--output drow.bag
121 | ```
122 |
123 | Use RViz to visualize the inference result.
124 | A simple RViz config is located at `dr_spaam_ros/example.rviz`.
125 |
126 | ## Inference time
127 | | | AP0.3 | AP0.5 | FPS (RTX 2080 laptop) | FPS (Jetson AGX) |
128 | |--------|------------------|------------------|-----------------------|------------------|
129 | |DROW | 0.638 | 0.659 | 95.8 | 24.8 |
130 | |DR-SPAAM| 0.707 | 0.723 | 87.3 | 22.6 |
131 |
132 | Note: In the original paper, we used a voting scheme for postprocessing.
133 | In the implementation here, we have replaced the voting with non-maximum suppression,
134 | where two detections that are less than 0.5 m apart are considered duplicates
135 | and the less confident one is suppressed.
136 | Consequently, the numbers reported here differ from those listed in the paper.
137 |
138 | ## Citation
139 | If you use DR-SPAAM in your project, please cite:
140 | ```BibTeX
141 | @inproceedings{Jia2020DRSPAAM,
142 | title = {{DR-SPAAM: A Spatial-Attention and Auto-regressive
143 | Model for Person Detection in 2D Range Data}},
144 | author = {Dan Jia and Alexander Hermans and Bastian Leibe},
145 | booktitle = {International Conference on Intelligent Robots and Systems (IROS)},
146 | year = {2020}
147 | }
148 | ```
149 |
--------------------------------------------------------------------------------
/dr_spaam/bin/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 |
5 | # import matplotlib
6 | # matplotlib.use('agg')
7 | import matplotlib.pyplot as plt
8 |
9 | from dr_spaam.detector import Detector
10 | import dr_spaam.utils.utils as u
11 |
12 |
def inference_time():
    """Benchmark per-scan inference latency of the released checkpoints.

    Runs DR-SPAAM, DROW and DROW-T5 on the first scans of one DROWv2 test
    sequence and prints, per model, the mean latency in milliseconds and the
    equivalent frames per second. The first runs are discarded as warm-up.
    """
    seq_name = './data/DROWv2-data/test/run_t_2015-11-26-11-55-45.bag.csv'
    # column 0 is the scan sequence number, column 1 the timestamp; the rest are ranges
    scans = np.genfromtxt(seq_name, delimiter=',')[:, 2:]

    use_gpu = True
    num_frames = 60  # scans fed to each model
    num_warmup = 10  # leading measurements excluded from the average
    model_names = ("DR-SPAAM", "DROW", "DROW-T5")
    ckpts = (
        "./ckpts/dr_spaam_e40.pth",
        "./ckpts/drow_e40.pth",
        "./ckpts/drow5_e40.pth",
    )
    for model_name, ckpt in zip(model_names, ckpts):
        detector = Detector(model_name=model_name, ckpt_file=ckpt, gpu=use_gpu, stride=1)
        detector.set_laser_spec(angle_inc=np.radians(0.5), num_pts=450)

        t_list = []
        for i in range(num_frames):
            # DROW-T5 is fed a window of 5 consecutive scans, the others a single scan
            s = scans[i:i + 5] if model_name == "DROW-T5" else scans[i]
            # perf_counter is monotonic and high-resolution; time.time() is neither
            t0 = time.perf_counter()
            detector(s)
            t_list.append(1e3 * (time.perf_counter() - t0))

        t = np.array(t_list[num_warmup:]).mean()
        print("inference time (model: %s, gpu: %s): %f ms (%.1f FPS)" % (
            model_name, use_gpu, t, 1e3 / t))
39 |
40 |
def play_sequence():
    """Play back a DROWv2 test sequence with DR-SPAAM detections overlaid.

    Every scan is drawn as blue points. Detections whose confidence is at least
    0.5 are drawn as red circles of radius 0.5. The plot is transformed with
    the odometry angle (column 4 of the matching ``.odom2`` file) of the first
    odometry sample at or after the scan's timestamp. Press any key in the
    figure window to stop playback.
    """
    # scans: column 1 is the timestamp, columns 2: are the range readings
    seq_name = './data/DROWv2-data/test/run_t_2015-11-26-11-22-03.bag.csv'
    scans_data = np.genfromtxt(seq_name, delimiter=',')
    scans_t = scans_data[:, 1]
    scans = scans_data[:, 2:]
    scan_phi = u.get_laser_phi()

    # odometry, used only for plotting ('...bag.csv' -> '...bag.odom2')
    odo_name = seq_name[:-3] + 'odom2'
    odos = np.genfromtxt(odo_name, delimiter=',')
    odos_t = odos[:, 1]
    odos_phi = odos[:, 4]

    # detector
    ckpt = './ckpts/dr_spaam_e40.pth'
    detector = Detector(model_name="DR-SPAAM", ckpt_file=ckpt, gpu=True, stride=1)
    detector.set_laser_spec(angle_inc=np.radians(0.5), num_pts=450)

    # scanner outline: points at radius 0.5 over all laser beam angles.
    # `np.float` was removed in NumPy 1.24; builtin `float` is the same dtype (float64).
    rad_tmp = 0.5 * np.ones(len(scan_phi), dtype=float)
    xy_scanner = np.stack(u.rphi_to_xy(rad_tmp, scan_phi), axis=1)

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    stop_requested = False

    def on_key(event):
        nonlocal stop_requested
        stop_requested = True
    fig.canvas.mpl_connect('key_press_event', on_key)

    cls_thresh = 0.5
    odo_idx = 0
    for i in range(len(scans)):
        plt.cla()
        ax.set_aspect('equal')
        ax.set_xlim(-15, 15)
        ax.set_ylim(-15, 15)
        ax.set_title('Press any key to exit.')
        ax.axis("off")

        # advance to the first odometry sample at or after this scan (or the last one)
        while odo_idx < len(odos_t) - 1 and odos_t[odo_idx] < scans_t[i]:
            odo_idx += 1
        odo_phi = odos_phi[odo_idx]
        odo_rot = np.array([[np.cos(odo_phi), np.sin(odo_phi)],
                            [-np.sin(odo_phi), np.cos(odo_phi)]], dtype=np.float32)

        # scanner outline and the two rays to its first and last beam
        xy_scanner_rot = np.matmul(xy_scanner, odo_rot.T)
        ax.plot(xy_scanner_rot[:, 0], xy_scanner_rot[:, 1], c='black')
        ax.plot((0, xy_scanner_rot[0, 0]), (0, xy_scanner_rot[0, 1]), c='black')
        ax.plot((0, xy_scanner_rot[-1, 0]), (0, xy_scanner_rot[-1, 1]), c='black')

        # scan points
        scan = scans[i]
        scan_x, scan_y = u.rphi_to_xy(scan, scan_phi + odo_phi)
        ax.scatter(scan_x, scan_y, s=1, c='blue')

        # inference and confident detections
        dets_xy, dets_cls, _ = detector(scan)
        dets_xy_rot = np.matmul(dets_xy, odo_rot.T)
        for xy, conf in zip(dets_xy_rot, dets_cls):
            if conf < cls_thresh:
                continue
            ax.add_artist(plt.Circle(xy, radius=0.5, color='r', fill=False, linewidth=2))

        # To export frames (e.g. for a video), call plt.savefig(...) here.
        plt.pause(0.001)

        if stop_requested:
            break
128 |
129 |
def play_sequence_with_tracking():
    """Play back part of a DROWv2 sequence with detections and tracklets.

    Loads only the scans whose sequence number lies in ``[seq0, seq1]``. Each
    scan is run through a DR-SPAAM detector with tracking enabled. Detections
    with confidence >= 0.3 are drawn as red circles, and tracklets with
    confidence >= 0.2 and more than one point are drawn as green lines. Press
    any key in the figure window to stop playback.
    """
    seq_name = './data/DROWv2-data/train/lunch_2015-11-26-12-04-23.bag.csv'
    seq0, seq1 = 109170, 109360  # inclusive range of scan sequence numbers
    scans, scans_t = [], []
    with open(seq_name) as f:
        for line in f:
            # each row: sequence number, timestamp, then the comma-separated ranges
            scan_seq, scan_t, scan = line.split(",", 2)
            scan_seq = int(scan_seq)
            if scan_seq < seq0:
                continue
            # check the upper bound *before* appending so no scan past seq1 is kept
            if scan_seq > seq1:
                break
            scans.append(np.fromstring(scan, sep=','))
            scans_t.append(float(scan_t))
    scans = np.stack(scans, axis=0)
    scans_t = np.array(scans_t)
    scan_phi = u.get_laser_phi()

    # odometry, used only for plotting ('...bag.csv' -> '...bag.odom2')
    odo_name = seq_name[:-3] + 'odom2'
    odos = np.genfromtxt(odo_name, delimiter=',')
    odos_t = odos[:, 1]
    odos_phi = odos[:, 4]

    # detector
    ckpt = './ckpts/dr_spaam_e40.pth'
    detector = Detector(model_name="DR-SPAAM", ckpt_file=ckpt, gpu=True, stride=1, tracking=True)
    detector.set_laser_spec(angle_inc=np.radians(0.5), num_pts=450)

    # scanner outline: points at radius 0.5 over all laser beam angles.
    # `np.float` was removed in NumPy 1.24; builtin `float` is the same dtype (float64).
    rad_tmp = 0.5 * np.ones(len(scan_phi), dtype=float)
    xy_scanner = np.stack(u.rphi_to_xy(rad_tmp, scan_phi), axis=1)

    fig = plt.figure(figsize=(6, 8))
    ax = fig.add_subplot(111)

    stop_requested = False

    def on_key(event):
        nonlocal stop_requested
        stop_requested = True
    fig.canvas.mpl_connect('key_press_event', on_key)

    det_thresh = 0.3    # minimum confidence for drawing a detection
    track_thresh = 0.2  # minimum confidence for drawing a tracklet
    odo_idx = 0
    for i in range(len(scans)):
        plt.cla()
        ax.set_aspect('equal')
        ax.set_xlim(-10, 5)
        ax.set_ylim(-5, 15)
        ax.set_title('Press any key to exit.')
        ax.axis("off")

        # advance to the first odometry sample at or after this scan (or the last one)
        while odo_idx < len(odos_t) - 1 and odos_t[odo_idx] < scans_t[i]:
            odo_idx += 1
        odo_phi = odos_phi[odo_idx]
        odo_rot = np.array([[np.cos(odo_phi), np.sin(odo_phi)],
                            [-np.sin(odo_phi), np.cos(odo_phi)]], dtype=np.float32)

        # scanner outline and the two rays to its first and last beam
        xy_scanner_rot = np.matmul(xy_scanner, odo_rot.T)
        ax.plot(xy_scanner_rot[:, 0], xy_scanner_rot[:, 1], c='black')
        ax.plot((0, xy_scanner_rot[0, 0]), (0, xy_scanner_rot[0, 1]), c='black')
        ax.plot((0, xy_scanner_rot[-1, 0]), (0, xy_scanner_rot[-1, 1]), c='black')

        # scan points
        scan = scans[i]
        scan_x, scan_y = u.rphi_to_xy(scan, scan_phi + odo_phi)
        ax.scatter(scan_x, scan_y, s=1, c='blue')

        # inference and confident detections
        dets_xy, dets_cls, _ = detector(scan)
        dets_xy_rot = np.matmul(dets_xy, odo_rot.T)
        for xy, conf in zip(dets_xy_rot, dets_cls):
            if conf < det_thresh:
                continue
            ax.add_artist(plt.Circle(xy, radius=0.5, color='r', fill=False, linewidth=2))

        # tracklets with at least two points
        tracks, tracks_cls = detector.get_tracklets()
        for track, track_conf in zip(tracks, tracks_cls):
            if track_conf >= track_thresh and len(track) > 1:
                track_rot = np.matmul(track, odo_rot.T)
                ax.plot(track_rot[:, 0], track_rot[:, 1], color='g', linewidth=2)

        # To export frames (e.g. for a video), call plt.savefig(...) here.
        plt.pause(0.001)

        if stop_requested:
            break
233 |
234 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="arg parser")
    # Each switch enables one demo; several may be combined.
    for flag in ("--time", "--dets", "--tracks"):
        parser.add_argument(flag, default=False, action='store_true')
    args = parser.parse_args()

    # Enabled demos run one after another in this fixed order.
    demos = (
        (args.time, inference_time),
        (args.dets, play_sequence),
        (args.tracks, play_sequence_with_tracking),
    )
    for enabled, run_demo in demos:
        if enabled:
            run_demo()
250 |
--------------------------------------------------------------------------------
/dr_spaam/bin/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import pickle
5 | import yaml
6 |
7 | import dr_spaam.utils.eval_utils as eu
8 | from dr_spaam.utils.dataset import create_test_dataloader
9 | from dr_spaam.utils.train_utils import load_checkpoint
10 |
11 |
def eval(model, cfg, epoch, split, it=0, writing=True, plotting=True,
         save_pkl=True, tb_log=None, scan_stride=1, pt_stride=1):
    """Evaluate ``model`` on one DROWv2 split via ``eu.eval_epoch_with_output``.

    Args:
        model: the network to evaluate.
        cfg: experiment config. ``cfg['name']`` selects the result directory
            ``./output/<name>`` and the evaluation tag. ``num_scans``,
            ``network``, ``cutout_kwargs``, ``polar_grid_kwargs`` and
            ``pedestrian_only`` configure the test dataloader, and
            ``vote_kwargs`` is forwarded to the evaluation.
        epoch, it: epoch / iteration values passed along with the results.
        split: dataset split to load, e.g. ``'val'`` or ``'test'``.
        writing, plotting, save_pkl, tb_log: output options forwarded to
            ``eu.eval_epoch_with_output``.
        scan_stride, pt_stride: strides forwarded to the test dataloader.
    """
    root_result_dir = os.path.join('./output', cfg['name'])

    loader_kwargs = dict(
        num_scans=cfg['num_scans'],
        network_type=cfg['network'],
        cutout_kwargs=cfg['cutout_kwargs'],
        polar_grid_kwargs=cfg['polar_grid_kwargs'],
        pedestrian_only=cfg['pedestrian_only'],
    )
    test_loader = create_test_dataloader(
        data_path="./data/DROWv2-data",
        split=split,
        scan_stride=scan_stride,
        pt_stride=pt_stride,
        **loader_kwargs)

    eu.eval_epoch_with_output(
        model,
        test_loader,
        epoch=epoch,
        it=it,
        vote_kwargs=cfg['vote_kwargs'],
        root_result_dir=root_result_dir,
        split=split,
        tag='eval_%s' % cfg['name'],
        writing=writing,
        plotting=plotting,
        save_pkl=save_pkl,
        tb_log=tb_log,
        full_eval=True,
    )
32 |
33 |
def eval_dir(cfgs_dir, split, epoch):
    """Evaluate the epoch-``epoch`` checkpoint of every config in ``cfgs_dir``.

    Configs (``*.yaml``) are processed in sorted order so runs are
    reproducible. For each one, ``./output/<name>/ckpts/ckpt_e<epoch>.pth`` is
    loaded and evaluated on ``split``. Configs without that checkpoint are
    reported and skipped.

    Args:
        cfgs_dir: directory containing the experiment configs.
        split: dataset split to evaluate, e.g. ``'val'`` or ``'test'``.
        epoch: epoch number used to build every checkpoint filename.
    """
    for cfg_file in sorted(glob.glob(os.path.join(cfgs_dir, '*.yaml'))):
        with open(cfg_file, 'r') as f:
            cfg = yaml.safe_load(f)
        cfg['name'] = os.path.basename(cfg_file).split(".")[0] + cfg['tag']

        ckpt = os.path.join('./output/', cfg['name'], 'ckpts', 'ckpt_e%s.pth' % epoch)
        if not os.path.isfile(ckpt):
            print("Could not load ckpt %s from config %s" % (ckpt, cfg['name']))
            continue

        print("Eval ckpt %s from config %s" % (ckpt, cfg["name"]))
        model = eu.cfg_to_model(cfg)
        model.cuda()

        # Store the loaded epoch separately: rebinding `epoch` here would change
        # the checkpoint filename looked up for every subsequent config.
        _, ckpt_epoch = load_checkpoint(model=model, filename=ckpt)
        eval(model, cfg, ckpt_epoch, split, writing=True, plotting=False, save_pkl=False)
53 |
54 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="arg parser")
    parser.add_argument("--cfg", type=str, required=False, default=None)
    parser.add_argument("--ckpt", type=str, required=False, default=None)
    parser.add_argument("--pkl", type=str, required=False, default=None)
    parser.add_argument("--val", default=False, action='store_true')
    parser.add_argument("--dir", type=str, required=False, default=None)
    parser.add_argument("--epoch", type=int, required=False, default=40)
    parser.add_argument("--pt_stride", type=int, required=False, default=1)
    parser.add_argument("--scan_stride", type=int, required=False, default=1)
    parser.add_argument("--tag", type=str, required=False, default="")
    args = parser.parse_args()

    split = 'val' if args.val else 'test'

    # load existing results, only plotting
    if args.pkl is not None:
        # NOTE: pickle.load can execute arbitrary code; only open result files you trust.
        with open(args.pkl, 'rb') as f:
            _, eval_rpt = pickle.load(f)

        # Strip only the extension. `str.split('.')` would also cut at the leading
        # '.' of relative paths such as './output/run.pkl', leaving an empty title.
        pkl_stem = os.path.splitext(args.pkl)[0]
        for k, v in eval_rpt.items():
            plot_title = pkl_stem + ('_t%s' % k)
            eu.plot_eval_result(v, plot_title=plot_title,
                                output_file=plot_title + '.png')

    # eval every config in a directory
    elif args.dir is not None:
        eval_dir(args.dir, split, args.epoch)

    # eval single config
    elif args.cfg is not None:
        with open(args.cfg, 'r') as f:
            cfg = yaml.safe_load(f)
        cfg['name'] = os.path.basename(args.cfg).split(".")[0] + cfg['tag']

        model = eu.cfg_to_model(cfg)
        model.cuda()

        # an explicit --ckpt wins over the default per-config checkpoint location
        if args.ckpt is not None:
            ckpt = args.ckpt
        else:
            ckpt = os.path.join('./output/', cfg['name'], 'ckpts', 'ckpt_e%s.pth' % args.epoch)

        _, epoch = load_checkpoint(model=model, filename=ckpt)

        # the tag only renames the output directory; the checkpoint path is already resolved
        if args.tag:
            cfg['name'] = cfg['name'] + "_" + args.tag

        eval(model, cfg, epoch, split, scan_stride=args.scan_stride, pt_stride=args.pt_stride)

    else:
        # none of --pkl / --dir / --cfg given: nothing to do, show usage
        parser.print_help()
108 |
--------------------------------------------------------------------------------
/dr_spaam/bin/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from shutil import copyfile
4 |
5 | import yaml
6 |
7 | import torch
8 | from torch import optim
9 |
10 | from dr_spaam.utils.dataset import create_dataloader
11 | from dr_spaam.utils.logger import create_logger, create_tb_logger
12 | from dr_spaam.utils.train_utils import Trainer, LucasScheduler, load_checkpoint
13 | from dr_spaam.utils.eval_utils import model_fn, eval_epoch_with_output, cfg_to_model
14 |
15 | from eval import eval
16 |
17 |
# Let cuDNN benchmark the available kernels and select the fastest implementation of ops.
torch.backends.cudnn.benchmark = True

# Command-line interface: --cfg (required) is the experiment YAML file; --ckpt
# optionally names a checkpoint whose model/optimizer state training starts from.
parser = argparse.ArgumentParser(description="arg parser")
parser.add_argument(
    "--cfg",
    type=str,
    required=True,
    help="configuration of the experiment",
)
parser.add_argument(
    "--ckpt",
    type=str,
    required=False,
    default=None,
)
args = parser.parse_args()

# Run name = config file base name (up to the first '.') + the config's `tag`;
# it names the output directory of the run.
with open(args.cfg, 'r') as f:
    cfg = yaml.safe_load(f)
cfg['name'] = os.path.basename(args.cfg).split(".")[0] + cfg['tag']
28 |
29 |
if __name__ == '__main__':
    root_result_dir = os.path.join('./', 'output', cfg['name'])
    os.makedirs(root_result_dir, exist_ok=True)
    # keep a copy of the config next to the results
    copyfile(args.cfg, os.path.join(root_result_dir, os.path.basename(args.cfg)))

    ckpt_dir = os.path.join(root_result_dir, 'ckpts')
    os.makedirs(ckpt_dir, exist_ok=True)

    logger, tb_logger = create_logger(root_result_dir), create_tb_logger(root_result_dir)
    logger.info('**********************Start logging**********************')

    # log to file
    gpu_list = os.environ.get('CUDA_VISIBLE_DEVICES', 'ALL')
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    # create dataloader & network & optimizer
    train_loader, eval_loader = create_dataloader(data_path="./data/DROWv2-data",
                                                  num_scans=cfg['num_scans'],
                                                  batch_size=cfg['batch_size'],
                                                  num_workers=cfg['num_workers'],
                                                  network_type=cfg['network'],
                                                  train_with_val=cfg['train_with_val'],
                                                  use_data_augumentation=cfg['use_data_augumentation'],
                                                  cutout_kwargs=cfg['cutout_kwargs'],
                                                  polar_grid_kwargs=cfg['polar_grid_kwargs'],
                                                  pedestrian_only=cfg['pedestrian_only'])

    model = cfg_to_model(cfg)
    model.cuda()

    optimizer = optim.Adam(model.parameters(), amsgrad=True)

    # Epoch bounds of the LR schedule, optionally overridden by cfg['lr_kwargs'].
    # The fallback (0, cfg['epochs']) is the schedule used without an override.
    # Presumably LucasScheduler(optimizer, e0, v0, e1, v1) moves the LR from
    # 1e-3 at epoch e0 to 1e-6 at epoch e1 -- see dr_spaam.utils.train_utils.
    if 'lr_kwargs' in cfg:
        e0, e1 = cfg['lr_kwargs']['e0'], cfg['lr_kwargs']['e1']
    else:
        e0, e1 = 0, cfg['epochs']
    lr_scheduler = LucasScheduler(optimizer, e0, 1e-3, e1, 1e-6)

    # resume from an explicit checkpoint, else from a checkpoint left by SIGTERM, else start fresh
    sigterm_ckpt = os.path.join(ckpt_dir, 'sigterm_ckpt.pth')
    if args.ckpt is not None:
        starting_iteration, starting_epoch = load_checkpoint(
            model=model, optimizer=optimizer, filename=args.ckpt, logger=logger)
    elif os.path.isfile(sigterm_ckpt):
        starting_iteration, starting_epoch = load_checkpoint(
            model=model, optimizer=optimizer, filename=sigterm_ckpt, logger=logger)
    else:
        starting_iteration, starting_epoch = 0, 0

    # start training
    logger.info('**********************Start training**********************')

    def model_fn_eval(m, d, e, i):
        """Periodic validation run invoked by the Trainer."""
        return eval_epoch_with_output(
            model=m, test_loader=d, epoch=e, it=i, root_result_dir=root_result_dir,
            tag=cfg['name'], split='val', writing=True, plotting=True, save_pkl=True,
            tb_log=tb_logger, vote_kwargs=cfg['vote_kwargs'], full_eval=False)

    trainer = Trainer(
        model,
        model_fn,
        optimizer,
        ckpt_dir=ckpt_dir,
        lr_scheduler=lr_scheduler,
        model_fn_eval=model_fn_eval,
        tb_log=tb_logger,
        grad_norm_clip=cfg['grad_norm_clip'],
        logger=logger)

    # evaluate ~20 times and checkpoint ~10 times over the run (at least every epoch)
    trainer.train(num_epochs=cfg['epochs'],
                  train_loader=train_loader,
                  eval_loader=eval_loader,
                  eval_frequency=max(int(cfg['epochs'] / 20), 1),
                  ckpt_save_interval=max(int(cfg['epochs'] / 10), 1),
                  lr_scheduler_each_iter=True,
                  starting_iteration=starting_iteration,
                  starting_epoch=starting_epoch)

    # testing
    logger.info('**********************Start testing (val)**********************')
    eval(model, cfg, epoch=trainer._epoch + 1, split='val', it=trainer._it,
         writing=True, plotting=True, save_pkl=True, tb_log=tb_logger)

    logger.info('**********************Start testing (test)**********************')
    eval(model, cfg, epoch=trainer._epoch + 1, split='test', it=trainer._it,
         writing=True, plotting=True, save_pkl=True, tb_log=tb_logger)

    tb_logger.close()
    logger.info('**********************End**********************')
117 |
--------------------------------------------------------------------------------
/dr_spaam/cfgs/dr_spaam.yaml:
--------------------------------------------------------------------------------
1 | tag: ""
2 | epochs: 40
3 | batch_size: 8
4 | grad_norm_clip: -1.0
5 | num_workers: 8
6 | num_scans: 10
7 | use_data_augumentation: False
8 | train_with_val: False
9 | use_polar_grid: False
10 | focal_loss_gamma: 0.0
11 | pedestrian_only: True
12 |
13 | # Network type: "cutout" or "cutout_spatial"
14 | network: "cutout_spatial"
15 |
16 | similarity_kwargs:
17 | alpha: 0.5
18 | window_size: 11
19 |
20 | cutout_kwargs:
21 | fixed: True
22 | centered: True
23 | window_width: 1.0
24 | window_depth: 0.5
25 | num_cutout_pts: 56
26 | padding_val: 29.99
27 | area_mode: True
28 |
29 | polar_grid_kwargs:
30 | min_range: 0.0
31 | max_range: 30.0
32 | range_bin_size: 0.1
33 | tsdf_clip: 1.0
34 | normalize: True
35 |
36 | # from hyperopt (no longer used)
37 | vote_kwargs:
38 | bin_size: 0.10048541940486004
39 | blur_sigma: 1.459561417325547
40 | min_thresh: 9.447764939669593e-05
41 | vote_collect_radius: 0.15719563974052672
42 |
--------------------------------------------------------------------------------
/dr_spaam/cfgs/drow.yaml:
--------------------------------------------------------------------------------
1 | tag: ""
2 | epochs: 40
3 | batch_size: 8
4 | grad_norm_clip: -1.0
5 | num_workers: 8
6 | num_scans: 1
7 | use_data_augumentation: False
8 | train_with_val: False
9 | use_polar_grid: False
10 | focal_loss_gamma: 0.0
11 | pedestrian_only: True
12 |
13 | # Network type: "cutout" or "cutout_spatial"
14 | network: "cutout"
15 |
16 | cutout_kwargs:
17 | fixed: False
18 | centered: True
19 | window_width: 1.0
20 | window_depth: 0.5
21 | num_cutout_pts: 56
22 | padding_val: 29.99
23 | area_mode: True
24 |
25 | polar_grid_kwargs:
26 | min_range: 0.0
27 | max_range: 30.0
28 | range_bin_size: 0.1
29 | tsdf_clip: 1.0
30 | normalize: True
31 |
32 | # from hyperopt (no longer used)
33 | vote_kwargs:
34 | bin_size: 0.11691041834028301
35 | blur_sigma: 0.7801193226779289
36 | min_thresh: 0.0013299798109178708
37 | vote_collect_radius: 0.1560556348793659
38 |
--------------------------------------------------------------------------------
/dr_spaam/cfgs/drow5.yaml:
--------------------------------------------------------------------------------
1 | tag: ""
2 | epochs: 40
3 | batch_size: 8
4 | grad_norm_clip: -1.0
5 | num_workers: 8
6 | num_scans: 5
7 | use_data_augumentation: False
8 | train_with_val: False
9 | use_polar_grid: False
10 | focal_loss_gamma: 0.0
11 | pedestrian_only: True
12 |
13 | # Network type: "cutout" or "cutout_spatial"
14 | network: "cutout"
15 |
16 | cutout_kwargs:
17 | fixed: False
18 | centered: True
19 | window_width: 1.0
20 | window_depth: 0.5
21 | num_cutout_pts: 56
22 | padding_val: 29.99
23 | area_mode: True
24 |
25 | polar_grid_kwargs:
26 | min_range: 0.0
27 | max_range: 30.0
28 | range_bin_size: 0.1
29 | tsdf_clip: 1.0
30 | normalize: True
31 |
32 | # from hyperopt (no longer used)
33 | vote_kwargs:
34 | bin_size: 0.10041661299422858
35 | blur_sigma: 1.3105587107688101
36 | min_thresh: 1.0228621127903203e-05
37 | vote_collect_radius: 0.15356209212109417
38 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/generate_inference_result.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from tqdm import tqdm
3 | import yaml
4 | import numpy as np
5 | import torch
6 |
7 | import os, sys
8 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../'))
9 |
10 | import utils.utils as u
11 | import utils.eval_utils as eu
12 | from utils.dataset import create_test_dataloader
13 |
14 |
if __name__ == '__main__':
    cfg_file = './cfgs/NCT_cfgs/STEP_bl_5.yaml'
    with open(cfg_file, 'r') as f:
        cfg = yaml.safe_load(f)
    cfg['name'] = os.path.basename(cfg_file).split(".")[0] + cfg['tag']

    # network restored from this config's epoch-40 checkpoint
    model = eu.cfg_to_model(cfg)
    model.cuda()

    ckpt_file = './output/%s/ckpts/ckpt_e40.pth' % cfg['name']
    ckpt = torch.load(ckpt_file)
    model.load_state_dict(ckpt['model_state'])

    # NOTE(review): cfg/ckpt/pkl paths are relative to './' while the data path
    # uses '../data' -- confirm the intended working directory.
    test_loader = create_test_dataloader(data_path="../data/DROWv2-data",
                                         num_scans=cfg['num_scans'],
                                         network_type=cfg['network'],
                                         cutout_kwargs=cfg['cutout_kwargs'],
                                         polar_grid_kwargs=cfg['polar_grid_kwargs'],
                                         pedestrian_only=cfg['pedestrian_only'],
                                         split='val',
                                         scan_stride=1,
                                         pt_stride=1)

    scan_list, pred_cls_list, pred_reg_list, gts_xy_list, gts_inds_list = [], [], [], [], []

    # Pure inference: switch to eval mode once and disable autograd for the whole pass.
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(tqdm(test_loader)):
            # `inputs`, not `input`, to avoid shadowing the builtin
            inputs = torch.from_numpy(data['input']).cuda(non_blocking=True).float()
            model_rtn = model(inputs)

            # some networks return a third output, which is not needed here
            if len(model_rtn) == 3:
                pred_cls, pred_reg, _ = model_rtn
            else:
                pred_cls, pred_reg = model_rtn

            # keep the first batch element, as numpy (classification passed through a sigmoid)
            pred_cls_list.append(torch.sigmoid(pred_cls[0]).data.cpu().numpy())
            pred_reg_list.append(pred_reg[0].data.cpu().numpy())

            # ground truth arrives in polar form; store x/y plus the frame index it belongs to
            for gt in data['dets_wp'][0]:
                xy = u.rphi_to_xy(gt[0], gt[1])
                gts_xy_list.append(np.array(xy))
                gts_inds_list.append(i)

            # last entry of the first batch element's scans (presumably the current scan)
            scan_list.append(data['scans'][0][-1])

    scans = np.stack(scan_list, axis=0)
    pred_cls = np.stack(pred_cls_list, axis=0)
    pred_reg = np.stack(pred_reg_list, axis=0)
    gts_xy = np.stack(gts_xy_list, axis=0)
    gts_inds = np.array(gts_inds_list)

    pkl_file = './hyperopt/inference_result_%s.pkl' % cfg['name']
    with open(pkl_file, 'wb') as f:
        pickle.dump([scans, pred_cls, pred_reg, gts_xy, gts_inds], f)
73 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/objective_functions.py:
--------------------------------------------------------------------------------
1 | import json
2 | import hyperopt as hp
3 | import numpy as np
4 |
5 | import os, sys
6 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '../'))
7 |
8 | import utils.utils as u
9 | import utils.eval_utils as eu
10 |
11 |
def objective(vote_kwargs, scans, pred_cls, pred_reg, gts_xy, gts_inds):
    """Hyperopt objective: score one set of voting parameters by AP.

    Runs ``u.group_predicted_center`` with ``vote_kwargs`` on every frame of the
    precomputed network outputs and pools all detections. It then computes
    precision/recall against the ground truth with an association radius of 0.5
    and returns the negated average precision as the loss to minimize.

    Args:
        vote_kwargs: keyword arguments forwarded to ``u.group_predicted_center``.
        scans, pred_cls, pred_reg: per-frame scans and network outputs,
            iterated in lockstep.
        gts_xy: ground-truth positions, one row per annotation.
        gts_inds: frame index of each row of ``gts_xy``.

    Returns:
        A hyperopt result dict with ``loss = -AP``, ``status = STATUS_OK`` and
        the JSON-encoded parameters and AP as attachments.
    """
    scan_phi = u.get_laser_phi()

    # pool detections over all frames, remembering the frame each came from
    dets_xy_list, dets_cls_list, dets_inds_list = [], [], []
    for i, (scan, p_cls, p_reg) in enumerate(zip(scans, pred_cls, pred_reg)):
        dets_xy, dets_cls, _ = u.group_predicted_center(
            scan, scan_phi, p_cls, p_reg, **vote_kwargs)

        for xy, c in zip(dets_xy, dets_cls):
            dets_xy_list.append(xy)
            dets_cls_list.append(c)
            dets_inds_list.append(i)

    if dets_cls_list:
        dets_xy = np.array(dets_xy_list)
        dets_cls = np.array(dets_cls_list)
        dets_inds = np.array(dets_inds_list)

        # compute precision recall
        eval_radius = 0.5
        rpt_tuple = eu.compute_prec_rec(dets_xy, dets_cls[:, 0], dets_inds,
                                        gts_xy, gts_inds, eval_radius)
        ap, _, _ = eu.eval_prec_rec(*rpt_tuple[:2])
        # json.dumps rejects numpy.float32; store a plain Python float
        ap = float(ap)
    else:
        # Parameters that suppress every detection would crash on `dets_cls[:, 0]`
        # (empty 1-D array); score them as AP 0, the worst possible loss.
        ap = 0.0

    # objective, maximize AP_0.5 for pedestrian class
    rtn_dict = {'loss': -ap,
                'status': hp.STATUS_OK,
                'real_attachments': {'kw': json.dumps(vote_kwargs).encode('utf-8'),
                                     'auc': json.dumps(ap).encode('utf-8')}}

    return rtn_dict
42 |
43 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/run_hyperopt.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "toc": "true"
7 | },
8 | "source": [
9 | "# Table of Contents\n",
10 | " "
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "metadata": {
17 | "ExecuteTime": {
18 | "end_time": "2018-02-22T22:19:27.721877Z",
19 | "start_time": "2018-02-22T22:19:26.989615Z"
20 | }
21 | },
22 | "outputs": [],
23 | "source": [
24 | "%matplotlib inline\n",
25 | "%config InlineBackend.figure_format = 'retina'\n",
26 | "\n",
27 | "# Font which got unicode math stuff.\n",
28 | "import matplotlib as mpl\n",
29 | "mpl.rcParams['font.family'] = 'DejaVu Sans'\n",
30 | "\n",
31 | "# Much more readable plots\n",
32 | "import matplotlib.pyplot as plt\n",
33 | "plt.style.use('ggplot')\n",
34 | "\n",
35 | "# Much better than plt.subplots() \n",
36 | "from mpl_toolkits.axes_grid1 import ImageGrid"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 2,
42 | "metadata": {
43 | "ExecuteTime": {
44 | "end_time": "2018-02-22T22:19:28.245175Z",
45 | "start_time": "2018-02-22T22:19:28.242267Z"
46 | }
47 | },
48 | "outputs": [],
49 | "source": [
50 | "import functools\n",
51 | "import os\n",
52 | "import json\n",
53 | "import math\n",
54 | "import numpy as np\n",
55 | "import pickle\n",
56 | "\n",
57 | "import hyperopt as hp\n",
58 | "import hyperopt.mongoexp\n",
59 | "\n",
60 | "import objective_functions as ofn"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 3,
66 | "metadata": {
67 | "ExecuteTime": {
68 | "end_time": "2018-02-22T22:19:29.021451Z",
69 | "start_time": "2018-02-22T22:19:29.010855Z"
70 | }
71 | },
72 | "outputs": [],
73 | "source": [
74 | "votes_to_detections_space = {\n",
75 | " 'bin_size': hyperopt.hp.uniform('bin_size', 0.1, 1.0),\n",
76 | " 'vote_collect_radius': hyperopt.hp.uniform('vote_collect_radius', 0.01, 2.0),\n",
77 | " 'min_thresh': hyperopt.hp.loguniform('min_thresh', -7*np.log(10), -2*np.log(10)),\n",
78 | " 'blur_sigma': hyperopt.hp.uniform('blur_sigma', 0.0, 5.0),\n",
79 | "}"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": 4,
85 | "metadata": {
86 | "ExecuteTime": {
87 | "end_time": "2018-02-22T22:19:30.558461Z",
88 | "start_time": "2018-02-22T22:19:30.436275Z"
89 | }
90 | },
91 | "outputs": [],
92 | "source": [
93 | "run_name = 'DR-SPAAM_11_5_new'\n",
94 | "mongodb_port = 27012\n",
95 | "trials = hp.mongoexp.MongoTrials('mongo://localhost:{}/hyperopt/jobs'.format(mongodb_port), exp_key=run_name)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": 5,
101 | "metadata": {
102 | "ExecuteTime": {
103 | "end_time": "2018-02-22T22:19:29.003364Z",
104 | "start_time": "2018-02-22T22:19:28.690003Z"
105 | }
106 | },
107 | "outputs": [],
108 | "source": [
109 | "inference_result = '/home/jia/v3/hyperopt/inference_result_new.pkl'\n",
110 | "with open(inference_result, 'rb') as f:\n",
111 | " scans, pred_cls, pred_reg, gts_xy, gts_inds = pickle.load(f)"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 7,
117 | "metadata": {
118 | "ExecuteTime": {
119 | "start_time": "2018-02-23T12:33:52.721Z"
120 | }
121 | },
122 | "outputs": [
123 | {
124 | "name": "stderr",
125 | "output_type": "stream",
126 | "text": [
127 | "over-writing old domain trials attachment\n"
128 | ]
129 | },
130 | {
131 | "name": "stdout",
132 | "output_type": "stream",
133 | "text": [
134 | " 52%|█████▏ | 15491/30000 [17:41:52<16:34:33, 4.11s/trial, best loss: -0.5389506816864014]\n"
135 | ]
136 | },
137 | {
138 | "ename": "KeyboardInterrupt",
139 | "evalue": "",
140 | "output_type": "error",
141 | "traceback": [
142 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
143 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
144 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mtrials\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrials\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0malgo\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mhp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msuggest\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m max_evals=30000)\n\u001b[0m",
145 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/hyperopt/fmin.py\u001b[0m in \u001b[0;36mfmin\u001b[0;34m(fn, space, algo, max_evals, timeout, loss_threshold, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)\u001b[0m\n\u001b[1;32m 480\u001b[0m \u001b[0mcatch_eval_exceptions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcatch_eval_exceptions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 481\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_argmin\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 482\u001b[0;31m \u001b[0mshow_progressbar\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshow_progressbar\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 483\u001b[0m )\n\u001b[1;32m 484\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
146 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/hyperopt/base.py\u001b[0m in \u001b[0;36mfmin\u001b[0;34m(self, fn, space, algo, max_evals, timeout, loss_threshold, max_queue_len, rstate, verbose, pass_expr_memo_ctrl, catch_eval_exceptions, return_argmin, show_progressbar)\u001b[0m\n\u001b[1;32m 684\u001b[0m \u001b[0mcatch_eval_exceptions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcatch_eval_exceptions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 685\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_argmin\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 686\u001b[0;31m \u001b[0mshow_progressbar\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mshow_progressbar\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 687\u001b[0m )\n\u001b[1;32m 688\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
147 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/hyperopt/fmin.py\u001b[0m in \u001b[0;36mfmin\u001b[0;34m(fn, space, algo, max_evals, timeout, loss_threshold, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)\u001b[0m\n\u001b[1;32m 507\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 508\u001b[0m \u001b[0;31m# next line is where the fmin is actually executed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 509\u001b[0;31m \u001b[0mrval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexhaust\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 510\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 511\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_argmin\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
148 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/hyperopt/fmin.py\u001b[0m in \u001b[0;36mexhaust\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mexhaust\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 329\u001b[0m \u001b[0mn_done\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 330\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_evals\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mn_done\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mblock_until_done\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masynchronous\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 331\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrefresh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
149 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/hyperopt/fmin.py\u001b[0m in \u001b[0;36mrun\u001b[0;34m(self, N, block_until_done)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_trials\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 270\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert_trial_docs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_trials\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 271\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrials\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrefresh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 272\u001b[0m \u001b[0mn_queued\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_trials\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0mqlen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_queue_len\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
150 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/hyperopt/mongoexp.py\u001b[0m in \u001b[0;36mrefresh\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 845\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 846\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrefresh\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 847\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrefresh_tids\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 848\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 849\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_insert_trial_docs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdocs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
151 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/hyperopt/mongoexp.py\u001b[0m in \u001b[0;36mrefresh_tids\u001b[0;34m(self, tids)\u001b[0m\n\u001b[1;32m 764\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 765\u001b[0m \u001b[0;31m# which records are in db but not in existing, and vice versa\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 766\u001b[0;31m \u001b[0mdb_in_existing\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfast_isin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdb_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"_id\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexisting_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"_id\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 767\u001b[0m \u001b[0mexisting_in_db\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfast_isin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexisting_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"_id\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdb_data\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"_id\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 768\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
152 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/hyperopt/utils.py\u001b[0m in \u001b[0;36mfast_isin\u001b[0;34m(X, Y)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mY\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0mT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mY\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msort\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0mD\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msearchsorted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0mT\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
153 | "\u001b[0;32m~/anaconda3/envs/drow/lib/python3.6/site-packages/bson/objectid.py\u001b[0m in \u001b[0;36m__lt__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__lt__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mObjectId\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 279\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__id\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbinary\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 280\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mNotImplemented\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
154 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
155 | ]
156 | }
157 | ],
158 | "source": [
159 | "votes_to_detections_func = functools.partial(ofn.objective,\n",
160 | " scans=scans, \n",
161 | " pred_cls=pred_cls, \n",
162 | " pred_reg=pred_reg,\n",
163 | " gts_xy=gts_xy, \n",
164 | " gts_inds=gts_inds)\n",
165 | "\n",
166 | " \n",
167 | "best = hp.fmin(votes_to_detections_func, \n",
168 | " space=votes_to_detections_space, \n",
169 | " trials=trials,\n",
170 | " algo=hp.tpe.suggest,\n",
171 | " max_evals=30000)"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": 9,
177 | "metadata": {},
178 | "outputs": [
179 | {
180 | "data": {
181 | "text/plain": [
182 | "[ObjectId('5e57d416d7c82d8659e556c6'),\n",
183 | " 2,\n",
184 | " 182811,\n",
185 | " None,\n",
186 | " SON([('loss', -0.5389506816864014), ('status', 'ok'), ('real_attachments', SON([('kw', b'{\"bin_size\": 0.10048541940486004, \"blur_sigma\": 1.459561417325547, \"min_thresh\": 9.447764939669593e-05, \"vote_collect_radius\": 0.15719563974052672}'), ('auc', b'0.5389506816864014')]))]),\n",
187 | " SON([('tid', 182811), ('cmd', ['domain_attachment', 'FMinIter_Domain']), ('workdir', None), ('idxs', SON([('bin_size', [182811]), ('blur_sigma', [182811]), ('min_thresh', [182811]), ('vote_collect_radius', [182811])])), ('vals', SON([('bin_size', [0.10048541940486004]), ('blur_sigma', [1.459561417325547]), ('min_thresh', [9.447764939669593e-05]), ('vote_collect_radius', [0.15719563974052672])]))]),\n",
188 | " 'DR-SPAAM_11_5_new',\n",
189 | " ['ncm0239.hpc.itc.rwth-aachen.de:59526'],\n",
190 | " 3,\n",
191 | " datetime.datetime(2020, 2, 27, 14, 37, 11, 293000),\n",
192 | " datetime.datetime(2020, 2, 27, 14, 37, 30, 169000)]"
193 | ]
194 | },
195 | "execution_count": 9,
196 | "metadata": {},
197 | "output_type": "execute_result"
198 | }
199 | ],
200 | "source": [
201 | "trials.best_trial.values()"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": 23,
207 | "metadata": {},
208 | "outputs": [
209 | {
210 | "data": {
211 | "text/plain": [
212 | "{'loss': -0.529565155506134,\n",
213 | " 'status': 'ok',\n",
214 | " 'real_attachments': {'kw': b'{\"bin_size\": 0.1315193551894875, \"blur_sigma\": 0.8769606532708437, \"min_thresh\": 9.436743980879768e-05, \"vote_collect_radius\": 0.16901292463464487}',\n",
215 | " 'auc': b'0.529565155506134'}}"
216 | ]
217 | },
218 | "execution_count": 23,
219 | "metadata": {},
220 | "output_type": "execute_result"
221 | }
222 | ],
223 | "source": [
224 | "votes_to_detections_func = functools.partial(ofn.objective,\n",
225 | " scans=scans, \n",
226 | " pred_cls=pred_cls, \n",
227 | " pred_reg=pred_reg,\n",
228 | " gts_xy=gts_xy, \n",
229 | " gts_inds=gts_inds)\n",
230 | "\n",
231 | "votes_to_detections_func({'bin_size': 0.1315193551894875,\n",
232 | " 'blur_sigma': 0.8769606532708437,\n",
233 | " 'min_thresh': 9.436743980879768e-5,\n",
234 | " 'vote_collect_radius': 0.16901292463464487})"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "metadata": {
241 | "ExecuteTime": {
242 | "end_time": "2018-02-23T08:15:45.959270Z",
243 | "start_time": "2018-02-23T08:15:40.602968Z"
244 | },
245 | "scrolled": false
246 | },
247 | "outputs": [],
248 | "source": [
249 | "hp_values = [(t['result']['loss'], t['misc']['vals']) for t in trials.trials if 'loss' in t['result']]\n",
250 | "scores = np.asarray([-t['result']['loss'] for t in trials.trials if 'loss' in t['result']])\n",
251 | "keys = hp_values[0][1].keys()\n",
252 | "val_count = len(keys)\n",
253 | "\n",
254 | "min_score = np.min(np.exp(scores))\n",
255 | "max_score = np.max(np.exp(scores))\n",
256 | "norm = mpl.colors.Normalize(min_score, max_score)\n",
257 | "\n",
258 | "\n",
259 | "fig, ax = plt.subplots(val_count,1,figsize=(18,5*val_count))\n",
260 | "\n",
261 | "for a, k in zip(ax, sorted(keys)):\n",
262 | " hp_vals = np.asarray([h[1][k][0] for h in hp_values])\n",
263 | " if k == 'min_thresh':\n",
264 | " hp_vals = np.log10(hp_vals)\n",
265 | " N, bins, patches = a.hist(hp_vals, bins=100)\n",
266 | " scores_sorted = scores[np.argsort(hp_vals)]\n",
267 | " \n",
268 | " start_idx = 0\n",
269 | " bin_scores = []\n",
270 | " for n in N:\n",
271 | " bin_scores.append(np.mean(scores_sorted[start_idx:start_idx+int(n)]))\n",
272 | " start_idx +=int(n)\n",
273 | " \n",
274 | " for b, thispatch in zip(bin_scores, patches):\n",
275 | " color = plt.cm.viridis(norm(np.exp(b)))\n",
276 | " thispatch.set_facecolor(color)\n",
277 | " a.set_title(k)"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": null,
283 | "metadata": {},
284 | "outputs": [],
285 | "source": []
286 | }
287 | ],
288 | "metadata": {
289 | "kernelspec": {
290 | "display_name": "Python 3",
291 | "language": "python",
292 | "name": "python3"
293 | },
294 | "language_info": {
295 | "codemirror_mode": {
296 | "name": "ipython",
297 | "version": 3
298 | },
299 | "file_extension": ".py",
300 | "mimetype": "text/x-python",
301 | "name": "python",
302 | "nbconvert_exporter": "python",
303 | "pygments_lexer": "ipython3",
304 | "version": "3.6.9"
305 | },
306 | "nav_menu": {},
307 | "toc": {
308 | "navigate_menu": true,
309 | "number_sections": true,
310 | "sideBar": true,
311 | "threshold": 6,
312 | "toc_cell": true,
313 | "toc_section_display": "block",
314 | "toc_window_display": false
315 | }
316 | },
317 | "nbformat": 4,
318 | "nbformat_minor": 1
319 | }
320 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/scripts/hyperopt_master_tmux.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Orchestrate a distributed hyperopt run: create a detached tmux session with
# one window per worker machine, ssh into every machine from its window and
# start hyperopt_slave.bash there with the configured worker count, then
# attach to the session.
tmux -2 new-session -d -s hyperopt_master

# One entry per machine, formatted "<hostname>:<number of workers>".
# Comment a line out to leave that machine out of the run.
machines=(
"Einhorn:12"
# "Grimbergen:10"
# "Bush:3"
# "Carolus:2"
# "Fix:2"
# "Hund:2"
# "Kriek:2"
# "Schlunz:2"
# "Tsingtao:4"
# "Veltins:2"
# "Zhiguli:2"
# "Astra:1"
# "Faxe:2"
# "Grolsch:2"
# "Hoppiness:6"
# "Kilkenny:4"
# "Lasko:4"
# "Mickey:6"
# "Paulaner:2"
# "Bevog:2"
# "Borsodi:2" <-- dies with hp workers.
# "Duff:2"
# "Duvel:3"
# "Helios:3"
# "Tyskie:2"
# "Reissdorf:8"
# "Becks:2"
# "Corona:4"
# "Kingfisher:4"
# "Stella:2"
"Chimay:10"
# "Rothaus:2"
)

# Create one window per machine, named after the host.
for m in "${machines[@]}"; do
    machine=${m%:*}
    count=${m#*:}

    echo "$machine:$count"
    tmux rename-window "$machine"
    tmux new-window
done

# The last loop iteration leaves one extra, unnamed window behind. It is the
# current window, so kill it.
tmux kill-window
sleep 1

# SSH to the actual machine and run the jobs.
for m in "${machines[@]}"; do
    machine=${m%:*}
    count=${m#*:}
    tmux send-keys -t "hyperopt_master:$machine" "ssh $machine" C-m
    tmux send-keys -t "hyperopt_master:$machine" "~/drower9k/hyperopt_scripts/hyperopt_slave.bash $count" C-m
done

# Focus the first machine's window. Selecting it by name (rather than by
# index 0) also works when tmux is configured with a non-zero base-index.
tmux select-window -t "hyperopt_master:${machines[0]%:*}"

# Attach to session
tmux -2 attach-session -t hyperopt_master
66 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/scripts/hyperopt_mongo_tmux.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Start the MongoDB server that hyperopt workers use as their trial store,
# inside a detached tmux session, then attach to that session.
tmux -2 new-session -d -s hyperopt_mongodb

# Target the session explicitly instead of relying on tmux's "current session".
tmux send-keys -t hyperopt_mongodb "mongod --dbpath /home/jia/tmp/dumps/drow/hyperopt_mongodb --port 27012 --directoryperdb --journal --bind_ip_all" C-m

#Attach to session
tmux -2 attach-session -t hyperopt_mongodb
8 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/scripts/hyperopt_slave.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Start N hyperopt-mongo-worker processes on this machine, each in its own
# pane of the tmux session "slave_hyperopt", then attach to that session.
# Usage: hyperopt_slave.bash <num_workers>

# Fail with a usage message instead of an arithmetic syntax error when the
# worker count is missing.
num_workers=${1:?"usage: $0 <num_workers>"}

tmux -2 new-session -d -s slave_hyperopt

# tmux send-keys "ssh -L -f -N localhost:12345:chimay:27010" C-m

# The new session starts with one pane; add num_workers - 1 more.
for (( c=1; c<num_workers; c++ ))
do
    tmux split-window -v
    tmux select-layout tiled
done

# Activate the environment and start one worker in every pane.
for (( c=0; c<num_workers; c++ ))
do
    tmux select-pane -t "$c"
    tmux send-keys "source /home/jia/torch_cuda10_venv/bin/activate" C-m
    tmux send-keys "cd /home/jia/tmp" C-m
    tmux send-keys "PYTHONPATH=$PYTHONPATH:~/drower9k hyperopt-mongo-worker --mongo=chimay:27010/hyperopt --reserve-timeout=inf --poll-interval=15" C-m
done

#Attach to session
tmux -2 attach-session -t slave_hyperopt
23 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/scripts/hyperopt_slave_claix.bash:
--------------------------------------------------------------------------------
1 | #!/usr/local_rwth/bin/zsh
2 |
3 | #SBATCH --job-name=hyperopt
4 |
5 | #SBATCH --output=/home/yx643192/slurm_logs/hyperopt/%J_%x.log
6 |
7 | #SBATCH --cpus-per-task=1
8 |
9 | #SBATCH --mem-per-cpu=3G
10 |
11 | #SBATCH --time=2-00:00:00
12 |
13 | #SBATCH --signal=TERM@120
14 |
15 | #SBATCH --partition=c18m
16 |
17 | #SBATCH --account=rwth0485
18 |
19 | #SBATCH --array=1-50
20 |
21 | source $HOME/.zshrc
22 | conda activate torch10
23 |
24 | cd /work/yx643192/hyperopt_tmp
25 |
26 | ssh -4 -N -f -J jia@recog.vision.rwth-aachen.de -L localhost:12345:chimay:27012 jia@chimay
27 |
28 | PYTHONPATH=$PYTHONPATH:/home/yx643192/v3/hyperopt hyperopt-mongo-worker --mongo=localhost:12345/hyperopt --reserve-timeout=inf --poll-interval=15
29 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/scripts/hyperopt_slave_colossus.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Start 24 hyperopt-mongo-worker processes in the background on this machine
# and block until all of them have exited.

# Abort rather than run the workers from an unexpected directory.
cd /tmp || exit 1
for (( c=0; c<24; c++ ))
do
    PYTHONPATH=$PYTHONPATH:~/drower9k hyperopt-mongo-worker --mongo=einhorn:27010/hyperopt --reserve-timeout=inf --poll-interval=15 &
done
# Wait for every worker, not only the first background job (%1).
wait
9 |
--------------------------------------------------------------------------------
/dr_spaam/hyperopt/scripts/kill_hyperopt_master_tmux.bash:
--------------------------------------------------------------------------------
#!/bin/bash
# Tear down a run started by hyperopt_master_tmux.bash: kill the worker tmux
# session on every machine, then the local master session.

# Every window of the master session is named after the machine it controls.
for machine in $(tmux list-windows -t hyperopt_master -F '#W'); do
    ssh "$machine" 'tmux kill-session -t slave_hyperopt'
done

# Finally kill the session
tmux kill-session -t hyperopt_master
--------------------------------------------------------------------------------
/dr_spaam/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Package metadata for the dr_spaam detector. The importable package lives
# under src/, hence the package_dir mapping for the root package.
setup(
    name="dr_spaam",
    version="1.1",
    description='DR-SPAAM, a deep-learning based person detector for 2D range data.',
    author='Dan Jia',
    author_email='jia@vision.rwth-aachen.de',
    license='LICENSE.txt',
    package_dir={'': 'src'},
    packages=find_packages('src'),
)
13 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/DR-SPAAM-Detector/e5a5f73f69523b90829be06a2558b597c2934f9f/dr_spaam/src/dr_spaam/__init__.py
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/detector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .utils import utils as u
4 | from .model.drow import DROW, SpatialDROW
5 |
6 |
class Detector(object):
    """Person detector for 2D range scans backed by DROW or DR-SPAAM.

    Call ``set_laser_spec`` once before feeding scans. Each call on a scan
    returns ``(dets_xy, dets_cls, instance_mask)`` as produced by
    ``u.nms_predicted_center``.
    """

    def __init__(self, model_name, ckpt_file, gpu=True, stride=1, tracking=False):
        """DR-SPAAM detector wrapper

        Args:
            model_name (str): "DROW", "DROW-T5", or "DR-SPAAM"
            ckpt_file (str): Path to checkpoint. It is loaded with torch.load
                (i.e. unpickled), so only use trusted files.
            gpu (bool, optional): True to use GPU. Defaults to True.
            stride (int, optional): Use stride to skip scan points. Defaults to 1.
            tracking (bool, optional): True to do tracking. Defaults to False.
                Tracking needs the similarity matrix that only the DR-SPAAM
                forward pass returns, so it requires model_name "DR-SPAAM".

        Raises:
            ValueError: If tracking is requested with "DROW" or "DROW-T5".
            RuntimeError: If model_name is unknown.
        """
        # Fail early. Otherwise the first __call__ dies with a NameError
        # because DROW models produce no similarity matrix.
        if tracking and model_name in ("DROW", "DROW-T5"):
            raise ValueError(
                "Tracking is only supported with 'DR-SPAAM', got '%s'." % model_name)

        self._gpu, self._scan_phi, self._stride = gpu, None, stride
        self._model_name = model_name
        self._use_dr_spaam = model_name == "DR-SPAAM"

        # Keyword arguments forwarded to u.scans_to_cutout().
        self._ct_kwargs = {
            "fixed": False,
            "centered": True,
            "window_width": 1.0,
            "window_depth": 0.5,
            "num_cutout_pts": 56,
            "padding_val": 29.99,
            "area_mode": True
        }

        # NOTE: Voting is replaced by NMS and vote kwargs are no longer needed.
        # They are kept for the commented-out group_predicted_center path.
        if model_name == "DR-SPAAM":
            model = SpatialDROW(num_pts=self._ct_kwargs['num_cutout_pts'],
                                pedestrian_only=True,
                                alpha=0.5,
                                window_size=11)
            self._vote_kwargs = {
                "bin_size": 0.10048541940486004,
                "blur_sigma": 1.459561417325547,
                "min_thresh": 9.447764939669593e-05,
                "vote_collect_radius": 0.15719563974052672
            }
        elif model_name == "DROW":
            model = DROW(num_scans=1,
                         num_pts=self._ct_kwargs['num_cutout_pts'],
                         pedestrian_only=True)
            self._vote_kwargs = {
                "bin_size": 0.11691041834028301,
                "blur_sigma": 0.7801193226779289,
                "min_thresh": 0.0013299798109178708,
                "vote_collect_radius": 0.1560556348793659
            }
        elif model_name == "DROW-T5":
            model = DROW(num_scans=5,
                         num_pts=self._ct_kwargs['num_cutout_pts'],
                         pedestrian_only=True)
            self._vote_kwargs = {
                "bin_size": 0.10041661299422858,
                "blur_sigma": 1.3105587107688101,
                "min_thresh": 1.0228621127903203e-05,
                "vote_collect_radius": 0.15356209212109417
            }
        else:
            raise RuntimeError(
                "Unknown model name '%s'. Use 'DR-SPAAM', 'DROW', or 'DROW-T5'." % (model_name))

        # Deserialize onto the CPU so that checkpoints saved from a GPU also
        # load on CPU-only machines; the model is moved to the GPU below.
        ckpt = torch.load(ckpt_file, map_location="cpu")
        model.load_state_dict(ckpt['model_state'])

        model.eval()
        self._model = model.cuda() if gpu else model

        self._tracker = _TrackingExtension() if tracking else None
        if self._use_dr_spaam:
            # Feature template carried from one call to the next (auto-regressive).
            self._fea = None

    def __call__(self, scan):
        """Run detection on one scan (1D) or a stack of scans (2D, last row is current).

        Returns:
            tuple: (dets_xy, dets_cls, instance_mask) from u.nms_predicted_center.
        """
        assert self.laser_spec_set(), "Need to call set_laser_spec() first."

        if len(scan.shape) == 1:
            scan = scan[None, ...]

        # preprocess
        ct = u.scans_to_cutout(
            scan, self._scan_phi,
            stride=self._stride, **self._ct_kwargs)
        ct = torch.from_numpy(ct).float()

        if self._gpu:
            ct = ct.cuda()

        # inference
        sim_matrix = None  # only produced by DR-SPAAM
        with torch.no_grad():
            if self._use_dr_spaam:
                pred_cls, pred_reg, self._fea, sim_matrix = self._model(
                    ct.unsqueeze(dim=0), testing=True, fea_template=self._fea)
            else:
                pred_cls, pred_reg = self._model(ct.unsqueeze(dim=0))  # one dim for batch
        pred_cls = torch.sigmoid(pred_cls[0]).data.cpu().numpy()
        pred_reg = pred_reg[0].data.cpu().numpy()

        # postprocess: only the most recent scan is used for localization
        dets_xy, dets_cls, instance_mask = u.nms_predicted_center(
            scan[-1, ::self._stride], self._scan_phi[::self._stride], pred_cls, pred_reg, min_dist=0.5)
        # dets_xy, dets_cls, instance_mask = u.group_predicted_center(
        #     scan[-1], self._scan_phi, pred_cls, pred_reg, **self._vote_kwargs)

        if self._tracker is not None:
            self._tracker(dets_xy, dets_cls, instance_mask, sim_matrix)

        return dets_xy, dets_cls, instance_mask

    def get_tracklets(self):
        """Return the tracker's (tracks, tracks_cls); requires tracking=True."""
        assert self._tracker is not None
        return self._tracker.get_tracklets()

    def set_laser_spec(self, angle_inc, num_pts):
        """Set the laser angular increment and number of points per scan."""
        self._scan_phi = u.get_laser_phi(angle_inc, num_pts)

    def laser_spec_set(self):
        """Return True once set_laser_spec() has been called."""
        return self._scan_phi is not None
123 |
124 |
class _TrackingExtension(object):
    """Greedy frame-to-frame tracklet builder for DR-SPAAM detections.

    On every frame after the first, each current detection is associated with
    at most one detection of the previous frame via the similarity matrix
    returned by the DR-SPAAM model. If the associated previous detection is
    closer than ``_max_assoc_dist``, the current detection extends that
    detection's tracklet; otherwise it starts a new tracklet.
    """

    def __init__(self):
        self._prev_dets_xy = None
        self._prev_dets_cls = None
        self._prev_instance_mask = None
        self._prev_dets_to_tracks = None  # a list of track id for each detection

        self._tracks = []  # per tracklet: list of detection positions
        self._tracks_cls = []  # per tracklet: list of detection scores (Python scalars)
        self._tracks_age = []  # per tracklet: frames since it was last extended

        self._max_track_age = 100
        self._max_assoc_dist = 0.7

    @staticmethod
    def _as_scalar(value):
        """Return ``value`` (NumPy scalar or size-1 array) as a Python scalar.

        Replaces ``np.asscalar`` (removed in NumPy 1.23), which returned
        ``value.item()``.
        """
        return np.asarray(value).item()

    def __call__(self, dets_xy, dets_cls, instance_mask, sim_matrix):
        """Add the detections of one frame to the tracklets.

        Args:
            dets_xy: Detection positions; ``dets_xy[i][0]``/``[1]`` are x/y.
            dets_cls: Detection scores, one per detection.
            instance_mask: Instance id per scan point; ids start at 1 and
                other values (e.g. 0) belong to no detection.
            sim_matrix: Similarity tensor from the model. In ``sim_matrix[0]``,
                rows index current scan points and columns previous scan
                points. Not used on the first call.
        """
        # first frame: every detection starts its own tracklet
        if self._prev_dets_xy is None:
            self._prev_dets_xy = dets_xy
            self._prev_dets_cls = dets_cls
            self._prev_instance_mask = instance_mask
            self._prev_dets_to_tracks = np.arange(len(dets_xy), dtype=np.int32)

            for d_xy, d_cls in zip(dets_xy, dets_cls):
                self._tracks.append([d_xy])
                self._tracks_cls.append([self._as_scalar(d_cls)])
                self._tracks_age.append(0)

            return

        # index of the associated previous detection (or -1) per detection
        prev_dets_inds = self._associate_prev_det(
            dets_xy, dets_cls, instance_mask, sim_matrix)

        # mapping from detection indices to tracklet indices
        dets_to_tracks = []

        for d_xy, d_cls, prev_d_idx in zip(dets_xy, dets_cls, prev_dets_inds):
            track_idx = -1
            # Only measure the distance for a real association; -1 would
            # silently index the last previous detection.
            if prev_d_idx >= 0:
                dxy = self._prev_dets_xy[prev_d_idx] - d_xy
                if np.hypot(dxy[0], dxy[1]) < self._max_assoc_dist:
                    track_idx = self._prev_dets_to_tracks[prev_d_idx]

            if track_idx >= 0:
                # close to the associated detection: extend its tracklet
                self._tracks[track_idx].append(d_xy)
                self._tracks_cls[track_idx].append(self._as_scalar(d_cls))
                self._tracks_age[track_idx] = -1  # becomes 0 after aging below
            else:
                # otherwise start a new tracklet
                self._tracks.append([d_xy])
                self._tracks_cls.append([self._as_scalar(d_cls)])
                self._tracks_age.append(-1)  # becomes 0 after aging below
                track_idx = len(self._tracks) - 1
            dets_to_tracks.append(track_idx)

        # every tracklet ages by one frame; tracklets touched above end at 0
        self._tracks_age = [age + 1 for age in self._tracks_age]

        # update
        self._prev_dets_xy = dets_xy
        self._prev_dets_cls = dets_cls
        self._prev_instance_mask = instance_mask
        self._prev_dets_to_tracks = dets_to_tracks

    def get_tracklets(self):
        """Return tracklets extended within the last ``_max_track_age`` frames.

        Returns:
            tuple: (tracks, tracks_cls). ``tracks`` is a list of arrays that
            stack each tracklet's positions along axis 0; ``tracks_cls`` holds
            each tracklet's mean detection score.
        """
        tracks, tracks_cls = [], []
        for track, track_cls, age in zip(self._tracks, self._tracks_cls, self._tracks_age):
            if age < self._max_track_age:
                tracks.append(np.stack(track, axis=0))
                tracks_cls.append(np.array(track_cls).mean())
        return tracks, tracks_cls

    def _associate_prev_det(self, dets_xy, dets_cls, instance_mask, sim_matrix):
        """Associate each current detection with at most one previous detection.

        Returns:
            list: Per current detection, the index of the associated previous
            detection, or -1 if there is none (the point cloud votes for
            background, the detection has no points, or the previous detection
            was already taken).
        """
        prev_dets_inds = []
        occupied_flag = np.zeros(len(self._prev_dets_xy), dtype=bool)
        sim = sim_matrix[0].data.cpu().numpy()
        for d_idx in range(len(dets_xy)):
            inst_id = d_idx + 1  # instance is 1-based

            # For all the points that belong to the current instance, find their
            # most similar points in the previous scans and take the point with
            # highest support as the associated point of this instance in the
            # previous scan.
            inst_sim = sim[instance_mask == inst_id].argmax(axis=1)
            if inst_sim.size == 0:
                prev_dets_inds.append(-1)
                continue
            assoc_prev_pt_ind = np.bincount(inst_sim).argmax()

            # associated detection (instance is 1-based, so -1 means background)
            prev_d_idx = int(self._prev_instance_mask[assoc_prev_pt_ind]) - 1

            # Background must not mark any detection as occupied; only
            # associate each previous detection once.
            if prev_d_idx < 0 or occupied_flag[prev_d_idx]:
                prev_dets_inds.append(-1)
            else:
                prev_dets_inds.append(prev_d_idx)
                occupied_flag[prev_d_idx] = True

        return prev_dets_inds
246 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/DR-SPAAM-Detector/e5a5f73f69523b90829be06a2558b597c2934f9f/dr_spaam/src/dr_spaam/model/__init__.py
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/model/drow.py:
--------------------------------------------------------------------------------
1 | from math import ceil
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .loss_utils import FocalLoss, BinaryFocalLoss
8 |
9 |
10 | def _conv(in_channel, out_channel, kernel_size, padding):
11 | return nn.Sequential(nn.Conv1d(in_channel, out_channel,
12 | kernel_size=kernel_size, padding=padding),
13 | nn.BatchNorm1d(out_channel),
14 | nn.LeakyReLU(negative_slope=0.1, inplace=True))
15 |
16 |
def _conv3x3(in_channel, out_channel):
    """Length-preserving conv block: kernel size 3 with one zero pad per side."""
    return _conv(in_channel, out_channel, padding=1, kernel_size=3)
19 |
20 |
def _conv1x1(in_channel, out_channel):
    """Pointwise conv block: kernel size 1, no padding.

    A 1x1 convolution already preserves the sequence length. The previous
    ``padding=1`` appended two zero columns and made the output two elements
    longer than the input.
    """
    return _conv(in_channel, out_channel, kernel_size=1, padding=0)
23 |
24 |
class DROW(nn.Module):
    """DROW person detector operating on cutouts of 2D range scans.

    Each cutout of each scan is encoded independently by ``conv_block_1`` and
    ``conv_block_2``. The per-scan features are summed over the scan axis and
    decoded by ``conv_block_3``/``conv_block_4`` into one classification
    output and one 2D regression output per cutout.

    Args:
        dropout (float): Dropout probability after each pooled conv block;
            0 disables dropout.
        num_scans (int): Accepted but not referenced in this class.
        num_pts (int): Accepted but not referenced in this class.
        focal_loss_gamma (float): If > 0, ``cls_loss`` is a focal loss with
            this gamma; otherwise plain (binary) cross entropy.
        pedestrian_only (bool): If True, one classification output per
            cutout with a binary loss; otherwise 4 outputs with a multi-class
            loss.
    """

    def __init__(self, dropout=0.5, num_scans=5, num_pts=48, focal_loss_gamma=0.0,
                 pedestrian_only=False):
        super(DROW, self).__init__()

        self.dropout = dropout

        # Encoder applied to every single-scan cutout (1 input channel).
        self.conv_block_1 = nn.Sequential(
            _conv3x3(1, 64), _conv3x3(64, 64), _conv3x3(64, 128))
        self.conv_block_2 = nn.Sequential(
            _conv3x3(128, 128), _conv3x3(128, 128), _conv3x3(128, 256))
        # Decoder applied to the scan-fused features.
        self.conv_block_3 = nn.Sequential(
            _conv3x3(256, 256), _conv3x3(256, 256), _conv3x3(256, 512))
        self.conv_block_4 = nn.Sequential(
            _conv3x3(512, 256), _conv3x3(256, 128))

        use_focal_loss = focal_loss_gamma > 0.0
        if pedestrian_only:
            self.conv_cls = nn.Conv1d(128, 1, kernel_size=1)  # probs
            if use_focal_loss:
                self.cls_loss = BinaryFocalLoss(gamma=focal_loss_gamma)
            else:
                self.cls_loss = F.binary_cross_entropy
        else:
            self.conv_cls = nn.Conv1d(128, 4, kernel_size=1)  # probs
            if use_focal_loss:
                self.cls_loss = FocalLoss(gamma=focal_loss_gamma)
            else:
                self.cls_loss = F.cross_entropy

        self.conv_reg = nn.Conv1d(128, 2, kernel_size=1)  # vote

        # Kaiming init for convs (matching the LeakyReLU slope), identity for BN.
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Conv2d)):
                nn.init.kaiming_normal_(module.weight, a=0.1, nonlinearity='leaky_relu')
            elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _forward_conv(self, x, conv_block):
        """Apply ``conv_block``, halve the length by max-pooling, then dropout."""
        pooled = F.max_pool1d(conv_block(x), kernel_size=2)
        if self.dropout > 0:
            pooled = F.dropout(pooled, p=self.dropout, training=self.training)
        return pooled

    def _forward_cutout(self, x):
        """Encode every cutout of every scan independently.

        ``x`` is (batch, cutout, scan, pts). The result is
        (batch, cutout, scan, channel, pts / 4).
        """
        n_batch, n_cutout, n_scan, n_pts = x.shape

        feat = x.view(n_batch * n_cutout * n_scan, 1, n_pts)
        feat = self._forward_conv(feat, self.conv_block_1)  # 24 for 48-pt cutouts
        feat = self._forward_conv(feat, self.conv_block_2)  # 12 for 48-pt cutouts

        return feat.view(n_batch, n_cutout, n_scan, *feat.shape[-2:])

    def _fuse_cutout(self, x):
        """Sum the per-scan features: (batch, cutout, channel, pts)."""
        return x.sum(dim=2)

    def _forward_fused_cutout(self, x):
        """Decode fused cutout features into per-cutout predictions.

        Returns:
            tuple: pred_cls of shape (batch, cutout, C) and pred_reg of shape
            (batch, cutout, 2).
        """
        n_batch, n_cutout, n_channel, n_pts = x.shape

        feat = x.view(n_batch * n_cutout, n_channel, n_pts)
        feat = self._forward_conv(feat, self.conv_block_3)  # 6 for 48-pt cutouts
        feat = self.conv_block_4(feat)
        # global average over the remaining points -> (batch*cutout, channel, 1)
        feat = F.avg_pool1d(feat, kernel_size=feat.shape[-1])

        pred_cls = self.conv_cls(feat).view(n_batch, n_cutout, -1)
        pred_reg = self.conv_reg(feat).view(n_batch, n_cutout, 2)
        return pred_cls, pred_reg

    def forward(self, x):
        """Run the full network on cutouts of shape (batch, cutout, scan, pts)."""
        fused = self._fuse_cutout(self._forward_cutout(x))
        return self._forward_fused_cutout(fused)
105 |
106 |
107 | class _TemporalAttention(nn.Module):
108 | def __init__(self, n_scans, n_pts, n_channel):
109 | super(_TemporalAttention, self).__init__()
110 | self.conv1 = nn.Conv1d(n_channel, 128, kernel_size=n_pts, padding=0)
111 | self.bn1 = nn.BatchNorm1d(128)
112 | self.conv2 = nn.Conv1d(128, 64, kernel_size=n_scans, padding=0)
113 | self.bn2 = nn.BatchNorm1d(64)
114 | self.fc = nn.Linear(64, n_scans)
115 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
116 |
117 | for m in self.modules():
118 | if isinstance(m, (nn.Conv1d, nn.Conv2d)):
119 | nn.init.kaiming_normal_(m.weight, a=0.1, nonlinearity='leaky_relu')
120 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
121 | nn.init.constant_(m.weight, 1)
122 | nn.init.constant_(m.bias, 0)
123 |
124 | def forward(self, x):
125 | n_batch, n_scans, n_channel, n_pts = x.shape
126 |
127 | out = x.view(n_batch * n_scans, n_channel, n_pts)
128 | out = self.conv1(out)
129 | out = self.bn1(out)
130 | out = self.relu(out)
131 |
132 | out = out.view(n_batch, n_scans, 128).permute(0, 2, 1) # (batch, feature, scans)
133 | out = self.conv2(out)
134 | out = self.bn2(out)
135 | out = self.relu(out).view(n_batch, 64) # (batch, feature)
136 |
137 | out = self.fc(out)
138 | out = F.softmax(out, dim=1) # (batch, scans)
139 |
140 | return out
141 |
142 |
class TemporalDROW(DROW):
    """DROW variant that fuses per-scan cutout features with learned temporal
    attention weights (``_TemporalAttention``) instead of a plain sum.

    Args: see ``DROW``. The attention gate is only created when
    ``num_scans > 1``; single-scan inputs bypass it.
    """

    def __init__(self, dropout=0.5, num_scans=5, num_pts=48, focal_loss_gamma=0.0,
                 pedestrian_only=False):
        super(TemporalDROW, self).__init__(
            dropout=dropout, num_scans=num_scans, num_pts=num_pts,
            focal_loss_gamma=focal_loss_gamma, pedestrian_only=pedestrian_only)

        if num_scans > 1:
            # Cutout features have num_pts // 4 points after the two pooled conv
            # blocks: max_pool1d floors (e.g. 48 -> 24 -> 12), assuming the 3x3
            # convs keep the length. ceil() over-estimates this when num_pts is
            # not a multiple of 4, making the full-width conv kernel too large.
            self.gate = _TemporalAttention(num_scans, int(num_pts) // 4, 256)

    def _fuse_cutout(self, x):
        """Attention-weighted sum over the scan axis.

        (batch, cutout, scans, channel, pts) -> (batch, cutout, channel, pts)
        """
        n_batch, n_cutout, n_scans, n_channel, n_pts = x.shape

        if n_scans == 1:
            return x.view(n_batch, n_cutout, n_channel, n_pts)

        out = x.view(n_batch * n_cutout, n_scans, n_channel, n_pts)
        gate = self.gate(out)  # (batch*cutout, scans)
        out = out * gate[..., None, None]
        out = torch.sum(out, dim=1)  # (batch*cutout, channel, pts)

        return out.view(n_batch, n_cutout, n_channel, n_pts)

    def forward(self, x, testing=False, fea_prev=None):
        """Predict class logits and regression.

        Args:
            x: (batch, cutout, scan, pts) input.
            testing: inference mode. ``x`` holds the newest scan, and the
                features of earlier scans are taken from ``fea_prev``.
            fea_prev: sequence of earlier per-scan features (as returned in
                ``fea_now``), or None / empty when there is no history yet.

        Returns:
            ``(pred_cls, pred_reg, fea_now)`` when testing, else
            ``(pred_cls, pred_reg)``.
        """
        # inference
        if testing:
            out = self._forward_cutout(x).squeeze(dim=2)
            fea_now = out.clone()
            if fea_prev is not None and len(fea_prev) > 0:
                out = torch.stack(list(fea_prev) + [out], dim=2)
            elif out.dim() == 4:
                # No history: restore the scan axis so _fuse_cutout receives
                # the 5-D tensor it unpacks.
                out = out.unsqueeze(dim=2)
            out = self._fuse_cutout(out)
            pred_cls, pred_reg = self._forward_fused_cutout(out)

            return pred_cls, pred_reg, fea_now

        out = self._forward_cutout(x)
        out = self._fuse_cutout(out)
        pred_cls, pred_reg = self._forward_fused_cutout(out)

        return pred_cls, pred_reg
183 |
184 |
185 | class _SpatialAttention(nn.Module):
186 | def __init__(self, n_pts, n_channel, alpha=0.5, window_size=7):
187 | super(_SpatialAttention, self).__init__()
188 | self._alpha = alpha
189 | self._window_size = window_size
190 |
191 | self.conv = nn.Sequential(
192 | nn.Conv1d(n_channel, 128, kernel_size=n_pts, padding=0),
193 | nn.BatchNorm1d(128),
194 | nn.LeakyReLU(negative_slope=0.1, inplace=True))
195 |
196 | # place holder, created at runtime
197 | self.neighbor_masks, self.neighbor_inds = None, None
198 |
199 | for m in self.modules():
200 | if isinstance(m, (nn.Conv1d, nn.Conv2d)):
201 | nn.init.kaiming_normal_(m.weight, a=0.1, nonlinearity='leaky_relu')
202 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
203 | nn.init.constant_(m.weight, 1)
204 | nn.init.constant_(m.bias, 0)
205 |
206 | def _generate_neighbor_mask(self, x):
207 | # indices of neighboring cutout
208 | n_cutout = x.shape[1]
209 | hw = int(self._window_size / 2)
210 | inds_col = torch.arange(n_cutout).unsqueeze(dim=-1).long()
211 | window_inds = torch.arange(-hw, hw+1).long()
212 | inds_col = inds_col + window_inds.unsqueeze(dim=0) # (cutout, neighbors)
213 | inds_col = inds_col.clamp(min=0, max=n_cutout-1)
214 | inds_row = torch.arange(n_cutout).unsqueeze(dim=-1).expand_as(inds_col).long()
215 | inds_full = torch.stack((inds_row, inds_col), dim=2).view(-1, 2)
216 | # self.register_buffer('neighbor_inds', inds_full)
217 |
218 | masks = torch.zeros(n_cutout, n_cutout).float()
219 | masks[inds_full[:, 0], inds_full[:, 1]] = 1.0
220 | return masks.cuda(x.get_device()) if x.is_cuda else masks, inds_full
221 |
222 | def forward(self, x, x_template):
223 | n_batch, n_cutout, n_channel, n_pts = x.shape
224 |
225 | # # for ablation study - no spatial attention
226 | # if True:
227 | # out_temp = self._alpha * x + (1.0 - self._alpha) * x_template
228 | # return out_temp, None
229 |
230 | # only need to generate neighbor mask once
231 | if self.neighbor_masks is None:
232 | self.neighbor_masks, self.neighbor_inds = self._generate_neighbor_mask(x)
233 |
234 | # embedding for cutout
235 | emb_x = self.conv(x.view(n_batch * n_cutout, n_channel, n_pts))
236 | emb_x = emb_x.view(n_batch, n_cutout, 128)
237 |
238 | # embedding for template
239 | emb_temp = self.conv(x_template.view(n_batch * n_cutout, n_channel, n_pts))
240 | emb_temp = emb_temp.view(n_batch, n_cutout, 128)
241 |
242 | # pair-wise similarity (batch, cutout, cutout)
243 | sim = torch.matmul(emb_x, emb_temp.permute(0, 2, 1))
244 |
245 | # # masked softmax (original)
246 | # # @note 1e-5 was added to `exps` before, not to `exps_sum`
247 | # maxes = (sim * self.neighbor_masks).max(dim=-1, keepdim=True)[0]
248 | # sim_centered = torch.clamp(sim - maxes, max=0.0)
249 | # exps = torch.exp(sim_centered) * self.neighbor_masks
250 | # exps_sum = exps.sum(dim=-1, keepdim=True)
251 | # sim = exps / exps_sum
252 |
253 | # masked softmax (new)
254 | sim = sim - 1e10 * (1.0 - self.neighbor_masks) # make sure the out-of-window elements have small values
255 | maxes = sim.max(dim=-1, keepdim=True)[0]
256 | exps = torch.exp(sim - maxes) * self.neighbor_masks
257 | exps_sum = exps.sum(dim=-1, keepdim=True)
258 | sim = exps / exps_sum
259 |
260 | # # weighted average on the template (old)
261 | # out_temp = x_template.view(n_batch, n_cutout, n_channel*n_pts).permute(0, 2, 1)
262 | # out_temp = torch.matmul(out_temp, sim.permute(0, 2, 1))
263 | # out_temp = out_temp.permute(0, 2, 1).view(
264 | # n_batch, n_cutout, n_channel, n_pts)
265 |
266 | # weighted average on the template (new, remove redundent transpose)
267 | out_temp = x_template.view(n_batch, n_cutout, n_channel*n_pts)
268 | out_temp = torch.matmul(sim, out_temp)
269 | out_temp = out_temp.view(n_batch, n_cutout, n_channel, n_pts)
270 |
271 | # auto-regressive
272 | out_temp = self._alpha * x + (1.0 - self._alpha) * out_temp
273 |
274 | return out_temp, sim
275 |
276 |
class SpatialDROW(DROW):
    """DR-SPAAM: DROW with an auto-regressive spatial-attention template.

    Scans are encoded one at a time and folded into a feature template with
    ``_SpatialAttention``. Detections are predicted from the final template.

    Args:
        alpha, window_size: forwarded to ``_SpatialAttention``.
        other arguments: see ``DROW``.
    """

    def __init__(self, dropout=0.5, num_scans=5, num_pts=48, focal_loss_gamma=0.0,
                 alpha=0.5, window_size=7, pedestrian_only=False):
        super(SpatialDROW, self).__init__(
            dropout=dropout, num_scans=num_scans, num_pts=num_pts,
            focal_loss_gamma=focal_loss_gamma, pedestrian_only=pedestrian_only)

        # Cutout features have num_pts // 4 points after the two pooled conv
        # blocks: max_pool1d floors (e.g. 48 -> 24 -> 12), assuming the 3x3
        # convs keep the length. ceil() is wrong when num_pts % 4 != 0.
        self.gate = _SpatialAttention(n_pts=int(num_pts) // 4,
                                      n_channel=256,
                                      alpha=alpha,
                                      window_size=window_size)

    def forward(self, x, testing=False, fea_template=None):
        """Predict class logits and regression.

        Args:
            x: (batch, cutout, scan, pts) input. In testing mode it holds the
                newest scan only.
            testing: inference mode, using ``fea_template`` as the running
                template.
            fea_template: template from the previous call, or None to start a
                new one.

        Returns:
            Testing: ``(pred_cls, pred_reg, new_template, sim)``, where ``sim``
            is None when no template was given. Otherwise:
            ``(pred_cls, pred_reg, sim)`` with ``sim`` from the last attention
            step, or None for a single-scan input.
        """
        # inference
        if testing:
            out = self._forward_cutout(x).squeeze(dim=2)
            if fea_template is None:
                out_template = out.clone()
                sim = None
            else:
                out_template, sim = self.gate(out, fea_template)

            pred_cls, pred_reg = self._forward_fused_cutout(out_template)

            return pred_cls, pred_reg, out_template, sim

        # Training or evaluation: fold the scans in order into the template.
        # sim stays None for a single scan, so no UnboundLocalError is raised.
        n_scan = x.shape[2]
        sim = None
        scan_i = x[:, :, 0, :].unsqueeze(dim=2)
        out_template = self._forward_cutout(scan_i).squeeze(dim=2)
        for i in range(1, n_scan):
            scan_i = x[:, :, i, :].unsqueeze(dim=2)
            out = self._forward_cutout(scan_i).squeeze(dim=2)
            out_template, sim = self.gate(out, out_template)

        pred_cls, pred_reg = self._forward_fused_cutout(out_template)

        return pred_cls, pred_reg, sim
325 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/model/loss_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from logits.

    Adapted from https://github.com/mbsariyildiz/focal-loss.pytorch/blob/master/focalloss.py

    Args:
        gamma: focusing parameter. ``gamma=0`` gives (optionally alpha-weighted)
            cross entropy.
        alpha: optional class weights. A float/int ``a`` becomes ``[a, 1 - a]``,
            a list or tuple gives one weight per class, and None disables
            weighting.
    """

    def __init__(self, gamma=0, alpha=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        elif isinstance(alpha, (list, tuple)):
            self.alpha = torch.Tensor(alpha)

    def forward(self, input, target, reduction='mean'):
        """Compute the focal loss.

        Args:
            input: (N, C) logits, or (N, C, ...) which is flattened to (N*..., C).
            target: integer class indices matching the flattened input rows.
            reduction: 'mean', 'sum' or 'none'.

        Raises:
            RuntimeError: for an unknown ``reduction``.
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        pt = logpt.exp()  # probability of the true class, before alpha weighting

        if self.alpha is not None:
            if self.alpha.type() != input.type():
                self.alpha = self.alpha.type_as(input)
            logpt = logpt * self.alpha.gather(0, target.view(-1))

        loss = -1 * (1 - pt) ** self.gamma * logpt

        if reduction == 'mean':
            return loss.mean()
        if reduction == 'sum':
            return loss.sum()
        if reduction == 'none':
            return loss
        raise RuntimeError("Unsupported reduction %r; expected 'mean', 'sum' or 'none'"
                           % (reduction,))
42 |
43 |
class BinaryFocalLoss(nn.Module):
    """``nn.Module`` wrapper around ``binary_focal_loss``.

    Args:
        gamma: focusing parameter passed to ``binary_focal_loss``.
        alpha: positive-term weight. ``binary_focal_loss`` applies it only
            when it lies in [0, 1]; the default -1 disables weighting.
    """

    def __init__(self, gamma=2.0, alpha=-1):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, pred, target, reduction='mean'):
        """Return the binary focal loss of probabilities ``pred`` vs. ``target``."""
        gamma, alpha = self.gamma, self.alpha
        return binary_focal_loss(pred, target, gamma, alpha, reduction)
51 |
52 |
def binary_focal_loss(pred, target, gamma=2.0, alpha=-1, reduction='mean'):
    """Binary focal loss on probabilities.

    loss = -[t * (1 - p)^gamma * log(p) + (1 - t) * p^gamma * log(1 - p)]

    The arguments of both logs are clamped to the smallest positive normal
    float of ``pred``'s dtype. Saturated probabilities of exactly 0 or 1
    therefore give a finite loss and finite gradients instead of
    ``0 * -inf = nan``. Values strictly inside (0, 1) are unaffected.

    Args:
        pred: predicted probabilities in [0, 1] (e.g. after sigmoid); floating dtype.
        target: binary targets, same shape as ``pred``.
        gamma: focusing parameter. ``gamma=0`` gives binary cross entropy.
        alpha: weight of the positive term (``1 - alpha`` for the negative one).
            Applied only when ``0 <= alpha <= 1``.
        reduction: 'mean', 'sum' or 'none'.

    Raises:
        RuntimeError: for an unknown ``reduction``.
    """
    tiny = torch.finfo(pred.dtype).tiny
    log_p = torch.log(pred.clamp(min=tiny))
    log_1mp = torch.log((1.0 - pred).clamp(min=tiny))

    loss_pos = - target * (1.0 - pred) ** gamma * log_p
    loss_neg = - (1.0 - target) * pred ** gamma * log_1mp

    if 0.0 <= alpha <= 1.0:
        loss_pos = loss_pos * alpha
        loss_neg = loss_neg * (1.0 - alpha)

    loss = loss_pos + loss_neg

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise RuntimeError("Unsupported reduction %r; expected 'mean', 'sum' or 'none'"
                       % (reduction,))
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/DR-SPAAM-Detector/e5a5f73f69523b90829be06a2558b597c2934f9f/dr_spaam/src/dr_spaam/utils/__init__.py
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/dataset.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import os
3 |
4 | import json
5 | import numpy as np
6 | from torch.utils.data import Dataset, DataLoader
7 |
8 | from . import utils as u
9 |
10 |
def create_dataloader(data_path, num_scans, batch_size, num_workers, network_type="cutout",
                      train_with_val=False, use_data_augumentation=False,
                      cutout_kwargs=None, polar_grid_kwargs=None,
                      pedestrian_only=False):
    """Build shuffled training and validation loaders over ``DROWDataset``.

    The training set uses the 'train' split (plus 'val' if ``train_with_val``)
    and optional augmentation. The evaluation set uses the 'val' split without
    augmentation. Both loaders batch with the dataset's ``collate_batch``.

    Returns:
        (train_loader, eval_loader)
    """
    common = dict(data_path=data_path,
                  num_scans=num_scans,
                  network_type=network_type,
                  cutout_kwargs=cutout_kwargs,
                  polar_grid_kwargs=polar_grid_kwargs,
                  pedestrian_only=pedestrian_only)

    train_set = DROWDataset(split='train',
                            train_with_val=train_with_val,
                            use_data_augumentation=use_data_augumentation,
                            **common)
    eval_set = DROWDataset(split='val',
                           train_with_val=False,
                           use_data_augumentation=False,
                           **common)

    def make_loader(dataset):
        return DataLoader(dataset, batch_size=batch_size, pin_memory=True,
                          num_workers=num_workers, shuffle=True,
                          collate_fn=dataset.collate_batch)

    return make_loader(train_set), make_loader(eval_set)
40 |
41 |
def create_test_dataloader(data_path, num_scans, network_type="cutout",
                           cutout_kwargs=None, polar_grid_kwargs=None,
                           pedestrian_only=False, split='test',
                           scan_stride=1, pt_stride=1):
    """Build an ordered, batch-size-1 loader over ``split`` without augmentation.

    Returns:
        A ``DataLoader`` with one worker and no shuffling.
    """
    dataset = DROWDataset(data_path=data_path, split=split, num_scans=num_scans,
                          network_type=network_type, train_with_val=False,
                          use_data_augumentation=False, cutout_kwargs=cutout_kwargs,
                          polar_grid_kwargs=polar_grid_kwargs,
                          pedestrian_only=pedestrian_only,
                          scan_stride=scan_stride, pt_stride=pt_stride)
    return DataLoader(dataset, batch_size=1, pin_memory=True, num_workers=1,
                      shuffle=False, collate_fn=dataset.collate_batch)
61 |
62 |
class DROWDataset(Dataset):
    """Dataset over DROW-format laser sequences, one sample per annotation line.

    For every ``<seq>.csv`` in the split directory, the scans are loaded
    together with the annotation files ``<seq>.wc``, ``<seq>.wa`` and
    ``<seq>.wp``. Sample ``idx`` corresponds to one annotation line. It holds:
    the annotated scan and the ``num_scans - 1`` scans before it (spaced
    ``scan_stride`` apart), the classification/regression targets, and the
    network input selected by ``network_type``.

    Args:
        data_path: directory with one subdirectory per split (e.g. 'train',
            'val') containing the sequence files.
        split: split subdirectory to load (ignored when ``train_with_val``).
        num_scans: number of scans stacked per sample.
        network_type: "cutout", "cutout_gating", "cutout_spatial", "fc1d",
            "fc1d_fea" or "fc2d". "fc2d_fea" raises NotImplementedError in
            ``__getitem__``.
        train_with_val: load both the 'train' and 'val' splits.
        cutout_kwargs: keyword arguments for the cutout helpers. Must be a dict
            for the cutout network types (it is tested with ``in``).
        polar_grid_kwargs: keyword arguments for ``u.scans_to_polar_grid``.
        use_data_augumentation: apply ``u.data_augmentation`` to every sample.
        pedestrian_only: forwarded to ``u.get_regression_target``.
        scan_stride: spacing, in scans, between the stacked scans.
        pt_stride: subsampling step applied to the points of each scan.
    """

    def __init__(self, data_path, split='train', num_scans=5, network_type="cutout",
                 train_with_val=False, cutout_kwargs=None, polar_grid_kwargs=None,
                 use_data_augumentation=False, pedestrian_only=False,
                 scan_stride=1, pt_stride=1):
        self._num_scans = num_scans
        self._use_data_augmentation = use_data_augumentation
        self._cutout_kwargs = cutout_kwargs
        self._network_type = network_type
        self._polar_grid_kwargs = polar_grid_kwargs
        self._pedestrian_only = pedestrian_only
        self._scan_stride = scan_stride
        self._pt_stride = pt_stride  # @TODO remove pt_stride

        # Sequence names are the .csv paths with the 4-character extension stripped.
        if train_with_val:
            seq_names = [f[:-4] for f in glob(os.path.join(data_path, 'train', '*.csv'))]
            seq_names += [f[:-4] for f in glob(os.path.join(data_path, 'val', '*.csv'))]
        else:
            seq_names = [f[:-4] for f in glob(os.path.join(data_path, split, '*.csv'))]

        # seq_names = seq_names[:1]
        self.seq_names = seq_names

        # Pre-load scans and annotations (one tuple entry per sequence).
        # NOTE(review): with no .csv files found, these zip(*...) unpackings
        # raise ValueError.
        self.scans_ns, self.scans_t, self.scans = zip(*[self._load_scan_file(f) for f in seq_names])
        self.dets_ns, self.dets_wc, self.dets_wa, self.dets_wp = zip(*map(
            lambda f: self._load_det_file(f), seq_names))

        # Pre-compute mappings from detection index to scan index
        # such that idet2iscan[seq_idx][det_idx] = scan_idx
        # (the first scan whose sequence number equals the annotation's).
        self.idet2iscan = [{i: np.where(ss == d)[0][0] for i, d in enumerate(ds)}
                           for ss, ds in zip(self.scans_ns, self.dets_ns)]

        # Look-up list for sequence indices and annotation indices:
        # flat sample index -> (seq_idx, det_idx).
        self.flat_seq_inds, self.flat_det_inds = [], []
        for seq_idx, det_ns in enumerate(self.dets_ns):
            num_samples = len(det_ns)
            self.flat_seq_inds += [seq_idx] * num_samples
            self.flat_det_inds += range(num_samples)

    def __len__(self):
        # number of annotation lines over all sequences
        return len(self.flat_det_inds)

    def __getitem__(self, idx):
        """Build the sample dict for annotation ``idx``.

        Keys: 'seq_name', 'dets_ns', 'dets_wc', 'dets_wa', 'dets_wp',
        'scans' ((num_scans, pts) array, oldest first, annotated scan last),
        'scans_ns', 'phi_grid', 'target_cls', 'target_reg' and, for the known
        network types, 'input'.
        """
        seq_idx = self.flat_seq_inds[idx]
        det_idx = self.flat_det_inds[idx]
        dets_ns = self.dets_ns[seq_idx][det_idx]

        rtn_dict = {}
        rtn_dict['seq_name'] = self.seq_names[seq_idx]
        rtn_dict['dets_ns'] = dets_ns

        # Annotation
        rtn_dict['dets_wc'] = self.dets_wc[seq_idx][det_idx]
        rtn_dict['dets_wa'] = self.dets_wa[seq_idx][det_idx]
        rtn_dict['dets_wp'] = self.dets_wp[seq_idx][det_idx]

        # Scan: offsets are descending, so the stack is ordered oldest -> newest
        # and ends with the annotated scan. Indices are clamped at 0, so the
        # first scan repeats at the start of a sequence.
        scan_idx = self.idet2iscan[seq_idx][det_idx]
        inds_tmp = (np.arange(self._num_scans) * self._scan_stride)[::-1]
        scan_inds = [max(0, scan_idx - i) for i in inds_tmp]
        scans = np.array([self.scans[seq_idx][i] for i in scan_inds])
        scans = scans[:, ::self._pt_stride]
        scans_ns = [self.scans_ns[seq_idx][i] for i in scan_inds]
        rtn_dict['scans'] = scans
        rtn_dict['scans_ns'] = scans_ns

        # angle, subsampled like the scan points
        scan_phi = u.get_laser_phi()[::self._pt_stride]
        rtn_dict['phi_grid'] = scan_phi

        # Regression target, computed for the annotated (last) scan
        target_cls, target_reg = u.get_regression_target(
            scans[-1],
            scan_phi,
            rtn_dict['dets_wc'],
            rtn_dict['dets_wa'],
            rtn_dict['dets_wp'],
            pedestrian_only=self._pedestrian_only)

        rtn_dict['target_cls'] = target_cls
        rtn_dict['target_reg'] = target_reg

        if self._use_data_augmentation:
            rtn_dict = u.data_augmentation(rtn_dict)

        # polar grid or cutout
        # NOTE(review): the cutout branches and "fc1d" read the local `scans` /
        # `scan_phi`, not rtn_dict['scans']. If u.data_augmentation changes the
        # scans, those changes do not reach the input (while "fc1d_fea"/"fc2d"
        # use rtn_dict['scans']). Confirm this is intended.
        if self._network_type == "cutout" \
                or self._network_type == "cutout_gating" \
                or self._network_type == "cutout_spatial":
            if "area_mode" not in self._cutout_kwargs:
                cutout = u.scans_to_cutout_original(
                    scans, scan_phi[1] - scan_phi[0],
                    **self._cutout_kwargs)
            else:
                cutout = u.scans_to_cutout(scans, scan_phi, stride=1,
                                           **self._cutout_kwargs)
            rtn_dict['input'] = cutout
        elif self._network_type == "fc1d":
            rtn_dict['input'] = np.expand_dims(scans, axis=1)
        elif self._network_type == 'fc1d_fea':
            # NOTE(review): passes the angle increment to u.scans_to_cutout,
            # whereas the cutout branch passes the full scan_phi array to that
            # helper (and the increment to scans_to_cutout_original). Verify
            # against the helper's signature.
            cutout = u.scans_to_cutout(rtn_dict['scans'],
                                       scan_phi[1] - scan_phi[0],
                                       **self._cutout_kwargs)
            rtn_dict['input'] = np.transpose(cutout, (1, 2, 0))
        elif self._network_type == "fc2d":
            polar_grid = u.scans_to_polar_grid(rtn_dict['scans'],
                                               **self._polar_grid_kwargs)
            rtn_dict['input'] = np.expand_dims(polar_grid, axis=1)
        elif self._network_type == 'fc2d_fea':
            raise NotImplementedError
        # NOTE(review): any other network_type leaves 'input' unset.

        return rtn_dict

    def collate_batch(self, batch):
        """Merge sample dicts: stack 'target_cls', 'target_reg' and 'input' into
        numpy arrays; keep every other field as a per-sample list."""
        rtn_dict = {}
        for k, _ in batch[0].items():
            if k in ["target_cls", "target_reg", "input"]:
                rtn_dict[k] = np.array([sample[k] for sample in batch])
            else:
                rtn_dict[k] = [sample[k] for sample in batch]

        return rtn_dict

    def _load_scan_file(self, seq_name):
        """Load ``<seq_name>.csv``.

        Column 0 is the scan sequence number (uint32) and column 1 is stored as
        ``times`` (float32). The remaining columns form the scan (float32;
        presumably range readings, to be confirmed).
        """
        data = np.genfromtxt(seq_name + '.csv', delimiter=",")
        seqs = data[:, 0].astype(np.uint32)
        times = data[:, 1].astype(np.float32)
        scans = data[:, 2:].astype(np.float32)
        return seqs, times, scans

    def _load_det_file(self, seq_name):
        """Load the three annotation files of a sequence.

        Each line is ``<seq>,<json>``: the sequence number, then a JSON payload
        (iterated elsewhere as (r, phi) pairs; to be confirmed).

        Returns:
            (np.array of sequence numbers, wc annotations, wa annotations,
            wp annotations)
        """
        def do_load(f_name):
            seqs, dets = [], []
            with open(f_name) as f:
                for line in f:
                    seq, tail = line.split(',', 1)
                    seqs.append(int(seq))
                    dets.append(json.loads(tail))
            return seqs, dets

        s1, wcs = do_load(seq_name + '.wc')
        s2, was = do_load(seq_name + '.wa')
        s3, wps = do_load(seq_name + '.wp')
        # NOTE(review): zip stops at the shortest list, so a length mismatch is
        # not detected, and asserts are stripped under `python -O`.
        assert all(a == b == c for a, b, c in zip(s1, s2, s3))

        return np.array(s1), wcs, was, wps
210 |
211 |
if __name__ == '__main__':
    # Debug view: plot the newest scan of every sample together with the
    # object centers encoded in its regression targets (classes 1-3).
    import matplotlib.pyplot as plt

    # "fc1d" builds its input with numpy only, so no cutout_kwargs are needed.
    # The default "cutout" type requires a cutout_kwargs dict and fails on None.
    dataset = DROWDataset(data_path='../data/DROWv2-data', network_type="fc1d")

    fig = plt.figure()
    ax = fig.add_subplot(111)

    colors = ['blue', 'green', 'red']
    cls_labels = [1, 2, 3]

    for sample in dataset:
        target_cls, target_reg = sample['target_cls'], sample['target_reg']
        scans = sample['scans']
        scan_phi = u.get_laser_phi()

        # Scans are stacked oldest -> newest, and the targets belong to the
        # newest one. Indexing with -scan_idx would give scans[0] (the oldest)
        # for scan_idx == 0, because -0 == 0.
        latest_scan = scans[-1]
        scan_x, scan_y = u.scan_to_xy(latest_scan)

        plt.cla()
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        ax.scatter(scan_x, scan_y, s=1, c='black')

        for cls_label, c in zip(cls_labels, colors):
            cls_mask = target_cls == cls_label
            canonical_dxy = target_reg[cls_mask]
            dets_r, dets_phi = u.canonical_to_global(
                latest_scan[cls_mask],
                scan_phi[cls_mask],
                canonical_dxy[:, 0],
                canonical_dxy[:, 1])
            dets_x, dets_y = u.rphi_to_xy(dets_r, dets_phi)
            ax.scatter(dets_x, dets_y, s=5, c=c)

        plt.pause(0.1)

    plt.show()
249 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 | import os
5 | import pickle
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | from . import utils as u
10 | from . import prec_rec_utils as pru
11 |
# Use the non-interactive Agg backend so figures can be rendered and saved on
# machines without a display (e.g. a lab cluster server), see
# https://github.com/matplotlib/matplotlib/issues/3466/
plt.switch_backend('agg')
14 |
15 |
def cfg_to_model(cfg):
    """Build the network selected by ``cfg['network']``.

    Args:
        cfg: experiment config. The cutout networks read ``cfg['num_scans']``,
            ``cfg['cutout_kwargs']['num_cutout_pts']``,
            ``cfg['focal_loss_gamma']`` and ``cfg['pedestrian_only']``.
            'cutout_spatial' also reads ``cfg['similarity_kwargs']['alpha']``
            and ``['window_size']``. 'fc1d_fea' reads
            ``cfg['cutout_kwargs']['num_cutout_pts']``.

    Returns:
        The constructed model.

    Raises:
        NotImplementedError: for 'fc2d_fea'.
        RuntimeError: for any other unknown network name.
    """
    network = cfg['network']

    if network == 'cutout':
        from ..model.drow import DROW
        return DROW(num_scans=cfg['num_scans'],
                    num_pts=cfg['cutout_kwargs']['num_cutout_pts'],
                    focal_loss_gamma=cfg['focal_loss_gamma'],
                    pedestrian_only=cfg['pedestrian_only'])

    if network == 'cutout_gating':
        from ..model.drow import TemporalDROW
        return TemporalDROW(num_scans=cfg['num_scans'],
                            num_pts=cfg['cutout_kwargs']['num_cutout_pts'],
                            focal_loss_gamma=cfg['focal_loss_gamma'],
                            pedestrian_only=cfg['pedestrian_only'])

    if network == 'cutout_spatial':
        from ..model.drow import SpatialDROW
        return SpatialDROW(num_scans=cfg['num_scans'],
                           num_pts=cfg['cutout_kwargs']['num_cutout_pts'],
                           focal_loss_gamma=cfg['focal_loss_gamma'],
                           alpha=cfg['similarity_kwargs']['alpha'],
                           window_size=cfg['similarity_kwargs']['window_size'],
                           pedestrian_only=cfg['pedestrian_only'])

    if network == 'fc2d':
        from ..model.polar_drow import PolarDROW
        return PolarDROW(in_channel=1)

    if network == 'fc2d_fea':
        raise NotImplementedError("network 'fc2d_fea' is not supported")

    if network == 'fc1d':
        from ..model.fconv_drow import FConvDROW
        return FConvDROW(in_channel=1)

    if network == 'fc1d_fea':
        from ..model.fconv_drow import FConvDROW
        return FConvDROW(in_channel=cfg['cutout_kwargs']['num_cutout_pts'])

    raise RuntimeError("Unknown network type: %r" % (network,))
61 |
62 |
def model_fn(model, data, rtn_result=False):
    """Run a forward pass and compute the classification and regression losses.

    Args:
        model: network returning ``(pred_cls, pred_reg)`` or
            ``(pred_cls, pred_reg, pred_sim)`` and exposing ``cls_loss``.
        data: batch dict with numpy arrays 'input', 'target_cls'
            (batch, pts) and 'target_reg' (batch, pts, 2).
        rtn_result: also return predictions reshaped to (batch, pts, -1).

    Returns:
        (total_loss, tb_dict, rtn_dict). ``tb_dict`` holds 'cls_loss',
        'fg_ratio' and, when foreground points exist, 'reg_loss'.
    """
    tb_dict, rtn_dict = {}, {}

    # Put the batch on the model's device. A model without parameters falls
    # back to CUDA, the previously hard-coded device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cuda')

    net_input = torch.from_numpy(data['input']).to(device, non_blocking=True).float()

    # Forward pass; a third output (spatial-attention similarity) is unused here.
    model_rtn = model(net_input)
    pred_cls, pred_reg = model_rtn[0], model_rtn[1]

    target_cls = torch.from_numpy(data['target_cls']).to(device, non_blocking=True).long()
    target_reg = torch.from_numpy(data['target_reg']).to(device, non_blocking=True).float()

    n_batch, n_pts = target_cls.shape[:2]

    # cls loss
    target_cls = target_cls.view(n_batch * n_pts)
    pred_cls = pred_cls.view(n_batch * n_pts, -1)
    if pred_cls.shape[1] == 1:
        # single-logit (pedestrian-only) head: the loss expects probabilities
        cls_loss = model.cls_loss(torch.sigmoid(pred_cls.squeeze(-1)),
                                  target_cls.float(),
                                  reduction='mean')
    else:
        cls_loss = model.cls_loss(pred_cls, target_cls, reduction='mean')
    total_loss = cls_loss
    tb_dict['cls_loss'] = cls_loss.item()

    # fraction of foreground (non-zero class) points
    fg_mask = target_cls.ne(0)
    fg_ratio = torch.sum(fg_mask).item() / (n_batch * n_pts)
    tb_dict['fg_ratio'] = fg_ratio

    # reg loss: mean Euclidean error on foreground points only
    if fg_ratio > 0.0:
        target_reg = target_reg.view(n_batch * n_pts, -1)
        pred_reg = pred_reg.view(n_batch * n_pts, -1)
        reg_loss = F.mse_loss(pred_reg[fg_mask], target_reg[fg_mask],
                              reduction='none')
        reg_loss = torch.sqrt(torch.sum(reg_loss, dim=1)).mean()
        total_loss = total_loss + reg_loss
        tb_dict['reg_loss'] = reg_loss.item()

    if rtn_result:
        rtn_dict["pred_reg"] = pred_reg.view(n_batch, n_pts, -1)
        rtn_dict["pred_cls"] = pred_cls.view(n_batch, n_pts, -1)

    return total_loss, tb_dict, rtn_dict
121 |
122 |
def eval_batch(model, data, vote_kwargs, full_eval=True):
    """Evaluate one batch: losses and, optionally, NMS-filtered detections.

    Args:
        model, data: forwarded to ``model_fn``.
        vote_kwargs: unused (kept from the replaced voting-based grouping).
        full_eval: if False, only losses are computed.

    Returns:
        (tb_dict, rtn_dict). For a full evaluation with at least one detection,
        ``rtn_dict`` additionally holds 'dets_xy', 'dets_cls' and 'dets_inds'
        (the index of the sample within the batch).
    """
    _, tb_dict, rtn_dict = model_fn(model, data, rtn_result=full_eval)

    # loss only, no detections
    if not full_eval:
        return tb_dict, rtn_dict

    # logits -> probabilities, moved to numpy
    cls_logits = rtn_dict['pred_cls']
    if cls_logits.shape[-1] == 1:
        cls_prob = torch.sigmoid(cls_logits)
    else:
        cls_prob = F.softmax(cls_logits, dim=-1)
    cls_prob = cls_prob.data.cpu().numpy()
    reg = rtn_dict['pred_reg'].data.cpu().numpy()

    all_xy, all_cls, all_inds = [], [], []
    samples = zip(data['scans'], data['phi_grid'], cls_prob, reg)
    for sample_idx, (scan_stack, phi, prob, offset) in enumerate(samples):
        # NMS on the detections of the last scan in the stack
        dets_xy, dets_cls, _ = u.nms_predicted_center(scan_stack[-1], phi, prob, offset)
        if len(dets_xy) == 0:
            continue
        all_xy.append(dets_xy)
        all_cls.append(dets_cls)
        all_inds.extend([sample_idx] * len(dets_cls))

    if all_xy:
        rtn_dict.update({'dets_xy': np.concatenate(all_xy, axis=0),
                         'dets_cls': np.concatenate(all_cls, axis=0),
                         'dets_inds': np.array(all_inds, dtype=np.int32)})

    return tb_dict, rtn_dict
158 |
159 |
def eval_epoch(model, test_loader, vote_kwargs, full_eval=True):
    """Evaluate ``model`` on every batch of ``test_loader``.

    Args:
        model: network accepted by ``eval_batch``; switched to eval mode.
        test_loader: iterable of collated batch dicts.
        vote_kwargs: forwarded to ``eval_batch``.
        full_eval: if False, only the losses are averaged.

    Returns:
        (tb_dict, rpt_dict, fwd_dict):
        - tb_dict: the mean of each logged value, plus AP/F1/EER entries for a
          full evaluation.
        - rpt_dict: per distance threshold and annotation type, the output of
          ``compute_prec_rec``.
        - fwd_dict: the raw detections and ground truth.
        The last two are None when ``full_eval`` is False.
    """
    model.eval()

    # hold all detections
    dets_xy_list, dets_cls_list, dets_inds_list = [], [], []

    # hold all ground truth, per annotation type and merged ('all')
    gts_xy, gts_inds = {}, {}
    gts_xy['wc'], gts_xy['wa'], gts_xy['wp'], gts_xy['all'] = [], [], [], []
    gts_inds['wc'], gts_inds['wa'], gts_inds['wp'], gts_inds['all'] = [], [], [], []

    # hold all items for tb logging
    tb_dict = {}

    # Global index of the first sample of the current batch, accumulated from
    # the actual batch sizes. `it * n_batch` would collide with earlier indices
    # when the last batch is smaller. Detections and ground truth of one sample
    # are tagged with the same index.
    sample_offset = 0

    # evaluation only: no autograd graph needed
    with torch.no_grad():
        for data in tqdm.tqdm(test_loader, desc='eval'):
            n_batch = len(data['scans'])
            it_global = sample_offset
            sample_offset += n_batch

            # inference
            batch_tb_dict, batch_rtn_dict = eval_batch(model, data, vote_kwargs, full_eval)

            # store tb log
            for k, v in batch_tb_dict.items():
                tb_dict.setdefault(k, []).append(v)

            if not full_eval:
                continue

            # store detection
            if 'dets_xy' in batch_rtn_dict:
                dets_xy_list.append(batch_rtn_dict['dets_xy'])
                dets_cls_list.append(batch_rtn_dict['dets_cls'])
                dets_inds_list.append(batch_rtn_dict['dets_inds'] + it_global)

            # store gt
            for k in ['wc', 'wa', 'wp']:
                for j, j_gts in enumerate(data['dets_' + k]):
                    for r, phi in j_gts:
                        xy = u.rphi_to_xy(r, phi)
                        gts_xy[k].append(xy)
                        gts_xy['all'].append(xy)
                        gts_inds[k].append(j + it_global)
                        gts_inds['all'].append(j + it_global)

    # average the logged values
    for k, v in tb_dict.items():
        tb_dict[k] = np.array(v).mean()

    # only log training loss
    if not full_eval:
        return tb_dict, None, None

    # dets for the whole epoch
    dets_xy = np.concatenate(dets_xy_list, axis=0)  # (N, 2)
    dets_cls = np.concatenate(dets_cls_list, axis=0)  # (N, cls)
    dets_inds = np.concatenate(dets_inds_list)  # (N)

    # gts for the whole epoch
    for k, v in gts_xy.items():
        gts_xy[k] = np.array(v)
        gts_inds[k] = np.array(gts_inds[k], dtype=np.int32)

    # evaluation
    rpt_dict = {}
    dist_thresh = [0.3, 0.5, 0.7]
    for dt in dist_thresh:
        rpt_dict[dt] = {}

    # pedestrian only
    if dets_cls.shape[1] == 1:
        for dt in dist_thresh:
            rpt_dict[dt]['wp'] = compute_prec_rec(dets_xy, dets_cls[:, 0], dets_inds,
                                                  gts_xy['wp'], gts_inds['wp'], dt)
            ap, f1, eer = eval_prec_rec(*rpt_dict[dt]['wp'][:2])

            tb_dict["ap_wp_t%s" % dt] = ap
            tb_dict["f1_wp_t%s" % dt] = f1
            tb_dict["eer_wp_t%s" % dt] = eer

    # multi-class: column 0 is background, columns 1-3 are wc/wa/wp
    else:
        for dt in dist_thresh:
            for k in gts_xy.keys():
                if k == 'wc':
                    d_cls = dets_cls[:, 1]
                elif k == 'wa':
                    d_cls = dets_cls[:, 2]
                elif k == 'wp':
                    d_cls = dets_cls[:, 3]
                elif k == 'all':
                    d_cls = np.sum(dets_cls[:, 1:], axis=1)
                else:
                    raise RuntimeError

                rpt_dict[dt][k] = compute_prec_rec(dets_xy, d_cls, dets_inds,
                                                   gts_xy[k], gts_inds[k], dt)
                ap, f1, eer = eval_prec_rec(*rpt_dict[dt][k][:2])

                tb_dict["ap_%s_t%s" % (k, dt)] = ap
                tb_dict["f1_%s_t%s" % (k, dt)] = f1
                tb_dict["eer_%s_t%s" % (k, dt)] = eer

    # also return network inference results
    fwd_dict = {}
    fwd_dict['dets'] = dets_xy
    fwd_dict['dets_inds'] = dets_inds
    fwd_dict['dets_cls'] = dets_cls
    fwd_dict['gts'] = gts_xy
    fwd_dict['gts_inds'] = gts_inds

    return tb_dict, rpt_dict, fwd_dict
267 |
268 |
def compute_prec_rec(dets, dets_cls, dets_inds, gts, gts_inds, dt):
    """Precision/recall of scored detections against ground truth.

    Adapter over ``pru.prec_rec_2d``: the scalar distance threshold ``dt`` is
    expanded to one float32 threshold per ground-truth entry.
    """
    per_gt_thresh = np.ones(len(gts_inds), dtype=np.float32) * dt
    return pru.prec_rec_2d(dets_cls, dets, dets_inds, gts, gts_inds, per_gt_thresh)
272 |
273 |
def eval_prec_rec(rec, prec):
    """Summarize a precision-recall curve with ``pru.eval_prec_rec``.

    Callers in this module unpack the result as ``(ap, f1, eer)``.
    """
    summary = pru.eval_prec_rec(rec, prec)
    return summary
276 |
277 |
def plot_prec_rec(rpt_dict, plot_title=None, output_file=None):
    """Plot the precision-recall curves of one distance threshold.

    Args:
        rpt_dict: per-annotation-type reports. Without an 'all' entry it is
            treated as pedestrian-only and only 'wp' is plotted.
        plot_title: optional figure title.
        output_file: if given, the current figure is saved there.

    Returns:
        (fig, ax) from the ``pru`` plotting helper.
    """
    if 'all' in rpt_dict:
        fig, ax = pru.plot_prec_rec(wds=rpt_dict['all'],
                                    wcs=rpt_dict['wc'],
                                    was=rpt_dict['wa'],
                                    wps=rpt_dict['wp'],
                                    title=plot_title)
    else:
        fig, ax = pru.plot_prec_rec_wps_only(wps=rpt_dict['wp'],
                                             title=plot_title)

    if output_file is not None:
        plt.savefig(output_file, bbox_inches='tight')

    return fig, ax
294 |
295 |
def eval_epoch_with_output(model, test_loader, epoch, it, root_result_dir, split, tag,
                           vote_kwargs, full_eval=True, writing=False, plotting=False,
                           save_pkl=False, tb_log=None):
    """Evaluate one epoch and export the results.

    Args:
        model, test_loader, vote_kwargs, full_eval: forwarded to ``eval_epoch``.
        epoch, it: epoch and iteration recorded with the results.
        root_result_dir: output root for 'results', 'pkl' and 'figs'.
        split, tag: labels used in file names and log entries.
        writing: append every tb_dict entry to ``results/<split>.csv`` and,
            if ``tb_log`` is given, log it as a scalar.
        plotting: save one precision-recall figure per distance threshold
            (also sent to ``tb_log`` as an image, if given).
        save_pkl: pickle the report and inference dicts.
        tb_log: optional TensorBoard-style writer; may be None.
    """
    tb_dict, rpt_dict, fwd_dict = eval_epoch(
        model, test_loader, vote_kwargs=vote_kwargs,
        full_eval=full_eval)

    if writing:
        ap_dir = os.path.join(root_result_dir, 'results')
        os.makedirs(ap_dir, exist_ok=True)
        ap_file = os.path.join(ap_dir, '%s.csv' % split)
        stag = ("eval_%s" % split) if tag.startswith('eval_') else split
        # open the results file once rather than once per metric
        with open(ap_file, "a") as f:
            for k, v in tb_dict.items():
                f.write("%s, %s, %s, %s, %s, %s\n" % (tag, it, epoch, split, k, v))
                if tb_log is not None:
                    tb_log.add_scalar("%s_%s" % (stag, k), v, it)

    if not full_eval:
        # only losses were computed; nothing to pickle or plot
        if tb_log is not None:
            tb_log.flush()
        return

    if save_pkl:
        pkl_dir = os.path.join(root_result_dir, 'pkl')
        os.makedirs(pkl_dir, exist_ok=True)

        s = '%s_e%s_%s.pkl' % (tag, epoch, split)
        with open(os.path.join(pkl_dir, 'rpt_' + s), "wb") as f:
            pickle.dump(rpt_dict, f)
        with open(os.path.join(pkl_dir, 'fwd_' + s), "wb") as f:
            pickle.dump(fwd_dict, f)

    if plotting:
        for k, v in rpt_dict.items():
            fig_dir = os.path.join(root_result_dir, 'figs', split, 't_%s' % k)
            os.makedirs(fig_dir, exist_ok=True)
            plot_file = '%s_e%s_%s_t%s.png' % (tag, epoch, split, k)
            fig, _ = plot_prec_rec(v, output_file=os.path.join(fig_dir, plot_file))

            if tb_log is not None:
                fig.canvas.draw()
                # The Agg canvas exposes an (H, W, 4) RGBA buffer. This replaces
                # tostring_rgb (removed in matplotlib 3.10) and np.fromstring
                # (deprecated).
                im = np.asarray(fig.canvas.buffer_rgba())[:, :, :3]
                im = im.transpose(2, 0, 1)  # (3, H, W)
                im = im.astype(np.float32) / 255.0
                tb_log.add_image("pr_curve_t%s" % k, im, it)

            plt.close(fig)

    if tb_log is not None:
        tb_log.flush()
349 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from tensorboardX import SummaryWriter
4 |
5 |
def create_logger(root_dir, file_name='log.txt'):
    """Configure file + console logging and return this module's logger.

    The log file ``<root_dir>/<file_name>`` is configured on the root logger via
    ``logging.basicConfig``, which has no effect if the root logger already has
    handlers. A DEBUG-level console handler is attached to the returned logger;
    repeated calls reuse it instead of adding another one (which would print
    every message several times).

    Args:
        root_dir: directory for the log file.
        file_name: log file name inside `root_dir`.

    Returns:
        ``logging.getLogger(__name__)``.
    """
    log_file = os.path.join(root_dir, file_name)
    log_format = '%(asctime)s %(levelname)5s %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file)

    logger = logging.getLogger(__name__)
    # FileHandler subclasses StreamHandler, so compare exact types to find a console handler.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(logging.Formatter(log_format))
        logger.addHandler(console)
    return logger
15 |
16 |
def create_tb_logger(root_dir, tb_log_dir='tensorboard'):
    """Return a tensorboardX ``SummaryWriter`` writing to ``<root_dir>/<tb_log_dir>``."""
    log_dir = os.path.join(root_dir, tb_log_dir)
    writer = SummaryWriter(log_dir=log_dir)
    return writer
19 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/prec_rec_utils.py:
--------------------------------------------------------------------------------
1 | # Most of the code here comes from
2 | # https://github.com/VisualComputingInstitute/DROW/blob/master/v2/utils/__init__.py
3 | from collections import defaultdict
4 | import numpy as np
5 | import matplotlib as mpl
6 | import matplotlib.pyplot as plt
7 | from scipy.optimize import linear_sum_assignment
8 | from scipy.spatial.distance import cdist
9 | from sklearn.metrics import auc
10 |
# For plotting using lab cluster server
# https://github.com/matplotlib/matplotlib/issues/3466/
# Force the non-interactive Agg backend at import time so figures can be
# rendered and saved without a display. This is a pyplot-global setting, so
# importing this module switches the backend for the whole process.
plt.switch_backend('agg')
14 |
15 |
def prec_rec_2d(det_scores, det_coords, det_frames, gt_coords, gt_frames, gt_radii):
    """ Computes full precision-recall curves at all possible thresholds.

    Arguments:
    - `det_scores` (D,) array containing the scores of the D detections.
    - `det_coords` (D,2) array containing the (x,y) coordinates of the D detections.
    - `det_frames` (D,) array containing the frame number of each of the D detections.
    - `gt_coords` (L,2) array containing the (x,y) coordinates of the L labels (ground-truth detections).
    - `gt_frames` (L,) array containing the frame number of each of the L labels.
    - `gt_radii` (L,) array containing the radius at which each of the L labels should consider detection associations.
                This will typically just be an np.full_like(gt_frames, 0.5) or similar,
                but could vary when mixing classes, for example.

    Returns: (recs, precs, threshs)
    - `threshs`: (D,) array of sorted thresholds (scores), from higher to lower.
    - `recs`: (D,) array of recall scores corresponding to the thresholds.
    - `precs`: (D,) array of precision scores corresponding to the thresholds.

    Detections are accepted one at a time in order of decreasing score. After
    each acceptance, the accepted detections of that frame are matched one-to-one
    to the frame's labels (Hungarian assignment); a match is a true positive only
    if the detection lies within the label's radius. Recall entries are NaN when
    there are no labels at all.
    """
    # This means that all reported detection frames which are not in ground-truth frames
    # will be counted as false-positives.
    # TODO: do some sanity-checks in the "linearization" functions before calling `prec_rec_2d`.
    frames = np.unique(np.r_[det_frames, gt_frames])

    # Group the labels by frame once, up front. The stable sort keeps labels of
    # one frame in their original order, and each per-frame lookup below becomes
    # two binary searches instead of a scan over all L labels per detection.
    gt_order = np.argsort(gt_frames, kind='mergesort')
    gt_frames_sorted = gt_frames[gt_order]

    det_accepted_idxs = defaultdict(list)
    tps = np.zeros(len(frames), dtype=np.uint32)
    fps = np.zeros(len(frames), dtype=np.uint32)
    # Before any detection is accepted, every label is a false negative.
    fns = (np.searchsorted(gt_frames_sorted, frames, side='right')
           - np.searchsorted(gt_frames_sorted, frames, side='left')).astype(np.uint32)

    precs = np.full_like(det_scores, np.nan)
    recs = np.full_like(det_scores, np.nan)
    threshs = np.full_like(det_scores, np.nan)

    indices = np.argsort(det_scores, kind='mergesort')  # mergesort for determinism.
    for i, idx in enumerate(reversed(indices)):
        frame = det_frames[idx]
        # `frames` is sorted, unique and contains `frame`, so this is its position.
        iframe = np.searchsorted(frames, frame)

        # Accept this detection
        dets_idxs = det_accepted_idxs[frame]
        dets_idxs.append(idx)
        threshs[i] = det_scores[idx]

        dets = det_coords[dets_idxs]

        lo = np.searchsorted(gt_frames_sorted, frame, side='left')
        hi = np.searchsorted(gt_frames_sorted, frame, side='right')
        gt_sel = gt_order[lo:hi]
        gts = gt_coords[gt_sel]
        radii = gt_radii[gt_sel]

        if len(gts) == 0:  # No GT, but there is a detection.
            fps[iframe] += 1
        else:  # There is GT and detection in this frame.
            # ngts x ndets, True (=1) if too far, False (=0) if it may match.
            not_in_radius = radii[:, None] < cdist(gts, dets)
            igt, idet = linear_sum_assignment(not_in_radius)

            tps[iframe] = np.sum(np.logical_not(not_in_radius[igt, idet]))  # Could match within radius
            fps[iframe] = len(dets) - tps[iframe]  # NB: dets is only the so-far accepted.
            fns[iframe] = len(gts) - tps[iframe]

        tp, fp, fn = np.sum(tps), np.sum(fps), np.sum(fns)
        precs[i] = tp/(fp+tp) if fp+tp > 0 else np.nan
        recs[i] = tp/(fn+tp) if fn+tp > 0 else np.nan

    return recs, precs, threshs
79 |
80 |
def eval_prec_rec(rec, prec):
    """Return ``(AUC, peak F1, EER)`` of a precision-recall curve.

    `rec` is the x-axis of the AUC and must be non-decreasing; this is checked
    with an assertion before any metric is computed.
    """
    is_sorted = np.count_nonzero(np.diff(rec) >= 0) == len(rec) - 1
    assert is_sorted
    area = auc(rec, prec)
    return area, peakf1(rec, prec), eer(rec, prec)
86 |
87 |
def peakf1(recs, precs):
    """Return the highest F1 score along a precision-recall curve.

    The denominator is clipped away from zero, so points with
    precision == recall == 0 contribute an F1 of 0 instead of NaN.
    """
    denom = np.clip(precs + recs, 1e-16, 2 + 1e-16)
    f1 = 2 * precs * recs / denom
    return np.max(f1)
90 |
91 |
def eer(recs, precs):
    """Return the equal error rate (EER) of a precision-recall curve.

    The EER is taken at the point where precision and recall are closest to
    each other; if they differ there, their average is returned.

    Leading points where precision or recall is still zero are skipped,
    otherwise the trivial (0, 0) point would always be the EER. Both curves are
    compared at the same indices (i.e. at the same thresholds). If either curve
    is zero everywhere, 0.0 is returned.
    """
    recs = np.asarray(recs)
    precs = np.asarray(precs)

    p_nonzero = np.flatnonzero(precs != 0)
    r_nonzero = np.flatnonzero(recs != 0)
    if len(p_nonzero) == 0 or len(r_nonzero) == 0:
        return 0.0

    # Use one common start so precs and recs stay index-aligned (comparing
    # precs[p1:] with recs[r1:] for p1 != r1 pairs different thresholds or
    # fails to broadcast).
    start = max(p_nonzero[0], r_nonzero[0])
    idx = np.argmin(np.abs(precs[start:] - recs[start:]))
    return (precs[start + idx] + recs[start + idx]) / 2  # They are often the exact same, but if not, use average.
101 |
102 |
def plot_prec_rec(wds, wcs, was, wps, figsize=(15,10), title=None):
    """Plot the agnostic and the three per-class PR curves in one figure.

    Each curve argument is indexable with ``[0]`` = recall, ``[1]`` = precision.
    Every recall array is asserted to be non-decreasing before anything is
    plotted; each legend entry shows AUC, peak F1 and EER.

    Returns:
        (fig, ax)
    """
    fig, ax = plt.subplots(figsize=figsize)

    curves = (
        ('agn (AUC: {:.1%}, F1: {:.1%}, EER: {:.1%})', wds, '#E24A33'),
        ('wcs (AUC: {:.1%}, F1: {:.1%}, EER: {:.1%})', wcs, '#348ABD'),
        ('was (AUC: {:.1%}, F1: {:.1%}, EER: {:.1%})', was, '#988ED5'),
        ('wps (AUC: {:.1%}, F1: {:.1%}, EER: {:.1%})', wps, '#8EBA42'),
    )

    # make sure the x-input to auc is sorted
    for _, curve, _ in curves:
        recall = curve[0]
        assert np.sum(np.diff(recall)>=0) == len(recall) - 1

    for template, curve, color in curves:
        rec_prec = curve[:2]
        label = template.format(auc(*rec_prec), peakf1(*rec_prec), eer(*rec_prec))
        ax.plot(*rec_prec, label=label, c=color)

    if title is not None:
        fig.suptitle(title, fontsize=16, y=0.91)

    _prettify_pr_curve(ax)
    _lbplt_fatlegend(ax, loc='upper right')

    return fig, ax
124 |
125 |
def plot_prec_rec_wps_only(wps, figsize=(15,10), title=None):
    """Plot only the pedestrian ('wps') PR curve.

    `wps[0]` is recall (asserted non-decreasing) and `wps[1]` precision; the
    legend entry shows AUC, peak F1 and EER.

    Returns:
        (fig, ax)
    """
    fig, ax = plt.subplots(figsize=figsize)

    recall = wps[0]
    # make sure the x-input to auc is sorted
    assert np.sum(np.diff(recall)>=0) == len(recall) - 1

    curve = wps[:2]
    label = 'wps (AUC: {:.1%}, F1: {:.1%}, EER: {:.1%})'.format(auc(*curve), peakf1(*curve), eer(*curve))
    ax.plot(*curve, label=label, c='#8EBA42')

    if title is not None:
        fig.suptitle(title, fontsize=16, y=0.91)

    _prettify_pr_curve(ax)
    _lbplt_fatlegend(ax, loc='upper right')
    return fig, ax
140 |
141 |
def _prettify_pr_curve(ax):
    """Style a PR-curve axes in place and return it.

    Draws the grey dashed diagonal, fixes both limits to [-0.02, 1.02], labels
    the axes, and formats the tick values of both axes as whole percentages.
    """
    def _as_percent(x, pos):
        return '{:.0f}'.format(x*100)

    ax.plot([0,1], [0,1], ls="--", c=".6")
    ax.set_xlim(-0.02,1.02)
    ax.set_ylim(-0.02,1.02)
    ax.set_xlabel("Recall [%]")
    ax.set_ylabel("Precision [%]")
    # One formatter instance per axis: a formatter is bound to the axis it is set on.
    for axis in (ax.axes.xaxis, ax.axes.yaxis):
        axis.set_major_formatter(mpl.ticker.FuncFormatter(_as_percent))
    return ax
151 |
152 |
def _lbplt_fatlegend(ax=None, *args, **kwargs):
    """Draw a legend whose line handles are twice as thick and fully opaque.

    Args:
        ax: axes to draw the legend on; if None, ``plt.legend`` is used.
        *args, **kwargs: forwarded to ``legend``.

    Returns:
        The created ``Legend``.
    """
    # Copy paste from lbtoolbox.plotting.fatlegend
    if ax is not None:
        leg = ax.legend(*args, **kwargs)
    else:
        leg = plt.legend(*args, **kwargs)

    # Matplotlib 3.7 renamed `Legend.legendHandles` to `legend_handles` and later
    # removed the old name; prefer the new attribute and fall back for old versions.
    handles = getattr(leg, 'legend_handles', None)
    if handles is None:
        handles = leg.legendHandles

    for handle in handles:
        handle.set_linewidth(handle.get_linewidth()*2.0)
        handle.set_alpha(1)
    return leg
164 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/pytorch_nms/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | * Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 |
10 | * Redistributions in binary form must reproduce the above copyright notice,
11 | this list of conditions and the following disclaimer in the documentation
12 | and/or other materials provided with the distribution.
13 |
14 | * Neither the name of the copyright holder nor the names of its
15 | contributors may be used to endorse or promote products derived from
16 | this software without specific prior written permission.
17 |
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | ************************************************************************
30 |
31 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
32 |
33 | This project incorporates material from the project(s)
34 | listed below (collectively, "Third Party Code"). This Third Party Code is
35 | licensed to you under their original license terms set forth below.
36 |
37 | 1. Faster R-CNN, (https://github.com/rbgirshick/py-faster-rcnn)
38 |
39 | The MIT License (MIT)
40 |
41 | Copyright (c) 2015 Microsoft Corporation
42 |
43 | Permission is hereby granted, free of charge, to any person obtaining a copy
44 | of this software and associated documentation files (the "Software"), to deal
45 | in the Software without restriction, including without limitation the rights
46 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
47 | copies of the Software, and to permit persons to whom the Software is
48 | furnished to do so, subject to the following conditions:
49 |
50 | The above copyright notice and this permission notice shall be included in
51 | all copies or substantial portions of the Software.
52 |
53 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
54 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
55 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
56 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
57 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
58 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
59 | THE SOFTWARE.
60 |
61 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/pytorch_nms/README.md:
--------------------------------------------------------------------------------
1 | # Torchvision support for NMS
2 |
3 | Note: Since the publication of this repository, NMS support has been included as part of torchvision. Therefore you might want to use this implementation instead:
4 | https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
5 |
6 | This repository might still be of interest if you need the index in the `keep` list of the highest-scoring box overlapping each input box.
7 |
8 | # CUDA implementation of NMS for PyTorch.
9 |
10 |
11 | This repository has a CUDA implementation of NMS for PyTorch 1.4.0.
12 |
13 | The code is released under the BSD license however it also includes parts of the original implementation from [Fast R-CNN](https://github.com/rbgirshick/py-faster-rcnn) which falls under the MIT license (see LICENSE file for details).
14 |
15 | The code is experimental and has not been thoroughly tested yet; use at your own risk. Any issues and pull requests are welcome.
16 |
17 | ## Installation
18 |
19 | ```
20 | python setup.py install
21 | ```
22 |
23 | ## Usage
24 |
25 | Example:
26 | ```
27 | from nms import nms
28 |
29 | keep, num_to_keep, parent_object_index = nms(boxes, scores, overlap=.5, top_k=200)
30 | ```
31 |
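32 | The `nms` function takes a (N,4) tensor of `boxes` and associated (N) tensor of `scores`, sorts the bounding boxes by score and selects boxes using Non-Maximum Suppression according to the given `overlap`. It returns the indices of the `top_k` with the highest score. Bounding boxes are represented using the standard (left,top,right,bottom) coordinates representation. Note: the copy in this directory has been modified for 2D points: the kernel reads `boxes` as (x,y) pairs, i.e. an (N,2) tensor, and suppresses a point when its Euclidean distance to a higher-scoring kept point is below `overlap`.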
33 |
34 | `keep` is the list of indices of kept bounding boxes. Note that the tensor size is always (N) however only the first `num_to_keep` entries are valid.
35 |
36 | For each input box, the (N) tensor `parent_object_index` contains the index (1-based) in the `keep` list of the highest-scoring box overlapping this box. This can be useful to group input boxes that are related to the same object. The index 0 represents a background box which has been ignored due to `top_k`.
37 |
38 | Currently there is a hard-limit of 64,000 input boxes. You can change the constant `MAX_COL_BLOCKS` in `nms_kernel.cu` to increase this limit.
39 |
40 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/pytorch_nms/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension
3 |
# Build the `nms` Python package from src/; the C++/CUDA sources are compiled
# into the `nms.details` extension module that `nms/__init__.py` imports.
nms_extension = CUDAExtension(
    'nms.details',
    ['src/nms.cpp', 'src/nms_kernel.cu'],
    extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']},
)

setup(
    name='nms',
    packages=['nms'],
    package_dir={'': 'src'},
    ext_modules=[nms_extension],
    cmdclass={'build_ext': BuildExtension},
)
13 |
14 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/pytorch_nms/src/nms.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University
2 | * All rights reserved.
3 | *
4 | * Redistribution and use in source and binary forms, with or without
5 | * modification, are permitted provided that the following conditions are met:
6 | *
7 | * * Redistributions of source code must retain the above copyright notice, this
8 | * list of conditions and the following disclaimer.
9 | *
10 | * * Redistributions in binary form must reproduce the above copyright notice,
11 | * this list of conditions and the following disclaimer in the documentation
12 | * and/or other materials provided with the distribution.
13 | *
14 | * * Neither the name of the copyright holder nor the names of its
15 | * contributors may be used to endorse or promote products derived from
16 | * this software without specific prior written permission.
17 | *
18 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | */
29 |
30 | #include
31 | #include
32 | #include
33 |
34 | std::vector nms_cuda_forward(
35 | at::Tensor boxes,
36 | at::Tensor idx,
37 | float nms_overlap_thresh,
38 | unsigned long top_k);
39 |
40 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
41 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
42 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
43 |
44 | std::vector nms_forward(
45 | at::Tensor boxes,
46 | at::Tensor scores,
47 | float thresh,
48 | unsigned long top_k) {
49 |
50 |
51 | auto idx = std::get<1>(scores.sort(0,true));
52 |
53 | CHECK_INPUT(boxes);
54 | CHECK_INPUT(idx);
55 |
56 | return nms_cuda_forward(boxes, idx, thresh, top_k);
57 | }
58 |
// Python binding. setup.py builds this file as the `nms.details` extension,
// so `nms_forward` is reachable as `nms.details.nms_forward` (wrapped by `nms.nms`).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("nms_forward", &nms_forward, "NMS");
}
62 |
63 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/pytorch_nms/src/nms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright notice, this
8 | # list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of the copyright holder nor the names of its
15 | # contributors may be used to endorse or promote products derived from
16 | # this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | from . import details
30 |
def nms(boxes, scores, overlap, top_k):
    """Run non-maximum suppression through the compiled `details` extension.

    Thin wrapper around ``details.nms_forward``; returns whatever the extension
    returns (``keep, num_to_keep, parent_object_index``).
    """
    forward = details.nms_forward
    return forward(boxes, scores, overlap, top_k)
33 |
34 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/pytorch_nms/src/nms_kernel.cu:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University
2 | * All rights reserved.
3 | *
4 | * Redistribution and use in source and binary forms, with or without
5 | * modification, are permitted provided that the following conditions are met:
6 | *
7 | * * Redistributions of source code must retain the above copyright notice, this
8 | * list of conditions and the following disclaimer.
9 | *
10 | * * Redistributions in binary form must reproduce the above copyright notice,
11 | * this list of conditions and the following disclaimer in the documentation
12 | * and/or other materials provided with the distribution.
13 | *
14 | * * Neither the name of the copyright holder nor the names of its
15 | * contributors may be used to endorse or promote products derived from
16 | * this software without specific prior written permission.
17 | *
18 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | */
29 | #include
30 | #include
31 | #include
32 |
33 | #include
34 | #include
35 | #include
36 | #include
37 |
// From https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
// Wraps a CUDA runtime call and reports a failure together with the call site.
#define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
// If `code` is not cudaSuccess, print the error string and file/line to stderr.
// With abort=true (the default, and what gpuErrchk uses) this calls exit(code),
// terminating the whole host process instead of returning an error to the caller.
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
{
    if (code != cudaSuccess)
    {
        fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
        if (abort) exit(code);
    }
}
48 |
// Debug helper: prints, for i in [0, n_boxes), the first two values of boxes[i]
// and inds[i][0]. It does not use threadIdx, so every launched thread runs the
// whole loop; its only call site (in nms_cuda_forward) is commented out and
// launches it with a single thread.
// NOTE(review): the accessors' template arguments appear to be missing in this
// copy. If `inds` holds int64 values (as the sort indices passed at the call
// site do), `%i` is the wrong printf specifier (`%lld` would match) — confirm.
__global__ void printTensorKernel(
    torch::PackedTensorAccessor64 boxes,
    torch::PackedTensorAccessor64 inds,
    const int n_boxes)
{
    for (int i = 0; i < n_boxes; ++i)
    {
        printf("idx: %d, x: %f, y: %f, sort: %i\n",
               i, boxes[i][0], boxes[i][1], inds[i][0]);
    }
}
60 |
// Hard-coded maximum. Increase if needed.
// Sizes the per-thread `remv` array in nms_collect. nms_cuda_forward asserts
// col_blocks < MAX_COL_BLOCKS, so at most (MAX_COL_BLOCKS - 1) * threadsPerBlock
// boxes are supported.
#define MAX_COL_BLOCKS 1000

// Integer ceiling division m / n.
#define DIVUP(m,n) (((m)+(n)-1) / (n))
// One thread (and one mask bit) per bit of an unsigned long long mask word.
int64_t const threadsPerBlock = sizeof(unsigned long long) * 8;
66 |
67 | // The functions below originates from Fast R-CNN
68 | // See https://github.com/rbgirshick/py-faster-rcnn
69 | // Copyright (c) 2015 Microsoft
70 | // Licensed under The MIT License
71 | // Written by Shaoqing Ren
72 |
// NOTE: despite its name (kept from Fast R-CNN), this does not compute an IoU.
// It returns the Euclidean distance between the 2-D points (a[0], a[1]) and
// (b[0], b[1]); the original box-IoU computation is left commented out below.
// NOTE(review): the template parameter list appears to have been lost in this
// copy; presumably `template <typename scalar_t>` — confirm against upstream.
template
__device__ inline scalar_t devIoU(scalar_t const * const a, scalar_t const * const b) {
    // scalar_t left = max(a[0], b[0]), right = min(a[2], b[2]);
    // scalar_t top = max(a[1], b[1]), bottom = min(a[3], b[3]);
    // scalar_t width = max(right - left, 0.f), height = max(bottom - top, 0.f);
    // scalar_t interS = width * height;
    // scalar_t Sa = (a[2] - a[0]) * (a[3] - a[1]);
    // scalar_t Sb = (b[2] - b[0]) * (b[3] - b[1]);
    // return interS / (Sa + Sb - interS);
    scalar_t x_diff = a[0] - b[0];
    scalar_t y_diff = a[1] - b[1];
    return sqrt(x_diff * x_diff + y_diff * y_diff);
}
86 |
// Builds the pairwise suppression bitmask.
// Launch layout (see nms_cuda_forward): a 2-D grid where blockIdx.y selects a
// block of threadsPerBlock "row" positions and blockIdx.x a block of "column"
// positions, both in the score-sorted order given by `idx`; one thread per row.
// Each thread writes one word to dev_mask[row_position * col_blocks + col_start]
// whose bit i is set when column position col_start * threadsPerBlock + i lies
// closer than nms_overlap_thresh to the thread's point. On diagonal blocks only
// later positions (i > threadIdx.x) are tested.
// NOTE(review): the template parameter list appears to have been lost in this
// copy; presumably `template <typename scalar_t>` — confirm against upstream.
template
__global__ void nms_kernel(const int64_t n_boxes, const scalar_t nms_overlap_thresh,
                           const scalar_t *dev_boxes, const int64_t *idx, int64_t *dev_mask) {
    const int64_t row_start = blockIdx.y;
    const int64_t col_start = blockIdx.x;

    // Number of valid rows / columns in this block (the last block may be partial).
    const int row_size =
        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
    const int col_size =
        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

    // (Original 4-coordinate box version, kept for reference.)
    // __shared__ scalar_t block_boxes[threadsPerBlock * 4];
    // if (threadIdx.x < col_size) {
    //   block_boxes[threadIdx.x * 4 + 0] =
    //       dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 0];
    //   block_boxes[threadIdx.x * 4 + 1] =
    //       dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 1];
    //   block_boxes[threadIdx.x * 4 + 2] =
    //       dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 2];
    //   block_boxes[threadIdx.x * 4 + 3] =
    //       dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 3];
    // }
    // Stage this block's column points (two values each) in shared memory.
    __shared__ scalar_t block_boxes[threadsPerBlock * 2];
    if (threadIdx.x < col_size) {
        block_boxes[threadIdx.x * 2 + 0] =
            dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 2 + 0];
        block_boxes[threadIdx.x * 2 + 1] =
            dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 2 + 1];
    }
    // Make the staged points visible to every thread of the block.
    __syncthreads();

    if (threadIdx.x < row_size) {
        const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
        const scalar_t *cur_box = dev_boxes + idx[cur_box_idx] * 2;
        // const scalar_t *cur_box = dev_boxes + idx[cur_box_idx] * 4;
        int i = 0;
        unsigned long long t = 0;
        int start = 0;
        // On the diagonal, only compare against later positions in score order.
        if (row_start == col_start) {
            start = threadIdx.x + 1;
        }
        for (i = start; i < col_size; i++) {
            // if (devIoU(cur_box, block_boxes + i * 4) > nms_overlap_thresh) {
            // devIoU returns a distance here, so "overlapping" means "closer than the threshold".
            if (devIoU(cur_box, block_boxes + i * 2) < nms_overlap_thresh) {
                t |= 1ULL << i;
            }
        }
        const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
        // Row-major mask: col_blocks words per row position.
        dev_mask[cur_box_idx * col_blocks + col_start] = t;
    }
}
138 |
139 |
// Serial reduction of the suppression bitmask built by nms_kernel (launched as
// a single thread by nms_cuda_forward).
//
// Walks positions in score order (`idx` maps position -> original box index),
// keeps every position not yet suppressed by an earlier kept one, and stops
// after `top_k` kept boxes.
// Outputs:
//   keep[0..num_to_keep)      original indices of kept boxes (rest set to 0);
//   parent_object_index[b]    1-based rank in `keep` of the highest-scoring
//                             kept box overlapping box b (itself if kept),
//                             0 if none was assigned (e.g. ignored by top_k);
//   *num_to_keep              number of kept boxes.
__global__ void nms_collect(const int64_t boxes_num, const int64_t col_blocks, int64_t top_k, const int64_t *idx, const int64_t *mask, int64_t *keep, int64_t *parent_object_index, int64_t *num_to_keep) {
    // One bit per position, set once that position has been suppressed.
    int64_t remv[MAX_COL_BLOCKS];
    int64_t num_to_keep_ = 0;

    for (int i = 0; i < col_blocks; i++) {
        remv[i] = 0;
    }

    for (int i = 0; i < boxes_num; ++i) {
        parent_object_index[i] = 0;
    }

    for (int i = 0; i < boxes_num; i++) {
        int nblock = i / threadsPerBlock;
        int inblock = i % threadsPerBlock;

        if (!(remv[nblock] & (1ULL << inblock))) {
            int64_t idxi = idx[i];
            keep[num_to_keep_] = idxi;
            const int64_t *p = &mask[0] + i * col_blocks;
            for (int j = nblock; j < col_blocks; j++) {
                remv[j] |= p[j];
            }
            // Assign this kept box as parent only to boxes that have no parent
            // yet. Kept boxes are visited in descending score order, so the first
            // assignment is the highest-scoring kept box overlapping each box;
            // overwriting would hand the box to a lower-scoring kept box instead.
            for (int j = i; j < boxes_num; j++) {
                int nblockj = j / threadsPerBlock;
                int inblockj = j % threadsPerBlock;
                if ((p[nblockj] & (1ULL << inblockj)) && parent_object_index[idx[j]] == 0)
                    parent_object_index[idx[j]] = num_to_keep_+1;
            }
            parent_object_index[idx[i]] = num_to_keep_+1;

            num_to_keep_++;

            if (num_to_keep_==top_k)
                break;
        }
    }

    // Initialize the rest of the keep array to avoid uninitialized values.
    for (int i = num_to_keep_; i < boxes_num; ++i)
        keep[i] = 0;

    *num_to_keep = min(top_k,num_to_keep_);
}
185 |
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")

// Host-side driver for the point NMS.
//
// boxes: N points, read as two values per point at offset idx * 2.
// idx:   permutation of [0, N) in descending score order (built by nms_forward).
// Builds the N x col_blocks suppression bitmask with nms_kernel, then reduces it
// on a single GPU thread with nms_collect.
// Returns {keep, num_to_keep, parent_object_index}; only the first num_to_keep
// entries of `keep` are valid.
std::vector<at::Tensor> nms_cuda_forward(
    at::Tensor boxes,
    at::Tensor idx,
    float nms_overlap_thresh,
    unsigned long top_k) {

    const auto boxes_num = boxes.size(0);
    auto longOptions = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kLong);

    // Nothing to suppress. Returning early also avoids launching nms_kernel on a
    // 0 x 0 grid, which is an invalid launch configuration that gpuErrchk would
    // turn into exit() of the whole process.
    if (boxes_num == 0) {
        auto keep = at::empty({0}, longOptions);
        auto parent_object_index = at::empty({0}, longOptions);
        auto num_to_keep = at::zeros({}, longOptions);
        return {keep, num_to_keep, parent_object_index};
    }

    const int col_blocks = DIVUP(boxes_num, threadsPerBlock);

    AT_ASSERTM (col_blocks < MAX_COL_BLOCKS, "The number of column blocks must be less than MAX_COL_BLOCKS. Increase the MAX_COL_BLOCKS constant if needed.");

    auto mask = at::empty({boxes_num * col_blocks}, longOptions);

    dim3 blocks(col_blocks, col_blocks);
    dim3 threads(threadsPerBlock);

    CHECK_CONTIGUOUS(boxes);
    CHECK_CONTIGUOUS(idx);
    CHECK_CONTIGUOUS(mask);

    AT_DISPATCH_FLOATING_TYPES(boxes.scalar_type(), "nms_cuda_forward", ([&] {
        nms_kernel<scalar_t><<<blocks, threads>>>(boxes_num,
                                                  (scalar_t)nms_overlap_thresh,
                                                  boxes.data_ptr<scalar_t>(),
                                                  idx.data_ptr<int64_t>(),
                                                  mask.data_ptr<int64_t>());
    }));

    gpuErrchk(cudaPeekAtLastError());
    gpuErrchk(cudaDeviceSynchronize());

    auto keep = at::empty({boxes_num}, longOptions);
    auto parent_object_index = at::empty({boxes_num}, longOptions);
    auto num_to_keep = at::empty({}, longOptions);

    nms_collect<<<1, 1>>>(boxes_num, col_blocks, top_k,
                          idx.data_ptr<int64_t>(),
                          mask.data_ptr<int64_t>(),
                          keep.data_ptr<int64_t>(),
                          parent_object_index.data_ptr<int64_t>(),
                          num_to_keep.data_ptr<int64_t>());

    return {keep, num_to_keep, parent_object_index};
}
241 |
242 |
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.nn.utils import clip_grad_norm_
4 | import tqdm
5 |
def checkpoint_state(model=None, optimizer=None, epoch=None, it=None):
    """Bundle training state into a dict suitable for ``torch.save``.

    For a ``torch.nn.DataParallel`` model the wrapped module's state dict is
    stored. A missing model or optimizer is recorded as None.

    Returns:
        dict with keys 'epoch', 'it', 'model_state', 'optimizer_state'.
    """
    if model is None:
        model_state = None
    else:
        net = model.module if isinstance(model, torch.nn.DataParallel) else model
        model_state = net.state_dict()

    optim_state = None if optimizer is None else optimizer.state_dict()

    return {'epoch': epoch, 'it': it, 'model_state': model_state, 'optimizer_state': optim_state}
17 |
18 |
def save_checkpoint(state, filename='checkpoint', logger=None):
    """Write `state` to ``<filename>.pth`` with ``torch.save``.

    The '.pth' extension is always appended to `filename`. If `logger` is
    given, the final path is logged at INFO level.
    """
    path = f'{filename}.pth'
    torch.save(state, path)
    if logger is None:
        return
    logger.info('Checkpoint saved to %s' % path)
24 |
25 |
def load_checkpoint(model=None, optimizer=None, filename='checkpoint', logger=None):
    """Restore model/optimizer state from a checkpoint file.

    Args:
        model: if given, receives ``checkpoint['model_state']`` via
            ``load_state_dict`` (skipped when the stored state is None).
        optimizer: likewise restored from ``checkpoint['optimizer_state']``.
        filename: full path of the checkpoint file (extension included).
        logger: optional logger for progress messages.

    Returns:
        (it, epoch): stored iteration (default 0.0) and epoch (default -1).

    Raises:
        FileNotFoundError: if `filename` is not an existing file.
    """
    if not os.path.isfile(filename):
        print('Could not find %s' % filename)
        raise FileNotFoundError(filename)

    if logger is not None:
        logger.info("Loading from checkpoint '{}'".format(filename))
    # Load onto the CPU so checkpoints written on a GPU can also be restored on
    # machines without that device; load_state_dict then moves the values to
    # the target parameters.
    checkpoint = torch.load(filename, map_location='cpu')
    epoch = checkpoint.get('epoch', -1)
    it = checkpoint.get('it', 0.0)
    if model is not None and checkpoint['model_state'] is not None:
        model.load_state_dict(checkpoint['model_state'])
    if optimizer is not None and checkpoint['optimizer_state'] is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state'])

    return it, epoch
45 |
46 |
class LucasScheduler(object):
    """
    Return `v0` until `e` reaches `e0`, then exponentially decay
    to `v1` when `e` reaches `e1` and return `v1` thereafter, until
    reaching `eNone`, after which it returns `None`.

    `step(epoch)` writes the scheduled value into every param group of the
    optimizer and returns it. Past `eNone` the schedule yields `None`; the
    optimizer's learning rate is then left unchanged (previously this case
    raised UnboundLocalError).

    Copyright (C) 2017 Lucas Beyer - http://lucasb.eyer.be =)
    """
    def __init__(self, optimizer, e0, v0, e1, v1, eNone=float('inf')):
        self.e0, self.v0 = e0, v0
        self.e1, self.v1 = e1, v1
        self.eNone = eNone
        self._optim = optimizer

    def _lr_at(self, epoch):
        """Return the scheduled learning rate at (possibly fractional) `epoch`, or None."""
        if epoch < self.e0:
            return self.v0
        if epoch < self.e1:
            progress = (epoch - self.e0) / (self.e1 - self.e0)
            return self.v0 * (self.v1 / self.v0) ** progress
        if epoch < self.eNone:
            return self.v1
        return None

    def step(self, epoch):
        """Apply the learning rate for `epoch` to the optimizer and return it (None past `eNone`)."""
        lr = self._lr_at(epoch)
        if lr is not None:
            for group in self._optim.param_groups:
                group['lr'] = lr
        return lr

    def get_lr(self):
        """Return the learning rate currently set on the optimizer's first param group."""
        return self._optim.param_groups[0]['lr']
74 |
75 |
class Trainer(object):
    """Epoch/iteration training loop with checkpointing, evaluation and TensorBoard logging.

    `model_fn(model, batch)` must return ``(loss, tb_dict, _)`` where `loss` is a
    scalar tensor and `tb_dict` maps names to scalars logged as ``train_<name>``.
    `model_fn_eval(model, eval_loader, epoch, it)` is called with gradients
    disabled. `lr_scheduler` must provide ``step(epoch)`` and ``get_lr()``;
    `tb_log` must provide ``add_scalar``, ``flush`` and ``close``.

    On SIGINT/SIGTERM the current state is saved as ``<ckpt_dir>/sigterm_ckpt``
    (via `save_checkpoint`) and the process exits.
    """
    def __init__(self, model, model_fn, optimizer, ckpt_dir, lr_scheduler,
                 model_fn_eval, logger, tb_log, grad_norm_clip):
        self.model, self.optimizer, self.lr_scheduler = model, optimizer, lr_scheduler
        self.model_fn, self.model_fn_eval = model_fn, model_fn_eval
        self.ckpt_dir, self.logger, self.tb_log = ckpt_dir, logger, tb_log
        self.grad_norm_clip = grad_norm_clip

        self._epoch, self._it = 0, 0

        # save a checkpoint before exiting on Ctrl-C / SIGTERM
        import signal
        signal.signal(signal.SIGINT, self._sigterm_cb)
        signal.signal(signal.SIGTERM, self._sigterm_cb)

    def _sigterm_cb(self, signum, frame):
        """Signal handler: checkpoint the current state, close the TensorBoard log and exit."""
        import sys
        self.logger.warning('Received signal %s at frame %s' % (signum, frame))
        ckpt_name = os.path.join(self.ckpt_dir, 'sigterm_ckpt')
        save_checkpoint(checkpoint_state(self.model, self.optimizer, self._epoch, self._it),
                        filename=ckpt_name, logger=self.logger)
        self.tb_log.flush()
        self.tb_log.close()
        sys.exit()

    def _train_it(self, batch):
        """Run one optimization step on `batch`; return ``(loss as float, tb_dict)``."""
        self.model.train()
        self.optimizer.zero_grad()

        loss, tb_dict, _ = self.model_fn(self.model, batch)

        loss.backward()
        if self.grad_norm_clip > 0:
            clip_grad_norm_(self.model.parameters(), self.grad_norm_clip)
        self.optimizer.step()

        return loss.item(), tb_dict

    def train(self, num_epochs, train_loader, eval_loader=None, eval_frequency=1,
              ckpt_save_interval=5, lr_scheduler_each_iter=True, starting_epoch=0,
              starting_iteration=0):
        """Train for epochs ``starting_epoch .. num_epochs - 1``.

        Args:
            num_epochs: epoch index at which training stops (exclusive).
            train_loader: sized iterable of batches passed to `model_fn`.
            eval_loader: optional; evaluated when ``(epoch + 1) % eval_frequency == 0``.
            eval_frequency: evaluation period in epochs.
            ckpt_save_interval: save ``ckpt_e<epoch + 1>`` when
                ``(epoch + 1) % ckpt_save_interval == 0``.
            lr_scheduler_each_iter: step the scheduler every iteration with a
                fractional epoch; otherwise once at the start of each epoch.
            starting_epoch, starting_iteration: counters to resume from.
        """
        self._it = starting_iteration
        num_batches = len(train_loader)

        with tqdm.trange(starting_epoch, num_epochs, desc='epochs') as tbar:
            pbar = tqdm.tqdm(total=num_batches, leave=False, desc='train')
            try:
                for self._epoch in tbar:
                    if not lr_scheduler_each_iter:
                        self.lr_scheduler.step(self._epoch)

                    # train one epoch
                    for cur_it, batch in enumerate(train_loader):
                        if lr_scheduler_each_iter:
                            self.lr_scheduler.step(self._epoch + cur_it / num_batches)
                        cur_lr = self.lr_scheduler.get_lr()

                        loss, tb_dict = self._train_it(batch)

                        # log to console
                        pbar.update()
                        pbar.set_postfix(dict(total_it=self._it))
                        tbar.set_postfix({'loss': loss, 'lr': cur_lr})
                        tbar.refresh()

                        # log to tensorboard (learning rate once per iteration)
                        self.tb_log.add_scalar('train_loss', loss, self._it)
                        self.tb_log.add_scalar('learning_rate', cur_lr, self._it)
                        for key, val in tb_dict.items():
                            self.tb_log.add_scalar('train_' + key, val, self._it)

                        self._it += 1

                    # save trained model
                    trained_epoch = self._epoch + 1
                    if trained_epoch % ckpt_save_interval == 0:
                        ckpt_name = os.path.join(self.ckpt_dir, 'ckpt_e%d' % trained_epoch)
                        save_checkpoint(
                            checkpoint_state(self.model, self.optimizer, trained_epoch, self._it),
                            filename=ckpt_name, logger=self.logger)

                    # eval one epoch
                    if eval_loader is not None and trained_epoch % eval_frequency == 0:
                        pbar.close()  # close the train bar before evaluation output
                        with torch.no_grad():
                            self.model.eval()
                            self.model_fn_eval(self.model, eval_loader, self._epoch, self._it)

                    self.tb_log.flush()

                    # fresh progress bar for the next epoch
                    pbar.close()
                    pbar = tqdm.tqdm(total=num_batches, leave=False, desc='train')
                    pbar.set_postfix(dict(total_it=self._it))
            finally:
                # the bar created after the last epoch would otherwise never be closed
                pbar.close()
--------------------------------------------------------------------------------
/dr_spaam/src/dr_spaam/utils/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | # from numba import jit
3 | import numpy as np
4 | from scipy.ndimage import maximum_filter
5 | from scipy.spatial.distance import cdist
6 | import torch
7 | import cv2
8 |
9 | # from nms import nms
10 |
# In numpy >= 1.17, np.clip is slow, use core.umath.clip instead
# https://github.com/numpy/numpy/issues/14281
def _resolve_fast_clip():
    """Return NumPy's raw ``clip`` ufunc when it can be found, else ``np.clip``.

    ``np.core`` became ``np._core`` in NumPy 2.0 (the old name is a deprecated
    alias that may disappear), so both names are probed; a missing attribute
    falls back to the public ``np.clip`` instead of failing at import time.
    """
    for core_name in ("_core", "core"):
        umath = getattr(getattr(np, core_name, None), "umath", None)
        clip_fn = getattr(umath, "clip", None)
        if clip_fn is not None:
            return clip_fn
    return np.clip


_clip = _resolve_fast_clip()
19 |
def get_laser_phi(angle_inc=np.radians(0.5), num_pts=450):
    """Return `num_pts` beam angles (radians), `angle_inc` apart and symmetric about 0.

    The defaults follow the DROW setup: a SICK S300 laser with a 225 deg field
    of view and 450 points, mounted at 37 cm height.
    """
    fov = (num_pts - 1) * angle_inc
    half_fov = fov * 0.5
    return np.linspace(-half_fov, half_fov, num_pts)
25 |
26 |
def scan_to_xy(scan, phi=None):
    """Convert range readings `scan` to Cartesian ``(x, y)``.

    `phi` gives the beam angles; when omitted, the default angles from
    `get_laser_phi()` are used.
    """
    angles = get_laser_phi() if phi is None else phi
    return rphi_to_xy(scan, angles)
32 |
33 |
def xy_to_rphi(x, y):
    """Convert Cartesian ``(x, y)`` to polar ``(r, phi)``.

    The frame is intentionally rotated 90 deg CCW: phi == 0 lies along +y (the
    scan center, pointing outward), +x points right, and phi grows towards +x.
    That is why the arguments to ``arctan2`` are ``(x, y)`` rather than ``(y, x)``.
    """
    r = np.hypot(x, y)
    phi = np.arctan2(x, y)
    return r, phi
39 |
40 |
41 | # @jit
def rphi_to_xy(r, phi):
    """Convert polar ``(r, phi)`` to Cartesian ``(x, y)`` with phi measured from +y towards +x."""
    x = r * np.sin(phi)
    y = r * np.cos(phi)
    return x, y
44 |
45 |
def rphi_to_xy_torch(r, phi):
    """Torch version of `rphi_to_xy`: polar ``(r, phi)`` tensors to Cartesian ``(x, y)``."""
    x = r * torch.sin(phi)
    y = r * torch.cos(phi)
    return x, y
48 |
49 |
def global_to_canonical(scan_r, scan_phi, dets_r, dets_phi):
    """Express detections ``(dets_r, dets_phi)`` in a scan point's canonical frame.

    The canonical frame has its origin at the scan point, y pointing outward
    along the beam and x pointing rightward. Returns the offsets ``(dx, dy)``.
    """
    rel_phi = dets_phi - scan_phi
    dx = dets_r * np.sin(rel_phi)
    dy = dets_r * np.cos(rel_phi) - scan_r
    return dx, dy
55 |
56 |
57 | # @jit
def canonical_to_global(scan_r, scan_phi, dx, dy):
    """Inverse of `global_to_canonical`: map canonical offsets back to global ``(r, phi)``."""
    along_beam = scan_r + dy
    # (dx, along_beam) order is intended: in this geometry dx plays the role of
    # the "y" argument of arctan2 and the along-beam distance that of "x".
    rel_phi = np.arctan2(dx, along_beam)
    return along_beam / np.cos(rel_phi), rel_phi + scan_phi
64 |
65 |
def canonical_to_global_torch(scan_r, scan_phi, dx, dy):
    """Torch version of `canonical_to_global`: canonical offsets to global ``(r, phi)``."""
    along_beam = scan_r + dy
    # (dx, along_beam) order is intended: dx maps to the arctan2 "y" argument here.
    rel_phi = torch.atan2(dx, along_beam)
    return along_beam / torch.cos(rel_phi), rel_phi + scan_phi
72 |
73 |
def data_augmentation(sample_dict):
    """Randomly (p = 0.5) mirror a sample left-right.

    The scan points are reversed along axis 1 of ``sample_dict['scans']``
    (assumed (num_scans, num_pts)). This mirrors the scene about the central
    beam when the beam angles are symmetric around 0, as produced by
    `get_laser_phi`. The per-point targets are reversed the same way so each
    point keeps its own label; previously only the x offset was negated, which
    left labels attached to the wrong points. The canonical x offset is negated
    because mirroring flips the x axis.

    Args:
        sample_dict: dict with 'scans', per-point 'target_reg' (num_pts, 2) and
            optionally per-point 'target_cls' (num_pts,). Entries are replaced
            by new arrays; the input arrays themselves are not modified.
            NOTE(review): other keys (e.g. raw detections) are not mirrored;
            confirm none are used downstream of augmentation.

    Returns:
        The same `sample_dict`, updated.
    """
    scans, target_reg = sample_dict['scans'], sample_dict['target_reg']

    if np.random.rand() < 0.5:
        scans = scans[:, ::-1]
        # reorder targets with their points; copy so the result is contiguous
        # (negative strides break torch.from_numpy) and the input is untouched
        target_reg = target_reg[::-1].copy()
        target_reg[:, 0] = -target_reg[:, 0]
        if 'target_cls' in sample_dict:
            sample_dict['target_cls'] = sample_dict['target_cls'][::-1].copy()

    sample_dict.update({'target_reg': target_reg, 'scans': scans})

    return sample_dict
90 |
91 |
def get_regression_target(scan, scan_phi, wcs, was, wps,
                          radius_wc=0.6, radius_wa=0.4, radius_wp=0.35,
                          label_wc=1, label_wa=2, label_wp=3,
                          pedestrian_only=False):
    """Build per-point classification and regression targets for one scan.

    Each scan point is matched (via `closest_detection`) to the closest
    annotation whose radius contains it. Matched points receive that
    annotation's class label and its center expressed in the point's canonical
    frame (`global_to_canonical`); unmatched points keep label 0 and a zero offset.

    Args:
        scan, scan_phi: range and angle of every scan point.
        wcs, was, wps: annotations as (r, phi) pairs for the three annotation
            groups (wheelchairs, walking aids and walking persons in DROW
            naming; verify against the dataset loader).
        radius_*: matching radius per group.
        label_*: class label per group (ignored when `pedestrian_only`).
        pedestrian_only: only use `wps`, labelled 1.

    Returns:
        ``(target_cls, target_reg)`` of shapes (num_pts,) int64 and (num_pts, 2) float32.
    """
    n = len(scan)
    target_cls = np.zeros(n, dtype=np.int64)
    target_reg = np.zeros((n, 2), dtype=np.float32)

    if pedestrian_only:
        det_groups = ((wps, radius_wp, 1),)
    else:
        det_groups = ((wcs, radius_wc, label_wc),
                      (was, radius_wa, label_wa),
                      (wps, radius_wp, label_wp))

    all_dets = [det for group, _, _ in det_groups for det in group]
    all_radius = [radius for group, radius, _ in det_groups for _ in range(len(group))]
    # index 0 is "no detection", so labels are shifted by one w.r.t. all_dets
    labels = [0] + [label for group, _, label in det_groups for _ in range(len(group))]

    matches = closest_detection(scan, scan_phi, all_dets, all_radius)

    for i, (r, phi) in enumerate(zip(scan, scan_phi)):
        det_idx = matches[i]
        if det_idx <= 0:
            continue
        target_cls[i] = labels[det_idx]
        target_reg[i, :] = global_to_canonical(r, phi, *all_dets[det_idx - 1])

    return target_cls, target_reg
117 |
118 |
def closest_detection(scan, scan_phi, dets, radii):
    """
    Map every point of a single `scan` (450 floats) to the closest detection
    whose radius contains it.

    `dets` is a list of (r, phi) detections (Nx2) and `radii` holds one radius
    per detection. The result is a 1-based detection index per scan point, where
    0 means that no detection is close enough to that point.
    """
    if len(dets) == 0:
        return np.zeros_like(scan, dtype=int)

    assert len(dets) == len(radii), "Need to give a radius for each detection!"

    points_xy = np.array(rphi_to_xy(scan, scan_phi)).T  # (num_pts, 2)
    dets_xy = np.array([rphi_to_xy(r, phi) for r, phi in dets])  # (num_dets, 2)

    # Distance to each detection's disc: negative inside, positive outside.
    signed_dists = cdist(points_xy, dets_xy) - radii

    # A leading zero column makes index 0 ("no detection") the argmin whenever
    # no signed distance is negative (ties resolve to the first column).
    background = np.zeros((len(scan), 1))
    return np.argmin(np.hstack([background, signed_dists]), axis=1)
144 |
145 |
def scans_to_cutout(scans, scan_phi, stride=1, centered=True,
                    fixed=False, window_width=1.66, window_depth=1.0,
                    num_cutout_pts=48, padding_val=29.99, area_mode=False):
    """Extract a fixed-width, depth-normalized cutout around every (strided) scan point.

    For each selected point, a window of metric width `window_width` is resampled
    to `num_cutout_pts` angular samples by linear interpolation. With
    `area_mode`, windows spanning more than `num_cutout_pts` points are instead
    averaged over sub-samples. Samples outside the scan's angular range are set
    to `padding_val`. Values are then clipped to the window range
    +-`window_depth` and, if `centered`, made relative to the window range and
    divided by `window_depth`.

    Args:
        scans: (num_scans, num_pts) range readings.
        scan_phi: (num_pts,) beam angles; only the first two entries set the
            angular spacing, so they are assumed to be uniformly spaced.
        stride: use every `stride`-th point as a cutout center.
        fixed: if True, each scan uses its own ranges to size the window;
            otherwise the ranges of the last scan (``scans[-1]``) are used for all.
        area_mode: enable area (averaging) sampling for wide windows.

    Returns:
        float32 array of shape (ceil(num_pts / stride), num_scans, num_cutout_pts).
    """
    num_scans, num_pts = scans.shape

    # size (width) of the window
    dists = scans[:, ::stride] if fixed else \
        np.tile(scans[-1, ::stride], num_scans).reshape(num_scans, -1)
    half_alpha = np.arctan(0.5 * window_width / np.maximum(dists, 1e-2))

    # cutout indices, shape (num_cutout_pts, num_scans, num_centers)
    delta_alpha = 2.0 * half_alpha / (num_cutout_pts - 1)
    ang_ct = scan_phi[::stride] - half_alpha + \
        np.arange(num_cutout_pts).reshape(num_cutout_pts, 1, 1) * delta_alpha
    inds_ct = (ang_ct - scan_phi[0]) / (scan_phi[1] - scan_phi[0])
    outbound_mask = np.logical_or(inds_ct < 0, inds_ct > num_pts - 1)

    # cutout (linear interp)
    # `np.int` (an alias of builtin int) was removed in NumPy 1.24; use `int`.
    inds_ct_low = _clip(np.floor(inds_ct), 0, num_pts - 1).astype(int)
    inds_ct_high = _clip(inds_ct_low + 1, 0, num_pts - 1).astype(int)
    inds_ct_ratio = _clip(inds_ct - inds_ct_low, 0.0, 1.0)
    inds_offset = np.arange(num_scans).reshape(1, num_scans, 1) * num_pts  # because np.take flattens array
    ct_low = np.take(scans, inds_ct_low + inds_offset)
    ct_high = np.take(scans, inds_ct_high + inds_offset)
    ct = ct_low + inds_ct_ratio * (ct_high - ct_low)

    # use area sampling for down-sampling (close points)
    if area_mode:
        num_pts_in_window = inds_ct[-1] - inds_ct[0]
        area_mask = num_pts_in_window > num_cutout_pts
        if np.sum(area_mask) > 0:
            # sample the window with more points than the actual number of points
            s_area = int(math.ceil(np.max(num_pts_in_window) / num_cutout_pts))
            num_ct_pts_area = s_area * num_cutout_pts
            delta_alpha_area = 2.0 * half_alpha / (num_ct_pts_area - 1)
            ang_ct_area = scan_phi[::stride] - half_alpha + \
                np.arange(num_ct_pts_area).reshape(num_ct_pts_area, 1, 1) * delta_alpha_area
            inds_ct_area = (ang_ct_area - scan_phi[0]) / (scan_phi[1] - scan_phi[0])
            inds_ct_area = np.rint(_clip(inds_ct_area, 0, num_pts - 1)).astype(np.int32)
            ct_area = np.take(scans, inds_ct_area + inds_offset)
            # average each group of s_area consecutive sub-samples
            ct_area = ct_area.reshape(num_cutout_pts, s_area, num_scans, dists.shape[1]).mean(axis=1)
            ct[:, area_mask] = ct_area[:, area_mask]

    # normalize cutout
    ct[outbound_mask] = padding_val
    ct = _clip(ct, dists - window_depth, dists + window_depth)
    if centered:
        ct = ct - dists
        ct = ct / window_depth

    return np.ascontiguousarray(ct.transpose((2, 1, 0)), dtype=np.float32)  # (scans, times, cutouts)
196 |
197 |
def scans_to_cutout_torch(scans, scan_phi, stride=1, centered=True,
                          fixed=False, window_width=1.66, window_depth=1.0,
                          num_cutout_pts=48, padding_val=29.99, area_mode=False):
    """Torch counterpart of `scans_to_cutout` (same arguments and semantics).

    Args:
        scans: (num_scans, num_pts) tensor of range readings.
        scan_phi: (num_pts,) tensor of beam angles; only the first two entries
            set the angular spacing.
        stride, centered, fixed, window_width, window_depth, num_cutout_pts,
        padding_val, area_mode: see `scans_to_cutout`.

    Returns:
        contiguous float32 tensor of shape
        (ceil(num_pts / stride), num_scans, num_cutout_pts).
    """
    num_scans, num_pts = scans.shape

    # size (width) of the window
    dists = scans[:, ::stride] if fixed else \
        scans[-1, ::stride].repeat(num_scans, 1)
    half_alpha = torch.atan(0.5 * window_width / torch.clamp(dists, min=1e-2))

    # cutout indices, shape (num_cutout_pts, num_scans, num_centers)
    delta_alpha = 2.0 * half_alpha / (num_cutout_pts - 1)
    ang_step = torch.arange(
        num_cutout_pts, device=scans.device).view(num_cutout_pts, 1, 1) * delta_alpha
    ang_ct = scan_phi[::stride] - half_alpha + ang_step
    inds_ct = (ang_ct - scan_phi[0]) / (scan_phi[1] - scan_phi[0])
    outbound_mask = torch.logical_or(inds_ct < 0, inds_ct > num_pts - 1)

    # cutout (linear interp)
    inds_ct_low = inds_ct.floor().long().clamp(min=0, max=num_pts - 1)
    inds_ct_high = inds_ct.ceil().long().clamp(min=0, max=num_pts - 1)
    inds_ct_ratio = (inds_ct - inds_ct_low).clamp(min=0.0, max=1.0)
    # Broadcast the full scans over the sample axis and gather along the point
    # axis. `gather` only requires the index to be no larger than the input in
    # the non-gathered dims, so this works for any stride; `expand_as(index)`
    # failed whenever stride > 1 (num_centers != num_pts).
    scans_exp = scans.unsqueeze(0).expand(num_cutout_pts, num_scans, num_pts)
    ct_low = torch.gather(scans_exp, dim=2, index=inds_ct_low)
    ct_high = torch.gather(scans_exp, dim=2, index=inds_ct_high)
    ct = ct_low + inds_ct_ratio * (ct_high - ct_low)

    # use area sampling for down-sampling (close points)
    if area_mode:
        num_pts_in_window = inds_ct[-1] - inds_ct[0]
        area_mask = num_pts_in_window > num_cutout_pts
        if torch.sum(area_mask) > 0:
            # sample the window with more points than the actual number of points
            s_area = (num_pts_in_window.max() / num_cutout_pts).ceil().long().item()
            num_ct_pts_area = s_area * num_cutout_pts
            delta_alpha_area = 2.0 * half_alpha / (num_ct_pts_area - 1)
            ang_step_area = torch.arange(
                num_ct_pts_area, device=scans.device).view(
                    num_ct_pts_area, 1, 1) * delta_alpha_area
            ang_ct_area = scan_phi[::stride] - half_alpha + ang_step_area
            inds_ct_area = torch.round(
                (ang_ct_area - scan_phi[0]) / (scan_phi[1] - scan_phi[0])) \
                .long().clamp(min=0, max=num_pts - 1)
            scans_exp_area = scans.unsqueeze(0).expand(num_ct_pts_area, num_scans, num_pts)
            ct_area = torch.gather(scans_exp_area, dim=2, index=inds_ct_area)
            # average each group of s_area consecutive sub-samples
            ct_area = ct_area.view(
                num_cutout_pts, s_area, num_scans, dists.shape[1]).mean(dim=1)
            ct[:, area_mask] = ct_area[:, area_mask]

    # normalize cutout
    ct[outbound_mask] = padding_val
    # clip to [dists - depth, dists + depth] with tensor bounds
    lower, upper = dists - window_depth, dists + window_depth
    ct = torch.where(ct < lower, lower, ct)
    ct = torch.where(ct > upper, upper, ct)
    if centered:
        ct = ct - dists
        ct = ct / window_depth

    return ct.permute((2, 1, 0)).float().contiguous()  # (scans, times, cutouts)
268 |
269 |
def scans_to_cutout_original(scans, angle_incre, fixed=True, centered=True,
                             pt_inds=None, window_width=1.66, window_depth=1.0,
                             num_cutout_pts=48, padding_val=29.99):
    """Reference per-point (loop-based) cutout implementation using OpenCV resizing.

    Args:
        scans: (num_scans, num_pts) range readings.
        angle_incre: angular spacing between consecutive beams (radians).
        fixed: if True, each scan uses its own range at the point to size the
            window; otherwise the range in the last scan (``scans[-1]``) is used.
        centered: subtract the point's range and divide by `window_depth`.
        pt_inds: indices of the points to process (default: all).
        window_width, window_depth, num_cutout_pts, padding_val: as in `scans_to_cutout`.

    Returns:
        float32 array of shape (num_pts, num_scans, num_cutout_pts). Entries of
        points not listed in `pt_inds` are zero (previously they were left
        uninitialized by ``np.empty``).
    """
    num_scans, num_pts = scans.shape
    if pt_inds is None:
        pt_inds = range(num_pts)

    # One padding column is appended; both index -1 and index num_pts address it.
    scans_padded = np.pad(scans, ((0, 0), (0, 1)), mode='constant', constant_values=padding_val)
    scans_cutout = np.zeros((num_pts, num_scans, num_cutout_pts), dtype=np.float32)

    for scan_idx in range(num_scans):
        for pt_idx in pt_inds:
            # Compute the size (width) of the window
            pt_r = scans[scan_idx, pt_idx] if fixed else scans[-1, pt_idx]

            half_alpha = float(np.arctan(0.5 * window_width / max(pt_r, 0.01)))

            # Compute the start and end indices of cutout; indices outside the
            # scan are clipped onto the padding column.
            start_idx = int(round(pt_idx - half_alpha / angle_incre))
            end_idx = int(round(pt_idx + half_alpha / angle_incre))
            cutout_pts_inds = np.arange(start_idx, end_idx + 1)
            cutout_pts_inds = _clip(cutout_pts_inds, -1, num_pts)

            # cutout points
            cutout_pts = scans_padded[scan_idx, cutout_pts_inds]

            # resampling/interpolation: area when shrinking, linear when enlarging
            interp = cv2.INTER_AREA if num_cutout_pts < len(cutout_pts_inds) else cv2.INTER_LINEAR
            cutout_sampled = cv2.resize(cutout_pts,
                                        (1, num_cutout_pts),
                                        interpolation=interp).squeeze()

            # center cutout and clip depth to avoid strong depth discontinuity
            cutout_sampled = _clip(cutout_sampled, pt_r - window_depth, pt_r + window_depth)

            if centered:
                cutout_sampled -= pt_r  # center
                cutout_sampled = cutout_sampled / window_depth  # normalize
            scans_cutout[pt_idx, scan_idx, :] = cutout_sampled

    return scans_cutout
321 |
322 |
def scans_to_polar_grid(scans, min_range=0.0, max_range=30.0, range_bin_size=1.0,
                        tsdf_clip=1.0, normalize=True):
    """Rasterize scans into a polar (range-bin x beam) grid holding a truncated signed distance.

    For each beam, the range bins store the signed distance (bin offset times
    `range_bin_size`) from the bin to the bin containing the reading, clipped
    to +-`tsdf_clip` (all zeros if `tsdf_clip` <= 0). The bin containing the
    reading stores the (clipped) reading itself.

    Args:
        scans: (num_scans, num_pts) range readings.
        min_range, max_range: readings are clipped to this interval before binning.
        range_bin_size: bin width along the range axis.
        tsdf_clip: truncation value of the signed distance; <= 0 disables it.
        normalize: if True, map readings from [min_range, max_range] to [-1, 1]
            and scale signed distances by 2 / (max_range - min_range).

    Returns:
        float32 array of shape (num_scans, num_range, num_pts) with
        ``num_range = int((max_range - min_range) / range_bin_size) + 1``.
    """
    num_scans, num_pts = scans.shape
    num_range = int((max_range - min_range) / range_bin_size) + 1
    mag_range = max_range - min_range
    # midpoint of [min_range, max_range]; `0.5 * (max - min)` was only right for min_range == 0
    mid_range = 0.5 * (max_range + min_range)

    scans = np.clip(scans, min_range, max_range)
    scans_grid_inds = ((scans - min_range) / range_bin_size).astype(np.int32)  # (S, P)

    if tsdf_clip > 0.0:
        # bin offset of every range bin relative to each reading's bin: (S, R, P)
        offsets = np.arange(num_range, dtype=np.int32)[None, :, None] - scans_grid_inds[:, None, :]
        polar_grid = np.clip(offsets.astype(np.float32) * range_bin_size, -tsdf_clip, tsdf_clip)
    else:
        polar_grid = np.zeros((num_scans, num_range, num_pts), dtype=np.float32)

    scan_vals = scans
    if normalize:
        scan_vals = (scans - mid_range) / mag_range * 2.0
        polar_grid = polar_grid / mag_range * 2.0

    polar_grid = np.ascontiguousarray(polar_grid, dtype=np.float32)
    # write each reading into its own range bin
    np.put_along_axis(polar_grid, scans_grid_inds[:, None, :].astype(np.intp),
                      scan_vals[:, None, :], axis=1)

    return polar_grid
354 |
355 |
def group_predicted_center(scan_grid, phi_grid, pred_cls, pred_reg, min_thresh=1e-5,
                           class_weights=None, bin_size=0.1, blur_sigma=0.5,
                           x_min=-15.0, x_max=15.0, y_min=-5.0, y_max=15.0,
                           vote_collect_radius=0.3, cls_agnostic_vote=False):
    '''
    Group per-point center votes into detections via non-max suppression on a voting grid.

    Each scan point votes at its predicted global center, weighted by its class
    scores. Votes are accumulated on a 2D grid. The class-agnostic channel is
    optionally Gaussian-blurred, and its 3x3 local maxima become detections.
    Every vote is assigned to its nearest maximum and collected only if it lies
    within `vote_collect_radius`. A detection is the mean (or score-weighted
    mean) of its collected votes.

    - `scan_grid`, `phi_grid` range and angle of every scan point.
    - `pred_cls` (N, C) per-point class scores. With C == 1 the single column is
      the voting weight; otherwise column 0 is excluded from the foreground
      score (presumably background) and columns 1: are the classes.
    - `pred_reg` per-point canonical (dx, dy) offsets to the predicted center.
    - `min_thresh` points whose foreground score is not above this do not vote.
    - `class_weights` optional multipliers for the foreground columns (C > 1 only).
    - `bin_size` the bin size (in meters) used for the grid where votes are cast.
    - `blur_sigma` the sigma (in bins) of the Gaussian blur of the voting grid;
      the window size is derived from it. <= 0 disables blurring.
    - `x_min`, `x_max` the left/right limits of the voting grid, in meters.
    - `y_min`, `y_max` the bottom/top limits of the voting grid, in meters.
      Votes outside the grid are discarded.
    - `vote_collect_radius` the radius used during the collection of votes
      assigned to each detection.
    - `cls_agnostic_vote` (C > 1 only) weight the votes by their foreground
      score instead of taking a plain mean.

    Returns `(dets_xy, dets_cls, instance_mask)`:
    - `dets_xy` an (M, 2) array of detection centers, and `dets_cls` an (M, C)
      array of the averaged scores of their votes. Both are empty lists when
      no point votes or no maximum is found.
    - `instance_mask` an (N,) int32 array giving the 1-based detection id of
      each point whose vote was collected, and 0 otherwise.
    '''
    pred_r, pred_phi = canonical_to_global(scan_grid, phi_grid,
                                           pred_reg[:, 0], pred_reg[:, 1])
    pred_xs, pred_ys = rphi_to_xy(pred_r, pred_phi)

    instance_mask = np.zeros(len(scan_grid), dtype=np.int32)
    scan_array_inds = np.arange(len(scan_grid))

    single_cls = pred_cls.shape[1] == 1

    if class_weights is not None and not single_cls:
        pred_cls = np.copy(pred_cls)
        pred_cls[:, 1:] *= class_weights

    # voting grid
    x_range = int((x_max - x_min) / bin_size)
    y_range = int((y_max - y_min) / bin_size)
    grid = np.zeros((x_range, y_range, pred_cls.shape[1]), np.float32)

    # filter out all the weak votes
    pred_cls_agn = pred_cls[:, 0] if single_cls else np.sum(pred_cls[:, 1:], axis=-1)
    voters_inds = np.where(pred_cls_agn > min_thresh)[0]

    if len(voters_inds) == 0:
        return [], [], instance_mask

    pred_xs, pred_ys = pred_xs[voters_inds], pred_ys[voters_inds]
    pred_cls = pred_cls[voters_inds]
    scan_array_inds = scan_array_inds[voters_inds]
    # floor rather than truncation toward zero, so votes just below x_min/y_min
    # get a negative bin index and are discarded instead of landing in bin 0
    pred_x_inds = np.floor((pred_xs - x_min) / bin_size).astype(np.int64)
    pred_y_inds = np.floor((pred_ys - y_min) / bin_size).astype(np.int64)

    # discard out of bound votes
    mask = (0 <= pred_x_inds) & (pred_x_inds < x_range) & (0 <= pred_y_inds) & (pred_y_inds < y_range)
    pred_x_inds, pred_xs = pred_x_inds[mask], pred_xs[mask]
    pred_y_inds, pred_ys = pred_y_inds[mask], pred_ys[mask]
    pred_cls = pred_cls[mask]
    scan_array_inds = scan_array_inds[mask]

    # vote into the grid, including the agnostic vote as sum of class-votes!
    if single_cls:
        np.add.at(grid, (pred_x_inds, pred_y_inds), pred_cls)
    else:
        np.add.at(grid, (pred_x_inds, pred_y_inds),
                  np.concatenate([np.sum(pred_cls[:, 1:], axis=1, keepdims=True),
                                  pred_cls[:, 1:]],
                                 axis=1))

    # NMS, only in the "common" voting grid
    grid_all_cls = grid[:, :, 0]
    if blur_sigma > 0:
        blur_win = int(2 * ((blur_sigma * 5) // 2) + 1)
        grid_all_cls = cv2.GaussianBlur(grid_all_cls, (blur_win, blur_win), blur_sigma)
    grid_nms_val = maximum_filter(grid_all_cls, size=3)
    grid_nms_inds = (grid_all_cls == grid_nms_val) & (grid_all_cls > 0)
    nms_xs, nms_ys = np.where(grid_nms_inds)

    if len(nms_xs) == 0:
        return [], [], instance_mask

    # Back from grid-bins to real-world locations (bin centers).
    nms_xs = nms_xs * bin_size + x_min + bin_size / 2
    nms_ys = nms_ys * bin_size + y_min + bin_size / 2

    # For each vote, get which maximum/detection it contributed to.
    # Shape of `distance_to_center` is (ndets, voters).
    distance_to_center = np.hypot(pred_xs - nms_xs[:, None], pred_ys - nms_ys[:, None])
    detection_ids = np.argmin(distance_to_center, axis=0)

    # Generate the final detections by averaging over their voters.
    dets_xs, dets_ys, dets_cls = [], [], []
    for ipeak in range(len(nms_xs)):
        voter_inds = np.where(detection_ids == ipeak)[0]
        voter_inds = voter_inds[distance_to_center[ipeak, voter_inds] < vote_collect_radius]

        support_xs, support_ys = pred_xs[voter_inds], pred_ys[voter_inds]
        support_cls = pred_cls[voter_inds]

        # mark instance, 0 is the background
        instance_mask[scan_array_inds[voter_inds]] = ipeak + 1

        if cls_agnostic_vote and not single_cls:
            weights = np.sum(support_cls[:, 1:], axis=1)
            norm = 1.0 / np.sum(weights)
            dets_xs.append(norm * np.sum(weights * support_xs))
            dets_ys.append(norm * np.sum(weights * support_ys))
            dets_cls.append(norm * np.sum(weights[:, None] * support_cls, axis=0))
        else:
            dets_xs.append(np.mean(support_xs))
            dets_ys.append(np.mean(support_ys))
            dets_cls.append(np.mean(support_cls, axis=0))

    return np.array([dets_xs, dets_ys]).T, np.array(dets_cls), instance_mask
474 |
475 |
476 | # @jit(nopython=True)
def nms_predicted_center(scan_grid, phi_grid, pred_cls, pred_reg, min_dist=0.5):
    """Greedy distance-based NMS over per-point center predictions.

    Every scan point's regression offset is converted to a global (x, y) center.
    Centers are visited from most to least confident. A center that has not been
    suppressed is kept, and every center closer than `min_dist` to it (itself
    included) is suppressed.

    Args:
        scan_grid, phi_grid: per-point range and angle of the scan.
        pred_cls: (N, 1) confidences; only a single class is supported.
        pred_reg: per-point canonical offsets; columns 0 and 1 are (dx, dy).
        min_dist: suppression radius, in the same units as the x/y centers.

    Returns:
        det_xys: (K, 2) kept centers, in descending confidence order.
        det_cls: (K, 1) their confidences.
        instance_mask: (N,) int32 in original point order, holding the 1-based
            id of a kept center (ids follow the kept order). Later kept centers
            overwrite earlier ids, so a point within `min_dist` of several kept
            centers ends up with the last (least confident) one's id. 0 means
            the point was never assigned.
    """
    assert pred_cls.shape[1] == 1

    pred_r, pred_phi = canonical_to_global(
        scan_grid, phi_grid, pred_reg[:, 0], pred_reg[:, 1])
    pred_xs, pred_ys = rphi_to_xy(pred_r, pred_phi)

    # visit predictions from most to least confident
    order = np.argsort(pred_cls[:, 0])[::-1]
    xs, ys = pred_xs[order], pred_ys[order]
    scores = pred_cls[order]

    # dense (N, N) matrix of pairwise center distances
    n = len(scan_grid)
    dx = xs.reshape(n, 1) - xs.reshape(1, n)
    dy = ys.reshape(n, 1) - ys.reshape(1, n)
    p_dist = np.sqrt(np.square(dx) + np.square(dy))

    keep = np.ones(n, dtype=np.bool_)
    instance_mask = np.zeros(n, dtype=np.int32)
    next_id = 1
    for i in range(n):
        if keep[i]:
            neighbours = p_dist[i] < min_dist
            keep[neighbours] = False
            keep[i] = True
            # `order` maps sorted positions back to original point indices
            instance_mask[order[neighbours]] = next_id
            next_id += 1

    det_xys = np.stack((xs, ys), axis=1)[keep]
    det_cls = scores[keep]

    return det_xys, det_cls, instance_mask
513 |
514 |
def nms_predicted_center_torch(scan_grid, phi_grid, pred_cls, pred_reg, min_dist=0.5):
    """Torch/CUDA counterpart of `nms_predicted_center`, backed by a compiled NMS op.

    Args:
        scan_grid, phi_grid: per-point range and angle tensors of the scan.
        pred_cls: (N, 1) confidence tensor; only a single class is supported.
        pred_reg: per-point canonical offsets; columns 0 and 1 are (dx, dy).
        min_dist: suppression radius, in the same units as the x/y centers.

    Returns:
        dets_xy: (K, 2) kept centers.
        dets_cls: (K, 1) their confidences.
        instance_mask: the op's `parent_object_index`, as a long tensor.
    """
    assert pred_cls.shape[1] == 1

    # The module-level `from nms import nms` is commented out, so `nms` used to
    # be an undefined name here (NameError on every call). Import it lazily so
    # this module stays importable when the extension is not built.
    # NOTE(review): module name taken from that commented-out import; confirm.
    from nms import nms

    with torch.no_grad():
        pred_r, pred_phi = canonical_to_global_torch(
            scan_grid, phi_grid, pred_reg[:, 0], pred_reg[:, 1])
        pred_xs, pred_ys = rphi_to_xy_torch(pred_r, pred_phi)
        pred_xys = torch.stack((pred_xs, pred_ys), dim=1).contiguous()

        top_k = 10000  # keep all detections
        keep, num_to_keep, parent_object_index = nms(pred_xys, pred_cls, min_dist, top_k)

    dets_xy = pred_xys[keep[:num_to_keep]]
    dets_cls = pred_cls[keep[:num_to_keep]]
    instance_mask = parent_object_index.long()

    return dets_xy, dets_cls, instance_mask
535 |
--------------------------------------------------------------------------------
/dr_spaam_ros/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.8.3)
2 | project(dr_spaam_ros)
3 |
4 | find_package(catkin REQUIRED
5 | COMPONENTS
6 | )
7 |
8 | catkin_package()
9 |
10 | catkin_python_setup()
11 |
12 |
13 |
--------------------------------------------------------------------------------
/dr_spaam_ros/config/dr_spaam_ros.yaml:
--------------------------------------------------------------------------------
1 | weight_file: '/home/dan/git/DR-SPAAM-Detector_private/dr_spaam/ckpts/dr_spaam_e40.pth'
2 | conf_thresh: 0.1
3 | stride: 1 # use this to skip laser points
4 |
--------------------------------------------------------------------------------
/dr_spaam_ros/config/topics.yaml:
--------------------------------------------------------------------------------
1 | publisher:
2 | detections:
3 | topic: /dr_spaam_detections
4 | queue_size: 1
5 | latch: false
6 |
7 | rviz:
8 | topic: /dr_spaam_rviz
9 | queue_size: 1
10 | latch: false
11 |
12 | subscriber:
13 | scan:
14 | topic: /sick_laser_front/scan
15 | queue_size: 1
16 |
--------------------------------------------------------------------------------
/dr_spaam_ros/example.rviz:
--------------------------------------------------------------------------------
1 | Panels:
2 | - Class: rviz/Displays
3 | Help Height: 0
4 | Name: Displays
5 | Property Tree Widget:
6 | Expanded:
7 | - /Global Options1
8 | - /PoseArray1/Status1
9 | Splitter Ratio: 0.6167800426483154
10 | Tree Height: 1886
11 | - Class: rviz/Selection
12 | Name: Selection
13 | - Class: rviz/Tool Properties
14 | Expanded:
15 | - /2D Pose Estimate1
16 | - /2D Nav Goal1
17 | - /Publish Point1
18 | Name: Tool Properties
19 | Splitter Ratio: 0.5886790156364441
20 | - Class: rviz/Views
21 | Expanded:
22 | - /Current View1
23 | Name: Views
24 | Splitter Ratio: 0.5
25 | - Class: rviz/Time
26 | Experimental: false
27 | Name: Time
28 | SyncMode: 0
29 | SyncSource: LaserScan
30 | Preferences:
31 | PromptSaveOnExit: true
32 | Toolbars:
33 | toolButtonStyle: 2
34 | Visualization Manager:
35 | Class: ""
36 | Displays:
37 | - Alpha: 0.10000000149011612
38 | Cell Size: 1
39 | Class: rviz/Grid
40 | Color: 85; 87; 83
41 | Enabled: false
42 | Line Style:
43 | Line Width: 0.029999999329447746
44 | Value: Lines
45 | Name: Grid
46 | Normal Cell Count: 0
47 | Offset:
48 | X: 0
49 | Y: 0
50 | Z: 0
51 | Plane: XY
52 | Plane Cell Count: 200
53 | Reference Frame:
54 | Value: false
55 | - Class: rviz/TF
56 | Enabled: true
57 | Frame Timeout: 1e+8
58 | Frames:
59 | All Enabled: false
60 | base_footprint:
61 | Value: true
62 | sick_laser_front:
63 | Value: true
64 | Marker Scale: 1
65 | Name: TF
66 | Show Arrows: true
67 | Show Axes: true
68 | Show Names: true
69 | Tree:
70 | base_footprint:
71 | sick_laser_front:
72 | {}
73 | Update Interval: 0
74 | Value: true
75 | - Alpha: 1
76 | Autocompute Intensity Bounds: true
77 | Autocompute Value Bounds:
78 | Max Value: 10
79 | Min Value: -10
80 | Value: true
81 | Axis: Z
82 | Channel Name: intensity
83 | Class: rviz/LaserScan
84 | Color: 204; 0; 0
85 | Color Transformer: FlatColor
86 | Decay Time: 0
87 | Enabled: true
88 | Invert Rainbow: false
89 | Max Color: 255; 255; 255
90 | Max Intensity: 4096
91 | Min Color: 0; 0; 0
92 | Min Intensity: 0
93 | Name: LaserScan
94 | Position Transformer: XYZ
95 | Queue Size: 10
96 | Selectable: true
97 | Size (Pixels): 5
98 | Size (m): 0.10000000149011612
99 | Style: Points
100 | Topic: /sick_laser_front/scan
101 | Unreliable: false
102 | Use Fixed Frame: true
103 | Use rainbow: true
104 | Value: true
105 | - Alpha: 1
106 | Arrow Length: 1
107 | Axes Length: 0.20000000298023224
108 | Axes Radius: 0.05000000074505806
109 | Class: rviz/PoseArray
110 | Color: 52; 101; 164
111 | Enabled: true
112 | Head Length: 0
113 | Head Radius: 0
114 | Name: PoseArray
115 | Shaft Length: 0.20000000298023224
116 | Shaft Radius: 0.20000000298023224
117 | Shape: Arrow (3D)
118 | Topic: /dr_spaam_detections
119 | Unreliable: false
120 | Value: true
121 | Enabled: true
122 | Global Options:
123 | Background Color: 211; 215; 207
124 | Default Light: true
125 | Fixed Frame: base_footprint
126 | Frame Rate: 30
127 | Name: root
128 | Tools:
129 | - Class: rviz/Interact
130 | Hide Inactive Objects: true
131 | - Class: rviz/MoveCamera
132 | - Class: rviz/Select
133 | - Class: rviz/FocusCamera
134 | - Class: rviz/Measure
135 | - Class: rviz/SetInitialPose
136 | Theta std deviation: 0.2617993950843811
137 | Topic: /initialpose
138 | X std deviation: 0.5
139 | Y std deviation: 0.5
140 | - Class: rviz/SetGoal
141 | Topic: /move_base_simple/goal
142 | - Class: rviz/PublishPoint
143 | Single click: true
144 | Topic: /clicked_point
145 | Value: true
146 | Views:
147 | Current:
148 | Angle: 3.7600014209747314
149 | Class: rviz/TopDownOrtho
150 | Enable Stereo Rendering:
151 | Stereo Eye Separation: 0.05999999865889549
152 | Stereo Focal Distance: 1
153 | Swap Stereo Eyes: false
154 | Value: false
155 | Invert Z Axis: false
156 | Name: Current View
157 | Near Clip Distance: 0.009999999776482582
158 | Scale: 85.89463806152344
159 | Target Frame: sick_laser_front
160 | Value: TopDownOrtho (rviz)
161 | X: 0
162 | Y: 0
163 | Saved: ~
164 | Window Geometry:
165 | Displays:
166 | collapsed: false
167 | Height: 2105
168 | Hide Left Dock: false
169 | Hide Right Dock: false
170 | QMainWindow State: 000000ff00000000fd0000000400000000000002240000079bfc0200000018fb0000001200530065006c0065006300740069006f006e00000001e10000009b0000005c00fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000003d0000079b000000c900fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000000000002160000024500000261fb0000000c004b0069006e00650063007402000001b4000002430000016a000000e2fb0000001400440053004c005200200069006d0061006700650200000c3e00000224000002c0000001a9fb00000028005200470042002000460072006f006e007400200054006f0070002000430061006d00650072006102000007b70000000f000003de00000375fb0000002800460072006f006e0074002000640065007000740068002000700061006e006f00720061006d00610000000000ffffffff0000000000000000fb0000002400460072006f006e00740020005200470042002000700061006e006f00720061006d006102000007c10000001800000262000003dbfb000000260052006500610072002000640065007000740068002000700061006e006f00720061006d0061020000051a0000002e00000266000003cbfb0000002200520065006100720020005200470042002000700061006e006f00720061006d006102000009ee0000001800000265000003cffb0000000c00430061006d00650072006102000000000000003d00000245000001d3fb0000000c00430061006d006500720061000000041d000000160000000000000000fb0000000a0049006d0061006700650200000947000002f3000001e90000014bfb00000008004c0065006600740200000b9a000000f2000002450000009ffb0000000a005200690067006800740200000b9a0000016400000245000000c7fb00000008005200650061007202000000000000023100000245000000fefb0000000a00520069006700680074020000031a0000032c0000015600000120fb0000000
8005200650061007202000004740000032c0000015600000120fb0000000c00430061006d00650072006103000004d30000024b00000175000000dc00000001000001a80000079bfc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073010000003d0000079b000000a400fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e1000001970000000300000ebd0000003efc0100000002fb0000000800540069006d0065010000000000000ebd000002eb00fffffffb0000000800540069006d0065010000000000000450000000000000000000000ae50000079b00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000
171 | Selection:
172 | collapsed: false
173 | Time:
174 | collapsed: false
175 | Tool Properties:
176 | collapsed: false
177 | Views:
178 | collapsed: false
179 | Width: 3773
180 | X: 67
181 | Y: 27
182 |
--------------------------------------------------------------------------------
/dr_spaam_ros/launch/dr_spaam_ros.launch:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/dr_spaam_ros/package.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | dr_spaam_ros
4 | 1.0.0
5 | ROS interface for DR-SPAAM detector
6 |
7 | Dan Jia
8 |
9 |
10 |
11 |
12 | TODO
13 |
14 | catkin
15 | rospy
16 | geometry_msgs
17 | sensor_msgs
18 |
19 |
--------------------------------------------------------------------------------
/dr_spaam_ros/scripts/drow_data_converter.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | from math import sin, cos
4 | import numpy as np
5 |
6 | import rospy
7 | import rosbag
8 |
9 | from geometry_msgs.msg import TransformStamped
10 | from sensor_msgs.msg import LaserScan
11 | from tf2_msgs.msg import TFMessage
12 |
13 |
def load_scans(fname):
    """Load a DROW scan CSV file.

    Each row is ``seq, time, range_0, range_1, ...``.

    Args:
        fname: Path to the ``.csv`` scan file.

    Returns:
        seqs: uint32 array of per-row sequence ids.
        times: float64 array of per-row timestamps (``sequence_to_bag`` passes
            them to ``rospy.Time`` as seconds). float64 is used because float32
            keeps only ~7 significant digits, which would quantize large
            timestamps (e.g. Unix epoch seconds) to steps of about two minutes.
        scans: float32 array of the remaining columns (the ranges), one row
            per scan.
    """
    # atleast_2d keeps a single-row file as shape (1, C) instead of (C,),
    # so the column slicing below works for any number of rows.
    data = np.atleast_2d(np.genfromtxt(fname, delimiter=","))
    seqs = data[:, 0].astype(np.uint32)
    times = data[:, 1].astype(np.float64)
    scans = data[:, 2:].astype(np.float32)
    return seqs, times, scans
18 |
19 |
def load_odoms(fname):
    """Load a DROW odometry file (``.odom2``).

    Each row is ``seq, time, x, y, phi``.

    Args:
        fname: Path to the odometry file.

    Returns:
        seqs: uint32 array of per-row sequence ids.
        times: float64 array of per-row timestamps (``sequence_to_bag`` passes
            them to ``rospy.Time`` as seconds). float64 is used because float32
            keeps only ~7 significant digits, which would quantize large
            timestamps (e.g. Unix epoch seconds) to steps of about two minutes.
        odos: float32 array of the remaining columns (x, y, phi), one row per
            odometry sample.
    """
    # atleast_2d keeps a single-row file as shape (1, C) instead of (C,).
    data = np.atleast_2d(np.genfromtxt(fname, delimiter=","))
    seqs = data[:, 0].astype(np.uint32)
    times = data[:, 1].astype(np.float64)
    odos = data[:, 2:].astype(np.float32)  # x, y, phi
    return seqs, times, odos
25 |
26 |
def sequence_to_bag(seq_name, bag_name):
    """Write one DROW sequence (scan CSV plus its ``odom2`` file) to a rosbag.

    Scans go to ``/sick_laser_front/scan`` as ``sensor_msgs/LaserScan``.
    Odometry samples go to ``/tf`` as a ``base_footprint`` ->
    ``sick_laser_front`` transform.

    Args:
        seq_name: Path to the scan CSV. The odometry path is derived by
            replacing its last three characters with ``odom2``, so this is
            expected to end in ``csv``.
        bag_name: Path of the bag to create (opened in ``'w'`` mode).
    """
    # Laser description shared by every scan; only seq, stamp and ranges are
    # updated per message before it is serialized by bag.write.
    laser_msg = LaserScan()
    laser_msg.header.frame_id = 'sick_laser_front'
    laser_msg.angle_min = np.radians(-225.0 / 2)
    laser_msg.angle_max = np.radians(225.0 / 2)
    laser_msg.range_min = 0.005
    laser_msg.range_max = 100.0
    laser_msg.scan_time = 0.066667
    laser_msg.time_increment = 0.000062
    # Field of view split into 450 steps (presumably one per DROW beam).
    fov = laser_msg.angle_max - laser_msg.angle_min
    laser_msg.angle_increment = fov / 450

    # Planar transform: translation.z and rotation.x/y stay zero for every
    # sample, so they are set once here.
    odom_tf = TransformStamped()
    odom_tf.header.frame_id = 'base_footprint'
    odom_tf.child_frame_id = 'sick_laser_front'
    odom_tf.transform.translation.z = 0.0
    odom_tf.transform.rotation.x = 0.0
    odom_tf.transform.rotation.y = 0.0

    with rosbag.Bag(bag_name, 'w') as bag:
        # Laser scans.
        scan_seqs, scan_times, scan_ranges = load_scans(seq_name)
        for seq_id, t_sec, ranges in zip(scan_seqs, scan_times, scan_ranges):
            stamp = rospy.Time(t_sec)
            laser_msg.header.seq = seq_id
            laser_msg.header.stamp = stamp
            laser_msg.ranges = ranges
            bag.write('/sick_laser_front/scan', laser_msg, t=stamp)

        # Odometry, one TF message per sample.
        odom_path = seq_name[:-3] + 'odom2'
        odom_seqs, odom_times, odom_rows = load_odoms(odom_path)
        for seq_id, t_sec, row in zip(odom_seqs, odom_times, odom_rows):
            stamp = rospy.Time(t_sec)
            half_yaw = row[2] * 0.5
            odom_tf.header.seq = seq_id
            odom_tf.header.stamp = stamp
            odom_tf.transform.translation.x = row[0]
            odom_tf.transform.translation.y = row[1]
            # Yaw-only quaternion for a rotation of row[2] about z.
            odom_tf.transform.rotation.z = sin(half_yaw)
            odom_tf.transform.rotation.w = cos(half_yaw)
            bag.write('/tf', TFMessage([odom_tf]), t=stamp)
67 |
68 |
if __name__ == '__main__':
    # Command line: --seq <scan csv> [--output <bag path>]
    arg_parser = argparse.ArgumentParser(description="arg parser")
    arg_parser.add_argument("--seq", type=str, required=True,
                            help="path to sequence")
    arg_parser.add_argument("--output", type=str, required=False,
                            default="./out.bag")
    cli_args = arg_parser.parse_args()

    sequence_to_bag(seq_name=cli_args.seq, bag_name=cli_args.output)
76 |
--------------------------------------------------------------------------------
/dr_spaam_ros/scripts/node.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import rospy
4 | from dr_spaam_ros.dr_spaam_ros import DrSpaamROS
5 |
6 |
if __name__ == '__main__':
    rospy.init_node('dr_spaam_ros')
    detector_node = None
    try:
        # Reads the node's parameters, builds the detector and creates the
        # publishers and scan subscriber. Keeping a name for the instance
        # makes its lifetime explicit.
        detector_node = DrSpaamROS()
    except rospy.ROSInterruptException:
        pass
    # Hand control to rospy so the scan callback keeps being serviced.
    rospy.spin()
14 |
--------------------------------------------------------------------------------
/dr_spaam_ros/setup.py:
--------------------------------------------------------------------------------
## ! DO NOT MANUALLY INVOKE THIS setup.py, USE CATKIN INSTEAD

# setuptools instead of distutils: distutils is deprecated (PEP 632) and
# removed in Python 3.12.
from setuptools import setup
from catkin_pkg.python_setup import generate_distutils_setup

# fetch values from package.xml
setup_args = generate_distutils_setup(
    packages=['dr_spaam_ros'],
    package_dir={'': 'src'})

setup(**setup_args)
--------------------------------------------------------------------------------
/dr_spaam_ros/src/dr_spaam_ros/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/DR-SPAAM-Detector/e5a5f73f69523b90829be06a2558b597c2934f9f/dr_spaam_ros/src/dr_spaam_ros/__init__.py
--------------------------------------------------------------------------------
/dr_spaam_ros/src/dr_spaam_ros/dr_spaam_ros.py:
--------------------------------------------------------------------------------
1 | # import time
2 | import numpy as np
3 | import rospy
4 |
5 | from sensor_msgs.msg import LaserScan
6 | from geometry_msgs.msg import Point, Pose, PoseArray
7 | from visualization_msgs.msg import Marker
8 |
9 | from dr_spaam.detector import Detector
10 |
11 |
class DrSpaamROS():
    """ROS node that detects pedestrians in 2D laser scans using DR-SPAAM."""

    def __init__(self):
        """
        @brief Read parameters, build the DR-SPAAM detector, set up topics.
        """
        self._read_params()
        self._detector = Detector(
            model_name="DR-SPAAM",
            ckpt_file=self.weight_file,
            gpu=self.gpu,
            stride=self.stride)
        self._init()

    def _read_params(self):
        """
        @brief Reads parameters from ROS server.

        Required private parameters: ~weight_file, ~conf_thresh, ~stride.
        Optional: ~gpu (defaults to True, the previously hard-coded value).
        """
        self.weight_file = rospy.get_param("~weight_file")
        self.conf_thresh = rospy.get_param("~conf_thresh")
        self.stride = rospy.get_param("~stride")
        self.gpu = rospy.get_param("~gpu", True)

    def _init(self):
        """
        @brief Initialize ROS connection.

        Creates the PoseArray detections publisher, the RViz Marker publisher
        and the LaserScan subscriber from their ~publisher/... and
        ~subscriber/... parameters.
        """
        # Publisher
        topic, queue_size, latch = read_publisher_param("detections")
        self._dets_pub = rospy.Publisher(
            topic, PoseArray, queue_size=queue_size, latch=latch)

        topic, queue_size, latch = read_publisher_param("rviz")
        self._rviz_pub = rospy.Publisher(
            topic, Marker, queue_size=queue_size, latch=latch)

        # Subscriber
        topic, queue_size = read_subscriber_param("scan")
        self._scan_sub = rospy.Subscriber(
            topic, LaserScan, self._scan_callback, queue_size=queue_size)

    def _scan_callback(self, msg):
        """
        @brief Run DR-SPAAM on a LaserScan and publish detections and markers.

        Inference is skipped only when neither the detections topic nor the
        RViz marker topic has any subscriber.
        @param msg sensor_msgs/LaserScan; its header is copied to both outputs.
        """
        # Check both publishers: an RViz-only subscriber must still get markers.
        if (self._dets_pub.get_num_connections() == 0
                and self._rviz_pub.get_num_connections() == 0):
            return

        # Configure the detector's laser geometry from the first scan seen.
        if not self._detector.laser_spec_set():
            self._detector.set_laser_spec(angle_inc=msg.angle_increment,
                                          num_pts=len(msg.ranges))

        # Replace invalid returns (exact 0, +/-inf, NaN) with a fixed far value.
        scan = np.array(msg.ranges)
        scan[(scan == 0.0) | ~np.isfinite(scan)] = 29.99

        dets_xy, dets_cls, _ = self._detector(scan)

        # Keep only detections whose confidence reaches the threshold.
        conf_mask = (dets_cls >= self.conf_thresh).reshape(-1)
        dets_xy = dets_xy[conf_mask]
        dets_cls = dets_cls[conf_mask]

        # convert and publish ros msg
        dets_msg = detections_to_pose_array(dets_xy, dets_cls)
        dets_msg.header = msg.header
        self._dets_pub.publish(dets_msg)

        rviz_msg = detections_to_rviz_marker(dets_xy, dets_cls)
        rviz_msg.header = msg.header
        self._rviz_pub.publish(rviz_msg)
80 |
81 |
def detections_to_rviz_marker(dets_xy, dets_cls):
    """
    @brief Build a red LINE_LIST marker drawing one circle per detection.

    Each circle has radius 0.2 and is approximated by 19 line segments
    (20 sample angles from 0 to 2*pi). The marker uses namespace
    "dr_spaam_ros" and id 0; its header is left for the caller to set.
    @param dets_xy  Detection centres, indexed as d_xy[0], d_xy[1].
    @param dets_cls Detection confidences; only paired with dets_xy, not drawn.
    @return visualization_msgs/Marker
    """
    marker = Marker()
    marker.action = Marker.ADD
    marker.ns = "dr_spaam_ros"
    marker.id = 0
    marker.type = Marker.LINE_LIST
    marker.scale.x = 0.03  # line width
    marker.color.r = 1.0  # opaque red
    marker.color.a = 1.0

    radius = 0.2
    theta = np.linspace(0, 2 * np.pi, 20)
    circle = radius * np.stack((np.cos(theta), np.sin(theta)), axis=1)
    # Consecutive sample pairs form the segments of the LINE_LIST.
    seg_starts, seg_ends = circle[:-1], circle[1:]

    def to_point(center_x, center_y, offset):
        # One planar Point at the circle centre shifted by a circle offset.
        pt = Point()
        pt.x = center_x + offset[0]
        pt.y = center_y + offset[1]
        pt.z = 0.0
        return pt

    for xy, _ in zip(dets_xy, dets_cls):
        # If laser is facing front, DR-SPAAM's y-axis aligns with the laser
        # center ray, x-axis points to right, z-axis points upward; hence
        # the swapped components below.
        center_x, center_y = xy[1], xy[0]
        for start, end in zip(seg_starts, seg_ends):
            marker.points.append(to_point(center_x, center_y, start))
            marker.points.append(to_point(center_x, center_y, end))

    return marker
123 |
124 |
def detections_to_pose_array(dets_xy, dets_cls):
    """
    @brief Convert detections to a geometry_msgs/PoseArray.

    The header is left unset for the caller to fill in.
    @param dets_xy  Detection positions, indexed as d_xy[0], d_xy[1].
    @param dets_cls Detection confidences; only paired with dets_xy, not
                    encoded in the message.
    @return PoseArray with one pose per detection and identity orientation.
    """
    pose_array = PoseArray()
    for d_xy, _ in zip(dets_xy, dets_cls):
        # If laser is facing front, DR-SPAAM's y-axis aligns with the laser
        # center ray, x-axis points to right, z-axis points upward
        # NOTE(review): ROS frames (REP 103) have y pointing left, so if
        # DR-SPAAM's x really points right, y here may need to be -d_xy[0];
        # confirm against the Detector's coordinate convention.
        p = Pose()
        p.position.x = d_xy[1]
        p.position.y = d_xy[0]
        p.position.z = 0.0
        # A default-constructed Pose has an all-zero quaternion, which is not a
        # valid rotation; use the identity so consumers such as RViz accept it.
        p.orientation.w = 1.0
        pose_array.poses.append(p)

    return pose_array
137 |
138 |
def read_subscriber_param(name):
    """
    @brief Convenience function to read subscriber parameter.

    Reads ~subscriber/<name>/topic and ~subscriber/<name>/queue_size.
    @param name Subscriber key under ~subscriber/.
    @return (topic, queue_size)
    """
    base_key = "~subscriber/" + name
    topic = rospy.get_param(base_key + "/topic")
    queue_size = rospy.get_param(base_key + "/queue_size")
    return topic, queue_size
146 |
147 |
def read_publisher_param(name):
    """
    @brief Convenience function to read publisher parameter.

    Reads ~publisher/<name>/topic, ~publisher/<name>/queue_size and
    ~publisher/<name>/latch.
    @param name Publisher key under ~publisher/.
    @return (topic, queue_size, latch)
    """
    base_key = "~publisher/" + name
    topic = rospy.get_param(base_key + "/topic")
    queue_size = rospy.get_param(base_key + "/queue_size")
    latch = rospy.get_param(base_key + "/latch")
    return topic, queue_size, latch
--------------------------------------------------------------------------------
/imgs/dets.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/DR-SPAAM-Detector/e5a5f73f69523b90829be06a2558b597c2934f9f/imgs/dets.gif
--------------------------------------------------------------------------------
/imgs/dr_spaam_ros.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/DR-SPAAM-Detector/e5a5f73f69523b90829be06a2558b597c2934f9f/imgs/dr_spaam_ros.gif
--------------------------------------------------------------------------------
/imgs/rosgraph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/DR-SPAAM-Detector/e5a5f73f69523b90829be06a2558b597c2934f9f/imgs/rosgraph.png
--------------------------------------------------------------------------------
/imgs/tracks.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/DR-SPAAM-Detector/e5a5f73f69523b90829be06a2558b597c2934f9f/imgs/tracks.gif
--------------------------------------------------------------------------------