├── models ├── __init__.py ├── model_utils.py ├── head_pose_estimator.py ├── model_selector.py ├── inferring_shared_attention_estimation_debug.py └── hourglass.py ├── dataset ├── __init__.py └── dataset_selector.py ├── ICCV2023-PJAE.png ├── count_gt_jumping_videocoatt_jumping.png ├── count_gt_jumping_videocoatt_diminishing.png ├── analysis ├── analyze_ja_tem_consistency │ ├── 11_71210 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 16_29570 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 21_68880 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 22_51790 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 23_51240 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 25_29675 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 28_13105 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 28_27115 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 33_23025 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 36_17455 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 40_1190 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 44_17755 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 44_37770 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 4_24805 │ │ ├── move_abs.png │ │ └── move_diff.png │ ├── 52_4290 │ │ ├── move_abs.png │ │ └── move_diff.png │ └── 53_15815 │ │ ├── move_abs.png │ │ └── move_diff.png ├── iccv2023 │ ├── videocoatt_ja_size_analysis.py │ ├── videocoatt_ja_existance_analysis.py │ ├── p_s_gen_ablation_on_videocoatt.py │ ├── p_p_gen_ablation_on_videocoatt.py │ ├── fusion_ablation_on_videocoatt.py │ ├── self_attetnion_ablation.py │ ├── gt_pred_ablation_on_volleyball.py │ ├── comparison_ja_trans_on_videocoatt.py │ ├── comparison_ja_trans_on_volleyball.py │ ├── gt_pred_ablation_on_videocoatt.py │ ├── p_p_agg_ablation_on_volleyball.py │ ├── fusion_ablation_on_volleyball.py │ ├── tran_enc_comparision_on_volleyball.py │ ├── p_s_gen_ablation_on_volleyball.py │ ├── multi_heads_comparision_on_volleyball.py │ ├── p_p_gen_ablation_on_volleyball.py │ ├── gaussian_comparison_on_volleyball.py │ ├── comparison_finetune_on_volleyball.py │ ├── p_p_agg_ablation_on_videocoatt.py │ ├── blur_ablation_on_volleyball.py │ ├── input_ablation_on_volleyball.py │ ├── comparison_on_volleyball.py │ ├── comparison_finetune_on_videocoatt.py │ ├── input_ablation_on_videocoatt.py │ └── comparison_on_videocoatt.py ├── utils.py ├── analyze_ja_tem_consistency.py ├── comparison_on_volleyball.py ├── iccv_comparison_on_vollleyball.py └── p_p_mask_comparison_on_vollleyball.py ├── yaml_files ├── gazefollow │ ├── demo.yaml │ ├── debug_ours.yaml │ └── train_ours.yaml ├── toy │ ├── demo.yaml │ ├── debug.yaml │ └── train.yaml ├── videoattentiontarget │ ├── demo.yaml │ ├── debug_hgt.yaml │ ├── train_hgt.yaml │ ├── train_ours.yaml │ └── debug_ours.yaml ├── volleyball │ ├── debug_hgt.yaml │ ├── eval.yaml │ ├── train_hgt.yaml │ ├── debug_isa.yaml │ ├── train_isa.yaml │ ├── demo.yaml │ ├── train_ours_p_p.yaml │ ├── debug_ours_p_p.yaml │ └── train_ours.yaml └── videocoatt │ ├── demo.yaml │ ├── debug_isa.yaml │ ├── train_isa.yaml │ ├── eval.yaml │ ├── debug_ours.yaml │ └── debug_ours_p_p.yaml ├── LICENSE ├── print_weights.py ├── eval_vol.bash ├── .gitignore ├── count_gt_jumping_videocoatt.py ├── make_gif_from_all_images.py ├── make_video_from_images.py ├── requirements.txt └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ICCV2023-PJAE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/ICCV2023-PJAE.png -------------------------------------------------------------------------------- /count_gt_jumping_videocoatt_jumping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/count_gt_jumping_videocoatt_jumping.png -------------------------------------------------------------------------------- /count_gt_jumping_videocoatt_diminishing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/count_gt_jumping_videocoatt_diminishing.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/11_71210/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/11_71210/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/11_71210/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/11_71210/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/16_29570/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/16_29570/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/16_29570/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/16_29570/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/21_68880/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/21_68880/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/21_68880/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/21_68880/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/22_51790/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/22_51790/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/22_51790/move_diff.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/22_51790/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/23_51240/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/23_51240/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/23_51240/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/23_51240/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/25_29675/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/25_29675/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/25_29675/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/25_29675/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/28_13105/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/28_13105/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/28_13105/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/28_13105/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/28_27115/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/28_27115/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/28_27115/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/28_27115/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/33_23025/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/33_23025/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/33_23025/move_diff.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/33_23025/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/36_17455/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/36_17455/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/36_17455/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/36_17455/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/40_1190/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/40_1190/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/40_1190/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/40_1190/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/44_17755/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/44_17755/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/44_17755/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/44_17755/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/44_37770/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/44_37770/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/44_37770/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/44_37770/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/4_24805/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/4_24805/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/4_24805/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/4_24805/move_diff.png 
-------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/52_4290/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/52_4290/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/52_4290/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/52_4290/move_diff.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/53_15815/move_abs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/53_15815/move_abs.png -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency/53_15815/move_diff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chihina/PJAE-ICCV2023/HEAD/analysis/analyze_ja_tem_consistency/53_15815/move_diff.png -------------------------------------------------------------------------------- /yaml_files/gazefollow/demo.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: gazefollow 3 | dataset_dir : data/gazefollow 4 | 5 | exp_set: 6 | save_folder : saved_weights 7 | model_name: gazefollow-dual-cnn-w_pre 8 | 9 | seed_num : 777 10 | gpu_mode : True 11 | gpu_start : 4 12 | gpu_finish : 4 13 | num_workers : 1 14 | batch_size : 1 15 | wandb_name : demo 16 | 17 | mode: test 18 | # mode : validate 19 | # mode : train 20 | 21 | exp_params: 22 | test_gt_gaze : False 23 | # test_gt_gaze : True 24 | 25 | # vis_dist_error: False 26 | vis_dist_error: True -------------------------------------------------------------------------------- /yaml_files/toy/demo.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: toy 3 | dataset_dir : data/joint_attention_toy 4 | 5 | exp_set: 6 | save_folder : saved_weights 7 | # model_name: toy-concat_direct 8 | # model_name: toy-concat_independent 9 | # model_name: toy-concat_independent_angle_mask_feat 10 | 11 | model_name: toy-cnn 12 | 13 | seed_num : 777 14 | gpu_mode : True 15 | gpu_start : 7 16 | gpu_finish : 7 17 | num_workers : 1 18 | batch_size : 1 19 | wandb_name : demo 20 | 21 | mode: test 22 | # mode : validate 23 | # mode : train 24 | 25 | exp_params: 26 | # use_gt_gaze : False 27 | use_gt_gaze : True -------------------------------------------------------------------------------- /yaml_files/videoattentiontarget/demo.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videoattentiontarget 3 | dataset_dir : data/videoattentiontarget 4 | 5 | exp_set: 6 | save_folder : saved_weights 7 | # model_name: videoattentiontarget-head_pose_estimator 8 | # model_name: videoattentiontarget-concat_independent 9 | 10 | # model_name: videoattentiontarget-hgt 11 | # model_name: videoattentiontarget-hgt-high 12 | # model_name: videoattentiontarget-hgt-1101 13 | # model_name: videoattentiontarget-hgt_bbox_PRED 14 | 15 | # model_name: videoattentiontarget-only_davt_PRED 16 | 17 | seed_num : 
777 18 | gpu_mode : True 19 | gpu_start : 4 20 | gpu_finish : 4 21 | num_workers : 1 22 | batch_size : 1 23 | wandb_name : demo 24 | 25 | # mode: test 26 | # mode : validate 27 | mode : train 28 | 29 | exp_params: 30 | test_gt_gaze : False 31 | # test_gt_gaze : True 32 | 33 | # vis_dist_error: False 34 | vis_dist_error: True -------------------------------------------------------------------------------- /analysis/iccv2023/videocoatt_ja_size_analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import glob 4 | 5 | dataset_dir = 'data/VideoCoAtt_Dataset' 6 | 7 | # joint attention detection analysis 8 | ja_cnt_dic = {} 9 | ja_cnt_dic[0] = 349468 10 | ja_cnt_dic[1] = 139348 11 | ja_cnt_dic[2] = 3284 12 | all_ja_cnt = sum(ja_cnt_dic.values()) 13 | no_ja_cnt = ja_cnt_dic[0] 14 | print('Joint attention existence ratio') 15 | print(f'{no_ja_cnt/all_ja_cnt}={no_ja_cnt}/{all_ja_cnt}') 16 | 17 | # joint attention size analysis 18 | size_list = [] 19 | for ann_path in glob.glob(os.path.join(dataset_dir, 'annotations', '*', '*.txt')): 20 | with open(ann_path, 'r') as f: 21 | ann_lines = f.readlines() 22 | for line in ann_lines: 23 | x_min, y_min, x_max, y_max = map(float, line.strip().split()[2:6]) 24 | width = (x_max-x_min) 25 | height = (y_max-y_min) 26 | size = (width+height)//4 27 | size_list.append(size) 28 | 29 | print('Joint attention size') 30 | print(np.mean(np.array(size_list))) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Chihiro Nakatani 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /analysis/utils.py: -------------------------------------------------------------------------------- 1 | import openpyxl 2 | from openpyxl.styles import PatternFill 3 | 4 | def refine_excel(save_excel_file_path): 5 | wb = openpyxl.load_workbook(save_excel_file_path) 6 | ws = wb['all'] 7 | for col_idx, col in enumerate(ws.iter_cols()): 8 | if col_idx == 0: 9 | pass 10 | else: 11 | col_name = col[0].value 12 | if 'dist' in col_name: 13 | row_judge_val = min([cell.value for cell in col if type(cell.value) is not str]) 14 | else: 15 | row_judge_val = max([cell.value for cell in col if type(cell.value) is not str]) 16 | 17 | for row_idx, cell in enumerate(col): 18 | cell.number_format = '0.0' 19 | if cell.value == row_judge_val: 20 | cell.fill = PatternFill(fgColor='FFFF00', bgColor="FFFF00", fill_type = "solid") 21 | cell.value = r'\red' + '{' + f'{cell.value:.1f}' + '}' 22 | # cell.fill = PatternFill(fgColor='FFFF00', bgColor="FFFF00", fill_type = "solid") 23 | # cell.value = r'\red' + '{' + f'{cell.value:.3f}' + '}' 24 | 25 | wb.save(save_excel_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/videocoatt_ja_existance_analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import glob 4 | import sys 5 | 6 | dataset_dir = 'data/VideoCoAtt_Dataset' 7 | 8 | # joint attention detection analysis 9 | ja_cnt_dic = {} 10 | ja_cnt_dic[0] = 349468 11 | ja_cnt_dic[1] = 139348 12 | ja_cnt_dic[2] = 3284 13 | all_ja_cnt = sum(ja_cnt_dic.values()) 14 | no_ja_cnt = ja_cnt_dic[0] 15 | print('Joint attention existence ratio') 16 | print(f'{no_ja_cnt/all_ja_cnt}={no_ja_cnt}/{all_ja_cnt}') 17 | 18 | # joint attention existence analysis (per image) 19 | img_id_list_ja = [] 20 | img_id_list_all = [] 21 | for ann_path in glob.glob(os.path.join(dataset_dir, 'annotations', 'test', '*.txt')): 22 | with open(ann_path, 'r') as f: 23 | ann_lines = f.readlines() 24 | 25 | vid_id = int(ann_path.split('/')[-1].split('.')[0]) 26 | for line in ann_lines: 27 | co_id, img_id = map(float, line.strip().split()[0:2]) 28 | 29 | data_id = f'{vid_id}_{img_id}' 30 | img_id_list_ja.append(data_id) 31 | 32 | for img_path in glob.glob(os.path.join(dataset_dir, 'images_nk', 'test', '*', '*.jpg')): 33 | img_id_list_all.append(img_path) 34 | 35 | print(len(img_id_list_ja)) 36 | print(len(img_id_list_all)) 37 | print(len(img_id_list_ja)/len(img_id_list_all)) 38 | sys.exit() 39 | -------------------------------------------------------------------------------- /yaml_files/videoattentiontarget/debug_hgt.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videoattentiontarget 3 | dataset_dir : data/videoattentiontarget 4 | 5 | exp_set: 6 | save_folder: saved_weights 7 | wandb_name: debug 8 | 9 | wandb_log : False 10 | 11 | batch_size: 2 12 | num_workers: 1 13 | seed_num: 777 14 | gpu_mode : True 15 | gpu_start : 6 16 | gpu_finish : 6 17 | 18 | # resize_height: 320 19 | resize_height: 224 20 | # resize_width: 480 21 | resize_width: 224 22 | resize_head_height: 64 23 | resize_head_width: 64 24 | 25 | exp_params: 26 | # bbox_types: GT 27 | bbox_types: PRED 28 | bbox_iou_thresh: 0.6 29 | 30 | # learning rate 31 | lr : 0.00001 32 | 33 | # gt gaussian 34 | gaussian_sigma: 10 35 | 36 | # learning schedule 37 | nEpochs : 500 38 | start_iter : 0 39 | snapshots : 100 40 | scheduler_start : 1000
41 | scheduler_iter : 1100000 42 | 43 | # pretrained models 44 | pretrained_models_dir: saved_weights 45 | 46 | use_pretrained_joint_attention_estimator: False 47 | # use_pretrained_head_pose_estimator: True 48 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 49 | freeze_joint_attention_estimator: False 50 | # freeze_joint_attention_estimator: True 51 | 52 | model_params: 53 | model_type: human_gaze_target_transformer -------------------------------------------------------------------------------- /yaml_files/videoattentiontarget/train_hgt.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videoattentiontarget 3 | dataset_dir : data/videoattentiontarget 4 | 5 | exp_set: 6 | save_folder: saved_weights 7 | # wandb_name: videoattentiontarget-hgt 8 | # wandb_name: videoattentiontarget-hgt-high 9 | # wandb_name: videoattentiontarget-hgt-1101 10 | 11 | # wandb_name: videoattentiontarget-hgt_bbox_GT 12 | wandb_name: videoattentiontarget-hgt_bbox_PRED 13 | 14 | wandb_log : True 15 | 16 | batch_size: 16 17 | num_workers: 16 18 | seed_num: 777 19 | gpu_mode : True 20 | gpu_start : 2 21 | gpu_finish : 2 22 | 23 | # resize_height: 320 24 | resize_height: 224 25 | # resize_width: 480 26 | resize_width: 224 27 | resize_head_height: 64 28 | resize_head_width: 64 29 | 30 | exp_params: 31 | # bbox_types: GT 32 | bbox_types: PRED 33 | bbox_iou_thresh: 0.6 34 | 35 | # learning rate 36 | lr : 0.00001 37 | 38 | # gt gaussian 39 | gaussian_sigma: 40 40 | 41 | # learning schedule 42 | nEpochs : 500 43 | start_iter : 0 44 | snapshots : 100 45 | scheduler_start : 1000 46 | scheduler_iter : 1100000 47 | 48 | # pretrained models 49 | pretrained_models_dir: saved_weights 50 | 51 | use_pretrained_joint_attention_estimator: False 52 | # use_pretrained_head_pose_estimator: True 53 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 54 | freeze_joint_attention_estimator: False 55 | # freeze_joint_attention_estimator: True 56 | 57 | model_params: 58 | model_type: human_gaze_target_transformer -------------------------------------------------------------------------------- /yaml_files/volleyball/debug_hgt.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | annotation_dir : data/vatic_ball_annotation/annotation_data/ 6 | dataset_bbox_gt: data/jae_dataset_bbox_gt 7 | dataset_bbox_pred: data/jae_dataset_bbox_pred 8 | 9 | exp_set: 10 | save_folder: saved_weights 11 | wandb_name: debug 12 | wandb_log : False 13 | 14 | batch_size: 2 15 | num_workers: 1 16 | seed_num: 777 17 | gpu_mode : True 18 | gpu_start : 4 19 | gpu_finish : 4 20 | 21 | # resize_height: 320 22 | resize_height: 224 23 | # resize_width: 480 24 | resize_width: 224 25 | resize_head_height: 64 26 | resize_head_width: 64 27 | 28 | exp_params: 29 | 30 | use_frame_type: mid 31 | # use_frame_type: all 32 | 33 | # bbox_types: GT 34 | bbox_types: PRED 35 | 36 | # gaze_types: GT 37 | gaze_types: PRED 38 | 39 | # action_types: GT 40 | action_types: PRED 41 | 42 | # learning rate 43 | lr : 0.00001 44 | 45 | # gt gaussian 46 | gaussian_sigma: 10 47 | 48 | # learning schedule 49 | nEpochs : 500 50 | start_iter : 0 51 | snapshots : 100 52 | scheduler_start : 1000 53 | scheduler_iter : 1100000 54 | 55 | # pretrained models 56 | pretrained_models_dir: saved_weights 57 | 58 | use_pretrained_joint_attention_estimator: False 59 
| # use_pretrained_head_pose_estimator: True 60 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 61 | freeze_joint_attention_estimator: False 62 | # freeze_joint_attention_estimator: True 63 | 64 | model_params: 65 | model_type: human_gaze_target_transformer -------------------------------------------------------------------------------- /yaml_files/videocoatt/demo.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videocoatt 3 | dataset_dir : data/VideoCoAtt_Dataset 4 | saliency_dataset_dir : data/deepgaze_output_loader 5 | 6 | exp_set: 7 | save_folder : saved_weights 8 | # model_name: videocoatt-head_pose_estimator 9 | 10 | # model_name: videoattentiontarget-hgt 11 | # model_name: videoattentiontarget-hgt-high 12 | # model_name: gazefollow-dual-cnn-w_pre 13 | # model_name: videocoatt-isa 14 | # model_name: videocoatt-isa-mse-1103 15 | # model_name: videoattentiontarget-dual-cnn-w_pre-w_att_in 16 | 17 | # model_name: videocoatt-p_p_field_deep_p_s_gaze_follow_freeze 18 | # model_name: videocoatt-p_p_field_deep_p_s_davt_freeze 19 | # model_name: videocoatt-p_p_field_deep_p_s_trans_gaze_follow_freeze 20 | # model_name: videocoatt-p_p_field_deep_p_s_cnn_gaze_follow_w_pre_simple_average 21 | # model_name: videocoatt-p_p_field_deep_p_s_davt_simple_average 22 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fine 23 | 24 | # model_name: videocoatt-dual-people_field_middle 25 | model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT 26 | 27 | seed_num : 777 28 | gpu_mode : True 29 | gpu_start : 6 30 | gpu_finish : 6 31 | num_workers : 1 32 | batch_size : 1 33 | wandb_name : demo 34 | 35 | mode: test 36 | # mode : validate 37 | # mode : train 38 | 39 | exp_params: 40 | # test_heads_type : det 41 | test_heads_type : gt 42 | det_heads_model : det_heads 43 | test_heads_conf : 0.6 44 | # use_gt_gaze : False 45 | use_gt_gaze : True 46 | 47 | # use_frame_type: mid 48 | use_frame_type: all 49 | 50 | # vis_dist_error: False 51 | vis_dist_error: True -------------------------------------------------------------------------------- /dataset/dataset_selector.py: -------------------------------------------------------------------------------- 1 | from dataset.volleyball import VolleyBallDataset 2 | from dataset.volleyball_wo_att import VolleyBallDatasetWithoutAtt 3 | from dataset.videocoatt import VideoCoAttDataset, VideoCoAttDatasetNoAtt, VideoCoAttDatasetMultAP 4 | from dataset.videoattentiontarget import VideoAttentionTargetDataset 5 | from dataset.toy import ToyDataset 6 | from dataset.gazefollow import GazeFollowDataset 7 | 8 | def dataset_generator(cfg, mode): 9 | print(f'{cfg.data.name} dataset') 10 | if cfg.data.name == 'volleyball': 11 | data_set = VolleyBallDataset(cfg, mode) 12 | elif cfg.data.name == 'volleyball_wo_att': 13 | data_set = VolleyBallDatasetWithoutAtt(cfg, mode) 14 | elif cfg.data.name == 'videocoatt': 15 | if mode == 'valid': 16 | mode = 'validate' 17 | data_set = VideoCoAttDataset(cfg, mode) 18 | elif cfg.data.name == 'videocoatt_no_att': 19 | if mode == 'valid': 20 | mode = 'validate' 21 | data_set = VideoCoAttDatasetNoAtt(cfg, mode) 22 | elif cfg.data.name == 'videocoatt_mult_att': 23 | if mode == 'valid': 24 | mode = 'validate' 25 | data_set = VideoCoAttDatasetMultAP(cfg, mode) 26 | elif cfg.data.name == 'videoattentiontarget': 27 | if mode == 'valid': 28 | mode = 'test' 29 | data_set = VideoAttentionTargetDataset(cfg, mode) 30 | elif cfg.data.name == 'toy': 31 | 
data_set = ToyDataset(cfg, mode) 32 | elif cfg.data.name == 'gazefollow': 33 | if mode == 'valid': 34 | mode = 'test' 35 | data_set = GazeFollowDataset(cfg, mode) 36 | else: 37 | assert False, 'cfg.data.name is incorrect' 38 | 39 | return data_set -------------------------------------------------------------------------------- /yaml_files/volleyball/eval.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | annotation_dir : data/vatic_ball_annotation/annotation_data_sub/ 6 | dataset_bbox_gt: data/jae_dataset_bbox_gt 7 | dataset_bbox_pred: data/jae_dataset_bbox_pred 8 | 9 | exp_set: 10 | save_folder : saved_weights 11 | 12 | # [ball detection] 13 | # model_name: 2021_0708_lr_e3_gamma_1_stack_3_mid_frame_ver2 14 | 15 | # [ISA] 16 | # model_name: volleyball-isa_bbox_PRED_gaze_PRED_act_PRED 17 | # model_name: volleyball-isa_bbox_GT_gaze_GT_act_GT 18 | 19 | # [DAVT] 20 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion_wo_p_p 21 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion_wo_p_p 22 | 23 | # [Ours] 24 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine_token_only 25 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only 26 | 27 | model_name: volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT 28 | 29 | seed_num : 777 30 | gpu_mode : True 31 | gpu_start : 6 32 | gpu_finish : 6 33 | num_workers : 1 34 | batch_size : 1 35 | wandb_name : test 36 | 37 | mode: test 38 | 39 | exp_params: 40 | use_frame_type: mid 41 | # use_frame_type: all 42 | 43 | # bbox_types: GT 44 | bbox_types: PRED 45 | 46 | # action_types: GT 47 | action_types: PRED 48 | 49 | # gaze_types: GT 50 | gaze_types: PRED 51 | 52 | # use_action: True 53 | # use_position: True 54 | # use_gaze: True 55 | 56 | use_blured_img: False 57 | # use_blured_img: True -------------------------------------------------------------------------------- /analysis/iccv2023/p_s_gen_ablation_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt') 13 | 14 | # define ablate type 15 | analyze_name_ablation_list = [] 16 | analyze_name_ablation_list.append('simple_average') 17 | analyze_name_ablation_list.append('scalar_weight') 18 | analyze_name_ablation_list.append('freeze') 19 | 20 | # define test data type 21 | test_data_type_list = [] 22 | test_data_type_list.append('test_gt_gaze_False_head_conf_0.6') 23 | for test_data_type in test_data_type_list: 24 | print(f'==={test_data_type}===') 25 | for analyze_name in analyze_name_list: 26 | model_name_list = [] 27 | eval_results_list = [] 28 | for ablation_name in analyze_name_ablation_list: 29 | 30 | model_name = f'{analyze_name}_{ablation_name}' 31 | model_name_list.append(model_name) 32 | 33 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 34 | with open(json_file_path, 'r') as f: 35 | eval_results_dic = json.load(f) 36 | 
eval_results_list.append(list(eval_results_dic.values())) 37 | eval_metrics_list = list(eval_results_dic.keys()) 38 | 39 | eval_results_array = np.array(eval_results_list) 40 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 41 | save_csv_file_path = os.path.join(saved_result_dir, f'p_s_gen_ablation_videocoatt_{test_data_type}.csv') 42 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /print_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from models.model_selector import model_generator 5 | from models.davt_scene_extractor import ModelSpatial, ModelSpatialDummy, ModelSpatioTemporal 6 | 7 | 8 | gpus_list = [4] 9 | print("===> Load pretrained model (saliency extractor)") 10 | model_name = 'pretrained_scene_extractor_davt' 11 | weight_name_list = ['model_demo.pt', 'model_gazefollow.pt', 'model_videoatttarget.pt', 'initial_weights_for_temporal_training.pt'] 12 | weight_key_dict = {} 13 | for weight_name in weight_name_list: 14 | model_weight_path = os.path.join('saved_weights', 'volleyball_wo_att', model_name, weight_name) 15 | pretrained_dict = torch.load(model_weight_path, map_location='cuda:'+str(gpus_list[0])) 16 | pretrained_dict = pretrained_dict['model'] 17 | weight_key_len = len(pretrained_dict.keys()) 18 | print(weight_name, weight_key_len) 19 | weight_key_dict[weight_name] = list(pretrained_dict.keys()) 20 | 21 | weight_img = weight_key_dict[weight_name_list[0]] 22 | weight_vid = weight_key_dict[weight_name_list[2]] 23 | weight_diff = sorted(set(weight_vid) - set(weight_img)) 24 | # print(weight_diff) 25 | # for weight_diff_key in weight_diff: 26 | # print(weight_diff_key) 27 | 28 | model_saliency = ModelSpatioTemporal() 29 | model_saliency_dict = model_saliency.state_dict() 30 | model_saliency_dict_key_len = len(model_saliency_dict.keys()) 31 | print("model_saliency_dict_key_len", model_saliency_dict_key_len) 32 | 33 | weight_vid_saliency = model_saliency_dict.keys() 34 | weight_diff = sorted(set(weight_vid_saliency) - set(weight_img)) 35 | # print(weight_diff) 36 | for weight_diff_key in weight_diff: 37 | print(weight_diff_key) 38 | 39 | # model_saliency_dict.update(pretrained_dict) 40 | # model_saliency.load_state_dict(model_saliency_dict) -------------------------------------------------------------------------------- /yaml_files/volleyball/train_hgt.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | annotation_dir : data/vatic_ball_annotation/annotation_data/ 6 | dataset_bbox_gt: data/jae_dataset_bbox_gt 7 | dataset_bbox_pred: data/jae_dataset_bbox_pred 8 | 9 | exp_set: 10 | save_folder: saved_weights 11 | # wandb_name: volleyball-hgtd_bbox_GT_gaze_GT_act_GT 12 | wandb_name: volleyball-hgtd_bbox_PRED_gaze_PRED_act_PRED 13 | wandb_log : True 14 | 15 | batch_size: 2 16 | num_workers: 1 17 | seed_num: 777 18 | gpu_mode : True 19 | gpu_start : 5 20 | gpu_finish : 5 21 | 22 | # resize_height: 320 23 | resize_height: 224 24 | # resize_width: 480 25 | resize_width: 224 26 | resize_head_height: 64 27 | resize_head_width: 64 28 | 29 | exp_params: 30 | 31 | use_frame_type: mid 32 | # use_frame_type: all 33 | 34 | # bbox_types: GT 35 | bbox_types: PRED 36 | 37 | # gaze_types: GT 38 | gaze_types: PRED 39 | 40 | # action_types: GT 41 
| action_types: PRED 42 | 43 | # learning rate 44 | lr : 0.00001 45 | 46 | # position augmentation 47 | use_position_aug: False 48 | # use_position_aug: True 49 | position_aug_std: 0.05 50 | 51 | # gt gaussian 52 | gaussian_sigma: 10 53 | 54 | # learning schedule 55 | nEpochs : 500 56 | start_iter : 0 57 | snapshots : 100 58 | scheduler_start : 1000 59 | scheduler_iter : 1100000 60 | 61 | # pretrained models 62 | pretrained_models_dir: saved_weights 63 | 64 | use_pretrained_joint_attention_estimator: False 65 | # use_pretrained_head_pose_estimator: True 66 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 67 | freeze_joint_attention_estimator: False 68 | # freeze_joint_attention_estimator: True 69 | 70 | model_params: 71 | model_type: human_gaze_target_transformer -------------------------------------------------------------------------------- /yaml_files/videocoatt/debug_isa.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videocoatt 3 | dataset_dir : data/VideoCoAtt_Dataset 4 | saliency_dataset_dir : data/deepgaze_output_loader 5 | 6 | exp_set: 7 | save_folder: saved_weights 8 | wandb_name: debug 9 | wandb_log : False 10 | 11 | batch_size: 16 12 | num_workers: 16 13 | seed_num: 777 14 | gpu_mode : True 15 | gpu_start : 3 16 | gpu_finish : 3 17 | 18 | resize_height: 320 19 | resize_width: 480 20 | resize_head_height: 64 21 | resize_head_width: 64 22 | 23 | exp_params: 24 | # use_e_att_loss : False 25 | use_e_att_loss : True 26 | 27 | use_frame_type: mid 28 | # use_frame_type: all 29 | 30 | use_gt_gaze: False 31 | # use_gt_gaze: True 32 | 33 | # position augmentation 34 | use_position_aug: False 35 | # use_position_aug: True 36 | position_aug_std: 0.05 37 | 38 | # loss function 39 | loss : mse 40 | # loss : bce 41 | 42 | # learning rate 43 | lr : 0.0001 44 | 45 | # gt gaussian 46 | gaussian_sigma: 10 47 | 48 | # learning schedule 49 | nEpochs : 500 50 | start_iter : 0 51 | snapshots : 100 52 | scheduler_start : 1000 53 | scheduler_iter : 1100000 54 | 55 | det_heads_model : det_heads 56 | train_det_heads : False 57 | # train_det_heads : True 58 | train_heads_conf : 0.6 59 | test_heads_conf : 0.6 60 | 61 | # pretrained models 62 | pretrained_models_dir: saved_weights 63 | 64 | # use_pretrained_head_pose_estimator: False 65 | use_pretrained_head_pose_estimator: True 66 | pretrained_head_pose_estimator_name: videocoatt-head_pose_estimator 67 | # freeze_head_pose_estimator: False 68 | freeze_head_pose_estimator: True 69 | 70 | use_pretrained_joint_attention_estimator: False 71 | # use_pretrained_head_pose_estimator: True 72 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 73 | freeze_joint_attention_estimator: False 74 | # freeze_joint_attention_estimator: True 75 | 76 | model_params: 77 | model_type: isa -------------------------------------------------------------------------------- /analysis/iccv2023/p_p_gen_ablation_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('videocoatt-dual-people') 13 | 14 | # define ablate type 15 | analyze_name_ablation_list = [] 16 | analyze_name_ablation_list.append('fc_shallow') 17 | analyze_name_ablation_list.append('fc_middle') 18 | 
analyze_name_ablation_list.append('fc_deep') 19 | analyze_name_ablation_list.append('deconv_shallow') 20 | analyze_name_ablation_list.append('deconv_middle') 21 | analyze_name_ablation_list.append('field_middle') 22 | analyze_name_ablation_list.append('field_deep') 23 | 24 | # define test data type 25 | test_data_type_list = [] 26 | test_data_type_list.append('test_gt_gaze_False_head_conf_0.6') 27 | for test_data_type in test_data_type_list: 28 | print(f'==={test_data_type}===') 29 | for analyze_name in analyze_name_list: 30 | model_name_list = [] 31 | eval_results_list = [] 32 | analyze_name_type = analyze_name.split('_')[0] 33 | for ablation_name in analyze_name_ablation_list: 34 | 35 | model_name = f'{analyze_name_type}_{ablation_name}' 36 | model_name_list.append(model_name) 37 | 38 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 39 | with open(json_file_path, 'r') as f: 40 | eval_results_dic = json.load(f) 41 | eval_results_list.append(list(eval_results_dic.values())) 42 | eval_metrics_list = list(eval_results_dic.keys()) 43 | 44 | eval_results_array = np.array(eval_results_list) 45 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 46 | save_csv_file_path = os.path.join(saved_result_dir, f'p_p_gen_ablation_videocoatt_{test_data_type}.csv') 47 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/fusion_ablation_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s') 13 | 14 | # define ablate type 15 | analyze_name_ablation_list = [] 16 | analyze_name_ablation_list = [] 17 | analyze_name_ablation_list.append('_davt_simple_average') 18 | analyze_name_ablation_list.append('_davt_scalar_weight_fix') 19 | analyze_name_ablation_list.append('_davt_scalar_weight_fine') 20 | # analyze_name_ablation_list.append('_davt_freeze') 21 | 22 | # define model names 23 | model_name_list = [] 24 | model_name_list.append('Mean average') 25 | model_name_list.append('Weighted average (fix)') 26 | model_name_list.append('Weighted average (fine)') 27 | # model_name_list.append('CNN fusion') 28 | 29 | # define test data type 30 | test_data_type_list = [] 31 | test_data_type_list.append('test_gt_gaze_False_head_conf_0.6') 32 | for test_data_type in test_data_type_list: 33 | print(f'==={test_data_type}===') 34 | for analyze_name in analyze_name_list: 35 | eval_results_list = [] 36 | for ablation_name in analyze_name_ablation_list: 37 | 38 | model_name = f'{analyze_name}{ablation_name}' 39 | 40 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 41 | with open(json_file_path, 'r') as f: 42 | eval_results_dic = json.load(f) 43 | eval_results_list.append(list(eval_results_dic.values())) 44 | eval_metrics_list = list(eval_results_dic.keys()) 45 | 46 | eval_results_array = np.array(eval_results_list) 47 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 48 | save_csv_file_path = os.path.join(saved_result_dir, f'fusion_ablation_videocoatt_{test_data_type}.csv') 49 | 
df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/self_attetnion_ablation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED') 13 | 14 | # define ablate type 15 | analyze_name_ablation_list = [] 16 | analyze_name_ablation_list.append('_ind_only_wo_trans') 17 | analyze_name_ablation_list.append('_token_only_wo_trans') 18 | analyze_name_ablation_list.append('_token_only') 19 | 20 | # define model name 21 | model_name_list = [] 22 | model_name_list.append('ind_wo_trans') 23 | model_name_list.append('token_wo_trans') 24 | model_name_list.append('Ours') 25 | 26 | # define test data type 27 | test_data_type_list = [] 28 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 29 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 30 | for test_data_type in test_data_type_list: 31 | print(f'==={test_data_type}===') 32 | for analyze_name in analyze_name_list: 33 | # model_name_list = [] 34 | eval_results_list = [] 35 | analyze_name_type = analyze_name 36 | for ablation_name in analyze_name_ablation_list: 37 | 38 | model_name = f'{analyze_name_type}{ablation_name}' 39 | # model_name_list.append(model_name) 40 | 41 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 42 | 43 | with open(json_file_path, 'r') as f: 44 | eval_results_dic = json.load(f) 45 | eval_results_list.append(list(eval_results_dic.values())) 46 | eval_metrics_list = list(eval_results_dic.keys()) 47 | 48 | eval_results_array = np.array(eval_results_list) 49 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 50 | save_csv_file_path = os.path.join(saved_result_dir, f'self_attention_ablation_{analyze_name}_{test_data_type}.csv') 51 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/analyze_ja_tem_consistency.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | def read_csv_ann(ann_path): 8 | if not os.path.exists(ann_path): 9 | return pd.DataFrame() 10 | df_ann = pd.read_csv(ann_path, header=None, sep=' ') 11 | df_ann = df_ann.iloc[:, 1:5] 12 | 13 | return df_ann 14 | 15 | dataset_dir = os.path.join('data', 'vatic_ball_annotation', 'annotation_data_sub') 16 | program_name = os.path.basename(__file__).split('.')[0] 17 | save_dir = os.path.join('analysis', program_name) 18 | if not os.path.exists(save_dir): 19 | os.makedirs(save_dir) 20 | 21 | center_idx, pad_idx = 20, 5 22 | img_height, img_width = 720, 1280 23 | stop_idx = 20 24 | 25 | for ann_idx, ann_file_name in enumerate(tqdm(os.listdir(dataset_dir))): 26 | vid_num, seq_num = ann_file_name.split('_')[1:3] 27 | df_ann = read_csv_ann(os.path.join(dataset_dir, ann_file_name)) 28 | if df_ann.shape[0] == 0: 29 | continue 30 | 31 | ja_xmid = (df_ann.iloc[:, 0] + df_ann.iloc[:, 2]) / 2 32 | ja_ymid = (df_ann.iloc[:, 1] + df_ann.iloc[:, 3]) / 2 33 | ja_mid = 
pd.concat([ja_xmid, ja_ymid], axis=1) 34 | ja_mid = ja_mid.iloc[center_idx-pad_idx:center_idx+pad_idx, :] 35 | ja_mid_diff = ja_mid.diff().fillna(0) 36 | 37 | save_dir_child = os.path.join(save_dir, f'{vid_num}_{seq_num}') 38 | if not os.path.exists(save_dir_child): 39 | os.makedirs(save_dir_child) 40 | 41 | plt.figure() 42 | for t in range(ja_mid.shape[0]): 43 | t_norm = str(t / (ja_mid.shape[0])) 44 | plt.plot(ja_mid.iloc[t, 0], ja_mid.iloc[t, 1], 'o', color=t_norm) 45 | plt.savefig(os.path.join(save_dir_child, f'move_abs.png')) 46 | plt.close() 47 | 48 | plt.figure() 49 | for t in range(ja_mid.shape[0]): 50 | t_norm = str(t / (ja_mid.shape[0])) 51 | plt.plot(ja_mid_diff.iloc[t, 0], ja_mid_diff.iloc[t, 1], 'o', color=t_norm) 52 | plt.xlim(-50, 50) 53 | plt.ylim(-50, 50) 54 | plt.savefig(os.path.join(save_dir_child, f'move_diff.png')) 55 | plt.close() 56 | 57 | if ann_idx > stop_idx: 58 | break -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | # https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py 5 | def positionalencoding1d(d_model, length): 6 | """ 7 | :param d_model: dimension of the model 8 | :param length: length of positions 9 | :return: length*d_model position matrix 10 | """ 11 | if d_model % 2 != 0: 12 | raise ValueError("Cannot use sin/cos positional encoding with " 13 | "odd dim (got dim={:d})".format(d_model)) 14 | pe = torch.zeros(length, d_model) 15 | position = torch.arange(0, length).unsqueeze(1) 16 | div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) * 17 | -(math.log(10000.0) / d_model))) 18 | pe[:, 0::2] = torch.sin(position.float() * div_term) 19 | pe[:, 1::2] = torch.cos(position.float() * div_term) 20 | 21 | return pe 22 | 23 | def positionalencoding2d(d_model, height, width): 24 | """ 25 | :param d_model: dimension of the model 26 | :param height: height of the positions 27 | :param width: width of the positions 28 | :return: d_model*height*width position matrix 29 | """ 30 | if d_model % 4 != 0: 31 | raise ValueError("Cannot use sin/cos positional encoding with " 32 | "a dimension not divisible by 4 (got dim={:d})".format(d_model)) 33 | pe = torch.zeros(d_model, height, width) 34 | # Each dimension uses half of d_model 35 | d_model = int(d_model / 2) 36 | div_term = torch.exp(torch.arange(0., d_model, 2) * 37 | -(math.log(10000.0) / d_model)) 38 | pos_w = torch.arange(0., width).unsqueeze(1) 39 | pos_h = torch.arange(0., height).unsqueeze(1) 40 | pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 41 | pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1) 42 | pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 43 | pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width) 44 | 45 | return pe -------------------------------------------------------------------------------- /analysis/iccv2023/gt_pred_ablation_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | # analyze_name = 
'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion' 12 | analyze_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion_scalar_weight_fine' 13 | # analyze_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion' 14 | # analyze_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion_scalar_weight_fine' 15 | 16 | # define test data type 17 | test_data_type_list = [] 18 | test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 19 | test_data_type_list.append('bbox_GT_gaze_PRED_act_GT_blur_False') 20 | test_data_type_list.append('bbox_GT_gaze_GT_act_PRED_blur_False') 21 | test_data_type_list.append('bbox_GT_gaze_PRED_act_PRED_blur_False') 22 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 23 | 24 | # define model name 25 | model_name_list = [] 26 | model_name_list.append('Ours (p=GT, g=GT, a=GT)') 27 | model_name_list.append('Ours (p=GT, g=Pr, a=GT)') 28 | model_name_list.append('Ours (p=GT, g=GT, a=Pr)') 29 | model_name_list.append('Ours (p=GT, g=Pr, a=Pr)') 30 | model_name_list.append('Ours (p=Pr, g=Pr, a=Pr)') 31 | 32 | eval_results_list = [] 33 | for test_data_type in test_data_type_list: 34 | print(f'==={test_data_type}===') 35 | json_file_path = os.path.join(saved_result_dir, analyze_name, 'eval_results', test_data_type, 'eval_results.json') 36 | with open(json_file_path, 'r') as f: 37 | eval_results_dic = json.load(f) 38 | eval_results_list.append(list(eval_results_dic.values())) 39 | eval_metrics_list = list(eval_results_dic.keys()) 40 | 41 | eval_results_array = np.array(eval_results_list) 42 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 43 | save_csv_file_path = os.path.join(saved_result_dir, f'gt_pred_ablation_{analyze_name}.csv') 44 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/comparison_ja_trans_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | # analyze_name_list.append('videocoatt-dual-people_field_middle') 13 | # analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fine') 14 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix') 15 | 16 | # define ablate type 17 | analyze_name_ablation_list = [] 18 | analyze_name_ablation_list.append('_ind_only') 19 | # analyze_name_ablation_list.append('_token_only') 20 | analyze_name_ablation_list.append('') 21 | analyze_name_ablation_list.append('_ind_and_token_ind_based') 22 | analyze_name_ablation_list.append('_ind_and_token_token_based') 23 | 24 | # define model names 25 | model_name_list = [] 26 | model_name_list.append('Ind only') 27 | model_name_list.append('Token only') 28 | model_name_list.append('Ind and Token (ind-based)') 29 | model_name_list.append('Ind and Token (token-based)') 30 | 31 | # define test data type 32 | test_data_type_list = [] 33 | test_data_type_list.append('test_gt_gaze_False_head_conf_0.6') 34 | for test_data_type in test_data_type_list: 35 | print(f'==={test_data_type}===') 36 | for analyze_name in analyze_name_list: 37 | eval_results_list = [] 38 | for ablation_name in 
analyze_name_ablation_list: 39 | 40 | model_name = f'{analyze_name}{ablation_name}' 41 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 42 | with open(json_file_path, 'r') as f: 43 | eval_results_dic = json.load(f) 44 | eval_results_list.append(list(eval_results_dic.values())) 45 | eval_metrics_list = list(eval_results_dic.keys()) 46 | 47 | eval_results_array = np.array(eval_results_list) 48 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 49 | save_csv_file_path = os.path.join(saved_result_dir, f'comparison_ja_trans_videocoatt_{test_data_type}.csv') 50 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/comparison_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | sys.path.append('./') 8 | from analysis.utils import refine_excel 9 | 10 | saved_result_dir = os.path.join('results', 'volleyball_wo_att') 11 | 12 | # define ablate type 13 | analyze_name_list = [] 14 | analyze_name_list.append('volleyball-isa_bbox_PRED_gaze_PRED_action_PRED_vid') 15 | analyze_name_list.append('volleyball_PRED_DAVT_only_lr_e3') 16 | analyze_name_list.append('volleyball_PRED_DAVT_only_lr_e3_demo') 17 | analyze_name_list.append('volleyball_PRED_DAVT_only_lr_e3_gazefollow') 18 | analyze_name_list.append('volleyball_PRED_DAVT_only_lr_e3_videoatttarget') 19 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_random25_t_enc_DAVT_scalar_fusion_mod') 20 | 21 | # define model name 22 | model_name_list = [] 23 | model_name_list.append('ISA') 24 | model_name_list.append('DAVT (init)') 25 | model_name_list.append('DAVT (demo)') 26 | model_name_list.append('DAVT (gaze)') 27 | model_name_list.append('DAVT (videoatt)') 28 | model_name_list.append('Ours') 29 | 30 | # define test data type 31 | test_data_type_list = [] 32 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 33 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 34 | for test_data_type in test_data_type_list: 35 | print(f'==={test_data_type}===') 36 | eval_results_list = [] 37 | for analyze_name in analyze_name_list: 38 | model_name = f'{analyze_name}' 39 | 40 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 41 | 42 | with open(json_file_path, 'r') as f: 43 | eval_results_dic = json.load(f) 44 | eval_results_list.append(list(eval_results_dic.values())) 45 | eval_metrics_list = list(eval_results_dic.keys()) 46 | 47 | eval_results_array = np.array(eval_results_list) 48 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 49 | save_excel_file_path = os.path.join(saved_result_dir, f'comparison_on_volleyball_{test_data_type}.xlsx') 50 | df_eval_results.to_excel(save_excel_file_path, sheet_name='all') 51 | refine_excel(save_excel_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/comparison_ja_trans_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | 
analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED') 13 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine') 14 | 15 | # define ablate type 16 | analyze_name_ablation_list = [] 17 | analyze_name_ablation_list.append('_ind_only') 18 | analyze_name_ablation_list.append('_token_only') 19 | analyze_name_ablation_list.append('_ind_and_token_ind_based') 20 | analyze_name_ablation_list.append('_ind_and_token_token_based') 21 | 22 | # define model names 23 | model_name_list = [] 24 | model_name_list.append('Ind only') 25 | model_name_list.append('Token only') 26 | model_name_list.append('Ind and Token (ind-based)') 27 | model_name_list.append('Ind and Token (token-based)') 28 | 29 | # define test data type 30 | test_data_type_list = [] 31 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 32 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 33 | for test_data_type in test_data_type_list: 34 | print(f'==={test_data_type}===') 35 | for analyze_name in analyze_name_list: 36 | eval_results_list = [] 37 | for ablation_name in analyze_name_ablation_list: 38 | 39 | model_name = f'{analyze_name}{ablation_name}' 40 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 41 | with open(json_file_path, 'r') as f: 42 | eval_results_dic = json.load(f) 43 | eval_results_list.append(list(eval_results_dic.values())) 44 | eval_metrics_list = list(eval_results_dic.keys()) 45 | 46 | eval_results_array = np.array(eval_results_list) 47 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 48 | save_csv_file_path = os.path.join(saved_result_dir, f'comparison_ja_trans_volleyball_{test_data_type}.csv') 49 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/gt_pred_ablation_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define analyze model type 11 | analyze_name = 'videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT' 12 | 13 | # define test data type 14 | test_data_type_list = [] 15 | test_data_type_list.append('bbox_gt_gaze_True_thresh_f_score') 16 | test_data_type_list.append('bbox_gt_gaze_False_thresh_f_score') 17 | test_data_type_list.append('bbox_det_gaze_False_thresh_f_score') 18 | 19 | # define model name 20 | model_name_list = [] 21 | model_name_list.append('Ours (p=GT, g=GT)') 22 | model_name_list.append('Ours (p=GT, g=Pr)') 23 | model_name_list.append('Ours (p=Pr, g=Pr)') 24 | 25 | eval_results_list = [] 26 | for test_data_type in test_data_type_list: 27 | print(f'==={test_data_type}===') 28 | json_file_path = os.path.join(saved_result_dir, analyze_name, 'eval_results', test_data_type, 'eval_results.json') 29 | with open(json_file_path, 'r') as f: 30 | eval_results_dic = json.load(f) 31 | 32 | eval_results_dic_update = {} 33 | eval_results_dic_update['Dist final (x)'] = eval_results_dic['l2_dist_x_final'] 34 | eval_results_dic_update['Dist final (y)'] = eval_results_dic['l2_dist_y_final'] 35 | eval_results_dic_update['Dist final (euc)'] = eval_results_dic['l2_dist_euc_final'] 36 | for i in range(20): 37 | thr = i*10 38 | 
eval_results_dic_update[f'Det final (Thr={thr})'] = eval_results_dic[f'Det final (Thr={thr})'] 39 | eval_results_dic_update['Accuracy final'] = eval_results_dic['accuracy final'] 40 | eval_results_dic_update['F-score final'] = eval_results_dic['f1 final'] 41 | eval_results_dic_update['AUC final'] = eval_results_dic['auc final'] 42 | 43 | eval_results_list.append(list(eval_results_dic_update.values())) 44 | eval_metrics_list = list(eval_results_dic_update.keys()) 45 | 46 | eval_results_array = np.array(eval_results_list) 47 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 48 | save_csv_file_path = os.path.join(saved_result_dir, f'gt_pred_ablation_{analyze_name}_videocoatt.csv') 49 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/p_p_agg_ablation_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT') 13 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED') 14 | 15 | # define ablate type 16 | analyze_name_ablation_list = [] 17 | analyze_name_ablation_list.append('_ind_only') 18 | # analyze_name_ablation_list.append('_token_only') 19 | analyze_name_ablation_list.append('') 20 | analyze_name_ablation_list.append('_ind_and_token_ind_based') 21 | # analyze_name_ablation_list.append('_ind_and_token_token_based') 22 | 23 | # define model names 24 | model_name_list = [] 25 | model_name_list.append('Ind only') 26 | model_name_list.append('Token only') 27 | model_name_list.append('Ind and Token (ind-based)') 28 | # model_name_list.append('Ind and Token (token-based)') 29 | 30 | # define test data type 31 | test_data_type_list = [] 32 | test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 33 | # test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 34 | for test_data_type in test_data_type_list: 35 | print(f'==={test_data_type}===') 36 | for analyze_name in analyze_name_list: 37 | eval_results_list = [] 38 | for ablation_name in analyze_name_ablation_list: 39 | 40 | model_name = f'{analyze_name}{ablation_name}' 41 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 42 | with open(json_file_path, 'r') as f: 43 | eval_results_dic = json.load(f) 44 | eval_results_list.append(list(eval_results_dic.values())) 45 | eval_metrics_list = list(eval_results_dic.keys()) 46 | 47 | eval_results_array = np.array(eval_results_list) 48 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 49 | save_csv_file_path = os.path.join(saved_result_dir, f'p_p_agg_ablation_volleyball_{test_data_type}.csv') 50 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/fusion_ablation_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | 
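# a minimal helper sketch of the saved-folder naming rule used in the loop below,
# assuming the suffix convention visible there ('average' runs were saved with the
# *_fusion_fix_* suffix, 'weight'/'cnn' runs with *_fusion_fine_*); the name
# fusion_model_folder and its arguments are illustrative only
def fusion_model_folder(base, ablation):
    # e.g. fusion_model_folder(analyze_name, 'cnn')
    if ablation == 'average':
        return f'{base}_{ablation}_fusion_fix_token_only'
    return f'{base}_{ablation}_fusion_fine_token_only'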
analyze_name_list = [] 12 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT') 13 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED') 14 | 15 | # define ablate type 16 | analyze_name_ablation_list = [] 17 | analyze_name_ablation_list.append('average') 18 | analyze_name_ablation_list.append('weight') 19 | analyze_name_ablation_list.append('cnn') 20 | 21 | # define model names 22 | model_name_list = [] 23 | model_name_list.append('Mean average') 24 | model_name_list.append('Weighted average') 25 | model_name_list.append('CNN fusion') 26 | 27 | # define test data type 28 | test_data_type_list = [] 29 | test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 30 | # test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 31 | for test_data_type in test_data_type_list: 32 | print(f'==={test_data_type}===') 33 | for analyze_name in analyze_name_list: 34 | eval_results_list = [] 35 | for ablation_name in analyze_name_ablation_list: 36 | 37 | if ablation_name == 'average': 38 | model_name = f'{analyze_name}_{ablation_name}_fusion_fix_token_only' 39 | else: 40 | model_name = f'{analyze_name}_{ablation_name}_fusion_fine_token_only' 41 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 42 | with open(json_file_path, 'r') as f: 43 | eval_results_dic = json.load(f) 44 | eval_results_list.append(list(eval_results_dic.values())) 45 | eval_metrics_list = list(eval_results_dic.keys()) 46 | 47 | eval_results_array = np.array(eval_results_list) 48 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 49 | save_csv_file_path = os.path.join(saved_result_dir, f'fusion_ablation_volleyball_{test_data_type}.csv') 50 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/tran_enc_comparision_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion') 13 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion') 14 | 15 | # define ablate type 16 | analyze_name_ablation_list = [] 17 | analyze_name_ablation_list.append('_enc_1') 18 | analyze_name_ablation_list.append('') 19 | analyze_name_ablation_list.append('_enc_3') 20 | analyze_name_ablation_list.append('_enc_4') 21 | 22 | # define model name 23 | model_name_list = [] 24 | model_name_list.append('1') 25 | model_name_list.append('2') 26 | model_name_list.append('3') 27 | model_name_list.append('4') 28 | 29 | # define test data type 30 | test_data_type_list = [] 31 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 32 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 33 | for test_data_type in test_data_type_list: 34 | print(f'==={test_data_type}===') 35 | for analyze_name in analyze_name_list: 36 | # model_name_list = [] 37 | eval_results_list = [] 38 | analyze_name_type = analyze_name 39 | for ablation_name in analyze_name_ablation_list: 40 | 41 | model_name = 
f'{analyze_name_type}{ablation_name}' 42 | # model_name_list.append(model_name) 43 | 44 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 45 | 46 | with open(json_file_path, 'r') as f: 47 | eval_results_dic = json.load(f) 48 | eval_results_list.append(list(eval_results_dic.values())) 49 | eval_metrics_list = list(eval_results_dic.keys()) 50 | 51 | eval_results_array = np.array(eval_results_list) 52 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 53 | save_csv_file_path = os.path.join(saved_result_dir, f'trans_enc_{analyze_name}_{test_data_type}.csv') 54 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /eval_vol.bash: -------------------------------------------------------------------------------- 1 | IFS_BACKUP=$IFS 2 | IFS=$'\n' 3 | 4 | model_array=( 5 | # Please fill in the model path 6 | # 'volleyball_GT_ori_att_vid_token_t_enc' 7 | # 'volleyball_GT_ori_att_vid_token_mask_t_enc' 8 | # 'volleyball_PRED_ori_att_vid_token_t_enc' 9 | # 'volleyball_PRED_ori_att_vid_token_mask_t_enc' 10 | # 'volleyball_PRED_ori_att_vid_wo_token' 11 | # 'volleyball_PRED_ori_att_vid_token_mask_every2_t_enc' 12 | # 'volleyball_PRED_ori_att_vid_token_mask_mid1_t_enc' 13 | # 'volleyball_PRED_ori_att_vid_token_mask_random10_t_enc' 14 | # 'volleyball_PRED_ori_att_vid_token_mask_random25_t_enc' 15 | # 'volleyball_PRED_ori_att_vid_token_mask_random75_t_enc' 16 | # 'volleyball_PRED_ori_att_vid_token_mask_random90_t_enc' 17 | # 'volleyball_PRED_ori_att_vid_token_mask_random25_t_enc_DAVT_scalar_fusion' 18 | # 'volleyball_PRED_ori_att_vid_token_mask_random25_t_enc_DAVT_scalar_fusion_freeze' 19 | # 'volleyball_PRED_DAVT_only' 20 | # 'volleyball_PRED_DAVT_only_lr_e2_modified' 21 | # 'volleyball_PRED_DAVT_only_lr_e3_modified' 22 | # 'volleyball_PRED_DAVT_only_lr_e3_modified_init' 23 | # 'volleyball_PRED_DAVT_only_lr_e3_modified_init_2layer' 24 | # 'volleyball_PRED_DAVT_only_lr_e3_modified_videoatttarget' 25 | # 'volleyball_PRED_ori_att_vid_token_mask_random25_t_enc_DAVT_scalar_fusion_mod' 26 | # 'volleyball_PRED_DAVT_only_lr_e2' 27 | # 'volleyball_PRED_DAVT_only_lr_e3_demo' 28 | # 'volleyball_PRED_DAVT_only_lr_e3_gazefollow' 29 | # 'volleyball_PRED_DAVT_only_lr_e3_videoatttarget' 30 | # 'volleyball_PRED_DAVT_only_lr_e3' 31 | # 'volleyball_PRED_DAVT_only_lr_e4' 32 | # 'volleyball_PRED_ori_att_vid_token_mask_random25_t_enc_DAVT_scalar_fusion_ver2' 33 | # 'volleyball_PRED_ori_att_vid_token_mask_random25_t_enc_DAVT_scalar_fusion_ver2_freeze' 34 | # 'volleyball_PRED_ori_att_vid_token_mask_every2_start_random_t_enc' 35 | # 'volleyball_PRED_ori_att_vid_token_mask_mid3_start_random_t_enc' 36 | # 'volleyball_PRED_ori_att_vid_token_mask_mid4_start_random_t_enc' 37 | 'volleyball_PRED_ori_att_vid_token_mask_random25_t_enc_DAVT_scalar_fusion_cont' 38 | ) 39 | 40 | for model in ${model_array[@]}; do 41 | echo $model 42 | python eval_on_volleyball_ours.py yaml_files/volleyball_wo_att/eval.yaml -m $model 43 | done -------------------------------------------------------------------------------- /analysis/iccv2023/p_s_gen_ablation_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | # 
analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion') 13 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion') 14 | 15 | # define ablate type 16 | analyze_name_ablation_list = [] 17 | analyze_name_ablation_list.append('') 18 | analyze_name_ablation_list.append('_scalar_weight') 19 | analyze_name_ablation_list.append('_scalar_weight_endtoend') 20 | analyze_name_ablation_list.append('_scalar_weight_endtoend_fix') 21 | analyze_name_ablation_list.append('_mid') 22 | 23 | # define model name 24 | model_name_list = [] 25 | model_name_list.append('simple_average') 26 | model_name_list.append('weight') 27 | model_name_list.append('weight_endtoend') 28 | model_name_list.append('weight_endtoend_fix') 29 | model_name_list.append('mid') 30 | 31 | # define test data type 32 | test_data_type_list = [] 33 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 34 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 35 | for test_data_type in test_data_type_list: 36 | print(f'==={test_data_type}===') 37 | for analyze_name in analyze_name_list: 38 | eval_results_list = [] 39 | for ablation_name in analyze_name_ablation_list: 40 | 41 | model_name = f'{analyze_name}{ablation_name}' 42 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 43 | with open(json_file_path, 'r') as f: 44 | eval_results_dic = json.load(f) 45 | eval_results_list.append(list(eval_results_dic.values())) 46 | eval_metrics_list = list(eval_results_dic.keys()) 47 | 48 | eval_results_array = np.array(eval_results_list) 49 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 50 | save_csv_file_path = os.path.join(saved_result_dir, f'p_s_gen_ablation_volleyball_{test_data_type}.csv') 51 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /yaml_files/volleyball/debug_isa.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | annotation_dir : data/vatic_ball_annotation/annotation_data/ 6 | dataset_bbox_gt: data/jae_dataset_bbox_gt 7 | dataset_bbox_pred: data/jae_dataset_bbox_pred 8 | 9 | exp_set: 10 | save_folder: saved_weights 11 | wandb_name: debug 12 | wandb_log : False 13 | 14 | batch_size: 16 15 | num_workers: 16 16 | seed_num: 777 17 | gpu_mode : True 18 | gpu_start : 2 19 | gpu_finish : 2 20 | 21 | resize_height: 320 22 | resize_width: 480 23 | resize_head_height: 64 24 | resize_head_width: 64 25 | 26 | exp_params: 27 | # use_e_att_loss : False 28 | use_e_att_loss : True 29 | 30 | use_frame_type: mid 31 | # use_frame_type: all 32 | 33 | # bbox_types: GT 34 | bbox_types: PRED 35 | 36 | # gaze_types: GT 37 | gaze_types: PRED 38 | 39 | # action_types: GT 40 | action_types: PRED 41 | 42 | # position augmentation 43 | use_position_aug: False 44 | # use_position_aug: True 45 | position_aug_std: 0.05 46 | 47 | # loss function 48 | loss : mse 49 | # loss : bce 50 | 51 | # learning rate 52 | lr : 0.0001 53 | 54 | # gt gaussian 55 | gaussian_sigma: 10 56 | 57 | # learning schedule 58 | nEpochs : 500 59 | start_iter : 0 60 | snapshots : 100 61 | scheduler_start : 1000 62 | scheduler_iter : 1100000 63 | # pretrained models 64 | pretrained_models_dir: saved_weights 65 | 66 | 
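# each pretrained component below follows the same toggle pattern: a use_* switch,
# the saved-weights folder name, and a freeze_* switch; the commented-out lines
# give the alternative setting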
# use_pretrained_head_pose_estimator: False 67 | use_pretrained_head_pose_estimator: True 68 | pretrained_head_pose_estimator_name: volleyball-head_pose_estimator 69 | # freeze_head_pose_estimator: False 70 | freeze_head_pose_estimator: True 71 | 72 | # use_pretrained_saliency_extractor: False 73 | use_pretrained_saliency_extractor: True 74 | pretrained_saliency_extractor_name: 2021_0708_lr_e3_gamma_1_stack_3_mid_frame_ver2 75 | # freeze_saliency_extractor: False 76 | freeze_saliency_extractor: True 77 | 78 | use_pretrained_joint_attention_estimator: False 79 | # use_pretrained_joint_attention_estimator: True 80 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 81 | freeze_joint_attention_estimator: False 82 | # freeze_joint_attention_estimator: True 83 | 84 | model_params: 85 | model_type: isa -------------------------------------------------------------------------------- /yaml_files/videocoatt/train_isa.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videocoatt 3 | dataset_dir : data/VideoCoAtt_Dataset 4 | saliency_dataset_dir : data/deepgaze_output_loader 5 | 6 | exp_set: 7 | save_folder: saved_weights 8 | # wandb_name: videocoatt-isa 9 | # wandb_name: videocoatt-isa-mse-1103 10 | # wandb_name: videocoatt-isa-mse-1103_w_dets_head 11 | # wandb_name: videocoatt-isa-mse-1103_wo_dets_head 12 | 13 | # wandb_name: videocoatt-isa_bbox_GT_gaze_GT 14 | wandb_name: videocoatt-isa_bbox_GT_gaze_GT_ver2 15 | # wandb_name: videocoatt-isa_bbox_PRED_gaze_PRED 16 | 17 | wandb_log : True 18 | 19 | batch_size: 16 20 | num_workers: 16 21 | # seed_num: 777 22 | seed_num: 888 23 | gpu_mode : True 24 | gpu_start : 6 25 | gpu_finish : 6 26 | 27 | # resize_height: 320 28 | resize_height: 28 29 | # resize_width: 480 30 | resize_width: 28 31 | resize_head_height: 64 32 | resize_head_width: 64 33 | 34 | exp_params: 35 | # use_e_att_loss : False 36 | use_e_att_loss : True 37 | 38 | use_frame_type: mid 39 | # use_frame_type: all 40 | 41 | # use_gt_gaze: False 42 | use_gt_gaze: True 43 | 44 | # position augmentation 45 | use_position_aug: False 46 | # use_position_aug: True 47 | position_aug_std: 0.05 48 | 49 | # loss function 50 | loss : mse 51 | # loss : bce 52 | 53 | # learning rate 54 | lr : 0.0001 55 | 56 | # gt gaussian 57 | gaussian_sigma: 10 58 | 59 | # learning schedule 60 | nEpochs : 500 61 | start_iter : 0 62 | snapshots : 100 63 | scheduler_start : 1000 64 | scheduler_iter : 1100000 65 | 66 | det_heads_model : det_heads 67 | train_det_heads : False 68 | # train_det_heads : True 69 | train_heads_conf : 0.6 70 | test_heads_conf : 0.6 71 | 72 | # pretrained models 73 | pretrained_models_dir: saved_weights 74 | 75 | # use_pretrained_head_pose_estimator: False 76 | use_pretrained_head_pose_estimator: True 77 | pretrained_head_pose_estimator_name: videocoatt-head_pose_estimator 78 | # freeze_head_pose_estimator: False 79 | freeze_head_pose_estimator: True 80 | 81 | use_pretrained_joint_attention_estimator: False 82 | # use_pretrained_joint_attention_estimator: True 83 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 84 | freeze_joint_attention_estimator: False 85 | # freeze_joint_attention_estimator: True 86 | 87 | model_params: 88 | model_type: isa -------------------------------------------------------------------------------- /analysis/iccv2023/multi_heads_comparision_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 
import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion') 13 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion') 14 | 15 | # define ablate type 16 | analyze_name_ablation_list = [] 17 | analyze_name_ablation_list.append('_multi_1') 18 | analyze_name_ablation_list.append('') 19 | analyze_name_ablation_list.append('_multi_4') 20 | analyze_name_ablation_list.append('_multi_8') 21 | analyze_name_ablation_list.append('_multi_16') 22 | 23 | # define model name 24 | model_name_list = [] 25 | model_name_list.append('1') 26 | model_name_list.append('2') 27 | model_name_list.append('4') 28 | model_name_list.append('8') 29 | model_name_list.append('16') 30 | 31 | # define test data type 32 | test_data_type_list = [] 33 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 34 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 35 | for test_data_type in test_data_type_list: 36 | print(f'==={test_data_type}===') 37 | for analyze_name in analyze_name_list: 38 | # model_name_list = [] 39 | eval_results_list = [] 40 | analyze_name_type = analyze_name 41 | for ablation_name in analyze_name_ablation_list: 42 | 43 | model_name = f'{analyze_name_type}{ablation_name}' 44 | # model_name_list.append(model_name) 45 | 46 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 47 | 48 | with open(json_file_path, 'r') as f: 49 | eval_results_dic = json.load(f) 50 | eval_results_list.append(list(eval_results_dic.values())) 51 | eval_metrics_list = list(eval_results_dic.keys()) 52 | 53 | eval_results_array = np.array(eval_results_list) 54 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 55 | save_csv_file_path = os.path.join(saved_result_dir, f'multi_heads_{analyze_name}_{test_data_type}.csv') 56 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv_comparison_on_vollleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | saved_result_dir_analyze = os.path.join('results', 'volleyball_wo_att') 8 | 9 | saved_result_dir_list = [] 10 | saved_result_dir_list.append(os.path.join('results', 'volleyball_all')) 11 | saved_result_dir_list.append(os.path.join('results', 'volleyball_all')) 12 | saved_result_dir_list.append(os.path.join('results', 'volleyball_wo_att')) 13 | 14 | analyze_name_list = [] 15 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only') 16 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine_token_only') 17 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT') 18 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED_token_only') 19 | analyze_name_list.append('volleyball_p_p_field_middle_bbox_PRED_ind_128_token_only_w_gaze_loss_img_att_cross') 20 | 21 | model_name_list = [] 22 | model_name_list.append('Ours (ICCV2023:GT)') 23 | 
model_name_list.append('Ours (ICCV2023:PRED)') 24 | model_name_list.append('Ours (PRED)') 25 | 26 | # define test data type 27 | test_data_type_list = [] 28 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 29 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 30 | for test_data_type in test_data_type_list: 31 | print(f'==={test_data_type}===') 32 | eval_results_list = [] 33 | for analyze_idx, analyze_name in enumerate(analyze_name_list): 34 | model_name = f'{analyze_name}' 35 | 36 | json_file_path = os.path.join(saved_result_dir_list[analyze_idx], model_name, 'eval_results', test_data_type, 'eval_results.json') 37 | 38 | with open(json_file_path, 'r') as f: 39 | eval_results_dic = json.load(f) 40 | eval_results_list.append(list(eval_results_dic.values())) 41 | eval_metrics_list = list(eval_results_dic.keys()) 42 | 43 | eval_results_array = np.array(eval_results_list) 44 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 45 | save_excel_file_path = os.path.join(saved_result_dir_analyze, f'iccv_comparison_{test_data_type}.xlsx') 46 | df_eval_results.to_excel(save_excel_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/p_p_gen_ablation_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('volleyball-dual-mid_p_p') 13 | 14 | # define ablate type 15 | analyze_name_ablation_list = [] 16 | # analyze_name_ablation_list.append('fc_shallow') 17 | analyze_name_ablation_list.append('fc_middle') 18 | analyze_name_ablation_list.append('fc_deep') 19 | analyze_name_ablation_list.append('deconv_shallow') 20 | analyze_name_ablation_list.append('deconv_middle') 21 | analyze_name_ablation_list.append('field_middle') 22 | analyze_name_ablation_list.append('field_deep') 23 | 24 | # define model name 25 | model_name_list = [] 26 | # model_name_list.append('fc_shallow') 27 | model_name_list.append('fc_middle') 28 | model_name_list.append('fc_deep') 29 | model_name_list.append('deconv_shallow') 30 | model_name_list.append('deconv_middle') 31 | model_name_list.append('field_middle') 32 | model_name_list.append('field_deep') 33 | 34 | # define test data type 35 | test_data_type_list = [] 36 | test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 37 | # test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 38 | for test_data_type in test_data_type_list: 39 | print(f'==={test_data_type}===') 40 | for analyze_name in analyze_name_list: 41 | eval_results_list = [] 42 | for ablation_name in analyze_name_ablation_list: 43 | 44 | model_name = f'{analyze_name}_{ablation_name}_bbox_GT_gaze_GT_act_GT' 45 | # model_name = f'{analyze_name}_{ablation_name}_bbox_PRED_gaze_PRED_act_PRED' 46 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 47 | with open(json_file_path, 'r') as f: 48 | eval_results_dic = json.load(f) 49 | eval_results_list.append(list(eval_results_dic.values())) 50 | eval_metrics_list = list(eval_results_dic.keys()) 51 | 52 | eval_results_array = np.array(eval_results_list) 53 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 54 | save_csv_file_path = 
os.path.join(saved_result_dir, f'p_p_gen_ablation_volleyball_{test_data_type}.csv') 55 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /yaml_files/volleyball/train_isa.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | annotation_dir : data/vatic_ball_annotation/annotation_data/ 6 | dataset_bbox_gt: data/jae_dataset_bbox_gt 7 | dataset_bbox_pred: data/jae_dataset_bbox_pred 8 | 9 | exp_set: 10 | save_folder: saved_weights 11 | wandb_name: volleyball-isa_bbox_GT_gaze_GT_act_GT 12 | # wandb_name: volleyball-isa_bbox_PRED_gaze_PRED_act_PRED 13 | wandb_log : True 14 | 15 | batch_size: 4 16 | num_workers: 16 17 | seed_num: 777 18 | gpu_mode : True 19 | gpu_start : 2 20 | gpu_finish : 2 21 | 22 | resize_height: 320 23 | resize_width: 480 24 | resize_head_height: 64 25 | resize_head_width: 64 26 | 27 | exp_params: 28 | # use_e_att_loss : False 29 | use_e_att_loss : True 30 | 31 | use_frame_type: mid 32 | # use_frame_type: all 33 | 34 | bbox_types: GT 35 | # bbox_types: PRED 36 | 37 | gaze_types: GT 38 | # gaze_types: PRED 39 | 40 | action_types: GT 41 | # action_types: PRED 42 | 43 | # position augmentation 44 | use_position_aug: False 45 | # use_position_aug: True 46 | position_aug_std: 0.05 47 | 48 | # loss function 49 | loss : mse 50 | # loss : bce 51 | 52 | # learning rate 53 | # lr : 0.0001 54 | lr : 0.001 55 | 56 | # gt gaussian 57 | gaussian_sigma: 40 58 | 59 | # learning schedule 60 | nEpochs : 500 61 | start_iter : 0 62 | snapshots : 100 63 | scheduler_start : 1000 64 | scheduler_iter : 1100000 65 | # pretrained models 66 | pretrained_models_dir: saved_weights 67 | 68 | # use_pretrained_head_pose_estimator: False 69 | use_pretrained_head_pose_estimator: True 70 | pretrained_head_pose_estimator_name: volleyball-head_pose_estimator 71 | # freeze_head_pose_estimator: False 72 | freeze_head_pose_estimator: True 73 | 74 | # use_pretrained_saliency_extractor: False 75 | use_pretrained_saliency_extractor: True 76 | pretrained_saliency_extractor_name: 2021_0708_lr_e3_gamma_1_stack_3_mid_frame_ver2 77 | # freeze_saliency_extractor: False 78 | freeze_saliency_extractor: True 79 | 80 | use_pretrained_joint_attention_estimator: False 81 | # use_pretrained_joint_attention_estimator: True 82 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 83 | freeze_joint_attention_estimator: False 84 | # freeze_joint_attention_estimator: True 85 | 86 | model_params: 87 | model_type: isa -------------------------------------------------------------------------------- /analysis/iccv2023/gaussian_comparison_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED_token_only') 13 | 14 | # define ablate type 15 | analyze_name_ablation_list = [] 16 | analyze_name_ablation_list.append('_var_100') 17 | analyze_name_ablation_list.append('_var_80') 18 | analyze_name_ablation_list.append('_var_60') 19 | analyze_name_ablation_list.append('_var_40') 20 | analyze_name_ablation_list.append('_var_20') 21 | 
analyze_name_ablation_list.append('') 22 | analyze_name_ablation_list.append('_var_5') 23 | analyze_name_ablation_list.append('_var_1') 24 | 25 | # define model name 26 | model_name_list = [] 27 | model_name_list.append('100') 28 | model_name_list.append('80') 29 | model_name_list.append('60') 30 | model_name_list.append('40') 31 | model_name_list.append('20') 32 | model_name_list.append('10') 33 | model_name_list.append('5') 34 | model_name_list.append('1') 35 | 36 | # define test data type 37 | test_data_type_list = [] 38 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 39 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 40 | for test_data_type in test_data_type_list: 41 | print(f'==={test_data_type}===') 42 | for analyze_name in analyze_name_list: 43 | # model_name_list = [] 44 | eval_results_list = [] 45 | analyze_name_type = analyze_name 46 | for ablation_name in analyze_name_ablation_list: 47 | 48 | model_name = f'{analyze_name_type}{ablation_name}' 49 | # model_name_list.append(model_name) 50 | 51 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 52 | 53 | with open(json_file_path, 'r') as f: 54 | eval_results_dic = json.load(f) 55 | eval_results_list.append(list(eval_results_dic.values())) 56 | eval_metrics_list = list(eval_results_dic.keys()) 57 | 58 | eval_results_array = np.array(eval_results_list) 59 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 60 | save_csv_file_path = os.path.join(saved_result_dir, f'gaussian_comparison_{analyze_name}_{test_data_type}.csv') 61 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # additional requirements 132 | wandb 133 | data 134 | saved_weights 135 | results 136 | .vscode 137 | models/detr -------------------------------------------------------------------------------- /yaml_files/volleyball/demo.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | # annotation_dir : data/vatic_ball_annotation/annotation_data/ 6 | annotation_dir : data/vatic_ball_annotation/annotation_data_sub/ 7 | dataset_bbox_gt: data/jae_dataset_bbox_gt 8 | dataset_bbox_pred: data/jae_dataset_bbox_pred 9 | 10 | exp_set: 11 | save_folder : saved_weights 12 | 13 | model_name: volleyball-selection 14 | 15 | # model_name: 2021_0708_lr_e3_gamma_1_stack_3_mid_frame_ver2 16 | # model_name: volleyball-isa_bbox_GT_gaze_GT_act_GT 17 | # model_name: volleyball-isa_bbox_PRED_gaze_PRED_act_PRED 18 | # model_name: volleyball-hgtd_bbox_GT_gaze_GT_act_GT 19 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_p_s_only 20 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_p_s_only 21 | 22 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion 23 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion 24 | 25 | # model_name: volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED 26 | # model_name: volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT 27 | # model_name: volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED_mse 28 | # model_name: volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED_bce 29 | 30 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion_scalar_weight_fine 31 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_cnn_fusion_fine_token_only 32 | 33 | # model_name: volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT 34 | # model_name: volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT_token_only_w_gaze_noise 35 | model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only 36 | 37 | seed_num : 777 38 | gpu_mode : True 39 | gpu_start : 0 40 | gpu_finish : 0 41 | num_workers : 1 42 | batch_size : 1 43 | wandb_name : demo 44 | 45 | # mode: train 46 | mode: test 47 | 48 | exp_params: 49 | # use_frame_type: mid 50 | use_frame_type: all 51 | 52 | bbox_types: GT 53 | # bbox_types: PRED 54 | action_types: GT 55 | # 
action_types: PRED 56 | gaze_types: GT 57 | # gaze_types: PRED 58 | # gaze_types: NONE 59 | 60 | # use_action: True 61 | # use_position: True 62 | # use_gaze: True 63 | 64 | # vis_dist_error: False 65 | vis_dist_error: True -------------------------------------------------------------------------------- /models/head_pose_estimator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | from torch.nn import functional as F 6 | 7 | import sys 8 | 9 | class HeadPoseEstimatorResnet(nn.Module): 10 | def __init__(self, cfg): 11 | super(HeadPoseEstimatorResnet, self).__init__() 12 | 13 | # load resnet 14 | resnet = models.resnet18(pretrained=True) 15 | resnet = nn.Sequential(*list(resnet.children())[:-1]) 16 | # for p in resnet.parameters(): 17 | # p.requires_grad = False 18 | self.feature_extractor = resnet 19 | 20 | self.head_pose_estimator = nn.Sequential( 21 | nn.Linear(512, 64), 22 | nn.ReLU(), 23 | nn.Linear(64, 16), 24 | nn.ReLU(), 25 | nn.Linear(16, 2), 26 | ) 27 | 28 | # define loss function 29 | self.use_gaze_loss = cfg.exp_params.use_gaze_loss 30 | self.loss_func_head_pose = nn.MSELoss(reduction='sum') 31 | 32 | def forward(self, inp): 33 | # unpack input data 34 | head_img = inp['head_img'] 35 | 36 | # head feature extraction 37 | batch_size, frame_num, people_num, channel_num, img_height, img_width = head_img.shape 38 | head_img = head_img.view(batch_size*frame_num*people_num, channel_num, img_height, img_width) 39 | 40 | head_feature = self.feature_extractor(head_img) 41 | head_feature = head_feature.mean(dim=(-2, -1)) 42 | 43 | # head pose estimation 44 | head_vector = self.head_pose_estimator(head_feature) 45 | head_vector = head_vector.view(batch_size, people_num, -1) 46 | head_feature = head_feature.view(batch_size, people_num, -1) 47 | 48 | # normalize head pose 49 | head_vector = F.normalize(head_vector, dim=-1) 50 | 51 | # pack output data 52 | out = {} 53 | out['head_vector'] = head_vector 54 | out['head_img_extract'] = head_feature 55 | 56 | return out 57 | 58 | def calc_loss(self, inp, out): 59 | # unpack data 60 | head_vector_gt = inp['head_vector_gt'] 61 | att_inside_flag = inp['att_inside_flag'] 62 | head_vector = out['head_vector'] 63 | 64 | # define coefficient 65 | if self.use_gaze_loss: 66 | loss_head_coef = 0.01 67 | else: 68 | loss_head_coef = 0 69 | 70 | # calculate loss 71 | head_vector_no_pad = head_vector[:, :, 0:2]*att_inside_flag[:, :, None] 72 | head_vector_gt_no_pad = head_vector_gt[:, :, 0:2]*att_inside_flag[:, :, None] 73 | head_num_sum_no_pad = torch.sum(att_inside_flag) 74 | loss_head = self.loss_func_head_pose(head_vector_no_pad, head_vector_gt_no_pad) 75 | loss_head = loss_head/head_num_sum_no_pad 76 | loss_head = loss_head*loss_head_coef 77 | 78 | # pack data 79 | loss_set = {} 80 | loss_set['loss_head'] = loss_head 81 | 82 | return loss_set -------------------------------------------------------------------------------- /analysis/iccv2023/comparison_finetune_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | saved_result_dir = os.path.join('results', 'volleyball') 8 | 9 | # define training modality type 10 | train_mode_list = [] 11 | train_mode_list.append('GT') 12 | # train_mode_list.append('Pr') 13 | 14 | # define test data type 15 | test_data_type_list = [] 16 | 
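# paired index-by-index with train_mode_list above; the loop below reads
# train_mode_list[data_idx] and test_data_type_list[data_idx] together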
test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 17 | # test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 18 | 19 | # define analyze model type 20 | analyze_name_list_dic = {} 21 | 22 | # (Train:Test, GT:GT) 23 | analyze_name_list = [] 24 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT_wo_action_wo_video_tuned') 25 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT_wo_action_w_video_tuned_lr_0001') 26 | analyze_name_list_dic[0] = analyze_name_list 27 | 28 | # (Train:Test, Pr:Pr) 29 | # analyze_name_list = [] 30 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED_wo_action_video_tuned') 31 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED_wo_action') 32 | # analyze_name_list_dic[1] = analyze_name_list 33 | 34 | # define model names 35 | model_name_list = [] 36 | model_name_list.append('Ours w/o finetune') 37 | model_name_list.append('Ours w/ finetune') 38 | 39 | epoch_sum = 105 40 | epoch_div = 5 41 | for data_idx, analyze_name_list in analyze_name_list_dic.items(): 42 | train_mode = train_mode_list[data_idx] 43 | test_data_type = test_data_type_list[data_idx] 44 | print(f'==={train_mode}:{test_data_type}===') 45 | 46 | model_wo_finetune = analyze_name_list[0] 47 | model_w_finetune = analyze_name_list[1] 48 | l2_dist_array = np.zeros((2, (epoch_sum//epoch_div)-1)) 49 | 50 | epoch_num_list = [epoch_num for epoch_num in range(epoch_div, epoch_sum, epoch_div)] 51 | for epoch_idx, epoch_num in enumerate(epoch_num_list): 52 | json_file_path_wo_finetune = os.path.join(saved_result_dir, model_wo_finetune, 'eval_results', test_data_type, f'epoch_{epoch_num}', 'eval_results.json') 53 | json_file_path_w_finetune = os.path.join(saved_result_dir, model_w_finetune, 'eval_results', test_data_type, f'epoch_{epoch_num}', 'eval_results.json') 54 | with open(json_file_path_wo_finetune, 'r') as f: 55 | eval_results_dic_wo_finetune = json.load(f) 56 | with open(json_file_path_w_finetune, 'r') as f: 57 | eval_results_dic_w_finetune = json.load(f) 58 | 59 | l2_dist_array[0, epoch_idx] = eval_results_dic_wo_finetune['l2_dist_euc_p_p'] 60 | l2_dist_array[1, epoch_idx] = eval_results_dic_w_finetune['l2_dist_euc_p_p'] 61 | 62 | df_eval_results = pd.DataFrame(l2_dist_array, model_name_list, epoch_num_list) 63 | save_csv_file_path = os.path.join(saved_result_dir, f'comparision_finetune_on_volleyball_train_{train_mode}_{test_data_type}.csv') 64 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/p_p_agg_ablation_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('videocoatt-dual-people_field_middle') 13 | 14 | # define ablate type 15 | analyze_name_ablation_list = [] 16 | analyze_name_ablation_list.append('_ind_only') 17 | analyze_name_ablation_list.append('_token_only') 18 | # analyze_name_ablation_list.append('') 19 | analyze_name_ablation_list.append('_ind_and_token_ind_based') 20 | # analyze_name_ablation_list.append('_ind_and_token_token_based') 21 | 22 | # define model names 23 | model_name_list = [] 24 | 
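# display names used as the DataFrame row index below; keep them aligned
# one-to-one with the active (uncommented) ablation suffixes above, otherwise
# pd.DataFrame raises a shape-mismatch error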
model_name_list.append('Ind only') 25 | model_name_list.append('Token only') 26 | model_name_list.append('Ind and Token (ind-based)') 27 | # model_name_list.append('Ind and Token (token-based)') 28 | 29 | # define test data type 30 | test_data_type_list = [] 31 | test_data_type_list.append('bbox_gt_gaze_True') 32 | # test_data_type_list.append('bbox_det_gaze_False') 33 | for test_data_type in test_data_type_list: 34 | print(f'==={test_data_type}===') 35 | for analyze_name in analyze_name_list: 36 | eval_results_list = [] 37 | for ablation_name in analyze_name_ablation_list: 38 | 39 | if test_data_type == 'bbox_gt_gaze_True': 40 | model_name = f'{analyze_name}{ablation_name}_bbox_GT_gaze_GT' 41 | else: 42 | model_name = f'{analyze_name}{ablation_name}' 43 | 44 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 45 | with open(json_file_path, 'r') as f: 46 | eval_results_dic = json.load(f) 47 | 48 | eval_results_dic_update = {} 49 | eval_results_dic_update['Dist final (x)'] = eval_results_dic['l2_dist_x_p_p'] 50 | eval_results_dic_update['Dist final (y)'] = eval_results_dic['l2_dist_y_p_p'] 51 | eval_results_dic_update['Dist final (euc)'] = eval_results_dic['l2_dist_euc_p_p'] 52 | for i in range(20): 53 | thr = i*10 54 | eval_results_dic_update[f'Det final (Thr={thr})'] = eval_results_dic[f'Det p-p (Thr={thr})'] 55 | eval_results_dic_update['Accuracy final'] = eval_results_dic['accuracy final'] 56 | eval_results_dic_update['F-score final'] = eval_results_dic['f1 final'] 57 | eval_results_dic_update['AUC final'] = eval_results_dic['auc final'] 58 | 59 | eval_results_list.append(list(eval_results_dic_update.values())) 60 | eval_metrics_list = list(eval_results_dic_update.keys()) 61 | 62 | eval_results_array = np.array(eval_results_list) 63 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 64 | save_csv_file_path = os.path.join(saved_result_dir, f'p_p_agg_ablation_videocoatt_{test_data_type}.csv') 65 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/blur_ablation_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pandas as pd 4 | import numpy as np 5 | import sys 6 | 7 | saved_result_dir = os.path.join('results', 'volleyball') 8 | 9 | # define analyze model type 10 | # analyze_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only' 11 | # analyze_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine_token_only' 12 | 13 | # define test data type 14 | # test_data_type_list = [] 15 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 16 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_True') 17 | # test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 18 | # test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_True') 19 | 20 | # eval_results_list = [] 21 | # for test_data_type in test_data_type_list: 22 | # print(f'==={test_data_type}===') 23 | # json_file_path = os.path.join(saved_result_dir, analyze_name, 'eval_results', test_data_type, 'eval_results.json') 24 | # with open(json_file_path, 'r') as f: 25 | # eval_results_dic = json.load(f) 26 | # eval_results_list.append(list(eval_results_dic.values())) 27 | # eval_metrics_list = list(eval_results_dic.keys()) 28 | 29 
| # eval_results_array = np.array(eval_results_list) 30 | # df_eval_results = pd.DataFrame(eval_results_array, test_data_type_list, eval_metrics_list) 31 | # save_csv_file_path = os.path.join(saved_result_dir, f'blur_ablation_{analyze_name}.csv') 32 | # df_eval_results.to_csv(save_csv_file_path) 33 | 34 | analyze_dic = {} 35 | analyze_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only' 36 | analyze_dic[analyze_name] = [] 37 | analyze_dic[analyze_name].append('bbox_GT_gaze_GT_act_GT_blur_False') 38 | analyze_dic[analyze_name].append('bbox_GT_gaze_GT_act_GT_blur_True') 39 | 40 | analyze_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine_token_only' 41 | analyze_dic[analyze_name] = [] 42 | analyze_dic[analyze_name].append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 43 | analyze_dic[analyze_name].append('bbox_PRED_gaze_PRED_act_PRED_blur_True') 44 | 45 | for analyze_name, test_data_type_list in analyze_dic.items(): 46 | eval_results_list = [] 47 | for test_data_type in test_data_type_list: 48 | print(f'==={test_data_type}===') 49 | json_file_path = os.path.join(saved_result_dir, analyze_name, 'eval_results', test_data_type, 'eval_results.json') 50 | with open(json_file_path, 'r') as f: 51 | eval_results_dic = json.load(f) 52 | eval_results_list.append(list(eval_results_dic.values())) 53 | eval_metrics_list = list(eval_results_dic.keys()) 54 | 55 | eval_results_array = np.array(eval_results_list) 56 | df_eval_results = pd.DataFrame(eval_results_array, test_data_type_list, eval_metrics_list) 57 | save_csv_file_path = os.path.join(saved_result_dir, f'blur_ablation_{analyze_name}.csv') 58 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/p_p_mask_comparison_on_vollleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | sys.path.append('./') 8 | from analysis.utils import refine_excel 9 | 10 | saved_result_dir = os.path.join('results', 'volleyball_wo_att') 11 | 12 | # define ablate type 13 | analyze_name_list = [] 14 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_t_enc') 15 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_random10_t_enc') 16 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_random25_t_enc') 17 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_t_enc') 18 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_random75_t_enc') 19 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_random90_t_enc') 20 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_every2_t_enc') 21 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_every2_start_random_t_enc') 22 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_mid1_t_enc') 23 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_mid3_start_random_t_enc') 24 | analyze_name_list.append('volleyball_PRED_ori_att_vid_token_mask_mid4_start_random_t_enc') 25 | 26 | # define model name 27 | model_name_list = [] 28 | model_name_list.append('Ours (w/o mask)') 29 | model_name_list.append('Ours (10% mask)') 30 | model_name_list.append('Ours (25% mask)') 31 | model_name_list.append('Ours (50% mask)') 32 | model_name_list.append('Ours (75% mask)') 33 | model_name_list.append('Ours (90% 
mask)') 34 | model_name_list.append('Ours (every 2 frames mask, start first)') 35 | model_name_list.append('Ours (every 2 frames mask, start random)') 36 | model_name_list.append('Ours (continuous 3 frames mask, start center)') 37 | model_name_list.append('Ours (continuous 3 frames mask, start random)') 38 | model_name_list.append('Ours (continuous 4 frames mask, start random)') 39 | 40 | # define test data type 41 | test_data_type_list = [] 42 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 43 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 44 | for test_data_type in test_data_type_list: 45 | print(f'==={test_data_type}===') 46 | eval_results_list = [] 47 | for analyze_name in analyze_name_list: 48 | model_name = f'{analyze_name}' 49 | 50 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 51 | 52 | with open(json_file_path, 'r') as f: 53 | eval_results_dic = json.load(f) 54 | eval_results_list.append(list(eval_results_dic.values())) 55 | eval_metrics_list = list(eval_results_dic.keys()) 56 | 57 | eval_results_array = np.array(eval_results_list) 58 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 59 | save_excel_file_path = os.path.join(saved_result_dir, f'p_p_mask_comparison_on_volleyball_{test_data_type}.xlsx') 60 | df_eval_results.to_excel(save_excel_file_path, sheet_name='all') 61 | refine_excel(save_excel_file_path) -------------------------------------------------------------------------------- /yaml_files/videocoatt/eval.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videocoatt 3 | dataset_dir : data/VideoCoAtt_Dataset 4 | saliency_dataset_dir : data/deepgaze_output_loader 5 | 6 | exp_set: 7 | save_folder : saved_weights 8 | # model_name: videocoatt-head_pose_estimator 9 | 10 | # model_name: videoattentiontarget-hgt 11 | # model_name: videoattentiontarget-hgt-high 12 | # model_name: videoattentiontarget-hgt-1101 13 | # model_name: videocoatt-isa_bbox_GT_gaze_GT 14 | # model_name: videocoatt-isa_bbox_GT_gaze_GT_ver2 15 | # model_name: videocoatt-isa_bbox_PRED_gaze_PRED 16 | 17 | # model_name: videocoatt-p_p_field_deep_p_s_davt_freeze 18 | 19 | # model_name: videocoatt-dual-people_fc_shallow 20 | # model_name: videocoatt-dual-people_fc_middle 21 | # model_name: videocoatt-dual-people_fc_deep 22 | # model_name: videocoatt-dual-people_deconv_shallow 23 | # model_name: videocoatt-dual-people_deconv_middle 24 | # model_name: videocoatt-dual-people_field_middle 25 | # model_name: videocoatt-dual-people_field_deep 26 | 27 | # model_name: videocoatt-p_p_field_deep_p_s_gaze_follow_freeze 28 | # model_name: videocoatt-p_p_field_deep_p_s_cnn_gaze_follow_w_pre_simple_average 29 | # model_name: videocoatt-p_p_field_deep_p_s_davt_simple_average 30 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight 31 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fine 32 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix 33 | 34 | # model_name: videocoatt-dual-people_field_middle_ind_only 35 | # model_name: videocoatt-dual-people_field_middle_token_only 36 | # model_name: videocoatt-dual-people_field_middle_ind_and_token_ind_based 37 | # model_name: videocoatt-dual-people_field_middle_ind_and_token_token_based 38 | 39 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fine_ind_only 40 | # model_name: 
videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fine_token_only 41 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fine_ind_and_token_ind_based 42 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fine_ind_and_token_token_based 43 | 44 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_ind_only 45 | model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only 46 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_ind_and_token_ind_based 47 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_ind_and_token_token_based 48 | 49 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT 50 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT_ver2 51 | # model_name: videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT_ver3 52 | 53 | seed_num : 777 54 | gpu_mode : True 55 | gpu_start : 4 56 | gpu_finish : 4 57 | num_workers : 1 58 | batch_size : 1 59 | wandb_name : test 60 | 61 | mode: test 62 | # mode : validate 63 | # mode : train 64 | 65 | exp_params: 66 | # test_heads_type : det 67 | test_heads_type : gt 68 | det_heads_model : det_heads 69 | test_heads_conf : 0.6 70 | # test_heads_conf : 0.8 71 | 72 | # test_gt_gaze : False 73 | test_gt_gaze : True 74 | 75 | use_frame_type: mid 76 | # use_frame_type: all 77 | -------------------------------------------------------------------------------- /analysis/iccv2023/input_ablation_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'volleyball') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_independ_fusion') 13 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_independ_fusion') 14 | 15 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion') 16 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion') 17 | 18 | # analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only') 19 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine_token_only') 20 | 21 | # define ablate type 22 | analyze_name_ablation_list = [] 23 | analyze_name_ablation_list.append('_wo_position') 24 | analyze_name_ablation_list.append('_wo_gaze') 25 | analyze_name_ablation_list.append('_wo_action') 26 | # analyze_name_ablation_list.append('_wo_gaze_wo_position') 27 | # analyze_name_ablation_list.append('_wo_action_wo_position') 28 | # analyze_name_ablation_list.append('_wo_action_wo_gaze') 29 | analyze_name_ablation_list.append('_wo_p_p') 30 | analyze_name_ablation_list.append('_wo_p_s') 31 | analyze_name_ablation_list.append('') 32 | 33 | # define model name 34 | model_name_list = [] 35 | model_name_list.append('Ours w/o p') 36 | model_name_list.append('Ours w/o g') 37 | model_name_list.append('Ours w/o a') 38 | # model_name_list.append('Ours w/o g and p') 39 | # model_name_list.append('Ours w/o a and p') 40 | # model_name_list.append('Ours w/o a and g') 41 | model_name_list.append('Ours w/o branch (a)') 42 | 
model_name_list.append('Ours w/o branch (b)') 43 | model_name_list.append('Ours') 44 | 45 | # define test data type 46 | test_data_type_list = [] 47 | # test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 48 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 49 | for test_data_type in test_data_type_list: 50 | print(f'==={test_data_type}===') 51 | for analyze_name in analyze_name_list: 52 | # model_name_list = [] 53 | eval_results_list = [] 54 | analyze_name_type = analyze_name 55 | for ablation_name in analyze_name_ablation_list: 56 | 57 | model_name = f'{analyze_name_type}{ablation_name}' 58 | # model_name_list.append(model_name) 59 | 60 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 61 | 62 | with open(json_file_path, 'r') as f: 63 | eval_results_dic = json.load(f) 64 | eval_results_list.append(list(eval_results_dic.values())) 65 | eval_metrics_list = list(eval_results_dic.keys()) 66 | 67 | eval_results_array = np.array(eval_results_list) 68 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 69 | save_csv_file_path = os.path.join(saved_result_dir, f'ablation_{analyze_name}_{test_data_type}.csv') 70 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /count_gt_jumping_videocoatt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | def count_gt_jumping(ja_ann, seq_num): 7 | ja_ann_frames = ja_ann[seq_num].keys() 8 | # print(f'=====Seq:{seq_num} Frames:{len(ja_ann_frames)}=====') 9 | 10 | jumping_count = 0 11 | diminishing_count = 0 12 | for iter, frame_id in enumerate(sorted(ja_ann_frames)): 13 | ja_ids = list(ja_ann[seq_num][frame_id].keys()) 14 | if len(ja_ids) != 1: 15 | continue 16 | 17 | # obtain ja points 18 | ja_bbox = np.array(ja_ann[seq_num][frame_id][ja_ids[0]]) 19 | ja_bbox_x_center = (ja_bbox[0] + ja_bbox[2]) / 2 20 | ja_bbox_y_center = (ja_bbox[1] + ja_bbox[3]) / 2 21 | ja_bbox_center = np.array([ja_bbox_x_center, ja_bbox_y_center]) 22 | 23 | # set initial values 24 | if iter == 0: 25 | ja_bbox_center_prev = ja_bbox_center 26 | frame_id_prev = frame_id 27 | 28 | ja_dist = np.linalg.norm(ja_bbox_center - ja_bbox_center_prev) 29 | ja_dist_thre = ja_dist > 50 30 | frame_continous = (frame_id - frame_id_prev) == 1 31 | if ja_dist_thre and frame_continous: 32 | # print(f'Seq:{seq_num} Frame:{frame_id} {ja_dist}') 33 | jumping_count += 1 34 | elif ja_dist_thre and not frame_continous: 35 | # print(f'Seq:{seq_num} Frame:{frame_id} {ja_dist}') 36 | diminishing_count += 1 37 | 38 | # update previous values 39 | ja_bbox_center_prev = ja_bbox_center 40 | frame_id_prev = frame_id 41 | 42 | return jumping_count, diminishing_count 43 | 44 | # load annotations 45 | dataset_dir = os.path.join('data', 'VideoCoAtt_Dataset', 'annotations') 46 | ja_ann = {} 47 | for data_type in os.listdir(dataset_dir): 48 | data_type_dir = os.path.join(dataset_dir, data_type) 49 | if not data_type in ja_ann.keys(): 50 | ja_ann[data_type] = {} 51 | for file in os.listdir(data_type_dir): 52 | file_path = os.path.join(data_type_dir, file) 53 | seq_num = int(file.split('.')[0]) 54 | with open(file_path, 'r') as f: 55 | data = [x.strip().split() for x in f.readlines()] 56 | ja_ann[data_type][seq_num] = {} 57 | for i in range(len(data)): 58 | ja_id, frame_id = 
int(data[i][0]), int(data[i][1]) 59 | ja_ann[data_type][seq_num].setdefault(frame_id, {}) # keep earlier boxes annotated in the same frame 60 | ja_ann[data_type][seq_num][frame_id][ja_id] = [int(x) for x in data[i][2:6]] 61 | 62 | # count jumping 63 | jumping_count_list = [] 64 | diminishing_count_list = [] 65 | data_type_list = ja_ann.keys() 66 | for data_type in data_type_list: 67 | ja_ann_data_type = ja_ann[data_type] 68 | seq_num_list = ja_ann_data_type.keys() 69 | for seq_num in seq_num_list: 70 | jumping_count, diminishing_count = count_gt_jumping(ja_ann_data_type, seq_num) 71 | print(f'Seq:{seq_num} Jumping Count:{jumping_count} Diminishing Count:{diminishing_count}') 72 | jumping_count_list.append(jumping_count) 73 | diminishing_count_list.append(diminishing_count) 74 | 75 | plt.figure() 76 | plt.hist(jumping_count_list) 77 | plt.xticks(np.arange(0, 20, 1)) 78 | plt.savefig('count_gt_jumping_videocoatt_jumping.png') 79 | 80 | plt.figure() 81 | plt.hist(diminishing_count_list) 82 | plt.xticks(np.arange(0, 20, 1)) 83 | plt.savefig('count_gt_jumping_videocoatt_diminishing.png') -------------------------------------------------------------------------------- /analysis/iccv2023/comparison_on_volleyball.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | saved_result_dir = os.path.join('results', 'volleyball') 8 | 9 | # define training modality type 10 | train_mode_list = [] 11 | train_mode_list.append('GT') 12 | train_mode_list.append('GT') 13 | train_mode_list.append('Pr') 14 | 15 | # define test data type 16 | test_data_type_list = [] 17 | test_data_type_list.append('bbox_GT_gaze_GT_act_GT_blur_False') 18 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 19 | test_data_type_list.append('bbox_PRED_gaze_PRED_act_PRED_blur_False') 20 | 21 | # define analyze model type 22 | analyze_name_list_dic = {} 23 | 24 | # (Train:Test, GT:GT) 25 | analyze_name_list = [] 26 | analyze_name_list.append('2021_0708_lr_e3_gamma_1_stack_3_mid_frame_ver2') 27 | analyze_name_list.append('volleyball-isa_bbox_GT_gaze_GT_act_GT') 28 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion_wo_p_p') 29 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only') 30 | analyze_name_list_dic[0] = analyze_name_list 31 | 32 | # (Train:Test, GT:Pr) 33 | analyze_name_list = [] 34 | analyze_name_list.append('2021_0708_lr_e3_gamma_1_stack_3_mid_frame_ver2') 35 | analyze_name_list.append('volleyball-isa_bbox_GT_gaze_GT_act_GT') 36 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_psfix_fusion_wo_p_p') 37 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only') 38 | analyze_name_list_dic[1] = analyze_name_list 39 | 40 | # (Train:Test, Pr:Pr) 41 | analyze_name_list = [] 42 | analyze_name_list.append('2021_0708_lr_e3_gamma_1_stack_3_mid_frame_ver2') 43 | analyze_name_list.append('volleyball-isa_bbox_PRED_gaze_PRED_act_PRED') 44 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_psfix_fusion_wo_p_p') 45 | analyze_name_list.append('volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine_token_only') 46 | analyze_name_list_dic[2] = analyze_name_list 47 | 48 | # define model names 49 | model_name_list = [] 50 |
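To make the rule in count_gt_jumping above concrete: a ground-truth co-attention box whose center moves more than 50 px between two annotated frames counts as "jumping" when the frames are consecutive, and as "diminishing" when there is a gap (the box disappeared and reappeared elsewhere). A minimal worked example of that decision:

import numpy as np

# Toy example of the jumping/diminishing rule (threshold 50 px, as in the script above).
ja_bbox_center_prev, frame_id_prev = np.array([100.0, 120.0]), 10
ja_bbox_center, frame_id = np.array([180.0, 120.0]), 11  # moved 80 px, consecutive frame

ja_dist = np.linalg.norm(ja_bbox_center - ja_bbox_center_prev)
if ja_dist > 50 and (frame_id - frame_id_prev) == 1:
    print('jumping')      # large move between consecutive frames
elif ja_dist > 50:
    print('diminishing')  # large move across a gap in the annotation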
model_name_list.append('Ball detection') 51 | model_name_list.append('ISA') 52 | model_name_list.append('DAVT') 53 | model_name_list.append('Ours') 54 | 55 | for data_idx, analyze_name_list in analyze_name_list_dic.items(): 56 | train_mode = train_mode_list[data_idx] 57 | test_data_type = test_data_type_list[data_idx] 58 | print(f'==={train_mode}:{test_data_type}===') 59 | eval_results_list = [] 60 | for analyze_name in analyze_name_list: 61 | analyze_name_type = analyze_name 62 | model_name = f'{analyze_name_type}' 63 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 64 | with open(json_file_path, 'r') as f: 65 | eval_results_dic = json.load(f) 66 | eval_results_list.append(list(eval_results_dic.values())) 67 | eval_metrics_list = list(eval_results_dic.keys()) 68 | 69 | eval_results_array = np.array(eval_results_list) 70 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 71 | save_csv_file_path = os.path.join(saved_result_dir, f'comparision_volleyball_train_{train_mode}_{test_data_type}.csv') 72 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /make_gif_from_all_images.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import glob 4 | import sys 5 | import argparse 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | gif_img_vol = [] 10 | gif_img_vid = [] 11 | dataset_list = ['volleyball', 'videocoatt'] 12 | for dataset in dataset_list: 13 | id_list = [] 14 | if dataset == 'volleyball': 15 | model_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only' 16 | id_list.append('4_105655') 17 | id_list.append('5_30480') 18 | # id_list.append('9_19275') 19 | # id_list.append('11_22120') 20 | # id_list.append('14_28045') 21 | # id_list.append('20_25385') 22 | # id_list.append('25_29630') 23 | # id_list.append('29_17050') 24 | id_list.append('34_12470') 25 | elif dataset == 'videocoatt': 26 | model_name = 'videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT' 27 | id_list.append('10') 28 | id_list.append('15') 29 | id_list.append('19') 30 | id_list.append('23') 31 | 32 | for id_txt in id_list: 33 | print(f'{dataset}:{id_txt}') 34 | if dataset == 'volleyball': 35 | vid_num, seq_num = id_txt.split('_') 36 | elif dataset == 'videocoatt': 37 | vid_num = id_txt 38 | 39 | image_type_list = ['final_jo_att_superimposed'] 40 | for image_type in image_type_list: 41 | video_folder = os.path.join('results', dataset, model_name, 'videos') 42 | if os.path.exists(video_folder) is False: 43 | os.makedirs(video_folder) 44 | 45 | image_folder = os.path.join('results', dataset, model_name, image_type) 46 | if dataset == 'volleyball': 47 | fps = 7 48 | images_file_base = os.path.join(image_folder, f'test_{vid_num}_{seq_num}_*_{image_type}.png') 49 | elif dataset == 'videocoatt': 50 | fps = 7 51 | images_file_base = os.path.join(image_folder, f'test_*_{vid_num}_*_{vid_num}_{image_type}.png') 52 | 53 | images = sorted(glob.glob(images_file_base)) 54 | if dataset == 'volleyball': 55 | gif_img_vol += images 56 | elif dataset == 'videocoatt': 57 | gif_img_vid += images 58 | 59 | # pil_img_vol = [Image.open(gif_img).resize((320, 180)) for gif_img in gif_img_vol] 60 | # pil_img_vid = [Image.open(gif_img).resize((320, 180)) for gif_img in gif_img_vid] 61 | # pil_img_all = pil_img_vol + pil_img_vid 62 | # 
pil_img_vol[0].save('results/results_vol.gif', save_all=True, append_images=pil_img_vol) 63 | # pil_img_vid[0].save('results/results_vid.gif', save_all=True, append_images=pil_img_vid) 64 | # pil_img_all[0].save('results/results_all.gif', save_all=True, append_images=pil_img_all) 65 | 66 | pil_img_vol = [Image.open(gif_img).resize((320*2, 180*2)) for gif_img in gif_img_vol] 67 | pil_img_vid = [Image.open(gif_img).resize((320*2, 180*2)) for gif_img in gif_img_vid] 68 | pil_img_all = pil_img_vol + pil_img_vid 69 | pil_img_vol[0].save('results/results_vol_large.gif', save_all=True, append_images=pil_img_vol, loop=0) 70 | pil_img_vid[0].save('results/results_vid_large.gif', save_all=True, append_images=pil_img_vid, loop=0) 71 | pil_img_all[0].save('results/results_all_large.gif', save_all=True, append_images=pil_img_all, loop=0) 72 | 73 | # gif_img_all = gif_img_vol + gif_img_vid 74 | # resize_height, resize_width = 720//4, 1280//4 75 | # video_name = os.path.join('results', f'ICCV2023-PJAE-demo.mp4') 76 | # fmt = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') 77 | # video = cv2.VideoWriter(video_name, fmt, fps, (resize_width,resize_height)) 78 | # fps = 7 79 | # for gif_img in tqdm(gif_img_all): 80 | # frame = cv2.resize(cv2.imread(gif_img), (resize_width,resize_height)) 81 | # video.write(frame) 82 | # cv2.destroyAllWindows() 83 | # video.release() -------------------------------------------------------------------------------- /make_video_from_images.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import glob 4 | import sys 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser(description="parameters for training") 8 | parser.add_argument("dataset", type=str, help="dataset name") 9 | args = parser.parse_args() 10 | dataset = args.dataset 11 | dataset_list = ['volleyball', 'videocoatt'] 12 | for dataset in dataset_list: 13 | id_list = [] 14 | if dataset == 'volleyball': 15 | model_name = 'volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only' 16 | id_list.append('4_105655') 17 | id_list.append('5_30480') 18 | # id_list.append('9_19275') 19 | # id_list.append('11_22120') 20 | # id_list.append('14_28045') 21 | # id_list.append('20_25385') 22 | # id_list.append('25_29630') 23 | # id_list.append('29_17050') 24 | id_list.append('34_12470') 25 | elif dataset == 'videocoatt': 26 | model_name = 'videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT' 27 | id_list.append('10') 28 | id_list.append('15') 29 | id_list.append('19') 30 | id_list.append('23') 31 | 32 | for id_txt in id_list: 33 | print(f'{dataset}:{id_txt}') 34 | if dataset == 'volleyball': 35 | vid_num, seq_num = id_txt.split('_') 36 | elif dataset == 'videocoatt': 37 | vid_num = id_txt 38 | 39 | image_type_list = ['final_jo_att_superimposed'] 40 | for image_type in image_type_list: 41 | video_folder = os.path.join('results', dataset, model_name, 'videos') 42 | if os.path.exists(video_folder) is False: 43 | os.makedirs(video_folder) 44 | 45 | image_folder = os.path.join('results', dataset, model_name, image_type) 46 | if dataset == 'volleyball': 47 | fps = 7 48 | images_file_base = os.path.join(image_folder, f'test_{vid_num}_{seq_num}_*_{image_type}.png') 49 | elif dataset == 'videocoatt': 50 | fps = 7 51 | images_file_base = os.path.join(image_folder, f'test_*_{vid_num}_*_{vid_num}_{image_type}.png') 52 | 53 | images = sorted(glob.glob(images_file_base)) 54 | if dataset == 'volleyball': 55 | images_all = 
[images] 56 | elif dataset == 'videocoatt': 57 | images_all = [] 58 | images_mini = [] 59 | for img_idx, image in enumerate(images): 60 | img_num = int(image.split('/')[-1].split('_')[3]) 61 | 62 | if img_idx != 0: 63 | continue_flag = (prev_img_num+1) == img_num 64 | else: 65 | continue_flag = True 66 | 67 | if continue_flag: 68 | images_mini.append(image) 69 | else: 70 | images_all.append(images_mini) 71 | images_mini = [] 72 | 73 | prev_img_num = img_num 74 | images_all.append(images_mini) 75 | 76 | for vid_idx, images in enumerate(images_all): 77 | frame = cv2.imread(images[0]) 78 | height, width, layers = frame.shape 79 | 80 | if dataset == 'volleyball': 81 | video_name = os.path.join(video_folder, f'test_{vid_num}_{seq_num}_{vid_idx}_{image_type}.avi') 82 | elif dataset == 'videocoatt': 83 | video_name = os.path.join(video_folder, f'test_{vid_num}_{vid_idx}_{image_type}.avi') 84 | 85 | video = cv2.VideoWriter(video_name, 0, fps, (width,height)) 86 | for image in images: 87 | video.write(cv2.imread(image)) 88 | cv2.destroyAllWindows() 89 | video.release() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | appdirs==1.4.3 3 | apturl==0.5.2 4 | asttokens==2.2.1 5 | async-timeout==4.0.2 6 | attrs==19.3.0 7 | Automat==0.8.0 8 | bcrypt==3.1.7 9 | blinker==1.4 10 | Brlapi==0.7.0 11 | cached-property==1.5.1 12 | chardet==3.0.4 13 | Click==7.0 14 | cloudpickle==2.2.1 15 | colorama==0.4.3 16 | command-not-found==0.3 17 | configobj==5.0.6 18 | constantly==15.1.0 19 | contourpy==1.0.7 20 | cryptography==2.8 21 | cupshelpers==1.0 22 | cycler==0.11.0 23 | dbus-python==1.2.16 24 | defer==1.0.6 25 | distlib==0.3.0 26 | distro==1.4.0 27 | distro-info===0.23ubuntu1 28 | dm-tree==0.1.8 29 | docker==4.1.0 30 | docker-compose==1.25.0 31 | docker-pycreds==0.4.0 32 | dockerpty==0.4.1 33 | docopt==0.6.2 34 | duplicity==0.8.12.0 35 | einops==0.4.1 36 | entrypoints==0.3 37 | et-xmlfile==1.1.0 38 | executing==1.2.0 39 | fasteners==0.14.1 40 | filelock==3.0.12 41 | fonttools==4.33.3 42 | future==0.18.2 43 | fvcore==0.1.5.post20221221 44 | galternatives==1.0.6 45 | gitdb==4.0.9 46 | grpcio==1.54.2 47 | gunicorn==20.0.4 48 | gym==0.22.0 49 | gym-notices==0.0.8 50 | httplib2==0.14.0 51 | hyperlink==19.0.0 52 | icecream==2.1.3 53 | idna==2.8 54 | imageio==2.26.0 55 | importlib-metadata==4.13.0 56 | importlib-resources==5.12.0 57 | incremental==16.10.1 58 | iopath==0.1.10 59 | iotop==0.6 60 | joblib==1.2.0 61 | jsonschema==3.2.0 62 | keyring==18.0.1 63 | kiwisolver==1.4.3 64 | language-selector==0.1 65 | launchpadlib==1.10.13 66 | lazr.restfulclient==0.14.2 67 | lazr.uri==1.0.3 68 | lazy_loader==0.1 69 | lightning-utilities==0.9.0 70 | llvmlite==0.39.1 71 | lockfile==0.12.2 72 | louis==3.12.0 73 | lz4==4.3.2 74 | macaroonbakery==1.3.1 75 | Mako==1.1.0 76 | MarkupSafe==1.1.0 77 | matplotlib==3.5.2 78 | monotonic==1.5 79 | more-itertools==4.2.0 80 | msgpack==1.0.5 81 | netifaces==0.10.4 82 | networkx==3.0 83 | numba==0.56.4 84 | numpy==1.22.4 85 | oauthlib==3.1.0 86 | olefile==0.46 87 | opencv-contrib-python==4.6.0.66 88 | opencv-python==3.4.18.65 89 | openpyxl==3.1.1 90 | packaging==21.3 91 | pandas==1.4.2 92 | paramiko==2.6.0 93 | pathtools==0.1.2 94 | PettingZoo==1.12.0 95 | pexpect==4.6.0 96 | Pillow==9.4.0 97 | pkg_resources==0.0.0 98 | portalocker==2.7.0 99 | promise==2.3 100 | protobuf==3.20.1 101 | psutil==5.9.1 102 | pyasn1==0.4.2 103 | pyasn1-modules==0.2.1 
104 | pycairo==1.16.2 105 | pycrypto==2.6.1 106 | pycups==1.9.73 107 | pyglet==2.0.7 108 | Pygments==2.3.1 109 | PyGObject==3.36.0 110 | PyHamcrest==1.9.0 111 | PyJWT==1.7.1 112 | pymacaroons==0.13.0 113 | PyNaCl==1.3.0 114 | pynndescent==0.5.8 115 | pyOpenSSL==19.0.0 116 | pyparsing==3.0.9 117 | pyRFC3339==1.1 118 | pyrsistent==0.15.5 119 | python-apt==2.0.1+ubuntu0.20.4.1 120 | python-dateutil==2.8.2 121 | python-debian==0.1.36+ubuntu1.1 122 | pytorch-pretrained-vit==0.0.7 123 | pytz==2022.1 124 | PyWavelets==1.4.1 125 | pyxdg==0.26 126 | PyYAML==5.3.1 127 | ray==1.8.0 128 | redis==4.5.5 129 | reportlab==3.5.34 130 | requests==2.22.0 131 | requests-unixsocket==0.2.0 132 | roi-align==0.0.2 133 | scikit-image==0.20.0 134 | scikit-learn==1.2.1 135 | scipy==1.9.1 136 | screen-resolution-extra==0.0.0 137 | seaborn==0.11.2 138 | SecretStorage==2.3.1 139 | sentry-sdk==1.5.12 140 | service-identity==18.1.0 141 | setproctitle==1.2.3 142 | shortuuid==1.0.9 143 | simplejson==3.16.0 144 | six==1.14.0 145 | smmap==5.0.0 146 | sos==4.4 147 | ssh-import-id==5.10 148 | SuperSuit==3.2.0 149 | systemd-python==234 150 | tabulate==0.9.0 151 | tensorboardX==2.6 152 | termcolor==2.3.0 153 | texttable==1.6.2 154 | thop==0.1.1.post2209072238 155 | threadpoolctl==3.1.0 156 | tifffile==2023.2.28 157 | timm==0.6.5 158 | torch==1.11.0+cu115 159 | torchaudio==0.11.0+cu115 160 | torchfile==0.1.0 161 | torchmetrics==1.0.0 162 | torchvision==0.12.0+cu115 163 | tqdm==4.64.0 164 | Twisted==18.9.0 165 | typing_extensions==4.2.0 166 | tzdata==2023.3 167 | ubuntu-advantage-tools==8001 168 | ubuntu-drivers-common==0.0.0 169 | ufw==0.36 170 | umap-learn==0.5.3 171 | unattended-upgrades==0.1 172 | urllib3==1.25.8 173 | usb-creator==0.3.7 174 | virtualenv==20.0.17 175 | vit-pytorch==0.35.4 176 | wadllib==1.3.3 177 | wandb==0.12.18 178 | websocket-client==0.53.0 179 | xkit==0.0.0 180 | yacs==0.1.8 181 | zipp==3.15.0 182 | zope.interface==4.7.1 183 | -------------------------------------------------------------------------------- /yaml_files/toy/debug.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: toy 3 | dataset_dir : data/joint_attention_toy 4 | 5 | exp_set: 6 | save_folder: saved_weights 7 | wandb_name: debug 8 | wandb_log : False 9 | 10 | batch_size: 2 11 | num_workers: 1 12 | seed_num: 777 13 | gpu_mode : True 14 | gpu_start : 3 15 | gpu_finish : 3 16 | 17 | resize_height: 320 18 | resize_width: 480 19 | resize_head_height: 64 20 | resize_head_width: 64 21 | 22 | exp_params: 23 | use_gaze_loss : False 24 | # use_gaze_loss : True 25 | 26 | use_e_map_loss : False 27 | # use_e_map_loss : True 28 | 29 | use_e_att_loss : False 30 | # use_e_att_loss : True 31 | 32 | use_each_e_map_loss : False 33 | # use_each_e_map_loss : True 34 | 35 | # use_regression_loss : False 36 | use_regression_loss : True 37 | 38 | use_attraction_loss : False 39 | # use_attraction_loss : True 40 | 41 | use_repulsion_loss : False 42 | # use_repulsion_loss : True 43 | 44 | use_frame_type: mid 45 | # use_frame_type: all 46 | 47 | # use_gt_gaze: False 48 | use_gt_gaze: True 49 | 50 | # loss function 51 | loss : mse 52 | # loss : bce 53 | 54 | # learning rate 55 | lr : 0.001 56 | 57 | # gt gaussian 58 | gaussian_sigma: 10 59 | 60 | # learning schedule 61 | nEpochs : 500 62 | start_iter : 0 63 | snapshots : 100 64 | scheduler_start : 1000 65 | scheduler_iter : 1100000 66 | 67 | # pretrained models 68 | pretrained_models_dir: saved_weights 69 | 70 | use_pretrained_head_pose_estimator: False 71 | # 
use_pretrained_head_pose_estimator: True 72 | pretrained_head_pose_estimator_name: none 73 | freeze_head_pose_estimator: False 74 | # freeze_head_pose_estimator: True 75 | 76 | use_pretrained_joint_attention_estimator: False 77 | # use_pretrained_head_pose_estimator: True 78 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 79 | freeze_joint_attention_estimator: False 80 | # freeze_joint_attention_estimator: True 81 | 82 | model_params: 83 | model_type: ja_transformer 84 | 85 | # Position 86 | # use_position : False 87 | use_position : True 88 | # use_position_enc_person : False 89 | use_position_enc_person : True 90 | use_position_enc_type : sine 91 | # use_position_enc_type : learnable 92 | 93 | # Gaze 94 | # use_gaze : False 95 | use_gaze : True 96 | gaze_type: vector 97 | # gaze_type: feature 98 | 99 | # Action 100 | use_action : False 101 | # use_action : True 102 | 103 | # Whole image 104 | # use_img : False 105 | use_img : True 106 | 107 | # Gaze map 108 | # use_angle_dist_rgb_type : none 109 | # use_angle_dist_rgb_type : raw 110 | use_angle_dist_rgb_type : feat 111 | 112 | use_dynamic_angle : False 113 | # use_dynamic_angle : True 114 | 115 | # use_dynamic_distance : False 116 | use_dynamic_distance : True 117 | 118 | dynamic_distance_type: gaussian 119 | dynamic_gaussian_num : 1 120 | # dynamic_distance_type: generator 121 | use_gauss_limit : False 122 | # # use_gauss_limit : True 123 | 124 | gaze_map_estimator_type : identity 125 | # gaze_map_estimator_type : deep 126 | # gaze_map_estimator_type : normal 127 | 128 | # transformer 129 | use_people_people_trans: False 130 | # use_people_people_trans: True 131 | 132 | # rgb_people_trans_type : concat_direct 133 | # rgb_people_trans_type : concat_paralell 134 | rgb_people_trans_type : concat_independent 135 | 136 | people_people_trans_enc_num : 2 137 | mha_num_heads_people_people : 2 138 | 139 | rgb_people_trans_enc_num : 4 140 | mha_num_heads_rgb_people : 4 141 | rgb_embeding_dim : 64 142 | people_feat_dim : 16 143 | 144 | # rgb_cnn_extractor_type : normal 145 | # rgb_cnn_extractor_type : patch 146 | # rgb_cnn_extractor_type : no_use 147 | rgb_cnn_extractor_type : resnet18 148 | # rgb_cnn_extractor_type : resnet50 149 | # rgb_cnn_extractor_stage_idx : 1 150 | # rgb_cnn_extractor_stage_idx : 2 151 | rgb_cnn_extractor_stage_idx : 3 152 | # rgb_cnn_extractor_stage_idx : 4 153 | # rgb_cnn_extractor_type : hrnet_w18_small 154 | # rgb_cnn_extractor_type : hrnet_w32 155 | # rgb_cnn_extractor_stage_idx : 3 156 | # rgb_cnn_extractor_type : convnext 157 | # rgb_cnn_extractor_stage_idx : 2 158 | 159 | # angle_distance_fusion: max 160 | # angle_distance_fusion: mean 161 | angle_distance_fusion: mult -------------------------------------------------------------------------------- /yaml_files/toy/train.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: toy 3 | dataset_dir : data/joint_attention_toy 4 | 5 | exp_set: 6 | save_folder: saved_weights 7 | wandb_log : True 8 | # wandb_name: toy-concat_independent 9 | wandb_name: toy-concat_independent_angle_mask_feat 10 | # wandb_name: toy-concat_direct 11 | 12 | batch_size: 64 13 | num_workers: 16 14 | seed_num: 777 15 | gpu_mode : True 16 | gpu_start : 4 17 | gpu_finish : 4 18 | 19 | resize_height: 320 20 | resize_width: 480 21 | resize_head_height: 64 22 | resize_head_width: 64 23 | 24 | exp_params: 25 | use_gaze_loss : False 26 | # use_gaze_loss : True 27 | 28 | use_e_map_loss : False 29 | # use_e_map_loss : True 30 | 31 | 
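The use_*_loss booleans in this block toggle individual loss terms; how the enabled terms are combined is decided in the training loop, which is not shown here. A plausible minimal sketch of such a combination, under the assumption of unit weights and the configured loss : mse (the function and argument names are hypothetical, not the repository's trainer):

import torch
import torch.nn.functional as F

def total_loss(pred_map, gt_map, pred_point, gt_point, use_e_map_loss=False, use_regression_loss=True):
    # Sum only the loss terms whose use_* flag is enabled (unit weights assumed).
    loss = torch.zeros(())
    if use_e_map_loss:
        loss = loss + F.mse_loss(pred_map, gt_map)      # heatmap term (loss : mse)
    if use_regression_loss:
        loss = loss + F.mse_loss(pred_point, gt_point)  # 2D point regression term
    return loss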
use_e_att_loss : False 32 | # use_e_att_loss : True 33 | 34 | use_each_e_map_loss : False 35 | # use_each_e_map_loss : True 36 | 37 | # use_regression_loss : False 38 | use_regression_loss : True 39 | 40 | use_attraction_loss : False 41 | # use_attraction_loss : True 42 | 43 | use_repulsion_loss : False 44 | # use_repulsion_loss : True 45 | 46 | # use_frame_type: mid 47 | use_frame_type: all 48 | 49 | # use_gt_gaze: False 50 | use_gt_gaze: True 51 | 52 | # loss function 53 | loss : mse 54 | # loss : bce 55 | 56 | # learning rate 57 | lr : 0.001 58 | 59 | # gt gaussian 60 | gaussian_sigma: 10 61 | 62 | # learning schedule 63 | nEpochs : 500 64 | start_iter : 0 65 | snapshots : 100 66 | scheduler_start : 1000 67 | scheduler_iter : 1100000 68 | 69 | # pretrained models 70 | pretrained_models_dir: saved_weights 71 | 72 | use_pretrained_head_pose_estimator: False 73 | # use_pretrained_head_pose_estimator: True 74 | pretrained_head_pose_estimator_name: none 75 | freeze_head_pose_estimator: False 76 | # freeze_head_pose_estimator: True 77 | 78 | use_pretrained_joint_attention_estimator: False 79 | # use_pretrained_head_pose_estimator: True 80 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 81 | freeze_joint_attention_estimator: False 82 | # freeze_joint_attention_estimator: True 83 | 84 | model_params: 85 | model_type: ja_transformer 86 | 87 | # Position 88 | # use_position : False 89 | use_position : True 90 | # use_position_enc_person : False 91 | use_position_enc_person : True 92 | use_position_enc_type : sine 93 | # use_position_enc_type : learnable 94 | 95 | # Gaze 96 | # use_gaze : False 97 | use_gaze : True 98 | gaze_type: vector 99 | # gaze_type: feature 100 | 101 | # Action 102 | use_action : False 103 | # use_action : True 104 | 105 | # Whole image 106 | # use_img : False 107 | use_img : True 108 | 109 | # Gaze map 110 | # use_angle_dist_rgb_type : none 111 | # use_angle_dist_rgb_type : raw 112 | use_angle_dist_rgb_type : feat 113 | 114 | use_dynamic_angle : False 115 | # use_dynamic_angle : True 116 | 117 | # use_dynamic_distance : False 118 | use_dynamic_distance : True 119 | 120 | dynamic_distance_type: gaussian 121 | dynamic_gaussian_num : 1 122 | # dynamic_distance_type: generator 123 | use_gauss_limit : False 124 | # # use_gauss_limit : True 125 | 126 | gaze_map_estimator_type : identity 127 | # gaze_map_estimator_type : deep 128 | # gaze_map_estimator_type : normal 129 | 130 | # transformer 131 | use_people_people_trans: False 132 | # use_people_people_trans: True 133 | 134 | # rgb_people_trans_type : concat_direct 135 | # rgb_people_trans_type : concat_paralell 136 | rgb_people_trans_type : concat_independent 137 | 138 | people_people_trans_enc_num : 2 139 | mha_num_heads_people_people : 2 140 | 141 | rgb_people_trans_enc_num : 4 142 | mha_num_heads_rgb_people : 4 143 | rgb_embeding_dim : 64 144 | people_feat_dim : 16 145 | 146 | # rgb_cnn_extractor_type : normal 147 | # rgb_cnn_extractor_type : patch 148 | # rgb_cnn_extractor_type : no_use 149 | rgb_cnn_extractor_type : resnet18 150 | # rgb_cnn_extractor_type : resnet50 151 | # rgb_cnn_extractor_stage_idx : 1 152 | # rgb_cnn_extractor_stage_idx : 2 153 | rgb_cnn_extractor_stage_idx : 3 154 | # rgb_cnn_extractor_stage_idx : 4 155 | # rgb_cnn_extractor_type : hrnet_w18_small 156 | # rgb_cnn_extractor_type : hrnet_w32 157 | # rgb_cnn_extractor_stage_idx : 3 158 | # rgb_cnn_extractor_type : convnext 159 | # rgb_cnn_extractor_stage_idx : 2 160 | 161 | # angle_distance_fusion: max 162 | # angle_distance_fusion: 
mean 163 | angle_distance_fusion: mult -------------------------------------------------------------------------------- /models/model_selector.py: -------------------------------------------------------------------------------- 1 | from models.head_pose_estimator import HeadPoseEstimatorResnet 2 | from models.joint_attention_estimator_transformer import JointAttentionEstimatorTransformer 3 | from models.joint_attention_estimator_transformer_dual import JointAttentionEstimatorTransformerDual 4 | from models.joint_attention_estimator_transformer_dual_only_people import JointAttentionEstimatorTransformerDualOnlyPeople 5 | from models.joint_attention_estimator_transformer_dual_img_feat import JointAttentionEstimatorTransformerDualImgFeat 6 | from models.joint_attention_estimator_transformer_dual_img_feat_only_people import JointAttentionEstimatorTransformerDualOnlyPeopleImgFeat 7 | from models.joint_attention_fusion import JointAttentionFusion, JointAttentionFusionDummy 8 | from models.inferring_shared_attention_estimation import InferringSharedAttentionEstimator 9 | from models.end_to_end_human_gaze_target import EndToEndHumanGazeTargetTransformer 10 | from models.davt_scene_extractor import ModelSpatial, ModelSpatialDummy, ModelSpatioTemporal 11 | from models.transformer_scene_extractor import SceneFeatureTransformer 12 | from models.cnn_scene_extractor import SceneFeatureCNN 13 | from models.hourglass import HourglassNet 14 | import sys 15 | 16 | def model_generator(cfg): 17 | if cfg.model_params.model_type == 'ja_transformer': 18 | model_head = HeadPoseEstimatorResnet(cfg) 19 | model_gaussian = JointAttentionEstimatorTransformer(cfg) 20 | model_saliency = ModelSpatial() 21 | elif cfg.model_params.model_type == 'ja_transformer_dual': 22 | model_head = HeadPoseEstimatorResnet(cfg) 23 | model_gaussian = JointAttentionEstimatorTransformerDual(cfg) 24 | if cfg.model_params.p_s_estimator_type == 'davt': 25 | model_saliency = ModelSpatial() 26 | elif cfg.model_params.p_s_estimator_type == 'cnn': 27 | model_saliency = SceneFeatureCNN(cfg) 28 | elif cfg.model_params.p_s_estimator_type == 'transformer': 29 | model_saliency = SceneFeatureTransformer(cfg) 30 | model_fusion = JointAttentionFusion(cfg) 31 | elif cfg.model_params.model_type == 'ja_transformer_dual_only_people': 32 | model_head = HeadPoseEstimatorResnet(cfg) 33 | model_gaussian = JointAttentionEstimatorTransformerDualOnlyPeople(cfg) 34 | model_saliency = ModelSpatialDummy() 35 | model_fusion = JointAttentionFusionDummy() 36 | elif cfg.model_params.model_type == 'ja_transformer_dual_img_feat': 37 | model_head = HeadPoseEstimatorResnet(cfg) 38 | model_gaussian = JointAttentionEstimatorTransformerDualImgFeat(cfg) 39 | if cfg.model_params.p_s_estimator_type == 'davt': 40 | if cfg.exp_params.use_frame_type == 'mid': 41 | model_saliency = ModelSpatial() 42 | else: 43 | model_saliency = ModelSpatioTemporal(num_lstm_layers = 2) 44 | elif cfg.model_params.p_s_estimator_type == 'cnn': 45 | model_saliency = SceneFeatureCNN(cfg) 46 | elif cfg.model_params.p_s_estimator_type == 'transformer': 47 | model_saliency = SceneFeatureTransformer(cfg) 48 | model_fusion = JointAttentionFusion(cfg) 49 | elif cfg.model_params.model_type == 'ja_transformer_dual_only_people_img_feat': 50 | model_head = HeadPoseEstimatorResnet(cfg) 51 | model_gaussian = JointAttentionEstimatorTransformerDualOnlyPeopleImgFeat(cfg) 52 | model_saliency = ModelSpatialDummy() 53 | model_fusion = JointAttentionFusionDummy() 54 | elif cfg.model_params.model_type == 'isa': 55 | model_head = 
HeadPoseEstimatorResnet(cfg) 56 | model_gaussian = InferringSharedAttentionEstimator(cfg) 57 | if 'volleyball' in cfg.data.name: 58 | model_saliency = HourglassNet(3, 3, 5) 59 | else: 60 | model_saliency = ModelSpatialDummy(cfg) 61 | model_fusion = JointAttentionFusionDummy() 62 | elif cfg.model_params.model_type == 'human_gaze_target_transformer': 63 | model_head = HeadPoseEstimatorResnet(cfg) 64 | model_gaussian = EndToEndHumanGazeTargetTransformer(cfg) 65 | model_saliency = ModelSpatialDummy(cfg) 66 | model_fusion = JointAttentionFusionDummy() 67 | elif cfg.model_params.model_type == 'ball_detection': 68 | model_head = None 69 | model_gaussian = None 70 | model_saliency = HourglassNet(3, 3, 5) 71 | model_fusion = JointAttentionFusionDummy() 72 | else: 73 | assert False, 'cfg.model_params.model_type is incorrect' 74 | 75 | return model_head, model_gaussian, model_saliency, model_fusion, cfg -------------------------------------------------------------------------------- /analysis/iccv2023/comparison_finetune_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | import glob 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define training modality type 11 | train_mode_list = [] 12 | train_mode_list.append('GT') 13 | # train_mode_list.append('Pr') 14 | 15 | # define test data type 16 | test_data_type_list = [] 17 | test_data_type_list.append('bbox_gt_gaze_True_thresh_f_score') 18 | # test_data_type_list.append('bbox_det_gaze_False_thresh_f_score') 19 | 20 | # define analyze model names 21 | analyze_name_list_dic = {} 22 | 23 | # (Train:Test = GT:GT) 24 | analyze_name_list = [] 25 | # analyze_name_list.append('videocoatt-dual-people_field_middle_token_only_bbox_GT_gaze_GT_wo_action_wo_volley_tuned') 26 | # analyze_name_list.append('videocoatt-dual-people_field_middle_token_only_bbox_GT_gaze_GT_wo_action_w_volley_tuned_lr_0001') 27 | # ================================================================================================================================ 28 | analyze_name_list.append('videocoatt-dual-people_field_middle_token_only_bbox_GT_gaze_GT_wo_action_wo_volley_tuned_lr_000001') 29 | analyze_name_list.append('videocoatt-dual-people_field_middle_token_only_bbox_GT_gaze_GT_wo_action_w_volley_tuned') 30 | analyze_name_list_dic[0] = analyze_name_list 31 | 32 | # (Train:Test = Pr:Pr) 33 | # analyze_name_list = [] 34 | # analyze_name_list.append('videocoatt-dual-people_field_middle_token_only') 35 | # analyze_name_list.append('videocoatt-dual-people_field_middle_token_only_bbox_PRED_gaze_PRED_wo_action_volley_tuned') 36 | # analyze_name_list_dic[1] = analyze_name_list 37 | 38 | # define model names 39 | model_name_list = [] 40 | model_name_list.append('Ours w/o finetune') 41 | model_name_list.append('Ours w/ finetune') 42 | 43 | # epoch_sum = 105 44 | epoch_sum = 55 45 | epoch_div = 5 46 | for data_type_idx, analyze_name_list in analyze_name_list_dic.items(): 47 | test_data_type = test_data_type_list[data_type_idx] 48 | print(f'==={test_data_type}===') 49 | for analyze_idx, analyze_name in enumerate(analyze_name_list): 50 | model_wo_finetune = analyze_name_list[0] 51 | model_w_finetune = analyze_name_list[1] 52 | 53 | l2_dist_array = np.zeros((2, (epoch_sum//epoch_div)-1)) 54 | 55 | epoch_num_list = [epoch_num for epoch_num in range(epoch_div, epoch_sum, epoch_div)] 56 | for epoch_idx, epoch_num in
enumerate(epoch_num_list): 57 | json_file_path_wo_finetune = os.path.join(saved_result_dir, model_wo_finetune, 'eval_results', test_data_type, f'epoch_{epoch_num}', 'eval_results.json') 58 | json_file_path_w_finetune = os.path.join(saved_result_dir, model_w_finetune, 'eval_results', test_data_type, f'epoch_{epoch_num}', 'eval_results.json') 59 | with open(json_file_path_wo_finetune, 'r') as f: 60 | eval_results_dic_wo_finetune = json.load(f) 61 | with open(json_file_path_w_finetune, 'r') as f: 62 | eval_results_dic_w_finetune = json.load(f) 63 | 64 | l2_dist_array[0, epoch_idx] = eval_results_dic_wo_finetune['l2_dist_euc_p_p'] 65 | l2_dist_array[1, epoch_idx] = eval_results_dic_w_finetune['l2_dist_euc_p_p'] 66 | 67 | # eval_results_dic_update = {} 68 | # eval_results_dic_update['Dist(x)'] = eval_results_dic['l2_dist_x_final'] 69 | # eval_results_dic_update['Dist(y)'] = eval_results_dic['l2_dist_y_final'] 70 | # eval_results_dic_update['Dist(euc)'] = eval_results_dic['l2_dist_euc_final'] 71 | # for i in range(20): 72 | # thr = i*10 73 | # eval_results_dic_update[f'Det(Thr={thr})'] = eval_results_dic[f'Det final (Thr={thr})'] 74 | # eval_results_dic_update['Accuracy'] = eval_results_dic['accuracy final'] 75 | # eval_results_dic_update['Precision'] = eval_results_dic['precision final'] 76 | # eval_results_dic_update['Recall'] = eval_results_dic['recall final'] 77 | # eval_results_dic_update['F-score'] = eval_results_dic['f1 final'] 78 | # eval_results_dic_update['AUC'] = eval_results_dic['auc final'] 79 | 80 | # eval_results_list.append(list(eval_results_dic_update.values())) 81 | # eval_metrics_list = list(eval_results_dic_update.keys()) 82 | 83 | df_eval_results = pd.DataFrame(l2_dist_array, model_name_list, epoch_num_list) 84 | save_csv_file_path = os.path.join(saved_result_dir, f'comparision_finetune_on_videocoatt_{train_mode_list[data_type_idx]}_{test_data_type}.csv') 85 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /yaml_files/gazefollow/debug_ours.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: gazefollow 3 | dataset_dir : data/gazefollow 4 | 5 | exp_set: 6 | save_folder: saved_weights 7 | wandb_log : False 8 | 9 | wandb_name: debug 10 | 11 | batch_size: 32 12 | # batch_size: 8 13 | # batch_size: 2 14 | num_workers: 16 15 | seed_num: 777 16 | gpu_mode : True 17 | gpu_start : 3 18 | gpu_finish : 3 19 | 20 | resize_height: 320 21 | resize_width: 480 22 | resize_head_height: 224 23 | resize_head_width: 224 24 | 25 | exp_params: 26 | # use_gaze_loss : False 27 | use_gaze_loss : True 28 | 29 | use_person_person_att_loss : False 30 | # use_person_person_att_loss : True 31 | person_person_att_loss_weight : 1 32 | 33 | use_person_person_jo_att_loss : False 34 | # use_person_person_jo_att_loss : True 35 | person_person_jo_att_loss_weight : 1 36 | 37 | # use_person_scene_att_loss : False 38 | use_person_scene_att_loss : True 39 | person_scene_att_loss_weight : 1 40 | 41 | use_person_scene_jo_att_loss : False 42 | # use_person_scene_jo_att_loss : True 43 | person_scene_jo_att_loss_weight : 1 44 | 45 | use_final_jo_att_loss : False 46 | # use_final_jo_att_loss : True 47 | final_jo_att_loss_weight : 1 48 | 49 | use_frame_type: mid 50 | # use_frame_type: all 51 | 52 | use_gt_gaze: False 53 | # use_gt_gaze: True 54 | 55 | # position augmentation 56 | use_position_aug: False 57 | # use_position_aug: True 58 | position_aug_std: 0.05 59 | 60 | # loss function 61 | 
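The finetune-comparison script above stores one p-p L2 distance per evaluation epoch for each of the two models. A hedged sketch of plotting those curves from the CSV it writes (the path mirrors the to_csv call above, including its spelling; the plot styling is arbitrary):

import pandas as pd
import matplotlib.pyplot as plt

# Rows: 'Ours w/o finetune' / 'Ours w/ finetune'; columns: epoch numbers (5, 10, ...).
csv_path = 'results/videocoatt/comparision_finetune_on_videocoatt_GT_bbox_gt_gaze_True_thresh_f_score.csv'
df = pd.read_csv(csv_path, index_col=0)
plt.figure()
for model_name, row in df.iterrows():
    plt.plot([int(epoch) for epoch in df.columns], row.values, label=model_name)
plt.xlabel('epoch')
plt.ylabel('L2 distance (p-p branch)')
plt.legend()
plt.savefig('comparison_finetune_l2_dist.png')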
loss : mse 62 | # loss : bce 63 | 64 | # learning rate 65 | lr : 0.0001 66 | 67 | # gt gaussian 68 | gaussian_sigma: 20 69 | 70 | # learning schedule 71 | nEpochs : 500 72 | start_iter : 0 73 | snapshots : 100 74 | scheduler_start : 1000 75 | scheduler_iter : 1100000 76 | 77 | # pretrained models 78 | pretrained_models_dir: saved_weights 79 | 80 | # use_pretrained_head_pose_estimator: False 81 | use_pretrained_head_pose_estimator: True 82 | # pretrained_head_pose_estimator_name: videocoatt-head_pose_estimator 83 | pretrained_head_pose_estimator_name: gazefollow-dual-cnn-w_pre 84 | # freeze_head_pose_estimator: False 85 | freeze_head_pose_estimator: True 86 | 87 | use_pretrained_saliency_extractor: False 88 | # use_pretrained_saliency_extractor: True 89 | pretrained_saliency_extractor_name: pretrained_scene_extractor_davt 90 | freeze_saliency_extractor: False 91 | # freeze_saliency_extractor: True 92 | 93 | use_pretrained_joint_attention_estimator: False 94 | # use_pretrained_joint_attention_estimator: True 95 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 96 | freeze_joint_attention_estimator: False 97 | # freeze_joint_attention_estimator: True 98 | 99 | model_params: 100 | model_type: ja_transformer_dual 101 | 102 | # Position 103 | # use_position : False 104 | use_position : True 105 | 106 | # Gaze 107 | # use_gaze : False 108 | use_gaze : True 109 | 110 | # Action 111 | use_action : False 112 | # use_action : True 113 | 114 | # Person embedding 115 | # head_embedding_type : liner 116 | head_embedding_type : mlp 117 | 118 | # Whole image 119 | # use_img : False 120 | use_img : True 121 | 122 | # person-person transformer 123 | people_feat_dim : 16 124 | # use_people_people_trans: False 125 | use_people_people_trans: True 126 | people_people_trans_enc_num : 2 127 | mha_num_heads_people_people : 2 128 | 129 | # rgb-person transformer 130 | rgb_feat_dim : 256 131 | rgb_people_trans_enc_num : 1 132 | mha_num_heads_rgb_people : 1 133 | # p_p_estimator_type : fc_shallow 134 | # p_p_estimator_type : fc_middle 135 | # p_p_estimator_type : fc_deep 136 | # p_p_estimator_type : deconv_shallow 137 | # p_p_estimator_type : deconv_middle 138 | # p_p_estimator_type : deconv_deep 139 | # p_p_estimator_type : field_shallow 140 | p_p_estimator_type : field_middle 141 | # p_p_estimator_type : field_deep 142 | 143 | p_p_aggregation_type : ind_only 144 | # p_p_aggregation_type : token_only 145 | # p_p_aggregation_type : ind_and_token_ind_based 146 | # p_p_aggregation_type : ind_and_token_token_based 147 | 148 | # rgb_cnn_extractor_type : rgb_patch 149 | # rgb_patch_size : 8 150 | # rgb_cnn_extractor_type : resnet18 151 | rgb_cnn_extractor_type : resnet50 152 | # rgb_cnn_extractor_stage_idx : 1 153 | # rgb_cnn_extractor_stage_idx : 2 154 | # rgb_cnn_extractor_stage_idx : 3 155 | rgb_cnn_extractor_stage_idx : 4 156 | 157 | p_s_estimator_type : davt 158 | # p_s_estimator_type : transformer 159 | # p_s_estimator_type : cnn 160 | # p_s_estimator_cnn_pretrain : False 161 | p_s_estimator_cnn_pretrain : True 162 | use_p_s_estimator_att_inside : False 163 | # use_p_s_estimator_att_inside : True 164 | 165 | # fusion_net_type : early 166 | # fusion_net_type : mid 167 | # fusion_net_type : late 168 | # fusion_net_type : scalar_weight 169 | fusion_net_type : simple_average -------------------------------------------------------------------------------- /yaml_files/gazefollow/train_ours.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: gazefollow
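models/model_selector.py (shown later in this dump) reads these settings with attribute access, e.g. cfg.model_params.model_type. With PyYAML and addict — both pinned in requirements.txt — a config such as this train_ours.yaml can be loaded to support that access style; a sketch, not necessarily the repository's actual loader:

import yaml
from addict import Dict

# Parse the YAML into nested dicts, then wrap for attribute-style access.
with open('yaml_files/gazefollow/train_ours.yaml', 'r') as f:
    cfg = Dict(yaml.safe_load(f))

print(cfg.data.name)                # -> gazefollow
print(cfg.exp_set.batch_size)       # -> 32
print(cfg.model_params.model_type)  # -> ja_transformer_dual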
3 | dataset_dir : data/gazefollow 4 | 5 | exp_set: 6 | save_folder: saved_weights 7 | wandb_log : True 8 | 9 | wandb_name: gazefollow-only_davt 10 | 11 | batch_size: 32 12 | # batch_size: 8 13 | # batch_size: 2 14 | num_workers: 16 15 | seed_num: 777 16 | gpu_mode : True 17 | gpu_start : 3 18 | gpu_finish : 3 19 | 20 | resize_height: 320 21 | resize_width: 480 22 | resize_head_height: 224 23 | resize_head_width: 224 24 | 25 | exp_params: 26 | use_gaze_loss : False 27 | # use_gaze_loss : True 28 | 29 | use_person_person_att_loss : False 30 | # use_person_person_att_loss : True 31 | person_person_att_loss_weight : 1 32 | 33 | use_person_person_jo_att_loss : False 34 | # use_person_person_jo_att_loss : True 35 | person_person_jo_att_loss_weight : 1 36 | 37 | # use_person_scene_att_loss : False 38 | use_person_scene_att_loss : True 39 | person_scene_att_loss_weight : 1 40 | 41 | use_person_scene_jo_att_loss : False 42 | # use_person_scene_jo_att_loss : True 43 | person_scene_jo_att_loss_weight : 1 44 | 45 | use_final_jo_att_loss : False 46 | # use_final_jo_att_loss : True 47 | final_jo_att_loss_weight : 1 48 | 49 | use_frame_type: mid 50 | # use_frame_type: all 51 | 52 | use_gt_gaze: False 53 | # use_gt_gaze: True 54 | 55 | # position augmentation 56 | use_position_aug: False 57 | # use_position_aug: True 58 | position_aug_std: 0.05 59 | 60 | # loss function 61 | loss : mse 62 | # loss : bce 63 | 64 | # learning rate 65 | lr : 0.00001 66 | 67 | # gt gaussian 68 | gaussian_sigma: 20 69 | 70 | # learning schedule 71 | nEpochs : 500 72 | start_iter : 0 73 | snapshots : 100 74 | scheduler_start : 1000 75 | scheduler_iter : 1100000 76 | 77 | # pretrained models 78 | pretrained_models_dir: saved_weights 79 | 80 | # use_pretrained_head_pose_estimator: False 81 | use_pretrained_head_pose_estimator: True 82 | # pretrained_head_pose_estimator_name: videocoatt-head_pose_estimator 83 | pretrained_head_pose_estimator_name: gazefollow-dual-cnn-w_pre 84 | # freeze_head_pose_estimator: False 85 | freeze_head_pose_estimator: True 86 | 87 | use_pretrained_saliency_extractor: False 88 | # use_pretrained_saliency_extractor: True 89 | pretrained_saliency_extractor_name: pretrained_scene_extractor_davt 90 | freeze_saliency_extractor: False 91 | # freeze_saliency_extractor: True 92 | 93 | use_pretrained_joint_attention_estimator: False 94 | # use_pretrained_joint_attention_estimator: True 95 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 96 | freeze_joint_attention_estimator: False 97 | # freeze_joint_attention_estimator: True 98 | 99 | model_params: 100 | model_type: ja_transformer_dual 101 | 102 | # Position 103 | # use_position : False 104 | use_position : True 105 | 106 | # Gaze 107 | # use_gaze : False 108 | use_gaze : True 109 | 110 | # Action 111 | use_action : False 112 | # use_action : True 113 | 114 | # Person embedding 115 | # head_embedding_type : liner 116 | head_embedding_type : mlp 117 | 118 | # Whole image 119 | # use_img : False 120 | use_img : True 121 | 122 | # person-person transformer 123 | people_feat_dim : 16 124 | # use_people_people_trans: False 125 | use_people_people_trans: True 126 | people_people_trans_enc_num : 2 127 | mha_num_heads_people_people : 2 128 | 129 | # rgb-person transformer 130 | rgb_feat_dim : 256 131 | rgb_people_trans_enc_num : 1 132 | mha_num_heads_rgb_people : 1 133 | # p_p_estimator_type : fc_shallow 134 | # p_p_estimator_type : fc_middle 135 | # p_p_estimator_type : fc_deep 136 | # p_p_estimator_type : deconv_shallow 137 | # p_p_estimator_type : 
deconv_middle 138 | # p_p_estimator_type : deconv_deep 139 | # p_p_estimator_type : field_shallow 140 | p_p_estimator_type : field_middle 141 | # p_p_estimator_type : field_deep 142 | 143 | p_p_aggregation_type : ind_only 144 | # p_p_aggregation_type : token_only 145 | # p_p_aggregation_type : ind_and_token_ind_based 146 | # p_p_aggregation_type : ind_and_token_token_based 147 | 148 | # rgb_cnn_extractor_type : rgb_patch 149 | # rgb_patch_size : 8 150 | # rgb_cnn_extractor_type : resnet18 151 | rgb_cnn_extractor_type : resnet50 152 | # rgb_cnn_extractor_stage_idx : 1 153 | # rgb_cnn_extractor_stage_idx : 2 154 | # rgb_cnn_extractor_stage_idx : 3 155 | rgb_cnn_extractor_stage_idx : 4 156 | 157 | p_s_estimator_type : davt 158 | # p_s_estimator_type : transformer 159 | # p_s_estimator_type : cnn 160 | # p_s_estimator_cnn_pretrain : False 161 | p_s_estimator_cnn_pretrain : True 162 | use_p_s_estimator_att_inside : False 163 | # use_p_s_estimator_att_inside : True 164 | 165 | # fusion_net_type : early 166 | # fusion_net_type : mid 167 | # fusion_net_type : late 168 | # fusion_net_type : scalar_weight 169 | fusion_net_type : simple_average -------------------------------------------------------------------------------- /yaml_files/videoattentiontarget/train_ours.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videoattentiontarget 3 | dataset_dir : data/videoattentiontarget 4 | 5 | exp_set: 6 | save_folder: saved_weights 7 | wandb_log : True 8 | 9 | wandb_name: videoattentiontarget-only_davt_PRED 10 | 11 | batch_size: 16 12 | # batch_size: 8 13 | # batch_size: 2 14 | num_workers: 16 15 | seed_num: 777 16 | gpu_mode : True 17 | gpu_start : 1 18 | gpu_finish : 1 19 | 20 | resize_height: 320 21 | resize_width: 480 22 | resize_head_height: 224 23 | resize_head_width: 224 24 | 25 | exp_params: 26 | use_gaze_loss : False 27 | # use_gaze_loss : True 28 | 29 | use_person_person_att_loss : False 30 | # use_person_person_att_loss : True 31 | person_person_att_loss_weight : 1 32 | 33 | use_person_person_jo_att_loss : False 34 | # use_person_person_jo_att_loss : True 35 | person_person_jo_att_loss_weight : 1 36 | 37 | # use_person_scene_att_loss : False 38 | use_person_scene_att_loss : True 39 | person_scene_att_loss_weight : 1 40 | 41 | use_person_scene_jo_att_loss : False 42 | # use_person_scene_jo_att_loss : True 43 | person_scene_jo_att_loss_weight : 1 44 | 45 | use_final_jo_att_loss : False 46 | # use_final_jo_att_loss : True 47 | final_jo_att_loss_weight : 1 48 | 49 | use_frame_type: mid 50 | # use_frame_type: all 51 | 52 | use_gt_gaze: False 53 | # use_gt_gaze: True 54 | 55 | # bbox_types: GT 56 | bbox_types: PRED 57 | bbox_iou_thresh: 0.6 58 | 59 | # position augmentation 60 | use_position_aug: False 61 | # use_position_aug: True 62 | position_aug_std: 0.05 63 | 64 | # loss function 65 | loss : mse 66 | # loss : bce 67 | 68 | # learning rate 69 | # lr : 0.00001 70 | lr : 0.001 71 | 72 | # gt gaussian 73 | gaussian_sigma: 20 74 | 75 | # learning schedule 76 | nEpochs : 500 77 | start_iter : 0 78 | snapshots : 100 79 | scheduler_start : 1000 80 | scheduler_iter : 1100000 81 | 82 | # pretrained models 83 | pretrained_models_dir: saved_weights 84 | 85 | # use_pretrained_head_pose_estimator: False 86 | use_pretrained_head_pose_estimator: True 87 | pretrained_head_pose_estimator_name: gazefollow-dual-cnn-w_pre 88 | # pretrained_head_pose_estimator_name: gazefollow-dual-trans-w_pre 89 | # freeze_head_pose_estimator: False 90 | 
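The freeze_* switches in these configs control whether a pretrained sub-network keeps learning during training; the standard PyTorch mechanism they presumably correspond to (a sketch under that assumption, not the repository's confirmed training code) is:

import torch.nn as nn

def freeze_module(module: nn.Module) -> None:
    # Exclude the module's weights from gradient updates and fix its
    # batch-norm/dropout behavior, so the pretrained weights stay untouched.
    for param in module.parameters():
        param.requires_grad = False
    module.eval()

# e.g. if cfg.exp_params.freeze_head_pose_estimator: freeze_module(model_head)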
freeze_head_pose_estimator: True 91 | 92 | # use_pretrained_saliency_extractor: False 93 | use_pretrained_saliency_extractor: True 94 | pretrained_saliency_extractor_name: pretrained_scene_extractor_davt 95 | freeze_saliency_extractor: False 96 | # freeze_saliency_extractor: True 97 | 98 | use_pretrained_joint_attention_estimator: False 99 | # use_pretrained_joint_attention_estimator: True 100 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 101 | freeze_joint_attention_estimator: False 102 | # freeze_joint_attention_estimator: True 103 | 104 | model_params: 105 | model_type: ja_transformer_dual 106 | 107 | # Position 108 | # use_position : False 109 | use_position : True 110 | 111 | # Gaze 112 | # use_gaze : False 113 | use_gaze : True 114 | 115 | # Action 116 | use_action : False 117 | # use_action : True 118 | 119 | # Person embedding 120 | # head_embedding_type : liner 121 | head_embedding_type : mlp 122 | 123 | # Whole image 124 | # use_img : False 125 | use_img : True 126 | 127 | # person-person transformer 128 | people_feat_dim : 16 129 | # use_people_people_trans: False 130 | use_people_people_trans: True 131 | people_people_trans_enc_num : 2 132 | mha_num_heads_people_people : 2 133 | 134 | # rgb-person transformer 135 | rgb_feat_dim : 256 136 | rgb_people_trans_enc_num : 1 137 | mha_num_heads_rgb_people : 1 138 | # p_p_estimator_type : fc_shallow 139 | # p_p_estimator_type : fc_middle 140 | # p_p_estimator_type : fc_deep 141 | # p_p_estimator_type : deconv_shallow 142 | # p_p_estimator_type : deconv_middle 143 | # p_p_estimator_type : deconv_deep 144 | # p_p_estimator_type : field_shallow 145 | p_p_estimator_type : field_middle 146 | # p_p_estimator_type : field_deep 147 | 148 | p_p_aggregation_type : ind_only 149 | # p_p_aggregation_type : token_only 150 | # p_p_aggregation_type : ind_and_token_ind_based 151 | # p_p_aggregation_type : ind_and_token_token_based 152 | 153 | # rgb_cnn_extractor_type : rgb_patch 154 | # rgb_patch_size : 8 155 | # rgb_cnn_extractor_type : resnet18 156 | rgb_cnn_extractor_type : resnet50 157 | # rgb_cnn_extractor_stage_idx : 1 158 | # rgb_cnn_extractor_stage_idx : 2 159 | # rgb_cnn_extractor_stage_idx : 3 160 | rgb_cnn_extractor_stage_idx : 4 161 | 162 | p_s_estimator_type : davt 163 | # p_s_estimator_type : transformer 164 | # p_s_estimator_type : cnn 165 | # p_s_estimator_cnn_pretrain : False 166 | p_s_estimator_cnn_pretrain : True 167 | use_p_s_estimator_att_inside : False 168 | # use_p_s_estimator_att_inside : True 169 | 170 | # fusion_net_type : early 171 | # fusion_net_type : mid 172 | # fusion_net_type : late 173 | # fusion_net_type : scalar_weight 174 | fusion_net_type : simple_average -------------------------------------------------------------------------------- /yaml_files/videoattentiontarget/debug_ours.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videoattentiontarget 3 | dataset_dir : data/videoattentiontarget 4 | 5 | exp_set: 6 | save_folder: saved_weights 7 | wandb_log : False 8 | 9 | wandb_name: debug 10 | 11 | batch_size: 16 12 | # batch_size: 8 13 | # batch_size: 2 14 | num_workers: 16 15 | seed_num: 777 16 | gpu_mode : True 17 | gpu_start : 0 18 | gpu_finish : 0 19 | 20 | resize_height: 320 21 | resize_width: 480 22 | resize_head_height: 224 23 | resize_head_width: 224 24 | 25 | exp_params: 26 | use_gaze_loss : False 27 | # use_gaze_loss : True 28 | 29 | use_person_person_att_loss : False 30 | # use_person_person_att_loss : True 31 | 
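gaussian_sigma in these configs sets the spread of the Gaussian placed on the annotated joint-attention point when the ground-truth heatmap is built. A minimal sketch of constructing such a target on the configured 480x320 grid (the exact normalization used by the repository is an assumption):

import numpy as np

def make_gt_heatmap(cx, cy, height=320, width=480, gaussian_sigma=20.0):
    # 2D Gaussian centered on the annotated joint-attention point (peak value 1.0).
    ys, xs = np.mgrid[0:height, 0:width]
    return np.exp(-((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * gaussian_sigma ** 2))

gt_map = make_gt_heatmap(cx=240, cy=160)  # example center near the image middle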
person_person_att_loss_weight : 1 32 | 33 | use_person_person_jo_att_loss : False 34 | # use_person_person_jo_att_loss : True 35 | person_person_jo_att_loss_weight : 1 36 | 37 | # use_person_scene_att_loss : False 38 | use_person_scene_att_loss : True 39 | person_scene_att_loss_weight : 1 40 | 41 | use_person_scene_jo_att_loss : False 42 | # use_person_scene_jo_att_loss : True 43 | person_scene_jo_att_loss_weight : 1 44 | 45 | use_final_jo_att_loss : False 46 | # use_final_jo_att_loss : True 47 | final_jo_att_loss_weight : 1 48 | 49 | use_frame_type: mid 50 | # use_frame_type: all 51 | 52 | use_gt_gaze: False 53 | # use_gt_gaze: True 54 | 55 | # bbox_types: GT 56 | bbox_types: PRED 57 | bbox_iou_thresh: 0.6 58 | 59 | # position augmentation 60 | use_position_aug: False 61 | # use_position_aug: True 62 | position_aug_std: 0.05 63 | 64 | # loss function 65 | loss : mse 66 | # loss : bce 67 | 68 | # learning rate 69 | lr : 0.00001 70 | 71 | # gt gaussian 72 | gaussian_sigma: 20 73 | 74 | # learning schedule 75 | nEpochs : 500 76 | start_iter : 0 77 | snapshots : 100 78 | scheduler_start : 1000 79 | scheduler_iter : 1100000 80 | 81 | # pretrained models 82 | pretrained_models_dir: saved_weights 83 | 84 | # use_pretrained_head_pose_estimator: False 85 | use_pretrained_head_pose_estimator: True 86 | pretrained_head_pose_estimator_name: gazefollow-dual-cnn-w_pre 87 | # pretrained_head_pose_estimator_name: gazefollow-dual-trans-w_pre 88 | # freeze_head_pose_estimator: False 89 | freeze_head_pose_estimator: True 90 | 91 | use_pretrained_saliency_extractor: False 92 | # use_pretrained_saliency_extractor: True 93 | pretrained_saliency_extractor_name: gazefollow-dual-cnn-w_pre 94 | # pretrained_saliency_extractor_name: gazefollow-dual-trans-w_pre 95 | freeze_saliency_extractor: False 96 | # freeze_saliency_extractor: True 97 | 98 | use_pretrained_joint_attention_estimator: False 99 | # use_pretrained_joint_attention_estimator: True 100 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 101 | freeze_joint_attention_estimator: False 102 | # freeze_joint_attention_estimator: True 103 | 104 | model_params: 105 | model_type: ja_transformer_dual 106 | 107 | # Position 108 | # use_position : False 109 | use_position : True 110 | 111 | # Gaze 112 | # use_gaze : False 113 | use_gaze : True 114 | 115 | # Action 116 | use_action : False 117 | # use_action : True 118 | 119 | # Person embedding 120 | # head_embedding_type : liner 121 | head_embedding_type : mlp 122 | 123 | # Whole image 124 | # use_img : False 125 | use_img : True 126 | 127 | # person-person transformer 128 | people_feat_dim : 16 129 | # use_people_people_trans: False 130 | use_people_people_trans: True 131 | people_people_trans_enc_num : 2 132 | mha_num_heads_people_people : 2 133 | 134 | # rgb-person transformer 135 | rgb_feat_dim : 256 136 | rgb_people_trans_enc_num : 1 137 | mha_num_heads_rgb_people : 1 138 | # p_p_estimator_type : fc_shallow 139 | # p_p_estimator_type : fc_middle 140 | # p_p_estimator_type : fc_deep 141 | # p_p_estimator_type : deconv_shallow 142 | # p_p_estimator_type : deconv_middle 143 | # p_p_estimator_type : deconv_deep 144 | # p_p_estimator_type : field_shallow 145 | p_p_estimator_type : field_middle 146 | # p_p_estimator_type : field_deep 147 | 148 | p_p_aggregation_type : ind_only 149 | # p_p_aggregation_type : token_only 150 | # p_p_aggregation_type : ind_and_token_ind_based 151 | # p_p_aggregation_type : ind_and_token_token_based 152 | 153 | # rgb_cnn_extractor_type : rgb_patch 154 | # 
rgb_patch_size : 8 155 | # rgb_cnn_extractor_type : resnet18 156 | rgb_cnn_extractor_type : resnet50 157 | # rgb_cnn_extractor_stage_idx : 1 158 | # rgb_cnn_extractor_stage_idx : 2 159 | # rgb_cnn_extractor_stage_idx : 3 160 | rgb_cnn_extractor_stage_idx : 4 161 | 162 | p_s_estimator_type : davt 163 | # p_s_estimator_type : transformer 164 | # p_s_estimator_type : cnn 165 | # p_s_estimator_cnn_pretrain : False 166 | p_s_estimator_cnn_pretrain : True 167 | use_p_s_estimator_att_inside : False 168 | # use_p_s_estimator_att_inside : True 169 | 170 | # fusion_net_type : early 171 | # fusion_net_type : mid 172 | # fusion_net_type : late 173 | # fusion_net_type : scalar_weight 174 | fusion_net_type : simple_average -------------------------------------------------------------------------------- /yaml_files/videocoatt/debug_ours.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videocoatt 3 | dataset_dir : data/VideoCoAtt_Dataset 4 | saliency_dataset_dir : data/deepgaze_output_loader 5 | 6 | exp_set: 7 | save_folder: saved_weights 8 | wandb_log : False 9 | wandb_name: debug 10 | 11 | batch_size: 8 12 | # batch_size: 2 13 | num_workers: 16 14 | seed_num: 777 15 | gpu_mode : True 16 | gpu_start : 4 17 | gpu_finish : 4 18 | 19 | resize_height: 320 20 | resize_width: 480 21 | resize_head_height: 224 22 | resize_head_width: 224 23 | 24 | exp_params: 25 | 26 | use_person_person_att_loss : False 27 | # use_person_person_att_loss : True 28 | person_person_att_loss_weight : 1 29 | 30 | # use_person_person_jo_att_loss : False 31 | use_person_person_jo_att_loss : True 32 | person_person_jo_att_loss_weight : 1 33 | 34 | use_person_scene_att_loss : False 35 | # use_person_scene_att_loss : True 36 | person_scene_att_loss_weight : 1 37 | 38 | use_person_scene_jo_att_loss : False 39 | # use_person_scene_jo_att_loss : True 40 | person_scene_jo_att_loss_weight : 1 41 | 42 | # use_final_jo_att_loss : False 43 | use_final_jo_att_loss : True 44 | final_jo_att_loss_weight : 1 45 | 46 | use_frame_type: mid 47 | # use_frame_type: all 48 | 49 | use_gt_gaze: False 50 | # use_gt_gaze: True 51 | 52 | # position augmentation 53 | use_position_aug: False 54 | # use_position_aug: True 55 | position_aug_std: 0.05 56 | 57 | # loss function 58 | loss : mse 59 | # loss : bce 60 | 61 | # learning rate 62 | lr : 0.0001 63 | 64 | # gt gaussian 65 | gaussian_sigma: 10 66 | 67 | # learning schedule 68 | nEpochs : 500 69 | start_iter : 0 70 | snapshots : 100 71 | scheduler_start : 1000 72 | scheduler_iter : 1100000 73 | 74 | det_heads_model : det_heads 75 | # train_det_heads : False 76 | train_det_heads : True 77 | train_heads_conf : 0.6 78 | test_heads_conf : 0.6 79 | 80 | # pretrained models 81 | pretrained_models_dir: saved_weights 82 | 83 | # use_pretrained_head_pose_estimator: False 84 | use_pretrained_head_pose_estimator: True 85 | # pretrained_head_pose_estimator_name: videocoatt-head_pose_estimator 86 | pretrained_head_pose_estimator_name: gazefollow-dual-cnn-w_pre 87 | # pretrained_head_pose_estimator_name: videoattentiontarget-dual-cnn_wo_pre_w_att_ins 88 | # freeze_head_pose_estimator: False 89 | freeze_head_pose_estimator: True 90 | 91 | # use_pretrained_saliency_extractor: False 92 | use_pretrained_saliency_extractor: True 93 | # pretrained_saliency_extractor_name: pretrained_scene_extractor_davt 94 | # pretrained_saliency_extractor_name: gazefollow-dual-cnn 95 | pretrained_saliency_extractor_name: videoattentiontarget-dual-cnn_wo_pre_w_att_ins 96 | # 
freeze_saliency_extractor: False 97 | freeze_saliency_extractor: True 98 | 99 | use_pretrained_joint_attention_estimator: False 100 | # use_pretrained_head_pose_estimator: True 101 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 102 | freeze_joint_attention_estimator: False 103 | # freeze_joint_attention_estimator: True 104 | 105 | model_params: 106 | model_type: ja_transformer_dual 107 | 108 | # Position 109 | # use_position : False 110 | use_position : True 111 | 112 | # Gaze 113 | # use_gaze : False 114 | use_gaze : True 115 | 116 | # Action 117 | use_action : False 118 | # use_action : True 119 | 120 | # Person embedding 121 | # head_embedding_type : liner 122 | head_embedding_type : mlp 123 | 124 | # Whole image 125 | # use_img : False 126 | use_img : True 127 | 128 | # person-person transformer 129 | people_feat_dim : 16 130 | # use_people_people_trans: False 131 | use_people_people_trans: True 132 | people_people_trans_enc_num : 2 133 | mha_num_heads_people_people : 2 134 | # p_p_estimator_type : fc_shallow 135 | # p_p_estimator_type : fc_middle 136 | # p_p_estimator_type : fc_deep 137 | # p_p_estimator_type : deconv_shallow 138 | # p_p_estimator_type : deconv_middle 139 | # p_p_estimator_type : deconv_deep 140 | # p_p_estimator_type : field_shallow 141 | # p_p_estimator_type : field_middle 142 | p_p_estimator_type : field_deep 143 | 144 | # rgb-person transformer 145 | rgb_feat_dim : 256 146 | rgb_people_trans_enc_num : 1 147 | mha_num_heads_rgb_people : 1 148 | 149 | # rgb_cnn_extractor_type : rgb_patch 150 | # rgb_patch_size : 8 151 | # rgb_cnn_extractor_type : resnet18 152 | rgb_cnn_extractor_type : resnet50 153 | # rgb_cnn_extractor_stage_idx : 1 154 | # rgb_cnn_extractor_stage_idx : 2 155 | # rgb_cnn_extractor_stage_idx : 3 156 | rgb_cnn_extractor_stage_idx : 4 157 | 158 | # p_s_estimator_type : davt 159 | # p_s_estimator_type : transformer 160 | p_s_estimator_type : cnn 161 | p_s_estimator_cnn_pretrain : False 162 | # p_s_estimator_cnn_pretrain : True 163 | # use_p_s_estimator_att_inside : False 164 | use_p_s_estimator_att_inside : True 165 | 166 | # fusion_net_type : early 167 | # fusion_net_type : mid 168 | # fusion_net_type : late 169 | fusion_net_type : simple_average 170 | # fusion_net_type : scalar_weight -------------------------------------------------------------------------------- /yaml_files/volleyball/train_ours_p_p.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | annotation_dir : data/vatic_ball_annotation/annotation_data/ 6 | dataset_bbox_gt: data/jae_dataset_bbox_gt 7 | dataset_bbox_pred: data/jae_dataset_bbox_pred 8 | 9 | exp_set: 10 | save_folder: saved_weights 11 | wandb_log : True 12 | 13 | wandb_name: volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED_token_only 14 | 15 | batch_size: 8 16 | # batch_size: 4 17 | num_workers: 16 18 | seed_num: 777 19 | # seed_num: 888 20 | gpu_mode : True 21 | gpu_start : 1 22 | gpu_finish : 1 23 | 24 | resize_height : 320 25 | resize_width : 640 26 | resize_head_width: 64 27 | resize_head_height: 64 28 | 29 | exp_params: 30 | 31 | use_person_person_att_loss : False 32 | # use_person_person_att_loss : True 33 | person_person_att_loss_weight : 1 34 | 35 | # use_person_person_jo_att_loss : False 36 | use_person_person_jo_att_loss : True 37 | person_person_jo_att_loss_weight : 1 38 | 39 | use_person_scene_att_loss : False 40 | # 
use_person_scene_att_loss : True 41 | person_scene_att_loss_weight : 1 42 | 43 | use_person_scene_jo_att_loss : False 44 | # use_person_scene_jo_att_loss : True 45 | person_scene_jo_att_loss_weight : 1 46 | 47 | use_final_jo_att_loss : False 48 | # use_final_jo_att_loss : True 49 | final_jo_att_loss_weight : 1 50 | 51 | use_frame_type: mid 52 | # use_frame_type: all 53 | 54 | bbox_types: GT 55 | # bbox_types: PRED 56 | 57 | gaze_types: GT 58 | # gaze_types: PRED 59 | 60 | action_types: GT 61 | # action_types: PRED 62 | 63 | # position augmentation 64 | use_position_aug: False 65 | # use_position_aug: True 66 | position_aug_std: 0.05 67 | 68 | # loss function 69 | loss : mse 70 | # loss : bce 71 | 72 | # learning rate 73 | # lr : 0.0001 74 | lr : 0.001 75 | # lr : 0.01 76 | # lr : 0.1 77 | 78 | # gt gaussian 79 | # gaussian_sigma: 40 80 | # gaussian_sigma: 20 81 | gaussian_sigma: 10 82 | 83 | # learning schedule 84 | nEpochs : 500 85 | start_iter : 0 86 | snapshots : 100 87 | scheduler_start : 1000 88 | scheduler_iter : 1100000 89 | 90 | # pretrained models 91 | pretrained_models_dir: saved_weights 92 | 93 | # use_pretrained_head_pose_estimator: False 94 | use_pretrained_head_pose_estimator: True 95 | pretrained_head_pose_estimator_name: volleyball-head_pose_estimator 96 | # freeze_head_pose_estimator: False 97 | freeze_head_pose_estimator: True 98 | 99 | use_pretrained_saliency_extractor: False 100 | # use_pretrained_saliency_extractor: True 101 | pretrained_saliency_extractor_name: pretrained_scene_extractor_davt 102 | freeze_saliency_extractor: False 103 | # freeze_saliency_extractor: True 104 | 105 | use_pretrained_joint_attention_estimator: False 106 | # use_pretrained_head_pose_estimator: True 107 | pretrained_joint_attention_estimator_name: pretrain_head_estimator 108 | freeze_joint_attention_estimator: False 109 | # freeze_joint_attention_estimator: True 110 | 111 | model_params: 112 | model_type: ja_transformer_dual_only_people 113 | 114 | # Position 115 | # use_position : False 116 | use_position : True 117 | 118 | # Gaze 119 | # use_gaze : False 120 | use_gaze : True 121 | 122 | # Action 123 | # use_action : False 124 | use_action : True 125 | 126 | # Person embedding 127 | # head_embedding_type : liner 128 | head_embedding_type : mlp 129 | 130 | # Whole image 131 | # use_img : False 132 | use_img : True 133 | 134 | # person-person transformer 135 | people_feat_dim : 16 136 | # people_feat_dim : 32 137 | # use_people_people_trans: False 138 | use_people_people_trans: True 139 | 140 | # people_people_trans_enc_num : 1 141 | # people_people_trans_enc_num : 2 142 | people_people_trans_enc_num : 3 143 | # people_people_trans_enc_num : 4 144 | 145 | # mha_num_heads_people_people : 1 146 | mha_num_heads_people_people : 2 147 | # mha_num_heads_people_people : 4 148 | # mha_num_heads_people_people : 8 149 | # mha_num_heads_people_people : 16 150 | 151 | # p_p_estimator_type : fc_shallow 152 | # p_p_estimator_type : fc_middle 153 | # p_p_estimator_type : fc_deep 154 | # p_p_estimator_type : deconv_shallow 155 | # p_p_estimator_type : deconv_middle 156 | p_p_estimator_type : field_middle 157 | # p_p_estimator_type : field_deep 158 | 159 | # p_p_aggregation_type : ind_only 160 | p_p_aggregation_type : token_only 161 | # p_p_aggregation_type : ind_and_token_ind_based 162 | # p_p_aggregation_type : ind_and_token_token_based 163 | 164 | # rgb-person transformer 165 | rgb_feat_dim : 256 166 | rgb_people_trans_enc_num : 1 167 | mha_num_heads_rgb_people : 1 168 | 169 | # rgb_cnn_extractor_type : 
rgb_patch 170 | # rgb_patch_size : 8 171 | # rgb_cnn_extractor_type : resnet18 172 | rgb_cnn_extractor_type : resnet50 173 | # rgb_cnn_extractor_stage_idx : 1 174 | # rgb_cnn_extractor_stage_idx : 2 175 | # rgb_cnn_extractor_stage_idx : 3 176 | rgb_cnn_extractor_stage_idx : 4 177 | 178 | p_s_estimator_type : davt 179 | # p_s_estimator_type : transformer 180 | # p_s_estimator_type : cnn 181 | p_s_estimator_cnn_pretrain : False 182 | # p_s_estimator_cnn_pretrain : True 183 | use_p_s_estimator_att_inside : False 184 | # use_p_s_estimator_att_inside : True -------------------------------------------------------------------------------- /yaml_files/videocoatt/debug_ours_p_p.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: videocoatt 3 | dataset_dir : data/VideoCoAtt_Dataset 4 | saliency_dataset_dir : data/deepgaze_output_loader 5 | 6 | exp_set: 7 | save_folder: saved_weights 8 | wandb_log : False 9 | 10 | wandb_name: debug 11 | 12 | # batch_size: 16 13 | batch_size: 8 14 | # batch_size: 2 15 | num_workers: 16 16 | seed_num: 777 17 | gpu_mode : True 18 | gpu_start : 3 19 | gpu_finish : 3 20 | 21 | resize_height: 320 22 | resize_width: 480 23 | # resize_width: 640 24 | resize_head_height: 224 25 | resize_head_width: 224 26 | 27 | exp_params: 28 | use_person_person_att_loss : False 29 | # use_person_person_att_loss : True 30 | person_person_att_loss_weight : 1 31 | 32 | # use_person_person_jo_att_loss : False 33 | use_person_person_jo_att_loss : True 34 | person_person_jo_att_loss_weight : 1 35 | 36 | use_person_scene_att_loss : False 37 | # use_person_scene_att_loss : True 38 | person_scene_att_loss_weight : 1 39 | 40 | use_person_scene_jo_att_loss : False 41 | # use_person_scene_jo_att_loss : True 42 | person_scene_jo_att_loss_weight : 1 43 | 44 | use_final_jo_att_loss : False 45 | # use_final_jo_att_loss : True 46 | final_jo_att_loss_weight : 1 47 | 48 | use_frame_type: mid 49 | # use_frame_type: all 50 | 51 | use_gt_gaze: False 52 | # use_gt_gaze: True 53 | 54 | # position augmentation 55 | use_position_aug: False 56 | # use_position_aug: True 57 | position_aug_std: 0.05 58 | 59 | # loss function 60 | loss : mse 61 | # loss : bce 62 | 63 | # learning rate 64 | lr : 0.0001 65 | 66 | # gt gaussian 67 | gaussian_sigma: 10 68 | 69 | # learning schedule 70 | nEpochs : 500 71 | start_iter : 0 72 | snapshots : 100 73 | scheduler_start : 1000 74 | scheduler_iter : 1100000 75 | 76 | det_heads_model : det_heads 77 | # train_det_heads : False 78 | train_det_heads : True 79 | train_heads_conf : 0.6 80 | test_heads_conf : 0.6 81 | 82 | # pretrained models 83 | pretrained_models_dir: saved_weights 84 | 85 | # use_pretrained_head_pose_estimator: False 86 | use_pretrained_head_pose_estimator: True 87 | # pretrained_head_pose_estimator_name: videocoatt-head_pose_estimator 88 | pretrained_head_pose_estimator_name: gazefollow-dual-cnn-w_pre 89 | # pretrained_head_pose_estimator_name: videoattentiontarget-dual-cnn_wo_pre_w_att_ins 90 | # freeze_head_pose_estimator: False 91 | freeze_head_pose_estimator: True 92 | 93 | use_pretrained_saliency_extractor: False 94 | # use_pretrained_saliency_extractor: True 95 | pretrained_saliency_extractor_name: pretrained_scene_extractor_davt 96 | # pretrained_saliency_extractor_name: gazefollow-dual-cnn 97 | # pretrained_saliency_extractor_name: videoattentiontarget-dual-cnn_wo_pre_w_att_ins 98 | freeze_saliency_extractor: False 99 | # freeze_saliency_extractor: True 100 | 101 | # use_pretrained_joint_attention_estimator: 
False 102 | use_pretrained_joint_attention_estimator: True 103 | # pretrained_joint_attention_estimator_name: pretrain_head_estimator 104 | # pretrained_joint_attention_estimator_name: volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT 105 | pretrained_joint_attention_estimator_name: volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT_wo_action 106 | freeze_joint_attention_estimator: False 107 | # freeze_joint_attention_estimator: True 108 | 109 | model_params: 110 | model_type: ja_transformer_dual_only_people 111 | 112 | # Position 113 | # use_position : False 114 | use_position : True 115 | 116 | # Gaze 117 | # use_gaze : False 118 | use_gaze : True 119 | 120 | # Action 121 | use_action : False 122 | # use_action : True 123 | 124 | # Person embedding 125 | # head_embedding_type : liner 126 | head_embedding_type : mlp 127 | 128 | # Whole image 129 | # use_img : False 130 | use_img : True 131 | 132 | # person-person transformer 133 | people_feat_dim : 16 134 | # use_people_people_trans: False 135 | use_people_people_trans: True 136 | people_people_trans_enc_num : 2 137 | mha_num_heads_people_people : 2 138 | # p_p_estimator_type : fc_shallow 139 | # p_p_estimator_type : fc_middle 140 | # p_p_estimator_type : fc_deep 141 | # p_p_estimator_type : deconv_shallow 142 | # p_p_estimator_type : deconv_middle 143 | # p_p_estimator_type : deconv_deep 144 | # p_p_estimator_type : field_shallow 145 | p_p_estimator_type : field_middle 146 | # p_p_estimator_type : field_deep 147 | 148 | # p_p_aggregation_type : ind_only 149 | p_p_aggregation_type : token_only 150 | # p_p_aggregation_type : ind_and_token_ind_based 151 | # p_p_aggregation_type : ind_and_token_token_based 152 | 153 | # rgb-person transformer 154 | rgb_feat_dim : 256 155 | rgb_people_trans_enc_num : 1 156 | mha_num_heads_rgb_people : 1 157 | 158 | # rgb_cnn_extractor_type : rgb_patch 159 | # rgb_patch_size : 8 160 | # rgb_cnn_extractor_type : resnet18 161 | rgb_cnn_extractor_type : resnet50 162 | # rgb_cnn_extractor_stage_idx : 1 163 | # rgb_cnn_extractor_stage_idx : 2 164 | # rgb_cnn_extractor_stage_idx : 3 165 | rgb_cnn_extractor_stage_idx : 4 166 | 167 | p_s_estimator_type : davt 168 | # p_s_estimator_type : transformer 169 | # p_s_estimator_type : cnn 170 | p_s_estimator_cnn_pretrain : False 171 | # p_s_estimator_cnn_pretrain : True 172 | use_p_s_estimator_att_inside : False 173 | # use_p_s_estimator_att_inside : True -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Intro 2 | 3 | This is the official repository for the following paper: 4 | 5 | Chihiro Nakatani, Hiroaki Kawashima, Norimichi Ukita 6 | Interaction-aware Joint Attention Estimation Using People Attributes, ICCV2023 7 | Project page: [https://toyota-ti.ac.jp/Lab/Denshi/iim/ukita/selection/ICCV2023-PJAE.html](https://www.toyota-ti.ac.jp/Lab/Denshi/iim/ukita/selection/ICCV2023-PJAE.html) 8 | 9 | ![Top page](https://github.com/chihina/Interaction-aware-Joint-Attention-Estimation-Using-People-Attributes/blob/master/ICCV2023-PJAE.png) 10 | 11 | # Citation 12 | 13 | ``` 14 | @inproceedings{DBLP:conf/iccv/NakataniKU23, 15 | author = {Chihiro Nakatani and 16 | Hiroaki Kawashima and 17 | Norimichi Ukita}, 18 | title = {Interaction-aware Joint Attention Estimation Using People Attributes}, 19 | booktitle = {ICCV}, 20 | year = {2023}, 21 | } 22 | ``` 23 | 24 | ## Environment 25 | python 3.6.9 26 | 27 | You can install the dependencies with requirements.txt: 28 | ``` 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | # Data preparation 33 | ## 1. Download dataset 34 | You can download the datasets from the following URLs. 35 | These datasets must be placed under data/ in the repository (see the layout sketch after this list). 36 | 37 | * Volleyball dataset (data/videos) 38 | https://github.com/mostafa-saad/deep-activity-rec 39 | 40 | * Volleyball dataset (data/jae_dataset_bbox_gt, data/jae_dataset_bbox_pred) 41 | https://drive.google.com/drive/folders/1O55_wri92uv87g-2aDh8ll6dFVupmFaB?usp=share_link 42 | 43 | * Volleyball dataset (data/vatic_ball_annotation/annotation_data) 44 | https://drive.google.com/drive/folders/1O55_wri92uv87g-2aDh8ll6dFVupmFaB?usp=share_link 45 | 46 | * VideoCoAtt dataset (data/VideoCoAtt_Dataset) 47 | http://www.stat.ucla.edu/~lifengfan/shared_attention 48 | 49 | * VideoCoAtt dataset (data/VideoCoAtt_Dataset/dets_heads) 50 | https://drive.google.com/drive/folders/1O55_wri92uv87g-2aDh8ll6dFVupmFaB?usp=share_link 51 |
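With everything downloaded, the data/ directory is expected to look roughly as follows (a sketch only; the directory names are taken from the list above and from the data paths in the yaml files of this repository):

```
data/
├── videos                          # Volleyball RGB frames
├── volleyball_tracking_annotation  # Volleyball tracking annotations
├── jae_dataset_bbox_gt
├── jae_dataset_bbox_pred
├── vatic_ball_annotation/
│   └── annotation_data
├── VideoCoAtt_Dataset/
│   └── dets_heads
└── deepgaze_output_loader          # saliency maps used by the VideoCoAtt configs
```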
52 | ## 2. Training 53 | * You can change the parameters of the model (e.g., the number of attention heads and transformer encoder layers) by editing the yaml files; see the example snippet at the end of this section. 54 | * Trained models are also published here (https://drive.google.com/drive/folders/1O55_wri92uv87g-2aDh8ll6dFVupmFaB?usp=share_link 55 | ) 56 | * The trained models must be placed in saved_weights/volleyball or saved_weights/videocoatt in the repository. 57 | 58 | 59 | ### 2.1 Volleyball dataset 60 | 61 | * Ours 62 | ``` 63 | python train.py yaml_files/volleyball/train_ours_p_p.yaml 64 | python train.py yaml_files/volleyball/train_ours.yaml 65 | ``` 66 | The following folders contain the trained models. 67 | 1. volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine_token_only (Ex.1) 68 | 2. volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only (Ex.2) 69 | 70 | * DAVT 71 | ``` 72 | python train.py yaml_files/volleyball/train_ours_p_p.yaml 73 | python train.py yaml_files/volleyball/train_ours.yaml 74 | ``` 75 | The following folders contain the trained models. 76 | 1. volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_p_s_only (Ex.1) 77 | 2. volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_p_s_only (Ex.2) 78 | 79 | * ISA 80 | ``` 81 | python train.py yaml_files/volleyball/train_isa.yaml 82 | ``` 83 | The following folders contain the trained models. 84 | 1. volleyball-dual-isa_bbox_PRED_gaze_PRED_act_PRED (Ex.1) 85 | 2. volleyball-dual-isa_bbox_GT_gaze_GT_act_GT (Ex.2) 86 | 87 | 88 | ### 2.2 VideoCoAtt dataset 89 | 90 | * Ours 91 | ``` 92 | python train.py yaml_files/videocoatt/train_ours_p_p.yaml 93 | python train.py yaml_files/videocoatt/train_ours.yaml 94 | ``` 95 | The following folders contain the trained models. 96 | 1. videocoatt-dual-p_p_field_deep_p_s_davt_scalar_weight_fix (Ex.1) 97 | 2. videocoatt-dual-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT (Ex.2) 98 | 99 | * DAVT 100 | The trained model is published here (https://github.com/ejcgt/attention-target-detection) 101 | 102 | * ISA 103 | ``` 104 | python train.py yaml_files/videocoatt/train_isa.yaml 105 | ``` 106 | The following folders contain the trained models. 107 | 1. videocoatt-isa_bbox_PRED_gaze_PRED (Ex.1) 108 | 2. videocoatt-isa_bbox_GT_gaze_GT (Ex.2) 109 | 110 | * HGTD 111 | ``` 112 | python train.py yaml_files/videoattentiontarget/train_hgt.yaml 113 | ``` 114 | The following folder contains the trained models. 115 | 1. videocoatt-videoattentiontarget-hgt-high (Ex.1 and Ex.2) 116 | 117 |
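As a concrete example of the tunable parameters mentioned at the top of this section, the person-person transformer in yaml_files/volleyball/train_ours.yaml is controlled by keys like the following (the values are the defaults shipped in this repository; the inline comments are explanatory additions):

```
model_params:
  people_feat_dim : 16               # per-person feature dimension
  people_people_trans_enc_num : 2    # number of transformer encoder layers
  mha_num_heads_people_people : 2    # number of multi-head attention heads
  p_p_estimator_type : field_middle  # one of the fc_* / deconv_* / field_* variants
```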
## 3. Evaluation 118 | ### 3.1 Volleyball dataset 119 | You can choose the model you would like to evaluate in the yaml files. 120 | 121 | * Ours and DAVT 122 | ``` 123 | python eval_on_volleyball_ours.py yaml_files/volleyball/eval.yaml 124 | ``` 125 | 126 | * ISA 127 | ``` 128 | python eval_on_videocoatt_isa.py yaml_files/volleyball/eval.yaml 129 | ``` 130 | 131 | ### 3.2 VideoCoAtt dataset 132 | 133 | * Ours and DAVT 134 | ``` 135 | python eval_on_videocoatt_ours.py yaml_files/videocoatt/eval.yaml 136 | ``` 137 | 138 | * ISA 139 | ``` 140 | python eval_on_videocoatt_isa.py yaml_files/videocoatt/eval.yaml 141 | ``` 142 | 143 | * HGTD 144 | ``` 145 | python eval_on_videocoatt_hgt.py yaml_files/videocoatt/eval.yaml 146 | ``` 147 | 148 | ## 4. Demo 149 | You can choose the model you would like to run in the yaml files. 150 | 151 | ### 4.1 Volleyball dataset 152 | ``` 153 | python demo_ours.py yaml_files/volleyball/demo.yaml 154 | ``` 155 | 156 | ### 4.2 VideoCoAtt dataset 157 | ``` 158 | python demo_ours.py yaml_files/videocoatt/demo.yaml 159 | ``` 160 | -------------------------------------------------------------------------------- /yaml_files/volleyball/debug_ours_p_p.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | annotation_dir : data/vatic_ball_annotation/annotation_data_sub/ 6 | dataset_bbox_gt: data/jae_dataset_bbox_gt 7 | dataset_bbox_pred: data/jae_dataset_bbox_pred 8 | 9 | exp_set: 10 | save_folder: saved_weights 11 | wandb_name: debug 12 | wandb_log : False 13 | 14 | batch_size: 1 15 | num_workers: 1 16 | seed_num: 777 17 | gpu_mode : True 18 | gpu_start : 4 19 | gpu_finish : 4 20 | 21 | resize_height : 320 22 | resize_width : 640 23 | resize_head_width: 64 24 | resize_head_height: 64 25 | 26 | exp_params: 27 | 28 | use_person_person_att_loss : False 29 | # use_person_person_att_loss : True 30 | person_person_att_loss_weight : 1 31 | 32 | # use_person_person_jo_att_loss : False 33 | use_person_person_jo_att_loss : True 34 | person_person_jo_att_loss_weight : 1 35 | 36 | use_person_scene_att_loss : False 37 | # use_person_scene_att_loss : True 38 | person_scene_att_loss_weight : 1 39 | 40 | use_person_scene_jo_att_loss : False 41 | # use_person_scene_jo_att_loss : True 42 | person_scene_jo_att_loss_weight : 1 43 | 44 | use_final_jo_att_loss : False 45 | # use_final_jo_att_loss : True 46 | final_jo_att_loss_weight : 1 47 | 48 | use_frame_type: mid 49 | # use_frame_type: all 50 | 51 | bbox_types: GT 52 | # bbox_types: PRED 53 | 54 | gaze_types: GT 55 | # gaze_types: PRED 56 | 57 | action_types: GT 58 | # action_types: PRED 59 | 60 | # position augmentation 61 | use_position_aug: False 62 | # use_position_aug: True 63 | position_aug_std: 0.05 64 | 65 | # loss function 66 | loss : mse 67 | # loss : bce 68 | 69 | # learning rate 70 | # lr : 0.0001 71 | lr : 0.001 72 | # lr : 0.01 73 | # lr : 0.1 74 | 75 | # gt gaussian 76 | # gaussian_sigma: 40 77 | # gaussian_sigma: 20 78 | gaussian_sigma: 10 79 | # gaussian_sigma: 5 80 | # gaussian_sigma: 1 81 | 82 | # learning schedule 83 | nEpochs : 500 84 | start_iter : 0 85 | snapshots : 100 86 | scheduler_start : 1000 87 | scheduler_iter : 1100000 88 | 89 | # pretrained models 90 | pretrained_models_dir: saved_weights 91 | 92 | # use_pretrained_head_pose_estimator: False 93 | use_pretrained_head_pose_estimator: True 94 | pretrained_head_pose_estimator_name: volleyball-head_pose_estimator 95 | #
freeze_head_pose_estimator: False 96 | freeze_head_pose_estimator: True 97 | 98 | use_pretrained_saliency_extractor: False 99 | # use_pretrained_saliency_extractor: True 100 | pretrained_saliency_extractor_name: pretrained_scene_extractor_davt 101 | freeze_saliency_extractor: False 102 | # freeze_saliency_extractor: True 103 | 104 | use_pretrained_joint_attention_estimator: False 105 | # use_pretrained_joint_attention_estimator: True 106 | # pretrained_joint_attention_estimator_name: pretrain_head_estimator 107 | pretrained_joint_attention_estimator_name: videocoatt-dual-people_field_middle_token_only_bbox_GT_gaze_GT 108 | # pretrained_joint_attention_estimator_name: videocoatt-dual-people_field_middle_token_only_bbox 109 | freeze_joint_attention_estimator: False 110 | # freeze_joint_attention_estimator: True 111 | 112 | model_params: 113 | model_type: ja_transformer_dual_only_people 114 | 115 | # Position 116 | # use_position : False 117 | use_position : True 118 | 119 | # Gaze 120 | # use_gaze : False 121 | use_gaze : True 122 | 123 | # Action 124 | # use_action : False 125 | use_action : True 126 | 127 | # Person embedding 128 | # head_embedding_type : liner 129 | head_embedding_type : mlp 130 | 131 | # Whole image 132 | # use_img : False 133 | use_img : True 134 | 135 | # person-person transformer 136 | people_feat_dim : 16 137 | # people_feat_dim : 32 138 | # use_people_people_trans: False 139 | use_people_people_trans: True 140 | 141 | # people_people_trans_enc_num : 1 142 | people_people_trans_enc_num : 2 143 | # people_people_trans_enc_num : 3 144 | # people_people_trans_enc_num : 4 145 | 146 | # mha_num_heads_people_people : 1 147 | mha_num_heads_people_people : 2 148 | # mha_num_heads_people_people : 4 149 | # mha_num_heads_people_people : 8 150 | # mha_num_heads_people_people : 16 151 | 152 | # p_p_estimator_type : fc_shallow 153 | # p_p_estimator_type : fc_middle 154 | # p_p_estimator_type : fc_deep 155 | # p_p_estimator_type : deconv_shallow 156 | # p_p_estimator_type : deconv_middle 157 | p_p_estimator_type : field_middle 158 | # p_p_estimator_type : field_deep 159 | 160 | # p_p_aggregation_type : ind_only 161 | p_p_aggregation_type : token_only 162 | # p_p_aggregation_type : token_only_concat 163 | # p_p_aggregation_type : ind_and_token_ind_based 164 | # p_p_aggregation_type : ind_and_token_token_based 165 | 166 | # rgb-person transformer 167 | rgb_feat_dim : 256 168 | rgb_people_trans_enc_num : 1 169 | mha_num_heads_rgb_people : 1 170 | 171 | # rgb_cnn_extractor_type : rgb_patch 172 | # rgb_patch_size : 8 173 | # rgb_cnn_extractor_type : resnet18 174 | rgb_cnn_extractor_type : resnet50 175 | # rgb_cnn_extractor_stage_idx : 1 176 | # rgb_cnn_extractor_stage_idx : 2 177 | # rgb_cnn_extractor_stage_idx : 3 178 | rgb_cnn_extractor_stage_idx : 4 179 | 180 | p_s_estimator_type : davt 181 | # p_s_estimator_type : transformer 182 | # p_s_estimator_type : cnn 183 | p_s_estimator_cnn_pretrain : False 184 | # p_s_estimator_cnn_pretrain : True 185 | use_p_s_estimator_att_inside : False 186 | # use_p_s_estimator_att_inside : True -------------------------------------------------------------------------------- /models/inferring_shared_attention_estimation_debug.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import sys 5 | 6 | class InferringSharedAttentionEstimator(nn.Module): 7 | def __init__(self, cfg): 8 | super(InferringSharedAttentionEstimator, 
self).__init__() 9 | 10 | ## set useful variables 11 | self.epsilon = 1e-7 12 | self.pi = 3.1415 13 | 14 | ## set data 15 | self.dataset_name = cfg.data.name 16 | 17 | ## exp settings 18 | self.resize_width = cfg.exp_set.resize_width 19 | self.resize_height = cfg.exp_set.resize_height 20 | 21 | self.gpu_list = range(cfg.exp_set.gpu_start, cfg.exp_set.gpu_finish+1) 22 | self.device = torch.device(f"cuda:{self.gpu_list[0]}") 23 | self.wandb_name = cfg.exp_set.wandb_name 24 | self.batch_size = cfg.exp_set.batch_size 25 | 26 | # define loss function 27 | self.loss = cfg.exp_params.loss 28 | if self.loss == 'mse': 29 | print('Use MSE loss function') 30 | self.loss_func_joint_attention = nn.MSELoss() 31 | elif self.loss == 'bce': 32 | print('Use BCE loss function') 33 | self.loss_func_joint_attention = nn.BCELoss() 34 | elif self.loss == 'l1': 35 | print('Use l1 loss function') 36 | self.loss_func_joint_attention = nn.L1Loss() 37 | self.use_e_att_loss = cfg.exp_params.use_e_att_loss 38 | 39 | self.conv_in_channels = 1+1 40 | self.spatial_detection_module = nn.Sequential( 41 | nn.Conv2d(in_channels=self.conv_in_channels, out_channels=16, kernel_size=3, padding=1), 42 | nn.ReLU(inplace=False), 43 | nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=1), 44 | nn.ReLU(inplace=False), 45 | nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding=1), 46 | nn.ReLU(inplace=False), 47 | nn.Conv2d(in_channels=8, out_channels=1, kernel_size=1), 48 | # nn.Sigmoid(), 49 | ) 50 | 51 | def forward(self, inp): 52 | head_vector = inp['head_vector'] 53 | head_feature = inp['head_feature'] 54 | xy_axis_map = inp['xy_axis_map'] 55 | head_xy_map = inp['head_xy_map'] 56 | gaze_xy_map = inp['gaze_xy_map'] 57 | saliency_img = inp['saliency_img'] 58 | 59 | people_exist_mask = (torch.sum(head_feature, dim=-1) != 0).bool() 60 | people_exist_num = torch.sum(people_exist_mask, dim=-1) 61 | 62 | # generate head xy map 63 | head_xy_map = head_xy_map * head_feature[:, :, :2, None, None] 64 | 65 | # generate gaze xy map 66 | gaze_xy_map = gaze_xy_map * head_vector[:, :, :, None, None] 67 | 68 | # generate gaze cone map 69 | xy_axis_map_dif_head = xy_axis_map - head_xy_map 70 | x_axis_map_dif_head_mul_gaze = xy_axis_map_dif_head * gaze_xy_map 71 | xy_dot_product = torch.sum(x_axis_map_dif_head_mul_gaze, dim=2) 72 | xy_dot_product = xy_dot_product / (torch.norm(xy_axis_map_dif_head, dim=2) + self.epsilon) 73 | xy_dot_product = xy_dot_product / (torch.norm(gaze_xy_map, dim=2) + self.epsilon) 74 | 75 | # calculate theta and distance map 76 | theta_x_y = torch.acos(torch.clamp(xy_dot_product, -1+self.epsilon, 1-self.epsilon)) 77 | 78 | # generate sigma of gaussian 79 | # multiply zero to padding maps 80 | self.gaussian_sigma = 0.5 81 | angle_dist = torch.exp(-torch.pow(theta_x_y, 2)/(2*self.gaussian_sigma**2)) / self.gaussian_sigma 82 | angle_dist = angle_dist * (torch.sum(head_feature, dim=2) != 0)[:, :, None, None] 83 |
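        # NOTE (added summary of the cone computation above): for person i and pixel p,
        #   cos(theta_i(p)) = (p - head_i) . gaze_i / ((|p - head_i| + eps) * (|gaze_i| + eps))
        #   theta_i(p)      = acos(clamp(cos(theta_i(p)), -1 + eps, 1 - eps))
        #   cone_i(p)       = exp(-theta_i(p)^2 / (2 * sigma^2)) / sigma, with sigma fixed to 0.5,
        # and the cones of padding people (rows whose head_feature is all zero) are zeroed out
        # before the per-person cones are pooled below.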
84 | # sum all gaze maps (divide people num excluding padding people) 85 | angle_dist_sum_pooling = torch.sum(angle_dist, dim=1)[:, None, :, :] 86 | angle_dist_sum_pooling = angle_dist_sum_pooling/people_exist_num[:, None, None, None] 87 | 88 | # cat angle img and saliency img 89 | # spatial detection module 90 | angle_saliency_img = torch.cat([angle_dist_sum_pooling, saliency_img], dim=1) 91 | estimated_joint_attention = self.spatial_detection_module(angle_saliency_img) 92 | 93 | # return final img 94 | estimated_joint_attention = estimated_joint_attention[:, 0, :, :] 95 | 96 | # pack return values 97 | data = {} 98 | data['head_tensor'] = head_vector 99 | data['img_pred'] = estimated_joint_attention 100 | data['angle_dist'] = angle_dist 101 | data['angle_dist_pool'] = angle_dist_sum_pooling 102 | data['saliency_map'] = saliency_img 103 | 104 | return data 105 | 106 | def calc_loss(self, inp, out, cfg): 107 | # unpack data 108 | img_gt = inp['img_gt'] 109 | att_inside_flag = inp['att_inside_flag'] 110 | img_pred = out['img_pred'] 111 | 112 | # switch loss coefficient 113 | if self.use_e_att_loss: 114 | loss_map_coef = 1 115 | else: 116 | loss_map_coef = 0 117 | 118 | # calculate final map loss 119 | img_gt_att = torch.sum(img_gt, dim=1) 120 | img_gt_att_thresh = torch.ones(1, dtype=img_gt_att.dtype, device=img_gt_att.device) 121 | img_gt_att = torch.where(img_gt_att>img_gt_att_thresh, img_gt_att_thresh, img_gt_att) 122 | loss_map = self.loss_func_joint_attention(img_pred.float(), img_gt_att.float()) 123 | loss_map = loss_map_coef * loss_map 124 | 125 | loss_set = {} 126 | loss_set['loss_map'] = loss_map 127 | 128 | return loss_set -------------------------------------------------------------------------------- /analysis/iccv2023/input_ablation_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define analyze model type 11 | analyze_name_list = [] 12 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT') 13 | # analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only') 14 | 15 | # define ablation types 16 | analyze_name_ablation_list = [] 17 | analyze_name_ablation_list.append('_wo_position') 18 | analyze_name_ablation_list.append('_wo_gaze') 19 | analyze_name_ablation_list.append('_wo_p_p') 20 | analyze_name_ablation_list.append('_wo_p_s') 21 | analyze_name_ablation_list.append('') 22 | 23 | # define model name 24 | model_name_list = [] 25 | model_name_list.append('Ours w/o p') 26 | model_name_list.append('Ours w/o g') 27 | model_name_list.append('Ours w/o branch (a)') 28 | model_name_list.append('Ours w/o branch (b)') 29 | model_name_list.append('Ours') 30 | 31 | # define test data type 32 | test_data_type_list = [] 33 | test_data_type_list.append('bbox_gt_gaze_True') 34 | # test_data_type_list.append('bbox_det_gaze_False') 35 | for test_data_type in test_data_type_list: 36 | print(f'==={test_data_type}===') 37 | for analyze_name in analyze_name_list: 38 | eval_results_list = [] 39 | analyze_name_type = analyze_name 40 | for ablation_name in analyze_name_ablation_list: 41 | 42 | model_name = f'{analyze_name_type}{ablation_name}' 43 | json_file_path = os.path.join(saved_result_dir, model_name, 'eval_results', test_data_type, 'eval_results.json') 44 | 45 | with open(json_file_path, 'r') as f: 46 | eval_results_dic = json.load(f) 47 | 48 | eval_results_dic_update = {} 49 | if ablation_name == '_wo_p_p': 50 | eval_results_dic_update['Dist p-p (euc)'] = eval_results_dic['l2_dist_euc_p_p'] 51 | eval_results_dic_update['Dist p-s (euc)'] = eval_results_dic['l2_dist_euc_p_s'] 52 | eval_results_dic_update['Dist final (euc)'] = eval_results_dic['l2_dist_euc_p_s'] 53 | for i in range(20): 54 | thr = i*10 55 | eval_results_dic_update[f'Det p-p (Thr={thr})'] = eval_results_dic[f'Det p-p (Thr={thr})'] 56 | for i in range(20): 57 | thr = i*10 58 | eval_results_dic_update[f'Det p-s (Thr={thr})'] =
eval_results_dic[f'Det p-s (Thr={thr})'] 59 | for i in range(20): 60 | thr = i*10 61 | eval_results_dic_update[f'Det final (Thr={thr})'] = eval_results_dic[f'Det p-s (Thr={thr})'] 62 | eval_results_dic_update['Accuracy p-p'] = eval_results_dic['accuracy p-p'] 63 | eval_results_dic_update['Accuracy p-s'] = eval_results_dic['accuracy p-s'] 64 | eval_results_dic_update['Accuracy final'] = eval_results_dic['accuracy p-s'] 65 | eval_results_dic_update['F-score p-p'] = eval_results_dic['f1 p-p'] 66 | eval_results_dic_update['F-score p-s'] = eval_results_dic['f1 p-s'] 67 | eval_results_dic_update['F-score final'] = eval_results_dic['f1 p-s'] 68 | eval_results_dic_update['AUC p-p'] = eval_results_dic['auc p-p'] 69 | eval_results_dic_update['AUC p-s'] = eval_results_dic['auc p-s'] 70 | eval_results_dic_update['AUC final'] = eval_results_dic['auc p-s'] 71 | else: 72 | eval_results_dic_update['Dist p-p (euc)'] = eval_results_dic['l2_dist_euc_p_p'] 73 | eval_results_dic_update['Dist p-s (euc)'] = eval_results_dic['l2_dist_euc_p_s'] 74 | eval_results_dic_update['Dist final (euc)'] = eval_results_dic['l2_dist_euc_final'] 75 | for i in range(20): 76 | thr = i*10 77 | eval_results_dic_update[f'Det p-p (Thr={thr})'] = eval_results_dic[f'Det p-p (Thr={thr})'] 78 | for i in range(20): 79 | thr = i*10 80 | eval_results_dic_update[f'Det p-s (Thr={thr})'] = eval_results_dic[f'Det p-s (Thr={thr})'] 81 | for i in range(20): 82 | thr = i*10 83 | eval_results_dic_update[f'Det final (Thr={thr})'] = eval_results_dic[f'Det final (Thr={thr})'] 84 | eval_results_dic_update['Accuracy p-p'] = eval_results_dic['accuracy p-p'] 85 | eval_results_dic_update['Accuracy p-s'] = eval_results_dic['accuracy p-s'] 86 | eval_results_dic_update['Accuracy final'] = eval_results_dic['accuracy final'] 87 | eval_results_dic_update['F-score p-p'] = eval_results_dic['f1 p-p'] 88 | eval_results_dic_update['F-score p-s'] = eval_results_dic['f1 p-s'] 89 | eval_results_dic_update['F-score final'] = eval_results_dic['f1 final'] 90 | eval_results_dic_update['AUC p-p'] = eval_results_dic['auc p-p'] 91 | eval_results_dic_update['AUC p-s'] = eval_results_dic['auc p-s'] 92 | eval_results_dic_update['AUC final'] = eval_results_dic['auc final'] 93 | 94 | eval_results_list.append(list(eval_results_dic_update.values())) 95 | eval_metrics_list = list(eval_results_dic_update.keys()) 96 | 97 | eval_results_array = np.array(eval_results_list) 98 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 99 | save_csv_file_path = os.path.join(saved_result_dir, f'ablation_{analyze_name}_{test_data_type}_videocoatt.csv') 100 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /analysis/iccv2023/comparison_on_videocoatt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | import pandas as pd 5 | import numpy as np 6 | import glob 7 | 8 | saved_result_dir = os.path.join('results', 'videocoatt') 9 | 10 | # define model names 11 | model_name_list = [] 12 | model_name_list.append('ISA') 13 | model_name_list.append('DAVT') 14 | model_name_list.append('HGTD') 15 | model_name_list.append('Ours') 16 | 17 | # define test data type 18 | test_data_type_list = [] 19 | test_data_type_list.append('bbox_det_gaze_False') 20 | test_data_type_list.append('bbox_det_gaze_False') 21 | test_data_type_list.append('bbox_gt_gaze_True') 22 | 23 | # define training modality type 24 | train_mode_list = [] 25 
| train_mode_list.append('Pr') 26 | train_mode_list.append('GT') 27 | train_mode_list.append('GT') 28 | 29 | # define analyze model names 30 | analyze_name_list_dic = {} 31 | 32 | # (Train:Test = Pr:Pr) 33 | analyze_name_list = [] 34 | analyze_name_list.append('videocoatt-isa_bbox_PRED_gaze_PRED') 35 | # analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_freeze') 36 | analyze_name_list.append('videoattentiontarget-only_davt_PRED') 37 | # analyze_name_list.append('videoattentiontarget-hgt-high') 38 | analyze_name_list.append('videoattentiontarget-hgt-hgt_bbox_PRED') 39 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix') 40 | analyze_name_list_dic[0] = analyze_name_list 41 | 42 | # (Train:Test = GT:Pr) 43 | analyze_name_list = [] 44 | analyze_name_list.append('videocoatt-isa_bbox_GT_gaze_GT') 45 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_freeze') 46 | analyze_name_list.append('videoattentiontarget-hgt-high') 47 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT') 48 | analyze_name_list_dic[1] = analyze_name_list 49 | 50 | # (Train:Test = GT:GT) 51 | analyze_name_list = [] 52 | analyze_name_list.append('videocoatt-isa_bbox_GT_gaze_GT') 53 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_freeze') 54 | analyze_name_list.append('videoattentiontarget-hgt-high') 55 | analyze_name_list.append('videocoatt-p_p_field_deep_p_s_davt_scalar_weight_fix_token_only_GT') 56 | analyze_name_list_dic[2] = analyze_name_list 57 | 58 | for data_type_idx, analyze_name_list in analyze_name_list_dic.items(): 59 | eval_results_list = [] 60 | test_data_type = test_data_type_list[data_type_idx] 61 | print(f'==={test_data_type}===') 62 | for analyze_idx, analyze_name in enumerate(analyze_name_list): 63 | model_name = model_name_list[analyze_idx] 64 | print(model_name, analyze_name) 65 | json_file_path = os.path.join(saved_result_dir, analyze_name, 'eval_results', test_data_type, 'eval_results.json') 66 | with open(json_file_path, 'r') as f: 67 | eval_results_dic = json.load(f) 68 | 69 | eval_results_dic_update = {} 70 | if model_name in ['ISA', 'HGTD']: 71 | eval_results_dic_update['Dist(x)'] = eval_results_dic['l2_dist_x'] 72 | eval_results_dic_update['Dist(y)'] = eval_results_dic['l2_dist_y'] 73 | eval_results_dic_update['Dist(euc)'] = eval_results_dic['l2_dist_euc'] 74 | for i in range(20): 75 | thr = i*10 76 | eval_results_dic_update[f'Det(Thr={thr})'] = eval_results_dic[f'Det (Thr={thr})'] 77 | eval_results_dic_update['Accuracy'] = eval_results_dic['accuracy'] 78 | eval_results_dic_update['Precision'] = eval_results_dic['precision'] 79 | eval_results_dic_update['Recall'] = eval_results_dic['recall'] 80 | eval_results_dic_update['F-score'] = eval_results_dic['f1'] 81 | eval_results_dic_update['AUC'] = eval_results_dic['auc'] 82 | elif model_name in ['DAVT']: 83 | eval_results_dic_update['Dist(x)'] = eval_results_dic['l2_dist_x_p_s'] 84 | eval_results_dic_update['Dist(y)'] = eval_results_dic['l2_dist_y_p_s'] 85 | eval_results_dic_update['Dist(euc)'] = eval_results_dic['l2_dist_euc_p_s'] 86 | for i in range(20): 87 | thr = i*10 88 | eval_results_dic_update[f'Det(Thr={thr})'] = eval_results_dic[f'Det p-s (Thr={thr})'] 89 | eval_results_dic_update['Accuracy'] = eval_results_dic['accuracy p-s'] 90 | eval_results_dic_update['Precision'] = eval_results_dic['precision p-s'] 91 | eval_results_dic_update['Recall'] = eval_results_dic['recall p-s'] 92 | eval_results_dic_update['F-score'] = eval_results_dic['f1
p-s'] 93 | eval_results_dic_update['AUC'] = eval_results_dic['auc p-s'] 94 | elif model_name in ['Ours']: 95 | eval_results_dic_update['Dist(x)'] = eval_results_dic['l2_dist_x_final'] 96 | eval_results_dic_update['Dist(y)'] = eval_results_dic['l2_dist_y_final'] 97 | eval_results_dic_update['Dist(euc)'] = eval_results_dic['l2_dist_euc_final'] 98 | for i in range(20): 99 | thr = i*10 100 | eval_results_dic_update[f'Det(Thr={thr})'] = eval_results_dic[f'Det final (Thr={thr})'] 101 | eval_results_dic_update['Accuracy'] = eval_results_dic['accuracy final'] 102 | eval_results_dic_update['Precision'] = eval_results_dic['precision final'] 103 | eval_results_dic_update['Recall'] = eval_results_dic['recall final'] 104 | eval_results_dic_update['F-score'] = eval_results_dic['f1 final'] 105 | eval_results_dic_update['AUC'] = eval_results_dic['auc final'] 106 | 107 | eval_results_list.append(list(eval_results_dic_update.values())) 108 | eval_metrics_list = list(eval_results_dic_update.keys()) 109 | 110 | eval_results_array = np.array(eval_results_list) 111 | df_eval_results = pd.DataFrame(eval_results_array, model_name_list, eval_metrics_list) 112 | save_csv_file_path = os.path.join(saved_result_dir, f'comparision_on_videocoatt_{train_mode_list[data_type_idx]}_{test_data_type}.csv') 113 | df_eval_results.to_csv(save_csv_file_path) -------------------------------------------------------------------------------- /yaml_files/volleyball/train_ours.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: volleyball 3 | sendo_dataset_dir : data/volleyball_tracking_annotation 4 | rgb_dataset_dir : data/videos 5 | annotation_dir : data/vatic_ball_annotation/annotation_data/ 6 | dataset_bbox_gt: data/jae_dataset_bbox_gt 7 | dataset_bbox_pred: data/jae_dataset_bbox_pred 8 | 9 | exp_set: 10 | save_folder: saved_weights 11 | wandb_log : True 12 | 13 | # [Ours] 14 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_weight_fusion_fine_token_only 15 | # model_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_weight_fusion_fine_token_only 16 | 17 | # batch_size: 8 18 | batch_size: 4 19 | num_workers: 16 20 | seed_num: 777 21 | gpu_mode : True 22 | gpu_start : 6 23 | gpu_finish : 6 24 | 25 | resize_height : 320 26 | resize_width : 640 27 | resize_head_width: 64 28 | resize_head_height: 64 29 | 30 | exp_params: 31 | 32 | use_person_person_att_loss : False 33 | # use_person_person_att_loss : True 34 | person_person_att_loss_weight : 1 35 | 36 | # use_person_person_jo_att_loss : False 37 | use_person_person_jo_att_loss : True 38 | person_person_jo_att_loss_weight : 1 39 | 40 | # use_person_scene_att_loss : False 41 | use_person_scene_att_loss : True 42 | person_scene_att_loss_weight : 1 43 | 44 | use_person_scene_jo_att_loss : False 45 | # use_person_scene_jo_att_loss : True 46 | person_scene_jo_att_loss_weight : 1 47 | 48 | # use_final_jo_att_loss : False 49 | use_final_jo_att_loss : True 50 | final_jo_att_loss_weight : 1 51 | 52 | use_frame_type: mid 53 | # use_frame_type: all 54 | 55 | bbox_types: GT 56 | # bbox_types: PRED 57 | 58 | gaze_types: GT 59 | # gaze_types: PRED 60 | 61 | action_types: GT 62 | # action_types: PRED 63 | 64 | # position augmentation 65 | use_position_aug: False 66 | # use_position_aug: True 67 | position_aug_std: 0.05 68 | 69 | # loss function 70 | loss : mse 71 | # loss : bce 72 | 73 | # learning rate 74 | lr : 0.00001 75 | # lr : 0.0001 76 | # lr : 0.001 77 | # lr : 0.01 78 | # lr 
: 0.1 79 | 80 | # gt gaussian 81 | # gaussian_sigma: 40 82 | gaussian_sigma: 10 83 | # gaussian_sigma: 5 84 | 85 | # learning schedule 86 | # nEpochs : 500 87 | nEpochs : 15 88 | # nEpochs : 50 89 | start_iter : 0 90 | snapshots : 100 91 | scheduler_start : 1000 92 | scheduler_iter : 1100000 93 | 94 | # pretrained models 95 | pretrained_models_dir: saved_weights 96 | 97 | # use_pretrained_head_pose_estimator: False 98 | use_pretrained_head_pose_estimator: True 99 | pretrained_head_pose_estimator_name: volleyball-head_pose_estimator 100 | # freeze_head_pose_estimator: False 101 | freeze_head_pose_estimator: True 102 | 103 | # use_pretrained_saliency_extractor: False 104 | use_pretrained_saliency_extractor: True 105 | # pretrained_saliency_extractor_name: pretrained_scene_extractor_davt 106 | pretrained_saliency_extractor_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_GT_gaze_GT_act_GT_p_s_only 107 | # pretrained_saliency_extractor_name: volleyball-dual-mid_p_p_field_middle_p_s_davt_bbox_PRED_gaze_PRED_act_PRED_p_s_only 108 | freeze_saliency_extractor: False 109 | # freeze_saliency_extractor: True 110 | 111 | # use_pretrained_joint_attention_estimator: False 112 | use_pretrained_joint_attention_estimator: True 113 | pretrained_joint_attention_estimator_name: volleyball-dual-mid_p_p_field_middle_bbox_GT_gaze_GT_act_GT 114 | # pretrained_joint_attention_estimator_name: volleyball-dual-mid_p_p_field_middle_bbox_PRED_gaze_PRED_act_PRED 115 | freeze_joint_attention_estimator: False 116 | # freeze_joint_attention_estimator: True 117 | 118 | model_params: 119 | model_type: ja_transformer_dual 120 | 121 | # Position 122 | # use_position : False 123 | use_position : True 124 | 125 | # Gaze 126 | # use_gaze : False 127 | use_gaze : True 128 | 129 | # Action 130 | # use_action : False 131 | use_action : True 132 | 133 | # Person embedding 134 | # head_embedding_type : liner 135 | head_embedding_type : mlp 136 | 137 | # Whole image 138 | # use_img : False 139 | use_img : True 140 | 141 | # person-person transformer 142 | people_feat_dim : 16 143 | # people_feat_dim : 32 144 | # use_people_people_trans: False 145 | use_people_people_trans: True 146 | # people_people_trans_enc_num : 1 147 | people_people_trans_enc_num : 2 148 | # people_people_trans_enc_num : 3 149 | # people_people_trans_enc_num : 4 150 | # mha_num_heads_people_people : 1 151 | mha_num_heads_people_people : 2 152 | # mha_num_heads_people_people : 4 153 | # mha_num_heads_people_people : 8 154 | # mha_num_heads_people_people : 16 155 | 156 | # p_p_estimator_type : fc_shallow 157 | # p_p_estimator_type : fc_middle 158 | # p_p_estimator_type : fc_deep 159 | # p_p_estimator_type : deconv_shallow 160 | # p_p_estimator_type : deconv_middle 161 | # p_p_estimator_type : deconv_deep 162 | # p_p_estimator_type : field_shallow 163 | p_p_estimator_type : field_middle 164 | # p_p_estimator_type : field_deep 165 | 166 | # p_p_aggregation_type : ind_only 167 | p_p_aggregation_type : token_only 168 | # p_p_aggregation_type : ind_and_token_ind_based 169 | # p_p_aggregation_type : ind_and_token_token_based 170 | 171 | # rgb-person transformer 172 | rgb_feat_dim : 256 173 | rgb_people_trans_enc_num : 1 174 | mha_num_heads_rgb_people : 1 175 | 176 | # rgb_cnn_extractor_type : rgb_patch 177 | # rgb_patch_size : 8 178 | # rgb_cnn_extractor_type : resnet18 179 | rgb_cnn_extractor_type : resnet50 180 | # rgb_cnn_extractor_stage_idx : 1 181 | # rgb_cnn_extractor_stage_idx : 2 182 | # rgb_cnn_extractor_stage_idx : 3 183 | rgb_cnn_extractor_stage_idx : 4 
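# NOTE (explanatory comment, not in the original config): the person-scene
# branch below supports three estimator types. davt is the setting used for
# the released models and presumably refers to the DAVT attention-target
# detection model linked in the README; transformer and cnn appear to be the
# variants exercised by the ablation scripts under analysis/iccv2023/.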
184 | 185 | p_s_estimator_type : davt 186 | # p_s_estimator_type : transformer 187 | # p_s_estimator_type : cnn 188 | p_s_estimator_cnn_pretrain : False 189 | # p_s_estimator_cnn_pretrain : True 190 | use_p_s_estimator_att_inside : False 191 | # use_p_s_estimator_att_inside : True 192 | 193 | # fusion_net_type : early 194 | # fusion_net_type : mid 195 | # fusion_net_type : late 196 | fusion_net_type : scalar_weight 197 | # fusion_net_type : simple_average -------------------------------------------------------------------------------- /models/hourglass.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import torchvision.models as models 5 | from torch.autograd import Variable 6 | 7 | Pool = nn.MaxPool2d 8 | 9 | def batchnorm(x): 10 | return nn.BatchNorm2d(x.size()[1])(x) 11 | 12 | class Conv(nn.Module): 13 | def __init__(self, inp_dim, out_dim, kernel_size=3, stride = 1, bn = False, relu = True): 14 | super(Conv, self).__init__() 15 | self.inp_dim = inp_dim 16 | self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=True) 17 | self.relu = None 18 | self.bn = None 19 | if relu: 20 | self.relu = nn.ReLU() 21 | if bn: 22 | self.bn = nn.BatchNorm2d(out_dim) 23 | 24 | def forward(self, x): 25 | assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim) 26 | x = self.conv(x) 27 | if self.bn is not None: 28 | x = self.bn(x) 29 | if self.relu is not None: 30 | x = self.relu(x) 31 | return x 32 | 33 | class Residual(nn.Module): 34 | def __init__(self, inp_dim, out_dim): 35 | super(Residual, self).__init__() 36 | self.relu = nn.ReLU() 37 | self.bn1 = nn.BatchNorm2d(inp_dim) 38 | self.conv1 = Conv(inp_dim, int(out_dim/2), 1, relu=False) 39 | self.bn2 = nn.BatchNorm2d(int(out_dim/2)) 40 | self.conv2 = Conv(int(out_dim/2), int(out_dim/2), 3, relu=False) 41 | self.bn3 = nn.BatchNorm2d(int(out_dim/2)) 42 | self.conv3 = Conv(int(out_dim/2), out_dim, 1, relu=False) 43 | self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False) 44 | if inp_dim == out_dim: 45 | self.need_skip = False 46 | else: 47 | self.need_skip = True 48 | 49 | def forward(self, x): 50 | if self.need_skip: 51 | residual = self.skip_layer(x) 52 | else: 53 | residual = x 54 | out = self.bn1(x) 55 | out = self.relu(out) 56 | out = self.conv1(out) 57 | out = self.bn2(out) 58 | out = self.relu(out) 59 | out = self.conv2(out) 60 | out = self.bn3(out) 61 | out = self.relu(out) 62 | out = self.conv3(out) 63 | out += residual 64 | return out 65 | 66 | class Hourglass(nn.Module): 67 | def __init__(self, n, f, bn=None, increase=0): 68 | super(Hourglass, self).__init__() 69 | nf = f + increase 70 | self.up1 = Residual(f, f) 71 | # Lower branch 72 | self.pool1 = Pool(2, 2) 73 | self.low1 = Residual(f, nf) 74 | self.n = n 75 | # Recursive hourglass 76 | if self.n > 1: 77 | self.low2 = Hourglass(n-1, nf, bn=bn) 78 | else: 79 | self.low2 = Residual(nf, nf) 80 | self.low3 = Residual(nf, f) 81 | self.up2 = nn.Upsample(scale_factor=2, mode='nearest') 82 | 83 | def forward(self, x): 84 | up1 = self.up1(x) 85 | pool1 = self.pool1(x) 86 | low1 = self.low1(pool1) 87 | low2 = self.low2(low1) 88 | low3 = self.low3(low2) 89 | up2 = self.up2(low3) 90 | 91 | return up1 + up2 92 | 93 | class Merge(nn.Module): 94 | def __init__(self, x_dim, y_dim): 95 | super(Merge, self).__init__() 96 | self.conv = Conv(x_dim, y_dim, 1, relu=False, bn=False) 97 | 98 | def forward(self, x): 99 | return 
self.conv(x) 100 | 101 | # Hourglass networks 102 | class HourglassNet(nn.Module): 103 | def __init__(self, nstack, inp_dim, oup_dim, bn=False, increase=0, **kwargs): 104 | super(HourglassNet, self).__init__() 105 | 106 | self.nstack = nstack 107 | self.pre = nn.Sequential( 108 | Conv(3, 64, 7, 2, bn=True, relu=True), 109 | Residual(64, 128), 110 | Pool(2, 2), 111 | Residual(128, 128), 112 | Residual(128, inp_dim) 113 | ) 114 | 115 | self.hgs = nn.ModuleList( [ 116 | nn.Sequential( 117 | Hourglass(1, inp_dim, bn, increase), 118 | ) for i in range(nstack)] ) 119 | 120 | self.features = nn.ModuleList( [ 121 | nn.Sequential( 122 | Residual(inp_dim, inp_dim), 123 | Conv(inp_dim, inp_dim, 1, bn=True, relu=True) 124 | ) for i in range(nstack)] ) 125 | 126 | self.outs = nn.ModuleList( [Conv(inp_dim, oup_dim, 1, relu=False, bn=False) for i in range(nstack)] ) 127 | self.merge_features = nn.ModuleList( [Merge(inp_dim, inp_dim) for i in range(nstack-1)] ) 128 | self.merge_preds = nn.ModuleList( [Merge(oup_dim, inp_dim) for i in range(nstack-1)] ) 129 | self.nstack = nstack 130 | self.sigmoid = nn.Sigmoid() 131 | 132 | def forward(self, inp): 133 | rgb_img = inp['rgb_img'] 134 | rgb_img_wo_norm = inp['rgb_img_wo_norm'] 135 | batch_size, frame_num, channel, resize_height, resize_width = rgb_img_wo_norm.shape 136 | rgb_img_wo_norm = rgb_img_wo_norm.reshape(batch_size*frame_num, channel, resize_height, resize_width) 137 | 138 | x = self.pre(rgb_img_wo_norm) 139 | combined_hm_preds = [] 140 | for i in range(self.nstack): 141 | hg = self.hgs[i](x) 142 | feature = self.features[i](hg) 143 | preds = self.outs[i](feature) 144 | preds[:, 0, :, :] = self.sigmoid(preds[:, 0, :, :]) 145 | combined_hm_preds.append(preds) 146 | if i < self.nstack - 1: 147 | x = x + self.merge_preds[i](preds) + self.merge_features[i](feature) 148 | 149 | # only return prob map (excluding size and offset) 150 | saliency_img = torch.stack(combined_hm_preds, 1)[:, -1, 0, :, :][:, None, :, :] 151 | ori_height, ori_width = rgb_img.shape[-2:] 152 | saliency_img = F.interpolate(saliency_img, (ori_height, ori_width), mode='bilinear') 153 | saliency_img = saliency_img.reshape(batch_size, frame_num, 1, ori_height, ori_width) 154 | 155 | # pack return values 156 | data = {} 157 | data['saliency_img'] = saliency_img 158 | 159 | return data 160 | 161 | def calc_loss(self, inp, out, cfg): 162 | loss_set = {} 163 | return loss_set --------------------------------------------------------------------------------
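A minimal smoke test for the HourglassNet saliency extractor above (an editor's sketch, not part of the repository: the constructor sizes are illustrative, and the input keys, shapes, and output key are read off the forward() method; 320x640 matches the resize_height/resize_width used in the volleyball configs):

```
import torch

from models.hourglass import HourglassNet

# Illustrative sizes: 2 stacked hourglasses, 128-d features, 3 output channels
# (channel 0 is the sigmoid-activated probability map that forward() returns).
net = HourglassNet(nstack=2, inp_dim=128, oup_dim=3).eval()

inp = {
    'rgb_img': torch.randn(1, 1, 3, 320, 640),         # normalized frames, (B, T, C, H, W)
    'rgb_img_wo_norm': torch.rand(1, 1, 3, 320, 640),  # un-normalized frames, same shape
}
with torch.no_grad():
    out = net(inp)
print(out['saliency_img'].shape)  # torch.Size([1, 1, 1, 320, 640])
```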