├── models ├── __init__.py ├── get_model.py ├── baseline_anticipate_cnn.py ├── baseline_pose.py ├── model_pos.py ├── backbone │ ├── resnet_based.py │ ├── imagenet_pretraining.py │ └── resnet │ │ ├── bottleneck.py │ │ ├── basicblock.py │ │ └── resnet.py └── model_concat.py ├── data ├── __init__.py ├── get_data_loader.py └── jaad_loc.py ├── .gitignore ├── utils ├── downsample.py ├── __init__.py ├── parse_box_from_segm.py ├── check_stip_time.py ├── check_ckpt.py ├── build.py ├── logger.py ├── check_switch.py ├── data_proc_loc.py ├── visual.py ├── stats.py ├── temp_plot.py ├── temp_plot_stip.py ├── colormap.py ├── draw_As.py ├── tracker.py ├── masking.py ├── postproc_smooth.py ├── stip_merge_views.py └── data_proc.py ├── args ├── __init__.py ├── train_args.py └── test_args.py ├── graphs ├── acc_temp.png ├── acc_across_time.png ├── acc_temp_graph.png ├── mix_temp_graph.png ├── prob_temp_graph.png ├── acc_temp_graph_smooth.png ├── mix_temp_graph_smooth.png └── prob_temp_graph_smooth.png ├── scripts_dir ├── train_pos.sh ├── train_baseline_anticipate_cnn.sh ├── train_baseline_pose.sh ├── wandb │ └── offline-run-20210314_201226-2eciwruc │ │ └── logs │ │ └── .nfs000000000001e91000000039 ├── train_concat.sh ├── test_baseline_anticipate_cnn.sh ├── train_loc_concat.sh ├── test_loc_concat.sh ├── train_concat_stip.sh ├── train_graph_fromMasks.sh ├── test_concat_stip.sh ├── train_graph.sh ├── test_concat_stip_side.sh ├── test_concat.sh ├── train_loc_graph.sh ├── train_graph_stip_all.sh ├── train_graph_stip.sh ├── test_graph_stip_all.sh ├── test_graph_stip.sh └── test_graph.sh ├── cache_data.py ├── README.md ├── train.py ├── cache_data_stip.py └── test.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .get_model import * 2 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .get_data_loader import * 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bin/ 2 | datasets/ 3 | *.pyc 4 | __pycache__/ 5 | -------------------------------------------------------------------------------- /utils/downsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build 2 | from .data_proc import * 3 | 4 | -------------------------------------------------------------------------------- /args/__init__.py: -------------------------------------------------------------------------------- 1 | from .train_args import TrainArgs 2 | from .test_args import TestArgs 3 | -------------------------------------------------------------------------------- /graphs/acc_temp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/acc_temp.png -------------------------------------------------------------------------------- /utils/parse_box_from_segm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | 5 | 
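# The file above is only a stub. Below is a minimal sketch of what a
# box-from-segmentation utility could look like, not the repository's own
# implementation: the per-instance mask layout, the '*_segm.npy' naming, and
# the helper names box_from_mask / parse_boxes are assumptions for
# illustration. Only the [x, y, w, h] box convention is taken from
# utils/data_proc_loc.py.
def box_from_mask(mask):
  # mask: (H, W) array that is nonzero where the instance is present
  ys, xs = np.where(mask)
  if len(ys) == 0:
    return []
  x, y = int(xs.min()), int(ys.min())
  w, h = int(xs.max()) - x + 1, int(ys.max()) - y + 1
  return [x, y, w, h]

def parse_boxes(segm_dir):
  # collect one list of boxes per segmentation file in segm_dir
  boxes = {}
  for fsegm in sorted(glob(os.path.join(segm_dir, '*_segm.npy'))):
    segm = np.load(fsegm)
    # one binary mask per instance id (0 is treated as background)
    boxes[fsegm] = [box_from_mask(segm == i) for i in np.unique(segm) if i]
  return boxes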
-------------------------------------------------------------------------------- /graphs/acc_across_time.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/acc_across_time.png -------------------------------------------------------------------------------- /graphs/acc_temp_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/acc_temp_graph.png -------------------------------------------------------------------------------- /graphs/mix_temp_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/mix_temp_graph.png -------------------------------------------------------------------------------- /graphs/prob_temp_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/prob_temp_graph.png -------------------------------------------------------------------------------- /graphs/acc_temp_graph_smooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/acc_temp_graph_smooth.png -------------------------------------------------------------------------------- /graphs/mix_temp_graph_smooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/mix_temp_graph_smooth.png -------------------------------------------------------------------------------- /graphs/prob_temp_graph_smooth.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/prob_temp_graph_smooth.png -------------------------------------------------------------------------------- /utils/check_stip_time.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | 4 | stip_root = '/vision/group/prolix' 5 | ped_root = os.path.join(stip_root, 'processed/pedestrians') 6 | img_root = os.path.join(stip_root, 'images_20fps') 7 | -------------------------------------------------------------------------------- /utils/check_ckpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | ckpt_root = '/sailhome/bingbin/STR-PIP/ckpts/' 5 | dataset = 'JAAD' 6 | 7 | def check(ckpt, ptype='pred'): 8 | f_best_pred = os.path.join(ckpt_root, dataset, ckpt, 'best_{}.pth'.format(ptype)) 9 | pred = torch.load(f_best_pred) 10 | print('best {} epoch:'.format(ptype), pred['epoch']) 11 | 12 | if __name__ == '__main__': 13 | ckpt = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch' 14 | check(ckpt, 'pred') 15 | check(ckpt, 'last') 16 | check(ckpt, 'det') 17 | -------------------------------------------------------------------------------- /utils/build.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import torch 5 | 6 | import args 7 | from .logger import Logger 8 | 9 | 10 | def build(is_train): 11 | opt, log = args.TrainArgs().parse() if is_train else 
args.TestArgs().parse() 12 | 13 | if not is_train: 14 | print('Options:') 15 | opt_dict = vars(opt) 16 | for key in sorted(opt_dict): 17 | print('{}: {}'.format(key, opt_dict[key])) 18 | if is_train: 19 | print('lr_init:', opt.lr_init) 20 | print('wd:', opt.wd) 21 | print('ckpt:', opt.ckpt_path) 22 | print() 23 | 24 | os.makedirs(opt.ckpt_path, exist_ok=True) 25 | 26 | # Set seed 27 | torch.manual_seed(2019) 28 | torch.cuda.manual_seed_all(2019) 29 | np.random.seed(2019) 30 | random.seed(2019) 31 | 32 | logger = Logger(opt.ckpt_path, opt.split) 33 | 34 | return opt, logger 35 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | def blue(string): 6 | return '\033[94m'+string+'\033[0m' 7 | 8 | class Logger: 9 | def __init__(self, ckpt_path, name='debug'): 10 | self.logger = logging.getLogger() 11 | self.logger.setLevel(logging.INFO) 12 | formatter = logging.Formatter('%(asctime)s %(message)s', datefmt=blue('[%Y-%m-%d,%H:%M:%S]')) 13 | 14 | fh = logging.FileHandler(os.path.join(ckpt_path, '{}.log'.format(name)), 'w') 15 | fh.setLevel(logging.INFO) 16 | fh.setFormatter(formatter) 17 | self.logger.addHandler(fh) 18 | 19 | ch = logging.StreamHandler(sys.stdout) 20 | ch.setLevel(logging.INFO) 21 | ch.setFormatter(formatter) 22 | self.logger.addHandler(ch) 23 | 24 | def print(self, log): 25 | if isinstance(log, list): 26 | self.logger.info('\n - '.join(log)) 27 | else: 28 | self.logger.info(log) 29 | -------------------------------------------------------------------------------- /models/get_model.py: -------------------------------------------------------------------------------- 1 | # ped-centric 2 | from .model_concat import ConcatModel 3 | from .model_graph import GraphModel 4 | from .model_pos import PosModel 5 | # loc-centric 6 | from .model_loc_graph import LocGraphModel 7 | from .model_loc_concat import LocConcatModel 8 | # baselines 9 | from .baseline_anticipate_cnn import BaselineAnticipateCNN 10 | from .baseline_pose import BaselinePose 11 | 12 | 13 | def get_model(opt): 14 | # ped-centric 15 | if opt.model == 'concat': 16 | model = ConcatModel(opt) 17 | elif opt.model == 'graph': 18 | model = GraphModel(opt) 19 | elif opt.model == 'pos': 20 | model = PosModel(opt) 21 | # loc-centric 22 | elif opt.model == 'loc_concat': 23 | model = LocConcatModel(opt) 24 | elif opt.model == 'loc_graph': 25 | model = LocGraphModel(opt) 26 | # baselines 27 | elif opt.model == 'baseline_anticipate_cnn': 28 | model = BaselineAnticipateCNN(opt) 29 | elif opt.model == 'baseline_pose': 30 | model = BaselinePose(opt) 31 | else: 32 | raise NotImplementedError 33 | 34 | model.setup() 35 | return model 36 | -------------------------------------------------------------------------------- /args/train_args.py: -------------------------------------------------------------------------------- 1 | from .base_args import BaseArgs 2 | 3 | 4 | class TrainArgs(BaseArgs): 5 | def __init__(self): 6 | super(TrainArgs, self).__init__() 7 | 8 | self.is_train = True 9 | # self.split = 'train' 10 | 11 | self.parser.add_argument('--batch-size', type=int, default=4, help='batch size per gpu') 12 | self.parser.add_argument('--n-epochs', type=int, default=50, help='total # of epochs') 13 | self.parser.add_argument('--n-iters', type=int, default=0, help='total # of iterations') 14 | self.parser.add_argument('--start-epoch', type=int, default=0, 
help='starting epoch') 15 | self.parser.add_argument('--lr-init', type=float, default=1e-3, help='initial learning rate') 16 | self.parser.add_argument('--lr-decay', type=int, default=0, choices=[0, 1], help='whether to decay learning rate') 17 | self.parser.add_argument('--decay-every', type=int, default=10) 18 | self.parser.add_argument('--wd', type=float, default=1e-5) 19 | self.parser.add_argument('--load-ckpt-dir', type=str, default='', help='directory of checkpoint') 20 | self.parser.add_argument('--load-ckpt-epoch', type=int, default=0, help='epoch to load checkpoint') 21 | -------------------------------------------------------------------------------- /args/test_args.py: -------------------------------------------------------------------------------- 1 | from .base_args import BaseArgs 2 | 3 | 4 | class TestArgs(BaseArgs): 5 | def __init__(self): 6 | super(TestArgs, self).__init__() 7 | 8 | self.is_train = False 9 | # self.split = 'test' 10 | 11 | self.parser.add_argument('--mode', type=str, choices=['extract', 'evaluate']) 12 | 13 | # hyperparameters 14 | self.parser.add_argument('--batch-size', type=int, default=4, help='batch size') 15 | self.parser.add_argument('--which-epoch', type=int, 16 | help='which epochs to evaluate, -1 to load the best checkpoint') 17 | self.parser.add_argument('--slide', type=int, default=1, 18 | help='Whether to use sliding window when testing.') 19 | self.parser.add_argument('--collect-A', type=int, default=0, 20 | help="Whether to collect weight matrices in the graph model.") 21 | self.parser.add_argument('--save-As-format', type=str, default="", 22 | help="Path to saved weight matrices.") 23 | # self.parser.add_argument('--rand-loader', type=int, default=0, 24 | # help="Whether to randomize the data loader.") 25 | 26 | -------------------------------------------------------------------------------- /utils/check_switch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | 5 | def find_segments(labels): 6 | durs = [] 7 | for row in labels: 8 | s_start = 0 9 | for j,val in enumerate(row): 10 | if j and row[j-1] != val: 11 | durs += j-s_start, 12 | s_start = j 13 | durs += len(row) - s_start, 14 | print('# segments:', len(durs)) 15 | print('# avg frames:', np.mean(durs)) 16 | 17 | def find_segments_wrapper(): 18 | ckpt_dir = '/sailhome/bingbin/STR-PIP/ckpts/JAAD/' 19 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-05_wd1.0e-05_bt16_posNone_branchped_collapse0_combinepair_adjTypeembed_nLayers2_v2Feats' 20 | label_name = 'label_epochbest_det_stepall.pkl' 21 | fpkl = os.path.join(ckpt_dir, ckpt_name, label_name) 22 | 23 | with open(fpkl, 'rb') as handle: 24 | data = pickle.load(handle) 25 | out = data['out'] 26 | gt = data['GT'] 27 | 28 | print('Out:') 29 | find_segments(out) 30 | print('GT:') 31 | find_segments(gt) 32 | 33 | print('\nOut det:') 34 | find_segments(out[:, :30]) 35 | print('\nOut pred:') 36 | find_segments(out[:, 30:]) 37 | print('\nGT det:') 38 | find_segments(gt[:, :30]) 39 | print('\nGT pred:') 40 | find_segments(gt[:, 30:]) 41 | 42 | if __name__ == '__main__': 43 | find_segments_wrapper() 44 | -------------------------------------------------------------------------------- /scripts_dir/train_pos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=0 4 | split='train' 5 | n_acts=9 6 | 7 | lr=1e-4 8 | wd=1e-5 9 | bt=16 10 | lr_decay=1 11 | decay_every=20 12 | n_workers=0 13 | save_every=10 
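# how often to write checkpoints and to run evaluation (both presumably in
# epochs; they are forwarded to train.py below as --save-every and
# --evaluate-every)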
14 | evaluate_every=1
15 | 
16 | seq_len=30
17 | predict=1
18 | pred_seq_len=30
19 | 
20 | load_cache='pos'
21 | # masks
22 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
23 | 
24 | use_gru=1
25 | branch='both'
26 | pos_mode='none'
27 | use_gt_act=1
28 | collapse_cls=0
29 | combine_method='pair'
30 | suffix='_'$pos_mode'_9acts_withGTAct_noPos'
31 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
32 | 
33 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
34 |   --model='pos' \
35 |   --split=$split \
36 |   --n-acts=$n_acts \
37 |   --device=$gpu_id \
38 |   --dset-name='JAAD' \
39 |   --ckpt-name=$ckpt_name \
40 |   --lr-init=$lr \
41 |   --wd=$wd \
42 |   --lr-decay=$lr_decay \
43 |   --decay-every=$decay_every \
44 |   --seq-len=$seq_len \
45 |   --predict=$predict \
46 |   --pred-seq-len=$pred_seq_len \
47 |   --batch-size=$bt \
48 |   --save-every=$save_every \
49 |   --evaluate-every=$evaluate_every \
50 |   --n-workers=$n_workers \
51 |   --load-cache=$load_cache \
52 |   --cache-format=$cache_format \
53 |   --branch=$branch \
54 |   --pos-mode=$pos_mode \
55 |   --use-gt-act=$use_gt_act \
56 |   --collapse-cls=$collapse_cls \
57 |   --combine-method=$combine_method \
58 |   --use-gru=$use_gru \
59 | 
--------------------------------------------------------------------------------
/scripts_dir/train_baseline_anticipate_cnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | gpu_id=1
4 | split='train'
5 | n_acts=9
6 | 
7 | lr=1e-4
8 | wd=1e-5
9 | bt=4
10 | lr_decay=1
11 | decay_every=20
12 | n_workers=0
13 | save_every=10
14 | evaluate_every=1
15 | 
16 | seq_len=30
17 | predict=1
18 | pred_seq_len=30
19 | 
20 | load_cache='masks'
21 | # masks
22 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
23 | 
24 | use_gru=1
25 | use_gt_act=0
26 | branch='ped'
27 | pos_mode='none'
28 | collapse_cls=0
29 | combine_method='pair'
30 | suffix='_noGTAct_addDetLoss'
31 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
32 | 
33 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
34 |   --model='baseline_anticipate_cnn' \
35 |   --split=$split \
36 |   --n-acts=$n_acts \
37 |   --device=$gpu_id \
38 |   --dset-name='JAAD' \
39 |   --ckpt-name=$ckpt_name \
40 |   --lr-init=$lr \
41 |   --wd=$wd \
42 |   --lr-decay=$lr_decay \
43 |   --decay-every=$decay_every \
44 |   --seq-len=$seq_len \
45 |   --predict=$predict \
46 |   --pred-seq-len=$pred_seq_len \
47 |   --batch-size=$bt \
48 |   --save-every=$save_every \
49 |   --evaluate-every=$evaluate_every \
50 |   --n-workers=$n_workers \
51 |   --load-cache=$load_cache \
52 |   --cache-format=$cache_format \
53 |   --branch=$branch \
54 |   --pos-mode=$pos_mode \
55 |   --collapse-cls=$collapse_cls \
56 |   --combine-method=$combine_method \
57 |   --use-gru=$use_gru \
58 |   --use-gt-act=$use_gt_act \
59 | 
--------------------------------------------------------------------------------
/utils/data_proc_loc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | 
5 | H, W = 1080, 1920
6 | k = 0.25
7 | c = -20
8 | y_max = 200
9 | 
10 | def parse_split(fannot_in, fannot_out):
11 |   with open(fannot_in, 'rb') as handle:
12 |     peds = pickle.load(handle)
13 | 
14 |   vids = {}
15 |   for ped in peds:
16 |     vid = ped['vid']
17 |     if vid not in vids:
18 |       vids[vid] = {
19 |         'act': np.zeros([len(ped['act']), 1]),
20 |         'ped_pos': [[] for _ in range(len(ped['act']))]
21 |       }
22 |     for fid, pos in enumerate(ped['pos_GT']):
23 |       if len(pos) == 0:
24 |         continue
25 |       x, y, w, h = pos
26 |       cx, cy = x+0.5*w, y+h
27 |       # check if a pedestrian is in the trapezoid area
28 |       if H-cy <= k*cx+c and H-cy <= y_max and H-cy <= k*(W-cx)+c:
29 |         vids[vid]['act'][fid] = 1
30 |         vids[vid]['ped_pos'][fid] += pos,
31 | 
32 |   with open(fannot_out, 'wb') as handle:
33 |     pickle.dump(vids, handle)
34 | 
35 | 
36 | def parse_split_wrapper():
37 |   annot_root = '/sailhome/ajarno/STR-PIP/datasets'
38 | 
39 |   # ftrain = 'annot_train_ped_withTag_sanityWithPose.pkl'
40 |   # annot_train = os.path.join(annot_root, ftrain)
41 |   # annot_train_out = os.path.join(annot_root, 'annot_train_loc.pkl')
42 |   # parse_split(annot_train, annot_train_out)
43 | 
44 |   # ftest = 'annot_test_ped_withTag_sanityWithPose.pkl'
45 |   ftest = 'annot_test_ped.pkl'
46 |   annot_test = os.path.join(annot_root, ftest)
47 |   annot_test_out = os.path.join(annot_root, 'annot_test_loc_new.pkl')
48 |   parse_split(annot_test, annot_test_out)
49 | 
50 | 
51 | if __name__ == '__main__':
52 |   parse_split_wrapper()
--------------------------------------------------------------------------------
/scripts_dir/train_baseline_pose.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | gpu_id=0
4 | split='train'
5 | n_acts=1
6 | 
7 | n_epochs=100
8 | lr=1e-4
9 | wd=1e-7
10 | bt=8
11 | lr_decay=1
12 | decay_every=20
13 | n_workers=0
14 | save_every=10
15 | evaluate_every=1
16 | fc_in_dim=256
17 | 
18 | seq_len=14
19 | predict=1
20 | pred_seq_len=1
21 | 
22 | load_cache='pos'
23 | # masks
24 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
25 | save_cache_format=$cache_format
26 | 
27 | use_gru=1
28 | use_gt_act=0
29 | use_pose=1
30 | branch='ped'
31 | pos_mode='none'
32 | collapse_cls=0
33 | combine_method='pair'
34 | suffix='_gru'$fc_in_dim'_zeroPad'
35 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
36 | 
37 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
38 |   --model='baseline_pose' \
39 |   --split=$split \
40 |   --n-acts=$n_acts \
41 |   --device=$gpu_id \
42 |   --dset-name='JAAD' \
43 |   --ckpt-name=$ckpt_name \
44 |   --n-epochs=$n_epochs \
45 |   --lr-init=$lr \
46 |   --wd=$wd \
47 |   --lr-decay=$lr_decay \
48 |   --decay-every=$decay_every \
49 |   --seq-len=$seq_len \
50 |   --predict=$predict \
51 |   --pred-seq-len=$pred_seq_len \
52 |   --batch-size=$bt \
53 |   --save-every=$save_every \
54 |   --evaluate-every=$evaluate_every \
55 |   --fc-in-dim=$fc_in_dim \
56 |   --n-workers=$n_workers \
57 |   --load-cache=$load_cache \
58 |   --cache-format=$cache_format \
59 |   --save-cache-format=$save_cache_format \
60 |   --branch=$branch \
61 |   --pos-mode=$pos_mode \
62 |   --collapse-cls=$collapse_cls \
63 |   --combine-method=$combine_method \
64 |   --use-gru=$use_gru \
65 |   --use-gt-act=$use_gt_act \
66 |   --use-pose=$use_pose \
67 | 
--------------------------------------------------------------------------------
/scripts_dir/wandb/offline-run-20210314_201226-2eciwruc/logs/.nfs000000000001e91000000039:
--------------------------------------------------------------------------------
1 | 2021-03-14 20:12:27,342 INFO MainThread:3357636 [internal.py:wandb_internal():91] W&B internal server running at pid: 3357636, started at: 2021-03-14 20:12:27.341321
2 | 2021-03-14 20:12:27,344 DEBUG HandlerThread:3357636 [handler.py:handle_request():101] handle_request: run_start
3 | 2021-03-14 20:12:27,344 INFO WriterThread:3357636 [datastore.py:open_for_write():77] open: /sailhome/agalczak/crossing/scripts_dir/wandb/offline-run-20210314_201226-2eciwruc/run-2eciwruc.wandb
4 | 2021-03-14 20:12:27,351 DEBUG HandlerThread:3357636 [meta.py:__init__():34] meta init
5 | 2021-03-14 20:12:27,351 DEBUG HandlerThread:3357636 [meta.py:__init__():48] meta init done
6 | 2021-03-14 20:12:27,351 DEBUG HandlerThread:3357636 [meta.py:probe():190] probe
7 | 2021-03-14 20:12:27,369 DEBUG HandlerThread:3357636 [meta.py:_setup_git():180] setup git
8 | 2021-03-14 20:12:27,455 DEBUG HandlerThread:3357636 [meta.py:_setup_git():187] setup git done
9 | 2021-03-14 20:12:27,455 DEBUG HandlerThread:3357636 [meta.py:_save_pip():52] save pip
10 | 2021-03-14 20:12:27,460 DEBUG HandlerThread:3357636 [meta.py:_save_pip():66] save pip done
11 | 2021-03-14 20:12:27,461 DEBUG HandlerThread:3357636 [meta.py:probe():231] probe done
12 | 2021-03-14 20:22:28,964 WARNING MainThread:3357636 [internal.py:is_dead():344] Internal process exiting, parent pid 3357618 disappeared
13 | 2021-03-14 20:22:28,965 ERROR MainThread:3357636 [internal.py:wandb_internal():142] Internal process shutdown.
14 | 2021-03-14 20:22:29,293 INFO HandlerThread:3357636 [handler.py:finish():532] shutting down handler
15 | 2021-03-14 20:22:29,295 INFO WriterThread:3357636 [datastore.py:close():258] close: /sailhome/agalczak/crossing/scripts_dir/wandb/offline-run-20210314_201226-2eciwruc/run-2eciwruc.wandb
16 | 2021-03-14 20:22:29,439 INFO SenderThread:3357636 [sender.py:finish():814] shutting down sender
--------------------------------------------------------------------------------
/scripts_dir/train_concat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | 
5 | gpu_id=0
6 | split='train'
7 | n_acts=9
8 | 
9 | lr=1e-4
10 | wd=1e-5
11 | bt=4
12 | lr_decay=1
13 | decay_every=20
14 | n_workers=0
15 | save_every=10
16 | evaluate_every=1
17 | 
18 | seq_len=30
19 | predict=1
20 | pred_seq_len=30
21 | predict_k=0
22 | 
23 | annot_ped_format='/sailhome/ajarno/STR-PIP/datasets/annot_{}_ped_withTag_sanityWithPose.pkl'
24 | 
25 | load_cache='masks'
26 | cache_format='/sailhome/ajarno/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
27 | save_cache_format=$cache_format
28 | 
29 | use_gru=1
30 | use_trn=0
31 | ped_gru=1
32 | # pos_mode='center'
33 | pos_mode='none'
34 | use_act=0
35 | use_gt_act=0
36 | use_pose=0
37 | branch='both'
38 | collapse_cls=0
39 | combine_method='pair'
40 | # suffix='_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU'
41 | suffix='_test_tmp'
42 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
43 | 
44 | # CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
45 | WANDB_MODE=dryrun CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
46 |   --model='concat' \
47 |   --split=$split \
48 |   --n-acts=$n_acts \
49 |   --device=$gpu_id \
50 |   --dset-name='JAAD' \
51 |   --ckpt-name=$ckpt_name \
52 |   --lr-init=$lr \
53 |   --wd=$wd \
54 |   --lr-decay=$lr_decay \
55 |   --decay-every=$decay_every \
56 |   --seq-len=$seq_len \
57 |   --predict=$predict \
58 |   --pred-seq-len=$pred_seq_len \
59 |   --predict-k=$predict_k \
60 |   --batch-size=$bt \
61 |   --save-every=$save_every \
62 |   --evaluate-every=$evaluate_every \
63 |   --n-workers=$n_workers \
64 |   --annot-ped-format=$annot_ped_format \
65 |   --load-cache=$load_cache \
66 |   --cache-format=$cache_format \
67 |   --save-cache-format=$save_cache_format \
68 |   --branch=$branch \
69 |   --collapse-cls=$collapse_cls \
70 |   --combine-method=$combine_method \
71 | --use-gru=$use_gru \ 72 | --use-trn=$use_trn \ 73 | --ped-gru=$ped_gru \ 74 | --pos-mode=$pos_mode \ 75 | --use-act=$use_act \ 76 | --use-gt-act=$use_gt_act \ 77 | --use-pose=$use_pose \ 78 | -------------------------------------------------------------------------------- /scripts_dir/test_baseline_anticipate_cnn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mode='evaluate' 4 | 5 | gpu_id=1 6 | split='test' 7 | n_acts=9 8 | n_workers=0 9 | bt=1 10 | 11 | # test only 12 | slide=0 13 | rand_test=1 14 | log_every=10 15 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts' 16 | dataset='JAAD' 17 | 18 | ckpt_name="baseline_anticipate_cnn_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchped_collapse0_combinepair_noGTAct_addDetLoss" 19 | 20 | seq_len=30 21 | predict=1 22 | pred_seq_len=30 23 | 24 | load_cache='masks' 25 | # masks 26 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl' 27 | 28 | which_epoch=-1 29 | if [ $which_epoch -eq -1 ] 30 | then 31 | epoch_name='best_pred' 32 | else 33 | epoch_name=$which_epoch 34 | fi 35 | save_output=1 36 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl' 37 | 38 | if [ "$mode" = "extract" ] 39 | then 40 | extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/test/' 41 | else 42 | extract_feats_dir='none_existent' 43 | fi 44 | 45 | 46 | use_gru=1 47 | use_gt_act=0 48 | branch='ped' 49 | pos_mode='none' 50 | collapse_cls=0 51 | combine_method='pair' 52 | 53 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \ 54 | --model='baseline_anticipate_cnn' \ 55 | --mode=$mode \ 56 | --slide=$slide \ 57 | --rand-test=$rand_test \ 58 | --split=$split \ 59 | --n-acts=$n_acts \ 60 | --device=$gpu_id \ 61 | --dset-name='JAAD' \ 62 | --ckpt-dir=$ckpt_dir \ 63 | --ckpt-name=$ckpt_name \ 64 | --which-epoch=$which_epoch \ 65 | --save-output=$save_output \ 66 | --save-output-format=$save_output_format \ 67 | --seq-len=$seq_len \ 68 | --predict=$predict \ 69 | --pred-seq-len=$pred_seq_len \ 70 | --batch-size=$bt \ 71 | --n-workers=$n_workers \ 72 | --load-cache=$load_cache \ 73 | --cache-format=$cache_format \ 74 | --branch=$branch \ 75 | --pos-mode=$pos_mode \ 76 | --collapse-cls=$collapse_cls \ 77 | --combine-method=$combine_method \ 78 | --use-gru=$use_gru \ 79 | --use-gt-act=$use_gt_act \ 80 | -------------------------------------------------------------------------------- /scripts_dir/train_loc_concat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=0 4 | split='train' 5 | n_acts=1 6 | 7 | n_epochs=80 8 | start_epoch=2 9 | lr=1e-4 10 | wd=1e-5 11 | bt=1 12 | lr_decay=1 13 | decay_every=20 14 | n_workers=0 15 | save_every=10 16 | evaluate_every=1 17 | 18 | seq_len=30 19 | predict=1 20 | pred_seq_len=30 21 | predict_k=0 22 | 23 | annot_loc_format='/sailhome/ajarno/STR-PIP/datasets/annot_{}_loc.pkl' 24 | 25 | load_cache='masks' 26 | cache_format='/sailhome/ajarno/STR-PIP/datasets/cache/jaad_loc/{}/ped{}_fid{}.pkl' 27 | save_cache_format=$cache_format 28 | 29 | pretrained_path='/sailhome/ajarno/STR-PIP/ckpts/JAAD_loc/loc_concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt1_posNone_branchboth_collapse0_combinepair_tmp/best_pred.pth' 30 | 31 | 32 | use_gru=1 33 | use_trn=0 34 | ped_gru=1 35 | # pos_mode='center' 36 | pos_mode='none' 37 | use_act=0 38 | use_gt_act=0 39 | use_pose=0 40 | branch='both' 41 | 
collapse_cls=0
42 | combine_method='pair'
43 | suffix='_tmp'
44 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
45 | 
46 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
47 |   --model='loc_concat' \
48 |   --split=$split \
49 |   --n-acts=$n_acts \
50 |   --device=$gpu_id \
51 |   --dset-name='JAAD_loc' \
52 |   --ckpt-name=$ckpt_name \
53 |   --n-epochs=$n_epochs \
54 |   --start-epoch=$start_epoch \
55 |   --lr-init=$lr \
56 |   --wd=$wd \
57 |   --lr-decay=$lr_decay \
58 |   --decay-every=$decay_every \
59 |   --seq-len=$seq_len \
60 |   --predict=$predict \
61 |   --pred-seq-len=$pred_seq_len \
62 |   --predict-k=$predict_k \
63 |   --batch-size=$bt \
64 |   --save-every=$save_every \
65 |   --evaluate-every=$evaluate_every \
66 |   --n-workers=$n_workers \
67 |   --annot-loc-format=$annot_loc_format \
68 |   --load-cache=$load_cache \
69 |   --cache-format=$cache_format \
70 |   --save-cache-format=$save_cache_format \
71 |   --branch=$branch \
72 |   --collapse-cls=$collapse_cls \
73 |   --combine-method=$combine_method \
74 |   --use-gru=$use_gru \
75 |   --use-trn=$use_trn \
76 |   --ped-gru=$ped_gru \
77 |   --pos-mode=$pos_mode \
78 |   --use-act=$use_act \
79 |   --use-gt-act=$use_gt_act \
80 |   --use-pose=$use_pose \
81 | 
--------------------------------------------------------------------------------
/scripts_dir/test_loc_concat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # mode='evaluate'
4 | mode='extract'
5 | 
6 | gpu_id=0
7 | n_workers=0
8 | n_acts=1
9 | 
10 | seq_len=1
11 | predict=0
12 | pred_seq_len=30
13 | 
14 | # test only
15 | slide=0
16 | rand_test=1
17 | log_every=10
18 | 
19 | 
20 | split='train'
21 | # split='test'
22 | use_gru=1
23 | use_trn=0
24 | pos_mode='none'
25 | use_act=0
26 | use_gt_act=0
27 | use_pose=0
28 | # branch='ped'
29 | branch='both'
30 | collapse_cls=0
31 | combine_method='pair'
32 | 
33 | annot_loc_format='/sailhome/bingbin/STR-PIP/datasets/annot_{}_loc.pkl'
34 | 
35 | load_cache='masks'
36 | # load_cache='none'
37 | 
38 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_loc/{}/ped{}_fid{}.pkl'
39 | save_cache_format=$cache_format
40 | 
41 | ckpt_name='loc_concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt1_posNone_branchboth_collapse0_combinepair_tmp'
42 | 
43 | # -1 for the best epoch
44 | which_epoch=-1
45 | 
46 | # set a non-existent epoch so that the features are extracted from the ImageNet backbone
47 | # which_epoch=100
48 | 
49 | if [ "$mode" = "extract" ]
50 | then
51 |   extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_loc/JAAD_conv_feats/'$ckpt_name'/'$split'/'
52 | else
53 |   extract_feats_dir='none_existent'
54 | fi
55 | 
56 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
57 |   --model='loc_concat' \
58 |   --split=$split \
59 |   --n-acts=$n_acts \
60 |   --mode=$mode \
61 |   --device=$gpu_id \
62 |   --log-every=$log_every \
63 |   --dset-name='JAAD_loc' \
64 |   --ckpt-name=$ckpt_name \
65 |   --batch-size=1 \
66 |   --n-workers=$n_workers \
67 |   --annot-loc-format=$annot_loc_format \
68 |   --load-cache=$load_cache \
69 |   --save-cache-format=$save_cache_format \
70 |   --cache-format=$cache_format \
71 |   --seq-len=$seq_len \
72 |   --predict=$predict \
73 |   --pred-seq-len=$pred_seq_len \
74 |   --use-gru=$use_gru \
75 |   --use-trn=$use_trn \
76 |   --use-act=$use_act \
77 |   --use-gt-act=$use_gt_act \
78 |   --use-pose=$use_pose \
79 |   --pos-mode=$pos_mode \
80 |   --collapse-cls=$collapse_cls \
81 |   --slide=$slide \
82 |   --rand-test=$rand_test \
83 |   --branch=$branch \
84 |   --which-epoch=$which_epoch \
85 |   --extract-feats-dir=$extract_feats_dir
86 | 
87 | 
--------------------------------------------------------------------------------
/models/baseline_anticipate_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | from models.model_base import BaseModel
7 | 
8 | import pdb
9 | 
10 | class BaselineAnticipateCNN(BaseModel):
11 |   def __init__(self, opt):
12 |     super().__init__(opt)
13 |     # nRows: seq_len
14 |     # nCols: n_acts, i.e. acts.shape[1]
15 |     self.W1 = nn.Conv2d(1, 8, kernel_size=(5,1), padding=(2,0))
16 |     self.W2 = nn.Conv2d(8, 16, kernel_size=(5,1), padding=(2,0))
17 |     self.fc1 = nn.Linear(16*self.seq_len*(self.n_acts//4), 1024)
18 |     self.fc2 = nn.Linear(1024, self.seq_len*self.n_acts)
19 | 
20 |     # init params
21 |     self.W1.weight.data.normal_(std=0.1)
22 |     self.W1.bias.data.fill_(0.1)
23 |     self.W2.weight.data.normal_(std=0.1)
24 |     self.W2.bias.data.fill_(0.1)
25 | 
26 |     self.relu = nn.ReLU()
27 |     self.pool = nn.MaxPool2d(kernel_size=(1,2), stride=(1,2))
28 |     self.l2_norm = F.normalize
29 | 
30 |     if not self.use_gt_act:
31 |       self.gru = nn.GRU(self.ped_dim, self.ped_dim, 2, batch_first=True).to(self.device)
32 | 
33 |   def forward(self, ped, masks, bbox, act, pose=None):
34 |     if self.use_gt_act:
35 |       # use GT action observations
36 |       x = act[:, :self.seq_len].unsqueeze(1)
37 |     else:
38 |       # predict action labels from pedestrian crops
39 | 
40 |       # ped_crops: (bt, 30, 3, 224, 224)
41 |       B, T, _, _, _ = ped.shape
42 | 
43 |       ped_crops = ped.view(-1, 3, 224, 224)
44 |       ped_crops = ped_crops.type(self.dtype).to(self.device)
45 |       ped_feats = self.ped_encoder(ped_crops)
46 |       # ped_feats: (B, T, d)
47 |       ped_feats = ped_feats.view(B, T, -1)
48 | 
49 |       temporal_feats, _ = self.gru(ped_feats)
50 |       h = self.classifier(temporal_feats)
51 |       logits = self.sigmoid(h)
52 |       x = (logits > 0.5).type(self.dtype)
53 |       x = x.unsqueeze(1)
54 | 
55 |     # x = acts[:, :self.seq_len].unsqueeze(1)
56 |     x = self.W1(x)
57 |     x = self.relu(x)
58 |     # max pool over the action dimension (the kernel is (1,2), so only the last axis shrinks)
59 |     x = self.pool(x)
60 | 
61 |     x = self.W2(x)
62 |     x = self.relu(x)
63 |     # max pool over the action dimension again
64 |     x = self.pool(x)
65 | 
66 |     x = x.view(-1, 16*self.seq_len*(self.n_acts // 4))
67 |     x = self.fc1(x)
68 |     x = self.fc2(x)
69 | 
70 |     x = x.view(-1, self.seq_len, self.n_acts)
71 |     x = self.l2_norm(x, dim=2)
72 | 
73 |     if not self.use_gt_act:
74 |       # also supervise on the detection
75 |       x = torch.cat([h, x], 1)
76 | 
77 |     return x
78 | 
--------------------------------------------------------------------------------
/scripts_dir/train_concat_stip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | gpu_id=0
4 | split='train'
5 | n_acts=1
6 | 
7 | lr=1e-5
8 | wd=1e-5
9 | bt=2
10 | lr_decay=1
11 | decay_every=10
12 | n_workers=0
13 | save_every=10
14 | evaluate_every=1
15 | 
16 | seq_len=30
17 | predict=1
18 | pred_seq_len=10
19 | predict_k=0
20 | 
21 | annot_ped_format='/vision/group/prolix/processed/pedestrians/{}.pkl'
22 | 
23 | cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl'
24 | cache_obj_bbox_format_left='/vision/group/prolix/processed/left/obj_bbox_20fps_merged/{}_seg{}.pkl'
25 | cache_obj_bbox_format_right='/vision/group/prolix/processed/right/obj_bbox_20fps_merged/{}_seg{}.pkl'
26 | 
27 | # load_cache='none'
28 | load_cache='mask'
29 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
30 | #cache_format='/vision/group/prolix/processed/cache/{}/ped{}_fid{}.pkl'
31 | cache_format=/dev/null
32 | save_cache_format=$cache_format
33 | 
34 | use_gru=1
35 | use_trn=0
36 | ped_gru=1
37 | # pos_mode='center'
38 | pos_mode='none'
39 | use_act=0
40 | use_gt_act=0
41 | use_pose=0
42 | branch='both'
43 | collapse_cls=0
44 | combine_method='pair'
45 | suffix='_decay10_tmp'
46 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
47 | 
48 | # CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
49 | WANDB_MODE=dryrun CUDA_VISIBLE_DEVICES=$gpu_id python train.py \
50 |   --model='concat' \
51 |   --split=$split \
52 |   --n-acts=$n_acts \
53 |   --device=$gpu_id \
54 |   --dset-name='STIP' \
55 |   --ckpt-name=$ckpt_name \
56 |   --lr-init=$lr \
57 |   --wd=$wd \
58 |   --lr-decay=$lr_decay \
59 |   --decay-every=$decay_every \
60 |   --seq-len=$seq_len \
61 |   --predict=$predict \
62 |   --pred-seq-len=$pred_seq_len \
63 |   --predict-k=$predict_k \
64 |   --batch-size=$bt \
65 |   --save-every=$save_every \
66 |   --evaluate-every=$evaluate_every \
67 |   --n-workers=$n_workers \
68 |   --annot-ped-format=$annot_ped_format \
69 |   --load-cache=$load_cache \
70 |   --cache-obj-bbox-format=$cache_obj_bbox_format \
71 |   --cache-obj-bbox-format-left=$cache_obj_bbox_format_left \
72 |   --cache-obj-bbox-format-right=$cache_obj_bbox_format_right \
73 |   --cache-format=$cache_format \
74 |   --save-cache-format=$save_cache_format \
75 |   --branch=$branch \
76 |   --collapse-cls=$collapse_cls \
77 |   --combine-method=$combine_method \
78 |   --use-gru=$use_gru \
79 |   --use-trn=$use_trn \
80 |   --ped-gru=$ped_gru \
81 |   --pos-mode=$pos_mode \
82 |   --use-act=$use_act \
83 |   --use-gt-act=$use_gt_act \
84 |   --use-pose=$use_pose \
85 | 
--------------------------------------------------------------------------------
/scripts_dir/train_graph_fromMasks.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | gpu_id=0
4 | n_workers=0
5 | split='train'
6 | n_acts=9
7 | 
8 | # train only
9 | n_epochs=80
10 | lr=1e-4
11 | wd=1e-5
12 | bt=6
13 | lr_decay=1
14 | decay_every=30
15 | save_every=10
16 | evaluate_every=1
17 | log_every=30
18 | 
19 | load_cache='masks'
20 | # load_cache='feats'
21 | 
22 | if [ $load_cache = 'feats' ]
23 | then
24 |   cache_root='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/'
25 |   cache_baseformat='/{}/ped{}_fid{}.pkl'
26 |   # v3 feats
27 |   cache_dir='concat_gru_seq14_pred1_lr1.0e-04_wd1.0e-05_bt4_posNone_branchped_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_ordered'
28 | 
29 |   # v2 feats
30 |   # cache_format='concat_gru_lr1.0e-04_wd1.0e-05_bt4_ped_collapse0_combinepair_useBBox0_cacheMasks_fixGRU_singleTime'
31 | 
32 |   # cache_format='imageNet_pretrained_singleTime'
33 |   cache_format=$cache_root$cache_dir$cache_baseformat
34 | else
35 |   cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse/{}/ped{}_fid{}.pkl'
36 | fi
37 | 
38 | 
39 | # Regularization on temporal smoothness
40 | reg_smooth='none'
41 | reg_lambda=5e-4
42 | 
43 | seq_len=30
44 | predict=1
45 | pred_seq_len=30
46 | predict_k=0
47 | 
48 | # temporal modeling
49 | use_gru=1
50 | use_trn=0
51 | ped_gru=1
52 | ctxt_gru=0
53 | 
54 | # features
55 | use_act=0
56 | use_gt_act=0
57 | pos_mode='none'
58 | branch='ped'
59 | adj_type='spatial'
60 | #adj_type='inner'
61 | n_layers=2
62 | collapse_cls=0
63 | combine_method='pair'
64 | 
65 | # saving & loading
66 | suffix='_v3Feats_fromMasks_pedGRU'
67 | # suffix='_v3Feats__pedGRU_ctxtGRU'
68 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers$suffix
69 | # ckpt_name='graph_seq30_layer2_embed'
70 | 
71 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
72 |   --model='graph' \
73 |   --reg-smooth=$reg_smooth \
74 |   --reg-lambda=$reg_lambda \
75 |   --split=$split \
76 |   --n-acts=$n_acts \
77 |   --device=$gpu_id \
78 |   --dset-name='JAAD' \
79 |   --ckpt-name=$ckpt_name \
80 |   --n-epochs=$n_epochs \
81 |   --lr-init=$lr \
82 |   --wd=$wd \
83 |   --lr-decay=$lr_decay \
84 |   --decay-every=$decay_every \
85 |   --seq-len=$seq_len \
86 |   --predict=$predict \
87 |   --pred-seq-len=$pred_seq_len \
88 |   --predict-k=$predict_k \
89 |   --batch-size=$bt \
90 |   --save-every=$save_every \
91 |   --evaluate-every=$evaluate_every \
92 |   --log-every=$log_every \
93 |   --n-workers=$n_workers \
94 |   --load-cache=$load_cache \
95 |   --cache-format=$cache_format \
96 |   --branch=$branch \
97 |   --adj-type=$adj_type \
98 |   --n-layers=$n_layers \
99 |   --collapse-cls=$collapse_cls \
100 |   --combine-method=$combine_method \
101 |   --use-gru=$use_gru \
102 |   --use-trn=$use_trn \
103 |   --ped-gru=$ped_gru \
104 |   --ctxt-gru=$ctxt_gru \
105 |   --use-act=$use_act \
106 |   --use-gt-act=$use_gt_act \
107 |   --pos-mode=$pos_mode
--------------------------------------------------------------------------------
/scripts_dir/test_concat_stip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # mode='evaluate'
4 | mode='extract'
5 | 
6 | gpu_id=0
7 | n_workers=0
8 | n_acts=1
9 | 
10 | seq_len=1
11 | predict=0
12 | pred_seq_len=2
13 | 
14 | # test only
15 | slide=0
16 | rand_test=1
17 | log_every=10
18 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
19 | dataset='STIP'
20 | 
21 | split='test'
22 | use_gru=1
23 | ped_gru=1
24 | use_trn=0
25 | pos_mode='none'
26 | use_act=0
27 | use_gt_act=0
28 | use_pose=0
29 | # branch='ped'
30 | branch='both'
31 | collapse_cls=0
32 | combine_method='pair'
33 | 
34 | 
35 | annot_ped_format='/vision/group/prolix/processed/pedestrians/{}.pkl'
36 | cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl'
37 | 
38 | 
39 | load_cache='masks'
40 | 
41 | cache_format='/vision/group/prolix/processed/cache/{}/ped{}_fid{}.pkl'
42 | save_cache_format=$cache_format
43 | 
44 | # 
ckpt_name='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_testANN_hanh1' 45 | 46 | # ckpt_name='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_run1' 47 | 48 | ckpt_name='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10' 49 | # -1 for the best epoch 50 | which_epoch=-1 51 | 52 | # this is to set a non-existent epoch s.t. the features are extracted from ImageNet backbone 53 | # which_epoch=100 54 | 55 | if [ $which_epoch -eq -1 ] 56 | then 57 | epoch_name='best_pred' 58 | else 59 | epoch_name=$which_epoch 60 | fi 61 | save_output=10 62 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl' 63 | 64 | 65 | if [ "$mode" = "extract" ] 66 | then 67 | extract_feats_dir='/vision/group/prolix/processed/cache/STIP_conv_feats/'$ckpt_name'/'$split'/' 68 | else 69 | extract_feats_dir='none_existent' 70 | fi 71 | 72 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \ 73 | --model='concat' \ 74 | --split=$split \ 75 | --n-acts=$n_acts \ 76 | --mode=$mode \ 77 | --device=$gpu_id \ 78 | --log-every=$log_every \ 79 | --dset-name='STIP' \ 80 | --ckpt-name=$ckpt_name \ 81 | --batch-size=1 \ 82 | --n-workers=$n_workers \ 83 | --annot-ped-format=$annot_ped_format \ 84 | --cache-obj-bbox-format=$cache_obj_bbox_format \ 85 | --load-cache=$load_cache \ 86 | --save-cache-format=$save_cache_format \ 87 | --cache-format=$cache_format \ 88 | --seq-len=$seq_len \ 89 | --predict=$predict \ 90 | --pred-seq-len=$pred_seq_len \ 91 | --use-gru=$use_gru \ 92 | --ped-gru=$ped_gru \ 93 | --use-trn=$use_trn \ 94 | --use-act=$use_act \ 95 | --use-gt-act=$use_gt_act \ 96 | --use-pose=$use_pose \ 97 | --pos-mode=$pos_mode \ 98 | --collapse-cls=$collapse_cls \ 99 | --slide=$slide \ 100 | --rand-test=$rand_test \ 101 | --branch=$branch \ 102 | --which-epoch=$which_epoch \ 103 | --save-output=$save_output \ 104 | --save-output-format=$save_output_format \ 105 | --extract-feats-dir=$extract_feats_dir 106 | 107 | -------------------------------------------------------------------------------- /models/baseline_pose.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from models.model_base import BaseModel 7 | 8 | import pdb 9 | 10 | class BaselinePose(BaseModel): 11 | def __init__(self, opt): 12 | super().__init__(opt) 13 | 14 | if self.use_gru: 15 | # aggregate poses with FC + GRU 16 | self.embed_pose = nn.Sequential( 17 | nn.Linear(9*2, 128), 18 | nn.ReLU(), 19 | nn.Linear(128, self.fc_in_dim), 20 | ) 21 | self.gru = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True) 22 | if self.predict: 23 | self.gru_pred = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True) 24 | else: 25 | # aggregating poses using 1D Conv as in BaselineAnticipateCNN 26 | # nRows: seq_len 27 | # nCols: n_acts, i.e. 
acts.shape[1]
28 |       self.W1 = nn.Conv2d(1, 8, kernel_size=(5,1), padding=(2,0))
29 |       self.W2 = nn.Conv2d(8, 16, kernel_size=(5,1), padding=(2,0))
30 |       hidden_size = 256
31 |       self.fc1 = nn.Linear(16*self.seq_len*((9*2)//4), hidden_size) # 9 joints x 2 coords (x & y)
32 |       self.fc2 = nn.Linear(hidden_size, self.pred_seq_len*self.n_acts)
33 | 
34 |       # init params
35 |       self.W1.weight.data.normal_(std=0.1)
36 |       self.W1.bias.data.fill_(0.1)
37 |       self.W2.weight.data.normal_(std=0.1)
38 |       self.W2.bias.data.fill_(0.1)
39 | 
40 |       self.relu = nn.ReLU()
41 |       self.pool = nn.MaxPool2d(kernel_size=(1,2), stride=(1,2))
42 |       self.l2_norm = F.normalize
43 | 
44 |     # if not self.use_gt_act:
45 |     #   self.gru = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True).to(self.device)
46 | 
47 |   def forward(self, ped, masks, bbox, act, pose=None):
48 |     # normalize the observed pose sequence
49 |     B = ped.shape[0]
50 | 
51 |     x = self.util_norm_pose(pose)
52 | 
53 |     x = x.contiguous().view(B, self.seq_len, 18)
54 | 
55 |     if self.use_gru:
56 |       feats = self.embed_pose(x)
57 |       temp_feats, h = self.gru(feats)
58 | 
59 |       if self.predict:
60 |         o = temp_feats[:, -1:]
61 |         pred_outs = []
62 |         for pred_t in range(self.pred_seq_len):
63 |           o, h = self.gru_pred(o, h)
64 |           pred_outs += o,
65 |         pred_outs = torch.cat(pred_outs, 1)
66 |         temp_feats = torch.cat([temp_feats, pred_outs], 1)
67 | 
68 |       logits = self.classifier(temp_feats)
69 |     else:
70 |       x = x.unsqueeze(1) # shape: (B, 1, self.seq_len, 18)
71 |       x = self.W1(x)
72 |       x = self.relu(x)
73 |       # max pool over the keypoint dimension (the kernel is (1,2), so only the last axis shrinks)
74 |       x = self.pool(x)
75 | 
76 |       x = self.W2(x)
77 |       x = self.relu(x)
78 |       # max pool over the keypoint dimension again
79 |       x = self.pool(x)
80 | 
81 |       x = x.view(-1, 16*self.seq_len*((9*2) // 4))
82 |       x = self.fc1(x)
83 |       x = self.fc2(x)
84 | 
85 |       x = x.view(-1, self.pred_seq_len, self.n_acts)
86 |       logits = self.l2_norm(x, dim=2)
87 | 
88 |     # if not self.pred_only:
89 |     #   x = torch.cat([h, x], 1)
90 |     return logits
91 | 
--------------------------------------------------------------------------------
/models/model_pos.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from models.backbone.resnet_based import resnet_backbone
6 | 
7 | from models.model_base import BaseModel
8 | 
9 | import pdb
10 | 
11 | 
12 | class PosModel(BaseModel):
13 |   def __init__(self, opt):
14 |     super().__init__(opt)
15 |     self.seq_len = opt.seq_len
16 |     self.pos_mode = opt.pos_mode
17 |     if self.pos_mode == 'center':
18 |       self.fc_in_dim = 2
19 |     elif self.pos_mode == 'height':
20 |       self.fc_in_dim = 3
21 |     elif self.pos_mode == 'bbox':
22 |       self.fc_in_dim = 4
23 |     elif self.pos_mode == 'both':
24 |       self.fc_in_dim = 6
25 |     elif self.pos_mode == 'none':
26 |       self.fc_in_dim = 0
27 | 
28 |     self.use_gt_act = opt.use_gt_act
29 |     if self.use_gt_act:
30 |       self.fc_in_dim += self.n_acts
31 | 
32 |     if self.fc_in_dim == 0:
33 |       raise ValueError("The model should use at least one of 'pos_mode' or 'use_gt_act'.")
34 | 
35 |     self.gru = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True).to(self.device)
36 |     self.classifier = nn.Sequential(
37 |       nn.Linear(self.fc_in_dim, self.fc_in_dim),
38 |       nn.Linear(self.fc_in_dim, self.n_acts)
39 |     )
40 | 
41 |     if self.predict:
42 |       # NOTE: this GRU only takes in seq of len 1, i.e. 
one time step 43 | # GRU is used over GRUCell for multilayer 44 | self.gru_pred = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True).to(self.device) 45 | 46 | 47 | def forward(self, ped, ctxt, bbox, act, pose=None): 48 | # bbox: (B, T, 4): (y, x, h, w) 49 | bbox = bbox.to(self.device) 50 | if self.pos_mode in ['center', 'both', 'height']: 51 | y = bbox[:,:, 0] + .5*bbox[:,:,2] 52 | x = bbox[:,:, 1] + .5*bbox[:,:,3] 53 | y = y.unsqueeze(-1) 54 | x = x.unsqueeze(-1) 55 | centers = torch.cat([y,x], -1) 56 | if self.pos_mode == 'both': 57 | gru_in = torch.cat([centers, bbox], -1) 58 | elif self.pos_mode == 'height': 59 | h = bbox[:,:,2:3] 60 | gru_in = torch.cat([centers, h], -1) 61 | else: 62 | gru_in = centers 63 | elif self.pos_mode == 'bbox': 64 | gru_in = bbox 65 | elif self.pos_mode == 'none': 66 | gru_in = None 67 | 68 | if self.use_gt_act: 69 | # shape: (bt, all_seq_len, n_acts+len_for_pos) 70 | if self.n_acts == 1: 71 | act = act[:, :, 1:2] 72 | if gru_in is None: 73 | gru_in = act 74 | else: 75 | gru_in = torch.cat([gru_in, act], -1) 76 | 77 | if self.predict: 78 | gru_in = gru_in[:, :self.seq_len] 79 | 80 | output, h = self.gru(gru_in) 81 | 82 | if self.predict: 83 | # o: (B, 1, fc_in_dim) 84 | o = output[:, -1:] 85 | pred_outs = [] 86 | for _ in range(self.pred_seq_len): 87 | o, h = self.gru_pred(o, h) 88 | pred_outs += o, 89 | 90 | # pred_outs: (B, pred_seq_len, fc_in_dim) 91 | pred_outs = torch.cat(pred_outs, 1) 92 | # frame_feats: (B, seq_len + pred_seq_len, fc_in_dim) 93 | output = torch.cat([output, pred_outs], 1) 94 | 95 | logits = self.classifier(output) 96 | return logits 97 | -------------------------------------------------------------------------------- /scripts_dir/train_graph.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=1 4 | split='train' 5 | n_workers=0 6 | n_acts=9 7 | bt=16 8 | log_every=30 9 | 10 | # Training only 11 | n_epochs=80 12 | start_epoch=0 13 | lr=1e-4 14 | wd=2e-6 15 | lr_decay=1 16 | decay_every=20 17 | save_every=10 18 | evaluate_every=1 19 | 20 | reg_smooth='none' 21 | reg_lambda=0 22 | 23 | seq_len=30 24 | predict=1 25 | pred_seq_len=30 26 | predict_k=0 27 | 28 | load_cache='feats' 29 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU/{}/ped{}_fid{}.pkl' 30 | 31 | use_gru=1 32 | use_trn=0 33 | ped_gru=1 34 | ctxt_gru=0 35 | ctxt_node=0 36 | 37 | # features 38 | use_act=0 39 | use_gt_act=0 40 | use_driver=0 41 | use_pose=0 42 | pos_mode='none' 43 | # branch='ped' 44 | branch='both' 45 | adj_type='spatial' 46 | # adj_type='random' 47 | # adj_type='uniform' 48 | # adj_type='spatialOnly' 49 | # adj_type='all' 50 | # adj_type='inner' 51 | use_obj_cls=0 52 | n_layers=2 53 | diff_layer_weight=0 54 | collapse_cls=0 55 | combine_method='pair' 56 | 57 | # saving & loading 58 | # suffix='_v4Feats_pedGRU_newCtxtGRU_3evalEpoch' 59 | suffix='_v4Feats_pedGRU_3evalEpoch_sigmoidMean' 60 | 61 | if [[ "$ctxt_gru" == 1 ]] 62 | then 63 | suffix=$suffix"_newCtxtGRU" 64 | fi 65 | 66 | if [[ "$use_obj_cls" == 1 ]] 67 | then 68 | suffix=$suffix"_objCls" 69 | fi 70 | 71 | if [[ "$use_driver" == 1 ]] 72 | then 73 | suffix=$suffix"_useDriver" 74 | fi 75 | 76 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers'_diffW'$diff_layer_weight$suffix 77 | # 
ckpt_name='graph_seq30_layer2_embed'
78 | 
79 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
80 |   --model='graph' \
81 |   --reg-smooth=$reg_smooth \
82 |   --reg-lambda=$reg_lambda \
83 |   --split=$split \
84 |   --n-acts=$n_acts \
85 |   --device=$gpu_id \
86 |   --dset-name='JAAD' \
87 |   --ckpt-name=$ckpt_name \
88 |   --n-epochs=$n_epochs \
89 |   --start-epoch=$start_epoch \
90 |   --lr-init=$lr \
91 |   --wd=$wd \
92 |   --lr-decay=$lr_decay \
93 |   --decay-every=$decay_every \
94 |   --seq-len=$seq_len \
95 |   --predict=$predict \
96 |   --pred-seq-len=$pred_seq_len \
97 |   --predict-k=$predict_k \
98 |   --batch-size=$bt \
99 |   --save-every=$save_every \
100 |   --evaluate-every=$evaluate_every \
101 |   --log-every=$log_every \
102 |   --n-workers=$n_workers \
103 |   --load-cache=$load_cache \
104 |   --cache-format=$cache_format \
105 |   --branch=$branch \
106 |   --adj-type=$adj_type \
107 |   --use-obj-cls=$use_obj_cls \
108 |   --n-layers=$n_layers \
109 |   --diff-layer-weight=$diff_layer_weight \
110 |   --collapse-cls=$collapse_cls \
111 |   --combine-method=$combine_method \
112 |   --use-gru=$use_gru \
113 |   --use-trn=$use_trn \
114 |   --ped-gru=$ped_gru \
115 |   --ctxt-gru=$ctxt_gru \
116 |   --ctxt-node=$ctxt_node \
117 |   --use-act=$use_act \
118 |   --use-gt-act=$use_gt_act \
119 |   --use-driver=$use_driver \
120 |   --use-pose=$use_pose \
121 |   --pos-mode=$pos_mode
--------------------------------------------------------------------------------
/models/backbone/resnet_based.py:
--------------------------------------------------------------------------------
1 | import torch
2 | try:
3 |   from models.backbone.imagenet_pretraining import load_pretrained_2D_weights
4 |   from models.backbone.resnet.basicblock import BasicBlock2D, BasicBlock3D, BasicBlock2_1D
5 |   from models.backbone.resnet.bottleneck import Bottleneck2D, Bottleneck3D, Bottleneck2_1D
6 |   from models.backbone.resnet.resnet import ResNetBackBone
7 | except ImportError:
8 |   from imagenet_pretraining import load_pretrained_2D_weights
9 |   from resnet.basicblock import BasicBlock2D, BasicBlock3D, BasicBlock2_1D
10 |   from resnet.bottleneck import Bottleneck2D, Bottleneck3D, Bottleneck2_1D
11 |   from resnet.resnet import ResNetBackBone
12 | 
13 | 
14 | # __all__ = [
15 | #   'resnet_two_heads',
16 | # ]
17 | 
18 | 
19 | def resnet_backbone(depth=18, blocks='2D_2D_2D_2D', **kwargs):
20 |   """Constructs a ResNet backbone (ResNet-18 by default; depths 18, 34 and 50 are supported).
21 |   """
22 |   # Blocks and layers
23 |   list_block, list_layers = get_cnn_structure(depth=depth,
24 |                                               str_blocks=blocks)
25 | 
26 |   # Model with two heads
27 |   model = ResNetBackBone(list_block,
28 |                          list_layers,
29 |                          **kwargs)
30 | 
31 |   # Pretrained from imagenet weights
32 |   model = load_pretrained_2D_weights('resnet{}'.format(depth), model, inflation='center')
33 | 
34 |   return model
35 | 
36 | 
37 | def get_cnn_structure(str_blocks='2D_2D_2D_2D', depth=18):
38 |   # List of blocks
39 |   list_block = []
40 | 
41 |   # layers
42 |   if depth == 18:
43 |     list_layers = [2, 2, 2, 2]
44 |     nature_of_block = 'basic'
45 |   elif depth == 34:
46 |     list_layers = [3, 4, 6, 3]
47 |     nature_of_block = 'basic'
48 |   elif depth == 50:
49 |     list_layers = [3, 4, 6, 3]
50 |     nature_of_block = 'bottleneck'
51 |   else:
52 |     raise ValueError('unsupported ResNet depth: {}'.format(depth))
53 | 
54 |   # blocks
55 |   if nature_of_block == 'basic':
56 |     block_2D, block_3D, block_2_1D = BasicBlock2D, BasicBlock3D, BasicBlock2_1D
57 |   elif nature_of_block == 'bottleneck':
58 |     block_2D, block_3D, block_2_1D = Bottleneck2D, Bottleneck3D, Bottleneck2_1D
59 |   else:
60 |     raise ValueError('unknown block nature: {}'.format(nature_of_block))
61 | 
62 |   # map each block id string to the corresponding block class
63 |   list_block_id = str_blocks.split('_')
64 | 
65 |   for i, str_block in enumerate(list_block_id):
66 |     # Block kind
67 |     if str_block == '2D':
68 |       list_block.append(block_2D)
69 |     elif str_block == '2.5D':
70 |       list_block.append(block_2_1D)
71 |     elif str_block == '3D':
72 |       list_block.append(block_3D)
73 |     else:
74 |       raise ValueError('unknown block kind: {}'.format(str_block))
75 | 
76 |   return list_block, list_layers
77 | 
78 | 
79 | if __name__ == '__main__':
80 |   model = resnet_backbone()
81 |   img = torch.ones([1, 1, 224, 224])
82 |   feats = model(img)
83 |   print(feats.shape)
84 | 
--------------------------------------------------------------------------------
/scripts_dir/test_concat_stip_side.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # mode='evaluate'
4 | mode='extract'
5 | 
6 | gpu_id=0
7 | n_workers=0
8 | n_acts=1
9 | 
10 | seq_len=1
11 | predict=0
12 | pred_seq_len=2
13 | 
14 | # test only
15 | slide=0
16 | rand_test=1
17 | log_every=10
18 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
19 | dataset='STIP'
20 | 
21 | split='test'
22 | use_gru=1
23 | ped_gru=1
24 | use_trn=0
25 | pos_mode='none'
26 | use_act=0
27 | use_gt_act=0
28 | use_pose=0
29 | # branch='ped'
30 | branch='both'
31 | collapse_cls=0
32 | combine_method='pair'
33 | 
34 | view='right'
35 | # view='all_views'
36 | # view='all'
37 | 
38 | annot_ped_format='/vision/group/prolix/processed/pedestrians_all_views/{}.pkl'
39 | # if [ $view = 'center' ]
40 | # then
41 | #   cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl'
42 | # else
43 | cache_obj_bbox_format='/vision/group/prolix/processed/'$view'/obj_bbox_20fps_merged/{}_seg{}.pkl'
44 | # fi
45 | 
46 | 
47 | load_cache='none'
48 | 
49 | cache_format='/vision/group/prolix/processed/cache/'$view'/{}/ped{}_fid{}.pkl'
50 | save_cache_format=$cache_format
51 | 
52 | # ckpt_name='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_testANN_hanh1'
53 | 
54 | # ckpt_name='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_run1'
55 | 
56 | ckpt_name='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10'
57 | # -1 for the best epoch
58 | which_epoch=-1
59 | 
60 | # this is to set a non-existent epoch s.t. 
the features are extracted from ImageNet backbone 61 | # which_epoch=100 62 | 63 | if [ $which_epoch -eq -1 ] 64 | then 65 | epoch_name='best_pred' 66 | else 67 | epoch_name=$which_epoch 68 | fi 69 | save_output=10 70 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl' 71 | 72 | 73 | if [ "$mode" = "extract" ] 74 | then 75 | extract_feats_dir='/vision/group/prolix/processed/cache/'$view'/STIP_conv_feats/'$ckpt_name'/'$split'/' 76 | else 77 | extract_feats_dir='none_existent' 78 | fi 79 | 80 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \ 81 | --model='concat' \ 82 | --view=$view \ 83 | --split=$split \ 84 | --n-acts=$n_acts \ 85 | --mode=$mode \ 86 | --device=$gpu_id \ 87 | --log-every=$log_every \ 88 | --dset-name='STIP' \ 89 | --ckpt-name=$ckpt_name \ 90 | --batch-size=1 \ 91 | --n-workers=$n_workers \ 92 | --annot-ped-format=$annot_ped_format \ 93 | --cache-obj-bbox-format=$cache_obj_bbox_format \ 94 | --load-cache=$load_cache \ 95 | --save-cache-format=$save_cache_format \ 96 | --cache-format=$cache_format \ 97 | --seq-len=$seq_len \ 98 | --predict=$predict \ 99 | --pred-seq-len=$pred_seq_len \ 100 | --use-gru=$use_gru \ 101 | --ped-gru=$ped_gru \ 102 | --use-trn=$use_trn \ 103 | --use-act=$use_act \ 104 | --use-gt-act=$use_gt_act \ 105 | --use-pose=$use_pose \ 106 | --pos-mode=$pos_mode \ 107 | --collapse-cls=$collapse_cls \ 108 | --slide=$slide \ 109 | --rand-test=$rand_test \ 110 | --branch=$branch \ 111 | --which-epoch=$which_epoch \ 112 | --save-output=$save_output \ 113 | --save-output-format=$save_output_format \ 114 | --extract-feats-dir=$extract_feats_dir 115 | 116 | -------------------------------------------------------------------------------- /scripts_dir/test_concat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mode='evaluate' 4 | # mode='extract' 5 | 6 | gpu_id=0 7 | n_workers=0 8 | n_acts=9 9 | 10 | seq_len=30 11 | predict=1 12 | pred_seq_len=30 13 | 14 | # test only 15 | slide=0 16 | rand_test=1 17 | log_every=10 18 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts' 19 | dataset='JAAD' 20 | 21 | split='test' 22 | use_gru=1 23 | ped_gru=1 24 | use_trn=0 25 | pos_mode='none' 26 | use_act=0 27 | use_gt_act=0 28 | use_pose=0 29 | # branch='ped' 30 | branch='both' 31 | collapse_cls=0 32 | combine_method='pair' 33 | 34 | load_cache='masks' 35 | # load_cache='none' 36 | 37 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse/{}/ped{}_fid{}.pkl' 38 | save_cache_format=$cache_format 39 | 40 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl' 41 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse/{}/ped{}_fid{}.pkl' 42 | 43 | # ckpt_name='concat_gru_lr1.0e-04_wd1.0e-05_bt4_ped_collapse0_combinepair_useBBox1_cacheMasks_fixGRU' 44 | # which_epoch=13 45 | 46 | # ckpt_name='concat_gru_lr1.0e-04_wd1.0e-05_bt4_ped_collapse0_combinepair_useBBox0_cacheMasks_fixGRU' 47 | # which_epoch=44 48 | 49 | # ckpt_name='concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchped_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_withGTAct' 50 | 51 | # bkwfairi 52 | # ckpt_name='concat_gru_seq14_pred1_lr1.0e-04_wd1.0e-05_bt4_posNone_branchped_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU' 53 | 54 | # dur1n8v7 55 | # saved: pred ~74.9 56 |
ckpt_name='concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU' 57 | 58 | # -1 for the best epoch 59 | which_epoch=-1 60 | 61 | # this is to set a non-existent epoch s.t. the features are extracted from ImageNet backbone 62 | # which_epoch=100 63 | 64 | if [ $which_epoch -eq -1 ] 65 | then 66 | epoch_name='best_pred' 67 | else 68 | epoch_name=$which_epoch 69 | fi 70 | save_output=10 71 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl' 72 | 73 | 74 | if [ "$mode" = "extract" ] 75 | then 76 | extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/'$ckpt_name'/'$split'/' 77 | else 78 | extract_feats_dir='none_existent' 79 | fi 80 | 81 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \ 82 | --model='concat' \ 83 | --split=$split \ 84 | --n-acts=$n_acts \ 85 | --mode=$mode \ 86 | --device=$gpu_id \ 87 | --log-every=$log_every \ 88 | --dset-name='JAAD' \ 89 | --ckpt-name=$ckpt_name \ 90 | --batch-size=1 \ 91 | --n-workers=$n_workers \ 92 | --load-cache=$load_cache \ 93 | --save-cache-format=$save_cache_format \ 94 | --cache-format=$cache_format \ 95 | --seq-len=$seq_len \ 96 | --predict=$predict \ 97 | --pred-seq-len=$pred_seq_len \ 98 | --use-gru=$use_gru \ 99 | --ped-gru=$ped_gru \ 100 | --use-trn=$use_trn \ 101 | --use-act=$use_act \ 102 | --use-gt-act=$use_gt_act \ 103 | --use-pose=$use_pose \ 104 | --pos-mode=$pos_mode \ 105 | --collapse-cls=$collapse_cls \ 106 | --slide=$slide \ 107 | --rand-test=$rand_test \ 108 | --branch=$branch \ 109 | --which-epoch=$which_epoch \ 110 | --save-output=$save_output \ 111 | --save-output-format=$save_output_format \ 112 | --extract-feats-dir=$extract_feats_dir 113 | 114 | -------------------------------------------------------------------------------- /scripts_dir/train_loc_graph.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=1 4 | split='train' 5 | n_workers=0 6 | n_acts=1 7 | bt=8 8 | log_every=30 9 | 10 | # Training only 11 | n_epochs=80 12 | start_epoch=0 13 | lr=1e-4 14 | wd=1e-5 15 | lr_decay=1 16 | decay_every=20 17 | save_every=10 18 | evaluate_every=1 19 | 20 | reg_smooth='none' 21 | reg_lambda=0 22 | 23 | seq_len=30 24 | predict=1 25 | pred_seq_len=30 26 | predict_k=0 27 | 28 | annot_loc_format='/sailhome/ajarno/STR-PIP/datasets/annot_{}_loc.pkl' 29 | 30 | load_cache='feats' 31 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU/{}/ped{}_fid{}.pkl' 32 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_loc/JAAD_conv_feats/loc_concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt1_posNone_branchboth_collapse0_combinepair_tmp/{}/vid{}_fid{}.pkl' 33 | 34 | use_gru=1 35 | use_trn=0 36 | frame_gru=1 37 | node_gru=0 38 | 39 | # features 40 | use_act=0 41 | use_gt_act=0 42 | use_driver=0 43 | use_pose=0 44 | pos_mode='none' 45 | # branch='ped' 46 | branch='both' 47 | # adj_type='spatial' 48 | # adj_type='random' 49 | # adj_type='uniform' 50 | # adj_type='spatialOnly' 51 | # adj_type='all' 52 | adj_type='inner' 53 | use_obj_cls=0 54 | n_layers=2 55 | diff_layer_weight=0 56 | collapse_cls=0 57 | combine_method='pair' 58 | 59 | # saving & loading 60 | # suffix='_v4Feats_pedGRU_newCtxtGRU_3evalEpoch' 61 | suffix='_v4Feats_pedGRU_3evalEpoch' 62 | 
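# The if-blocks below append a tag to $suffix for each non-default flag, so the
# final $ckpt_name (assembled after these checks) encodes the run configuration.
# With the defaults above it comes out as:
#   branchboth_collapse0_combinepair_adjTypeinner_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch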
63 | if [[ "$node_gru" == 1 ]] 64 | then 65 | suffix=$suffix"_nodeGRU" 66 | fi 67 | 68 | if [[ "$use_obj_cls" == 1 ]] 69 | then 70 | suffix=$suffix"_objCls" 71 | fi 72 | 73 | if [[ "$use_driver" == 1 ]] 74 | then 75 | suffix=$suffix"_useDriver" 76 | fi 77 | 78 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers'_diffW'$diff_layer_weight$suffix 79 | # ckpt_name='graph_seq30_layer2_embed' 80 | 81 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \ 82 | --model='loc_graph' \ 83 | --reg-smooth=$reg_smooth \ 84 | --reg-lambda=$reg_lambda \ 85 | --split=$split \ 86 | --n-acts=$n_acts \ 87 | --device=$gpu_id \ 88 | --dset-name='JAAD_loc' \ 89 | --ckpt-name=$ckpt_name \ 90 | --n-epochs=$n_epochs \ 91 | --start-epoch=$start_epoch \ 92 | --lr-init=$lr \ 93 | --wd=$wd \ 94 | --lr-decay=$lr_decay \ 95 | --decay-every=$decay_every \ 96 | --seq-len=$seq_len \ 97 | --predict=$predict \ 98 | --pred-seq-len=$pred_seq_len \ 99 | --predict-k=$predict_k \ 100 | --batch-size=$bt \ 101 | --save-every=$save_every \ 102 | --evaluate-every=$evaluate_every \ 103 | --log-every=$log_every \ 104 | --n-workers=$n_workers \ 105 | --annot-loc-format=$annot_loc_format \ 106 | --load-cache=$load_cache \ 107 | --cache-format=$cache_format \ 108 | --branch=$branch \ 109 | --adj-type=$adj_type \ 110 | --use-obj-cls=$use_obj_cls \ 111 | --n-layers=$n_layers \ 112 | --diff-layer-weight=$diff_layer_weight \ 113 | --collapse-cls=$collapse_cls \ 114 | --combine-method=$combine_method \ 115 | --use-gru=$use_gru \ 116 | --use-trn=$use_trn \ 117 | --frame-gru=$frame_gru \ 118 | --node-gru=$node_gru \ 119 | --use-act=$use_act \ 120 | --use-gt-act=$use_gt_act \ 121 | --use-driver=$use_driver \ 122 | --use-pose=$use_pose \ 123 | --pos-mode=$pos_mode 124 | -------------------------------------------------------------------------------- /utils/visual.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import json 4 | import cv2 5 | import pdb 6 | 7 | # MACHINE = 'TRI' 8 | MACHINE = 'CS' 9 | 10 | fps_annot = 2 11 | fps_png = 20 12 | fps_tgt = 2 # target FPS 13 | overRate = 1 # oversample rate for annots 14 | downRate = 10 # downsample rate for png 15 | 16 | if MACHINE == 'TRI': 17 | # on TRI instance 18 | stip_root = '/mnt/paralle/stip' 19 | annot_root = os.path.join(stip_root, 'annotation') 20 | # subdir_format_png: two placeholders: area (e.g. "ANN_conor1"), subdir (e.g. "00:15-00:41").
21 | subdir_format_png = os.path.join(stip_root, 'stip_instances', '{}/rgb/{}') 22 | else: 23 | # on Stanford CS servers 24 | stip_root = '/vision2/u/bingbin/STR-PIP/STIP' 25 | annot_root = os.path.join(stip_root, 'annotations') 26 | subdir_format_png = os.path.join(stip_root, 'rgb', '{}/{}') 27 | 28 | # drawing setting 29 | FONT = cv2.FONT_HERSHEY_SIMPLEX 30 | FONT_SCALE = 4 31 | TEXT_COLOR = (255, 255, 255) 32 | TEXT_THICK = 2 33 | 34 | def png_lookup(fpngs): 35 | png_map = {} 36 | for fpng in fpngs: 37 | fid = int(os.path.basename(fpng).split('.')[0]) 38 | png_map[fid] = fpng 39 | return png_map 40 | 41 | 42 | def visual_clip(fannot, sdarea, tgt_dir): 43 | # pngs = sorted([fpng for sdir in glob(png_formt.format(darea, '*')) for fpng in os.path.join(sdir, '*.png')]) 44 | fpngs_all = sorted([fpng for fpng in glob(os.path.join(sdarea, '*.png'))]) 45 | png_map = png_lookup(fpngs_all) 46 | 47 | os.makedirs(tgt_dir, exist_ok=True) 48 | 49 | annot = json.load(open(fannot, 'r')) 50 | frames_annot = annot['frames'] 51 | for fid_annot in frames_annot: 52 | fid_tgt = [overRate*(int(fid_annot)-1)+i for i in range(overRate)] 53 | fid_png = [tfid*downRate for tfid in fid_tgt] 54 | fpngs = [png_map[pfid] for pfid in fid_png if pfid in png_map] 55 | if len(fpngs): 56 | print('# fpngs: {} / len(png_map): {}'.format(len(fpngs), len(png_map))) 57 | print('fid_annot:', fid_annot) 58 | print('fid_tgt:', fid_tgt) 59 | print('fid_png:', fid_png) 60 | print('fpngs:', fpngs) 61 | print() 62 | 63 | # draw bbox on png frames 64 | annot = frames_annot[fid_annot] 65 | for fpng in fpngs: 66 | img = cv2.imread(fpng) 67 | print('img size:', img.shape) 68 | if len(annot) == 0: 69 | cv2.putText(img, 'No obj in the current frame.', (100,100), 70 | FONT, FONT_SCALE, TEXT_COLOR, TEXT_THICK) 71 | else: 72 | for obj in annot: 73 | cv2.rectangle(img, (int(obj['x1']), int(obj['y1'])), (int(obj['x2']), int(obj['y2'])), (0,255,0), 3) 74 | cv2.putText(img, '-'.join(obj['tags']), (int(obj['x1']), int(obj['y1'])), 75 | FONT, FONT_SCALE, TEXT_COLOR, TEXT_THICK) 76 | tgt_fpng = fpng.replace(sdarea, tgt_dir) 77 | if tgt_fpng == fpng: 78 | print('Error saving drawn png: tgt_fpng == fpng') 79 | print('fpng:', fpng) 80 | pdb.set_trace() 81 | else: 82 | cv2.imwrite(tgt_fpng, img) 83 | 84 | def visual_clip_wrapper(): 85 | fannot = os.path.join(annot_root, '20170907_prolix_trial_ANN_hanh2-09-07-2017_15-44-07.concat.12fps.mp4.json') 86 | sdarea = subdir_format_png.format('ANN_hanh2', '00:16--00:27') 87 | if MACHINE == 'TRI': 88 | tgt_dir = sdarea.replace('/stip/', '/stip/tmp_vis/') 89 | else: 90 | tgt_dir = sdarea.replace('/rgb/' ,'/tmp_vis/') 91 | visual_clip(fannot, sdarea, tgt_dir) 92 | 93 | 94 | if __name__ == "__main__": 95 | visual_clip_wrapper() 96 | -------------------------------------------------------------------------------- /scripts_dir/train_graph_stip_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=0 4 | split='train' 5 | n_workers=0 6 | n_acts=1 7 | bt=32 8 | log_every=30 9 | 10 | # Training only 11 | n_epochs=200 12 | start_epoch=0 13 | lr=1e-4 14 | wd=1e-5 15 | lr_decay=1 16 | decay_every=20 17 | save_every=10 18 | evaluate_every=1 19 | 20 | reg_smooth='none' 21 | reg_lambda=0 22 | 23 | seq_len=8 24 | predict=1 25 | pred_seq_len=6 26 | predict_k=0 27 | 28 | view='all' 29 | 30 | annot_ped_format='/vision/group/prolix/processed/pedestrians_all_views/{}.pkl' 31 | 32 | 
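# Note: these annotation/cache paths use Python str.format-style '{}'
# placeholders (e.g. the split name for annot_ped_format, and pedestrian/frame
# ids for the '{}/ped{}_fid{}.pkl' cache entries below); the exact fill-in
# arguments are assumed to be supplied by the data loader in data/.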
cache_obj_bbox_format='/vision/group/prolix/processed/all/obj_bbox_20fps_merged/{}_seg{}.pkl' 33 | 34 | load_cache='feats' 35 | cache_format_base='/vision/group/prolix/processed/cache/all/STIP_conv_feats/' 36 | cache_format_pkl='{}/ped{}_fid{}.pkl' 37 | concat_ckpt='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10/' 38 | cache_format=$cache_format_base$concat_ckpt$cache_format_pkl 39 | 40 | use_gru=1 41 | use_trn=0 42 | ped_gru=1 43 | ctxt_gru=0 44 | ctxt_node=0 45 | 46 | # features 47 | use_act=0 48 | use_gt_act=0 49 | use_driver=0 50 | use_pose=0 51 | pos_mode='none' 52 | # branch='ped' 53 | branch='both' 54 | adj_type='spatial' 55 | # adj_type='random' 56 | # adj_type='uniform' 57 | # adj_type='spatialOnly' 58 | # adj_type='all' 59 | # adj_type='inner' 60 | use_obj_cls=0 61 | n_layers=2 62 | diff_layer_weight=0 63 | collapse_cls=0 64 | combine_method='pair' 65 | 66 | # saving & loading 67 | suffix='_decay10Feats_'$view 68 | 69 | if [[ "$ctxt_gru" == 1 ]] 70 | then 71 | suffix=$suffix"_newCtxtGRU" 72 | fi 73 | 74 | if [[ "$use_obj_cls" == 1 ]] 75 | then 76 | suffix=$suffix"_objCls" 77 | fi 78 | 79 | if [[ "$use_driver" == 1 ]] 80 | then 81 | suffix=$suffix"_useDriver" 82 | fi 83 | 84 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers'_diffW'$diff_layer_weight$suffix 85 | 86 | # WANDB_MODE=dryrun CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \ 87 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \ 88 | --model='graph' \ 89 | --view=$view \ 90 | --reg-smooth=$reg_smooth \ 91 | --reg-lambda=$reg_lambda \ 92 | --split=$split \ 93 | --n-acts=$n_acts \ 94 | --device=$gpu_id \ 95 | --dset-name='STIP' \ 96 | --ckpt-name=$ckpt_name \ 97 | --n-epochs=$n_epochs \ 98 | --start-epoch=$start_epoch \ 99 | --lr-init=$lr \ 100 | --wd=$wd \ 101 | --lr-decay=$lr_decay \ 102 | --decay-every=$decay_every \ 103 | --seq-len=$seq_len \ 104 | --predict=$predict \ 105 | --pred-seq-len=$pred_seq_len \ 106 | --predict-k=$predict_k \ 107 | --batch-size=$bt \ 108 | --save-every=$save_every \ 109 | --evaluate-every=$evaluate_every \ 110 | --log-every=$log_every \ 111 | --n-workers=$n_workers \ 112 | --annot-ped-format=$annot_ped_format \ 113 | --cache-obj-bbox-format=$cache_obj_bbox_format \ 114 | --load-cache=$load_cache \ 115 | --cache-format=$cache_format \ 116 | --branch=$branch \ 117 | --adj-type=$adj_type \ 118 | --use-obj-cls=$use_obj_cls \ 119 | --n-layers=$n_layers \ 120 | --diff-layer-weight=$diff_layer_weight \ 121 | --collapse-cls=$collapse_cls \ 122 | --combine-method=$combine_method \ 123 | --use-gru=$use_gru \ 124 | --use-trn=$use_trn \ 125 | --ped-gru=$ped_gru \ 126 | --ctxt-gru=$ctxt_gru \ 127 | --ctxt-node=$ctxt_node \ 128 | --use-act=$use_act \ 129 | --use-gt-act=$use_gt_act \ 130 | --use-driver=$use_driver \ 131 | --use-pose=$use_pose \ 132 | --pos-mode=$pos_mode 133 | -------------------------------------------------------------------------------- /scripts_dir/train_graph_stip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=0 4 | split='train' 5 | n_workers=0 6 | n_acts=1 7 | bt=32 8 | log_every=30 9 | 10 | # Training only 11 | n_epochs=200 12 | start_epoch=0 13 | lr=1e-4 14 | wd=1e-5 15 | lr_decay=1 16 | decay_every=20 17 | save_every=10 18 | evaluate_every=1 19 | 20 | reg_smooth='none' 21 | reg_lambda=0 22 | 23 | seq_len=30 24 | predict=1 25 | pred_seq_len=60 26 | predict_k=0 27 | 28 |
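# Horizon setup: with predict=1 the model observes seq_len frames and is also
# supervised on the following pred_seq_len frames (here 30 observed + 60
# predicted); predict_k=0 presumably disables fixed single-offset prediction
# in favor of the full future sequence.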
annot_ped_format='/vision/group/prolix/processed/pedestrians/{}.pkl' 29 | 30 | cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl' 31 | 32 | load_cache='feats' 33 | cache_format_base='/vision/group/prolix/processed/cache/STIP_conv_feats/' 34 | cache_format_pkl='{}/ped{}_fid{}.pkl' 35 | # cache_format=$cache_format_base'concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_testANN_hanh1/'$cache_format_pkl 36 | # concat_ckpt='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_run1/' 37 | concat_ckpt='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10/' 38 | cache_format=$cache_format_base$concat_ckpt$cache_format_pkl 39 | 40 | use_gru=1 41 | use_trn=0 42 | ped_gru=1 43 | ctxt_gru=0 44 | ctxt_node=0 45 | 46 | # features 47 | use_act=0 48 | use_gt_act=0 49 | use_driver=0 50 | use_pose=0 51 | pos_mode='none' 52 | # branch='ped' 53 | branch='both' 54 | adj_type='spatial' 55 | # adj_type='random' 56 | # adj_type='uniform' 57 | # adj_type='spatialOnly' 58 | # adj_type='all' 59 | # adj_type='inner' 60 | use_obj_cls=0 61 | n_layers=2 62 | diff_layer_weight=0 63 | collapse_cls=0 64 | combine_method='pair' 65 | 66 | # saving & loading 67 | suffix='_decay10Feats' 68 | 69 | if [[ "$ctxt_gru" == 1 ]] 70 | then 71 | suffix=$suffix"_newCtxtGRU" 72 | fi 73 | 74 | if [[ "$use_obj_cls" == 1 ]] 75 | then 76 | suffix=$suffix"_objCls" 77 | fi 78 | 79 | if [[ "$use_driver" == 1 ]] 80 | then 81 | suffix=$suffix"_useDriver" 82 | fi 83 | 84 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers'_diffW'$diff_layer_weight$suffix 85 | 86 | # WANDB_MODE=dryrun CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \ 87 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \ 88 | --model='graph' \ 89 | --reg-smooth=$reg_smooth \ 90 | --reg-lambda=$reg_lambda \ 91 | --split=$split \ 92 | --n-acts=$n_acts \ 93 | --device=$gpu_id \ 94 | --dset-name='STIP' \ 95 | --ckpt-name=$ckpt_name \ 96 | --n-epochs=$n_epochs \ 97 | --start-epoch=$start_epoch \ 98 | --lr-init=$lr \ 99 | --wd=$wd \ 100 | --lr-decay=$lr_decay \ 101 | --decay-every=$decay_every \ 102 | --seq-len=$seq_len \ 103 | --predict=$predict \ 104 | --pred-seq-len=$pred_seq_len \ 105 | --predict-k=$predict_k \ 106 | --batch-size=$bt \ 107 | --save-every=$save_every \ 108 | --evaluate-every=$evaluate_every \ 109 | --log-every=$log_every \ 110 | --n-workers=$n_workers \ 111 | --annot-ped-format=$annot_ped_format \ 112 | --cache-obj-bbox-format=$cache_obj_bbox_format \ 113 | --load-cache=$load_cache \ 114 | --cache-format=$cache_format \ 115 | --branch=$branch \ 116 | --adj-type=$adj_type \ 117 | --use-obj-cls=$use_obj_cls \ 118 | --n-layers=$n_layers \ 119 | --diff-layer-weight=$diff_layer_weight \ 120 | --collapse-cls=$collapse_cls \ 121 | --combine-method=$combine_method \ 122 | --use-gru=$use_gru \ 123 | --use-trn=$use_trn \ 124 | --ped-gru=$ped_gru \ 125 | --ctxt-gru=$ctxt_gru \ 126 | --ctxt-node=$ctxt_node \ 127 | --use-act=$use_act \ 128 | --use-gt-act=$use_gt_act \ 129 | --use-driver=$use_driver \ 130 | --use-pose=$use_pose \ 131 | --pos-mode=$pos_mode 132 | -------------------------------------------------------------------------------- /scripts_dir/test_graph_stip_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mode='evaluate' 4 | split='test' 5 | 6 | gpu_id=1 7 | n_workers=0 8 | n_acts=1 9 | bt=1
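# Evaluation settings: batch size 1, observing seq_len frames and predicting
# the next pred_seq_len; the seq_len check further down picks the matching
# training checkpoint name.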
10 | 11 | seq_len=4 12 | predict=1 13 | pred_seq_len=4 14 | predict_k=0 15 | 16 | view='all' 17 | 18 | # TODO: get new ped formats 19 | annot_ped_format='/vision/group/prolix/processed/pedestrians_all_views/{}.pkl' 20 | cache_obj_bbox_format='/vision/group/prolix/processed/'$view'/obj_bbox_20fps_merged/{}_seg{}.pkl' 21 | 22 | # test only 23 | slide=0 24 | rand_test=1 25 | log_every=10 26 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts' 27 | dataset='STIP' 28 | 29 | if [ $seq_len -eq 8 ] 30 | then 31 | ckpt_name='graph_gru_seq'$seq_len'_pred'$pred_seq_len'_lr1.0e-04_wd1.0e-05_bt32_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_decay10Feats_all' 32 | else 33 | # seq_len = 4 34 | ckpt_name='graph_gru_seq'$seq_len'_pred'$pred_seq_len'_lr1.0e-04_wd1.0e-05_bt32_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_decay10Feats' 35 | fi 36 | which_epoch=-1 37 | if [ $which_epoch -eq -1 ] 38 | then 39 | epoch_name='best_pred' 40 | else 41 | epoch_name=$which_epoch 42 | fi 43 | save_output=1 44 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}_'$seq_len'+'$pred_seq_len'.pkl' 45 | collect_A=1 46 | save_As_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/test_graph_weights_epoch'$epoch_name'/vid{}_eval{}.pkl' 47 | 48 | 49 | load_cache='feats' 50 | feat_ckpt_name='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10' 51 | cache_format='/vision/group/prolix/processed/cache/'$view'/STIP_conv_feats/'$feat_ckpt_name'/{}/ped{}_fid{}.pkl' 52 | 53 | 54 | # if [ "$mode" = "extract" ] 55 | # then 56 | # extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/test/' 57 | # else 58 | # extract_feats_dir='none_existent' 59 | # fi 60 | 61 | use_gru=1 62 | use_trn=0 63 | ped_gru=1 64 | ctxt_gru=0 65 | ctxt_node=0 66 | 67 | # features 68 | use_act=0 69 | use_gt_act=0 70 | use_driver=0 71 | use_pose=0 72 | pos_mode='none' 73 | # branch='ped' 74 | branch='both' 75 | adj_type='spatial' 76 | use_obj_cls=0 77 | n_layers=2 78 | diff_layer_weight=0 79 | collapse_cls=0 80 | combine_method='pair' 81 | 82 | CUDA_VISIBLE_DEVICES=$gpu_id python3 -W ignore test.py \ 83 | --model='graph' \ 84 | --split=$split \ 85 | --view=$view \ 86 | --mode=$mode \ 87 | --slide=$slide \ 88 | --rand-test=$rand_test \ 89 | --ckpt-dir=$ckpt_dir \ 90 | --ckpt-name=$ckpt_name \ 91 | --n-acts=$n_acts \ 92 | --device=$gpu_id \ 93 | --dset-name='STIP' \ 95 | --which-epoch=$which_epoch \ 96 | --save-output=$save_output \ 97 | --save-output-format=$save_output_format \ 98 | --collect-A=$collect_A \ 99 | --save-As-format=$save_As_format \ 102 | --seq-len=$seq_len \ 103 | --predict=$predict \ 104 | --pred-seq-len=$pred_seq_len \ 105 | --predict-k=$predict_k \ 106 | --batch-size=$bt \ 107 | --log-every=$log_every \ 108 | --n-workers=$n_workers \ 109 | --annot-ped-format=$annot_ped_format \ 110 | --cache-obj-bbox-format=$cache_obj_bbox_format \ 111 | --load-cache=$load_cache \ 112 | --cache-format=$cache_format \ 113 | --branch=$branch \ 114 | --adj-type=$adj_type \ 115 | --use-obj-cls=$use_obj_cls \ 116 | --n-layers=$n_layers \ 117 | --diff-layer-weight=$diff_layer_weight \ 118 | --collapse-cls=$collapse_cls \ 119 | --combine-method=$combine_method \ 120 | --use-gru=$use_gru \ 121 | --use-trn=$use_trn \ 122 | --ped-gru=$ped_gru \ 123 | --ctxt-gru=$ctxt_gru \ 124 | --ctxt-node=$ctxt_node \ 125 |
--use-act=$use_act \ 126 | --use-gt-act=$use_gt_act \ 127 | --use-driver=$use_driver \ 128 | --use-pose=$use_pose \ 129 | --pos-mode=$pos_mode 130 | 131 | 132 | -------------------------------------------------------------------------------- /utils/stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import pickle 4 | 5 | import pdb 6 | 7 | data_root = '/sailhome/ajarno/STR-PIP/datasets/' 8 | img_root = os.path.join(data_root, 'JAAD_dataset/JAAD_clip_images') 9 | 10 | def count_frames(): 11 | vids = glob(os.path.join(img_root, 'video_*.mp4')) 12 | cnts = [] 13 | for vid in vids: 14 | cnt = len(glob(os.path.join(vid, '*.jpg'))) 15 | cnts += cnt, 16 | 17 | print('total # frames:', sum(cnts)) 18 | print('cnts: max={} / min={} / mean={}'.format(max(cnts), min(cnts), sum(cnts)/len(cnts))) 19 | # pdb.set_trace() 20 | 21 | def ped_count_crossing(): 22 | def helper(ped): 23 | act = ped['act'] 24 | crossing = [each[1] for each in act] 25 | n_cross = sum(crossing) 26 | n_noncross = len(crossing) - n_cross 27 | return n_cross, n_noncross 28 | 29 | with open(os.path.join(data_root, 'annot_train_ped_withTag_sanityWithPose.pkl'), 'rb') as handle: 30 | train = pickle.load(handle) 31 | with open(os.path.join(data_root, 'annot_test_ped_withTag_sanityWithPose.pkl'), 'rb') as handle: 32 | test = pickle.load(handle) 33 | 34 | n_crosses = 0 35 | n_noncrosses = 0 36 | for ped in train+test: 37 | n_cross, n_noncross = helper(ped) 38 | n_crosses += n_cross 39 | n_noncrosses += n_noncross 40 | 41 | n_total = n_crosses + n_noncrosses 42 | print('n_cross: {} ({:.4f})'.format(n_crosses, n_crosses/n_total)) 43 | print('n_noncross: {} ({:.4f})'.format(n_noncrosses, n_noncrosses/n_total)) 44 | 45 | def loc_count_crossing(): 46 | def helper(split): 47 | split_cross, split_noncross = 0, 0 48 | for vid in split: 49 | n_cross = sum(split[vid]['act'] == 1)[0] 50 | n_noncross = sum(split[vid]['act'] == 0)[0] 51 | split_cross += n_cross 52 | split_noncross += n_noncross 53 | split_total = split_cross + split_noncross 54 | # pdb.set_trace() 55 | print(' cross:{} ({:.4f}) / non-cross:{} ({:.4f}) / total:{}'.format( 56 | split_cross, split_cross / split_total, split_noncross, split_noncross / split_total, split_total)) 57 | return split_cross, split_noncross, split_total 58 | 59 | with open(os.path.join(data_root, 'annot_train_loc.pkl'), 'rb') as handle: 60 | train = pickle.load(handle) 61 | print('train:') 62 | train_cross, train_noncross, train_total = helper(train) 63 | 64 | with open(os.path.join(data_root, 'annot_test_loc.pkl'), 'rb') as handle: 65 | test = pickle.load(handle) 66 | print('\ntest:') 67 | test_cross, test_noncross, test_total = helper(test) 68 | 69 | n_cross, n_noncross, n_total = train_cross+test_cross, train_noncross+test_noncross, train_total+test_total 70 | print('\ntotal:') 71 | print(' cross:{} ({:.4f}) / non-cross:{} ({:.4f}) / total:{}'.format( 72 | n_cross, n_cross / n_total, n_noncross, n_noncross / n_total, n_total)) 73 | 74 | def loc_count_ped(): 75 | def helper(split): 76 | n_peds = [] 77 | for vid in split: 78 | for frame in split[vid]['ped_pos']: 79 | n_peds += len(frame), 80 | print(' max:{} / min:{} / mean:{}'.format(max(n_peds), min(n_peds), sum(n_peds)/len(n_peds))) 81 | return n_peds 82 | 83 | with open(os.path.join(data_root, 'annot_train_loc.pkl'), 'rb') as handle: 84 | train = pickle.load(handle) 85 | print('train:') 86 | train_n_peds = helper(train) 87 | 88 | with open(os.path.join(data_root, 
'annot_test_loc_new.pkl'), 'rb') as handle: 89 | test = pickle.load(handle) 90 | print('\ntest:') 91 | test_n_peds = helper(test) 92 | 93 | n_peds = train_n_peds + test_n_peds 94 | print('\ntotal:') 95 | print(' max:{} / min:{} / mean:{}'.format(max(n_peds), min(n_peds), sum(n_peds)/len(n_peds))) 96 | 97 | 98 | 99 | 100 | if __name__ == '__main__': 101 | print('count_frames:') 102 | count_frames() 103 | 104 | print('\nped_count_crossing:') 105 | ped_count_crossing() 106 | 107 | print('\nloc_count_crossing:') 108 | loc_count_crossing() 109 | 110 | print('\nloc_count_ped:') 111 | loc_count_ped() 112 | -------------------------------------------------------------------------------- /scripts_dir/test_graph_stip.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mode='evaluate' 4 | split='test' 5 | 6 | gpu_id=0 7 | n_workers=0 8 | n_acts=1 9 | bt=1 10 | 11 | seq_len=8 12 | predict=1 13 | pred_seq_len=6 14 | predict_k=0 15 | 16 | annot_ped_format='/vision/group/prolix/processed/pedestrians/{}.pkl' 17 | cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl' 18 | 19 | # test only 20 | slide=0 21 | rand_test=1 22 | log_every=10 23 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts' 24 | dataset='STIP' 25 | 26 | # pred 10 27 | # ckpt_name='graph_gru_seq30_pred10_lr1.0e-04_wd1.0e-05_bt8_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_decay10Feats' 28 | ckpt_name='graph_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt32_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_decay10Feats' 29 | 30 | # pred 30 31 | # ckpt_name='graph_gru_seq30_pred60_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch' 32 | 33 | # pred 60 34 | # ckpt_name='graph_gru_seq30_pred90_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch' 35 | 36 | which_epoch=-1 37 | if [ $which_epoch -eq -1 ] 38 | then 39 | epoch_name='best_pred' 40 | else 41 | epoch_name=$which_epoch 42 | fi 43 | save_output=1 44 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}_'$seq_len'+'$pred_seq_len'.pkl' 45 | collect_A=1 46 | save_As_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/test_graph_weights_epoch'$epoch_name'/vid{}_eval{}.pkl' 47 | 48 | 49 | load_cache='feats' 50 | cache_format='/vision/group/prolix/processed/cache/STIP_conv_feats/concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10/{}/ped{}_fid{}.pkl' 51 | 52 | 53 | # if [ "$mode" = "extract" ] 54 | # then 55 | # extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/test/' 56 | # else 57 | # extract_feats_dir='none_existent' 58 | # fi 59 | 60 | use_gru=1 61 | use_trn=0 62 | ped_gru=1 63 | ctxt_gru=0 64 | ctxt_node=0 65 | 66 | # features 67 | use_act=0 68 | use_gt_act=0 69 | use_driver=0 70 | use_pose=0 71 | pos_mode='none' 72 | # branch='ped' 73 | branch='both' 74 | adj_type='spatial' 75 | use_obj_cls=0 76 | n_layers=2 77 | diff_layer_weight=0 78 | collapse_cls=0 79 | combine_method='pair' 80 | 81 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \ 82 | --model='graph' \ 83 | --split=$split \ 84 | --mode=$mode \ 85 | --slide=$slide \ 86 | --rand-test=$rand_test \ 87 | --ckpt-dir=$ckpt_dir \ 88 | --ckpt-name=$ckpt_name \ 89 | --n-acts=$n_acts \ 90 | --device=$gpu_id \ 91 | --dset-name='STIP' \ 92 | 
93 | --which-epoch=$which_epoch \ 94 | --save-output=$save_output \ 95 | --save-output-format=$save_output_format \ 96 | --collect-A=$collect_A \ 97 | --save-As-format=$save_As_format \ 100 | --seq-len=$seq_len \ 101 | --predict=$predict \ 102 | --pred-seq-len=$pred_seq_len \ 103 | --predict-k=$predict_k \ 104 | --batch-size=$bt \ 105 | --log-every=$log_every \ 106 | --n-workers=$n_workers \ 107 | --annot-ped-format=$annot_ped_format \ 108 | --cache-obj-bbox-format=$cache_obj_bbox_format \ 109 | --load-cache=$load_cache \ 110 | --cache-format=$cache_format \ 111 | --branch=$branch \ 112 | --adj-type=$adj_type \ 113 | --use-obj-cls=$use_obj_cls \ 114 | --n-layers=$n_layers \ 115 | --diff-layer-weight=$diff_layer_weight \ 116 | --collapse-cls=$collapse_cls \ 117 | --combine-method=$combine_method \ 118 | --use-gru=$use_gru \ 119 | --use-trn=$use_trn \ 120 | --ped-gru=$ped_gru \ 121 | --ctxt-gru=$ctxt_gru \ 122 | --ctxt-node=$ctxt_node \ 123 | --use-act=$use_act \ 124 | --use-gt-act=$use_gt_act \ 125 | --use-driver=$use_driver \ 126 | --use-pose=$use_pose \ 127 | --pos-mode=$pos_mode 128 | 129 | exit 130 | 131 | 132 | -------------------------------------------------------------------------------- /utils/temp_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import pickle 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | 9 | import pdb 10 | 11 | SMOOTH = True 12 | MIX = True 13 | 14 | fpkls = [] 15 | 16 | ckpt_root = '/sailhome/bingbin/STR-PIP/ckpts/JAAD/' 17 | 18 | # best proposed 19 | 20 | # pred 30 21 | 22 | # ckpt_name = 'graph_gru_seq30_pred30_lr3.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch' 23 | # fout = 'label_epochbest_pred_stepall_run2.pkl' 24 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout) 25 | # fpkls += fpkl, 26 | 27 | fout = 'output_epochbest_pred_stepall.pkl' 28 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout) 29 | # fpkls += fpkl, 30 | 31 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch' 32 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 33 | # fpkls += fpkl, 34 | 35 | # fout = 'output_epochbest_pred_stepall_run3Epochs.pkl' 36 | # fout = 'output_epochbest_pred_stepall_run1Epochs.pkl' 37 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout) 38 | # fpkls += fpkl, 39 | 40 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_newCtxtGRU_3evalEpoch' 41 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 42 | fpkls += fpkl, 43 | 44 | # graph - pred 60 45 | ckpt_name = 'graph_gru_seq30_pred60_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch' 46 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 47 | fpkls += fpkl, 48 | 49 | 50 | # graph - pred 90 51 | ckpt_name = 'graph_gru_seq30_pred90_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch' 52 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 53 | fpkls += fpkl, 54 | 55 | 56 | # best concat (pred 30) 57 | ckpt_name =
'concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU' 58 | fout = 'output_all.pkl' 59 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 60 | 61 | # fpkls += fpkl, 62 | 63 | 64 | for i, fpkl in enumerate(fpkls): 65 | with open(fpkl, 'rb') as handle: 66 | data = pickle.load(handle) 67 | out, gt = data['out'], data['GT'] 68 | if out.shape[-1] % 10 == 1: 69 | out = out[:, :-1] 70 | gt = gt[:, :-1] 71 | try: 72 | acc = (out == gt).mean(0) 73 | pred_acc = (out[:, 30:] == gt[:, 30:]).mean() 74 | print('pred_acc:', pred_acc) 75 | except Exception as e: 76 | print(e) 77 | pdb.set_trace() 78 | acc = list(acc) 79 | if SMOOTH: 80 | t1, t3 = acc[1:]+[acc[-1]], [acc[0]]+acc[:-1] 81 | acc = (np.array(t1) + np.array(t3) + np.array(acc)) / 3 82 | # pdb.set_trace() 83 | plt.plot(acc, label='{} frames{}'.format(30*(i+1), ' (acc)' if MIX else '')) 84 | plt.legend() 85 | plt.xlabel('Frames', fontsize=12) 86 | plt.ylabel('Accuracy', fontsize=12) 87 | plt.savefig('acc_temp_graph{}.png'.format('_smooth' if SMOOTH else ''), dpi=1000) 88 | if not MIX: 89 | plt.clf() 90 | 91 | 92 | for i, fpkl in enumerate(fpkls): 93 | with open(fpkl, 'rb') as handle: 94 | data = pickle.load(handle) 95 | prob = data['prob'] 96 | B = data['out'].shape[0] 97 | prob = prob.reshape(B, -1) 98 | if prob.shape[-1] % 10 == 1: 99 | prob = prob[:, :-1] 100 | prob = list(prob.mean(0)) 101 | if SMOOTH: 102 | t1, t3 = prob[1:]+[prob[-1]], [prob[0]]+prob[:-1] 103 | prob = (np.array(t1) + np.array(t3) + np.array(prob)) / 3 104 | # pdb.set_trace() 105 | plt.plot(prob, label='{} frames{}'.format(30*(i+1), ' (prob)' if MIX else '')) 106 | plt.legend() 107 | plt.xlabel('Frames', fontsize=12) 108 | plt.ylabel('Accuracy & Probability' if MIX else 'Probability', fontsize=12) 109 | plt.savefig('{}_temp_graph{}.png'.format('mix' if MIX else 'prob', '_smooth' if SMOOTH else ''), dpi=1000) 110 | plt.clf() 111 | 112 | -------------------------------------------------------------------------------- /utils/temp_plot_stip.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import pickle 5 | import matplotlib 6 | matplotlib.use('Agg') 7 | import matplotlib.pyplot as plt 8 | 9 | import pdb 10 | 11 | SMOOTH = True 12 | MIX = True 13 | 14 | fpkls = [] 15 | 16 | ckpt_root = '/sailhome/bingbin/STR-PIP/ckpts/STIP/' 17 | 18 | # best proposed 19 | 20 | # pred 30 21 | 22 | # ckpt_name = 'graph_gru_seq30_pred30_lr3.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch' 23 | # fout = 'label_epochbest_pred_stepall_run2.pkl' 24 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout) 25 | # fpkls += fpkl, 26 | 27 | fout = 'output_epochbest_pred_stepall.pkl' 28 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout) 29 | # fpkls += fpkl, 30 | 31 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch' 32 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 33 | # fpkls += fpkl, 34 | 35 | # fout = 'output_epochbest_pred_stepall_run3Epochs.pkl' 36 | # fout = 'output_epochbest_pred_stepall_run1Epochs.pkl' 37 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout) 38 | # fpkls += fpkl, 39 | 40 | ckpt_name = 
'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_newCtxtGRU_3evalEpoch' 41 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 42 | fpkls += fpkl, 43 | 44 | # graph - pred 60 45 | ckpt_name = 'graph_gru_seq30_pred60_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch' 46 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 47 | fpkls += fpkl, 48 | 49 | 50 | # graph - pred 90 51 | ckpt_name = 'graph_gru_seq30_pred90_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch' 52 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 53 | fpkls += fpkl, 54 | 55 | 56 | # best concat (pred 30) 57 | ckpt_name = 'concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU' 58 | fout = 'output_all.pkl' 59 | fpkl = os.path.join(ckpt_root, ckpt_name, fout) 60 | 61 | # fpkls += fpkl, 62 | 63 | 64 | for i, fpkl in enumerate(fpkls): 65 | with open(fpkl, 'rb') as handle: 66 | data = pickle.load(handle) 67 | out, gt = data['out'], data['GT'] 68 | if out.shape[-1] % 10 == 1: 69 | out = out[:, :-1] 70 | gt = gt[:, :-1] 71 | try: 72 | acc = (out == gt).mean(0) 73 | pred_acc = (out[:, 30:] == gt[:, 30:]).mean() 74 | print('pred_acc:', pred_acc) 75 | except Exception as e: 76 | print(e) 77 | pdb.set_trace() 78 | acc = list(acc) 79 | if SMOOTH: 80 | t1, t3 = acc[1:]+[acc[-1]], [acc[0]]+acc[:-1] 81 | acc = (np.array(t1) + np.array(t3) + np.array(acc)) / 3 82 | # pdb.set_trace() 83 | plt.plot(acc, label='{} frames{}'.format(30*(i+1), ' (acc)' if MIX else '')) 84 | plt.legend() 85 | plt.xlabel('Frames', fontsize=12) 86 | plt.ylabel('Accuracy', fontsize=12) 87 | plt.savefig('acc_temp_graph{}.png'.format('_smooth' if SMOOTH else ''), dpi=1000) 88 | if not MIX: 89 | plt.clf() 90 | 91 | 92 | for i, fpkl in enumerate(fpkls): 93 | with open(fpkl, 'rb') as handle: 94 | data = pickle.load(handle) 95 | prob = data['prob'] 96 | B = data['out'].shape[0] 97 | prob = prob.reshape(B, -1) 98 | if prob.shape[-1] % 10 == 1: 99 | prob = prob[:, :-1] 100 | prob = list(prob.mean(0)) 101 | if SMOOTH: 102 | t1, t3 = prob[1:]+[prob[-1]], [prob[0]]+prob[:-1] 103 | prob = (np.array(t1) + np.array(t3) + np.array(prob)) / 3 104 | # pdb.set_trace() 105 | plt.plot(prob, label='{} frames{}'.format(30*(i+1), ' (prob)' if MIX else '')) 106 | plt.legend() 107 | plt.xlabel('Frames', fontsize=12) 108 | plt.ylabel('Accuracy & Probability' if MIX else 'Probability', fontsize=12) 109 | plt.savefig('{}_temp_graph{}.png'.format('mix' if MIX else 'prob', '_smooth' if SMOOTH else ''), dpi=1000) 110 | plt.clf() 111 | 112 | -------------------------------------------------------------------------------- /utils/colormap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | 16 | """An awesome colormap for really neat visualizations.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | from __future__ import unicode_literals 22 | 23 | import numpy as np 24 | 25 | 26 | def colormap(rgb=False): 27 | color_list = np.array( 28 | [ 29 | 0.000, 0.447, 0.741, 30 | 0.850, 0.325, 0.098, 31 | 0.929, 0.694, 0.125, 32 | 0.494, 0.184, 0.556, 33 | 0.466, 0.674, 0.188, 34 | 0.301, 0.745, 0.933, 35 | 0.635, 0.078, 0.184, 36 | 0.300, 0.300, 0.300, 37 | 0.600, 0.600, 0.600, 38 | 1.000, 0.000, 0.000, 39 | 1.000, 0.500, 0.000, 40 | 0.749, 0.749, 0.000, 41 | 0.000, 1.000, 0.000, 42 | 0.000, 0.000, 1.000, 43 | 0.667, 0.000, 1.000, 44 | 0.333, 0.333, 0.000, 45 | 0.333, 0.667, 0.000, 46 | 0.333, 1.000, 0.000, 47 | 0.667, 0.333, 0.000, 48 | 0.667, 0.667, 0.000, 49 | 0.667, 1.000, 0.000, 50 | 1.000, 0.333, 0.000, 51 | 1.000, 0.667, 0.000, 52 | 1.000, 1.000, 0.000, 53 | 0.000, 0.333, 0.500, 54 | 0.000, 0.667, 0.500, 55 | 0.000, 1.000, 0.500, 56 | 0.333, 0.000, 0.500, 57 | 0.333, 0.333, 0.500, 58 | 0.333, 0.667, 0.500, 59 | 0.333, 1.000, 0.500, 60 | 0.667, 0.000, 0.500, 61 | 0.667, 0.333, 0.500, 62 | 0.667, 0.667, 0.500, 63 | 0.667, 1.000, 0.500, 64 | 1.000, 0.000, 0.500, 65 | 1.000, 0.333, 0.500, 66 | 1.000, 0.667, 0.500, 67 | 1.000, 1.000, 0.500, 68 | 0.000, 0.333, 1.000, 69 | 0.000, 0.667, 1.000, 70 | 0.000, 1.000, 1.000, 71 | 0.333, 0.000, 1.000, 72 | 0.333, 0.333, 1.000, 73 | 0.333, 0.667, 1.000, 74 | 0.333, 1.000, 1.000, 75 | 0.667, 0.000, 1.000, 76 | 0.667, 0.333, 1.000, 77 | 0.667, 0.667, 1.000, 78 | 0.667, 1.000, 1.000, 79 | 1.000, 0.000, 1.000, 80 | 1.000, 0.333, 1.000, 81 | 1.000, 0.667, 1.000, 82 | 0.167, 0.000, 0.000, 83 | 0.333, 0.000, 0.000, 84 | 0.500, 0.000, 0.000, 85 | 0.667, 0.000, 0.000, 86 | 0.833, 0.000, 0.000, 87 | 1.000, 0.000, 0.000, 88 | 0.000, 0.167, 0.000, 89 | 0.000, 0.333, 0.000, 90 | 0.000, 0.500, 0.000, 91 | 0.000, 0.667, 0.000, 92 | 0.000, 0.833, 0.000, 93 | 0.000, 1.000, 0.000, 94 | 0.000, 0.000, 0.167, 95 | 0.000, 0.000, 0.333, 96 | 0.000, 0.000, 0.500, 97 | 0.000, 0.000, 0.667, 98 | 0.000, 0.000, 0.833, 99 | 0.000, 0.000, 1.000, 100 | 0.000, 0.000, 0.000, 101 | 0.143, 0.143, 0.143, 102 | 0.286, 0.286, 0.286, 103 | 0.429, 0.429, 0.429, 104 | 0.571, 0.571, 0.571, 105 | 0.714, 0.714, 0.714, 106 | 0.857, 0.857, 0.857, 107 | 1.000, 1.000, 1.000 108 | ] 109 | ).astype(np.float32) 110 | color_list = color_list.reshape((-1, 3)) * 255 111 | if not rgb: 112 | color_list = color_list[:, ::-1] 113 | return color_list 114 | -------------------------------------------------------------------------------- /utils/draw_As.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import cv2 3 | import numpy as np 4 | import os 5 | from glob import glob 6 | import matplotlib 7 | matplotlib.use('Agg') # Use a non-interactive backend 8 | import matplotlib.pyplot as plt 9 | from matplotlib.patches import Polygon 10 | import pickle 11 | import random 12 | 13 | from colormap import colormap 14 | 15 | import pdb 16 | 17 | FLIP = 1 18 | TRANS = 0 19 | 20 | ckpt_dir = 
'/sailhome/bingbin/STR-PIP/ckpts/JAAD/graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_newCtxtGRU_3evalEpoch/' 21 | 22 | cache_dir = os.path.join(ckpt_dir, 'test_graph_weights_epochbest_pred') 23 | out_dir = os.path.join(ckpt_dir, 'vis_out') 24 | os.makedirs(out_dir, exist_ok=True) 25 | 26 | def tmp_vis_one_image(): 27 | fpkls = sorted(glob(os.path.join(cache_dir, '*pkl'))) 28 | for pi,fpkl in enumerate(fpkls): 29 | if pi < 70: 30 | continue 31 | if pi and pi%10 == 0: 32 | print("{} / {}".format(pi, len(fpkls))) 33 | 34 | with open(fpkl, 'rb') as handle: 35 | vid = pickle.load(handle) 36 | v_ws = vid['ws'] 37 | v_bbox = vid['obj_bbox'] 38 | img_names = vid['img_paths'] 39 | 40 | for i in range(len(v_ws)): 41 | img_name = img_names[i] 42 | fid = os.path.basename(img_name).split('.')[0] 43 | out_name = os.path.join(out_dir, os.path.basename(fpkl).replace('.pkl', '_i{}_f{}.png'.format(i, fid))) 44 | ws = v_ws[i][-1] # take weights from the last graph layer 45 | vis_one_image(img_names[i], out_name, v_bbox[i], weights=ws) 46 | 47 | def vis_one_image(im_name, fout, bboxes, dpi=200, weights=None): 48 | # bboxes: (N, 4) ... one (x, y, w, h) box per object in the frame 49 | if not len(bboxes): 50 | return 51 | 52 | im = cv2.imread(im_name) 53 | H, W, _ = im.shape 54 | color_list = colormap(rgb=True) / 255 55 | 56 | fig = plt.figure(frameon=False) 57 | fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi) 58 | ax = plt.Axes(fig, [0., 0., 1., 1.]) 59 | ax.axis('off') 60 | fig.add_axes(ax) 61 | ax.imshow(im) 62 | 63 | mask_color_id = 0 64 | if weights is None: 65 | n_objs = len(bboxes) 66 | obj_ids = range(n_objs) 67 | else: 68 | obj_ids = np.argsort(weights) 69 | 70 | ws = [0] 71 | for oid in obj_ids: 72 | x,y,w,h = bboxes[oid] 73 | mask = np.zeros([H, W]) 74 | mask[x:x+w, y:y+h] = 1 75 | mask = mask.astype('uint8') 76 | if mask.sum() == 0: 77 | continue 78 | 79 | if weights is not None: 80 | ws += weights[oid], 81 | color_mask = color_list[mask_color_id % len(color_list), 0:3] 82 | mask_color_id += 1 83 | 84 | w_ratio = .4 85 | for c in range(3): 86 | color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio 87 | 88 | e_down = mask 89 | 90 | e_pil = Image.fromarray(e_down) 91 | e_pil_up = e_pil.resize((H, W) if TRANS else (W, H),Image.ANTIALIAS) 92 | e = np.array(e_pil_up) 93 | 94 | _, contour, hier = cv2.findContours(e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) 95 | 96 | if len(contour) > 1: 97 | print('# contour:', len(contour)) 98 | for c in contour: 99 | if FLIP: 100 | assert(c.shape[1] == 1 and c.shape[2] == 2), print('c.shape:', c.shape) 101 | for pid in range(c.shape[0]): 102 | c[pid][0][0], c[pid][0][1] = c[pid][0][1], c[pid][0][0] 103 | linewidth = 1.2 104 | alpha = 0.5 105 | if oid == obj_ids[-1]: 106 | # most probable obj 107 | edgecolor=(1,0,0,1) # 'r' 108 | else: 109 | edgecolor=(1,1,1,1) # 'w' 110 | if weights is not None: 111 | linewidth *= (4 ** weights[oid]) 112 | alpha /= (4 ** weights[oid]) 113 | 114 | polygon = Polygon( 115 | c.reshape((-1, 2)), 116 | fill=True, facecolor=(color_mask[0], color_mask[1], color_mask[2], alpha), 117 | edgecolor=edgecolor, linewidth=linewidth, 118 | ) 119 | xy = polygon.get_xy() 120 | 121 | ax.add_patch(polygon) 122 | 123 | fig.savefig(fout.replace('.png', '_{:.3f}.png'.format(max(ws))), dpi=dpi) 124 | plt.close('all') 125 | 126 | 127 | if __name__ == '__main__': 128 | # tmp_wrapper() 129 | tmp_vis_one_image() 130 | 131 |
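# A minimal standalone sketch (illustrative, not part of the repo) of the
# weight-to-style mapping vis_one_image applies above: larger graph-attention
# weights yield bolder, more opaque outlines via the 4**w scaling.
def edge_style(weight, base_linewidth=1.2, base_alpha=0.5):
    # Mirrors `linewidth *= 4 ** w` and `alpha /= 4 ** w` in vis_one_image.
    return base_linewidth * (4 ** weight), base_alpha / (4 ** weight)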
-------------------------------------------------------------------------------- /models/backbone/imagenet_pretraining.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.model_zoo as model_zoo 3 | try: 4 | from models.backbone.resnet.resnet import model_urls 5 | except: 6 | from resnet.resnet import model_urls 7 | 8 | 9 | 10 | def _inflate_weight(w, new_temporal_size, inflation='center'): 11 | w_up = w.unsqueeze(2).repeat(1, 1, new_temporal_size, 1, 1) 12 | if inflation == 'center': 13 | w_up = central_inflate_3D_conv(w_up) # center 14 | elif inflation == 'mean': 15 | w_up /= new_temporal_size # mean 16 | return w_up 17 | 18 | 19 | def central_inflate_3D_conv(w): 20 | new_temporal_size = w.size(2) 21 | middle_timestep = int(new_temporal_size / 2.) 22 | before, after = list(range(middle_timestep)), list(range(middle_timestep + 1, new_temporal_size)) 23 | if len(before) > 0: 24 | w[:, :, before] = torch.zeros_like(w[:, :, before]) 25 | if len(after): 26 | w[:, :, after] = torch.zeros_like(w[:, :, after]) 27 | return w 28 | 29 | 30 | def inflate_temporal_conv(pretrained_W_updated, model_dict, inflation): 31 | for k, v in model_dict.items(): 32 | if '1t' in k: 33 | if 'conv' in k: 34 | if 'bias' in k: 35 | v_up = torch.zeros_like(v) 36 | elif 'weight' in k: 37 | v_up = torch.zeros_like(v) 38 | 39 | h, w, T, *_ = v_up.size() 40 | t_2 = int(T / 2.) 41 | for i in range(h): 42 | for j in range(w): 43 | if i == j: 44 | if inflation == 'center': 45 | v_up[i, j, t_2] = torch.ones_like(v_up[i, j, t_2]) 46 | elif inflation == 'mean': 47 | v_up[i, j] = torch.ones_like(v_up[i, j]) / T 48 | # elif 'bn' in k: 49 | # if 'running_mean' in k: 50 | # v_up = torch.zeros_like(v) 51 | # elif 'running_var' in k: 52 | # v_up = torch.ones_like(v) - 1e-05 53 | # elif 'bias' in k: 54 | # v_up = torch.zeros_like(v) 55 | # elif 'weight' in k: 56 | # v_up = torch.ones_like(v) 57 | 58 | # udpate 59 | pretrained_W_updated.update({k: v_up}) 60 | return pretrained_W_updated 61 | 62 | 63 | def _update_pretrained_weights(model, pretrained_W, inflation='center'): 64 | pretrained_W_updated = pretrained_W.copy() 65 | model_dict = model.state_dict() 66 | for k, v in pretrained_W.items(): 67 | if "conv" in k or ('bn' in k and '1t' not in k) or 'downsample' in k: 68 | if k in model_dict.keys(): 69 | if len(model_dict[k].shape) == 5: 70 | new_temporal_size = model_dict[k].size(2) 71 | v_updated = _inflate_weight(v, new_temporal_size, inflation) 72 | else: 73 | v_updated = v 74 | 75 | if isinstance(v, torch.autograd.Variable): 76 | pretrained_W_updated.update({k: v_updated.data}) 77 | else: 78 | pretrained_W_updated.update({k: v_updated}) 79 | if "fc.weight" in k: 80 | pretrained_W_updated.pop('fc.weight', None) 81 | if "fc.bias" in k: 82 | pretrained_W_updated.pop('fc.bias', None) 83 | 84 | # update the dict for 1D conv for 2.5D conv 85 | pretrained_W_updated = inflate_temporal_conv(pretrained_W_updated, model_dict, inflation) 86 | 87 | # update the state dict 88 | model_dict.update(pretrained_W_updated) 89 | 90 | return model_dict 91 | 92 | 93 | def _keep_only_existing_keys(model, pretrained_weights_inflated): 94 | # Loop over the model_dict and update W 95 | model_dict = model.state_dict() # Take the initial weights 96 | for k, v in model_dict.items(): 97 | if k in pretrained_weights_inflated.keys(): 98 | model_dict[k] = pretrained_weights_inflated[k] 99 | return model_dict 100 | 101 | 102 | def load_pretrained_2D_weights(arch, model, inflation): 103 | 
pretrained_weights = model_zoo.load_url(model_urls[arch]) 104 | pretrained_weights_inflated = _update_pretrained_weights(model, pretrained_weights, inflation) 105 | model.load_state_dict(pretrained_weights_inflated) 106 | print(" -> Init: Imagenet - 3D from 2D (inflation = {})".format(inflation)) 107 | return model 108 | -------------------------------------------------------------------------------- /utils/tracker.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import numpy as np 4 | import optparse 5 | import os 6 | import time 7 | 8 | 9 | def check_color(crossed): 10 | if crossed: 11 | return (0, 0, 255) 12 | return (0, 255, 0) 13 | 14 | 15 | def create_rect(box): 16 | x1, y1 = int(box['x1']), int(box['y1']) 17 | x2, y2 = int(box['x2']), int(box['y2']) 18 | 19 | return x1, y1, x2, y2 20 | 21 | 22 | def create_writer(capture): 23 | fourcc = cv2.VideoWriter_fourcc(*'X264') 24 | height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) 25 | width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) 26 | frame_count = int(capture.get(cv2.CAP_PROP_FPS)) 27 | 28 | writer = cv2.VideoWriter('output.mp4', 29 | fourcc, 30 | 2, 31 | (width, height)) 32 | 33 | return writer 34 | 35 | 36 | def get_params(frame_data): 37 | boxes = [f['box'] for f in frame_data] 38 | ids = [f['matchIds'] for f in frame_data] 39 | crossed = [f['crossed'] for f in frame_data] 40 | 41 | return boxes, ids, crossed 42 | 43 | 44 | def parse_options(): 45 | parser = optparse.OptionParser() 46 | parser.add_option('-v', '--video', 47 | dest='video_path', 48 | default=None) 49 | parser.add_option('-j', '--json', 50 | dest='json_path', 51 | default=None) 52 | parser.add_option('-u', '--until', 53 | type='int', 54 | default=None) 55 | parser.add_option('-w', '--write', 56 | dest='write', 57 | action='store_true', 58 | default=False) 59 | parser.add_option('-m', '--mask', 60 | dest='mask', 61 | action='store_true', 62 | default=False) 63 | options, remainder = parser.parse_args() 64 | 65 | # Check for errors. 66 | if options.video_path is None: 67 | raise Exception('Undefined video') 68 | if options.json_path is None: 69 | raise Exception('Undefined json_file') 70 | 71 | return options 72 | 73 | 74 | def Main(): 75 | options = parse_options() 76 | 77 | # Open VideoCapture. 78 | cap = cv2.VideoCapture(options.video_path) 79 | 80 | # Load json file with annotations. 81 | with open(options.json_path, 'r') as f: 82 | data = json.load(f)['frames'] 83 | 84 | if options.write: 85 | writer = create_writer(cap) 86 | 87 | font = cv2.FONT_HERSHEY_SIMPLEX 88 | frame_no = 1 89 | while True: 90 | 91 | wait_key = 25 92 | flag, img = cap.read() 93 | 94 | # Create black image. 95 | black_img = np.zeros(img.shape, dtype=np.uint8) 96 | 97 | if frame_no % 120 == 0: 98 | print('Processed {0} frames'.format(frame_no)) 99 | 100 | if frame_no % 6 != 0: 101 | frame_no += 1 102 | continue 103 | 104 | key = str(int(frame_no / 6 + 1)) 105 | 106 | boxes = data.get(key) 107 | 108 | if boxes == None: 109 | boxes = [] 110 | 111 | # Create list of trackers each 60 frames. 
112 | boxes, ids, crossed = get_params(boxes) 113 | 114 | for i, box in enumerate(boxes): 115 | x1, y1, x2, y2 = create_rect(box) 116 | 117 | if options.mask: 118 | roi = img[y1:y2, x1:x2].copy() 119 | black_img[y1:y2, x1:x2] = roi 120 | 121 | if not options.mask: 122 | crossed_color = check_color(crossed[i]) 123 | cv2.rectangle(img, (x1, y1), (x2, y2), crossed_color, 2, 1) 124 | cv2.putText(img, ids[i], (x1, y1 - 10), font, 0.6, 125 | (0, 0, 0), 5, cv2.LINE_AA) 126 | cv2.putText(img, ids[i], (x1, y1 - 10), font, 0.6, 127 | crossed_color, 1, cv2.LINE_AA) 128 | 129 | if options.write: 130 | if options.mask: 131 | writer.write(black_img) 132 | else: 133 | writer.write(img) 134 | else: 135 | if options.mask: 136 | cv2.imshow('frame', black_img) 137 | else: 138 | cv2.imshow('frame', img) 139 | 140 | if cv2.waitKey(wait_key) & 0xFF == ord('q'): 141 | break 142 | if frame_no == options.until: 143 | break 144 | 145 | if flag is False: 146 | break 147 | 148 | frame_no += 1 149 | 150 | cap.release() 151 | if options.write: 152 | writer.release() 153 | 154 | 155 | if __name__ == '__main__': 156 | Main() 157 | -------------------------------------------------------------------------------- /models/backbone/resnet/bottleneck.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Bottleneck(nn.Module): 5 | expansion = 4 6 | only_2D = False 7 | 8 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 9 | super(Bottleneck, self).__init__() 10 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 11 | self.bn1 = nn.BatchNorm3d(planes) 12 | self.conv2 = None 13 | self.bn2 = nn.BatchNorm3d(planes) 14 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 15 | self.bn3 = nn.BatchNorm3d(planes * 4) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.downsample = downsample 18 | self.stride = stride 19 | self.input_dim = 5 20 | self.dilation = dilation 21 | 22 | def forward(self, x): 23 | residual = x 24 | 25 | # print('Bottleneck devices:') 26 | # print('current device:', torch.cuda.current_device()) 27 | # print('x:', x.device) 28 | # print('self.conv1:', torch.cuda.device_of(self.conv1)) 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv3(out) 38 | out = self.bn3(out) 39 | 40 | if self.downsample is not None: 41 | residual = self.downsample(x) 42 | 43 | out += residual 44 | out = self.relu(out) 45 | 46 | return out 47 | 48 | 49 | class Bottleneck3D(Bottleneck): 50 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, **kwargs): 51 | super().__init__(inplanes, planes, stride, downsample, dilation) 52 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False, dilation=(1, dilation, dilation)) 54 | 55 | 56 | class Bottleneck2D(Bottleneck): 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, **kwargs): 58 | super().__init__(inplanes, planes, stride, downsample, dilation) 59 | # to speed up the inference process 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = nn.BatchNorm2d(planes) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False, dilation=dilation) 64 | self.bn3 = nn.BatchNorm2d(planes * 4) 65 | self.input_dim = 4 66 | 67 | if isinstance(stride, 
int): 68 | stride_1, stride_2 = stride, stride 69 | else: 70 | stride_1, stride_2 = stride[0], stride[1] 71 | 72 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=(3, 3), stride=(stride_1, stride_2), 73 | padding=(1, 1), bias=False) 74 | 75 | 76 | class Bottleneck2_1D(Bottleneck): 77 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, nb_temporal_conv=1): 78 | super().__init__(inplanes, planes, stride, downsample, dilation) 79 | 80 | if isinstance(stride, int): 81 | stride_2d, stride_1t = (1, stride, stride), (stride, 1, 1) 82 | else: 83 | stride_2d, stride_1t = (1, stride[1], stride[2]), (stride[0], 1, 1) 84 | 85 | # CONV2 86 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), stride=stride_2d, 87 | padding=(0, dilation, dilation), bias=False, dilation=dilation) 88 | 89 | self.conv2_1t = nn.Sequential() 90 | for i in range(nb_temporal_conv): 91 | temp_conv = nn.Conv3d(planes, planes, kernel_size=(3, 1, 1), stride=stride_1t, 92 | padding=(1, 0, 0), bias=False, dilation=1) 93 | self.conv2_1t.add_module('temp_conv_{}'.format(i), temp_conv) 94 | self.conv2_1t.add_module(('relu_{}').format(i), nn.ReLU(inplace=True)) 95 | 96 | 97 | def forward(self, x): 98 | residual = x 99 | 100 | ## CONV1 - 3D (1,1,1) 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | ## CONV2 106 | # Spatial - 2D (1,3,3) 107 | out = self.conv2(out) 108 | out = self.bn2(out) 109 | out = self.relu(out) 110 | 111 | # Temporal - 3D (3,1,1) 112 | out = self.conv2_1t(out) 113 | 114 | ## CONV3 - 3D (1,1,1) 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | residual = self.downsample(x) 120 | 121 | out += residual 122 | out = self.relu(out) 123 | 124 | return out 125 | -------------------------------------------------------------------------------- /utils/masking.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import optparse 4 | import os 5 | import time 6 | import numpy as np 7 | import IPython 8 | 9 | MACHINE = 'TRI' 10 | # MACHINE = 'Stanford' 11 | 12 | if MACHINE == 'TRI': 13 | # downsampled from:'/mnt/parallel/stip/ANN_hanh2/20170907_prolix_trial_ANN_hanh2-09-07-2017_15-44-07_idx00.mkv' 14 | test_video_in = '/mnt/parallel/stip/ANN_hanh2/ANN_hanh2_downsample_fps12.mkv' 15 | test_video_out = 'ANN_hanh2_fps12_masked.mp4' 16 | test_annot = '/mnt/parallel/stip/annotation/20170907_prolix_trial_ANN_hanh2-09-07-2017_15-44-07.concat.12fps.mp4.json' 17 | else: 18 | test_video_in = '/sailhome/bingbin/STR-PIP/datasets/STIP/ANN_hanh2_downsampled.mkv' 19 | test_video_out = 'ANN_hanh2_fps12_masked.mp4' 20 | test_annot = '/sailhome/bingbin/STR-PIP/datasets/STIP/annotations/20170907_prolix_trial_ANN_hanh2-09-07-2017_15-44-07.concat.12fps.mp4.json' 21 | 22 | 23 | def check_color(crossed): 24 | if crossed: 25 | return (0, 0, 255) 26 | return (0, 255, 0) 27 | 28 | 29 | def create_rect(box): 30 | x1, y1 = int(box['x1']), int(box['y1']) 31 | x2, y2 = int(box['x2']), int(box['y2']) 32 | 33 | return x1, y1, x2, y2 34 | 35 | 36 | def create_writer(capture, options): 37 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 38 | height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)) 40 | frame_count = int(capture.get(cv2.CAP_PROP_FPS)) 41 | 42 | print('create_writer:', options.saved_video_path) 43 | writer = cv2.VideoWriter(options.saved_video_path, 44 | fourcc, 45 | 2, 46 | (width, height)) 47 | 48 | return 
writer 49 | 50 | 51 | def get_params(frame_data): 52 | boxes = [f['box'] for f in frame_data] 53 | ids = [f['matchIds'] for f in frame_data] 54 | crossed = [f['crossed'] for f in frame_data] 55 | 56 | return boxes, ids, crossed 57 | 58 | 59 | def parse_options(): 60 | parser = optparse.OptionParser() 61 | parser.add_option('-v', '--video', 62 | dest='video_path', 63 | default=test_video_in) 64 | parser.add_option('-j', '--json', 65 | dest='json_path', 66 | default=test_annot) 67 | parser.add_option('-u', '--until', 68 | type='int', 69 | default=None) 70 | parser.add_option('-w', '--write_video', 71 | dest='saved_video_path', 72 | default=test_video_out) 73 | options, remainder = parser.parse_args() 74 | 75 | # Check for errors. 76 | if options.video_path is None: 77 | raise Exception('Undefined video') 78 | if options.json_path is None: 79 | raise Exception('Undefined json_file') 80 | 81 | return options 82 | 83 | 84 | def Main(): 85 | options = parse_options() 86 | print('here') 87 | # Open VideoCapture. 88 | cap = cv2.VideoCapture(options.video_path) 89 | 90 | # Load json file with annotations. 91 | with open(options.json_path, 'r') as f: 92 | data = json.load(f)['frames'] 93 | 94 | lastKeyFrame = int(list(data.keys())[-1]) 95 | 96 | writer = create_writer(cap, options) 97 | 98 | font = cv2.FONT_HERSHEY_SIMPLEX 99 | frame_no = 1 100 | while True: 101 | wait_key = 25 102 | flag, img = cap.read() 103 | if frame_no % 120 == 0: 104 | print('Processed {0} frames'.format(frame_no)) 105 | 106 | if frame_no % 6 != 0: 107 | frame_no += 1 108 | continue 109 | 110 | key = str(int(frame_no / 6 + 1)) 111 | 112 | boxes = data.get(key) 113 | 114 | if boxes == None: 115 | boxes = [] 116 | 117 | # Create list of trackers each 60 frames. 118 | boxes, ids, crossed = get_params(boxes) 119 | mask = np.zeros(img.shape) 120 | for i, box in enumerate(boxes): 121 | x1, y1, x2, y2 = create_rect(box) 122 | if ids[i] == '4' or ids[i] == '16': 123 | print(frame_no, key, box) 124 | print((x1, y1, x2, y2)) 125 | 126 | mask[y1:y2,x1:x2, :] = 1 127 | # crossed_color = check_color(crossed[i]) 128 | # cv2.rectangle(img, (x1, y1), (x2, y2), crossed_color, 2, 1) 129 | if '4' in ids or '16' in ids: 130 | wait_key = 0 131 | 132 | 133 | # print('writing') 134 | writer.write(np.uint8(img*mask)) 135 | 136 | if frame_no == options.until: 137 | break 138 | 139 | if frame_no > lastKeyFrame: 140 | break 141 | 142 | frame_no += 1 143 | 144 | cap.release() 145 | writer.release() 146 | 147 | 148 | if __name__ == '__main__': 149 | Main() 150 | -------------------------------------------------------------------------------- /cache_data.py: -------------------------------------------------------------------------------- 1 | # python imports 2 | import os 3 | import sys 4 | from glob import glob 5 | import pickle 6 | import time 7 | import numpy as np 8 | 9 | 10 | # local imports 11 | import data 12 | import utils 13 | from utils.data_proc import parse_objs 14 | 15 | import pdb 16 | 17 | 18 | def cache_masks(): 19 | opt, logger = utils.build(is_train=False) 20 | opt.combine_method = '' 21 | opt.split = 'train' 22 | cache_dir_name = 'jaad_collapse{}'.format('_'+opt.combine_method if opt.combine_method else '') 23 | data.cache_all_objs(opt, cache_dir_name) 24 | 25 | 26 | def cache_crops(): 27 | fnpy_root = '/sailhome/bingbin/STR-PIP/datasets/JAAD_instance_segm' 28 | fpkl_root = '/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_instance_crops' 29 | utils.get_obj_crops(fnpy_root, fpkl_root) 30 | 31 | def add_obj_bbox(): 32 | fnpy_root = 
'/sailhome/bingbin/STR-PIP/datasets/JAAD_instance_segm' 33 | # fpkl_root = '/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse' 34 | fobj_root = '/sailhome/bingbin/STR-PIP/datasets/cache/obj_bbox' 35 | os.makedirs(fobj_root, exist_ok=True) 36 | dir_vids = sorted(glob(os.path.join(fnpy_root, 'vid*'))) 37 | 38 | def helper(vid_range, split): 39 | for dir_vid in vid_range: 40 | print(dir_vid) 41 | sys.stdout.flush() 42 | vid = int(os.path.basename(dir_vid).split('_')[1]) 43 | t_start = time.time() 44 | fsegms = sorted(glob(os.path.join(dir_vid, '*_segm.npy'))) 45 | for i, fsegm in enumerate(fsegms): 46 | if i and i%100 == 0: 47 | print('Time per frame:', (time.time() - t_start)/i) 48 | sys.stdout.flush() 49 | # Note: 'fid' is 0-based for segm, but 1-based for images and caches. 50 | fid = os.path.basename(fsegm).split('_')[0] 51 | fbbox = os.path.join(fobj_root, 'vid{:08d}_fid{:s}.pkl'.format(vid, fid)) 52 | if os.path.exists(fbbox): 53 | continue 54 | if not os.path.exists(fsegm): 55 | print('File does not exist:', fsegm) 56 | continue 57 | objs = parse_objs(fsegm) 58 | dobjs = {cls:[] for cls in range(1,5)} 59 | for cls, masks in objs.items(): 60 | for mask in masks: 61 | try: 62 | if len(mask.shape) == 3: 63 | h, w, c = mask.shape 64 | if c != 1: 65 | raise ValueError('Each mask should have shape (1080, 1920, 1)') 66 | mask = mask.reshape(h, w) 67 | x_pos = mask.sum(0).nonzero()[0] 68 | if not len(x_pos): 69 | x_pos = [0,0] 70 | x_min, x_max = x_pos[0], x_pos[-1] 71 | y_pos = mask.sum(1).nonzero()[0] 72 | if not len(y_pos): 73 | y_pos = [0,0] 74 | y_min, y_max = y_pos[0], y_pos[-1] 75 | # bbox: [x_min, y_min, w, h]; same as bbox for ped['pos_GT'] 76 | bbox = [x_min, y_min, x_max-x_min, y_max-y_min] 77 | except Exception as e: 78 | print(e) 79 | pdb.set_trace() 80 | 81 | dobjs[cls] += bbox, 82 | with open(fbbox, 'wb') as handle: 83 | pickle.dump(dobjs, handle) 84 | 85 | if False: 86 | vids_train = dir_vids[:250] 87 | helper(vids_train, 'train') 88 | if True: 89 | vids_test = dir_vids[250:] 90 | helper(vids_test, 'test') 91 | 92 | def merge_and_flat(vrange): 93 | """ 94 | Merge fids in a vid and flatten the classes 95 | """ 96 | pkl_in_root = '/sailhome/bingbin/STR-PIP/datasets/cache/obj_bbox' 97 | pkl_out_root = '/sailhome/bingbin/STR-PIP/datasets/cache/obj_bbox_merged' 98 | os.makedirs(pkl_out_root, exist_ok=True) 99 | # for vid in range(1, 347): 100 | for vid in vrange: 101 | fpkls = sorted(glob(os.path.join(pkl_in_root, 'vid{:08d}*pkl'.format(vid)))) 102 | print(vid, len(fpkls)) 103 | sys.stdout.flush() 104 | # merged = [[] for _ in range(len(fpkls))] 105 | merged_bbox = [] 106 | merged_cls = [] 107 | t_start = time.time() 108 | for fpkl in fpkls: 109 | with open(fpkl, 'rb') as handle: 110 | data = pickle.load(handle) 111 | curr_bbox = [] 112 | cls = [] 113 | for c in [1,2,3,4]: 114 | for bbox in data[c]: 115 | cls += c, 116 | curr_bbox += bbox, 117 | merged_bbox += np.array(curr_bbox), 118 | merged_cls += np.array(cls), 119 | 120 | fpkl_out = os.path.join(pkl_out_root, 'vid{:08d}.pkl'.format(vid)) 121 | with open(fpkl_out, 'wb') as handle: 122 | dout = { 123 | 'obj_cls': merged_cls, 124 | 'obj_bbox': merged_bbox, 125 | } 126 | pickle.dump(dout, handle) 127 | print('avg time: ', (time.time()-t_start) / len(fpkls)) 128 | 129 | if __name__ == '__main__': 130 | # cache_masks() 131 | cache_crops() 132 | add_obj_bbox() 133 | 134 | # merge_and_flat(range(1, 200)) 135 | # merge_and_flat(range(100, 200)) 136 | # merge_and_flat(range(200, 300)) 137 | # merge_and_flat(range(100, 347)) 138 | 
merge_and_flat(range(1, 347)) 139 | 140 | -------------------------------------------------------------------------------- /models/backbone/resnet/basicblock.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 4 | "3x3 convolution with padding" 5 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 6 | padding=1, bias=False, dilation=dilation) 7 | 8 | 9 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False, dilation=(1, dilation, dilation)) 13 | 14 | 15 | def conv1x3x3(in_planes, out_planes, stride=1, dilation=1): 16 | "3x3 convolution with padding" 17 | if isinstance(stride, int): 18 | stride_1, stride_2, stride_3 = 1, stride, stride 19 | else: 20 | stride_1, stride_2, stride_3 = 1, stride[1], stride[2] 21 | 22 | return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), 23 | stride=(stride_1, stride_2, stride_3), 24 | padding=(0, 1, 1), bias=False, dilation=(1, dilation, dilation)) 25 | 26 | 27 | def conv1x3x3_conv3x1x1(in_planes, out_planes, stride=1, dilation=1, nb_temporal_conv=3): 28 | "3x3 convolution with padding" 29 | if isinstance(stride, int): 30 | stride_2d, stride_1t = (1, stride, stride), (stride, 1, 1) 31 | else: 32 | stride_2d, stride_1t = (1, stride[1], stride[2]), (stride[0], 1, 1) 33 | 34 | _2d = nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride_2d, 35 | padding=(0, 1, 1), bias=False, dilation=dilation) 36 | 37 | _1t = nn.Sequential() 38 | for i in range(nb_temporal_conv): 39 | temp_conv = nn.Conv3d(out_planes, out_planes, kernel_size=(3, 1, 1), stride=stride_1t, 40 | padding=(1, 0, 0), bias=False, dilation=1) 41 | _1t.add_module('temp_conv_{}'.format(i), temp_conv) 42 | _1t.add_module(('relu_{}').format(i), nn.ReLU(inplace=True)) 43 | 44 | return _2d, _1t 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | only_2D = False 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 52 | super(BasicBlock, self).__init__() 53 | self.bn1 = nn.BatchNorm3d(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.bn2 = nn.BatchNorm3d(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | self.conv1, self.conv2 = None, None 59 | self.input_dim = 5 60 | self.dilation = dilation 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | out = self.relu(out) 77 | 78 | return out 79 | 80 | 81 | class BasicBlock3D(BasicBlock): 82 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, **kwargs): 83 | super().__init__(inplanes, planes, stride, downsample, dilation) 84 | self.conv1 = conv3x3x3(inplanes, planes, stride, dilation) 85 | self.conv2 = conv3x3x3(planes, planes, dilation) 86 | 87 | 88 | class BasicBlock2D(BasicBlock): 89 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, **kwargs): 90 | super().__init__(inplanes, planes, stride, downsample, dilation) 91 | # not the same input size here to speed up training 92 | self.bn1 = nn.BatchNorm2d(planes) 93 | self.bn2 = nn.BatchNorm2d(planes) 94 | self.conv1 = conv3x3(inplanes, planes, 
stride, dilation) 95 | self.conv2 = conv3x3(planes, planes, dilation) 96 | self.input_dim = 4 97 | 98 | 99 | class BasicBlock2_1D(BasicBlock): 100 | expansion = 1 101 | 102 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, nb_temporal_conv=1): 103 | super().__init__(inplanes, planes, stride, downsample, dilation) 104 | self.conv1, self.conv1_1t = conv1x3x3_conv3x1x1(inplanes, planes, stride, dilation, 105 | nb_temporal_conv=nb_temporal_conv) 106 | self.conv2, self.conv2_1t = conv1x3x3_conv3x1x1(planes, planes, dilation, 107 | nb_temporal_conv=nb_temporal_conv) 108 | 109 | 110 | def forward(self, x): 111 | residual = x 112 | 113 | out = self.conv1(x) # 2D in space 114 | out = self.bn1(out) 115 | out = self.relu(out) 116 | 117 | # ipdb.set_trace() 118 | out = self.conv1_1t(out) # 1D in time + relu after each conv 119 | 120 | out = self.conv2(out) # 2D in space 121 | out = self.bn2(out) 122 | 123 | if self.downsample is not None: 124 | residual = self.downsample(x) 125 | 126 | out += residual 127 | out = self.relu(out) 128 | 129 | out = self.conv2_1t(out) # 1D in time 130 | 131 | return out 132 | -------------------------------------------------------------------------------- /utils/postproc_smooth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | 5 | import pdb 6 | 7 | def smooth_vote(fpkl, win_size=5): 8 | with open(fpkl, 'rb') as handle: 9 | data = pickle.load(handle) 10 | out = data['out'] 11 | gt = data['GT'] 12 | acc = (out == gt).mean() 13 | 14 | smooth = np.zeros_like(out) 15 | win_half = win_size // 2 16 | smooth[:, :win_half] = out[:, :win_half] 17 | 18 | z2o, o2z = 0, 0 19 | for i,row in enumerate(out): 20 | for j in range(win_half, len(row)): 21 | # vote within the window 22 | smooth[i, j] = 1 if row[j-win_half:j+win_half].mean() >= 0.5 else 0 23 | if smooth[i,j] != out[i,j]: 24 | if out[i,j] == 0: 25 | z2o += 1 26 | else: 27 | o2z += 1 28 | acc_smooth = (smooth == gt).mean() 29 | 30 | print('Acc:', acc) 31 | print('Acc smooth:', acc_smooth) 32 | 33 | print('o2z:', o2z) 34 | print('z2o:', z2o) 35 | 36 | def smooth_flip(fpkl): 37 | with open(fpkl, 'rb') as handle: 38 | data = pickle.load(handle) 39 | out = data['out'] 40 | gt = data['GT'] 41 | acc = (out == gt).mean() 42 | 43 | smooth = np.zeros_like(out) 44 | smooth[:, :1] = out[:, :1] 45 | 46 | z2o, o2z = 0, 0 47 | for i,row in enumerate(out): 48 | for j in range(1, len(row)-1): 49 | # check with the neighbors 50 | if row[j] != row[j-1] and row[j] != row[j+1]: 51 | smooth[i,j] = row[j-1] 52 | if row[j-1] == 0: 53 | o2z += 1 54 | else: 55 | z2o += 1 56 | else: 57 | smooth[i,j] = row[j] 58 | acc_smooth = (smooth == gt).mean() 59 | 60 | print('Acc:', acc) 61 | print('Acc smooth:', acc_smooth) 62 | 63 | print('o2z:', o2z) 64 | print('z2o:', z2o) 65 | 66 | def smooth_sticky(fpkl, win_size=5): 67 | with open(fpkl, 'rb') as handle: 68 | data = pickle.load(handle) 69 | out = data['out'] 70 | gt = data['GT'] 71 | acc = (out == gt).mean() 72 | 73 | smooth = np.zeros_like(out) 74 | smooth[:, :win_size] = out[:, :win_size] 75 | 76 | for i,row in enumerate(out): 77 | for j in range(win_size, len(row)): 78 | if smooth[i, j-win_size:j].mean() > 0.5: 79 | smooth[i, j] = 1 80 | else: 81 | smooth[i,j] = row[j] 82 | acc_smooth = (smooth == gt).mean() 83 | 84 | print('Acc:', acc) 85 | print('Acc smooth:', acc_smooth) 86 | 87 | o2z = ((smooth!=out) * (out==1)).sum() 88 | z2o = ((smooth!=out) * (out==0)).sum() 89 | print('o2z:', 
o2z) 90 | print('z2o:', z2o) 91 | 92 | print('GT o:', gt.mean()) 93 | print('GT nelem:', gt.size) 94 | 95 | def smooth_hold(fpkl): 96 | # Stay at 1 once the prediction first reaches 1. 97 | with open(fpkl, 'rb') as handle: 98 | data = pickle.load(handle) 99 | out = data['out'][:, :-1] 100 | gt = data['GT'][:, :-1] 101 | print('Acc det:', (out[:, :30] == gt[:, :30]).mean()) 102 | 103 | out = out[:, 30:] 104 | gt = gt[:, 30:] 105 | acc = (out == gt).mean() 106 | 107 | # pdb.set_trace() 108 | 109 | smooth = out.copy() 110 | for i,row in enumerate(smooth): 111 | # pdb.set_trace() 112 | pos = np.where(row == 1)[0] 113 | if len(pos): 114 | first = pos[0] 115 | row[first:] = 1 116 | acc_smooth = (smooth == gt).mean() 117 | 118 | print('Acc pred:', acc) 119 | print('Acc pred smooth:', acc_smooth) 120 | print('Acc last:', (smooth[:, -1] == gt[:, -1]).mean()) 121 | 122 | def smooth_hold_lastObserve(fpkl): 123 | # Predict 1 for the whole horizon when the last observed frame is 1. 124 | with open(fpkl, 'rb') as handle: 125 | data = pickle.load(handle) 126 | out = data['out'][:, :-1] 127 | gt = data['GT'][:, :-1] 128 | print('Acc det:', (out[:, :30] == gt[:, :30]).mean()) 129 | 130 | last_observe = out[:, 29] 131 | print('prob of last==1:', last_observe.mean()) 132 | out = out[:, 30:] 133 | gt = gt[:, 30:] 134 | acc = (out == gt).mean() 135 | 136 | # pdb.set_trace() 137 | 138 | smooth = out.copy() 139 | for i,row in enumerate(smooth): 140 | if last_observe[i] == 1: 141 | row[:] = 1 142 | acc_smooth = (smooth == gt).mean() 143 | 144 | print('Acc pred:', acc) 145 | print('Acc pred smooth:', acc_smooth) 146 | print('Acc last:', (smooth[:, -1] == gt[:, -1]).mean()) 147 | 148 | 149 | def smooth_wrapper(): 150 | # ped-centric model 151 | ckpt_dir = '/sailhome/bingbin/STR-PIP/ckpts/JAAD/' 152 | # ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-05_wd1.0e-05_bt16_posNone_branchped_collapse0_combinepair_adjTypeembed_nLayers2_v2Feats' 153 | # label_name = 'label_epochbest_det_stepall.pkl' 154 | ckpt_name = 'graph_gru_seq30_pred30_lr3.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch' 155 | label_name = 'label_epochbest_pred_stepall.pkl' 156 | 157 | fpkl = os.path.join(ckpt_dir, ckpt_name, label_name) 158 | 159 | win_size = 3 160 | print('Vote (win_size={}):'.format(win_size)) 161 | smooth_vote(fpkl, win_size=win_size) 162 | 163 | print('\nFlip:') 164 | smooth_flip(fpkl) 165 | 166 | print('\nSticky:') 167 | smooth_sticky(fpkl) 168 | 169 | print('\nHold:') 170 | smooth_hold(fpkl) 171 | 172 | print('\nHold last observed:') 173 | smooth_hold_lastObserve(fpkl) 174 | 175 | if __name__ == '__main__': 176 | smooth_wrapper() 177 |
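To make the "hold" heuristics above concrete, here is a tiny self-contained sketch (toy data, not from the repo) of what smooth_hold() does to each row of binary crossing predictions: once a 1 first appears, every later frame is forced to 1.

```python
import numpy as np

out = np.array([[0, 0, 1, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 0]])  # toy predictions, one row per pedestrian
smooth = out.copy()
for row in smooth:
    pos = np.where(row == 1)[0]
    if len(pos):
        row[pos[0]:] = 1  # hold the 'crossing' label from its first occurrence on
print(smooth)
# [[0 0 1 1 1 1 1]
#  [0 0 0 0 0 0 0]]
```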
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spatiotemporal Relationship Reasoning for Pedestrian Intent Prediction (STR-PIP) 2 | 3 | **_Paper: https://arxiv.org/pdf/2002.08945.pdf_** 4 | 5 | ``` 6 | @inproceedings{liu2020spatiotemporal, 7 |    title={Spatiotemporal Relationship Reasoning for Pedestrian Intent Prediction}, 8 |    author={Bingbin Liu and Ehsan Adeli and Zhangjie Cao and Kuan-Hui Lee and Abhijeet Shenoi and Adrien Gaidon and Juan Carlos Niebles}, 9 |    year={2020}, 10 |    booktitle={IEEE Robotics and Automation Letters (IEEE RA-L) and International Conference on Robotics and Automation (ICRA)}, 11 |    publisher={IEEE} 12 | } 13 | ``` 14 | 15 | ## Abstract 16 | 17 | 18 | Reasoning over visual data is a desirable capability for robotics and vision-based applications. Such reasoning enables forecasting the next events or actions in videos. In recent years, various models have been developed based on convolution operations for prediction or forecasting, but they lack the ability to reason over spatiotemporal data and infer the relationships of different objects in the scene. In this paper, we present a framework based on graph convolution to uncover the spatiotemporal relationships in the scene for reasoning about pedestrian intent. A scene graph is built on top of segmented object instances within and across video frames. Pedestrian intent, defined as the future action of crossing or not-crossing the street, is a very crucial piece of information for autonomous vehicles to navigate safely and more smoothly. We approach the problem of intent prediction from two different perspectives and anticipate the intention-to-cross within both pedestrian-centric and location-centric scenarios. In addition, we introduce a new dataset designed specifically for autonomous-driving scenarios in areas with dense pedestrian populations: the Stanford-TRI Intent Prediction (STIP) dataset. Our experiments on STIP and another benchmark dataset show that our graph modeling framework is able to predict the intention-to-cross of the pedestrians with an accuracy of 79.10% on STIP and 79.28% on the Joint Attention for Autonomous Driving (JAAD) dataset up to one second earlier than when the actual crossing happens. These results outperform baselines and previous work. Please refer to [https://stip.stanford.edu](https://stip.stanford.edu) for the dataset and code. 19 | 20 | 21 | 22 | ## Datasets 23 | ### Stanford-TRI Intention Prediction (STIP) Dataset: 24 | STIP includes over 900 hours of driving scene videos recorded from front, right, and left cameras while the vehicle was driving in dense areas of five cities in the United States. The videos were annotated at 2 fps with pedestrian bounding boxes and labels of crossing/not-crossing the street, which are respectively shown with green/red boxes in the above videos. We used the [JRMOT (JackRabbot real-time Multi-Object Tracker) platform](https://sites.google.com/view/jrmot) to track the pedestrians and interpolate the annotations to the full 20 frames per second. 25 | 26 | **Dataset Code:** https://github.com/StanfordVL/STIP 27 | 28 | **Dataset Information:** https://stip.stanford.edu/dataset.html 29 | 30 | **Request Access to Dataset:** [here](https://docs.google.com/forms/d/e/1FAIpQLSdG5CLJQs7QWY27uIkZj27O4XDm0-OsZVEmBRiHB8EaCoNZXA/viewform) 31 | 32 | ### Joint Attention in Autonomous Driving (JAAD) Dataset: 33 | JAAD is a dataset for studying joint attention in the context of autonomous driving. The focus is on pedestrian and driver behaviors at the point of crossing and the factors that influence them. To this end, the JAAD dataset provides a richly annotated collection of 346 short video clips (5-10 sec long) extracted from over 240 hours of driving footage. Bounding boxes with occlusion tags are provided for all pedestrians, making this dataset suitable for pedestrian detection. 34 | Behavior annotations specify behaviors for pedestrians that interact with or require the attention of the driver. For each video there are several tags (weather, locations, etc.) and timestamped behavior labels from a fixed list (e.g. stopped, walking, looking, etc.). In addition, a list of demographic attributes is provided for each pedestrian (e.g. age, gender, direction of motion, etc.)
as well as a list of visible traffic scene elements (e.g. stop sign, traffic signal, etc.) for each frame. 35 | 36 | **Annotation Data:** https://github.com/ykotseruba/JAAD 37 | 38 | **Full Dataset:** http://data.nvision2.eecs.yorku.ca/JAAD_dataset 39 | 40 | ## Installation 41 | 42 | ### Notes Before Running 43 | 44 | This training code is set up to run on a 16 GB GPU. You may have to make some adjustments if you do not have this hardware available. 45 | 46 | ### Create your Virtual Environment 47 | ``` 48 | conda create --name crossing python=3.7 49 | conda activate crossing 50 | ``` 51 | 52 | ### Install Required Libraries 53 | ``` 54 | pip install torchvision==0.5.0 55 | pip install opencv-python 56 | pip install pycocotools 57 | # pickle ships with the Python standard library; no install is needed 58 | pip install numpy 59 | pip install wandb 60 | ``` 61 | 62 | ### Configure Environment Variables 63 | ``` 64 | export PYTHONPATH=$PYTHONPATH:[path-to-repo] 65 | ``` 66 | 67 | ### Clone this repo 68 | ``` 69 | git clone https://github.com/StanfordVL/STR-PIP.git 70 | cd STR-PIP 71 | ``` 72 | 73 | ## Usage 74 | To train the concatenation model, run "scripts_dir/train_concat.sh". 75 | 76 | In order to generate "annot_test_ped_withTag_sanityNoPose.pkl", you will need to run utils/data_proc.py; specifically, the function prepare_data() inside data_proc.py, as sketched below.
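A minimal sketch of that step (assumption: prepare_data() takes no required arguments — check utils/data_proc.py before running):
```
# run from the repo root, with PYTHONPATH set as in the Installation section
from utils.data_proc import prepare_data

prepare_data()
```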
77 | 78 | 79 | ## Current code status (3-14-21) 80 | 81 | * With the correct pre-processing, STIP training/testing works. 82 | * We are still producing the specific pre-processed files needed to train on JAAD, and are very close to generating the correct ones. 83 | * This code base is in the process of clean-up, documentation, and refactoring. There may be sudden changes. We will publish a new release when it is in a more complete state. 84 | 85 | -------------------------------------------------------------------------------- /scripts_dir/test_graph.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mode='evaluate' 4 | split='test' 5 | 6 | gpu_id=1 7 | n_workers=0 8 | n_acts=9 9 | bt=1 10 | 11 | seq_len=30 12 | predict=1 13 | pred_seq_len=30 14 | predict_k=0 15 | 16 | # test only 17 | slide=0 18 | rand_test=1 19 | log_every=10 20 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts' 21 | dataset='JAAD' 22 | 23 | # pred 30 24 | # ckpt_name='graph_gru_seq30_pred30_lr1.0e-05_wd1.0e-05_bt16_posNone_branchped_collapse0_combinepair_adjTypeembed_nLayers2_v2Feats' 25 | # ckpt_name='graph_gru_seq30_pred30_lr3.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch' 26 | # ckpt_name='graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW1_v4Feats_pedGRU_3evalEpoch' 27 | # ckpt_name='graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch' 28 | ckpt_name='graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_newCtxtGRU_3evalEpoch' # best 29 | 30 | # pred 60 31 | # ckpt_name='graph_gru_seq30_pred60_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch' 32 | 33 | # pred 90 34 | # ckpt_name='graph_gru_seq30_pred90_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch' 35 | 36 | which_epoch=-1 37 | if [ $which_epoch -eq -1 ] 38 | then 39 | epoch_name='best_pred' 40 | else 41 | epoch_name=$which_epoch 42 | fi 43 | save_output=1 44 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl' 45 | collect_A=1 46 | save_As_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/test_graph_weights_epoch'$epoch_name'/vid{}_eval{}.pkl' 47 | 48 | 49 | 50 | load_cache='feats' 51 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-04_wd1.0e-05_bt4_ped_collapse0_combinepair_useBBox0_cacheMasks_fixGRU_singleTime/{}/ped{}_fid{}.pkl' 52 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU/{}/ped{}_fid{}.pkl' 53 | 54 | 55 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/{}/ped{}_fid{}.pkl' 56 | 57 | # ckpt_name='graph_gru_lr1.0e-05_wd1.0e-05_bt16_ped_collapse0_combinepair_adjTypeembed_nLayers0_useBBox0_fixGRU' 58 | # which_epoch=50 59 | # 60 | # ckpt_name='graph_gru_lr1.0e-05_wd1.0e-05_bt16_ped_collapse0_combinepair_adjTypeembed_nLayers2_useBBox0_fixGRU' 61 | # which_epoch=22 62 | # 63 | # ckpt_name='graph_gru_lr1.0e-05_wd1.0e-05_bt16_both_collapse0_combinepair_adjTypeembed_nLayers2_useBBox0_fixGRU' 64 | # which_epoch=38 65 | # 66 | # ckpt_name='graph_gru_lr1.0e-05_wd1.0e-05_bt16_ped_collapse0_combinepair_adjTypeuniform_nLayers0_useBBox0_fixGRU' 67 | # which_epoch=42 68 | 69 | if [ "$mode" = "extract" ] 70 | then 71 | extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/test/' 72 | else 73 | extract_feats_dir='none_existent' 74 | fi 75 | 76 | use_gru=1 77 | use_trn=0 78 | ped_gru=1 79 | ctxt_gru=0 80 | ctxt_node=0 81 | 82 | # features 83 | use_act=0 84 | use_gt_act=0 85 | use_driver=0 86 | use_pose=0 87 | pos_mode='none' 88 | # branch='ped' 89 | branch='both' 90 | adj_type='spatial' 91 | use_obj_cls=0 92 | n_layers=2 93 | diff_layer_weight=0 94 | collapse_cls=0 95 | combine_method='pair' 96 | 97 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \ 98 | --model='graph' \ 99 | --split=$split \ 100 | --mode=$mode \ 101 | --slide=$slide \ 102 | --rand-test=$rand_test \ 103 | --ckpt-dir=$ckpt_dir \ 104 | --ckpt-name=$ckpt_name \ 105 | --n-acts=$n_acts \ 106 | --device=$gpu_id \ 107 | --dset-name=$dataset \ 108 | --which-epoch=$which_epoch \ 109 | --save-output=$save_output \ 110 | --save-output-format=$save_output_format \ 111 | --collect-A=$collect_A \ 112 | --save-As-format=$save_As_format \ 113 | --seq-len=$seq_len \ 114 | --predict=$predict \ 115 | --pred-seq-len=$pred_seq_len \ 116 | --predict-k=$predict_k \ 117 | --batch-size=$bt \ 118 | --log-every=$log_every \ 119 | --n-workers=$n_workers \ 120 | --load-cache=$load_cache \ 121 | --cache-format=$cache_format \ 122 | --branch=$branch \ 123 | --adj-type=$adj_type \ 124 | --use-obj-cls=$use_obj_cls \ 125 | --n-layers=$n_layers \ 126 | --diff-layer-weight=$diff_layer_weight \ 127 | --collapse-cls=$collapse_cls \ 128 | --combine-method=$combine_method \ 129 | --use-gru=$use_gru \ 130 | --use-trn=$use_trn \ 131 | --ped-gru=$ped_gru \ 132 | --ctxt-gru=$ctxt_gru \ 133 | --ctxt-node=$ctxt_node \ 134 | --use-act=$use_act \ 135 | --use-gt-act=$use_gt_act \ 136 | --use-driver=$use_driver \ 137 | --use-pose=$use_pose \ 138 | --pos-mode=$pos_mode 139 | 140 | 141 | 142 | 143 | exit 144 | 145 | # bak of prev command 146 | 147 | 
CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \ 148 | --model='graph' \ 149 | --split=$split \ 150 | --mode=$mode \ 151 | --device=$gpu_id \ 152 | --log-every=$log_every \ 153 | --dset-name=$dataset \ 154 | --ckpt-dir=$ckpt_dir \ 155 | --ckpt-name=$ckpt_name \ 156 | --batch-size=1 \ 157 | --n-workers=$n_workers \ 158 | --load-cache=$load_cache \ 159 | --cache-format=$cache_format \ 160 | --n-layers=$n_layers \ 161 | --seq-len=$seq_len \ 162 | --predict=$predict \ 163 | --pred-seq-len=$pred_seq_len \ 164 | --use-gru=$use_gru \ 165 | --pos-mode=$pos_mode \ 166 | --adj-type=$adj_type \ 167 | --collapse-cls=$collapse_cls \ 168 | --slide=$slide \ 169 | --rand-test=$rand_test \ 170 | --branch=$branch \ 171 | --which-epoch=$which_epoch \ 172 | --extract-feats-dir=$extract_feats_dir \ 173 | --save-output=$save_output \ 174 | --save-output-format=$save_output_format 175 | 176 | -------------------------------------------------------------------------------- /models/backbone/resnet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch 4 | import pdb 5 | try: 6 | from models.backbone.resnet.basicblock import BasicBlock2D 7 | from models.backbone.resnet.bottleneck import Bottleneck2D 8 | # from utils.other import transform_input 9 | # from utils.meter import * 10 | except: 11 | from resnet.basicblock import BasicBlock2D 12 | from resnet.bottleneck import Bottleneck2D 13 | # from basemodel.utils.other import transform_input 14 | # from basemodel.utils.meter import * 15 | 16 | model_urls = { 17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 22 | } 23 | 24 | K_1st_CONV = 3 25 | 26 | 27 | class ResNetBackBone(nn.Module): 28 | def __init__(self, blocks, layers, 29 | str_first_conv='2D', 30 | nb_temporal_conv=1, 31 | list_stride=[1, 2, 2, 2], 32 | **kwargs): 33 | self.nb_temporal_conv = nb_temporal_conv 34 | self.inplanes = 64 35 | super(ResNetBackBone, self).__init__() 36 | self._first_conv(str_first_conv) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.list_channels = [64, 128, 256, 512] 39 | self.list_inplanes = [] 40 | self.list_inplanes.append(self.inplanes) # store the inplanes after layer1 41 | self.layer1 = self._make_layer(blocks[0], self.list_channels[0], layers[0], stride=list_stride[0]) 42 | self.list_inplanes.append(self.inplanes) # store the inplanes after layer1 43 | self.layer2 = self._make_layer(blocks[1], self.list_channels[1], layers[1], stride=list_stride[1]) 44 | self.list_inplanes.append(self.inplanes) # store the inplanes after layer2 45 | self.layer3 = self._make_layer(blocks[2], self.list_channels[2], layers[2], stride=list_stride[2]) 46 | self.list_inplanes.append(self.inplanes) # store the inplanes after layer3 47 | self.layer4 = self._make_layer(blocks[3], self.list_channels[3], layers[3], stride=list_stride[3]) 48 | self.avgpool, self.avgpool_space, self.avgpool_time = None, None, None 49 | 50 | # Init of the weights 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d): 53 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 54 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 55 | elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.BatchNorm2d): 56 | m.weight.data.fill_(1) 57 | m.bias.data.zero_() 58 | 59 | def _first_conv(self, str): 60 | self.conv1_1t = None 61 | self.bn1_1t = None 62 | if str == '3D_stabilize': 63 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(K_1st_CONV, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), 64 | bias=False) 65 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 66 | self.bn1 = nn.BatchNorm3d(64) 67 | 68 | 69 | elif str == '2.5D_stabilize': 70 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), 71 | bias=False) 72 | self.conv1_1t = nn.Conv3d(64, 64, kernel_size=(K_1st_CONV, 1, 1), stride=(1, 1, 1), 73 | padding=(1, 0, 0), 74 | bias=False) 75 | self.bn1_1t = nn.BatchNorm3d(64) 76 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)) 77 | self.bn1 = nn.BatchNorm3d(64) 78 | 79 | elif str == '2D': 80 | self.conv1 = nn.Conv2d(3, 64, 81 | kernel_size=(7, 7), 82 | stride=(2, 2), 83 | padding=(3, 3), 84 | bias=False) 85 | self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 86 | self.bn1 = nn.BatchNorm2d(64) 87 | 88 | else: 89 | raise NameError 90 | 91 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 92 | downsample = None 93 | 94 | # Upgrade the stride is spatio-temporal kernel 95 | if not (block == BasicBlock2D or block == Bottleneck2D): 96 | stride = (1, stride, stride) 97 | 98 | if stride != 1 or self.inplanes != planes * block.expansion: 99 | if block is BasicBlock2D or block is Bottleneck2D: 100 | conv, batchnorm = nn.Conv2d, nn.BatchNorm2d 101 | else: 102 | conv, batchnorm = nn.Conv3d, nn.BatchNorm3d 103 | 104 | downsample = nn.Sequential( 105 | conv(self.inplanes, planes * block.expansion, 106 | kernel_size=1, stride=stride, bias=False, dilation=dilation), 107 | batchnorm(planes * block.expansion), 108 | ) 109 | 110 | layers = [] 111 | layers.append( 112 | block(self.inplanes, planes, stride, downsample, dilation, nb_temporal_conv=self.nb_temporal_conv)) 113 | self.inplanes = planes * block.expansion 114 | for i in range(1, blocks): 115 | layers.append(block(self.inplanes, planes, nb_temporal_conv=self.nb_temporal_conv)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x, num=4) : 120 | # pdb.set_trace() 121 | 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = self.relu(x) 125 | 126 | if self.conv1_1t is not None: 127 | x = self.conv1_1t(x) 128 | x = self.bn1_1t(x) 129 | x = self.relu(x) 130 | 131 | x = self.maxpool(x) 132 | 133 | x = self.layer1(x) 134 | x = self.layer2(x) 135 | x = self.layer3(x) 136 | x = self.layer4(x) 137 | 138 | # Global average pooling 139 | self.avgpool = nn.AvgPool2d((x.size(-1), x.size(-1))) if self.avgpool is None else self.avgpool 140 | x = self.avgpool(x) 141 | 142 | # Final classifier 143 | # x = x.view(x.size(0), -1) 144 | # x = self.fc_classifier(x) 145 | 146 | return x 147 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import copy 5 | import sys, traceback, code 6 | import torch 7 | 8 | import data 9 | import models 10 | import utils 11 | from test import evaluate 12 | 13 | import pdb 14 | 15 | import wandb 16 | 17 | N_EVAL_EPOCHS = 3 18 | 19 | opt, logger = utils.build(is_train=True) 20 | with 
open(os.path.join(opt.ckpt_path, 'opt.pkl'), 'wb') as handle: 21 | pickle.dump(opt, handle) 22 | 23 | tags = [opt.model, opt.branch, 'w/ bbox' if opt.use_bbox else 'no bbox', 'seq{}'.format(opt.seq_len)] 24 | if opt.model == 'graph': 25 | tags += ['{} layer'.format(opt.n_layers), opt.adj_type] 26 | tags += ['diff {}'.format(opt.diff_layer_weight)] 27 | 28 | if opt.model == 'pos' or opt.pos_mode != 'none': 29 | tags += opt.pos_mode, 30 | if opt.predict: 31 | tags += 'pred{}'.format(opt.pred_seq_len), 32 | if opt.predict_k: 33 | tags += 'pred_k{}'.format(opt.predict_k), 34 | if opt.ped_gru: 35 | tags += 'pedGRU', 36 | if opt.ctxt_gru: 37 | tags += 'ctxtGRU', 38 | if opt.ctxt_node: 39 | tags += 'ctxtNode', 40 | if opt.load_cache == 'none': 41 | tags += 'cacheNone', 42 | if opt.load_cache == 'masks': 43 | tags += 'cacheMasks', 44 | if opt.load_cache == 'feats': 45 | tags += 'cacheFeats', 46 | if opt.use_driver: 47 | tags += 'driver', 48 | tags += '{}evalEpochs'.format(N_EVAL_EPOCHS), 49 | wandb.init( 50 | project='crossing', 51 | tags=tags) 52 | 53 | opt_dict = vars(opt) 54 | print('Options:') 55 | for key in sorted(opt_dict): 56 | print('{}: {}'.format(key, opt_dict[key])) 57 | 58 | 59 | # train 60 | print(opt) 61 | train_loader = data.get_data_loader(opt) 62 | print('Train dataset: {}'.format(len(train_loader.dataset))) 63 | 64 | # val 65 | # val_opt = copy.deepcopy(opt) 66 | val_opt, _ = utils.build(is_train=False) 67 | val_opt.split = 'test' 68 | val_opt.slide = 0 69 | val_opt.is_train = False 70 | val_opt.rand_test = True 71 | val_opt.batch_size = 1 72 | 73 | val_loader = data.get_data_loader(val_opt) 74 | print('Val dataset: {}'.format(len(val_loader.dataset))) 75 | 76 | model = models.get_model(opt) 77 | # model = model.to('cuda:{}'.format(opt.device)) 78 | if opt.pretrained_path and os.path.exists(opt.pretrained_path): 79 | print('Loading model from', opt.pretrained_path) 80 | model.load(opt.pretrained_path) 81 | model = model.to('cuda:0') 82 | wandb.watch(model) 83 | 84 | if opt.load_ckpt_dir != '': 85 | ckpt_dir = os.path.join(opt.ckpt_dir, opt.dset_name, opt.load_ckpt_dir) 86 | assert os.path.exists(ckpt_dir) 87 | logger.print('Loading checkpoint from {}'.format(ckpt_dir)) 88 | model.load(ckpt_dir, opt.load_ckpt_epoch) 89 | 90 | opt.n_epochs = max(opt.n_epochs, opt.n_iters // len(train_loader)) 91 | logger.print('Total epochs: {}'.format(opt.n_epochs)) 92 | 93 | 94 | def train(): 95 | best_eval_acc = 0 96 | best_eval_loss = 10 97 | best_epoch = -1 98 | 99 | if val_opt.predict or val_opt.predict_k: 100 | # pred 101 | best_eval_acc_pred = 0 102 | best_eval_loss_pred = 10 103 | best_epoch_pred = -1 104 | # last 105 | best_eval_acc_last = 0 106 | best_eval_loss_last = 10 107 | best_epoch_last = -1 108 | 109 | for epoch in range(opt.start_epoch, opt.n_epochs): 110 | model.setup() 111 | print('Train epoch', epoch) 112 | model.update_hyperparameters(epoch) 113 | 114 | losses = [] 115 | for step, data in enumerate(train_loader): 116 | # break 117 | if epoch == 0: 118 | torch.cuda.empty_cache() 119 | # break 120 | loss = model.step_train(data) 121 | losses += loss, 122 | 123 | torch.cuda.empty_cache() 124 | 125 | if step % opt.log_every == 0: 126 | print('avg loss:', sum(losses) / len(losses)) 127 | wandb.log({"Train loss": sum(losses) / len(losses)}) 128 | losses = [] 129 | 130 | # Evaluate on val set 131 | if opt.evaluate_every > 0 and (epoch + 1) % opt.evaluate_every == 0: 132 | result_det, result_pred, result_last = evaluate(model, val_loader, 
val_opt, n_eval_epochs=N_EVAL_EPOCHS) 133 | eval_acc_frame, eval_acc_clip, eval_acc_cross, eval_acc_non_cross, eval_loss = result_det 134 | if eval_acc_frame > best_eval_acc: 135 | best_eval_acc = eval_acc_frame 136 | best_eval_loss = eval_loss 137 | best_epoch = epoch+1 138 | model.save(opt.ckpt_path, best_epoch, 'best_det') 139 | wandb.log({ 140 | 'eval_acc_frame':eval_acc_frame, 'eval_acc_clip':eval_acc_clip, 141 | 'eval_acc_cross':eval_acc_cross, 'eval_acc_non_cross':eval_acc_non_cross, 142 | 'eval_loss':eval_loss, 143 | 'best_eval_acc': best_eval_acc, 'best_eval_loss':best_eval_loss, 'best_epoch':best_epoch}) 144 | 145 | if val_opt.predict or val_opt.predict_k: 146 | # pred 147 | eval_acc_frame, eval_acc_clip, eval_acc_cross, eval_acc_non_cross, eval_loss = result_pred 148 | if eval_acc_frame > best_eval_acc_pred: 149 | best_eval_acc_pred = eval_acc_frame 150 | best_eval_loss_pred = eval_loss 151 | best_epoch_pred = epoch+1 152 | model.save(opt.ckpt_path, best_epoch_pred, 'best_pred') 153 | wandb.log({ 154 | 'eval_acc_frame_pred':eval_acc_frame, 'eval_acc_clip_pred':eval_acc_clip, 155 | 'eval_acc_cross_pred':eval_acc_cross, 'eval_acc_non_cross_pred':eval_acc_non_cross, 156 | 'eval_loss_pred':eval_loss, 157 | 'best_eval_acc_pred': best_eval_acc_pred, 'best_eval_loss_pred':best_eval_loss_pred, 'best_epoch_pred':best_epoch_pred}) 158 | # last 159 | eval_acc_frame, eval_acc_clip, eval_acc_cross, eval_acc_non_cross, eval_loss = result_last 160 | if eval_acc_frame > best_eval_acc_last: 161 | best_eval_acc_last = eval_acc_frame 162 | best_eval_loss_last = eval_loss 163 | best_epoch_last = epoch+1 164 | model.save(opt.ckpt_path, best_epoch_last, 'best_last') 165 | wandb.log({ 166 | 'eval_acc_frame_last':eval_acc_frame, 'eval_acc_clip_last':eval_acc_clip, 167 | 'eval_acc_cross_last':eval_acc_cross, 'eval_acc_non_cross_last':eval_acc_non_cross, 168 | 'eval_loss_last':eval_loss, 169 | 'best_eval_acc_last': best_eval_acc_last, 'best_eval_loss_last':best_eval_loss_last, 'best_epoch_last':best_epoch_last}) 170 | 171 | 172 | # Save model checkpoints 173 | if (epoch + 1) % opt.save_every == 0 and epoch >= 0 or epoch == opt.n_epochs - 1: 174 | model.save(opt.ckpt_path, epoch+1) 175 | 176 | 177 | try: 178 | train() 179 | except Exception as e: 180 | print(e) 181 | typ, vacl, tb = sys.exc_info() 182 | traceback.print_exc() 183 | last_frame = lambda tb=tb: last_frame(tb.tb_next) if tb.tb_next else tb 184 | frame = last_frame().tb_frame 185 | ns = dict(frame.f_globals) 186 | ns.update(frame.f_locals) 187 | code.interact(local=ns) 188 | 189 | -------------------------------------------------------------------------------- /utils/stip_merge_views.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import pickle 5 | import torch 6 | import pdb 7 | 8 | 9 | if False: 10 | # Objs 11 | print('Objects') 12 | root = '/vision/group/prolix/processed' 13 | center_root = os.path.join(root, 'center', 'obj_bbox_20fps') 14 | left_root = os.path.join(root, 'left', 'obj_bbox_20fps') 15 | right_root = os.path.join(root, 'right', 'obj_bbox_20fps') 16 | merged_root = os.path.join(root, 'all', 'obj_bbox_20fps') 17 | os.makedirs(merged_root, exist_ok=True) 18 | 19 | def merged_obj_loader(f): 20 | with open(f, 'rb') as handle: 21 | data = pickle.load(handle) 22 | 23 | obj_cls, obj_bbox = [], [] 24 | for (cs, bs) in zip(data['obj_cls'], data['obj_bbox']): 25 | # bs: [xmin, ymin, w, h] 26 | curr_cls, curr_bbox = [], [] 27 | for i in 
range(len(bs)): 28 | if bs[i, 3] != 0 and bs[i,2] != 0: 29 | # keep only non-empty bbox 30 | curr_cls += cs[i], 31 | curr_bbox += bs[i], 32 | obj_cls += np.array(curr_cls), 33 | obj_bbox += np.array(curr_bbox), 34 | 35 | return obj_cls, obj_bbox 36 | 37 | segs = sorted(glob(os.path.join(center_root, '*pkl'))) 38 | for seg in segs: 39 | fl = seg.replace(center_root, left_root).replace('pkl', 'pkl.pkl') 40 | fr = seg.replace(center_root, right_root).replace('pkl', 'pkl.pkl') 41 | if not os.path.exists(fl) or not os.path.exists(fr): 42 | continue 43 | fout = seg.replace(center_root, merged_root) 44 | # if os.path.exists(fout): 45 | # continue 46 | 47 | with open(seg, 'rb') as handle: 48 | data = pickle.load(handle) 49 | with open(fl, 'rb') as handle: 50 | data_l = pickle.load(handle) 51 | with open(fr, 'rb') as handle: 52 | data_r = pickle.load(handle) 53 | assert(len(data) == len(data_l)) 54 | assert(len(data) == len(data_r)) 55 | data_all = {} 56 | for cls in data: 57 | data_l[cls] = [[x-1216, y, h, w] for (x,y,h,w) in data_l[cls]]  # shift left-view boxes into center-view coords 58 | data_r[cls] = [[x+1216, y, h, w] for (x,y,h,w) in data_r[cls]]  # shift right-view boxes into center-view coords 59 | data_all[cls] = data[cls] + data_l[cls] + data_r[cls] 60 | with open(fout, 'wb') as handle: 61 | pickle.dump(data_all, handle) 62 | 63 | # obj_cls_c, obj_bbox_c = obj_loader(seg) 64 | # obj_cls_l, obj_bbox_l = obj_loader(fl) 65 | # obj_cls_r, obj_bbox_r = obj_loader(fr) 66 | 67 | # obj_cls = [np.concatenate([c,l,r], 0) for (c,l,r) in zip(obj_cls_c, obj_cls_l, obj_cls_r)] 68 | # obj_bbox = [] 69 | # for (c,l,r) in zip(obj_bbox_c, obj_bbox_l, obj_bbox_r): 70 | # valid = [each for each in [c,l,r] if len(each.shape) > 1] 71 | # if valid: 72 | # a = np.concatenate(valid, 0) 73 | # else: 74 | # a = np.array([]) 75 | # obj_bbox += a, 76 | # # obj_bbox = [np.concatenate([c,l,r], 0) for (c,l,r) in zip(obj_bbox_c, obj_bbox_l, obj_bbox_r)] 77 | # 78 | # with open(fout, 'wb') as handle: 79 | # pickle.dump({'obj_cls':obj_cls, 'obj_bbox':obj_bbox}, handle) 80 | 81 | # Masks 82 | if False: 83 | print('Masks') 84 | root = '/vision/group/prolix/processed' 85 | center_root = os.path.join(root, 'cache/center') 86 | left_root = os.path.join(root, 'cache/left') 87 | right_root = os.path.join(root, 'cache/right') 88 | merged_root = os.path.join(root, 'cache/all') 89 | os.makedirs(merged_root, exist_ok=True) 90 | 91 | for split in ['train']: 92 | split_root = os.path.join(merged_root, split) 93 | os.makedirs(split_root, exist_ok=True) 94 | 95 | caches = sorted(glob(os.path.join(center_root, split, '*pkl'))) 96 | for fc in caches: 97 | fl = fc.replace(center_root, left_root) 98 | fr = fc.replace(center_root, right_root) 99 | if not os.path.exists(fl) or not os.path.exists(fr): 100 | continue 101 | fout = fc.replace(center_root, merged_root) 102 | # if os.path.exists(fout): 103 | # continue 104 | 105 | with open(fc, 'rb') as handle: 106 | data = pickle.load(handle) 107 | with open(fl, 'rb') as handle: 108 | data_l = pickle.load(handle) 109 | with open(fr, 'rb') as handle: 110 | data_r = pickle.load(handle) 111 | 112 | data_out = {'ped_crops': data['ped_crops']} 113 | merged_masks = {k:[] for k in range(1,5)} 114 | keys = list(merged_masks.keys()) 115 | for k in keys: 116 | if k in data['masks']: 117 | merged_masks[k] += data['masks'][k], 118 | if k in data_l['masks']: 119 | merged_masks[k] += data_l['masks'][k], 120 | if k in data_r['masks']: 121 | merged_masks[k] += data_r['masks'][k], 122 | if len(merged_masks[k]) == 0: 123 | merged_masks.pop(k) 124 | else: 125 | merged_masks[k] = torch.cat(merged_masks[k], 0) 126
| data_out['masks'] = merged_masks 127 | 128 | with open(fout, 'wb') as handle: 129 | pickle.dump(data_out, handle) 130 | 131 | # Feats 132 | if True: 133 | print('Feats') 134 | root = '/vision/group/prolix/processed/cache/' 135 | center_root = os.path.join(root, 'center/STIP_conv_feats') 136 | left_root = os.path.join(root, 'left/STIP_conv_feats') 137 | right_root = os.path.join(root, 'right/STIP_conv_feats') 138 | merged_root = os.path.join(root, 'all/STIP_conv_feats') 139 | os.makedirs(merged_root, exist_ok=True) 140 | 141 | for split in ['train']: 142 | caches = sorted(glob(os.path.join(center_root, "concat*", split, '*pkl'))) 143 | split_root = os.path.dirname(caches[0]).replace(center_root, merged_root) 144 | os.makedirs(split_root, exist_ok=True) 145 | 146 | for fc in caches: 147 | fl = fc.replace(center_root, left_root) 148 | fr = fc.replace(center_root, right_root) 149 | if not os.path.exists(fl) or not os.path.exists(fr): 150 | print('Skipping') 151 | continue 152 | fout = fc.replace(center_root, merged_root) 153 | # pdb.set_trace() 154 | # if os.path.exists(fout): 155 | # continue 156 | 157 | with open(fc, 'rb') as handle: 158 | data = pickle.load(handle) 159 | with open(fl, 'rb') as handle: 160 | data_l = pickle.load(handle) 161 | with open(fr, 'rb') as handle: 162 | data_r = pickle.load(handle) 163 | 164 | data_out = {'ped_feats': data['ped_feats']} 165 | merged_feats = [] 166 | if data['ctxt_feats'].shape[0] != 1 or data['ctxt_feats'].sum() != 0: 167 | merged_feats += data['ctxt_feats'], 168 | if data_l['ctxt_feats'].shape[0] != 1 or data_l['ctxt_feats'].sum() != 0: 169 | merged_feats += data_l['ctxt_feats'], 170 | if data_r['ctxt_feats'].shape[0] != 1 or data_r['ctxt_feats'].sum() != 0: 171 | merged_feats += data_r['ctxt_feats'], 172 | merged_feats = torch.cat(merged_feats, 0) 173 | merged_cls = torch.cat([data['ctxt_cls'], data_l['ctxt_cls'], data_r['ctxt_cls']], 0) 174 | data_out['ctxt_feats'] = merged_feats 175 | data_out['ctxt_cls'] = merged_cls 176 | 177 | with open(fout, 'wb') as handle: 178 | pickle.dump(data_out, handle) 179 | -------------------------------------------------------------------------------- /cache_data_stip.py: -------------------------------------------------------------------------------- 1 | # NOTE: this script uses segmentation files only 2 | # and does not rely on pedestrian annotations. 
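# Pipeline sketch (inferred from the __main__ block at the bottom of this file):
# for each camera view, add_obj_bbox(view=...) first writes one pkl per frame
# mapping each class id (1-4) to a list of [x_min, y_min, w, h] boxes derived
# from the instance masks; merge_and_flat_wrapper(view=...) then collapses the
# per-frame files into one pkl per video segment holding flattened
# 'obj_cls' / 'obj_bbox' arrays.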
3 | 4 | # python imports 5 | import os 6 | import sys 7 | from glob import glob 8 | import pickle 9 | import time 10 | import numpy as np 11 | 12 | 13 | # local imports 14 | import data 15 | import utils 16 | from utils.data_proc_stip import parse_objs 17 | 18 | import pdb 19 | 20 | 21 | # def cache_masks(): 22 | # opt, logger = utils.build(is_train=False) 23 | # opt.combine_method = '' 24 | # opt.split = 'train' 25 | # cache_dir_name = 'jaad_collapse{}'.format('_'+opt.combine_method if opt.combine_method else '') 26 | # data.cache_all_objs(opt, cache_dir_name) 27 | # 28 | # 29 | # def cache_crops(): 30 | # fnpy_root = '/sailhome/bingbin/STR-PIP/datasets/JAAD_instance_segm' 31 | # fpkl_root = '/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_instance_crops' 32 | # utils.get_obj_crops(fnpy_root, fpkl_root) 33 | 34 | def add_obj_bbox(view=''): 35 | if not view: 36 | vid_root = '/vision/group/prolix/instances_20fps/stip_instances/' 37 | fobj_root = '/vision/group/prolix/processed/obj_bbox_20fps' 38 | elif view == 'center': 39 | vid_root = '/vision/group/prolix/instances_20fps/stip_instances/' 40 | fobj_root = '/vision/group/prolix/processed/center/obj_bbox_20fps' 41 | else: # view = 'left' or 'right' 42 | vid_root = '/vision/group/prolix/stip_side/{}/instances_20fps/'.format(view) 43 | fobj_root = '/vision/group/prolix/processed/{}/obj_bbox_20fps/'.format(view) 44 | os.makedirs(fobj_root, exist_ok=True) 45 | 46 | def helper(vid_range, split): 47 | for dir_vid in vid_range: 48 | print(dir_vid) 49 | sys.stdout.flush() 50 | t_start = time.time() 51 | 52 | if not view or view == 'center': 53 | segs = sorted(glob(os.path.join(vid_root, dir_vid, 'inference', '*--*'))) 54 | else: 55 | segs = sorted(glob(os.path.join(vid_root, dir_vid, '*--*'))) 56 | 57 | for seg in segs: 58 | if not view or view == 'center': 59 | fsegms = sorted(glob(os.path.join(seg, '*_segm.npy'))) 60 | else: 61 | fsegms = sorted(glob(os.path.join(seg, '*.pkl'))) 62 | for i, fsegm in enumerate(fsegms): 63 | if i and i%100 == 0: 64 | print('Time per frame:', (time.time() - t_start)/i) 65 | sys.stdout.flush() 66 | fid = os.path.basename(fsegm).split('_')[0] 67 | fbbox = os.path.join(fobj_root, '{:s}_seg{:s}_fid{:s}.pkl'.format(dir_vid, os.path.basename(seg), fid)) 68 | # if 'ANN_conor1_seg12:22--12:59_fid0000015483.pkl' not in fbbox: 69 | # continue 70 | if os.path.exists(fbbox): 71 | continue 72 | if not os.path.exists(fsegm): 73 | print('File does not exist:', fsegm) 74 | continue 75 | objs = parse_objs(fsegm) 76 | dobjs = {cls:[] for cls in range(1,5)} 77 | for cls, masks in objs.items(): 78 | for mask in masks: 79 | try: 80 | if len(mask.shape) == 3: 81 | h, w, c = mask.shape 82 | if c != 1: 83 | raise ValueError('Each mask should have shape (1080, 1920, 1)') 84 | mask = mask.reshape(h, w) 85 | x_pos = mask.sum(0).nonzero()[0] 86 | if not len(x_pos): 87 | x_pos = [0,0] 88 | x_min, x_max = x_pos[0], x_pos[-1] 89 | y_pos = mask.sum(1).nonzero()[0] 90 | if not len(y_pos): 91 | y_pos = [0,0] 92 | y_min, y_max = y_pos[0], y_pos[-1] 93 | # bbox: [x_min, y_min, w, h]; same as bbox for ped['pos_GT'] 94 | bbox = [x_min, y_min, x_max-x_min, y_max-y_min] 95 | except Exception as e: 96 | print(e) 97 | pdb.set_trace() 98 | 99 | dobjs[cls] += bbox, 100 | with open(fbbox, 'wb') as handle: 101 | pickle.dump(dobjs, handle) 102 | 103 | vids = sorted(glob(os.path.join(vid_root, '*_*'))) 104 | vids = [os.path.basename(vid) for vid in vids] 105 | vids_test = ['downtown_ann_3-09-28-2017', 'downtown_palo_alto_6', 'dt_san_jose_4', 'mountain_view_4', 106 | 
'sf_soma_2'] 107 | vids_train = [vid for vid in vids if vid not in vids_test] 108 | 109 | # tmp 110 | # vids_train = ['downtown_ann_1-09-27-2017', 'downtown_ann_2-09-27-2017', 'downtown_ann_3-09-27-2017', 'downtown_ann_1-09-28-2017'] 111 | # vids_test = [] 112 | 113 | if True: 114 | helper(vids_train, 'train') 115 | if False: 116 | helper(vids_test, 'test') 117 | 118 | def merge_and_flat(vrange, view=''): 119 | """ 120 | Merge fids in a vid and flatten the classes 121 | """ 122 | if not view: 123 | pkl_in_root = '/vision/group/prolix/processed/obj_bbox_20fps' 124 | pkl_out_root = '/vision/group/prolix/processed/obj_bbox_20fps_merged' 125 | else: 126 | pkl_in_root = '/vision/group/prolix/processed/{}/obj_bbox_20fps/'.format(view) 127 | pkl_out_root = '/vision/group/prolix/processed/{}/obj_bbox_20fps_merged'.format(view) 128 | os.makedirs(pkl_out_root, exist_ok=True) 129 | 130 | for vid in vrange: 131 | print(vid) 132 | fpkls = sorted(glob(os.path.join(pkl_in_root, '{:s}*pkl'.format(vid)))) 133 | segs = list(set([fpkl.split('_fid')[0] for fpkl in fpkls])) 134 | print('# segs:', len(segs)) 135 | for seg in segs: 136 | fpkls = sorted(glob(seg+'*pkl')) 137 | print(vid, len(fpkls)) 138 | sys.stdout.flush() 139 | # merged = [[] for _ in range(len(fpkls))] 140 | merged_bbox = [] 141 | merged_cls = [] 142 | t_start = time.time() 143 | for fpkl in fpkls: 144 | try: 145 | with open(fpkl, 'rb') as handle: 146 | data = pickle.load(handle) 147 | except: 148 | pdb.set_trace() 149 | curr_bbox = [] 150 | cls = [] 151 | for c in [1,2,3,4]: 152 | for bbox in data[c]: 153 | cls += c, 154 | curr_bbox += bbox, 155 | merged_bbox += np.array(curr_bbox), 156 | merged_cls += np.array(cls), 157 | 158 | seg = seg.split('seg')[-1] 159 | fpkl_out = os.path.join(pkl_out_root, '{}_seg{}.pkl'.format(vid, seg)) 160 | with open(fpkl_out, 'wb') as handle: 161 | dout = { 162 | 'obj_cls': merged_cls, 163 | 'obj_bbox': merged_bbox, 164 | } 165 | pickle.dump(dout, handle) 166 | print('avg time: ', (time.time()-t_start) / len(fpkls)) 167 | 168 | 169 | def merge_and_flat_wrapper(view=''): 170 | vid_root = '/vision/group/prolix/instances_20fps/stip_instances/' 171 | vids = sorted(glob(os.path.join(vid_root, '*_*'))) 172 | vids = [os.path.basename(vid) for vid in vids] 173 | vids_test = ['downtown_ann_3-09-28-2017', 'downtown_palo_alto_6', 'dt_san_jose_4', 'mountain_view_4', 174 | 'sf_soma_2'] 175 | vids_train = [vid for vid in vids if vid not in vids_test] 176 | 177 | # tmp 178 | # vids_train = ['downtown_ann_1-09-27-2017', 'downtown_ann_2-09-27-2017', 'downtown_ann_3-09-27-2017', 'downtown_ann_1-09-28-2017'] 179 | # vids_test = [] 180 | 181 | if True: 182 | merge_and_flat(vids_train, view=view) 183 | if False: 184 | merge_and_flat(vids_test, view=view) 185 | 186 | 187 | def cache_loc(): 188 | def helper(annots): 189 | for vid in annots: 190 | annot = annots[vid] 191 | n_frames = len(annot['act']) 192 | for fid in range(n_frames): 193 | # fid in ped cache file name: 1 based 194 | loc 195 | # ped: 196 | # ped['ped_crops']: ndarray: (3, 224, 224) 197 | # ped['masks']: tensor: [n_objs, 224, 224] 198 | 199 | 200 | if __name__ == '__main__': 201 | # cache_masks() 202 | # cache_crops() 203 | 204 | # add_obj_bbox(view='left') 205 | # merge_and_flat_wrapper(view='left') 206 | # add_obj_bbox(view='right') 207 | # merge_and_flat_wrapper(view='right') 208 | # add_obj_bbox(view='center') 209 | merge_and_flat_wrapper(view='center') 210 | 211 | # merge_and_flat(range(1, 200)) 212 | # merge_and_flat(range(100, 200)) 213 | # 
169 | def merge_and_flat_wrapper(view=''):
170 | vid_root = '/vision/group/prolix/instances_20fps/stip_instances/'
171 | vids = sorted(glob(os.path.join(vid_root, '*_*')))
172 | vids = [os.path.basename(vid) for vid in vids]
173 | vids_test = ['downtown_ann_3-09-28-2017', 'downtown_palo_alto_6', 'dt_san_jose_4', 'mountain_view_4',
174 | 'sf_soma_2']
175 | vids_train = [vid for vid in vids if vid not in vids_test]
176 |
177 | # tmp
178 | # vids_train = ['downtown_ann_1-09-27-2017', 'downtown_ann_2-09-27-2017', 'downtown_ann_3-09-27-2017', 'downtown_ann_1-09-28-2017']
179 | # vids_test = []
180 |
181 | if True:
182 | merge_and_flat(vids_train, view=view)
183 | if False:
184 | merge_and_flat(vids_test, view=view)
185 |
186 |
187 | def cache_loc():  # NOTE: unfinished; not called from __main__
188 | def helper(annots):
189 | for vid in annots:
190 | annot = annots[vid]
191 | n_frames = len(annot['act'])
192 | for fid in range(n_frames):
193 | # fid in ped cache file name: 1 based
194 | pass  # FIXME: unfinished stub; the original body was a bare name `loc`, which raises NameError
195 | # ped:
196 | # ped['ped_crops']: ndarray: (3, 224, 224)
197 | # ped['masks']: tensor: [n_objs, 224, 224]
198 |
199 |
200 | if __name__ == '__main__':
201 | # cache_masks()
202 | # cache_crops()
203 |
204 | # add_obj_bbox(view='left')
205 | # merge_and_flat_wrapper(view='left')
206 | # add_obj_bbox(view='right')
207 | # merge_and_flat_wrapper(view='right')
208 | # add_obj_bbox(view='center')
209 | merge_and_flat_wrapper(view='center')
210 |
211 | # merge_and_flat(range(1, 200))
212 | # merge_and_flat(range(100, 200))
213 | # merge_and_flat(range(200, 300))
214 | # merge_and_flat(range(100, 347))
215 | # merge_and_flat(range(200, 347))
216 |
217 |
--------------------------------------------------------------------------------
/data/get_data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | from PIL import Image
5 | import torch
6 | torch.manual_seed(2019)
7 | torch.cuda.manual_seed_all(2019)
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 | import argparse
11 |
12 | import time
13 | import pdb
14 |
15 | try:
16 | from .jaad import JAADDataset
17 | from .jaad_loc import JAADLocDataset
18 | from .stip import STIPDataset
19 | except:
20 | from jaad import JAADDataset
21 | from jaad_loc import JAADLocDataset
22 | from stip import STIPDataset
23 |
24 | def jaad_collate(batch):
25 | # each item in batch: a tuple of
26 | # 1. ped_crops: (30, 224, 224, 3)
27 | # 2. masks: list of len 30: each = dict of ndarrays: (n_obj, 1080, 1920, 1)
28 | # 3. GT_act: binary ndarray: (30, 9)
29 | ped_crops = []
30 | masks = []
31 | GT_act, GT_bbox, GT_pose = [], [], []
32 | obj_bbox, obj_cls = [], []
 | obj_bbox_l, obj_cls_l = [], []  # fix: these four side-view accumulators were used below but never initialized (NameError)
 | obj_bbox_r, obj_cls_r = [], []
33 | fids = []
34 | img_paths = []
35 | for each in batch:
36 | ped_crops += each['ped_crops'],
37 | masks += each['all_masks'],
38 | GT_act += each['GT_act'],
39 | GT_bbox += each['GT_bbox'],
40 | GT_pose += each['GT_pose'],
41 | obj_bbox += each['obj_bbox'],
42 | obj_cls += each['obj_cls'],
43 | obj_bbox_l += each.get('obj_bbox_l', []),  # .get: only side-view items carry these keys
44 | obj_cls_l += each.get('obj_cls_l', []),
45 | obj_bbox_r += each.get('obj_bbox_r', []),
46 | obj_cls_r += each.get('obj_cls_r', []),
47 | fids += each['fids'],
48 | img_paths += each['img_paths'],
49 | ped_crops = torch.stack(ped_crops)
50 | GT_act = torch.stack(GT_act)
51 | GT_bbox = torch.stack(GT_bbox)
52 | GT_pose = torch.stack(GT_pose)
53 | fids = torch.stack(fids)
54 | ret = {
55 | 'ped_crops': ped_crops,
56 | 'all_masks': masks,
57 | 'GT_act': GT_act,
58 | 'GT_bbox': GT_bbox,
59 | 'GT_pose': GT_pose,
60 | 'obj_cls': obj_cls,
61 | 'obj_bbox': obj_bbox,
 | 'obj_bbox_l': obj_bbox_l, 'obj_cls_l': obj_cls_l,  # fix: collected above but previously dropped from ret
 | 'obj_bbox_r': obj_bbox_r, 'obj_cls_r': obj_cls_r,
62 | 'fids': fids,
63 | 'img_paths': img_paths,
64 | }
65 | if 'frames' in batch[0]:
66 | ret['frames'] = torch.stack([each['frames'] for each in batch], 0)
67 | ret['GT_driver_act'] = torch.stack([each['GT_driver_act'] for each in batch], 0)
68 |
69 | return ret
70 |
71 |
72 | def jaad_loc_collate(batch):
73 | # each item in batch: a tuple of
74 | # 1. ped_crops: (30, 224, 224, 3)
75 | # 2. masks: list of len 30: each = dict of ndarrays: (n_obj, 1080, 1920, 1)
76 | # 3. GT_act: binary ndarray: (30, 9)
77 | ped_crops = []
78 | masks = []
79 | GT_act, GT_ped_bbox = [], []
80 | obj_bbox, obj_cls = [], []
81 | fids = []
82 | for each in batch:
83 | ped_crops += each['ped_crops'],
84 | masks += each['all_masks'],
85 | GT_act += each['GT_act'],
86 | GT_ped_bbox += each['GT_ped_bbox'],
87 | obj_bbox += each['obj_bbox'],
88 | obj_cls += each['obj_cls'],
89 | fids += each['fids'],
90 | GT_act = torch.stack(GT_act)
91 | fids = torch.stack(fids)
92 | ret = {
93 | 'ped_crops': ped_crops,
94 | 'all_masks': masks,
95 | 'GT_act': GT_act,
96 | 'GT_ped_bbox': GT_ped_bbox,
97 | 'obj_cls': obj_cls,
98 | 'obj_bbox': obj_bbox,
99 | 'fids': fids,
100 | }
101 | if 'frames' in batch[0]:
102 | ret['frames'] = torch.stack([each['frames'] for each in batch], 0)
103 | return ret
104 |
105 |
106 | def stip_collate(batch):
107 | # each item in batch: a tuple of
108 | # 1. ped_crops: (30, 224, 224, 3)
109 | # 2. 
masks: list of len 30: each = dict of ndarrays: (n_obj, 1080, 1920, 1) 110 | # 3. GT_act: binary ndarray: (30, 9) 111 | ped_crops = [] 112 | masks = [] 113 | GT_act, GT_bbox, GT_pose = [], [], [] 114 | obj_bbox, obj_cls = [], [] 115 | fids = [] 116 | img_paths = [] 117 | for each in batch: 118 | ped_crops += each['ped_crops'], 119 | masks += each['all_masks'], 120 | GT_act += each['GT_act'], 121 | GT_bbox += each['GT_bbox'], 122 | # GT_pose += each['GT_pose'], 123 | obj_bbox += each['obj_bbox'], 124 | obj_cls += each['obj_cls'], 125 | fids += each['fids'], 126 | img_paths += each['img_paths'], 127 | ped_crops = torch.stack(ped_crops) 128 | GT_act = torch.stack(GT_act) 129 | GT_bbox = torch.stack(GT_bbox) 130 | if len(GT_pose): 131 | GT_pose = torch.stack(GT_pose) 132 | fids = torch.stack(fids) 133 | ret = { 134 | 'ped_crops': ped_crops, 135 | 'all_masks': masks, 136 | 'GT_act': GT_act, 137 | 'GT_bbox': GT_bbox, 138 | 'obj_cls': obj_cls, 139 | 'obj_bbox': obj_bbox, 140 | 'fids': fids, 141 | 'img_paths': img_paths, 142 | } 143 | if len(GT_pose): 144 | ret['GT_pose'] = GT_pose 145 | if 'frames' in batch[0]: 146 | ret['frames'] = torch.stack([each['frames'] for each in batch], 0) 147 | ret['GT_driver_act'] = torch.stack([each['GT_driver_act'] for each in batch], 0) 148 | 149 | return ret 150 | 151 | 152 | 153 | def get_data_loader(opt): 154 | if opt.dset_name.lower() == 'jaad': 155 | dset = JAADDataset(opt) 156 | print('Built JAADDataset.') 157 | collate_fn = jaad_collate 158 | 159 | elif opt.dset_name.lower() == 'jaad_loc': 160 | dset = JAADLocDataset(opt) 161 | print('Built JAADLocDataset.') 162 | collate_fn = jaad_loc_collate 163 | 164 | elif opt.dset_name.lower() == 'stip': 165 | dset = STIPDataset(opt) 166 | print('Built STIPDataset') 167 | collate_fn = stip_collate 168 | 169 | else: 170 | raise NotImplementedError('Sorry but we currently only support JAAD. 
Got: {}'.format(opt.dset_name))
171 |
172 | dloader = data.DataLoader(dset,
173 | batch_size=opt.batch_size,
174 | shuffle=opt.is_train,
175 | num_workers=opt.n_workers,
176 | pin_memory=True,
177 | collate_fn=collate_fn,
178 | )
179 |
180 | return dloader
181 |
182 |
183 | def cache_all_objs(opt, cache_dir_name):
184 | opt.is_train = False
185 |
186 | opt.collapse_cls = 1
187 | cache_dir_root = '/sailhome/ajarno/STR-PIP/datasets/cache/'
188 | cache_dir = os.path.join(cache_dir_root, cache_dir_name)
189 | os.makedirs(cache_dir, exist_ok=True)
190 | opt.save_cache_format = os.path.join(cache_dir, opt.split, 'ped{}_fid{}.pkl')
191 | os.makedirs(os.path.dirname(opt.save_cache_format), exist_ok=True)
192 |
193 | dset = JAADDataset(opt)
194 | dloader = data.DataLoader(dset,
195 | batch_size=1,
196 | shuffle=False,
197 | num_workers=0,
198 | collate_fn=jaad_collate)
199 |
200 | t_start = time.time()
201 | for i,each in enumerate(dloader):
202 | if i%50 == 0 and i:
203 | print('{}: avg time: {:.3f}'.format(i, (time.time()-t_start) / 50))
204 | t_start = time.time()
205 |
206 |
207 | if __name__ == '__main__':
208 | parser = argparse.ArgumentParser()
209 | parser.add_argument('--dset-name', type=str, default='JAAD_loc')
210 | parser.add_argument('--annot-ped-format', type=str, default='/sailhome/ajarno/STR-PIP/datasets/annot_{}_ped.pkl')
211 | parser.add_argument('--annot-loc-format', type=str, default='/sailhome/ajarno/STR-PIP/datasets/annot_{}_loc.pkl')
212 | parser.add_argument('--is-train', type=int, default=1)
213 | parser.add_argument('--split', type=str, default='train')
214 | parser.add_argument('--seq-len', type=int, default=30)
215 | parser.add_argument('--ped-crop-size', type=tuple, default=(224, 224))  # NOTE: type=tuple cannot parse '224,224' from the CLI; only the default is usable
216 | parser.add_argument('--mask-size', type=tuple, default=(224, 224))
217 | parser.add_argument('--collapse-cls', type=int, default=0,
218 | help='Whether to merge the classes. 
If 1 then each item in masks is a dict keyed by cls, otherwise a list.') 219 | parser.add_argument('--img-path-format', type=str, 220 | default='/sailhome/ajarno/STR-PIP/datasets/JAAD_dataset/JAAD_clip_images/video_{:04d}.mp4/{:d}.jpg') 221 | parser.add_argument('--fsegm-format', type=str, 222 | default='/sailhome/ajarno/STR-PIP/datasets/JAAD_instance_segm/video_{:04d}/{:08d}_segm.npy') 223 | parser.add_argument('--save-cache-format', type=str, default='') 224 | parser.add_argument('--cache-format', type=str, default='') 225 | parser.add_argument('--batch-size', type=int, default=4) 226 | parser.add_argument('--n-workers', type=int, default=0) 227 | # added to test loader 228 | parser.add_argument('--rand-test', type=int, default=1) 229 | parser.add_argument('--predict', type=int, default=0) 230 | parser.add_argument('--predict-k', type=int, default=0) 231 | parser.add_argument('--combine-method', type=str, default='none') 232 | parser.add_argument('--load-cache', type=str, default='masks') 233 | parser.add_argument('--cache-obj-bbox-format', type=str, 234 | default='/sailhome/ajarno/STR-PIP/datasets/cache/obj_bbox_merged/vid{:08d}.pkl') 235 | 236 | opt = parser.parse_args() 237 | opt.save_cache_format = '/sailhome/ajarno/STR-PIP/datasets/cache/jaad_loc/{}/vid{}_fid{}.pkl' 238 | opt.cache_format = opt.save_cache_format 239 | opt.seq_len = 1 240 | opt.split = 'test' 241 | 242 | if True: 243 | # test dloader 244 | dloader = get_data_loader(opt) 245 | 246 | # for i,eg in enumerate(dloader): 247 | # if i%100 == 0: 248 | # print(i) 249 | # sys.stdout.flush() 250 | 251 | for i,vid in enumerate(dloader.dataset.vids): 252 | print('vid:', vid) 253 | annot = dloader.dataset.annots[vid] 254 | n_frames = len(annot['act']) 255 | for fid in range(n_frames): 256 | fcache = opt.cache_format.format(opt.split, vid, fid+1) 257 | if os.path.exists(fcache): 258 | continue 259 | dloader.dataset.__getitem__(i, fid_start=fid) 260 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pickle 4 | import sys, traceback, code 5 | import torch 6 | 7 | import data 8 | import models 9 | import utils 10 | 11 | import pdb 12 | 13 | 14 | def evaluate(model, dloader, opt, n_eval_epochs=3): 15 | print('Begin to evaluate') 16 | model.eval() 17 | 18 | if opt.collect_A: 19 | os.makedirs(os.path.dirname(opt.save_As_format), exist_ok=True) 20 | 21 | acc_det = { 22 | 'frames': 0, 23 | 'correct_frames': 0, 24 | 'clips': 0, 25 | 'correct_clips': 0, 26 | 'cross': 0, 27 | 'non_cross': 0, 28 | 'correct_cross': 0, 29 | 'correct_non_cross': 0, 30 | 'probs': None, 31 | 'loss': 0, 32 | } 33 | label_out = [] # labels output by the model 34 | label_GT = [] # GT labels (crossing) 35 | label_prob = [] 36 | 37 | if opt.predict or opt.predict_k: 38 | acc_pred = {key:0 for key in acc_det} 39 | acc_last = {key:0 for key in acc_det} 40 | 41 | def helper_update_metrics(ret, acc): 42 | n_frames, n_correct_frames, n_clips, n_correct_clips = ret[:4] 43 | n_cross, n_non_cross, n_correct_cross, n_correct_non_cross = ret[4:8] 44 | probs = ret[8] 45 | loss = ret[9] 46 | preds = ret[10] # (B, T) 47 | crossing = ret[11] # (B, T) 48 | 49 | acc['frames'] += n_frames 50 | acc['correct_frames'] += n_correct_frames 51 | acc['clips'] += n_clips 52 | acc['correct_clips'] += n_correct_clips 53 | acc['cross'] += n_cross 54 | acc['non_cross'] += n_non_cross 55 | acc['correct_cross'] += n_correct_cross 56 
| acc['correct_non_cross'] += n_correct_non_cross
57 | if acc['probs'] is None:
58 | acc['probs'] = probs
59 | else:
60 | acc['probs'] += probs
61 | acc['loss'] += loss
62 |
63 | return acc
64 |
65 | def helper_report_metrics(acc):
66 | if acc['probs'] is None:
67 | return 0, 0, 0, 0, 100
68 |
69 | acc_frame = acc['correct_frames'] / max(1, acc['frames'])
70 | acc_clip = acc['correct_clips'] / max(1, acc['clips'])
71 | acc_cross = acc['correct_cross'] / max(1, acc['cross'])
72 | acc_non_cross = acc['correct_non_cross'] / max(1, acc['non_cross'])
73 | avg_probs = acc['probs'] / max(1, acc['clips'])
74 | avg_loss = acc['loss'] / max(1, acc['frames'])
75 | print('Accuracy: frame:{:.5f}\t/ clip:{:.5f}'.format(acc_frame, acc_clip))
76 | print('Recall: cross:{:.5f}\t/ non-cross:{:.5f}'.format(acc_cross, acc_non_cross))
77 | print('Probs:', ' / '.join(['{}:{:.1f}'.format(i, each.item()*100) for i,each in enumerate(avg_probs)]))
78 | print('Loss: {:.3f}'.format(avg_loss))
79 | return acc_frame, acc_clip, acc_cross, acc_non_cross, avg_loss
80 |
81 | with torch.no_grad():
82 | for eid in range(n_eval_epochs):
83 | for step, data in enumerate(dloader):
84 | ret_det, ret_pred, ret_last, As = model.step_test(data, slide=opt.slide, collect_A=opt.collect_A)
85 |
86 | if ret_det is not None:
87 | acc_det = helper_update_metrics(ret_det, acc_det)
88 |
89 | if opt.predict or opt.predict_k:
90 | acc_pred = helper_update_metrics(ret_pred, acc_pred)
91 | acc_last = helper_update_metrics(ret_last, acc_last)
92 |
93 | if (opt.predict or opt.predict_k) and opt.save_output > 0 and ret_det is not None:  # fix: the predict check makes the elif below reachable (both branches had identical conditions, and ret_pred/ret_last are only valid when predicting)
94 | curr_out = torch.cat([ret_det[10], ret_pred[10], ret_last[10]], -1)
95 | curr_GT = torch.cat([ret_det[11], ret_pred[11], ret_last[11]], -1)
96 | curr_prob = torch.cat([ret_det[8], ret_pred[8], ret_last[8]])
97 | label_out += curr_out,
98 | label_GT += curr_GT,
99 | label_prob += curr_prob,
100 | elif opt.save_output > 0 and ret_det is not None:
101 | label_out += ret_det[10],
102 | label_GT += ret_det[11],
103 | label_prob += ret_det[8],
104 |
105 | if As is not None:
106 | data = {  # NOTE: rebinds the batch variable; safe since it is reassigned next iteration
107 | 'As': As,
108 | 'fids': data['fids'],
109 | 'img_paths': data['img_paths'],
110 | 'probs': ret_pred[8], # 1D tensor of size T (avg over B)
111 | }
112 | with open(opt.save_As_format.format(step, eid), 'wb') as handle:
113 | pickle.dump(data, handle)
114 |
115 | if opt.save_output and (step+1)%opt.save_output == 0 and False:  # NOTE: disabled via 'and False'; the dump after the loop covers saving
116 | label_out = torch.cat(label_out, 0).numpy()
117 | label_GT = torch.cat(label_GT, 0).numpy()
118 | label_prob = torch.cat(label_prob, 0).numpy()
119 | with open(opt.save_output_format.format(step), 'wb') as handle:
120 | pickle.dump({'out':label_out, 'GT': label_GT, 'prob': label_prob}, handle)
121 | label_out = []
122 | label_GT = []
123 | label_prob = []
124 |
125 | torch.cuda.empty_cache()
126 |
127 | if opt.save_output_format and len(label_out):  # guard: torch.cat fails on an empty list when nothing was collected
128 | label_out = torch.cat(label_out, 0).numpy()
129 | label_GT = torch.cat(label_GT, 0).numpy()
130 | label_prob = torch.cat(label_prob, 0).numpy()
131 | with open(opt.save_output_format.format('all'), 'wb') as handle:
132 | pickle.dump({'out':label_out, 'GT': label_GT, 'prob': label_prob}, handle)
133 |
134 | print('Detection:')
135 | result_det = helper_report_metrics(acc_det)
136 | if opt.predict or opt.predict_k:
137 | print('Prediction:')
138 | result_pred = helper_report_metrics(acc_pred)
139 | result_last = helper_report_metrics(acc_last)
140 | print()
141 | return result_det, result_pred, result_last
142 |
143 | print()
144 | return result_det, None, None
145 |
146 |
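 | # Layout of each ret_* tuple returned by model.step_test, as consumed by
 | # helper_update_metrics above (inferred from the indexing here, not a spec):
 | #   ret[0:4]  -> n_frames, n_correct_frames, n_clips, n_correct_clips
 | #   ret[4:8]  -> n_cross, n_non_cross, n_correct_cross, n_correct_non_cross
 | #   ret[8]    -> probs (1D tensor of length T, averaged over the batch)
 | #   ret[9]    -> loss
 | #   ret[10]   -> preds, shape (B, T)
 | #   ret[11]   -> crossing ground truth, shape (B, T)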
147 | def extract_feats(model, dloader, extract_feats_dir, seq_len=30):
148 | print('Begin to extract')
149 | model.eval()
150 |
151 | n_peds = len(dloader)
152 | print('n_peds:', n_peds)
153 |
154 | for pid in range(0, n_peds):
155 | ped = dloader.dataset.peds[pid]
156 | if 'frame_end' in ped:
157 | # JAAD setting
158 | n_frames = ped['frame_end'] - ped['frame_start'] + 1
159 | fid_range = range(ped['frame_start'], ped['frame_end']+1)
160 | fid_display = list(fid_range)
161 | elif 'fids20' in ped:
162 | # STIP setting
163 | n_frames = len(ped['fids20'])
164 | fid_range = range(n_frames)
165 | fid_display = ped['fids20']
166 | else:
167 | print("extract_feats: missing/unexpected keys... o_o")
168 | pdb.set_trace()
169 |
170 | for fid,fid_dis in zip(fid_range, fid_display):
171 | print('pid:{} / fid:{}'.format(pid, fid))
172 | item = dloader.dataset.__getitem__(pid, fid_start=fid)
173 | ped_crops, masks, act = item['ped_crops'], item['all_masks'], item['GT_act']
174 | # print('masks[0][1]:', masks[0][1].shape)
175 | ped_feats, ctxt_feats, ctxt_cls = model.extract_feats(ped_crops, masks, pid)
176 |
177 | feat_path = os.path.join(extract_feats_dir, 'ped{}_fid{}.pkl'.format(pid, fid_dis))
178 | with open(feat_path, 'wb') as handle:
179 | feats = {
180 | 'ped_feats': ped_feats.cpu(), # shape: 1, 512
181 | 'ctxt_feats': ctxt_feats.cpu(), # shape: n_objs, 512
182 | 'ctxt_cls': torch.tensor(ctxt_cls)
183 | }
184 | pickle.dump(feats, handle)
185 |
186 | del ped_feats
187 | del ctxt_feats
188 |
189 | torch.cuda.empty_cache()
190 |
191 | if pid % opt.log_every == 0:  # NOTE: `opt` is the module-level object built under __main__, not a parameter
192 | print('pid', pid)
193 |
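 | # Illustrative sketch: pickles written above are consumed later with
 | # load_cache == 'feats' (see data/jaad_loc.py). The filename is one produced
 | # by the format string above; keys match the dict dumped in extract_feats.
 | #
 | #   with open(os.path.join(extract_feats_dir, 'ped0_fid1.pkl'), 'rb') as handle:
 | #       feats = pickle.load(handle)
 | #   # feats['ped_feats']:  (1, 512) pedestrian crop feature
 | #   # feats['ctxt_feats']: (n_objs, 512), one row per context object
 | #   # feats['ctxt_cls']:   (n_objs,) classes aligned with ctxt_feats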
194 | def extract_feats_loc(model, dloader, extract_feats_dir, seq_len=1):
195 | print('Begin to extract')
196 | model.eval()
197 |
198 | n_vids = len(dloader)
199 | print('n_vids:', n_vids)
200 |
201 | for vid in range(0, n_vids):
202 | key = dloader.dataset.vids[vid]
203 | annot = dloader.dataset.annots[key]
204 |
205 | for fid in range(len(annot['act'])):
206 | print('vid:{} / fid:{}'.format(vid, fid))
207 | feat_path = os.path.join(extract_feats_dir, 'vid{}_fid{}.pkl'.format(vid, fid))
208 | if os.path.exists(feat_path):
209 | continue
210 |
211 | item = dloader.dataset.__getitem__(vid, fid_start=fid)
212 | ped_crops, masks, act = item['ped_crops'], item['all_masks'], item['GT_act']
213 | # print('masks[0][1]:', masks[0][1].shape)
214 | ped_feats, ctxt_feats, ctxt_cls = model.extract_feats(ped_crops, masks)
215 |
216 | with open(feat_path, 'wb') as handle:
217 | feats = {
218 | 'ped_feats': ped_feats[0].cpu(), # shape: 1, 512
219 | 'ctxt_feats': ctxt_feats.cpu(), # shape: n_objs, 512
220 | 'ctxt_cls': torch.tensor(ctxt_cls)
221 | }
222 | pickle.dump(feats, handle)
223 |
224 | del ped_feats
225 | del ctxt_feats
226 |
227 | torch.cuda.empty_cache()
228 |
229 | if vid % opt.log_every == 0:
230 | print('vid', vid)
231 |
232 |
233 | if __name__ == '__main__':
234 | opt, logger = utils.build(is_train=False)
235 |
236 | dloader = data.get_data_loader(opt)
237 | print('{} dataset: {}'.format(opt.split, len(dloader.dataset)))
238 |
239 | model = models.get_model(opt)
240 | print('Got model')
241 | if opt.which_epoch == -1:
242 | model_path = os.path.join(opt.ckpt_path, 'best_pred.pth')
243 | else:
244 | model_path = os.path.join(opt.ckpt_path, '{}.pth'.format(opt.which_epoch))
245 | if os.path.exists(model_path):
246 | # NOTE: if the path does not exist, the backbone keeps its ImageNet-pretrained weights
247 | model.load(model_path)
248 | print('Model loaded:', model_path)
249 | else:
250 | print('Model does not exist:', model_path)
251 | model = model.to('cuda:0')
252 |
253 | try:
254 | if opt.mode == 'evaluate':
255 | evaluate(model, dloader, opt)
256 | elif opt.mode == 'extract':
257 | assert(opt.batch_size == 1)
258 | assert(opt.seq_len == 1)
259 | assert(opt.predict == 0)
260 |
261 | print('Saving at', opt.extract_feats_dir)
262 | os.makedirs(opt.extract_feats_dir, exist_ok=True)
263 | if 'loc' in opt.model:
264 | extract_feats_loc(model, dloader, opt.extract_feats_dir, opt.seq_len)
265 | else:
266 | extract_feats(model, dloader, opt.extract_feats_dir, opt.seq_len)
267 | except Exception as e:
268 | print(e)
269 | typ, vacl, tb = sys.exc_info()
270 | traceback.print_exc()
271 | last_frame = lambda tb=tb: last_frame(tb.tb_next) if tb.tb_next else tb
272 | frame = last_frame().tb_frame
273 | ns = dict(frame.f_globals)
274 | ns.update(frame.f_locals)
275 | code.interact(local=ns)
276 |
277 |
--------------------------------------------------------------------------------
/data/jaad_loc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import pickle
5 | import cv2
6 | import random
7 | random.seed(2019)
8 | import torch
9 | import torch.utils.data as data
10 | import torchvision.transforms as transforms
11 |
12 | import pdb
13 | import time
14 |
15 | from utils.data_proc import parse_objs
16 |
17 | class JAADLocDataset(data.Dataset):
18 | def __init__(self, opt):
19 | # annot_ped_format, is_train, split,
20 | # seq_len, ped_crop_size, mask_size, collapse_cls,
21 | # img_path_format, fsegm_format):
22 |
23 | self.split = opt.split
24 | annot_loc = opt.annot_loc_format.format(self.split)
25 | with open(annot_loc, 'rb') as handle:
26 | self.annots = pickle.load(handle)
27 | self.vids = sorted(self.annots.keys())
28 |
29 | self.is_train = opt.is_train
30 | self.rand_test = opt.rand_test
31 | self.seq_len = opt.seq_len
32 | self.predict = opt.predict
33 | if self.predict:
34 | self.all_seq_len = self.seq_len + opt.pred_seq_len
35 | else:
36 | self.all_seq_len = self.seq_len
37 | self.predict_k = opt.predict_k
38 | if self.predict_k:
39 | self.all_seq_len += self.predict_k
40 | self.ped_crop_size = opt.ped_crop_size
41 | self.mask_size = opt.mask_size
42 | self.collapse_cls = opt.collapse_cls
43 | self.combine_method = opt.combine_method
44 |
45 | self.img_path_format = opt.img_path_format
46 | self.fsegm_format = opt.fsegm_format
47 | self.save_cache_format = opt.save_cache_format
48 | self.load_cache = opt.load_cache
49 | self.cache_format = opt.cache_format
50 | self.cache_obj_bbox_format = opt.cache_obj_bbox_format
51 |
52 | def __len__(self):
53 | return len(self.vids)
54 |
55 | def __getitem__(self, idx, fid_start=-1):
56 | t_start = time.time()
57 |
58 | vid = self.vids[idx]
59 | annot = self.annots[vid]
60 |
61 | # print('vid:', vid)
62 |
63 | if fid_start != -1:
64 | f_start = fid_start
65 | elif self.is_train or self.rand_test:
66 | # train: randomly sample all_seq_len number of frames
67 | f_start = random.randint(0, max(0, len(annot['act'])-self.all_seq_len+1))  # max(0, ...) guards clips shorter than all_seq_len, where randint would raise ValueError
68 | elif fid_start == -1:
69 | f_start = 0
70 |
71 | if self.is_train or self.rand_test or fid_start != -1:
72 | if self.predict_k:
73 | fids = [min(len(annot['act'])-1, f_start+i) for i in range(self.seq_len)]
74 | fids += min(len(annot['act'])-1, f_start+self.seq_len-1+self.predict_k),
75 | else:
76 | fids = [min(len(annot['act'])-1, f_start+i) for i in range(self.all_seq_len)]
77 | else:
78 | # take the entire video
79 | fids = range(0, len(annot['act']))
80 |
81 | GT_act = [annot['act'][fid] for fid in fids]
82 | GT_act = np.stack(GT_act)
83 | GT_act = torch.tensor(GT_act)
84 |
85 | GT_ped_bbox = [[pos.astype(float) for pos in annot['ped_pos'][fid]] for fid in fids]  # fix: np.float is removed in recent NumPy; builtin float is equivalent
86 |
87 | with open(self.cache_obj_bbox_format.format(vid), 'rb') as handle:
88 | data = pickle.load(handle)
89 | obj_cls = [data['obj_cls'][fid] for fid in fids]
90 | obj_bbox = [data['obj_bbox'][fid] for fid in fids]
91 |
92 | frames = []
93 | for fid in fids[:self.seq_len]:
94 | img_path = self.img_path_format.format(vid, fid+1) # +1 since fid is 0-based while the frame path starts at 1.
95 | img = cv2.imread(img_path)
96 | img = cv2.resize(img, (224,224)).transpose((2,0,1))
97 | frames += img,
98 |
99 | ret = {
100 | 'GT_act': GT_act,
101 | 'GT_ped_bbox': GT_ped_bbox,
102 | 'obj_cls': obj_cls,
103 | 'obj_bbox': obj_bbox,
104 | 'frames': torch.tensor(frames),
105 | 'fids': torch.tensor(np.array(fids)),
106 | }
107 | ped_crops = []
108 | all_masks = []
109 |
110 | # only the first seq_len fids are input data
111 | fids = fids[:self.seq_len]
112 |
113 | if self.load_cache == 'masks':
114 | for i,fid in enumerate(fids):
115 | # NOTE: fid for masks is 1-based.
116 | fcache = self.cache_format.format(self.split, vid, fid+1)
117 | # print(fcache)
118 | if os.path.exists(fcache):
119 | with open(fcache, 'rb') as handle:
120 | data = pickle.load(handle)
121 | ped_crops += data['ped_crops'],
122 | all_masks += data['masks'],
123 |
124 | if 'max' not in fcache:
125 | # 1 mask per obj: check for # objs
126 | if type(data['masks']) == dict:
127 | n_objs = len([each for val in data['masks'].values() for each in val])
128 | else:
129 | n_objs = len(data['masks'])
130 | if n_objs != len(obj_bbox[i]):
131 | print('JAAD: n_objs mismatch')
132 | pdb.set_trace()
133 | else:
134 | try:
135 | ped_crop, cls_masks = self.get_vid_fid(vid, annot['ped_pos'][fid], fid)
136 | except Exception as e:
137 | print(e)
138 | pdb.set_trace()
139 | ped_crops += ped_crop,
140 | all_masks += cls_masks,
141 | if type(cls_masks) == dict:
142 | n_objs = len([each for val in cls_masks.values() for each in val])
143 | else:
144 | n_objs = len(cls_masks)
145 | if n_objs != len(obj_bbox[i]):
146 | print('JAAD: n_objs mismatch')
147 | pdb.set_trace()
148 |
149 | ret['ped_crops'] = ped_crops
150 | ret['all_masks'] = all_masks
151 |
152 | # pdb.set_trace()
153 | n_peds = sum([len(each) for each in ped_crops])
154 | n_objs = sum([len(each) for each in all_masks])
155 | # print('n_peds:{} / n_objs:{}'.format(n_peds, n_objs))
156 | return ret
157 |
158 | elif self.load_cache == 'feats':
159 | ped_feats = []
160 | ctxt_feats = []
161 | for fid in fids:
162 | # NOTE: fid for feats is 0-based.
163 | # with open(self.cache_format.format(self.split, vid, fid), 'rb') as handle:
164 | with open(self.cache_format.format(self.split, idx, fid), 'rb') as handle:
165 | data = pickle.load(handle)
166 | ped = data['ped_feats'] # shape: 1, 512
167 | ctxt = data['ctxt_feats'] # shape: n_objs, 512
168 |
169 | ped_feats += ped,
170 | ctxt_feats += ctxt,
171 | # ped_feats = torch.stack(ped_feats, 0)
172 |
173 | ret['ped_crops'] = ped_feats
174 | ret['all_masks'] = ctxt_feats
175 | return ret
176 |
177 | elif self.load_cache == 'pos':
178 | ret['ped_crops'] = torch.zeros([1,1,512])
179 | ret['all_masks'] = torch.zeros([1,1,512])
180 | return ret
181 |
182 | for fid in fids:
183 | ped_crop, cls_masks = self.get_vid_fid(vid, annot['ped_pos'][fid], fid)  # fix: was self.get_ped_fid(ped, fid, idx) — neither `ped` nor get_ped_fid exists in this class; get_vid_fid mirrors the masks branch above
184 | ped_crops += ped_crop,
185 | all_masks += cls_masks,
186 |
187 | # shape: [n_frames, self.ped_crop_size[0], self.ped_crop_size[1]]
188 |
189 | ped_crops = np.stack(ped_crops)
190 | ped_crops = torch.Tensor(ped_crops)
191 |
192 | # print('time per item:', time.time()-t_start)
193 | ret['ped_crops'] = ped_crops
194 | ret['all_masks'] = all_masks
195 | return ret
196 |
197 |
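 | # Worked example of the frame-window logic in __getitem__ above (illustrative
 | # numbers only): with seq_len=30, predict_k=10, f_start=5 and a long enough
 | # clip, fids = [5, 6, ..., 34] + [44], i.e. 30 observed frames plus one frame
 | # predict_k steps past the last observed one. With predict (pred_seq_len=p)
 | # instead, fids = [5, ..., 34+p], and only the first seq_len entries are used
 | # as input (the tail serves as supervision for prediction).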
198 | def get_vid_fid(self, vid, peds, fid):
199 | """
200 | Prepare ped_crop and obj masks for given ped and fid.
201 | """
202 |
203 | ped_crops = []
204 | if len(peds) == 0:
205 | ped_crops = [np.zeros([3, 224, 224])]
206 | # if no bbox, take the entire frame
207 | x,y,w,h = 0,0,-1,-1
208 | else:
209 | img_path = self.img_path_format.format(vid, fid+1) # +1 since fid is 0-based while the frame path starts at 1.
210 | img = cv2.imread(img_path)
211 |
212 | # pedestrian crops
213 | for ped in peds:
214 | x, y, w, h = ped
215 | x, y, w, h = int(x), int(y), int(w), int(h)
216 |
217 | try:
218 | ped_crop = img[y:y+h, x:x+w]
219 | except Exception as e:
220 | print(e)
221 | print('img_path:', img_path)
222 | print('x:{}, y:{}, w:{}, h:{}'.format(x, y, w, h))
223 | ped_crop = cv2.resize(ped_crop, self.ped_crop_size)
224 | ped_crop = ped_crop.transpose((2,0,1))
225 | ped_crops += ped_crop,
226 |
227 | # obj masks
228 | fsegm = self.fsegm_format.format(vid, fid)
229 | objs = parse_objs(fsegm)
230 | if self.collapse_cls:
231 | cls_masks = []
232 | else:
233 | cls_masks = {}
234 | for cls, masks in objs.items():
235 | if not self.collapse_cls:
236 | cls_masks[cls] = []
237 | for mask in masks:
238 | mask[y:y+h, x:x+w] = 1
239 | if self.combine_method == 'pair':
240 | # crop out the union bbox of ped + obj
241 | # note: emptiness of x_pos/y_pos is not checked here.
242 | x_pos = mask.sum(0).nonzero()[0]
243 | x_min, x_max = x_pos[0], x_pos[-1]
244 | y_pos = mask.sum(1).nonzero()[0]  # fix: was nonzero()[1] — sum(1) is 1-D, so nonzero() returns a single index array (cf. x_pos above)
245 | y_min, y_max = y_pos[0], y_pos[-1]
246 | mask = mask[y_min:y_max+1, x_min:x_max+1]
247 | mask = cv2.resize(mask, self.mask_size)
248 | mask = torch.tensor(mask)
249 | # mask = torch.stack([mask, mask, mask])
250 | if self.collapse_cls:
251 | cls_masks += mask,
252 | else:
253 | # TODO: transform the mask: e.g. crop & norm over the union
254 | cls_masks[cls] += mask,
255 | if not self.collapse_cls:
256 | cls_masks[cls] = torch.stack(cls_masks[cls])
257 | if self.combine_method == 'sum':
258 | cls_masks[cls] = cls_masks[cls].sum(0)
259 | elif self.combine_method == 'max':
260 | cls_masks[cls], _ = cls_masks[cls].max(0)
261 |
262 | if self.collapse_cls:
263 | if len(cls_masks) != 0:
264 | cls_masks = torch.stack(cls_masks)
265 | if self.combine_method == 'sum':
266 | cls_masks = cls_masks.sum(0)
267 | elif self.combine_method == 'max':
268 | cls_masks, _ = cls_masks.max(0)
269 | else:
270 | # no objects in the frame
271 | if self.combine_method:
272 | # e.g. 'sum' or 'max'
273 | cls_masks = torch.zeros(self.mask_size)
274 | else:
275 | cls_masks = torch.zeros([0, self.mask_size[0], self.mask_size[1]])
276 |
277 | if self.cache_format:
278 | with open(self.cache_format.format(self.split, vid, fid+1), 'wb') as handle:
279 | cache = {
280 | 'ped_crops': ped_crops,
281 | 'masks': cls_masks.data if self.collapse_cls else {cls:cls_masks[cls].data for cls in cls_masks},
282 | }
283 | pickle.dump(cache, handle)
284 |
285 | return ped_crops, cls_masks
286 |
287 |
288 |
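 | # Illustrative note on combine_method in get_vid_fid above (not executed):
 | # given per-class masks of shape (n, H, W),
 | #   'sum'  -> masks.sum(0): one (H, W) map counting overlapping objects
 | #   'max'  -> masks.max(0): one (H, W) binary union map
 | #   'pair' -> each mask is cropped to the union bbox of pedestrian + object
 | #             before resizing, so the pair's spatial relation is preserved
 | # With no combine_method set, all n masks are returned individually.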
--------------------------------------------------------------------------------
/utils/data_proc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import pickle
5 | import cv2
6 | import xml.etree.ElementTree as ET
7 | from pycocotools import mask as maskUtils
8 |
9 | import pdb
10 |
11 | # Objects of interest
12 | cls_map = {
13 | # other road users (treated as part of vehicles)
14 | 5: 1, # bicyclist
15 | 6: 1, # motorcyclist
16 | 7: 1, # other riders
17 | # vehicles
18 | 28: 1, # bicycle
19 | 29: 1, # boat (??)
20 | 30: 1, # bus
21 | 31: 1, # car
22 | 32: 1, # caravan
23 | 33: 1, # motorcycle
24 | 34: 1, # other vehicle
25 | 35: 1, # trailer
26 | 36: 1, # truck
27 | 37: 1, # wheeled slow
28 | # environments
29 | 3: 2, # crosswalk - plain
30 | 8: 3, # crosswalk - zebra
31 | 24: 4, # traffic lights
32 | }
33 |
34 | # smaller set of objs because of memory
35 | cls_map_small = {
36 | # vehicles
37 | 28: 1, # bicycle
38 | 30: 1, # bus
39 | 31: 1, # car
40 | 33: 1, # motorcycle
41 | 35: 1, # trailer
42 | 36: 1, # truck
43 | # environments
44 | 3: 2, # crosswalk - plain
45 | 8: 3, # crosswalk - zebra
46 | 24: 4, # traffic lights
47 | }
48 |
49 |
50 | def parse_objs(fnpy):
51 | # Parse instance segmentations on one frame to masks.
52 | # In: filename of instance segmentation (i.e. `dataset/JAAD_instance_segm/video_{:04d}/{:08d}_segm.npy`)
53 | # Out: dict with key 1-4 (see `cls_map`), each for a type of graph nodes.
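 | # Example of the returned structure (illustrative): for a frame with two
 | # cars and one zebra crosswalk,
 | #   {1: [mask_car_a, mask_car_b], 3: [mask_zebra]}
 | # where each mask is an (H, W, 1) uint8 array decoded from COCO RLE.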
54 | segms = np.load(fnpy, allow_pickle=True) 55 | 56 | selected_segms = {} 57 | for key,val in cls_map.items(): 58 | if segms[key]: 59 | if val not in selected_segms: 60 | selected_segms[val] = [] 61 | for each in segms[key]: 62 | mask = maskUtils.decode([each]) 63 | selected_segms[val] += mask, 64 | del segms 65 | return selected_segms 66 | 67 | 68 | def get_obj_crops(fnpy_root, fpkl_root): 69 | os.makedirs(fpkl_root, exist_ok=True) 70 | 71 | vids = sorted(glob(os.path.join(fnpy_root, 'video_*'))) 72 | for i,vid in enumerate(vids): 73 | if i < 2: 74 | continue 75 | print('vid:', os.path.basename(vid)) 76 | fnpys = sorted(glob(os.path.join(vid, '*_segm.npy'))) 77 | for fnpy in fnpys: 78 | selected_segms = parse_objs(fnpy) 79 | crops = {} 80 | for cls,segms in selected_segms.items(): 81 | if len(segms): 82 | crops[cls] = [] 83 | for mask in segms: 84 | y_pos = mask.sum(1).nonzero()[0] 85 | x_pos = mask.sum(0).nonzero()[0] 86 | if len(y_pos) == 0: 87 | print('empty y_pos:', fnpy) 88 | continue 89 | if len(x_pos) == 0: 90 | print('empty x_pos:', fnpy) 91 | continue 92 | y_min, y_max = y_pos[0], y_pos[-1] 93 | x_min, x_max = x_pos[0], x_pos[-1] 94 | if x_min >= x_max or y_min >= y_max: 95 | print('empty_crop:', fnpy) 96 | print('x_min={} / x_max={} / y_min={} / y_max={}\n'.format(x_min, x_max, y_min, y_max)) 97 | continue 98 | crop = mask[y_min:(y_max+1), x_min:(x_max+1)] 99 | crop = cv2.resize(crop, (224,224)) 100 | crops[cls] += crop, 101 | if len(crops[cls]): 102 | crops[cls] = np.stack(crops[cls]) 103 | else: 104 | crops[cls] = np.zeros([0, 224, 224]) 105 | else: 106 | crops[cls] = np.zeros([0, 224, 224]) 107 | 108 | fpkl = fnpy.replace(fnpy_root, fpkl_root) 109 | fpkl_dir = os.path.dirname(fpkl) 110 | os.makedirs(fpkl_dir, exist_ok=True) 111 | with open(fpkl, 'wb') as handle: 112 | pickle.dump(crops, handle) 113 | 114 | 115 | pedestrian_act_map = { 116 | 'clear path': 0, 117 | 'crossing': 1, 118 | 'handwave': 2, 119 | 'looking': 3, 120 | 'nod': 4, 121 | 'slow down': 5, 122 | 'speed up': 6, 123 | 'standing': 7, 124 | 'walking': 8, 125 | } 126 | 127 | joint_ids = [ 128 | 1, # 0 Neck 129 | 2, # 1 RShoulder 130 | 5, # 2 LShoulder 131 | 9, # 3 RHip 132 | 10, # 4 RKnee 133 | 11, # 5 RAnkle 134 | 12, # 6 LHip 135 | 13, # 7 LKnee 136 | 14, # 8 LAnkle 137 | ] 138 | 139 | def parse_pedestrian(fxml, fpos_GT, fpos_pred='', fpose=''): 140 | """ 141 | parse xml (action label) 142 | """ 143 | e = ET.parse(fxml).getroot() 144 | # peds: dict of id to a list of per-frame acts. 145 | # Action labels at each frame is a one-hot vec of length 9 (i.e. len(pedestrian_act_map)). 146 | peds = {} 147 | nframes = int(e.get('num_frames')) 148 | for child in e.getchildren(): 149 | if child.tag != 'actions': 150 | continue 151 | 152 | for each in child.getchildren(): 153 | if 'pedestrian' not in each.tag: # e.g. Driver 154 | continue 155 | 156 | # NOTE: `pid` starts at 1. 
157 | tag = each.tag 158 | if tag == 'pedestrian': 159 | pid = 1 160 | tag = 'pedestrian1' 161 | else: 162 | pid = int(tag[len('pedestrian'):]) 163 | # NOTE: change indexing from 'pid' to 'each.tag' 164 | peds[tag] = {'act': [[0] * len(pedestrian_act_map) for _ in range(nframes)]} 165 | peds[tag]['tag'] = tag 166 | peds[tag]['pid'] = pid 167 | pacts = each.getchildren() 168 | for act in pacts: 169 | act_cls = pedestrian_act_map[act.get('id').lower()] 170 | for t in range(int(act.get('start_frame'))-1, int(act.get('end_frame'))): 171 | peds[tag]['act'][t][act_cls] = 1 172 | 173 | if fpose: 174 | """ 175 | parse pose 176 | """ 177 | pose_data = np.load(fpose, encoding='latin1', allow_pickle=True).item() 178 | for ped_tag, frames in sorted(pose_data.items()): 179 | if ped_tag not in peds: 180 | continue 181 | peds[ped_tag]['pose'] = [[] for _ in range(nframes)] 182 | peds[ped_tag]['pos_Pose'] = [[] for _ in range(nframes)] 183 | 184 | # NOTE fid are 0-indexed in pose_data 185 | for frame in sorted(frames): 186 | fid = frame[0] 187 | peds[ped_tag]['pose'][fid] = np.array(frame[2][0])[joint_ids] 188 | peds[ped_tag]['pos_Pose'][fid] = frame[1]['pos'] 189 | 190 | """ 191 | parse position 192 | """ 193 | poccl = {ped_tag:[[] for _ in range(nframes)] for ped_tag in peds} 194 | 195 | # NOTE: not all pedestrians in the npy files have labels. 196 | # We only use those with action labels in xml. 197 | data = np.load(fpos_GT, allow_pickle=True).item() 198 | assert(data['nFrame'] == nframes), print("nFrame mismatch: xml:{} / npy: {}".format(nframes, data['nFrame'])) 199 | 200 | assert(len(data['objLists']) == nframes), print("nFrame mismatch: xml:{} / npy-->objLists: {}".format(nframes, len(data['objLists']))) 201 | 202 | ppos_GT = {ped_tag:[[] for _ in range(nframes)] for ped_tag in peds} 203 | # pdb.set_trace() 204 | for fid,frame in enumerate(data['objLists']): 205 | # frame: dict w/ keys: 'id', 'pos', 'posv', 'occl', 'lock' 206 | for ped in frame: 207 | pid = ped['id'][0] 208 | for ped_tag in peds: 209 | if peds[ped_tag]['pid'] == pid: 210 | ppos_GT[ped_tag][fid] = ped['pos'] # (x, y, w, h) 211 | poccl[ped_tag][fid] = ped['occl'][0] 212 | break 213 | # if ped_tag in peds: 214 | # ppos_GT[ped_tag][fid] = ped['pos'] # (x, y, w, h) 215 | # poccl[ped_tag][fid] = ped['occl'][0] 216 | 217 | for ped_tag in peds: 218 | pid = peds[ped_tag]['pid'] 219 | peds[ped_tag]['frame_start'] = data['objStr'][pid-1]-1 220 | peds[ped_tag]['frame_end'] = data['objEnd'][pid-1]-1 221 | peds[ped_tag]['pos_GT'] = ppos_GT[ped_tag] 222 | peds[ped_tag]['occl'] = poccl[ped_tag] 223 | 224 | if fpos_pred: 225 | raise NotImplementedError('cannot parse pedestrian segm for now. 
Sorry!!')
227 |
228 | return peds
229 |
230 |
231 | def prepare_data():
232 | # object GT directory
233 | obj_root = '/sailhome/ajarno/STR-PIP/datasets/JAAD_instance_segm'
234 | fobj_dir_format = os.path.join(obj_root, 'video_{:04d}')
235 | # pedestrian GT files
236 | #ped_root = '/sailhome/ajarno/STR-PIP/datasets/JAAD_dataset/'
237 | ped_root = '/vision/u/caozj1995/data/JAAD_dataset/'
238 | fxml_format = os.path.join(ped_root, 'behavioral_data_xml', 'video_{:04d}.xml')
239 | fpose_format = os.path.join('/vision2/u/mangalam/JAAD/openpose_track_with_pose/', 'video_{:04d}.npy')
240 | fpos_GT_format = os.path.join(ped_root, 'bounding_box_python', 'vbb_part', 'video_{:04d}.npy')
241 |
242 | def prepare_split(vid_range):
243 |
244 | all_peds = []
245 | for vid in vid_range:
246 | # print(vid)
247 |
248 | if True:
249 | # objects
250 | fobj_dir = fobj_dir_format.format(vid)
251 | fsegms = sorted(glob(os.path.join(fobj_dir, '*_segm.npy')))
252 | # pdb.set_trace()  # fix: debugging leftover commented out; when enabled it halts prepare_data on every video
253 |
254 | frame_objs = []
255 | for fid,fsegm in enumerate(fsegms):
256 | print('fid:', fid)
257 | frame_objs += parse_objs(fsegm),
258 |
259 | print('Finished objects')
260 |
261 | if True:
262 | # pedestrians
263 | fxml = fxml_format.format(vid)
264 | fpose = fpose_format.format(vid)
265 | fpose = ''
266 | fpos_GT = fpos_GT_format.format(vid)
267 | ped_label = parse_pedestrian(fxml, fpos_GT, fpose=fpose)
268 | if len(ped_label) == 0:
269 | print('No pedestrian: vid:', vid)
270 |
271 | for ped in ped_label.values():
272 | ped['vid'] = vid
273 | all_peds += ped,
274 |
275 | return all_peds
276 |
277 |
278 | annot_root = '/sailhome/ajarno/STR-PIP/datasets/'
279 | # Train
280 | print('Processing training data...')
281 | train_range = range(1, 250+1)
282 | annot_train = prepare_split(train_range)
283 | print('# Train:', len(annot_train))
284 | with open(os.path.join(annot_root, 'annot_train_ped_withTag_sanityNoPose.pkl'), 'wb') as handle:
285 | pickle.dump(annot_train, handle)
286 | print('Training data ready.\n')
287 |
288 | # Test
289 | print('Processing testing data...')
290 | test_range = range(251, 346+1)
291 | annot_test = prepare_split(test_range)
292 | print('# Test:', len(annot_test))
293 | with open(os.path.join(annot_root, 'annot_test_ped_withTag_sanityNoPose.pkl'), 'wb') as handle:
294 | pickle.dump(annot_test, handle)
295 | print('Testing data ready.\n')
296 |
297 |
298 |
299 | if __name__ == '__main__':
300 | if False:
301 | # test
302 | fsegm = '/sailhome/bingbin/STR-PIP/datasets/JAAD_instance_segm/video_0131/00000001_segm.npy'
303 | parse_objs(fsegm)
304 | # fxml = '/sailhome/bingbin/STR-PIP/datasets/JAAD_dataset/behavioral_data_xml/video_0001.xml'
305 | # fpos = '/vision2/u/caozj/datasets/JAAD_dataset/bounding_box_python/vbb_part/video_0001.npy'
306 | # parse_pedestrian(fxml, fpos)
307 |
308 | if True:
309 | prepare_data()
310 |
311 | if False:
312 | stip_segm2box_wrapper()  # NOTE: not defined in this module (dead branch)
--------------------------------------------------------------------------------
/models/model_concat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import random
7 |
8 | from models.backbone.resnet_based import resnet_backbone
9 |
10 | from models.model_base import BaseModel
11 |
12 | import pdb
13 |
14 |
15 | class ConcatModel(BaseModel):
16 | def __init__(self, opt):
17 | super().__init__(opt)
18 | if self.use_gru:
19 | self.gru = nn.GRU(self.cls_dim, self.cls_dim, 2, batch_first=True).to(self.device)
20 | 
elif self.use_trn: 21 | if self.predict and self.pred_seq_len != 1: 22 | raise ValueError("self.pred_seq_len has to be 1 when using TRN.") 23 | self.k = 3 # number of samples per branch 24 | self.f2 = nn.Sequential( 25 | nn.Linear(2*self.cls_dim, 256), 26 | nn.ReLU(), 27 | nn.Linear(256, 256)) 28 | self.f3 = nn.Sequential( 29 | nn.Linear(3*self.cls_dim, 256), 30 | nn.ReLU(), 31 | nn.Linear(256, 256)) 32 | self.f4 = nn.Sequential( 33 | nn.Linear(4*self.cls_dim, 256), 34 | nn.ReLU(), 35 | nn.Linear(256, 256)) 36 | self.h2 = nn.Linear(256*self.k, self.n_acts) 37 | self.h3 = nn.Linear(256*self.k, self.n_acts) 38 | self.h4 = nn.Linear(256*self.k, self.n_acts) 39 | 40 | if self.use_ped_gru: 41 | self.ped_gru = nn.GRU(self.ped_dim, self.ped_dim, 2, batch_first=True).to(self.device) 42 | if self.use_ctxt_gru: 43 | self.ctxt_gru = nn.GRU(self.conv_dim, self.conv_dim, 2, batch_first=True).to(self.device) 44 | 45 | if self.use_act and not self.use_gt_act: 46 | self.act_gru = nn.GRU(self.cls_dim-self.n_acts, self.cls_dim, 2, batch_first=True).to(self.device) 47 | 48 | 49 | if self.predict: 50 | # NOTE: this GRU only takes in seq of len 1, i.e. one time step 51 | # GRU is used over GRUCell for multilayer 52 | self.gru_pred = nn.GRU(self.cls_dim, self.cls_dim, 2, batch_first=True).to(self.device) 53 | 54 | if self.predict_k: 55 | self.fc_pred = nn.Sequential( 56 | nn.Linear(self.cls_dim, 256), 57 | nn.ReLU(), 58 | nn.Linear(256, self.cls_dim) 59 | ) 60 | 61 | 62 | def forward(self, ped_crops, masks, bbox=None, act=None, pose=None, obj_bbox=None, obj_cls=None): 63 | B, T, _,_,_ = ped_crops.shape 64 | 65 | if self.branch in ['both', 'ped']: 66 | # ped_crops: (bt, 30, 3, 224, 224) 67 | ped_crops = ped_crops.type(self.dtype).view(-1, 3, 224, 224) 68 | ped_crops = ped_crops.to(self.device) 69 | ped_feats = self.ped_encoder(ped_crops) 70 | # ped_feats: (B, T, d) 71 | ped_feats = ped_feats.view(B, T, -1) 72 | if self.use_ped_gru: 73 | ped_feats, _ = self.ped_gru(ped_feats) 74 | 75 | if self.branch in ['both', 'ctxt']: 76 | # masks: list: (bt, 30, {cls: (n, 3, 224, 224)}) 77 | ctxt_feats = [[] for _ in range(B)] 78 | for b in range(B): 79 | for t in range(T): 80 | if len(masks[b][t]) == 0: 81 | ctxt_feat = torch.zeros([self.conv_dim]) 82 | ctxt_feat = ctxt_feat.to(self.device) 83 | else: 84 | if type(masks[b][t]) is dict: 85 | vals = list(masks[b][t].values()) 86 | # ctxt_masks: (n_total, (3,) 224, 224) 87 | ctxt_masks = torch.cat(vals, 0) 88 | else: 89 | ctxt_masks = masks[b][t] 90 | # TODO: this seems to be a bug though it didn't complain before. 91 | # Check whether this will affect the prev model. 
92 | # ctxt_masks = ctxt_masks.sum(0, True)
93 | if ctxt_masks.dim() == 3:
94 | n_total, h, w = ctxt_masks.shape
95 | ctxt_masks = ctxt_masks.unsqueeze(1).expand([n_total, 3, h, w])
96 | elif ctxt_masks.dim() == 2:
97 | h, w = ctxt_masks.shape
98 | ctxt_masks = ctxt_masks.unsqueeze(0).unsqueeze(0).expand([1, 3, h, w])
99 | ctxt_masks = ctxt_masks.to(self.device)
100 | # ctxt_feats: (n_total, d)
101 | # print('ctxt_masks', ctxt_masks.shape)
102 | ctxt_feat = self.ctxt_encoder(ctxt_masks.type(self.dtype))
103 | # average pool
104 | ctxt_feat = ctxt_feat.mean(0).squeeze(-1).squeeze(-1)
105 | ctxt_feats[b] += ctxt_feat,
106 | ctxt_feats[b] = torch.stack(ctxt_feats[b])
107 | ctxt_feats = torch.stack(ctxt_feats)
108 |
109 |
110 | # NOTE: leftover feature-dump path, disabled here: `idx` is undefined in
 | # forward() and `pickle` is never imported in this module, so entering this
 | # branch raised NameError. Feature dumping is handled by extract_feats()
 | # below (called from test.py), with rand_test = 0 to cover all frames.
 | # if os.path.exists(self.extract_feats_dir):
 | #   feat_path = os.path.join(self.extract_feats_dir, 'ped{}.pkl'.format(idx))
 | #   with open(feat_path, 'wb') as handle:
 | #     feats = {
 | #       'ped_feats': ped_feats,
 | #       'ctxt_feats': ctxt_feats,
 | #     }
 | #     pickle.dump(feats, handle)
119 |
120 | if self.pos_mode != 'none':
121 | ped_feats = self.append_pos(ped_feats, bbox[:, :self.seq_len])
122 |
123 | if self.use_signal:
124 | # act 2: handwave / act 3: looking
125 | signal = act[:, :, 2:4].to(self.device)
126 | ped_feats = torch.cat([ped_feats, signal], -1)
127 |
128 | if self.branch == 'both':
129 | frame_feats = torch.cat([ped_feats, ctxt_feats], -1)
130 | elif self.branch == 'ped':
131 | frame_feats = ped_feats
132 | elif self.branch == 'ctxt':
133 | frame_feats = ctxt_feats
134 | else:
135 | raise ValueError("self.branch should be 'both', 'ped', or 'ctxt'. Got {}".format(self.branch))
136 |
137 | if self.use_act:
138 | if self.use_gt_act:
139 | if self.n_acts == 1:
140 | act = act[:, :, 1:2]
141 | act = act[:, :self.seq_len]
142 | else:
143 | # use predicted action labels
144 | temporal_feats, _ = self.act_gru(frame_feats)
145 | h = self.classifier(temporal_feats)
146 | act_logits = self.sigmoid(h)
147 | act = (act_logits > 0.5).type(self.dtype)
148 | frame_feats = torch.cat([frame_feats, act], -1)
149 |
150 | if self.use_pose:
151 | normed_pose = self.util_norm_pose(pose)
152 | normed_pose = normed_pose.contiguous().view(B, T, -1)
153 | frame_feats = torch.cat([frame_feats, normed_pose], -1)
154 |
155 | if self.use_gru:
156 | # self.gru keeps the dimension of frame_feats
157 | frame_feats, h = self.gru(frame_feats)
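 | # Worked example of the TRN branch below (illustrative numbers): with
 | # seq_len=30 and k=3, each of the k rounds samples frame indices
 | #   2-frame: id1 in [0,14], id2 in [15,29]
 | #   3-frame: id1 in [0,9],  id2 in [10,19], id3 in [20,29]
 | #   4-frame: id1 in [0,6],  id2 in [7,13],  id3 in [14,20], id4 in [21,29]
 | # The k features per scale are concatenated, mapped through the per-scale
 | # heads h2/h3/h4, and the three logits are summed.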
158 | elif self.use_trn:
159 | # Note: for predicting the next frame only (i.e. Lopez's setting)
160 | feats2 = []
161 | feats3 = []
162 | feats4 = []
163 | for _ in range(self.k):
164 | # 2-frame relations
165 | l2 = self.seq_len // 2
166 | id1 = random.randint(0, l2-1)
167 | id2 = random.randint(0, l2-1) + l2
168 | feat2 = torch.cat([frame_feats[:, id1], frame_feats[:, id2]], -1)
169 | feats2 += self.f2(feat2),
170 | # 3-frame relations
171 | l3 = self.seq_len // 3
172 | id1 = random.randint(0, l3-1)
173 | id2 = random.randint(l3, 2*l3-1)
174 | id3 = random.randint(2*l3, self.seq_len-1)
175 | feat3 = torch.cat([frame_feats[:, id1], frame_feats[:, id2], frame_feats[:, id3]], -1)
176 | feats3 += self.f3(feat3),
177 | # 4-frame relations
178 | l4 = self.seq_len // 4
179 | id1 = random.randint(0, l4-1)
180 | id2 = random.randint(l4, 2*l4-1)
181 | id3 = random.randint(2*l4, 3*l4-1)
182 | id4 = random.randint(3*l4, self.seq_len-1)
183 | feat4 = torch.cat([frame_feats[:, id1], frame_feats[:, id2], frame_feats[:, id3], frame_feats[:, id4]], -1)
184 | feats4 += self.f4(feat4),
185 | t2 = self.h2(torch.cat(feats2, -1))
186 | t3 = self.h3(torch.cat(feats3, -1))  # fix: was self.h2 — each scale has its own head (see __init__)
187 | t4 = self.h4(torch.cat(feats4, -1))  # fix: was self.h2
188 | logits = t2 + t3 + t4
189 | logits = logits.view(B, 1, self.n_acts)
190 | return logits
191 |
192 | if self.predict:
193 | # o: (B, 1, cls_dim)
194 | o = frame_feats[:, -1:]
195 | pred_outs = []
196 | for pred_t in range(self.pred_seq_len):
197 | o, h = self.gru_pred(o, h)
198 | # if self.pos_mode != 'none':
199 | # o = self.append_pos(o, bbox[:, self.seq_len+pred_t:self.seq_len+pred_t+1])
200 | pred_outs += o,
201 |
202 | # pred_outs: (B, pred_seq_len, cls_dim)
203 | pred_outs = torch.cat(pred_outs, 1)
204 | # frame_feats: (B, seq_len + pred_seq_len, cls_dim)
205 | frame_feats = torch.cat([frame_feats, pred_outs], 1)
206 |
207 | if self.predict_k:
208 | # h: (n_gru_layers, B, cls_dim) --> (B, 1, cls_dim)
209 | h = h.transpose(0,1)[:, -1:]
210 | # pred_feats: (B, T, cls_dim)
211 | pred_feats = self.fc_pred(h)
212 | frame_feats = torch.cat([frame_feats, pred_feats], 1)
213 |
214 | # shape: (B, T, 2)
215 | logits = self.classifier(frame_feats)
216 | if self.use_act and not self.use_gt_act:
217 | return logits, act_logits
218 | # logits[:, :self.seq_len] = act_logits
219 |
220 | return logits
221 |
222 |
223 | def extract_feats(self, ped_crops, masks, idx):
224 | """
225 | ped_crops: (1, 3, 224, 224)
226 | masks: list of len 1; each item being a dict w/ key 1-4
227 | """
228 | assert(ped_crops.dim()==4), print('extract_feats does not support batch mode. 
Data is obtained directly from __getitem__.') 229 | assert(len(masks) == 1) 230 | T = ped_crops.shape[0] 231 | 232 | # step = 30 233 | 234 | # ped_crops: (bt, 30, 3, 224, 224) 235 | ped_crops = ped_crops.view(-1, 3, 224, 224) 236 | 237 | # ped_feats: shape: (T, D (e.g.512), 1, 1) 238 | # ped_feats = self.ped_encoder(ped_crops.to(self.device)).cpu() 239 | ped_feats = self.ped_encoder(ped_crops.type(self.dtype)).cpu() 240 | ped_feats = ped_feats.view(T, -1) 241 | 242 | ctxt_feats = [] 243 | ctxt_cls = [] 244 | for t,mask in enumerate(masks): 245 | if len(mask) == 0: 246 | ctxt_feat = torch.zeros([1, self.conv_dim]) 247 | ctxt_cls += [0] 248 | else: 249 | if type(mask) is dict: 250 | # grouped by classes 251 | ctxt_masks = [] 252 | for k in sorted(mask): 253 | for each in mask[k]: 254 | ctxt_cls += k, 255 | ctxt_masks += each, 256 | ctxt_masks = torch.stack(ctxt_masks, 0) 257 | else: 258 | # class collapsed 259 | ctxt_masks = mask 260 | ctxt_cls += [0] * ctxt_masks.shape[0] 261 | 262 | if ctxt_masks.dim() == 3: 263 | n_total, h, w = ctxt_masks.shape 264 | ctxt_masks = ctxt_masks.unsqueeze(1).expand([n_total, 3, h, w]) 265 | ctxt_masks = ctxt_masks.to(self.device) 266 | 267 | # ctxt_feats: (n_total, d) 268 | ctxt_feat = self.ctxt_encoder(ctxt_masks.type(self.dtype)).cpu() 269 | ctxt_feat = ctxt_feat.squeeze(-1).squeeze(-1) 270 | ctxt_feats += ctxt_feat, 271 | ctxt_feats = torch.cat(ctxt_feats, 0) 272 | assert(len(ctxt_feats) == len(ctxt_cls)) 273 | 274 | # NOTE: set self.rand_test = 0 when extracting features since we want to cover all frames 275 | return ped_feats, ctxt_feats, ctxt_cls 276 | 277 | --------------------------------------------------------------------------------