├── models
│   ├── __init__.py
│   ├── get_model.py
│   ├── baseline_anticipate_cnn.py
│   ├── baseline_pose.py
│   ├── model_pos.py
│   ├── backbone
│   │   ├── resnet_based.py
│   │   ├── imagenet_pretraining.py
│   │   └── resnet
│   │       ├── bottleneck.py
│   │       ├── basicblock.py
│   │       └── resnet.py
│   └── model_concat.py
├── data
│   ├── __init__.py
│   ├── get_data_loader.py
│   └── jaad_loc.py
├── .gitignore
├── utils
│   ├── downsample.py
│   ├── __init__.py
│   ├── parse_box_from_segm.py
│   ├── check_stip_time.py
│   ├── check_ckpt.py
│   ├── build.py
│   ├── logger.py
│   ├── check_switch.py
│   ├── data_proc_loc.py
│   ├── visual.py
│   ├── stats.py
│   ├── temp_plot.py
│   ├── temp_plot_stip.py
│   ├── colormap.py
│   ├── draw_As.py
│   ├── tracker.py
│   ├── masking.py
│   ├── postproc_smooth.py
│   ├── stip_merge_views.py
│   └── data_proc.py
├── args
│   ├── __init__.py
│   ├── train_args.py
│   └── test_args.py
├── graphs
│   ├── acc_temp.png
│   ├── acc_across_time.png
│   ├── acc_temp_graph.png
│   ├── mix_temp_graph.png
│   ├── prob_temp_graph.png
│   ├── acc_temp_graph_smooth.png
│   ├── mix_temp_graph_smooth.png
│   └── prob_temp_graph_smooth.png
├── scripts_dir
│   ├── train_pos.sh
│   ├── train_baseline_anticipate_cnn.sh
│   ├── train_baseline_pose.sh
│   ├── wandb
│   │   └── offline-run-20210314_201226-2eciwruc
│   │       └── logs
│   │           └── .nfs000000000001e91000000039
│   ├── train_concat.sh
│   ├── test_baseline_anticipate_cnn.sh
│   ├── train_loc_concat.sh
│   ├── test_loc_concat.sh
│   ├── train_concat_stip.sh
│   ├── train_graph_fromMasks.sh
│   ├── test_concat_stip.sh
│   ├── train_graph.sh
│   ├── test_concat_stip_side.sh
│   ├── test_concat.sh
│   ├── train_loc_graph.sh
│   ├── train_graph_stip_all.sh
│   ├── train_graph_stip.sh
│   ├── test_graph_stip_all.sh
│   ├── test_graph_stip.sh
│   └── test_graph.sh
├── cache_data.py
├── README.md
├── train.py
├── cache_data_stip.py
└── test.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .get_model import *
2 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .get_data_loader import *
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bin/
2 | datasets/
3 | *.pyc
4 | __pycache__/
5 |
--------------------------------------------------------------------------------
/utils/downsample.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 |
5 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build
2 | from .data_proc import *
3 |
4 |
--------------------------------------------------------------------------------
/args/__init__.py:
--------------------------------------------------------------------------------
1 | from .train_args import TrainArgs
2 | from .test_args import TestArgs
3 |
--------------------------------------------------------------------------------
/graphs/acc_temp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/acc_temp.png
--------------------------------------------------------------------------------
/utils/parse_box_from_segm.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 |
5 |
--------------------------------------------------------------------------------
/graphs/acc_across_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/acc_across_time.png
--------------------------------------------------------------------------------
/graphs/acc_temp_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/acc_temp_graph.png
--------------------------------------------------------------------------------
/graphs/mix_temp_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/mix_temp_graph.png
--------------------------------------------------------------------------------
/graphs/prob_temp_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/prob_temp_graph.png
--------------------------------------------------------------------------------
/graphs/acc_temp_graph_smooth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/acc_temp_graph_smooth.png
--------------------------------------------------------------------------------
/graphs/mix_temp_graph_smooth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/mix_temp_graph_smooth.png
--------------------------------------------------------------------------------
/graphs/prob_temp_graph_smooth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StanfordVL/STR-PIP/HEAD/graphs/prob_temp_graph_smooth.png
--------------------------------------------------------------------------------
/utils/check_stip_time.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 |
4 | stip_root = '/vision/group/prolix'
5 | ped_root = os.path.join(stip_root, 'processed/pedestrians')
6 | img_root = os.path.join(stip_root, 'images_20fps')
7 |
--------------------------------------------------------------------------------
/utils/check_ckpt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | ckpt_root = '/sailhome/bingbin/STR-PIP/ckpts/'
5 | dataset = 'JAAD'
6 |
7 | def check(ckpt, ptype='pred'):
8 | f_best_pred = os.path.join(ckpt_root, dataset, ckpt, 'best_{}.pth'.format(ptype))
9 | pred = torch.load(f_best_pred)
10 | print('best {} epoch:'.format(ptype), pred['epoch'])
11 |
12 | if __name__ == '__main__':
13 | ckpt = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch'
14 | check(ckpt, 'pred')
15 | check(ckpt, 'last')
16 | check(ckpt, 'det')
17 |
--------------------------------------------------------------------------------
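A hedged variant of check() for CPU-only hosts; map_location is a standard
torch.load argument, and the checkpoint layout is taken from the script above:

    import os
    import torch

    def check_cpu(ckpt_dir, ptype='pred'):
        # map_location='cpu' avoids CUDA deserialization errors without a GPU
        f_best = os.path.join(ckpt_dir, 'best_{}.pth'.format(ptype))
        ckpt = torch.load(f_best, map_location='cpu')
        print('best {} epoch:'.format(ptype), ckpt['epoch'])

--------------------------------------------------------------------------------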
/utils/build.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import torch
5 |
6 | import args
7 | from .logger import Logger
8 |
9 |
10 | def build(is_train):
11 | opt, log = args.TrainArgs().parse() if is_train else args.TestArgs().parse()
12 |
13 | if not is_train:
14 | print('Options:')
15 | opt_dict = vars(opt)
16 | for key in sorted(opt_dict):
17 | print('{}: {}'.format(key, opt_dict[key]))
18 | if is_train:
19 | print('lr_init:', opt.lr_init)
20 | print('wd:', opt.wd)
21 | print('ckpt:', opt.ckpt_path)
22 | print()
23 |
24 | os.makedirs(opt.ckpt_path, exist_ok=True)
25 |
26 | # Set seed
27 | torch.manual_seed(2019)
28 | torch.cuda.manual_seed_all(2019)
29 | np.random.seed(2019)
30 | random.seed(2019)
31 |
32 | logger = Logger(opt.ckpt_path, opt.split)
33 |
34 | return opt, logger
35 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 |
5 | def blue(string):
6 | return '\033[94m'+string+'\033[0m'
7 |
8 | class Logger:
9 | def __init__(self, ckpt_path, name='debug'):
10 | self.logger = logging.getLogger()
11 | self.logger.setLevel(logging.INFO)
12 | formatter = logging.Formatter('%(asctime)s %(message)s', datefmt=blue('[%Y-%m-%d,%H:%M:%S]'))
13 |
14 | fh = logging.FileHandler(os.path.join(ckpt_path, '{}.log'.format(name)), 'w')
15 | fh.setLevel(logging.INFO)
16 | fh.setFormatter(formatter)
17 | self.logger.addHandler(fh)
18 |
19 | ch = logging.StreamHandler(sys.stdout)
20 | ch.setLevel(logging.INFO)
21 | ch.setFormatter(formatter)
22 | self.logger.addHandler(ch)
23 |
24 | def print(self, log):
25 | if isinstance(log, list):
26 | self.logger.info('\n - '.join(log))
27 | else:
28 | self.logger.info(log)
29 |
--------------------------------------------------------------------------------
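A minimal usage sketch for Logger, assuming the repo root is on PYTHONPATH and
the target directory is writable ('demo' is a hypothetical run name):

    from utils.logger import Logger

    logger = Logger('.', name='demo')       # writes ./demo.log and echoes to stdout
    logger.print('single message')
    logger.print(['epoch 1', 'loss 0.52'])  # list items are joined with '\n - '

--------------------------------------------------------------------------------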
/models/get_model.py:
--------------------------------------------------------------------------------
1 | # ped-centric
2 | from .model_concat import ConcatModel
3 | from .model_graph import GraphModel
4 | from .model_pos import PosModel
5 | # loc-centric
6 | from .model_loc_graph import LocGraphModel
7 | from .model_loc_concat import LocConcatModel
8 | # baselines
9 | from .baseline_anticipate_cnn import BaselineAnticipateCNN
10 | from .baseline_pose import BaselinePose
11 |
12 |
13 | def get_model(opt):
14 | # ped-centric
15 | if opt.model == 'concat':
16 | model = ConcatModel(opt)
17 | elif opt.model == 'graph':
18 | model = GraphModel(opt)
19 | elif opt.model == 'pos':
20 | model = PosModel(opt)
21 | # loc-centric
22 | elif opt.model == 'loc_concat':
23 | model = LocConcatModel(opt)
24 | elif opt.model == 'loc_graph':
25 | model = LocGraphModel(opt)
26 | # baselines
27 | elif opt.model == 'baseline_anticipate_cnn':
28 | model = BaselineAnticipateCNN(opt)
29 | elif opt.model == 'baseline_pose':
30 | model = BaselinePose(opt)
31 | else:
32 | raise NotImplementedError
33 |
34 | model.setup()
35 | return model
36 |
--------------------------------------------------------------------------------
/args/train_args.py:
--------------------------------------------------------------------------------
1 | from .base_args import BaseArgs
2 |
3 |
4 | class TrainArgs(BaseArgs):
5 | def __init__(self):
6 | super(TrainArgs, self).__init__()
7 |
8 | self.is_train = True
9 | # self.split = 'train'
10 |
11 | self.parser.add_argument('--batch-size', type=int, default=4, help='batch size per gpu')
12 | self.parser.add_argument('--n-epochs', type=int, default=50, help='total # of epochs')
13 | self.parser.add_argument('--n-iters', type=int, default=0, help='total # of iterations')
14 | self.parser.add_argument('--start-epoch', type=int, default=0, help='starting epoch')
15 | self.parser.add_argument('--lr-init', type=float, default=1e-3, help='initial learning rate')
16 | self.parser.add_argument('--lr-decay', type=int, default=0, choices=[0, 1], help='whether to decay learning rate')
17 | self.parser.add_argument('--decay-every', type=int, default=10)
18 | self.parser.add_argument('--wd', type=float, default=1e-5)
19 | self.parser.add_argument('--load-ckpt-dir', type=str, default='', help='directory of checkpoint')
20 | self.parser.add_argument('--load-ckpt-epoch', type=int, default=0, help='epoch to load checkpoint')
21 |
--------------------------------------------------------------------------------
/args/test_args.py:
--------------------------------------------------------------------------------
1 | from .base_args import BaseArgs
2 |
3 |
4 | class TestArgs(BaseArgs):
5 | def __init__(self):
6 | super(TestArgs, self).__init__()
7 |
8 | self.is_train = False
9 | # self.split = 'test'
10 |
11 | self.parser.add_argument('--mode', type=str, choices=['extract', 'evaluate'])
12 |
13 | # hyperparameters
14 | self.parser.add_argument('--batch-size', type=int, default=4, help='batch size')
15 | self.parser.add_argument('--which-epoch', type=int,
16 |                              help='which epoch to evaluate; -1 to load the best checkpoint')
17 | self.parser.add_argument('--slide', type=int, default=1,
18 | help='Whether to use sliding window when testing.')
19 | self.parser.add_argument('--collect-A', type=int, default=0,
20 | help="Whether to collect weight matrices in the graph model.")
21 | self.parser.add_argument('--save-As-format', type=str, default="",
22 |                              help="Path format for saving the collected weight matrices.")
23 | # self.parser.add_argument('--rand-loader', type=int, default=0,
24 | # help="Whether to randomize the data loader.")
25 |
26 |
--------------------------------------------------------------------------------
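BaseArgs itself is not included in this dump; as a standalone sketch, the Args
classes build on the usual argparse pattern (hypothetical defaults, for
illustration only):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=4, help='batch size')
    parser.add_argument('--which-epoch', type=int, default=-1,
                        help='which epoch to evaluate; -1 loads the best checkpoint')

    opt = parser.parse_args(['--batch-size', '16'])
    print(opt.batch_size, opt.which_epoch)  # 16 -1

--------------------------------------------------------------------------------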
/utils/check_switch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 |
5 | def find_segments(labels):
6 | durs = []
7 | for row in labels:
8 | s_start = 0
9 | for j,val in enumerate(row):
10 | if j and row[j-1] != val:
11 | durs += j-s_start,
12 | s_start = j
13 | durs += len(row) - s_start,
14 | print('# segments:', len(durs))
15 | print('# avg frames:', np.mean(durs))
16 |
17 | def find_segments_wrapper():
18 | ckpt_dir = '/sailhome/bingbin/STR-PIP/ckpts/JAAD/'
19 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-05_wd1.0e-05_bt16_posNone_branchped_collapse0_combinepair_adjTypeembed_nLayers2_v2Feats'
20 | label_name = 'label_epochbest_det_stepall.pkl'
21 | fpkl = os.path.join(ckpt_dir, ckpt_name, label_name)
22 |
23 | with open(fpkl, 'rb') as handle:
24 | data = pickle.load(handle)
25 | out = data['out']
26 | gt = data['GT']
27 |
28 | print('Out:')
29 | find_segments(out)
30 | print('GT:')
31 | find_segments(gt)
32 |
33 | print('\nOut det:')
34 | find_segments(out[:, :30])
35 | print('\nOut pred:')
36 | find_segments(out[:, 30:])
37 | print('\nGT det:')
38 | find_segments(gt[:, :30])
39 | print('\nGT pred:')
40 | find_segments(gt[:, 30:])
41 |
42 | if __name__ == '__main__':
43 | find_segments_wrapper()
44 |
--------------------------------------------------------------------------------
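find_segments counts maximal runs of identical labels in each row, then reports
the number of runs and their mean length. An equivalent vectorized check on a
toy row (numpy only, illustrative):

    import numpy as np

    row = np.array([0, 0, 1, 1, 1, 0])
    # indices where the label changes, plus both ends
    idx = np.flatnonzero(np.diff(row)) + 1
    durs = np.diff(np.concatenate(([0], idx, [len(row)])))
    print(len(durs), durs.mean())  # 3 segments, 2.0 frames on average

--------------------------------------------------------------------------------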
/scripts_dir/train_pos.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | split='train'
5 | n_acts=9
6 |
7 | lr=1e-4
8 | wd=1e-5
9 | bt=16
10 | lr_decay=1
11 | decay_every=20
12 | n_workers=0
13 | save_every=10
14 | evaluate_every=1
15 |
16 | seq_len=30
17 | predict=1
18 | pred_seq_len=30
19 |
20 | load_cache='pos'
21 | # masks
22 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
23 |
24 | use_gru=1
25 | branch='both'
26 | pos_mode='none'
27 | use_gt_act=1
28 | collapse_cls=0
29 | combine_method='pair'
30 | suffix='_'$pos_mode'_9acts_withGTAct_noPos'
31 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
32 |
33 | CUDA_VISIBLE_DEVICES=$gpu_id python3 -m train.py \
34 | --model='pos' \
35 | --split=$split \
36 | --n-acts=$n_acts \
37 | --device=$gpu_id \
38 | --dset-name='JAAD' \
39 | --ckpt-name=$ckpt_name \
40 | --lr-init=$lr \
41 | --wd=$wd \
42 | --lr-decay=$lr_decay \
43 | --decay-every=$decay_every \
44 | --seq-len=$seq_len \
45 | --predict=$predict \
46 | --pred-seq-len=$pred_seq_len \
47 | --batch-size=$bt \
48 | --save-every=$save_every \
49 | --evaluate-every=$evaluate_every \
50 | --n-workers=$n_workers \
51 | --load-cache=$load_cache \
52 | --cache-format=$cache_format \
53 | --branch=$branch \
54 | --pos-mode=$pos_mode \
55 | --use-gt-act=$use_gt_act \
56 | --collapse-cls=$collapse_cls \
57 | --combine-method=$combine_method \
58 | --use-gru=$use_gru \
59 |
--------------------------------------------------------------------------------
/scripts_dir/train_baseline_anticipate_cnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=1
4 | split='train'
5 | n_acts=9
6 |
7 | lr=1e-4
8 | wd=1e-5
9 | bt=4
10 | lr_decay=1
11 | decay_every=20
12 | n_workers=0
13 | save_every=10
14 | evaluate_every=1
15 |
16 | seq_len=30
17 | predict=1
18 | pred_seq_len=30
19 |
20 | load_cache='masks'
21 | # masks
22 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
23 |
24 | use_gru=1
25 | use_gt_act=0
26 | branch='ped'
27 | pos_mode='none'
28 | collapse_cls=0
29 | combine_method='pair'
30 | suffix='_noGTAct_addDetLoss'
31 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
32 |
33 | CUDA_VISIBLE_DEVICES=$gpu_id python3 -m train.py \
34 | --model='baseline_anticipate_cnn' \
35 | --split=$split \
36 | --n-acts=$n_acts \
37 | --device=$gpu_id \
38 | --dset-name='JAAD' \
39 | --ckpt-name=$ckpt_name \
40 | --lr-init=$lr \
41 | --wd=$wd \
42 | --lr-decay=$lr_decay \
43 | --decay-every=$decay_every \
44 | --seq-len=$seq_len \
45 | --predict=$predict \
46 | --pred-seq-len=$pred_seq_len \
47 | --batch-size=$bt \
48 | --save-every=$save_every \
49 | --evaluate-every=$evaluate_every \
50 | --n-workers=$n_workers \
51 | --load-cache=$load_cache \
52 | --cache-format=$cache_format \
53 | --branch=$branch \
54 | --pos-mode=$pos_mode \
55 | --collapse-cls=$collapse_cls \
56 | --combine-method=$combine_method \
57 | --use-gru=$use_gru \
58 | --use-gt-act=$use_gt_act \
59 |
--------------------------------------------------------------------------------
/utils/data_proc_loc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 |
5 | H, W = 1080, 1920
6 | k = 0.25
7 | c = -20
8 | y_max = 200
9 |
10 | def parse_split(fannot_in, fannot_out):
11 | with open(fannot_in, 'rb') as handle:
12 | peds = pickle.load(handle)
13 |
14 | vids = {}
15 | for ped in peds:
16 | vid = ped['vid']
17 | if vid not in vids:
18 | vids[vid] = {
19 | 'act': np.zeros([len(ped['act']), 1]),
20 | 'ped_pos': [[] for _ in range(len(ped['act']))]
21 | }
22 | for fid, pos in enumerate(ped['pos_GT']):
23 | if len(pos) == 0:
24 | continue
25 | x, y, w, h = pos
26 | cx, cy = x+0.5*w, y+h
27 | # check if a pedestrian is in the trapezoid area
28 | if H-cy <= k*cx+c and H-cy <= y_max and H-cy <= k*(W-cx)+c:
29 | vids[vid]['act'][fid] = 1
30 | vids[vid]['ped_pos'][fid] += pos,
31 |
32 | with open(fannot_out, 'wb') as handle:
33 | pickle.dump(vids, handle)
34 |
35 |
36 | def parse_split_wrapper():
37 | annot_root = '/sailhome/ajarno/STR-PIP/datasets'
38 |
39 | # ftrain = 'annot_train_ped_withTag_sanityWithPose.pkl'
40 | # annot_train = os.path.join(annot_root, ftrain)
41 | # annot_train_out = os.path.join(annot_root, 'annot_train_loc.pkl')
42 | # parse_split(annot_train, annot_train_out)
43 |
44 | # ftest = 'annot_test_ped_withTag_sanityWithPose.pkl'
45 | ftest = 'annot_test_ped.pkl'
46 | annot_test = os.path.join(annot_root, ftest)
47 | annot_test_out = os.path.join(annot_root, 'annot_test_loc_new.pkl')
48 | parse_split(annot_test, annot_test_out)
49 |
50 |
51 | if __name__ == '__main__':
52 | parse_split_wrapper()
53 |
--------------------------------------------------------------------------------
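parse_split keeps a pedestrian when the bbox's bottom-center (cx, cy) falls in
a trapezoid at the bottom of the frame: its height above the bottom edge,
H - cy, must stay under both slanted sides (slope k, offset c) and under y_max.
A standalone check with the module's constants (the bbox values are
hypothetical):

    H, W = 1080, 1920
    k, c, y_max = 0.25, -20, 200

    def in_trapezoid(x, y, w, h):
        cx, cy = x + 0.5 * w, y + h  # bottom-center of the bbox
        d = H - cy                   # height above the bottom edge
        return d <= k * cx + c and d <= y_max and d <= k * (W - cx) + c

    print(in_trapezoid(900, 800, 60, 150))  # near the bottom center -> True
    print(in_trapezoid(10, 100, 60, 150))   # far left and high up   -> False

--------------------------------------------------------------------------------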
/scripts_dir/train_baseline_pose.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | split='train'
5 | n_acts=1
6 |
7 | n_epochs=100
8 | lr=1e-4
9 | wd=1e-7
10 | bt=8
11 | lr_decay=1
12 | decay_every=20
13 | n_workers=0
14 | save_every=10
15 | evaluate_every=1
16 | fc_in_dim=256
17 |
18 | seq_len=14
19 | predict=1
20 | pred_seq_len=1
21 |
22 | load_cache='pos'
23 | # masks
24 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
25 | save_cache_format=$cache_format
26 |
27 | use_gru=1
28 | use_gt_act=0
29 | use_pose=1
30 | branch='ped'
31 | pos_mode='none'
32 | collapse_cls=0
33 | combine_method='pair'
34 | suffix='_gru'$fc_in_dim'_zeroPad'
35 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
36 |
37 | CUDA_VISIBLE_DEVICES=$gpu_id python3 -m train.py \
38 | --model='baseline_pose' \
39 | --split=$split \
40 | --n-acts=$n_acts \
41 | --device=$gpu_id \
42 | --dset-name='JAAD' \
43 | --ckpt-name=$ckpt_name \
44 | --n-epochs=$n_epochs \
45 | --lr-init=$lr \
46 | --wd=$wd \
47 | --lr-decay=$lr_decay \
48 | --decay-every=$decay_every \
49 | --seq-len=$seq_len \
50 | --predict=$predict \
51 | --pred-seq-len=$pred_seq_len \
52 | --batch-size=$bt \
53 | --save-every=$save_every \
54 | --evaluate-every=$evaluate_every \
55 | --fc-in-dim=$fc_in_dim \
56 | --n-workers=$n_workers \
57 | --load-cache=$load_cache \
58 | --cache-format=$cache_format \
59 | --save-cache-format=$save_cache_format \
60 | --branch=$branch \
61 | --pos-mode=$pos_mode \
62 | --collapse-cls=$collapse_cls \
63 | --combine-method=$combine_method \
64 | --use-gru=$use_gru \
65 | --use-gt-act=$use_gt_act \
66 | --use-pose=$use_pose \
67 |
--------------------------------------------------------------------------------
/scripts_dir/wandb/offline-run-20210314_201226-2eciwruc/logs/.nfs000000000001e91000000039:
--------------------------------------------------------------------------------
1 | 2021-03-14 20:12:27,342 INFO MainThread:3357636 [internal.py:wandb_internal():91] W&B internal server running at pid: 3357636, started at: 2021-03-14 20:12:27.341321
2 | 2021-03-14 20:12:27,344 DEBUG HandlerThread:3357636 [handler.py:handle_request():101] handle_request: run_start
3 | 2021-03-14 20:12:27,344 INFO WriterThread:3357636 [datastore.py:open_for_write():77] open: /sailhome/agalczak/crossing/scripts_dir/wandb/offline-run-20210314_201226-2eciwruc/run-2eciwruc.wandb
4 | 2021-03-14 20:12:27,351 DEBUG HandlerThread:3357636 [meta.py:__init__():34] meta init
5 | 2021-03-14 20:12:27,351 DEBUG HandlerThread:3357636 [meta.py:__init__():48] meta init done
6 | 2021-03-14 20:12:27,351 DEBUG HandlerThread:3357636 [meta.py:probe():190] probe
7 | 2021-03-14 20:12:27,369 DEBUG HandlerThread:3357636 [meta.py:_setup_git():180] setup git
8 | 2021-03-14 20:12:27,455 DEBUG HandlerThread:3357636 [meta.py:_setup_git():187] setup git done
9 | 2021-03-14 20:12:27,455 DEBUG HandlerThread:3357636 [meta.py:_save_pip():52] save pip
10 | 2021-03-14 20:12:27,460 DEBUG HandlerThread:3357636 [meta.py:_save_pip():66] save pip done
11 | 2021-03-14 20:12:27,461 DEBUG HandlerThread:3357636 [meta.py:probe():231] probe done
12 | 2021-03-14 20:22:28,964 WARNING MainThread:3357636 [internal.py:is_dead():344] Internal process exiting, parent pid 3357618 disappeared
13 | 2021-03-14 20:22:28,965 ERROR MainThread:3357636 [internal.py:wandb_internal():142] Internal process shutdown.
14 | 2021-03-14 20:22:29,293 INFO HandlerThread:3357636 [handler.py:finish():532] shutting down handler
15 | 2021-03-14 20:22:29,295 INFO WriterThread:3357636 [datastore.py:close():258] close: /sailhome/agalczak/crossing/scripts_dir/wandb/offline-run-20210314_201226-2eciwruc/run-2eciwruc.wandb
16 | 2021-03-14 20:22:29,439 INFO SenderThread:3357636 [sender.py:finish():814] shutting down sender
17 |
--------------------------------------------------------------------------------
/scripts_dir/train_concat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | gpu_id=0
6 | split='train'
7 | n_acts=9
8 |
9 | lr=1e-4
10 | wd=1e-5
11 | bt=4
12 | lr_decay=1
13 | decay_every=20
14 | n_workers=0
15 | save_every=10
16 | evaluate_every=1
17 |
18 | seq_len=30
19 | predict=1
20 | pred_seq_len=30
21 | predict_k=0
22 |
23 | annot_ped_format='/sailhome/ajarno/STR-PIP/datasets/annot_{}_ped_withTag_sanityWithPose.pkl'
24 |
25 | load_cache='masks'
26 | cache_format='/sailhome/ajarno/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
27 | save_cache_format=$cache_format
28 |
29 | use_gru=1
30 | use_trn=0
31 | ped_gru=1
32 | # pos_mode='center'
33 | pos_mode='none'
34 | use_act=0
35 | use_gt_act=0
36 | use_pose=0
37 | branch='both'
38 | collapse_cls=0
39 | combine_method='pair'
40 | # suffix='_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU'
41 | suffix='_test_tmp'
42 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
43 |
44 | # CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
45 | WANDB_MODE=dryrun CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
46 | --model='concat' \
47 | --split=$split \
48 | --n-acts=$n_acts \
49 | --device=$gpu_id \
50 | --dset-name='JAAD' \
51 | --ckpt-name=$ckpt_name \
52 | --lr-init=$lr \
53 | --wd=$wd \
54 | --lr-decay=$lr_decay \
55 | --decay-every=$decay_every \
56 | --seq-len=$seq_len \
57 | --predict=$predict \
58 | --pred-seq-len=$pred_seq_len \
59 | --predict-k=$predict_k \
60 | --batch-size=$bt \
61 | --save-every=$save_every \
62 | --evaluate-every=$evaluate_every \
63 | --n-workers=$n_workers \
64 | --annot-ped-format=$annot_ped_format \
65 | --load-cache=$load_cache \
66 | --cache-format=$cache_format \
67 | --save-cache-format=$save_cache_format \
68 | --branch=$branch \
69 | --collapse-cls=$collapse_cls \
70 | --combine-method=$combine_method \
71 | --use-gru=$use_gru \
72 | --use-trn=$use_trn \
73 | --ped-gru=$ped_gru \
74 | --pos-mode=$pos_mode \
75 | --use-act=$use_act \
76 | --use-gt-act=$use_gt_act \
77 | --use-pose=$use_pose \
78 |
--------------------------------------------------------------------------------
/scripts_dir/test_baseline_anticipate_cnn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mode='evaluate'
4 |
5 | gpu_id=1
6 | split='test'
7 | n_acts=9
8 | n_workers=0
9 | bt=1
10 |
11 | # test only
12 | slide=0
13 | rand_test=1
14 | log_every=10
15 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
16 | dataset='JAAD'
17 |
18 | ckpt_name="baseline_anticipate_cnn_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchped_collapse0_combinepair_noGTAct_addDetLoss"
19 |
20 | seq_len=30
21 | predict=1
22 | pred_seq_len=30
23 |
24 | load_cache='masks'
25 | # masks
26 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
27 |
28 | which_epoch=-1
29 | if [ $which_epoch -eq -1 ]
30 | then
31 | epoch_name='best_pred'
32 | else
33 | epoch_name=$which_epoch
34 | fi
35 | save_output=1
36 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl'
37 |
38 | if [ "$mode" = "extract" ]
39 | then
40 | extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/test/'
41 | else
42 | extract_feats_dir='none_existent'
43 | fi
44 |
45 |
46 | use_gru=1
47 | use_gt_act=0
48 | branch='ped'
49 | pos_mode='none'
50 | collapse_cls=0
51 | combine_method='pair'
52 |
53 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
54 | --model='baseline_anticipate_cnn' \
55 | --mode=$mode \
56 | --slide=$slide \
57 | --rand-test=$rand_test \
58 | --split=$split \
59 | --n-acts=$n_acts \
60 | --device=$gpu_id \
61 | --dset-name='JAAD' \
62 | --ckpt-dir=$ckpt_dir \
63 | --ckpt-name=$ckpt_name \
64 | --which-epoch=$which_epoch \
65 | --save-output=$save_output \
66 | --save-output-format=$save_output_format \
67 | --seq-len=$seq_len \
68 | --predict=$predict \
69 | --pred-seq-len=$pred_seq_len \
70 | --batch-size=$bt \
71 | --n-workers=$n_workers \
72 | --load-cache=$load_cache \
73 | --cache-format=$cache_format \
74 | --branch=$branch \
75 | --pos-mode=$pos_mode \
76 | --collapse-cls=$collapse_cls \
77 | --combine-method=$combine_method \
78 | --use-gru=$use_gru \
79 | --use-gt-act=$use_gt_act \
80 |
--------------------------------------------------------------------------------
/scripts_dir/train_loc_concat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | split='train'
5 | n_acts=1
6 |
7 | n_epochs=80
8 | start_epoch=2
9 | lr=1e-4
10 | wd=1e-5
11 | bt=1
12 | lr_decay=1
13 | decay_every=20
14 | n_workers=0
15 | save_every=10
16 | evaluate_every=1
17 |
18 | seq_len=30
19 | predict=1
20 | pred_seq_len=30
21 | predict_k=0
22 |
23 | annot_loc_format='/sailhome/ajarno/STR-PIP/datasets/annot_{}_loc.pkl'
24 |
25 | load_cache='masks'
26 | cache_format='/sailhome/ajarno/STR-PIP/datasets/cache/jaad_loc/{}/ped{}_fid{}.pkl'
27 | save_cache_format=$cache_format
28 |
29 | pretrained_path='/sailhome/ajarno/STR-PIP/ckpts/JAAD_loc/loc_concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt1_posNone_branchboth_collapse0_combinepair_tmp/best_pred.pth'
30 |
31 |
32 | use_gru=1
33 | use_trn=0
34 | ped_gru=1
35 | # pos_mode='center'
36 | pos_mode='none'
37 | use_act=0
38 | use_gt_act=0
39 | use_pose=0
40 | branch='both'
41 | collapse_cls=0
42 | combine_method='pair'
43 | suffix='_tmp'
44 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
45 |
46 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
47 | --model='loc_concat' \
48 | --split=$split \
49 | --n-acts=$n_acts \
50 | --device=$gpu_id \
51 | --dset-name='JAAD_loc' \
52 | --ckpt-name=$ckpt_name \
53 | --n-epochs=$n_epochs \
54 | --start-epoch=$start_epoch \
55 | --lr-init=$lr \
56 | --wd=$wd \
57 | --lr-decay=$lr_decay \
58 | --decay-every=$decay_every \
59 | --seq-len=$seq_len \
60 | --predict=$predict \
61 | --pred-seq-len=$pred_seq_len \
62 | --predict-k=$predict_k \
63 | --batch-size=$bt \
64 | --save-every=$save_every \
65 | --evaluate-every=$evaluate_every \
66 | --n-workers=$n_workers \
67 | --annot-loc-format=$annot_loc_format \
68 | --load-cache=$load_cache \
69 | --cache-format=$cache_format \
70 | --save-cache-format=$save_cache_format \
71 | --branch=$branch \
72 | --collapse-cls=$collapse_cls \
73 | --combine-method=$combine_method \
74 | --use-gru=$use_gru \
75 | --use-trn=$use_trn \
76 | --ped-gru=$ped_gru \
77 | --pos-mode=$pos_mode \
78 | --use-act=$use_act \
79 | --use-gt-act=$use_gt_act \
80 | --use-pose=$use_pose \
81 |
--------------------------------------------------------------------------------
/scripts_dir/test_loc_concat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # mode='evaluate'
4 | mode='extract'
5 |
6 | gpu_id=0
7 | n_workers=0
8 | n_acts=1
9 |
10 | seq_len=1
11 | predict=0
12 | pred_seq_len=30
13 |
14 | # test only
15 | slide=0
16 | rand_test=1
17 | log_every=10
18 |
19 |
20 | split='train'
21 | # split='test'
22 | use_gru=1
23 | use_trn=0
24 | pos_mode='none'
25 | use_act=0
26 | use_gt_act=0
27 | use_pose=0
28 | # branch='ped'
29 | branch='both'
30 | collapse_cls=0
31 | combine_method='pair'
32 |
33 | annot_loc_format='/sailhome/bingbin/STR-PIP/datasets/annot_{}_loc.pkl'
34 |
35 | load_cache='masks'
36 | # load_cache='none'
37 |
38 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_loc/{}/ped{}_fid{}.pkl'
39 | save_cache_format=$cache_format
40 |
41 | ckpt_name='loc_concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt1_posNone_branchboth_collapse0_combinepair_tmp'
42 |
43 | # -1 for the best epoch
44 | which_epoch=-1
45 |
46 | # set a non-existent epoch so that features are extracted from the ImageNet-pretrained backbone
47 | # which_epoch=100
48 |
49 | if [ "$mode" = "extract" ]
50 | then
51 | extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_loc/JAAD_conv_feats/'$ckpt_name'/'$split'/'
52 | else
53 | extract_feats_dir='none_existent'
54 | fi
55 |
56 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
57 | --model='loc_concat' \
58 | --split=$split \
59 | --n-acts=$n_acts \
60 | --mode=$mode \
61 | --device=$gpu_id \
62 | --log-every=$log_every \
63 | --dset-name='JAAD_loc' \
64 | --ckpt-name=$ckpt_name \
65 | --batch-size=1 \
66 | --n-workers=$n_workers \
67 | --annot-loc-format=$annot_loc_format \
68 | --load-cache=$load_cache \
69 | --save-cache-format=$save_cache_format \
70 | --cache-format=$cache_format \
71 | --seq-len=$seq_len \
72 | --predict=$predict \
73 | --pred-seq-len=$pred_seq_len \
74 | --use-gru=$use_gru \
75 | --use-trn=$use_trn \
76 | --use-act=$use_act \
77 | --use-gt-act=$use_gt_act \
78 | --use-pose=$use_pose \
79 | --pos-mode=$pos_mode \
80 | --collapse-cls=$collapse_cls \
81 | --slide=$slide \
82 | --rand-test=$rand_test \
83 | --branch=$branch \
84 | --which-epoch=$which_epoch \
85 | --extract-feats-dir=$extract_feats_dir
86 |
87 |
--------------------------------------------------------------------------------
/models/baseline_anticipate_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from models.model_base import BaseModel
7 |
8 | import pdb
9 |
10 | class BaselineAnticipateCNN(BaseModel):
11 | def __init__(self, opt):
12 | super().__init__(opt)
13 | # nRows: seq_len
14 | # nCols: n_acts, i.e. acts.shape[1]
15 | self.W1 = nn.Conv2d(1, 8, kernel_size=(5,1), padding=(2,0))
16 | self.W2 = nn.Conv2d(8, 16, kernel_size=(5,1), padding=(2,0))
17 | self.fc1 = nn.Linear(16*self.seq_len*(self.n_acts//4), 1024)
18 | self.fc2 = nn.Linear(1024, self.seq_len*self.n_acts)
19 |
20 | # init params
21 | self.W1.weight.data.normal_(std=0.1)
22 | self.W1.bias.data.fill_(0.1)
23 | self.W2.weight.data.normal_(std=0.1)
24 | self.W2.bias.data.fill_(0.1)
25 |
26 | self.relu = nn.ReLU()
27 | self.pool = nn.MaxPool2d(kernel_size=(1,2), stride=(1,2))
28 | self.l2_norm = F.normalize
29 |
30 | if not self.use_gt_act:
31 | self.gru = nn.GRU(self.ped_dim, self.ped_dim, 2, batch_first=True).to(self.device)
32 |
33 | def forward(self, ped, masks, bbox, act, pose=None):
34 | if self.use_gt_act:
35 | # use GT action observations
36 | x = act[:, :self.seq_len].unsqueeze(1)
37 | else:
38 | # predict action labels from pedestrian crops
39 |
40 | # ped_crops: (bt, 30, 3, 224, 224)
41 | B, T, _, _, _ = ped.shape
42 |
43 | ped_crops = ped.view(-1, 3, 224, 224)
44 | ped_crops = ped_crops.type(self.dtype).to(self.device)
45 | ped_feats = self.ped_encoder(ped_crops)
46 | # ped_feats: (B, T, d)
47 | ped_feats = ped_feats.view(B, T, -1)
48 |
49 | temporal_feats, _ = self.gru(ped_feats)
50 | h = self.classifier(temporal_feats)
51 | logits = self.sigmoid(h)
52 | x = (logits > 0.5).type(self.dtype)
53 | x = x.unsqueeze(1)
54 |
55 | # x = acts[:, :self.seq_len].unsqueeze(1)
56 | x = self.W1(x)
57 | x = self.relu(x)
58 |     # max pool: the (1,2) kernel halves the action (last) dimension
59 | x = self.pool(x)
60 |
61 | x = self.W2(x)
62 | x = self.relu(x)
63 |     # max pool: the (1,2) kernel halves the action (last) dimension
64 | x = self.pool(x)
65 |
66 | x = x.view(-1, 16*self.seq_len*(self.n_acts // 4))
67 | x = self.fc1(x)
68 | x = self.fc2(x)
69 |
70 | x = x.view(-1, self.seq_len, self.n_acts)
71 | x = self.l2_norm(x, dim=2)
72 |
73 | if not self.use_gt_act:
74 | # also supervise on the detection
75 | x = torch.cat([h, x], 1)
76 |
77 | return x
78 |
--------------------------------------------------------------------------------
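A standalone shape check for the conv/pool stack above, assuming seq_len=30 and
n_acts=9 as in the JAAD scripts: the (5,1) convolutions preserve both spatial
dims, and each (1,2) max pool halves only the action dim, which is why fc1
expects 16*seq_len*(n_acts//4) input features:

    import torch
    import torch.nn as nn

    seq_len, n_acts = 30, 9
    W1 = nn.Conv2d(1, 8, kernel_size=(5, 1), padding=(2, 0))
    W2 = nn.Conv2d(8, 16, kernel_size=(5, 1), padding=(2, 0))
    pool = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))

    x = torch.zeros(2, 1, seq_len, n_acts)  # (B, 1, T, n_acts)
    x = pool(torch.relu(W1(x)))             # -> (2, 8, 30, 4)
    x = pool(torch.relu(W2(x)))             # -> (2, 16, 30, 2)
    assert x.shape == (2, 16, seq_len, n_acts // 4)

--------------------------------------------------------------------------------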
/scripts_dir/train_concat_stip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | split='train'
5 | n_acts=1
6 |
7 | lr=1e-5
8 | wd=1e-5
9 | bt=2
10 | lr_decay=1
11 | decay_every=10
12 | n_workers=0
13 | save_every=10
14 | evaluate_every=1
15 |
16 | seq_len=30
17 | predict=1
18 | pred_seq_len=10
19 | predict_k=0
20 |
21 | annot_ped_format='/vision/group/prolix/processed/pedestrians/{}.pkl'
22 |
23 | cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl'
24 | cache_obj_bbox_format_left='/vision/group/prolix/processed/left/obj_bbox_20fps_merged/{}_seg{}.pkl'
25 | cache_obj_bbox_format_right='/vision/group/prolix/processed/right/obj_bbox_20fps_merged/{}_seg{}.pkl'
26 |
27 | # load_cache='none'
28 | load_cache='mask'
29 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
30 | #cache_format='/vision/group/prolix/processed/cache/{}/ped{}_fid{}.pkl'
31 | cache_format=/dev/null
32 | save_cache_format=$cache_format
33 |
34 | use_gru=1
35 | use_trn=0
36 | ped_gru=1
37 | # pos_mode='center'
38 | pos_mode='none'
39 | use_act=0
40 | use_gt_act=0
41 | use_pose=0
42 | branch='both'
43 | collapse_cls=0
44 | combine_method='pair'
45 | suffix='_decay10_tmp'
46 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method$suffix
47 |
48 | # CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
49 | WANDB_MODE=dryrun CUDA_VISIBLE_DEVICES=$gpu_id python train.py \
50 | --model='concat' \
51 | --split=$split \
52 | --n-acts=$n_acts \
53 | --device=$gpu_id \
54 | --dset-name='STIP' \
55 | --ckpt-name=$ckpt_name \
56 | --lr-init=$lr \
57 | --wd=$wd \
58 | --lr-decay=$lr_decay \
59 | --decay-every=$decay_every \
60 | --seq-len=$seq_len \
61 | --predict=$predict \
62 | --pred-seq-len=$pred_seq_len \
63 | --predict-k=$predict_k \
64 | --batch-size=$bt \
65 | --save-every=$save_every \
66 | --evaluate-every=$evaluate_every \
67 | --n-workers=$n_workers \
68 | --annot-ped-format=$annot_ped_format \
69 | --load-cache=$load_cache \
70 | --cache-obj-bbox-format=$cache_obj_bbox_format \
71 | --cache-obj-bbox-format-left=$cache_obj_bbox_format_left \
72 | --cache-obj-bbox-format-right=$cache_obj_bbox_format_right \
73 | --cache-format=$cache_format \
74 | --save-cache-format=$save_cache_format \
75 | --branch=$branch \
76 | --collapse-cls=$collapse_cls \
77 | --combine-method=$combine_method \
78 | --use-gru=$use_gru \
79 | --use-trn=$use_trn \
80 | --ped-gru=$ped_gru \
81 | --pos-mode=$pos_mode \
82 | --use-act=$use_act \
83 | --use-gt-act=$use_gt_act \
84 | --use-pose=$use_pose \
85 |
--------------------------------------------------------------------------------
/scripts_dir/train_graph_fromMasks.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | n_workers=0
5 | split='train'
6 | n_acts=9
7 |
8 | # train only
9 | n_epochs=80
10 | lr=1e-4
11 | wd=1e-5
12 | bt=6
13 | lr_decay=1
14 | decay_every=30
15 | save_every=10
16 | evaluate_every=1
17 | log_every=30
18 |
19 | load_cache='masks'
20 | # load_cache='feats'
21 |
22 | if [ $load_cache = 'feats' ]
23 | then
24 | cache_root='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/'
25 | cache_baseformat='/{}/ped{}_fid{}.pkl'
26 | # v3 feats
27 | cache_dir='concat_gru_seq14_pred1_lr1.0e-04_wd1.0e-05_bt4_posNone_branchped_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_ordered'
28 |
29 | # v2 feats
30 | # cache_format='concat_gru_lr1.0e-04_wd1.0e-05_bt4_ped_collapse0_combinepair_useBBox0_cacheMasks_fixGRU_singleTime'
31 |
32 | # cache_format='imageNet_pretrained_singleTime'
33 | cache_format=$cache_root$cache_dir$cache_baseformat
34 | else
35 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse/{}/ped{}_fid{}.pkl'
36 | fi
37 |
38 |
39 | # Regularization on temporal smoothness
40 | reg_smooth='none'
41 | reg_lambda=5e-4
42 |
43 | seq_len=30
44 | predict=1
45 | pred_seq_len=30
46 | predict_k=0
47 |
48 | # temporal modeling
49 | use_gru=1
50 | use_trn=0
51 | ped_gru=1
52 | ctxt_gru=0
53 |
54 | # features
55 | use_act=0
56 | use_gt_act=0
57 | pos_mode='none'
58 | branch='ped'
59 | adj_type='spatial'
60 | #adj_type='inner'
61 | n_layers=2
62 | collapse_cls=0
63 | combine_method='pair'
64 |
65 | # saving & loading
66 | suffix='_v3Feats_fromMasks_pedGRU'
67 | # suffix='_v3Feats__pedGRU_ctxtGRU'
68 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers$suffix
69 | # ckpt_name='graph_seq30_layer2_embed'
70 |
71 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
72 | --model='graph' \
73 | --reg-smooth=$reg_smooth \
74 | --reg-lambda=$reg_lambda \
75 | --split=$split \
76 | --n-acts=$n_acts \
77 | --device=$gpu_id \
78 | --dset-name='JAAD' \
79 | --ckpt-name=$ckpt_name \
80 | --n-epochs=$n_epochs \
81 | --lr-init=$lr \
82 | --wd=$wd \
83 | --lr-decay=$lr_decay \
84 | --decay-every=$decay_every \
85 | --seq-len=$seq_len \
86 | --predict=$predict \
87 | --pred-seq-len=$pred_seq_len \
88 | --predict-k=$predict_k \
89 | --batch-size=$bt \
90 | --save-every=$save_every \
91 | --evaluate-every=$evaluate_every \
92 | --log-every=$log_every \
93 | --n-workers=$n_workers \
94 | --load-cache=$load_cache \
95 | --cache-format=$cache_format \
96 | --branch=$branch \
97 | --adj-type=$adj_type \
98 | --n-layers=$n_layers \
99 | --collapse-cls=$collapse_cls \
100 | --combine-method=$combine_method \
101 | --use-gru=$use_gru \
102 | --use-trn=$use_trn \
103 | --ped-gru=$ped_gru \
104 | --ctxt-gru=$ctxt_gru \
105 | --use-act=$use_act \
106 | --use-gt-act=$use_gt_act \
107 | --pos-mode=$pos_mode
108 |
--------------------------------------------------------------------------------
/scripts_dir/test_concat_stip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # mode='evaluate'
4 | mode='extract'
5 |
6 | gpu_id=0
7 | n_workers=0
8 | n_acts=1
9 |
10 | seq_len=1
11 | predict=0
12 | pred_seq_len=2
13 |
14 | # test only
15 | slide=0
16 | rand_test=1
17 | log_every=10
18 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
19 | dataset='STIP'
20 |
21 | split='test'
22 | use_gru=1
23 | ped_gru=1
24 | use_trn=0
25 | pos_mode='none'
26 | use_act=0
27 | use_gt_act=0
28 | use_pose=0
29 | # branch='ped'
30 | branch='both'
31 | collapse_cls=0
32 | combine_method='pair'
33 |
34 |
35 | annot_ped_format='/vision/group/prolix/processed/pedestrians/{}.pkl'
36 | cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl'
37 |
38 |
39 | load_cache='masks'
40 |
41 | cache_format='/vision/group/prolix/processed/cache/{}/ped{}_fid{}.pkl'
42 | save_cache_format=$cache_format
43 |
44 | # ckpt_name='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_testANN_hanh1'
45 |
46 | # ckpt_name='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_run1'
47 |
48 | ckpt_name='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10'
49 | # -1 for the best epoch
50 | which_epoch=-1
51 |
52 | # set a non-existent epoch so that features are extracted from the ImageNet-pretrained backbone
53 | # which_epoch=100
54 |
55 | if [ $which_epoch -eq -1 ]
56 | then
57 | epoch_name='best_pred'
58 | else
59 | epoch_name=$which_epoch
60 | fi
61 | save_output=10
62 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl'
63 |
64 |
65 | if [ "$mode" = "extract" ]
66 | then
67 | extract_feats_dir='/vision/group/prolix/processed/cache/STIP_conv_feats/'$ckpt_name'/'$split'/'
68 | else
69 | extract_feats_dir='none_existent'
70 | fi
71 |
72 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
73 | --model='concat' \
74 | --split=$split \
75 | --n-acts=$n_acts \
76 | --mode=$mode \
77 | --device=$gpu_id \
78 | --log-every=$log_every \
79 | --dset-name='STIP' \
80 | --ckpt-name=$ckpt_name \
81 | --batch-size=1 \
82 | --n-workers=$n_workers \
83 | --annot-ped-format=$annot_ped_format \
84 | --cache-obj-bbox-format=$cache_obj_bbox_format \
85 | --load-cache=$load_cache \
86 | --save-cache-format=$save_cache_format \
87 | --cache-format=$cache_format \
88 | --seq-len=$seq_len \
89 | --predict=$predict \
90 | --pred-seq-len=$pred_seq_len \
91 | --use-gru=$use_gru \
92 | --ped-gru=$ped_gru \
93 | --use-trn=$use_trn \
94 | --use-act=$use_act \
95 | --use-gt-act=$use_gt_act \
96 | --use-pose=$use_pose \
97 | --pos-mode=$pos_mode \
98 | --collapse-cls=$collapse_cls \
99 | --slide=$slide \
100 | --rand-test=$rand_test \
101 | --branch=$branch \
102 | --which-epoch=$which_epoch \
103 | --save-output=$save_output \
104 | --save-output-format=$save_output_format \
105 | --extract-feats-dir=$extract_feats_dir
106 |
107 |
--------------------------------------------------------------------------------
/models/baseline_pose.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from models.model_base import BaseModel
7 |
8 | import pdb
9 |
10 | class BaselinePose(BaseModel):
11 | def __init__(self, opt):
12 | super().__init__(opt)
13 |
14 | if self.use_gru:
15 | # aggregate poses with FC + GRU
16 | self.embed_pose = nn.Sequential(
17 | nn.Linear(9*2, 128),
18 | nn.ReLU(),
19 | nn.Linear(128, self.fc_in_dim),
20 | )
21 | self.gru = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True)
22 | if self.predict:
23 | self.gru_pred = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True)
24 | else:
25 |       # aggregate poses with temporal Conv2d layers ((5,1) kernels), as in BaselineAnticipateCNN
26 | # nRows: seq_len
27 | # nCols: n_acts, i.e. acts.shape[1]
28 | self.W1 = nn.Conv2d(1, 8, kernel_size=(5,1), padding=(2,0))
29 | self.W2 = nn.Conv2d(8, 16, kernel_size=(5,1), padding=(2,0))
30 | hidden_size = 256
31 |       self.fc1 = nn.Linear(16*self.seq_len*((9*2)//4), hidden_size) # 9 joints x 2 coords (x & y)
32 | self.fc2 = nn.Linear(hidden_size, self.pred_seq_len*self.n_acts)
33 |
34 | # init params
35 | self.W1.weight.data.normal_(std=0.1)
36 | self.W1.bias.data.fill_(0.1)
37 | self.W2.weight.data.normal_(std=0.1)
38 | self.W2.bias.data.fill_(0.1)
39 |
40 | self.relu = nn.ReLU()
41 | self.pool = nn.MaxPool2d(kernel_size=(1,2), stride=(1,2))
42 | self.l2_norm = F.normalize
43 |
44 | # if not self.use_gt_act:
45 | # self.gru = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True).to(self.device)
46 |
47 | def forward(self, ped, masks, bbox, act, pose=None):
48 | # use GT action observations
49 | B = ped.shape[0]
50 |
51 | x = self.util_norm_pose(pose)
52 |
53 | x = x.contiguous().view(B, self.seq_len, 18)
54 |
55 | if self.use_gru:
56 | feats = self.embed_pose(x)
57 | temp_feats, h = self.gru(feats)
58 |
59 | if self.predict:
60 | o = temp_feats[:, -1:]
61 | pred_outs = []
62 | for pred_t in range(self.pred_seq_len):
63 | o, h = self.gru_pred(o, h)
64 | pred_outs += o,
65 | pred_outs = torch.cat(pred_outs, 1)
66 | temp_feats = torch.cat([temp_feats, pred_outs], 1)
67 |
68 | logits = self.classifier(temp_feats)
69 | else:
70 | x = x.unsqueeze(1) # shape: (B, 1, self.seq_len, 18)
71 | x = self.W1(x)
72 | x = self.relu(x)
73 |       # max pool: the (1,2) kernel halves the joint-coordinate (last) dimension
74 | x = self.pool(x)
75 |
76 | x = self.W2(x)
77 | x = self.relu(x)
78 |       # max pool: the (1,2) kernel halves the joint-coordinate (last) dimension
79 | x = self.pool(x)
80 |
81 | x = x.view(-1, 16*self.seq_len*((9*2) // 4))
82 | x = self.fc1(x)
83 | x = self.fc2(x)
84 |
85 | x = x.view(-1, self.pred_seq_len, self.n_acts)
86 | logits = self.l2_norm(x, dim=2)
87 |
88 | # if not self.pred_only:
89 | # x = torch.cat([h, x], 1)
90 | return logits
91 |
--------------------------------------------------------------------------------
/models/model_pos.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.backbone.resnet_based import resnet_backbone
6 |
7 | from models.model_base import BaseModel
8 |
9 | import pdb
10 |
11 |
12 | class PosModel(BaseModel):
13 | def __init__(self, opt):
14 | super().__init__(opt)
15 | self.seq_len = opt.seq_len
16 | self.pos_mode = opt.pos_mode
17 | if self.pos_mode == 'center':
18 | self.fc_in_dim = 2
19 | elif self.pos_mode == 'height':
20 | self.fc_in_dim = 3
21 | elif self.pos_mode == 'bbox':
22 | self.fc_in_dim = 4
23 | elif self.pos_mode == 'both':
24 | self.fc_in_dim = 6
25 | elif self.pos_mode == 'none':
26 | self.fc_in_dim = 0
27 |
28 | self.use_gt_act = opt.use_gt_act
29 | if self.use_gt_act:
30 | self.fc_in_dim += self.n_acts
31 |
32 | if self.fc_in_dim == 0:
33 | raise ValueError("The model should use at least one of 'pos_mode' or 'use_gt_act'.")
34 |
35 | self.gru = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True).to(self.device)
36 | self.classifier = nn.Sequential(
37 | nn.Linear(self.fc_in_dim, self.fc_in_dim),
38 | nn.Linear(self.fc_in_dim, self.n_acts)
39 | )
40 |
41 | if self.predict:
42 | # NOTE: this GRU only takes in seq of len 1, i.e. one time step
43 |       # nn.GRU is used instead of nn.GRUCell to support multiple layers
44 | self.gru_pred = nn.GRU(self.fc_in_dim, self.fc_in_dim, 2, batch_first=True).to(self.device)
45 |
46 |
47 | def forward(self, ped, ctxt, bbox, act, pose=None):
48 | # bbox: (B, T, 4): (y, x, h, w)
49 | bbox = bbox.to(self.device)
50 | if self.pos_mode in ['center', 'both', 'height']:
51 | y = bbox[:,:, 0] + .5*bbox[:,:,2]
52 | x = bbox[:,:, 1] + .5*bbox[:,:,3]
53 | y = y.unsqueeze(-1)
54 | x = x.unsqueeze(-1)
55 | centers = torch.cat([y,x], -1)
56 | if self.pos_mode == 'both':
57 | gru_in = torch.cat([centers, bbox], -1)
58 | elif self.pos_mode == 'height':
59 | h = bbox[:,:,2:3]
60 | gru_in = torch.cat([centers, h], -1)
61 | else:
62 | gru_in = centers
63 | elif self.pos_mode == 'bbox':
64 | gru_in = bbox
65 | elif self.pos_mode == 'none':
66 | gru_in = None
67 |
68 | if self.use_gt_act:
69 | # shape: (bt, all_seq_len, n_acts+len_for_pos)
70 | if self.n_acts == 1:
71 | act = act[:, :, 1:2]
72 | if gru_in is None:
73 | gru_in = act
74 | else:
75 | gru_in = torch.cat([gru_in, act], -1)
76 |
77 | if self.predict:
78 | gru_in = gru_in[:, :self.seq_len]
79 |
80 | output, h = self.gru(gru_in)
81 |
82 | if self.predict:
83 | # o: (B, 1, fc_in_dim)
84 | o = output[:, -1:]
85 | pred_outs = []
86 | for _ in range(self.pred_seq_len):
87 | o, h = self.gru_pred(o, h)
88 | pred_outs += o,
89 |
90 | # pred_outs: (B, pred_seq_len, fc_in_dim)
91 | pred_outs = torch.cat(pred_outs, 1)
92 | # frame_feats: (B, seq_len + pred_seq_len, fc_in_dim)
93 | output = torch.cat([output, pred_outs], 1)
94 |
95 | logits = self.classifier(output)
96 | return logits
97 |
--------------------------------------------------------------------------------
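The prediction branch decodes autoregressively: the last observed GRU output
seeds gru_pred, which is stepped one frame at a time while the hidden state
carries the history. A self-contained sketch of that rollout (illustrative
dimensions):

    import torch
    import torch.nn as nn

    d, pred_len = 6, 5
    gru = nn.GRU(d, d, 2, batch_first=True)       # observation encoder
    gru_pred = nn.GRU(d, d, 2, batch_first=True)  # one-step decoder

    obs = torch.randn(1, 30, d)                   # (B, T, d) observed sequence
    output, h = gru(obs)

    o = output[:, -1:]                            # seed with the last observed step
    pred_outs = []
    for _ in range(pred_len):
        o, h = gru_pred(o, h)                     # hidden state carries history
        pred_outs.append(o)
    pred = torch.cat(pred_outs, 1)                # (1, pred_len, d)
    print(pred.shape)

--------------------------------------------------------------------------------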
/scripts_dir/train_graph.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=1
4 | split='train'
5 | n_workers=0
6 | n_acts=9
7 | bt=16
8 | log_every=30
9 |
10 | # Training only
11 | n_epochs=80
12 | start_epoch=0
13 | lr=1e-4
14 | wd=2e-6
15 | lr_decay=1
16 | decay_every=20
17 | save_every=10
18 | evaluate_every=1
19 |
20 | reg_smooth='none'
21 | reg_lambda=0
22 |
23 | seq_len=30
24 | predict=1
25 | pred_seq_len=30
26 | predict_k=0
27 |
28 | load_cache='feats'
29 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU/{}/ped{}_fid{}.pkl'
30 |
31 | use_gru=1
32 | use_trn=0
33 | ped_gru=1
34 | ctxt_gru=0
35 | ctxt_node=0
36 |
37 | # features
38 | use_act=0
39 | use_gt_act=0
40 | use_driver=0
41 | use_pose=0
42 | pos_mode='none'
43 | # branch='ped'
44 | branch='both'
45 | adj_type='spatial'
46 | # adj_type='random'
47 | # adj_type='uniform'
48 | # adj_type='spatialOnly'
49 | # adj_type='all'
50 | # adj_type='inner'
51 | use_obj_cls=0
52 | n_layers=2
53 | diff_layer_weight=0
54 | collapse_cls=0
55 | combine_method='pair'
56 |
57 | # saving & loading
58 | # suffix='_v4Feats_pedGRU_newCtxtGRU_3evalEpoch'
59 | suffix='_v4Feats_pedGRU_3evalEpoch_sigmoidMean'
60 |
61 | if [[ "$ctxt_gru" == 1 ]]
62 | then
63 | suffix=$suffix"_newCtxtGRU"
64 | fi
65 |
66 | if [[ "$use_obj_cls" == 1 ]]
67 | then
68 | suffix=$suffix"_objCls"
69 | fi
70 |
71 | if [[ "$use_driver" == 1 ]]
72 | then
73 | suffix=$suffix"_useDriver"
74 | fi
75 |
76 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers'_diffW'$diff_layer_weight$suffix
77 | # ckpt_name='graph_seq30_layer2_embed'
78 |
79 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
80 | --model='graph' \
81 | --reg-smooth=$reg_smooth \
82 | --reg-lambda=$reg_lambda \
83 | --split=$split \
84 | --n-acts=$n_acts \
85 | --device=$gpu_id \
86 | --dset-name='JAAD' \
87 | --ckpt-name=$ckpt_name \
88 | --n-epochs=$n_epochs \
89 | --start-epoch=$start_epoch \
90 | --lr-init=$lr \
91 | --wd=$wd \
92 | --lr-decay=$lr_decay \
93 | --decay-every=$decay_every \
94 | --seq-len=$seq_len \
95 | --predict=$predict \
96 | --pred-seq-len=$pred_seq_len \
97 | --predict-k=$predict_k \
98 | --batch-size=$bt \
99 | --save-every=$save_every \
100 | --evaluate-every=$evaluate_every \
101 | --log-every=$log_every \
102 | --n-workers=$n_workers \
103 | --load-cache=$load_cache \
104 | --cache-format=$cache_format \
105 | --branch=$branch \
106 | --adj-type=$adj_type \
107 | --use-obj-cls=$use_obj_cls \
108 | --n-layers=$n_layers \
109 | --diff-layer-weight=$diff_layer_weight \
110 | --collapse-cls=$collapse_cls \
111 | --combine-method=$combine_method \
112 | --use-gru=$use_gru \
113 | --use-trn=$use_trn \
114 | --ped-gru=$ped_gru \
115 | --ctxt-gru=$ctxt_gru \
116 | --ctxt-node=$ctxt_node \
117 | --use-act=$use_act \
118 | --use-gt-act=$use_gt_act \
119 | --use-driver=$use_driver \
120 | --use-pose=$use_pose \
121 | --pos-mode=$pos_mode
122 |
--------------------------------------------------------------------------------
/models/backbone/resnet_based.py:
--------------------------------------------------------------------------------
1 | import torch
2 | try:
3 | from models.backbone.imagenet_pretraining import load_pretrained_2D_weights
4 | from models.backbone.resnet.basicblock import BasicBlock2D, BasicBlock3D, BasicBlock2_1D
5 | from models.backbone.resnet.bottleneck import Bottleneck2D, Bottleneck3D, Bottleneck2_1D
6 | from models.backbone.resnet.resnet import ResNetBackBone
7 | except ImportError:
8 | from imagenet_pretraining import load_pretrained_2D_weights
9 | from resnet.basicblock import BasicBlock2D, BasicBlock3D, BasicBlock2_1D
10 | from resnet.bottleneck import Bottleneck2D, Bottleneck3D, Bottleneck2_1D
11 | from resnet.resnet import ResNetBackBone
12 |
13 |
14 | # __all__ = [
15 | # 'resnet_two_heads',
16 | # ]
17 |
18 |
19 | def resnet_backbone(depth=18, blocks='2D_2D_2D_2D', **kwargs):
20 |     """Constructs a ResNet backbone (depth 18, 34, or 50)
21 | """
22 | # Blocks and layers
23 | list_block, list_layers = get_cnn_structure(depth=depth,
24 | str_blocks=blocks)
25 |
26 |     # Backbone model
27 | model = ResNetBackBone(list_block,
28 | list_layers,
29 | **kwargs)
30 |
40 | # Pretrained from imagenet weights
41 | model = load_pretrained_2D_weights('resnet{}'.format(depth), model, inflation='center')
42 |
43 | return model
44 |
45 |
46 | def get_cnn_structure(str_blocks='2D_2D_2D_2D', depth=18):
47 | # List of blocks
48 | list_block = []
49 |
50 | # layers
51 | if depth == 18:
52 | list_layers = [2, 2, 2, 2]
53 | nature_of_block = 'basic'
54 | elif depth == 34:
55 | list_layers = [3, 4, 6, 3]
56 | nature_of_block = 'basic'
57 | elif depth == 50:
58 | list_layers = [3, 4, 6, 3]
59 | nature_of_block = 'bottleneck'
60 | else:
61 | raise NameError
62 |
63 | # blocks
64 | if nature_of_block == 'basic':
65 | block_2D, block_3D, block_2_1D = BasicBlock2D, BasicBlock3D, BasicBlock2_1D
66 | elif nature_of_block == 'bottleneck':
67 | block_2D, block_3D, block_2_1D = Bottleneck2D, Bottleneck3D, Bottleneck2_1D
68 | else:
69 | raise NameError
70 |
71 | # From string to blocks
72 | list_block_id = str_blocks.split('_')
73 |
74 | # Catch from the options if exists
75 | for i, str_block in enumerate(list_block_id):
76 | # Block kind
77 | if str_block == '2D':
78 | list_block.append(block_2D)
79 | elif str_block == '2.5D':
80 | list_block.append(block_2_1D)
81 | elif str_block == '3D':
82 | list_block.append(block_3D)
83 | else:
84 | # ipdb.set_trace()
85 | raise NameError
86 |
87 | return list_block, list_layers
88 |
89 |
90 | if __name__ == '__main__':
91 | model = resnet_backbone()
92 |     img = torch.ones([1, 3, 224, 224])  # 3-channel input to match the ImageNet-pretrained conv1
93 | feats = model(img)
94 | print(feats.shape)
95 |
--------------------------------------------------------------------------------
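The blocks string is positional, one token per ResNet stage. A standalone
mirror of get_cnn_structure's mapping (names only, no torch, for illustration):

    def structure(depth=18, blocks='2D_2D_2D_2D'):
        layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3]}[depth]
        kind = 'Bottleneck' if depth == 50 else 'BasicBlock'
        # '2.5D' tokens map to the *2_1D block classes in the real code
        return ['{}{}'.format(kind, tok) for tok in blocks.split('_')], layers

    print(structure(50, '2D_2D_3D_3D'))
    # (['Bottleneck2D', 'Bottleneck2D', 'Bottleneck3D', 'Bottleneck3D'], [3, 4, 6, 3])

--------------------------------------------------------------------------------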
/scripts_dir/test_concat_stip_side.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # mode='evaluate'
4 | mode='extract'
5 |
6 | gpu_id=0
7 | n_workers=0
8 | n_acts=1
9 |
10 | seq_len=1
11 | predict=0
12 | pred_seq_len=2
13 |
14 | # test only
15 | slide=0
16 | rand_test=1
17 | log_every=10
18 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
19 | dataset='STIP'
20 |
21 | split='test'
22 | use_gru=1
23 | ped_gru=1
24 | use_trn=0
25 | pos_mode='none'
26 | use_act=0
27 | use_gt_act=0
28 | use_pose=0
29 | # branch='ped'
30 | branch='both'
31 | collapse_cls=0
32 | combine_method='pair'
33 |
34 | view='right'
35 | # view='all_views'
36 | # view='all'
37 |
38 | annot_ped_format='/vision/group/prolix/processed/pedestrians_all_views/{}.pkl'
39 | # if [ $view = 'center' ]
40 | # then
41 | # cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl'
42 | # else
43 | cache_obj_bbox_format='/vision/group/prolix/processed/'$view'/obj_bbox_20fps_merged/{}_seg{}.pkl'
44 | # fi
45 |
46 |
47 | load_cache='none'
48 |
49 | cache_format='/vision/group/prolix/processed/cache/'$view'/{}/ped{}_fid{}.pkl'
50 | save_cache_format=$cache_format
51 |
52 | # ckpt_name='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_testANN_hanh1'
53 |
54 | # ckpt_name='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_run1'
55 |
56 | ckpt_name='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10'
57 | # -1 for the best epoch
58 | which_epoch=-1
59 |
60 | # setting a non-existent epoch makes the features come from the ImageNet-pretrained backbone
61 | # which_epoch=100
62 |
63 | if [ $which_epoch -eq -1 ]
64 | then
65 | epoch_name='best_pred'
66 | else
67 | epoch_name=$which_epoch
68 | fi
69 | save_output=10
70 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl'
71 |
72 |
73 | if [ "$mode" = "extract" ]
74 | then
75 | extract_feats_dir='/vision/group/prolix/processed/cache/'$view'/STIP_conv_feats/'$ckpt_name'/'$split'/'
76 | else
77 | extract_feats_dir='none_existent'
78 | fi
79 |
80 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
81 | --model='concat' \
82 | --view=$view \
83 | --split=$split \
84 | --n-acts=$n_acts \
85 | --mode=$mode \
86 | --device=$gpu_id \
87 | --log-every=$log_every \
88 | --dset-name='STIP' \
89 | --ckpt-name=$ckpt_name \
90 | --batch-size=1 \
91 | --n-workers=$n_workers \
92 | --annot-ped-format=$annot_ped_format \
93 | --cache-obj-bbox-format=$cache_obj_bbox_format \
94 | --load-cache=$load_cache \
95 | --save-cache-format=$save_cache_format \
96 | --cache-format=$cache_format \
97 | --seq-len=$seq_len \
98 | --predict=$predict \
99 | --pred-seq-len=$pred_seq_len \
100 | --use-gru=$use_gru \
101 | --ped-gru=$ped_gru \
102 | --use-trn=$use_trn \
103 | --use-act=$use_act \
104 | --use-gt-act=$use_gt_act \
105 | --use-pose=$use_pose \
106 | --pos-mode=$pos_mode \
107 | --collapse-cls=$collapse_cls \
108 | --slide=$slide \
109 | --rand-test=$rand_test \
110 | --branch=$branch \
111 | --which-epoch=$which_epoch \
112 | --save-output=$save_output \
113 | --save-output-format=$save_output_format \
114 | --extract-feats-dir=$extract_feats_dir
115 |
116 |
--------------------------------------------------------------------------------
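Note that the brace-style paths in this script (annot_ped_format, cache_format, save_output_format) are Python str.format templates, not shell expansions. A hedged sketch of how cache_format is presumably filled on the Python side (the argument order of split, pedestrian id, frame id is an assumption based on the 'ped{}_fid{}' naming, not repo code):

    cache_format = '/vision/group/prolix/processed/cache/right/{}/ped{}_fid{}.pkl'
    fpath = cache_format.format('test', 3, 120)
    # -> '/vision/group/prolix/processed/cache/right/test/ped3_fid120.pkl'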
/scripts_dir/test_concat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mode='evaluate'
4 | # mode='extract'
5 |
6 | gpu_id=0
7 | n_workers=0
8 | n_acts=9
9 |
10 | seq_len=30
11 | predict=1
12 | pred_seq_len=30
13 |
14 | # test only
15 | slide=0
16 | rand_test=1
17 | log_every=10
18 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
19 | dataset='JAAD'
20 |
21 | split='test'
22 | use_gru=1
23 | ped_gru=1
24 | use_trn=0
25 | pos_mode='none'
26 | use_act=0
27 | use_gt_act=0
28 | use_pose=0
29 | # branch='ped'
30 | branch='both'
31 | collapse_cls=0
32 | combine_method='pair'
33 |
34 | load_cache='masks'
35 | # load_cache='none'
36 |
37 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse/{}/ped{}_fid{}.pkl'
38 | save_cache_format=$cache_format
39 |
40 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse_max/{}/ped{}_fid{}.pkl'
41 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse/{}/ped{}_fid{}.pkl'
42 |
43 | # ckpt_name='concat_gru_lr1.0e-04_wd1.0e-05_bt4_ped_collapse0_combinepair_useBBox1_cacheMasks_fixGRU'
44 | # which_epoch=13
45 |
46 | # ckpt_name='concat_gru_lr1.0e-04_wd1.0e-05_bt4_ped_collapse0_combinepair_useBBox0_cacheMasks_fixGRU'
47 | # which_epoch=44
48 |
49 | # ckpt_name='concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchped_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_withGTAct'
50 |
51 | # bkwfairi
52 | # ckpt_name='concat_gru_seq14_pred1_lr1.0e-04_wd1.0e-05_bt4_posNone_branchped_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU'
53 |
54 | # dur1n8v7
55 | # saved: pred ~74.9
56 | # ckpt_name='concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU'
57 |
58 | # -1 for the best epoch
59 | which_epoch=-1
60 |
61 | # setting a non-existent epoch makes the features come from the ImageNet-pretrained backbone
62 | # which_epoch=100
63 |
64 | if [ $which_epoch -eq -1 ]
65 | then
66 | epoch_name='best_pred'
67 | else
68 | epoch_name=$which_epoch
69 | fi
70 | save_output=10
71 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl'
72 |
73 |
74 | if [ "$mode" = "extract" ]
75 | then
76 | extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/'$ckpt_name'/'$split'/'
77 | else
78 | extract_feats_dir='none_existent'
79 | fi
80 |
81 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
82 | --model='concat' \
83 | --split=$split \
84 | --n-acts=$n_acts \
85 | --mode=$mode \
86 | --device=$gpu_id \
87 | --log-every=$log_every \
88 | --dset-name='JAAD' \
89 | --ckpt-name=$ckpt_name \
90 | --batch-size=1 \
91 | --n-workers=$n_workers \
92 | --load-cache=$load_cache \
93 | --save-cache-format=$save_cache_format \
94 | --cache-format=$cache_format \
95 | --seq-len=$seq_len \
96 | --predict=$predict \
97 | --pred-seq-len=$pred_seq_len \
98 | --use-gru=$use_gru \
99 | --ped-gru=$ped_gru \
100 | --use-trn=$use_trn \
101 | --use-act=$use_act \
102 | --use-gt-act=$use_gt_act \
103 | --use-pose=$use_pose \
104 | --pos-mode=$pos_mode \
105 | --collapse-cls=$collapse_cls \
106 | --slide=$slide \
107 | --rand-test=$rand_test \
108 | --branch=$branch \
109 | --which-epoch=$which_epoch \
110 | --save-output=$save_output \
111 | --save-output-format=$save_output_format \
112 | --extract-feats-dir=$extract_feats_dir
113 |
114 |
--------------------------------------------------------------------------------
/scripts_dir/train_loc_graph.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=1
4 | split='train'
5 | n_workers=0
6 | n_acts=1
7 | bt=8
8 | log_every=30
9 |
10 | # Training only
11 | n_epochs=80
12 | start_epoch=0
13 | lr=1e-4
14 | wd=1e-5
15 | lr_decay=1
16 | decay_every=20
17 | save_every=10
18 | evaluate_every=1
19 |
20 | reg_smooth='none'
21 | reg_lambda=0
22 |
23 | seq_len=30
24 | predict=1
25 | pred_seq_len=30
26 | predict_k=0
27 |
28 | annot_loc_format='/sailhome/ajarno/STR-PIP/datasets/annot_{}_loc.pkl'
29 |
30 | load_cache='feats'
31 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU/{}/ped{}_fid{}.pkl'
32 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/jaad_loc/JAAD_conv_feats/loc_concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt1_posNone_branchboth_collapse0_combinepair_tmp/{}/vid{}_fid{}.pkl'
33 |
34 | use_gru=1
35 | use_trn=0
36 | frame_gru=1
37 | node_gru=0
38 |
39 | # features
40 | use_act=0
41 | use_gt_act=0
42 | use_driver=0
43 | use_pose=0
44 | pos_mode='none'
45 | # branch='ped'
46 | branch='both'
47 | # adj_type='spatial'
48 | # adj_type='random'
49 | # adj_type='uniform'
50 | # adj_type='spatialOnly'
51 | # adj_type='all'
52 | adj_type='inner'
53 | use_obj_cls=0
54 | n_layers=2
55 | diff_layer_weight=0
56 | collapse_cls=0
57 | combine_method='pair'
58 |
59 | # saving & loading
60 | # suffix='_v4Feats_pedGRU_newCtxtGRU_3evalEpoch'
61 | suffix='_v4Feats_pedGRU_3evalEpoch'
62 |
63 | if [[ "$node_gru" == 1 ]]
64 | then
65 | suffix=$suffix"_nodeGRU"
66 | fi
67 |
68 | if [[ "$use_obj_cls" == 1 ]]
69 | then
70 | suffix=$suffix"_objCls"
71 | fi
72 |
73 | if [[ "$use_driver" == 1 ]]
74 | then
75 | suffix=$suffix"_useDriver"
76 | fi
77 |
78 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers'_diffW'$diff_layer_weight$suffix
79 | # ckpt_name='graph_seq30_layer2_embed'
80 |
81 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
82 | --model='loc_graph' \
83 | --reg-smooth=$reg_smooth \
84 | --reg-lambda=$reg_lambda \
85 | --split=$split \
86 | --n-acts=$n_acts \
87 | --device=$gpu_id \
88 | --dset-name='JAAD_loc' \
89 | --ckpt-name=$ckpt_name \
90 | --n-epochs=$n_epochs \
91 | --start-epoch=$start_epoch \
92 | --lr-init=$lr \
93 | --wd=$wd \
94 | --lr-decay=$lr_decay \
95 | --decay-every=$decay_every \
96 | --seq-len=$seq_len \
97 | --predict=$predict \
98 | --pred-seq-len=$pred_seq_len \
99 | --predict-k=$predict_k \
100 | --batch-size=$bt \
101 | --save-every=$save_every \
102 | --evaluate-every=$evaluate_every \
103 | --log-every=$log_every \
104 | --n-workers=$n_workers \
105 | --annot-loc-format=$annot_loc_format \
106 | --load-cache=$load_cache \
107 | --cache-format=$cache_format \
108 | --branch=$branch \
109 | --adj-type=$adj_type \
110 | --use-obj-cls=$use_obj_cls \
111 | --n-layers=$n_layers \
112 | --diff-layer-weight=$diff_layer_weight \
113 | --collapse-cls=$collapse_cls \
114 | --combine-method=$combine_method \
115 | --use-gru=$use_gru \
116 | --use-trn=$use_trn \
117 | --frame-gru=$frame_gru \
118 | --node-gru=$node_gru \
119 | --use-act=$use_act \
120 | --use-gt-act=$use_gt_act \
121 | --use-driver=$use_driver \
122 | --use-pose=$use_pose \
123 | --pos-mode=$pos_mode
124 |
--------------------------------------------------------------------------------
/utils/visual.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import json
4 | import cv2
5 | import pdb
6 |
7 | # MACHINE = 'TRI'  # toggle: TRI instance vs. Stanford CS servers
8 | MACHINE = 'CS'
9 |
10 | fps_annot = 2
11 | fps_png = 20
12 | fps_tgt = 2 # target FPS
13 | overRate = 1 # oversample rate for annots
14 | downRate = 10 # downsample rate for png
15 |
16 | if MACHINE == 'TRI':
17 | # on TRI instance
18 | stip_root = '/mnt/paralle/stip'
19 | annot_root = os.path.join(stip_root, 'annotation')
20 | # subdir_format_png: two placeholders: area (e.g. "ANN_conor1"), subdir (e.g. "00:15-00:41").
21 | subdir_format_png = os.path.join(stip_root, 'stip_instances', '{}/rgb/{}')
22 | else:
23 | # on Stanford CS servers
24 | stip_root = '/vision2/u/bingbin/STR-PIP/STIP'
25 | annot_root = os.path.join(stip_root, 'annotations')
26 | subdir_format_png = os.path.join(stip_root, 'rgb', '{}/{}')
27 |
28 | # drawing setting
29 | FONT = cv2.FONT_HERSHEY_SIMPLEX
30 | FONT_SCALE = 4
31 | TEXT_COLOR = (255, 255, 255)
32 | TEXT_THICK = 2
33 |
34 | def png_lookup(fpngs):
35 | png_map = {}
36 | for fpng in fpngs:
37 | fid = int(os.path.basename(fpng).split('.')[0])
38 | png_map[fid] = fpng
39 | return png_map
40 |
41 |
42 | def visual_clip(fannot, sdarea, tgt_dir):
43 | # pngs = sorted([fpng for sdir in glob(png_formt.format(darea, '*')) for fpng in os.path.join(sdir, '*.png')])
44 | fpngs_all = sorted([fpng for fpng in glob(os.path.join(sdarea, '*.png'))])
45 | png_map = png_lookup(fpngs_all)
46 |
47 | os.makedirs(tgt_dir, exist_ok=True)
48 |
49 | annot = json.load(open(fannot, 'r'))
50 | frames_annot = annot['frames']
51 | for fid_annot in frames_annot:
52 | fid_tgt = [overRate*(int(fid_annot)-1)+i for i in range(overRate)]
53 | fid_png = [tfid*downRate for tfid in fid_tgt]
54 | fpngs = [png_map[pfid] for pfid in fid_png if pfid in png_map]
55 | if len(fpngs):
56 | print('# fpngs: {} / len(png_map): {}'.format(len(fpngs), len(png_map)))
57 | print('fid_annot:', fid_annot)
58 | print('fid_tgt:', fid_tgt)
59 | print('fid_png:', fid_png)
60 | print('fpngs:', fpngs)
61 | print()
62 |
63 | # draw bbox on png frames
64 | annot = frames_annot[fid_annot]
65 | for fpng in fpngs:
66 | img = cv2.imread(fpng)
67 | print('img size:', img.shape)
68 | if len(annot) == 0:
69 | cv2.putText(img, 'No obj in the current frame.', (100,100),
70 | FONT, FONT_SCALE, TEXT_COLOR, TEXT_THICK)
71 | else:
72 | for obj in annot:
73 | cv2.rectangle(img, (int(obj['x1']), int(obj['y1'])), (int(obj['x2']), int(obj['y2'])), (0,255,0), 3)
74 | cv2.putText(img, '-'.join(obj['tags']), (int(obj['x1']), int(obj['y1'])),
75 | FONT, FONT_SCALE, TEXT_COLOR, TEXT_THICK)
76 | tgt_fpng = fpng.replace(sdarea, tgt_dir)
77 | if tgt_fpng == fpng:
78 | print('Error saving drawn png: tgt_fpng == fpng')
79 | print('fpng:', fpng)
80 | pdb.set_trace()
81 | else:
82 | cv2.imwrite(tgt_fpng, img)
83 |
84 | def visual_clip_wrapper():
85 | fannot = os.path.join(annot_root, '20170907_prolix_trial_ANN_hanh2-09-07-2017_15-44-07.concat.12fps.mp4.json')
86 | sdarea = subdir_format_png.format('ANN_hanh2', '00:16--00:27')
87 | if MACHINE == 'TRI':
88 | tgt_dir = sdarea.replace('/stip/', '/stip/tmp_vis/')
89 | else:
90 |     tgt_dir = sdarea.replace('/rgb/', '/tmp_vis/')
91 | visual_clip(fannot, sdarea, tgt_dir)
92 |
93 |
94 | if __name__ == "__main__":
95 | visual_clip_wrapper()
96 |
--------------------------------------------------------------------------------
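A worked example of the frame-id arithmetic in visual_clip: annotations are at 2 fps and pngs at 20 fps, so each annotation frame expands by overRate and then scales by downRate to a png frame id:

    overRate, downRate = 1, 10          # values set at the top of visual.py
    fid_annot = '5'                     # annotation frame ids arrive as strings
    fid_tgt = [overRate * (int(fid_annot) - 1) + i for i in range(overRate)]  # [4]
    fid_png = [tfid * downRate for tfid in fid_tgt]                           # [40]
    # i.e. the 5th annotated frame is drawn on the png whose numeric stem is 40.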
/scripts_dir/train_graph_stip_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | split='train'
5 | n_workers=0
6 | n_acts=1
7 | bt=32
8 | log_every=30
9 |
10 | # Training only
11 | n_epochs=200
12 | start_epoch=0
13 | lr=1e-4
14 | wd=1e-5
15 | lr_decay=1
16 | decay_every=20
17 | save_every=10
18 | evaluate_every=1
19 |
20 | reg_smooth='none'
21 | reg_lambda=0
22 |
23 | seq_len=8
24 | predict=1
25 | pred_seq_len=6
26 | predict_k=0
27 |
28 | view='all'
29 |
30 | annot_ped_format='/vision/group/prolix/processed/pedestrians_all_views/{}.pkl'
31 |
32 | cache_obj_bbox_format='/vision/group/prolix/processed/all/obj_bbox_20fps_merged/{}_seg{}.pkl'
33 |
34 | load_cache='feats'
35 | cache_format_base='/vision/group/prolix/processed/cache/all/STIP_conv_feats/'
36 | cache_format_pkl='{}/ped{}_fid{}.pkl'
37 | concat_ckpt='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10/'
38 | cache_format=$cache_format_base$concat_ckpt$cache_format_pkl
39 |
40 | use_gru=1
41 | use_trn=0
42 | ped_gru=1
43 | ctxt_gru=0
44 | ctxt_node=0
45 |
46 | # features
47 | use_act=0
48 | use_gt_act=0
49 | use_driver=0
50 | use_pose=0
51 | pos_mode='none'
52 | # branch='ped'
53 | branch='both'
54 | adj_type='spatial'
55 | # adj_type='random'
56 | # adj_type='uniform'
57 | # adj_type='spatialOnly'
58 | # adj_type='all'
59 | # adj_type='inner'
60 | use_obj_cls=0
61 | n_layers=2
62 | diff_layer_weight=0
63 | collapse_cls=0
64 | combine_method='pair'
65 |
66 | # saving & loading
67 | suffix='_decay10Feats_'$view
68 |
69 | if [[ "$ctxt_gru" == 1 ]]
70 | then
71 | suffix=$suffix"_newCtxtGRU"
72 | fi
73 |
74 | if [[ "$use_obj_cls" == 1 ]]
75 | then
76 | suffix=$suffix"_objCls"
77 | fi
78 |
79 | if [[ "$use_driver" == 1 ]]
80 | then
81 | suffix=$suffix"_useDriver"
82 | fi
83 |
84 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers'_diffW'$diff_layer_weight$suffix
85 |
86 | # WANDB_MODE=dryrun CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
87 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
88 | --model='graph' \
89 | --view=$view \
90 | --reg-smooth=$reg_smooth \
91 | --reg-lambda=$reg_lambda \
92 | --split=$split \
93 | --n-acts=$n_acts \
94 | --device=$gpu_id \
95 | --dset-name='STIP' \
96 | --ckpt-name=$ckpt_name \
97 | --n-epochs=$n_epochs \
98 | --start-epoch=$start_epoch \
99 | --lr-init=$lr \
100 | --wd=$wd \
101 | --lr-decay=$lr_decay \
102 | --decay-every=$decay_every \
103 | --seq-len=$seq_len \
104 | --predict=$predict \
105 | --pred-seq-len=$pred_seq_len \
106 | --predict-k=$predict_k \
107 | --batch-size=$bt \
108 | --save-every=$save_every \
109 | --evaluate-every=$evaluate_every \
110 | --log-every=$log_every \
111 | --n-workers=$n_workers \
112 | --annot-ped-format=$annot_ped_format \
113 | --cache-obj-bbox-format=$cache_obj_bbox_format \
114 | --load-cache=$load_cache \
115 | --cache-format=$cache_format \
116 | --branch=$branch \
117 | --adj-type=$adj_type \
118 | --use-obj-cls=$use_obj_cls \
119 | --n-layers=$n_layers \
120 | --diff-layer-weight=$diff_layer_weight \
121 | --collapse-cls=$collapse_cls \
122 | --combine-method=$combine_method \
123 | --use-gru=$use_gru \
124 | --use-trn=$use_trn \
125 | --ped-gru=$ped_gru \
126 | --ctxt-gru=$ctxt_gru \
127 | --ctxt-node=$ctxt_node \
128 | --use-act=$use_act \
129 | --use-gt-act=$use_gt_act \
130 | --use-driver=$use_driver \
131 | --use-pose=$use_pose \
132 | --pos-mode=$pos_mode
133 |
--------------------------------------------------------------------------------
/scripts_dir/train_graph_stip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=0
4 | split='train'
5 | n_workers=0
6 | n_acts=1
7 | bt=32
8 | log_every=30
9 |
10 | # Training only
11 | n_epochs=200
12 | start_epoch=0
13 | lr=1e-4
14 | wd=1e-5
15 | lr_decay=1
16 | decay_every=20
17 | save_every=10
18 | evaluate_every=1
19 |
20 | reg_smooth='none'
21 | reg_lambda=0
22 |
23 | seq_len=30
24 | predict=1
25 | pred_seq_len=60
26 | predict_k=0
27 |
28 | annot_ped_format='/vision/group/prolix/processed/pedestrians/{}.pkl'
29 |
30 | cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl'
31 |
32 | load_cache='feats'
33 | cache_format_base='/vision/group/prolix/processed/cache/STIP_conv_feats/'
34 | cache_format_pkl='{}/ped{}_fid{}.pkl'
35 | # cache_format=$cache_format_base'concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_testANN_hanh1/'$cache_format_pkl
36 | # concat_ckpt='concat_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_run1/'
37 | concat_ckpt='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10/'
38 | cache_format=$cache_format_base$concat_ckpt$cache_format_pkl
39 |
40 | use_gru=1
41 | use_trn=0
42 | ped_gru=1
43 | ctxt_gru=0
44 | ctxt_node=0
45 |
46 | # features
47 | use_act=0
48 | use_gt_act=0
49 | use_driver=0
50 | use_pose=0
51 | pos_mode='none'
52 | # branch='ped'
53 | branch='both'
54 | adj_type='spatial'
55 | # adj_type='random'
56 | # adj_type='uniform'
57 | # adj_type='spatialOnly'
58 | # adj_type='all'
59 | # adj_type='inner'
60 | use_obj_cls=0
61 | n_layers=2
62 | diff_layer_weight=0
63 | collapse_cls=0
64 | combine_method='pair'
65 |
66 | # saving & loading
67 | suffix='_decay10Feats'
68 |
69 | if [[ "$ctxt_gru" == 1 ]]
70 | then
71 | suffix=$suffix"_newCtxtGRU"
72 | fi
73 |
74 | if [[ "$use_obj_cls" == 1 ]]
75 | then
76 | suffix=$suffix"_objCls"
77 | fi
78 |
79 | if [[ "$use_driver" == 1 ]]
80 | then
81 | suffix=$suffix"_useDriver"
82 | fi
83 |
84 | ckpt_name='branch'$branch'_collapse'$collapse_cls'_combine'$combine_method'_adjType'$adj_type'_nLayers'$n_layers'_diffW'$diff_layer_weight$suffix
85 |
86 | # WANDB_MODE=dryrun CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
87 | CUDA_VISIBLE_DEVICES=$gpu_id python3 train.py \
88 | --model='graph' \
89 | --reg-smooth=$reg_smooth \
90 | --reg-lambda=$reg_lambda \
91 | --split=$split \
92 | --n-acts=$n_acts \
93 | --device=$gpu_id \
94 | --dset-name='STIP' \
95 | --ckpt-name=$ckpt_name \
96 | --n-epochs=$n_epochs \
97 | --start-epoch=$start_epoch \
98 | --lr-init=$lr \
99 | --wd=$wd \
100 | --lr-decay=$lr_decay \
101 | --decay-every=$decay_every \
102 | --seq-len=$seq_len \
103 | --predict=$predict \
104 | --pred-seq-len=$pred_seq_len \
105 | --predict-k=$predict_k \
106 | --batch-size=$bt \
107 | --save-every=$save_every \
108 | --evaluate-every=$evaluate_every \
109 | --log-every=$log_every \
110 | --n-workers=$n_workers \
111 | --annot-ped-format=$annot_ped_format \
112 | --cache-obj-bbox-format=$cache_obj_bbox_format \
113 | --load-cache=$load_cache \
114 | --cache-format=$cache_format \
115 | --branch=$branch \
116 | --adj-type=$adj_type \
117 | --use-obj-cls=$use_obj_cls \
118 | --n-layers=$n_layers \
119 | --diff-layer-weight=$diff_layer_weight \
120 | --collapse-cls=$collapse_cls \
121 | --combine-method=$combine_method \
122 | --use-gru=$use_gru \
123 | --use-trn=$use_trn \
124 | --ped-gru=$ped_gru \
125 | --ctxt-gru=$ctxt_gru \
126 | --ctxt-node=$ctxt_node \
127 | --use-act=$use_act \
128 | --use-gt-act=$use_gt_act \
129 | --use-driver=$use_driver \
130 | --use-pose=$use_pose \
131 | --pos-mode=$pos_mode
132 |
--------------------------------------------------------------------------------
/scripts_dir/test_graph_stip_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mode='evaluate'
4 | split='test'
5 |
6 | gpu_id=1
7 | n_workers=0
8 | n_acts=1
9 | bt=1
10 |
11 | seq_len=4
12 | predict=1
13 | pred_seq_len=4
14 | predict_k=0
15 |
16 | view='all'
17 |
18 | # TODO: get new ped formats
19 | annot_ped_format='/vision/group/prolix/processed/pedestrians_all_views/{}.pkl'
20 | cache_obj_bbox_format='/vision/group/prolix/processed/'$view'/obj_bbox_20fps_merged/{}_seg{}.pkl'
21 |
22 | # test only
23 | slide=0
24 | rand_test=1
25 | log_every=10
26 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
27 | dataset='STIP'
28 |
29 | if [ $seq_len -eq 8 ]
30 | then
31 | ckpt_name='graph_gru_seq'$seq_len'_pred'$pred_seq_len'_lr1.0e-04_wd1.0e-05_bt32_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_decay10Feats_all'
32 | else
33 | # seq_len = 4
34 | ckpt_name='graph_gru_seq'$seq_len'_pred'$pred_seq_len'_lr1.0e-04_wd1.0e-05_bt32_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_decay10Feats'
35 | fi
36 | which_epoch=-1
37 | if [ $which_epoch -eq -1 ]
38 | then
39 | epoch_name='best_pred'
40 | else
41 | epoch_name=$which_epoch
42 | fi
43 | save_output=1
44 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}_'$seq_len'+'$pred_seq_len'.pkl'
45 | collect_A=1
46 | save_As_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/test_graph_weights_epoch'$epoch_name'/vid{}_eval{}.pkl'
47 |
48 |
49 | load_cache='feats'
50 | feat_ckpt_name='concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10'
51 | cache_format='/vision/group/prolix/processed/cache/'$view'/STIP_conv_feats/'$feat_ckpt_name'/{}/ped{}_fid{}.pkl'
52 |
53 |
54 | # if [ "$mode" = "extract" ]
55 | # then
56 | # extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/test/'
57 | # else
58 | # extract_feats_dir='none_existent'
59 | # fi
60 |
61 | use_gru=1
62 | use_trn=0
63 | ped_gru=1
64 | ctxt_gru=0
65 | ctxt_node=0
66 |
67 | # features
68 | use_act=0
69 | use_gt_act=0
70 | use_driver=0
71 | use_pose=0
72 | pos_mode='none'
73 | # branch='ped'
74 | branch='both'
75 | adj_type='spatial'
76 | use_obj_cls=0
77 | n_layers=2
78 | diff_layer_weight=0
79 | collapse_cls=0
80 | combine_method='pair'
81 |
82 | CUDA_VISIBLE_DEVICES=$gpu_id python3 -W ignore test.py \
83 | --model='graph' \
84 | --split=$split \
85 | --view=$view \
86 | --mode=$mode \
87 | --slide=$slide \
88 | --rand-test=$rand_test \
89 | --ckpt-dir=$ckpt_dir \
90 | --ckpt-name=$ckpt_name \
91 | --n-acts=$n_acts \
92 | --device=$gpu_id \
93 | --dset-name='STIP' \
94 | --ckpt-name=$ckpt_name \
95 | --which-epoch=$which_epoch \
96 | --save-output=$save_output \
97 | --save-output-format=$save_output_format \
98 | --collect-A=$collect_A \
99 | --save-As-format=$save_As_format \
100 |   --n-epochs=${n_epochs:-1} \
101 |   --start-epoch=${start_epoch:-0} \
102 | --seq-len=$seq_len \
103 | --predict=$predict \
104 | --pred-seq-len=$pred_seq_len \
105 | --predict-k=$predict_k \
106 | --batch-size=$bt \
107 | --log-every=$log_every \
108 | --n-workers=$n_workers \
109 | --annot-ped-format=$annot_ped_format \
110 | --cache-obj-bbox-format=$cache_obj_bbox_format \
111 | --load-cache=$load_cache \
112 | --cache-format=$cache_format \
113 | --branch=$branch \
114 | --adj-type=$adj_type \
115 | --use-obj-cls=$use_obj_cls \
116 | --n-layers=$n_layers \
117 | --diff-layer-weight=$diff_layer_weight \
118 | --collapse-cls=$collapse_cls \
119 | --combine-method=$combine_method \
120 | --use-gru=$use_gru \
121 | --use-trn=$use_trn \
122 | --ped-gru=$ped_gru \
123 | --ctxt-gru=$ctxt_gru \
124 | --ctxt-node=$ctxt_node \
125 | --use-act=$use_act \
126 | --use-gt-act=$use_gt_act \
127 | --use-driver=$use_driver \
128 | --use-pose=$use_pose \
129 | --pos-mode=$pos_mode
130 |
131 |
132 |
--------------------------------------------------------------------------------
/utils/stats.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import pickle
4 |
5 | import pdb
6 |
7 | data_root = '/sailhome/ajarno/STR-PIP/datasets/'
8 | img_root = os.path.join(data_root, 'JAAD_dataset/JAAD_clip_images')
9 |
10 | def count_frames():
11 | vids = glob(os.path.join(img_root, 'video_*.mp4'))
12 | cnts = []
13 | for vid in vids:
14 | cnt = len(glob(os.path.join(vid, '*.jpg')))
15 | cnts += cnt,
16 |
17 | print('total # frames:', sum(cnts))
18 | print('cnts: max={} / min={} / mean={}'.format(max(cnts), min(cnts), sum(cnts)/len(cnts)))
19 | # pdb.set_trace()
20 |
21 | def ped_count_crossing():
22 | def helper(ped):
23 | act = ped['act']
24 | crossing = [each[1] for each in act]
25 | n_cross = sum(crossing)
26 | n_noncross = len(crossing) - n_cross
27 | return n_cross, n_noncross
28 |
29 | with open(os.path.join(data_root, 'annot_train_ped_withTag_sanityWithPose.pkl'), 'rb') as handle:
30 | train = pickle.load(handle)
31 | with open(os.path.join(data_root, 'annot_test_ped_withTag_sanityWithPose.pkl'), 'rb') as handle:
32 | test = pickle.load(handle)
33 |
34 | n_crosses = 0
35 | n_noncrosses = 0
36 | for ped in train+test:
37 | n_cross, n_noncross = helper(ped)
38 | n_crosses += n_cross
39 | n_noncrosses += n_noncross
40 |
41 | n_total = n_crosses + n_noncrosses
42 | print('n_cross: {} ({:.4f})'.format(n_crosses, n_crosses/n_total))
43 | print('n_noncross: {} ({:.4f})'.format(n_noncrosses, n_noncrosses/n_total))
44 |
45 | def loc_count_crossing():
46 | def helper(split):
47 | split_cross, split_noncross = 0, 0
48 | for vid in split:
49 | n_cross = sum(split[vid]['act'] == 1)[0]
50 | n_noncross = sum(split[vid]['act'] == 0)[0]
51 | split_cross += n_cross
52 | split_noncross += n_noncross
53 | split_total = split_cross + split_noncross
54 | # pdb.set_trace()
55 | print(' cross:{} ({:.4f}) / non-cross:{} ({:.4f}) / total:{}'.format(
56 | split_cross, split_cross / split_total, split_noncross, split_noncross / split_total, split_total))
57 | return split_cross, split_noncross, split_total
58 |
59 | with open(os.path.join(data_root, 'annot_train_loc.pkl'), 'rb') as handle:
60 | train = pickle.load(handle)
61 | print('train:')
62 | train_cross, train_noncross, train_total = helper(train)
63 |
64 | with open(os.path.join(data_root, 'annot_test_loc.pkl'), 'rb') as handle:
65 | test = pickle.load(handle)
66 | print('\ntest:')
67 | test_cross, test_noncross, test_total = helper(test)
68 |
69 | n_cross, n_noncross, n_total = train_cross+test_cross, train_noncross+test_noncross, train_total+test_total
70 | print('\ntotal:')
71 | print(' cross:{} ({:.4f}) / non-cross:{} ({:.4f}) / total:{}'.format(
72 | n_cross, n_cross / n_total, n_noncross, n_noncross / n_total, n_total))
73 |
74 | def loc_count_ped():
75 | def helper(split):
76 | n_peds = []
77 | for vid in split:
78 | for frame in split[vid]['ped_pos']:
79 | n_peds += len(frame),
80 | print(' max:{} / min:{} / mean:{}'.format(max(n_peds), min(n_peds), sum(n_peds)/len(n_peds)))
81 | return n_peds
82 |
83 | with open(os.path.join(data_root, 'annot_train_loc.pkl'), 'rb') as handle:
84 | train = pickle.load(handle)
85 | print('train:')
86 | train_n_peds = helper(train)
87 |
88 | with open(os.path.join(data_root, 'annot_test_loc_new.pkl'), 'rb') as handle:
89 | test = pickle.load(handle)
90 | print('\ntest:')
91 | test_n_peds = helper(test)
92 |
93 | n_peds = train_n_peds + test_n_peds
94 | print('\ntotal:')
95 | print(' max:{} / min:{} / mean:{}'.format(max(n_peds), min(n_peds), sum(n_peds)/len(n_peds)))
96 |
97 |
98 |
99 |
100 | if __name__ == '__main__':
101 | print('count_frames:')
102 | count_frames()
103 |
104 | print('\nped_count_crossing:')
105 | ped_count_crossing()
106 |
107 | print('\nloc_count_crossing:')
108 | loc_count_crossing()
109 |
110 | print('\nloc_count_ped:')
111 | loc_count_ped()
112 |
--------------------------------------------------------------------------------
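A toy check of the counting logic in ped_count_crossing; the annotation layout, with act[t][1] holding the binary crossing flag, is inferred from the helper above:

    ped = {'act': [[0, 1], [0, 1], [0, 0]]}            # 3 frames, crossing in 2
    crossing = [each[1] for each in ped['act']]
    n_cross, n_noncross = sum(crossing), len(crossing) - sum(crossing)
    assert (n_cross, n_noncross) == (2, 1)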
/scripts_dir/test_graph_stip.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mode='evaluate'
4 | split='test'
5 |
6 | gpu_id=0
7 | n_workers=0
8 | n_acts=1
9 | bt=1
10 |
11 | seq_len=8
12 | predict=1
13 | pred_seq_len=6
14 | predict_k=0
15 |
16 | annot_ped_format='/vision/group/prolix/processed/pedestrians/{}.pkl'
17 | cache_obj_bbox_format='/vision/group/prolix/processed/obj_bbox_20fps_merged/{}_seg{}.pkl'
18 |
19 | # test only
20 | slide=0
21 | rand_test=1
22 | log_every=10
23 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
24 | dataset='STIP'
25 |
26 | # pred 10
27 | # ckpt_name='graph_gru_seq30_pred10_lr1.0e-04_wd1.0e-05_bt8_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_decay10Feats'
28 | ckpt_name='graph_gru_seq8_pred2_lr1.0e-04_wd1.0e-05_bt32_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_decay10Feats'
29 |
30 | # pred 30
31 | # ckpt_name='graph_gru_seq30_pred60_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch'
32 |
33 | # pred 60
34 | # ckpt_name='graph_gru_seq30_pred90_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch'
35 |
36 | which_epoch=-1
37 | if [ $which_epoch -eq -1 ]
38 | then
39 | epoch_name='best_pred'
40 | else
41 | epoch_name=$which_epoch
42 | fi
43 | save_output=1
44 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}_'$seq_len'+'$pred_seq_len'.pkl'
45 | collect_A=1
46 | save_As_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/test_graph_weights_epoch'$epoch_name'/vid{}_eval{}.pkl'
47 |
48 |
49 | load_cache='feats'
50 | cache_format='/vision/group/prolix/processed/cache/STIP_conv_feats/concat_gru_seq8_pred2_lr1.0e-05_wd1.0e-05_bt2_posNone_branchboth_collapse0_combinepair_decay10/{}/ped{}_fid{}.pkl'
51 |
52 |
53 | # if [ "$mode" = "extract" ]
54 | # then
55 | # extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/test/'
56 | # else
57 | # extract_feats_dir='none_existent'
58 | # fi
59 |
60 | use_gru=1
61 | use_trn=0
62 | ped_gru=1
63 | ctxt_gru=0
64 | ctxt_node=0
65 |
66 | # features
67 | use_act=0
68 | use_gt_act=0
69 | use_driver=0
70 | use_pose=0
71 | pos_mode='none'
72 | # branch='ped'
73 | branch='both'
74 | adj_type='spatial'
75 | use_obj_cls=0
76 | n_layers=2
77 | diff_layer_weight=0
78 | collapse_cls=0
79 | combine_method='pair'
80 |
81 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
82 | --model='graph' \
83 | --split=$split \
84 | --mode=$mode \
85 | --slide=$slide \
86 | --rand-test=$rand_test \
87 | --ckpt-dir=$ckpt_dir \
88 | --ckpt-name=$ckpt_name \
89 | --n-acts=$n_acts \
90 | --device=$gpu_id \
91 | --dset-name='STIP' \
92 | --ckpt-name=$ckpt_name \
93 | --which-epoch=$which_epoch \
94 | --save-output=$save_output \
95 | --save-output-format=$save_output_format \
96 | --collect-A=$collect_A \
97 | --save-As-format=$save_As_format \
98 |   --n-epochs=${n_epochs:-1} \
99 |   --start-epoch=${start_epoch:-0} \
100 | --seq-len=$seq_len \
101 | --predict=$predict \
102 | --pred-seq-len=$pred_seq_len \
103 | --predict-k=$predict_k \
104 | --batch-size=$bt \
105 | --log-every=$log_every \
106 | --n-workers=$n_workers \
107 | --annot-ped-format=$annot_ped_format \
108 | --cache-obj-bbox-format=$cache_obj_bbox_format \
109 | --load-cache=$load_cache \
110 | --cache-format=$cache_format \
111 | --branch=$branch \
112 | --adj-type=$adj_type \
113 | --use-obj-cls=$use_obj_cls \
114 | --n-layers=$n_layers \
115 | --diff-layer-weight=$diff_layer_weight \
116 | --collapse-cls=$collapse_cls \
117 | --combine-method=$combine_method \
118 | --use-gru=$use_gru \
119 | --use-trn=$use_trn \
120 | --ped-gru=$ped_gru \
121 | --ctxt-gru=$ctxt_gru \
122 | --ctxt-node=$ctxt_node \
123 | --use-act=$use_act \
124 | --use-gt-act=$use_gt_act \
125 | --use-driver=$use_driver \
126 | --use-pose=$use_pose \
127 | --pos-mode=$pos_mode
128 |
129 | exit
130 |
131 |
132 |
--------------------------------------------------------------------------------
/utils/temp_plot.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import pickle
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import matplotlib.pyplot as plt
8 |
9 | import pdb
10 |
11 | SMOOTH = True
12 | MIX = True
13 |
14 | fpkls = []
15 |
16 | ckpt_root = '/sailhome/bingbin/STR-PIP/ckpts/JAAD/'
17 |
18 | # best proposed
19 |
20 | # pred 30
21 |
22 | # ckpt_name = 'graph_gru_seq30_pred30_lr3.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch'
23 | # fout = 'label_epochbest_pred_stepall_run2.pkl'
24 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout)
25 | # fpkls += fpkl,
26 |
27 | fout = 'output_epochbest_pred_stepall.pkl'
28 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout)
29 | # fpkls += fpkl,
30 |
31 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch'
32 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
33 | # fpkls += fpkl,
34 |
35 | # fout = 'output_epochbest_pred_stepall_run3Epochs.pkl'
36 | # fout = 'output_epochbest_pred_stepall_run1Epochs.pkl'
37 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout)
38 | # fpkls += fpkl,
39 |
40 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_newCtxtGRU_3evalEpoch'
41 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
42 | fpkls += fpkl,
43 |
44 | # graph - pred 60
45 | ckpt_name = 'graph_gru_seq30_pred60_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch'
46 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
47 | fpkls += fpkl,
48 |
49 |
50 | # graph - pred 90
51 | ckpt_name = 'graph_gru_seq30_pred90_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch'
52 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
53 | fpkls += fpkl,
54 |
55 |
56 | # best concat (pred 30)
57 | ckpt_name = 'concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU'
58 | fout = 'output_all.pkl'
59 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
60 |
61 | # fpkls += fpkl,
62 |
63 |
64 | for i, fpkl in enumerate(fpkls):
65 | with open(fpkl, 'rb') as handle:
66 | data = pickle.load(handle)
67 | out, gt = data['out'], data['GT']
68 | if out.shape[-1] % 10 == 1:
69 | out = out[:, :-1]
70 | gt = gt[:, :-1]
71 | try:
72 | acc = (out == gt).mean(0)
73 | pred_acc = (out[:, 30:] == gt[:, 30:]).mean()
74 | print('pred_acc:', pred_acc)
75 | except Exception as e:
76 | print(e)
77 | pdb.set_trace()
78 | acc = list(acc)
79 | if SMOOTH:
80 | t1, t3 = acc[1:]+[acc[-1]], [acc[0]]+acc[:-1]
81 | acc = (np.array(t1) + np.array(t3) + np.array(acc)) / 3
82 | # pdb.set_trace()
83 | plt.plot(acc, label='{} frames{}'.format(30*(i+1), ' (acc)' if MIX else ''))
84 | plt.legend()
85 | plt.xlabel('Frames', fontsize=12)
86 | plt.ylabel('Accuracy', fontsize=12)
87 | plt.savefig('acc_temp_graph{}.png'.format('_smooth' if SMOOTH else ''), dpi=1000)
88 | if not MIX:
89 | plt.clf()
90 |
91 |
92 | for i, fpkl in enumerate(fpkls):
93 | with open(fpkl, 'rb') as handle:
94 | data = pickle.load(handle)
95 | prob = data['prob']
96 | B = data['out'].shape[0]
97 | prob = prob.reshape(B, -1)
98 | if prob.shape[-1] % 10 == 1:
99 | prob = prob[:, :-1]
100 | prob = list(prob.mean(0))
101 | if SMOOTH:
102 | t1, t3 = prob[1:]+[prob[-1]], [prob[0]]+prob[:-1]
103 | prob = (np.array(t1) + np.array(t3) + np.array(prob)) / 3
104 | # pdb.set_trace()
105 | plt.plot(prob, label='{} frames{}'.format(30*(i+1), ' (prob)' if MIX else ''))
106 | plt.legend()
107 | plt.xlabel('Frames', fontsize=12)
108 | plt.ylabel('Accuracy & Probability' if MIX else 'Probability', fontsize=12)
109 | plt.savefig('{}_temp_graph{}.png'.format('mix' if MIX else 'prob', '_smooth' if SMOOTH else ''), dpi=1000)
110 | plt.clf()
111 |
112 |
--------------------------------------------------------------------------------
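The SMOOTH branch above is a 3-tap moving average with edge replication; a standalone version for reference:

    import numpy as np

    def smooth3(x):
        # Average each point with its two neighbors, replicating the endpoints.
        x = list(x)
        prev = [x[0]] + x[:-1]
        nxt = x[1:] + [x[-1]]
        return (np.array(prev) + np.array(x) + np.array(nxt)) / 3

    assert np.allclose(smooth3([0, 3, 0]), [1, 1, 1])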
/utils/temp_plot_stip.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import pickle
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import matplotlib.pyplot as plt
8 |
9 | import pdb
10 |
11 | SMOOTH = True
12 | MIX = True
13 |
14 | fpkls = []
15 |
16 | ckpt_root = '/sailhome/bingbin/STR-PIP/ckpts/STIP/'
17 |
18 | # best proposed
19 |
20 | # pred 30
21 |
22 | # ckpt_name = 'graph_gru_seq30_pred30_lr3.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch'
23 | # fout = 'label_epochbest_pred_stepall_run2.pkl'
24 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout)
25 | # fpkls += fpkl,
26 |
27 | fout = 'output_epochbest_pred_stepall.pkl'
28 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout)
29 | # fpkls += fpkl,
30 |
31 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch'
32 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
33 | # fpkls += fpkl,
34 |
35 | # fout = 'output_epochbest_pred_stepall_run3Epochs.pkl'
36 | # fout = 'output_epochbest_pred_stepall_run1Epochs.pkl'
37 | # fpkl = os.path.join(ckpt_root, ckpt_name, fout)
38 | # fpkls += fpkl,
39 |
40 | ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_newCtxtGRU_3evalEpoch'
41 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
42 | fpkls += fpkl,
43 |
44 | # graph - pred 60
45 | ckpt_name = 'graph_gru_seq30_pred60_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch'
46 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
47 | fpkls += fpkl,
48 |
49 |
50 | # graph - pred 90
51 | ckpt_name = 'graph_gru_seq30_pred90_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch'
52 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
53 | fpkls += fpkl,
54 |
55 |
56 | # best concat (pred 30)
57 | ckpt_name = 'concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU'
58 | fout = 'output_all.pkl'
59 | fpkl = os.path.join(ckpt_root, ckpt_name, fout)
60 |
61 | # fpkls += fpkl,
62 |
63 |
64 | for i, fpkl in enumerate(fpkls):
65 | with open(fpkl, 'rb') as handle:
66 | data = pickle.load(handle)
67 | out, gt = data['out'], data['GT']
68 | if out.shape[-1] % 10 == 1:
69 | out = out[:, :-1]
70 | gt = gt[:, :-1]
71 | try:
72 | acc = (out == gt).mean(0)
73 | pred_acc = (out[:, 30:] == gt[:, 30:]).mean()
74 | print('pred_acc:', pred_acc)
75 | except Exception as e:
76 | print(e)
77 | pdb.set_trace()
78 | acc = list(acc)
79 | if SMOOTH:
80 | t1, t3 = acc[1:]+[acc[-1]], [acc[0]]+acc[:-1]
81 | acc = (np.array(t1) + np.array(t3) + np.array(acc)) / 3
82 | # pdb.set_trace()
83 | plt.plot(acc, label='{} frames{}'.format(30*(i+1), ' (acc)' if MIX else ''))
84 | plt.legend()
85 | plt.xlabel('Frames', fontsize=12)
86 | plt.ylabel('Accuracy', fontsize=12)
87 | plt.savefig('acc_temp_graph{}.png'.format('_smooth' if SMOOTH else ''), dpi=1000)
88 | if not MIX:
89 | plt.clf()
90 |
91 |
92 | for i, fpkl in enumerate(fpkls):
93 | with open(fpkl, 'rb') as handle:
94 | data = pickle.load(handle)
95 | prob = data['prob']
96 | B = data['out'].shape[0]
97 | prob = prob.reshape(B, -1)
98 | if prob.shape[-1] % 10 == 1:
99 | prob = prob[:, :-1]
100 | prob = list(prob.mean(0))
101 | if SMOOTH:
102 | t1, t3 = prob[1:]+[prob[-1]], [prob[0]]+prob[:-1]
103 | prob = (np.array(t1) + np.array(t3) + np.array(prob)) / 3
104 | # pdb.set_trace()
105 | plt.plot(prob, label='{} frames{}'.format(30*(i+1), ' (prob)' if MIX else ''))
106 | plt.legend()
107 | plt.xlabel('Frames', fontsize=12)
108 | plt.ylabel('Accuracy & Probability' if MIX else 'Probability', fontsize=12)
109 | plt.savefig('{}_temp_graph{}.png'.format('mix' if MIX else 'prob', '_smooth' if SMOOTH else ''), dpi=1000)
110 | plt.clf()
111 |
112 |
--------------------------------------------------------------------------------
/utils/colormap.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 |
16 | """An awesome colormap for really neat visualizations."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | from __future__ import unicode_literals
22 |
23 | import numpy as np
24 |
25 |
26 | def colormap(rgb=False):
27 | color_list = np.array(
28 | [
29 | 0.000, 0.447, 0.741,
30 | 0.850, 0.325, 0.098,
31 | 0.929, 0.694, 0.125,
32 | 0.494, 0.184, 0.556,
33 | 0.466, 0.674, 0.188,
34 | 0.301, 0.745, 0.933,
35 | 0.635, 0.078, 0.184,
36 | 0.300, 0.300, 0.300,
37 | 0.600, 0.600, 0.600,
38 | 1.000, 0.000, 0.000,
39 | 1.000, 0.500, 0.000,
40 | 0.749, 0.749, 0.000,
41 | 0.000, 1.000, 0.000,
42 | 0.000, 0.000, 1.000,
43 | 0.667, 0.000, 1.000,
44 | 0.333, 0.333, 0.000,
45 | 0.333, 0.667, 0.000,
46 | 0.333, 1.000, 0.000,
47 | 0.667, 0.333, 0.000,
48 | 0.667, 0.667, 0.000,
49 | 0.667, 1.000, 0.000,
50 | 1.000, 0.333, 0.000,
51 | 1.000, 0.667, 0.000,
52 | 1.000, 1.000, 0.000,
53 | 0.000, 0.333, 0.500,
54 | 0.000, 0.667, 0.500,
55 | 0.000, 1.000, 0.500,
56 | 0.333, 0.000, 0.500,
57 | 0.333, 0.333, 0.500,
58 | 0.333, 0.667, 0.500,
59 | 0.333, 1.000, 0.500,
60 | 0.667, 0.000, 0.500,
61 | 0.667, 0.333, 0.500,
62 | 0.667, 0.667, 0.500,
63 | 0.667, 1.000, 0.500,
64 | 1.000, 0.000, 0.500,
65 | 1.000, 0.333, 0.500,
66 | 1.000, 0.667, 0.500,
67 | 1.000, 1.000, 0.500,
68 | 0.000, 0.333, 1.000,
69 | 0.000, 0.667, 1.000,
70 | 0.000, 1.000, 1.000,
71 | 0.333, 0.000, 1.000,
72 | 0.333, 0.333, 1.000,
73 | 0.333, 0.667, 1.000,
74 | 0.333, 1.000, 1.000,
75 | 0.667, 0.000, 1.000,
76 | 0.667, 0.333, 1.000,
77 | 0.667, 0.667, 1.000,
78 | 0.667, 1.000, 1.000,
79 | 1.000, 0.000, 1.000,
80 | 1.000, 0.333, 1.000,
81 | 1.000, 0.667, 1.000,
82 | 0.167, 0.000, 0.000,
83 | 0.333, 0.000, 0.000,
84 | 0.500, 0.000, 0.000,
85 | 0.667, 0.000, 0.000,
86 | 0.833, 0.000, 0.000,
87 | 1.000, 0.000, 0.000,
88 | 0.000, 0.167, 0.000,
89 | 0.000, 0.333, 0.000,
90 | 0.000, 0.500, 0.000,
91 | 0.000, 0.667, 0.000,
92 | 0.000, 0.833, 0.000,
93 | 0.000, 1.000, 0.000,
94 | 0.000, 0.000, 0.167,
95 | 0.000, 0.000, 0.333,
96 | 0.000, 0.000, 0.500,
97 | 0.000, 0.000, 0.667,
98 | 0.000, 0.000, 0.833,
99 | 0.000, 0.000, 1.000,
100 | 0.000, 0.000, 0.000,
101 | 0.143, 0.143, 0.143,
102 | 0.286, 0.286, 0.286,
103 | 0.429, 0.429, 0.429,
104 | 0.571, 0.571, 0.571,
105 | 0.714, 0.714, 0.714,
106 | 0.857, 0.857, 0.857,
107 | 1.000, 1.000, 1.000
108 | ]
109 | ).astype(np.float32)
110 | color_list = color_list.reshape((-1, 3)) * 255
111 | if not rgb:
112 | color_list = color_list[:, ::-1]
113 | return color_list
114 |
--------------------------------------------------------------------------------
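Usage sketch: the palette is returned in BGR order for OpenCV by default; pass rgb=True and rescale for matplotlib, as draw_As.py does below:

    from colormap import colormap   # this module

    colors = colormap(rgb=True) / 255         # normalize to [0, 1] for matplotlib
    for i in range(5):
        c = colors[i % len(colors), 0:3]      # cycle when objects outnumber colors
        print(c)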
/utils/draw_As.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import cv2
3 | import numpy as np
4 | import os
5 | from glob import glob
6 | import matplotlib
7 | matplotlib.use('Agg') # Use a non-interactive backend
8 | import matplotlib.pyplot as plt
9 | from matplotlib.patches import Polygon
10 | import pickle
11 | import random
12 |
13 | from colormap import colormap
14 |
15 | import pdb
16 |
17 | FLIP = 1
18 | TRANS = 0
19 |
20 | ckpt_dir = '/sailhome/bingbin/STR-PIP/ckpts/JAAD/graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_newCtxtGRU_3evalEpoch/'
21 |
22 | cache_dir = os.path.join(ckpt_dir, 'test_graph_weights_epochbest_pred')
23 | out_dir = os.path.join(ckpt_dir, 'vis_out')
24 | os.makedirs(out_dir, exist_ok=True)
25 |
26 | def tmp_vis_one_image():
27 | fpkls = sorted(glob(os.path.join(cache_dir, '*pkl')))
28 | for pi,fpkl in enumerate(fpkls):
29 | if pi < 70:
30 | continue
31 | if pi and pi%10 == 0:
32 | print("{} / {}".format(pi, len(fpkls)))
33 |
34 | with open(fpkl, 'rb') as handle:
35 | vid = pickle.load(handle)
36 | v_ws = vid['ws']
37 | v_bbox = vid['obj_bbox']
38 | img_names = vid['img_paths']
39 |
40 | for i in range(len(v_ws)):
41 | img_name = img_names[i]
42 | fid = os.path.basename(img_name).split('.')[0]
43 | out_name = os.path.join(out_dir, os.path.basename(fpkl).replace('.pkl', '_i{}_f{}.png'.format(i, fid)))
44 | ws = v_ws[i][-1] # take weights from the last graph layer
45 | vis_one_image(img_names[i], out_name, v_bbox[i], weights=ws)
46 |
47 | def vis_one_image(im_name, fout, bboxes, dpi=200, weights=None):
48 |     # bboxes: (x, y, w, h) boxes for one frame; weights: per-object graph weights (optional)
49 | if not len(bboxes):
50 | return
51 |
52 | im = cv2.imread(im_name)
53 | H, W, _ = im.shape
54 | color_list = colormap(rgb=True) / 255
55 |
56 | fig = plt.figure(frameon=False)
57 | fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
58 | ax = plt.Axes(fig, [0., 0., 1., 1.])
59 | ax.axis('off')
60 | fig.add_axes(ax)
61 | ax.imshow(im)
62 |
63 | mask_color_id = 0
64 | if weights is None:
65 |         n_objs = len(bboxes)  # 'masks' was undefined here; count objects from bboxes
66 | obj_ids = range(n_objs)
67 | else:
68 | obj_ids = np.argsort(weights)
69 |
70 | ws = [0]
71 | for oid in obj_ids:
72 | x,y,w,h = bboxes[oid]
73 | mask = np.zeros([H, W])
74 | mask[x:x+w, y:y+h] = 1
75 | mask = mask.astype('uint8')
76 | if mask.sum() == 0:
77 | continue
78 |
79 | if weights is not None:
80 | ws += weights[oid],
81 | color_mask = color_list[mask_color_id % len(color_list), 0:3]
82 | mask_color_id += 1
83 |
84 | w_ratio = .4
85 | for c in range(3):
86 | color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio
87 |
88 | e_down = mask
89 |
90 | e_pil = Image.fromarray(e_down)
91 |       e_pil_up = e_pil.resize((H, W) if TRANS else (W, H), Image.ANTIALIAS)
92 | e = np.array(e_pil_up)
93 |
94 | _, contour, hier = cv2.findContours(e.copy(), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
95 |
96 | if len(contour) > 1:
97 | print('# contour:', len(contour))
98 | for c in contour:
99 | if FLIP:
100 | assert(c.shape[1] == 1 and c.shape[2] == 2), print('c.shape:', c.shape)
101 | for pid in range(c.shape[0]):
102 | c[pid][0][0], c[pid][0][1] = c[pid][0][1], c[pid][0][0]
103 | linewidth = 1.2
104 | alpha = 0.5
105 | if oid == obj_ids[-1]:
106 | # most probable obj
107 | edgecolor=(1,0,0,1) # 'r'
108 | else:
109 | edgecolor=(1,1,1,1) # 'w'
110 | if weights is not None:
111 | linewidth *= (4 ** weights[oid])
112 | alpha /= (4 ** weights[oid])
113 |
114 | polygon = Polygon(
115 | c.reshape((-1, 2)),
116 | fill=True, facecolor=(color_mask[0], color_mask[1], color_mask[2], alpha),
117 | edgecolor=edgecolor, linewidth=linewidth,
118 | )
119 | xy = polygon.get_xy()
120 |
121 | ax.add_patch(polygon)
122 |
123 |     fig.savefig(fout.replace('.png', '_{:.3f}.png'.format(max(ws))), dpi=dpi)  # fout ends in '.png' (see tmp_vis_one_image)
124 | plt.close('all')
125 |
126 |
127 | if __name__ == '__main__':
128 | # tmp_wrapper()
129 | tmp_vis_one_image()
130 |
131 |
--------------------------------------------------------------------------------
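The attention-to-style mapping used in vis_one_image, shown on a few weights: a larger graph weight yields a thicker, more opaque outline:

    for w in (0.0, 0.5, 1.0):
        linewidth = 1.2 * (4 ** w)
        alpha = 0.5 / (4 ** w)
        print('w={:.1f} -> linewidth={:.2f}, alpha={:.3f}'.format(w, linewidth, alpha))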
/models/backbone/imagenet_pretraining.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.model_zoo as model_zoo
3 | try:
4 | from models.backbone.resnet.resnet import model_urls
5 | except ImportError:
6 | from resnet.resnet import model_urls
7 |
8 |
9 |
10 | def _inflate_weight(w, new_temporal_size, inflation='center'):
11 | w_up = w.unsqueeze(2).repeat(1, 1, new_temporal_size, 1, 1)
12 | if inflation == 'center':
13 | w_up = central_inflate_3D_conv(w_up) # center
14 | elif inflation == 'mean':
15 | w_up /= new_temporal_size # mean
16 | return w_up
17 |
18 |
19 | def central_inflate_3D_conv(w):
20 | new_temporal_size = w.size(2)
21 | middle_timestep = int(new_temporal_size / 2.)
22 | before, after = list(range(middle_timestep)), list(range(middle_timestep + 1, new_temporal_size))
23 | if len(before) > 0:
24 | w[:, :, before] = torch.zeros_like(w[:, :, before])
25 | if len(after):
26 | w[:, :, after] = torch.zeros_like(w[:, :, after])
27 | return w
28 |
29 |
30 | def inflate_temporal_conv(pretrained_W_updated, model_dict, inflation):
31 | for k, v in model_dict.items():
32 | if '1t' in k:
33 | if 'conv' in k:
34 | if 'bias' in k:
35 | v_up = torch.zeros_like(v)
36 | elif 'weight' in k:
37 | v_up = torch.zeros_like(v)
38 |
39 |                 h, w, T, *_ = v_up.size()  # (out_ch, in_ch, T, 1, 1) for a '1t' temporal conv weight
40 | t_2 = int(T / 2.)
41 | for i in range(h):
42 | for j in range(w):
43 | if i == j:
44 | if inflation == 'center':
45 | v_up[i, j, t_2] = torch.ones_like(v_up[i, j, t_2])
46 | elif inflation == 'mean':
47 | v_up[i, j] = torch.ones_like(v_up[i, j]) / T
48 | # elif 'bn' in k:
49 | # if 'running_mean' in k:
50 | # v_up = torch.zeros_like(v)
51 | # elif 'running_var' in k:
52 | # v_up = torch.ones_like(v) - 1e-05
53 | # elif 'bias' in k:
54 | # v_up = torch.zeros_like(v)
55 | # elif 'weight' in k:
56 | # v_up = torch.ones_like(v)
57 |
58 |     # update
59 | pretrained_W_updated.update({k: v_up})
60 | return pretrained_W_updated
61 |
62 |
63 | def _update_pretrained_weights(model, pretrained_W, inflation='center'):
64 | pretrained_W_updated = pretrained_W.copy()
65 | model_dict = model.state_dict()
66 | for k, v in pretrained_W.items():
67 | if "conv" in k or ('bn' in k and '1t' not in k) or 'downsample' in k:
68 | if k in model_dict.keys():
69 | if len(model_dict[k].shape) == 5:
70 | new_temporal_size = model_dict[k].size(2)
71 | v_updated = _inflate_weight(v, new_temporal_size, inflation)
72 | else:
73 | v_updated = v
74 |
75 | if isinstance(v, torch.autograd.Variable):
76 | pretrained_W_updated.update({k: v_updated.data})
77 | else:
78 | pretrained_W_updated.update({k: v_updated})
79 | if "fc.weight" in k:
80 | pretrained_W_updated.pop('fc.weight', None)
81 | if "fc.bias" in k:
82 | pretrained_W_updated.pop('fc.bias', None)
83 |
84 | # update the dict for 1D conv for 2.5D conv
85 | pretrained_W_updated = inflate_temporal_conv(pretrained_W_updated, model_dict, inflation)
86 |
87 | # update the state dict
88 | model_dict.update(pretrained_W_updated)
89 |
90 | return model_dict
91 |
92 |
93 | def _keep_only_existing_keys(model, pretrained_weights_inflated):
94 | # Loop over the model_dict and update W
95 | model_dict = model.state_dict() # Take the initial weights
96 | for k, v in model_dict.items():
97 | if k in pretrained_weights_inflated.keys():
98 | model_dict[k] = pretrained_weights_inflated[k]
99 | return model_dict
100 |
101 |
102 | def load_pretrained_2D_weights(arch, model, inflation):
103 | pretrained_weights = model_zoo.load_url(model_urls[arch])
104 | pretrained_weights_inflated = _update_pretrained_weights(model, pretrained_weights, inflation)
105 | model.load_state_dict(pretrained_weights_inflated)
106 | print(" -> Init: Imagenet - 3D from 2D (inflation = {})".format(inflation))
107 | return model
108 |
--------------------------------------------------------------------------------
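A shape-level check of center inflation in _inflate_weight: the 2D filter is copied along a new temporal axis and every slice except the middle one is zeroed, so at initialization the inflated 3D conv reproduces the 2D network frame-wise (a sketch, assuming this module is importable):

    import torch

    w2d = torch.randn(8, 3, 7, 7)                    # 2D conv weight (out, in, kH, kW)
    w3d = _inflate_weight(w2d, new_temporal_size=3)  # -> (out, in, 3, kH, kW)
    assert torch.equal(w3d[:, :, 1], w2d)            # middle slice keeps the 2D weights
    assert w3d[:, :, 0].abs().sum() == 0 and w3d[:, :, 2].abs().sum() == 0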
/utils/tracker.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import json
3 | import numpy as np
4 | import optparse
5 | import os
6 | import time
7 |
8 |
9 | def check_color(crossed):
10 | if crossed:
11 | return (0, 0, 255)
12 | return (0, 255, 0)
13 |
14 |
15 | def create_rect(box):
16 | x1, y1 = int(box['x1']), int(box['y1'])
17 | x2, y2 = int(box['x2']), int(box['y2'])
18 |
19 | return x1, y1, x2, y2
20 |
21 |
22 | def create_writer(capture):
23 | fourcc = cv2.VideoWriter_fourcc(*'X264')
24 | height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
25 | width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
26 |     frame_count = int(capture.get(cv2.CAP_PROP_FPS))  # source FPS (unused: the writer below is fixed at 2 FPS)
27 |
28 | writer = cv2.VideoWriter('output.mp4',
29 | fourcc,
30 | 2,
31 | (width, height))
32 |
33 | return writer
34 |
35 |
36 | def get_params(frame_data):
37 | boxes = [f['box'] for f in frame_data]
38 | ids = [f['matchIds'] for f in frame_data]
39 | crossed = [f['crossed'] for f in frame_data]
40 |
41 | return boxes, ids, crossed
42 |
43 |
44 | def parse_options():
45 | parser = optparse.OptionParser()
46 | parser.add_option('-v', '--video',
47 | dest='video_path',
48 | default=None)
49 | parser.add_option('-j', '--json',
50 | dest='json_path',
51 | default=None)
52 | parser.add_option('-u', '--until',
53 | type='int',
54 | default=None)
55 | parser.add_option('-w', '--write',
56 | dest='write',
57 | action='store_true',
58 | default=False)
59 | parser.add_option('-m', '--mask',
60 | dest='mask',
61 | action='store_true',
62 | default=False)
63 | options, remainder = parser.parse_args()
64 |
65 | # Check for errors.
66 | if options.video_path is None:
67 | raise Exception('Undefined video')
68 | if options.json_path is None:
69 | raise Exception('Undefined json_file')
70 |
71 | return options
72 |
73 |
74 | def Main():
75 | options = parse_options()
76 |
77 | # Open VideoCapture.
78 | cap = cv2.VideoCapture(options.video_path)
79 |
80 | # Load json file with annotations.
81 | with open(options.json_path, 'r') as f:
82 | data = json.load(f)['frames']
83 |
84 | if options.write:
85 | writer = create_writer(cap)
86 |
87 | font = cv2.FONT_HERSHEY_SIMPLEX
88 | frame_no = 1
89 | while True:
90 |
91 | wait_key = 25
92 | flag, img = cap.read()
93 |         if not flag: break  # cap.read() failed (end of stream); img would be None below
94 | # Create black image.
95 | black_img = np.zeros(img.shape, dtype=np.uint8)
96 |
97 | if frame_no % 120 == 0:
98 | print('Processed {0} frames'.format(frame_no))
99 |
100 | if frame_no % 6 != 0:
101 | frame_no += 1
102 | continue
103 |
104 | key = str(int(frame_no / 6 + 1))
105 |
106 | boxes = data.get(key)
107 |
108 |         if boxes is None:
109 | boxes = []
110 |
111 | # Create list of trackers each 60 frames.
112 | boxes, ids, crossed = get_params(boxes)
113 |
114 | for i, box in enumerate(boxes):
115 | x1, y1, x2, y2 = create_rect(box)
116 |
117 | if options.mask:
118 | roi = img[y1:y2, x1:x2].copy()
119 | black_img[y1:y2, x1:x2] = roi
120 |
121 | if not options.mask:
122 | crossed_color = check_color(crossed[i])
123 | cv2.rectangle(img, (x1, y1), (x2, y2), crossed_color, 2, 1)
124 | cv2.putText(img, ids[i], (x1, y1 - 10), font, 0.6,
125 | (0, 0, 0), 5, cv2.LINE_AA)
126 | cv2.putText(img, ids[i], (x1, y1 - 10), font, 0.6,
127 | crossed_color, 1, cv2.LINE_AA)
128 |
129 | if options.write:
130 | if options.mask:
131 | writer.write(black_img)
132 | else:
133 | writer.write(img)
134 | else:
135 | if options.mask:
136 | cv2.imshow('frame', black_img)
137 | else:
138 | cv2.imshow('frame', img)
139 |
140 | if cv2.waitKey(wait_key) & 0xFF == ord('q'):
141 | break
142 | if frame_no == options.until:
143 | break
144 |
145 | if flag is False:
146 | break
147 |
148 | frame_no += 1
149 |
150 | cap.release()
151 | if options.write:
152 | writer.release()
153 |
154 |
155 | if __name__ == '__main__':
156 | Main()
157 |
--------------------------------------------------------------------------------
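The frame-to-annotation key mapping in Main(), made explicit: only every 6th video frame is processed, and annotation keys are 1-based strings:

    for frame_no in (6, 12, 600):
        key = str(int(frame_no / 6 + 1))
        print(frame_no, '->', key)   # 6 -> '2', 12 -> '3', 600 -> '101'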
/models/backbone/resnet/bottleneck.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Bottleneck(nn.Module):
5 | expansion = 4
6 | only_2D = False
7 |
8 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
9 | super(Bottleneck, self).__init__()
10 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
11 | self.bn1 = nn.BatchNorm3d(planes)
12 | self.conv2 = None
13 | self.bn2 = nn.BatchNorm3d(planes)
14 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
15 | self.bn3 = nn.BatchNorm3d(planes * 4)
16 | self.relu = nn.ReLU(inplace=True)
17 | self.downsample = downsample
18 | self.stride = stride
19 | self.input_dim = 5
20 | self.dilation = dilation
21 |
22 | def forward(self, x):
23 | residual = x
24 |
25 | # print('Bottleneck devices:')
26 | # print('current device:', torch.cuda.current_device())
27 | # print('x:', x.device)
28 | # print('self.conv1:', torch.cuda.device_of(self.conv1))
29 | out = self.conv1(x)
30 | out = self.bn1(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv2(out)
34 | out = self.bn2(out)
35 | out = self.relu(out)
36 |
37 | out = self.conv3(out)
38 | out = self.bn3(out)
39 |
40 | if self.downsample is not None:
41 | residual = self.downsample(x)
42 |
43 | out += residual
44 | out = self.relu(out)
45 |
46 | return out
47 |
48 |
49 | class Bottleneck3D(Bottleneck):
50 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, **kwargs):
51 | super().__init__(inplanes, planes, stride, downsample, dilation)
52 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
53 |                  padding=(1, dilation, dilation), bias=False, dilation=(1, dilation, dilation))
54 |
55 |
56 | class Bottleneck2D(Bottleneck):
57 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, **kwargs):
58 | super().__init__(inplanes, planes, stride, downsample, dilation)
59 | # to speed up the inference process
60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61 | self.bn1 = nn.BatchNorm2d(planes)
62 | self.bn2 = nn.BatchNorm2d(planes)
63 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)  # dilation on a 1x1 conv is a no-op; moved to conv2
64 | self.bn3 = nn.BatchNorm2d(planes * 4)
65 | self.input_dim = 4
66 |
67 | if isinstance(stride, int):
68 | stride_1, stride_2 = stride, stride
69 | else:
70 | stride_1, stride_2 = stride[0], stride[1]
71 |
72 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=(3, 3), stride=(stride_1, stride_2),
73 |                  padding=(dilation, dilation), bias=False, dilation=dilation)
74 |
75 |
76 | class Bottleneck2_1D(Bottleneck):
77 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, nb_temporal_conv=1):
78 | super().__init__(inplanes, planes, stride, downsample, dilation)
79 |
80 | if isinstance(stride, int):
81 | stride_2d, stride_1t = (1, stride, stride), (stride, 1, 1)
82 | else:
83 | stride_2d, stride_1t = (1, stride[1], stride[2]), (stride[0], 1, 1)
84 |
85 | # CONV2
86 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=(1, 3, 3), stride=stride_2d,
87 | padding=(0, dilation, dilation), bias=False, dilation=dilation)
88 |
89 | self.conv2_1t = nn.Sequential()
90 | for i in range(nb_temporal_conv):
91 | temp_conv = nn.Conv3d(planes, planes, kernel_size=(3, 1, 1), stride=stride_1t,
92 | padding=(1, 0, 0), bias=False, dilation=1)
93 | self.conv2_1t.add_module('temp_conv_{}'.format(i), temp_conv)
94 | self.conv2_1t.add_module(('relu_{}').format(i), nn.ReLU(inplace=True))
95 |
96 |
97 | def forward(self, x):
98 | residual = x
99 |
100 | ## CONV1 - 3D (1,1,1)
101 | out = self.conv1(x)
102 | out = self.bn1(out)
103 | out = self.relu(out)
104 |
105 | ## CONV2
106 | # Spatial - 2D (1,3,3)
107 | out = self.conv2(out)
108 | out = self.bn2(out)
109 | out = self.relu(out)
110 |
111 | # Temporal - 3D (3,1,1)
112 | out = self.conv2_1t(out)
113 |
114 | ## CONV3 - 3D (1,1,1)
115 | out = self.conv3(out)
116 | out = self.bn3(out)
117 |
118 | if self.downsample is not None:
119 | residual = self.downsample(x)
120 |
121 | out += residual
122 | out = self.relu(out)
123 |
124 | return out
125 |
--------------------------------------------------------------------------------
/utils/masking.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import json
3 | import optparse
4 | import os
5 | import time
6 | import numpy as np
7 | import IPython
8 |
9 | MACHINE = 'TRI'
10 | # MACHINE = 'Stanford'
11 |
12 | if MACHINE == 'TRI':
13 | # downsampled from:'/mnt/parallel/stip/ANN_hanh2/20170907_prolix_trial_ANN_hanh2-09-07-2017_15-44-07_idx00.mkv'
14 | test_video_in = '/mnt/parallel/stip/ANN_hanh2/ANN_hanh2_downsample_fps12.mkv'
15 | test_video_out = 'ANN_hanh2_fps12_masked.mp4'
16 | test_annot = '/mnt/parallel/stip/annotation/20170907_prolix_trial_ANN_hanh2-09-07-2017_15-44-07.concat.12fps.mp4.json'
17 | else:
18 | test_video_in = '/sailhome/bingbin/STR-PIP/datasets/STIP/ANN_hanh2_downsampled.mkv'
19 | test_video_out = 'ANN_hanh2_fps12_masked.mp4'
20 | test_annot = '/sailhome/bingbin/STR-PIP/datasets/STIP/annotations/20170907_prolix_trial_ANN_hanh2-09-07-2017_15-44-07.concat.12fps.mp4.json'
21 |
22 |
23 | def check_color(crossed):
24 | if crossed:
25 | return (0, 0, 255)
26 | return (0, 255, 0)
27 |
28 |
29 | def create_rect(box):
30 | x1, y1 = int(box['x1']), int(box['y1'])
31 | x2, y2 = int(box['x2']), int(box['y2'])
32 |
33 | return x1, y1, x2, y2
34 |
35 |
36 | def create_writer(capture, options):
37 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
38 | height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
39 | width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
40 | fps = int(capture.get(cv2.CAP_PROP_FPS))  # source fps (informational only; output is written at a fixed rate below)
41 |
42 | print('create_writer:', options.saved_video_path)
43 | writer = cv2.VideoWriter(options.saved_video_path,
44 | fourcc,
45 |                          2,  # output at 2 fps (matches the dataset's 2 fps annotation rate)
46 | (width, height))
47 |
48 | return writer
49 |
50 |
51 | def get_params(frame_data):
52 | boxes = [f['box'] for f in frame_data]
53 | ids = [f['matchIds'] for f in frame_data]
54 | crossed = [f['crossed'] for f in frame_data]
55 |
56 | return boxes, ids, crossed
57 |
58 |
59 | def parse_options():
60 | parser = optparse.OptionParser()
61 | parser.add_option('-v', '--video',
62 | dest='video_path',
63 | default=test_video_in)
64 | parser.add_option('-j', '--json',
65 | dest='json_path',
66 | default=test_annot)
67 | parser.add_option('-u', '--until',
68 | type='int',
69 | default=None)
70 | parser.add_option('-w', '--write_video',
71 | dest='saved_video_path',
72 | default=test_video_out)
73 | options, remainder = parser.parse_args()
74 |
75 | # Check for errors.
76 | if options.video_path is None:
77 | raise Exception('Undefined video')
78 | if options.json_path is None:
79 | raise Exception('Undefined json_file')
80 |
81 | return options
82 |
83 |
84 | def Main():
85 | options = parse_options()
86 |
87 | # Open VideoCapture.
88 | cap = cv2.VideoCapture(options.video_path)
89 |
90 | # Load json file with annotations.
91 | with open(options.json_path, 'r') as f:
92 | data = json.load(f)['frames']
93 |
94 | lastKeyFrame = int(list(data.keys())[-1])
95 |
96 | writer = create_writer(cap, options)
97 |
98 | font = cv2.FONT_HERSHEY_SIMPLEX
99 | frame_no = 1
100 | while True:
101 | wait_key = 25
102 | flag, img = cap.read()
103 | if not flag: break  # end of video: img would be None below
104 | if frame_no % 120 == 0: print('Processed {0} frames'.format(frame_no))
105 |
106 | if frame_no % 6 != 0:
107 | frame_no += 1
108 | continue
109 |
110 | key = str(int(frame_no / 6 + 1))
111 |
112 | boxes = data.get(key)
113 |
114 | if boxes is None:
115 | boxes = []
116 |
117 | # Unpack per-frame boxes, ids, and crossing labels.
118 | boxes, ids, crossed = get_params(boxes)
119 | mask = np.zeros(img.shape)
120 | for i, box in enumerate(boxes):
121 | x1, y1, x2, y2 = create_rect(box)
122 | if ids[i] == '4' or ids[i] == '16':
123 | print(frame_no, key, box)
124 | print((x1, y1, x2, y2))
125 |
126 | mask[y1:y2,x1:x2, :] = 1
127 | # crossed_color = check_color(crossed[i])
128 | # cv2.rectangle(img, (x1, y1), (x2, y2), crossed_color, 2, 1)
129 | if '4' in ids or '16' in ids:
130 | wait_key = 0
131 |
132 |
133 | # print('writing')
134 | writer.write(np.uint8(img*mask))
135 |
136 | if frame_no == options.until:
137 | break
138 |
139 | if frame_no > lastKeyFrame:
140 | break
141 |
142 | frame_no += 1
143 |
144 | cap.release()
145 | writer.release()
146 |
147 |
148 | if __name__ == '__main__':
149 | Main()
150 |
--------------------------------------------------------------------------------
/cache_data.py:
--------------------------------------------------------------------------------
1 | # python imports
2 | import os
3 | import sys
4 | from glob import glob
5 | import pickle
6 | import time
7 | import numpy as np
8 |
9 |
10 | # local imports
11 | import data
12 | import utils
13 | from utils.data_proc import parse_objs
14 |
15 | import pdb
16 |
17 |
18 | def cache_masks():
19 | opt, logger = utils.build(is_train=False)
20 | opt.combine_method = ''
21 | opt.split = 'train'
22 | cache_dir_name = 'jaad_collapse{}'.format('_'+opt.combine_method if opt.combine_method else '')
23 | data.cache_all_objs(opt, cache_dir_name)
24 |
25 |
26 | def cache_crops():
27 | fnpy_root = '/sailhome/bingbin/STR-PIP/datasets/JAAD_instance_segm'
28 | fpkl_root = '/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_instance_crops'
29 | utils.get_obj_crops(fnpy_root, fpkl_root)
30 |
31 | def add_obj_bbox():
32 | fnpy_root = '/sailhome/bingbin/STR-PIP/datasets/JAAD_instance_segm'
33 | # fpkl_root = '/sailhome/bingbin/STR-PIP/datasets/cache/jaad_collapse'
34 | fobj_root = '/sailhome/bingbin/STR-PIP/datasets/cache/obj_bbox'
35 | os.makedirs(fobj_root, exist_ok=True)
36 | dir_vids = sorted(glob(os.path.join(fnpy_root, 'vid*')))
37 |
38 | def helper(vid_range, split):
39 | for dir_vid in vid_range:
40 | print(dir_vid)
41 | sys.stdout.flush()
42 | vid = int(os.path.basename(dir_vid).split('_')[1])
43 | t_start = time.time()
44 | fsegms = sorted(glob(os.path.join(dir_vid, '*_segm.npy')))
45 | for i, fsegm in enumerate(fsegms):
46 | if i and i%100 == 0:
47 | print('Time per frame:', (time.time() - t_start)/i)
48 | sys.stdout.flush()
49 | # Note: 'fid' is 0-based for segm, but 1-based for images and caches.
50 | fid = os.path.basename(fsegm).split('_')[0]
51 | fbbox = os.path.join(fobj_root, 'vid{:08d}_fid{:s}.pkl'.format(vid, fid))
52 | if os.path.exists(fbbox):
53 | continue
54 | if not os.path.exists(fsegm):
55 | print('File does not exist:', fsegm)
56 | continue
57 | objs = parse_objs(fsegm)
58 | dobjs = {cls:[] for cls in range(1,5)}
59 | for cls, masks in objs.items():
60 | for mask in masks:
61 | try:
62 | if len(mask.shape) == 3:
63 | h, w, c = mask.shape
64 | if c != 1:
65 | raise ValueError('Each mask should have shape (1080, 1920, 1)')
66 | mask = mask.reshape(h, w)
67 | x_pos = mask.sum(0).nonzero()[0]
68 | if not len(x_pos):
69 | x_pos = [0,0]
70 | x_min, x_max = x_pos[0], x_pos[-1]
71 | y_pos = mask.sum(1).nonzero()[0]
72 | if not len(y_pos):
73 | y_pos = [0,0]
74 | y_min, y_max = y_pos[0], y_pos[-1]
75 | # bbox: [x_min, y_min, w, h]; same as bbox for ped['pos_GT']
76 | bbox = [x_min, y_min, x_max-x_min, y_max-y_min]
77 | except Exception as e:
78 | print(e)
79 | pdb.set_trace()
80 |
81 | dobjs[cls] += bbox,
82 | with open(fbbox, 'wb') as handle:
83 | pickle.dump(dobjs, handle)
84 |
85 | if False:
86 | vids_train = dir_vids[:250]
87 | helper(vids_train, 'train')
88 | if True:
89 | vids_test = dir_vids[250:]
90 | helper(vids_test, 'test')
91 |
92 | def merge_and_flat(vrange):
93 | """
94 | Merge fids in a vid and flatten the classes
95 | """
96 | pkl_in_root = '/sailhome/bingbin/STR-PIP/datasets/cache/obj_bbox'
97 | pkl_out_root = '/sailhome/bingbin/STR-PIP/datasets/cache/obj_bbox_merged'
98 | os.makedirs(pkl_out_root, exist_ok=True)
99 | # for vid in range(1, 347):
100 | for vid in vrange:
101 | fpkls = sorted(glob(os.path.join(pkl_in_root, 'vid{:08d}*pkl'.format(vid))))
102 | print(vid, len(fpkls))
103 | sys.stdout.flush()
104 | # merged = [[] for _ in range(len(fpkls))]
105 | merged_bbox = []
106 | merged_cls = []
107 | t_start = time.time()
108 | for fpkl in fpkls:
109 | with open(fpkl, 'rb') as handle:
110 | data = pickle.load(handle)
111 | curr_bbox = []
112 | cls = []
113 | for c in [1,2,3,4]:
114 | for bbox in data[c]:
115 | cls += c,
116 | curr_bbox += bbox,
117 | merged_bbox += np.array(curr_bbox),
118 | merged_cls += np.array(cls),
119 |
120 | fpkl_out = os.path.join(pkl_out_root, 'vid{:08d}.pkl'.format(vid))
121 | with open(fpkl_out, 'wb') as handle:
122 | dout = {
123 | 'obj_cls': merged_cls,
124 | 'obj_bbox': merged_bbox,
125 | }
126 | pickle.dump(dout, handle)
127 | print('avg time: ', (time.time()-t_start) / len(fpkls))
128 |
129 | if __name__ == '__main__':
130 | # cache_masks()
131 | cache_crops()
132 | add_obj_bbox()
133 |
134 | # merge_and_flat(range(1, 200))
135 | # merge_and_flat(range(100, 200))
136 | # merge_and_flat(range(200, 300))
137 | # merge_and_flat(range(100, 347))
138 | merge_and_flat(range(1, 347))
139 |
140 |
--------------------------------------------------------------------------------
/models/backbone/resnet/basicblock.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
4 | "3x3 convolution with padding"
5 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
6 |                  padding=dilation, bias=False, dilation=dilation)
7 |
8 |
9 | def conv3x3x3(in_planes, out_planes, stride=1, dilation=1):
10 | "3x3 convolution with padding"
11 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
12 |                  padding=(1, dilation, dilation), bias=False, dilation=(1, dilation, dilation))
13 |
14 |
15 | def conv1x3x3(in_planes, out_planes, stride=1, dilation=1):
16 | "3x3 convolution with padding"
17 | if isinstance(stride, int):
18 | stride_1, stride_2, stride_3 = 1, stride, stride
19 | else:
20 | stride_1, stride_2, stride_3 = 1, stride[1], stride[2]
21 |
22 | return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3),
23 | stride=(stride_1, stride_2, stride_3),
24 |                  padding=(0, dilation, dilation), bias=False, dilation=(1, dilation, dilation))
25 |
26 |
27 | def conv1x3x3_conv3x1x1(in_planes, out_planes, stride=1, dilation=1, nb_temporal_conv=3):
28 | "3x3 convolution with padding"
29 | if isinstance(stride, int):
30 | stride_2d, stride_1t = (1, stride, stride), (stride, 1, 1)
31 | else:
32 | stride_2d, stride_1t = (1, stride[1], stride[2]), (stride[0], 1, 1)
33 |
34 | _2d = nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride_2d,
35 |                  padding=(0, dilation, dilation), bias=False, dilation=(1, dilation, dilation))
36 |
37 | _1t = nn.Sequential()
38 | for i in range(nb_temporal_conv):
39 | temp_conv = nn.Conv3d(out_planes, out_planes, kernel_size=(3, 1, 1), stride=stride_1t,
40 | padding=(1, 0, 0), bias=False, dilation=1)
41 | _1t.add_module('temp_conv_{}'.format(i), temp_conv)
42 | _1t.add_module(('relu_{}').format(i), nn.ReLU(inplace=True))
43 |
44 | return _2d, _1t
45 |
46 |
47 | class BasicBlock(nn.Module):
48 | expansion = 1
49 | only_2D = False
50 |
51 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
52 | super(BasicBlock, self).__init__()
53 | self.bn1 = nn.BatchNorm3d(planes)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.bn2 = nn.BatchNorm3d(planes)
56 | self.downsample = downsample
57 | self.stride = stride
58 | self.conv1, self.conv2 = None, None
59 | self.input_dim = 5
60 | self.dilation = dilation
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 |
72 | if self.downsample is not None:
73 | residual = self.downsample(x)
74 |
75 | out += residual
76 | out = self.relu(out)
77 |
78 | return out
79 |
80 |
81 | class BasicBlock3D(BasicBlock):
82 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, **kwargs):
83 | super().__init__(inplanes, planes, stride, downsample, dilation)
84 | self.conv1 = conv3x3x3(inplanes, planes, stride, dilation)
85 | self.conv2 = conv3x3x3(planes, planes, dilation=dilation)  # fix: dilation was passed positionally as stride
86 |
87 |
88 | class BasicBlock2D(BasicBlock):
89 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, **kwargs):
90 | super().__init__(inplanes, planes, stride, downsample, dilation)
91 | # not the same input size here to speed up training
92 | self.bn1 = nn.BatchNorm2d(planes)
93 | self.bn2 = nn.BatchNorm2d(planes)
94 | self.conv1 = conv3x3(inplanes, planes, stride, dilation)
95 | self.conv2 = conv3x3(planes, planes, dilation=dilation)  # fix: dilation was passed positionally as stride
96 | self.input_dim = 4
97 |
98 |
99 | class BasicBlock2_1D(BasicBlock):
100 | expansion = 1
101 |
102 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, nb_temporal_conv=1):
103 | super().__init__(inplanes, planes, stride, downsample, dilation)
104 | self.conv1, self.conv1_1t = conv1x3x3_conv3x1x1(inplanes, planes, stride, dilation,
105 | nb_temporal_conv=nb_temporal_conv)
106 | self.conv2, self.conv2_1t = conv1x3x3_conv3x1x1(planes, planes, dilation=dilation,
107 |                  nb_temporal_conv=nb_temporal_conv)  # fix: dilation was passed positionally as stride
108 |
109 |
110 | def forward(self, x):
111 | residual = x
112 |
113 | out = self.conv1(x) # 2D in space
114 | out = self.bn1(out)
115 | out = self.relu(out)
116 |
117 | # ipdb.set_trace()
118 | out = self.conv1_1t(out) # 1D in time + relu after each conv
119 |
120 | out = self.conv2(out) # 2D in space
121 | out = self.bn2(out)
122 |
123 | if self.downsample is not None:
124 | residual = self.downsample(x)
125 |
126 | out += residual
127 | out = self.relu(out)
128 |
129 | out = self.conv2_1t(out) # 1D in time
130 |
131 | return out
132 |
--------------------------------------------------------------------------------
/utils/postproc_smooth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 |
5 | import pdb
6 |
7 | def smooth_vote(fpkl, win_size=5):
8 | with open(fpkl, 'rb') as handle:
9 | data = pickle.load(handle)
10 | out = data['out']
11 | gt = data['GT']
12 | acc = (out == gt).mean()
13 |
14 | smooth = np.zeros_like(out)
15 | win_half = win_size // 2
16 | smooth[:, :win_half] = out[:, :win_half]
17 |
18 | z2o, o2z = 0, 0
19 | for i,row in enumerate(out):
20 | for j in range(win_half, len(row)):
21 | # vote within the window
22 | smooth[i, j] = 1 if row[j-win_half:j+win_half+1].mean() >= 0.5 else 0  # inclusive, centered window of win_size frames
23 | if smooth[i,j] != out[i,j]:
24 | if out[i,j] == 0:
25 | z2o += 1
26 | else:
27 | o2z += 1
28 | acc_smooth = (smooth == gt).mean()
29 |
30 | print('Acc:', acc)
31 | print('Acc smooth:', acc_smooth)
32 |
33 | print('o2z:', o2z)
34 | print('z2o:', z2o)
35 |
36 | def smooth_flip(fpkl):
37 | with open(fpkl, 'rb') as handle:
38 | data = pickle.load(handle)
39 | out = data['out']
40 | gt = data['GT']
41 | acc = (out == gt).mean()
42 |
43 | smooth = np.zeros_like(out)
44 | smooth[:, 0], smooth[:, -1] = out[:, 0], out[:, -1]  # copy boundary frames; the loop below skips them
45 |
46 | z2o, o2z = 0, 0
47 | for i,row in enumerate(out):
48 | for j in range(1, len(row)-1):
49 | # check with the neighbors
50 | if row[j] != row[j-1] and row[j] != row[j+1]:
51 | smooth[i,j] = row[j-1]
52 | if row[j-1] == 0:
53 | o2z += 1
54 | else:
55 | z2o += 1
56 | else:
57 | smooth[i,j] = row[j]
58 | acc_smooth = (smooth == gt).mean()
59 |
60 | print('Acc:', acc)
61 | print('Acc smooth:', acc_smooth)
62 |
63 | print('o2z:', o2z)
64 | print('z2o:', z2o)
65 |
66 | def smooth_sticky(fpkl, win_size=5):
67 | with open(fpkl, 'rb') as handle:
68 | data = pickle.load(handle)
69 | out = data['out']
70 | gt = data['GT']
71 | acc = (out == gt).mean()
72 |
73 | smooth = np.zeros_like(out)
74 | smooth[:, :win_size] = out[:, :win_size]
75 |
76 | for i,row in enumerate(out):
77 | for j in range(win_size, len(row)):
78 | if smooth[i, j-win_size:j].mean() > 0.5:
79 | smooth[i, j] = 1
80 | else:
81 | smooth[i,j] = row[j]
82 | acc_smooth = (smooth == gt).mean()
83 |
84 | print('Acc:', acc)
85 | print('Acc smooth:', acc_smooth)
86 |
87 | o2z = ((smooth!=out) * (out==1)).sum()
88 | z2o = ((smooth!=out) * (out==0)).sum()
89 | print('o2z:', o2z)
90 | print('z2o:', z2o)
91 |
92 | print('GT o:', gt.mean())
93 | print('GT nelem:', gt.size)
94 |
95 | def smooth_hold(fpkl):
96 | # Stay 1 once gets to 1.
97 | with open(fpkl, 'rb') as handle:
98 | data = pickle.load(handle)
99 | out = data['out'][:, :-1]
100 | gt = data['GT'][:, :-1]
101 | print('Acc det:', (out[:, :30] == gt[:, :30]).mean())
102 |
103 | out = out[:, 30:]
104 | gt = gt[:, 30:]
105 | acc = (out == gt).mean()
106 |
107 | # pdb.set_trace()
108 |
109 | smooth = out.copy()
110 | for i,row in enumerate(smooth):
111 | # pdb.set_trace()
112 | pos = np.where(row == 1)[0]
113 | if len(pos):
114 | first = pos[0]
115 | row[first:] = 1
116 | acc_smooth = (smooth == gt).mean()
117 |
118 | print('Acc pred:', acc)
119 | print('Acc pred smooth:', acc_smooth)
120 | print('Acc last:', (smooth[:, -1] == gt[:, -1]).mean())
121 |
122 | def smooth_hold_lastObserve(fpkl):
123 | # Stay 1 once gets to 1.
124 | with open(fpkl, 'rb') as handle:
125 | data = pickle.load(handle)
126 | out = data['out'][:, :-1]
127 | gt = data['GT'][:, :-1]
128 | print('Acc det:', (out[:, :30] == gt[:, :30]).mean())
129 |
130 | last_observe = out[:, 29]
131 | print('prob of last==1:', last_observe.mean())
132 | out = out[:, 30:]
133 | gt = gt[:, 30:]
134 | acc = (out == gt).mean()
135 |
136 | # pdb.set_trace()
137 |
138 | smooth = out.copy()
139 | for i,row in enumerate(smooth):
140 | if last_observe[i] == 1:
141 | row[:] = 1
142 | acc_smooth = (smooth == gt).mean()
143 |
144 | print('Acc pred:', acc)
145 | print('Acc pred smooth:', acc_smooth)
146 | print('Acc last:', (smooth[:, -1] == gt[:, -1]).mean())
147 |
148 |
149 | def smooth_wrapper():
150 | # ped-centric model
151 | ckpt_dir = '/sailhome/bingbin/STR-PIP/ckpts/JAAD/'
152 | # ckpt_name = 'graph_gru_seq30_pred30_lr1.0e-05_wd1.0e-05_bt16_posNone_branchped_collapse0_combinepair_adjTypeembed_nLayers2_v2Feats'
153 | # label_name = 'label_epochbest_det_stepall.pkl'
154 | ckpt_name = 'graph_gru_seq30_pred30_lr3.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch'
155 | label_name = 'label_epochbest_pred_stepall.pkl'
156 |
157 | fpkl = os.path.join(ckpt_dir, ckpt_name, label_name)
158 |
159 | win_size = 3
160 | print('Vote (win_size={}):'.format(win_size))
161 | smooth_vote(fpkl, win_size=win_size)
162 |
163 | print('\nFlip:')
164 | smooth_flip(fpkl)
165 |
166 | print('\nSticky:')
167 | smooth_sticky(fpkl)
168 |
169 | print('\nHold:')
170 | smooth_hold(fpkl)
171 |
172 | print('\nHold last observed:')
173 | smooth_hold_lastObserve(fpkl)
174 |
175 | if __name__ == '__main__':
176 | smooth_wrapper()
177 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Spatiotemporal Relationship Reasoning for Pedestrian Intent Prediction (STR-PIP)
2 |
3 | **_Paper: https://arxiv.org/pdf/2002.08945.pdf_**
4 |
5 | ```
6 | @inproceedings{liu2020spatiotemporal,
7 | title={Spatiotemporal Relationship Reasoning for Pedestrian Intent Prediction},
8 | author={Bingbin Liu and Ehsan Adeli and Zhangjie Cao and Kuan-Hui Lee and Abhijeet Shenoi and Adrien Gaidon and Juan Carlos Niebles},
9 | year={2020},
10 | booktitle={IEEE Robotics and Automation Letters (IEEE RA-L) and International Conference on Robotics and Automation (ICRA)},
11 | publisher={IEEE}
12 | }
13 | ```
14 |
15 | ## Abstract
16 |
17 |
18 | Reasoning over visual data is a desirable capability for robotics and vision-based applications. Such reasoning enables forecasting the next events or actions in videos. In recent years, various models have been developed based on convolution operations for prediction or forecasting, but they lack the ability to reason over spatiotemporal data and infer the relationships of different objects in the scene. In this paper, we present a framework based on graph convolution to uncover the spatiotemporal relationships in the scene for reasoning about pedestrian intent. A scene graph is built on top of segmented object instances within and across video frames. Pedestrian intent, defined as the future action of crossing or not-crossing the street, is a crucial piece of information for autonomous vehicles to navigate safely and more smoothly. We approach the problem of intent prediction from two different perspectives and anticipate the intention-to-cross within both pedestrian-centric and location-centric scenarios. In addition, we introduce a new dataset designed specifically for autonomous-driving scenarios in areas with dense pedestrian populations: the Stanford-TRI Intent Prediction (STIP) dataset. Our experiments on STIP and another benchmark dataset show that our graph modeling framework is able to predict the intention-to-cross of pedestrians with an accuracy of 79.10% on STIP and 79.28% on the Joint Attention for Autonomous Driving (JAAD) dataset, up to one second earlier than when the actual crossing happens. These results outperform baselines and previous work. Please refer to [https://stip.stanford.edu](https://stip.stanford.edu) for the dataset and code.
19 |
20 |
21 |
22 | ## Datasets
23 | ### Stanford-TRI Intent Prediction (STIP) Dataset:
24 | STIP includes over 900 hours of driving-scene videos from front, right, and left cameras, recorded while the vehicle was driving in dense areas of five cities in the United States. The videos were annotated at 2 fps with pedestrian bounding boxes and crossing/not-crossing labels, shown with green/red boxes respectively in the videos above. We used the [JRMOT (JackRabbot real-time Multi-Object Tracker) platform](https://sites.google.com/view/jrmot) to track the pedestrians and interpolate the annotations to the full 20 frames per second (see the annotation-reading sketch below).
25 |
26 | **Dataset Code:** https://github.com/StanfordVL/STIP
27 |
28 | **Dataset Information:** https://stip.stanford.edu/dataset.html
29 |
30 | **Request Access to Dataset:** [here](https://docs.google.com/forms/d/e/1FAIpQLSdG5CLJQs7QWY27uIkZj27O4XDm0-OsZVEmBRiHB8EaCoNZXA/viewform)
31 |
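For orientation, here is a minimal sketch of how this repository's own tooling (e.g., `utils/masking.py`) reads a STIP annotation JSON. The field names below come from that script; the file path is a placeholder.

```
import json

# Each entry in data['frames'] maps a frame key (at the 2 fps annotation rate)
# to a list of pedestrian records for that frame.
with open('path/to/annotation.json', 'r') as f:  # placeholder path
    frames = json.load(f)['frames']

for key, peds in frames.items():
    for ped in peds:
        box = ped['box']          # dict with 'x1', 'y1', 'x2', 'y2' pixel coordinates
        pid = ped['matchIds']     # tracker-assigned pedestrian id
        crossed = ped['crossed']  # crossing / not-crossing label
```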
32 | ### Joint Attention in Autonomous Driving (JAAD) Dataset:
33 | JAAD is a dataset for studying joint attention in the context of autonomous driving. The focus is on pedestrian and driver behaviors at the point of crossing and the factors that influence them. To this end, the JAAD dataset provides a richly annotated collection of 346 short video clips (5-10 sec long) extracted from over 240 hours of driving footage. Bounding boxes with occlusion tags are provided for all pedestrians, making this dataset suitable for pedestrian detection.
34 | Behavior annotations specify behaviors for pedestrians that interact with or require the attention of the driver. Each video carries several tags (weather, location, etc.) and timestamped behavior labels from a fixed list (e.g., stopped, walking, looking). In addition, a list of demographic attributes is provided for each pedestrian (e.g., age, gender, direction of motion), as well as a list of visible traffic scene elements (e.g., stop sign, traffic signal) for each frame.
35 |
36 | **Annotation Data:** https://github.com/ykotseruba/JAAD
37 |
38 | **Full Dataset:** http://data.nvision2.eecs.yorku.ca/JAAD_dataset
39 |
40 | ## Installation
41 |
42 | ### Notes Before Running
43 |
44 | This training code is set up to run on a 16 GB GPU. You may have to make some adjustments, such as lowering the batch size as sketched below, if you do not have this hardware available.
45 |
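As a hedged example (check the particular script you run; defaults vary across scripts), the batch size in the provided scripts is controlled by the `bt` variable, which you can lower to fit a smaller GPU:

```
# e.g., in scripts_dir/train_concat.sh (illustrative value)
bt=8   # lower this further if you still hit out-of-memory errors
```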
46 | ### Create your Virtual Environment
47 | ```
48 | conda create --name crossing python=3.7
49 | conda activate crossing
50 | ```
51 |
52 | ### Install Required Libraries
53 | ```
54 | pip install torchvision==0.5.0
55 | pip install opencv-python
56 | pip install pycocotools
57 | # pickle ships with the Python standard library; no pip install is needed
58 | pip install numpy
59 | pip install wandb
60 | ```
61 |
62 | ### Configure Environment Variables
63 | ```
64 | export PYTHONPATH=$PYTHONPATH:[path-to-repo]
65 | ```
66 |
67 | ### Clone this repo
68 | ```
69 | git clone https://github.com/StanfordVL/STR-PIP.git
70 | cd STR-PIP
71 | ```
72 |
73 | ## Usage
74 | To train, run `scripts_dir/train_concat.sh`.
75 |
76 | In order to generate `annot_test_ped_withTag_sanityNoPose.pkl`, you will need to run `utils/data_proc.py`; specifically, run the function `prepare_data()` inside `data_proc.py`, as sketched below.
77 |
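A minimal sketch of the two steps, assuming `prepare_data()` takes no required arguments (check its actual signature in `utils/data_proc.py`):

```
# 1) one-time preprocessing: generates annot_test_ped_withTag_sanityNoPose.pkl
python -c "from utils.data_proc import prepare_data; prepare_data()"

# 2) train the concat model
bash scripts_dir/train_concat.sh
```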
78 |
79 | ## Current code status (3-14-21)
80 |
81 | * With the correct pre-processing, STIP training/testing works.
82 | * We are looking for some specific artifacts to train on JAAD. We are very close to generating the correct files.
83 | * This code base is in the process of clean-up, documentation, and refactoring. There may be sudden changes. We will publish a new release when it is in a more ready state.
84 |
85 |
--------------------------------------------------------------------------------
/scripts_dir/test_graph.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mode='evaluate'
4 | split='test'
5 |
6 | gpu_id=1
7 | n_workers=0
8 | n_acts=9
9 | bt=1
10 |
11 | seq_len=30
12 | predict=1
13 | pred_seq_len=30
14 | predict_k=0
15 |
16 | # test only
17 | slide=0
18 | rand_test=1
19 | log_every=10
20 | ckpt_dir='/sailhome/bingbin/STR-PIP/ckpts'
21 | dataset='JAAD'
22 |
23 | # pred 30
24 | # ckpt_name='graph_gru_seq30_pred30_lr1.0e-05_wd1.0e-05_bt16_posNone_branchped_collapse0_combinepair_adjTypeembed_nLayers2_v2Feats'
25 | # ckpt_name='graph_gru_seq30_pred30_lr3.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch'
26 | # ckpt_name='graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW1_v4Feats_pedGRU_3evalEpoch'
27 | # ckpt_name='graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_3evalEpoch'
28 | ckpt_name='graph_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_v4Feats_pedGRU_newCtxtGRU_3evalEpoch' # best
29 |
30 | # pred 60
31 | # ckpt_name='graph_gru_seq30_pred60_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch'
32 |
33 | # pred 90
34 | # ckpt_name='graph_gru_seq30_pred90_lr1.0e-04_wd1.0e-05_bt16_posNone_branchboth_collapse0_combinepair_adjTypespatial_nLayers2_diffW0_v4Feats_pedGRU_3evalEpoch'
35 |
36 | which_epoch=-1
37 | if [ $which_epoch -eq -1 ]
38 | then
39 | epoch_name='best_pred'
40 | else
41 | epoch_name=$which_epoch
42 | fi
43 | save_output=1
44 | save_output_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/output_epoch'$epoch_name'_step{}.pkl'
45 | collect_A=1
46 | save_As_format=$ckpt_dir'/'$dataset'/'$ckpt_name'/test_graph_weights_epoch'$epoch_name'/vid{}_eval{}.pkl'
47 |
48 | n_epochs=1       # passed to test.py below; otherwise the flags expand to empty values and argparse fails
49 | start_epoch=0
50 | load_cache='feats'
51 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-04_wd1.0e-05_bt4_ped_collapse0_combinepair_useBBox0_cacheMasks_fixGRU_singleTime/{}/ped{}_fid{}.pkl'
52 | cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_seq30_pred30_lr1.0e-04_wd1.0e-05_bt4_posNone_branchboth_collapse0_combinepair_cacheMasks_fixGRU_eval3_9acts_noAct_sanityWithPose_withReLU_pedGRU/{}/ped{}_fid{}.pkl'
53 |
54 |
55 | # cache_format='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/{}/ped{}_fid{}.pkl'
56 |
57 | # ckpt_name='graph_gru_lr1.0e-05_wd1.0e-05_bt16_ped_collapse0_combinepair_adjTypeembed_nLayers0_useBBox0_fixGRU'
58 | # which_epoch=50
59 | #
60 | # ckpt_name='graph_gru_lr1.0e-05_wd1.0e-05_bt16_ped_collapse0_combinepair_adjTypeembed_nLayers2_useBBox0_fixGRU'
61 | # which_epoch=22
62 | #
63 | # ckpt_name='graph_gru_lr1.0e-05_wd1.0e-05_bt16_both_collapse0_combinepair_adjTypeembed_nLayers2_useBBox0_fixGRU'
64 | # which_epoch=38
65 | #
66 | # ckpt_name='graph_gru_lr1.0e-05_wd1.0e-05_bt16_ped_collapse0_combinepair_adjTypeuniform_nLayers0_useBBox0_fixGRU'
67 | # which_epoch=42
68 |
69 | if [ "$mode" = "extract" ]
70 | then
71 | extract_feats_dir='/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_conv_feats/concat_gru_lr1.0e-05_bt4_test_epoch5/test/'
72 | else
73 | extract_feats_dir='non_existent'
74 | fi
75 |
76 | use_gru=1
77 | use_trn=0
78 | ped_gru=1
79 | ctxt_gru=0
80 | ctxt_node=0
81 |
82 | # features
83 | use_act=0
84 | use_gt_act=0
85 | use_driver=0
86 | use_pose=0
87 | pos_mode='none'
88 | # branch='ped'
89 | branch='both'
90 | adj_type='spatial'
91 | use_obj_cls=0
92 | n_layers=2
93 | diff_layer_weight=0
94 | collapse_cls=0
95 | combine_method='pair'
96 |
97 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
98 | --model='graph' \
99 | --split=$split \
100 | --mode=$mode \
101 | --slide=$slide \
102 | --rand-test=$rand_test \
103 | --ckpt-dir=$ckpt_dir \
104 | --ckpt-name=$ckpt_name \
105 | --n-acts=$n_acts \
106 | --device=$gpu_id \
107 | --dset-name='JAAD' \
108 | --ckpt-name=$ckpt_name \
109 | --which-epoch=$which_epoch \
110 | --save-output=$save_output \
111 | --save-output-format=$save_output_format \
112 | --collect-A=$collect_A \
113 | --save-As-format=$save_As_format \
114 | --n-epochs=$n_epochs \
115 | --start-epoch=$start_epoch \
116 | --seq-len=$seq_len \
117 | --predict=$predict \
118 | --pred-seq-len=$pred_seq_len \
119 | --predict-k=$predict_k \
120 | --batch-size=$bt \
121 | --log-every=$log_every \
122 | --n-workers=$n_workers \
123 | --load-cache=$load_cache \
124 | --cache-format=$cache_format \
125 | --branch=$branch \
126 | --adj-type=$adj_type \
127 | --use-obj-cls=$use_obj_cls \
128 | --n-layers=$n_layers \
129 | --diff-layer-weight=$diff_layer_weight \
130 | --collapse-cls=$collapse_cls \
131 | --combine-method=$combine_method \
132 | --use-gru=$use_gru \
133 | --use-trn=$use_trn \
134 | --ped-gru=$ped_gru \
135 | --ctxt-gru=$ctxt_gru \
136 | --ctxt-node=$ctxt_node \
137 | --use-act=$use_act \
138 | --use-gt-act=$use_gt_act \
139 | --use-driver=$use_driver \
140 | --use-pose=$use_pose \
141 | --pos-mode=$pos_mode
142 |
143 | exit
144 |
145 | # bak of prev command
146 |
147 | CUDA_VISIBLE_DEVICES=$gpu_id python3 test.py \
148 | --model='graph' \
149 | --split=$split \
150 | --mode=$mode \
151 | --device=$gpu_id \
152 | --log-every=$log_every \
153 | --dset-name=$dataset \
154 | --ckpt-dir=$ckpt_dir \
155 | --ckpt-name=$ckpt_name \
156 | --batch-size=1 \
157 | --n-workers=$n_workers \
158 | --load-cache=$load_cache \
159 | --cache-format=$cache_format \
160 | --n-layers=$n_layers \
161 | --seq-len=$seq_len \
162 | --predict=$predict \
163 | --pred-seq-len=$pred_seq_len \
164 | --use-gru=$use_gru \
165 | --pos-mode=$pos_mode \
166 | --adj-type=$adj_type \
167 | --collapse-cls=$collapse_cls \
168 | --slide=$slide \
169 | --rand-test=$rand_test \
170 | --branch=$branch \
171 | --which-epoch=$which_epoch \
172 | --extract-feats-dir=$extract_feats_dir \
173 | --save-output=$save_output \
174 | --save-output-format=$save_output_format
175 |
176 |
--------------------------------------------------------------------------------
/models/backbone/resnet/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch
4 | import pdb
5 | try:
6 | from models.backbone.resnet.basicblock import BasicBlock2D
7 | from models.backbone.resnet.bottleneck import Bottleneck2D
8 | # from utils.other import transform_input
9 | # from utils.meter import *
10 | except:
11 | from resnet.basicblock import BasicBlock2D
12 | from resnet.bottleneck import Bottleneck2D
13 | # from basemodel.utils.other import transform_input
14 | # from basemodel.utils.meter import *
15 |
16 | model_urls = {
17 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
18 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
19 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
20 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
21 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
22 | }
23 |
24 | K_1st_CONV = 3
25 |
26 |
27 | class ResNetBackBone(nn.Module):
28 | def __init__(self, blocks, layers,
29 | str_first_conv='2D',
30 | nb_temporal_conv=1,
31 | list_stride=[1, 2, 2, 2],
32 | **kwargs):
33 | self.nb_temporal_conv = nb_temporal_conv
34 | self.inplanes = 64
35 | super(ResNetBackBone, self).__init__()
36 | self._first_conv(str_first_conv)
37 | self.relu = nn.ReLU(inplace=True)
38 | self.list_channels = [64, 128, 256, 512]
39 | self.list_inplanes = []
40 | self.list_inplanes.append(self.inplanes) # store the initial inplanes (before layer1)
41 | self.layer1 = self._make_layer(blocks[0], self.list_channels[0], layers[0], stride=list_stride[0])
42 | self.list_inplanes.append(self.inplanes) # store the inplanes after layer1
43 | self.layer2 = self._make_layer(blocks[1], self.list_channels[1], layers[1], stride=list_stride[1])
44 | self.list_inplanes.append(self.inplanes) # store the inplanes after layer2
45 | self.layer3 = self._make_layer(blocks[2], self.list_channels[2], layers[2], stride=list_stride[2])
46 | self.list_inplanes.append(self.inplanes) # store the inplanes after layer3
47 | self.layer4 = self._make_layer(blocks[3], self.list_channels[3], layers[3], stride=list_stride[3])
48 | self.avgpool, self.avgpool_space, self.avgpool_time = None, None, None
49 |
50 | # Init of the weights
51 | for m in self.modules():
52 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d):
53 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
54 | m.weight.data.normal_(0, math.sqrt(2. / n))
55 | elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.BatchNorm2d):
56 | m.weight.data.fill_(1)
57 | m.bias.data.zero_()
58 |
59 | def _first_conv(self, conv_type):  # conv_type in {'3D_stabilize', '2.5D_stabilize', '2D'}; was named 'str', shadowing the builtin
60 | self.conv1_1t = None
61 | self.bn1_1t = None
62 | if conv_type == '3D_stabilize':
63 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(K_1st_CONV, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3),
64 | bias=False)
65 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
66 | self.bn1 = nn.BatchNorm3d(64)
67 |
68 |
69 | elif conv_type == '2.5D_stabilize':
70 | self.conv1 = nn.Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3),
71 | bias=False)
72 | self.conv1_1t = nn.Conv3d(64, 64, kernel_size=(K_1st_CONV, 1, 1), stride=(1, 1, 1),
73 | padding=(1, 0, 0),
74 | bias=False)
75 | self.bn1_1t = nn.BatchNorm3d(64)
76 | self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
77 | self.bn1 = nn.BatchNorm3d(64)
78 |
79 | elif conv_type == '2D':
80 | self.conv1 = nn.Conv2d(3, 64,
81 | kernel_size=(7, 7),
82 | stride=(2, 2),
83 | padding=(3, 3),
84 | bias=False)
85 | self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
86 | self.bn1 = nn.BatchNorm2d(64)
87 |
88 | else:
89 | raise NameError('Unknown first conv type: {}'.format(conv_type))
90 |
91 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
92 | downsample = None
93 |
94 | # Promote the stride to a spatio-temporal tuple for 3D blocks
95 | if not (block == BasicBlock2D or block == Bottleneck2D):
96 | stride = (1, stride, stride)
97 |
98 | if stride != 1 or self.inplanes != planes * block.expansion:
99 | if block is BasicBlock2D or block is Bottleneck2D:
100 | conv, batchnorm = nn.Conv2d, nn.BatchNorm2d
101 | else:
102 | conv, batchnorm = nn.Conv3d, nn.BatchNorm3d
103 |
104 | downsample = nn.Sequential(
105 | conv(self.inplanes, planes * block.expansion,
106 | kernel_size=1, stride=stride, bias=False, dilation=dilation),
107 | batchnorm(planes * block.expansion),
108 | )
109 |
110 | layers = []
111 | layers.append(
112 | block(self.inplanes, planes, stride, downsample, dilation, nb_temporal_conv=self.nb_temporal_conv))
113 | self.inplanes = planes * block.expansion
114 | for i in range(1, blocks):
115 | layers.append(block(self.inplanes, planes, nb_temporal_conv=self.nb_temporal_conv))
116 |
117 | return nn.Sequential(*layers)
118 |
119 | def forward(self, x, num=4):  # note: num is currently unused
120 | # pdb.set_trace()
121 |
122 | x = self.conv1(x)
123 | x = self.bn1(x)
124 | x = self.relu(x)
125 |
126 | if self.conv1_1t is not None:
127 | x = self.conv1_1t(x)
128 | x = self.bn1_1t(x)
129 | x = self.relu(x)
130 |
131 | x = self.maxpool(x)
132 |
133 | x = self.layer1(x)
134 | x = self.layer2(x)
135 | x = self.layer3(x)
136 | x = self.layer4(x)
137 |
138 | # Global average pooling
139 | self.avgpool = nn.AvgPool2d((x.size(-1), x.size(-1))) if self.avgpool is None else self.avgpool
140 | x = self.avgpool(x)
141 |
142 | # Final classifier
143 | # x = x.view(x.size(0), -1)
144 | # x = self.fc_classifier(x)
145 |
146 | return x
147 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 | import copy
5 | import sys, traceback, code
6 | import torch
7 |
8 | import data
9 | import models
10 | import utils
11 | from test import evaluate
12 |
13 | import pdb
14 |
15 | import wandb
16 |
17 | N_EVAL_EPOCHS = 3
18 |
19 | opt, logger = utils.build(is_train=True)
20 | with open(os.path.join(opt.ckpt_path, 'opt.pkl'), 'wb') as handle:
21 | pickle.dump(opt, handle)
22 |
23 | tags = [opt.model, opt.branch, 'w/ bbox' if opt.use_bbox else 'no bbox', 'seq{}'.format(opt.seq_len)]
24 | if opt.model == 'graph':
25 | tags += ['{} layer'.format(opt.n_layers), opt.adj_type]
26 | tags += ['diff {}'.format(opt.diff_layer_weight)]
27 | tags += [opt.adj_type]
28 | if opt.model == 'pos' or opt.pos_mode != 'none':
29 | tags += opt.pos_mode,
30 | if opt.predict:
31 | tags += 'pred{}'.format(opt.pred_seq_len),
32 | if opt.predict_k:
33 | tags += 'pred_k{}'.format(opt.predict_k),
34 | if opt.ped_gru:
35 | tags += 'pedGRU',
36 | if opt.ctxt_gru:
37 | tags += 'ctxtGRU',
38 | if opt.ctxt_node:
39 | tags += 'ctxtNode',
40 | if opt.load_cache == 'none':
41 | tags += 'cacheNone',
42 | if opt.load_cache == 'masks':
43 | tags += 'cacheMasks',
44 | if opt.load_cache == 'feats':
45 | tags += 'cacheFeats',
46 | if opt.use_driver:
47 | tags += 'driver',
48 | tags += '{}evalEpochs'.format(N_EVAL_EPOCHS),
49 | wandb.init(
50 | project='crossing',
51 | tags=tags)
52 |
53 | opt_dict = vars(opt)
54 | print('Options:')
55 | for key in sorted(opt_dict):
56 | print('{}: {}'.format(key, opt_dict[key]))
57 |
58 |
59 | # train
60 | print(opt)
61 | train_loader = data.get_data_loader(opt)
62 | print('Train dataset: {}'.format(len(train_loader.dataset)))
63 |
64 | # val
65 | # val_opt = copy.deepcopy(opt)
66 | val_opt, _ = utils.build(is_train=False)
67 | val_opt.split = 'test'
68 | val_opt.slide = 0
69 | val_opt.is_train = False
70 | val_opt.rand_test = True
71 | val_opt.batch_size = 1
72 |
73 | val_loader = data.get_data_loader(val_opt)
74 | print('Val dataset: {}'.format(len(val_loader.dataset)))
75 |
76 | model = models.get_model(opt)
77 | # model = model.to('cuda:{}'.format(opt.device))
78 | if opt.pretrained_path and os.path.exists(opt.pretrained_path):
79 | print('Loading model from', opt.pretrained_path)
80 | model.load(opt.pretrained_path)
81 | model = model.to('cuda:0')
82 | wandb.watch(model)
83 |
84 | if opt.load_ckpt_dir != '':
85 | ckpt_dir = os.path.join(opt.ckpt_dir, opt.dset_name, opt.load_ckpt_dir)
86 | assert os.path.exists(ckpt_dir)
87 | logger.print('Loading checkpoint from {}'.format(ckpt_dir))
88 | model.load(ckpt_dir, opt.load_ckpt_epoch)
89 |
90 | opt.n_epochs = max(opt.n_epochs, opt.n_iters // len(train_loader))
91 | logger.print('Total epochs: {}'.format(opt.n_epochs))
92 |
93 |
94 | def train():
95 | best_eval_acc = 0
96 | best_eval_loss = 10
97 | best_epoch = -1
98 |
99 | if val_opt.predict or val_opt.predict_k:
100 | # pred
101 | best_eval_acc_pred = 0
102 | best_eval_loss_pred = 10
103 | best_epoch_pred = -1
104 | # last
105 | best_eval_acc_last = 0
106 | best_eval_loss_last = 10
107 | best_epoch_last = -1
108 |
109 | for epoch in range(opt.start_epoch, opt.n_epochs):
110 | model.setup()
111 | print('Train epoch', epoch)
112 | model.update_hyperparameters(epoch)
113 |
114 | losses = []
115 | for step, data in enumerate(train_loader):
116 | # break
117 | if epoch == 0:
118 | torch.cuda.empty_cache()
119 | # break
120 | loss = model.step_train(data)
121 | losses += loss,
122 |
123 | torch.cuda.empty_cache()
124 |
125 | if step % opt.log_every == 0:
126 | print('avg loss:', sum(losses) / len(losses))
127 | wandb.log({'Train loss': sum(losses) / len(losses)})
128 | losses = []
129 |
130 | # Evaluate on val set
131 | if opt.evaluate_every > 0 and (epoch + 1) % opt.evaluate_every == 0:
132 | result_det, result_pred, result_last = evaluate(model, val_loader, val_opt, n_eval_epochs=N_EVAL_EPOCHS)
133 | eval_acc_frame, eval_acc_clip, eval_acc_cross, eval_acc_non_cross, eval_loss = result_det
134 | if eval_acc_frame > best_eval_acc:
135 | best_eval_acc = eval_acc_frame
136 | best_eval_loss = eval_loss
137 | best_epoch = epoch+1
138 | model.save(opt.ckpt_path, best_epoch, 'best_det')
139 | wandb.log({
140 | 'eval_acc_frame':eval_acc_frame, 'eval_acc_clip':eval_acc_clip,
141 | 'eval_acc_cross':eval_acc_cross, 'eval_acc_non_cross':eval_acc_non_cross,
142 | 'eval_loss':eval_loss,
143 | 'best_eval_acc': best_eval_acc, 'best_eval_loss':best_eval_loss, 'best_epoch':best_epoch})
144 |
145 | if val_opt.predict or val_opt.predict_k:
146 | # pred
147 | eval_acc_frame, eval_acc_clip, eval_acc_cross, eval_acc_non_cross, eval_loss = result_pred
148 | if eval_acc_frame > best_eval_acc_pred:
149 | best_eval_acc_pred = eval_acc_frame
150 | best_eval_loss_pred = eval_loss
151 | best_epoch_pred = epoch+1
152 | model.save(opt.ckpt_path, best_epoch_pred, 'best_pred')
153 | wandb.log({
154 | 'eval_acc_frame_pred':eval_acc_frame, 'eval_acc_clip_pred':eval_acc_clip,
155 | 'eval_acc_cross_pred':eval_acc_cross, 'eval_acc_non_cross_pred':eval_acc_non_cross,
156 | 'eval_loss_pred':eval_loss,
157 | 'best_eval_acc_pred': best_eval_acc_pred, 'best_eval_loss_pred':best_eval_loss_pred, 'best_epoch_pred':best_epoch_pred})
158 | # last
159 | eval_acc_frame, eval_acc_clip, eval_acc_cross, eval_acc_non_cross, eval_loss = result_last
160 | if eval_acc_frame > best_eval_acc_last:
161 | best_eval_acc_last = eval_acc_frame
162 | best_eval_loss_last = eval_loss
163 | best_epoch_last = epoch+1
164 | model.save(opt.ckpt_path, best_epoch_last, 'best_last')
165 | wandb.log({
166 | 'eval_acc_frame_last':eval_acc_frame, 'eval_acc_clip_last':eval_acc_clip,
167 | 'eval_acc_cross_last':eval_acc_cross, 'eval_acc_non_cross_last':eval_acc_non_cross,
168 | 'eval_loss_last':eval_loss,
169 | 'best_eval_acc_last': best_eval_acc_last, 'best_eval_loss_last':best_eval_loss_last, 'best_epoch_last':best_epoch_last})
170 |
171 |
172 | # Save model checkpoints
173 | if ((epoch + 1) % opt.save_every == 0 and epoch >= 0) or epoch == opt.n_epochs - 1:
174 | model.save(opt.ckpt_path, epoch+1)
175 |
176 |
177 | try:
178 | train()
179 | except Exception as e:
180 | print(e)
181 | typ, vacl, tb = sys.exc_info()
182 | traceback.print_exc()
183 | last_frame = lambda tb=tb: last_frame(tb.tb_next) if tb.tb_next else tb
184 | frame = last_frame().tb_frame
185 | ns = dict(frame.f_globals)
186 | ns.update(frame.f_locals)
187 | code.interact(local=ns)
188 |
189 |
--------------------------------------------------------------------------------
/utils/stip_merge_views.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import pickle
5 | import torch
6 | import pdb
7 |
8 |
9 | if False:
10 | # Objs
11 | print('Objects')
12 | root = '/vision/group/prolix/processed'
13 | center_root = os.path.join(root, 'center', 'obj_bbox_20fps')
14 | left_root = os.path.join(root, 'left', 'obj_bbox_20fps')
15 | right_root = os.path.join(root, 'right', 'obj_bbox_20fps')
16 | merged_root = os.path.join(root, 'all', 'obj_bbox_20fps')
17 | os.makedirs(merged_root, exist_ok=True)
18 |
19 | def merged_obj_loader(f):
20 | with open(f, 'rb') as handle:
21 | data = pickle.load(handle)
22 |
23 | obj_cls, obj_bbox = [], []
24 | for (cs, bs) in zip(data['obj_cls'], data['obj_bbox']):
25 | # bs: [xmin, ymin, w, h]
26 | curr_cls, curr_bbox = [], []
27 | for i in range(len(bs)):
28 | if bs[i, 3] != 0 and bs[i,2] != 0:
29 | # keep only non-empty bbox
30 | curr_cls += cs[i],
31 | curr_bbox += bs[i],
32 | obj_cls += np.array(curr_cls),
33 | obj_bbox += np.array(curr_bbox),
34 |
35 | return obj_cls, obj_bbox
36 |
37 | segs = sorted(glob(os.path.join(center_root, '*pkl')))
38 | for seg in segs:
39 | fl = seg.replace(center_root, left_root).replace('pkl', 'pkl.pkl')
40 | fr = seg.replace(center_root, right_root).replace('pkl', 'pkl.pkl')
41 | if not os.path.exists(fl) or not os.path.exists(fr):
42 | continue
43 | fout = seg.replace(center_root, merged_root)
44 | # if os.path.exists(fout):
45 | # continue
46 |
47 | with open(seg, 'rb') as handle:
48 | data = pickle.load(handle)
49 | with open(fl, 'rb') as handle:
50 | data_l = pickle.load(handle)
51 | with open(fr, 'rb') as handle:
52 | data_r = pickle.load(handle)
53 | assert(len(data) == len(data_l))
54 | assert(len(data) == len(data_r))
55 | data_all = {}
56 | for cls in data:
57 | data_l[cls] = [[x-1216, y, w, h] for (x,y,w,h) in data_l[cls]]
58 | data_r[cls] = [[x+1216, y, w, h] for (x,y,w,h) in data_r[cls]]  # fix: was iterating over data_l
59 | data_all[cls] = data[cls] + data_l[cls] + data_r[cls]
60 | with open(fout, 'wb') as handle:
61 | pickle.dump(data_all, handle)
62 |
63 | # obj_cls_c, obj_bbox_c = obj_loader(seg)
64 | # obj_cls_l, obj_bbox_l = obj_loader(fl)
65 | # obj_cls_r, obj_bbox_r = obj_loader(fr)
66 |
67 | # obj_cls = [np.concatenate([c,l,r], 0) for (c,l,r) in zip(obj_cls_c, obj_cls_l, obj_cls_r)]
68 | # obj_bbox = []
69 | # for (c,l,r) in zip(obj_bbox_c, obj_bbox_l, obj_bbox_r):
70 | # valid = [each for each in [c,l,r] if len(each.shape) > 1]
71 | # if valid:
72 | # a = np.concatenate(valid, 0)
73 | # else:
74 | # a = np.array([])
75 | # obj_bbox += a,
76 | # # obj_bbox = [np.concatenate([c,l,r], 0) for (c,l,r) in zip(obj_bbox_c, obj_bbox_l, obj_bbox_r)]
77 | #
78 | # with open(fout, 'wb') as handle:
79 | # pickle.dump({'obj_cls':obj_cls, 'obj_bbox':obj_bbox}, handle)
80 |
81 | # Masks
82 | if False:
83 | print('Masks')
84 | root = '/vision/group/prolix/processed'
85 | center_root = os.path.join(root, 'cache/center')
86 | left_root = os.path.join(root, 'cache/left')
87 | right_root = os.path.join(root, 'cache/right')
88 | merged_root = os.path.join(root, 'cache/all')
89 | os.makedirs(merged_root, exist_ok=True)
90 |
91 | for split in ['train']:
92 | split_root = os.path.join(merged_root, split)
93 | os.makedirs(split_root, exist_ok=True)
94 |
95 | caches = sorted(glob(os.path.join(center_root, split, '*pkl')))
96 | for fc in caches:
97 | fl = fc.replace(center_root, left_root)
98 | fr = fc.replace(center_root, right_root)
99 | if not os.path.exists(fl) or not os.path.exists(fr):
100 | continue
101 | fout = fc.replace(center_root, merged_root)
102 | # if os.path.exists(fout):
103 | # continue
104 |
105 | with open(fc, 'rb') as handle:
106 | data = pickle.load(handle)
107 | with open(fl, 'rb') as handle:
108 | data_l = pickle.load(handle)
109 | with open(fr, 'rb') as handle:
110 | data_r = pickle.load(handle)
111 |
112 | data_out = {'ped_crops': data['ped_crops']}
113 | merged_masks = {k:[] for k in range(1,5)}
114 | keys = list(merged_masks.keys())
115 | for k in keys:
116 | if k in data['masks']:
117 | merged_masks[k] += data['masks'][k],
118 | if k in data_l['masks']:
119 | merged_masks[k] += data_l['masks'][k],
120 | if k in data_r['masks']:
121 | merged_masks[k] += data_r['masks'][k],
122 | if len(merged_masks[k]) == 0:
123 | merged_masks.pop(k)
124 | else:
125 | merged_masks[k] = torch.cat(merged_masks[k], 0)
126 | data_out['masks'] = merged_masks
127 |
128 | with open(fout, 'wb') as handle:
129 | pickle.dump(data_out, handle)
130 |
131 | # Feats
132 | if True:
133 | print('Feats')
134 | root = '/vision/group/prolix/processed/cache/'
135 | center_root = os.path.join(root, 'center/STIP_conv_feats')
136 | left_root = os.path.join(root, 'left/STIP_conv_feats')
137 | right_root = os.path.join(root, 'right/STIP_conv_feats')
138 | merged_root = os.path.join(root, 'all/STIP_conv_feats')
139 | os.makedirs(merged_root, exist_ok=True)
140 |
141 | for split in ['train']:
142 | caches = sorted(glob(os.path.join(center_root, "concat*", split, '*pkl')))
143 | split_root = os.path.dirname(caches[0]).replace(center_root, merged_root)
144 | os.makedirs(split_root, exist_ok=True)
145 |
146 | for fc in caches:
147 | fl = fc.replace(center_root, left_root)
148 | fr = fc.replace(center_root, right_root)
149 | if not os.path.exists(fl) or not os.path.exists(fr):
150 | print('Skipping')
151 | continue
152 | fout = fc.replace(center_root, merged_root)
153 | # pdb.set_trace()
154 | # if os.path.exists(fout):
155 | # continue
156 |
157 | with open(fc, 'rb') as handle:
158 | data = pickle.load(handle)
159 | with open(fl, 'rb') as handle:
160 | data_l = pickle.load(handle)
161 | with open(fr, 'rb') as handle:
162 | data_r = pickle.load(handle)
163 |
164 | data_out = {'ped_feats': data['ped_feats']}
165 | merged_feats = []
166 | if data['ctxt_feats'].shape[0] != 1 or data['ctxt_feats'].sum() != 0:
167 | merged_feats += data['ctxt_feats'],
168 | if data_l['ctxt_feats'].shape[0] != 1 or data_l['ctxt_feats'].sum() != 0:
169 | merged_feats += data_l['ctxt_feats'],
170 | if data_r['ctxt_feats'].shape[0] != 1 or data_r['ctxt_feats'].sum() != 0:
171 | merged_feats += data_r['ctxt_feats'],
172 | merged_feats = torch.cat(merged_feats, 0)
173 | merged_cls = torch.cat([data['ctxt_cls'], data_l['ctxt_cls'], data_r['ctxt_cls']], 0)
174 | data_out['ctxt_feats'] = merged_feats
175 | data_out['ctxt_cls'] = merged_cls
176 |
177 | with open(fout, 'wb') as handle:
178 | pickle.dump(data_out, handle)
179 |
--------------------------------------------------------------------------------
/cache_data_stip.py:
--------------------------------------------------------------------------------
1 | # NOTE: this script uses segmentation files only
2 | # and does not rely on pedestrian annotations.
3 |
4 | # python imports
5 | import os
6 | import sys
7 | from glob import glob
8 | import pickle
9 | import time
10 | import numpy as np
11 |
12 |
13 | # local imports
14 | import data
15 | import utils
16 | from utils.data_proc_stip import parse_objs
17 |
18 | import pdb
19 |
20 |
21 | # def cache_masks():
22 | # opt, logger = utils.build(is_train=False)
23 | # opt.combine_method = ''
24 | # opt.split = 'train'
25 | # cache_dir_name = 'jaad_collapse{}'.format('_'+opt.combine_method if opt.combine_method else '')
26 | # data.cache_all_objs(opt, cache_dir_name)
27 | #
28 | #
29 | # def cache_crops():
30 | # fnpy_root = '/sailhome/bingbin/STR-PIP/datasets/JAAD_instance_segm'
31 | # fpkl_root = '/sailhome/bingbin/STR-PIP/datasets/cache/JAAD_instance_crops'
32 | # utils.get_obj_crops(fnpy_root, fpkl_root)
33 |
34 | def add_obj_bbox(view=''):
35 | if not view:
36 | vid_root = '/vision/group/prolix/instances_20fps/stip_instances/'
37 | fobj_root = '/vision/group/prolix/processed/obj_bbox_20fps'
38 | elif view == 'center':
39 | vid_root = '/vision/group/prolix/instances_20fps/stip_instances/'
40 | fobj_root = '/vision/group/prolix/processed/center/obj_bbox_20fps'
41 | else: # view = 'left' or 'right'
42 | vid_root = '/vision/group/prolix/stip_side/{}/instances_20fps/'.format(view)
43 | fobj_root = '/vision/group/prolix/processed/{}/obj_bbox_20fps/'.format(view)
44 | os.makedirs(fobj_root, exist_ok=True)
45 |
46 | def helper(vid_range, split):
47 | for dir_vid in vid_range:
48 | print(dir_vid)
49 | sys.stdout.flush()
50 | t_start = time.time()
51 |
52 | if not view or view == 'center':
53 | segs = sorted(glob(os.path.join(vid_root, dir_vid, 'inference', '*--*')))
54 | else:
55 | segs = sorted(glob(os.path.join(vid_root, dir_vid, '*--*')))
56 |
57 | for seg in segs:
58 | if not view or view == 'center':
59 | fsegms = sorted(glob(os.path.join(seg, '*_segm.npy')))
60 | else:
61 | fsegms = sorted(glob(os.path.join(seg, '*.pkl')))
62 | for i, fsegm in enumerate(fsegms):
63 | if i and i%100 == 0:
64 | print('Time per frame:', (time.time() - t_start)/i)
65 | sys.stdout.flush()
66 | fid = os.path.basename(fsegm).split('_')[0]
67 | fbbox = os.path.join(fobj_root, '{:s}_seg{:s}_fid{:s}.pkl'.format(dir_vid, os.path.basename(seg), fid))
68 | # if 'ANN_conor1_seg12:22--12:59_fid0000015483.pkl' not in fbbox:
69 | # continue
70 | if os.path.exists(fbbox):
71 | continue
72 | if not os.path.exists(fsegm):
73 | print('File does not exist:', fsegm)
74 | continue
75 | objs = parse_objs(fsegm)
76 | dobjs = {cls:[] for cls in range(1,5)}
77 | for cls, masks in objs.items():
78 | for mask in masks:
79 | try:
80 | if len(mask.shape) == 3:
81 | h, w, c = mask.shape
82 | if c != 1:
83 | raise ValueError('Each mask should have shape (1080, 1920, 1)')
84 | mask = mask.reshape(h, w)
85 | x_pos = mask.sum(0).nonzero()[0]
86 | if not len(x_pos):
87 | x_pos = [0,0]
88 | x_min, x_max = x_pos[0], x_pos[-1]
89 | y_pos = mask.sum(1).nonzero()[0]
90 | if not len(y_pos):
91 | y_pos = [0,0]
92 | y_min, y_max = y_pos[0], y_pos[-1]
93 | # bbox: [x_min, y_min, w, h]; same as bbox for ped['pos_GT']
94 | bbox = [x_min, y_min, x_max-x_min, y_max-y_min]
95 | except Exception as e:
96 | print(e)
97 | pdb.set_trace()
98 |
99 | dobjs[cls] += bbox,
100 | with open(fbbox, 'wb') as handle:
101 | pickle.dump(dobjs, handle)
102 |
103 | vids = sorted(glob(os.path.join(vid_root, '*_*')))
104 | vids = [os.path.basename(vid) for vid in vids]
105 | vids_test = ['downtown_ann_3-09-28-2017', 'downtown_palo_alto_6', 'dt_san_jose_4', 'mountain_view_4',
106 | 'sf_soma_2']
107 | vids_train = [vid for vid in vids if vid not in vids_test]
108 |
109 | # tmp
110 | # vids_train = ['downtown_ann_1-09-27-2017', 'downtown_ann_2-09-27-2017', 'downtown_ann_3-09-27-2017', 'downtown_ann_1-09-28-2017']
111 | # vids_test = []
112 |
113 | if True:
114 | helper(vids_train, 'train')
115 | if False:
116 | helper(vids_test, 'test')
117 |
118 | def merge_and_flat(vrange, view=''):
119 | """
120 | Merge fids in a vid and flatten the classes
121 | """
122 | if not view:
123 | pkl_in_root = '/vision/group/prolix/processed/obj_bbox_20fps'
124 | pkl_out_root = '/vision/group/prolix/processed/obj_bbox_20fps_merged'
125 | else:
126 | pkl_in_root = '/vision/group/prolix/processed/{}/obj_bbox_20fps/'.format(view)
127 | pkl_out_root = '/vision/group/prolix/processed/{}/obj_bbox_20fps_merged'.format(view)
128 | os.makedirs(pkl_out_root, exist_ok=True)
129 |
130 | for vid in vrange:
131 | print(vid)
132 | fpkls = sorted(glob(os.path.join(pkl_in_root, '{:s}*pkl'.format(vid))))
133 | segs = list(set([fpkl.split('_fid')[0] for fpkl in fpkls]))
134 | print('# segs:', len(segs))
135 | for seg in segs:
136 | fpkls = sorted(glob(seg+'*pkl'))
137 | print(vid, len(fpkls))
138 | sys.stdout.flush()
139 | # merged = [[] for _ in range(len(fpkls))]
140 | merged_bbox = []
141 | merged_cls = []
142 | t_start = time.time()
143 | for fpkl in fpkls:
144 | try:
145 | with open(fpkl, 'rb') as handle:
146 | data = pickle.load(handle)
147 |       except Exception:
148 | pdb.set_trace()
149 | curr_bbox = []
150 | cls = []
151 | for c in [1,2,3,4]:
152 | for bbox in data[c]:
153 | cls += c,
154 | curr_bbox += bbox,
155 | merged_bbox += np.array(curr_bbox),
156 | merged_cls += np.array(cls),
157 |
158 | seg = seg.split('seg')[-1]
159 | fpkl_out = os.path.join(pkl_out_root, '{}_seg{}.pkl'.format(vid, seg))
160 | with open(fpkl_out, 'wb') as handle:
161 | dout = {
162 | 'obj_cls': merged_cls,
163 | 'obj_bbox': merged_bbox,
164 | }
165 | pickle.dump(dout, handle)
166 | print('avg time: ', (time.time()-t_start) / len(fpkls))
167 |
168 |
169 | def merge_and_flat_wrapper(view=''):
170 | vid_root = '/vision/group/prolix/instances_20fps/stip_instances/'
171 | vids = sorted(glob(os.path.join(vid_root, '*_*')))
172 | vids = [os.path.basename(vid) for vid in vids]
173 | vids_test = ['downtown_ann_3-09-28-2017', 'downtown_palo_alto_6', 'dt_san_jose_4', 'mountain_view_4',
174 | 'sf_soma_2']
175 | vids_train = [vid for vid in vids if vid not in vids_test]
176 |
177 | # tmp
178 | # vids_train = ['downtown_ann_1-09-27-2017', 'downtown_ann_2-09-27-2017', 'downtown_ann_3-09-27-2017', 'downtown_ann_1-09-28-2017']
179 | # vids_test = []
180 |
181 | if True:
182 | merge_and_flat(vids_train, view=view)
183 | if False:
184 | merge_and_flat(vids_test, view=view)
185 |
186 |
187 | def cache_loc():
188 | def helper(annots):
189 | for vid in annots:
190 | annot = annots[vid]
191 | n_frames = len(annot['act'])
192 | for fid in range(n_frames):
193 | # fid in ped cache file name: 1 based
194 |         pass  # TODO: the per-frame location caching body was never implemented
195 | # ped:
196 | # ped['ped_crops']: ndarray: (3, 224, 224)
197 | # ped['masks']: tensor: [n_objs, 224, 224]
198 |
199 |
200 | if __name__ == '__main__':
201 | # cache_masks()
202 | # cache_crops()
203 |
204 | # add_obj_bbox(view='left')
205 | # merge_and_flat_wrapper(view='left')
206 | # add_obj_bbox(view='right')
207 | # merge_and_flat_wrapper(view='right')
208 | # add_obj_bbox(view='center')
209 | merge_and_flat_wrapper(view='center')
210 |
211 | # merge_and_flat(range(1, 200))
212 | # merge_and_flat(range(100, 200))
213 | # merge_and_flat(range(200, 300))
214 | # merge_and_flat(range(100, 347))
215 | # merge_and_flat(range(200, 347))
216 |
217 |
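218 | # Hedged illustration (added; _demo_mask_to_bbox is hypothetical and not
219 | # part of the original pipeline): the mask -> [x_min, y_min, w, h]
220 | # conversion used in add_obj_bbox, assuming a binary (H, W) mask like the
221 | # ones parse_objs yields after squeezing the channel dim.
222 | def _demo_mask_to_bbox():
223 |   mask = np.zeros((8, 8), dtype=np.uint8)
224 |   mask[2:5, 3:7] = 1                # foreground: rows 2-4, cols 3-6
225 |   x_pos = mask.sum(0).nonzero()[0]  # columns containing foreground
226 |   y_pos = mask.sum(1).nonzero()[0]  # rows containing foreground
227 |   bbox = [x_pos[0], y_pos[0], x_pos[-1] - x_pos[0], y_pos[-1] - y_pos[0]]
228 |   return bbox                       # [3, 2, 3, 2], same layout as ped['pos_GT']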
--------------------------------------------------------------------------------
/data/get_data_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | from PIL import Image
5 | import torch
6 | torch.manual_seed(2019)
7 | torch.cuda.manual_seed_all(2019)
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 | import argparse
11 |
12 | import time
13 | import pdb
14 |
15 | try:
16 | from .jaad import JAADDataset
17 | from .jaad_loc import JAADLocDataset
18 | from .stip import STIPDataset
19 | except ImportError:
20 | from jaad import JAADDataset
21 | from jaad_loc import JAADLocDataset
22 | from stip import STIPDataset
23 |
24 | def jaad_collate(batch):
25 | # each item in batch: a tuple of
26 | # 1. ped_crops: (30, 224, 224, 3)
27 | # 2. masks: list of len 30: each = dict of ndarrays: (n_obj, 1080, 1920, 1)
28 | # 3. GT_act: binary ndarray: (30, 9)
29 | ped_crops = []
30 | masks = []
31 | GT_act, GT_bbox, GT_pose = [], [], []
32 | obj_bbox, obj_cls = [], []
33 | fids = []
34 | img_paths = []
35 | for each in batch:
36 | ped_crops += each['ped_crops'],
37 | masks += each['all_masks'],
38 | GT_act += each['GT_act'],
39 | GT_bbox += each['GT_bbox'],
40 | GT_pose += each['GT_pose'],
41 | obj_bbox += each['obj_bbox'],
42 | obj_cls += each['obj_cls'],
43 |         # NOTE: appends to obj_bbox_l/obj_cls_l/obj_bbox_r/obj_cls_r used to
44 |         # live here, but those lists were never initialized (a NameError at
45 |         # runtime) and never returned; the side-view fields are STIP-specific,
46 |         # so they are omitted from this JAAD collate.
47 | fids += each['fids'],
48 | img_paths += each['img_paths'],
49 | ped_crops = torch.stack(ped_crops)
50 | GT_act = torch.stack(GT_act)
51 | GT_bbox = torch.stack(GT_bbox)
52 | GT_pose = torch.stack(GT_pose)
53 | fids = torch.stack(fids)
54 | ret = {
55 | 'ped_crops': ped_crops,
56 | 'all_masks': masks,
57 | 'GT_act': GT_act,
58 | 'GT_bbox': GT_bbox,
59 | 'GT_pose': GT_pose,
60 | 'obj_cls': obj_cls,
61 | 'obj_bbox': obj_bbox,
62 | 'fids': fids,
63 | 'img_paths': img_paths,
64 | }
65 | if 'frames' in batch[0]:
66 | ret['frames'] = torch.stack([each['frames'] for each in batch], 0)
67 | ret['GT_driver_act'] = torch.stack([each['GT_driver_act'] for each in batch], 0)
68 |
69 | return ret
70 |
71 |
72 | def jaad_loc_collate(batch):
73 | # each item in batch: a tuple of
74 | # 1. ped_crops: (30, 224, 224, 3)
75 | # 2. masks: list of len 30: each = dict of ndarrays: (n_obj, 1080, 1920, 1)
76 | # 3. GT_act: binary ndarray: (30, 9)
77 | ped_crops = []
78 | masks = []
79 | GT_act, GT_ped_bbox = [], []
80 | obj_bbox, obj_cls = [], []
81 | fids = []
82 | for each in batch:
83 | ped_crops += each['ped_crops'],
84 | masks += each['all_masks'],
85 | GT_act += each['GT_act'],
86 | GT_ped_bbox += each['GT_ped_bbox'],
87 | obj_bbox += each['obj_bbox'],
88 | obj_cls += each['obj_cls'],
89 | fids += each['fids'],
90 | GT_act = torch.stack(GT_act)
91 | fids = torch.stack(fids)
92 | ret = {
93 | 'ped_crops': ped_crops,
94 | 'all_masks': masks,
95 | 'GT_act': GT_act,
96 | 'GT_ped_bbox': GT_ped_bbox,
97 | 'obj_cls': obj_cls,
98 | 'obj_bbox': obj_bbox,
99 | 'fids': fids,
100 | }
101 | if 'frames' in batch[0]:
102 | ret['frames'] = torch.stack([each['frames'] for each in batch], 0)
103 | return ret
104 |
105 |
106 | def stip_collate(batch):
107 | # each item in batch: a tuple of
108 | # 1. ped_crops: (30, 224, 224, 3)
109 | # 2. masks: list of len 30: each = dict of ndarrays: (n_obj, 1080, 1920, 1)
110 | # 3. GT_act: binary ndarray: (30, 9)
111 | ped_crops = []
112 | masks = []
113 | GT_act, GT_bbox, GT_pose = [], [], []
114 | obj_bbox, obj_cls = [], []
115 | fids = []
116 | img_paths = []
117 | for each in batch:
118 | ped_crops += each['ped_crops'],
119 | masks += each['all_masks'],
120 | GT_act += each['GT_act'],
121 | GT_bbox += each['GT_bbox'],
122 | # GT_pose += each['GT_pose'],
123 | obj_bbox += each['obj_bbox'],
124 | obj_cls += each['obj_cls'],
125 | fids += each['fids'],
126 | img_paths += each['img_paths'],
127 | ped_crops = torch.stack(ped_crops)
128 | GT_act = torch.stack(GT_act)
129 | GT_bbox = torch.stack(GT_bbox)
130 | if len(GT_pose):
131 | GT_pose = torch.stack(GT_pose)
132 | fids = torch.stack(fids)
133 | ret = {
134 | 'ped_crops': ped_crops,
135 | 'all_masks': masks,
136 | 'GT_act': GT_act,
137 | 'GT_bbox': GT_bbox,
138 | 'obj_cls': obj_cls,
139 | 'obj_bbox': obj_bbox,
140 | 'fids': fids,
141 | 'img_paths': img_paths,
142 | }
143 | if len(GT_pose):
144 | ret['GT_pose'] = GT_pose
145 | if 'frames' in batch[0]:
146 | ret['frames'] = torch.stack([each['frames'] for each in batch], 0)
147 | ret['GT_driver_act'] = torch.stack([each['GT_driver_act'] for each in batch], 0)
148 |
149 | return ret
150 |
151 |
152 |
153 | def get_data_loader(opt):
154 | if opt.dset_name.lower() == 'jaad':
155 | dset = JAADDataset(opt)
156 | print('Built JAADDataset.')
157 | collate_fn = jaad_collate
158 |
159 | elif opt.dset_name.lower() == 'jaad_loc':
160 | dset = JAADLocDataset(opt)
161 | print('Built JAADLocDataset.')
162 | collate_fn = jaad_loc_collate
163 |
164 | elif opt.dset_name.lower() == 'stip':
165 | dset = STIPDataset(opt)
166 |     print('Built STIPDataset.')
167 | collate_fn = stip_collate
168 |
169 | else:
170 |     raise NotImplementedError("Unsupported dset_name '{}'. Supported: 'jaad', 'jaad_loc', 'stip'. ^ ^b".format(opt.dset_name))
171 |
172 | dloader = data.DataLoader(dset,
173 | batch_size=opt.batch_size,
174 | shuffle=opt.is_train,
175 | num_workers=opt.n_workers,
176 | pin_memory=True,
177 | collate_fn=collate_fn,
178 | )
179 |
180 | return dloader
181 |
182 |
183 | def cache_all_objs(opt, cache_dir_name):
184 | opt.is_train = False
185 |
186 | opt.collapse_cls = 1
187 | cache_dir_root = '/sailhome/ajarno/STR-PIP/datasets/cache/'
188 | cache_dir = os.path.join(cache_dir_root, cache_dir_name)
189 | os.makedirs(cache_dir, exist_ok=True)
190 | opt.save_cache_format = os.path.join(cache_dir, opt.split, 'ped{}_fid{}.pkl')
191 | os.makedirs(os.path.dirname(opt.save_cache_format), exist_ok=True)
192 |
193 | dset = JAADDataset(opt)
194 | dloader = data.DataLoader(dset,
195 | batch_size=1,
196 | shuffle=False,
197 | num_workers=0,
198 | collate_fn=jaad_collate)
199 |
200 | t_start = time.time()
201 | for i,each in enumerate(dloader):
202 | if i%50 == 0 and i:
203 | print('{}: avg time: {:.3f}'.format(i, (time.time()-t_start) / 50))
204 | t_start = time.time()
205 |
206 |
207 | if __name__ == '__main__':
208 | parser = argparse.ArgumentParser()
209 | parser.add_argument('--dset-name', type=str, default='JAAD_loc')
210 | parser.add_argument('--annot-ped-format', type=str, default='/sailhome/ajarno/STR-PIP/datasets/annot_{}_ped.pkl')
211 | parser.add_argument('--annot-loc-format', type=str, default='/sailhome/ajarno/STR-PIP/datasets/annot_{}_loc.pkl')
212 | parser.add_argument('--is-train', type=int, default=1)
213 | parser.add_argument('--split', type=str, default='train')
214 | parser.add_argument('--seq-len', type=int, default=30)
215 | parser.add_argument('--ped-crop-size', type=tuple, default=(224, 224))
216 | parser.add_argument('--mask-size', type=tuple, default=(224, 224))
217 | parser.add_argument('--collapse-cls', type=int, default=0,
218 | help='Whether to merge the classes. If 1 then each item in masks is a dict keyed by cls, otherwise a list.')
219 | parser.add_argument('--img-path-format', type=str,
220 | default='/sailhome/ajarno/STR-PIP/datasets/JAAD_dataset/JAAD_clip_images/video_{:04d}.mp4/{:d}.jpg')
221 | parser.add_argument('--fsegm-format', type=str,
222 | default='/sailhome/ajarno/STR-PIP/datasets/JAAD_instance_segm/video_{:04d}/{:08d}_segm.npy')
223 | parser.add_argument('--save-cache-format', type=str, default='')
224 | parser.add_argument('--cache-format', type=str, default='')
225 | parser.add_argument('--batch-size', type=int, default=4)
226 | parser.add_argument('--n-workers', type=int, default=0)
227 | # added to test loader
228 | parser.add_argument('--rand-test', type=int, default=1)
229 | parser.add_argument('--predict', type=int, default=0)
230 | parser.add_argument('--predict-k', type=int, default=0)
231 | parser.add_argument('--combine-method', type=str, default='none')
232 | parser.add_argument('--load-cache', type=str, default='masks')
233 | parser.add_argument('--cache-obj-bbox-format', type=str,
234 | default='/sailhome/ajarno/STR-PIP/datasets/cache/obj_bbox_merged/vid{:08d}.pkl')
235 |
236 | opt = parser.parse_args()
237 | opt.save_cache_format = '/sailhome/ajarno/STR-PIP/datasets/cache/jaad_loc/{}/vid{}_fid{}.pkl'
238 | opt.cache_format = opt.save_cache_format
239 | opt.seq_len = 1
240 | opt.split = 'test'
241 |
242 | if True:
243 | # test dloader
244 | dloader = get_data_loader(opt)
245 |
246 | # for i,eg in enumerate(dloader):
247 | # if i%100 == 0:
248 | # print(i)
249 | # sys.stdout.flush()
250 |
251 | for i,vid in enumerate(dloader.dataset.vids):
252 | print('vid:', vid)
253 | annot = dloader.dataset.annots[vid]
254 | n_frames = len(annot['act'])
255 | for fid in range(n_frames):
256 | fcache = opt.cache_format.format(opt.split, vid, fid+1)
257 | if os.path.exists(fcache):
258 | continue
259 | dloader.dataset.__getitem__(i, fid_start=fid)
260 |
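261 |
262 | # Hedged sketch (added; _demo_collate is hypothetical): why custom
263 | # collate_fns are needed here. The default collation would try to stack
264 | # the ragged per-frame mask lists; these functions stack only the
265 | # fixed-shape targets and keep variable-length fields as plain lists.
266 | def _demo_collate(batch):
267 |   return {
268 |     'GT_act': torch.stack([each['GT_act'] for each in batch]),  # (B, T, 9)
269 |     'all_masks': [each['all_masks'] for each in batch],         # ragged list
270 |   }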
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import pickle
4 | import sys, traceback, code
5 | import torch
6 |
7 | import data
8 | import models
9 | import utils
10 |
11 | import pdb
12 |
13 |
14 | def evaluate(model, dloader, opt, n_eval_epochs=3):
15 | print('Begin to evaluate')
16 | model.eval()
17 |
18 | if opt.collect_A:
19 | os.makedirs(os.path.dirname(opt.save_As_format), exist_ok=True)
20 |
21 | acc_det = {
22 | 'frames': 0,
23 | 'correct_frames': 0,
24 | 'clips': 0,
25 | 'correct_clips': 0,
26 | 'cross': 0,
27 | 'non_cross': 0,
28 | 'correct_cross': 0,
29 | 'correct_non_cross': 0,
30 | 'probs': None,
31 | 'loss': 0,
32 | }
33 | label_out = [] # labels output by the model
34 | label_GT = [] # GT labels (crossing)
35 | label_prob = []
36 |
37 | if opt.predict or opt.predict_k:
38 | acc_pred = {key:0 for key in acc_det}
39 | acc_last = {key:0 for key in acc_det}
40 |
41 | def helper_update_metrics(ret, acc):
42 | n_frames, n_correct_frames, n_clips, n_correct_clips = ret[:4]
43 | n_cross, n_non_cross, n_correct_cross, n_correct_non_cross = ret[4:8]
44 | probs = ret[8]
45 | loss = ret[9]
46 | preds = ret[10] # (B, T)
47 | crossing = ret[11] # (B, T)
48 |
49 | acc['frames'] += n_frames
50 | acc['correct_frames'] += n_correct_frames
51 | acc['clips'] += n_clips
52 | acc['correct_clips'] += n_correct_clips
53 | acc['cross'] += n_cross
54 | acc['non_cross'] += n_non_cross
55 | acc['correct_cross'] += n_correct_cross
56 | acc['correct_non_cross'] += n_correct_non_cross
57 | if acc['probs'] is None:
58 | acc['probs'] = probs
59 | else:
60 | acc['probs'] += probs
61 | acc['loss'] += loss
62 |
63 | return acc
64 |
65 | def helper_report_metrics(acc):
66 | if acc['probs'] is None:
67 | return 0, 0, 0, 0, 100
68 |
69 | acc_frame = acc['correct_frames'] / max(1, acc['frames'])
70 | acc_clip = acc['correct_clips'] / max(1, acc['clips'])
71 | acc_cross = acc['correct_cross'] / max(1, acc['cross'])
72 | acc_non_cross = acc['correct_non_cross'] / max(1, acc['non_cross'])
73 | avg_probs = acc['probs'] / max(1, acc['clips'])
74 | avg_loss = acc['loss'] / max(1, acc['frames'])
75 | print('Accuracy: frame:{:.5f}\t/ clip:{:.5f}'.format(acc_frame, acc_clip))
76 | print('Recall: cross:{:.5f}\t/ non-cross:{:.5f}'.format(acc_cross, acc_non_cross))
77 | print('Probs:', ' / '.join(['{}:{:.1f}'.format(i, each.item()*100) for i,each in enumerate(avg_probs)]))
78 | print('Loss: {:.3f}'.format(avg_loss))
79 | return acc_frame, acc_clip, acc_cross, acc_non_cross, avg_loss
80 |
81 | with torch.no_grad():
82 | for eid in range(n_eval_epochs):
83 | for step, data in enumerate(dloader):
84 | ret_det, ret_pred, ret_last, As = model.step_test(data, slide=opt.slide, collect_A=opt.collect_A)
85 |
86 | if ret_det is not None:
87 | acc_det = helper_update_metrics(ret_det, acc_det)
88 |
89 | if opt.predict or opt.predict_k:
90 | acc_pred = helper_update_metrics(ret_pred, acc_pred)
91 | acc_last = helper_update_metrics(ret_last, acc_last)
92 |
93 |         if opt.save_output > 0 and ret_det is not None and (opt.predict or opt.predict_k):
94 | curr_out = torch.cat([ret_det[10], ret_pred[10], ret_last[10]], -1)
95 | curr_GT = torch.cat([ret_det[11], ret_pred[11], ret_last[11]], -1)
96 | curr_prob = torch.cat([ret_det[8], ret_pred[8], ret_last[8]])
97 | label_out += curr_out,
98 | label_GT += curr_GT,
99 | label_prob += curr_prob,
100 | elif opt.save_output > 0 and ret_det is not None:
101 | label_out += ret_det[10],
102 | label_GT += ret_det[11],
103 | label_prob += ret_det[8],
104 |
105 | if As is not None:
106 | data = {
107 | 'As': As,
108 | 'fids': data['fids'],
109 | 'img_paths': data['img_paths'],
110 | 'probs': ret_pred[8], # 1D tensor of size T (avg over B)
111 | }
112 | with open(opt.save_As_format.format(step, eid), 'wb') as handle:
113 | pickle.dump(data, handle)
114 |
115 |         if opt.save_output and (step+1)%opt.save_output == 0 and False:  # 'and False' intentionally disables periodic dumps; everything is saved after the loop
116 | label_out = torch.cat(label_out, 0).numpy()
117 | label_GT = torch.cat(label_GT, 0).numpy()
118 | label_prob = torch.cat(label_prob, 0).numpy()
119 | with open(opt.save_output_format.format(step), 'wb') as handle:
120 | pickle.dump({'out':label_out, 'GT': label_GT, 'prob': label_prob}, handle)
121 | label_out = []
122 | label_GT = []
123 | label_prob = []
124 |
125 | torch.cuda.empty_cache()
126 |
127 | if opt.save_output_format:
128 | label_out = torch.cat(label_out, 0).numpy()
129 | label_GT = torch.cat(label_GT, 0).numpy()
130 | label_prob = torch.cat(label_prob, 0).numpy()
131 | with open(opt.save_output_format.format('all'), 'wb') as handle:
132 | pickle.dump({'out':label_out, 'GT': label_GT, 'prob': label_prob}, handle)
133 |
134 | print('Detection:')
135 | result_det = helper_report_metrics(acc_det)
136 | if opt.predict or opt.predict_k:
137 | print('Prediction:')
138 | result_pred = helper_report_metrics(acc_pred)
139 | result_last = helper_report_metrics(acc_last)
140 | print()
141 | return result_det, result_pred, result_last
142 |
143 | print()
144 | return result_det, None, None
145 |
146 |
147 | def extract_feats(model, dloader, extract_feats_dir, seq_len=30):
148 | print('Begin to extract')
149 | model.eval()
150 |
151 | n_peds = len(dloader)
152 | print('n_peds:', n_peds)
153 |
154 | for pid in range(0, n_peds):
155 | ped = dloader.dataset.peds[pid]
156 | if 'frame_end' in ped:
157 | # JAAD setting
158 | n_frames = ped['frame_end'] - ped['frame_start'] + 1
159 | fid_range = range(ped['frame_start'], ped['frame_end']+1)
160 | fid_display = list(fid_range)
161 | elif 'fids20' in ped:
162 | # STIP setting
163 | n_frames = len(ped['fids20'])
164 | fid_range = range(n_frames)
165 | fid_display = ped['fids20']
166 | else:
167 | print("extract_feats: missing/unexpected keys... o_o")
168 | pdb.set_trace()
169 |
170 | for fid,fid_dis in zip(fid_range, fid_display):
171 | print('pid:{} / fid:{}'.format(pid, fid))
172 | item = dloader.dataset.__getitem__(pid, fid_start=fid)
173 | ped_crops, masks, act = item['ped_crops'], item['all_masks'], item['GT_act']
174 | # print('masks[0][1]:', masks[0][1].shape)
175 | ped_feats, ctxt_feats, ctxt_cls = model.extract_feats(ped_crops, masks, pid)
176 |
177 | feat_path = os.path.join(extract_feats_dir, 'ped{}_fid{}.pkl'.format(pid, fid_dis))
178 | with open(feat_path, 'wb') as handle:
179 | feats = {
180 | 'ped_feats': ped_feats.cpu(), # shape: 1, 512
181 | 'ctxt_feats': ctxt_feats.cpu(), # shape: n_objs, 512
182 | 'ctxt_cls': torch.tensor(ctxt_cls)
183 | }
184 | pickle.dump(feats, handle)
185 |
186 | del ped_feats
187 | del ctxt_feats
188 |
189 | torch.cuda.empty_cache()
190 |
191 |     if pid % opt.log_every == 0:  # NOTE: relies on the module-level 'opt' set in __main__
192 | print('pid', pid)
193 |
194 | def extract_feats_loc(model, dloader, extract_feats_dir, seq_len=1):
195 | print('Begin to extract')
196 | model.eval()
197 |
198 | n_vids = len(dloader)
199 | print('n_vids:', n_vids)
200 |
201 | for vid in range(0, n_vids):
202 | key = dloader.dataset.vids[vid]
203 | annot = dloader.dataset.annots[key]
204 |
205 | for fid in range(len(annot['act'])):
206 | print('vid:{} / fid:{}'.format(vid, fid))
207 | feat_path = os.path.join(extract_feats_dir, 'vid{}_fid{}.pkl'.format(vid, fid))
208 | if os.path.exists(feat_path):
209 | continue
210 |
211 | item = dloader.dataset.__getitem__(vid, fid_start=fid)
212 | ped_crops, masks, act = item['ped_crops'], item['all_masks'], item['GT_act']
213 | # print('masks[0][1]:', masks[0][1].shape)
214 | ped_feats, ctxt_feats, ctxt_cls = model.extract_feats(ped_crops, masks)
215 |
216 | with open(feat_path, 'wb') as handle:
217 | feats = {
218 | 'ped_feats': ped_feats[0].cpu(), # shape: 1, 512
219 | 'ctxt_feats': ctxt_feats.cpu(), # shape: n_objs, 512
220 | 'ctxt_cls': torch.tensor(ctxt_cls)
221 | }
222 | pickle.dump(feats, handle)
223 |
224 | del ped_feats
225 | del ctxt_feats
226 |
227 | torch.cuda.empty_cache()
228 |
229 | if vid % opt.log_every == 0:
230 | print('vid', vid)
231 |
232 |
233 | if __name__ == '__main__':
234 | opt, logger = utils.build(is_train=False)
235 |
236 | dloader = data.get_data_loader(opt)
237 | print('{} dataset: {}'.format(opt.split, len(dloader.dataset)))
238 |
239 | model = models.get_model(opt)
240 | print('Got model')
241 | if opt.which_epoch == -1:
242 | model_path = os.path.join(opt.ckpt_path, 'best_pred.pth')
243 | else:
244 | model_path = os.path.join(opt.ckpt_path, '{}.pth'.format(opt.which_epoch))
245 | if os.path.exists(model_path):
246 |         # NOTE: if the checkpoint path does not exist, the backbone keeps its ImageNet-pretrained weights
247 | model.load(model_path)
248 | print('Model loaded:', model_path)
249 | else:
250 |     print('Model does not exist:', model_path)
251 | model = model.to('cuda:0')
252 |
253 | try:
254 | if opt.mode == 'evaluate':
255 | evaluate(model, dloader, opt)
256 | elif opt.mode == 'extract':
257 | assert(opt.batch_size == 1)
258 | assert(opt.seq_len == 1)
259 | assert(opt.predict == 0)
260 |
261 | print('Saving at', opt.extract_feats_dir)
262 | os.makedirs(opt.extract_feats_dir, exist_ok=True)
263 | if 'loc' in opt.model:
264 | extract_feats_loc(model, dloader, opt.extract_feats_dir, opt.seq_len)
265 | else:
266 | extract_feats(model, dloader, opt.extract_feats_dir, opt.seq_len)
267 | except Exception as e:
268 | print(e)
269 | typ, vacl, tb = sys.exc_info()
270 | traceback.print_exc()
271 | last_frame = lambda tb=tb: last_frame(tb.tb_next) if tb.tb_next else tb
272 | frame = last_frame().tb_frame
273 | ns = dict(frame.f_globals)
274 | ns.update(frame.f_locals)
275 | code.interact(local=ns)
276 |
277 |
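278 | # Hedged sketch (added; _demo_post_mortem is hypothetical): a standalone
279 | # version of the post-mortem idiom in the except-block above, which walks
280 | # to the innermost traceback frame and opens a shell in its namespace.
281 | def _demo_post_mortem():
282 |   try:
283 |     1 / 0
284 |   except Exception:
285 |     _, _, tb = sys.exc_info()
286 |     while tb.tb_next:  # descend to the frame that actually raised
287 |       tb = tb.tb_next
288 |     ns = dict(tb.tb_frame.f_globals)
289 |     ns.update(tb.tb_frame.f_locals)
290 |     code.interact(local=ns)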
--------------------------------------------------------------------------------
/data/jaad_loc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import pickle
5 | import cv2
6 | import random
7 | random.seed(2019)
8 | import torch
9 | import torch.utils.data as data
10 | import torchvision.transforms as transforms
11 |
12 | import pdb
13 | import time
14 |
15 | from utils.data_proc import parse_objs
16 |
17 | class JAADLocDataset(data.Dataset):
18 | def __init__(self, opt):
19 | # annot_ped_format, is_train, split,
20 | # seq_len, ped_crop_size, mask_size, collapse_cls,
21 | # img_path_format, fsegm_format):
22 |
23 | self.split = opt.split
24 | annot_loc = opt.annot_loc_format.format(self.split)
25 | with open(annot_loc, 'rb') as handle:
26 | self.annots = pickle.load(handle)
27 | self.vids = sorted(self.annots.keys())
28 |
29 | self.is_train = opt.is_train
30 | self.rand_test = opt.rand_test
31 | self.seq_len = opt.seq_len
32 | self.predict = opt.predict
33 | if self.predict:
34 | self.all_seq_len = self.seq_len + opt.pred_seq_len
35 | else:
36 | self.all_seq_len = self.seq_len
37 | self.predict_k = opt.predict_k
38 | if self.predict_k:
39 | self.all_seq_len += self.predict_k
40 | self.ped_crop_size = opt.ped_crop_size
41 | self.mask_size = opt.mask_size
42 | self.collapse_cls = opt.collapse_cls
43 | self.combine_method = opt.combine_method
44 |
45 | self.img_path_format = opt.img_path_format
46 | self.fsegm_format = opt.fsegm_format
47 | self.save_cache_format = opt.save_cache_format
48 | self.load_cache = opt.load_cache
49 | self.cache_format = opt.cache_format
50 | self.cache_obj_bbox_format = opt.cache_obj_bbox_format
51 |
52 | def __len__(self):
53 | return len(self.vids)
54 |
55 | def __getitem__(self, idx, fid_start=-1):
56 | t_start = time.time()
57 |
58 | vid = self.vids[idx]
59 | annot = self.annots[vid]
60 |
61 | # print('vid:', vid)
62 |
63 | if fid_start != -1:
64 | f_start = fid_start
65 | elif self.is_train or self.rand_test:
66 | # train: randomly sample all_seq_len number of frames
67 |       f_start = random.randint(0, max(0, len(annot['act'])-self.all_seq_len))  # guard against clips shorter than the window
68 | elif fid_start == -1:
69 | f_start = 0
70 |
71 | if self.is_train or self.rand_test or fid_start != -1:
72 | if self.predict_k:
73 | fids = [min(len(annot['act'])-1, f_start+i) for i in range(self.seq_len)]
74 | fids += min(len(annot['act'])-1, f_start+self.seq_len-1+self.predict_k),
75 | else:
76 | fids = [min(len(annot['act'])-1, f_start+i) for i in range(self.all_seq_len)]
77 | else:
78 | # take the entire video
79 | fids = range(0, len(annot['act']))
80 |
81 | GT_act = [annot['act'][fid] for fid in fids]
82 | GT_act = np.stack(GT_act)
83 | GT_act = torch.tensor(GT_act)
84 |
85 |     GT_ped_bbox = [[pos.astype(np.float64) for pos in annot['ped_pos'][fid]] for fid in fids]  # np.float was removed in NumPy >= 1.24
86 |
87 | with open(self.cache_obj_bbox_format.format(vid), 'rb') as handle:
88 | data = pickle.load(handle)
89 | obj_cls = [data['obj_cls'][fid] for fid in fids]
90 | obj_bbox = [data['obj_bbox'][fid] for fid in fids]
91 |
92 | frames = []
93 | for fid in fids[:self.seq_len]:
94 | img_path = self.img_path_format.format(vid, fid+1) # +1 since fid is 0-based while the frame path starts at 1.
95 | img = cv2.imread(img_path)
96 | img = cv2.resize(img, (224,224)).transpose((2,0,1))
97 | frames += img,
98 |
99 | ret = {
100 | 'GT_act': GT_act,
101 | 'GT_ped_bbox': GT_ped_bbox,
102 | 'obj_cls': obj_cls,
103 | 'obj_bbox': obj_bbox,
104 |       'frames': torch.tensor(np.array(frames)),  # stack into one ndarray first; tensor-from-list-of-arrays is slow
105 | 'fids': torch.tensor(np.array(fids)),
106 | }
107 | ped_crops = []
108 | all_masks = []
109 |
110 | # only the first seq_len fids are input data
111 | fids = fids[:self.seq_len]
112 |
113 | if self.load_cache == 'masks':
114 | for i,fid in enumerate(fids):
115 | # NOTE: fid for masks is 1-based.
116 | fcache = self.cache_format.format(self.split, vid, fid+1)
117 | # print(fcache)
118 | if os.path.exists(fcache):
119 | with open(fcache, 'rb') as handle:
120 | data = pickle.load(handle)
121 | ped_crops += data['ped_crops'],
122 | all_masks += data['masks'],
123 |
124 | if 'max' not in fcache:
125 | # 1 mask per obj: check for # objs
126 | if type(data['masks']) == dict:
127 | n_objs = len([each for val in data['masks'].values() for each in val])
128 | else:
129 | n_objs = len(data['masks'])
130 | if n_objs != len(obj_bbox[i]):
131 | print('JAAD: n_objs mismatch')
132 | pdb.set_trace()
133 | else:
134 | try:
135 | ped_crop, cls_masks = self.get_vid_fid(vid, annot['ped_pos'][fid], fid)
136 | except Exception as e:
137 | print(e)
138 | pdb.set_trace()
139 | ped_crops += ped_crop,
140 | all_masks += cls_masks,
141 | if type(cls_masks) == dict:
142 | n_objs = len([each for val in cls_masks.values() for each in val])
143 | else:
144 | n_objs = len(cls_masks)
145 | if n_objs != len(obj_bbox[i]):
146 | print('JAAD: n_objs mismatch')
147 | pdb.set_trace()
148 |
149 | ret['ped_crops'] = ped_crops
150 | ret['all_masks'] = all_masks
151 |
152 | # pdb.set_trace()
153 | n_peds = sum([len(each) for each in ped_crops])
154 | n_objs = sum([len(each) for each in all_masks])
155 | # print('n_peds:{} / n_objs:{}'.format(n_peds, n_objs))
156 | return ret
157 |
158 | elif self.load_cache == 'feats':
159 | ped_feats = []
160 | ctxt_feats = []
161 | for fid in fids:
162 | # NOTE: fid for feats is 0-based.
163 | # with open(self.cache_format.format(self.split, vid, fid), 'rb') as handle:
164 | with open(self.cache_format.format(self.split, idx, fid), 'rb') as handle:
165 | data = pickle.load(handle)
166 | ped = data['ped_feats'] # shape: 1, 512
167 | ctxt = data['ctxt_feats'] # shape: n_objs, 512
168 |
169 | ped_feats += ped,
170 | ctxt_feats += ctxt,
171 | # ped_feats = torch.stack(ped_feats, 0)
172 |
173 | ret['ped_crops'] = ped_feats
174 | ret['all_masks'] = ctxt_feats
175 | return ret
176 |
177 | elif self.load_cache == 'pos':
178 | ret['ped_crops'] = torch.zeros([1,1,512])
179 | ret['all_masks'] = torch.zeros([1,1,512])
180 | return ret
181 |
182 | for fid in fids:
183 |       ped_crop, cls_masks = self.get_vid_fid(vid, annot['ped_pos'][fid], fid)  # was get_ped_fid(ped, ...): neither 'ped' nor that method exists here
184 | ped_crops += ped_crop,
185 | all_masks += cls_masks,
186 |
187 | # shape: [n_frames, self.ped_crop_size[0], self.ped_crop_size[1]]
188 |
189 | ped_crops = np.stack(ped_crops)
190 | ped_crops = torch.Tensor(ped_crops)
191 |
192 | # print('time per item:', time.time()-t_start)
193 | ret['ped_crops'] = ped_crops
194 | ret['all_masks'] = all_masks
195 | return ret
196 |
197 |
198 | def get_vid_fid(self, vid, peds, fid):
199 | """
200 | Prepare ped_crop and obj masks for given ped and fid.
201 | """
202 |
203 | ped_crops = []
204 | if len(peds) == 0:
205 | ped_crops = [np.zeros([3, 224, 224])]
206 | # if no bbox, take the entire frame
207 | x,y,w,h = 0,0,-1,-1
208 | else:
209 | img_path = self.img_path_format.format(vid, fid+1) # +1 since fid is 0-based while the frame path starts at 1.
210 | img = cv2.imread(img_path)
211 |
212 | # pedestrian crops
213 | for ped in peds:
214 | x, y, w, h = ped
215 | x, y, w, h = int(x), int(y), int(w), int(h)
216 |
217 |         try:
218 |           ped_crop = img[max(0,y):y+h, max(0,x):x+w]  # clamp negative coords
219 |           ped_crop = cv2.resize(ped_crop, self.ped_crop_size)  # resize inside try: empty crops raise
220 |         except Exception as e:
221 |           print(e)
222 |           print('img_path: {} / x:{}, y:{}, w:{}, h:{}'.format(img_path, x, y, w, h))
223 |           continue  # skip this pedestrian rather than crash on an undefined crop
224 |         ped_crop = ped_crop.transpose((2,0,1))
225 |         ped_crops += ped_crop,
226 |
227 | # obj masks
228 | fsegm = self.fsegm_format.format(vid, fid)
229 | objs = parse_objs(fsegm)
230 | if self.collapse_cls:
231 | cls_masks = []
232 | else:
233 | cls_masks = {}
234 | for cls, masks in objs.items():
235 | if not self.collapse_cls:
236 | cls_masks[cls] = []
237 | for mask in masks:
238 | mask[y:y+h, x:x+w] = 1
239 | if self.combine_method == 'pair':
240 | # crop out the union bbox of ped + obj
241 | # note that didn't check empty.
242 | x_pos = mask.sum(0).nonzero()[0]
243 | x_min, x_max = x_pos[0], x_pos[-1]
244 |             y_pos = mask.sum(1).nonzero()[0]  # was [1]: nonzero() on a 1-D sum returns a 1-tuple
245 | y_min, y_max = y_pos[0], y_pos[-1]
246 | mask = mask[y_min:y_max+1, x_min:x_max+1]
247 | mask = cv2.resize(mask, self.mask_size)
248 | mask = torch.tensor(mask)
249 | # mask = torch.stack([mask, mask, mask])
250 | if self.collapse_cls:
251 | cls_masks += mask,
252 | else:
253 | # TODO: transform the mask: e.g. crop & norm over the union
254 | cls_masks[cls] += mask,
255 | if not self.collapse_cls:
256 | cls_masks[cls] = torch.stack(cls_masks[cls])
257 | if self.combine_method == 'sum':
258 | cls_masks[cls] = cls_masks[cls].sum(0)
259 | elif self.combine_method == 'max':
260 | cls_masks[cls], _ = cls_masks[cls].max(0)
261 |
262 | if self.collapse_cls:
263 | if len(cls_masks) != 0:
264 | cls_masks = torch.stack(cls_masks)
265 | if self.combine_method == 'sum':
266 | cls_masks = cls_masks.sum(0)
267 | elif self.combine_method == 'max':
268 | cls_masks, _ = cls_masks.max(0)
269 | else:
270 | # no objects in the frame
271 | if self.combine_method:
272 | # e.g. 'sum' or 'max'
273 | cls_masks = torch.zeros(self.mask_size)
274 | else:
275 | cls_masks = torch.zeros([0, self.mask_size[0], self.mask_size[1]])
276 |
277 | if self.cache_format:
278 | with open(self.cache_format.format(self.split, vid, fid+1), 'wb') as handle:
279 | cache = {
280 | 'ped_crops': ped_crops,
281 | 'masks': cls_masks.data if self.collapse_cls else {cls:cls_masks[cls].data for cls in cls_masks},
282 | }
283 | pickle.dump(cache, handle)
284 |
285 | return ped_crops, cls_masks
286 |
287 |
288 |
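289 | # Hedged sketch (added; _demo_predict_k_fids is hypothetical): the frame-id
290 | # layout __getitem__ builds when predicting exactly k steps ahead -- seq_len
291 | # observed frames plus one clamped target frame at offset seq_len-1+k.
292 | def _demo_predict_k_fids(n_frames=10, f_start=5, seq_len=4, k=3):
293 |   fids = [min(n_frames - 1, f_start + i) for i in range(seq_len)]
294 |   fids += min(n_frames - 1, f_start + seq_len - 1 + k),
295 |   return fids  # [5, 6, 7, 8, 9]: the raw target 11 clamps to the last frame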
--------------------------------------------------------------------------------
/utils/data_proc.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | import numpy as np
4 | import pickle
5 | import cv2
6 | import xml.etree.ElementTree as ET
7 | from pycocotools import mask as maskUtils
8 |
9 | import pdb
10 |
11 | # Objects of interest
12 | cls_map = {
13 | # other road users (treated as part of vehicles)
14 | 5: 1, # bicyclist
15 | 6: 1, # motorcyclist
16 | 7: 1, # other riders
17 | # vehicles
18 | 28: 1, # bicycle
19 | 29: 1, # boat (??)
20 | 30: 1, # bus
21 | 31: 1, # car
22 | 32: 1, # caravan
23 | 33: 1, # motorcycle
24 | 34: 1, # other vehicle
25 | 35: 1, # trailer
26 | 36: 1, # truck
27 | 37: 1, # wheeled slow
28 | # environments
29 | 3: 2, # crosswalk - plain
30 | 8: 3, # crosswalk - zebra
31 | 24: 4, # traffic lights
32 | }
33 |
34 | # smaller set of objs because of memory
35 | cls_map_small = {
36 | # vehicles
37 | 28: 1, # bicycle
38 | 30: 1, # bus
39 | 31: 1, # car
40 | 33: 1, # motorcycle
41 | 35: 1, # trailer
42 | 36: 1, # truck
43 | # environments
44 | 3: 2, # crosswalk - plain
45 | 8: 3, # crosswalk - zebra
46 | 24: 4, # traffic lights
47 | }
48 |
49 |
50 | def parse_objs(fnpy):
51 | # Parse instance segmentations on one frame to masks.
52 | # In: filename of instance segmentation (i.e. `dataset/JAAD_instance_segm/video_{:04d}/{:08d}_segm.npy`)
53 | # Out: dict with key 1-4 (see `cls_map`), each for a type of graph nodes.
54 | segms = np.load(fnpy, allow_pickle=True)
55 |
56 | selected_segms = {}
57 | for key,val in cls_map.items():
58 | if segms[key]:
59 | if val not in selected_segms:
60 | selected_segms[val] = []
61 | for each in segms[key]:
62 | mask = maskUtils.decode([each])
63 | selected_segms[val] += mask,
64 | del segms
65 | return selected_segms
66 |
67 |
68 | def get_obj_crops(fnpy_root, fpkl_root):
69 | os.makedirs(fpkl_root, exist_ok=True)
70 |
71 | vids = sorted(glob(os.path.join(fnpy_root, 'video_*')))
72 | for i,vid in enumerate(vids):
73 | if i < 2:
74 | continue
75 | print('vid:', os.path.basename(vid))
76 | fnpys = sorted(glob(os.path.join(vid, '*_segm.npy')))
77 | for fnpy in fnpys:
78 | selected_segms = parse_objs(fnpy)
79 | crops = {}
80 | for cls,segms in selected_segms.items():
81 | if len(segms):
82 | crops[cls] = []
83 | for mask in segms:
84 | y_pos = mask.sum(1).nonzero()[0]
85 | x_pos = mask.sum(0).nonzero()[0]
86 | if len(y_pos) == 0:
87 | print('empty y_pos:', fnpy)
88 | continue
89 | if len(x_pos) == 0:
90 | print('empty x_pos:', fnpy)
91 | continue
92 | y_min, y_max = y_pos[0], y_pos[-1]
93 | x_min, x_max = x_pos[0], x_pos[-1]
94 | if x_min >= x_max or y_min >= y_max:
95 | print('empty_crop:', fnpy)
96 | print('x_min={} / x_max={} / y_min={} / y_max={}\n'.format(x_min, x_max, y_min, y_max))
97 | continue
98 | crop = mask[y_min:(y_max+1), x_min:(x_max+1)]
99 | crop = cv2.resize(crop, (224,224))
100 | crops[cls] += crop,
101 | if len(crops[cls]):
102 | crops[cls] = np.stack(crops[cls])
103 | else:
104 | crops[cls] = np.zeros([0, 224, 224])
105 | else:
106 | crops[cls] = np.zeros([0, 224, 224])
107 |
108 | fpkl = fnpy.replace(fnpy_root, fpkl_root)
109 | fpkl_dir = os.path.dirname(fpkl)
110 | os.makedirs(fpkl_dir, exist_ok=True)
111 | with open(fpkl, 'wb') as handle:
112 | pickle.dump(crops, handle)
113 |
114 |
115 | pedestrian_act_map = {
116 | 'clear path': 0,
117 | 'crossing': 1,
118 | 'handwave': 2,
119 | 'looking': 3,
120 | 'nod': 4,
121 | 'slow down': 5,
122 | 'speed up': 6,
123 | 'standing': 7,
124 | 'walking': 8,
125 | }
126 |
127 | joint_ids = [
128 | 1, # 0 Neck
129 | 2, # 1 RShoulder
130 | 5, # 2 LShoulder
131 | 9, # 3 RHip
132 | 10, # 4 RKnee
133 | 11, # 5 RAnkle
134 | 12, # 6 LHip
135 | 13, # 7 LKnee
136 | 14, # 8 LAnkle
137 | ]
138 |
139 | def parse_pedestrian(fxml, fpos_GT, fpos_pred='', fpose=''):
140 | """
141 | parse xml (action label)
142 | """
143 | e = ET.parse(fxml).getroot()
144 | # peds: dict of id to a list of per-frame acts.
145 | # Action labels at each frame is a one-hot vec of length 9 (i.e. len(pedestrian_act_map)).
146 | peds = {}
147 | nframes = int(e.get('num_frames'))
148 |   for child in list(e):  # Element.getchildren() was removed in Python 3.9
149 | if child.tag != 'actions':
150 | continue
151 |
152 |     for each in list(child):
153 | if 'pedestrian' not in each.tag: # e.g. Driver
154 | continue
155 |
156 | # NOTE: `pid` starts at 1.
157 | tag = each.tag
158 | if tag == 'pedestrian':
159 | pid = 1
160 | tag = 'pedestrian1'
161 | else:
162 | pid = int(tag[len('pedestrian'):])
163 | # NOTE: change indexing from 'pid' to 'each.tag'
164 | peds[tag] = {'act': [[0] * len(pedestrian_act_map) for _ in range(nframes)]}
165 | peds[tag]['tag'] = tag
166 | peds[tag]['pid'] = pid
167 |       pacts = list(each)
168 | for act in pacts:
169 | act_cls = pedestrian_act_map[act.get('id').lower()]
170 | for t in range(int(act.get('start_frame'))-1, int(act.get('end_frame'))):
171 | peds[tag]['act'][t][act_cls] = 1
172 |
173 | if fpose:
174 | """
175 | parse pose
176 | """
177 | pose_data = np.load(fpose, encoding='latin1', allow_pickle=True).item()
178 | for ped_tag, frames in sorted(pose_data.items()):
179 | if ped_tag not in peds:
180 | continue
181 | peds[ped_tag]['pose'] = [[] for _ in range(nframes)]
182 | peds[ped_tag]['pos_Pose'] = [[] for _ in range(nframes)]
183 |
184 | # NOTE fid are 0-indexed in pose_data
185 | for frame in sorted(frames):
186 | fid = frame[0]
187 | peds[ped_tag]['pose'][fid] = np.array(frame[2][0])[joint_ids]
188 | peds[ped_tag]['pos_Pose'][fid] = frame[1]['pos']
189 |
190 | """
191 | parse position
192 | """
193 | poccl = {ped_tag:[[] for _ in range(nframes)] for ped_tag in peds}
194 |
195 | # NOTE: not all pedestrians in the npy files have labels.
196 | # We only use those with action labels in xml.
197 | data = np.load(fpos_GT, allow_pickle=True).item()
198 |   assert data['nFrame'] == nframes, "nFrame mismatch: xml:{} / npy: {}".format(nframes, data['nFrame'])
199 |
200 |   assert len(data['objLists']) == nframes, "nFrame mismatch: xml:{} / npy-->objLists: {}".format(nframes, len(data['objLists']))
201 |
202 | ppos_GT = {ped_tag:[[] for _ in range(nframes)] for ped_tag in peds}
203 | # pdb.set_trace()
204 | for fid,frame in enumerate(data['objLists']):
205 | # frame: dict w/ keys: 'id', 'pos', 'posv', 'occl', 'lock'
206 | for ped in frame:
207 | pid = ped['id'][0]
208 | for ped_tag in peds:
209 | if peds[ped_tag]['pid'] == pid:
210 | ppos_GT[ped_tag][fid] = ped['pos'] # (x, y, w, h)
211 | poccl[ped_tag][fid] = ped['occl'][0]
212 | break
213 | # if ped_tag in peds:
214 | # ppos_GT[ped_tag][fid] = ped['pos'] # (x, y, w, h)
215 | # poccl[ped_tag][fid] = ped['occl'][0]
216 |
217 | for ped_tag in peds:
218 | pid = peds[ped_tag]['pid']
219 | peds[ped_tag]['frame_start'] = data['objStr'][pid-1]-1
220 | peds[ped_tag]['frame_end'] = data['objEnd'][pid-1]-1
221 | peds[ped_tag]['pos_GT'] = ppos_GT[ped_tag]
222 | peds[ped_tag]['occl'] = poccl[ped_tag]
223 |
224 | if fpos_pred:
225 | raise NotImplementedError('cannot parse pedestrian segm for now. Sorry!!')
226 |
227 | return peds
228 |
229 |
230 | def prepare_data():
231 | # object GT directory
232 | obj_root = '/sailhome/ajarno/STR-PIP/datasets/JAAD_instance_segm'
233 | fobj_dir_format = os.path.join(obj_root, 'video_{:04d}')
234 | # pedestrian GT files
235 | #ped_root = '/sailhome/ajarno/STR-PIP/datasets/JAAD_dataset/'
236 | ped_root = '/vision/u/caozj1995/data/JAAD_dataset/'
237 | fxml_format = os.path.join(ped_root, 'behavioral_data_xml', 'video_{:04d}.xml')
238 | fpose_format = os.path.join('/vision2/u/mangalam/JAAD/openpose_track_with_pose/', 'video_{:04d}.npy')
239 | fpos_GT_format = os.path.join(ped_root, 'bounding_box_python', 'vbb_part', 'video_{:04d}.npy')
240 |
241 | def prepare_split(vid_range):
242 |
243 | all_peds = []
244 | for vid in vid_range:
245 | # print(vid)
246 |
247 | if True:
248 | # objects
249 | fobj_dir = fobj_dir_format.format(vid)
250 | fsegms = sorted(glob(os.path.join(fobj_dir, '*_segm.npy')))
251 |       # pdb.set_trace()  # leftover debug stop removed: it halted processing on every video
252 |
253 | frame_objs = []
254 | for fid,fsegm in enumerate(fsegms):
255 | print('fid:', fid)
256 | frame_objs += parse_objs(fsegm),
257 |
258 |       print('Finished objects')
259 |
260 | if True:
261 | # pedestrians
262 | fxml = fxml_format.format(vid)
263 | fpose = fpose_format.format(vid)
264 | fpose = ''
265 | fpos_GT = fpos_GT_format.format(vid)
266 | ped_label = parse_pedestrian(fxml, fpos_GT, fpose=fpose)
267 | if len(ped_label) == 0:
268 | print('No pedestrian: vid:', vid)
269 |
270 | for ped in ped_label.values():
271 | ped['vid'] = vid
272 | all_peds += ped,
273 |
274 | return all_peds
275 |
276 |
277 | annot_root = '/sailhome/ajarno/STR-PIP/datasets/'
278 | # Train
279 | print('Processing training data...')
280 | train_range = range(1, 250+1)
281 | annot_train = prepare_split(train_range)
282 | print('# Train:', len(annot_train))
283 | with open(os.path.join(annot_root, 'annot_train_ped_withTag_sanityNoPose.pkl'), 'wb') as handle:
284 | pickle.dump(annot_train, handle)
285 | print('Training data ready.\n')
286 |
287 | # Test
288 | print('Processing testing data...')
289 | test_range = range(251, 346+1)
290 | annot_test = prepare_split(test_range)
291 | print('# Test:', len(annot_test))
292 | with open(os.path.join(annot_root, 'annot_test_ped_withTag_sanityNoPose.pkl'), 'wb') as handle:
293 | pickle.dump(annot_test, handle)
294 | print('Testing data ready.\n')
295 |
296 |
297 |
298 | if __name__ == '__main__':
299 | if False:
300 | # test
301 | fsegm = '/sailhome/bingbin/STR-PIP/datasets/JAAD_instance_segm/video_0131/00000001_segm.npy'
302 | parse_objs(fsegm)
303 | # fxml = '/sailhome/bingbin/STR-PIP/datasets/JAAD_dataset/behavioral_data_xml/video_0001.xml'
304 | # fpos = '/vision2/u/caozj/datasets/JAAD_dataset/bounding_box_python/vbb_part/video_0001.npy'
305 | # parse_pedestrian(fxml, fpos)
306 |
307 | if True:
308 | prepare_data()
309 |
310 | if False:
311 |     stip_segm2box_wrapper()  # NOTE: not defined in this file
312 |
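313 |
314 | # Hedged sketch (added; _demo_action_fill is hypothetical): how
315 | # parse_pedestrian turns an XML action span into per-frame one-hot rows.
316 | # XML frames are 1-based and end-inclusive, hence the range below.
317 | def _demo_action_fill(nframes=6, start_frame=3, end_frame=5):
318 |   act = [[0] * len(pedestrian_act_map) for _ in range(nframes)]
319 |   cls = pedestrian_act_map['crossing']
320 |   for t in range(start_frame - 1, end_frame):
321 |     act[t][cls] = 1
322 |   return act  # rows 2..4 (0-based) carry a 1 in the 'crossing' column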
--------------------------------------------------------------------------------
/models/model_concat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import pickle  # needed by the feature-dump block in forward()
6 | import random
7 |
8 | from models.backbone.resnet_based import resnet_backbone
9 |
10 | from models.model_base import BaseModel
11 |
12 | import pdb
13 |
14 |
15 | class ConcatModel(BaseModel):
16 | def __init__(self, opt):
17 | super().__init__(opt)
18 | if self.use_gru:
19 | self.gru = nn.GRU(self.cls_dim, self.cls_dim, 2, batch_first=True).to(self.device)
20 | elif self.use_trn:
21 | if self.predict and self.pred_seq_len != 1:
22 | raise ValueError("self.pred_seq_len has to be 1 when using TRN.")
23 | self.k = 3 # number of samples per branch
24 | self.f2 = nn.Sequential(
25 | nn.Linear(2*self.cls_dim, 256),
26 | nn.ReLU(),
27 | nn.Linear(256, 256))
28 | self.f3 = nn.Sequential(
29 | nn.Linear(3*self.cls_dim, 256),
30 | nn.ReLU(),
31 | nn.Linear(256, 256))
32 | self.f4 = nn.Sequential(
33 | nn.Linear(4*self.cls_dim, 256),
34 | nn.ReLU(),
35 | nn.Linear(256, 256))
36 | self.h2 = nn.Linear(256*self.k, self.n_acts)
37 | self.h3 = nn.Linear(256*self.k, self.n_acts)
38 | self.h4 = nn.Linear(256*self.k, self.n_acts)
39 |
40 | if self.use_ped_gru:
41 | self.ped_gru = nn.GRU(self.ped_dim, self.ped_dim, 2, batch_first=True).to(self.device)
42 | if self.use_ctxt_gru:
43 | self.ctxt_gru = nn.GRU(self.conv_dim, self.conv_dim, 2, batch_first=True).to(self.device)
44 |
45 | if self.use_act and not self.use_gt_act:
46 | self.act_gru = nn.GRU(self.cls_dim-self.n_acts, self.cls_dim, 2, batch_first=True).to(self.device)
47 |
48 |
49 | if self.predict:
50 | # NOTE: this GRU only takes in seq of len 1, i.e. one time step
51 | # GRU is used over GRUCell for multilayer
52 | self.gru_pred = nn.GRU(self.cls_dim, self.cls_dim, 2, batch_first=True).to(self.device)
53 |
54 | if self.predict_k:
55 | self.fc_pred = nn.Sequential(
56 | nn.Linear(self.cls_dim, 256),
57 | nn.ReLU(),
58 | nn.Linear(256, self.cls_dim)
59 | )
60 |
61 |
62 |   def forward(self, ped_crops, masks, bbox=None, act=None, pose=None, obj_bbox=None, obj_cls=None, idx=None):
63 | B, T, _,_,_ = ped_crops.shape
64 |
65 | if self.branch in ['both', 'ped']:
66 | # ped_crops: (bt, 30, 3, 224, 224)
67 | ped_crops = ped_crops.type(self.dtype).view(-1, 3, 224, 224)
68 | ped_crops = ped_crops.to(self.device)
69 | ped_feats = self.ped_encoder(ped_crops)
70 | # ped_feats: (B, T, d)
71 | ped_feats = ped_feats.view(B, T, -1)
72 | if self.use_ped_gru:
73 | ped_feats, _ = self.ped_gru(ped_feats)
74 |
75 | if self.branch in ['both', 'ctxt']:
76 | # masks: list: (bt, 30, {cls: (n, 3, 224, 224)})
77 | ctxt_feats = [[] for _ in range(B)]
78 | for b in range(B):
79 | for t in range(T):
80 | if len(masks[b][t]) == 0:
81 | ctxt_feat = torch.zeros([self.conv_dim])
82 | ctxt_feat = ctxt_feat.to(self.device)
83 | else:
84 | if type(masks[b][t]) is dict:
85 | vals = list(masks[b][t].values())
86 | # ctxt_masks: (n_total, (3,) 224, 224)
87 | ctxt_masks = torch.cat(vals, 0)
88 | else:
89 | ctxt_masks = masks[b][t]
90 | # TODO: this seems to be a bug though it didn't complain before.
91 | # Check whether this will affect the prev model.
92 | # ctxt_masks = ctxt_masks.sum(0, True)
93 | if ctxt_masks.dim() == 3:
94 | n_total, h, w = ctxt_masks.shape
95 | ctxt_masks = ctxt_masks.unsqueeze(1).expand([n_total, 3, h, w])
96 | elif ctxt_masks.dim() == 2:
97 | h, w = ctxt_masks.shape
98 | ctxt_masks = ctxt_masks.unsqueeze(0).unsqueeze(0).expand([1, 3, h, w])
99 | ctxt_masks = ctxt_masks.to(self.device)
100 | # ctxt_feats: (n_total, d)
101 | # print('ctxt_masks', ctxt_masks.shape)
102 | ctxt_feat = self.ctxt_encoder(ctxt_masks.type(self.dtype))
103 | # average pool
104 | ctxt_feat = ctxt_feat.mean(0).squeeze(-1).squeeze(-1)
105 | ctxt_feats[b] += ctxt_feat,
106 | ctxt_feats[b] = torch.stack(ctxt_feats[b])
107 | ctxt_feats = torch.stack(ctxt_feats)
108 |
109 |
110 |     if idx is not None and os.path.exists(self.extract_feats_dir):  # 'idx' must now be passed in; it was previously undefined here
111 | # NOTE: set self.rand_test = 0 when extracting features since we want to cover all frames
112 | feat_path = os.path.join(self.extract_feats_dir, 'ped{}.pkl'.format(idx))
113 | with open(feat_path, 'wb') as handle:
114 | feats = {
115 | 'ped_feats': ped_feats,
116 | 'ctxt_feats': ctxt_feats,
117 | }
118 | pickle.dump(feats, handle)
119 |
120 | if self.pos_mode != 'none':
121 | ped_feats = self.append_pos(ped_feats, bbox[:, :self.seq_len])
122 |
123 | if self.use_signal:
124 | # act 2: handwave / act 3: looking
125 | signal = act[:, :, 2:4].to(self.device)
126 | ped_feats = torch.cat([ped_feats, signal], -1)
127 |
128 | if self.branch == 'both':
129 | frame_feats = torch.cat([ped_feats, ctxt_feats], -1)
130 | elif self.branch == 'ped':
131 | frame_feats = ped_feats
132 | elif self.branch == 'ctxt':
133 | frame_feats = ctxt_feats
134 | else:
135 | raise ValueError("self.branch should be 'both', 'ped', or 'ctxt'. Got {}".format(self.branch))
136 |
137 | if self.use_act:
138 | if self.use_gt_act:
139 | if self.n_acts == 1:
140 | act = act[:, :, 1:2]
141 | act = act[:, :self.seq_len]
142 | else:
143 | # use predicted action labels
144 | temporal_feats, _ = self.act_gru(frame_feats)
145 | h = self.classifier(temporal_feats)
146 | act_logits = self.sigmoid(h)
147 | act = (act_logits > 0.5).type(self.dtype)
148 | frame_feats = torch.cat([frame_feats, act], -1)
149 |
150 | if self.use_pose:
151 | normed_pose = self.util_norm_pose(pose)
152 | normed_pose = normed_pose.contiguous().view(B, T, -1)
153 | frame_feats = torch.cat([frame_feats, normed_pose], -1)
154 |
155 | if self.use_gru:
156 | # self.gru keeps the dimension of frame_feats
157 | frame_feats, h = self.gru(frame_feats)
158 | elif self.use_trn:
159 | # Note: for predicting the next frame only (i.e. Lopez's setting)
160 | feats2 = []
161 | feats3 = []
162 | feats4 = []
163 | for _ in range(self.k):
164 | # 2-frame relations
165 | l2 = self.seq_len // 2
166 | id1 = random.randint(0, l2-1)
167 | id2 = random.randint(0, l2-1) + l2
168 | feat2 = torch.cat([frame_feats[:, id1], frame_feats[:, id2]], -1)
169 | feats2 += self.f2(feat2),
170 | # 3-frame relations
171 | l3 = self.seq_len // 3
172 | id1 = random.randint(0, l3-1)
173 | id2 = random.randint(l3, 2*l3-1)
174 | id3 = random.randint(2*l3, self.seq_len-1)
175 | feat3 = torch.cat([frame_feats[:, id1], frame_feats[:, id2], frame_feats[:, id3]], -1)
176 | feats3 += self.f3(feat3),
177 | # 4-frame relations
178 | l4 = self.seq_len // 4
179 | id1 = random.randint(0, l4-1)
180 | id2 = random.randint(l4, 2*l4-1)
181 | id3 = random.randint(2*l4, 3*l4-1)
182 | id4 = random.randint(3*l4, self.seq_len-1)
183 | feat4 = torch.cat([frame_feats[:, id1], frame_feats[:, id2], frame_feats[:, id3], frame_feats[:, id4]], -1)
184 | feats4 += self.f4(feat4),
185 | t2 = self.h2(torch.cat(feats2, -1))
186 |       t3 = self.h3(torch.cat(feats3, -1))  # was h2: wrong head for 3-frame relations
187 |       t4 = self.h4(torch.cat(feats4, -1))  # was h2: wrong head for 4-frame relations
188 | logits = t2 + t3 + t4
189 | logits = logits.view(B, 1, self.n_acts)
190 | return logits
191 |
192 | if self.predict:
193 | # o: (B, 1, cls_dim)
194 | o = frame_feats[:, -1:]
195 | pred_outs = []
196 | for pred_t in range(self.pred_seq_len):
197 | o, h = self.gru_pred(o, h)
198 | # if self.pos_mode != 'none':
199 | # o = self.append_pos(o, bbox[:, self.seq_len+pred_t:self.seq_len+pred_t+1])
200 | pred_outs += o,
201 |
202 | # pred_outs: (B, pred_seq_len, cls_dim)
203 | pred_outs = torch.cat(pred_outs, 1)
204 | # frame_feats: (B, seq_len + pred_seq_len, cls_dim)
205 | frame_feats = torch.cat([frame_feats, pred_outs], 1)
206 |
207 | if self.predict_k:
208 | # h: (n_gru_layers, B, cls_dim) --> (B, 1, cls_dim)
209 | h = h.transpose(0,1)[:, -1:]
210 | # pred_feats: (B, T, cls_dim)
211 | pred_feats = self.fc_pred(h)
212 | frame_feats = torch.cat([frame_feats, pred_feats], 1)
213 |
214 | # shape: (B, T, 2)
215 | logits = self.classifier(frame_feats)
216 | if self.use_act and not self.use_gt_act:
217 | return logits, act_logits
218 | # logits[:, :self.seq_len] = act_logits
219 |
220 | return logits
221 |
222 |
223 | def extract_feats(self, ped_crops, masks, idx):
224 | """
225 | ped_crops: (1, 3, 224, 224)
226 | masks: list of len 1; each item being a dict w/ key 1-4
227 | """
228 |     assert ped_crops.dim() == 4, 'extract_feats does not support batch mode. Data is obtained directly from __getitem__.'
229 | assert(len(masks) == 1)
230 | T = ped_crops.shape[0]
231 |
232 | # step = 30
233 |
234 | # ped_crops: (bt, 30, 3, 224, 224)
235 | ped_crops = ped_crops.view(-1, 3, 224, 224)
236 |
237 | # ped_feats: shape: (T, D (e.g.512), 1, 1)
238 | # ped_feats = self.ped_encoder(ped_crops.to(self.device)).cpu()
239 | ped_feats = self.ped_encoder(ped_crops.type(self.dtype)).cpu()
240 | ped_feats = ped_feats.view(T, -1)
241 |
242 | ctxt_feats = []
243 | ctxt_cls = []
244 | for t,mask in enumerate(masks):
245 | if len(mask) == 0:
246 | ctxt_feat = torch.zeros([1, self.conv_dim])
247 | ctxt_cls += [0]
248 | else:
249 | if type(mask) is dict:
250 | # grouped by classes
251 | ctxt_masks = []
252 | for k in sorted(mask):
253 | for each in mask[k]:
254 | ctxt_cls += k,
255 | ctxt_masks += each,
256 | ctxt_masks = torch.stack(ctxt_masks, 0)
257 | else:
258 | # class collapsed
259 | ctxt_masks = mask
260 | ctxt_cls += [0] * ctxt_masks.shape[0]
261 |
262 | if ctxt_masks.dim() == 3:
263 | n_total, h, w = ctxt_masks.shape
264 | ctxt_masks = ctxt_masks.unsqueeze(1).expand([n_total, 3, h, w])
265 | ctxt_masks = ctxt_masks.to(self.device)
266 |
267 | # ctxt_feats: (n_total, d)
268 | ctxt_feat = self.ctxt_encoder(ctxt_masks.type(self.dtype)).cpu()
269 | ctxt_feat = ctxt_feat.squeeze(-1).squeeze(-1)
270 | ctxt_feats += ctxt_feat,
271 | ctxt_feats = torch.cat(ctxt_feats, 0)
272 | assert(len(ctxt_feats) == len(ctxt_cls))
273 |
274 | # NOTE: set self.rand_test = 0 when extracting features since we want to cover all frames
275 | return ped_feats, ctxt_feats, ctxt_cls
276 |
277 |
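278 | # Hedged sketch (added; _demo_trn_indices is hypothetical): the TRN branch
279 | # above samples ordered frame indices for an n-frame relation by cutting
280 | # [0, seq_len) into n chunks and drawing one index per chunk, e.g. n == 3:
281 | def _demo_trn_indices(seq_len=30):
282 |   l3 = seq_len // 3
283 |   id1 = random.randint(0, l3 - 1)            # first third
284 |   id2 = random.randint(l3, 2 * l3 - 1)       # second third
285 |   id3 = random.randint(2 * l3, seq_len - 1)  # remainder
286 |   return id1, id2, id3                       # strictly increasing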
--------------------------------------------------------------------------------