├── .flake8 ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── dr_spaam ├── bin │ ├── plotting │ │ ├── analyze_pseudo_labels.py │ │ ├── get_eer_thresh.py │ │ ├── get_pseudo_label_videos.py │ │ ├── make_videos.sh │ │ └── plot_clustering.py │ ├── setup_jrdb_dataset.py │ └── train.py ├── cfgs │ ├── base_dr_spaam_drow_cfg.yaml │ ├── base_dr_spaam_jrdb_cfg.yaml │ ├── base_drow_drow_cfg.yaml │ └── base_drow_jrdb_cfg.yaml ├── dr_spaam │ ├── __init__.py │ ├── datahandle │ │ ├── __init__.py │ │ ├── _pypcd.py │ │ ├── drow_handle.py │ │ ├── jrdb_handle.py │ │ └── jrdb_handle_det3d.py │ ├── dataset │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── drow_dataset.py │ │ └── jrdb_dataset.py │ ├── detector.py │ ├── model │ │ ├── __init__.py │ │ ├── _common.py │ │ ├── dr_spaam.py │ │ ├── dr_spaam_fn.py │ │ ├── drow_net.py │ │ ├── get_model.py │ │ └── losses.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── logger.py │ │ ├── optim.py │ │ ├── pipeline.py │ │ └── trainer.py │ ├── pseudo_labels.py │ └── utils │ │ ├── __init__.py │ │ ├── jrdb_transforms.py │ │ ├── jrdb_utils.py │ │ ├── plotting.py │ │ ├── precision_recall.py │ │ ├── pytorch_nms │ │ ├── LICENSE │ │ ├── README.md │ │ ├── setup.py │ │ └── src │ │ │ ├── nms.cpp │ │ │ ├── nms │ │ │ └── __init__.py │ │ │ └── nms_kernel.cu │ │ └── utils.py ├── setup.py └── tests │ ├── test_dataloader.py │ ├── test_detector.py │ ├── test_detr_dataloader.py │ ├── test_drow_handle.py │ ├── test_inference_speed.py │ ├── test_jrdb_handle.py │ └── test_jrdb_handle_mayavi.py ├── dr_spaam_ros ├── CMakeLists.txt ├── config │ ├── dr_spaam_ros.yaml │ └── topics.yaml ├── example.rviz ├── launch │ └── dr_spaam_ros.launch ├── package.xml ├── scripts │ ├── drow_data_converter.py │ └── node.py ├── setup.py └── src │ └── dr_spaam_ros │ ├── __init__.py │ └── dr_spaam_ros.py └── imgs ├── dr_spaam_ros_graph.png ├── dr_spaam_ros_teaser.gif └── teaser_1.gif /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | ignore = E203, W503, E741 4 | exclude = */__init__.py 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *__pycache__* 3 | *_ext* 4 | *.bag 5 | *.csv 6 | *.cu.o 7 | *.DS_Store 8 | *.gif 9 | *.ipynb_checkpoints 10 | *.pkl 11 | *.png 12 | *.pth 13 | *.pyc 14 | *.tfevents* 15 | *.yml 16 | 17 | .idea/ 18 | # .vscode/ 19 | *.egg-info/ 20 | build/ 21 | ckpt/ 22 | ckpts/ 23 | ckpts 24 | data 25 | dist/ 26 | devel/ 27 | exp_*/ 28 | experiments 29 | logs 30 | output/ 31 | results/ 32 | result_*/ 33 | wandb 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.8 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: 3.7.9 9 | hooks: 10 | - id: flake8 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Person Detection in 2D Range Data 2 | This repository implements DROW3 ([arXiv](https://arxiv.org/abs/1804.02463)) and DR-SPAAM ([arXiv](https://arxiv.org/abs/2004.14079)), real-time person detectors using 2D LiDARs mounted at ankle or knee height. 
3 | Also included are experiments from *Self-Supervised Person Detection in 2D Range Data using a Calibrated Camera* ([arXiv](https://arxiv.org/abs/2012.08890)).
4 | Pre-trained models (using PyTorch 1.6) can be found in this [Google drive](https://drive.google.com/drive/folders/1Wl2nC8lJ6s9NI1xtWwmxeAUnuxDiiM4W?usp=sharing).
5 | 
6 | ![](imgs/teaser_1.gif)
7 | 
8 | ## News
9 | 
10 | [06-03-2021] Our work has been accepted to ICRA'21! Check out the presentation video [here](https://www.youtube.com/watch?v=f5U1ZfqXtc0).
11 | 
12 | ## Quick start
13 | 
14 | First clone and install the repository
15 | ```
16 | git clone https://github.com/VisualComputingInstitute/2D_lidar_person_detection.git
17 | cd 2D_lidar_person_detection/dr_spaam
18 | python setup.py install
19 | ```
20 | 
21 | Use the `Detector` class to run inference
22 | ```python
23 | import numpy as np
24 | from dr_spaam.detector import Detector
25 | 
26 | ckpt = 'path_to_checkpoint'
27 | detector = Detector(
28 |     ckpt,
29 |     model="DROW3",         # Or DR-SPAAM
30 |     gpu=True,              # Use GPU
31 |     stride=1,              # Optionally downsample scan for faster inference
32 |     panoramic_scan=True    # Set to True if the scan covers 360 degrees
33 | )
34 | 
35 | # tell the detector the field of view of the LiDAR
36 | laser_fov_deg = 360
37 | detector.set_laser_fov(laser_fov_deg)
38 | 
39 | # detection
40 | num_pts = 1091
41 | while True:
42 |     # create a random scan
43 |     scan = np.random.rand(num_pts)  # (N,)
44 | 
45 |     # detect persons
46 |     dets_xy, dets_cls, instance_mask = detector(scan)  # (M, 2), (M,), (N,)
47 | 
48 |     # confidence threshold
49 |     cls_thresh = 0.5
50 |     cls_mask = dets_cls > cls_thresh
51 |     dets_xy = dets_xy[cls_mask]
52 |     dets_cls = dets_cls[cls_mask]
53 | ```
54 | 
55 | ## ROS node
56 | 
57 | ![](imgs/dr_spaam_ros_teaser.gif)
58 | 
59 | ![](imgs/dr_spaam_ros_graph.png)
60 | 
61 | We provide an example ROS node `dr_spaam_ros`.
62 | First install `dr_spaam` into your Python environment.
63 | Then compile the ROS package
64 | ```
65 | catkin build dr_spaam_ros
66 | ```
67 | 
68 | Modify the topics and the path to the pre-trained checkpoint at
69 | `dr_spaam_ros/config/` and launch the node
70 | ```
71 | roslaunch dr_spaam_ros dr_spaam_ros.launch
72 | ```
73 | 
74 | For testing, you can play a rosbag sequence from the JRDB dataset.
75 | For example,
76 | ```
77 | rosbag play JRDB/test_dataset/rosbags/tressider-2019-04-26_0.bag
78 | ```
79 | and use RViz to visualize the inference result.
80 | A simple RViz config is located at `dr_spaam_ros/example.rviz`.
81 | 
82 | In addition, if you want to test with the DROW dataset, you can convert a DROW sequence to a rosbag
83 | ```
84 | python scripts/drow_data_converter.py --seq --output drow.bag
85 | ```
86 | 
87 | ## Training and evaluation
88 | 
89 | Download the [DROW dataset](https://github.com/VisualComputingInstitute/DROW) and the [JackRabbot dataset](https://jrdb.stanford.edu/),
90 | and put them under `dr_spaam/data` as below.
91 | ```
92 | dr_spaam
93 | ├── data
94 | │   ├── DROWv2-data
95 | │   │   ├── test
96 | │   │   ├── train
97 | │   │   ├── val
98 | │   ├── JRDB
99 | │   │   ├── test_dataset
100 | │   │   ├── train_dataset
101 | ...
102 | ```
103 | 
104 | First preprocess the JRDB dataset (extract laser measurements from the raw rosbags and synchronize them with images)
105 | ```
106 | python bin/setup_jrdb_dataset.py
107 | ```
108 | 
109 | To train a network (or evaluate a pretrained checkpoint), run
110 | ```
111 | python bin/train.py --cfg net_cfg.yaml [--ckpt ckpt_file.pth --evaluation]
112 | ```
113 | where `net_cfg.yaml` specifies the training configuration (see examples under `cfgs`).
114 | 
115 | ## Self-supervised training with a calibrated camera
116 | 
117 | If your robot has a calibrated camera (i.e. the transformation between the camera and the LiDAR is known),
118 | you can generate pseudo labels automatically during deployment and fine-tune the detector (no manual labeling needed).
119 | We provide a wrapper function `dr_spaam.pseudo_labels.get_regression_target_using_bounding_boxes()` for conveniently generating pseudo labels.
120 | For experiments using pseudo labels,
121 | check out our paper *Self-Supervised Person Detection in 2D Range Data using a Calibrated Camera* ([arXiv](https://arxiv.org/abs/2012.08890)).
122 | Use the checkpoints in this [Google drive](https://drive.google.com/drive/folders/1Wl2nC8lJ6s9NI1xtWwmxeAUnuxDiiM4W?usp=sharing) to reproduce our results.
123 | 
124 | ## Inference time
125 | On the DROW dataset (450 points, 225 degrees field of view)
126 | |        | AP0.3 | AP0.5 | FPS (RTX 2080 laptop) | FPS (Jetson AGX Xavier) |
127 | |--------|-------|-------|-----------------------|-------------------------|
128 | |DROW3   | 0.638 | 0.659 | 115.7                 | 24.9                    |
129 | |DR-SPAAM| 0.707 | 0.723 | 99.6                  | 22.5                    |
130 | 
131 | On the JackRabbot dataset (1091 points, 360 degrees field of view)
132 | |        | AP0.3 | AP0.5 | FPS (RTX 2080 laptop) | FPS (Jetson AGX Xavier) |
133 | |--------|-------|-------|-----------------------|-------------------------|
134 | |DROW3   | 0.762 | 0.829 | 35.6                  | 10.0                    |
135 | |DR-SPAAM| 0.785 | 0.849 | 29.4                  | 8.8                     |
136 | 
137 | Note: Evaluation on DROW and JackRabbot was done using different models (the APs are not comparable across datasets).
138 | Inference time was measured with PyTorch 1.7 and CUDA 10.2 on the RTX 2080 laptop,
139 | and PyTorch 1.6 and L4T 4.4 on the Jetson AGX Xavier.
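
To benchmark on your own hardware, a minimal timing sketch (reusing the `Detector` API from the Quick start above; the checkpoint path is a placeholder) could look like the following. The repository also ships `dr_spaam/tests/test_inference_speed.py` for this purpose.
```python
import time
import numpy as np
from dr_spaam.detector import Detector

detector = Detector(
    "path_to_checkpoint",  # placeholder, point this to a downloaded .pth file
    model="DROW3",
    gpu=True,
    stride=1,
    panoramic_scan=True,
)
detector.set_laser_fov(360)

scan = np.random.rand(1091)  # dummy scan with JRDB-like resolution
detector(scan)  # warm-up run (CUDA context init, cuDNN benchmark)

n_runs = 100
t0 = time.perf_counter()
for _ in range(n_runs):
    detector(scan)
print(f"FPS: {n_runs / (time.perf_counter() - t0):.1f}")
```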
140 | 
141 | ## Citation
142 | If you use this repo in your project, please cite:
143 | ```BibTeX
144 | @inproceedings{Jia2021Person2DRange,
145 |   title = {{Self-Supervised Person Detection in 2D Range Data using a
146 |           Calibrated Camera}},
147 |   author = {Dan Jia and Mats Steinweg and Alexander Hermans and Bastian Leibe},
148 |   booktitle = {International Conference on Robotics and Automation (ICRA)},
149 |   year = {2021}
150 | }
151 | 
152 | @inproceedings{Jia2020DRSPAAM,
153 |   title = {{DR-SPAAM: A Spatial-Attention and Auto-regressive
154 |           Model for Person Detection in 2D Range Data}},
155 |   author = {Dan Jia and Alexander Hermans and Bastian Leibe},
156 |   booktitle = {International Conference on Intelligent Robots and Systems (IROS)},
157 |   year = {2020}
158 | }
159 | ```
160 | 
--------------------------------------------------------------------------------
/dr_spaam/bin/plotting/get_eer_thresh.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | 
4 | p1 = "/home/jia/git/awesome_repos/2D_lidar_person_detection/dr_spaam/logs/20210520_224024_drow_jrdb_EVAL/output/val/e000000/evaluation/all/result_r05.pkl"
5 | p2 = "/home/jia/git/awesome_repos/2D_lidar_person_detection/dr_spaam/logs/20210520_231344_drow_jrdb_EVAL/output/val/e000000/evaluation/all/result_r05.pkl"
6 | 
7 | for p in (p1, p2):
8 |     with open(p, "rb") as f:
9 |         res = pickle.load(f)
10 | 
11 |     eer = res["eer"]
12 |     arg = np.argmin(np.abs(res["precisions"] - eer))
13 |     print(res["thresholds"][arg], " ", res["precisions"][arg], " ", res["recalls"][arg])
14 | 
--------------------------------------------------------------------------------
/dr_spaam/bin/plotting/get_pseudo_label_videos.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from tqdm import tqdm
4 | import yaml
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | 
8 | import torch
9 | 
10 | from dr_spaam.dataset import get_dataloader
11 | import dr_spaam.utils.jrdb_transforms as jt
12 | import dr_spaam.utils.utils as u
13 | 
14 | from dr_spaam.pipeline.logger import Logger
15 | from dr_spaam.model.get_model import get_model
16 | 
17 | # for _PLOTTING_INTERVAL = 2
18 | # 000579 = 14s
19 | # ffmpeg -r 20 -pattern_type glob -i 'packard-poster-session-2019-03-20_1/scan_pl_*.png' -c:v libx264 -vf fps=30 -pix_fmt yuv420p out_pl.mp4
20 | 
21 | _X_LIM = (-15, 15)
22 | # _Y_LIM = (-10, 4)
23 | _Y_LIM = (-7, 7)
24 | 
25 | _PLOTTING_INTERVAL = 2
26 | _MAX_COUNT = 1e9
27 | # _MAX_COUNT = 1e1
28 | _SEQ_MAX_COUNT = 2000
29 | 
30 | # _COLOR_CLOSE_HSV = (1.0, 0.59, 0.75)
31 | _COLOR_CLOSE_HSV = (0.0, 1.0, 1.0)
32 | _COLOR_FAR_HSV = (0.0, 0.0, 1.0)
33 | _COLOR_DIST_RANGE = (0.0, 20.0)
34 | 
35 | # _SPLIT = "train"
36 | _SPLIT = "val"
37 | 
38 | _SAVE_DIR = f"/globalwork/jia/tmp/pseudo_label_videos/{_SPLIT}"
39 | os.makedirs(_SAVE_DIR, exist_ok=True)
40 | 
41 | 
42 | def _get_bounding_box_plotting_vertices(x0, y0, x1, y1):
43 |     return np.array([(x0, y0), (x0, y1), (x1, y1), (x1, y0), (x0, y0)])
44 | 
45 | 
46 | def _distance_to_bgr_color(dist):
47 |     dist_normalized = (
48 |         np.clip(dist, _COLOR_DIST_RANGE[0], _COLOR_DIST_RANGE[1]) / _COLOR_DIST_RANGE[1]
49 |     ).reshape(-1, 1)
50 | 
51 |     c_hsv = (
52 |         np.array(_COLOR_CLOSE_HSV).reshape(1, -1) * (1.0 - dist_normalized)
53 |         + np.array(_COLOR_FAR_HSV).reshape(1, -1) * dist_normalized
54 |     ).astype(np.float32)
55 |     c_hsv = c_hsv[None, ...]
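    # NOTE: the [None, ...] above adds a leading axis, turning the (N, 3) HSV array
    # into a (1, N, 3) "image", since cv2.cvtColor operates on image-shaped arrays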
56 | c_bgr = cv2.cvtColor(c_hsv, cv2.COLOR_HSV2RGB) 57 | 58 | return c_bgr[0] 59 | 60 | 61 | def _plot_frame_im(batch_dict, ib, show_pseudo_labels=False): 62 | frame_id = f"{batch_dict['frame_id'][ib]:06d}" 63 | sequence = batch_dict["sequence"][ib] 64 | 65 | im = batch_dict["im_data"][ib]["stitched_image0"] 66 | crop_min_x = 0 67 | im = im[:, crop_min_x:] 68 | height = im.shape[0] 69 | width = im.shape[1] 70 | dpi = height / 1.0 71 | 72 | fig = plt.figure() 73 | fig.set_size_inches(1.0 * width / height, 1, forward=False) 74 | ax_im = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) 75 | fig.add_axes(ax_im) 76 | 77 | # image 78 | ax_im.axis("off") 79 | ax_im.imshow(im) 80 | plt.xlim(0, width) 81 | plt.ylim(height, 0) 82 | 83 | # laser points on image 84 | scan_r = batch_dict["scans"][ib][-1] 85 | scan_phi = batch_dict["scan_phi"][ib] 86 | scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi) 87 | scan_z = batch_dict["laser_z"][ib] 88 | scan_xyz_laser = np.stack((scan_x, -scan_y, scan_z), axis=0) # in JRDB laser frame 89 | p_xy, ib_mask = jt.transform_pts_laser_to_stitched_im(scan_xyz_laser) 90 | 91 | if show_pseudo_labels: 92 | # detection bounding box 93 | for box_dict in batch_dict["im_dets"][ib]: 94 | x0, y0, w, h = box_dict["box"] 95 | x1 = x0 + w 96 | y1 = y0 + h 97 | verts = _get_bounding_box_plotting_vertices(x0, y0, x1, y1) 98 | ax_im.plot( 99 | verts[:, 0] - crop_min_x, verts[:, 1], c=(0.0, 0.0, 1.0), alpha=0.3 100 | ) 101 | # c = max(float(box_dict["score"]) - 0.5, 0) * 2.0 102 | # ax_im.plot(verts[:, 0] - crop_min_x, verts[:, 1], c=(1.0 - c, 1.0 - c, 1.0)) 103 | # ax_im.plot(verts[:, 0] - crop_min_x, verts[:, 1], 104 | # c=(0.0, 0.0, 1.0), alpha=1.0) 105 | 106 | # x1_large = x1 + 0.05 * w 107 | # x0_large = x0 - 0.05 * w 108 | # y1_large = y1 + 0.05 * w 109 | # y0_large = y0 - 0.05 * w 110 | # in_box_mask = np.logical_and( 111 | # np.logical_and(p_xy[0] > x0_large, p_xy[0] < x1_large), 112 | # np.logical_and(p_xy[1] > y0_large, p_xy[1] < y1_large) 113 | # ) 114 | # neg_mask[in_box_mask] = False 115 | 116 | for box in batch_dict["pseudo_label_boxes"][ib]: 117 | x0, y0, x1, y1 = box 118 | verts = _get_bounding_box_plotting_vertices(x0, y0, x1, y1) 119 | ax_im.plot(verts[:, 0] - crop_min_x, verts[:, 1], c="green") 120 | 121 | # overlay only pseudo-label laser points on image 122 | pl_pos_mask = np.logical_and(batch_dict["target_cls"][ib] == 1, ib_mask) 123 | pl_neg_mask = np.logical_and(batch_dict["target_cls"][ib] == 0, ib_mask) 124 | ax_im.scatter( 125 | p_xy[0, pl_pos_mask] - crop_min_x, p_xy[1, pl_pos_mask], s=1, color="green", 126 | ) 127 | ax_im.scatter( 128 | p_xy[0, pl_neg_mask] - crop_min_x, 129 | p_xy[1, pl_neg_mask], 130 | s=1, 131 | color="orange", 132 | ) 133 | 134 | fig_file = os.path.join(_SAVE_DIR, f"figs/{sequence}/im_pl_{frame_id}.png") 135 | else: 136 | # overlay all laser points on image 137 | c_bgr = _distance_to_bgr_color(scan_r) 138 | ax_im.scatter( 139 | p_xy[0, ib_mask] - crop_min_x, p_xy[1, ib_mask], s=1, color=c_bgr[ib_mask] 140 | ) 141 | fig_file = os.path.join(_SAVE_DIR, f"figs/{sequence}/im_raw_{frame_id}.png") 142 | 143 | # save fig 144 | os.makedirs(os.path.dirname(fig_file), exist_ok=True) 145 | fig.savefig(fig_file, dpi=dpi) 146 | plt.close(fig) 147 | 148 | 149 | def _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, pred_cls_p, pred_reg_p): 150 | frame_id = f"{batch_dict['frame_id'][ib]:06d}" 151 | sequence = batch_dict["sequence"][ib] 152 | 153 | fig = plt.figure(figsize=(10, 10)) 154 | ax = fig.add_subplot(111) 155 | 156 | ax.set_xlim(_X_LIM[0], _X_LIM[1]) 157 | 
ax.set_ylim(_Y_LIM[0], _Y_LIM[1]) 158 | ax.set_xlabel("x [m]") 159 | ax.set_ylabel("y [m]") 160 | ax.set_aspect("equal") 161 | # ax.set_title(f"Frame {data_dict['idx'][ib]}. Press any key to exit.") 162 | 163 | # scan and cls label 164 | scan_r = batch_dict["scans"][ib][-1] 165 | scan_phi = batch_dict["scan_phi"][ib] 166 | scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi) 167 | ax.scatter(scan_x, scan_y, s=0.5, c="blue") 168 | 169 | # annotation 170 | ann = batch_dict["dets_wp"][ib] 171 | ann_valid_mask = batch_dict["anns_valid_mask"][ib] 172 | if len(ann) > 0: 173 | ann = np.array(ann) 174 | det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1]) 175 | for x, y, valid in zip(det_x, det_y, ann_valid_mask): 176 | if valid: 177 | # c = plt.Circle((x, y), radius=0.1, color="red", fill=True) 178 | c = plt.Circle( 179 | (x, y), radius=0.4, color="red", fill=False, linestyle="--" 180 | ) 181 | ax.add_artist(c) 182 | 183 | # plot detections 184 | if pred_cls is not None and pred_reg is not None: 185 | dets_xy, dets_cls, _ = u.nms_predicted_center( 186 | scan_r, scan_phi, pred_cls[ib].reshape(-1), pred_reg[ib] 187 | ) 188 | dets_xy = dets_xy[dets_cls >= 0.9438938] # at EER 189 | if len(dets_xy) > 0: 190 | for x, y in dets_xy: 191 | c = plt.Circle((x, y), radius=0.4, color="green", fill=False) 192 | ax.add_artist(c) 193 | fig_file = os.path.join(_SAVE_DIR, f"figs/{sequence}/scan_det_{frame_id}.png") 194 | 195 | # plot in addition detections from a pre-trained 196 | if pred_cls_p is not None and pred_reg_p is not None: 197 | dets_xy, dets_cls, _ = u.nms_predicted_center( 198 | scan_r, scan_phi, pred_cls_p[ib].reshape(-1), pred_reg_p[ib] 199 | ) 200 | dets_xy = dets_xy[dets_cls > 0.29919282] # at EER 201 | if len(dets_xy) > 0: 202 | for x, y in dets_xy: 203 | c = plt.Circle((x, y), radius=0.4, color="green", fill=False) 204 | ax.add_artist(c) 205 | # plot pre-trained detections only 206 | elif pred_cls_p is not None and pred_reg_p is not None: 207 | dets_xy, dets_cls, _ = u.nms_predicted_center( 208 | scan_r, scan_phi, pred_cls_p[ib].reshape(-1), pred_reg_p[ib] 209 | ) 210 | dets_xy = dets_xy[dets_cls > 0.29919282] # at EER 211 | if len(dets_xy) > 0: 212 | for x, y in dets_xy: 213 | c = plt.Circle((x, y), radius=0.4, color="green", fill=False) 214 | ax.add_artist(c) 215 | fig_file = os.path.join( 216 | _SAVE_DIR, f"figs/{sequence}/scan_pretrain_{frame_id}.png" 217 | ) 218 | # plot pseudo-labels only 219 | else: 220 | pl_neg_mask = batch_dict["target_cls"][ib] == 0 221 | ax.scatter(scan_x[pl_neg_mask], scan_y[pl_neg_mask], s=0.5, c="orange") 222 | 223 | pl_xy = batch_dict["pseudo_label_loc_xy"][ib] 224 | if len(pl_xy) > 0: 225 | for x, y in pl_xy: 226 | c = plt.Circle((x, y), radius=0.4, color="green", fill=False) 227 | ax.add_artist(c) 228 | fig_file = os.path.join(_SAVE_DIR, f"figs/{sequence}/scan_pl_{frame_id}.png") 229 | 230 | # save fig 231 | os.makedirs(os.path.dirname(fig_file), exist_ok=True) 232 | fig.savefig(fig_file, dpi=200) 233 | plt.close(fig) 234 | 235 | 236 | def plot_pseudo_label_for_all_frames(): 237 | with open("./cfgs/base_drow_jrdb_cfg.yaml", "r") as f: 238 | cfg = yaml.safe_load(f) 239 | cfg["dataset"]["pseudo_label"] = True 240 | cfg["dataset"]["pl_correction_level"] = 0 241 | 242 | test_loader = get_dataloader( 243 | split=_SPLIT, 244 | batch_size=1, 245 | num_workers=1, 246 | shuffle=False, 247 | dataset_cfg=cfg["dataset"], 248 | ) 249 | 250 | model = get_model(cfg["model"]) 251 | model.cuda() 252 | model.eval() 253 | 254 | logger = Logger(cfg["pipeline"]["Logger"]) 255 | 
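    # NOTE: the checkpoint names suggest that the model loaded below was trained on
    # JRDB pseudo-labels (phce = partially Huberised cross entropy), while
    # model_pretrain further down is a DROW-trained baseline used for comparison plots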
logger.load_ckpt("./ckpts/ckpt_jrdb_pl_drow3_phce_e40.pth", model) 256 | 257 | model_pretrain = get_model(cfg["model"]) 258 | model_pretrain.cuda() 259 | model_pretrain.eval() 260 | logger.load_ckpt("./ckpts/ckpt_drow_drow3_e40.pth", model_pretrain) 261 | 262 | # generate pseudo labels for all sample 263 | seq_count = 0 264 | for count, batch_dict in enumerate(tqdm(test_loader)): 265 | if batch_dict["first_frame"][0]: 266 | print(f"new seq, reset count, idx {count}") 267 | seq_count = 0 268 | 269 | if seq_count > _SEQ_MAX_COUNT: 270 | continue 271 | else: 272 | seq_count += 1 273 | 274 | if count >= _MAX_COUNT: 275 | break 276 | 277 | with torch.no_grad(): 278 | net_input = torch.from_numpy(batch_dict["input"]).cuda().float() 279 | pred_cls, pred_reg = model(net_input) 280 | pred_cls = torch.sigmoid(pred_cls).data.cpu().numpy() 281 | pred_reg = pred_reg.data.cpu().numpy() 282 | 283 | pred_cls_p, pred_reg_p = model_pretrain(net_input) 284 | pred_cls_p = torch.sigmoid(pred_cls_p).data.cpu().numpy() 285 | pred_reg_p = pred_reg_p.data.cpu().numpy() 286 | 287 | if count % _PLOTTING_INTERVAL == 0: 288 | for ib in range(len(batch_dict["input"])): 289 | # generate sequence videos 290 | _plot_frame_im(batch_dict, ib, False) # image 291 | _plot_frame_im(batch_dict, ib, True) # image 292 | _plot_frame_pts( 293 | batch_dict, ib, None, None, None, None 294 | ) # pseudo-label 295 | _plot_frame_pts( 296 | batch_dict, ib, pred_cls, pred_reg, None, None 297 | ) # detections 298 | 299 | # # use to generate comparsion between pre-trained and pseudo-label trained # noqa 300 | # _plot_frame_pts(batch_dict, ib, pred_cls, pred_reg, None, None) 301 | # _plot_frame_pts(batch_dict, ib, None, None, pred_cls_p, pred_reg_p) 302 | 303 | 304 | def plot_color_bar(): 305 | dist = np.linspace(0, 20.0, int(1e4)) 306 | c_bgr = _distance_to_bgr_color(dist) 307 | c_bgr = np.repeat(c_bgr[None, ...], 1000, axis=0) 308 | 309 | fig = plt.figure() 310 | fig.set_size_inches(10, 1, forward=False) 311 | ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) 312 | ax.set_axis_off() 313 | fig.add_axes(ax) 314 | ax.imshow(c_bgr) 315 | 316 | fig_file = os.path.join(_SAVE_DIR, "color_bar.pdf") 317 | fig.savefig(fig_file) 318 | plt.close(fig) 319 | 320 | 321 | if __name__ == "__main__": 322 | plot_color_bar() 323 | plot_pseudo_label_for_all_frames() 324 | -------------------------------------------------------------------------------- /dr_spaam/bin/plotting/make_videos.sh: -------------------------------------------------------------------------------- 1 | for DIR in figs/*; do 2 | if [[ ! -d $DIR ]]; then 3 | continue 4 | fi 5 | 6 | if [[ ! 
-d videos ]]; then
7 |         mkdir videos
8 |     fi
9 | 
10 |     seq=$(basename -- $DIR)
11 | 
12 |     ffmpeg -r 20 -pattern_type glob -i "${DIR}/im_raw_*.png" -c:v libx264 -vf fps=30 -pix_fmt yuv420p "videos/${seq}__im_raw.mp4"
13 |     ffmpeg -r 20 -pattern_type glob -i "${DIR}/im_pl_*.png" -c:v libx264 -vf fps=30 -pix_fmt yuv420p "videos/${seq}__im_pl.mp4"
14 |     ffmpeg -r 20 -pattern_type glob -i "${DIR}/scan_det_*.png" -c:v libx264 -vf fps=30 -pix_fmt yuv420p "videos/${seq}__det.mp4"
15 |     ffmpeg -r 20 -pattern_type glob -i "${DIR}/scan_pl_*.png" -c:v libx264 -vf fps=30 -pix_fmt yuv420p "videos/${seq}__pl.mp4"
16 | done
17 | 
--------------------------------------------------------------------------------
/dr_spaam/bin/setup_jrdb_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import os
4 | import rosbag
5 | import shutil
6 | 
7 | # Set root dir to JRDB
8 | _jrdb_dir = "./data/JRDB"
9 | 
10 | # Variables defining output location.
11 | # Do not change unless you know what you are doing.
12 | _output_laser_dir_name = "lasers"
13 | _output_frames_laser_im_fname = "frames_img_laser.json"
14 | _output_frames_laser_pc_fname = "frames_pc_laser.json"
15 | _output_laser_timestamp_fname = "timestamps.txt"
16 | 
17 | 
18 | def _laser_idx_to_fname(idx):
19 |     return str(idx).zfill(6) + ".txt"
20 | 
21 | 
22 | def extract_laser_from_rosbag(split):
23 |     """Extract and save combined laser scans from a rosbag. Existing files will be overwritten.
24 |     """
25 |     data_dir = os.path.join(_jrdb_dir, split + "_dataset")
26 | 
27 |     timestamp_dir = os.path.join(data_dir, "timestamps")
28 |     bag_dir = os.path.join(data_dir, "rosbags")
29 |     sequence_names = os.listdir(timestamp_dir)
30 | 
31 |     laser_dir = os.path.join(data_dir, _output_laser_dir_name)
32 |     if os.path.exists(laser_dir):
33 |         shutil.rmtree(laser_dir)
34 |     os.mkdir(laser_dir)
35 | 
36 |     for idx, seq_name in enumerate(sequence_names):
37 |         seq_laser_dir = os.path.join(laser_dir, seq_name)
38 |         os.mkdir(seq_laser_dir)
39 | 
40 |         bag_file = os.path.join(bag_dir, seq_name + ".bag")
41 |         print(
42 |             "({}/{}) Extract laser from {} to {}".format(
43 |                 idx + 1, len(sequence_names), bag_file, seq_laser_dir
44 |             )
45 |         )
46 |         bag = rosbag.Bag(bag_file)
47 | 
48 |         # extract all laser msgs
49 |         timestamp_list = []
50 |         for count, (topic, msg, t) in enumerate(
51 |             bag.read_messages(topics=["segway/scan_multi"])
52 |         ):
53 |             scan = np.array(msg.ranges)
54 |             fname = _laser_idx_to_fname(count)
55 |             np.savetxt(os.path.join(seq_laser_dir, fname), scan, newline=" ")
56 | 
57 |             timestamp_list.append(t.to_sec())
58 | 
59 |         np.savetxt(
60 |             os.path.join(seq_laser_dir, _output_laser_timestamp_fname),
61 |             np.array(timestamp_list),
62 |             newline=" ",
63 |         )
64 | 
65 |         bag.close()
66 | 
67 | 
68 | def _match_pc_im_laser_one_sequence(split, sequence_name):
69 |     """Write in the timestamp dir a json file that contains urls to the matching pointcloud,
70 |     laser, and image. Existing files will be overwritten. The pointcloud is used as
71 |     the main sensor, to which the other sensors are synchronized.
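    Synchronization is done by nearest-timestamp matching (argmin over the
    absolute time difference between sensors), as implemented below.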
72 | 73 | Args: 74 | split (str): "train" or "test" 75 | sequence_name (str): 76 | """ 77 | data_dir = os.path.join(_jrdb_dir, split + "_dataset") 78 | 79 | timestamp_dir = os.path.join(data_dir, "timestamps", sequence_name) 80 | laser_dir = os.path.join(data_dir, "lasers", sequence_name) 81 | 82 | # pc frames 83 | pc_frames_file = os.path.join(timestamp_dir, "frames_pc.json") 84 | with open(pc_frames_file, "r") as f: 85 | pc_frames = json.load(f)["data"] 86 | 87 | # im frames 88 | im_frames_file = os.path.join(timestamp_dir, "frames_img.json") 89 | with open(im_frames_file, "r") as f: 90 | im_frames = json.load(f)["data"] 91 | 92 | # synchronize pc and im frame 93 | pc_timestamp = np.array([float(f["timestamp"]) for f in pc_frames]) 94 | im_timestamp = np.array([float(f["timestamp"]) for f in im_frames]) 95 | 96 | pc_im_ft_diff = np.abs(pc_timestamp.reshape(-1, 1) - im_timestamp.reshape(1, -1)) 97 | pc_im_matching_inds = pc_im_ft_diff.argmin(axis=1) 98 | 99 | # synchronize pc and laser 100 | laser_timestamp = np.loadtxt( 101 | os.path.join(laser_dir, _output_laser_timestamp_fname), dtype=np.float64 102 | ) 103 | pc_laser_ft_diff = np.abs( 104 | pc_timestamp.reshape(-1, 1) - laser_timestamp.reshape(1, -1) 105 | ) 106 | pc_laser_matching_inds = pc_laser_ft_diff.argmin(axis=1) 107 | 108 | # create a merged frame dict 109 | output_frames = [] 110 | for i in range(len(pc_frames)): 111 | frame = { 112 | "pc_frame": pc_frames[i], 113 | "im_frame": im_frames[pc_im_matching_inds[i]], 114 | "laser_frame": { 115 | "url": os.path.join( 116 | _output_laser_dir_name, 117 | sequence_name, 118 | _laser_idx_to_fname(pc_laser_matching_inds[i]), 119 | ), 120 | "name": "laser_combined", 121 | "timestamp": laser_timestamp[pc_laser_matching_inds[i]], 122 | }, 123 | "timestamp": pc_frames[i]["timestamp"], 124 | "frame_id": pc_frames[i]["frame_id"], 125 | } 126 | 127 | # correct file url for pc and im 128 | for pc_dict in frame["pc_frame"]["pointclouds"]: 129 | f_name = os.path.basename(pc_dict["url"]) 130 | pc_dict["url"] = os.path.join( 131 | "pointclouds", pc_dict["name"], sequence_name, f_name 132 | ) 133 | 134 | for im_dict in frame["im_frame"]["cameras"]: 135 | f_name = os.path.basename(im_dict["url"]) 136 | cam_name = ( 137 | "image_stitched" 138 | if im_dict["name"] == "stitched_image0" 139 | else im_dict["name"][:-1] + "_" + im_dict["name"][-1] 140 | ) 141 | im_dict["url"] = os.path.join("images", cam_name, sequence_name, f_name) 142 | 143 | output_frames.append(frame) 144 | 145 | # write to file 146 | output_dict = {"data": output_frames} 147 | f_name = os.path.join(timestamp_dir, "frames_pc_im_laser.json") 148 | with open(f_name, "w") as fp: 149 | json.dump(output_dict, fp) 150 | 151 | 152 | def match_pc_im_laser(split): 153 | sequence_names = os.listdir( 154 | os.path.join(_jrdb_dir, split + "_dataset", "timestamps") 155 | ) 156 | for idx, seq_name in enumerate(sequence_names): 157 | print( 158 | "({}/{}) Match sensor data for sequence {}".format( 159 | idx + 1, len(sequence_names), seq_name 160 | ) 161 | ) 162 | _match_pc_im_laser_one_sequence(split, seq_name) 163 | 164 | 165 | if __name__ == "__main__": 166 | extract_laser_from_rosbag("train") 167 | match_pc_im_laser("train") 168 | extract_laser_from_rosbag("test") 169 | match_pc_im_laser("test") 170 | -------------------------------------------------------------------------------- /dr_spaam/bin/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import torch 4 | 5 | 
from dr_spaam.dataset import get_dataloader 6 | from dr_spaam.pipeline.pipeline import Pipeline 7 | from dr_spaam.model.get_model import get_model 8 | 9 | 10 | def run_training(model, pipeline, cfg): 11 | # main train loop 12 | train_loader = get_dataloader( 13 | split="train", shuffle=True, dataset_cfg=cfg["dataset"], **cfg["dataloader"] 14 | ) 15 | val_loader = get_dataloader( 16 | split="val", shuffle=True, dataset_cfg=cfg["dataset"], **cfg["dataloader"] 17 | ) 18 | status = pipeline.train(model, train_loader, val_loader) 19 | 20 | # test after training 21 | if not status: 22 | test_loader = get_dataloader( 23 | split="test", 24 | batch_size=1, 25 | num_workers=1, 26 | shuffle=False, 27 | dataset_cfg=cfg["dataset"], 28 | ) 29 | pipeline.evaluate(model, test_loader, tb_prefix="TEST") 30 | 31 | 32 | def run_evaluation(model, pipeline, cfg): 33 | val_loader = get_dataloader( 34 | split="val", 35 | batch_size=1, 36 | num_workers=1, 37 | shuffle=False, 38 | dataset_cfg=cfg["dataset"], 39 | ) 40 | pipeline.evaluate(model, val_loader, tb_prefix="VAL") 41 | 42 | test_loader = get_dataloader( 43 | split="test", 44 | batch_size=1, 45 | num_workers=1, 46 | shuffle=False, 47 | dataset_cfg=cfg["dataset"], 48 | ) 49 | pipeline.evaluate(model, test_loader, tb_prefix="TEST") 50 | 51 | 52 | if __name__ == "__main__": 53 | # Run benchmark to select fastest implementation of ops. 54 | torch.backends.cudnn.benchmark = True 55 | 56 | parser = argparse.ArgumentParser(description="arg parser") 57 | parser.add_argument( 58 | "--cfg", type=str, required=True, help="configuration of the experiment" 59 | ) 60 | parser.add_argument("--ckpt", type=str, required=False, default=None) 61 | parser.add_argument("--cont", default=False, action="store_true") 62 | parser.add_argument("--tmp", default=False, action="store_true") 63 | parser.add_argument("--evaluation", default=False, action="store_true") 64 | args = parser.parse_args() 65 | 66 | with open(args.cfg, "r") as f: 67 | cfg = yaml.safe_load(f) 68 | cfg["pipeline"]["Logger"]["backup_list"].append(args.cfg) 69 | if args.tmp: 70 | cfg["pipeline"]["Logger"]["tag"] += "_TMP" 71 | 72 | model = get_model(cfg["model"]) 73 | model.cuda() 74 | 75 | pipeline = Pipeline(model, cfg["pipeline"]) 76 | 77 | if args.ckpt: 78 | pipeline.load_ckpt(model, args.ckpt) 79 | elif args.cont and pipeline.sigterm_ckpt_exists(): 80 | pipeline.load_sigterm_ckpt(model) 81 | 82 | # dirty fix to avoid repeatative entries in cfg file 83 | cfg["dataset"]["mixup_alpha"] = cfg["model"]["mixup_alpha"] 84 | 85 | # training or evaluation 86 | if not args.evaluation: 87 | run_training(model, pipeline, cfg) 88 | else: 89 | run_evaluation(model, pipeline, cfg) 90 | 91 | pipeline.close() 92 | -------------------------------------------------------------------------------- /dr_spaam/cfgs/base_dr_spaam_drow_cfg.yaml: -------------------------------------------------------------------------------- 1 | # This file is intended as a config template for generate_experiments.py 2 | # Modify with caution 3 | 4 | model: 5 | # type: "drow" 6 | # kwargs: 7 | # dropout: 0.5 8 | # # focal_loss_gamma: 0.0 9 | 10 | type: "dr-spaam" 11 | kwargs: 12 | dropout: 0.5 13 | num_pts: 56 14 | embedding_length: 128 15 | alpha: 0.5 16 | window_size: 11 17 | panoramic_scan: False 18 | # focal_loss_gamma: 0.0 19 | 20 | # for coping with noisy pseudo labels 21 | cls_loss: 22 | type: 0 # 0 BCE, 1 SymmetricBCE, 2 PartiallyHuberisedBCE 23 | 24 | # # SymmetricBCE 25 | # kwargs: 26 | # alpha: 0.1 27 | # beta: 1.0 28 | 29 | # # 
PartiallyHuberisedBCE 30 | # kwargs: 31 | # tau: 3.0 32 | 33 | mixup_alpha: 0.0 34 | mixup_w: 0.0 35 | self_paced: False 36 | 37 | dataset: 38 | augment_data: False 39 | person_only: True 40 | pseudo_label: False # only matters for JRDB 41 | # For ablation study, remove the wrong label from pseudo labels. 42 | # 0 no correction, 1 remove false positives, 2 remove false negatives, 3 both. 43 | pl_correction_level: 0 44 | 45 | # DataHandle: 46 | # data_dir: "./data/JRDB" # ./data/JRDB or ./data/DROWv2-data 47 | # num_scans: 10 48 | # scan_stride: 1 49 | # tracking: False 50 | 51 | DataHandle: 52 | data_dir: "./data/DROWv2-data" # ./data/JRDB or ./data/DROWv2-data 53 | num_scans: 10 54 | scan_stride: 1 55 | 56 | cutout_kwargs: 57 | fixed: True 58 | centered: True 59 | window_width: 1.0 60 | window_depth: 0.5 61 | num_cutout_pts: 56 62 | padding_val: 29.99 63 | area_mode: True 64 | 65 | dataloader: 66 | batch_size: 8 67 | num_workers: 8 68 | 69 | pipeline: 70 | Trainer: 71 | grad_norm_clip: -1.0 72 | ckpt_interval: 5 73 | eval_interval: 5 74 | epoch: 40 75 | 76 | Optim: 77 | scheduler_kwargs: 78 | epoch0: 0 79 | epoch1: 40 80 | lr0: 1.e-3 81 | lr1: 1.e-6 82 | 83 | Logger: 84 | log_dir: "./logs/" 85 | tag: "dr_spaam" 86 | log_fname: "log.txt" 87 | backup_list: [] 88 | 89 | -------------------------------------------------------------------------------- /dr_spaam/cfgs/base_dr_spaam_jrdb_cfg.yaml: -------------------------------------------------------------------------------- 1 | # This file is intended as a config template for generate_experiments.py 2 | # Modify with caution 3 | 4 | model: 5 | # type: "drow" 6 | # kwargs: 7 | # dropout: 0.5 8 | # # focal_loss_gamma: 0.0 9 | 10 | type: "dr-spaam" 11 | kwargs: 12 | dropout: 0.5 13 | num_pts: 56 14 | embedding_length: 128 15 | alpha: 0.5 16 | window_size: 17 17 | panoramic_scan: True 18 | # focal_loss_gamma: 0.0 19 | 20 | # for coping with noisy pseudo labels 21 | cls_loss: 22 | type: 0 # 0 BCE, 1 SymmetricBCE, 2 PartiallyHuberisedBCE 23 | 24 | # # SymmetricBCE 25 | # kwargs: 26 | # alpha: 0.1 27 | # beta: 1.0 28 | 29 | # # PartiallyHuberisedBCE 30 | # kwargs: 31 | # tau: 3.0 32 | 33 | mixup_alpha: 0.0 34 | mixup_w: 0.0 35 | self_paced: False 36 | 37 | dataset: 38 | augment_data: False 39 | person_only: True 40 | pseudo_label: False # only matters for JRDB 41 | # For ablation study, remove the wrong label from pseudo labels. 42 | # 0 no correction, 1 remove false positives, 2 remove false negatives, 3 both. 
43 | pl_correction_level: 0 44 | 45 | DataHandle: 46 | data_dir: "./data/JRDB" # ./data/JRDB or ./data/DROWv2-data 47 | num_scans: 10 48 | scan_stride: 1 49 | tracking: False 50 | 51 | # DataHandle: 52 | # data_dir: "./data/DROWv2-data" # ./data/JRDB or ./data/DROWv2-data 53 | # num_scans: 10 54 | # scan_stride: 1 55 | 56 | cutout_kwargs: 57 | fixed: True 58 | centered: True 59 | window_width: 1.0 60 | window_depth: 0.5 61 | num_cutout_pts: 56 62 | padding_val: 29.99 63 | area_mode: True 64 | 65 | dataloader: 66 | batch_size: 4 67 | num_workers: 4 68 | 69 | pipeline: 70 | Trainer: 71 | grad_norm_clip: -1.0 72 | ckpt_interval: 2 73 | eval_interval: 2 74 | epoch: 20 75 | 76 | Optim: 77 | scheduler_kwargs: 78 | epoch0: 5 79 | epoch1: 20 80 | lr0: 1.e-3 81 | lr1: 1.e-6 82 | 83 | Logger: 84 | log_dir: "./logs/" 85 | tag: "dr_spaam_jrdb" 86 | log_fname: "log.txt" 87 | backup_list: [] 88 | 89 | -------------------------------------------------------------------------------- /dr_spaam/cfgs/base_drow_drow_cfg.yaml: -------------------------------------------------------------------------------- 1 | # This file is intended as a config template for generate_experiments.py 2 | # Modify with caution 3 | 4 | model: 5 | type: "drow" 6 | kwargs: 7 | dropout: 0.5 8 | # focal_loss_gamma: 0.0 9 | 10 | # type: "dr-spaam" 11 | # kwargs: 12 | # dropout: 0.5 13 | # num_pts: 56 14 | # embedding_length: 128 15 | # alpha: 0.5 16 | # window_size: 11 17 | # panoramic_scan: False 18 | # # focal_loss_gamma: 0.0 19 | 20 | # for coping with noisy pseudo labels 21 | cls_loss: 22 | type: 0 # 0 BCE, 1 SymmetricBCE, 2 PartiallyHuberisedBCE 23 | 24 | # # SymmetricBCE 25 | # kwargs: 26 | # alpha: 0.1 27 | # beta: 1.0 28 | 29 | # # PartiallyHuberisedBCE 30 | # kwargs: 31 | # tau: 3.0 32 | 33 | mixup_alpha: 0.0 34 | mixup_w: 0.0 35 | self_paced: False 36 | 37 | dataset: 38 | augment_data: False 39 | person_only: True 40 | pseudo_label: False # only matters for JRDB 41 | # For ablation study, remove the wrong label from pseudo labels. 42 | # 0 no correction, 1 remove false positives, 2 remove false negatives, 3 both. 
43 | pl_correction_level: 0 44 | 45 | # DataHandle: 46 | # data_dir: "./data/JRDB" # ./data/JRDB or ./data/DROWv2-data 47 | # num_scans: 1 48 | # scan_stride: 1 49 | # tracking: False 50 | 51 | DataHandle: 52 | data_dir: "./data/DROWv2-data" # ./data/JRDB or ./data/DROWv2-data 53 | num_scans: 1 54 | scan_stride: 1 55 | 56 | cutout_kwargs: 57 | fixed: True 58 | centered: True 59 | window_width: 1.0 60 | window_depth: 0.5 61 | num_cutout_pts: 56 62 | padding_val: 29.99 63 | area_mode: True 64 | 65 | dataloader: 66 | batch_size: 8 67 | num_workers: 8 68 | 69 | pipeline: 70 | Trainer: 71 | grad_norm_clip: -1.0 72 | ckpt_interval: 5 73 | eval_interval: 5 74 | epoch: 40 75 | 76 | Optim: 77 | scheduler_kwargs: 78 | epoch0: 0 79 | epoch1: 40 80 | lr0: 1.e-3 81 | lr1: 1.e-6 82 | 83 | Logger: 84 | log_dir: "./logs/" 85 | tag: "drow" 86 | log_fname: "log.txt" 87 | backup_list: [] 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /dr_spaam/cfgs/base_drow_jrdb_cfg.yaml: -------------------------------------------------------------------------------- 1 | # This file is intended as a config template for generate_experiments.py 2 | # Modify with caution 3 | 4 | model: 5 | type: "drow" 6 | kwargs: 7 | dropout: 0.5 8 | # focal_loss_gamma: 0.0 9 | 10 | # type: "dr-spaam" 11 | # kwargs: 12 | # dropout: 0.5 13 | # num_pts: 56 14 | # embedding_length: 128 15 | # alpha: 0.5 16 | # window_size: 11 17 | # panoramic_scan: True 18 | # # focal_loss_gamma: 0.0 19 | 20 | # for coping with noisy pseudo labels 21 | cls_loss: 22 | type: 0 # 0 BCE, 1 SymmetricBCE, 2 PartiallyHuberisedBCE 23 | 24 | # # SymmetricBCE 25 | # kwargs: 26 | # alpha: 0.1 27 | # beta: 1.0 28 | 29 | # # PartiallyHuberisedBCE 30 | # kwargs: 31 | # tau: 3.0 32 | 33 | mixup_alpha: 0.0 34 | mixup_w: 0.0 35 | self_paced: False 36 | 37 | dataset: 38 | augment_data: False 39 | person_only: True 40 | pseudo_label: False # only matters for JRDB 41 | # For ablation study, remove the wrong label from pseudo labels. 42 | # 0 no correction, 1 remove false positives, 2 remove false negatives, 3 both. 
43 | pl_correction_level: 0 44 | 45 | DataHandle: 46 | data_dir: "./data/JRDB" # ./data/JRDB or ./data/DROWv2-data 47 | num_scans: 1 48 | scan_stride: 1 49 | tracking: False 50 | 51 | # DataHandle: 52 | # data_dir: "./data/DROWv2-data" # ./data/JRDB or ./data/DROWv2-data 53 | # num_scans: 1 54 | # scan_stride: 1 55 | 56 | cutout_kwargs: 57 | fixed: True 58 | centered: True 59 | window_width: 1.0 60 | window_depth: 0.5 61 | num_cutout_pts: 56 62 | padding_val: 29.99 63 | area_mode: True 64 | 65 | dataloader: 66 | batch_size: 8 67 | num_workers: 8 68 | 69 | pipeline: 70 | Trainer: 71 | grad_norm_clip: -1.0 72 | ckpt_interval: 4 73 | eval_interval: 4 74 | epoch: 40 75 | 76 | Optim: 77 | scheduler_kwargs: 78 | epoch0: 10 79 | epoch1: 40 80 | lr0: 1.e-3 81 | lr1: 1.e-6 82 | 83 | Logger: 84 | log_dir: "./logs/" 85 | tag: "drow_jrdb" 86 | log_fname: "log.txt" 87 | backup_list: [] 88 | 89 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisualComputingInstitute/2D_lidar_person_detection/99dd7a2a0d64252905e4f621e2c45be64b653a32/dr_spaam/dr_spaam/__init__.py -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/datahandle/__init__.py: -------------------------------------------------------------------------------- 1 | from .drow_handle import * 2 | from .jrdb_handle import * 3 | from .jrdb_handle_det3d import * 4 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/datahandle/drow_handle.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import json 3 | import numpy as np 4 | import os 5 | 6 | # Force the dataloader to load only one sample, in which case the network should 7 | # fit perfectly. 
8 | _DEBUG_ONE_SAMPLE = False
9 | 
10 | __all__ = ["DROWHandle"]
11 | 
12 | 
13 | class DROWHandle:
14 |     def __init__(self, split, cfg):
15 |         assert split in ["train", "val", "test"], f'Invalid split "{split}"'
16 | 
17 |         self._num_scans = cfg["num_scans"]
18 |         self._scan_stride = cfg["scan_stride"]
19 |         data_dir = os.path.abspath(os.path.expanduser(cfg["data_dir"]))
20 | 
21 |         if _DEBUG_ONE_SAMPLE:
22 |             split = "train"
23 |             seq_names = [f[:-4] for f in glob(os.path.join(data_dir, split, "*.csv"))]
24 |             seq_names = seq_names[:6]
25 |         else:
26 |             seq_names = [f[:-4] for f in glob(os.path.join(data_dir, split, "*.csv"))]
27 | 
28 |         self.seq_names = seq_names
29 | 
30 |         # preload all annotations
31 |         self.dets_ns, self.dets_wc, self.dets_wa, self.dets_wp = zip(
32 |             *map(lambda f: self._load_det_file(f), seq_names)
33 |         )
34 | 
35 |         # look-up list to convert an index to sequence index and annotation index
36 |         self.__flat_seq_inds, self.__flat_det_inds = [], []
37 |         for seq_idx, det_ns in enumerate(self.dets_ns):
38 |             num_samples = len(det_ns)
39 |             self.__flat_seq_inds += [seq_idx] * num_samples
40 |             self.__flat_det_inds += range(num_samples)
41 | 
42 |         # placeholder for scans, which will be loaded on the fly
43 |         self.scans_ns = [None] * len(seq_names)
44 |         self.scans_t = [None] * len(seq_names)
45 |         self.scans = [None] * len(seq_names)
46 | 
47 |         # placeholder for mapping from detection index to scan index
48 |         self.__id2is = [None] * len(seq_names)
49 | 
50 |         # load the scan sequence into memory if it has not been loaded
51 |         for seq_idx in range(len(self.seq_names)):
52 |             if self.scans[seq_idx] is None:
53 |                 self._load_scan_sequence(seq_idx)
54 | 
55 |     def __len__(self):
56 |         if _DEBUG_ONE_SAMPLE:
57 |             return 80
58 |         else:
59 |             return len(self.__flat_det_inds)
60 | 
61 |     def __getitem__(self, idx):
62 |         if _DEBUG_ONE_SAMPLE:
63 |             idx = 511
64 | 
65 |         # find matching seq_idx, det_idx, and scan_idx
66 |         seq_idx = self.__flat_seq_inds[idx]
67 |         det_idx = self.__flat_det_inds[idx]
68 |         scan_idx = self.__id2is[seq_idx][det_idx]
69 | 
70 |         # annotation [(r, phi), (r, phi), ...]
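        # dets_wc / dets_wa / dets_wp are the DROW annotations for wheelchairs,
        # walking aids, and persons respectively, each a list of (r, phi) polar
        # coordinates (see _load_det_file below)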
71 |         rtn_dict = {
72 |             "idx": idx,
73 |             "dets_wc": self.dets_wc[seq_idx][det_idx],
74 |             "dets_wa": self.dets_wa[seq_idx][det_idx],
75 |             "dets_wp": self.dets_wp[seq_idx][det_idx],
76 |         }
77 | 
78 |         # load sequential scans up to the current one (array[frame, point])
79 |         delta_inds = (np.arange(self._num_scans) * self._scan_stride)[::-1]
80 |         scans_inds = [max(0, scan_idx - i) for i in delta_inds]
81 |         rtn_dict["scans"] = np.array([self.scans[seq_idx][i] for i in scans_inds])
82 |         rtn_dict["scans_ind"] = scans_inds
83 |         rtn_dict["scan_phi"] = self.get_laser_phi()
84 | 
85 |         return rtn_dict
86 | 
87 |     def _load_scan_sequence(self, seq_idx):
88 |         data = np.genfromtxt(self.seq_names[seq_idx] + ".csv", delimiter=",")
89 |         self.scans_ns[seq_idx] = data[:, 0].astype(np.uint32)
90 |         self.scans_t[seq_idx] = data[:, 1].astype(np.float32)
91 |         self.scans[seq_idx] = data[:, 2:].astype(np.float32)
92 | 
93 |         # precompute a mapping from detection index to scan index such that
94 |         # scans[seq_idx][scan_idx] matches dets[seq_idx][det_idx]
95 |         is_ = 0
96 |         id2is = []
97 |         for det_ns in self.dets_ns[seq_idx]:
98 |             while self.scans_ns[seq_idx][is_] != det_ns:
99 |                 is_ += 1
100 |             id2is.append(is_)
101 |         self.__id2is[seq_idx] = id2is
102 | 
103 |     def _load_det_file(self, seq_name):
104 |         def do_load(f_name):
105 |             seqs, dets = [], []
106 |             with open(f_name) as f:
107 |                 for line in f:
108 |                     seq, tail = line.split(",", 1)
109 |                     seqs.append(int(seq))
110 |                     dets.append(json.loads(tail))
111 |             return seqs, dets
112 | 
113 |         s1, wcs = do_load(seq_name + ".wc")
114 |         s2, was = do_load(seq_name + ".wa")
115 |         s3, wps = do_load(seq_name + ".wp")
116 |         assert all(a == b == c for a, b, c in zip(s1, s2, s3))
117 | 
118 |         return np.array(s1), wcs, was, wps
119 | 
120 |     @staticmethod
121 |     def get_laser_phi(angle_inc=np.radians(0.5), num_pts=450):
122 |         # Default setting of DROW, which uses a SICK S300 laser with 225 deg fov
123 |         # and 450 pts, mounted at 37 cm height.
124 |         laser_fov = (num_pts - 1) * angle_inc  # 450 points
125 |         return np.linspace(-laser_fov * 0.5, laser_fov * 0.5, num_pts)
126 | 
--------------------------------------------------------------------------------
/dr_spaam/dr_spaam/datahandle/jrdb_handle.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import cv2
3 | import json
4 | import numpy as np
5 | import os
6 | from ._pypcd import point_cloud_from_path
7 | 
8 | # NOTE: Don't use open3d to load point cloud since it spams the console. Setting
9 | # verbosity level does not solve the problem
10 | # https://github.com/intel-isl/Open3D/issues/1921
11 | # https://github.com/intel-isl/Open3D/issues/884
12 | 
13 | # Force the dataloader to load only one sample, in which case the network should
14 | # fit perfectly.
15 | _DEBUG_ONE_SAMPLE = False
16 | 
17 | # Pointcloud and image are only needed for visualization.
Turn off for fast dataloading 18 | _LOAD_PC_IM = True 19 | 20 | __all__ = ["JRDBHandle"] 21 | 22 | 23 | class JRDBHandle: 24 | def __init__(self, split, cfg, sequences=None, exclude_sequences=None): 25 | if _DEBUG_ONE_SAMPLE: 26 | split = "train" 27 | sequences = None 28 | exclude_sequences = None 29 | 30 | self.__num_scans = cfg["num_scans"] 31 | self.__scan_stride = cfg["scan_stride"] 32 | 33 | data_dir = os.path.abspath(os.path.expanduser(cfg["data_dir"])) 34 | data_dir = ( 35 | os.path.join(data_dir, "train_dataset") 36 | if split == "train" or split == "val" 37 | else os.path.join(data_dir, "test_dataset") 38 | ) 39 | 40 | self.data_dir = data_dir 41 | self.timestamp_dir = os.path.join(data_dir, "timestamps") 42 | self.pc_label_dir = os.path.join(data_dir, "labels", "labels_3d") 43 | self.im_label_dir = os.path.join(data_dir, "labels", "labels_2d_stitched") 44 | 45 | if sequences is not None: 46 | sequence_names = sequences 47 | else: 48 | sequence_names = os.listdir(self.timestamp_dir) 49 | # NOTE it is important to sort the return of os.listdir, since its order 50 | # changes for different file system. 51 | sequence_names.sort() 52 | 53 | if exclude_sequences is not None: 54 | sequence_names = [s for s in sequence_names if s not in exclude_sequences] 55 | 56 | self.sequence_names = sequence_names 57 | 58 | self.sequence_handle = [] 59 | self._sequence_beginning_inds = [0] 60 | self.__flat_inds_sequence = [] 61 | self.__flat_inds_frame = [] 62 | for seq_idx, seq_name in enumerate(self.sequence_names): 63 | self.sequence_handle.append(_SequenceHandle(self.data_dir, seq_name)) 64 | 65 | # build a flat index for all sequences and frames 66 | sequence_length = len(self.sequence_handle[-1]) 67 | self.__flat_inds_sequence += sequence_length * [seq_idx] 68 | self.__flat_inds_frame += range(sequence_length) 69 | 70 | self._sequence_beginning_inds.append( 71 | self._sequence_beginning_inds[-1] + sequence_length 72 | ) 73 | 74 | def __len__(self): 75 | if _DEBUG_ONE_SAMPLE: 76 | return 80 77 | else: 78 | return len(self.__flat_inds_frame) 79 | 80 | def __getitem__(self, idx): 81 | if _DEBUG_ONE_SAMPLE: 82 | idx = 500 83 | 84 | idx_sq = self.__flat_inds_sequence[idx] 85 | idx_fr = self.__flat_inds_frame[idx] 86 | 87 | frame_dict, pc_anns, im_anns, im_dets = self.sequence_handle[idx_sq][idx_fr] 88 | 89 | pc_data = {} 90 | im_data = {} 91 | if _LOAD_PC_IM: 92 | for pc_dict in frame_dict["pc_frame"]["pointclouds"]: 93 | pc_data[pc_dict["name"]] = self._load_pointcloud(pc_dict["url"]) 94 | 95 | for im_dict in frame_dict["im_frame"]["cameras"]: 96 | im_data[im_dict["name"]] = self._load_image(im_dict["url"]) 97 | 98 | laser_data = self._load_consecutive_lasers(frame_dict["laser_frame"]["url"]) 99 | 100 | frame_dict.update( 101 | { 102 | "frame_id": int(frame_dict["frame_id"]), 103 | "sequence": self.sequence_handle[idx_sq].sequence, 104 | "first_frame": idx_fr == 0, 105 | "idx": idx, 106 | "pc_data": pc_data, 107 | "im_data": im_data, 108 | "laser_data": laser_data, 109 | "pc_anns": pc_anns, 110 | "im_anns": im_anns, 111 | "im_dets": im_dets, 112 | "laser_grid": np.linspace( 113 | -np.pi, np.pi, laser_data.shape[1], dtype=np.float32 114 | ), 115 | "laser_z": -0.5 * np.ones(laser_data.shape[1], dtype=np.float32), 116 | } 117 | ) 118 | 119 | return frame_dict 120 | 121 | @staticmethod 122 | def box_is_on_ground(jrdb_ann_dict): 123 | bottom_h = float(jrdb_ann_dict["box"]["cz"]) - 0.5 * float( 124 | jrdb_ann_dict["box"]["h"] 125 | ) 126 | 127 | return bottom_h < -0.69 # value found by examining dataset 
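        # (presumably used to filter annotations of people who are not standing on
        # the ground, e.g. on stairs or elevated platforms)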
128 | 129 | @property 130 | def sequence_beginning_inds(self): 131 | return copy.deepcopy(self._sequence_beginning_inds) 132 | 133 | def _load_pointcloud(self, url): 134 | """Load a point cloud given file url. 135 | 136 | Returns: 137 | pc (np.ndarray[3, N]): 138 | """ 139 | # pcd_load = 140 | # o3d.io.read_point_cloud(os.path.join(self.data_dir, url), format='pcd') 141 | # return np.asarray(pcd_load.points, dtype=np.float32) 142 | pc = point_cloud_from_path(os.path.join(self.data_dir, url)).pc_data 143 | # NOTE: redundent copy, ok for now 144 | pc = np.array([pc["x"], pc["y"], pc["z"]], dtype=np.float32) 145 | return pc 146 | 147 | def _load_image(self, url): 148 | """Load an image given file url. 149 | 150 | Returns: 151 | im (np.ndarray[H, W, 3]): (H, W) = (480, 3760) for stitched image, 152 | (480, 752) for individual image 153 | """ 154 | im = cv2.imread(os.path.join(self.data_dir, url), cv2.IMREAD_COLOR) 155 | im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR) 156 | return im 157 | 158 | def _load_consecutive_lasers(self, url): 159 | """Load current and previous consecutive laser scans. 160 | 161 | Args: 162 | url (str): file url of the current scan 163 | 164 | Returns: 165 | pc (np.ndarray[self.num_scan, N]): Forward in time with increasing 166 | row index, i.e. the latest scan is pc[-1] 167 | """ 168 | fpath = os.path.dirname(url) 169 | current_frame_idx = int(os.path.basename(url).split(".")[0]) 170 | frames_list = [] 171 | for del_idx in reversed(range(self.__num_scans)): 172 | frame_idx = max(0, current_frame_idx - del_idx * self.__scan_stride) 173 | url = os.path.join(fpath, str(frame_idx).zfill(6) + ".txt") 174 | frames_list.append(self._load_laser(url)) 175 | 176 | return np.stack(frames_list, axis=0) 177 | 178 | def _load_laser(self, url): 179 | """Load a laser given file url. 180 | 181 | Returns: 182 | pc (np.ndarray[N, ]): 183 | """ 184 | return np.loadtxt(os.path.join(self.data_dir, url), dtype=np.float32) 185 | 186 | 187 | class _SequenceHandle: 188 | def __init__(self, data_dir, sequence, use_unlabeled_frames=False): 189 | self.sequence = sequence 190 | self._use_unlabeled_frames = use_unlabeled_frames 191 | 192 | # load frames of the sequence 193 | timestamp_dir = os.path.join(data_dir, "timestamps") 194 | fname = os.path.join(timestamp_dir, self.sequence, "frames_pc_im_laser.json") 195 | with open(fname, "r") as f: 196 | """ 197 | list[dict]. 
Each dict has following keys: 198 | pc_frame: dict with keys frame_id, pointclouds, laser, timestamp 199 | im_frame: same as above 200 | laser_frame: dict with keys url, name, timestamp 201 | frame_id: same as pc_frame["frame_id"] 202 | timestamp: same as pc_frame["timestamp"] 203 | """ 204 | self.frames = json.load(f)["data"] 205 | 206 | # load 3D annotation 207 | pc_label_dir = os.path.join(data_dir, "labels", "labels_3d") 208 | fname = os.path.join(pc_label_dir, f"{self.sequence}.json") 209 | with open(fname, "r") as f: 210 | """ 211 | dict, key is the pc file name, value is the labels (list[dict]) 212 | Each label is a dict with keys: 213 | attributes 214 | box 215 | file_id 216 | observation_angle 217 | label_id 218 | """ 219 | self.pc_labels = json.load(f)["labels"] 220 | 221 | # load 2D annotation 222 | im_label_dir = os.path.join(data_dir, "labels", "labels_2d_stitched") 223 | fname = os.path.join(im_label_dir, f"{self.sequence}.json") 224 | with open(fname, "r") as f: 225 | """ 226 | dict, key is the im file name, value is the labels (list[dict]) 227 | Each label is a dict with keys: 228 | attributes 229 | truncated (bool) 230 | interpolated (bool) 231 | occlusion (str)Fully_visible 232 | area (float) 233 | no_eval (bool) 234 | box (list) (x0, y0, w, h) 235 | file_id (str) 236 | label_id (str) e.g. "pedestrian:46" 237 | """ 238 | self.im_labels = json.load(f)["labels"] 239 | 240 | # load 2D detection 241 | im_det_dir = os.path.join(data_dir, "detections", "detections_2d_stitched") 242 | fname = os.path.join(im_det_dir, f"{self.sequence}.json") 243 | with open(fname, "r") as f: 244 | """ 245 | dict, key is the im file name, value is the provided detections (list[dict]) 246 | Each detection is a dict with keys: 247 | box (list) (x0, y0, w, h) 248 | file_id (str) 249 | label_id (str) e.g. "person:-1" 250 | score (float) 251 | """ 252 | self.im_dets = json.load(f)["detections"] 253 | 254 | # find out which frames has 3D annotation 255 | self.frames_labeled = [] 256 | for frame in self.frames: 257 | pc_file = os.path.basename(frame["pc_frame"]["pointclouds"][0]["url"]) 258 | if pc_file in self.pc_labels: 259 | self.frames_labeled.append(frame) 260 | 261 | # choose if labeled or all frames are used 262 | self.data_frames = ( 263 | self.frames if self._use_unlabeled_frames else self.frames_labeled 264 | ) 265 | 266 | def __len__(self): 267 | return len(self.data_frames) 268 | 269 | def __getitem__(self, idx): 270 | # NOTE It's important to use a copy as the return dict, otherwise the 271 | # original dict in the data handle will be corrupted 272 | frame = copy.deepcopy(self.data_frames[idx]) 273 | 274 | if self._use_unlabeled_frames: 275 | return frame, [], [], [] 276 | 277 | pc_file = os.path.basename(frame["pc_frame"]["pointclouds"][0]["url"]) 278 | pc_anns = copy.deepcopy(self.pc_labels[pc_file]) 279 | 280 | im_file = os.path.basename(frame["im_frame"]["cameras"][0]["url"]) 281 | im_anns = copy.deepcopy(self.im_labels[im_file]) 282 | im_dets = copy.deepcopy(self.im_dets[im_file]) 283 | 284 | return frame, pc_anns, im_anns, im_dets 285 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/datahandle/jrdb_handle_det3d.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import os 4 | from ._pypcd import point_cloud_from_path 5 | 6 | # NOTE: Don't use open3d to load point cloud since it spams the console. 
Setting 7 | # verbosity level does not solve the problem 8 | # https://github.com/intel-isl/Open3D/issues/1921 9 | # https://github.com/intel-isl/Open3D/issues/884 10 | 11 | # Force the dataloader to load only one sample, in which case the network should 12 | # fit perfectly. 13 | _DEBUG_ONE_SAMPLE = False 14 | 15 | 16 | __all__ = ["JRDBHandleDet3D"] 17 | 18 | 19 | class JRDBHandleDet3D: 20 | def __init__(self, split, cfg, sequences=None, exclude_sequences=None): 21 | if _DEBUG_ONE_SAMPLE: 22 | split = "train" 23 | sequences = None 24 | exclude_sequences = None 25 | 26 | data_dir = os.path.abspath(os.path.expanduser(cfg["data_dir"])) 27 | data_dir = ( 28 | os.path.join(data_dir, "train_dataset") 29 | if split == "train" or split == "val" 30 | else os.path.join(data_dir, "test_dataset") 31 | ) 32 | 33 | self.data_dir = data_dir 34 | 35 | sequence_names = ( 36 | os.listdir(os.path.join(data_dir, "timestamps")) 37 | if sequences is None 38 | else sequences 39 | ) 40 | 41 | # NOTE it is important to sort the return of os.listdir, since its order 42 | # changes for different file system. 43 | sequence_names.sort() 44 | 45 | if exclude_sequences is not None: 46 | sequence_names = [s for s in sequence_names if s not in exclude_sequences] 47 | 48 | self.sequence_names = sequence_names 49 | 50 | self.sequence_handle = [] 51 | self._sequence_beginning_inds = [0] 52 | self.__flat_inds_sequence = [] 53 | self.__flat_inds_frame = [] 54 | for seq_idx, seq_name in enumerate(self.sequence_names): 55 | self.sequence_handle.append(_SequenceHandle(self.data_dir, seq_name)) 56 | 57 | # build a flat index for all sequences and frames 58 | sequence_length = len(self.sequence_handle[-1]) 59 | self.__flat_inds_sequence += sequence_length * [seq_idx] 60 | self.__flat_inds_frame += range(sequence_length) 61 | 62 | self._sequence_beginning_inds.append( 63 | self._sequence_beginning_inds[-1] + sequence_length 64 | ) 65 | 66 | def __len__(self): 67 | if _DEBUG_ONE_SAMPLE: 68 | return 80 69 | else: 70 | return len(self.__flat_inds_frame) 71 | 72 | def __getitem__(self, idx): 73 | if _DEBUG_ONE_SAMPLE: 74 | idx = 500 75 | 76 | idx_sq = self.__flat_inds_sequence[idx] 77 | idx_fr = self.__flat_inds_frame[idx] 78 | 79 | frame_dict = self.sequence_handle[idx_sq][idx_fr] 80 | urls = frame_dict["url"] 81 | 82 | frame_dict.update( 83 | { 84 | "frame_id": int(frame_dict["frame_id"]), 85 | "sequence": self.sequence_handle[idx_sq].sequence, 86 | "first_frame": idx_fr == 0, 87 | "dataset_idx": idx, 88 | "pc_upper": self.load_pointcloud(urls["pc_upper"]), 89 | "pc_lower": self.load_pointcloud(urls["pc_lower"]), 90 | } 91 | ) 92 | 93 | if urls["label"] is not None: 94 | frame_dict["label_str"] = self.load_label(urls["label"]) 95 | 96 | return frame_dict 97 | 98 | @staticmethod 99 | def box_is_on_ground(jrdb_ann_dict): 100 | bottom_h = float(jrdb_ann_dict["box"]["cz"]) - 0.5 * float( 101 | jrdb_ann_dict["box"]["h"] 102 | ) 103 | 104 | return bottom_h < -0.69 # value found by examining dataset 105 | 106 | @property 107 | def sequence_beginning_inds(self): 108 | return copy.deepcopy(self._sequence_beginning_inds) 109 | 110 | def load_pointcloud(self, url): 111 | """Load a point cloud given file url. 
112 | 113 | Returns: 114 | pc (np.ndarray[3, N]): 115 | """ 116 | # pcd_load = 117 | # o3d.io.read_point_cloud(os.path.join(self.data_dir, url), format='pcd') 118 | # return np.asarray(pcd_load.points, dtype=np.float32) 119 | pc = point_cloud_from_path(url).pc_data 120 | # NOTE: redundent copy, ok for now 121 | pc = np.array([pc["x"], pc["y"], pc["z"]], dtype=np.float32) 122 | return pc 123 | 124 | def load_label(self, url): 125 | with open(url, "r") as f: 126 | s = f.read() 127 | return s 128 | 129 | 130 | class _SequenceHandle: 131 | def __init__(self, data_dir, sequence): 132 | self.sequence = sequence 133 | 134 | # pc frames 135 | pc_dir = os.path.join(data_dir, "pointclouds", "upper_velodyne", sequence) 136 | frames = [f.split(".")[0] for f in os.listdir(pc_dir)] 137 | 138 | # labels 139 | label_dir = os.path.join(data_dir, "labels_kitti", sequence) 140 | if os.path.exists(label_dir): 141 | labeled_frames = [f.split(".")[0] for f in os.listdir(label_dir)] 142 | frames = list(set(frames) & set(labeled_frames)) 143 | 144 | self._upper_pc_dir = pc_dir 145 | self._lower_pc_dir = os.path.join( 146 | data_dir, "pointclouds", "lower_velodyne", sequence 147 | ) 148 | self._label_dir = label_dir 149 | self._frames = frames 150 | self._frames.sort() 151 | self._load_labels = os.path.exists(label_dir) 152 | 153 | def __len__(self): 154 | return self._frames.__len__() 155 | 156 | def __getitem__(self, idx): 157 | frame = self._frames[idx] 158 | url_upper_pc = os.path.join(self._upper_pc_dir, frame + ".pcd") 159 | url_lower_pc = os.path.join(self._lower_pc_dir, frame + ".pcd") 160 | url_label = ( 161 | os.path.join(self._label_dir, frame + ".txt") if self._load_labels else None 162 | ) 163 | 164 | return { 165 | "frame_id": frame, 166 | "url": { 167 | "pc_upper": url_upper_pc, 168 | "pc_lower": url_lower_pc, 169 | "label": url_label, 170 | }, 171 | } 172 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import * 2 | from .drow_dataset import * 3 | from .jrdb_dataset import * -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/dataset/builder.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | 4 | def get_dataloader(split, batch_size, num_workers, shuffle, dataset_cfg): 5 | if "DROW" in dataset_cfg["DataHandle"]["data_dir"]: 6 | from .drow_dataset import DROWDataset 7 | 8 | ds = DROWDataset(split, dataset_cfg) 9 | elif "JRDB" in dataset_cfg["DataHandle"]["data_dir"]: 10 | if dataset_cfg["DataHandle"]["tracking"]: 11 | from .jrdb_detr_dataset import JRDBDeTrDataset 12 | 13 | assert dataset_cfg["DataHandle"]["num_scans"] == 1 14 | ds = JRDBDeTrDataset(split, dataset_cfg) 15 | else: 16 | from .jrdb_dataset import JRDBDataset 17 | 18 | ds = JRDBDataset(split, dataset_cfg) 19 | else: 20 | raise RuntimeError(f"Unknown dataset {dataset_cfg['name']}.") 21 | 22 | return DataLoader( 23 | ds, 24 | batch_size=batch_size, 25 | pin_memory=True, 26 | num_workers=num_workers, 27 | shuffle=shuffle, 28 | collate_fn=ds.collate_batch, 29 | ) 30 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/dataset/drow_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | 
from torch.utils.data import Dataset 4 | 5 | from dr_spaam.datahandle.drow_handle import DROWHandle 6 | import dr_spaam.utils.utils as u 7 | 8 | 9 | class DROWDataset(Dataset): 10 | def __init__(self, split, cfg): 11 | self.__handle = DROWHandle(split, cfg["DataHandle"]) 12 | self.__split = split 13 | 14 | self._augment_data = cfg["augment_data"] 15 | self._person_only = cfg["person_only"] 16 | self._cutout_kwargs = cfg["cutout_kwargs"] 17 | 18 | @property 19 | def split(self): 20 | return self.__split # used by trainer.py 21 | 22 | def __len__(self): 23 | return len(self.__handle) 24 | 25 | def __getitem__(self, idx): 26 | data_dict = self.__handle[idx] 27 | 28 | # regression target 29 | target_cls, target_reg = _get_regression_target( 30 | data_dict["scans"][-1], 31 | data_dict["scan_phi"], 32 | data_dict["dets_wc"], 33 | data_dict["dets_wa"], 34 | data_dict["dets_wp"], 35 | person_only=self._person_only, 36 | ) 37 | 38 | data_dict["target_cls"] = target_cls 39 | data_dict["target_reg"] = target_reg 40 | 41 | if self._augment_data: 42 | data_dict = u.data_augmentation(data_dict) 43 | 44 | data_dict["input"] = u.scans_to_cutout( 45 | data_dict["scans"], data_dict["scan_phi"], stride=1, **self._cutout_kwargs 46 | ) 47 | 48 | # to be consistent with JRDB dataset 49 | data_dict["frame_id"] = data_dict["idx"] 50 | data_dict["sequence"] = "all" 51 | 52 | # this is used by JRDB dataset to mask out annotations, to be consistent 53 | data_dict["anns_valid_mask"] = np.ones(len(data_dict["dets_wp"]), dtype=bool) 54 | 55 | return data_dict 56 | 57 | def collate_batch(self, batch): 58 | rtn_dict = {} 59 | for k, _ in batch[0].items(): 60 | if k in ["target_cls", "target_reg", "input"]: 61 | rtn_dict[k] = np.array([sample[k] for sample in batch]) 62 | else: 63 | rtn_dict[k] = [sample[k] for sample in batch] 64 | 65 | return rtn_dict 66 | 67 | 68 | def _get_regression_target( 69 | scan, 70 | scan_phi, 71 | wcs, 72 | was, 73 | wps, 74 | radius_wc=0.6, 75 | radius_wa=0.4, 76 | radius_wp=0.35, 77 | label_wc=1, 78 | label_wa=2, 79 | label_wp=3, 80 | person_only=False, 81 | ): 82 | num_pts = len(scan) 83 | target_cls = np.zeros(num_pts, dtype=np.int64) 84 | target_reg = np.zeros((num_pts, 2), dtype=np.float32) 85 | 86 | if person_only: 87 | all_dets = list(wps) 88 | all_radius = [radius_wp] * len(wps) 89 | labels = [0] + [1] * len(wps) 90 | else: 91 | all_dets = list(wcs) + list(was) + list(wps) 92 | all_radius = ( 93 | [radius_wc] * len(wcs) + [radius_wa] * len(was) + [radius_wp] * len(wps) 94 | ) 95 | labels = ( 96 | [0] + [label_wc] * len(wcs) + [label_wa] * len(was) + [label_wp] * len(wps) 97 | ) 98 | 99 | dets = _closest_detection(scan, scan_phi, all_dets, all_radius) 100 | 101 | for i, (r, phi) in enumerate(zip(scan, scan_phi)): 102 | if 0 < dets[i]: 103 | target_cls[i] = labels[dets[i]] 104 | target_reg[i, :] = u.global_to_canonical(r, phi, *all_dets[dets[i] - 1]) 105 | 106 | return target_cls, target_reg 107 | 108 | 109 | def _closest_detection(scan, scan_phi, dets, radii): 110 | """ 111 | Given a single `scan` (450 floats), a list of r,phi detections `dets` (Nx2), 112 | and a list of N `radii` for those detections, return a mapping from each 113 | point in `scan` to the closest detection for which the point falls inside its 114 | radius. The returned detection-index is a 1-based index, with 0 meaning no 115 | detection is close enough to that point. 
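For example (illustrative): with a single detection at r=2, phi=0 and radius 0.35, every scan point whose (x, y) position lies within 0.35 m of that detection maps to index 1, and all remaining points map to 0.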
116 | """ 117 | if len(dets) == 0: 118 | return np.zeros_like(scan, dtype=int) 119 | 120 | assert len(dets) == len(radii), "Need to give a radius for each detection!" 121 | 122 | # Distance (in x,y space) of each laser-point with each detection. 123 | scan_xy = np.array(u.rphi_to_xy(scan, scan_phi)).T # (N, 2) 124 | dists = cdist(scan_xy, np.array([u.rphi_to_xy(r, phi) for r, phi in dets])) 125 | 126 | # Subtract the radius from the distances, such that they are < 0 if inside, 127 | # > 0 if outside. 128 | dists -= radii 129 | 130 | # Prepend zeros so that argmin is 0 for everything "outside". 131 | dists = np.hstack([np.zeros((len(scan), 1)), dists]) 132 | 133 | # And find out who's closest, including the threshold! 134 | return np.argmin(dists, axis=1) 135 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from dr_spaam.model.drow_net import DrowNet 5 | from dr_spaam.model.dr_spaam import DrSpaam 6 | from dr_spaam.utils import utils as u 7 | 8 | 9 | class Detector(object): 10 | def __init__( 11 | self, ckpt_file, model="DROW3", gpu=True, stride=1, panoramic_scan=False 12 | ): 13 | """A warpper class around DROW3 or DR-SPAAM network for end-to-end inference. 14 | 15 | Args: 16 | ckpt_file (str): Path to checkpoint 17 | model (str): Model name, "DROW3" or "DR-SPAAM". 18 | gpu (bool): True to use GPU. Defaults to True. 19 | stride (int): Downsample scans for faster inference. 20 | panoramic_scan (bool): True if the scan covers 360 degree. 21 | """ 22 | self._gpu = gpu 23 | self._stride = stride 24 | self._use_dr_spaam = model == "DR-SPAAM" 25 | 26 | self._scan_phi = None 27 | self._laser_fov_deg = None 28 | 29 | if model == "DROW3": 30 | self._model = DrowNet( 31 | dropout=0.5, cls_loss=None, mixup_alpha=0.0, mixup_w=0.0 32 | ) 33 | elif model == "DR-SPAAM": 34 | self._model = DrSpaam( 35 | dropout=0.5, 36 | num_pts=56, 37 | embedding_length=128, 38 | alpha=0.5, 39 | window_size=17, 40 | panoramic_scan=panoramic_scan, 41 | cls_loss=None, 42 | mixup_alpha=0.0, 43 | mixup_w=0.0, 44 | ) 45 | else: 46 | raise NotImplementedError( 47 | "model should be 'DROW3' or 'DR-SPAAM', received {} instead.".format( 48 | model 49 | ) 50 | ) 51 | 52 | ckpt = torch.load(ckpt_file) 53 | self._model.load_state_dict(ckpt["model_state"]) 54 | 55 | self._model.eval() 56 | if gpu: 57 | torch.backends.cudnn.benchmark = True 58 | self._model = self._model.cuda() 59 | 60 | def __call__(self, scan): 61 | if self._scan_phi is None: 62 | assert self.is_ready(), "Call set_laser_fov() first." 
63 | half_fov_rad = 0.5 * np.deg2rad(self._laser_fov_deg) 64 | self._scan_phi = np.linspace( 65 | -half_fov_rad, half_fov_rad, len(scan), dtype=np.float32 66 | ) 67 | 68 | # preprocess 69 | ct = u.scans_to_cutout( 70 | scan[None, ...], 71 | self._scan_phi, 72 | stride=self._stride, 73 | centered=True, 74 | fixed=True, 75 | window_width=1.0, 76 | window_depth=0.5, 77 | num_cutout_pts=56, 78 | padding_val=29.99, 79 | area_mode=True, 80 | ) 81 | ct = torch.from_numpy(ct).float() 82 | 83 | if self._gpu: 84 | ct = ct.cuda() 85 | 86 | # inference 87 | with torch.no_grad(): 88 | # one extra dimension for batch 89 | if self._use_dr_spaam: 90 | pred_cls, pred_reg, _ = self._model(ct.unsqueeze(dim=0), inference=True) 91 | else: 92 | pred_cls, pred_reg = self._model(ct.unsqueeze(dim=0)) 93 | 94 | pred_cls = torch.sigmoid(pred_cls[0]).data.cpu().numpy() 95 | pred_reg = pred_reg[0].data.cpu().numpy() 96 | 97 | # postprocess 98 | dets_xy, dets_cls, instance_mask = u.nms_predicted_center( 99 | scan[:: self._stride], 100 | self._scan_phi[:: self._stride], 101 | pred_cls[:, 0], 102 | pred_reg, 103 | ) 104 | 105 | return dets_xy, dets_cls, instance_mask 106 | 107 | def set_laser_fov(self, fov_deg): 108 | self._laser_fov_deg = fov_deg 109 | 110 | def is_ready(self): 111 | return self._laser_fov_deg is not None 112 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/model/__init__.py: -------------------------------------------------------------------------------- 1 | # Only import python2 compatible modules here 2 | from .dr_spaam import * 3 | from .drow_net import * 4 | # from .dr_spaam_fn import * 5 | # from .losses import * 6 | # from .get_model import * 7 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/model/_common.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def _conv1d(in_channel, out_channel, kernel_size, padding): 5 | return nn.Sequential( 6 | nn.Conv1d(in_channel, out_channel, kernel_size=kernel_size, padding=padding), 7 | nn.BatchNorm1d(out_channel), 8 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 9 | ) 10 | 11 | 12 | def _conv1d_3(in_channel, out_channel): 13 | return _conv1d(in_channel, out_channel, kernel_size=3, padding=1) 14 | 15 | 16 | def _conv1d_1(in_channel, out_channel): 17 | return _conv1d(in_channel, out_channel, kernel_size=1, padding=1) 18 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/model/dr_spaam.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ._common import _conv1d_3 7 | 8 | 9 | class DrSpaam(nn.Module): 10 | def __init__( 11 | self, 12 | dropout=0.5, 13 | num_pts=48, 14 | alpha=0.5, 15 | embedding_length=128, 16 | window_size=7, 17 | panoramic_scan=False, 18 | cls_loss=None, 19 | mixup_alpha=0.0, 20 | mixup_w=0.0, 21 | use_box=False, 22 | ): 23 | super(DrSpaam, self).__init__() 24 | 25 | self.dropout = dropout 26 | self.mixup_alpha = mixup_alpha 27 | self.mixup_w = mixup_w 28 | if mixup_alpha <= 0.0: 29 | mixup_w = 0.0 30 | else: 31 | assert mixup_w >= 0.0 and mixup_w <= 1.0 32 | 33 | # backbone 34 | self.conv_block_1 = nn.Sequential( 35 | _conv1d_3(1, 64), _conv1d_3(64, 64), _conv1d_3(64, 128) 36 | ) 37 | self.conv_block_2 = nn.Sequential( 38 | _conv1d_3(128, 128), _conv1d_3(128, 128), 
_conv1d_3(128, 256) 39 | ) 40 | self.conv_block_3 = nn.Sequential( 41 | _conv1d_3(256, 256), _conv1d_3(256, 256), _conv1d_3(256, 512) 42 | ) 43 | self.conv_block_4 = nn.Sequential(_conv1d_3(512, 256), _conv1d_3(256, 128)) 44 | 45 | # detection layer 46 | self.conv_cls = nn.Conv1d(128, 1, kernel_size=1) 47 | self.conv_reg = nn.Conv1d(128, 2, kernel_size=1) 48 | self._use_box = use_box 49 | if use_box: 50 | self.conv_box = nn.Conv1d( 51 | 128, 4, kernel_size=1 52 | ) # length, width, sin_rot, cos_rot 53 | 54 | # spatial attention 55 | self.gate = _SpatialAttentionMemory( 56 | n_pts=int(ceil(num_pts / 4)), 57 | n_channel=256, 58 | embedding_length=embedding_length, 59 | alpha=alpha, 60 | window_size=window_size, 61 | panoramic_scan=panoramic_scan, 62 | ) 63 | 64 | # classification loss 65 | self.cls_loss = ( 66 | cls_loss if cls_loss is not None else F.binary_cross_entropy_with_logits 67 | ) 68 | 69 | # initialize weights 70 | for m in self.modules(): 71 | if isinstance(m, (nn.Conv1d, nn.Conv2d)): 72 | nn.init.kaiming_normal_(m.weight, a=0.1, nonlinearity="leaky_relu") 73 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 74 | nn.init.constant_(m.weight, 1) 75 | nn.init.constant_(m.bias, 0) 76 | 77 | @property 78 | def use_box(self): 79 | return self._use_box 80 | 81 | def forward(self, x, inference=False): 82 | """ 83 | Args: 84 | x (tensor[B, CT, N, L]): (batch, cutout, scan, points per cutout) 85 | inference (bool, optional): Set to True for sequential inference 86 | (i.e. in deployment). Defaults to False. 87 | 88 | Returns: 89 | pred_cls (tensor[B, CT, C]): C = number of classes 90 | pred_reg (tensor[B, CT, 2]) 91 | """ 92 | B, CT, N, L = x.shape 93 | 94 | if not inference: 95 | self.gate.reset() 96 | 97 | # NOTE: Ablation study, DR-SPA, no auto-regression, only two consecutive scans 98 | # x = x[:, :, -2:, :] 99 | 100 | # process scan sequentially 101 | n_scan = x.shape[2] 102 | for i in range(n_scan): 103 | x_i = x[:, :, i, :] # (B, CT, L) 104 | 105 | # extract feature from current scan 106 | out = x_i.view(B * CT, 1, L) 107 | out = self._conv_and_pool(out, self.conv_block_1) # /2 108 | out = self._conv_and_pool(out, self.conv_block_2) # /4 109 | out = out.view(B, CT, out.shape[-2], out.shape[-1]) # (B, CT, C, L) 110 | 111 | # combine current feature with memory 112 | out, sim = self.gate(out) # (B, CT, C, L) 113 | 114 | # detection using combined feature memory 115 | out = out.view(B * CT, out.shape[-2], out.shape[-1]) 116 | out = self._conv_and_pool(out, self.conv_block_3) # /8 117 | out = self.conv_block_4(out) 118 | out = F.avg_pool1d(out, kernel_size=out.shape[-1]) # (B * CT, C, 1) 119 | 120 | pred_cls = self.conv_cls(out).view(B, CT, -1) # (B, CT, cls) 121 | pred_reg = self.conv_reg(out).view(B, CT, 2) # (B, CT, 2) 122 | 123 | if self._use_box: 124 | pred_box = self.conv_box(out).view(B, CT, 4) 125 | return pred_cls, pred_reg, pred_box, sim 126 | else: 127 | return pred_cls, pred_reg, sim 128 | 129 | def _conv_and_pool(self, x, conv_block): 130 | out = conv_block(x) 131 | out = F.max_pool1d(out, kernel_size=2) 132 | if self.dropout > 0: 133 | out = F.dropout(out, p=self.dropout, training=self.training) 134 | 135 | return out 136 | 137 | 138 | class _SpatialAttentionMemory(nn.Module): 139 | def __init__( 140 | self, n_pts, n_channel, embedding_length, alpha, window_size, panoramic_scan 141 | ): 142 | """A memory network that updates with similarity-based spatial attention and 143 | an auto-regressive model. 
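Each forward pass embeds the incoming cutouts and the stored memory, computes attention over a local window of neighboring cutouts, and refreshes the memory as memory = alpha * x_new + (1 - alpha) * attention(memory).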
144 | 145 | Args: 146 | n_pts (int): Length of the input sequence (cutout) 147 | n_channel (int): Channel of the input sequence 148 | embedding_length (int): Each cutout is converted to an embedding vector 149 | alpha (float): Auto-regressive update rate, in range [0, 1] 150 | window_size (int): Full neighborhood window size to compute attention 151 | panoramic_scan (bool): True if the scan spans 360 degrees, used to wrap 152 | window indices accordingly 153 | """ 154 | super(_SpatialAttentionMemory, self).__init__() 155 | self._alpha = alpha 156 | self._window_size = window_size 157 | self._embedding_length = embedding_length 158 | self._panoramic_scan = panoramic_scan 159 | 160 | self.conv = nn.Sequential( 161 | nn.Conv1d(n_channel, self._embedding_length, kernel_size=n_pts, padding=0), 162 | nn.BatchNorm1d(self._embedding_length), 163 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 164 | ) 165 | 166 | self._memory = None 167 | 168 | # place holder, created at runtime 169 | self.neighbor_masks, self.neighbor_inds = None, None 170 | 171 | for m in self.modules(): 172 | if isinstance(m, (nn.Conv1d, nn.Conv2d)): 173 | nn.init.kaiming_normal_(m.weight, a=0.1, nonlinearity="leaky_relu") 174 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 175 | nn.init.constant_(m.weight, 1) 176 | nn.init.constant_(m.bias, 0) 177 | 178 | def reset(self): 179 | self._memory = None 180 | 181 | def forward(self, x_new): 182 | if self._memory is None: 183 | self._memory = x_new 184 | return self._memory, None 185 | 186 | # ########## 187 | # NOTE: Ablation study, DR-AM, no spatial attention 188 | # self._memory = self._alpha * x_new + (1.0 - self._alpha) * self._memory 189 | # return self._memory, None 190 | # ########## 191 | 192 | n_batch, n_cutout, n_channel, n_pts = x_new.shape 193 | 194 | # only need to generate neighbor mask once 195 | if ( 196 | self.neighbor_masks is None 197 | or self.neighbor_masks.shape[0] != x_new.shape[1] 198 | ): 199 | self.neighbor_masks, self.neighbor_inds = self._generate_neighbor_mask( 200 | x_new 201 | ) 202 | 203 | # embedding for cutout 204 | emb_x = self.conv(x_new.view(n_batch * n_cutout, n_channel, n_pts)) 205 | emb_x = emb_x.view(n_batch, n_cutout, self._embedding_length) 206 | 207 | # embedding for template 208 | emb_temp = self.conv(self._memory.view(n_batch * n_cutout, n_channel, n_pts)) 209 | emb_temp = emb_temp.view(n_batch, n_cutout, self._embedding_length) 210 | 211 | # pair-wise similarity (batch, cutout, cutout) 212 | sim = torch.matmul(emb_x, emb_temp.permute(0, 2, 1)) 213 | 214 | # masked softmax 215 | # TODO replace with gather and scatter 216 | sim = sim - 1e10 * ( 217 | 1.0 - self.neighbor_masks 218 | ) # make sure the out-of-window elements have small values 219 | maxes = sim.max(dim=-1, keepdim=True)[0] 220 | exps = torch.exp(sim - maxes) * self.neighbor_masks 221 | exps_sum = exps.sum(dim=-1, keepdim=True) 222 | sim = exps / exps_sum 223 | 224 | # # NOTE this gather scatter version is only marginally more efficient on memory 225 | # sim_w = torch.gather(sim, 2, self.neighbor_inds.unsqueeze(dim=0)) 226 | # sim_w = sim_w.softmax(dim=2) 227 | # sim = torch.zeros_like(sim) 228 | # sim.scatter_(2, self.neighbor_inds.unsqueeze(dim=0), sim_w) 229 | 230 | # weighted average on the template 231 | atten_memory = self._memory.view(n_batch, n_cutout, n_channel * n_pts) 232 | atten_memory = torch.matmul(sim, atten_memory) 233 | atten_memory = atten_memory.view(n_batch, n_cutout, n_channel, n_pts) 234 | 235 | # update memory using the auto-regressive rule 236 | 
self._memory = self._alpha * x_new + (1.0 - self._alpha) * atten_memory 237 | 238 | return self._memory, sim 239 | 240 | def _generate_neighbor_mask(self, x): 241 | # indices of neighboring cutout 242 | n_cutout = x.shape[1] 243 | hw = int(self._window_size / 2) 244 | inds_col = torch.arange(n_cutout).unsqueeze(dim=-1).long() 245 | window_inds = torch.arange(-hw, hw + 1).long() 246 | inds_col = inds_col + window_inds.unsqueeze(dim=0) # (cutout, neighbors) 247 | # NOTE On JRDB, DR-SPAAM takes part of the panoramic scan and at test time 248 | # takes the whole panoramic scan 249 | inds_col = ( 250 | inds_col % n_cutout 251 | if self._panoramic_scan and not self.training 252 | else inds_col.clamp(min=0, max=n_cutout - 1) 253 | ) 254 | inds_row = torch.arange(n_cutout).unsqueeze(dim=-1).expand_as(inds_col).long() 255 | inds_full = torch.stack((inds_row, inds_col), dim=2).view(-1, 2) 256 | 257 | masks = torch.zeros(n_cutout, n_cutout).float() 258 | masks[inds_full[:, 0], inds_full[:, 1]] = 1.0 259 | return masks.cuda(x.get_device()) if x.is_cuda else masks, inds_full 260 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/model/dr_spaam_fn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import dr_spaam.utils.utils as u 7 | import dr_spaam.utils.precision_recall as pru 8 | from dr_spaam.utils.plotting import plot_one_frame 9 | 10 | 11 | # TODO when to plot? 12 | _PLOTTING = False 13 | 14 | 15 | def _sample_or_repeat(population, n): 16 | """Select n sample from population, without replacement if population size 17 | greater than n, otherwise with replacement. 18 | 19 | Only work for population of 1D tensor (N,) 20 | """ 21 | N = len(population) 22 | if N == n: 23 | return population 24 | elif N > n: 25 | return population[torch.randperm(N, device=population.device)[:n]] 26 | else: 27 | return population[torch.randint(N, (n,), device=population.device)] 28 | 29 | 30 | def _balanced_sampling_reweighting(target_cls, goal_fg_ratio=0.4): 31 | # target_cls is 1D tensor (N, ) 32 | N = target_cls.shape[0] 33 | goal_fg_num = int(N * goal_fg_ratio) 34 | goal_bg_num = int(N * (1.0 - goal_fg_ratio)) 35 | 36 | inds = torch.arange(N, device=target_cls.device) 37 | fg_inds = inds[target_cls > 0] 38 | bg_inds = inds[target_cls == 0] 39 | 40 | if len(fg_inds) > 0: 41 | fg_inds = _sample_or_repeat(fg_inds, goal_fg_num) 42 | bg_inds = _sample_or_repeat(bg_inds, goal_bg_num) 43 | sample_inds = torch.cat((fg_inds, bg_inds)) 44 | else: 45 | sample_inds = _sample_or_repeat(bg_inds, N) 46 | 47 | weights = torch.zeros(N, device=target_cls.device).float() 48 | weights.index_add_(0, sample_inds, torch.ones_like(sample_inds).float()) 49 | 50 | return weights 51 | 52 | 53 | def _model_fn(model, batch_dict, max_num_pts=1e6, cls_loss_weight=1.0): 54 | tb_dict, rtn_dict = {}, {} 55 | 56 | net_input = batch_dict["input"] 57 | target_cls, target_reg = batch_dict["target_cls"], batch_dict["target_reg"] 58 | 59 | B, N = target_cls.shape 60 | 61 | # train only on part of scan, if the GPU cannot fit the whole scan 62 | num_pts = target_cls.shape[1] 63 | if model.training and num_pts > max_num_pts: 64 | idx0 = np.random.randint(0, num_pts - max_num_pts) 65 | idx1 = idx0 + max_num_pts 66 | target_cls = target_cls[:, idx0:idx1] 67 | target_reg = target_reg[:, idx0:idx1, :] 68 | net_input = net_input[:, idx0:idx1, :, :] 69 | N = max_num_pts 70 | 71 | # to gpu 72 | 
net_input = torch.from_numpy(net_input).cuda(non_blocking=True).float() 73 | target_cls = torch.from_numpy(target_cls).cuda(non_blocking=True).float() 74 | target_reg = torch.from_numpy(target_reg).cuda(non_blocking=True).float() 75 | 76 | # forward pass 77 | rtn_tuple = model(net_input) 78 | 79 | # so this function can be used for both DROW and DR-SPAAM 80 | if len(rtn_tuple) == 2: 81 | pred_cls, pred_reg = rtn_tuple 82 | elif len(rtn_tuple) == 3: 83 | pred_cls, pred_reg, pred_sim = rtn_tuple 84 | rtn_dict["pred_sim"] = pred_sim 85 | 86 | target_cls = target_cls.view(B * N) 87 | pred_cls = pred_cls.view(B * N) 88 | 89 | # number of valid points 90 | valid_mask = target_cls >= 0 91 | valid_ratio = torch.sum(valid_mask).item() / (B * N) 92 | # assert valid_ratio > 0, "No valid points in this batch." 93 | tb_dict["valid_ratio"] = valid_ratio 94 | 95 | # cls loss 96 | cls_loss = ( 97 | model.cls_loss(pred_cls[valid_mask], target_cls[valid_mask], reduction="mean") 98 | * cls_loss_weight 99 | ) 100 | total_loss = cls_loss 101 | tb_dict["cls_loss"] = cls_loss.item() 102 | 103 | # number fg points 104 | # NOTE supervise regression for both close and far neighbor points 105 | fg_mask = torch.logical_or(target_cls == 1, target_cls == -1) 106 | fg_ratio = torch.sum(fg_mask).item() / (B * N) 107 | tb_dict["fg_ratio"] = fg_ratio 108 | 109 | # reg loss 110 | if fg_ratio > 0.0: 111 | target_reg = target_reg.view(B * N, -1) 112 | pred_reg = pred_reg.view(B * N, -1) 113 | reg_loss = F.mse_loss(pred_reg[fg_mask], target_reg[fg_mask], reduction="none") 114 | reg_loss = torch.sqrt(torch.sum(reg_loss, dim=1)).mean() 115 | total_loss = total_loss + reg_loss 116 | tb_dict["reg_loss"] = reg_loss.item() 117 | 118 | # # regularization loss for spatial attention 119 | # if spatial_drow: 120 | # # shannon entropy 121 | # att_loss = (-torch.log(pred_sim + 1e-5) * pred_sim).sum(dim=2).mean() 122 | # tb_dict['att_loss'] = att_loss.item() 123 | # total_loss = total_loss + att_loss 124 | 125 | rtn_dict["pred_reg"] = pred_reg.view(B, N, 2) 126 | rtn_dict["pred_cls"] = pred_cls.view(B, N) 127 | 128 | return total_loss, tb_dict, rtn_dict 129 | 130 | 131 | def _model_fn_mixup(model, batch_dict, max_num_pts=1e6, cls_loss_weight=1.0): 132 | # mixup regularization for robust training against label noise 133 | # https://arxiv.org/pdf/1710.09412.pdf 134 | 135 | tb_dict, rtn_dict = {}, {} 136 | 137 | net_input = batch_dict["input_mixup"] 138 | target_cls = batch_dict["target_cls_mixup"] 139 | 140 | B, N = target_cls.shape 141 | 142 | # train only on part of scan, if the GPU cannot fit the whole scan 143 | num_pts = target_cls.shape[1] 144 | if model.training and num_pts > max_num_pts: 145 | idx0 = np.random.randint(0, num_pts - max_num_pts) 146 | idx1 = idx0 + max_num_pts 147 | target_cls = target_cls[:, idx0:idx1] 148 | net_input = net_input[:, idx0:idx1, :, :] 149 | N = max_num_pts 150 | 151 | # to gpu 152 | net_input = torch.from_numpy(net_input).cuda(non_blocking=True).float() 153 | target_cls = torch.from_numpy(target_cls).cuda(non_blocking=True).float() 154 | 155 | # forward pass 156 | rtn_tuple = model(net_input) 157 | 158 | # so this function can be used for both DROW and DR-SPAAM 159 | if len(rtn_tuple) == 2: 160 | pred_cls, pred_reg = rtn_tuple 161 | elif len(rtn_tuple) == 3: 162 | pred_cls, pred_reg, pred_sim = rtn_tuple 163 | rtn_dict["pred_sim"] = pred_sim 164 | 165 | target_cls = target_cls.view(B * N) 166 | pred_cls = pred_cls.view(B * N) 167 | 168 | # number of valid points 169 | valid_mask = target_cls >= 0 170 | 
valid_ratio = torch.sum(valid_mask).item() / (B * N) 171 | # assert valid_ratio > 0, "No valid points in this batch." 172 | tb_dict["valid_ratio_mixup"] = valid_ratio 173 | 174 | # cls loss 175 | cls_loss = ( 176 | model.cls_loss(pred_cls[valid_mask], target_cls[valid_mask], reduction="mean") 177 | * cls_loss_weight 178 | ) 179 | total_loss = cls_loss 180 | tb_dict["cls_loss_mixup"] = cls_loss.item() 181 | 182 | return total_loss, tb_dict, rtn_dict 183 | 184 | 185 | def _model_eval_fn(model, batch_dict): 186 | _, tb_dict, rtn_dict = _model_fn(model, batch_dict) 187 | 188 | pred_cls = torch.sigmoid(rtn_dict["pred_cls"]).data.cpu().numpy() 189 | pred_reg = rtn_dict["pred_reg"].data.cpu().numpy() 190 | 191 | # # DEBUG use perfect predictions 192 | # pred_cls = batch_dict["target_cls"] 193 | # pred_cls[pred_cls < 0] = 1 194 | # pred_reg = batch_dict["target_reg"] 195 | 196 | fig_dict = {} 197 | file_dict = {} 198 | 199 | # postprocess network prediction to get detection 200 | scans = batch_dict["scans"] 201 | scan_phi = batch_dict["scan_phi"] 202 | for ib in range(len(scans)): 203 | # store detection, which will be used by _model_eval_collate_fn to compute AP 204 | dets_xy, dets_cls, _ = u.nms_predicted_center( 205 | scans[ib][-1], scan_phi[ib], pred_cls[ib], pred_reg[ib] 206 | ) 207 | frame_id = f"{batch_dict['frame_id'][ib]:06d}" 208 | sequence = batch_dict["sequence"][ib] 209 | 210 | # save detection results for evaluation 211 | det_str = pru.drow_detection_to_kitti_string(dets_xy, dets_cls, None) 212 | file_dict[f"detections/{sequence}/{frame_id}"] = det_str 213 | 214 | # save corresponding groundtruth for evaluation 215 | anns_rphi = batch_dict["dets_wp"][ib] 216 | if len(anns_rphi) > 0: 217 | anns_rphi = np.array(anns_rphi, dtype=np.float32) 218 | gts_xy = np.stack(u.rphi_to_xy(anns_rphi[:, 0], anns_rphi[:, 1]), axis=1) 219 | gts_occluded = np.logical_not(batch_dict["anns_valid_mask"][ib]).astype( 220 | np.int64 221 | ) 222 | gts_str = pru.drow_detection_to_kitti_string(gts_xy, None, gts_occluded) 223 | file_dict[f"groundtruth/{sequence}/{frame_id}"] = gts_str 224 | else: 225 | file_dict[f"groundtruth/{sequence}/{frame_id}"] = "" 226 | 227 | # TODO When to plot 228 | if _PLOTTING: 229 | fig, ax = plot_one_frame( 230 | batch_dict, ib, pred_cls[ib], pred_reg[ib], dets_cls, dets_xy 231 | ) 232 | fig_dict[f"figs/{sequence}/{frame_id}"] = (fig, ax) 233 | 234 | return tb_dict, file_dict, fig_dict 235 | 236 | 237 | def _model_eval_collate_fn(tb_dict_list, result_dir): 238 | # tb_dict should only contain scalar values, collate them into an array 239 | # and take their mean as the value of the epoch 240 | epoch_tb_dict = {} 241 | for batch_tb_dict in tb_dict_list: 242 | for k, v in batch_tb_dict.items(): 243 | epoch_tb_dict.setdefault(k, []).append(v) 244 | for k, v in epoch_tb_dict.items(): 245 | epoch_tb_dict[k] = np.array(v).mean() 246 | 247 | sequences, sequences_results_03, sequences_results_05 = pru.evaluate_drow( 248 | result_dir, remove_raw_files=True 249 | ) 250 | 251 | # save evaluation output to system 252 | epoch_dict = {} 253 | for n, re03, re05 in zip(sequences, sequences_results_03, sequences_results_05): 254 | epoch_dict[f"evaluation/{n}/result_r03"] = re03 255 | epoch_dict[f"evaluation/{n}/result_r05"] = re05 256 | 257 | # log scalar values in tensorboard 258 | for k, v in re03.items(): 259 | if not isinstance(v, (np.ndarray, list, tuple)): 260 | epoch_tb_dict[f"{n}_{k}_r03"] = v 261 | 262 | for k, v in re05.items(): 263 | if not isinstance(v, (np.ndarray, list, tuple)): 264 | 
epoch_tb_dict[f"{n}_{k}_r05"] = v 265 | 266 | return epoch_tb_dict, epoch_dict 267 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/model/drow_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ._common import _conv1d_3 6 | 7 | 8 | class DrowNet(nn.Module): 9 | def __init__(self, dropout=0.5, cls_loss=None, mixup_alpha=0.0, mixup_w=0.0): 10 | super(DrowNet, self).__init__() 11 | 12 | self.dropout = dropout 13 | self.mixup_alpha = mixup_alpha 14 | self.mixup_w = mixup_w 15 | if mixup_alpha <= 0.0: 16 | mixup_w = 0.0 17 | else: 18 | assert mixup_w >= 0.0 and mixup_w <= 1.0 19 | 20 | self.conv_block_1 = nn.Sequential( 21 | _conv1d_3(1, 64), _conv1d_3(64, 64), _conv1d_3(64, 128) 22 | ) 23 | self.conv_block_2 = nn.Sequential( 24 | _conv1d_3(128, 128), _conv1d_3(128, 128), _conv1d_3(128, 256) 25 | ) 26 | self.conv_block_3 = nn.Sequential( 27 | _conv1d_3(256, 256), _conv1d_3(256, 256), _conv1d_3(256, 512) 28 | ) 29 | self.conv_block_4 = nn.Sequential(_conv1d_3(512, 256), _conv1d_3(256, 128)) 30 | 31 | self.conv_cls = nn.Conv1d(128, 1, kernel_size=1) 32 | self.conv_reg = nn.Conv1d(128, 2, kernel_size=1) 33 | 34 | # classification loss 35 | self.cls_loss = ( 36 | cls_loss if cls_loss is not None else F.binary_cross_entropy_with_logits 37 | ) 38 | 39 | # initialize weights 40 | for m in self.modules(): 41 | if isinstance(m, (nn.Conv1d, nn.Conv2d)): 42 | nn.init.kaiming_normal_(m.weight, a=0.1, nonlinearity="leaky_relu") 43 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)): 44 | nn.init.constant_(m.weight, 1) 45 | nn.init.constant_(m.bias, 0) 46 | 47 | def forward(self, x): 48 | """ 49 | Args: 50 | x (tensor[B, CT, N, L]): (batch, cutout, scan, points per cutout) 51 | 52 | Returns: 53 | pred_cls (tensor[B, CT, C]): C = number of class 54 | pred_reg (tensor[B, CT, 2]) 55 | """ 56 | n_batch, n_cutout, n_scan, n_pts = x.shape 57 | 58 | # forward cutout from all scans 59 | out = x.view(n_batch * n_cutout * n_scan, 1, n_pts) 60 | out = self._conv_and_pool(out, self.conv_block_1) # /2 61 | out = self._conv_and_pool(out, self.conv_block_2) # /4 62 | 63 | # (batch, cutout, scan, channel, pts) 64 | out = out.view(n_batch, n_cutout, n_scan, out.shape[-2], out.shape[-1]) 65 | # combine all scans 66 | out = torch.sum(out, dim=2) # (B, CT, C, L) 67 | 68 | # forward fused cutout 69 | out = out.view(n_batch * n_cutout, out.shape[-2], out.shape[-1]) 70 | out = self._conv_and_pool(out, self.conv_block_3) # /8 71 | out = self.conv_block_4(out) 72 | out = F.avg_pool1d(out, kernel_size=out.shape[-1]) # (B * CT, C, 1) 73 | 74 | pred_cls = self.conv_cls(out).view(n_batch, n_cutout, -1) # (B, CT, cls) 75 | pred_reg = self.conv_reg(out).view(n_batch, n_cutout, 2) # (B, CT, 2) 76 | 77 | return pred_cls, pred_reg 78 | 79 | def _conv_and_pool(self, x, conv_block): 80 | out = conv_block(x) 81 | out = F.max_pool1d(out, kernel_size=2) 82 | if self.dropout > 0: 83 | out = F.dropout(out, p=self.dropout, training=self.training) 84 | 85 | return out 86 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/model/get_model.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import torch.nn.functional as F 3 | 4 | from .losses import ( 5 | SymmetricBCELoss, 6 | SelfPacedLearningLoss, 7 | PartiallyHuberisedBCELoss, 8 | ) 9 | from 
.dr_spaam_fn import ( 10 | _model_fn, 11 | _model_eval_fn, 12 | _model_eval_collate_fn, 13 | _model_fn_mixup, 14 | ) 15 | 16 | 17 | def get_model(cfg): 18 | if cfg["cls_loss"]["type"] == 0: 19 | cls_loss = F.binary_cross_entropy_with_logits 20 | 21 | elif cfg["cls_loss"]["type"] == 1: 22 | if "kwargs" in cfg["cls_loss"]: 23 | cls_loss = SymmetricBCELoss(**cfg["cls_loss"]["kwargs"]) 24 | else: 25 | cls_loss = SymmetricBCELoss() 26 | 27 | elif cfg["cls_loss"]["type"] == 2: 28 | if "kwargs" in cfg["cls_loss"]: 29 | cls_loss = PartiallyHuberisedBCELoss(**cfg["cls_loss"]["kwargs"]) 30 | else: 31 | cls_loss = PartiallyHuberisedBCELoss() 32 | 33 | else: 34 | raise NotImplementedError 35 | 36 | if cfg["self_paced"]: 37 | cls_loss = SelfPacedLearningLoss(cls_loss) 38 | 39 | if cfg["type"] == "drow": 40 | from .drow_net import DrowNet 41 | 42 | d = DrowNet( 43 | **cfg["kwargs"], 44 | cls_loss=cls_loss, 45 | mixup_alpha=cfg["mixup_alpha"], 46 | mixup_w=cfg["mixup_w"] 47 | ) 48 | d.model_eval_fn = _model_eval_fn 49 | d.model_eval_collate_fn = _model_eval_collate_fn 50 | d.model_fn = partial( 51 | _model_fn, max_num_pts=1e6, cls_loss_weight=1.0 - d.mixup_w 52 | ) 53 | d.model_fn_mixup = partial( 54 | _model_fn_mixup, max_num_pts=1e6, cls_loss_weight=d.mixup_w 55 | ) 56 | return d 57 | elif cfg["type"] == "dr-spaam": 58 | from .dr_spaam import DrSpaam 59 | 60 | d = DrSpaam( 61 | **cfg["kwargs"], 62 | cls_loss=cls_loss, 63 | mixup_alpha=cfg["mixup_alpha"], 64 | mixup_w=cfg["mixup_w"] 65 | ) 66 | d.model_eval_fn = _model_eval_fn 67 | d.model_eval_collate_fn = _model_eval_collate_fn 68 | d.model_fn = partial( 69 | _model_fn, max_num_pts=1000, cls_loss_weight=1.0 - d.mixup_w 70 | ) 71 | d.model_fn_mixup = partial( 72 | _model_fn_mixup, max_num_pts=1000, cls_loss_weight=d.mixup_w 73 | ) 74 | return d 75 | elif cfg["type"] == "detr": 76 | from .detr_net import DeTrNet 77 | 78 | return DeTrNet(**cfg["kwargs"]) 79 | else: 80 | raise NotImplementedError 81 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/model/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PartiallyHuberisedBCELoss(nn.Module): 9 | """Partially Huberised binary cross-entropy loss 10 | https://openreview.net/pdf?id=rklB76EKPr 11 | """ 12 | 13 | def __init__(self, tau=5.0): 14 | super(PartiallyHuberisedBCELoss, self).__init__() 15 | self._tau = tau 16 | self._log_tau = math.log(self._tau) 17 | self._inv_tau = 1.0 / self._tau 18 | 19 | def forward(self, pred, target, reduction="mean"): 20 | pred_logits = torch.sigmoid(pred) 21 | neg_pred_logits = 1.0 - pred_logits 22 | 23 | loss_pos = -self._tau * pred_logits + self._log_tau + 1.0 24 | pos_mask = pred_logits > self._inv_tau 25 | loss_pos[pos_mask] = -torch.log(pred_logits[pos_mask]) 26 | 27 | loss_neg = -self._tau * neg_pred_logits + self._log_tau + 1.0 28 | neg_mask = neg_pred_logits > self._inv_tau 29 | loss_neg[neg_mask] = -torch.log(neg_pred_logits[neg_mask]) 30 | 31 | loss = target * loss_pos + (1.0 - target) * loss_neg 32 | 33 | if reduction == "mean": 34 | return loss.mean() 35 | elif reduction == "sum": 36 | return loss.sum() 37 | elif reduction == "none": 38 | return loss 39 | else: 40 | raise RuntimeError 41 | 42 | 43 | class SelfPacedLearningLoss(nn.Module): 44 | """Self-paced learning loss 45 | https://arxiv.org/abs/1712.05055 46 | 
https://papers.nips.cc/paper/3923-self-paced-learning-for-latent-variable-models 47 | https://arxiv.org/abs/1511.06049 48 | """ 49 | 50 | def __init__(self, base_loss, lam1=0.4, lam2=0.5, alpha=1e-2): 51 | super(SelfPacedLearningLoss, self).__init__() 52 | self._base_loss = base_loss 53 | self._lam1 = lam1 54 | self._lam2 = lam2 55 | 56 | self._l1 = None 57 | self._l2 = None 58 | self._alpha = alpha 59 | 60 | self._lam1_max = 0.60 61 | self._lam2_max = 0.72 62 | 63 | self._step = -1 64 | self._burn_in = False 65 | self._burn_in_step = int(2629 * 0.5) 66 | self._update_step = int(2629) 67 | self._update_rate = 1.02 68 | 69 | def forward(self, pred, target, reduction="mean"): 70 | self._update() 71 | 72 | if self._burn_in: 73 | return self._base_loss(pred, target, reduction=reduction) 74 | 75 | # raw loss 76 | base_loss = self._base_loss(pred, target, reduction="none") 77 | 78 | # exponential moving average of loss percentile 79 | with torch.no_grad(): 80 | l1_now = self._percentile(base_loss, self._lam1) 81 | self._l1 = ( 82 | self._alpha * l1_now + (1.0 - self._alpha) * self._l1 83 | if self._l1 is not None 84 | else l1_now 85 | ) 86 | 87 | l2_now = self._percentile(base_loss, self._lam2) 88 | self._l2 = ( 89 | self._alpha * l2_now + (1.0 - self._alpha) * self._l2 90 | if self._l2 is not None 91 | else l2_now 92 | ) 93 | 94 | # compute v 95 | v = (1.0 - (base_loss - self._l1) / (self._l2 - self._l1)).clamp( 96 | min=0.0, max=1.0 97 | ) 98 | 99 | # weighted loss 100 | loss = base_loss * v 101 | 102 | # NOTE reweight the loss, it may be better to use a fixed weighting factor 103 | # loss = loss / (v.sum() / v.numel()) 104 | loss = loss / (loss.sum() / base_loss.sum()) 105 | 106 | if reduction == "mean": 107 | return loss.mean() 108 | elif reduction == "sum": 109 | return loss.sum() 110 | elif reduction == "none": 111 | return loss 112 | else: 113 | raise RuntimeError 114 | 115 | def _update(self): 116 | self._step += 1 117 | 118 | if self._burn_in: 119 | if self._step >= self._burn_in_step: 120 | self._burn_in = False 121 | self._step = 0 122 | else: 123 | if self._step >= self._update_step: 124 | self._lam1 = min(self._lam1_max, self._lam1 * self._update_rate) 125 | self._lam2 = min(self._lam2_max, self._lam2 * self._update_rate) 126 | self._step = 0 127 | 128 | def _percentile(self, t, q): 129 | """ 130 | From https://gist.github.com/spezold/42a451682422beb42bc43ad0c0967a30 131 | """ 132 | # Note that ``kthvalue()`` works one-based, i.e. the first sorted value 133 | # indeed corresponds to k=1, not k=0! Use float(q) instead of q directly, 134 | # so that ``round()`` returns an integer, even if q is a np.float32. 
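# e.g. for a tensor with 5 elements and q = 0.5, k = 1 + round(0.5 * 4) = 3, i.e. the median.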
135 | k = 1 + round(float(q) * (t.numel() - 1)) 136 | result = t.kthvalue(k).values.item() 137 | return result 138 | 139 | 140 | class SymmetricBCELoss(nn.Module): 141 | """Symmetric Cross Entropy loss https://arxiv.org/pdf/1908.06112.pdf 142 | for binary classification 143 | """ 144 | 145 | def __init__(self, alpha=0.1, beta=0.5, A=-6): 146 | assert A < 0.0 147 | super(SymmetricBCELoss, self).__init__() 148 | self._alpha = alpha 149 | self._beta = beta 150 | self._A = A 151 | 152 | def forward(self, pred, target, reduction="mean"): 153 | bce_loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") 154 | 155 | log_target_pos = torch.log(target + 1e-10).clamp_min(self._A) 156 | log_target_neg = torch.log(1.0 - target + 1e-10).clamp_min(self._A) 157 | pred_logits = torch.sigmoid(pred) 158 | 159 | rbce_loss = -log_target_pos * pred_logits - log_target_neg * (1.0 - pred_logits) 160 | 161 | loss = self._alpha * bce_loss + self._beta * rbce_loss 162 | 163 | if reduction == "mean": 164 | return loss.mean() 165 | elif reduction == "sum": 166 | return loss.sum() 167 | elif reduction == "none": 168 | return loss 169 | else: 170 | raise RuntimeError 171 | 172 | 173 | class FocalLoss(nn.Module): 174 | """From https://github.com/mbsariyildiz/focal-loss.pytorch/blob/master/focalloss.py 175 | """ 176 | 177 | def __init__(self, gamma=0, alpha=None): 178 | super(FocalLoss, self).__init__() 179 | self.gamma = gamma 180 | self.alpha = alpha 181 | if isinstance(alpha, (float, int)): 182 | self.alpha = torch.Tensor([alpha, 1 - alpha]) 183 | if isinstance(alpha, list): 184 | self.alpha = torch.Tensor(alpha) 185 | 186 | def forward(self, input, target, reduction="mean"): 187 | if input.dim() > 2: 188 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W 189 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 190 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 191 | target = target.view(-1, 1) 192 | 193 | logpt = F.log_softmax(input, dim=1) 194 | logpt = logpt.gather(1, target) 195 | logpt = logpt.view(-1) 196 | pt = logpt.exp() 197 | 198 | if self.alpha is not None: 199 | if self.alpha.type() != input.data.type(): 200 | self.alpha = self.alpha.type_as(input.data) 201 | at = self.alpha.gather(0, target.data.view(-1)) 202 | logpt = logpt * at 203 | 204 | loss = -1 * (1 - pt) ** self.gamma * logpt 205 | 206 | if reduction == "mean": 207 | return loss.mean() 208 | elif reduction == "sum": 209 | return loss.sum() 210 | elif reduction == "none": 211 | return loss 212 | else: 213 | raise RuntimeError 214 | 215 | 216 | class BinaryFocalLoss(nn.Module): 217 | def __init__(self, gamma=2.0, alpha=-1): 218 | super(BinaryFocalLoss, self).__init__() 219 | self.gamma, self.alpha = gamma, alpha 220 | 221 | def forward(self, pred, target, reduction="mean"): 222 | return binary_focal_loss(pred, target, self.gamma, self.alpha, reduction) 223 | 224 | 225 | def binary_focal_loss(pred, target, gamma=2.0, alpha=-1, reduction="mean"): 226 | loss_pos = -target * (1.0 - pred) ** gamma * torch.log(pred) 227 | loss_neg = -(1.0 - target) * pred ** gamma * torch.log(1.0 - pred) 228 | 229 | if alpha >= 0.0 and alpha <= 1.0: 230 | loss_pos = loss_pos * alpha 231 | loss_neg = loss_neg * (1.0 - alpha) 232 | 233 | loss = loss_pos + loss_neg 234 | 235 | if reduction == "mean": 236 | return loss.mean() 237 | elif reduction == "sum": 238 | return loss.sum() 239 | elif reduction == "none": 240 | return loss 241 | else: 242 | raise RuntimeError 243 | 
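A minimal usage sketch for the losses above (tensor values are illustrative; note that, unlike the logit-based losses, `BinaryFocalLoss` expects probabilities in (0, 1), hence the sigmoid):

```python
import torch

from dr_spaam.model.losses import BinaryFocalLoss, SymmetricBCELoss

logits = torch.randn(8)                 # raw network outputs
target = (torch.rand(8) > 0.5).float()  # binary labels

# SymmetricBCELoss operates on logits (it applies the sigmoid internally)
sym_loss = SymmetricBCELoss(alpha=0.1, beta=0.5)(logits, target)

# BinaryFocalLoss operates on probabilities, so squash the logits first
focal_loss = BinaryFocalLoss(gamma=2.0)(torch.sigmoid(logits), target)
```

Both calls return a scalar tensor under the default "mean" reduction.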
-------------------------------------------------------------------------------- /dr_spaam/dr_spaam/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import * 2 | from .logger import * 3 | from .optim import * 4 | from .trainer import * -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/pipeline/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | from shutil import copyfile 5 | import time 6 | 7 | import json 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | from tensorboardX import SummaryWriter 11 | import torch 12 | 13 | 14 | def _create_logger(root_dir, file_name="log.txt"): 15 | log_file = os.path.join(root_dir, file_name) 16 | log_format = "%(asctime)s %(levelname)5s %(message)s" 17 | logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file) 18 | console = logging.StreamHandler() 19 | console.setLevel(logging.DEBUG) 20 | console.setFormatter(logging.Formatter(log_format)) 21 | logging.getLogger(__name__).addHandler(console) 22 | return logging.getLogger(__name__) 23 | 24 | 25 | class Logger: 26 | def __init__(self, cfg): 27 | cfg["log_dir"] = os.path.abspath(os.path.expanduser(cfg["log_dir"])) 28 | 29 | # main log 30 | if "use_timestamp" in cfg.keys() and cfg["use_timestamp"] is False: 31 | dir_name = cfg["tag"] 32 | else: 33 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 34 | dir_name = f"{timestamp}_{cfg['tag']}" 35 | self.__log_dir = os.path.join(cfg["log_dir"], dir_name) 36 | os.makedirs(self.__log_dir, exist_ok=True) 37 | 38 | self.__log = _create_logger(self.__log_dir, cfg["log_fname"]) 39 | self.log_debug(f"Log directory: {self.__log_dir}") 40 | 41 | # backup important files (e.g. config.yaml) 42 | self.__backup_dir = os.path.join(self.__log_dir, "backup") 43 | os.makedirs(self.__backup_dir, exist_ok=True) 44 | for file_ in cfg["backup_list"]: 45 | self.log_debug(f"Backup {file_}") 46 | copyfile( 47 | os.path.abspath(file_), 48 | os.path.join(self.__backup_dir, os.path.basename(file_)), 49 | ) 50 | 51 | # for storing results (network output etc.) 
52 | self.__output_dir = os.path.join(self.__log_dir, "output") 53 | os.makedirs(self.__output_dir, exist_ok=True) 54 | 55 | # for storing ckpt 56 | self.__ckpt_dir = os.path.join(self.__log_dir, "ckpt") 57 | os.makedirs(self.__ckpt_dir, exist_ok=True) 58 | 59 | # for tensorboard 60 | tb_dir = os.path.join(self.__log_dir, "tb") 61 | os.makedirs(tb_dir, exist_ok=True) 62 | self.__tb = SummaryWriter(log_dir=tb_dir) 63 | 64 | # the sigterm checkpoint 65 | self.__sigterm_ckpt = os.path.join( 66 | cfg["log_dir"], f"sigterm_ckpt_{cfg['tag']}.pth" 67 | ) 68 | 69 | gpu = ( 70 | os.environ["CUDA_VISIBLE_DEVICES"] 71 | if "CUDA_VISIBLE_DEVICES" in os.environ.keys() 72 | else "ALL" 73 | ) 74 | self.log_info(f"CUDA_VISIBLE_DEVICES={gpu}") 75 | 76 | def flush(self): 77 | self.__tb.flush() 78 | 79 | def close(self): 80 | self.__tb.close() 81 | handlers = self.__log.handlers[:] 82 | for handler in handlers: 83 | handler.close() 84 | self.__log.removeHandler(handler) 85 | 86 | """ 87 | Python log 88 | """ 89 | 90 | def log_warning(self, s): 91 | self.__log.warning(s) 92 | 93 | def log_info(self, s): 94 | self.__log.info(s) 95 | 96 | def log_debug(self, s): 97 | self.__log.debug(s) 98 | 99 | """ 100 | Add to tensorboard 101 | """ 102 | 103 | def add_scalar(self, key, val, step): 104 | self.__tb.add_scalar(key, val, step) 105 | 106 | def add_fig(self, key, fig, step, close_fig=False): 107 | """Convert a python fig to np.ndarray and add it to tensorboard""" 108 | fig.canvas.draw() 109 | im = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 110 | im = im.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 111 | im = im.transpose(2, 0, 1) # (3, H, W) 112 | im = im.astype(np.float32) / 255.0 113 | self.add_im(key, im, step) 114 | 115 | if close_fig: 116 | plt.close(fig) 117 | 118 | def add_im(self, key, im, step): 119 | """Add an image to tensorboard. The image should be an np.ndarray 120 | https://tensorboardx.readthedocs.io/en/latest/tutorial.html#add-image 121 | """ 122 | self.__tb.add_image(key, im, step) 123 | 124 | """ 125 | Save to system 126 | """ 127 | 128 | def get_save_dir(self, epoch, split): 129 | return os.path.join(self.__output_dir, split, f"e{epoch:06d}") 130 | 131 | def save_dict(self, dict_, file_name, epoch, split): 132 | """Save the dictionary to a pickle file. Single value items in the dictionary 133 | are stored in addition as a json file for easy inspection. 
134 | """ 135 | json_dict = {} 136 | for key, val in dict_.items(): 137 | if not isinstance(val, (np.ndarray, tuple, list, dict)): 138 | json_dict[key] = str(val) 139 | 140 | save_dir = self.get_save_dir(epoch, split) 141 | json_fname = os.path.join(save_dir, f"{file_name}.json") 142 | os.makedirs(os.path.dirname(json_fname), exist_ok=True) 143 | with open(json_fname, "w") as fp: 144 | json.dump(json_dict, fp, sort_keys=True, indent=4) 145 | self.log_info(f"Dictionary saved to {json_fname}.") 146 | 147 | pickle_fname = os.path.join(save_dir, f"{file_name}.pkl") 148 | with open(pickle_fname, "wb") as fp: 149 | pickle.dump(dict_, fp, protocol=pickle.HIGHEST_PROTOCOL) 150 | self.log_info(f"Dictionary saved to {pickle_fname}.") 151 | 152 | def save_fig(self, fig, file_name, epoch, split, close_fig=True): 153 | fname = os.path.join(self.get_save_dir(epoch, split), f"{file_name}.png") 154 | os.makedirs(os.path.dirname(fname), exist_ok=True) 155 | fig.savefig(fname) 156 | if close_fig: 157 | plt.close(fig) 158 | 159 | def save_file(self, file_str, file_name, epoch, split): 160 | fname = os.path.join(self.get_save_dir(epoch, split), f"{file_name}.txt") 161 | os.makedirs(os.path.dirname(fname), exist_ok=True) 162 | with open(fname, "w") as f: 163 | f.write(file_str) 164 | 165 | """ 166 | Save and load checkpoints 167 | """ 168 | 169 | def save_ckpt(self, fname, model, optimizer, epoch, step): 170 | if not os.path.dirname(fname): 171 | fname = os.path.join(self.__ckpt_dir, fname) 172 | 173 | if model is not None: 174 | if isinstance(model, torch.nn.DataParallel): 175 | model_state = model.module.state_dict() 176 | else: 177 | model_state = model.state_dict() 178 | else: 179 | model_state = None 180 | optim_state = optimizer.state_dict() if optimizer is not None else None 181 | 182 | ckpt_dict = { 183 | "epoch": epoch, 184 | "step": step, 185 | "model_state": model_state, 186 | "optimizer_state": optim_state, 187 | } 188 | torch.save(ckpt_dict, fname) 189 | self.log_info(f"Checkpoint saved to {fname}.") 190 | 191 | def load_ckpt(self, fname, model, optimizer=None): 192 | ckpt = torch.load(fname) 193 | epoch = ckpt["epoch"] if "epoch" in ckpt.keys() else 0 194 | step = ckpt["step"] if "step" in ckpt.keys() else 0 195 | 196 | model.load_state_dict(ckpt["model_state"]) 197 | 198 | if optimizer is not None: 199 | optimizer.load_state_dict(ckpt["optimizer_state"]) 200 | 201 | self.log_info(f"Load checkpoint {fname}: epoch {epoch}, step {step}.") 202 | 203 | return epoch, step 204 | 205 | def save_sigterm_ckpt(self, model, optimizer, epoch, step): 206 | """Save a checkpoint that another process can use to continue the training 207 | if the current process is terminated or preempted. This checkpoint should 208 | be saved in a process-agnostic directory such that it can be located by 209 | both processes. 
210 | """ 211 | self.save_ckpt(self.__sigterm_ckpt, model, optimizer, epoch, step) 212 | 213 | def load_sigterm_ckpt(self, model, optimizer): 214 | return self.load_ckpt(self.__sigterm_ckpt, model, optimizer) 215 | 216 | def sigterm_ckpt_exists(self): 217 | return os.path.isfile(self.__sigterm_ckpt) 218 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/pipeline/optim.py: -------------------------------------------------------------------------------- 1 | from torch import optim 2 | 3 | 4 | class Optim: 5 | def __init__(self, model, cfg): 6 | self._optim = optim.Adam(model.parameters(), amsgrad=True) 7 | self._lr_scheduler = _ExpDecayScheduler(**cfg["scheduler_kwargs"]) 8 | 9 | def zero_grad(self): 10 | self._optim.zero_grad() 11 | 12 | def step(self): 13 | self._optim.step() 14 | 15 | def state_dict(self): 16 | return self._optim.state_dict() 17 | 18 | def load_state_dict(self, state_dict): 19 | self._optim.load_state_dict(state_dict) 20 | 21 | def set_lr(self, epoch): 22 | for group in self._optim.param_groups: 23 | group["lr"] = self._lr_scheduler(epoch) 24 | 25 | def get_lr(self): 26 | return self._optim.param_groups[0]["lr"] 27 | 28 | 29 | class _ExpDecayScheduler: 30 | """ 31 | Return `v0` until `e` reaches `e0`, then exponentially decay 32 | to `v1` when `e` reaches `e1` and return `v1` thereafter, until 33 | reaching `eNone`, after which it returns `None`. 34 | """ 35 | 36 | def __init__(self, epoch0, lr0, epoch1, lr1): 37 | self._epoch0 = epoch0 38 | self._epoch1 = epoch1 39 | self._lr0 = lr0 40 | self._lr1 = lr1 41 | 42 | def __call__(self, epoch): 43 | if epoch < self._epoch0: 44 | return self._lr0 45 | elif epoch > self._epoch1: 46 | return self._lr1 47 | else: 48 | return self._lr0 * (self._lr1 / self._lr0) ** ( 49 | (epoch - self._epoch0) / (self._epoch1 - self._epoch0) 50 | ) 51 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | from .optim import Optim 2 | from .trainer import Trainer 3 | from .logger import Logger 4 | 5 | 6 | class Pipeline: 7 | def __init__(self, model, cfg): 8 | self.logger = Logger(cfg["Logger"]) 9 | self.optim = Optim(model, cfg["Optim"]) 10 | self.trainer = Trainer(self.logger, self.optim, cfg["Trainer"]) 11 | self.logger.log_debug("Pipeline starts.") 12 | 13 | def close(self): 14 | self.logger.log_debug("Pipeline closes.") 15 | self.logger.close() 16 | 17 | def train(self, model, train_loader, eval_loader=None): 18 | self.logger.log_debug("Training starts.") 19 | status = self.trainer.train(model, train_loader, eval_loader) 20 | self.logger.log_debug(f"Training ends (status {status}).") 21 | return status 22 | 23 | def evaluate(self, model, eval_loader, tb_prefix): 24 | self.logger.log_debug("Evaluation starts.") 25 | status = self.trainer.evaluate( 26 | model, eval_loader, tb_prefix, plotting=False 27 | ) 28 | self.logger.log_debug(f"Evaluation ends (status {status}).") 29 | return status 30 | 31 | def load_ckpt(self, model, ckpt, use_ckpt_epoch=False): 32 | epoch, step = self.logger.load_ckpt(ckpt, model, self.optim) 33 | # When finetuning a pre-trained checkpoint, we don't care the previous 34 | # training schedule, so not setting epoch and step 35 | if use_ckpt_epoch: 36 | self.trainer.set_epoch_step(epoch, step) 37 | 38 | def load_sigterm_ckpt(self, model): 39 | epoch, step = self.logger.load_sigterm_ckpt(model, self.optim) 40 | 
self.trainer.set_epoch_step(epoch, step) 41 | 42 | def sigterm_ckpt_exists(self): 43 | return self.logger.sigterm_ckpt_exists() 44 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/pipeline/trainer.py: -------------------------------------------------------------------------------- 1 | import signal 2 | import tqdm 3 | 4 | import torch 5 | from torch.nn.utils import clip_grad_norm_ 6 | 7 | 8 | class Trainer(object): 9 | def __init__(self, logger, optimizer, cfg): 10 | self._logger = logger 11 | self._optim = optimizer 12 | self._epoch, self._step = 0, 0 13 | 14 | self._grad_norm_clip = cfg["grad_norm_clip"] 15 | self._ckpt_interval = cfg["ckpt_interval"] 16 | self._eval_interval = cfg["eval_interval"] 17 | self._max_epoch = cfg["epoch"] 18 | 19 | self.__sigterm = False 20 | signal.signal(signal.SIGINT, self._sigterm_cb) 21 | signal.signal(signal.SIGTERM, self._sigterm_cb) 22 | 23 | def set_epoch_step(self, epoch=None, step=None): 24 | if epoch is not None: 25 | self._epoch = epoch 26 | if step is not None: 27 | self._step = step 28 | 29 | def evaluate(self, model, eval_loader, tb_prefix, plotting): 30 | model.eval() 31 | tb_dict_list = [] 32 | split = eval_loader.dataset.split 33 | pbar = tqdm.tqdm(total=len(eval_loader), leave=False, desc=f"eval ({split})") 34 | 35 | for b_idx, batch in enumerate(eval_loader): 36 | if self.__sigterm: 37 | pbar.close() 38 | return 1 39 | 40 | with torch.no_grad(): 41 | tb_dict, file_dict, fig_dict = model.model_eval_fn(model, batch) 42 | 43 | # collate tb_dict for epoch evaluation 44 | tb_dict_list.append(tb_dict) 45 | 46 | # save file 47 | for k, v in file_dict.items(): 48 | self._logger.save_file(v, k, self._epoch, split) 49 | 50 | # save figure 51 | for k, (fig, ax) in fig_dict.items(): 52 | # self._logger.add_fig(k, fig, self._step) 53 | self._logger.save_fig(fig, k, self._epoch, split, close_fig=True) 54 | 55 | pbar.update() 56 | 57 | tb_dict, epoch_dict = model.model_eval_collate_fn( 58 | tb_dict_list, self._logger.get_save_dir(self._epoch, split), 59 | ) 60 | 61 | for k, v in tb_dict.items(): 62 | self._logger.add_scalar(f"{tb_prefix}_{k}", v, self._step) 63 | 64 | for k, v in epoch_dict.items(): 65 | self._logger.save_dict(v, k, self._epoch, split) 66 | 67 | pbar.close() 68 | 69 | return 0 70 | 71 | def train(self, model, train_loader, eval_loader=None): 72 | for self._epoch in tqdm.trange(0, self._max_epoch, desc="epochs"): 73 | if self.__sigterm: 74 | self._logger.save_sigterm_ckpt( 75 | model, self._optim, self._epoch, self._step, 76 | ) 77 | return 1 78 | 79 | self._train_epoch(model, train_loader) 80 | 81 | if not self.__sigterm: 82 | if self._is_ckpt_epoch(): 83 | self._logger.save_ckpt( 84 | f"ckpt_e{self._epoch}.pth", 85 | model, 86 | self._optim, 87 | self._epoch, 88 | self._step, 89 | ) 90 | 91 | if eval_loader is not None and self._is_evaluation_epoch(): 92 | self.evaluate( 93 | model, 94 | eval_loader, 95 | tb_prefix="VAL", 96 | plotting=False, 97 | ) 98 | 99 | self._logger.flush() 100 | 101 | return 0 102 | 103 | def _is_ckpt_epoch(self): 104 | return self._epoch % self._ckpt_interval == 0 or self._epoch == self._max_epoch 105 | 106 | def _is_evaluation_epoch(self): 107 | return self._epoch % self._eval_interval == 0 or self._epoch == self._max_epoch 108 | 109 | def _sigterm_cb(self, signum, frame): 110 | self.__sigterm = True 111 | self._logger.log_info(f"Received signal {signum} at frame {frame}.") 112 | 113 | def _train_batch(self, model, batch, ratio): 114 | """Train one batch. 
`ratio`, in [0, 1), is the progress of training within the
115 | current epoch. It is used by the scheduler to update the learning rate.
116 | """
117 | model.train()
118 | self._optim.zero_grad()
119 | self._optim.set_lr(self._epoch + ratio)
120 | 
121 | loss, tb_dict, _ = model.model_fn(model, batch)
122 | loss.backward()
123 | 
124 | if self._grad_norm_clip > 0:
125 | clip_grad_norm_(model.parameters(), self._grad_norm_clip)
126 | 
127 | self._optim.step()
128 | 
129 | self._logger.add_scalar("TRAIN_lr", self._optim.get_lr(), self._step)
130 | self._logger.add_scalar("TRAIN_loss", loss, self._step)
131 | self._logger.add_scalar("TRAIN_epoch", self._epoch + ratio, self._step)
132 | for key, val in tb_dict.items():
133 | self._logger.add_scalar(f"TRAIN_{key}", val, self._step)
134 | 
135 | return loss.item()
136 | 
137 | # # NOTE Dirty fix to use mixup regularization, to be removed
138 | # if model.mixup_alpha <= 0.0:
139 | #     return loss.item()
140 | 
141 | # original_loss_value = loss.item()
142 | 
143 | # self._optim.zero_grad()
144 | # loss, tb_dict, _ = model.model_fn_mixup(model, batch)
145 | # loss.backward()
146 | 
147 | # if self._grad_norm_clip > 0:
148 | #     clip_grad_norm_(model.parameters(), self._grad_norm_clip)
149 | 
150 | # self._optim.step()
151 | 
152 | # for key, val in tb_dict.items():
153 | #     self._logger.add_scalar(f"TRAIN_{key}", val, self._step)
154 | 
155 | # return original_loss_value
156 | 
157 | def _train_epoch(self, model, train_loader):
158 | pbar = tqdm.tqdm(total=len(train_loader), leave=False, desc="train")
159 | for ib, batch in enumerate(train_loader):
160 | if self.__sigterm:
161 | pbar.close()
162 | return
163 | 
164 | loss = self._train_batch(model, batch, ratio=(ib / len(train_loader)))
165 | self._step += 1
166 | pbar.set_postfix({"total_it": self._step, "loss": loss})
167 | pbar.update()
168 | 
169 | self._epoch = self._epoch + 1
170 | pbar.close()
171 | 
--------------------------------------------------------------------------------
/dr_spaam/dr_spaam/pseudo_labels.py:
--------------------------------------------------------------------------------
1 | from dr_spaam.dataset.jrdb_dataset import _get_regression_target_from_pseudo_labels
2 | from dr_spaam.utils import utils as u
3 | 
4 | 
5 | def get_regression_target_using_bounding_boxes(
6 | scan_r, scan_phi, scan_uv, boxes, boxes_conf
7 | ):
8 | pl_xy, pl_boxes, pl_neg_mask = u.generate_pseudo_labels(
9 | scan_r, scan_phi, scan_uv, boxes, boxes_conf
10 | )
11 | 
12 | (target_cls_pseudo, target_reg_pseudo,) = _get_regression_target_from_pseudo_labels(
13 | scan_r,
14 | pl_xy,
15 | pl_neg_mask,
16 | person_radius_small=0.4,
17 | person_radius_large=0.8,
18 | min_close_points=5,
19 | pl_correction_level=-1,
20 | target_cls_annotated=None,
21 | target_reg_annotated=None,
22 | )
23 | 
24 | return target_cls_pseudo, target_reg_pseudo
25 | 
--------------------------------------------------------------------------------
/dr_spaam/dr_spaam/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/VisualComputingInstitute/2D_lidar_person_detection/99dd7a2a0d64252905e4f621e2c45be64b653a32/dr_spaam/dr_spaam/utils/__init__.py
--------------------------------------------------------------------------------
/dr_spaam/dr_spaam/utils/jrdb_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | """
4 | Transformations.
Following frames are defined: 5 | 6 | base: main frame where 3D annotations are done in, x-forward, y-left, z-up 7 | upper_lidar: x-forward, y-left, z-up 8 | lower_lidar: x-forward, y-left, z-up 9 | laser: x-forward, y-left, z-up 10 | """ 11 | 12 | 13 | def _get_R_z(rot_z): 14 | cs, ss = np.cos(rot_z), np.sin(rot_z) 15 | return np.array([[cs, -ss, 0], [ss, cs, 0], [0, 0, 1]], dtype=np.float32) 16 | 17 | 18 | # laser to base 19 | _rot_z_laser_to_base = np.pi / 120 20 | _R_laser_to_base = _get_R_z(_rot_z_laser_to_base) 21 | 22 | # upper_lidar to base 23 | _rot_z_upper_lidar_to_base = 0.085 24 | _T_upper_lidar_to_base = np.array([0, 0, 0.33529], dtype=np.float32).reshape(3, 1) 25 | _R_upper_lidar_to_base = _get_R_z(_rot_z_upper_lidar_to_base) 26 | 27 | # lower_lidar to base 28 | _rot_z_lower_lidar_to_base = 0.0 29 | _T_lower_lidar_to_base = np.array([0, 0, -0.13511], dtype=np.float32).reshape(3, 1) 30 | _R_lower_lidar_to_base = np.eye(3, dtype=np.float32) 31 | 32 | 33 | """ 34 | Transformation API 35 | """ 36 | 37 | 38 | def transform_pts_upper_velodyne_to_base(pts): 39 | """Transform points from upper velodyne frame to base frame 40 | 41 | Args: 42 | pts (np.array[3, N]): points (x, y, z) 43 | 44 | Returns: 45 | pts_trans (np.array[3, N]) 46 | """ 47 | return _R_upper_lidar_to_base @ pts + _T_upper_lidar_to_base 48 | 49 | 50 | def transform_pts_lower_velodyne_to_base(pts): 51 | return _R_lower_lidar_to_base @ pts + _T_lower_lidar_to_base 52 | 53 | 54 | def transform_pts_laser_to_base(pts): 55 | return _R_laser_to_base @ pts 56 | 57 | 58 | def transform_pts_base_to_upper_velodyne(pts): 59 | return _R_upper_lidar_to_base.T @ (pts - _T_upper_lidar_to_base) 60 | 61 | 62 | def transform_pts_base_to_lower_velodyne(pts): 63 | return _R_lower_lidar_to_base.T @ (pts - _T_lower_lidar_to_base) 64 | 65 | 66 | def transform_pts_base_to_laser(pts): 67 | return _R_laser_to_base.T @ pts 68 | 69 | 70 | def transform_pts_base_to_stitched_im(pts): 71 | """Project 3D points in base frame to the stitched image 72 | 73 | Args: 74 | pts (np.array[3, N]): points (x, y, z) 75 | 76 | Returns: 77 | pts_im (np.array[2, N]) 78 | inbound_mask (np.array[N]) 79 | """ 80 | im_size = (480, 3760) 81 | 82 | # to image coordinate 83 | pts_rect = pts[[1, 2, 0], :] 84 | pts_rect[:2, :] *= -1 85 | 86 | # to pixel 87 | horizontal_theta = np.arctan2(pts_rect[0], pts_rect[2]) 88 | horizontal_percent = horizontal_theta / (2 * np.pi) + 0.5 89 | x = im_size[1] * horizontal_percent 90 | y = ( 91 | 485.78 * pts_rect[1] / pts_rect[2] * np.cos(horizontal_theta) 92 | + 0.4375 * im_size[0] 93 | ) 94 | # horizontal_theta = np.arctan(pts_rect[0, :] / pts_rect[2, :]) 95 | # horizontal_theta += (pts_rect[2, :] < 0) * np.pi 96 | # horizontal_percent = horizontal_theta / (2 * np.pi) 97 | # x = ((horizontal_percent * im_size[1]) + 1880) % im_size[1] 98 | # y = ( 99 | # 485.78 * (pts_rect[1, :] / ((1 / np.cos(horizontal_theta)) * pts_rect[2, :])) 100 | # ) + (0.4375 * im_size[0]) 101 | 102 | # x is always in bound by cylindrical parametrization 103 | # y is always at the lower half of the image, since laser is lower than the camera 104 | # thus only one boundary needs to be checked 105 | inbound_mask = y < im_size[0] 106 | 107 | return np.stack((x, y), axis=0).astype(np.int32), inbound_mask 108 | 109 | 110 | def transform_pts_laser_to_stitched_im(pts): 111 | pts_base = transform_pts_laser_to_base(pts) 112 | return transform_pts_base_to_stitched_im(pts_base) 113 | -------------------------------------------------------------------------------- 
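The transformation API above is purely functional. As a quick orientation, here is a minimal usage sketch (not part of the repository) that projects a fabricated 2D laser scan onto the stitched panoramic image. The scan values and the laser height are made up for illustration; only the `(3, N)` point layout and the two functions used come from the module above. Note the repo's own `rphi_to_xy` helper may use a different frame convention, so the polar-to-Cartesian conversion is written out explicitly.

```python
import numpy as np

import dr_spaam.utils.jrdb_transforms as jt

# Fabricated 360 degree scan (values are illustrative only).
num_pts = 1091
scan_r = np.random.uniform(0.5, 10.0, size=num_pts).astype(np.float32)
scan_phi = np.linspace(-np.pi, np.pi, num_pts, dtype=np.float32)
laser_z = np.full(num_pts, -0.5, dtype=np.float32)  # assumed laser height, illustrative

# Polar to Cartesian in the laser frame (x-forward, y-left), stacked as (3, N).
pts_laser = np.stack(
    (scan_r * np.cos(scan_phi), scan_r * np.sin(scan_phi), laser_z), axis=0
)

# Base frame, e.g. for comparison against the 3D annotations.
pts_base = jt.transform_pts_laser_to_base(pts_laser)

# Pixel coordinates on the 480x3760 stitched image; keep only inbound points.
pts_im, inbound_mask = jt.transform_pts_laser_to_stitched_im(pts_laser)
print(pts_base.shape, pts_im[:, inbound_mask].shape)  # (3, N) (2, M)
```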
/dr_spaam/dr_spaam/utils/jrdb_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def box_to_kitti_string(dets_xy, dets_cls, occluded):
5 | """Obtain a KITTI format string for a detected box
6 | 
7 | Args:
8 | dets_xy (np.array[N, 2])
9 | dets_cls (np.array[N])
10 | occluded (np.array[N]): Only relevant for annotations. Set to 1 if the
11 | annotation is not visible (fewer than 5 points in proximity)
12 | 
13 | Returns:
14 | s (str)
15 | """
16 | if dets_cls is None:
17 | dets_cls = np.ones(len(dets_xy), dtype=np.float32)
18 | 
19 | if occluded is None:
20 | occluded = np.zeros(len(dets_xy), dtype=np.int)
21 | 
22 | s = ""
23 | for cls, xy, occ in zip(dets_cls, dets_xy, occluded):
24 | s += f"Pedestrian 0 {occ} 0 0 0 0 0 0 0 0 0 {xy[0]} {xy[1]} 0 0 {cls}\n"
25 | s = s.strip("\n")
26 | 
27 | return s
28 | 
29 | 
30 | def kitti_string_to_box(s):
31 | dets_xy = []
32 | dets_cls = []
33 | occluded = []
34 | 
35 | if s:
36 | lines = s.split("\n")
37 | for line in lines:
38 | vals = line.split(" ")
39 | dets_cls.append(float(vals[-1]))
40 | dets_xy.append((float(vals[-5]), float(vals[-4])))
41 | occluded.append(int(vals[2]))
42 | 
43 | dets_cls = np.array(dets_cls, dtype=np.float32)
44 | dets_xy = np.array(dets_xy, dtype=np.float32)
45 | occluded = np.array(occluded, dtype=np.int)
46 | 
47 | return dets_xy, dets_cls, occluded
48 | 
--------------------------------------------------------------------------------
/dr_spaam/dr_spaam/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | 
4 | import dr_spaam.utils.utils as u
5 | 
6 | _X_LIM = (-7, 7)
7 | _Y_LIM = (-7, 7)
8 | 
9 | 
10 | def plot_one_frame(
11 | batch_dict,
12 | frame_idx,
13 | pred_cls=None,
14 | pred_reg=None,
15 | dets_cls=None,
16 | dets_xy=None,
17 | xlim=_X_LIM,
18 | ylim=_Y_LIM,
19 | ):
20 | """Plot one frame from a batch, specified by frame_idx.
21 | 22 | Returns: 23 | fig: figure handle 24 | ax: axis handle 25 | """ 26 | fig, ax = _create_figure("", xlim, ylim) 27 | 28 | # scan and cls label 29 | scan_r = batch_dict["scans"][frame_idx][-1] 30 | scan_phi = batch_dict["scan_phi"][frame_idx] 31 | target_cls = batch_dict["target_cls"][frame_idx] 32 | _plot_scan(ax, scan_r, scan_phi, target_cls, s=1) 33 | 34 | # annotation 35 | ann = batch_dict["dets_wp"][frame_idx] 36 | ann_valid_mask = batch_dict["anns_valid_mask"][frame_idx] 37 | if len(ann) > 0: 38 | ann = np.array(ann) 39 | det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1]) 40 | for x, y, valid in zip(det_x, det_y, ann_valid_mask): 41 | c = "blue" if valid else "orange" 42 | c = plt.Circle((x, y), radius=0.4, color=c, fill=False) 43 | ax.add_artist(c) 44 | 45 | # regression target 46 | target_reg = batch_dict["target_reg"][frame_idx] 47 | _plot_target(ax, target_reg, target_cls > 0, scan_r, scan_phi, s=10, c="blue") 48 | 49 | # regression result 50 | if dets_xy is not None and dets_cls is not None: 51 | _plot_detection(ax, dets_cls, dets_xy, s=40, color_dim=1) 52 | 53 | if pred_cls is not None and pred_reg is not None: 54 | _plot_prediction(ax, pred_cls, pred_reg, scan_r, scan_phi, s=2, color_dim=1) 55 | 56 | return fig, ax 57 | 58 | 59 | def plot_one_batch(batch_dict, xlim=_X_LIM, ylim=_Y_LIM): 60 | fig_ax_list = [] 61 | for ib in range(len(batch_dict["input"])): 62 | fig_ax_list.append(plot_one_frame(batch_dict, ib)) 63 | 64 | return fig_ax_list 65 | 66 | 67 | def plot_one_batch_detr(batch_dict, xlim=_X_LIM, ylim=_Y_LIM): 68 | fig_ax_list = [] 69 | 70 | for ib in range(len(batch_dict["input"])): 71 | fr_idx = batch_dict["frame_dict_curr"][ib]["idx"] 72 | fig, ax = _create_figure(fr_idx, xlim, ylim) 73 | 74 | # scan and cls label 75 | scan_r = batch_dict["frame_dict_curr"][ib]["laser_data"][-1] 76 | scan_phi = batch_dict["frame_dict_curr"][ib]["laser_grid"] 77 | target_cls = batch_dict["target_cls"][ib] 78 | _plot_scan(ax, scan_r, scan_phi, target_cls, s=1) 79 | 80 | # annotation for current frame 81 | anns = batch_dict["frame_dict_curr"][ib]["dets_rphi"] 82 | anns_valid_mask = batch_dict["anns_valid_mask"][ib] 83 | anns_valid = anns[:, anns_valid_mask] 84 | anns_invalid = anns[:, np.logical_not(anns_valid_mask)] 85 | _plot_annotation_detr(ax, anns_valid, radius=0.4, color="blue") 86 | _plot_annotation_detr(ax, anns_invalid, radius=0.4, color="orange") 87 | 88 | # annotation for previous frame 89 | anns_prev = batch_dict["frame_dict_curr"][ib]["dets_rphi_prev"] 90 | anns_tracking_mask = batch_dict["anns_tracking_mask"][ib] 91 | anns_prev = anns_prev[:, anns_tracking_mask] 92 | _plot_annotation_detr(ax, anns_prev, radius=0.4, color="gray", linestyle="--") 93 | 94 | # regression target for previous frame 95 | target_reg_prev = batch_dict["target_reg_prev"][ib] 96 | target_tracking_flag = batch_dict["target_tracking_flag"][ib] 97 | _plot_target( 98 | ax, target_reg_prev, target_tracking_flag, scan_r, scan_phi, s=25, c="gray" 99 | ) 100 | 101 | # regression target for current frame 102 | target_reg = batch_dict["target_reg"][ib] 103 | _plot_target(ax, target_reg, target_cls > 0, scan_r, scan_phi, s=10, c="red") 104 | 105 | # regression result for previous frame 106 | pred_cls = batch_dict["pred_cls"][ib] 107 | pred_reg_prev = batch_dict["pred_reg_prev"][ib] 108 | _plot_prediction( 109 | ax, pred_cls, pred_reg_prev, scan_r, scan_phi, s=2, color_dim=2 110 | ) 111 | 112 | # regression result for current frame 113 | pred_reg = batch_dict["pred_reg"][ib] 114 | _plot_prediction(ax, 
pred_cls, pred_reg, scan_r, scan_phi, s=2, color_dim=1) 115 | 116 | fig_ax_list.append((fig, ax)) 117 | 118 | return fig_ax_list 119 | 120 | 121 | def _cls_to_color(cls, color_dim): 122 | color = 1.0 - cls.reshape(-1, 1).repeat(3, axis=1) 123 | color[:, color_dim] = 1 124 | return color 125 | 126 | 127 | def _create_figure(title, xlim, ylim): 128 | fig = plt.figure(figsize=(10, 10)) 129 | ax = fig.add_subplot(111) 130 | 131 | ax.set_xlim(xlim[0], xlim[1]) 132 | ax.set_ylim(ylim[0], ylim[1]) 133 | ax.set_xlabel("x [m]") 134 | ax.set_ylabel("y [m]") 135 | ax.set_aspect("equal") 136 | ax.set_title(f"{title}") 137 | 138 | return fig, ax 139 | 140 | 141 | def _plot_scan(ax, scan_r, scan_phi, target_cls, s): 142 | scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi) 143 | ax.scatter(scan_x[target_cls < 0], scan_y[target_cls < 0], s=s, c="orange") 144 | ax.scatter(scan_x[target_cls == 0], scan_y[target_cls == 0], s=s, c="black") 145 | ax.scatter(scan_x[target_cls > 0], scan_y[target_cls > 0], s=s, c="green") 146 | 147 | 148 | def _plot_prediction(ax, pred_cls, pred_reg, scan_r, scan_phi, s, color_dim): 149 | pred_r, pred_phi = u.canonical_to_global( 150 | scan_r, scan_phi, pred_reg[:, 0], pred_reg[:, 1] 151 | ) 152 | pred_x, pred_y = u.rphi_to_xy(pred_r, pred_phi) 153 | pred_color = _cls_to_color(pred_cls, color_dim=color_dim) 154 | ax.scatter(pred_x, pred_y, s=s, c=pred_color) 155 | 156 | 157 | def _plot_detection(ax, dets_cls, dets_xy, s, color_dim): 158 | dets_color = _cls_to_color(dets_cls, color_dim=color_dim) 159 | ax.scatter(dets_xy[:, 0], dets_xy[:, 1], marker="x", s=s, c=dets_color) 160 | 161 | 162 | def _plot_target(ax, target_reg, target_flag, scan_r, scan_phi, s, c): 163 | dets_r, dets_phi = u.canonical_to_global( 164 | scan_r, scan_phi, target_reg[:, 0], target_reg[:, 1] 165 | ) 166 | dets_r = dets_r[target_flag] 167 | dets_phi = dets_phi[target_flag] 168 | dets_x, dets_y = u.rphi_to_xy(dets_r, dets_phi) 169 | ax.scatter(dets_x, dets_y, s=s, c=c) 170 | 171 | 172 | def _plot_annotation_detr(ax, anns, radius, color, linestyle="-"): 173 | if len(anns) == 0: 174 | return 175 | 176 | det_x, det_y = u.rphi_to_xy(anns[0], anns[1]) 177 | for x, y in zip(det_x, det_y): 178 | c = plt.Circle( 179 | (x, y), radius=radius, color=color, fill=False, linestyle=linestyle 180 | ) 181 | ax.add_artist(c) 182 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/utils/pytorch_nms/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of the copyright holder nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 
17 | 
18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | 
29 | ************************************************************************
30 | 
31 | THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
32 | 
33 | This project incorporates material from the project(s)
34 | listed below (collectively, "Third Party Code"). This Third Party Code is
35 | licensed to you under their original license terms set forth below.
36 | 
37 | 1. Faster R-CNN, (https://github.com/rbgirshick/py-faster-rcnn)
38 | 
39 | The MIT License (MIT)
40 | 
41 | Copyright (c) 2015 Microsoft Corporation
42 | 
43 | Permission is hereby granted, free of charge, to any person obtaining a copy
44 | of this software and associated documentation files (the "Software"), to deal
45 | in the Software without restriction, including without limitation the rights
46 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
47 | copies of the Software, and to permit persons to whom the Software is
48 | furnished to do so, subject to the following conditions:
49 | 
50 | The above copyright notice and this permission notice shall be included in
51 | all copies or substantial portions of the Software.
52 | 
53 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
54 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
55 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
56 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
57 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
58 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
59 | THE SOFTWARE.
60 | 
61 | 
--------------------------------------------------------------------------------
/dr_spaam/dr_spaam/utils/pytorch_nms/README.md:
--------------------------------------------------------------------------------
1 | # Torchvision support for NMS
2 | 
3 | Note: Since the publication of this repository, NMS support has been included as part of torchvision. Therefore you might want to use this implementation instead:
4 | https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py.
5 | 
6 | This repository might still be of interest if you need the index in the `keep` list of the highest-scoring box overlapping each input box.
7 | 
8 | # CUDA implementation of NMS for PyTorch.
9 | 
10 | 
11 | This repository has a CUDA implementation of NMS for PyTorch 1.4.0.
12 | 
13 | The code is released under the BSD license; however, it also includes parts of the original implementation from [Fast R-CNN](https://github.com/rbgirshick/py-faster-rcnn) which falls under the MIT license (see LICENSE file for details).
14 | 
15 | The code is experimental and has not been thoroughly tested yet; use at your own risk.
Any issues and pull requests are welcome. 16 | 17 | ## Installation 18 | 19 | ``` 20 | python setup.py install 21 | ``` 22 | 23 | ## Usage 24 | 25 | Example: 26 | ``` 27 | from nms import nms 28 | 29 | keep, num_to_keep, parent_object_index = nms(boxes, scores, overlap=.5, top_k=200) 30 | ``` 31 | 32 | The `nms` function takes a (N,4) tensor of `boxes` and associated (N) tensor of `scores`, sorts the bounding boxes by score and selects boxes using Non-Maximum Suppression according to the given `overlap`. It returns the indices of the `top_k` with the highest score. Bounding boxes are represented using the standard (left,top,right,bottom) coordinates representation. 33 | 34 | `keep` is the list of indices of kept bounding boxes. Note that the tensor size is always (N) however only the first `num_to_keep` entries are valid. 35 | 36 | For each input box, the (N) tensor `parent_object_index` contains the index (1-based) in the `keep` list of the highest-scoring box overlapping this box. This can be useful to group input boxes that are related to the same object. The index 0 represents a background box which has been ignored due to `top_k`. 37 | 38 | Currently there is a hard-limit of 64,000 input boxes. You can change the constant `MAX_COL_BLOCKS` in `nms_kernel.cu` to increase this limit. 39 | 40 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/utils/pytorch_nms/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import CUDAExtension, BuildExtension 3 | 4 | setup( 5 | name="nms", 6 | packages=["nms"], 7 | package_dir={"": "src"}, 8 | ext_modules=[ 9 | CUDAExtension( 10 | "nms.details", 11 | ["src/nms.cpp", "src/nms_kernel.cu"], 12 | extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]}, 13 | ) 14 | ], 15 | cmdclass={"build_ext": BuildExtension}, 16 | ) 17 | -------------------------------------------------------------------------------- /dr_spaam/dr_spaam/utils/pytorch_nms/src/nms.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University 2 | * All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted provided that the following conditions are met: 6 | * 7 | * * Redistributions of source code must retain the above copyright notice, this 8 | * list of conditions and the following disclaimer. 9 | * 10 | * * Redistributions in binary form must reproduce the above copyright notice, 11 | * this list of conditions and the following disclaimer in the documentation 12 | * and/or other materials provided with the distribution. 13 | * 14 | * * Neither the name of the copyright holder nor the names of its 15 | * contributors may be used to endorse or promote products derived from 16 | * this software without specific prior written permission. 17 | * 18 | * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |  */
29 | 
30 | #include <torch/extension.h>
31 | #include <ATen/ATen.h>
32 | #include <vector>
33 | 
34 | std::vector<at::Tensor> nms_cuda_forward(
35 | at::Tensor boxes,
36 | at::Tensor idx,
37 | float nms_overlap_thresh,
38 | unsigned long top_k);
39 | 
40 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
41 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
42 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
43 | 
44 | std::vector<at::Tensor> nms_forward(
45 | at::Tensor boxes,
46 | at::Tensor scores,
47 | float thresh,
48 | unsigned long top_k) {
49 | 
50 | 
51 | auto idx = std::get<1>(scores.sort(0,true));
52 | 
53 | CHECK_INPUT(boxes);
54 | CHECK_INPUT(idx);
55 | 
56 | return nms_cuda_forward(boxes, idx, thresh, top_k);
57 | }
58 | 
59 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
60 | m.def("nms_forward", &nms_forward, "NMS");
61 | }
62 | 
63 | 
--------------------------------------------------------------------------------
/dr_spaam/dr_spaam/utils/pytorch_nms/src/nms/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University
2 | # All rights reserved.
3 | #
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted provided that the following conditions are met:
6 | #
7 | # * Redistributions of source code must retain the above copyright notice, this
8 | #   list of conditions and the following disclaimer.
9 | #
10 | # * Redistributions in binary form must reproduce the above copyright notice,
11 | #   this list of conditions and the following disclaimer in the documentation
12 | #   and/or other materials provided with the distribution.
13 | #
14 | # * Neither the name of the copyright holder nor the names of its
15 | #   contributors may be used to endorse or promote products derived from
16 | #   this software without specific prior written permission.
17 | #
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 | 
29 | from . import details
30 | 
31 | 
32 | def nms(boxes, scores, overlap, top_k):
33 | return details.nms_forward(boxes, scores, overlap, top_k)
34 | 
--------------------------------------------------------------------------------
/dr_spaam/dr_spaam/utils/pytorch_nms/src/nms_kernel.cu:
--------------------------------------------------------------------------------
1 | /* Copyright (c) 2018, Grégoire Payen de La Garanderie, Durham University
2 |  * All rights reserved.
3 |  *
4 |  * Redistribution and use in source and binary forms, with or without
5 |  * modification, are permitted provided that the following conditions are met:
6 |  *
7 |  * * Redistributions of source code must retain the above copyright notice, this
8 |  *   list of conditions and the following disclaimer.
9 |  *
10 |  * * Redistributions in binary form must reproduce the above copyright notice,
11 |  *   this list of conditions and the following disclaimer in the documentation
12 |  *   and/or other materials provided with the distribution.
13 |  *
14 |  * * Neither the name of the copyright holder nor the names of its
15 |  *   contributors may be used to endorse or promote products derived from
16 |  *   this software without specific prior written permission.
17 |  *
18 |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |  */
29 | #include <ATen/ATen.h>
30 | #include <ATen/cuda/CUDAContext.h>
31 | #include <torch/extension.h>
32 | 
33 | #include <cuda.h>
34 | #include <cuda_runtime.h>
35 | #include <vector>
36 | #include <iostream>
37 | 
38 | // From https://stackoverflow.com/questions/14038589/what-is-the-canonical-way-to-check-for-errors-using-the-cuda-runtime-api
39 | #define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); }
40 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
41 | {
42 | if (code != cudaSuccess)
43 | {
44 | fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
45 | if (abort) exit(code);
46 | }
47 | }
48 | 
49 | __global__ void printTensorKernel(
50 | torch::PackedTensorAccessor64<float, 2> boxes,
51 | torch::PackedTensorAccessor64<int64_t, 2> inds,
52 | const int n_boxes)
53 | {
54 | for (int i = 0; i < n_boxes; ++i)
55 | {
56 | printf("idx: %d, x: %f, y: %f, sort: %i\n",
57 | i, boxes[i][0], boxes[i][1], inds[i][0]);
58 | }
59 | }
60 | 
61 | // Hard-coded maximum. Increase if needed.
62 | #define MAX_COL_BLOCKS 1000
63 | 
64 | #define DIVUP(m,n) (((m)+(n)-1) / (n))
65 | int64_t const threadsPerBlock = sizeof(unsigned long long) * 8;
66 | 
67 | // The functions below originate from Fast R-CNN
68 | // See https://github.com/rbgirshick/py-faster-rcnn
69 | // Copyright (c) 2015 Microsoft
70 | // Licensed under The MIT License
71 | // Written by Shaoqing Ren
72 | 
73 | template <typename scalar_t>
74 | __device__ inline scalar_t devIoU(scalar_t const * const a, scalar_t const * const b) {
75 | // scalar_t left = max(a[0], b[0]), right = min(a[2], b[2]);
76 | // scalar_t top = max(a[1], b[1]), bottom = min(a[3], b[3]);
77 | // scalar_t width = max(right - left, 0.f), height = max(bottom - top, 0.f);
78 | // scalar_t interS = width * height;
79 | // scalar_t Sa = (a[2] - a[0]) * (a[3] - a[1]);
80 | // scalar_t Sb = (b[2] - b[0]) * (b[3] - b[1]);
81 | // return interS / (Sa + Sb - interS);
82 | scalar_t x_diff = a[0] - b[0];
83 | scalar_t y_diff = a[1] - b[1];
84 | return sqrt(x_diff * x_diff + y_diff * y_diff);
85 | }
86 | 
87 | template <typename scalar_t>
88 | __global__ void nms_kernel(const int64_t n_boxes, const scalar_t nms_overlap_thresh,
89 | const scalar_t *dev_boxes, const int64_t *idx, int64_t *dev_mask) {
90 | const int64_t row_start = blockIdx.y;
91 | const int64_t col_start = blockIdx.x;
92 | 
93 | const int row_size =
94 | min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
95 | const int col_size =
96 | min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
97 | 
98 | // __shared__ scalar_t block_boxes[threadsPerBlock * 4];
99 | // if (threadIdx.x < col_size) {
100 | //   block_boxes[threadIdx.x * 4 + 0] =
101 | //       dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 0];
102 | //   block_boxes[threadIdx.x * 4 + 1] =
103 | //       dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 1];
104 | //   block_boxes[threadIdx.x * 4 + 2] =
105 | //       dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 2];
106 | //   block_boxes[threadIdx.x * 4 + 3] =
107 | //       dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 4 + 3];
108 | // }
109 | __shared__ scalar_t block_boxes[threadsPerBlock * 2];
110 | if (threadIdx.x < col_size) {
111 | block_boxes[threadIdx.x * 2 + 0] =
112 | dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 2 + 0];
113 | block_boxes[threadIdx.x * 2 + 1] =
114 | dev_boxes[idx[(threadsPerBlock * col_start + threadIdx.x)] * 2 + 1];
115 | }
116 | __syncthreads();
117 | 
118 | if (threadIdx.x < row_size) {
119 | const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
120 | const scalar_t *cur_box = dev_boxes + idx[cur_box_idx] * 2;
121 | // const scalar_t *cur_box = dev_boxes + idx[cur_box_idx] * 4;
122 | int i = 0;
123 | unsigned long long t = 0;
124 | int start = 0;
125 | if (row_start == col_start) {
126 | start = threadIdx.x + 1;
127 | }
128 | for (i = start; i < col_size; i++) {
129 | // if (devIoU(cur_box, block_boxes + i * 4) > nms_overlap_thresh) {
130 | if (devIoU(cur_box, block_boxes + i * 2) < nms_overlap_thresh) {
131 | t |= 1ULL << i;
132 | }
133 | }
134 | const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
135 | dev_mask[cur_box_idx * col_blocks + col_start] = t;
136 | }
137 | }
138 | 
139 | 
140 | __global__ void nms_collect(const int64_t boxes_num, const int64_t col_blocks, int64_t top_k, const int64_t *idx, const int64_t *mask, int64_t *keep, int64_t *parent_object_index, int64_t *num_to_keep) {
141 | int64_t remv[MAX_COL_BLOCKS];
142 | int64_t num_to_keep_ = 0;
143 | 
144 | for (int i = 0; i < col_blocks; i++) {
145 | remv[i] = 0;
146 | }
147 | 
148 | for (int i = 0; i < boxes_num; ++i) {
149 | parent_object_index[i] = 0;
150 | }
151 | 
152 | for (int i = 0; i < boxes_num; i++) {
153 | int nblock = i / threadsPerBlock;
154 | int inblock = i % threadsPerBlock;
155 | 
156 | 
157 | if (!(remv[nblock] & (1ULL << inblock))) {
158 | int64_t idxi = idx[i];
159 | keep[num_to_keep_] = idxi;
160 | const int64_t *p = &mask[0] + i * col_blocks;
161 | for (int j = nblock; j < col_blocks; j++) {
162 | remv[j] |= p[j];
163 | }
164 | for (int j = i; j < boxes_num; j++) {
165 | int nblockj = j / threadsPerBlock;
166 | int inblockj = j % threadsPerBlock;
167 | if (p[nblockj] & (1ULL << inblockj))
168 | parent_object_index[idx[j]] = num_to_keep_+1;
169 | }
170 | parent_object_index[idx[i]] = num_to_keep_+1;
171 | 
172 | num_to_keep_++;
173 | 
174 | if (num_to_keep_==top_k)
175 | break;
176 | }
177 | }
178 | 
179 | // Initialize the rest of the keep array to avoid uninitialized values.
180 | for (int i = num_to_keep_; i < boxes_num; ++i)
181 | keep[i] = 0;
182 | 
183 | *num_to_keep = min(top_k,num_to_keep_);
184 | }
185 | 
186 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
187 | 
188 | std::vector<at::Tensor> nms_cuda_forward(
189 | at::Tensor boxes,
190 | at::Tensor idx,
191 | float nms_overlap_thresh,
192 | unsigned long top_k) {
193 | 
194 | // // check tensor value
195 | // auto boxes_a = boxes.packed_accessor64<float, 2>();
196 | // auto idx_a = idx.packed_accessor64<int64_t, 2>();
197 | // printTensorKernel<<<1, 1>>>(boxes_a, idx_a, boxes.size(0));
198 | 
199 | const auto boxes_num = boxes.size(0);
200 | 
201 | const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
202 | 
203 | AT_ASSERTM (col_blocks < MAX_COL_BLOCKS, "The number of column blocks must be less than MAX_COL_BLOCKS. Increase the MAX_COL_BLOCKS constant if needed.");
204 | 
205 | auto longOptions = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kLong);
206 | auto mask = at::empty({boxes_num * col_blocks}, longOptions);
207 | 
208 | dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
209 | DIVUP(boxes_num, threadsPerBlock));
210 | dim3 threads(threadsPerBlock);
211 | 
212 | CHECK_CONTIGUOUS(boxes);
213 | CHECK_CONTIGUOUS(idx);
214 | CHECK_CONTIGUOUS(mask);
215 | 
216 | AT_DISPATCH_FLOATING_TYPES(boxes.type(), "nms_cuda_forward", ([&] {
217 | nms_kernel<scalar_t><<<blocks, threads>>>(boxes_num,
218 | (scalar_t)nms_overlap_thresh,
219 | boxes.data<scalar_t>(),
220 | idx.data<int64_t>(),
221 | mask.data<int64_t>());
222 | }));
223 | 
224 | gpuErrchk(cudaPeekAtLastError());
225 | gpuErrchk(cudaDeviceSynchronize());
226 | 
227 | auto keep = at::empty({boxes_num}, longOptions);
228 | auto parent_object_index = at::empty({boxes_num}, longOptions);
229 | auto num_to_keep = at::empty({}, longOptions);
230 | 
231 | nms_collect<<<1, 1>>>(boxes_num, col_blocks, top_k,
232 | idx.data<int64_t>(),
233 | mask.data<int64_t>(),
234 | keep.data<int64_t>(),
235 | parent_object_index.data<int64_t>(),
236 | num_to_keep.data<int64_t>());
237 | 
238 | 
239 | return {keep,num_to_keep,parent_object_index};
240 | }
241 | 
242 | 
--------------------------------------------------------------------------------
/dr_spaam/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | setup(
4 | name="dr_spaam",
5 | version="1.2.0",
6 | author="Dan Jia",
7 | author_email="jia@vision.rwth-aachen.de",
8 | packages=find_packages(include=["dr_spaam", "dr_spaam.*", "dr_spaam.*.*"]),
9 | license="LICENSE.txt",
10 | description="DR-SPAAM, a deep-learning based person detector for 2D range data.",
11 | )
12 | 
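Note that this fork of `pytorch_nms` replaces the IoU criterion: `devIoU` in the sources above returns the Euclidean distance between two `(x, y)` detection centers, and a candidate is suppressed when its distance to a higher-scoring detection falls *below* `nms_overlap_thresh`. A minimal usage sketch (not from the repository, and assuming the extension has been built with the `pytorch_nms` setup.py and a CUDA device is available):

```python
import torch
from nms import nms

# (N, 2) detection centers in meters and (N,) confidence scores. This fork
# consumes 2D centers, not the (N, 4) corner boxes of the upstream repository.
centers = (torch.rand(100, 2, device="cuda") * 10.0).contiguous()
scores = torch.rand(100, device="cuda")

# Here `overlap` acts as a distance threshold, e.g. 0.5 m between persons.
keep, num_to_keep, parent_object_index = nms(centers, scores, overlap=0.5, top_k=200)
keep = keep[:num_to_keep]  # only the first num_to_keep entries are valid
```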
-------------------------------------------------------------------------------- /dr_spaam/tests/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import yaml 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | import dr_spaam.utils.utils as u 8 | from dr_spaam.dataset import get_dataloader 9 | 10 | _X_LIM = (-7, 7) 11 | _Y_LIM = (-7, 7) 12 | # _X_LIM = (-15, 15) 13 | # _Y_LIM = (-15, 15) 14 | _MAX_COUNT = 3 15 | _INTERACTIVE = False 16 | _SAVE_DIR = "/home/jia/tmp_imgs/test_dataloader" 17 | 18 | 19 | def _plot_sample_light(fig, ax, ib, count, data_dict): 20 | plt.cla() 21 | ax.set_xlim(_X_LIM[0], _X_LIM[1]) 22 | ax.set_ylim(_Y_LIM[0], _Y_LIM[1]) 23 | ax.set_xlabel("x [m]") 24 | ax.set_ylabel("y [m]") 25 | ax.set_aspect("equal") 26 | # ax.set_title(f"Frame {data_dict['idx'][ib]}. Press any key to exit.") 27 | 28 | # scan and cls label 29 | scan_r = data_dict["scans"][ib][-1] 30 | scan_phi = data_dict["scan_phi"][ib] 31 | scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi) 32 | ax.scatter(scan_x, scan_y, s=0.5, c="blue") 33 | 34 | # annotation 35 | ann = data_dict["dets_wp"][ib] 36 | ann_valid_mask = data_dict["anns_valid_mask"][ib] 37 | if len(ann) > 0: 38 | ann = np.array(ann) 39 | det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1]) 40 | for x, y, valid in zip(det_x, det_y, ann_valid_mask): 41 | if valid: 42 | # c = plt.Circle((x, y), radius=0.1, color="red", fill=True) 43 | c = plt.Circle((x, y), radius=0.4, color="red", fill=False) 44 | ax.add_artist(c) 45 | 46 | 47 | def _plot_sample(fig, ax, ib, count, data_dict): 48 | plt.cla() 49 | ax.set_xlim(_X_LIM[0], _X_LIM[1]) 50 | ax.set_ylim(_Y_LIM[0], _Y_LIM[1]) 51 | ax.set_xlabel("x [m]") 52 | ax.set_ylabel("y [m]") 53 | ax.set_aspect("equal") 54 | ax.set_title(f"Frame {data_dict['idx'][ib]}. 
Press any key to exit.") 55 | 56 | # scan and cls label 57 | scan_r = data_dict["scans"][ib][-1] 58 | scan_phi = data_dict["scan_phi"][ib] 59 | scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi) 60 | 61 | target_cls = data_dict["target_cls"][ib] 62 | ax.scatter(scan_x[target_cls == -2], scan_y[target_cls == -2], s=1, c="yellow") 63 | ax.scatter(scan_x[target_cls == -1], scan_y[target_cls == -1], s=1, c="orange") 64 | ax.scatter(scan_x[target_cls == 0], scan_y[target_cls == 0], s=1, c="black") 65 | ax.scatter(scan_x[target_cls > 0], scan_y[target_cls > 0], s=1, c="green") 66 | 67 | # annotation 68 | ann = data_dict["dets_wp"][ib] 69 | ann_valid_mask = data_dict["anns_valid_mask"][ib] 70 | if len(ann) > 0: 71 | ann = np.array(ann) 72 | det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1]) 73 | for x, y, valid in zip(det_x, det_y, ann_valid_mask): 74 | c = "blue" if valid else "orange" 75 | c = plt.Circle((x, y), radius=0.4, color=c, fill=False) 76 | ax.add_artist(c) 77 | 78 | # reg label 79 | target_reg = data_dict["target_reg"][ib] 80 | dets_r, dets_phi = u.canonical_to_global( 81 | scan_r, scan_phi, target_reg[:, 0], target_reg[:, 1] 82 | ) 83 | dets_r = dets_r[target_cls > 0] 84 | dets_phi = dets_phi[target_cls > 0] 85 | dets_x, dets_y = u.rphi_to_xy(dets_r, dets_phi) 86 | ax.scatter(dets_x, dets_y, s=10, c="red") 87 | 88 | 89 | def _test_dataloader(): 90 | with open("./base_dr_spaam_jrdb_cfg.yaml", "r") as f: 91 | cfg = yaml.safe_load(f) 92 | 93 | cfg["dataset"]["pseudo_label"] = False 94 | cfg["dataset"]["pl_correction_level"] = 0 95 | 96 | test_loader = get_dataloader( 97 | split="val", 98 | batch_size=5, 99 | num_workers=1, 100 | shuffle=False, 101 | dataset_cfg=cfg["dataset"], 102 | ) 103 | 104 | fig = plt.figure(figsize=(10, 10)) 105 | ax = fig.add_subplot(111) 106 | 107 | _break = False 108 | 109 | if _INTERACTIVE: 110 | 111 | def p(event): 112 | nonlocal _break 113 | _break = True 114 | 115 | fig.canvas.mpl_connect("key_press_event", p) 116 | else: 117 | if os.path.exists(_SAVE_DIR): 118 | shutil.rmtree(_SAVE_DIR) 119 | os.makedirs(_SAVE_DIR) 120 | 121 | for count, data_dict in enumerate(test_loader): 122 | if count >= _MAX_COUNT: 123 | break 124 | 125 | for ib in range(len(data_dict["input"])): 126 | _plot_sample(fig, ax, ib, count, data_dict) 127 | 128 | if _INTERACTIVE: 129 | plt.pause(0.1) 130 | else: 131 | plt.savefig( 132 | os.path.join( 133 | _SAVE_DIR, f"b{count:03}s{ib:02}f{data_dict['idx'][ib]:04}.pdf" 134 | ) 135 | ) 136 | 137 | if _INTERACTIVE: 138 | plt.show() 139 | 140 | 141 | if __name__ == "__main__": 142 | _test_dataloader() 143 | -------------------------------------------------------------------------------- /dr_spaam/tests/test_detector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from dr_spaam.utils import utils as u 7 | import dr_spaam.utils.jrdb_transforms as jt 8 | from dr_spaam.detector import Detector 9 | from dr_spaam.datahandle.jrdb_handle import JRDBHandle 10 | 11 | _X_LIM = (-15, 15) 12 | _Y_LIM = (-15, 15) 13 | _INTERACTIVE = False 14 | _SAVE_DIR = "/home/jia/tmp_imgs/test_detector" 15 | 16 | 17 | def _plot_annotation(ann, ax, color, radius): 18 | if len(ann) > 0: 19 | ann = np.array(ann) 20 | det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1]) 21 | for x, y in zip(det_x, det_y): 22 | c = plt.Circle( 23 | (x, y), radius=radius, color=color, fill=False, linestyle="--" 24 | ) 25 | ax.add_artist(c) 26 | 27 | 28 | def 
test_detector(): 29 | data_handle = JRDBHandle( 30 | split="train", 31 | cfg={"data_dir": "./data/JRDB", "num_scans": 10, "scan_stride": 1}, 32 | ) 33 | 34 | # ckpt_file = "/home/jia/ckpts/ckpt_jrdb_ann_drow3_e40.pth" 35 | # d = Detector( 36 | # ckpt_file, model="DROW3", gpu=True, stride=1, panoramic_scan=True 37 | # ) 38 | 39 | ckpt_file = "/home/jia/ckpts/ckpt_jrdb_ann_dr_spaam_e20.pth" 40 | d = Detector(ckpt_file, model="DR-SPAAM", gpu=True, stride=1, panoramic_scan=True) 41 | 42 | d.set_laser_fov(360) 43 | 44 | fig = plt.figure(figsize=(10, 10)) 45 | ax = fig.add_subplot(111) 46 | 47 | _break = False 48 | 49 | if _INTERACTIVE: 50 | 51 | def p(event): 52 | nonlocal _break 53 | _break = True 54 | 55 | fig.canvas.mpl_connect("key_press_event", p) 56 | else: 57 | if os.path.exists(_SAVE_DIR): 58 | shutil.rmtree(_SAVE_DIR) 59 | os.makedirs(_SAVE_DIR) 60 | 61 | for i, data_dict in enumerate(data_handle): 62 | if _break: 63 | break 64 | 65 | # plot scans 66 | scan_r = data_dict["laser_data"][-1, ::-1] # to DROW frame 67 | scan_x, scan_y = u.rphi_to_xy(scan_r, data_dict["laser_grid"]) 68 | 69 | plt.cla() 70 | ax.set_aspect("equal") 71 | ax.set_xlim(_X_LIM[0], _X_LIM[1]) 72 | ax.set_ylim(_Y_LIM[0], _Y_LIM[1]) 73 | ax.set_xlabel("x [m]") 74 | ax.set_ylabel("y [m]") 75 | ax.set_title(f"Frame {data_dict['idx']}. Press any key to exit.") 76 | # ax.axis("off") 77 | 78 | ax.scatter(scan_x, scan_y, s=1, c="black") 79 | 80 | # plot annotation 81 | ann_xyz = [ 82 | (ann["box"]["cx"], ann["box"]["cy"], ann["box"]["cz"]) 83 | for ann in data_dict["pc_anns"] 84 | ] 85 | if len(ann_xyz) > 0: 86 | ann_xyz = np.array(ann_xyz, dtype=np.float32).T 87 | ann_xyz = jt.transform_pts_base_to_laser(ann_xyz) 88 | ann_xyz[1] = -ann_xyz[1] # to DROW frame 89 | for xyz in ann_xyz.T: 90 | c = plt.Circle( 91 | (xyz[0], xyz[1]), 92 | radius=0.4, 93 | color="red", 94 | fill=False, 95 | linestyle="--", 96 | ) 97 | ax.add_artist(c) 98 | 99 | # plot detection 100 | dets_xy, dets_cls, _ = d(scan_r) 101 | dets_cls_norm = np.clip(dets_cls, 0, 0.3) / 0.3 102 | for xy, cls_norm in zip(dets_xy, dets_cls_norm): 103 | color = (1.0 - cls_norm, 1.0, 1.0 - cls_norm) 104 | c = plt.Circle( 105 | (xy[0], xy[1]), radius=0.4, color=color, fill=False, linestyle="-" 106 | ) 107 | ax.add_artist(c) 108 | 109 | if _INTERACTIVE: 110 | plt.pause(0.1) 111 | else: 112 | plt.savefig(os.path.join(_SAVE_DIR, f"frame_{data_dict['idx']:04}.png")) 113 | 114 | if _INTERACTIVE: 115 | plt.show() 116 | 117 | 118 | if __name__ == "__main__": 119 | test_detector() 120 | -------------------------------------------------------------------------------- /dr_spaam/tests/test_detr_dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import yaml 4 | import matplotlib.pyplot as plt 5 | 6 | import dr_spaam.utils.utils as u 7 | from dr_spaam.dataset.get_dataloader import get_dataloader 8 | 9 | _X_LIM = (-7, 7) 10 | _Y_LIM = (-7, 7) 11 | _INTERACTIVE = False 12 | _SAVE_DIR = "/home/jia/tmp_imgs/test_detr_dataloader" 13 | 14 | 15 | def _test_detr_dataloader(): 16 | with open("./tests/test.yaml", "r") as f: 17 | cfg = yaml.safe_load(f) 18 | cfg["dataset"]["DataHandle"]["tracking"] = True 19 | cfg["dataset"]["DataHandle"]["num_scans"] = 1 20 | 21 | test_loader = get_dataloader( 22 | split="train", 23 | batch_size=8, 24 | num_workers=1, 25 | shuffle=True, 26 | dataset_cfg=cfg["dataset"], 27 | ) 28 | 29 | fig = plt.figure(figsize=(10, 10)) 30 | ax = fig.add_subplot(111) 31 | 32 | _break = False 33 | 34 
| if _INTERACTIVE: 35 | 36 | def p(event): 37 | nonlocal _break 38 | _break = True 39 | 40 | fig.canvas.mpl_connect("key_press_event", p) 41 | else: 42 | if os.path.exists(_SAVE_DIR): 43 | shutil.rmtree(_SAVE_DIR) 44 | os.makedirs(_SAVE_DIR) 45 | 46 | for count, data_dict in enumerate(test_loader): 47 | for ib in range(len(data_dict["input"])): 48 | fr_idx = data_dict["frame_dict_curr"][ib]["idx"] 49 | 50 | plt.cla() 51 | ax.set_xlim(_X_LIM[0], _X_LIM[1]) 52 | ax.set_ylim(_Y_LIM[0], _Y_LIM[1]) 53 | ax.set_xlabel("x [m]") 54 | ax.set_ylabel("y [m]") 55 | ax.set_aspect("equal") 56 | ax.set_title(f"Frame {fr_idx}. Press any key to exit.") 57 | 58 | # scan and cls label 59 | scan_r = data_dict["frame_dict_curr"][ib]["laser_data"][-1] 60 | scan_phi = data_dict["frame_dict_curr"][ib]["laser_grid"] 61 | scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi) 62 | 63 | target_cls = data_dict["target_cls"][ib] 64 | ax.scatter(scan_x[target_cls < 0], scan_y[target_cls < 0], s=1, c="orange") 65 | ax.scatter(scan_x[target_cls == 0], scan_y[target_cls == 0], s=1, c="black") 66 | ax.scatter(scan_x[target_cls > 0], scan_y[target_cls > 0], s=1, c="green") 67 | 68 | # annotation for tracking 69 | anns_tracking = data_dict["frame_dict_curr"][ib]["dets_rphi_prev"] 70 | anns_tracking_mask = data_dict["anns_tracking_mask"][ib] 71 | anns_tracking = anns_tracking[:, anns_tracking_mask] 72 | if len(anns_tracking) > 0: 73 | det_x, det_y = u.rphi_to_xy(anns_tracking[0], anns_tracking[1]) 74 | for x, y in zip(det_x, det_y): 75 | c = plt.Circle( 76 | (x, y), radius=0.5, color="gray", fill=False, linestyle="--" 77 | ) 78 | ax.add_artist(c) 79 | 80 | # annotation 81 | anns = data_dict["frame_dict_curr"][ib]["dets_rphi"] 82 | anns_valid_mask = data_dict["anns_valid_mask"][ib] 83 | if len(anns) > 0: 84 | det_x, det_y = u.rphi_to_xy(anns[0], anns[1]) 85 | for x, y, valid in zip(det_x, det_y, anns_valid_mask): 86 | c = "blue" if valid else "orange" 87 | c = plt.Circle((x, y), radius=0.4, color=c, fill=False) 88 | ax.add_artist(c) 89 | 90 | # reg label for previous frame 91 | target_reg_prev = data_dict["target_reg_prev"][ib] 92 | target_tracking_flag = data_dict["target_tracking_flag"][ib] 93 | dets_r_prev, dets_phi_prev = u.canonical_to_global( 94 | scan_r, scan_phi, target_reg_prev[:, 0], target_reg_prev[:, 1] 95 | ) 96 | dets_r_prev = dets_r_prev[target_tracking_flag] 97 | dets_phi_prev = dets_phi_prev[target_tracking_flag] 98 | dets_x_prev, dets_y_prev = u.rphi_to_xy(dets_r_prev, dets_phi_prev) 99 | ax.scatter(dets_x_prev, dets_y_prev, s=25, c="gray") 100 | 101 | # reg label for current frame 102 | target_reg = data_dict["target_reg"][ib] 103 | dets_r, dets_phi = u.canonical_to_global( 104 | scan_r, scan_phi, target_reg[:, 0], target_reg[:, 1] 105 | ) 106 | dets_r = dets_r[target_cls > 0] 107 | dets_phi = dets_phi[target_cls > 0] 108 | dets_x, dets_y = u.rphi_to_xy(dets_r, dets_phi) 109 | ax.scatter(dets_x, dets_y, s=10, c="red") 110 | 111 | if _INTERACTIVE: 112 | plt.pause(0.1) 113 | else: 114 | plt.savefig( 115 | os.path.join(_SAVE_DIR, f"b{count:03}s{ib:02}f{fr_idx:04}.png",) 116 | ) 117 | 118 | if _INTERACTIVE: 119 | plt.show() 120 | 121 | 122 | if __name__ == "__main__": 123 | _test_detr_dataloader() 124 | -------------------------------------------------------------------------------- /dr_spaam/tests/test_drow_handle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | 6 | from 
dr_spaam.datahandle import DROWHandle 7 | from dr_spaam.utils import utils as u 8 | 9 | 10 | _X_LIM = (-7, 7) 11 | _Y_LIM = (-7, 7) 12 | _INTERACTIVE = False 13 | _SAVE_DIR = "/home/jia/tmp_imgs/test_drow_handle" 14 | 15 | 16 | def _plot_annotation(ann, ax, color, radius): 17 | if len(ann) > 0: 18 | ann = np.array(ann) 19 | det_x, det_y = u.rphi_to_xy(ann[:, 0], ann[:, 1]) 20 | for x, y in zip(det_x, det_y): 21 | c = plt.Circle((x, y), radius=radius, color=color, fill=False) 22 | ax.add_artist(c) 23 | 24 | 25 | def _plot_sequence(): 26 | drow_handle = DROWHandle( 27 | split="train", 28 | cfg={"num_scans": 1, "scan_stride": 1, "data_dir": "./data/DROWv2-data"}, 29 | ) 30 | 31 | fig = plt.figure(figsize=(10, 10)) 32 | ax = fig.add_subplot(111) 33 | 34 | _break = False 35 | 36 | if _INTERACTIVE: 37 | 38 | def p(event): 39 | nonlocal _break 40 | _break = True 41 | 42 | fig.canvas.mpl_connect("key_press_event", p) 43 | else: 44 | if os.path.exists(_SAVE_DIR): 45 | shutil.rmtree(_SAVE_DIR) 46 | os.makedirs(_SAVE_DIR) 47 | 48 | for i, data_dict in enumerate(drow_handle): 49 | if _break: 50 | break 51 | 52 | scan_x, scan_y = u.rphi_to_xy(data_dict["scans"][-1], data_dict["scan_phi"]) 53 | 54 | plt.cla() 55 | ax.set_aspect("equal") 56 | ax.set_xlim(_X_LIM[0], _X_LIM[1]) 57 | ax.set_ylim(_Y_LIM[0], _Y_LIM[1]) 58 | ax.set_xlabel("x [m]") 59 | ax.set_ylabel("y [m]") 60 | ax.set_title(f"Frame {data_dict['idx']}. Press any key to exit.") 61 | # ax.axis("off") 62 | 63 | ax.scatter(scan_x, scan_y, s=1, c="black") 64 | 65 | _plot_annotation(data_dict["dets_wc"], ax, "red", 0.6) 66 | _plot_annotation(data_dict["dets_wa"], ax, "green", 0.4) 67 | _plot_annotation(data_dict["dets_wp"], ax, "blue", 0.35) 68 | 69 | if _INTERACTIVE: 70 | plt.pause(0.1) 71 | else: 72 | plt.savefig(os.path.join(_SAVE_DIR, f"frame_{data_dict['idx']:04}.png")) 73 | 74 | if _INTERACTIVE: 75 | plt.show() 76 | 77 | 78 | if __name__ == "__main__": 79 | _plot_sequence() 80 | -------------------------------------------------------------------------------- /dr_spaam/tests/test_inference_speed.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | from dr_spaam.detector import Detector 5 | from dr_spaam.datahandle.drow_handle import DROWHandle 6 | from dr_spaam.datahandle.jrdb_handle import JRDBHandle 7 | 8 | 9 | _FRAME_NUM = 100 10 | _STRIDE = 1 11 | 12 | 13 | def test_inference_speed_on_drow(): 14 | data_handle = DROWHandle( 15 | split="test", 16 | cfg={"num_scans": 1, "scan_stride": 1, "data_dir": "./data/DROWv2-data"}, 17 | ) 18 | 19 | ckpt_file = "./ckpts/ckpt_jrdb_ann_drow3_e40.pth" 20 | detector_drow3 = Detector( 21 | ckpt_file, model="DROW3", gpu=True, stride=_STRIDE, panoramic_scan=True 22 | ) 23 | detector_drow3.set_laser_fov(225) 24 | 25 | ckpt_file = "./ckpts/ckpt_jrdb_ann_dr_spaam_e20.pth" 26 | detector_dr_spaam = Detector( 27 | ckpt_file, model="DR-SPAAM", gpu=True, stride=_STRIDE, panoramic_scan=True 28 | ) 29 | detector_dr_spaam.set_laser_fov(225) 30 | 31 | # sample random frames, discard beginning frames, where PyTorch is searching 32 | # for optimal algorithm 33 | frame_inds = np.random.randint(0, len(data_handle), size=(_FRAME_NUM + 20,)) 34 | 35 | for n, detector in zip(["DROW3", "DR-SPAAM"], [detector_drow3, detector_dr_spaam]): 36 | t_list = [] 37 | for frame_idx in frame_inds: 38 | data_dict = data_handle[frame_idx] 39 | scan_r = data_dict["scans"][-1] 40 | 41 | t0 = time.time() 42 | dets_xy, dets_cls, _ = detector(scan_r) 43 | 
t_list.append(time.time() - t0) 44 | 45 | t_ave = np.array(t_list[20:]).mean() 46 | print(f"{n} on DROW: {1.0 / t_ave:.1f} FPS " f"({t_ave:.6f} seconds per frame)") 47 | 48 | 49 | def test_inference_speed_on_jrdb(): 50 | data_handle = JRDBHandle( 51 | split="train", 52 | cfg={"data_dir": "./data/JRDB", "num_scans": 1, "scan_stride": 1}, 53 | ) 54 | 55 | ckpt_file = "./ckpts/ckpt_jrdb_ann_drow3_e40.pth" 56 | detector_drow3 = Detector( 57 | ckpt_file, model="DROW3", gpu=True, stride=_STRIDE, panoramic_scan=True 58 | ) 59 | detector_drow3.set_laser_fov(360) 60 | 61 | ckpt_file = "./ckpts/ckpt_jrdb_ann_dr_spaam_e20.pth" 62 | detector_dr_spaam = Detector( 63 | ckpt_file, model="DR-SPAAM", gpu=True, stride=_STRIDE, panoramic_scan=True 64 | ) 65 | detector_dr_spaam.set_laser_fov(360) 66 | 67 | frame_inds = np.random.randint(0, len(data_handle), size=(_FRAME_NUM,)) 68 | 69 | # sample random frames, discard beginning frames, where PyTorch is searching 70 | # for optimal algorithm 71 | frame_inds = np.random.randint(0, len(data_handle), size=(_FRAME_NUM + 20,)) 72 | 73 | for n, detector in zip(["DROW3", "DR-SPAAM"], [detector_drow3, detector_dr_spaam]): 74 | t_list = [] 75 | for frame_idx in frame_inds: 76 | data_dict = data_handle[frame_idx] 77 | scan_r = data_dict["laser_data"][-1, ::-1] # to DROW frame 78 | 79 | t0 = time.time() 80 | dets_xy, dets_cls, _ = detector(scan_r) 81 | t_list.append(time.time() - t0) 82 | 83 | t_ave = np.array(t_list[20:]).mean() 84 | print(f"{n} on JRDB: {1.0 / t_ave:.1f} FPS " f"({t_ave:.6f} seconds per frame)") 85 | 86 | 87 | if __name__ == "__main__": 88 | test_inference_speed_on_drow() 89 | test_inference_speed_on_jrdb() 90 | -------------------------------------------------------------------------------- /dr_spaam/tests/test_jrdb_handle.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import time 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from matplotlib.gridspec import GridSpec 9 | 10 | from dr_spaam.datahandle import JRDBHandle 11 | import dr_spaam.utils as u 12 | import dr_spaam.utils.jrdb_transforms as jt 13 | 14 | 15 | _XY_LIM = (-7, 7) 16 | # _XY_LIM = (-30, 30) 17 | _Z_LIM = (-1, 2) 18 | _INTERACTIVE = False 19 | _SAVE_DIR = "/home/jia/tmp_imgs/test_jrdb_handle" 20 | 21 | 22 | def _get_pts_color(pts, dim, r_max=20.0): 23 | d = np.clip(np.hypot(pts[0], pts[1]), 0.0, r_max) / r_max 24 | # d = np.clip(np.abs(pts[0]), 0.0, r_max) / r_max 25 | color = d.reshape(-1, 1).repeat(3, axis=1) 26 | color[:, dim] = 1 27 | return color 28 | 29 | 30 | def _test_loading_speed(): 31 | data_handle = JRDBHandle( 32 | split="train", 33 | cfg={"data_dir": "./data/JRDB", "num_scans": 10, "scan_stride": 1}, 34 | ) 35 | 36 | total_frame = 100 37 | inds = random.sample(range(len(data_handle)), total_frame) 38 | 39 | t0 = time.time() 40 | for idx in inds: 41 | _ = data_handle[idx] 42 | t1 = time.time() 43 | 44 | print(f"Loaded {total_frame} frames in {t1 - t0} seconds.") 45 | 46 | 47 | def _plot_sequence(): 48 | jrdb_handle = JRDBHandle( 49 | split="train", 50 | cfg={"data_dir": "./data/JRDB", "num_scans": 10, "scan_stride": 1}, 51 | ) 52 | 53 | fig = plt.figure(figsize=(20, 10)) 54 | gs = GridSpec(3, 2, figure=fig) 55 | 56 | ax_im = fig.add_subplot(gs[0, :]) 57 | ax_bev = fig.add_subplot(gs[1:, 1]) 58 | ax_fpv_xz = fig.add_subplot(gs[1, 0]) 59 | ax_fpv_yz = fig.add_subplot(gs[2, 0]) 60 | 61 | color_pool = np.random.uniform(size=(100, 3)) 62 | 63 | _break = False 64 | 65 | 
if _INTERACTIVE: 66 | 67 | def p(event): 68 | nonlocal _break 69 | _break = True 70 | 71 | fig.canvas.mpl_connect("key_press_event", p) 72 | else: 73 | if os.path.exists(_SAVE_DIR): 74 | shutil.rmtree(_SAVE_DIR) 75 | os.makedirs(_SAVE_DIR) 76 | 77 | for i, data_dict in enumerate(jrdb_handle): 78 | if _break: 79 | break 80 | 81 | # lidar 82 | pc_xyz_upper = jt.transform_pts_upper_velodyne_to_base( 83 | data_dict["pc_data"]["upper_velodyne"] 84 | ) 85 | pc_xyz_lower = jt.transform_pts_lower_velodyne_to_base( 86 | data_dict["pc_data"]["lower_velodyne"] 87 | ) 88 | 89 | # labels 90 | boxes = [] 91 | for ann in data_dict["pc_anns"]: 92 | jrdb_handle.box_is_on_ground(ann) 93 | boxes.append(u.box_from_jrdb(ann, fast_mode=False)) 94 | 95 | # laser 96 | laser_r = data_dict["laser_data"][-1] 97 | laser_phi = data_dict["laser_grid"] 98 | laser_z = data_dict["laser_z"] 99 | laser_x, laser_y = u.rphi_to_xy(laser_r, laser_phi) 100 | pc_xyz_laser = jt.transform_pts_laser_to_base( 101 | np.stack((laser_x, laser_y, laser_z), axis=0) 102 | ) 103 | 104 | # BEV 105 | ax_bev.cla() 106 | ax_bev.set_aspect("equal") 107 | ax_bev.set_xlim(_XY_LIM[0], _XY_LIM[1]) 108 | ax_bev.set_ylim(_XY_LIM[0], _XY_LIM[1]) 109 | ax_bev.set_title(f"Frame {data_dict['idx']}. Press any key to exit.") 110 | ax_bev.set_xlabel("x [m]") 111 | ax_bev.set_ylabel("y [m]") 112 | # ax_bev.axis("off") 113 | 114 | for rgb_dim, pc_xyz in zip( 115 | (2, 1, 0), (pc_xyz_upper, pc_xyz_lower, pc_xyz_laser) 116 | ): 117 | ax_bev.scatter(pc_xyz[0], pc_xyz[1], s=1, c=_get_pts_color(pc_xyz, rgb_dim)) 118 | 119 | for box in boxes: 120 | box.draw_bev(ax_bev, c=color_pool[box.get_id()]) 121 | 122 | # side view 123 | for dim, ax_fpv in zip((0, 1), (ax_fpv_xz, ax_fpv_yz)): 124 | ax_fpv.cla() 125 | ax_fpv.set_aspect("equal") 126 | ax_fpv.set_xlim(_XY_LIM[0], _XY_LIM[1]) 127 | ax_fpv.set_ylim(_Z_LIM[0], _Z_LIM[1]) 128 | ax_fpv.set_title(f"Frame {data_dict['idx']}. 
Press any key to exit.") 129 | ax_fpv.set_xlabel("x [m]" if dim == 0 else "y [m]") 130 | ax_fpv.set_ylabel("z [m]") 131 | # ax_fpv.axis("off") 132 | 133 | for rgb_dim, pc_xyz in zip( 134 | (2, 1, 0), (pc_xyz_upper, pc_xyz_lower, pc_xyz_laser) 135 | ): 136 | ax_fpv.scatter( 137 | pc_xyz[dim], pc_xyz[2], s=1, c=_get_pts_color(pc_xyz, rgb_dim) 138 | ) 139 | 140 | for box in boxes: 141 | box.draw_fpv(ax_fpv, dim=dim, c=color_pool[box.get_id()]) 142 | 143 | # image 144 | ax_im.cla() 145 | ax_im.axis("off") 146 | ax_im.imshow(data_dict["im_data"]["stitched_image0"]) 147 | 148 | # detection bounding box 149 | for box_dict in data_dict["im_dets"]: 150 | x0, y0, w, h = box_dict["box"] 151 | verts = np.array( 152 | [(x0, y0), (x0, y0 + h), (x0 + w, y0 + h), (x0 + w, y0), (x0, y0)] 153 | ) 154 | c = max(float(box_dict["score"]) - 0.5, 0) * 2.0 155 | ax_im.plot(verts[:, 0], verts[:, 1], c=(1.0 - c, 1.0 - c, 1.0)) 156 | 157 | # laser points on image 158 | p_xy, ib_mask = jt.transform_pts_base_to_stitched_im(pc_xyz_laser) 159 | ax_im.scatter( 160 | p_xy[0, ib_mask], 161 | p_xy[1, ib_mask], 162 | s=1, 163 | c=_get_pts_color(pc_xyz_laser[:, ib_mask], dim=0), 164 | ) 165 | 166 | if _INTERACTIVE: 167 | plt.pause(0.1) 168 | else: 169 | plt.savefig(os.path.join(_SAVE_DIR, f"frame_{data_dict['idx']:04}.png")) 170 | 171 | if _INTERACTIVE: 172 | plt.show() 173 | 174 | 175 | if __name__ == "__main__": 176 | # _test_loading_speed() 177 | _plot_sequence() 178 | -------------------------------------------------------------------------------- /dr_spaam/tests/test_jrdb_handle_mayavi.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import numpy as np 5 | from mayavi import mlab 6 | 7 | from dr_spaam.datahandle.jrdb_handle import JRDBHandle 8 | import dr_spaam.utils.utils as u 9 | import dr_spaam.utils.utils_box3d as ub3d 10 | import dr_spaam.utils.jrdb_transforms as jt 11 | 12 | _COLOR_INSTANCE = True 13 | 14 | 15 | def _test_loading_speed(): 16 | data_handle = JRDBHandle( 17 | split="train", 18 | cfg={"data_dir": "./data/JRDB", "num_scans": 10, "scan_stride": 1}, 19 | ) 20 | 21 | total_frame = 100 22 | inds = random.sample(range(len(data_handle)), total_frame) 23 | 24 | t0 = time.time() 25 | for idx in inds: 26 | _ = data_handle[idx] 27 | t1 = time.time() 28 | 29 | print(f"Loaded {total_frame} frames in {t1 - t0} seconds.") 30 | 31 | 32 | def _plot_sequence(): 33 | jrdb_handle = JRDBHandle( 34 | split="train", 35 | cfg={"data_dir": "./data/JRDB", "num_scans": 10, "scan_stride": 1}, 36 | ) 37 | 38 | color_pool = np.random.uniform(size=(100, 3)) 39 | 40 | for i, data_dict in enumerate(jrdb_handle): 41 | # lidar 42 | pc_xyz_upper = jt.transform_pts_upper_velodyne_to_base( 43 | data_dict["pc_data"]["upper_velodyne"] 44 | ) 45 | pc_xyz_lower = jt.transform_pts_lower_velodyne_to_base( 46 | data_dict["pc_data"]["lower_velodyne"] 47 | ) 48 | 49 | # laser 50 | laser_r = data_dict["laser_data"][-1] 51 | laser_phi = data_dict["laser_grid"] 52 | laser_z = data_dict["laser_z"] 53 | laser_x, laser_y = u.rphi_to_xy(laser_r, laser_phi) 54 | pc_xyz_laser = jt.transform_pts_laser_to_base( 55 | np.stack((laser_x, laser_y, laser_z), axis=0) 56 | ) 57 | 58 | if _COLOR_INSTANCE: 59 | # labels 60 | boxes, label_ids = [], [] 61 | for ann in data_dict["pc_anns"]: 62 | # jrdb_handle.box_is_on_ground(ann) 63 | box, b_id = ub3d.box_from_jrdb(ann) 64 | boxes.append(box) 65 | label_ids.append(b_id) 66 | boxes = np.array(boxes) # (B, 7) 67 | pc = 
np.concatenate([pc_xyz_laser, pc_xyz_upper, pc_xyz_lower], axis=1) 68 | in_box_mask, closest_box_inds = ub3d.associate_points_and_boxes( 69 | pc, boxes, resize_factor=1.0 70 | ) 71 | 72 | # plot bg points 73 | bg_pc = pc[:, np.logical_not(in_box_mask)] 74 | mlab.points3d( 75 | bg_pc[0], 76 | bg_pc[1], 77 | bg_pc[2], 78 | scale_factor=0.05, 79 | color=(1.0, 0.0, 0.0), 80 | ) 81 | 82 | # plot box and fg points 83 | fg_pc = pc[:, in_box_mask] 84 | fg_box_inds = closest_box_inds[in_box_mask] 85 | corners_xyz, connect_inds = ub3d.boxes_to_corners( 86 | boxes, rtn_connect_inds=True 87 | ) 88 | for box_idx, (p_id, corner_xyz) in enumerate(zip(label_ids, corners_xyz)): 89 | color = tuple(color_pool[p_id % 100]) 90 | # box 91 | for inds in connect_inds: 92 | mlab.plot3d( 93 | corner_xyz[0, inds], 94 | corner_xyz[1, inds], 95 | corner_xyz[2, inds], 96 | tube_radius=None, 97 | line_width=5, 98 | color=color, 99 | ) 100 | 101 | # point 102 | in_box_pc = fg_pc[:, fg_box_inds == box_idx] 103 | mlab.points3d( 104 | in_box_pc[0], 105 | in_box_pc[1], 106 | in_box_pc[2], 107 | scale_factor=0.05, 108 | color=color, 109 | ) 110 | 111 | else: 112 | # plot points 113 | mlab.points3d( 114 | pc_xyz_lower[0], 115 | pc_xyz_lower[1], 116 | pc_xyz_lower[2], 117 | scale_factor=0.05, 118 | color=(0.0, 1.0, 0.0), 119 | ) 120 | mlab.points3d( 121 | pc_xyz_upper[0], 122 | pc_xyz_upper[1], 123 | pc_xyz_upper[2], 124 | scale_factor=0.05, 125 | color=(0.0, 0.0, 1.0), 126 | ) 127 | mlab.points3d( 128 | pc_xyz_laser[0], 129 | pc_xyz_laser[1], 130 | pc_xyz_laser[2], 131 | scale_factor=0.05, 132 | color=(1.0, 0.0, 0.0), 133 | ) 134 | 135 | # plot box 136 | boxes = [] 137 | for ann in data_dict["pc_anns"]: 138 | # jrdb_handle.box_is_on_ground(ann) 139 | box = ub3d.box_from_jrdb(ann, fast_mode=False) 140 | corners_xyz, connect_inds = box.to_corners(resize_factor=1.0, rtn_connect_inds=True) 141 | for inds in connect_inds: 142 | mlab.plot3d( 143 | corners_xyz[0, inds], 144 | corners_xyz[1, inds], 145 | corners_xyz[2, inds], 146 | tube_radius=None, 147 | line_width=5, 148 | color=tuple(color_pool[box.get_id() % 100]), 149 | ) 150 | 151 | mlab.show() 152 | 153 | 154 | if __name__ == "__main__": 155 | # _test_loading_speed() 156 | _plot_sequence() 157 | -------------------------------------------------------------------------------- /dr_spaam_ros/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(dr_spaam_ros) 3 | 4 | find_package(catkin REQUIRED 5 | COMPONENTS 6 | ) 7 | 8 | catkin_package() 9 | 10 | catkin_python_setup() 11 | 12 | 13 | -------------------------------------------------------------------------------- /dr_spaam_ros/config/dr_spaam_ros.yaml: -------------------------------------------------------------------------------- 1 | weight_file: "/home/jia/ckpts/ckpt_jrdb_ann_drow3_e40.pth" 2 | detector_model: "DROW3" # DROW3 or DR-SPAAM 3 | # weight_file: "/home/jia/ckpts/ckpt_jrdb_ann_dr_spaam_e20.pth" 4 | # detector_model: "DR-SPAAM" # DROW3 or DR-SPAAM 5 | conf_thresh: 0.5 6 | stride: 1 # use this to skip laser points 7 | panoramic_scan: True # Set to true if the scan covers 360 degree 8 | -------------------------------------------------------------------------------- /dr_spaam_ros/config/topics.yaml: -------------------------------------------------------------------------------- 1 | publisher: 2 | detections: 3 | topic: /dr_spaam_detections 4 | queue_size: 1 5 | latch: false 6 | 7 | rviz: 8 | topic: /dr_spaam_rviz 9 | 
queue_size: 1 10 | latch: false 11 | 12 | subscriber: 13 | scan: 14 | topic: /segway/scan_multi 15 | queue_size: 1 16 | -------------------------------------------------------------------------------- /dr_spaam_ros/example.rviz: -------------------------------------------------------------------------------- 1 | Panels: 2 | - Class: rviz/Displays 3 | Help Height: 0 4 | Name: Displays 5 | Property Tree Widget: 6 | Expanded: 7 | - /Global Options1 8 | - /LaserScan1 9 | - /Marker1/Status1 10 | Splitter Ratio: 0.6167800426483154 11 | Tree Height: 1681 12 | - Class: rviz/Selection 13 | Name: Selection 14 | - Class: rviz/Tool Properties 15 | Expanded: 16 | - /2D Pose Estimate1 17 | - /2D Nav Goal1 18 | - /Publish Point1 19 | Name: Tool Properties 20 | Splitter Ratio: 0.5886790156364441 21 | - Class: rviz/Views 22 | Expanded: 23 | - /Current View1 24 | Name: Views 25 | Splitter Ratio: 0.5 26 | - Class: rviz/Time 27 | Experimental: false 28 | Name: Time 29 | SyncMode: 0 30 | SyncSource: LaserScan 31 | Preferences: 32 | PromptSaveOnExit: true 33 | Toolbars: 34 | toolButtonStyle: 2 35 | Visualization Manager: 36 | Class: "" 37 | Displays: 38 | - Alpha: 0.30000001192092896 39 | Cell Size: 1 40 | Class: rviz/Grid 41 | Color: 85; 87; 83 42 | Enabled: true 43 | Line Style: 44 | Line Width: 0.029999999329447746 45 | Value: Lines 46 | Name: Grid 47 | Normal Cell Count: 0 48 | Offset: 49 | X: 0 50 | Y: 0 51 | Z: 0 52 | Plane: XY 53 | Plane Cell Count: 2000 54 | Reference Frame: 55 | Value: true 56 | - Class: rviz/TF 57 | Enabled: true 58 | Frame Timeout: 1e+8 59 | Frames: 60 | All Enabled: false 61 | base_chassis_link: 62 | Value: true 63 | base_link: 64 | Value: true 65 | left_wheel_link: 66 | Value: true 67 | m1n6s200_link_1: 68 | Value: true 69 | m1n6s200_link_2: 70 | Value: true 71 | m1n6s200_link_3: 72 | Value: true 73 | m1n6s200_link_4: 74 | Value: true 75 | m1n6s200_link_5: 76 | Value: true 77 | m1n6s200_link_6: 78 | Value: true 79 | m1n6s200_link_base: 80 | Value: true 81 | m1n6s200_link_finger_1: 82 | Value: true 83 | m1n6s200_link_finger_2: 84 | Value: true 85 | odom: 86 | Value: true 87 | pan_link: 88 | Value: true 89 | right_wheel_link: 90 | Value: true 91 | tilt_link: 92 | Value: true 93 | Marker Scale: 1 94 | Name: TF 95 | Show Arrows: true 96 | Show Axes: true 97 | Show Names: true 98 | Tree: 99 | odom: 100 | base_link: 101 | {} 102 | Update Interval: 0 103 | Value: true 104 | - Alpha: 1 105 | Autocompute Intensity Bounds: true 106 | Autocompute Value Bounds: 107 | Max Value: 10 108 | Min Value: -10 109 | Value: true 110 | Axis: Z 111 | Channel Name: intensity 112 | Class: rviz/LaserScan 113 | Color: 114; 159; 207 114 | Color Transformer: FlatColor 115 | Decay Time: 0 116 | Enabled: true 117 | Invert Rainbow: false 118 | Max Color: 255; 255; 255 119 | Max Intensity: 4096 120 | Min Color: 0; 0; 0 121 | Min Intensity: 0 122 | Name: LaserScan 123 | Position Transformer: XYZ 124 | Queue Size: 10 125 | Selectable: true 126 | Size (Pixels): 5 127 | Size (m): 0.10000000149011612 128 | Style: Points 129 | Topic: /segway/scan_multi 130 | Unreliable: false 131 | Use Fixed Frame: true 132 | Use rainbow: true 133 | Value: true 134 | - Alpha: 1 135 | Arrow Length: 1 136 | Axes Length: 0.20000000298023224 137 | Axes Radius: 0.05000000074505806 138 | Class: rviz/PoseArray 139 | Color: 52; 101; 164 140 | Enabled: false 141 | Head Length: 0 142 | Head Radius: 0 143 | Name: PoseArray 144 | Shaft Length: 0.20000000298023224 145 | Shaft Radius: 0.20000000298023224 146 | Shape: Arrow (3D) 147 | Topic: 
/dr_spaam_detections 148 | Unreliable: false 149 | Value: false 150 | - Class: rviz/Marker 151 | Enabled: true 152 | Marker Topic: /dr_spaam_rviz 153 | Name: Marker 154 | Namespaces: 155 | dr_spaam_ros: true 156 | Queue Size: 100 157 | Value: true 158 | Enabled: true 159 | Global Options: 160 | Background Color: 46; 52; 54 161 | Default Light: true 162 | Fixed Frame: odom 163 | Frame Rate: 30 164 | Name: root 165 | Tools: 166 | - Class: rviz/Interact 167 | Hide Inactive Objects: true 168 | - Class: rviz/MoveCamera 169 | - Class: rviz/Select 170 | - Class: rviz/FocusCamera 171 | - Class: rviz/Measure 172 | - Class: rviz/SetInitialPose 173 | Theta std deviation: 0.2617993950843811 174 | Topic: /initialpose 175 | X std deviation: 0.5 176 | Y std deviation: 0.5 177 | - Class: rviz/SetGoal 178 | Topic: /move_base_simple/goal 179 | - Class: rviz/PublishPoint 180 | Single click: true 181 | Topic: /clicked_point 182 | Value: true 183 | Views: 184 | Current: 185 | Angle: 3.655003786087036 186 | Class: rviz/TopDownOrtho 187 | Enable Stereo Rendering: 188 | Stereo Eye Separation: 0.05999999865889549 189 | Stereo Focal Distance: 1 190 | Swap Stereo Eyes: false 191 | Value: false 192 | Invert Z Axis: false 193 | Name: Current View 194 | Near Clip Distance: 0.009999999776482582 195 | Scale: 52.93107604980469 196 | Target Frame: sick_laser_front 197 | Value: TopDownOrtho (rviz) 198 | X: 0.7433948516845703 199 | Y: 89.31505584716797 200 | Saved: ~ 201 | Window Geometry: 202 | Displays: 203 | collapsed: false 204 | Height: 2049 205 | Hide Left Dock: false 206 | Hide Right Dock: true 207 | QMainWindow State: 000000ff00000000fd000000040000000000000224000006fffc0200000018fb0000001200530065006c0065006300740069006f006e00000001e10000009b000000b000fffffffb0000001e0054006f006f006c002000500072006f007000650072007400690065007302000001ed000001df00000185000000a3fb000000120056006900650077007300200054006f006f02000001df000002110000018500000122fb000000200054006f006f006c002000500072006f0070006500720074006900650073003203000002880000011d000002210000017afb000000100044006900730070006c006100790073010000006e000006ff0000018200fffffffb0000002000730065006c0065006300740069006f006e00200062007500660066006500720200000138000000aa0000023a00000294fb00000014005700690064006500530074006500720065006f02000000e6000000d2000003ee0000030bfb0000000c004b0069006e0065006300740200000000000002160000024500000261fb0000000c004b0069006e00650063007402000001b4000002430000016a000000e2fb0000001400440053004c005200200069006d0061006700650200000c3e00000224000002c0000001a9fb00000028005200470042002000460072006f006e007400200054006f0070002000430061006d00650072006102000007b70000000f000003de00000375fb0000002800460072006f006e0074002000640065007000740068002000700061006e006f00720061006d00610000000000ffffffff0000000000000000fb0000002400460072006f006e00740020005200470042002000700061006e006f00720061006d006102000007c10000001800000262000003dbfb000000260052006500610072002000640065007000740068002000700061006e006f00720061006d0061020000051a0000002e00000266000003cbfb0000002200520065006100720020005200470042002000700061006e006f00720061006d006102000009ee0000001800000265000003cffb0000000c00430061006d00650072006102000000000000003d00000245000001d3fb0000000c00430061006d006500720061000000041d000000160000000000000000fb0000000a0049006d0061006700650200000947000002f3000001e90000014bfb00000008004c00650066007402000009eb000000f2000002450000009ffb0000000a0052006900670068007402000009eb0000016400000245000000c7fb00000008005200650061007202000000000000023100000245000000fefb0000000a0052006900670068007402000
0031a0000032c0000015600000120fb00000008005200650061007202000004740000032c0000015600000120fb0000000c00430061006d00650072006103000004d30000024b00000175000000dc00000001000001a8000003dbfc0200000003fb0000001e0054006f006f006c002000500072006f00700065007200740069006500730100000041000000780000000000000000fb0000000a00560069006500770073000000003d000003db0000013200fffffffb0000001200530065006c0065006300740069006f006e010000025a000000b200000000000000000000000200000490000000a9fc0100000001fb0000000a00560069006500770073030000004e00000080000002e1000001970000000300000f000000005afc0100000002fb0000000800540069006d0065010000000000000f000000057100fffffffb0000000800540069006d0065010000000000000450000000000000000000000cd0000006ff00000004000000040000000800000008fc0000000100000002000000010000000a0054006f006f006c00730100000000ffffffff0000000000000000 208 | Selection: 209 | collapsed: false 210 | Time: 211 | collapsed: false 212 | Tool Properties: 213 | collapsed: false 214 | Views: 215 | collapsed: true 216 | Width: 3840 217 | X: 0 218 | Y: 55 219 | -------------------------------------------------------------------------------- /dr_spaam_ros/launch/dr_spaam_ros.launch: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /dr_spaam_ros/package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | dr_spaam_ros 4 | 1.0.0 5 | ROS interface for DR-SPAAM detector 6 | 7 | Dan Jia 8 | 9 | 10 | 11 | 12 | TODO 13 | 14 | catkin 15 | rospy 16 | geometry_msgs 17 | sensor_msgs 18 | 19 | -------------------------------------------------------------------------------- /dr_spaam_ros/scripts/drow_data_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | from math import sin, cos 4 | import numpy as np 5 | 6 | import rospy 7 | import rosbag 8 | 9 | from geometry_msgs.msg import TransformStamped 10 | from sensor_msgs.msg import LaserScan 11 | from tf2_msgs.msg import TFMessage 12 | 13 | 14 | def load_scans(fname): 15 | data = np.genfromtxt(fname, delimiter=",") 16 | seqs, times, scans = data[:, 0].astype(np.uint32), data[:, 1].astype(np.float32), data[:, 2:].astype(np.float32) 17 | return seqs, times, scans 18 | 19 | 20 | def load_odoms(fname): 21 | data = np.genfromtxt(fname, delimiter=",") 22 | seqs, times = data[:, 0].astype(np.uint32), data[:, 1].astype(np.float32) 23 | odos = data[:, 2:].astype(np.float32) # x, y, phi 24 | return seqs, times, odos 25 | 26 | 27 | def sequence_to_bag(seq_name, bag_name): 28 | scan_msg = LaserScan() 29 | scan_msg.header.frame_id = 'sick_laser_front' 30 | scan_msg.angle_min = np.radians(-225.0 / 2) 31 | scan_msg.angle_max = np.radians(225.0 / 2) 32 | scan_msg.range_min = 0.005 33 | scan_msg.range_max = 100.0 34 | scan_msg.scan_time = 0.066667 35 | scan_msg.time_increment = 0.000062 36 | scan_msg.angle_increment = (scan_msg.angle_max - scan_msg.angle_min) / 450 37 | 38 | tran = TransformStamped() 39 | tran.header.frame_id = 'base_footprint' 40 | tran.child_frame_id = 'sick_laser_front' 41 | 42 | with rosbag.Bag(bag_name, 'w') as bag: 43 | # write scans 44 | seqs, times, scans = load_scans(seq_name) 45 | for seq, time, scan in zip(seqs, times, scans): 46 | time = rospy.Time(time) 47 | scan_msg.header.seq = seq 48 | scan_msg.header.stamp = time 49 | scan_msg.ranges = scan 50 | bag.write('/sick_laser_front/scan', scan_msg, t=time) 51 | 52 
| # write odometry 53 | seqs, times, odoms = load_odoms(seq_name[:-3] + 'odom2') 54 | for seq, time, odom in zip(seqs, times, odoms): 55 | time = rospy.Time(time) 56 | tran.header.seq = seq 57 | tran.header.stamp = time 58 | tran.transform.translation.x = odom[0] 59 | tran.transform.translation.y = odom[1] 60 | tran.transform.translation.z = 0.0 61 | tran.transform.rotation.x = 0.0 62 | tran.transform.rotation.y = 0.0 63 | tran.transform.rotation.z = sin(odom[2] * 0.5) 64 | tran.transform.rotation.w = cos(odom[2] * 0.5) 65 | tf_msg = TFMessage([tran]) 66 | bag.write('/tf', tf_msg, t=time) 67 | 68 | 69 | if __name__ == '__main__': 70 | parser = argparse.ArgumentParser(description="Convert a DROW sequence into a rosbag") 71 | parser.add_argument("--seq", type=str, required=True, help="path to sequence") 72 | parser.add_argument("--output", type=str, required=False, default="./out.bag") 73 | args = parser.parse_args() 74 | 75 | sequence_to_bag(args.seq, args.output) 76 | -------------------------------------------------------------------------------- /dr_spaam_ros/scripts/node.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import rospy 4 | from dr_spaam_ros.dr_spaam_ros import DrSpaamROS 5 | 6 | 7 | if __name__ == '__main__': 8 | rospy.init_node('dr_spaam_ros') 9 | try: 10 | DrSpaamROS() 11 | except rospy.ROSInterruptException: 12 | pass 13 | rospy.spin() 14 | -------------------------------------------------------------------------------- /dr_spaam_ros/setup.py: -------------------------------------------------------------------------------- 1 | ## ! DO NOT MANUALLY INVOKE THIS setup.py, USE CATKIN INSTEAD 2 | 3 | from distutils.core import setup 4 | from catkin_pkg.python_setup import generate_distutils_setup 5 | 6 | # fetch values from package.xml 7 | setup_args = generate_distutils_setup( 8 | packages=['dr_spaam_ros'], 9 | package_dir={'': 'src'}) 10 | 11 | setup(**setup_args) -------------------------------------------------------------------------------- /dr_spaam_ros/src/dr_spaam_ros/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisualComputingInstitute/2D_lidar_person_detection/99dd7a2a0d64252905e4f621e2c45be64b653a32/dr_spaam_ros/src/dr_spaam_ros/__init__.py -------------------------------------------------------------------------------- /dr_spaam_ros/src/dr_spaam_ros/dr_spaam_ros.py: -------------------------------------------------------------------------------- 1 | # import time 2 | import numpy as np 3 | import rospy 4 | 5 | from sensor_msgs.msg import LaserScan 6 | from geometry_msgs.msg import Point, Pose, PoseArray 7 | from visualization_msgs.msg import Marker 8 | 9 | from dr_spaam.detector import Detector 10 | 11 | 12 | class DrSpaamROS: 13 | """ROS node for detecting pedestrians using DROW3 or DR-SPAAM.""" 14 | 15 | def __init__(self): 16 | self._read_params() 17 | self._detector = Detector( 18 | self.weight_file, 19 | model=self.detector_model, 20 | gpu=True, 21 | stride=self.stride, 22 | panoramic_scan=self.panoramic_scan, 23 | ) 24 | self._init() 25 | 26 | def _read_params(self): 27 | """ 28 | @brief Reads parameters from ROS server. 
29 | """ 30 | self.weight_file = rospy.get_param("~weight_file") 31 | self.conf_thresh = rospy.get_param("~conf_thresh") 32 | self.stride = rospy.get_param("~stride") 33 | self.detector_model = rospy.get_param("~detector_model") 34 | self.panoramic_scan = rospy.get_param("~panoramic_scan") 35 | 36 | def _init(self): 37 | """ 38 | @brief Initialize ROS connection. 39 | """ 40 | # Publisher 41 | topic, queue_size, latch = read_publisher_param("detections") 42 | self._dets_pub = rospy.Publisher( 43 | topic, PoseArray, queue_size=queue_size, latch=latch 44 | ) 45 | 46 | topic, queue_size, latch = read_publisher_param("rviz") 47 | self._rviz_pub = rospy.Publisher( 48 | topic, Marker, queue_size=queue_size, latch=latch 49 | ) 50 | 51 | # Subscriber 52 | topic, queue_size = read_subscriber_param("scan") 53 | self._scan_sub = rospy.Subscriber( 54 | topic, LaserScan, self._scan_callback, queue_size=queue_size 55 | ) 56 | 57 | def _scan_callback(self, msg): 58 | if ( 59 | self._dets_pub.get_num_connections() == 0 60 | and self._rviz_pub.get_num_connections() == 0 61 | ): 62 | return 63 | 64 | # TODO check the computation here 65 | if not self._detector.is_ready(): 66 | self._detector.set_laser_fov( 67 | np.rad2deg(msg.angle_increment * len(msg.ranges)) 68 | ) 69 | 70 | scan = np.array(msg.ranges) 71 | scan[scan == 0.0] = 29.99  # map invalid returns (0, inf, nan) to a large range 72 | scan[np.isinf(scan)] = 29.99 73 | scan[np.isnan(scan)] = 29.99 74 | 75 | # t = time.time() 76 | dets_xy, dets_cls, _ = self._detector(scan) 77 | # print("[DrSpaamROS] End-to-end inference time: %f" % (time.time() - t)) 78 | 79 | # confidence threshold 80 | conf_mask = (dets_cls >= self.conf_thresh).reshape(-1) 81 | dets_xy = dets_xy[conf_mask] 82 | dets_cls = dets_cls[conf_mask] 83 | 84 | # convert to ROS msg and publish 85 | dets_msg = detections_to_pose_array(dets_xy, dets_cls) 86 | dets_msg.header = msg.header 87 | self._dets_pub.publish(dets_msg) 88 | 89 | rviz_msg = detections_to_rviz_marker(dets_xy, dets_cls) 90 | rviz_msg.header = msg.header 91 | self._rviz_pub.publish(rviz_msg) 92 | 93 | 94 | def detections_to_rviz_marker(dets_xy, dets_cls): 95 | """ 96 | @brief Convert detections to an RViz marker msg. Each detection is marked as 97 | a circle approximated by line segments. 
98 | """ 99 | msg = Marker() 100 | msg.action = Marker.ADD 101 | msg.ns = "dr_spaam_ros" 102 | msg.id = 0 103 | msg.type = Marker.LINE_LIST 104 | 105 | # set quaternion so that RViz does not give a warning 106 | msg.pose.orientation.x = 0.0 107 | msg.pose.orientation.y = 0.0 108 | msg.pose.orientation.z = 0.0 109 | msg.pose.orientation.w = 1.0 110 | 111 | msg.scale.x = 0.03 # line width 112 | # red color 113 | msg.color.r = 1.0 114 | msg.color.a = 1.0 115 | 116 | # circle 117 | r = 0.4 118 | ang = np.linspace(0, 2 * np.pi, 20) 119 | xy_offsets = r * np.stack((np.cos(ang), np.sin(ang)), axis=1) 120 | 121 | # to msg 122 | for d_xy, d_cls in zip(dets_xy, dets_cls): 123 | for i in range(len(xy_offsets) - 1): 124 | # start point of a segment 125 | p0 = Point() 126 | p0.x = d_xy[0] + xy_offsets[i, 0] 127 | p0.y = d_xy[1] + xy_offsets[i, 1] 128 | p0.z = 0.0 129 | msg.points.append(p0) 130 | 131 | # end point 132 | p1 = Point() 133 | p1.x = d_xy[0] + xy_offsets[i + 1, 0] 134 | p1.y = d_xy[1] + xy_offsets[i + 1, 1] 135 | p1.z = 0.0 136 | msg.points.append(p1) 137 | 138 | return msg 139 | 140 | 141 | def detections_to_pose_array(dets_xy, dets_cls): 142 | pose_array = PoseArray() 143 | for d_xy, d_cls in zip(dets_xy, dets_cls): 144 | # Detector uses the following frame convention: 145 | # x forward, y rightward, z downward, phi is the angle w.r.t. the x-axis 146 | p = Pose() 147 | p.position.x = d_xy[0] 148 | p.position.y = d_xy[1] 149 | p.position.z = 0.0 150 | pose_array.poses.append(p) 151 | 152 | return pose_array 153 | 154 | 155 | def read_subscriber_param(name): 156 | """ 157 | @brief Convenience function to read subscriber parameter. 158 | """ 159 | topic = rospy.get_param("~subscriber/" + name + "/topic") 160 | queue_size = rospy.get_param("~subscriber/" + name + "/queue_size") 161 | return topic, queue_size 162 | 163 | 164 | def read_publisher_param(name): 165 | """ 166 | @brief Convenience function to read publisher parameter. 167 | """ 168 | topic = rospy.get_param("~publisher/" + name + "/topic") 169 | queue_size = rospy.get_param("~publisher/" + name + "/queue_size") 170 | latch = rospy.get_param("~publisher/" + name + "/latch") 171 | return topic, queue_size, latch 172 | -------------------------------------------------------------------------------- /imgs/dr_spaam_ros_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisualComputingInstitute/2D_lidar_person_detection/99dd7a2a0d64252905e4f621e2c45be64b653a32/imgs/dr_spaam_ros_graph.png -------------------------------------------------------------------------------- /imgs/dr_spaam_ros_teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisualComputingInstitute/2D_lidar_person_detection/99dd7a2a0d64252905e4f621e2c45be64b653a32/imgs/dr_spaam_ros_teaser.gif -------------------------------------------------------------------------------- /imgs/teaser_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisualComputingInstitute/2D_lidar_person_detection/99dd7a2a0d64252905e4f621e2c45be64b653a32/imgs/teaser_1.gif --------------------------------------------------------------------------------
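
A note on consuming the detections: `detections_to_pose_array` above publishes each detected person as a position-only `Pose` in the detector's frame convention (x forward, y rightward). The snippet below is a minimal consumer sketch, not part of the repository; it assumes the default topic name from `config/topics.yaml` and that the consumer wants REP-103 axes (y leftward), hence the sign flip on y.

```python
#!/usr/bin/env python
# Hypothetical consumer sketch, not part of this repository.
import rospy
from geometry_msgs.msg import PoseArray


def on_detections(msg):
    for i, pose in enumerate(msg.poses):
        # Detector convention: x forward, y rightward. Flip y to obtain
        # REP-103 axes (an assumption about the consumer's frame).
        x, y = pose.position.x, -pose.position.y
        rospy.loginfo("person %d: x=%.2f m, y=%.2f m (%s)", i, x, y, msg.header.frame_id)


if __name__ == "__main__":
    rospy.init_node("dr_spaam_listener")
    # Topic name taken from the default config/topics.yaml
    rospy.Subscriber("/dr_spaam_detections", PoseArray, on_detections, queue_size=1)
    rospy.spin()
```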
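
`load_scans` in `drow_data_converter.py` also suggests a quick way to sanity-check a model on a recorded DROW sequence without ROS. A minimal offline sketch, assuming the DROW scanner layout used in that converter (450 beams over 225 degrees) and placeholder paths for the sequence and checkpoint:

```python
# Offline sketch with placeholder paths; not part of this repository.
import numpy as np
from dr_spaam.detector import Detector

# DROW csv rows: sequence id, timestamp, then the 450 range readings
scans = np.genfromtxt("path_to_sequence.csv", delimiter=",")[:, 2:].astype(np.float32)

detector = Detector(
    "path_to_checkpoint",
    model="DROW3",  # or "DR-SPAAM"
    gpu=False,
    stride=1,
    panoramic_scan=False,  # a DROW scan covers only 225 degrees
)
detector.set_laser_fov(225)

for scan in scans:
    dets_xy, dets_cls, _ = detector(scan)
    cls_mask = (dets_cls >= 0.5).reshape(-1)  # same default threshold as the ROS node
    print(f"{cls_mask.sum()} detections")
```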