├── common ├── __init__.py ├── transforms.py ├── scheduler.py ├── metric_tracking.py ├── mixup.py ├── utils.py └── runner.py ├── models ├── __init__.py ├── feature_mapping.py ├── base_model.py └── transformerblock.py ├── datasets ├── __init__.py ├── data.py ├── reader_fns.py └── epic_kitchens.py ├── conf ├── opt │ ├── optimizer │ │ ├── adam.yaml │ │ ├── adamW.yaml │ │ └── sgd.yaml │ └── scheduler │ │ ├── multi_step.yaml │ │ └── cosine.yaml ├── model │ ├── backbone │ │ └── identity.yaml │ ├── CMFP │ │ ├── cmfp_early.yaml │ │ ├── scorefusion.yaml │ │ └── individual.yaml │ ├── mapping │ │ ├── gatedlinear.yaml │ │ ├── linear.yaml │ │ └── nonlinear.yaml │ ├── fuser │ │ ├── MATT.yaml │ │ ├── CA-Fuser.yaml │ │ ├── SA-Fuser_wo_token.yaml │ │ ├── T-SA-Fuser.yaml │ │ └── SA-Fuser.yaml │ ├── future_predictor │ │ └── base_future_predictor.yaml │ └── common.yaml ├── .DS_Store ├── dataset │ ├── egtea │ │ ├── common.yaml │ │ ├── val.yaml │ │ └── train.yaml │ └── epic_kitchens100 │ │ ├── common.yaml │ │ ├── train.yaml │ │ ├── val.yaml │ │ └── test.yaml ├── data │ └── default.yaml └── config.yaml ├── fuser.png ├── expts ├── .DS_Store ├── 00_RGB_TSN_ek100_train.txt ├── 00_RGB_Swin_ek100_train.txt ├── 01_SA-Fuser_ek100_val_TSN_wo_audio.txt ├── 01_SA-Fuser_ek100_test_TSN_wo_audio.txt ├── 06_SA-Fuser_egtea_val.txt ├── 01_SA-Fuser_ek100_val_TSN.txt ├── 01_SA-Fuser_ek100_val_Swin.txt ├── 06_SA-Fuser_egtea_train.txt ├── 05_MATT_ek100_train.txt ├── 04_CA-Fuser_ek100_train.txt ├── 02_SA-Fuser_wo_token_ek100_train.txt ├── 01_SA-Fuser_ek100_train.txt └── 03_T-SA-Fuser_ek100_train.txt ├── annotations ├── .DS_Store ├── ek55_ori │ ├── .DS_Store │ ├── EPIC_test_s1_timestamps.pkl │ ├── EPIC_test_s2_timestamps.pkl │ ├── EPIC_train_action_labels.pkl │ ├── EPIC_many_shot_verbs.csv │ ├── EPIC_many_shot_nouns.csv │ └── EPIC_verb_classes.csv ├── ek100_ori │ ├── .DS_Store │ ├── EPIC_100_train.pkl │ ├── EPIC_100_validation.pkl │ ├── EPIC_100_test_timestamps.pkl │ └── EPIC_100_verb_classes.csv ├── ek55_rulstm │ ├── .DS_Store │ ├── EPIC_many_shot_verbs.csv │ ├── validation_videos.csv │ ├── EPIC_many_shot_nouns.csv │ └── training_videos.csv ├── ek100_rulstm │ ├── .DS_Store │ ├── validation_videos.csv │ └── training_videos.csv └── egtea │ └── actions.csv ├── checkpoints ├── fusion_egtea_tsn │ └── README.md ├── fusion_ek100_swin_4h_16s │ └── README.md ├── fusion_ek100_tsn_4h_18s │ └── README.md └── fusion_ek100_tsn_wo_audio_4h_18s │ └── README.md ├── logits └── README.md ├── run.py ├── .gitignore ├── environment.yml ├── test.py ├── README.md ├── LICENSE └── tmp.py /common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /conf/opt/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam -------------------------------------------------------------------------------- /conf/opt/optimizer/adamW.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 
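The optimizer configs above (adam.yaml, adamW.yaml, sgd.yaml) only declare a `_target_` class path plus keyword arguments; the actual `torch.optim` object is built by Hydra at runtime, with the remaining arguments (learning rate, parameter groups) supplied by the training code. A minimal sketch of that mechanism, assuming nothing about the repo's runner beyond standard `hydra.utils.instantiate` usage; the toy model and learning rate are placeholders:

```python
# Standalone sketch: how a `_target_`-style node such as conf/opt/optimizer/sgd.yaml
# becomes a real optimizer. Not the repo's own runner code.
import torch
from omegaconf import OmegaConf
from hydra.utils import instantiate

# Mirrors conf/opt/optimizer/sgd.yaml; lr normally comes from the opt.lr override.
opt_cfg = OmegaConf.create({
    "_target_": "torch.optim.SGD",
    "momentum": 0.9,
    "nesterov": False,
})

model = torch.nn.Linear(1024, 1024)  # stand-in for the real model
optimizer = instantiate(opt_cfg, params=model.parameters(), lr=1e-3)
print(type(optimizer))  # torch.optim.SGD
```

The expt files under expts/ select which of these yaml nodes is used via group overrides such as `opt/optimizer=sgd` and `opt/scheduler=cosine`, and then set individual fields like `opt.optimizer.nesterov=true`.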
-------------------------------------------------------------------------------- /conf/model/backbone/identity.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.Identity -------------------------------------------------------------------------------- /fuser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/fuser.png -------------------------------------------------------------------------------- /conf/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/conf/.DS_Store -------------------------------------------------------------------------------- /expts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/expts/.DS_Store -------------------------------------------------------------------------------- /conf/opt/scheduler/multi_step.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.MultiStepLR 2 | -------------------------------------------------------------------------------- /conf/opt/optimizer/sgd.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.SGD 2 | momentum: 0.9 3 | nesterov: false -------------------------------------------------------------------------------- /annotations/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/.DS_Store -------------------------------------------------------------------------------- /conf/model/CMFP/cmfp_early.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.future_prediction.CMFPEarly 2 | model_cfg: null -------------------------------------------------------------------------------- /conf/model/CMFP/scorefusion.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.future_prediction.CMFPScoreFusion 2 | model_cfg: null -------------------------------------------------------------------------------- /conf/model/mapping/gatedlinear.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.feature_mapping.GatedLinear 2 | use_layernorm: true -------------------------------------------------------------------------------- /annotations/ek55_ori/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek55_ori/.DS_Store -------------------------------------------------------------------------------- /conf/model/CMFP/individual.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.future_prediction.IndividualFuturePrediction 2 | model_cfg: null -------------------------------------------------------------------------------- /annotations/ek100_ori/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek100_ori/.DS_Store -------------------------------------------------------------------------------- /annotations/ek55_rulstm/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek55_rulstm/.DS_Store -------------------------------------------------------------------------------- /annotations/ek100_rulstm/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek100_rulstm/.DS_Store -------------------------------------------------------------------------------- /conf/model/mapping/linear.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.feature_mapping.Linear 2 | use_layernorm: false 3 | sparse_mapping: true -------------------------------------------------------------------------------- /conf/model/mapping/nonlinear.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.feature_mapping.NonLinear 2 | use_layernorm: true 3 | activation: relu -------------------------------------------------------------------------------- /annotations/ek100_ori/EPIC_100_train.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek100_ori/EPIC_100_train.pkl -------------------------------------------------------------------------------- /annotations/ek100_ori/EPIC_100_validation.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek100_ori/EPIC_100_validation.pkl -------------------------------------------------------------------------------- /annotations/ek55_ori/EPIC_test_s1_timestamps.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek55_ori/EPIC_test_s1_timestamps.pkl -------------------------------------------------------------------------------- /annotations/ek55_ori/EPIC_test_s2_timestamps.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek55_ori/EPIC_test_s2_timestamps.pkl -------------------------------------------------------------------------------- /checkpoints/fusion_egtea_tsn/README.md: -------------------------------------------------------------------------------- 1 | Please download the corresponding checkpoint from [Model Zoo](../../README.md#Model Zoo) 2 | and put it here. 
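The checkpoint READMEs only ask you to drop the downloaded `checkpoint_best.pth` into the matching folder; the eval expt files then reference it via `init_from_model=<folder>/checkpoint_best.pth`. A hedged sketch for sanity-checking such a file before running evaluation; the `'model'`/`'state_dict'` key lookup is an assumption about the checkpoint layout, so inspect the printed keys rather than relying on it:

```python
# Quick inspection of a downloaded checkpoint before pointing init_from_model at it.
# Sketch only: the exact dictionary layout is an assumption; check the printed keys.
import torch

ckpt_path = "checkpoints/fusion_ek100_tsn_4h_18s/checkpoint_best.pth"
ckpt = torch.load(ckpt_path, map_location="cpu")

state_dict = ckpt
if isinstance(ckpt, dict):
    print("top-level keys:", list(ckpt.keys()))
    state_dict = ckpt.get("model", ckpt.get("state_dict", ckpt))

# Peek at a few entries to verify fuser / future-predictor weights are present.
for name, value in list(state_dict.items())[:5]:
    shape = tuple(value.shape) if hasattr(value, "shape") else type(value).__name__
    print(name, shape)
```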
-------------------------------------------------------------------------------- /conf/model/fuser/MATT.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.fusion.MATT 2 | modal_dims: ${model.modal_dims} 3 | dim: ${model.common.in_features} 4 | drop_rate: 0.8 -------------------------------------------------------------------------------- /conf/opt/scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | _target_: common.scheduler.CosineLR 2 | num_epochs: ${train.num_epochs} 3 | eta_min: 1e-6 # Min LR (default) 4 | -------------------------------------------------------------------------------- /annotations/ek100_ori/EPIC_100_test_timestamps.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek100_ori/EPIC_100_test_timestamps.pkl -------------------------------------------------------------------------------- /annotations/ek55_ori/EPIC_train_action_labels.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeyun-zhong/AFFT/HEAD/annotations/ek55_ori/EPIC_train_action_labels.pkl -------------------------------------------------------------------------------- /checkpoints/fusion_ek100_swin_4h_16s/README.md: -------------------------------------------------------------------------------- 1 | Please download the corresponding checkpoint from [Model Zoo](../../README.md#Model Zoo) 2 | and put it here. -------------------------------------------------------------------------------- /checkpoints/fusion_ek100_tsn_4h_18s/README.md: -------------------------------------------------------------------------------- 1 | Please download the corresponding checkpoint from [Model Zoo](../../README.md#Model Zoo) 2 | and put it here. -------------------------------------------------------------------------------- /checkpoints/fusion_ek100_tsn_wo_audio_4h_18s/README.md: -------------------------------------------------------------------------------- 1 | Please download the corresponding checkpoint from [Model Zoo](../../README.md#Model Zoo) 2 | and put it here. -------------------------------------------------------------------------------- /conf/model/fuser/CA-Fuser.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.fusion.TemporalCrossAttentFuser 2 | dim: ${model.common.in_features} 3 | modalities: ${model.modal_dims} 4 | num_heads: 4 5 | embd_drop_rate: 0.1 6 | drop_rate: 0.1 7 | attn_drop_rate: 0.1 8 | drop_path_rate: 0.1 9 | -------------------------------------------------------------------------------- /conf/model/fuser/SA-Fuser_wo_token.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.fusion.CMFuser 2 | dim: ${model.common.in_features} 3 | depth: 6 4 | num_heads: 4 5 | embd_drop_rate: 0.1 6 | drop_rate: 0.1 7 | attn_drop_rate: 0.1 8 | drop_path_rate: 0.1 9 | cross_attn: false 10 | -------------------------------------------------------------------------------- /logits/README.md: -------------------------------------------------------------------------------- 1 | Logits that generated by running [test.py](../test.py) and 2 | the submission file for ek100 generated by [challenge.py](../challenge.py) will be saved here. 
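As noted above, test.py writes logits into this folder (the test expt uses `+save_name=test.h5`) and challenge.py builds the EK100 submission from them. A small hedged sketch for listing what a saved `.h5` file actually contains; the file name is an example and no dataset names are assumed, the snippet just walks the file:

```python
# List the contents of a logits file produced by test.py (file name is an example).
import h5py

with h5py.File("logits/test.h5", "r") as f:
    def show(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape, obj.dtype)
    f.visititems(show)
```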
3 | 4 | You can simply change the variable `LOGITS_DIR` in [challenge.py](../challenge.py) 5 | to assign a new path. -------------------------------------------------------------------------------- /conf/model/fuser/T-SA-Fuser.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.fusion.TemporalCMFuser 2 | dim: ${model.common.in_features} 3 | depth: 6 4 | num_heads: 4 5 | embd_drop_rate: 0.1 6 | drop_rate: 0.1 7 | attn_drop_rate: 0.1 8 | drop_path_rate: 0.1 9 | modalities: ${model.modal_dims} 10 | modal_encoding: true 11 | frame_level_token: false 12 | temporal_sequence_length: null -------------------------------------------------------------------------------- /annotations/ek55_ori/EPIC_many_shot_verbs.csv: -------------------------------------------------------------------------------- 1 | verb_class,verb 2 | 1,put 3 | 0,take 4 | 4,wash 5 | 2,open 6 | 3,close 7 | 5,cut 8 | 6,mix 9 | 7,pour 10 | 9,move 11 | 12,turn-on 12 | 10,remove 13 | 15,turn-off 14 | 8,throw 15 | 11,dry 16 | 16,peel 17 | 22,insert 18 | 13,turn 19 | 14,shake 20 | 21,squeeze 21 | 23,press 22 | 20,check 23 | 19,scoop 24 | 18,empty 25 | 17,adjust 26 | 24,fill 27 | 32,flip 28 | -------------------------------------------------------------------------------- /annotations/ek55_rulstm/EPIC_many_shot_verbs.csv: -------------------------------------------------------------------------------- 1 | verb_class,verb 2 | 1,put 3 | 0,take 4 | 4,wash 5 | 2,open 6 | 3,close 7 | 5,cut 8 | 6,mix 9 | 7,pour 10 | 9,move 11 | 12,turn-on 12 | 10,remove 13 | 15,turn-off 14 | 8,throw 15 | 11,dry 16 | 16,peel 17 | 22,insert 18 | 13,turn 19 | 14,shake 20 | 21,squeeze 21 | 23,press 22 | 20,check 23 | 19,scoop 24 | 18,empty 25 | 17,adjust 26 | 24,fill 27 | 32,flip 28 | -------------------------------------------------------------------------------- /conf/model/fuser/SA-Fuser.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.fusion.ModalTokenCMFuser 2 | dim: ${model.common.in_features} 3 | depth: 6 4 | num_heads: 4 5 | embd_drop_rate: 0.1 6 | drop_rate: 0.1 7 | attn_drop_rate: 0.1 8 | drop_path_rate: 0.1 9 | cross_attn: false 10 | norm_elementwise: true 11 | modalities: ${model.modal_dims} 12 | modal_encoding: false 13 | frame_level_token: false 14 | temporal_sequence_length: null -------------------------------------------------------------------------------- /conf/dataset/egtea/common.yaml: -------------------------------------------------------------------------------- 1 | # @package dataset.egtea.common 2 | 3 | version: -1 4 | # RULSTM feats dirs 5 | rulstm_feats_dir: ${dataset_root_dir}/egtea/features 6 | annot_dir: ${cwd}/annotations/egtea/ 7 | rulstm_annot_dir: ${cwd}/annotations/egtea/ 8 | label_type: action 9 | sample_strategy: "last_clip" 10 | tau_a: 0.5 11 | tau_o: 10 12 | split: 1 13 | compute_dataset_stats: false 14 | reader_fn: null 15 | max_els: null -------------------------------------------------------------------------------- /conf/model/future_predictor/base_future_predictor.yaml: -------------------------------------------------------------------------------- 1 | _target_: models.future_prediction.BaseFuturePredictor 2 | in_features: ${model.common.in_features} 3 | inter_dim: ${model.common.fp_inter_dim} 4 | n_layer: ${model.common.fp_layers} 5 | n_head: ${model.common.fp_heads} 6 | output_attentions: ${model.common.fp_output_attentions} 7 | embd_pdrop: ${model.common.embd_pdrop} 8 | resid_pdrop: 
${model.common.resid_pdrop} 9 | attn_pdrop: ${model.common.attn_pdrop} -------------------------------------------------------------------------------- /conf/dataset/epic_kitchens100/common.yaml: -------------------------------------------------------------------------------- 1 | # @package dataset.epic_kitchens100.common 2 | 3 | version: 0.2 4 | # RULSTM feats dirs 5 | rulstm_feats_dir: ${dataset_root_dir}/epickitchens100/features 6 | annot_dir: ${cwd}/annotations/ek100_ori/ 7 | rulstm_annot_dir: ${cwd}/annotations/ek100_rulstm/ 8 | label_type: action 9 | sample_strategy: "last_clip" 10 | tau_a: 1 11 | tau_o: 10 12 | compute_dataset_stats: false 13 | reader_fn: null 14 | max_els: null 15 | -------------------------------------------------------------------------------- /annotations/ek55_rulstm/validation_videos.csv: -------------------------------------------------------------------------------- 1 | P01_01 2 | P01_10 3 | P02_03 4 | P02_05 5 | P03_06 6 | P03_11 7 | P04_09 8 | P06_05 9 | P07_02 10 | P07_08 11 | P07_10 12 | P08_01 13 | P08_05 14 | P08_12 15 | P10_01 16 | P13_04 17 | P13_06 18 | P13_09 19 | P14_01 20 | P14_02 21 | P20_03 22 | P20_04 23 | P22_08 24 | P22_10 25 | P22_11 26 | P22_13 27 | P23_03 28 | P24_08 29 | P25_11 30 | P26_02 31 | P26_11 32 | P26_16 33 | P27_03 34 | P28_05 35 | P28_12 36 | P28_13 37 | P30_01 38 | P30_03 39 | P31_01 40 | P31_08 41 | -------------------------------------------------------------------------------- /conf/data/default.yaml: -------------------------------------------------------------------------------- 1 | 2 | # The top few go into the dataset object to load as per these 3 | num_frames: 10 4 | frame_rate: 1 5 | frame_subclips: 6 | num_frames: 1 7 | stride: 1 8 | sec_subclips: # allows to have different sequence of labels than the sequence of frames 9 | num_frames: 1 10 | stride: 1 11 | 12 | # Load segmentation labels only if a classifier on the past is being applied 13 | load_seg_labels: true 14 | 15 | # Augmentation for RULSTM Feats, only for training 16 | zero_mask_rate: 0.0 17 | -------------------------------------------------------------------------------- /conf/model/common.yaml: -------------------------------------------------------------------------------- 1 | # @package model.common 2 | 3 | in_features: ${model.common_dim} 4 | 5 | # boolean options controlling future predictor and classifier 6 | share_classifiers: true # whether a common classifier should be used 7 | share_predictors: false # whether a common future predictor should be used 8 | modality_cls: false # whether modality-wise classification 9 | fusion_cls: true # whether the fused features should be classified 10 | 11 | # backbones (identity layer for feature vectors) 12 | backbones: null 13 | 14 | # for base future predictor 15 | fp_output_len: 1 16 | fp_inter_dim: 2048 17 | fp_layers: 6 18 | fp_heads: 4 19 | fp_output_attentions: false 20 | embd_pdrop: 0.1 21 | resid_pdrop: 0.1 22 | attn_pdrop: 0.1 23 | -------------------------------------------------------------------------------- /common/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | 5 | class PermuteRULSTMFeats: 6 | def __init__(self): 7 | pass 8 | 9 | def __call__(self, vid): 10 | return vid.permute(3, 0, 1, 2) 11 | 12 | 13 | class ZeroMaskRULSTMFeats: 14 | """Mask random frames with zeros""" 15 | def __init__(self, mask_rate=0.2): 16 | self.mask_rate = mask_rate 17 | 18 | def __call__(self, vid): 19 | if self.mask_rate == 0: 
20 | return vid 21 | num_frames = vid.size(0) 22 | num_masked_frames = round(num_frames * self.mask_rate) 23 | random_choices = random.sample(range(num_frames), num_masked_frames) 24 | vid[random_choices, :, :, :] = torch.zeros((num_masked_frames, vid.size(1), vid.size(2), vid.size(-1))) 25 | return vid 26 | -------------------------------------------------------------------------------- /datasets/data.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | 3 | 4 | def get_dataset(dataset_cfg, data_cfg, transforms, logger): 5 | kwargs = {} 6 | kwargs['transforms'] = transforms 7 | kwargs['frame_rate'] = data_cfg.frame_rate 8 | kwargs['frames_per_clip'] = data_cfg.num_frames 9 | # Have to call dict() here since relative interpolation somehow doesn't work once I get the subclips object 10 | kwargs['frame_subclips_options'] = dict(data_cfg.frame_subclips) 11 | kwargs['sec_subclips_options'] = dict(data_cfg.sec_subclips) 12 | kwargs['load_seg_labels'] = data_cfg.load_seg_labels 13 | logger.info('Creating the dataset object...') 14 | # Not recursive since many of the sub-instantiations would need positional arguments 15 | _dataset = hydra.utils.instantiate(dataset_cfg, _recursive_=False, **kwargs) 16 | logger.info(f'Created dataset with {len(_dataset)} elts') 17 | return _dataset 18 | -------------------------------------------------------------------------------- /conf/dataset/egtea/val.yaml: -------------------------------------------------------------------------------- 1 | # @package dataset.egtea.val 2 | 3 | _target_: datasets.epic_kitchens.EPICKitchens 4 | version: ${dataset.egtea.common.version} 5 | annotation_path: 6 | - ${dataset.egtea.common.annot_dir}/validation${dataset.egtea.common.split}.csv 7 | annotation_dir: ${dataset.egtea.common.annot_dir} 8 | rulstm_annotation_dir: ${dataset.egtea.common.rulstm_annot_dir} # Needed during computing final outputs to get tail classes etc. 9 | label_type: ${dataset.egtea.common.label_type} 10 | sample_strategy: ${dataset.egtea.common.sample_strategy} 11 | action_labels_fpath: ${dataset.egtea.common.rulstm_annot_dir}/actions.csv 12 | compute_dataset_stats: ${dataset.egtea.common.compute_dataset_stats} 13 | conv_to_anticipate_fn: 14 | _target_: datasets.base_video_dataset.convert_to_anticipation 15 | tau_a: ${dataset.egtea.common.tau_a} 16 | tau_o: ${dataset.egtea.common.tau_o} 17 | drop_style: correct 18 | reader_fn: ${dataset.egtea.common.reader_fn} 19 | max_els: ${dataset.egtea.common.max_els} -------------------------------------------------------------------------------- /conf/dataset/egtea/train.yaml: -------------------------------------------------------------------------------- 1 | # @package dataset.egtea.train 2 | 3 | _target_: datasets.epic_kitchens.EPICKitchens 4 | version: ${dataset.egtea.common.version} 5 | annotation_path: 6 | - ${dataset.egtea.common.annot_dir}/training${dataset.egtea.common.split}.csv 7 | annotation_dir: ${dataset.egtea.common.annot_dir} 8 | rulstm_annotation_dir: ${dataset.egtea.common.rulstm_annot_dir} # Needed during computing final outputs to get tail classes etc. 
9 | label_type: ${dataset.egtea.common.label_type} 10 | sample_strategy: ${dataset.egtea.common.sample_strategy} 11 | action_labels_fpath: ${dataset.egtea.common.rulstm_annot_dir}/actions.csv 12 | compute_dataset_stats: ${dataset.egtea.common.compute_dataset_stats} 13 | conv_to_anticipate_fn: 14 | _target_: datasets.base_video_dataset.convert_to_anticipation 15 | tau_a: ${dataset.egtea.common.tau_a} 16 | tau_o: ${dataset.egtea.common.tau_o} 17 | drop_style: correct 18 | reader_fn: ${dataset.egtea.common.reader_fn} 19 | max_els: ${dataset.egtea.common.max_els} -------------------------------------------------------------------------------- /annotations/ek55_ori/EPIC_many_shot_nouns.csv: -------------------------------------------------------------------------------- 1 | noun_class,noun 2 | 3,tap 3 | 4,plate 4 | 8,cupboard 5 | 1,pan 6 | 7,spoon 7 | 5,knife 8 | 9,drawer 9 | 10,fridge 10 | 6,bowl 11 | 12,hand 12 | 11,lid 13 | 13,onion 14 | 16,glass 15 | 23,cup 16 | 17,water 17 | 19,board:chopping 18 | 21,sponge 19 | 18,fork 20 | 32,cloth 21 | 20,bag 22 | 28,bottle 23 | 15,pot 24 | 22,spatula 25 | 39,box 26 | 26,meat 27 | 24,oil 28 | 30,tomato 29 | 31,salt 30 | 29,container 31 | 27,potato 32 | 77,package 33 | 37,food 34 | 47,hob 35 | 35,pasta 36 | 78,top 37 | 40,carrot 38 | 45,garlic 39 | 68,skin 40 | 44,rice 41 | 25,bin 42 | 38,kettle 43 | 46,pepper 44 | 33,sink 45 | 51,cheese 46 | 56,oven 47 | 70,liquid:washing 48 | 58,coffee 49 | 52,bread 50 | 108,rubbish 51 | 67,peach 52 | 42,colander 53 | 41,sauce 54 | 54,salad 55 | 126,maker:coffee 56 | 60,jar 57 | 84,sausage 58 | 75,cutlery 59 | 43,milk 60 | 62,chicken 61 | 50,egg 62 | 59,filter 63 | 55,microwave 64 | 49,dishwasher 65 | 87,can 66 | 48,dough 67 | 63,tray 68 | 72,leaf 69 | 105,jug 70 | 106,heat 71 | 79,spice 72 | 111,stock 73 | -------------------------------------------------------------------------------- /annotations/ek55_rulstm/EPIC_many_shot_nouns.csv: -------------------------------------------------------------------------------- 1 | noun_class,noun 2 | 3,tap 3 | 4,plate 4 | 8,cupboard 5 | 1,pan 6 | 7,spoon 7 | 5,knife 8 | 9,drawer 9 | 10,fridge 10 | 6,bowl 11 | 12,hand 12 | 11,lid 13 | 13,onion 14 | 16,glass 15 | 23,cup 16 | 17,water 17 | 19,board:chopping 18 | 21,sponge 19 | 18,fork 20 | 32,cloth 21 | 20,bag 22 | 28,bottle 23 | 15,pot 24 | 22,spatula 25 | 39,box 26 | 26,meat 27 | 24,oil 28 | 30,tomato 29 | 31,salt 30 | 29,container 31 | 27,potato 32 | 77,package 33 | 37,food 34 | 47,hob 35 | 35,pasta 36 | 78,top 37 | 40,carrot 38 | 45,garlic 39 | 68,skin 40 | 44,rice 41 | 25,bin 42 | 38,kettle 43 | 46,pepper 44 | 33,sink 45 | 51,cheese 46 | 56,oven 47 | 70,liquid:washing 48 | 58,coffee 49 | 52,bread 50 | 108,rubbish 51 | 67,peach 52 | 42,colander 53 | 41,sauce 54 | 54,salad 55 | 126,maker:coffee 56 | 60,jar 57 | 84,sausage 58 | 75,cutlery 59 | 43,milk 60 | 62,chicken 61 | 50,egg 62 | 59,filter 63 | 55,microwave 64 | 49,dishwasher 65 | 87,can 66 | 48,dough 67 | 63,tray 68 | 72,leaf 69 | 105,jug 70 | 106,heat 71 | 79,spice 72 | 111,stock 73 | -------------------------------------------------------------------------------- /conf/dataset/epic_kitchens100/train.yaml: -------------------------------------------------------------------------------- 1 | # @package dataset.epic_kitchens100.train 2 | 3 | _target_: datasets.epic_kitchens.EPICKitchens 4 | version: ${dataset.epic_kitchens100.common.version} 5 | annotation_path: 6 | - ${dataset.epic_kitchens100.common.annot_dir}/EPIC_100_train.pkl 7 | annotation_dir: 
${dataset.epic_kitchens100.common.annot_dir} 8 | rulstm_annotation_dir: ${dataset.epic_kitchens100.common.rulstm_annot_dir} # Needed during computing final outputs to get tail classes etc. 9 | label_type: ${dataset.epic_kitchens100.common.label_type} 10 | sample_strategy: ${dataset.epic_kitchens100.common.sample_strategy} 11 | action_labels_fpath: ${dataset.epic_kitchens100.common.rulstm_annot_dir}/actions.csv 12 | compute_dataset_stats: ${dataset.epic_kitchens100.common.compute_dataset_stats} 13 | conv_to_anticipate_fn: 14 | _target_: datasets.base_video_dataset.convert_to_anticipation 15 | tau_a: ${dataset.epic_kitchens100.common.tau_a} 16 | tau_o: ${dataset.epic_kitchens100.common.tau_o} 17 | drop_style: correct 18 | reader_fn: ${dataset.epic_kitchens100.common.reader_fn} 19 | max_els: ${dataset.epic_kitchens100.common.max_els} -------------------------------------------------------------------------------- /conf/dataset/epic_kitchens100/val.yaml: -------------------------------------------------------------------------------- 1 | # @package dataset.epic_kitchens100.val 2 | 3 | _target_: datasets.epic_kitchens.EPICKitchens 4 | version: ${dataset.epic_kitchens100.common.version} 5 | annotation_path: 6 | - ${dataset.epic_kitchens100.common.annot_dir}/EPIC_100_validation.pkl 7 | annotation_dir: ${dataset.epic_kitchens100.common.annot_dir} 8 | rulstm_annotation_dir: ${dataset.epic_kitchens100.common.rulstm_annot_dir} # Needed during computing final outputs to get tail classes etc. 9 | label_type: ${dataset.epic_kitchens100.common.label_type} 10 | sample_strategy: ${dataset.epic_kitchens100.common.sample_strategy} 11 | action_labels_fpath: ${dataset.epic_kitchens100.common.rulstm_annot_dir}/actions.csv 12 | compute_dataset_stats: ${dataset.epic_kitchens100.common.compute_dataset_stats} 13 | conv_to_anticipate_fn: 14 | _target_: datasets.base_video_dataset.convert_to_anticipation 15 | tau_a: ${dataset.epic_kitchens100.common.tau_a} 16 | tau_o: ${dataset.epic_kitchens100.common.tau_o} 17 | drop_style: correct 18 | reader_fn: ${dataset.epic_kitchens100.common.reader_fn} 19 | max_els: ${dataset.epic_kitchens100.common.max_els} -------------------------------------------------------------------------------- /conf/dataset/epic_kitchens100/test.yaml: -------------------------------------------------------------------------------- 1 | # @package dataset.epic_kitchens100.train 2 | 3 | _target_: datasets.epic_kitchens.EPICKitchens 4 | version: ${dataset.epic_kitchens100.common.version} 5 | annotation_path: 6 | - ${dataset.epic_kitchens100.common.annot_dir}/EPIC_100_test_timestamps.pkl 7 | annotation_dir: ${dataset.epic_kitchens100.common.annot_dir} 8 | rulstm_annotation_dir: ${dataset.epic_kitchens100.common.rulstm_annot_dir} # Needed during computing final outputs to get tail classes etc. 
9 | label_type: ${dataset.epic_kitchens100.common.label_type} 10 | sample_strategy: ${dataset.epic_kitchens100.common.sample_strategy} 11 | action_labels_fpath: ${dataset.epic_kitchens100.common.rulstm_annot_dir}/actions.csv 12 | compute_dataset_stats: ${dataset.epic_kitchens100.common.compute_dataset_stats} 13 | conv_to_anticipate_fn: 14 | _target_: datasets.base_video_dataset.convert_to_anticipation 15 | tau_a: ${dataset.epic_kitchens100.common.tau_a} 16 | tau_o: ${dataset.epic_kitchens100.common.tau_o} 17 | drop_style: correct 18 | reader_fn: ${dataset.epic_kitchens100.common.reader_fn} 19 | max_els: ${dataset.epic_kitchens100.common.max_els} -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | 4 | 5 | def parse_args(): 6 | """Parse arguments""" 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-c', '--cfg', type=str, required=True, 9 | help='Overrides config file') 10 | parser.add_argument('-m', '--mode', type=str, required=True, choices=['train', 'test', 'visualize_attention'], 11 | help='Choose which file to run') 12 | parser.add_argument('-n', '--nproc_per_node', type=int, default=4, required=True, 13 | help='number of gpus per node') 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def read_file_into_cli(fpath): 19 | """Read cli from file into a string.""" 20 | res = [] 21 | with open(fpath, 'r') as fin: 22 | for line in fin: 23 | args = line.split('#')[0].strip() 24 | if len(args) == 0: 25 | continue 26 | res.append(args) 27 | return res 28 | 29 | 30 | def escape_str(input_str): 31 | return f"'{input_str}'" 32 | 33 | 34 | def construct_cmd(args): 35 | if args.cfg: 36 | assert args.cfg.startswith("expts"), "Must be wrt this directory" 37 | 38 | cli_stuff = read_file_into_cli(args.cfg) 39 | cli_stuff = [escape_str(el) for el in cli_stuff] 40 | cli_stuff = ' '.join(cli_stuff) 41 | 42 | cli = (f'HYDRA_FULL_ERROR=1 torchrun --nproc_per_node={args.nproc_per_node} {args.mode}.py ') 43 | cli += cli_stuff 44 | return cli 45 | 46 | 47 | def main(): 48 | args = parse_args() 49 | cmd = construct_cmd(args) 50 | print('>> Running "{}"'.format(cmd)) 51 | subprocess.call(cmd, shell=True) 52 | 53 | 54 | if __name__ == "__main__": 55 | main() -------------------------------------------------------------------------------- /annotations/ek100_rulstm/validation_videos.csv: -------------------------------------------------------------------------------- 1 | P01_11 2 | P01_12 3 | P01_13 4 | P01_14 5 | P01_15 6 | P02_12 7 | P02_13 8 | P02_14 9 | P02_15 10 | P03_21 11 | P03_22 12 | P03_23 13 | P03_24 14 | P03_25 15 | P03_26 16 | P04_24 17 | P04_25 18 | P04_26 19 | P04_27 20 | P04_28 21 | P04_29 22 | P04_30 23 | P04_31 24 | P04_32 25 | P04_33 26 | P05_07 27 | P05_09 28 | P06_10 29 | P06_11 30 | P06_12 31 | P06_13 32 | P06_14 33 | P07_12 34 | P07_13 35 | P07_14 36 | P07_15 37 | P07_16 38 | P07_17 39 | P07_18 40 | P08_09 41 | P08_10 42 | P08_14 43 | P08_15 44 | P08_16 45 | P08_17 46 | P09_07 47 | P09_08 48 | P10_03 49 | P11_17 50 | P11_18 51 | P11_19 52 | P11_20 53 | P11_21 54 | P11_22 55 | P11_23 56 | P11_24 57 | P12_03 58 | P12_08 59 | P13_01 60 | P13_02 61 | P13_03 62 | P14_06 63 | P14_08 64 | P15_04 65 | P15_05 66 | P15_06 67 | P16_04 68 | P17_02 69 | P18_01 70 | P18_02 71 | P18_03 72 | P18_04 73 | P18_05 74 | P18_06 75 | P18_07 76 | P18_08 77 | P18_09 78 | P18_10 79 | P18_11 80 | P18_12 81 | P19_05 82 | P19_06 83 
| P20_05 84 | P20_06 85 | P20_07 86 | P21_02 87 | P22_01 88 | P22_02 89 | P22_03 90 | P22_04 91 | P23_05 92 | P24_09 93 | P25_06 94 | P25_07 95 | P25_08 96 | P26_30 97 | P26_31 98 | P26_32 99 | P26_33 100 | P26_34 101 | P26_35 102 | P26_36 103 | P26_37 104 | P26_38 105 | P26_39 106 | P26_40 107 | P26_41 108 | P27_05 109 | P28_15 110 | P28_16 111 | P28_17 112 | P28_18 113 | P28_19 114 | P28_20 115 | P28_21 116 | P28_22 117 | P28_23 118 | P28_24 119 | P28_25 120 | P28_26 121 | P29_05 122 | P29_06 123 | P30_07 124 | P30_08 125 | P30_09 126 | P31_10 127 | P31_11 128 | P31_12 129 | P32_01 130 | P32_02 131 | P32_03 132 | P32_04 133 | P32_05 134 | P32_06 135 | P32_07 136 | P32_08 137 | P32_09 138 | P32_10 139 | -------------------------------------------------------------------------------- /expts/00_RGB_TSN_ek100_train.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=2 3 | experiment_name=TSN_fp6l4h2048_bs32_lr0.001_mixupbackbone-0.1 4 | init_from_model=null 5 | primary_metric=val_mt5r_action_rgb 6 | 7 | train.batch_size=16 8 | eval.batch_size=16 9 | train.num_epochs=50 10 | train.use_mixup=true 11 | train.mixup_backbone=true 12 | train.mixup_alpha=0.1 13 | 14 | model.modal_dims={rgb:1024} 15 | model.common_dim=1024 16 | model.dropout=0.2 17 | model.common.backbones={rgb: {_target_: torch.nn.Identity}} 18 | model/future_predictor=base_future_predictor 19 | model/CMFP=individual 20 | 21 | model.common.share_classifiers=false 22 | model.common.share_predictors=false 23 | model.common.modality_cls=true 24 | model.common.fusion_cls=false 25 | 26 | model.common.fp_output_len=1 27 | model.common.fp_inter_dim=2048 28 | model.common.fp_layers=6 29 | model.common.fp_heads=4 30 | model.common.fp_output_attentions=false 31 | model.common.embd_pdrop=0.1 32 | model.common.resid_pdrop=0.1 33 | model.common.attn_pdrop=0.1 34 | 35 | opt.lr=0.001 36 | opt.wd=0.000001 37 | opt/optimizer=sgd 38 | opt/scheduler=cosine 39 | opt.optimizer.nesterov=true 40 | opt.warmup.num_epochs=20 41 | opt.scheduler.num_epochs=30 42 | opt.scheduler.eta_min=1e-6 43 | 44 | data_train.zero_mask_rate=0. 
45 | 46 | dataset@dataset_train=epic_kitchens100/train 47 | dataset@dataset_eval=epic_kitchens100/val 48 | dataset.epic_kitchens100.common.label_type=action 49 | dataset.epic_kitchens100.common.sample_strategy=last_clip 50 | dataset.epic_kitchens100.common.tau_a=1 51 | dataset.epic_kitchens100.common.tau_o=10 52 | dataset.epic_kitchens100.common.compute_dataset_stats=true 53 | dataset.epic_kitchens100.common.max_els=null 54 | 55 | dataset.epic_kitchens100.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb/}} -------------------------------------------------------------------------------- /expts/00_RGB_Swin_ek100_train.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=2 3 | experiment_name=Swin_fp6l4h2048_bs32_lr0.001_mixupbackbone-0.1 4 | init_from_model=null 5 | primary_metric=val_mt5r_action_rgb 6 | 7 | train.batch_size=16 8 | eval.batch_size=16 9 | train.num_epochs=50 10 | train.use_mixup=true 11 | train.mixup_backbone=true 12 | train.mixup_alpha=0.1 13 | 14 | model.modal_dims={rgb:1024} 15 | model.common_dim=1024 16 | model.dropout=0.2 17 | model.common.backbones={rgb: {_target_: torch.nn.Identity}} 18 | model/future_predictor=base_future_predictor 19 | model/CMFP=individual 20 | 21 | model.common.share_classifiers=false 22 | model.common.share_predictors=false 23 | model.common.modality_cls=true 24 | model.common.fusion_cls=false 25 | 26 | model.common.fp_output_len=1 27 | model.common.fp_inter_dim=2048 28 | model.common.fp_layers=6 29 | model.common.fp_heads=4 30 | model.common.fp_output_attentions=false 31 | model.common.embd_pdrop=0.1 32 | model.common.resid_pdrop=0.1 33 | model.common.attn_pdrop=0.1 34 | 35 | opt.lr=0.001 36 | opt.wd=0.000001 37 | opt/optimizer=sgd 38 | opt/scheduler=cosine 39 | opt.optimizer.nesterov=true 40 | opt.warmup.num_epochs=20 41 | opt.scheduler.num_epochs=30 42 | opt.scheduler.eta_min=1e-6 43 | 44 | data_train.zero_mask_rate=0. 
45 | 46 | dataset@dataset_train=epic_kitchens100/train 47 | dataset@dataset_eval=epic_kitchens100/val 48 | dataset.epic_kitchens100.common.label_type=action 49 | dataset.epic_kitchens100.common.sample_strategy=last_clip 50 | dataset.epic_kitchens100.common.tau_a=1 51 | dataset.epic_kitchens100.common.tau_o=10 52 | dataset.epic_kitchens100.common.compute_dataset_stats=true 53 | dataset.epic_kitchens100.common.max_els=null 54 | 55 | dataset.epic_kitchens100.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb_omnivore/}} -------------------------------------------------------------------------------- /expts/01_SA-Fuser_ek100_val_TSN_wo_audio.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=1 3 | init_from_model=fusion_ek100_tsn_wo_audio_4h_18s/checkpoint_best.pth 4 | 5 | train.batch_size=32 6 | eval.batch_size=32 7 | 8 | model.modal_dims={rgb:1024,objects:352,flow:1024} 9 | model.common_dim=1024 10 | model.dropout=0.2 11 | model.common.backbones={rgb:{_target_:torch.nn.Identity},objects:{_target_:torch.nn.Identity},flow:{_target_:torch.nn.Identity}} 12 | model/future_predictor=base_future_predictor 13 | model/fuser=SA-Fuser 14 | model/CMFP=cmfp_early 15 | model/mapping=linear 16 | 17 | model.common.share_classifiers=true 18 | model.common.share_predictors=true 19 | model.common.modality_cls=false 20 | model.common.fusion_cls=true 21 | 22 | model.mapping.use_layernorm=false 23 | model.mapping.sparse_mapping=true 24 | 25 | model.fuser.depth=6 26 | model.fuser.num_heads=4 27 | model.fuser.embd_drop_rate=0.1 28 | model.fuser.drop_rate=0.1 29 | model.fuser.attn_drop_rate=0.1 30 | model.fuser.drop_path_rate=0.1 31 | model.fuser.cross_attn=false 32 | 33 | data_train.num_frames=18 34 | data_eval.num_frames=18 35 | 36 | dataset@dataset_train=epic_kitchens100/train 37 | dataset@dataset_eval=epic_kitchens100/val 38 | dataset.epic_kitchens100.common.label_type=action 39 | dataset.epic_kitchens100.common.sample_strategy=last_clip 40 | dataset.epic_kitchens100.common.tau_a=1 41 | dataset.epic_kitchens100.common.tau_o=18 42 | dataset.epic_kitchens100.common.compute_dataset_stats=false 43 | dataset.epic_kitchens100.common.max_els=null 44 | 45 | dataset.epic_kitchens100.common.reader_fn={rgb:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb/},objects:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/},flow:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/}} -------------------------------------------------------------------------------- /expts/01_SA-Fuser_ek100_test_TSN_wo_audio.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=1 3 | init_from_model=fusion_ek100_tsn_wo_audio_4h_18s/checkpoint_best.pth 4 | +save_name=test.h5 5 | 6 | train.batch_size=32 7 | eval.batch_size=32 8 | 9 | model.modal_dims={rgb:1024,objects:352,flow:1024} 10 | model.common_dim=1024 11 | model.dropout=0.2 12 | model.common.backbones={rgb:{_target_:torch.nn.Identity},objects:{_target_:torch.nn.Identity},flow:{_target_:torch.nn.Identity}} 13 | model/future_predictor=base_future_predictor 14 | model/fuser=mtcmfuser 15 | model/CMFP=cmfp_early 16 | model/mapping=linear 17 | 18 | model.common.share_classifiers=true 19 | 
model.common.share_predictors=true 20 | model.common.modality_cls=false 21 | model.common.fusion_cls=true 22 | 23 | model.mapping.use_layernorm=false 24 | model.mapping.sparse_mapping=true 25 | 26 | model.fuser.depth=6 27 | model.fuser.num_heads=4 28 | model.fuser.embd_drop_rate=0.1 29 | model.fuser.drop_rate=0.1 30 | model.fuser.attn_drop_rate=0.1 31 | model.fuser.drop_path_rate=0.1 32 | model.fuser.cross_attn=false 33 | 34 | data_train.num_frames=18 35 | data_eval.num_frames=18 36 | 37 | dataset@dataset_train=epic_kitchens100/train 38 | dataset@dataset_eval=epic_kitchens100/test 39 | dataset.epic_kitchens100.common.label_type=action 40 | dataset.epic_kitchens100.common.sample_strategy=last_clip 41 | dataset.epic_kitchens100.common.tau_a=1 42 | dataset.epic_kitchens100.common.tau_o=18 43 | dataset.epic_kitchens100.common.compute_dataset_stats=false 44 | dataset.epic_kitchens100.common.max_els=null 45 | 46 | dataset.epic_kitchens100.common.reader_fn={rgb:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb/},objects:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/},flow:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/}} -------------------------------------------------------------------------------- /expts/06_SA-Fuser_egtea_val.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=1 3 | init_from_model=fusion_egtea_tsn/checkpoint_best.pth 4 | 5 | train.batch_size=32 6 | eval.batch_size=32 7 | 8 | model.modal_dims={rgb:1024, flow:1024} 9 | model.common_dim=1024 10 | model.dropout=0.2 11 | model.common.backbones={rgb: {_target_: torch.nn.Identity}, flow: {_target_: torch.nn.Identity}} 12 | model/future_predictor=base_future_predictor 13 | model/fuser=SA-Fuser 14 | model/CMFP=cmfp_early 15 | model/mapping=linear 16 | 17 | model.common.share_classifiers=true 18 | model.common.share_predictors=true 19 | model.common.modality_cls=false 20 | model.common.fusion_cls=true 21 | 22 | model.mapping.use_layernorm=false 23 | model.mapping.sparse_mapping=true 24 | 25 | model.fuser.depth=2 26 | model.fuser.num_heads=4 27 | model.fuser.embd_drop_rate=0.1 28 | model.fuser.drop_rate=0.1 29 | model.fuser.attn_drop_rate=0.1 30 | model.fuser.drop_path_rate=0.1 31 | model.fuser.cross_attn=false 32 | 33 | model.common.fp_output_len=1 34 | model.common.fp_inter_dim=2048 35 | model.common.fp_layers=2 36 | model.common.fp_heads=4 37 | model.common.fp_output_attentions=false 38 | model.common.embd_pdrop=0.1 39 | model.common.resid_pdrop=0.1 40 | model.common.attn_pdrop=0.1 41 | 42 | data_train.zero_mask_rate=0.0 43 | 44 | dataset@dataset_train=egtea/train 45 | dataset@dataset_eval=egtea/val 46 | dataset.egtea.common.label_type=action 47 | dataset.egtea.common.sample_strategy=last_clip 48 | dataset.egtea.common.tau_a=0.5 49 | dataset.egtea.common.tau_o=10 50 | dataset.egtea.common.compute_dataset_stats=false 51 | dataset.egtea.common.max_els=null 52 | 53 | dataset.egtea.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.egtea.common.rulstm_feats_dir}/TSN-C_3_egtea_action_CE_s${dataset.egtea.common.split}_rgb_model_best_fcfull_hd/}, flow: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: 
${dataset.egtea.common.rulstm_feats_dir}/TSN-C_3_egtea_action_CE_s${dataset.egtea.common.split}_flow_model_best_fcfull_hd/}} -------------------------------------------------------------------------------- /expts/01_SA-Fuser_ek100_val_TSN.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=1 3 | init_from_model=fusion_ek100_tsn_4h_18s/checkpoint_best.pth 4 | 5 | train.batch_size=32 6 | eval.batch_size=32 7 | 8 | model.modal_dims={rgb:1024,objects:352,audio:1024,flow:1024} 9 | model.common_dim=1024 10 | model.dropout=0.2 11 | model.common.backbones={rgb:{_target_:torch.nn.Identity},objects:{_target_:torch.nn.Identity},flow:{_target_:torch.nn.Identity},audio:{_target_:torch.nn.Identity}} 12 | model/future_predictor=base_future_predictor 13 | model/fuser=SA-Fuser 14 | model/CMFP=cmfp_early 15 | model/mapping=linear 16 | 17 | model.common.share_classifiers=true 18 | model.common.share_predictors=true 19 | model.common.modality_cls=false 20 | model.common.fusion_cls=true 21 | 22 | model.mapping.use_layernorm=false 23 | model.mapping.sparse_mapping=true 24 | 25 | model.fuser.depth=6 26 | model.fuser.num_heads=4 27 | model.fuser.embd_drop_rate=0.1 28 | model.fuser.drop_rate=0.1 29 | model.fuser.attn_drop_rate=0.1 30 | model.fuser.drop_path_rate=0.1 31 | model.fuser.cross_attn=false 32 | 33 | data_train.num_frames=18 34 | data_eval.num_frames=18 35 | 36 | dataset@dataset_train=epic_kitchens100/train 37 | dataset@dataset_eval=epic_kitchens100/val 38 | dataset.epic_kitchens100.common.label_type=action 39 | dataset.epic_kitchens100.common.sample_strategy=last_clip 40 | dataset.epic_kitchens100.common.tau_a=1 41 | dataset.epic_kitchens100.common.tau_o=18 42 | dataset.epic_kitchens100.common.compute_dataset_stats=false 43 | dataset.epic_kitchens100.common.max_els=null 44 | 45 | dataset.epic_kitchens100.common.reader_fn={rgb:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb/},objects:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/},flow:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/},audio:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/audio/,warn_if_using_closeby_frame:false}} -------------------------------------------------------------------------------- /expts/01_SA-Fuser_ek100_val_Swin.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=1 3 | init_from_model=fusion_ek100_swin_4h_16s/checkpoint_best.pth 4 | 5 | train.batch_size=32 6 | eval.batch_size=32 7 | 8 | model.modal_dims={rgb:1024,objects:352,audio:1024,flow:1024} 9 | model.common_dim=1024 10 | model.dropout=0.2 11 | model.common.backbones={rgb:{_target_:torch.nn.Identity},objects:{_target_:torch.nn.Identity},flow:{_target_:torch.nn.Identity},audio:{_target_:torch.nn.Identity}} 12 | model/future_predictor=base_future_predictor 13 | model/fuser=SA-Fuser 14 | model/CMFP=cmfp_early 15 | model/mapping=linear 16 | 17 | model.common.share_classifiers=true 18 | model.common.share_predictors=true 19 | model.common.modality_cls=false 20 | model.common.fusion_cls=true 21 | 22 | model.mapping.use_layernorm=false 23 | model.mapping.sparse_mapping=true 24 | 25 | model.fuser.depth=6 26 | model.fuser.num_heads=4 27 | model.fuser.embd_drop_rate=0.1 28 | 
model.fuser.drop_rate=0.1 29 | model.fuser.attn_drop_rate=0.1 30 | model.fuser.drop_path_rate=0.1 31 | model.fuser.cross_attn=false 32 | 33 | data_train.num_frames=16 34 | data_eval.num_frames=16 35 | 36 | dataset@dataset_train=epic_kitchens100/train 37 | dataset@dataset_eval=epic_kitchens100/val 38 | dataset.epic_kitchens100.common.label_type=action 39 | dataset.epic_kitchens100.common.sample_strategy=last_clip 40 | dataset.epic_kitchens100.common.tau_a=1 41 | dataset.epic_kitchens100.common.tau_o=16 42 | dataset.epic_kitchens100.common.compute_dataset_stats=false 43 | dataset.epic_kitchens100.common.max_els=null 44 | 45 | dataset.epic_kitchens100.common.reader_fn={rgb:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb_omnivore/},objects:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/},flow:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/},audio:{_target_:datasets.reader_fns.EpicRULSTMFeatsReader,lmdb_path:${dataset.epic_kitchens100.common.rulstm_feats_dir}/audio/,warn_if_using_closeby_frame:false}} -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | cwd: ${hydra:runtime.cwd} 2 | workers: 4 3 | num_gpus: 2 4 | seed: 42 5 | project_name: Anticipation 6 | experiment_name: CMFuser 7 | init_from_model: null 8 | dataset_root_dir: /home/zhong/Documents/datasets 9 | primary_metric: val_mt5r_action_all-fused 10 | dist_backend: nccl 11 | temporal_context: 10 12 | 13 | train: 14 | batch_size: 3 15 | num_epochs: 50 16 | use_mixup: true 17 | mixup_backbone: true # whether to mixup inputs or the backbone outputs 18 | mixup_alpha: 0.1 # this value is from vivit: https://github.com/google-research/scenic/blob/main/scenic/projects/vivit/configs/epic_kitchens/vivit_large_factorised_encoder.py 19 | label_smoothing: 20 | action: 0.4 21 | verb: 0.01 22 | noun: 0.03 23 | modules_to_keep: null 24 | loss_wts: 25 | # classification for future action 26 | cls_action: 1.0 27 | cls_verb: 1.0 28 | cls_noun: 1.0 29 | # classification for updated past action 30 | past_cls_action: 1.0 31 | past_cls_verb: 1.0 32 | past_cls_noun: 1.0 33 | # regression for updated past feature 34 | past_reg: 1.0 35 | 36 | eval: 37 | batch_size: 3 38 | 39 | model: 40 | modal_dims: null #{"rgb": 1024, "objects": 352} # length of this dict corresponds to the number of modalities 41 | modal_feature_order: ["rgb", "objects", "audio", "poses", "flow"] 42 | common_dim: 1024 43 | dropout: 0.2 44 | 45 | opt: 46 | lr: 0.001 # learning rate 47 | wd: 0.000001 # weight decay 48 | lr_wd: null # [[backbone, 0.0001, 0.000001]] # modules with specific lr and wd 49 | grad_clip: null # by default, no clipping 50 | warmup: 51 | _target_: common.scheduler.Warmup 52 | init_lr_ratio: 0.01 # Warmup from this ratio of the orig LRs 53 | num_epochs: 0 # Warmup for this many epochs (will take out of total epochs) 54 | 55 | defaults: 56 | - dataset@dataset_train: epic_kitchens100/train 57 | - dataset@dataset_eval: epic_kitchens100/val 58 | - data@data_train: default 59 | - data@data_eval: default 60 | - dataset/epic_kitchens100/common 61 | - dataset/egtea/common 62 | - model/common 63 | - opt/optimizer: sgd 64 | - opt/scheduler: cosine 65 | - model/backbone: identity 66 | - model/future_predictor: base_future_predictor 67 | - 
model/fuser: SA-Fuser 68 | - model/CMFP: cmfp_early 69 | - model/mapping: linear 70 | - _self_ 71 | -------------------------------------------------------------------------------- /expts/06_SA-Fuser_egtea_train.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=2 3 | experiment_name=egtea 4 | init_from_model=null 5 | primary_metric=val_acc1_action_all-fused 6 | 7 | train.batch_size=16 8 | eval.batch_size=16 9 | train.num_epochs=50 10 | train.use_mixup=true 11 | train.mixup_backbone=true 12 | train.mixup_alpha=0.1 13 | train.loss_wts.past_cls_action=0.1 # following AVT 14 | 15 | model.modal_dims={rgb:1024, flow:1024} 16 | model.common_dim=1024 17 | model.dropout=0.2 18 | model.common.backbones={rgb: {_target_: torch.nn.Identity}, flow: {_target_: torch.nn.Identity}} 19 | model/future_predictor=base_future_predictor 20 | model/fuser=SA-Fuser 21 | model/CMFP=cmfp_early 22 | model/mapping=linear 23 | 24 | model.common.share_classifiers=true 25 | model.common.share_predictors=true 26 | model.common.modality_cls=false 27 | model.common.fusion_cls=true 28 | 29 | model.mapping.use_layernorm=false 30 | model.mapping.sparse_mapping=true 31 | 32 | model.fuser.depth=2 33 | model.fuser.num_heads=4 34 | model.fuser.embd_drop_rate=0.1 35 | model.fuser.drop_rate=0.1 36 | model.fuser.attn_drop_rate=0.1 37 | model.fuser.drop_path_rate=0.1 38 | model.fuser.cross_attn=false 39 | 40 | model.common.fp_output_len=1 41 | model.common.fp_inter_dim=2048 42 | model.common.fp_layers=2 43 | model.common.fp_heads=4 44 | model.common.fp_output_attentions=false 45 | model.common.embd_pdrop=0.1 46 | model.common.resid_pdrop=0.1 47 | model.common.attn_pdrop=0.1 48 | 49 | opt.lr=0.001 50 | opt.wd=0.000001 51 | opt/optimizer=sgd 52 | opt/scheduler=cosine 53 | opt.optimizer.nesterov=true 54 | opt.warmup.num_epochs=20 55 | opt.scheduler.num_epochs=30 56 | opt.scheduler.eta_min=1e-6 57 | 58 | data_train.zero_mask_rate=0.0 59 | 60 | dataset@dataset_train=egtea/train 61 | dataset@dataset_eval=egtea/val 62 | dataset.egtea.common.label_type=action 63 | dataset.egtea.common.sample_strategy=last_clip 64 | dataset.egtea.common.tau_a=0.5 65 | dataset.egtea.common.tau_o=10 66 | dataset.egtea.common.compute_dataset_stats=false 67 | dataset.egtea.common.max_els=null 68 | 69 | dataset.egtea.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.egtea.common.rulstm_feats_dir}/TSN-C_3_egtea_action_CE_s${dataset.egtea.common.split}_rgb_model_best_fcfull_hd/}, flow: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.egtea.common.rulstm_feats_dir}/TSN-C_3_egtea_action_CE_s${dataset.egtea.common.split}_flow_model_best_fcfull_hd/}} -------------------------------------------------------------------------------- /expts/05_MATT_ek100_train.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=2 3 | experiment_name=MATT 4 | init_from_model=null 5 | primary_metric=val_mt5r_action_all-fused 6 | 7 | train.loss_wts.past_cls_action=0 8 | 9 | train.batch_size=16 10 | eval.batch_size=16 11 | train.num_epochs=50 12 | train.use_mixup=true 13 | train.mixup_backbone=true 14 | train.mixup_alpha=0.1 15 | 16 | model.modal_dims={rgb:1024, objects:352, audio:1024, flow:1024} 17 | model.common_dim=1024 18 | model.dropout=0.2 19 | model.common.backbones={rgb: {_target_: torch.nn.Identity}, objects: {_target_: torch.nn.Identity}, flow: {_target_: torch.nn.Identity}, audio: 
{_target_: torch.nn.Identity}} 20 | model/future_predictor=base_future_predictor 21 | model/fuser=MATT 22 | model/CMFP=scorefusion 23 | model/mapping=linear 24 | 25 | model.common.share_classifiers=false 26 | model.common.share_predictors=false 27 | model.common.modality_cls=true 28 | model.common.fusion_cls=false 29 | 30 | model.mapping.use_layernorm=false 31 | model.mapping.sparse_mapping=true 32 | 33 | model.fuser.drop_rate=0.8 34 | 35 | model.common.fp_output_len=1 36 | model.common.fp_inter_dim=2048 37 | model.common.fp_layers=2 38 | model.common.fp_heads=4 39 | model.common.fp_output_attentions=false 40 | model.common.embd_pdrop=0.1 41 | model.common.resid_pdrop=0.1 42 | model.common.attn_pdrop=0.1 43 | 44 | opt.lr=0.001 45 | opt.wd=0.000001 46 | opt/optimizer=sgd 47 | opt/scheduler=cosine 48 | opt.optimizer.nesterov=true 49 | opt.warmup.num_epochs=20 50 | opt.scheduler.num_epochs=30 51 | opt.scheduler.eta_min=1e-6 52 | 53 | data_train.zero_mask_rate=0.0 54 | 55 | dataset@dataset_train=epic_kitchens100/train 56 | dataset@dataset_eval=epic_kitchens100/val 57 | dataset.epic_kitchens100.common.label_type=action 58 | dataset.epic_kitchens100.common.sample_strategy=last_clip 59 | dataset.epic_kitchens100.common.tau_a=1 60 | dataset.epic_kitchens100.common.tau_o=10 61 | dataset.epic_kitchens100.common.compute_dataset_stats=false 62 | dataset.epic_kitchens100.common.max_els=null 63 | 64 | dataset.epic_kitchens100.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb_omnivore/}, objects: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/}, flow: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/}, audio: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/audio/, warn_if_using_closeby_frame: false}} -------------------------------------------------------------------------------- /expts/04_CA-Fuser_ek100_train.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=2 3 | experiment_name=CA-Fuser 4 | init_from_model=null 5 | primary_metric=val_mt5r_action_all-fused 6 | 7 | train.batch_size=16 8 | eval.batch_size=16 9 | train.num_epochs=50 10 | train.use_mixup=true 11 | train.mixup_backbone=true 12 | train.mixup_alpha=0.1 13 | 14 | model.modal_dims={rgb:1024, objects:352, audio:1024, flow:1024} 15 | model.common_dim=1024 16 | model.dropout=0.2 17 | model.common.backbones={rgb: {_target_: torch.nn.Identity}, objects: {_target_: torch.nn.Identity}, flow: {_target_: torch.nn.Identity}, audio: {_target_: torch.nn.Identity}} 18 | model/future_predictor=base_future_predictor 19 | model/fuser=CA-Fuser 20 | model/CMFP=cmfp_early 21 | model/mapping=linear 22 | 23 | model.common.share_classifiers=true 24 | model.common.share_predictors=true 25 | model.common.modality_cls=false 26 | model.common.fusion_cls=true 27 | 28 | model.mapping.use_layernorm=false 29 | model.mapping.sparse_mapping=true 30 | 31 | model.fuser.num_heads=4 32 | model.fuser.embd_drop_rate=0.1 33 | model.fuser.drop_rate=0.1 34 | model.fuser.attn_drop_rate=0.1 35 | model.fuser.drop_path_rate=0.1 36 | 37 | model.common.fp_output_len=1 38 | model.common.fp_inter_dim=2048 39 | model.common.fp_layers=6 40 | model.common.fp_heads=4 41 | model.common.fp_output_attentions=false 42 | 
model.common.embd_pdrop=0.1 43 | model.common.resid_pdrop=0.1 44 | model.common.attn_pdrop=0.1 45 | 46 | opt.lr=0.001 47 | opt.wd=0.000001 48 | opt/optimizer=sgd 49 | opt/scheduler=cosine 50 | opt.optimizer.nesterov=true 51 | opt.warmup.num_epochs=20 52 | opt.scheduler.num_epochs=30 53 | opt.scheduler.eta_min=1e-6 54 | 55 | data_train.zero_mask_rate=0.0 56 | 57 | dataset@dataset_train=epic_kitchens100/train 58 | dataset@dataset_eval=epic_kitchens100/val 59 | dataset.epic_kitchens100.common.label_type=action 60 | dataset.epic_kitchens100.common.sample_strategy=last_clip 61 | dataset.epic_kitchens100.common.tau_a=1 62 | dataset.epic_kitchens100.common.tau_o=10 63 | dataset.epic_kitchens100.common.compute_dataset_stats=false 64 | dataset.epic_kitchens100.common.max_els=null 65 | 66 | dataset.epic_kitchens100.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb_omnivore/}, objects: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/}, flow: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/}, audio: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/audio/, warn_if_using_closeby_frame: false}} -------------------------------------------------------------------------------- /expts/02_SA-Fuser_wo_token_ek100_train.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=2 3 | experiment_name=SA-Fuser_wo_token 4 | init_from_model=null 5 | primary_metric=val_mt5r_action_all-fused 6 | 7 | train.batch_size=16 8 | eval.batch_size=16 9 | train.num_epochs=50 10 | train.use_mixup=true 11 | train.mixup_backbone=true 12 | train.mixup_alpha=0.1 13 | 14 | model.modal_dims={rgb:1024, objects:352, audio:1024, flow:1024} 15 | model.common_dim=1024 16 | model.dropout=0.2 17 | model.common.backbones={rgb: {_target_: torch.nn.Identity}, objects: {_target_: torch.nn.Identity}, flow: {_target_: torch.nn.Identity}, audio: {_target_: torch.nn.Identity}} 18 | model/future_predictor=base_future_predictor 19 | model/fuser=SA-Fuser_wo_token 20 | model/CMFP=cmfp_early 21 | model/mapping=linear 22 | 23 | model.common.share_classifiers=true 24 | model.common.share_predictors=true 25 | model.common.modality_cls=false 26 | model.common.fusion_cls=true 27 | 28 | model.mapping.use_layernorm=false 29 | model.mapping.sparse_mapping=true 30 | 31 | model.fuser.depth=6 32 | model.fuser.num_heads=4 33 | model.fuser.embd_drop_rate=0.1 34 | model.fuser.drop_rate=0.1 35 | model.fuser.attn_drop_rate=0.1 36 | model.fuser.drop_path_rate=0.1 37 | model.fuser.cross_attn=false 38 | 39 | model.common.fp_output_len=1 40 | model.common.fp_inter_dim=2048 41 | model.common.fp_layers=6 42 | model.common.fp_heads=4 43 | model.common.fp_output_attentions=false 44 | model.common.embd_pdrop=0.1 45 | model.common.resid_pdrop=0.1 46 | model.common.attn_pdrop=0.1 47 | 48 | opt.lr=0.001 49 | opt.wd=0.000001 50 | opt/optimizer=sgd 51 | opt/scheduler=cosine 52 | opt.optimizer.nesterov=true 53 | opt.warmup.num_epochs=20 54 | opt.scheduler.num_epochs=30 55 | opt.scheduler.eta_min=1e-6 56 | 57 | data_train.zero_mask_rate=0.0 58 | 59 | dataset@dataset_train=epic_kitchens100/train 60 | dataset@dataset_eval=epic_kitchens100/val 61 | dataset.epic_kitchens100.common.label_type=action 62 | 
dataset.epic_kitchens100.common.sample_strategy=last_clip 63 | dataset.epic_kitchens100.common.tau_a=1 64 | dataset.epic_kitchens100.common.tau_o=10 65 | dataset.epic_kitchens100.common.compute_dataset_stats=false 66 | dataset.epic_kitchens100.common.max_els=null 67 | 68 | dataset.epic_kitchens100.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb_omnivore/}, objects: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/}, flow: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/}, audio: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/audio/, warn_if_using_closeby_frame: false}} -------------------------------------------------------------------------------- /expts/01_SA-Fuser_ek100_train.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=2 3 | experiment_name=SA-Fuser 4 | init_from_model=null 5 | primary_metric=val_mt5r_action_all-fused 6 | 7 | train.batch_size=16 8 | eval.batch_size=16 9 | train.num_epochs=50 10 | train.use_mixup=true 11 | train.mixup_backbone=true 12 | train.mixup_alpha=0.1 13 | 14 | model.modal_dims={rgb:1024, objects:352, audio:1024, flow:1024} 15 | model.common_dim=1024 16 | model.dropout=0.2 17 | model.common.backbones={rgb: {_target_: torch.nn.Identity}, objects: {_target_: torch.nn.Identity}, flow: {_target_: torch.nn.Identity}, audio: {_target_: torch.nn.Identity}} 18 | model/future_predictor=base_future_predictor 19 | model/fuser=SA-Fuser 20 | model/CMFP=cmfp_early 21 | model/mapping=linear 22 | 23 | model.common.share_classifiers=true 24 | model.common.share_predictors=true 25 | model.common.modality_cls=false 26 | model.common.fusion_cls=true 27 | 28 | model.mapping.use_layernorm=false 29 | model.mapping.sparse_mapping=true 30 | 31 | model.fuser.depth=6 32 | model.fuser.num_heads=4 33 | model.fuser.embd_drop_rate=0.1 34 | model.fuser.drop_rate=0.1 35 | model.fuser.attn_drop_rate=0.1 36 | model.fuser.drop_path_rate=0.1 37 | model.fuser.cross_attn=false 38 | 39 | model.common.fp_output_len=1 40 | model.common.fp_inter_dim=2048 41 | model.common.fp_layers=6 42 | model.common.fp_heads=4 43 | model.common.fp_output_attentions=false 44 | model.common.embd_pdrop=0.1 45 | model.common.resid_pdrop=0.1 46 | model.common.attn_pdrop=0.1 47 | 48 | opt.lr=0.001 49 | opt.wd=0.000001 50 | opt/optimizer=sgd 51 | opt/scheduler=cosine 52 | opt.optimizer.nesterov=true 53 | opt.warmup.num_epochs=20 54 | opt.scheduler.num_epochs=30 55 | opt.scheduler.eta_min=1e-6 56 | 57 | data_train.zero_mask_rate=0.0 58 | data_train.num_frames=16 59 | data_eval.num_frames=16 60 | 61 | dataset@dataset_train=epic_kitchens100/train 62 | dataset@dataset_eval=epic_kitchens100/val 63 | dataset.epic_kitchens100.common.label_type=action 64 | dataset.epic_kitchens100.common.sample_strategy=last_clip 65 | dataset.epic_kitchens100.common.tau_a=1 66 | dataset.epic_kitchens100.common.tau_o=16 67 | dataset.epic_kitchens100.common.compute_dataset_stats=false 68 | dataset.epic_kitchens100.common.max_els=null 69 | 70 | dataset.epic_kitchens100.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb_omnivore/}, objects: {_target_: 
datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/}, flow: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/}, audio: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/audio/, warn_if_using_closeby_frame: false}} -------------------------------------------------------------------------------- /expts/03_T-SA-Fuser_ek100_train.txt: -------------------------------------------------------------------------------- 1 | workers=32 2 | num_gpus=2 3 | experiment_name=T-SA-Fuser 4 | init_from_model=null 5 | primary_metric=val_mt5r_action_all-fused 6 | 7 | train.batch_size=16 8 | eval.batch_size=16 9 | train.num_epochs=50 10 | train.use_mixup=true 11 | train.mixup_backbone=true 12 | train.mixup_alpha=0.1 13 | 14 | model.modal_dims={rgb:1024, objects:352, audio:1024, flow:1024} 15 | model.common_dim=1024 16 | model.dropout=0.2 17 | model.common.backbones={rgb: {_target_: torch.nn.Identity}, objects: {_target_: torch.nn.Identity}, flow: {_target_: torch.nn.Identity}, audio: {_target_: torch.nn.Identity}} 18 | model/future_predictor=base_future_predictor 19 | model/fuser=T-SA-Fuser 20 | model/CMFP=cmfp_early 21 | model/mapping=linear 22 | 23 | model.common.share_classifiers=true 24 | model.common.share_predictors=true 25 | model.common.modality_cls=false 26 | model.common.fusion_cls=true 27 | 28 | model.mapping.use_layernorm=false 29 | model.mapping.sparse_mapping=true 30 | 31 | model.fuser.depth=6 32 | model.fuser.num_heads=4 33 | model.fuser.embd_drop_rate=0.1 34 | model.fuser.drop_rate=0.1 35 | model.fuser.attn_drop_rate=0.1 36 | model.fuser.drop_path_rate=0.1 37 | model.fuser.modal_encoding=true 38 | model.fuser.frame_level_token=true 39 | model.fuser.temporal_sequence_length=10 40 | 41 | model.common.fp_output_len=1 42 | model.common.fp_inter_dim=2048 43 | model.common.fp_layers=6 44 | model.common.fp_heads=4 45 | model.common.fp_output_attentions=false 46 | model.common.embd_pdrop=0.1 47 | model.common.resid_pdrop=0.1 48 | model.common.attn_pdrop=0.1 49 | 50 | opt.lr=0.001 51 | opt.wd=0.000001 52 | opt/optimizer=sgd 53 | opt/scheduler=cosine 54 | opt.optimizer.nesterov=true 55 | opt.warmup.num_epochs=20 56 | opt.scheduler.num_epochs=30 57 | opt.scheduler.eta_min=1e-6 58 | 59 | data_train.zero_mask_rate=0.0 60 | 61 | dataset@dataset_train=epic_kitchens100/train 62 | dataset@dataset_eval=epic_kitchens100/val 63 | dataset.epic_kitchens100.common.label_type=action 64 | dataset.epic_kitchens100.common.sample_strategy=last_clip 65 | dataset.epic_kitchens100.common.tau_a=1 66 | dataset.epic_kitchens100.common.tau_o=10 67 | dataset.epic_kitchens100.common.compute_dataset_stats=false 68 | dataset.epic_kitchens100.common.max_els=null 69 | 70 | dataset.epic_kitchens100.common.reader_fn={rgb: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb_omnivore/}, objects: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/obj/}, flow: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/flow/}, audio: {_target_: datasets.reader_fns.EpicRULSTMFeatsReader, lmdb_path: ${dataset.epic_kitchens100.common.rulstm_feats_dir}/audio/, warn_if_using_closeby_frame: false}} 
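# Note (added for clarity, not part of the original config): files in expts/ are plain lists of
# Hydra-style overrides consumed by run.py. Assuming the same launch pattern as the README's
# EK100 training example, this configuration would be started with something like:
#   python run.py -c expts/03_T-SA-Fuser_ek100_train.txt --mode train --nproc_per_node 2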
-------------------------------------------------------------------------------- /annotations/ek55_rulstm/training_videos.csv: -------------------------------------------------------------------------------- 1 | P01_02 2 | P01_03 3 | P01_04 4 | P01_05 5 | P01_06 6 | P01_07 7 | P01_08 8 | P01_09 9 | P01_16 10 | P01_17 11 | P01_18 12 | P01_19 13 | P02_01 14 | P02_02 15 | P02_04 16 | P02_06 17 | P02_07 18 | P02_08 19 | P02_09 20 | P02_10 21 | P02_11 22 | P03_02 23 | P03_03 24 | P03_04 25 | P03_05 26 | P03_07 27 | P03_08 28 | P03_09 29 | P03_10 30 | P03_12 31 | P03_13 32 | P03_14 33 | P03_15 34 | P03_16 35 | P03_17 36 | P03_18 37 | P03_19 38 | P03_20 39 | P03_27 40 | P03_28 41 | P04_01 42 | P04_02 43 | P04_03 44 | P04_04 45 | P04_05 46 | P04_06 47 | P04_07 48 | P04_08 49 | P04_10 50 | P04_11 51 | P04_12 52 | P04_13 53 | P04_14 54 | P04_15 55 | P04_16 56 | P04_17 57 | P04_18 58 | P04_19 59 | P04_20 60 | P04_21 61 | P04_22 62 | P04_23 63 | P05_01 64 | P05_02 65 | P05_03 66 | P05_04 67 | P05_05 68 | P05_06 69 | P05_08 70 | P06_01 71 | P06_02 72 | P06_03 73 | P06_07 74 | P06_08 75 | P06_09 76 | P07_01 77 | P07_03 78 | P07_04 79 | P07_05 80 | P07_06 81 | P07_07 82 | P07_09 83 | P07_11 84 | P08_02 85 | P08_03 86 | P08_04 87 | P08_06 88 | P08_07 89 | P08_08 90 | P08_11 91 | P08_13 92 | P08_18 93 | P08_19 94 | P08_20 95 | P08_21 96 | P08_22 97 | P08_23 98 | P08_24 99 | P08_25 100 | P08_26 101 | P08_27 102 | P08_28 103 | P10_02 104 | P10_04 105 | P12_01 106 | P12_02 107 | P12_04 108 | P12_05 109 | P12_06 110 | P12_07 111 | P13_05 112 | P13_07 113 | P13_08 114 | P13_10 115 | P14_03 116 | P14_04 117 | P14_05 118 | P14_07 119 | P14_09 120 | P15_01 121 | P15_02 122 | P15_03 123 | P15_07 124 | P15_08 125 | P15_09 126 | P15_10 127 | P15_11 128 | P15_12 129 | P15_13 130 | P16_01 131 | P16_02 132 | P16_03 133 | P17_01 134 | P17_03 135 | P17_04 136 | P19_01 137 | P19_02 138 | P19_03 139 | P19_04 140 | P20_01 141 | P20_02 142 | P21_01 143 | P21_03 144 | P21_04 145 | P22_05 146 | P22_06 147 | P22_07 148 | P22_09 149 | P22_12 150 | P22_14 151 | P22_15 152 | P22_16 153 | P22_17 154 | P23_01 155 | P23_02 156 | P23_04 157 | P24_01 158 | P24_02 159 | P24_03 160 | P24_04 161 | P24_05 162 | P24_06 163 | P24_07 164 | P25_01 165 | P25_02 166 | P25_03 167 | P25_04 168 | P25_05 169 | P25_09 170 | P25_10 171 | P25_12 172 | P26_01 173 | P26_03 174 | P26_04 175 | P26_05 176 | P26_06 177 | P26_07 178 | P26_08 179 | P26_09 180 | P26_10 181 | P26_12 182 | P26_13 183 | P26_14 184 | P26_15 185 | P26_17 186 | P26_18 187 | P26_19 188 | P26_20 189 | P26_21 190 | P26_22 191 | P26_23 192 | P26_24 193 | P26_25 194 | P26_26 195 | P26_27 196 | P26_28 197 | P26_29 198 | P27_01 199 | P27_02 200 | P27_04 201 | P27_06 202 | P27_07 203 | P28_01 204 | P28_02 205 | P28_03 206 | P28_04 207 | P28_06 208 | P28_07 209 | P28_08 210 | P28_09 211 | P28_10 212 | P28_11 213 | P28_14 214 | P29_01 215 | P29_02 216 | P29_03 217 | P29_04 218 | P30_02 219 | P30_04 220 | P30_05 221 | P30_06 222 | P30_10 223 | P30_11 224 | P31_02 225 | P31_03 226 | P31_04 227 | P31_05 228 | P31_06 229 | P31_07 230 | P31_09 231 | P31_13 232 | P31_14 233 | -------------------------------------------------------------------------------- /annotations/egtea/actions.csv: -------------------------------------------------------------------------------- 1 | 0, 0_0, Inspect/Read_recipe 2 | 1, 1_1, Open_fridge 3 | 2, 2_2, Take_eating:utensil 4 | 3, 3_3, Cut_tomato 5 | 4, 4_4, Turn on_faucet 6 | 5, 5_2, Put_eating:utensil 7 | 6, 1_5, Open_cabinet 8 | 7, 2_6, Take_condiment:container 9 | 8, 
3_7, Cut_cucumber 10 | 9, 6_8, Operate_stove 11 | 10, 7_1, Close_fridge 12 | 11, 3_9, Cut_carrot 13 | 12, 5_6, Put_condiment:container 14 | 13, 3_10, Cut_onion 15 | 14, 1_11, Open_drawer 16 | 15, 2_12, Take_plate 17 | 16, 2_13, Take_bowl 18 | 17, 5_13, Put_bowl 19 | 18, 5_14, Put_trash 20 | 19, 5_12, Put_plate 21 | 20, 3_15, Cut_bell:pepper 22 | 21, 5_16, Put_cooking:utensil 23 | 22, 2_17, Take_paper:towel 24 | 23, 8_18, Move Around_bacon 25 | 24, 1_6, Open_condiment:container 26 | 25, 9_2, Wash_eating:utensil 27 | 26, 10_19, Spread_condiment 28 | 27, 11_4, Turn off_faucet 29 | 28, 5_20, Put_pan 30 | 29, 2_16, Take_cooking:utensil 31 | 30, 5_21, Put_lettuce 32 | 31, 8_22, Move Around_patty 33 | 32, 5_23, Put_pot 34 | 33, 7_5, Close_cabinet 35 | 34, 5_24, Put_bread 36 | 35, 2_24, Take_bread 37 | 36, 7_6, Close_condiment:container 38 | 37, 1_25, Open_fridge:drawer 39 | 38, 9_26, Wash_hand 40 | 39, 5_3, Put_tomato 41 | 40, 2_27, Take_seasoning:container 42 | 41, 2_28, Take_cup 43 | 42, 12_21, Divide/Pull Apart_lettuce 44 | 43, 5_28, Put_cup 45 | 44, 2_23, Take_pot 46 | 45, 13_29, Clean/Wipe_counter 47 | 46, 2_30, Take_bread:container 48 | 47, 2_3, Take_tomato 49 | 48, 2_20, Take_pan 50 | 49, 8_20, Move Around_pan 51 | 50, 9_31, Wash_cutting:board 52 | 51, 5_30, Put_bread:container 53 | 52, 2_32, Take_sponge 54 | 53, 2_21, Take_lettuce 55 | 54, 2_10, Take_onion 56 | 55, 5_32, Put_sponge 57 | 56, 12_17, Divide/Pull Apart_paper:towel 58 | 57, 1_33, Open_dishwasher 59 | 58, 2_34, Take_cheese:container 60 | 59, 2_35, Take_oil:container 61 | 60, 5_27, Put_seasoning:container 62 | 61, 2_7, Take_cucumber 63 | 62, 9_20, Wash_pan 64 | 63, 2_15, Take_bell:pepper 65 | 64, 12_10, Divide/Pull Apart_onion 66 | 65, 5_31, Put_cutting:board 67 | 66, 14_36, Mix_mixture 68 | 67, 2_37, Take_tomato:container 69 | 68, 5_38, Put_cheese 70 | 69, 8_2, Move Around_eating:utensil 71 | 70, 5_15, Put_bell:pepper 72 | 71, 15_39, Pour_oil 73 | 72, 2_40, Take_pasta:container 74 | 73, 3_21, Cut_lettuce 75 | 74, 5_37, Put_tomato:container 76 | 75, 9_13, Wash_bowl 77 | 76, 3_41, Cut_olive 78 | 77, 7_11, Close_drawer 79 | 78, 15_19, Pour_condiment 80 | 79, 9_23, Wash_pot 81 | 80, 14_42, Mix_pasta 82 | 81, 1_30, Open_bread:container 83 | 82, 2_43, Take_grocery:bag 84 | 83, 2_38, Take_cheese 85 | 84, 15_44, Pour_seasoning 86 | 85, 14_45, Mix_egg 87 | 86, 15_46, Pour_water 88 | 87, 5_17, Put_paper:towel 89 | 88, 5_7, Put_cucumber 90 | 89, 16_47, Compress_sandwich 91 | 90, 5_34, Put_cheese:container 92 | 91, 5_10, Put_onion 93 | 92, 17_45, Crack_egg 94 | 93, 2_31, Take_cutting:board 95 | 94, 1_35, Open_oil:container 96 | 95, 18_48, Squeeze_washing:liquid 97 | 96, 6_49, Operate_microwave 98 | 97, 7_25, Close_fridge:drawer 99 | 98, 9_50, Wash_strainer 100 | 99, 8_13, Move Around_bowl 101 | 100, 8_23, Move Around_pot 102 | 101, 5_43, Put_grocery:bag 103 | 102, 2_45, Take_egg 104 | 103, 1_34, Open_cheese:container 105 | 104, 7_35, Close_oil:container 106 | 105, 5_35, Put_oil:container 107 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | 
.installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | .idea/ 161 | 162 | outputs/ 163 | -------------------------------------------------------------------------------- /models/feature_mapping.py: -------------------------------------------------------------------------------- 1 | """Implementation of different projection functions that map feature vectors with different sizes to a common size""" 2 | 3 | import torch 4 | from torch import nn as nn 5 | from functools import partial 6 | from torch.nn import functional as F 7 | 8 | 9 | class GatedEmbeddingUnit(nn.Module): 10 | def __init__(self, input_dimension, output_dimension): 11 | super().__init__() 12 | self.fc = nn.Linear(input_dimension, output_dimension) 13 | self.cg = ContextGating(output_dimension) 14 | 15 | def forward(self, x): 16 | x = self.fc(x) 17 | x = self.cg(x) 18 | return x 19 | 20 | 21 | class ContextGating(nn.Module): 22 | def __init__(self, dimension): 23 | super(ContextGating, self).__init__() 24 | self.fc = nn.Linear(dimension, dimension) 25 | 26 | def forward(self, x): 27 | x1 = self.fc(x) 28 | x = torch.cat((x, x1), 1) 29 | return F.glu(x, 1) 30 | 31 | 32 | class GatedLinear(nn.Module): 33 | def __init__(self, in_features, out_features, use_layernorm: bool = True): 34 | super().__init__() 35 | 36 | tmp = [nn.Linear(in_features, out_features), ContextGating(out_features)] 37 | 38 | if use_layernorm: 39 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 40 | tmp.append(norm_layer(out_features)) 41 | 42 | layers = tmp # deprecated: if in_features != out_features else [nn.Identity()] 43 | 44 | self.mapping = nn.Sequential(*layers) 45 | self.use_layernorm = use_layernorm 46 | 47 | def forward(self, x): 48 | return self.mapping(x) 49 | 50 | def __str__(self): 51 | return f'Gated linear mapping layer with use_layernorm: {self.use_layernorm}' 52 | 53 | 54 | class Linear(nn.Module): 55 | """Implements the linear feature mapping layer""" 56 | def __init__(self, in_features, out_features, use_layernorm: bool = False, sparse_mapping=True): 57 | super().__init__() 58 | 59 | if sparse_mapping: 60 | layers = [nn.Linear(in_features, out_features, bias=False) 61 | if in_features != out_features else nn.Identity()] 62 | else: 63 | layers = [nn.Linear(in_features, out_features, bias=False)] 64 | 65 | if use_layernorm: 66 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 67 | layers.append(norm_layer(out_features)) 68 | 69 | self.mapping = nn.Sequential(*layers) 70 | self.use_layernorm = use_layernorm 71 | self.sparse_mapping = sparse_mapping 72 | 73 | def forward(self, x): 74 | return self.mapping(x) 75 | 76 | def __str__(self): 77 | return f'Linear mapping layer with use_layernorm: {self.use_layernorm}, ' \ 78 | f'and sparse_mapping: {self.sparse_mapping}' 79 | 80 | 81 | def get_activation_layer(name): 82 | act_layers = { 83 | 'relu': nn.ReLU(), 84 | 'gelu': nn.GELU(), 85 | 'none': nn.Identity(), 86 | } 87 | assert name in act_layers.keys(), f'{name} is not supported in {list(act_layers.keys())}.' 
88 | return act_layers[name] 89 | 90 | 91 | class NonLinear(nn.Module): 92 | """Implements the non-linear feature mapping layer""" 93 | def __init__(self, in_features, out_features, use_layernorm: bool = False, activation='relu'): 94 | super().__init__() 95 | 96 | layers = [nn.Linear(in_features, out_features), get_activation_layer(activation)] 97 | 98 | if use_layernorm: 99 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 100 | layers.append(norm_layer(out_features)) 101 | 102 | self.mapping = nn.Sequential(*layers) 103 | self.use_layernorm = use_layernorm 104 | self.activation = activation 105 | 106 | def forward(self, x): 107 | return self.mapping(x) 108 | 109 | def __str__(self): 110 | return f'Nonlinear mapping layer with use_layernorm: {self.use_layernorm}, ' \ 111 | f'and activation: {self.activation}' 112 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | """Implementation of the base model framework, instantiating different backbones, fusion methods and future predictor 2 | methods using hydra.utils.instantiate""" 3 | from itertools import repeat 4 | from typing import Dict, Tuple 5 | import torch 6 | import torch.nn as nn 7 | import hydra 8 | from omegaconf import OmegaConf 9 | from common import utils 10 | 11 | CLS_MAP_PREFIX = 'cls_map_' 12 | PAST_LOGITS_PREFIX = 'past_' 13 | 14 | 15 | class BaseModel(nn.Module): 16 | def __init__(self, model_cfg: OmegaConf, num_classes: Dict[str, int], 17 | class_mappings: Dict[Tuple[str, str], torch.FloatTensor]): 18 | super().__init__() 19 | self.backbone = nn.ModuleDict() 20 | 21 | for mod, backbone_conf in model_cfg.common.backbones.items(): 22 | self.backbone[mod] = hydra.utils.instantiate(backbone_conf) 23 | 24 | self.future_predictor = hydra.utils.instantiate(model_cfg.CMFP, model_cfg=model_cfg, 25 | num_classes=num_classes, _recursive_=False) 26 | 27 | # Store the class mapping as buffers 28 | for (src, dst), mapping in class_mappings.items(): 29 | self.register_buffer(f'{CLS_MAP_PREFIX}{src}_{dst}', mapping) 30 | 31 | def forward_singlecrop(self, data_dict, **kwargs): 32 | """ 33 | Args: 34 | video (torch.Tensor, Bx#clipsxCxTxHxW) 35 | target_shape: The shape of the target. Some of these layers might 36 | be able to use this information. 
37 | """ 38 | feats_past = {} 39 | for mod, data in data_dict.items(): 40 | feats = self.backbone[mod](data) 41 | # spatial mean B*clipsxCxT 42 | feats = torch.mean(feats, [-1, -2]) 43 | feats = feats.permute((0, 1, 3, 2)) 44 | if feats.ndim == 4: 45 | feats = torch.flatten(feats, 1, 2) # BxTxF, T=10 46 | feats_past[mod] = feats 47 | 48 | target = kwargs['target'] 49 | target_subclips = kwargs['target_subclips'] 50 | target_subclips_ignore_index = kwargs['target_subclips_ignore_index'] 51 | 52 | # Mixup the backbone outputs if required 53 | if kwargs['mixup_fn'] is not None: 54 | mixup_fn = kwargs['mixup_fn'] 55 | feats_past, target, target_subclips, target_subclips_ignore_index = \ 56 | mixup_fn(feats_past, target, target_subclips) 57 | 58 | # Future prediction 59 | outputs = self.future_predictor(feats_past) 60 | outputs_target = { 61 | 'target': target, 62 | 'target_subclips': target_subclips, 63 | 'target_subclips_ignore_index': target_subclips_ignore_index 64 | } 65 | 66 | return outputs, outputs_target 67 | 68 | def forward(self, video_data, *args, **kwargs): 69 | """ 70 | Args: video (torch.Tensor) 71 | Could be (B, #clips, C, T, H, W) or 72 | (B, #clips, #crops, C, T, H, W) 73 | Returns: 74 | Final features 75 | """ 76 | for mod, data in video_data.items(): 77 | if data.ndim == 6: 78 | video_data[mod] = [data] 79 | elif data.ndim == 7 and data.size(2) == 1: 80 | video_data[mod] = [data.squeeze(2)] 81 | elif data.ndim == 7: 82 | video_data[mod] = torch.unbind(data, dim=2) 83 | else: 84 | raise NotImplementedError('Unsupported size %s' % data.shape) 85 | 86 | all_mods = sorted(list(video_data.keys())) 87 | all_data = [video_data[mod] for mod in all_mods] 88 | num_crops = max([len(sl) for sl in all_data]) 89 | all_data = [sl * (num_crops // len(sl)) for sl in all_data] 90 | all_crops = list(zip(*all_data)) 91 | 92 | video_data = [{m: c for m, c in zip(mods, crops)} for mods, crops in zip(repeat(all_mods), all_crops)] 93 | 94 | feats = [self.forward_singlecrop(el, *args, **kwargs) for el in video_data] 95 | 96 | # Since we only apply mixup in training and in training we only have one single crop, 97 | # it's fine to just use the index 0 here 98 | output_targets = feats[0][1] 99 | 100 | # convert to dicts of lists 101 | feats_merged = {} 102 | for out_dict, _ in feats: 103 | for key in out_dict: 104 | if key not in feats_merged: 105 | feats_merged[key] = {k: [v] for k, v in out_dict[key].items()} 106 | else: 107 | for k, v in feats_merged[key].items(): 108 | v.append(out_dict[key][k]) 109 | 110 | # Average over the crops 111 | for out_key in feats_merged: 112 | if out_key == 'attentions': 113 | # we select the attentions from the first element, as for attention analysis we only have one crop 114 | feats_merged[out_key] = {k: el[0] for k, el in feats_merged[out_key].items()} 115 | continue 116 | feats_merged[out_key] = {k: torch.mean(torch.stack(el, dim=0), dim=0) for k, el in 117 | feats_merged[out_key].items()} 118 | 119 | return feats_merged, output_targets 120 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: afft 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - blas=1.0=mkl 10 | - bzip2=1.0.8=h7b6447c_0 11 | - ca-certificates=2021.10.8=ha878542_0 12 | - certifi=2021.10.8=py37h89c1867_2 13 | - colorama=0.4.4=pyh9f0ad1d_0 14 | - 
cudatoolkit=11.3.1=h2bc3f7f_2 15 | - ffmpeg=4.3=hf484d3e_0 16 | - freetype=2.11.0=h70c0345_0 17 | - giflib=5.2.1=h7b6447c_0 18 | - gmp=6.2.1=h2531618_2 19 | - gnutls=3.6.15=he1e5248_0 20 | - intel-openmp=2021.4.0=h06a4308_3561 21 | - jpeg=9d=h7f8727e_0 22 | - lame=3.100=h7b6447c_0 23 | - lcms2=2.12=h3be6417_0 24 | - ld_impl_linux-64=2.35.1=h7274673_9 25 | - libffi=3.3=he6710b0_2 26 | - libgcc-ng=9.3.0=h5101ec6_17 27 | - libgomp=9.3.0=h5101ec6_17 28 | - libiconv=1.15=h63c8f33_5 29 | - libidn2=2.3.2=h7f8727e_0 30 | - libpng=1.6.37=hbc83047_0 31 | - libstdcxx-ng=9.3.0=hd4cf53a_17 32 | - libtasn1=4.16.0=h27cfd23_0 33 | - libtiff=4.2.0=h85742a9_0 34 | - libunistring=0.9.10=h27cfd23_0 35 | - libuv=1.40.0=h7b6447c_0 36 | - libwebp=1.2.0=h89dd481_0 37 | - libwebp-base=1.2.0=h27cfd23_0 38 | - lz4-c=1.9.3=h295c915_1 39 | - mkl=2021.4.0=h06a4308_640 40 | - mkl-service=2.4.0=py37h7f8727e_0 41 | - mkl_fft=1.3.1=py37hd3c417c_0 42 | - mkl_random=1.2.2=py37h51133e4_0 43 | - ncurses=6.3=h7f8727e_2 44 | - nettle=3.7.3=hbbd107a_1 45 | - numpy=1.21.2=py37h20f2e39_0 46 | - numpy-base=1.21.2=py37h79a1101_0 47 | - olefile=0.46=py37_0 48 | - openh264=2.1.1=h4ff587b_0 49 | - openssl=1.1.1n=h7f8727e_0 50 | - pillow=8.4.0=py37h5aabda8_0 51 | - pip=21.2.2=py37h06a4308_0 52 | - python=3.7.11=h12debd9_0 53 | - python_abi=3.7=2_cp37m 54 | - pytorch=1.10.1=py3.7_cuda11.3_cudnn8.2.0_0 55 | - pytorch-mutex=1.0=cuda 56 | - readline=8.1.2=h7f8727e_1 57 | - setuptools=58.0.4=py37h06a4308_0 58 | - six=1.16.0=pyhd3eb1b0_0 59 | - sqlite=3.37.0=hc218d9a_0 60 | - tk=8.6.11=h1ccaba5_0 61 | - torchaudio=0.10.1=py37_cu113 62 | - torchvision=0.11.2=py37_cu113 63 | - tqdm=4.64.0=pyhd8ed1ab_0 64 | - typing_extensions=3.10.0.2=pyh06a4308_0 65 | - wheel=0.37.1=pyhd3eb1b0_0 66 | - xz=5.2.5=h7b6447c_0 67 | - zlib=1.2.11=h7f8727e_4 68 | - zstd=1.4.9=haebb681_0 69 | - pip: 70 | - absl-py==1.0.0 71 | - aiohttp==3.8.1 72 | - aiosignal==1.2.0 73 | - antlr4-python3-runtime==4.8 74 | - async-timeout==4.0.2 75 | - asynctest==0.13.0 76 | - attrs==21.4.0 77 | - av==9.2.0 78 | - blessed==1.19.1 79 | - blessings==1.7 80 | - cached-property==1.5.2 81 | - cachetools==5.1.0 82 | - charset-normalizer==2.0.12 83 | - click==8.0.4 84 | - cloudpickle==2.0.0 85 | - cycler==0.11.0 86 | - datasets==2.3.2 87 | - decorator==4.4.2 88 | - dill==0.3.5.1 89 | - docker-pycreds==0.4.0 90 | - einops==0.4.1 91 | - filelock==3.6.0 92 | - fonttools==4.33.3 93 | - frozenlist==1.3.0 94 | - fsspec==2022.5.0 95 | - fvcore==0.1.5.post20220512 96 | - gitdb==4.0.9 97 | - gitpython==3.1.27 98 | - google-auth==2.6.6 99 | - google-auth-oauthlib==0.4.6 100 | - gpustat==0.6.0 101 | - grpcio==1.46.3 102 | - h5py==3.6.0 103 | - huggingface-hub==0.5.1 104 | - hydra-core==1.1.1 105 | - idna==3.3 106 | - imageio==2.19.2 107 | - imageio-ffmpeg==0.4.7 108 | - importlib-metadata==4.11.1 109 | - importlib-resources==5.4.0 110 | - inquirer==2.9.2 111 | - iopath==0.1.9 112 | - joblib==1.1.0 113 | - kiwisolver==1.4.2 114 | - lmdb==1.3.0 115 | - markdown==3.3.7 116 | - matplotlib==3.5.2 117 | - moviepy==1.0.3 118 | - multidict==6.0.2 119 | - multiprocess==0.70.13 120 | - munch==2.5.0 121 | - networkx==2.6.3 122 | - numpyencoder==0.3.0 123 | - nvidia-ml-py3==7.352.0 124 | - oauthlib==3.2.0 125 | - omegaconf==2.1.1 126 | - opencv-python==4.5.5.64 127 | - packaging==21.3 128 | - pandas==1.3.5 129 | - parameterized==0.8.1 130 | - pathtools==0.1.2 131 | - portalocker==2.4.0 132 | - pretrainedmodels==0.7.4 133 | - proglog==0.1.10 134 | - promise==2.3 135 | - protobuf==3.19.4 136 | - psutil==5.9.0 137 | - 
pyarrow==8.0.0 138 | - pyasn1==0.4.8 139 | - pyasn1-modules==0.2.8 140 | - pyparsing==3.0.8 141 | - python-dateutil==2.8.2 142 | - python-editor==1.0.4 143 | - pytorchvideo==0.1.5 144 | - pytz==2021.3 145 | - pyyaml==6.0 146 | - readchar==3.0.5 147 | - regex==2022.4.24 148 | - requests==2.27.1 149 | - requests-oauthlib==1.3.1 150 | - responses==0.18.0 151 | - rsa==4.8 152 | - sacremoses==0.0.49 153 | - scipy==1.7.3 154 | - seaborn==0.11.2 155 | - sentry-sdk==1.5.5 156 | - shortuuid==1.0.8 157 | - smmap==5.0.0 158 | - submitit==1.4.2 159 | - tabulate==0.8.9 160 | - tensorboard==2.9.0 161 | - tensorboard-data-server==0.6.1 162 | - tensorboard-plugin-wit==1.8.1 163 | - termcolor==1.1.0 164 | - timm==0.5.4 165 | - tokenizers==0.12.1 166 | - transformers==4.18.0 167 | - urllib3==1.26.8 168 | - wandb==0.12.10 169 | - wcwidth==0.2.5 170 | - werkzeug==2.1.2 171 | - wget==3.2 172 | - xxhash==3.0.0 173 | - yacs==0.1.8 174 | - yarl==1.7.2 175 | - yaspin==2.1.0 176 | - zipp==3.7.0 177 | prefix: /home/haicore-project-kit/on3546/anaconda3/envs/action 178 | -------------------------------------------------------------------------------- /annotations/ek100_rulstm/training_videos.csv: -------------------------------------------------------------------------------- 1 | P01_01 2 | P01_02 3 | P01_03 4 | P01_04 5 | P01_05 6 | P01_06 7 | P01_07 8 | P01_08 9 | P01_09 10 | P01_102 11 | P01_103 12 | P01_104 13 | P01_105 14 | P01_106 15 | P01_107 16 | P01_108 17 | P01_109 18 | P01_10 19 | P01_16 20 | P01_17 21 | P01_18 22 | P01_19 23 | P02_01 24 | P02_02 25 | P02_03 26 | P02_04 27 | P02_05 28 | P02_06 29 | P02_07 30 | P02_08 31 | P02_09 32 | P02_101 33 | P02_102 34 | P02_103 35 | P02_104 36 | P02_105 37 | P02_107 38 | P02_108 39 | P02_109 40 | P02_10 41 | P02_110 42 | P02_111 43 | P02_112 44 | P02_113 45 | P02_114 46 | P02_115 47 | P02_116 48 | P02_118 49 | P02_119 50 | P02_11 51 | P02_120 52 | P02_121 53 | P02_122 54 | P02_123 55 | P02_124 56 | P02_126 57 | P02_127 58 | P02_128 59 | P02_129 60 | P02_130 61 | P02_131 62 | P02_132 63 | P02_133 64 | P02_134 65 | P02_135 66 | P03_02 67 | P03_03 68 | P03_04 69 | P03_05 70 | P03_06 71 | P03_07 72 | P03_08 73 | P03_09 74 | P03_101 75 | P03_102 76 | P03_106 77 | P03_107 78 | P03_108 79 | P03_109 80 | P03_10 81 | P03_110 82 | P03_111 83 | P03_112 84 | P03_113 85 | P03_114 86 | P03_115 87 | P03_116 88 | P03_117 89 | P03_118 90 | P03_119 91 | P03_11 92 | P03_120 93 | P03_121 94 | P03_122 95 | P03_123 96 | P03_12 97 | P03_13 98 | P03_14 99 | P03_15 100 | P03_16 101 | P03_17 102 | P03_18 103 | P03_19 104 | P03_20 105 | P03_27 106 | P03_28 107 | P04_01 108 | P04_02 109 | P04_03 110 | P04_04 111 | P04_05 112 | P04_06 113 | P04_07 114 | P04_08 115 | P04_09 116 | P04_101 117 | P04_102 118 | P04_103 119 | P04_104 120 | P04_106 121 | P04_107 122 | P04_108 123 | P04_109 124 | P04_10 125 | P04_110 126 | P04_111 127 | P04_112 128 | P04_113 129 | P04_114 130 | P04_115 131 | P04_116 132 | P04_117 133 | P04_118 134 | P04_119 135 | P04_11 136 | P04_120 137 | P04_121 138 | P04_12 139 | P04_13 140 | P04_14 141 | P04_15 142 | P04_16 143 | P04_17 144 | P04_18 145 | P04_19 146 | P04_20 147 | P04_21 148 | P04_22 149 | P04_23 150 | P05_01 151 | P05_02 152 | P05_03 153 | P05_04 154 | P05_05 155 | P05_06 156 | P05_08 157 | P06_01 158 | P06_02 159 | P06_03 160 | P06_05 161 | P06_07 162 | P06_08 163 | P06_09 164 | P06_101 165 | P06_102 166 | P06_103 167 | P06_104 168 | P06_105 169 | P06_106 170 | P06_107 171 | P06_108 172 | P06_109 173 | P06_110 174 | P06_113 175 | P07_01 176 | P07_02 177 | P07_03 
178 | P07_04 179 | P07_05 180 | P07_06 181 | P07_07 182 | P07_08 183 | P07_09 184 | P07_101 185 | P07_102 186 | P07_103 187 | P07_106 188 | P07_107 189 | P07_10 190 | P07_110 191 | P07_111 192 | P07_112 193 | P07_113 194 | P07_114 195 | P07_115 196 | P07_116 197 | P07_117 198 | P07_11 199 | P08_01 200 | P08_02 201 | P08_03 202 | P08_04 203 | P08_05 204 | P08_06 205 | P08_07 206 | P08_08 207 | P08_11 208 | P08_12 209 | P08_13 210 | P08_18 211 | P08_19 212 | P08_20 213 | P08_21 214 | P08_22 215 | P08_23 216 | P08_24 217 | P08_25 218 | P08_26 219 | P08_27 220 | P08_28 221 | P09_01 222 | P09_02 223 | P09_03 224 | P09_04 225 | P09_05 226 | P09_06 227 | P09_103 228 | P09_104 229 | P09_105 230 | P09_106 231 | P10_01 232 | P10_02 233 | P10_04 234 | P11_01 235 | P11_02 236 | P11_03 237 | P11_04 238 | P11_05 239 | P11_06 240 | P11_07 241 | P11_08 242 | P11_09 243 | P11_101 244 | P11_102 245 | P11_103 246 | P11_104 247 | P11_105 248 | P11_107 249 | P11_109 250 | P11_10 251 | P11_11 252 | P11_12 253 | P11_13 254 | P11_14 255 | P11_15 256 | P11_16 257 | P12_01 258 | P12_02 259 | P12_04 260 | P12_05 261 | P12_06 262 | P12_07 263 | P12_101 264 | P12_103 265 | P12_104 266 | P12_105 267 | P13_04 268 | P13_05 269 | P13_06 270 | P13_07 271 | P13_08 272 | P13_09 273 | P13_10 274 | P14_01 275 | P14_02 276 | P14_03 277 | P14_04 278 | P14_05 279 | P14_07 280 | P14_09 281 | P15_01 282 | P15_02 283 | P15_03 284 | P15_07 285 | P15_08 286 | P15_09 287 | P15_10 288 | P15_11 289 | P15_12 290 | P15_13 291 | P16_01 292 | P16_02 293 | P16_03 294 | P17_01 295 | P17_03 296 | P17_04 297 | P19_01 298 | P19_02 299 | P19_03 300 | P19_04 301 | P20_01 302 | P20_02 303 | P20_03 304 | P20_04 305 | P21_01 306 | P21_03 307 | P21_04 308 | P22_05 309 | P22_06 310 | P22_07 311 | P22_08 312 | P22_09 313 | P22_101 314 | P22_102 315 | P22_103 316 | P22_104 317 | P22_105 318 | P22_106 319 | P22_107 320 | P22_108 321 | P22_109 322 | P22_10 323 | P22_110 324 | P22_111 325 | P22_112 326 | P22_113 327 | P22_115 328 | P22_116 329 | P22_117 330 | P22_11 331 | P22_12 332 | P22_13 333 | P22_14 334 | P22_15 335 | P22_16 336 | P22_17 337 | P23_01 338 | P23_02 339 | P23_03 340 | P23_04 341 | P23_101 342 | P23_102 343 | P24_01 344 | P24_02 345 | P24_03 346 | P24_04 347 | P24_05 348 | P24_06 349 | P24_07 350 | P24_08 351 | P25_01 352 | P25_02 353 | P25_03 354 | P25_04 355 | P25_05 356 | P25_09 357 | P25_101 358 | P25_102 359 | P25_103 360 | P25_104 361 | P25_106 362 | P25_107 363 | P25_10 364 | P25_11 365 | P25_12 366 | P26_01 367 | P26_02 368 | P26_03 369 | P26_04 370 | P26_05 371 | P26_06 372 | P26_07 373 | P26_08 374 | P26_09 375 | P26_101 376 | P26_102 377 | P26_103 378 | P26_104 379 | P26_105 380 | P26_106 381 | P26_107 382 | P26_108 383 | P26_109 384 | P26_10 385 | P26_110 386 | P26_111 387 | P26_112 388 | P26_113 389 | P26_114 390 | P26_115 391 | P26_116 392 | P26_117 393 | P26_118 394 | P26_119 395 | P26_11 396 | P26_124 397 | P26_12 398 | P26_13 399 | P26_14 400 | P26_15 401 | P26_16 402 | P26_17 403 | P26_18 404 | P26_19 405 | P26_20 406 | P26_21 407 | P26_22 408 | P26_23 409 | P26_24 410 | P26_25 411 | P26_26 412 | P26_27 413 | P26_28 414 | P26_29 415 | P27_01 416 | P27_02 417 | P27_03 418 | P27_04 419 | P27_06 420 | P27_07 421 | P27_101 422 | P27_103 423 | P27_104 424 | P27_105 425 | P28_01 426 | P28_02 427 | P28_03 428 | P28_04 429 | P28_05 430 | P28_06 431 | P28_07 432 | P28_08 433 | P28_09 434 | P28_101 435 | P28_102 436 | P28_103 437 | P28_104 438 | P28_105 439 | P28_106 440 | P28_107 441 | P28_108 442 | P28_109 443 | P28_10 444 | 
P28_110 445 | P28_111 446 | P28_112 447 | P28_113 448 | P28_11 449 | P28_12 450 | P28_13 451 | P28_14 452 | P29_01 453 | P29_02 454 | P29_03 455 | P29_04 456 | P30_01 457 | P30_02 458 | P30_03 459 | P30_04 460 | P30_05 461 | P30_06 462 | P30_101 463 | P30_103 464 | P30_104 465 | P30_107 466 | P30_108 467 | P30_109 468 | P30_10 469 | P30_110 470 | P30_111 471 | P30_112 472 | P30_113 473 | P30_114 474 | P30_11 475 | P31_01 476 | P31_02 477 | P31_03 478 | P31_04 479 | P31_05 480 | P31_06 481 | P31_07 482 | P31_08 483 | P31_09 484 | P31_13 485 | P31_14 486 | P35_101 487 | P35_103 488 | P35_104 489 | P35_105 490 | P35_107 491 | P35_108 492 | P35_109 493 | P37_101 494 | P37_102 495 | P37_103 496 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from tqdm import tqdm 6 | import hydra 7 | from omegaconf import OmegaConf, DictConfig, ListConfig 8 | import numpy as np 9 | import os 10 | import h5py 11 | from collections import defaultdict 12 | 13 | from models.base_model import BaseModel 14 | from datasets.data import get_dataset 15 | from train import get_transform_val, init_model 16 | from challenge import marginalize_verb_noun, print_accuracies_epic, LOGITS_DIR 17 | from train import DATASET_EVAL_CFG_KEY 18 | 19 | 20 | def store_append_h5(endpoints, output_dir, save_file_name): 21 | output_fpath = os.path.join(output_dir, save_file_name) 22 | os.makedirs(output_dir, exist_ok=True) 23 | with h5py.File(output_fpath, 'a') as fout: 24 | for key, val in endpoints.items(): 25 | if key not in fout: 26 | fout.create_dataset(key, data=val, compression='gzip', compression_opts=9, 27 | chunks=True, maxshape=(None, ) + val.shape[1:]) 28 | else: 29 | fout[key].resize((fout[key].shape[0] + val.shape[0], ) + val.shape[1:]) 30 | fout[key][-val.shape[0]:, ...] = val 31 | 32 | 33 | def save_logits(model, data_loader: DataLoader, device, logger, save_dir=None, save_file_name=None): 34 | """Saves logits to given path, so that the logits can be used for ensemble or any other analysis""" 35 | # construct kwargs for forwarding 36 | kwargs = {} 37 | kwargs['mixup_fn'] = None 38 | kwargs['target'] = None 39 | kwargs['target_subclips'] = None 40 | kwargs['target_subclips_ignore_index'] = None 41 | 42 | for idx, data in enumerate(tqdm(data_loader)): 43 | data, _ = data 44 | feature_dict = {mod: tens.to(device, non_blocking=True) for mod, tens in data["data_dict"].items()} 45 | outputs, outputs_target = model(feature_dict, **kwargs) 46 | 47 | logits_key = 'logits/action' 48 | 49 | logits = {} 50 | if len(outputs[logits_key]) == 1: # single modality or early fusion model 51 | modk = next(iter(outputs[logits_key].keys())) 52 | logits[f'{logits_key}_{modk}'] = outputs[f'{logits_key}'][modk][:, 0, :].detach().cpu().numpy() 53 | else: 54 | fusion_key = 'all-fused' 55 | logging.info(f'This model consists of multiple branches. 
' 56 | f'Saving fusion branch "{fusion_key}" only ...') 57 | logits[f'{logits_key}_{fusion_key}'] = \ 58 | outputs[f'{logits_key}'][fusion_key][:, 0, :].detach().cpu().numpy() 59 | 60 | store_append_h5(logits, save_dir, save_file_name) 61 | logger.info(f'Saved logits {logits.keys()} as {save_file_name} to {save_dir}.') 62 | 63 | 64 | def evaluate(model, dataset, data_loader: DataLoader, device): 65 | """ 66 | Computes the verb, noun and action performance of overall, unseen and tail 67 | """ 68 | logits_key = 'logits/action' 69 | logits = defaultdict(list) 70 | 71 | # construct kwargs for forwarding 72 | kwargs = {} 73 | kwargs['mixup_fn'] = None 74 | kwargs['target'] = None 75 | kwargs['target_subclips'] = None 76 | kwargs['target_subclips_ignore_index'] = None 77 | 78 | # forwarding 79 | for idx, data in enumerate(tqdm(data_loader)): 80 | data, _ = data 81 | feature_dict = {mod: tens.to(device, non_blocking=True) for mod, tens in data["data_dict"].items()} 82 | outputs, outputs_target = model(feature_dict, **kwargs) 83 | 84 | if len(outputs[logits_key]) == 1: # single modality or early fusion model 85 | modk = next(iter(outputs[logits_key].keys())) 86 | logits[f'{logits_key}_{modk}'].append(outputs[f'{logits_key}'][modk][:, 0, :].detach().cpu().numpy()) 87 | else: 88 | fusion_key = 'all-fused' 89 | logging.info(f'This model consists of multiple branches. ' 90 | f'Saving fusion branch "{fusion_key}" only ...') 91 | logits[f'{logits_key}_{fusion_key}'].append( 92 | outputs[f'{logits_key}'][fusion_key][:, 0, :].detach().cpu().numpy()) 93 | 94 | # since we only save one entry 95 | logits_array = np.concatenate(next(iter(logits.values())), axis=0) 96 | 97 | accs, scores = marginalize_verb_noun(logits_array, dataset, to_prob=True, compute_manyshot_unseen_tail=True) 98 | print_accuracies_epic(accs) 99 | 100 | 101 | @hydra.main(config_path="conf", config_name="config") 102 | def main(cfg: DictConfig): 103 | print(OmegaConf.to_yaml(cfg)) 104 | logger = logging.getLogger(__name__) 105 | 106 | device = torch.device('cuda') 107 | transform_val = get_transform_val(cfg) 108 | dataset_test = get_dataset(getattr(cfg, DATASET_EVAL_CFG_KEY), cfg.data_eval, transform_val, logger) 109 | logger.info('Creating data loaders...') 110 | data_loader_test = torch.utils.data.DataLoader( 111 | dataset_test, 112 | batch_size=cfg.eval.batch_size or cfg.train.batch_size * 4, 113 | num_workers=cfg.workers, 114 | pin_memory=True, 115 | shuffle=False 116 | ) 117 | 118 | num_classes = {key: len(val) for key, val in dataset_test.classes.items()} 119 | model = BaseModel(cfg.model, num_classes=num_classes, class_mappings=dataset_test.class_mappings) 120 | 121 | # load pretrained weights 122 | assert cfg.init_from_model is not None, 'Checkpoint is required for test.' 123 | ckpt_paths = cfg.init_from_model 124 | if not isinstance(ckpt_paths, ListConfig): 125 | ckpt_paths = [ckpt_paths] 126 | ckpt_paths = [os.path.join(cfg.cwd, 'checkpoints', path) for path in ckpt_paths] 127 | modules_to_keep = None 128 | _ = init_model(model, ckpt_paths, modules_to_keep, logger) 129 | 130 | model = nn.DataParallel(model, device_ids=range(cfg.num_gpus)) 131 | model = model.to(device) # Sends model to device 0, other gpus are used automatically. 
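# The remainder of main() runs in one of two modes: if `save_name` is present in the config,
# save_logits() appends per-batch action logits to an HDF5 file under LOGITS_DIR (useful for
# ensembling or generating challenge submissions with challenge.py); otherwise evaluate()
# computes the verb/noun/action anticipation metrics directly on the loaded split.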
132 | 133 | # test 134 | model.eval() 135 | with torch.no_grad(): 136 | if 'save_name' in cfg: 137 | save_dir = os.path.join(cfg.cwd, LOGITS_DIR, cfg.init_from_model.split('/')[0]) 138 | save_logits(model, data_loader_test, device, logger, save_dir, cfg.save_name) 139 | else: 140 | evaluate(model, dataset_test, data_loader_test, device) 141 | 142 | 143 | if __name__ == '__main__': 144 | main() 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Anticipative Feature Fusion Transformer for Multi-Modal Action Anticipation (WACV 2023) 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/anticipative-feature-fusion-transformer-for/action-anticipation-on-epic-kitchens-100)](https://paperswithcode.com/sota/action-anticipation-on-epic-kitchens-100?p=anticipative-feature-fusion-transformer-for)
3 | 4 | This repository contains the official source code and data for 5 | our [AFFT](https://arxiv.org/abs/2210.12649) paper. 6 | If you find our code or paper useful, please consider citing: 7 | 8 | Z. Zhong, D. Schneider, M. Voit, R. Stiefelhagen and J. Beyerer. 9 | Anticipative Feature Fusion Transformer for Multi-Modal Action Anticipation. 10 | In *WACV*, 2023. 11 | 12 | ```bibtex 13 | @InProceedings{Zhong_2023_WACV, 14 | author = {Zhong, Zeyun and Schneider, David and Voit, Michael and Stiefelhagen, Rainer and Beyerer, J\"urgen}, 15 | title = {Anticipative Feature Fusion Transformer for Multi-Modal Action Anticipation}, 16 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 17 | month = {January}, 18 | year = {2023}, 19 | pages = {6068-6077} 20 | } 21 | ``` 22 | 23 |
24 | 25 |
26 | 27 | ## Installation 28 | First clone the repo and set up the required packages in a conda environment. 29 | 30 | ```bash 31 | $ git clone https://github.com/zeyun-zhong/AFFT.git 32 | $ conda env create -f environment.yml 33 | $ conda activate afft 34 | ``` 35 | 36 | ## Download Data 37 | ### Dataset features 38 | 39 | AFFT works on pre-extracted features, so you will need to download the features first. You can 40 | download the TSN-features from RULSTM for [EK100](https://github.com/fpv-iplab/rulstm/blob/master/RULSTM/scripts/download_data_ek100_full.sh) 41 | and for [EGTEA Gaze+](https://iplab.dmi.unict.it/sharing/rulstm/features/egtea.zip). 42 | The RGB-Swin features are available [here](https://cvhci.anthropomatik.kit.edu/~dschneider/epic-kitchens/features/rgb_omnivore.zip) and audio features are available [here](https://cvhci.anthropomatik.kit.edu/~dschneider/epic-kitchens/features/audio.zip). 43 | 44 | Please make sure that your data structure follows the structure shown below. Note that 45 | `dataset_root_dir` in [config.yaml](conf/config.yaml) should be changed to your specific data path. 46 | 47 | ``` 48 | Dataset root path (e.g., /home/user/datasets) 49 | ├── epickitchens100 50 | │ └── features 51 | │ │── rgb 52 | │ │ └── data.mdb 53 | │ │── rgb_omnivore 54 | │ │ └── data.mdb 55 | │ │── obj 56 | │ │ └── data.mdb 57 | │ │── audio 58 | │ │ └── data.mdb 59 | │ └── flow 60 | │ └── data.mdb 61 | └── egtea 62 | └── features 63 | │── TSN-C_3_egtea_action_CE_s1_rgb_model_best_fcfull_hd 64 | │ └── data.mdb 65 | │── TSN-C_3_egtea_action_CE_s1_flow_model_best_fcfull_hd 66 | │ └── data.mdb 67 | │── TSN-C_3_egtea_action_CE_s2_rgb_model_best_fcfull_hd 68 | │ └── data.mdb 69 | │── TSN-C_3_egtea_action_CE_s2_flow_model_best_fcfull_hd 70 | │ └── data.mdb 71 | │── TSN-C_3_egtea_action_CE_s3_rgb_model_best_fcfull_hd 72 | │ └── data.mdb 73 | └── TSN-C_3_egtea_action_CE_s3_flow_model_best_fcfull_hd 74 | └── data.mdb 75 | ``` 76 | 77 | If you use a different organization, you would need to edit `rulstm_feats_dir` in [EK100-common](conf/dataset/epic_kitchens100/common.yaml) 78 | and [EGTEA-common](conf/dataset/egtea/common.yaml). 79 | 80 | ### Model Zoo 81 | 82 | | Dataset | Modalities | Performance<br>(Actions) | Config | Model | 83 | |---------|:---------------------------------------------------------|-------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------|-----------| 84 | | EK100 | R-Swin, O, AU, F<br>R-TSN, O, AU, F<br>R-TSN, O, F | 18.5 (MT5R)<br>17.0 (MT5R)<br>16.4 (MT5R) | `expts/01_SA-Fuser_ek100_val_Swin.txt`<br>`expts/01_SA-Fuser_ek100_val_TSN.txt`<br>`expts/01_SA-Fuser_ek100_val_TSN_wo_audio.txt` | [link](https://bwsyncandshare.kit.edu/s/C2xLnk3H8WXgMbW)<br>[link](https://bwsyncandshare.kit.edu/s/M97peiEnjsW2Qji)<br>
[link](https://bwsyncandshare.kit.edu/s/gZgH8gFJd4MFegR) | 85 | | EGTEA | RGB-TSN, Flow | 42.5 (Top-1) | `expts/02_ek100_avt_tsn.txt` | [link](https://bwsyncandshare.kit.edu/s/iXySQLxqZTdH4qB) | 86 | 87 | 88 | ## Training 89 | Recall that `dataset_root_dir` in [config.yaml](conf/config.yaml) should be changed to your specific path. 90 | >### EpicKitchens-100 91 | > ```bash 92 | > python run.py -c expts/01_SA-Fuser_ek100_train.txt --mode train --nproc_per_node 2 93 | >``` 94 | 95 | >### EGTEA Gaze+ 96 | > ```bash 97 | > python run.py -c expts/06_SA-Fuser_egtea_train.txt --mode train --nproc_per_node 2 98 | >``` 99 | 100 | ## Validation 101 | >### EpicKitchens-100 102 | > ```bash 103 | > python run.py -c expts/01_SA-Fuser_ek100_val_TSN_wo_audio.txt --mode test --nproc_per_node 1 104 | >``` 105 | 106 | >### EGTEA Gaze+ 107 | > ```bash 108 | > python run.py -c expts/06_SA-Fuser_egtea_val.txt --mode test --nproc_per_node 1 109 | >``` 110 | 111 | ## Test / Challenge (EK100) 112 | > ```bash 113 | > # save logits 114 | > python run.py -c expts/01_SA-Fuser_ek100_test_TSN_wo_audio.txt --mode test --nproc_per_node 1 115 | > 116 | > # generate test / challenge file 117 | > python challenge.py --prefix_h5 test --models fusion_ek100_tsn_wo_audio_4h_18s --weights 1. 118 | >``` 119 | 120 | 121 | ## License 122 | 123 | This codebase is released under the license terms specified in the [LICENSE](LICENSE) file. Any imported libraries, datasets or other code follows the license terms set by respective authors. 124 | 125 | 126 | ## Acknowledgements 127 | 128 | Many thanks to [Rohit Girdhar](https://github.com/facebookresearch/AVT) and [Antonino Furnari](https://github.com/fpv-iplab/rulstm) for providing their code and data. 129 | -------------------------------------------------------------------------------- /annotations/ek55_ori/EPIC_verb_classes.csv: -------------------------------------------------------------------------------- 1 | verb_id,class_key,verbs 2 | 0,take,"['take', 'grab', 'pick', 'draw', 'get', 'grab-up', 'collect-from', 'take-up', 'grab-down', 'pull-down', 'fetch', 'pick-up']" 3 | 1,put,"['put', 'pose', 'put-away', 'place-that', 'place-down', 'put', 'set-down', 'put-over', 'place-on', 'reture', 'put-on', 'place-away', 'swap', 'layer', 'position', 'put-down', 'put-back', 'place-back', 'return', 'pace', 'dose', 'put-up', 'leave', 'place', 'put-aside', 'lay', 'reposition', 'replace', 'put-onto']" 4 | 2,open,"['open', 'unzip', 'open-up']" 5 | 3,close,"['close', 'close-off', 'shut']" 6 | 4,wash,"['wash', 'sponge', 'lather', 'wash-with', 'rinse', 'rinse-off', 'soap-up', 'wash-off', 'soap', 'rise', 'wash-up', 'clean-around', 'clean-off', 'clean-with', 'clean', 'clean-up', 'wipe-down', 'wipe-of', 'wipe', 'wipe-off']" 7 | 5,cut,"['cut', 'chop', 'chop-off', 'cut-off', 'slice-into', 'slice-along', 'slice-off', 'stem', 'slice-up', 'cut-into', 'dice', 'half', 'halve', 'chop-up', 'snip', 'trim', 'slice']" 8 | 6,mix,"['mix', 'beat', 'mix-around', 'stir-with', 'whisk', 'stir', 'blend', 'mix-in', 'stir-in']" 9 | 7,pour,"['pour', 'pour-in', 'tip-in', 'pour-out', 'pour-on', 'pour-into', 'sieve']" 10 | 8,throw,"['throw', 'throw-out', 'recycle', 'dispose-of', 'throw-over', 'throw-away', 'throw-in', 'bin', 'trow', 'toss']" 11 | 9,move,"['move', 'transfer', 'move-around']" 12 | 10,remove,"['remove', 'extract', 'take-off', 'remove-out', 'take-out', 'get-out', 'remove-inside', 'remove-from', 'pick-out']" 13 | 11,dry,"['dry', 'dry-off', 'towel']" 14 | 12,turn-on,"['turn-on', 'start', 'begin', 'ignite', 'switch-on', 
'activate', 'water-on', 'play', 'start-to', 'restart', 'light']" 15 | 13,turn,"['turn', 'rotate']" 16 | 14,shake,"['shake', 'shake-off', 'shake-out']" 17 | 15,turn-off,"['turn-off', 'switch-of', 'water-off', 'switch-off', 'switch-out', 'shut-off', 'turn-of']" 18 | 16,peel,"['peel', 'skin-from', 'skin', 'peel-off', 'peel-back']" 19 | 17,adjust,"['adjust', 'change', 'regulate']" 20 | 18,empty,['empty'] 21 | 19,scoop,"['scoop', 'spoon-in', 'spoon', 'scoop-out', 'scoop-up']" 22 | 20,check,"['check', 'ensure', 'test', 'look-in', 'watch', 'inspect', 'check-on']" 23 | 21,squeeze,"['squeeze', 'squidge', 'squidge-into', 'squash', 'squish-into', 'wring-out', 'wring', 'squish', 'squeeze-into']" 24 | 22,insert,"['insert', 'put-in', 'put-into', 'fit', 'place-in', 'put-inside']" 25 | 23,press,"['press', 'push-down', 'collapse', 'compress', 'push', 'press-on']" 26 | 24,fill,"['fill', 'fill-with', 'fill-up', 'stuff']" 27 | 25,add,"['add', 'combine', 'add-to']" 28 | 26,scrape,"['scrape', 'scrape-out', 'loose', 'scour', 'scrap', 'scrape-off']" 29 | 27,sharpen,"['sharpen', 'thin']" 30 | 28,wrap,"['wrap', 'top', 'lid', 'cover', 'seal', 'clamp', 'fasten', 'reseal', 'clip', 'wrap-up', 'tie']" 31 | 29,roll,"['roll', 'roll-up']" 32 | 30,sprinkle,"['sprinkle', 'sprinkle-on', 'drizzle', 'scatter', 'sprincle', 'crumble']" 33 | 31,break,"['break', 'snap', 'crack', 'break-up']" 34 | 32,flip,"['flip', 'overturn', 'turn-over']" 35 | 33,hang,"['hang', 'drape', 'hang-up']" 36 | 34,hold,['hold'] 37 | 35,sort,"['sort', 'rearrange', 'arrange', 'clear', 'tidy', 'line-up']" 38 | 36,apply,"['apply', 'spread']" 39 | 37,crush,"['crush', 'hit', 'grind', 'grate', 'mash', 'crush']" 40 | 38,search,"['search', 'search-for', 'search-in', 'look-for', 'find', 'locate']" 41 | 39,sample,"['sample', 'taste', 'smell', 'lick']" 42 | 40,knead,['knead'] 43 | 41,set,"['set', 'set-up', 'set-out']" 44 | 42,walk,"['walk', 'enter', 'enter-into', 'walk-into', 'walk-around', 'walk-to', 'walk-with']" 45 | 43,divide,"['divide', 'split', 'detach', 'separate', 'distribute']" 46 | 44,spray,"['spray', 'spay']" 47 | 45,use,"['use', 'used-to', 'used']" 48 | 46,fold,['fold'] 49 | 47,cook,"['cook', 'toast', 'fry-in', 'heat', 'fry', 'reduce', 'boil']" 50 | 48,filter,"['filter', 'strain', 'drain', 'dump-out']" 51 | 49,scrub,"['scrub', 'scrub-inside']" 52 | 50,look,"['look', 'stare-at', 'see', 'read-on', 'read', 'look-at']" 53 | 51,finish,"['finish', 'stop', 'end', 'do']" 54 | 52,soak,"['soak', 'submerge', 'immerse']" 55 | 53,brush,['brush'] 56 | 54,pull,"['pull', 'pull-out']" 57 | 55,pat,"['pat', 'pat-into', 'dab', 'tap', 'tap-on', 'poke', 'pat-down']" 58 | 56,form,"['form', 'shape', 'forge', 'make', 'shape-into']" 59 | 57,measure,['measure'] 60 | 58,drink,"['drink', 'drink-from']" 61 | 59,choose,"['choose', 'select']" 62 | 60,serve,"['serve', 'plate-on', 'plate', 'dish', 'plate-up']" 63 | 61,drop,"['drop', 'drop-on']" 64 | 62,wear,['wear'] 65 | 63,rip,"['rip', 'tear', 'tear-off', 'tear-down']" 66 | 64,tip,"['tip', 'tip-over', 'tip-out']" 67 | 65,turn-down,['turn-down'] 68 | 66,gather,"['gather', 'collect']" 69 | 67,eat,"['eat', 'chew-on', 'bite']" 70 | 68,stack,"['stack', 'stack-up']" 71 | 69,store,['store'] 72 | 70,switch,['switch'] 73 | 71,increase,"['increase', 'switch-up', 'turn-up']" 74 | 72,carry,"['carry', 'have', 'bring']" 75 | 73,lift,"['lift', 'raise', 'tilt', 'lift-up']" 76 | 74,twist,"['twist', 'screw', 'screw-in', 'screw-on', 'tighten']" 77 | 75,sweep,['sweep'] 78 | 76,rub,"['rub', 'rub-off']" 79 | 77,unwrap,"['unwrap', 'unpack', 'unseal']" 80 | 
78,stab,['stab'] 81 | 79,attach,"['attach', 'connect', 'plug', 'assemble', 'plug-into', 'plug-in']" 82 | 80,stretch,"['stretch', 'unfold']" 83 | 81,lower,['lower'] 84 | 82,prepare,['prepare'] 85 | 83,unscrew,"['unscrew', 'untwist']" 86 | 84,season,"['season', 'sweeten', 'pepper', 'salt']" 87 | 85,video,['video'] 88 | 86,tap-off,['tap-off'] 89 | 87,set-off,['set-off'] 90 | 88,squirt,['squirt'] 91 | 89,load,['load'] 92 | 90,unroll,"['unroll', 'roll-out']" 93 | 91,water,"['water', 'wet']" 94 | 92,do,['do'] 95 | 93,flatten,"['flatten', 'flatten-with']" 96 | 94,uncover,"['uncover', 'lid-off']" 97 | 95,slide,"['slide', 'slide-out']" 98 | 96,unplug,['unplug'] 99 | 97,level,['level'] 100 | 98,tear-out,['tear-out'] 101 | 99,feel,"['feel', 'touch']" 102 | 100,fix,['fix'] 103 | 101,spill,['spill'] 104 | 102,pack,['pack'] 105 | 103,bake,['bake'] 106 | 104,blow,"['blow', 'blow-out']" 107 | 105,sit-on,['sit-on'] 108 | 106,count,['count'] 109 | 107,dip,['dip'] 110 | 108,cool,['cool'] 111 | 109,flush,['flush'] 112 | 110,knife,['knife'] 113 | 111,fork,['fork'] 114 | 112,swirl,['swirl'] 115 | 113,stick,['stick'] 116 | 114,pet-down,['pet-down'] 117 | 115,realize,['realize'] 118 | 116,weigh,['weigh'] 119 | 117,defoliate,['defoliate'] 120 | 118,deseed,['deseed'] 121 | 119,tessellate,['tessellate'] 122 | 120,unfreeze,['unfreeze'] 123 | 121,decide-if,['decide-if'] 124 | 122,let-out,['let-out'] 125 | 123,save,['save'] 126 | 124,reverse,['reverse'] 127 | -------------------------------------------------------------------------------- /common/scheduler.py: -------------------------------------------------------------------------------- 1 | """copied from AVT""" 2 | 3 | from typing import Sequence 4 | 5 | import torch 6 | from bisect import bisect_right 7 | 8 | 9 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 10 | def __init__( 11 | self, 12 | optimizer: torch.optim.Optimizer, 13 | milestone_epochs: Sequence[int], 14 | gamma: float = 0.1, 15 | warmup_factor: float = 1.0 / 3, 16 | warmup_epochs: int = 5, 17 | warmup_method: str = 'linear', 18 | last_epoch: int = -1, 19 | iters_per_epoch: int = None, # Must be set by calling code 20 | world_size: int = None, 21 | ): 22 | del world_size 23 | if not milestone_epochs == sorted(milestone_epochs): 24 | raise ValueError( 25 | "Milestones should be a list of" 26 | " increasing integers. 
Got {}", 27 | milestone_epochs, 28 | ) 29 | 30 | if warmup_method not in ("constant", "linear"): 31 | raise ValueError( 32 | "Only 'constant' or 'linear' warmup_method accepted" 33 | "got {}".format(warmup_method)) 34 | self.milestones = [iters_per_epoch * m for m in milestone_epochs] 35 | self.gamma = gamma 36 | self.warmup_factor = warmup_factor 37 | self.warmup_iters = max(warmup_epochs * iters_per_epoch, 1) 38 | 39 | self.warmup_method = warmup_method 40 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 41 | 42 | def get_lr(self): 43 | warmup_factor = 1 44 | if self.last_epoch < self.warmup_iters: 45 | if self.warmup_method == "constant": 46 | warmup_factor = self.warmup_factor 47 | elif self.warmup_method == "linear": 48 | alpha = float(self.last_epoch) / self.warmup_iters 49 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 50 | return [ 51 | base_lr * warmup_factor * 52 | self.gamma**bisect_right(self.milestones, self.last_epoch) 53 | for base_lr in self.base_lrs 54 | ] 55 | 56 | 57 | class CosineLR(torch.optim.lr_scheduler.CosineAnnealingLR): 58 | def __init__(self, 59 | optimizer, 60 | num_epochs, 61 | iters_per_epoch=None, 62 | world_size=None, 63 | **kwargs): 64 | kwargs['eta_min'] *= world_size 65 | super().__init__(optimizer, 66 | T_max=num_epochs * iters_per_epoch, 67 | **kwargs) 68 | 69 | def get_lr(self, *args, **kwargs): 70 | if self.last_epoch < self.T_max: 71 | return super().get_lr(*args, **kwargs) 72 | else: 73 | # Adding this if I train the model longer than the T_max set in 74 | # this. Happens when I sweep over different amounts of warmup. 75 | return [0.0 for _ in self.optimizer.param_groups] 76 | 77 | 78 | class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): 79 | def __init__(self, 80 | optimizer, 81 | iters_per_epoch=None, 82 | world_size=None, 83 | **kwargs): 84 | del iters_per_epoch, world_size 85 | super().__init__(optimizer, **kwargs) 86 | 87 | 88 | class Warmup(torch.optim.lr_scheduler._LRScheduler): 89 | """Wrap the scheduler for warmup before it kicks in.""" 90 | def __init__( 91 | self, 92 | optimizer: torch.optim.Optimizer, 93 | scheduler: torch.optim.lr_scheduler._LRScheduler, 94 | init_lr_ratio: float = 0.0, 95 | num_epochs: int = 5, 96 | last_epoch: int = -1, 97 | iters_per_epoch: int = None, # Must be set by calling code 98 | world_size: int = None, 99 | ): 100 | """ 101 | Args: 102 | init_lr_ratio (float in [0, 1]): Ratio of the original LR to start 103 | from. If 0.1, it will start from 0.1 of the original LRs and go 104 | upto 1.0 of the original LRs in the epochs. By def start from 105 | 0 up. 106 | num_epochs (int): Num of epochs to take to warmup. 107 | last_epoch (int): Which was the last epoch to init from (not really 108 | used anymore since we store the state_dict when loading 109 | scheduler from disk.) 110 | """ 111 | del world_size 112 | self.base_scheduler = scheduler 113 | self.warmup_iters = max(num_epochs * iters_per_epoch, 1) 114 | if self.warmup_iters > 1: 115 | self.init_lr_ratio = init_lr_ratio 116 | else: 117 | self.init_lr_ratio = 1.0 # Don't go from 0 to 1 in 1 iteration 118 | super().__init__(optimizer, last_epoch) 119 | 120 | def get_lr(self): 121 | # Epoch is iters for me, since I step after each iteration 122 | # (not after each epoch) 123 | # Based on logic in step, this should only be called for the warmup 124 | # iters. 
After that it should go to the base scheduler 125 | assert self.last_epoch < self.warmup_iters # since it increments 126 | return [ 127 | el * (self.init_lr_ratio + (1 - self.init_lr_ratio) * 128 | (float(self.last_epoch) / self.warmup_iters)) 129 | for el in self.base_lrs 130 | ] 131 | 132 | def step(self, *args, **kwargs): 133 | if self.last_epoch < (self.warmup_iters - 1): 134 | super().step(*args, **kwargs) 135 | else: 136 | self.base_scheduler.step(*args, **kwargs) 137 | 138 | def state_dict(self): 139 | """Returns the state of the scheduler as a :class:`dict`. 140 | 141 | It contains an entry for every variable in self.__dict__ which 142 | is not the optimizer. 143 | """ 144 | base_sched_dict = self.base_scheduler.state_dict() 145 | other_stuff = { 146 | key: value 147 | for key, value in self.__dict__.items() if key not in [ 148 | 'base_scheduler', 'optimizer'] 149 | } 150 | return {'base_sched_dict': base_sched_dict, 'other_stuff': other_stuff} 151 | 152 | def load_state_dict(self, state_dict): 153 | """Loads the schedulers state. 154 | 155 | Arguments: 156 | state_dict (dict): scheduler state. Should be an object returned 157 | from a call to :meth:`state_dict`. 158 | """ 159 | self.base_scheduler.__dict__.update(state_dict['base_sched_dict']) 160 | self.__dict__.update(state_dict['other_stuff']) 161 | -------------------------------------------------------------------------------- /common/metric_tracking.py: -------------------------------------------------------------------------------- 1 | """Implementation of metrictracker in training process""" 2 | 3 | import numpy as np 4 | from typing import Dict 5 | from common.utils import is_dist_avail_and_initialized 6 | import torch 7 | import torch.distributed as dist 8 | 9 | 10 | class MeanTopKRecallMeter(object): 11 | """adapted from RULSTM""" 12 | def __init__(self, name, num_classes: int, k=5, string_format='{:.3f}'): 13 | self.name = name 14 | self.num_classes = num_classes 15 | self.k = k 16 | self.string_format = string_format 17 | 18 | def reset(self): 19 | self.tps = np.zeros(self.num_classes) 20 | self.nums = np.zeros(self.num_classes) 21 | 22 | def update(self, logits_labels_dict, n=1): 23 | del n # not used here 24 | scores = logits_labels_dict['logits'] 25 | labels = logits_labels_dict['labels'] 26 | tp = (np.argsort(scores, axis=1)[:, -self.k:] == labels.reshape(-1, 1)).max(1) 27 | for l in np.unique(labels): 28 | self.tps[l] += tp[labels == l].sum() 29 | self.nums[l] += (labels == l).sum() 30 | 31 | def synchronize_between_processes(self): 32 | if not is_dist_avail_and_initialized(): 33 | return 34 | tps_all = torch.tensor(self.tps, device='cuda') 35 | nums_all = torch.tensor(self.nums, device='cuda') 36 | dist.barrier() 37 | dist.all_reduce(tps_all) 38 | dist.all_reduce(nums_all) 39 | self.tps = tps_all 40 | self.nums = nums_all 41 | 42 | @property 43 | def value(self): 44 | tps = self.tps[self.nums > 0] 45 | nums = self.nums[self.nums > 0] 46 | recalls = tps / nums 47 | if len(recalls) > 0: 48 | return recalls.mean() * 100 49 | else: 50 | return None 51 | 52 | def to_string(self): 53 | return self.string_format.format(self.value) 54 | 55 | 56 | class AverageMeter: 57 | """Computes and stores the average and current value""" 58 | def __init__(self, name, string_format='{:.3f}'): 59 | self.name = name 60 | self.string_format = string_format 61 | 62 | def reset(self): 63 | self.val, self.avg, self.sum, self.count = 0, 0, 0, 0 64 | 65 | def update(self, val, n=1): 66 | self.val = val 67 | self.sum += val * n 68 | 
self.count += n 69 | 70 | def synchronize_between_processes(self): 71 | if not is_dist_avail_and_initialized(): 72 | return 73 | count_all = torch.tensor(self.count, device='cuda') 74 | sum_all = torch.tensor(self.sum, device='cuda') 75 | dist.barrier() 76 | dist.all_reduce(count_all) 77 | dist.all_reduce(sum_all) 78 | self.count = count_all 79 | self.sum = sum_all 80 | 81 | @property 82 | def value(self): 83 | """returns the current floating average""" 84 | self.avg = self.sum / self.count 85 | return self.avg 86 | 87 | def to_string(self): 88 | return self.string_format.format(self.value) 89 | 90 | 91 | class MetricTracker: 92 | """Interface of all metrics, tracks multiple metrics""" 93 | def __init__(self, num_classes: Dict[str, int]): 94 | self.training_metrics = {} 95 | self.validation_metrics = {} 96 | self.num_classes = num_classes 97 | self.training_prefix = 'train_' 98 | self.validation_prefix = 'val_' 99 | 100 | def _get_num_classes(self, name): 101 | num_classes = None 102 | for key, value in self.num_classes.items(): 103 | if key in name: 104 | num_classes = value 105 | if num_classes is None: 106 | raise ValueError('Name of the mt5r metric muss contain action, verb or noun.') 107 | return num_classes 108 | 109 | def add_metric(self, name, is_training=None): 110 | meter = AverageMeter(name) 111 | if 'mt5r' in name: 112 | num_classes = self._get_num_classes(name) 113 | meter = MeanTopKRecallMeter(name, num_classes) 114 | 115 | # reset the meter 116 | meter.reset() 117 | 118 | if is_training is None: 119 | self.training_metrics[name] = meter 120 | self.validation_metrics[name] = meter 121 | elif is_training: 122 | self.training_metrics[name] = meter 123 | else: 124 | self.validation_metrics[name] = meter 125 | 126 | def update(self, metric_dict: Dict, batch_size: int, is_training: bool): 127 | if is_training: 128 | metrics = self.training_metrics 129 | prefix = self.training_prefix 130 | else: 131 | metrics = self.validation_metrics 132 | prefix = self.validation_prefix 133 | 134 | for key, value in metric_dict.items(): 135 | key = prefix + key 136 | if key not in metrics: 137 | self.add_metric(key, is_training) 138 | metrics[key].update(value, batch_size) 139 | 140 | def synchronize_between_processes(self, is_training): 141 | if is_training: 142 | metrics = self.training_metrics 143 | else: 144 | metrics = self.validation_metrics 145 | 146 | for key in metrics: 147 | metrics[key].synchronize_between_processes() 148 | 149 | def reset(self): 150 | """reset all metrics at the beginning of each training epoch""" 151 | for name in self.training_metrics: 152 | self.training_metrics[name].reset() 153 | for name in self.validation_metrics: 154 | self.validation_metrics[name].reset() 155 | 156 | def get_all_data(self, is_training): 157 | """returns the current values of all tracked metrics""" 158 | if is_training: 159 | metrics = self.training_metrics 160 | else: 161 | metrics = self.validation_metrics 162 | data = {} 163 | for key in metrics: 164 | data[key] = metrics[key].value 165 | return data 166 | 167 | def get_data(self, metric_name, is_training): 168 | """returns the current value of the metric""" 169 | if is_training: 170 | return self.training_metrics[metric_name].value 171 | else: 172 | return self.validation_metrics[metric_name].value 173 | 174 | def to_string(self, is_training): 175 | """returns the string of all values""" 176 | if is_training: 177 | result = '\33[0;36;40m' + 'Training: ' 178 | metrics = self.training_metrics 179 | else: 180 | result = '\33[0;32;40m' + 
'Validation: ' 181 | metrics = self.validation_metrics 182 | 183 | for key in metrics: 184 | result += metrics[key].name + ': ' + metrics[key].to_string() + ' ' 185 | return result + '\033[0m' 186 | -------------------------------------------------------------------------------- /models/transformerblock.py: -------------------------------------------------------------------------------- 1 | """Implementation of basic transformer architecture, code is based on Timm""" 2 | 3 | import torch.nn as nn 4 | import torch 5 | 6 | 7 | class Attention(nn.Module): 8 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 9 | super().__init__() 10 | self.num_heads = num_heads 11 | head_dim = dim // num_heads 12 | self.scale = qk_scale or head_dim ** -0.5 13 | 14 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 15 | self.attn_drop = nn.Dropout(attn_drop) 16 | self.proj = nn.Linear(dim, dim) 17 | self.proj_drop = nn.Dropout(proj_drop) 18 | 19 | def forward(self, x, attn_mask=None): 20 | B, N, C = x.shape 21 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 22 | q, k, v = qkv[0], qkv[1], qkv[2] 23 | 24 | attn = (q @ k.transpose(-2, -1)) * self.scale 25 | 26 | # add the mask to enable cross attention 27 | if attn_mask is not None: 28 | attn = attn + attn_mask 29 | 30 | attn = attn.softmax(dim=-1) 31 | attn = self.attn_drop(attn) 32 | 33 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 34 | x = self.proj(x) 35 | x = self.proj_drop(x) 36 | return x, attn 37 | 38 | 39 | class CrossAttention(nn.Module): 40 | """Cross attention used in transformer decoder""" 41 | def __init__(self, dim, mem_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 42 | super().__init__() 43 | self.num_heads = num_heads 44 | head_dim = dim // num_heads 45 | self.scale = qk_scale or head_dim ** -0.5 46 | 47 | mem_dim = mem_dim or dim 48 | self.w_q = nn.Linear(dim, dim, bias=qkv_bias) 49 | self.w_k = nn.Linear(mem_dim, dim, bias=qkv_bias) 50 | self.w_v = nn.Linear(mem_dim, dim, bias=qkv_bias) 51 | 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x, mem, attn_mask=None): 57 | B, N, C = x.shape 58 | 59 | # calculate query, key, value for all heads 60 | q = self.w_q(x).view(B, N, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, N, c) 61 | k = self.w_k(mem).view(B, N, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, N, c) 62 | v = self.w_v(mem).view(B, N, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, N, c) 63 | 64 | attn = (q @ k.transpose(-2, -1)) * self.scale 65 | 66 | # add the mask to enable "causality" 67 | if attn_mask is not None: 68 | attn = attn + attn_mask 69 | 70 | attn = attn.softmax(dim=-1) 71 | attn = self.attn_drop(attn) 72 | 73 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 74 | x = self.proj(x) 75 | x = self.proj_drop(x) 76 | return x 77 | 78 | 79 | class MLP(nn.Module): 80 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 81 | super().__init__() 82 | out_features = out_features or in_features 83 | hidden_features = hidden_features or in_features 84 | self.mlp = nn.Sequential( 85 | nn.Linear(in_features, hidden_features), 86 | act_layer(), 87 | nn.Linear(hidden_features, out_features), 88 | nn.Dropout(drop) 89 | ) 90 | 91 | def forward(self, x): 92 | x = self.mlp(x) 93 | return 
x 94 | 95 | 96 | def drop_path(x, drop_prob: float = 0., training: bool = False): 97 | if drop_prob == 0. or not training: 98 | return x 99 | keep_prob = 1 - drop_prob 100 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 101 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 102 | random_tensor.floor_() # binarize 103 | output = x.div(keep_prob) * random_tensor 104 | return output 105 | 106 | 107 | class DropPath(nn.Module): 108 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 109 | """ 110 | def __init__(self, drop_prob=None): 111 | super(DropPath, self).__init__() 112 | self.drop_prob = drop_prob 113 | 114 | def forward(self, x): 115 | return drop_path(x, self.drop_prob, self.training) 116 | 117 | 118 | class Block(nn.Module): 119 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 120 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 121 | super().__init__() 122 | self.norm1 = norm_layer(dim) 123 | self.attn = Attention( 124 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop 125 | ) 126 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 127 | self.norm2 = norm_layer(dim) 128 | mlp_hidden_dim = int(dim * mlp_ratio) 129 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 130 | 131 | def forward(self, x, attn_mask=None): 132 | attn_output, attn_weights = self.attn(self.norm1(x), attn_mask) 133 | x = x + self.drop_path(attn_output) 134 | x = x + self.drop_path(self.mlp(self.norm2(x))) 135 | return x, attn_weights 136 | 137 | 138 | class DecoderBlock(nn.Module): 139 | """Transformer decoder block with pre-layernorm""" 140 | def __init__(self, dim, mem_dim=None, num_heads=4, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 141 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 142 | super().__init__() 143 | self.norm_self = norm_layer(dim) 144 | self.attn = Attention( 145 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop 146 | ) 147 | self.cross_attn = CrossAttention( 148 | dim, mem_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop 149 | ) 150 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 151 | self.norm_q = norm_layer(dim) 152 | self.norm_kv = norm_layer(mem_dim or dim) 153 | self.norm_mlp = norm_layer(dim) 154 | mlp_hidden_dim = int(dim * mlp_ratio) 155 | self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 156 | 157 | def forward(self, x, mem, attn_mask=None): 158 | attn_output, _ = self.attn(self.norm_self(x), attn_mask) # attention weights will be ignored 159 | x = x + self.drop_path(attn_output) # self attention + short-cut 160 | x = x + self.drop_path(self.cross_attn(self.norm_q(x), self.norm_kv(mem), attn_mask)) # cross attention + short-cut 161 | x = x + self.drop_path(self.mlp(self.norm_mlp(x))) # mlp 162 | return x 163 | -------------------------------------------------------------------------------- /common/mixup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of mixup with ignore class, since some sequences donot have gt labels 3 | Mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 4 | """ 5 | 6 | from typing import Dict, Sequence, Union 7 | import torch 8 | 9 | 10 | def batch_wo_ignore_cls(target_subclips: torch.Tensor, ignore_cls=-1): 11 | target_subclips = target_subclips.squeeze(-1) # avoid dim like (B, 1) 12 | assert target_subclips.ndim == 2, "Target subclips should have dimension of 2." 13 | batch_index = (target_subclips != ignore_cls).all(-1) 14 | return batch_index 15 | 16 | 17 | def convert_to_one_hot( 18 | targets: torch.Tensor, 19 | num_class: int, 20 | label_smooth: float = 0.0, 21 | ) -> torch.Tensor: 22 | """ 23 | This function converts target class indices to one-hot vectors, 24 | given the number of classes. 25 | Args: 26 | targets (torch.Tensor): Index labels to be converted. 27 | num_class (int): Total number of classes. 28 | label_smooth (float): Label smooth value for non-target classes. Label smooth 29 | is disabled by default (0). 30 | """ 31 | assert ( 32 | torch.max(targets).item() < num_class 33 | ), "Class Index must be less than number of classes" 34 | assert 0 <= label_smooth < 1.0, "Label smooth value needs to be between 0 and 1." 35 | 36 | targets = targets.squeeze(-1) # avoids dim like (B, 1) 37 | 38 | non_target_value = label_smooth / num_class 39 | target_value = 1.0 - label_smooth + non_target_value 40 | one_hot_targets = torch.full( 41 | (*targets.shape, num_class), 42 | non_target_value, 43 | dtype=None, 44 | device=targets.device, 45 | ) 46 | one_hot_targets.scatter_(-1, targets.unsqueeze(-1), target_value) 47 | return one_hot_targets 48 | 49 | 50 | def _mix_labels( 51 | labels: torch.Tensor, 52 | num_classes: int, 53 | lam: float = 1.0, 54 | label_smoothing: float = 0.0, 55 | one_hot: bool = False, 56 | ): 57 | """ 58 | This function converts class indices to one-hot vectors and mix labels, given the 59 | number of classes. 60 | Args: 61 | labels (torch.Tensor): Class labels. 62 | num_classes (int): Total number of classes. 63 | lam (float): lamba value for mixing labels. 64 | label_smoothing (float): Label smoothing value. 
65 | """ 66 | if one_hot: 67 | labels1 = labels 68 | labels2 = labels.flip(0) 69 | else: 70 | labels1 = convert_to_one_hot(labels, num_classes, label_smoothing) 71 | labels2 = convert_to_one_hot(labels.flip(0), num_classes, label_smoothing) 72 | return labels1 * lam + labels2 * (1.0 - lam) 73 | 74 | 75 | def _mix(inputs: torch.Tensor, batch_wo_ignore_index: torch.Tensor, lam: float = 1.0) -> torch.Tensor: 76 | """ 77 | mix inputs of specific indexes 78 | :param inputs: input tensor 79 | :param batch_wo_ignore_index: index of batches where the ignore class does not occur 80 | :param lam: mixup lambda 81 | :return: mixed inputs 82 | """ 83 | inputs_selected = inputs[batch_wo_ignore_index] 84 | inputs_flipped = inputs_selected.flip(0).mul_(1.0 - lam) 85 | inputs_selected.mul_(lam).add_(inputs_flipped) 86 | inputs[batch_wo_ignore_index] = inputs_selected 87 | return inputs 88 | 89 | 90 | class MixUp(torch.nn.Module): 91 | """ 92 | Mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 93 | """ 94 | 95 | def __init__( 96 | self, 97 | alpha: float = 1.0, 98 | label_smoothing: Dict = 0.0, 99 | num_classes: Dict = None, 100 | one_hot: bool = False, 101 | ignore_cls=-1, 102 | ) -> None: 103 | """ 104 | This implements MixUp for videos. 105 | Args: 106 | alpha (float): Mixup alpha value. 107 | label_smoothing (dict): Label smoothing value. 108 | num_classes (dict, int): Number of total classes. 109 | one_hot (bool): whether labels are already in one-hot form 110 | ignore_cls (int): class that will not contribute for backpropagation 111 | """ 112 | super().__init__() 113 | self.mixup_beta_sampler = torch.distributions.beta.Beta(alpha, alpha) 114 | self.label_smoothing = label_smoothing 115 | self.num_classes = num_classes 116 | self.one_hot = one_hot 117 | self.ignore_cls = ignore_cls 118 | 119 | def forward( 120 | self, 121 | x_video: Dict, 122 | labels: Dict, 123 | labels_subclips: Union[Dict, None], 124 | ) -> Sequence[Union[Dict, None]]: 125 | """ 126 | :param x_video: Dict of inputs from different modalities 127 | :param labels: Dict of action / (verb, noun) labels 128 | :param labels_subclips: Dict of action / (verb, noun) labels for past frames 129 | :return: mixed inputs and labels 130 | """ 131 | assert next(iter(x_video.values())).size(0) > 1, "MixUp cannot be applied to a single instance." 132 | batch_wo_ignore_index = [...]
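# A minimal worked example of the mixing performed below, with illustrative values only
# (lam = 0.7 and the toy batch of action labels [0, 2, 1, 3] are not part of the original pipeline):
#   labels1 = convert_to_one_hot(torch.tensor([0, 2, 1, 3]), num_class=5)
#   labels2 = labels1.flip(0)                 # pair each sample with its mirror in the reversed batch
#   mixed   = 0.7 * labels1 + 0.3 * labels2   # soft targets; each row still sums to 1
# The same lam is applied to the input features of the samples without ignore labels
# (see _mix above), so inputs and targets stay consistent.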
133 | 134 | # convert labels to one-hot format 135 | labels_out = {key: convert_to_one_hot(val, self.num_classes[key], self.label_smoothing[key]) 136 | for key, val in labels.items()} 137 | 138 | if labels_subclips is not None: 139 | labels_subclips_curr = next(iter(labels_subclips.values())) 140 | batch_wo_ignore_index = batch_wo_ignore_cls(labels_subclips_curr, self.ignore_cls) 141 | 142 | # convert labels_subclips to one-hot format 143 | labels_subclips_out = {} 144 | labels_subclips_ignore_index = {} 145 | for key, val in labels_subclips.items(): 146 | val_tmp = val.clone() 147 | # we first assign those ignore classes 0, so that the code works 148 | # the runner will filter out these ignore classes later 149 | subclips_ignore_index = val == self.ignore_cls 150 | val_tmp[subclips_ignore_index] = 0 151 | labels_subclips_ignore_index[key] = subclips_ignore_index 152 | val_one_hot = convert_to_one_hot(val_tmp, self.num_classes[key], self.label_smoothing[key]) 153 | labels_subclips_out[key] = val_one_hot 154 | 155 | if batch_wo_ignore_index.sum() <= 1: 156 | # we don't do mixup here, since there is only one single batch wo ignore index 157 | return x_video, labels_out, labels_subclips_out, labels_subclips_ignore_index 158 | 159 | mixup_lambda = self.mixup_beta_sampler.sample() 160 | 161 | # mix inputs 162 | x_out = { 163 | modk: _mix(x.clone(), batch_wo_ignore_index, mixup_lambda) 164 | for modk, x in x_video.items() 165 | } 166 | 167 | # mix labels 168 | labels_out = { 169 | key: _mix(val, batch_wo_ignore_index, mixup_lambda) 170 | for key, val in labels_out.items() 171 | } 172 | 173 | if labels_subclips is None: 174 | return x_video, labels_out, None, None 175 | 176 | # mix labels of past frames 177 | labels_subclips_out = { 178 | key: _mix(val, batch_wo_ignore_index, mixup_lambda) 179 | for key, val in labels_subclips_out.items() 180 | } 181 | 182 | return x_out, labels_out, labels_subclips_out, labels_subclips_ignore_index 183 | -------------------------------------------------------------------------------- /datasets/reader_fns.py: -------------------------------------------------------------------------------- 1 | """Implementation of reader functions, modified from AVT""" 2 | 3 | import logging 4 | from pathlib import Path 5 | from typing import Union, List 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | from omegaconf import OmegaConf 10 | import lmdb 11 | import numpy as np 12 | import pandas as pd 13 | import time 14 | 15 | from common.utils import get_video_info 16 | 17 | 18 | # An abstract class to keep track of all reader type classes 19 | class Reader(nn.Module): 20 | pass 21 | 22 | 23 | class DefaultReader(Reader): 24 | def forward(self, video_path, start, end, fps, df_row, **kwargs): 25 | del df_row, fps # Not needed here 26 | torchvision.set_video_backend('pyav') 27 | st = time.perf_counter() 28 | video_info = torchvision.io.read_video(video_path, start, end, **kwargs) 29 | timings = {"T GetItem.GetVideo.I/O.reader.pyav": time.perf_counter() - st} 30 | 31 | # DEBUG see what is breaking 32 | logging.debug('Read %s from %s', video_info[0].shape, video_path.split('/')[-1]) 33 | return (*video_info, timings) 34 | 35 | @staticmethod 36 | def get_frame_rate(video_path: Path) -> float: 37 | return get_video_info(video_path, ['fps'])['fps'] 38 | 39 | 40 | class EpicRULSTMFeatsReader(Reader): 41 | def __init__(self, 42 | lmdb_path: Union[Path, List[Path]] = None, 43 | warn_if_using_closeby_frame: bool = True): 44 | """ 45 | Args: 46 | feats_lmdb_path: LMDB path 
for RULSTM features. Must be 47 | specified if using rulstm_tsn_feat input_type. Could be a 48 | list, in which case it will concat all those features together. 49 | """ 50 | super().__init__() 51 | if OmegaConf.get_type(lmdb_path) != list: 52 | lmdb_path = [lmdb_path] 53 | self.lmdb_path = lmdb_path 54 | self.lmdb_envs = [lmdb.open(el, readonly=True, lock=False) for el in lmdb_path] 55 | self.warn_if_using_closeby_frame = warn_if_using_closeby_frame 56 | 57 | def forward(self, *args, **kwargs): 58 | return self._read_rulstm_features(*args, **kwargs) 59 | 60 | @staticmethod 61 | def get_frame_rate(video_path: Path) -> float: 62 | del video_path 63 | return 30.0 64 | 65 | def read_representations(self, frames, env, frame_format): 66 | """Reads a set of representations, given their frame names and an LMDB environment. 67 | From https://github.com/fpv-iplab/rulstm/blob/96e38666fad7feafebbeeae94952dba24771e512/RULSTM/dataset.py#L10 68 | """ 69 | features = [] 70 | # for each frame 71 | for frame_id in frames: 72 | # read the current frame 73 | with env.begin() as e: 74 | # Need to search for a frame that has features stored, the exact frame may not have. 75 | # To avoid looking at the future when training/testing, (important for anticipation), 76 | # look only for previous to current position. 77 | dd = None 78 | search_radius = 0 79 | for search_radius in range(10): 80 | dd = e.get(frame_format.format(frame_id - search_radius).strip().encode('utf-8')) 81 | if dd is not None: 82 | break 83 | if dd is not None and search_radius > 0: 84 | if self.warn_if_using_closeby_frame: 85 | logging.warning('Missing %s, but used %d instead', frame_format.format(frame_id), 86 | frame_id - search_radius) 87 | if dd is None: 88 | logging.error('Missing %s, Only specific frames are stored in lmdb :(', frame_format.format(frame_id)) 89 | features.append(None) 90 | else: 91 | # convert to numpy array 92 | data = np.frombuffer(dd, 'float32') 93 | # append to list 94 | features.append(data) 95 | # For any frames we didn't find a feature, use a series of 0s 96 | features_not_none = [el for el in features if el is not None] 97 | assert len(features_not_none) > 0, (f'No features found in {frame_format} - {frames}') 98 | feature_not_none = features_not_none[0] # any 99 | features = [np.zeros_like(feature_not_none) if el is None else el for el in features] 100 | # convert list to numpy array 101 | features = np.array(features) 102 | # Add singleton dimensions to make it look like a video, so 103 | # rest of the code just works 104 | features = features[:, np.newaxis, np.newaxis, :] 105 | # Make it torch Tensor to be consistent 106 | features = torch.as_tensor(features) 107 | return features 108 | 109 | def _read_rulstm_features(self, video_path: Path, start_sec: float, end_sec: float, fps: float, 110 | df_row: pd.DataFrame, pts_unit='sec'): 111 | del pts_unit # Not supported here 112 | # Read every single frame between the start and end, the base_video_dataset code will deal with how to sample into 4fps 113 | # (i.e. 0.25s steps), Rather than first computing the timestamps, just compute the 114 | # frame ID of the start and end, and do a arange .. 
that avoids any repeated frames due to quantization/floor 115 | time_stamps = None 116 | timings = {} 117 | start_frame = np.floor(start_sec * fps) 118 | end_frame = np.floor(end_sec * fps) 119 | frames = np.arange(end_frame, start_frame, -1).astype(int)[::-1] 120 | # If the frames go below 1, replace them with the lowest time pt 121 | assert frames.max() >= 1, (f'The dataset shouldnt have cases otherwise. {video_path} {start_sec} {end_sec} ' 122 | f'{df_row} {frames} {time_stamps}') 123 | frames[frames < 1] = frames[frames >= 1].min() 124 | # Get the features 125 | all_feats = [] 126 | for i, lmdb_env in enumerate(self.lmdb_envs): 127 | video_name = Path(video_path).stem 128 | lmdb_path = self.lmdb_path[i] 129 | st = time.perf_counter() 130 | if 'audio' in lmdb_path or 'poses' in lmdb_path: 131 | frames_new = self._convert_to_orig_video_fps(video_name, fps, frames) 132 | all_feats.append(self.read_representations(frames_new, lmdb_env, video_name + '_frame_{:010d}.jpg')) 133 | else: 134 | all_feats.append(self.read_representations(frames, lmdb_env, video_name + '_frame_{:010d}.jpg')) 135 | timings = {"T GetItem.GetVideo.I/O.reader.lmdb_read": time.perf_counter() - st} 136 | final_feat = torch.cat(all_feats, dim=-1) 137 | # Must return rgb, audio, info; so padding with empty dicts for those 138 | return final_feat, {}, {}, timings 139 | 140 | def _convert_to_orig_video_fps(self, video_name, fps, frames): 141 | """Convert the frames in rulstm fps to the frames in orig video fps. 142 | This is used for features (e.g. audio) which were extracted based on videos.""" 143 | orig_fps = self._get_orig_video_fps(video_name) 144 | frames_new = frames / fps * orig_fps 145 | frames_new = np.rint(frames_new).astype(int) 146 | return frames_new 147 | 148 | @staticmethod 149 | def _get_orig_video_fps(video_name): 150 | length = len(video_name.split('_')[-1]) 151 | if length == 3: # epic 100 152 | return 50.0 153 | elif length == 2: # epic 55 154 | return 59.94005994005994 155 | else: 156 | raise ValueError(f'Unkown video name format: {video_name}') 157 | -------------------------------------------------------------------------------- /common/utils.py: -------------------------------------------------------------------------------- 1 | """Implementation of helper functions, some of them are copied/modified from AVT/RULSTM""" 2 | 3 | from __future__ import print_function 4 | from typing import List, Dict 5 | 6 | import errno 7 | import os 8 | from pathlib import Path 9 | import logging 10 | import submitit 11 | import cv2 12 | import numpy as np 13 | 14 | import torch 15 | import torch.distributed as dist 16 | from torch import nn 17 | 18 | 19 | def topk_accuracy(scores, labels, ks, selected_class=None): 20 | """Computes TOP-K accuracies for different values of k 21 | Args: 22 | rankings: numpy ndarray, shape = (instance_count, label_count) 23 | labels: numpy ndarray, shape = (instance_count,) 24 | ks: tuple of integers 25 | 26 | Returns: 27 | list of float: TOP-K accuracy for each k in ks 28 | """ 29 | # from RULSTM 30 | if selected_class is not None: 31 | idx = labels == selected_class 32 | scores = scores[idx] 33 | labels = labels[idx] 34 | rankings = scores.argsort()[:, ::-1] 35 | # trim to max k to avoid extra computation 36 | maxk = np.max(ks) 37 | 38 | # compute true positives in the top-maxk predictions 39 | tp = rankings[:, :maxk] == labels.reshape(-1, 1) 40 | 41 | # trim to selected ks and compute accuracies 42 | return [tp[:, :k].max(1).mean() for k in ks] 43 | 44 | 45 | def topk_recall(scores, 
labels, k=5, classes=None): 46 | # From RULSTM 47 | unique = np.unique(labels) 48 | if classes is None: 49 | classes = unique 50 | else: 51 | classes = np.intersect1d(classes, unique) 52 | recalls = 0 53 | #np.zeros((scores.shape[0], scores.shape[1])) 54 | for c in classes: 55 | recalls += topk_accuracy(scores, labels, ks=(k,), selected_class=c)[0] 56 | return recalls/len(classes) 57 | 58 | 59 | def accuracy(output, target, topk=(1, )): 60 | """Computes the accuracy over the k top predictions 61 | for the specified values of k 62 | Args: 63 | output (*, K) predictions 64 | target (*, ) targets 65 | """ 66 | if torch.all(target < 0): 67 | return [ 68 | torch.zeros([], device=output.device) for _ in range(len(topk)) 69 | ] 70 | with torch.no_grad(): 71 | # flatten the initial dimensions, to deal with 3D+ input 72 | output = output.flatten(0, -2) 73 | target = target.flatten() 74 | # Now compute the accuracy 75 | maxk = max(topk) 76 | batch_size = target.size(0) 77 | 78 | _, pred = output.topk(maxk, 1, True, True) 79 | pred = pred.t() 80 | correct = pred.eq(target[None]) 81 | 82 | res = [] 83 | for k in topk: 84 | correct_k = correct[:k].flatten().sum(dtype=torch.float32) 85 | res.append(correct_k * (100.0 / batch_size)) 86 | return res 87 | 88 | 89 | def mkdir(path): 90 | try: 91 | os.makedirs(path) 92 | except OSError as e: 93 | if e.errno != errno.EEXIST: 94 | raise 95 | 96 | 97 | def setup_for_distributed(is_master, logger): 98 | """ 99 | This function disables printing when not in master process 100 | """ 101 | import builtins as __builtin__ 102 | builtin_print = __builtin__.print 103 | 104 | def print(*args, **kwargs): 105 | force = kwargs.pop('force', False) 106 | if is_master or force: 107 | builtin_print(*args, **kwargs) 108 | 109 | __builtin__.print = print 110 | if not is_master: 111 | # Don't print anything except FATAL 112 | logger.setLevel(logging.ERROR) 113 | logging.basicConfig(level=logging.ERROR) 114 | else: 115 | logger.setLevel(logging.INFO) 116 | logging.basicConfig(level=logging.INFO) 117 | 118 | 119 | def is_dist_avail_and_initialized(): 120 | if not dist.is_available(): 121 | return False 122 | if not dist.is_initialized(): 123 | return False 124 | return True 125 | 126 | 127 | def get_world_size(): 128 | if not is_dist_avail_and_initialized(): 129 | return 1 130 | return dist.get_world_size() 131 | 132 | 133 | def get_rank(): 134 | if not is_dist_avail_and_initialized(): 135 | return 0 136 | return dist.get_rank() 137 | 138 | 139 | def is_main_process(): 140 | return get_rank() == 0 141 | 142 | 143 | def save_on_master(*args, **kwargs): 144 | if is_main_process(): 145 | torch.save(*args, **kwargs) 146 | 147 | 148 | def init_distributed_mode(logger, dist_backend='nccl'): 149 | dist_info = dict( 150 | distributed=False, 151 | rank=0, 152 | world_size=1, 153 | gpu=0, 154 | dist_backend=dist_backend, 155 | dist_url=get_init_file(None).as_uri(), 156 | ) 157 | # If launched using submitit, get the job_env and set using those 158 | try: 159 | job_env = submitit.JobEnvironment() 160 | except RuntimeError: 161 | job_env = None 162 | if job_env is not None: 163 | dist_info['rank'] = job_env.global_rank 164 | dist_info['world_size'] = job_env.num_tasks 165 | dist_info['gpu'] = job_env.local_rank 166 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 167 | dist_info['rank'] = int(os.environ["RANK"]) 168 | dist_info['world_size'] = int(os.environ['WORLD_SIZE']) 169 | dist_info['gpu'] = int(os.environ['LOCAL_RANK']) 170 | elif 'SLURM_PROCID' in os.environ: 171 | 
dist_info['rank'] = int(os.environ['SLURM_PROCID']) 172 | dist_info['gpu'] = dist_info['rank'] % torch.cuda.device_count() 173 | elif 'rank' in dist_info: 174 | pass 175 | else: 176 | print('Not using distributed mode') 177 | dist_info['distributed'] = False 178 | return dist_info 179 | 180 | dist_info['distributed'] = True 181 | 182 | torch.cuda.set_device(dist_info['gpu']) 183 | dist_info['dist_backend'] = dist_backend 184 | print('| distributed init (rank {}): {}'.format(dist_info['rank'], 185 | dist_info['dist_url']), 186 | flush=True) 187 | torch.distributed.init_process_group(backend=dist_info['dist_backend'], 188 | init_method=dist_info['dist_url'], 189 | world_size=dist_info['world_size'], 190 | rank=dist_info['rank']) 191 | setup_for_distributed(dist_info['rank'] == 0, logger) 192 | return dist_info 193 | 194 | 195 | def get_shared_folder(name) -> Path: 196 | # Since using hydra, which figures the out folder 197 | return Path('./').absolute() 198 | 199 | 200 | def get_init_file(name): 201 | # Init file must not exist, but it's parent dir must exist. 202 | os.makedirs(str(get_shared_folder(name)), exist_ok=True) 203 | init_file = get_shared_folder(name) / 'sync_file_init' 204 | return init_file 205 | 206 | 207 | def gather_tensors_from_all(tensor: torch.Tensor) -> List[torch.Tensor]: 208 | """ 209 | Wrapper over torch.distributed.all_gather for performing 210 | 'gather' of 'tensor' over all processes in both distributed / 211 | non-distributed scenarios. 212 | """ 213 | if tensor.ndim == 0: 214 | # 0 dim tensors cannot be gathered. so unsqueeze 215 | tensor = tensor.unsqueeze(0) 216 | 217 | if is_dist_avail_and_initialized(): 218 | gathered_tensors = [ 219 | torch.zeros_like(tensor) 220 | for _ in range(torch.distributed.get_world_size()) 221 | ] 222 | torch.distributed.all_gather(gathered_tensors, tensor) 223 | else: 224 | gathered_tensors = [tensor] 225 | 226 | return gathered_tensors 227 | 228 | 229 | def gather_from_all(tensor: torch.Tensor) -> torch.Tensor: 230 | gathered_tensors = gather_tensors_from_all(tensor) 231 | gathered_tensor = torch.cat(gathered_tensors, 0) 232 | return gathered_tensor 233 | 234 | 235 | def get_video_info(video_path: Path, props: List[str]) -> Dict[str, float]: 236 | """ 237 | Given the video, return the properties asked for 238 | """ 239 | output = {} 240 | cam = cv2.VideoCapture(str(video_path)) 241 | if 'fps' in props: 242 | output['fps'] = cam.get(cv2.CAP_PROP_FPS) 243 | if 'len' in props: 244 | fps = cam.get(cv2.CAP_PROP_FPS) 245 | if fps <= 0: 246 | output['len'] = 0 247 | else: 248 | output['len'] = (cam.get(cv2.CAP_PROP_FRAME_COUNT) / fps) 249 | cam.release() 250 | return output 251 | 252 | 253 | def human_format(num): 254 | num = float('{:.3g}'.format(num)) 255 | magnitude = 0 256 | while abs(num) >= 1000: 257 | magnitude += 1 258 | num /= 1000.0 259 | return '{}{}'.format('{:f}'.format(num).rstrip('0').rstrip('.'), ['', 'K', 'M', 'B', 'T'][magnitude]) 260 | 261 | 262 | def question(question, dflt=True): 263 | while True: 264 | answer = input(f"{question} ({'[Y]/N' if dflt else 'Y/[N]'})") 265 | if any(answer.lower() == f for f in ["yes", 'y', '1', 'ye']): 266 | return True 267 | elif any(answer.lower() == f for f in ['no', 'n', '0']): 268 | return False 269 | elif answer == '': 270 | return dflt 271 | 272 | 273 | def has_batchnorms(model): 274 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) 275 | for name, module in model.named_modules(): 276 | if isinstance(module, bn_types): 277 | return True 278 | return 
False 279 | 280 | 281 | def has_key(obj, key): 282 | # recursively find key inside a dict 283 | if any([key in k for k in obj.keys()]): return True 284 | for k, v in obj.items(): 285 | if isinstance(v, dict): 286 | if has_key(v, key): return True 287 | return False 288 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /tmp.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from models.fusion import * 4 | from models.future_prediction import * 5 | from models.base_model import BaseModel 6 | from datasets.data import get_dataset 7 | from train import get_transform_val 8 | from challenge import * 9 | from models.action_recognition import * 10 | from models.backbone import * 11 | 12 | 13 | def ordered_feature_list(x_d: Dict[str, Tensor], feats_order: List) -> List[Tensor]: 14 | """Converts multimodal feature dictionary to a list according to the given order 15 | used for cmfuser""" 16 | tensor_list = [] 17 | for i, modk in enumerate(feats_order): 18 | tensor_list.append(x_d[modk]) 19 | return tensor_list 20 | 21 | 22 | @hydra.main(config_path='conf', config_name='config') 23 | def debug_model(cfg: DictConfig): 24 | ckpt_root_dir = '/home/zhong/Documents/projects/AVAA/checkpoints/' 25 | ckpt_path = [ckpt_root_dir + 'IndividualFuturePrediction_CMFuser_rgb/checkpoint_best.pth', 26 | ckpt_root_dir + 'IndividualFuturePrediction_CMFuser_objects/checkpoint_best.pth'] 27 | logger = logging.getLogger(__name__) 28 | 29 | # model configs 30 | model_cfg = cfg.model 31 | model_cfg.modal_dims = {"rgb": 1024, "objects": 352} 32 | model_cfg.common.share_classifiers = False 33 | model_cfg.common.share_predictors = False 34 | model_cfg.common.modality_cls = True 35 | model_cfg.common.fusion_cls = True 36 | model_cfg.CMFP = {'_target_': 'models.future_prediction.CMFPLate', 'model_cfg': 'null'} 37 | 38 | num_classes = {'action': 3806} 39 | model = BaseModel(cfg.model, num_classes=num_classes, class_mappings={}) 40 | 41 | named_params = list(model.named_parameters()) 42 | named_buffs = list(model.named_buffers()) 43 | model_state = list(model.state_dict()) 44 | 45 | modules_to_keep = ['future_predictor.future_predictor', 'future_predictor.dim_encoder', 46 | 'future_predictor.dim_decoder'] 47 | 48 | params_require_grad = [p for p in model.parameters() if p.requires_grad] 49 | print(1) 50 | 51 | 52 | @hydra.main(config_path="conf", config_name="config") 53 | def debug_cmfp(cfg: DictConfig): 54 | cmfp_name = 'early' 55 | input_len = 10 # 10 seconds if fps = 1 56 | bs = 64 57 | 58 | feats = { 59 | 'rgb': torch.randn((bs, input_len, 768)).cuda(), 60 | # 'objects': torch.randn((bs, input_len, 352)).cuda() 61 | } 62 | 63 | model_cfg = cfg.model 64 | model_cfg.modal_dims = {"rgb": 768, "objects": 352} 65 | model_cfg.common.fp_inter_dim = 768 66 | model_cfg.common_dim = 768 67 | model_cfg.common.fp_layers = 4 68 | 69 | model_cfg.common.share_classifiers = False # may to be changed 70 | model_cfg.common.share_predictors = False # may to be changed 71 | model_cfg.common.map_features = False # may to be changed 72 | model_cfg.common.modality_cls = True # may to be changed 73 | model_cfg.common.fusion_cls = False 74 | 75 | num_classes = {'action': 3806} 76 | 77 | # from models.action_recognition import CMRecognitionEarly 78 | # model = CMRecognitionEarly(model_cfg, num_classes) 79 | # model = CMFPEarly(model_cfg, num_classes, extra_cls_rgb=True) 80 | model = IndividualRecognition(model_cfg, num_classes) 81 | 82 | model.to('cuda') 83 | out = model(feats) 84 | 85 | print(1) 86 | 87 | 88 | def debug_video_reading(model, dataset, device, logger): 89 | model.eval() 90 | dur_data = 0 91 | dur_infer = 0 92 | length = 100 93 | for idx in range(length): 94 | start_time = time.time() 95 | with 
torch.no_grad(): 96 | data = dataset[idx] 97 | time1 = time.time() 98 | dur1 = time1 - start_time 99 | logger.info(f'fetch data takes {dur1}s') 100 | video = data['video'].to(device) 101 | outputs = model(video) 102 | dur2 = time.time() - time1 103 | logger.info(f'inference takes {dur2}s') 104 | dur_data += dur1 105 | dur_infer += dur2 106 | logger.info(f'averaged fetch data duration per sample: {dur_data / length}s') 107 | logger.info(f'averaged inference duration per sample: {dur_infer / length}s') 108 | 109 | 110 | def debug_fuser(): 111 | feats_order = ["rgb", "objects"] 112 | order_feature_func = partial(ordered_feature_list, feats_order=feats_order) 113 | modal_dims = {'rgb': 1024, 'objects': 1024} 114 | 115 | # fuser = ModalTokenCMFuser(dim=1024, frame_level_token=True, temporal_sequence_length=10, modalities=modal_dims) 116 | # fuser = TemporalCrossAttentFuser(dim=1024, num_modals=2).cuda() 117 | fuser = TemporalCMFuser(dim=1024, modalities=modal_dims, frame_level_token=False, temporal_sequence_length=10).cuda() 118 | feats = {'rgb': torch.randn((64, 10, 1024)).cuda(), 119 | 'objects': torch.randn((64, 10, 1024)).cuda()} 120 | weights = fuser(feats, order_feature_func) 121 | print(1) 122 | 123 | 124 | @hydra.main(config_path="conf", config_name="config") 125 | def debug_recognition(cfg: DictConfig): 126 | model_cfg = cfg.model 127 | model_cfg.modal_dims = {"rgb": 768} 128 | num_classes = {'action': 3806} 129 | 130 | feats = {'rgb': torch.randn((64, 3, 768))} 131 | 132 | model = IndividualRecognition(model_cfg, num_classes) 133 | y = model(feats) 134 | print(1) 135 | 136 | @hydra.main(config_path="conf", config_name="config") 137 | def debug_dataset(cfg: DictConfig): 138 | logger = logging.getLogger(__name__) 139 | cfg.dataset.epic_kitchens100.common.sample_strategy = 'random_clip' 140 | cfg.dataset.epic_kitchens100.common.reader_fn = {'_target_': 'datasets.reader_fns.EpicRULSTMFeatsReader', 141 | 'lmdb_path': ['${dataset.epic_kitchens100.common.rulstm_feats_dir}/rgb/']} 142 | transform_val = get_transform_val(cfg) 143 | dataset_test = get_dataset(getattr(cfg, 'dataset_eval'), cfg.data_eval, transform_val, logger) 144 | for i in tqdm(range(9638)): 145 | data = dataset_test[i] 146 | 147 | print(1) 148 | 149 | 150 | def contains_list(test_list): 151 | for element in test_list: 152 | if isinstance(element, list): 153 | return True 154 | return False 155 | 156 | 157 | def debug(): 158 | w_tsn_10s = np.arange(0, 1, 0.25) 159 | w_tsn_14s = np.arange(0, 1, 0.25) 160 | w_tsn_18s = np.arange(0, 1, 0.25) 161 | w_swin_4h_8s = np.arange(0.25, 1, 0.25) 162 | w_swin_4h_14s = np.arange(0.25, 1, 0.25) 163 | w_swin_4h_16s = np.arange(0.75, 1.25, 0.25) # important 164 | w_swin_4h_18s = np.arange(0.25, 1, 0.25) 165 | w_swin_8h_10s = np.arange(0.5, 1, 0.25) # important 166 | w_swin_8h_14s = np.arange(0.5, 1, 0.25) # important 167 | w_swin_8h_16s = np.arange(0.25, 1, 0.25) 168 | w_swin_8h_18s = np.arange(0.25, 1, 0.25) 169 | 170 | # weights = [w for i in range(len(ex))] 171 | weights = [w_tsn_10s, w_tsn_14s, w_tsn_18s, w_swin_4h_8s, w_swin_4h_14s, w_swin_4h_16s, 172 | w_swin_4h_18s, w_swin_8h_10s, w_swin_8h_14s, w_swin_8h_16s, w_swin_8h_18s] 173 | # weights = [w_tsn_10s, w_tsn_14s, w_tsn_18s] 174 | weights_combinations = list(itertools.product(*weights)) 175 | 176 | 177 | def debug_crossentropy_with_ignore_index(): 178 | func = nn.CrossEntropyLoss(ignore_index=-1) 179 | num_class = 5 180 | target = torch.tensor([-1, 1, 3]) 181 | #target1 = convert_to_one_hot(target, num_class, label_smooth=0.0) 182 | logits
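# A hedged sketch of what the weight grids in debug() above feed into: every
# combination from itertools.product is one candidate setting for late (score)
# fusion of several models' logits. The weighted-sum fusion and the placeholder
# scoring below are illustrative assumptions; the repo's actual fusion and
# evaluation code is not shown in this file.
import itertools
import numpy as np

def _weight_grid_search_sketch():
    logits_per_model = [np.random.randn(4, 3806) for _ in range(2)]  # 2 models, batch of 4
    weight_grids = [np.arange(0, 1, 0.25), np.arange(0.25, 1, 0.25)]
    best_combo, best_score = None, -np.inf
    for combo in itertools.product(*weight_grids):
        fused = sum(w * lg for w, lg in zip(combo, logits_per_model))
        score = float(fused.max())  # placeholder; a real sweep would score val accuracy / recall
        if score > best_score:
            best_combo, best_score = combo, score
    return best_combo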
= torch.randn(3, num_class, requires_grad=True) 183 | target1 = torch.tensor([1, 3]) 184 | logits1 = logits[1:] 185 | 186 | loss = func(logits, target) 187 | loss1 = func(logits1, target1) 188 | print(loss) 189 | print(loss1) 190 | 191 | 192 | def debug_mixup_simple(): 193 | B, num_classes = 5, {'action': 6} 194 | feature_dict = { 195 | 'rgb': torch.randn((B, 5, 3)), 196 | 'objects': torch.randn((B, 5, 3)) 197 | } 198 | target = {'action': torch.tensor([0, 0, 0, 1, 1])} 199 | target_subclips = {'action': torch.tensor([[1, 1, 0, -1, 1], 200 | [0, 1, 1, 0, 0], 201 | [1, -1, 0, 1, 1], 202 | [1, 0, 0, 1, 1], 203 | [1, -1, 1, 1, 0]])} 204 | from common.mixup import MixUp 205 | op = MixUp(label_smoothing=0.1, num_classes=num_classes) 206 | x_out, labels_out, labels_subclips_out, labels_subclips_ignore_index = op(feature_dict, target, target_subclips) 207 | 208 | past_logits = torch.randn((B, 5, 6)) 209 | logits = torch.randn((B, 6)) 210 | 211 | from common.runner import MultiDimCrossEntropy 212 | loss_func = MultiDimCrossEntropy() 213 | 214 | loss = loss_func(logits, target['action']) 215 | past_loss = loss_func(past_logits, labels_subclips_out['action'], one_hot=True, ignore_index=labels_subclips_ignore_index['action']) 216 | 217 | labels = labels_out['action'] 218 | _top_max_k_vals, top_max_k_inds = torch.topk( 219 | labels, 2, dim=1, largest=True, sorted=True 220 | ) 221 | idx_top1 = torch.arange(labels.shape[0]), top_max_k_inds[:, 0] 222 | idx_top2 = torch.arange(labels.shape[0]), top_max_k_inds[:, 1] 223 | preds = logits.detach() 224 | preds[idx_top1] += preds[idx_top2] 225 | preds[idx_top2] = 0.0 226 | labels = top_max_k_inds[:, 0] 227 | 228 | print(1) 229 | 230 | 231 | def debug_mixup(): 232 | B = 8 233 | num_classes = 5 234 | feature_dict = { 235 | 'rgb': torch.randn((B, 10, 3, 1, 224, 224)), 236 | 'objects': torch.randn((B, 10, 352, 1, 1, 1)) 237 | } 238 | target = {'action': torch.randint(0, num_classes, (B,))} 239 | target_subclip = {'action': torch.randint(0, num_classes, (B, 10, 1))} 240 | 241 | from common.mixup import MixUp 242 | 243 | op = MixUp(label_smoothing=0.1, num_classes=num_classes) 244 | a, b, c = op(feature_dict, target, target_subclip) 245 | print(1) 246 | 247 | 248 | def debug_backbone(): 249 | # model = MViTModel() 250 | # ckpt = torch.load('checkpoints/TIMM/MViTv2_S_in1k.pyth') 251 | # missing_keys, unexp_keys = model.model.load_state_dict(ckpt['model_state'], strict=False) 252 | model = TIMMModel(model_type='beit_base_patch16_224_in22k') 253 | ckpt = torch.load('checkpoints/TIMM/beit_base_patch16_224_pt22k_ft22k.pth') 254 | missing_keys, unexp_keys = model.model.load_state_dict(ckpt['model'], strict=False) 255 | print(1) 256 | 257 | 258 | @hydra.main(config_path="conf", config_name="config") 259 | def debug_future_embed_prediction(cfg): 260 | model_cfg = cfg.model 261 | model_cfg.modal_dims = {"rgb": 768} 262 | model_cfg.common.fp_inter_dim = 768 263 | model_cfg.common_dim = 768 264 | 265 | num_classes = {'action': 3806} 266 | 267 | from models.future_embed_prediction import FutureEmbedPrediction 268 | model = FutureEmbedPrediction(model_cfg, num_classes, dim=2048) 269 | model.to('cuda') 270 | 271 | input_len = 10 # 10 seconds if fps = 1 272 | bs = 64 273 | 274 | feats = { 275 | 'rgb': torch.randn((bs, input_len, 768)).cuda(), 276 | # 'objects': torch.randn((bs, input_len, 352)).cuda() 277 | } 278 | 279 | out = model(feats) 280 | print(1) 281 | 282 | 283 | def tmp(): 284 | input_channel = 1 285 | output_channel = 128 286 | 287 | conv_layer = 
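# A self-contained sketch of the point the tmp() function around this spot is set up
# to check: a 1x1 Conv2d over the channel dimension and a Linear layer compute the
# same affine map once their parameters are shared; only the memory layout (NCHW vs.
# channels-last) differs. Sizes and the tolerance here are illustrative assumptions.
import torch
import torch.nn as nn

def _conv1x1_vs_linear_sketch():
    cin, cout = 1, 128
    conv = nn.Conv2d(cin, cout, kernel_size=1)
    lin = nn.Linear(cin, cout)
    with torch.no_grad():                      # share parameters between the two layers
        lin.weight.copy_(conv.weight.view(cout, cin))
        lin.bias.copy_(conv.bias)
    x = torch.randn(1, cin, 7, 7)
    out_conv = conv(x)                         # (1, cout, 7, 7)
    out_lin = lin(x.permute(0, 2, 3, 1))       # (1, 7, 7, cout)
    assert torch.allclose(out_conv, out_lin.permute(0, 3, 1, 2), atol=1e-5)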
nn.Conv2d(input_channel, output_channel, kernel_size=(1, 1)) 288 | linear_layer = nn.Linear(input_channel, output_channel) 289 | 290 | input_linear = torch.randn(1, 7, 7, input_channel) 291 | input_conv = torch.randn(1, input_channel, 7, 7) 292 | 293 | output_conv = conv_layer(input_conv) 294 | output_linear = linear_layer(input_linear) 295 | 296 | print(1) 297 | 298 | 299 | if __name__ == '__main__': 300 | tmp() 301 | # debug() 302 | # debug_mixup_simple() 303 | # debug_crossentropy_with_ignore_index() 304 | # debug_mixup() 305 | # debug_recognition() 306 | # debug_dataset() 307 | # debug_cmfp() 308 | # debug_model() 309 | # debug_fuser() 310 | # debug_backbone() 311 | # debug_future_embed_prediction() 312 | # causal_mask = generate_square_subsequent_mask(5) 313 | # causal_modality_mask = causal_mask.repeat(2, 2) 314 | # print(causal_modality_mask) 315 | -------------------------------------------------------------------------------- /common/runner.py: -------------------------------------------------------------------------------- 1 | """Implementation of a training iteration""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from typing import Dict, Tuple, Union, Callable, Optional 6 | 7 | from common import utils 8 | 9 | CLS_MAP_PREFIX = 'cls_map_' 10 | PAST_LOGITS_PREFIX = 'past_' 11 | 12 | 13 | class MultiDimCrossEntropy(nn.CrossEntropyLoss): 14 | """Will reshape the flatten initial dimensions and then incur loss""" 15 | 16 | def forward(self, inp, tgt, 17 | one_hot: bool = False, 18 | ignore_index: Union[torch.Tensor, None] = None): 19 | """ 20 | Args: 21 | inp: (*, C) 22 | tgt: (*, ) 23 | one_hot: whether the labels are already one-hotted 24 | ignore_index: index of inputs to be ignored 25 | """ 26 | inp = inp.reshape(-1, inp.size(-1)) 27 | tgt = tgt.reshape(-1,) if not one_hot else tgt.reshape(-1, tgt.size(-1)) 28 | 29 | if ignore_index is not None: 30 | assert one_hot, "Target should be one-hotted." 31 | ignore_index = ignore_index.reshape(-1,) 32 | keep_index = ~ignore_index 33 | inp = inp[keep_index] 34 | tgt = tgt[keep_index] 35 | 36 | res = super().forward(inp, tgt) 37 | return res 38 | 39 | 40 | class BasicLossAccuracy(nn.Module): 41 | """Computes acc1, acc5 and the three loss values 42 | 1. loss for future action prediction, 43 | 2. loss for past action prediction and 44 | 3. loss for past feature regression. 
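# A minimal sketch of the flattening that MultiDimCrossEntropy (defined above) relies
# on: dense (B, T, C) logits and (B, T) integer targets are reshaped to (B*T, C) and
# (B*T,) so the stock CrossEntropyLoss treats every frame as one sample. Shapes are
# illustrative assumptions.
import torch
import torch.nn as nn

def _flattened_ce_sketch():
    B, T, C = 2, 5, 7
    logits = torch.randn(B, T, C)
    target = torch.randint(0, C, (B, T))
    ce = nn.CrossEntropyLoss()
    flat = ce(logits.reshape(-1, C), target.reshape(-1))
    # equivalent per-frame formulation, averaged by hand
    per_frame = torch.stack([ce(logits[b, t:t + 1], target[b, t:t + 1])
                             for b in range(B) for t in range(T)]).mean()
    assert torch.allclose(flat, per_frame, atol=1e-6)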
45 | """ 46 | 47 | def __init__(self): 48 | super().__init__() 49 | kwargs = {'ignore_index': -1} 50 | kwargs['reduction'] = 'none' # to get batch level output 51 | self.cls_criterion = MultiDimCrossEntropy(**kwargs) 52 | self.reg_criterion = torch.nn.MSELoss() 53 | 54 | def forward_future_action(self, logits, tgt_val, mixup_enable, losses, metrics, 55 | acc1_key, acc5_key, mt5r_key, loss_key, key_suffix=''): 56 | """Computes accuracy, mt5r and loss value of future action""" 57 | loss_future_action = self.cls_criterion(logits, tgt_val, one_hot=mixup_enable) 58 | 59 | sequence_index = 0 60 | if mixup_enable: 61 | # we add up the top1 and top2 predictions and labels 62 | _top_max_k_vals, top_max_k_inds = torch.topk( 63 | tgt_val, 2, dim=1, largest=True, sorted=True 64 | ) 65 | idx_top1 = torch.arange(tgt_val.shape[0]), \ 66 | torch.tensor(sequence_index).repeat(tgt_val.shape[0]), \ 67 | top_max_k_inds[:, 0] 68 | idx_top2 = torch.arange(tgt_val.shape[0]), \ 69 | torch.tensor(sequence_index).repeat(tgt_val.shape[0]), \ 70 | top_max_k_inds[:, 1] 71 | preds = logits.detach().clone() 72 | preds[idx_top1] += preds[idx_top2] 73 | preds[idx_top2] = 0.0 74 | labels = top_max_k_inds[:, 0] 75 | else: 76 | preds = logits.detach() 77 | labels = tgt_val.clone() 78 | 79 | if len(labels.shape) == 1: # single frame prediction 80 | labels = labels.unsqueeze(dim=-1) 81 | 82 | metrics[mt5r_key + key_suffix] = { 83 | 'logits': preds[:, sequence_index, :].cpu().numpy(), 84 | 'labels': labels[:, sequence_index].cpu().numpy() 85 | } 86 | 87 | dataset_max_classes = preds.size(-1) 88 | acc1, acc5 = utils.accuracy(preds, labels, topk=(1, min(5, dataset_max_classes))) 89 | 90 | losses[loss_key + key_suffix] = loss_future_action 91 | metrics[acc1_key + key_suffix] = acc1 92 | metrics[acc5_key + key_suffix] = acc5 93 | 94 | def forward_past_action(self, past_logits, past_target, mixup_enable, 95 | losses, loss_key, past_target_ignore_index=None, key_suffix=''): 96 | frames_to_keep = None 97 | if mixup_enable: 98 | assert past_logits.shape == past_target.shape 99 | assert past_target_ignore_index is not None 100 | if frames_to_keep is not None: 101 | past_target_ignore_index = past_target_ignore_index[:, frames_to_keep] 102 | loss_past_action = self.cls_criterion( 103 | past_logits, past_target, one_hot=mixup_enable, ignore_index=past_target_ignore_index 104 | ) 105 | else: 106 | past_target = past_target.squeeze(-1) # this assumes the last dimension is 1 107 | assert past_logits.shape[:-1] == past_target.shape 108 | loss_past_action = self.cls_criterion(past_logits, past_target) 109 | 110 | losses[loss_key + key_suffix] = loss_past_action 111 | 112 | def forward(self, outputs, target, target_subclips, 113 | mixup_enable: bool = False, target_subclips_ignore_index: Union[Dict, None] = None): 114 | """ 115 | Args: 116 | outputs['logits'] torch.Tensor (B, num_classes) or 117 | (B, T, num_classes) 118 | Latter in case of dense prediction 119 | target: {type: (B) or (B, T')}; latter in case of dense prediction 120 | target_subclips: {type: (B, #clips, T)}: The target for each input frame 121 | mixup_enable (bool): whether the targets are already one-hotted 122 | target_subclips_ignore_index: index of inputs to be ignored 123 | """ 124 | losses = {} 125 | metrics = {} 126 | for tgt_type, tgt_val in target.items(): 127 | # --------FUTURE ACTION PREDICTION------- 128 | for modk in outputs[f'logits/{tgt_type}']: 129 | logits = outputs[f'logits/{tgt_type}'][modk] 130 | 131 | # metric keys 132 | acc1_key, acc5_key = 
f'acc1_{tgt_type}_{modk}', f'acc5_{tgt_type}_{modk}' 133 | mt5r_key = f'mt5r_{tgt_type}_{modk}' 134 | loss_key = f'cls_{tgt_type}_{modk}' 135 | 136 | assert len(logits.shape) == 3 # Includes temporal dimension (B, T, C), even if T == 1 137 | self.forward_future_action( 138 | logits, tgt_val, mixup_enable, losses, metrics, 139 | acc1_key, acc5_key, mt5r_key, loss_key 140 | ) 141 | 142 | # --------PAST ACTION PREDICTION------- 143 | past_logits_key = f'{PAST_LOGITS_PREFIX}logits/{tgt_type}' 144 | if past_logits_key in outputs and target_subclips is not None: 145 | for modk in outputs[past_logits_key]: 146 | past_logits = outputs[past_logits_key][modk] 147 | loss_key = f'past_cls_{tgt_type}_{modk}' 148 | past_target_ignore_index = None if target_subclips_ignore_index is None \ 149 | else target_subclips_ignore_index[tgt_type] 150 | 151 | self.forward_past_action( 152 | past_logits, target_subclips[tgt_type], mixup_enable, 153 | losses, loss_key, past_target_ignore_index 154 | ) 155 | 156 | # --------PAST FEATURE REGRESSION--------- 157 | if 'orig_past' in outputs and 'past_futures' in outputs: 158 | orig_past_features = outputs['orig_past'] 159 | updated_past_features = outputs['past_futures'] 160 | 161 | for modk, updated_past_feature in updated_past_features.items(): 162 | if modk not in orig_past_features: continue 163 | loss_key = f'past_reg_{modk}' 164 | losses[loss_key] = self.reg_criterion( 165 | updated_past_features[modk][:, 1:], orig_past_features[modk][:, 1:] 166 | ) 167 | 168 | return losses, metrics 169 | 170 | 171 | def get_loss_wts(loss_wts: Dict, key: str) -> float: 172 | for k, v in loss_wts.items(): 173 | if key.startswith(k): 174 | return v 175 | raise ValueError(f'{key} not contained in predefined loss_wts: {loss_wts}') 176 | 177 | 178 | class Runner: 179 | """wrapper class of BasicLossAccuracy, runs on each batch, returns all metrics""" 180 | 181 | def __init__(self, model, device, loss_wts): 182 | super().__init__() 183 | self.model = model 184 | self.device = device 185 | self.loss_acc_fn = BasicLossAccuracy() 186 | self.loss_wts = loss_wts 187 | 188 | def _basic_preproc(self, data): 189 | if not isinstance(data, dict): 190 | video, target = data 191 | # Make a dict so that later code can use it 192 | data = {} 193 | data['video'] = video 194 | data['target'] = target 195 | data['idx'] = -torch.ones_like(target) 196 | return data 197 | 198 | @staticmethod 199 | def _reduce_loss(losses, loss_wts): 200 | # reduce the losses 201 | losses = {key: torch.mean(val) for key, val in losses.items()} 202 | # weight the losses 203 | losses_wtd = [] 204 | for key, val in losses.items(): 205 | this_loss_wt = get_loss_wts(loss_wts, key) 206 | if this_loss_wt > 0: 207 | losses_wtd.append(this_loss_wt * val) 208 | loss = torch.sum(torch.stack(losses_wtd)) 209 | if torch.isnan(loss): 210 | raise ValueError('The loss is NaN!') 211 | losses_metric = {k: v.item() for k, v in losses.items()} # prevents increasing gpu memory 212 | losses_metric['total_loss'] = loss.item() 213 | return loss, losses_metric 214 | 215 | def __call__(self, 216 | data: Union[Dict[str, torch.Tensor], # If dict 217 | Tuple[torch.Tensor, torch.Tensor]], 218 | mixup_fn: Optional[Callable] = None, 219 | mixup_backbone: Optional[bool] = True): 220 | """ 221 | Args: 222 | data (dict): Dictionary of all the data from the data loader 223 | mixup_fn: Mixup function 224 | mixup_backbone: Whether to mixup the inputs or the backbone outputs 225 | """ 226 | data, timings = data # Getting timings from dataloader 227 | data = 
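# A small sketch of how the pieces above compose: each loss key is matched by prefix
# against the configured weights (get_loss_wts), zero-weighted losses are dropped, and
# the rest are summed into the single scalar that is backpropagated (_reduce_loss).
# The keys and weights below are illustrative assumptions, not the repo's defaults.
import torch

def _weighted_loss_sketch():
    losses = {'cls_action_fusion': torch.tensor(1.2),
              'past_cls_action_rgb': torch.tensor(0.8),
              'past_reg_rgb': torch.tensor(0.1)}
    loss_wts = {'cls': 1.0, 'past_cls': 0.5, 'past_reg': 0.0}
    def get_wt(key):
        for prefix, wt in loss_wts.items():
            if key.startswith(prefix):
                return wt
        raise ValueError(f'{key} not covered by loss_wts')
    total = torch.sum(torch.stack([get_wt(k) * v for k, v in losses.items() if get_wt(k) > 0]))
    return total  # 1.2 * 1.0 + 0.8 * 0.5 = 1.6 for the values above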
self._basic_preproc(data) 228 | feature_dict = {mod: tens.to(self.device, non_blocking=True) for mod, tens in data["data_dict"].items()} 229 | target = {} 230 | target_subclips = {} 231 | for key in data['target'].keys(): 232 | target[key] = data['target'][key].to(self.device, non_blocking=True) 233 | if 'target_subclips' in data: 234 | for key in data['target_subclips'].keys(): 235 | target_subclips[key] = data['target_subclips'][key].to(self.device, non_blocking=True) 236 | else: 237 | target_subclips = None 238 | 239 | kwargs = {} 240 | kwargs['mixup_fn'] = None 241 | kwargs['target'] = target 242 | kwargs['target_subclips'] = target_subclips 243 | kwargs['target_subclips_ignore_index'] = None 244 | 245 | # the target will be one-hotted after mixup 246 | if mixup_fn is not None: 247 | # mixup the inputs 248 | if not mixup_backbone: 249 | feature_dict, target, target_subclips, target_subclips_ignore_index = \ 250 | mixup_fn(feature_dict, target, target_subclips) 251 | kwargs['target'] = target 252 | kwargs['target_subclips'] = target_subclips 253 | kwargs['target_subclips_ignore_index'] = target_subclips_ignore_index 254 | # mixup the backbone outputs 255 | else: 256 | kwargs['mixup_fn'] = mixup_fn 257 | 258 | outputs, outputs_targets = self.model(feature_dict, **kwargs) 259 | 260 | losses, metrics = self.loss_acc_fn( 261 | outputs, 262 | outputs_targets['target'], 263 | outputs_targets['target_subclips'], 264 | mixup_enable=(mixup_fn is not None), 265 | target_subclips_ignore_index=outputs_targets['target_subclips_ignore_index'], 266 | ) 267 | loss, losses_metric = self._reduce_loss(losses, self.loss_wts) 268 | metrics.update(losses_metric) 269 | metrics.update(timings) 270 | return loss, metrics 271 | -------------------------------------------------------------------------------- /annotations/ek100_ori/EPIC_100_verb_classes.csv: -------------------------------------------------------------------------------- 1 | id,key,instances,category 2 | 0,take,"['collect-from', 'collect-into', 'draw', 'fetch', 'get', 'get-from', 'get-in', 'get-off', 'get-with', 'grab', 'grab-down', 'grab-from', 'grab-up', 'pick', 'pick-from', 'pick-off', 'pick-up', 'pull-off', 'retrieve', 'take', 'take-from', 'take-in', 'take-of', 'take-to', 'take-up', 'unload', 'catch', 'take-into', 'remove-on', 'get-on', 'take-with', 'take-inside', 'take-on', 'take-down']",retrieve 3 | 1,put,"['create', 'dose', 'lay', 'lay-down', 'lay-on', 'lay-onto', 'lay-out', 'layer', 'leave', 'leave-in', 'leave-with', 'lie', 'pace', 'place', 'place-away', 'place-back', 'place-down', 'place-on', 'place-onto', 'place-that', 'pose', 'position', 'pout', 'put', 'put-around', 'put-aside', 'put-away', 'put-back', 'put-down', 'put-from', 'put-of', 'put-on', 'put-onto', 'put-over', 'put-to', 'put-under', 'put-up', 'put-with', 'replace', 'reposition', 'rest', 'rest-against', 'rest-on', 'rest-onto', 'reture', 'return', 'return-from', 'return-into', 'set-down', 'tuck-under', 'pet-down', 'save', 'store', 'tip', 'put-unto', 'mushroom-onto', 'put-for', 'place-over', 'place-up', 'place-under', 'get-into', 'combine-from', 'put-outside', 'store-in', 'put-between', 'pick-into', 'layer-on', 'place-with']",leave 4 | 2,wash,"['clean', 'clean-around', 'clean-from', 'clean-off', 'clean-up', 'clean-with', 'hoover', 'lather', 'mop-up', 'rinse', 'rinse-from', 'rinse-in', 'rinse-off', 'rinse-out', 'rinse-under', 'rise', 'soap', 'soap-up', 'sponge', 'wash', 'wash-around', 'wash-by', 'wash-off', 'wash-out', 'wash-under', 'wash-up', 'wash-with', 'wipe', 'wipe-around', 
'wipe-by', 'wipe-down', 'wipe-from', 'wipe-in', 'wipe-into', 'wipe-of', 'wipe-off', 'wipe-on', 'wipe-onto', 'wipe-over', 'wipe-up', 'wipe-with', 'cool', 'wash-in', 'wash-on', 'rinse-into', 'rinse-on']",clean 5 | 3,open,"['lever-open', 'open', 'open-in', 'open-on', 'open-up', 'open-with', 'uncork', 'unzip', 'open-to']",access 6 | 4,close,"['close', 'close-off', 'close-with', 'screw-on', 'screw-onto', 'shut', 'twist-on', 'close-in', 'close-on']",block 7 | 5,insert,"['drop-into', 'fit', 'fit-inside', 'insert', 'insert-into', 'move-in', 'place-in', 'place-into', 'pop-into', 'push-in', 'push-into', 'put-in', 'put-inside', 'put-into', 'slide-into', 'slide-onto', 'load', 'pack', 'place-inside', 'sheath', 'sheathe', 'insert-in']",leave 8 | 6,turn-on,"['activate', 'begin', 'ignite', 'light', 'play', 'restart', 'start', 'switch-on', 'turn-on', 'water-on', 'tun-on']",access 9 | 7,cut,"['chop', 'chop-in', 'chop-off', 'chop-up', 'chop-with', 'cut', 'cut-in', 'cut-into', 'cut-off', 'cut-on', 'cut-out', 'cut-with', 'dice', 'half', 'halve', 'julienne', 'score', 'slice', 'slice-along', 'slice-at', 'slice-in', 'slice-into','slice-of', 'slice-off', 'slice-onto', 'slice-out', 'slice-through', 'slice-up', 'slice-with', 'snip', 'trim', 'make', 'stick', 'start-to', 'top', 'slice-to', 'chop-into', 'cut-from', 'slice-from', 'cut-by', 'make-into', 'trim-into']",split 10 | 8,turn-off,"['shut-off', 'switch-of', 'switch-off', 'switch-out', 'tap-off', 'turn-of', 'turn-off', 'water-off', 'flush']",block 11 | 9,pour,"['drizzle', 'drizzle-into', 'drizzle-on', 'pour', 'pour-down', 'pour-from', 'pour-in', 'pour-into', 'pour-on', 'pour-onto', 'pour-out', 'pour-over', 'shake-into', 'sieve', 'tip-from', 'tip-in', 'tip-into', 'pour-inside', 'sieve-with']",merge 12 | 10,mix,"['beat', 'blend', 'fold-in', 'mix', 'mix-around', 'mix-in', 'mix-so', 'mix-together', 'mix-with', 'scramble', 'stir', 'stir-in', 'stir-into', 'stir-with', 'whisk', 'whisk-in', 'mix-into', 'stir-on']",merge 13 | 11,move,"['back', 'drag-around', 'move', 'move-along', 'move-around', 'move-from', 'move-into', 'move-off', 'move-on', 'move-onto', 'move-through', 'move-to', 'move-with', 'move-within', 'pass-to', 'return-to', 'transfer', 'transfer-from', 'transfer-into', 'transfer-to', 'flick', 'flick-up', 'swap', 'distribute-in', 'make-in', 'change-for']",transition 14 | 12,remove,"['dispense', 'extract', 'get-out', 'pick-out', 'pop-out', 'remove', 'remove-from', 'remove-inside', 'remove-into', 'remove-out', 'remove-with', 'take-off', 'take-out', 'unclip', 'unsheathe', 'clear', 'unplug', 'desee', 'remove-off', 'extract-from', 'fish-out']",retrieve 15 | 13,throw,"['bin', 'chuck-into', 'discard', 'dispose-of', 'recycle', 'rubbish-in', 'throw', 'throw-away', 'throw-down', 'throw-for', 'throw-from', 'throw-in', 'throw-inside', 'throw-into', 'throw-off', 'throw-on', 'throw-onto', 'throw-out', 'throw-over', 'toss', 'trash', 'trow', 'swipe-into', 'throw-to', 'remain-in']",leave 16 | 14,dry,"['dry', 'dry-off', 'dry-on', 'dry-with', 'towel', 'dry-in']",clean 17 | 15,shake,"['jiggle', 'shake', 'shake-from', 'shake-in', 'shake-off', 'shake-on', 'shake-onto', 'shake-out', 'shake-over', 'shake-with', 'waft', 'swirl', 'swirl-in']",manipulate 18 | 16,scoop,"['ladle', 'scoop', 'scoop-at', 'scoop-from', 'scoop-in', 'scoop-into', 'scoop-out', 'scoop-up', 'scoop-with', 'shovel', 'shovel-up', 'skim-off', 'spoon', 'spoon-in', 'spoon-into', 'spoon-onto', 'spoon-out']",retrieve 19 | 17,adjust,"['adjust', 'adjust-in', 'adjust-on', 'calibrate', 'change', 'regulate']",monitor 20 | 
18,squeeze,"['crumple', 'crumple-up', 'juice', 'scrunch-up', 'squash', 'squeeze', 'squeeze-around', 'squeeze-from', 'squeeze-in', 'squeeze-into', 'squeeze-onto', 'squeeze-out', 'squeeze-over', 'squeeze-with', 'squidge', 'squidge-into', 'squidge-with', 'squish', 'squish-into', 'wring', 'wring-out', 'let-out']",manipulate 21 | 19,peel,"['peel', 'peel-back', 'peel-from', 'peel-off', 'peel-with', 'skin', 'skin-from', 'defoliate']",split 22 | 20,empty,"['empty', 'empty-from', 'empty-into', 'empty-out', 'tip-out', 'tip-to']",split 23 | 21,press,"['collapse', 'compress', 'press', 'press-down', 'press-into', 'press-on', 'press-onto', 'push', 'push-down', 'push-off', 'push-onto', 'push-out', 'push-up', 'press-in', 'push-on', 'compress-in', 'compress-on']",manipulate 24 | 22,flip,"['flip', 'flip-in', 'flip-over', 'flip-through', 'overturn', 'turn-over', 'upturn', 'reverse', 'rotate-to', 'turn-with', 'flip-with', 'flip-onto']",manipulate 25 | 23,turn,"['rotate', 'rotate-around', 'spin', 'turn', 'turn-around', 'turn-away', 'turn-in', 'turn-to']",manipulate 26 | 24,check,"['check', 'check-in', 'check-on', 'check-out', 'ensure', 'examine', 'inspect', 'look-in', 'test', 'watch', 'count', 'realize']",monitor 27 | 25,scrape,"['drag-through', 'loose', 'scrap', 'scrape', 'scrape-against', 'scrape-around', 'scrape-down', 'scrape-from', 'scrape-in', 'scrape-into', 'scrape-of', 'scrape-off', 'scrape-on', 'scrape-onto', 'scrape-out', 'scrape-to', 'scrape-up', 'scrape-with', 'scratch-off', 'level', 'scrap-on', 'detach-from']",clean 28 | 26,fill,"['fill', 'fill-from', 'fill-into', 'fill-up', 'fill-with', 'stuff', 'stuff-with']",merge 29 | 27,apply,"['apply', 'apply-to', 'distribute', 'smear', 'spread', 'spread-around', 'spread-in', 'spread-into', 'spread-on', 'spread-onto', 'spread-out', 'spread-over', 'pump-onto', 'squirt', 'add-on', 'apply-in', 'spread-with']",distribute 30 | 28,fold,"['fold', 'fold-into', 'fold-over', 'fold-up']",order 31 | 29,scrub,"['scrub', 'scrub-down', 'scrub-inside', 'scrub-with', 'scour']",clean 32 | 30,break,"['break', 'break-into', 'break-off', 'break-up', 'break-with', 'crack', 'crack-in', 'crack-into', 'crack-on', 'snap', 'shred', 'stem', 'debone', 'break-apart', 'break-in', 'open-into']",split 33 | 31,pull,"['pull', 'pull-at', 'pull-from', 'pull-in', 'pull-on', 'pull-out', 'pull-through', 'pull-up', 'pull-down']",manipulate 34 | 32,pat,"['dab', 'dab-in', 'pat', 'pat-down', 'pat-into', 'poke', 'prod', 'tap', 'tap-against', 'tap-on', 'tap-with', 'hit', 'hit-on', 'hit-with', 'dab-with']",sense 35 | 33,lift,"['lift', 'lift-from', 'lift-off', 'lift-out', 'lift-up', 'raise', 'tilt']",retrieve 36 | 34,hold,"['hold', 'hold-along', 'hold-down', 'hold-in', 'hold-onto', 'hold-out', 'hold-over', 'hold-with']",manipulate 37 | 35,eat,"['bite', 'chew-on', 'eat', 'eat-from', 'swallow', 'lick', 'sample', 'taste', 'eat-on']",sense 38 | 36,wrap,"['clamp', 'clip', 'cover', 'cover-in', 'cover-with', 'fasten', 'lay-over', 'reseal', 'rewrap', 'seal', 'tie', 'tie-in', 'wrap', 'wrap-around', 'wrap-in', 'wrap-onto', 'wrap-up', 'wrap-with', 'seal-with', 'wrap-over', 'roll-up']",block 39 | 37,filter,"['drain', 'drain-from', 'drain-into', 'drain-on', 'drain-out', 'dump-out', 'filter', 'strain', 'strain-from', 'strain-in', 'drain-in']",split 40 | 38,look,"['look', 'look-at', 'look-behind', 'look-through', 'look-under', 'read', 'read-on', 'see', 'see-at', 'see-off', 'stare-at']",monitor 41 | 39,unroll,"['roll-out', 'unroll', 'unroll-from', 'unfold', 'unfurl']",access 42 | 40,sort,"['arrange', 'arrange-on', 
'coil', 'line-up', 'pile', 'rearrange', 'sort', 'sort-out', 'straighten', 'straighten-out', 'tessellate', 'tidy', 'stack', 'stack-in', 'stack-up', 'align', 'rearrange-in', 'arrange-in', 'sort-in']",order 43 | 41,hang,"['drape', 'hang', 'hang-up', 'hand', 'hang-on']",leave 44 | 42,sprinkle,"['crumble', 'crumble-into', 'distribute-on', 'distribute-onto', 'scatter', 'sprincle', 'sprinkle', 'sprinkle-in', 'sprinkle-into', 'sprinkle-on', 'sprinkle-onto', 'sprinkle-over', 'scatter-onto']",distribute 45 | 43,rip,"['rip', 'rip-off', 'tear', 'tear-down', 'tear-from', 'tear-in', 'tear-into', 'tear-off', 'tear-on', 'tear-out', 'tear-up', 'rip-with']",split 46 | 44,spray,"['spay', 'spray', 'spray-down', 'spray-on']",distribute 47 | 45,cook,"['boil', 'cook', 'fry', 'fry-in', 'heat', 'steam', 'stir-fry', 'toast', 'cook-into', 'boil-with']",manipulate 48 | 46,add,"['add', 'add-from', 'add-into', 'add-to', 'combine', 'combine-into', 'add-in']",merge 49 | 47,roll,"['roll', 'roll-around', 'roll-down', 'roll-in', 'roll-into', 'roll-on', 'roll-over']",block 50 | 48,search,"['browse', 'find', 'fumble-for', 'locate', 'look-for', 'look-inside', 'rifle-through', 'rummage-in', 'search', 'search-for', 'search-in']",monitor 51 | 49,crush,"['bash-against', 'crush', 'crush-into', 'crush-with', 'grind', 'grind-into', 'grind-on', 'mash', 'mash-with', 'soften', 'smash', 'grind-onto', 'crush-in']",split 52 | 50,stretch,"['stretch', 'stretch-around', 'stretch-onto', 'stretch-out']",manipulate 53 | 51,knead,['knead'],manipulate 54 | 52,divide,"['detach', 'disassemble', 'divide', 'separate', 'separate-from', 'separate-out', 'split', 'split-into', 'unattach', 'divide-into']",split 55 | 53,set,"['set', 'set-in', 'set-on', 'set-out', 'set-off', 'zero', 'set-to']",manipulate 56 | 54,feel,"['feel', 'touch']",sense 57 | 55,rub,"['rub', 'rub-in', 'rub-into', 'rub-off', 'rub-on', 'rub-onto', 'rub-over', 'rub-with', 'scratch', 'deburr', 'rub-around']",clean 58 | 56,soak,"['immerge', 'immerse', 'soak', 'soak-in', 'submerge', 'submerge-in']",clean 59 | 57,brush,"['brush', 'brush-into', 'brush-off', 'brush-onto', 'sweep', 'sweep-from', 'sweep-into', 'sweep-off', 'sweep-up', 'brush-on', 'paste-onto']",clean 60 | 58,sharpen,"['hone', 'sharpen', 'thin', 'forge']",manipulate 61 | 59,drop,"['drop', 'drop-down', 'drop-in', 'drop-on', 'drop-onto', 'fall-on', 'release-from', 'spill', 'tip-over']",leave 62 | 60,drink,"['drink', 'drink-from']",sense 63 | 61,slide,"['slide', 'slide-acros', 'slide-along', 'slide-down', 'slide-from', 'slide-inside', 'slide-off', 'slide-on', 'slide-out']",manipulate 64 | 62,water,"['damp', 'water', 'wet']",manipulate 65 | 63,gather,"['collect', 'gather', 'gather-in', 'gather-into', 'gather-with', 'gather-from']",retrieve 66 | 64,attach,"['assemble', 'attach', 'attach-onto', 'attach-to', 'clip-on', 'connect', 'plug', 'plug-in', 'plug-into', 'reassemble', 'reattach', 'snap-in', 'fix', 'expand']",merge 67 | 65,turn-down,['turn-down'],monitor 68 | 66,coat,"['batter', 'coat', 'brush-with', 'dip', 'marinate', 'butter-in', 'oil', 'smooth-into']",merge 69 | 67,transition,"['transition', 'enter', 'enter-into', 'walk', 'walk-around', 'walk-into', 'walk-to', 'walk-with', 'sit-on', 'step', 'step-off', 'step-on', 'stand-up']",transition 70 | 68,wear,['wear'],manipulate 71 | 69,measure,"['measure', 'measure-into', 'measure-out', 'weigh', 'measure-in']",monitor 72 | 70,increase,"['increase', 'set-up', 'switch-up', 'turn-up', 'increase-on', 'rise-up']",monitor 73 | 71,unscrew,"['screw-off', 'twist-off', 'unscrew', 'unscrew-from', 
'untwist']",access 74 | 72,wait,"['wait', 'wait-for', 'decide-if', 'decide']",monitor 75 | 73,lower,"['lower', 'untilt']",monitor 76 | 74,form,"['ball-up', 'form', 'form-into', 'reconstruct', 'shape', 'shape-into', 'form-from']",manipulate 77 | 75,smell,"['smell', 'smell-in', 'sniff', 'blow', 'blow-out']",sense 78 | 76,use,['use'],manipulate 79 | 77,grate,"['grate', 'grate-into', 'grate-on', 'grate-onto', 'zest-into']",split 80 | 78,screw,"['screw', 'screw-in', 'tighten', 'twist']",manipulate 81 | 79,let-go,"['let', 'let-go']",leave 82 | 80,finish,"['end', 'finish', 'stop']",manipulate 83 | 81,stab,"['pierce', 'pierce-in', 'stab', 'fork']",split 84 | 82,serve,"['dish', 'dish-up', 'plate', 'plate-on', 'plate-up', 'serve', 'serve-in', 'serve-on']",leave 85 | 83,uncover,"['lid-off', 'uncover']",access 86 | 84,unwrap,"['unpack', 'unravel', 'unseal', 'unwrap']",access 87 | 85,choose,"['choose', 'select']",retrieve 88 | 86,lock,"['lock', 'lock-in', 'lock-into', 'lock-on']",block 89 | 87,flatten,"['flatten', 'flatten-out', 'flatten-with']",manipulate 90 | 88,switch,"['switch', 'switch-to']",access 91 | 89,carry,"['bring', 'carry', 'have', 'bring-to']",transition 92 | 90,season,"['pepper', 'salt', 'season', 'sweeten', 'salt-in', 'pepper-in', 'season-with']",distribute 93 | 91,unlock,"['unlock', 'unlock-from']",access 94 | 92,prepare,"['prepare', 'prepare-for']",manipulate 95 | 93,bake,['bake'],manipulate 96 | 94,mark,"['mark', 'mark-on']",manipulate 97 | 95,bend,['bend'],manipulate 98 | 96,unfreeze,['unfreeze'],monitor -------------------------------------------------------------------------------- /datasets/epic_kitchens.py: -------------------------------------------------------------------------------- 1 | """The Epic Kitchens dataset loaders, this class also supports EGTEA Gaze+ dataset""" 2 | 3 | from typing import List, Dict, Sequence, Tuple, Union 4 | from datetime import datetime, date 5 | from collections import OrderedDict 6 | import pickle as pkl 7 | import csv 8 | import logging 9 | from pathlib import Path 10 | import pandas as pd 11 | import torch 12 | 13 | from .base_video_dataset import BaseVideoDataset, RULSTM_TSN_FPS 14 | 15 | EGTEA_VERSION = -1 # This class also supports EGTEA Gaze+ 16 | EPIC55_VERSION = 0.1 17 | EPIC100_VERSION = 0.2 18 | 19 | 20 | class EPICKitchens(BaseVideoDataset): 21 | """EPICKitchens and EGTEA dataloader.""" 22 | 23 | def __init__(self, 24 | annotation_path: Sequence[Path], 25 | action_labels_fpath: Path = None, 26 | annotation_dir: Path = None, 27 | rulstm_annotation_dir: Path = None, 28 | version: float = EPIC55_VERSION, 29 | **other_kwargs, 30 | ): 31 | """ 32 | Args: 33 | label_type (str): The type of label to return 34 | action_labels_fpath (Path): Path to map the verb and noun labels to 35 | actions. It was used in the anticipation paper, that defines 36 | a set of actions and train for action prediction, as opposed 37 | to verb and noun prediction. 
38 | annotation_dir: Where all the other annotations are typically stored 39 | """ 40 | self.version = version 41 | df = pd.concat([self._load_df(el) for el in annotation_path]) 42 | df.reset_index(inplace=True, drop=True) # to combine all of them 43 | self.annotation_dir = Path(annotation_dir) 44 | self.rulstm_annotation_dir = rulstm_annotation_dir 45 | 46 | # Load verb and noun classes 47 | epic_postfix = '' 48 | if self.version == EPIC100_VERSION: 49 | epic_postfix = '_100' 50 | if self.version != EGTEA_VERSION: 51 | verb_classes = self._load_class_names(self.annotation_dir / f'EPIC{epic_postfix}_verb_classes.csv') 52 | noun_classes = self._load_class_names(self.annotation_dir / f'EPIC{epic_postfix}_noun_classes.csv') 53 | else: 54 | verb_classes, noun_classes = [], [] 55 | 56 | # Create action classes 57 | if action_labels_fpath is not None: 58 | load_action_fn = self._load_action_classes 59 | if self.version == EGTEA_VERSION: 60 | load_action_fn = self._load_action_classes_egtea 61 | action_classes, verb_noun_to_action = load_action_fn(action_labels_fpath) 62 | else: 63 | logging.warning('Action labels were not provided. Generating actions ...') 64 | action_classes, verb_noun_to_action = self._gen_all_actions(verb_classes, noun_classes) 65 | 66 | # Add the action classes to the data frame 67 | if 'action_class' not in df.columns and {'noun_class', 'verb_class'}.issubset(df.columns): 68 | df.loc[:, 'action_class'] = df.loc[:, ('verb_class', 'noun_class')].apply( 69 | lambda row: (verb_noun_to_action[(row.at['verb_class'], row.at['noun_class'])] 70 | if (row.at['verb_class'], row.at['noun_class']) in verb_noun_to_action else -1), axis=1) 71 | elif 'action_class' not in df.columns: 72 | df.loc[:, 'action_class'] = -1 73 | df.loc[:, 'verb_class'] = -1 74 | df.loc[:, 'noun_class'] = -1 75 | num_undefined_actions = len(df[df['action_class'] == -1].index) 76 | if num_undefined_actions > 0: 77 | logging.error(f'Did not found valid action label for {num_undefined_actions}/{len(df)} samples!') 78 | 79 | other_kwargs['verb_classes'] = verb_classes 80 | other_kwargs['noun_classes'] = noun_classes 81 | other_kwargs['action_classes'] = action_classes 82 | 83 | super().__init__(df, **other_kwargs) 84 | self.verb_noun_to_action = verb_noun_to_action 85 | logging.info(f'Created EPIC {self.version} dataset with {len(self)} samples') 86 | 87 | @property 88 | def class_mappings(self) -> Dict[Tuple[str, str], torch.FloatTensor]: 89 | num_verbs = len(self.verb_classes) 90 | if num_verbs == 0: 91 | num_verbs = len(set([el[0] for el, _ in self.verb_noun_to_action.items()])) 92 | num_nouns = len(self.noun_classes) 93 | if num_nouns == 0: 94 | num_nouns = len(set([el[1] for el, _ in self.verb_noun_to_action.items()])) 95 | num_actions = len(self.action_classes) 96 | if num_actions == 0: 97 | num_actions = len(set([el for _, el in self.verb_noun_to_action.items()])) 98 | verb_in_action = torch.zeros((num_actions, num_verbs), dtype=torch.float) 99 | noun_in_action = torch.zeros((num_actions, num_nouns), dtype=torch.float) 100 | for (verb, noun), action in self.verb_noun_to_action.items(): 101 | verb_in_action[action, verb] = 1.0 102 | noun_in_action[action, noun] = 1.0 103 | return { 104 | ('verb', 'action'): verb_in_action, 105 | ('noun', 'action'): noun_in_action 106 | } 107 | 108 | @property 109 | def classes_manyshot(self) -> OrderedDict: 110 | """ 111 | In EPIC-55, the recall computation was done for "many shot" classes, 112 | and not for all classes. 
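# A compact sketch of what the class_mappings property above returns: binary matrices
# that record, for every action class, the verb and noun it is composed of. The toy
# mapping and the marginalisation at the end are illustrative assumptions.
import torch

def _class_mappings_sketch():
    verb_noun_to_action = {(0, 0): 0, (0, 1): 1, (1, 0): 2}
    num_verbs, num_nouns, num_actions = 2, 2, 3
    verb_in_action = torch.zeros((num_actions, num_verbs))
    noun_in_action = torch.zeros((num_actions, num_nouns))
    for (verb, noun), action in verb_noun_to_action.items():
        verb_in_action[action, verb] = 1.0
        noun_in_action[action, noun] = 1.0
    action_scores = torch.rand(4, num_actions)            # (B, A)
    verb_scores = action_scores @ verb_in_action           # (B, V): pool action scores per verb
    return verb_in_action, noun_in_action, verb_scores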
So, for that version read the class names as 113 | provided by RULSTM.""" 114 | if self.version != EPIC55_VERSION: 115 | return super().classes_manyshot 116 | # read the list of many shot verbs 117 | many_shot_verbs = { 118 | el['verb']: el['verb_class'] 119 | for el in pd.read_csv(self.annotation_dir / 'EPIC_many_shot_verbs.csv').to_dict('records') 120 | } 121 | # read the list of many shot nouns 122 | many_shot_nouns = { 123 | el['noun']: el['noun_class'] 124 | for el in pd.read_csv(self.annotation_dir / 'EPIC_many_shot_nouns.csv').to_dict('records') 125 | } 126 | # create the list of many shot actions 127 | # an action is "many shot" if at least one between the related verb and noun are many shot 128 | many_shot_actions = {} 129 | action_names = {val: key for key, val in self.action_classes.items()} 130 | for (verb_id, noun_id), action_id in self.verb_noun_to_action.items(): 131 | if (verb_id in many_shot_verbs.values()) or (noun_id in many_shot_nouns.values()): 132 | many_shot_actions[action_names[action_id]] = action_id 133 | return { 134 | 'verb': many_shot_verbs, 135 | 'noun': many_shot_nouns, 136 | 'action': many_shot_actions, 137 | } 138 | 139 | def _load_class_names(self, annot_path: Path): 140 | res = {} 141 | with open(annot_path, 'r') as fin: 142 | reader = csv.DictReader(fin, delimiter=',') 143 | for lno, line in enumerate(reader): 144 | res[line['class_key' if self.version == 145 | EPIC55_VERSION else 'key']] = lno 146 | return res 147 | 148 | @staticmethod 149 | def _load_action_classes(action_labels_fpath: Path) -> Tuple[Dict[str, int], Dict[Tuple[int, int], int]]: 150 | """ 151 | Given a CSV file with the actions (as from RULSTM paper), construct the set of actions and mapping from verb/noun to action 152 | Args: 153 | action_labels_fpath: path to the file 154 | Returns: 155 | class_names: Dict of action class names 156 | verb_noun_to_action: Mapping from verb/noun to action IDs 157 | """ 158 | class_names = {} 159 | verb_noun_to_action = {} 160 | with open(action_labels_fpath, 'r') as fin: 161 | reader = csv.DictReader(fin, delimiter=',') 162 | for lno, line in enumerate(reader): 163 | class_names[line['action']] = lno 164 | verb_noun_to_action[(int(line['verb']), int(line['noun']))] = int(line['id']) 165 | return class_names, verb_noun_to_action 166 | 167 | @staticmethod 168 | def _load_action_classes_egtea(action_labels_fpath: Path) -> Tuple[Dict[str, int], Dict[Tuple[int, int], int]]: 169 | """ 170 | Given a CSV file with the actions (as from RULSTM paper), construct the set of actions and mapping from verb/noun to action 171 | Args: 172 | action_labels_fpath: path to the file 173 | Returns: 174 | class_names: Dict of action class names 175 | verb_noun_to_action: Mapping from verb/noun to action IDs 176 | """ 177 | class_names = {} 178 | verb_noun_to_action = {} 179 | with open(action_labels_fpath, 'r') as fin: 180 | reader = csv.DictReader(fin, delimiter=',', 181 | # Assuming the order is verb/noun 182 | # TODO check if that is correct 183 | fieldnames=['id', 'verb_noun', 'action']) 184 | for lno, line in enumerate(reader): 185 | class_names[line['action']] = lno 186 | verb, noun = [int(el) for el in line['verb_noun'].split('_')] 187 | verb_noun_to_action[(verb, noun)] = int(line['id']) 188 | return class_names, verb_noun_to_action 189 | 190 | @staticmethod 191 | def _gen_all_actions(verb_classes: List[str], noun_classes: List[str]) -> Tuple[ 192 | Dict[str, int], Dict[Tuple[int, int], int]]: 193 | """ 194 | Given all possible verbs and nouns, construct all possible 
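# A self-contained sketch of the CSV layout _load_action_classes above expects
# (RULSTM-style action labels with id, verb, noun and action-name columns), parsed
# from an inline string so it runs without the annotation files. The rows are made
# up for illustration.
import csv
import io

def _load_action_classes_sketch():
    text = "id,verb,noun,action\n0,0,0,take plate\n1,1,0,put plate\n"
    class_names, verb_noun_to_action = {}, {}
    for lno, line in enumerate(csv.DictReader(io.StringIO(text), delimiter=',')):
        class_names[line['action']] = lno
        verb_noun_to_action[(int(line['verb']), int(line['noun']))] = int(line['id'])
    return class_names, verb_noun_to_action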
actions 195 | Args: 196 | verb_classes: All verbs 197 | noun_classes: All nouns 198 | Returns: 199 | class_names: list of action class names 200 | verb_noun_to_action: Mapping from verb/noun to action IDs 201 | """ 202 | class_names = {} 203 | verb_noun_to_action = {} 204 | action_id = 0 205 | for verb_id, verb_cls in enumerate(verb_classes): 206 | for noun_id, noun_cls in enumerate(noun_classes): 207 | class_names[f'{verb_cls}:{noun_cls}'] = action_id 208 | verb_noun_to_action[(verb_id, noun_id)] = action_id 209 | action_id += 1 210 | return class_names, verb_noun_to_action 211 | 212 | def _init_df_orig(self, annotation_path): 213 | """Loading the original EPIC Kitchens annotations""" 214 | 215 | def timestr_to_sec(s, fmt='%H:%M:%S.%f'): 216 | # Convert timestr to seconds 217 | timeobj = datetime.strptime(s, fmt).time() 218 | td = datetime.combine(date.min, timeobj) - datetime.min 219 | return td.total_seconds() 220 | 221 | # Load the DF from annot path 222 | logging.info(f'Loading original EPIC pkl annotations {annotation_path}') 223 | with open(annotation_path, 'rb') as fin: 224 | df = pkl.load(fin) 225 | # Make a copy of the UID column, since that will be needed to gen output files 226 | df.reset_index(drop=False, inplace=True) 227 | 228 | # parse timestamps from the video 229 | df.loc[:, 'start'] = df.start_timestamp.apply(timestr_to_sec) 230 | df.loc[:, 'end'] = df.stop_timestamp.apply(timestr_to_sec) 231 | 232 | # original annotations have text in weird format - fix that 233 | if 'noun' in df.columns: 234 | df.loc[:, 'noun'] = df.loc[:, 'noun'].apply(lambda s: ' '.join(s.replace(':', ' ').split(sep=' ')[::-1])) 235 | if 'verb' in df.columns: 236 | df.loc[:, 'verb'] = df.loc[:, 'verb'].apply(lambda s: ' '.join(s.replace('-', ' ').split(sep=' '))) 237 | df = self._init_df_gen_vidpath(df) 238 | df.reset_index(inplace=True, drop=True) 239 | return df 240 | 241 | def _init_df_gen_vidpath(self, df): 242 | # generate video_path 243 | if self.version == EGTEA_VERSION: 244 | df.loc[:, 'video_path'] = df.apply(lambda x: Path(x.video_id + '.mp4'), axis=1) 245 | else: # For the EPIC datasets 246 | df.loc[:, 'video_path'] = df.apply(lambda x: (Path(x.participant_id) / Path(x.video_id + '.MP4')), axis=1) 247 | return df 248 | 249 | def _init_df_rulstm(self, annotation_path): 250 | logging.info('Loading RULSTM EPIC csv annotations %s', annotation_path) 251 | df = pd.read_csv(annotation_path, 252 | names=['uid', 'video_id', 'start_frame_30fps', 'end_frame_30fps', 'verb_class', 'noun_class', 253 | 'action_class'], 254 | index_col=0, 255 | skipinitialspace=True, 256 | dtype={'uid': str, 'video_id': str, 'start_frame_30fps': int, 'end_frame_30fps': int, 257 | 'verb_class': int, 'noun_class': int, 'action_class': int}) 258 | # Make a copy of the UID column, since that will be needed to gen output files 259 | df.reset_index(drop=False, inplace=True) 260 | # Convert the frame number to start and end 261 | df.loc[:, 'start'] = df.loc[:, 'start_frame_30fps'].apply(lambda x: x / RULSTM_TSN_FPS) 262 | df.loc[:, 'end'] = df.loc[:, 'end_frame_30fps'].apply(lambda x: x / RULSTM_TSN_FPS) 263 | # Participant ID from video_id 264 | df.loc[:, 'participant_id'] = df.loc[:, 'video_id'].apply(lambda x: x.split('_')[0]) 265 | df = self._init_df_gen_vidpath(df) 266 | df.reset_index(inplace=True, drop=True) 267 | return df 268 | 269 | def _load_df(self, annotation_path): 270 | if annotation_path.endswith('.pkl'): 271 | return self._init_df_orig(annotation_path) 272 | elif annotation_path.endswith('.csv'): 273 | # Else, 
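# A standalone version of the timestamp conversion used inside _init_df_orig above:
# "HH:MM:SS.ff" strings from the original EPIC annotations are routed through datetime
# to obtain seconds. The example value is illustrative.
from datetime import datetime, date

def _timestr_to_sec_sketch(s: str = '00:01:30.50', fmt: str = '%H:%M:%S.%f') -> float:
    timeobj = datetime.strptime(s, fmt).time()
    td = datetime.combine(date.min, timeobj) - datetime.min
    return td.total_seconds()  # 90.5 for the default example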
it must be the RULSTM annotations (fps 30) 274 | return self._init_df_rulstm(annotation_path) 275 | else: 276 | raise NotImplementedError(annotation_path) 277 | --------------------------------------------------------------------------------
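# A hedged, self-contained sketch of the frame-to-seconds conversion performed in
# _init_df_rulstm above: RULSTM annotations store 30-fps frame indices, which are
# divided by the frame rate to obtain start/end times. The rows are made up and
# RULSTM_TSN_FPS is assumed to be 30.0 here.
import io
import pandas as pd

ASSUMED_RULSTM_TSN_FPS = 30.0

def _rulstm_frames_to_seconds_sketch():
    text = "1,P01_01,90,210,0,1,0\n2,P01_01,300,450,2,3,5\n"
    df = pd.read_csv(io.StringIO(text),
                     names=['uid', 'video_id', 'start_frame_30fps', 'end_frame_30fps',
                            'verb_class', 'noun_class', 'action_class'])
    df['start'] = df['start_frame_30fps'] / ASSUMED_RULSTM_TSN_FPS   # 3.0 s, 10.0 s
    df['end'] = df['end_frame_30fps'] / ASSUMED_RULSTM_TSN_FPS       # 7.0 s, 15.0 s
    df['participant_id'] = df['video_id'].str.split('_').str[0]      # 'P01'
    return df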