├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── README.md ├── data.py ├── demo.py ├── imgs ├── arXiv_fig1.png ├── dataset_cmp.png ├── demo01.png └── demo_exp.png ├── main.py ├── model.py ├── requirements.txt ├── tools ├── extract_comma2k19.py └── extract_nuscenes.py ├── utils.py ├── utils_comma2k19 ├── LICENSE ├── __init__.py ├── benchmarks.py ├── camera.py ├── coordinates.py ├── dataset.py ├── orientation.py └── unzip_msft_fs.py └── view_transform.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/* 2 | lightning_logs/ 3 | runs/ 4 | __pycache__ 5 | vis*/ 6 | *.ckpt 7 | .vscode -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2021 T.T. 
Tang tttang@sjtu.edu.cn
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Openpilot-Deepdive
2 | **Level 2 Autonomous Driving on a Single Device: Diving into the Devils of Openpilot**
3 | ![image](./imgs/arXiv_fig1.png)
4 | [**Webpage**](https://sites.google.com/view/openpilot-deepdive/home) | [**Paper**](https://arxiv.org/abs/2206.08176) | [**Zhihu**](https://www.zhihu.com/people/PerceptionX)
5 | ***
6 | # Introduction
7 | 
8 | This repository is the PyTorch implementation of Openpilot-Deepdive.
9 | In contrast to most traditional autonomous driving solutions, where the perception, prediction, and planning modules are separate, [Openpilot](https://github.com/commaai/openpilot) uses an end-to-end neural network, called Supercombo, to predict the trajectory directly from camera images. We try to reimplement the training details and test the pipeline on public benchmarks. Experimental results of OP-Deepdive on nuScenes, Comma2k19, CARLA, and in-house realistic scenarios (collected in Shanghai) verify that a low-cost device can indeed achieve most L2 functionalities and be on par with the original Supercombo model. We also test on a CommaTwo device with a dual-model deployment framework, which is in this repo: [Openpilot-Deployment](https://github.com/OpenPerceptionX/Openpilot-Deepdive/tree/deploy/Openpilot-Deployment).
10 | 
11 | * [Directory Structure](#directory-structure)
12 | * [Changelog](#changelog)
13 | * [Quick Start](#quick-start-examples)
14 | * [Installation](#installation)
15 | * [Dataset](#dataset)
16 | * [Training and Testing](#training-and-testing)
17 | * [Demo](#demo)
18 | * [Baselines](#baselines)
19 | * [Citation](#citation)
20 | * [License](#license)
21 | 
22 | ***
23 | # Directory Structure
24 | 
25 | ```
26 | Openpilot-Deepdive
27 | ├── tools - Tools to generate split on Comma2k19 and nuScenes datasets.
28 | ├── utils_comma2k19 - The utils provided by comma, copied from `commaai/comma2k19.git/utils`
29 | ├── data
30 |     ├── nuscenes -> soft link to the nuScenes-all dataset
31 |     ├── comma2k19 -> soft link to the Comma2k19 dataset
32 | ```
33 | ***
34 | # Changelog
35 | 
36 | 2022-6-17: We released the v1.0 code for Openpilot-Deepdive.
37 | 
38 | 2022-6-26: We fixed some problems and updated the README for running the code on bare-metal machines. Thanks @EliomEssaim and @MicroHest!
39 | 
40 | 2022-7-13: We released the v1.0 code for [Openpilot-Deployment](https://github.com/OpenPerceptionX/Openpilot-Deepdive/tree/deploy/Openpilot-Deployment) for dual-model deployment in the Openpilot framework.
41 | 
42 | ***
43 | # Quick Start Examples
44 | Before starting, we recommend reading the [arXiv paper](https://arxiv.org/abs/2206.08176) to understand the details of our work.
45 | ## Installation
46 | Clone the repo and install the packages listed in requirements.txt in a [Python>=3.7.0](https://www.python.org/) environment, including [PyTorch>=1.7](https://pytorch.org/get-started/locally/).
47 | 
48 | ```
49 | git clone https://github.com/OpenPerceptionX/Openpilot-Deepdive.git # clone
50 | cd Openpilot-Deepdive
51 | pip install -r requirements.txt # install
52 | ```
53 | ## Dataset
54 | We train and evaluate our model on two datasets, [nuScenes](https://www.nuscenes.org/nuscenes) and [Comma2k19](https://github.com/commaai/comma2k19).
55 | The table below shows some of their key features.
56 | 
57 | | Dataset | Raw FPS (Hz) | Aligned FPS (Hz) | Length Per Sequence (Frames / Seconds) | Total Length (Minutes) | Scenario | Locations |
58 | | :----: |:----:|:----:|:----:|:----:|:----:|:----:|
59 | | nuScenes | 12 | 2 | 40 / 20 | 330 | Street | America, Singapore |
60 | | Comma2k19 | 20 | 20 | 1000 / 60 | 2000 | Highway | America |
61 | 
62 | Please create a `data` folder and create soft links to the datasets.
63 | 
64 | For dataset splits, you may create your own by running the scripts in the `tools` folder, or download them from https://github.com/OpenPerceptionX/Openpilot-Deepdive/issues/4.
65 | 
66 | ## Training and Testing
67 | By default, the batch size is 6 per GPU, which consumes 27 GB of GPU memory. With 8 V100 GPUs, it takes approximately 120 hours to train for 100 epochs on the Comma2k19 dataset.
68 | 
69 | **Note**: Our lab uses `slurm` to run and manage tasks. The PyTorch distributed training processes are therefore initialized manually by `slurm`, since the automatic `mp.spawn` may cause unknown problems on slurm clusters. If you do not use a cluster, you can still launch the training process on a bare-metal machine, but you will have to open multiple terminals and set some environment variables manually if you want to use multiple GPUs. We explain this below.
70 | 
71 | **Warning**: Since we have to extract all the frames from the video before sending them into the network, the program is hungry for memory. The actual memory usage depends on `batch_size` and `n_workers`. By default, each process with `n_workers=4` and `batch_size=6` consumes around 40 to 50 GB of memory. We recommend monitoring memory usage with `htop` so that you can stop the job before the machine hangs.
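
The memory figures in the warning above can be roughly sanity-checked with some arithmetic. The sketch below is a back-of-the-envelope estimate, not a measurement. The sequence length (800) and the per-step tensor shape (`[6, 128, 256]`, float32) are taken from `data.py`; the decoded clip size (about 1200 frames of 874 x 1164 pixels) is an assumption inferred from the camera intrinsics used there and may not match every clip.

```python
def tensor_bytes(shape, bytes_per_element):
    """Return the size in bytes of a dense array with the given shape."""
    total = bytes_per_element
    for dim in shape:
        total *= dim
    return total


MIB = 1024 ** 2

# Values taken from data.py: 800 steps per training sequence, each step
# stored as a float32 tensor of shape [6, 128, 256].
SEQ_LENGTH = 800
sample_bytes = tensor_bytes((SEQ_LENGTH, 6, 128, 256), 4)

# Assumptions (not stated in this README): a one-minute clip decodes to
# about 1200 BGR uint8 frames of 874 x 1164 pixels.
ASSUMED_FRAMES = 1200
ASSUMED_HEIGHT, ASSUMED_WIDTH = 874, 1164
decoded_bytes = tensor_bytes(
    (ASSUMED_FRAMES, ASSUMED_HEIGHT, ASSUMED_WIDTH, 3), 1
)

print(f"one training sample: {sample_bytes / MIB:.0f} MiB")
print(f"one decoded clip   : {decoded_bytes / MIB:.0f} MiB")
```

Under these assumptions a single training sample is 600 MiB and a single decoded clip is several GiB, so a handful of workers that each hold a decoded clip, plus prefetched batches, can plausibly add up to tens of GB per process.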
72 | 
73 | ```bash
74 | # Training on a slurm cluster
75 | export DIST_PORT=23333 # You may use whatever you want
76 | export NUM_GPUS=8
77 | PORT=$DIST_PORT srun -p $PARTITION --job-name=openpilot -n $NUM_GPUS --gres=gpu:$NUM_GPUS --ntasks-per-node=$NUM_GPUS python main.py
78 | ```
79 | 
80 | ```bash
81 | # Training on a bare-metal machine with a single GPU
82 | PORT=23333 SLURM_PROCID=0 SLURM_NTASKS=1 python main.py
83 | ```
84 | 
85 | ```bash
86 | # Training on a bare-metal machine with multiple GPUs
87 | # You need to open multiple terminals
88 | 
89 | # Let's use 4 GPUs for example
90 | # Terminal 1
91 | PORT=23333 SLURM_PROCID=0 SLURM_NTASKS=4 python main.py
92 | # Terminal 2
93 | PORT=23333 SLURM_PROCID=1 SLURM_NTASKS=4 python main.py
94 | # Terminal 3
95 | PORT=23333 SLURM_PROCID=2 SLURM_NTASKS=4 python main.py
96 | # Terminal 4
97 | PORT=23333 SLURM_PROCID=3 SLURM_NTASKS=4 python main.py
98 | # Then, the training process will start after all 4 processes are launched.
99 | ```
100 | 
101 | By default, the program prints nothing once training starts, because the widely used `tqdm` can be buggy on slurm clusters. You may therefore see only some debugging info like the lines below, and the program may seem to be stuck.
102 | 
103 | ```
104 | [1656218909.68] starting job... 0 of 1
105 | [1656218911.53] DDP Initialized at localhost:23333 0 of 1
106 | Comma2k19SequenceDataset: DEMO mode is on.
107 | Loaded pretrained weights for efficientnet-b2
108 | ```
109 | 
110 | Don't worry, you can open TensorBoard to see the loss and validation curves.
111 | ```
112 | tensorboard --logdir runs --bind_all
113 | ```
114 | 
115 | Alternatively, you can pass `--tqdm=True` to show the progress bar in `Terminal 1`.
116 | 
117 | By default, the test process runs once every epoch, so we did not implement a separate test script.
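
If opening one terminal per GPU is inconvenient, the same environment variables can be set from a small launcher script. The following is only a sketch, not part of this repository: `build_rank_env` and `launch_all` are hypothetical helper names, and it assumes that `main.py` needs nothing beyond the `PORT`, `SLURM_PROCID`, and `SLURM_NTASKS` variables shown in the shell commands above.

```python
import os
import subprocess
import sys


def build_rank_env(rank, world_size, port, base_env=None):
    """Return the environment for one training process.

    Mirrors the shell commands above: PORT is shared by all processes,
    SLURM_PROCID is the rank, SLURM_NTASKS is the number of processes.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["PORT"] = str(port)
    env["SLURM_PROCID"] = str(rank)
    env["SLURM_NTASKS"] = str(world_size)
    return env


def launch_all(world_size, port=23333, script="main.py", extra_args=()):
    """Start one process per rank and wait for all of them.

    Hypothetical helper: returns the list of exit codes, one per rank.
    """
    procs = [
        subprocess.Popen(
            [sys.executable, script, *extra_args],
            env=build_rank_env(rank, world_size, port),
        )
        for rank in range(world_size)
    ]
    return [proc.wait() for proc in procs]


# Example (not executed here): the four-terminal setup above would be
#   exit_codes = launch_all(world_size=4)
```

All processes share one terminal in this setup, so their output is interleaved; keeping separate terminals may still be easier for debugging.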
118 | 
119 | ## Demo
120 | See more demos and test cases on our [webpage](https://sites.google.com/view/openpilot-deepdive/home)!
121 | 
122 | You can generate your own demo video using `demo.py`. It writes frames to the `./vis` folder (you may have to create it first). You can then assemble the frames into a video using `ffmpeg`.
123 | 
124 | https://user-images.githubusercontent.com/20351686/174319920-35b3ad34-a15e-43d7-be23-b135c24712e2.mp4
125 | 
126 | 
127 | ***
128 | # Baselines
129 | Here we list several baselines for the trajectory prediction task on different datasets. You are welcome to open a pull request and add your work here!
130 | 
131 | **nuScenes**
132 | | Method | AP@0.5(0-10) | AP@1(10-20) | AP@1(20-30) | AP@1(30-50) |
133 | | :----: |:----: |:----: |:----: |:----: |
134 | | Supercombo | 0.237 | 0.064 | 0.038 | 0.053 |
135 | | Supercombo-finetuned | 0.305 | 0.162 | 0.088 | 0.050 |
136 | | OP-Deepdive (Ours) | 0.28 | 0.14 | 0.067 | 0.038 |
137 | 
138 | **Comma2k19**
139 | | Method | AP@0.5(0-10) | AP@1(10-20) | AP@1(20-30) | AP@1(30-50) | AP@2(50+) | Average Jerk* |
140 | | :----: |:----: |:----: |:----: |:----: |:----: |:----: |
141 | | Supercombo | 0.7966 | 0.6170 | 0.2661 | 0.0889 | 0.0062 | 2.2243 |
142 | | OP-Deepdive (Ours) | 0.909 | 0.808 | 0.651 | 0.465 | 0.239 | 4.7959 |
143 | 
144 | \*: The lower, the better. For comparison, the average jerk of human drivers' trajectories is 0.3232 m/s^2.
145 | 
146 | ***
147 | # Citation
148 | Please use the following citation when referencing our repo or [arXiv](https://arxiv.org/abs/2206.08176).
149 | ``` 150 | @article{chen2022op, 151 | title={Level 2 Autonomous Driving on a Single Device: Diving into the Devils of Openpilot}, 152 | author={Li Chen and Tutian Tang and Zhitian Cai and Yang Li and Penghao Wu and Hongyang Li and Jianping Shi and Junchi Yan and Yu Qiao}, 153 | journal={arXiv preprint arXiv:2206.08176}, 154 | year={2022} 155 | } 156 | ``` 157 | *** 158 | # License 159 | All code within this repository is under [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0). 160 | *** 161 | 162 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import json 4 | import torch 5 | from math import pi 6 | import numpy as np 7 | from scipy.interpolate import interp1d 8 | import cv2 9 | cv2.setNumThreads(0) 10 | cv2.ocl.setUseOpenCL(False) 11 | 12 | from PIL import Image 13 | from torch.utils.data import Dataset 14 | from torchvision import transforms 15 | from utils import warp, generate_random_params_for_warp 16 | from view_transform import calibration 17 | 18 | import utils_comma2k19.orientation as orient 19 | import utils_comma2k19.coordinates as coord 20 | 21 | 22 | class PlanningDataset(Dataset): 23 | def __init__(self, root='data', json_path_pattern='p3_%s.json', split='train'): 24 | self.samples = json.load(open(os.path.join(root, json_path_pattern % split))) 25 | print('PlanningDataset: %d samples loaded from %s' % 26 | (len(self.samples), os.path.join(root, json_path_pattern % split))) 27 | self.split = split 28 | 29 | self.img_root = os.path.join(root, 'nuscenes') 30 | self.transforms = transforms.Compose( 31 | [ 32 | # transforms.Resize((900 // 2, 1600 // 2)), 33 | # transforms.Resize((9 * 32, 16 * 32)), 34 | transforms.Resize((128, 256)), 35 | transforms.ToTensor(), 36 | transforms.Normalize([0.3890, 0.3937, 0.3851], 37 | [0.2172, 0.2141, 0.2209]), 38 | ] 39 | ) 40 | 41 | 
self.enable_aug = False 42 | self.view_transform = False 43 | 44 | self.use_memcache = False 45 | if self.use_memcache: 46 | self._init_mc_() 47 | 48 | def _init_mc_(self): 49 | from petrel_client.client import Client 50 | self.client = Client('~/petreloss.conf') 51 | print('======== Initializing Memcache: Success =======') 52 | 53 | def _get_cv2_image(self, path): 54 | if self.use_memcache: 55 | img_bytes = self.client.get(str(path)) 56 | assert(img_bytes is not None) 57 | img_mem_view = memoryview(img_bytes) 58 | img_array = np.frombuffer(img_mem_view, np.uint8) 59 | return cv2.imdecode(img_array, cv2.IMREAD_COLOR) 60 | 61 | else: 62 | return cv2.imread(path) 63 | 64 | def __len__(self): 65 | return len(self.samples) 66 | 67 | def __getitem__(self, idx): 68 | sample = self.samples[idx] 69 | imgs, future_poses = sample['imgs'], sample['future_poses'] 70 | 71 | # process future_poses 72 | future_poses = torch.tensor(future_poses) 73 | future_poses[:, 0] = future_poses[:, 0].clamp(1e-2, ) # the car will never go backward 74 | 75 | imgs = list(self._get_cv2_image(os.path.join(self.img_root, p)) for p in imgs) 76 | imgs = list(cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs) # RGB 77 | 78 | # process images 79 | if self.enable_aug and self.split == 'train': 80 | # data augumentation when training 81 | # random distort (warp) 82 | w_offsets, h_offsets = generate_random_params_for_warp(imgs[0], random_rate=0.1) 83 | imgs = list(warp(img, w_offsets, h_offsets) for img in imgs) 84 | 85 | # random flip 86 | if np.random.rand() > 0.5: 87 | imgs = list(img[:, ::-1, :] for img in imgs) 88 | future_poses[:, 1] *= -1 89 | 90 | 91 | if self.view_transform: 92 | camera_rotation_matrix = np.linalg.inv(np.array(sample["camera_rotation_matrix_inv"])) 93 | camera_translation = -np.array(sample["camera_translation_inv"]) 94 | camera_extrinsic = np.vstack((np.hstack((camera_rotation_matrix, camera_translation.reshape((3, 1)))), np.array([0, 0, 0, 1]))) 95 | camera_extrinsic = 
np.linalg.inv(camera_extrinsic) 96 | warp_matrix = calibration(camera_extrinsic, np.array(sample["camera_intrinsic"])) 97 | imgs = list(cv2.warpPerspective(src = img, M = warp_matrix, dsize= (256,128), flags= cv2.WARP_INVERSE_MAP) for img in imgs) 98 | 99 | # cvt back to PIL images 100 | # cv2.imshow('0', imgs[0]) 101 | # cv2.imshow('1', imgs[1]) 102 | # cv2.waitKey(0) 103 | imgs = list(Image.fromarray(img) for img in imgs) 104 | imgs = list(self.transforms(img) for img in imgs) 105 | input_img = torch.cat(imgs, dim=0) 106 | 107 | return dict( 108 | input_img=input_img, 109 | future_poses=future_poses, 110 | camera_intrinsic=torch.tensor(sample['camera_intrinsic']), 111 | camera_extrinsic=torch.tensor(sample['camera_extrinsic']), 112 | camera_translation_inv=torch.tensor(sample['camera_translation_inv']), 113 | camera_rotation_matrix_inv=torch.tensor(sample['camera_rotation_matrix_inv']), 114 | ) 115 | 116 | 117 | class SequencePlanningDataset(PlanningDataset): 118 | def __init__(self, root='data', json_path_pattern='p3_%s.json', split='train'): 119 | print('Sequence', end='') 120 | self.fix_seq_length = 18 121 | super().__init__(root=root, json_path_pattern=json_path_pattern, split=split) 122 | 123 | def __getitem__(self, idx): 124 | seq_samples = self.samples[idx] 125 | seq_length = len(seq_samples) 126 | if seq_length < self.fix_seq_length: 127 | # Only 1 sample < 28 (==21) 128 | return self.__getitem__(np.random.randint(0, len(self.samples))) 129 | if seq_length > self.fix_seq_length: 130 | seq_length_delta = seq_length - self.fix_seq_length 131 | seq_length_delta = np.random.randint(0, seq_length_delta+1) 132 | seq_samples = seq_samples[seq_length_delta:self.fix_seq_length+seq_length_delta] 133 | 134 | seq_future_poses = list(smp['future_poses'] for smp in seq_samples) 135 | seq_imgs = list(smp['imgs'] for smp in seq_samples) 136 | 137 | seq_input_img = [] 138 | for imgs in seq_imgs: 139 | imgs = list(self._get_cv2_image(os.path.join(self.img_root, p)) for p 
in imgs)
140 |             imgs = list(cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs)  # RGB
141 |             imgs = list(Image.fromarray(img) for img in imgs)
142 |             imgs = list(self.transforms(img) for img in imgs)
143 |             input_img = torch.cat(imgs, dim=0)
144 |             seq_input_img.append(input_img[None])
145 |         seq_input_img = torch.cat(seq_input_img)
146 | 
147 |         return dict(
148 |             seq_input_img=seq_input_img,  # [fix_seq_length, 6, 128, 256]
149 |             seq_future_poses=torch.tensor(seq_future_poses),  # [fix_seq_length, n_future_poses, 3]
150 |             camera_intrinsic=torch.tensor(seq_samples[0]['camera_intrinsic']),
151 |             camera_extrinsic=torch.tensor(seq_samples[0]['camera_extrinsic']),
152 |             camera_translation_inv=torch.tensor(seq_samples[0]['camera_translation_inv']),
153 |             camera_rotation_matrix_inv=torch.tensor(seq_samples[0]['camera_rotation_matrix_inv']),
154 |         )
155 | 
156 | 
157 | class Comma2k19SequenceDataset(PlanningDataset):
158 |     def __init__(self, split_txt_path, prefix, mode, use_memcache=True, return_origin=False):
159 |         self.split_txt_path = split_txt_path
160 |         self.prefix = prefix
161 | 
162 |         self.samples = open(split_txt_path).readlines()
163 |         self.samples = [i.strip() for i in self.samples]
164 | 
165 |         assert mode in ('train', 'val', 'demo')
166 |         self.mode = mode
167 |         if self.mode == 'demo':
168 |             print('Comma2k19SequenceDataset: DEMO mode is on.')
169 | 
170 |         self.fix_seq_length = 800 if mode == 'train' else 800
171 | 
172 |         self.transforms = transforms.Compose(
173 |             [
174 |                 # transforms.Resize((900 // 2, 1600 // 2)),
175 |                 # transforms.Resize((9 * 32, 16 * 32)),
176 |                 transforms.Resize((128, 256)),
177 |                 transforms.ToTensor(),
178 |                 transforms.Normalize([0.3890, 0.3937, 0.3851],
179 |                                      [0.2172, 0.2141, 0.2209]),
180 |             ]
181 |         )
182 | 
183 |         self.warp_matrix = calibration(extrinsic_matrix=np.array([[ 0, -1,  0,    0],
184 |                                                                   [ 0,  0, -1, 1.22],
185 |                                                                   [ 1,  0,  0,    0],
186 |                                                                   [ 0,  0,  0,    1]]),
187 |                                        cam_intrinsics=np.array([[910, 0, 582],
188 |                                                                 [0, 910, 437],
189 |                                                                 [0,   0,   1]]),
190 | 
device_frame_from_road_frame=np.hstack((np.diag([1, -1, -1]), [[0], [0], [1.22]]))) 191 | 192 | self.use_memcache = use_memcache 193 | if self.use_memcache: 194 | self._init_mc_() 195 | 196 | self.return_origin = return_origin 197 | 198 | # from OpenPilot 199 | self.num_pts = 10 * 20 # 10 s * 20 Hz = 200 frames 200 | self.t_anchors = np.array( 201 | (0. , 0.00976562, 0.0390625 , 0.08789062, 0.15625 , 202 | 0.24414062, 0.3515625 , 0.47851562, 0.625 , 0.79101562, 203 | 0.9765625 , 1.18164062, 1.40625 , 1.65039062, 1.9140625 , 204 | 2.19726562, 2.5 , 2.82226562, 3.1640625 , 3.52539062, 205 | 3.90625 , 4.30664062, 4.7265625 , 5.16601562, 5.625 , 206 | 6.10351562, 6.6015625 , 7.11914062, 7.65625 , 8.21289062, 207 | 8.7890625 , 9.38476562, 10.) 208 | ) 209 | self.t_idx = np.linspace(0, 10, num=self.num_pts) 210 | 211 | 212 | def _get_cv2_vid(self, path): 213 | if self.use_memcache: 214 | path = self.client.generate_presigned_url(str(path), client_method='get_object', expires_in=3600) 215 | return cv2.VideoCapture(path) 216 | 217 | def _get_numpy(self, path): 218 | if self.use_memcache: 219 | bytes = io.BytesIO(memoryview(self.client.get(str(path)))) 220 | return np.lib.format.read_array(bytes) 221 | else: 222 | return np.load(path) 223 | 224 | def __getitem__(self, idx): 225 | seq_sample_path = self.prefix + self.samples[idx] 226 | cap = self._get_cv2_vid(seq_sample_path + '/video.hevc') 227 | if (cap.isOpened() == False): 228 | raise RuntimeError 229 | imgs = [] # <--- all frames here 230 | origin_imgs = [] 231 | while (cap.isOpened()): 232 | ret, frame = cap.read() 233 | if ret == True: 234 | imgs.append(frame) 235 | # cv2.imshow('frame', frame) 236 | # cv2.waitKey(0) 237 | if self.return_origin: 238 | origin_imgs.append(frame) 239 | else: 240 | break 241 | cap.release() 242 | 243 | seq_length = len(imgs) 244 | 245 | if self.mode == 'demo': 246 | self.fix_seq_length = seq_length - self.num_pts - 1 247 | 248 | if seq_length < self.fix_seq_length + self.num_pts: 249 | 
print('The length of sequence', seq_sample_path, 'is too short', 250 | '(%d < %d)' % (seq_length, self.fix_seq_length + self.num_pts)) 251 | return self.__getitem__(idx+1) 252 | 253 | seq_length_delta = seq_length - (self.fix_seq_length + self.num_pts) 254 | seq_length_delta = np.random.randint(1, seq_length_delta+1) 255 | 256 | seq_start_idx = seq_length_delta 257 | seq_end_idx = seq_length_delta + self.fix_seq_length 258 | 259 | # seq_input_img 260 | imgs = imgs[seq_start_idx-1: seq_end_idx] # contains one more img 261 | imgs = [cv2.warpPerspective(src=img, M=self.warp_matrix, dsize=(512,256), flags=cv2.WARP_INVERSE_MAP) for img in imgs] 262 | imgs = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in imgs] 263 | imgs = list(Image.fromarray(img) for img in imgs) 264 | imgs = list(self.transforms(img)[None] for img in imgs) 265 | input_img = torch.cat(imgs, dim=0) # [N+1, 3, H, W] 266 | del imgs 267 | input_img = torch.cat((input_img[:-1, ...], input_img[1:, ...]), dim=1) 268 | 269 | # poses 270 | frame_positions = self._get_numpy(self.prefix + self.samples[idx] + '/global_pose/frame_positions')[seq_start_idx: seq_end_idx+self.num_pts] 271 | frame_orientations = self._get_numpy(self.prefix + self.samples[idx] + '/global_pose/frame_orientations')[seq_start_idx: seq_end_idx+self.num_pts] 272 | 273 | future_poses = [] 274 | for i in range(self.fix_seq_length): 275 | ecef_from_local = orient.rot_from_quat(frame_orientations[i]) 276 | local_from_ecef = ecef_from_local.T 277 | frame_positions_local = np.einsum('ij,kj->ki', local_from_ecef, frame_positions - frame_positions[i]).astype(np.float32) 278 | 279 | # Time-Anchor like OpenPilot 280 | fs = [interp1d(self.t_idx, frame_positions_local[i: i+self.num_pts, j]) for j in range(3)] 281 | interp_positions = [fs[j](self.t_anchors)[:, None] for j in range(3)] 282 | interp_positions = np.concatenate(interp_positions, axis=1) 283 | 284 | future_poses.append(interp_positions) 285 | future_poses = 
torch.tensor(np.array(future_poses), dtype=torch.float32) 286 | 287 | rtn_dict = dict( 288 | seq_input_img=input_img, # torch.Size([N, 6, 128, 256]) 289 | seq_future_poses=future_poses, # torch.Size([N, num_pts, 3]) 290 | # camera_intrinsic=torch.tensor(seq_samples[0]['camera_intrinsic']), 291 | # camera_extrinsic=torch.tensor(seq_samples[0]['camera_extrinsic']), 292 | # camera_translation_inv=torch.tensor(seq_samples[0]['camera_translation_inv']), 293 | # camera_rotation_matrix_inv=torch.tensor(seq_samples[0]['camera_rotation_matrix_inv']), 294 | ) 295 | 296 | # For DEMO 297 | if self.return_origin: 298 | origin_imgs = origin_imgs[seq_start_idx: seq_end_idx] 299 | origin_imgs = [torch.tensor(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))[None] for img in origin_imgs] 300 | origin_imgs = torch.cat(origin_imgs, dim=0) # N, H_ori, W_ori, 3 301 | rtn_dict['origin_imgs'] = origin_imgs 302 | 303 | return rtn_dict 304 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from utils import draw_path, draw_trajectory_on_ax 2 | import torch 3 | import os 4 | import cv2 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torch.nn.functional import softmax 8 | 9 | from data import PlanningDataset, Comma2k19SequenceDataset 10 | from main import SequenceBaselineV1 11 | from torch.utils.data import DataLoader 12 | 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | CKPT_PATH = 'vis/M5_epoch_94.pth' # Path to your checkpoint 17 | 18 | # You can generate your own comma2k19_demo.txt to make some fancy demos 19 | # val = Comma2k19SequenceDataset('data/comma2k19_demo.txt', 'data/comma2k19/','demo', use_memcache=False, return_origin=True) 20 | val = Comma2k19SequenceDataset('data/comma2k19_val_non_overlap.txt', 'data/comma2k19/','demo', use_memcache=False, return_origin=True) 21 | val_loader = DataLoader(val, 1, num_workers=0, shuffle=False) 22 | 23 | planning_v0 = 
SequenceBaselineV1(5, 33, 1.0, 0.0, 'adamw') 24 | planning_v0.load_state_dict(torch.load(CKPT_PATH)) 25 | planning_v0.eval().cuda() 26 | 27 | seq_idx = 0 28 | for b_idx, batch in enumerate(val_loader): 29 | os.mkdir('vis/M5_DEMO_%04d' % seq_idx) 30 | seq_inputs, seq_labels = batch['seq_input_img'].cuda(), batch['seq_future_poses'].cuda() 31 | origin_imgs = batch['origin_imgs'] 32 | # camera_rotation_matrix_inv=batch['camera_rotation_matrix_inv'].numpy()[0] 33 | # camera_translation_inv=batch['camera_translation_inv'].numpy()[0] 34 | # camera_intrinsic=batch['camera_intrinsic'].numpy()[0] 35 | bs = seq_labels.size(0) 36 | seq_length = seq_labels.size(1) 37 | 38 | hidden = torch.zeros((2, bs, 512), device=seq_inputs.device) 39 | 40 | img_idx = 0 41 | for t in tqdm(range(seq_length)): 42 | 43 | with torch.no_grad(): 44 | inputs, labels = seq_inputs[:, t, :, :, :], seq_labels[:, t, :, :] 45 | pred_cls, pred_trajectory, hidden = planning_v0(inputs, hidden) 46 | 47 | pred_conf = softmax(pred_cls, dim=-1).cpu().numpy()[0] 48 | pred_trajectory = pred_trajectory.reshape(planning_v0.M, planning_v0.num_pts, 3).cpu().numpy() 49 | 50 | inputs, labels = inputs.cpu(), labels.cpu() 51 | vis_img = (inputs.permute(0, 2, 3, 1)[0] * torch.tensor((0.2172, 0.2141, 0.2209, 0.2172, 0.2141, 0.2209)) + torch.tensor((0.3890, 0.3937, 0.3851, 0.3890, 0.3937, 0.3851)) ) * 255 52 | # print(vis_img.max(), vis_img.min(), vis_img.mean()) 53 | vis_img = vis_img.clamp(0, 255) 54 | img_0, img_1 = vis_img[..., :3].numpy().astype(np.uint8), vis_img[..., 3:].numpy().astype(np.uint8) 55 | 56 | # fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(16, 9)) 57 | # fig = plt.figure(figsize=(16, 9), constrained_layout=True) 58 | # fig = plt.figure(figsize=(12, 9.9)) # W, H 59 | fig = plt.figure(figsize=(12, 9)) # W, H 60 | spec = fig.add_gridspec(3, 3) # H, W 61 | ax1 = fig.add_subplot(spec[ 2, 0]) # H, W 62 | ax2 = fig.add_subplot(spec[ 2, 1]) 63 | ax3 = fig.add_subplot(spec[ :, 2]) 64 | 
ax4 = fig.add_subplot(spec[0:2, 0:2]) 65 | 66 | ax1.imshow(img_0) 67 | ax1.set_title('network input [previous]') 68 | ax1.axis('off') 69 | 70 | ax2.imshow(img_1) 71 | ax2.set_title('network input [current]') 72 | ax2.axis('off') 73 | 74 | current_metric = (((pred_trajectory[pred_conf.argmax()] - labels.numpy()) ** 2).sum(-1) ** 0.5).mean().item() 75 | 76 | trajectories = list(pred_trajectory) + list(labels) 77 | confs = list(pred_conf) + [1, ] 78 | ax3 = draw_trajectory_on_ax(ax3, trajectories, confs, ylim=(0, 200)) 79 | ax3.set_title('Mean L2: %.2f' % current_metric) 80 | ax3.grid() 81 | 82 | origin_img = origin_imgs[0, t, :, :, :].numpy() 83 | overlay = origin_img.copy() 84 | draw_path(pred_trajectory[pred_conf.argmax(), :], overlay, width=1, height=1.2, fill_color=(255,255,255), line_color=(0,255,0)) 85 | origin_img = 0.5 * origin_img + 0.5 * overlay 86 | draw_path(pred_trajectory[pred_conf.argmax(), :], origin_img, width=1, height=1.2, fill_color=None, line_color=(0,255,0)) 87 | 88 | ax4.imshow(origin_img.astype(np.uint8)) 89 | ax4.set_title('project on current frame') 90 | ax4.axis('off') 91 | 92 | # ax4.imshow(img_1) 93 | # pred_mask = np.argmax(pred_conf) 94 | # pred_trajectory = [pred_trajectory[pred_mask, ...], ] + [batch['future_poses'].numpy()[0], ] 95 | # pred_conf = [pred_conf[pred_mask], ] + [1, ] 96 | # for pred_trajectory_single, pred_conf_single in zip(pred_trajectory, pred_conf): 97 | # location = list((p + camera_translation_inv for p in pred_trajectory_single)) 98 | # proj_trajectory = np.array(list((camera_intrinsic @ (camera_rotation_matrix_inv @ l) for l in location))) 99 | # proj_trajectory /= proj_trajectory[..., 2:3].repeat(3, -1) 100 | # proj_trajectory /= 2 101 | # proj_trajectory = proj_trajectory[(proj_trajectory[..., 0] > 0) & (proj_trajectory[..., 0] < 800)] 102 | # proj_trajectory = proj_trajectory[(proj_trajectory[..., 1] > 0) & (proj_trajectory[..., 1] < 450)] 103 | # ax4.plot(proj_trajectory[:, 0], proj_trajectory[:, 1], 'o-', 
label='gt' if pred_conf_single == 1.0 else 'pred - conf %.3f' % pred_conf_single, alpha=np.clip(pred_conf_single, 0.1, np.Inf)) 104 | 105 | # ax4.legend() 106 | plt.tight_layout() 107 | plt.savefig('vis/M5_DEMO_%04d/%08d.jpg' % (seq_idx, img_idx), pad_inches=0.2, bbox_inches='tight') 108 | img_idx += 1 109 | # plt.show() 110 | plt.close(fig) 111 | 112 | seq_idx += 1 113 | -------------------------------------------------------------------------------- /imgs/arXiv_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/Openpilot-Deepdive/fae05055c071fe8b6ed0dd578bb047f29f2b4dd4/imgs/arXiv_fig1.png -------------------------------------------------------------------------------- /imgs/dataset_cmp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/Openpilot-Deepdive/fae05055c071fe8b6ed0dd578bb047f29f2b4dd4/imgs/dataset_cmp.png -------------------------------------------------------------------------------- /imgs/demo01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/Openpilot-Deepdive/fae05055c071fe8b6ed0dd578bb047f29f2b4dd4/imgs/demo01.png -------------------------------------------------------------------------------- /imgs/demo_exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/Openpilot-Deepdive/fae05055c071fe8b6ed0dd578bb047f29f2b4dd4/imgs/demo_exp.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | from tqdm import tqdm 6 | from argparse import ArgumentParser 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | import 
torch.nn.functional as F 12 | from torch import optim 13 | from torch.utils.data import DataLoader 14 | 15 | import torch.distributed as dist 16 | from torch.utils.data.distributed import DistributedSampler 17 | 18 | if torch.__version__ == 'parrots': 19 | from pavi import SummaryWriter 20 | else: 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | from data import PlanningDataset, SequencePlanningDataset, Comma2k19SequenceDataset 24 | from model import PlaningNetwork, MultipleTrajectoryPredictionLoss, SequencePlanningNetwork 25 | from utils import draw_trajectory_on_ax, get_val_metric, get_val_metric_keys 26 | 27 | 28 | def get_hyperparameters(parser: ArgumentParser): 29 | parser.add_argument('--batch_size', type=int, default=6) 30 | parser.add_argument('--lr', type=float, default=1e-4) 31 | parser.add_argument('--n_workers', type=int, default=4) 32 | parser.add_argument('--epochs', type=int, default=100) 33 | parser.add_argument('--log_per_n_step', type=int, default=20) 34 | parser.add_argument('--val_per_n_epoch', type=int, default=1) 35 | 36 | parser.add_argument('--resume', type=str, default='') 37 | 38 | parser.add_argument('--M', type=int, default=5) 39 | parser.add_argument('--num_pts', type=int, default=33) 40 | parser.add_argument('--mtp_alpha', type=float, default=1.0) 41 | parser.add_argument('--optimizer', type=str, default='sgd') 42 | parser.add_argument('--sync_bn', type=bool, default=True) 43 | parser.add_argument('--tqdm', type=bool, default=False) 44 | parser.add_argument('--optimize_per_n_step', type=int, default=40) 45 | 46 | try: 47 | exp_name = os.environ["SLURM_JOB_ID"] 48 | except KeyError: 49 | exp_name = str(time.time()) 50 | parser.add_argument('--exp_name', type=str, default=exp_name) 51 | 52 | return parser 53 | 54 | 55 | def setup(rank, world_size): 56 | torch.cuda.set_device(rank) 57 | dist.init_process_group('nccl', init_method='tcp://localhost:%s' % os.environ['PORT'], rank=rank, world_size=world_size) 58 | 
print('[%.2f]' % time.time(), 'DDP Initialized at %s:%s' % ('localhost', os.environ['PORT']), rank, 'of', world_size, flush=True) 59 | 60 | 61 | def get_dataloader(rank, world_size, batch_size, pin_memory=False, num_workers=0): 62 | train = Comma2k19SequenceDataset('data/comma2k19_train_non_overlap.txt', 'data/comma2k19/','train', use_memcache=False) 63 | val = Comma2k19SequenceDataset('data/comma2k19_val_non_overlap.txt', 'data/comma2k19/','demo', use_memcache=False) 64 | 65 | if torch.__version__ == 'parrots': 66 | dist_sampler_params = dict(num_replicas=world_size, rank=rank, shuffle=True) 67 | else: 68 | dist_sampler_params = dict(num_replicas=world_size, rank=rank, shuffle=True, drop_last=True) 69 | train_sampler = DistributedSampler(train, **dist_sampler_params) 70 | val_sampler = DistributedSampler(val, **dist_sampler_params) 71 | 72 | loader_args = dict(num_workers=num_workers, persistent_workers=True if num_workers > 0 else False, prefetch_factor=2, pin_memory=pin_memory) 73 | train_loader = DataLoader(train, batch_size, sampler=train_sampler, **loader_args) 74 | val_loader = DataLoader(val, batch_size=1, sampler=val_sampler, **loader_args) 75 | 76 | return train_loader, val_loader 77 | 78 | 79 | def cleanup(): 80 | dist.destroy_process_group() 81 | 82 | class SequenceBaselineV1(nn.Module): 83 | def __init__(self, M, num_pts, mtp_alpha, lr, optimizer, optimize_per_n_step=40) -> None: 84 | super().__init__() 85 | self.M = M 86 | self.num_pts = num_pts 87 | self.mtp_alpha = mtp_alpha 88 | self.lr = lr 89 | self.optimizer = optimizer 90 | 91 | self.net = SequencePlanningNetwork(M, num_pts) 92 | 93 | self.optimize_per_n_step = optimize_per_n_step # for the gru module 94 | 95 | @staticmethod 96 | def configure_optimizers(args, model): 97 | if args.optimizer == 'sgd': 98 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.01) 99 | elif args.optimizer == 'adam': 100 | optimizer = optim.Adam(model.parameters(), lr=args.lr, 
weight_decay=0.01) 101 | elif args.optimizer == 'adamw': 102 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, ) 103 | else: 104 | raise NotImplementedError 105 | lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 20, 0.9) 106 | 107 | return optimizer, lr_scheduler 108 | 109 | def forward(self, x, hidden=None): 110 | if hidden is None: 111 | hidden = torch.zeros((2, x.size(0), 512), device=x.device) # nn.Module has no .device attribute; follow the input's device 112 | return self.net(x, hidden) 113 | 114 | 115 | def main(rank, world_size, args): 116 | if rank == 0: 117 | writer = SummaryWriter() 118 | 119 | train_dataloader, val_dataloader = get_dataloader(rank, world_size, args.batch_size, False, args.n_workers) 120 | model = SequenceBaselineV1(args.M, args.num_pts, args.mtp_alpha, args.lr, args.optimizer, args.optimize_per_n_step) 121 | use_sync_bn = args.sync_bn 122 | if use_sync_bn: 123 | model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 124 | model = model.cuda() 125 | optimizer, lr_scheduler = model.configure_optimizers(args, model) 126 | model: SequenceBaselineV1 127 | if args.resume and rank == 0: 128 | print('Loading weights from', args.resume) 129 | model.load_state_dict(torch.load(args.resume), strict=True) 130 | dist.barrier() 131 | model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True, broadcast_buffers=False) 132 | loss = MultipleTrajectoryPredictionLoss(args.mtp_alpha, args.M, args.num_pts, distance_type='angle') 133 | 134 | num_steps = 0 135 | disable_tqdm = (not args.tqdm) or (rank != 0) 136 | 137 | for epoch in tqdm(range(args.epochs), disable=disable_tqdm, position=0): 138 | train_dataloader.sampler.set_epoch(epoch) 139 | 140 | for batch_idx, data in enumerate(tqdm(train_dataloader, leave=False, disable=disable_tqdm, position=1)): 141 | seq_inputs, seq_labels = data['seq_input_img'].cuda(), data['seq_future_poses'].cuda() 142 | bs = seq_labels.size(0) 143 | seq_length = seq_labels.size(1) 144 | 145 | hidden = torch.zeros((2, bs, 512)).cuda() 146 |
total_loss = 0 147 | for t in tqdm(range(seq_length), leave=False, disable=disable_tqdm, position=2): 148 | num_steps += 1 149 | inputs, labels = seq_inputs[:, t, :, :, :], seq_labels[:, t, :, :] 150 | pred_cls, pred_trajectory, hidden = model(inputs, hidden) 151 | 152 | cls_loss, reg_loss = loss(pred_cls, pred_trajectory, labels) 153 | total_loss += (cls_loss + args.mtp_alpha * reg_loss.mean()) / model.module.optimize_per_n_step 154 | 155 | if rank == 0 and (num_steps + 1) % args.log_per_n_step == 0: 156 | # TODO: add a customized log function 157 | writer.add_scalar('train/epoch', epoch, num_steps) 158 | writer.add_scalar('loss/cls', cls_loss, num_steps) 159 | writer.add_scalar('loss/reg', reg_loss.mean(), num_steps) 160 | writer.add_scalar('loss/reg_x', reg_loss[0], num_steps) 161 | writer.add_scalar('loss/reg_y', reg_loss[1], num_steps) 162 | writer.add_scalar('loss/reg_z', reg_loss[2], num_steps) 163 | writer.add_scalar('param/lr', optimizer.param_groups[0]['lr'], num_steps) 164 | 165 | if (t + 1) % model.module.optimize_per_n_step == 0: 166 | hidden = hidden.clone().detach() 167 | optimizer.zero_grad() 168 | total_loss.backward() 169 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # TODO: move to args 170 | optimizer.step() 171 | if rank == 0: 172 | writer.add_scalar('loss/total', total_loss, num_steps) 173 | total_loss = 0 174 | 175 | if not isinstance(total_loss, int): 176 | optimizer.zero_grad() 177 | total_loss.backward() 178 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # TODO: move to args 179 | optimizer.step() 180 | if rank == 0: 181 | writer.add_scalar('loss/total', total_loss, num_steps) 182 | 183 | lr_scheduler.step() 184 | if (epoch + 1) % args.val_per_n_epoch == 0: 185 | if rank == 0: 186 | # save model 187 | ckpt_path = os.path.join(writer.log_dir, 'epoch_%d.pth' % epoch) 188 | torch.save(model.module.state_dict(), ckpt_path) 189 | print('[Epoch %d] checkpoint saved at %s' % (epoch, ckpt_path)) 190 | 191 | model.eval() 192 
| with torch.no_grad(): 193 | saved_metric_epoch = get_val_metric_keys() 194 | for batch_idx, data in enumerate(tqdm(val_dataloader, leave=False, disable=disable_tqdm, position=1)): 195 | seq_inputs, seq_labels = data['seq_input_img'].cuda(), data['seq_future_poses'].cuda() 196 | 197 | bs = seq_labels.size(0) 198 | seq_length = seq_labels.size(1) 199 | 200 | hidden = torch.zeros((2, bs, 512), device=seq_inputs.device) 201 | for t in tqdm(range(seq_length), leave=False, disable=True, position=2): 202 | inputs, labels = seq_inputs[:, t, :, :, :], seq_labels[:, t, :, :] 203 | pred_cls, pred_trajectory, hidden = model(inputs, hidden) 204 | 205 | metrics = get_val_metric(pred_cls, pred_trajectory.view(-1, args.M, args.num_pts, 3), labels) 206 | 207 | for k, v in metrics.items(): 208 | saved_metric_epoch[k].append(v.float().mean().item()) 209 | 210 | dist.barrier() # Wait for all processes 211 | # sync 212 | metric_single = torch.zeros((len(saved_metric_epoch), ), dtype=torch.float32, device='cuda') 213 | counter_single = torch.zeros((len(saved_metric_epoch), ), dtype=torch.int32, device='cuda') 214 | # dict preserves insertion order (a CPython 3.6 implementation detail, guaranteed by the language from 3.7), 215 | # but sorting the keys keeps the metric order explicit and identical on every rank.
216 | for i, k in enumerate(sorted(saved_metric_epoch.keys())): 217 | metric_single[i] = np.mean(saved_metric_epoch[k]) 218 | counter_single[i] = len(saved_metric_epoch[k]) 219 | 220 | metric_gather = [torch.zeros((len(saved_metric_epoch), ), dtype=torch.float32, device='cuda')[None] for _ in range(world_size)] 221 | counter_gather = [torch.zeros((len(saved_metric_epoch), ), dtype=torch.int32, device='cuda')[None] for _ in range(world_size)] 222 | dist.all_gather(metric_gather, metric_single[None]) 223 | dist.all_gather(counter_gather, counter_single[None]) 224 | 225 | if rank == 0: 226 | metric_gather = torch.cat(metric_gather, dim=0) # [world_size, num_metric_keys] 227 | counter_gather = torch.cat(counter_gather, dim=0) # [world_size, num_metric_keys] 228 | metric_gather_weighted_mean = (metric_gather * counter_gather).sum(0) / counter_gather.sum(0) 229 | for i, k in enumerate(sorted(saved_metric_epoch.keys())): 230 | writer.add_scalar(k, metric_gather_weighted_mean[i], num_steps) 231 | dist.barrier() 232 | 233 | model.train() 234 | 235 | cleanup() 236 | 237 | 238 | if __name__ == "__main__": 239 | print('[%.2f]' % time.time(), 'starting job...', os.environ['SLURM_PROCID'], 'of', os.environ['SLURM_NTASKS'], flush=True) 240 | 241 | parser = ArgumentParser() 242 | parser = get_hyperparameters(parser) 243 | args = parser.parse_args() 244 | 245 | setup(rank=int(os.environ['SLURM_PROCID']), world_size=int(os.environ['SLURM_NTASKS'])) 246 | main(rank=int(os.environ['SLURM_PROCID']), world_size=int(os.environ['SLURM_NTASKS']), args=args) 247 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from efficientnet_pytorch import EfficientNet 6 | 7 | 8 | class PlaningNetwork(nn.Module): 9 | def __init__(self, M, num_pts): 10 | super().__init__() 11 | self.M = M 12 | 
self.num_pts = num_pts 13 | self.backbone = EfficientNet.from_pretrained('efficientnet-b2', in_channels=6) 14 | 15 | use_avg_pooling = False # TODO 16 | if use_avg_pooling: 17 | self.plan_head = nn.Sequential( 18 | nn.AdaptiveAvgPool2d(1), 19 | nn.Flatten(), 20 | nn.BatchNorm1d(1408), 21 | nn.ReLU(), 22 | nn.Linear(1408, 4096), 23 | nn.BatchNorm1d(4096), 24 | nn.ReLU(), 25 | # nn.Dropout(0.3), 26 | nn.Linear(4096, M * (num_pts * 3 + 1)) # +1 for cls 27 | ) 28 | else: # more like the structure of OpenPilot 29 | self.plan_head = nn.Sequential( 30 | # 6, 450, 800 -> 1408, 14, 25 31 | # nn.AdaptiveMaxPool2d((4, 8)), # 1408, 4, 8 32 | nn.BatchNorm2d(1408), 33 | nn.Conv2d(1408, 32, 1), # 32, 4, 8 34 | nn.BatchNorm2d(32), 35 | nn.Flatten(), 36 | nn.ELU(), 37 | nn.Linear(1024, 4096), 38 | nn.BatchNorm1d(4096), 39 | nn.ReLU(), 40 | # nn.Dropout(0.3), 41 | nn.Linear(4096, M * (num_pts * 3 + 1)) # +1 for cls 42 | ) 43 | 44 | 45 | def forward(self, x): 46 | features = self.backbone.extract_features(x) 47 | raw_preds = self.plan_head(features) 48 | pred_cls = raw_preds[:, :self.M] 49 | pred_trajectory = raw_preds[:, self.M:].reshape(-1, self.M, self.num_pts, 3) 50 | 51 | pred_xs = pred_trajectory[:, :, :, 0:1].exp() 52 | pred_ys = pred_trajectory[:, :, :, 1:2].sinh() 53 | pred_zs = pred_trajectory[:, :, :, 2:3] 54 | return pred_cls, torch.cat((pred_xs, pred_ys, pred_zs), dim=3) 55 | 56 | 57 | class SequencePlanningNetwork(nn.Module): 58 | def __init__(self, M, num_pts): 59 | super().__init__() 60 | self.M = M 61 | self.num_pts = num_pts 62 | self.backbone = EfficientNet.from_pretrained('efficientnet-b2', in_channels=6) 63 | 64 | self.plan_head = nn.Sequential( 65 | # 6, 450, 800 -> 1408, 14, 25 66 | # nn.AdaptiveMaxPool2d((4, 8)), # 1408, 4, 8 67 | nn.BatchNorm2d(1408), 68 | nn.Conv2d(1408, 32, 1), # 32, 4, 8 69 | nn.BatchNorm2d(32), 70 | nn.Flatten(), 71 | nn.ELU(), 72 | ) 73 | self.gru = nn.GRU(input_size=1024, hidden_size=512, bidirectional=True, batch_first=True) # 1024 out 
74 | self.plan_head_tip = nn.Sequential( 75 | nn.Flatten(), 76 | # nn.BatchNorm1d(1024), 77 | nn.ELU(), 78 | nn.Linear(1024, 4096), 79 | # nn.BatchNorm1d(4096), 80 | nn.ReLU(), 81 | # nn.Dropout(0.3), 82 | nn.Linear(4096, M * (num_pts * 3 + 1)) # +1 for cls 83 | ) 84 | 85 | def forward(self, x, hidden): 86 | features = self.backbone.extract_features(x) 87 | 88 | raw_preds = self.plan_head(features) 89 | raw_preds, hidden = self.gru(raw_preds[:, None, :], hidden) # N, L, H_in for batch_first=True 90 | raw_preds = self.plan_head_tip(raw_preds) 91 | 92 | pred_cls = raw_preds[:, :self.M] 93 | pred_trajectory = raw_preds[:, self.M:].reshape(-1, self.M, self.num_pts, 3) 94 | 95 | pred_xs = pred_trajectory[:, :, :, 0:1].exp() 96 | pred_ys = pred_trajectory[:, :, :, 1:2].sinh() 97 | pred_zs = pred_trajectory[:, :, :, 2:3] 98 | return pred_cls, torch.cat((pred_xs, pred_ys, pred_zs), dim=3), hidden 99 | 100 | 101 | class AbsoluteRelativeErrorLoss(nn.Module): 102 | def __init__(self, epsilon=1e-4): 103 | super().__init__() 104 | self.epsilon = epsilon 105 | 106 | def forward(self, pred, target): 107 | error = (pred - target) / (target + self.epsilon) 108 | return torch.abs(error) 109 | 110 | 111 | class SigmoidAbsoluteRelativeErrorLoss(nn.Module): 112 | def __init__(self, epsilon=1e-4): 113 | super().__init__() 114 | self.epsilon = epsilon 115 | 116 | def forward(self, pred, target): 117 | error = (pred - target) / (target + self.epsilon) 118 | return torch.sigmoid(torch.abs(error)) 119 | 120 | 121 | class MultipleTrajectoryPredictionLoss(nn.Module): 122 | def __init__(self, alpha, M, num_pts, distance_type='angle'): 123 | super().__init__() 124 | self.alpha = alpha # TODO: currently no use 125 | self.M = M 126 | self.num_pts = num_pts 127 | 128 | self.distance_type = distance_type 129 | if self.distance_type == 'angle': 130 | self.distance_func = nn.CosineSimilarity(dim=2) 131 | else: 132 | raise NotImplementedError 133 | self.cls_loss = nn.CrossEntropyLoss() 134 | 
self.reg_loss = nn.SmoothL1Loss(reduction='none') 135 | # self.reg_loss = SigmoidAbsoluteRelativeErrorLoss() 136 | # self.reg_loss = AbsoluteRelativeErrorLoss() 137 | 138 | def forward(self, pred_cls, pred_trajectory, gt): 139 | """ 140 | pred_cls: [B, M] 141 | pred_trajectory: [B, M * num_pts * 3] 142 | gt: [B, num_pts, 3] 143 | """ 144 | assert len(pred_cls) == len(pred_trajectory) == len(gt) 145 | pred_trajectory = pred_trajectory.reshape(-1, self.M, self.num_pts, 3) 146 | with torch.no_grad(): 147 | # step 1: calculate distance between gt and each prediction 148 | pred_end_positions = pred_trajectory[:, :, self.num_pts-1, :] # B, M, 3 149 | gt_end_positions = gt[:, self.num_pts-1:, :].expand(-1, self.M, -1) # B, 1, 3 -> B, M, 3 150 | 151 | distances = 1 - self.distance_func(pred_end_positions, gt_end_positions) # B, M 152 | index = distances.argmin(dim=1) # B 153 | 154 | gt_cls = index 155 | pred_trajectory = pred_trajectory[torch.tensor(range(len(gt_cls)), device=gt_cls.device), index, ...] 
# B, num_pts, 3 156 | 157 | cls_loss = self.cls_loss(pred_cls, gt_cls) 158 | 159 | reg_loss = self.reg_loss(pred_trajectory, gt).mean(dim=(0, 1)) 160 | 161 | return cls_loss, reg_loss 162 | 163 | 164 | if __name__ == '__main__': 165 | # model = EfficientNet.from_pretrained('efficientnet-b2', in_channels=6) 166 | model = PlaningNetwork(M=3, num_pts=20) 167 | 168 | dummy_input = torch.zeros((1, 6, 256, 512)) 169 | 170 | # features = model.extract_features(dummy_input) 171 | features = model(dummy_input) 172 | 173 | pred_cls = torch.rand(16, 5) 174 | pred_trajectory = torch.rand(16, 5*20*3) 175 | gt = torch.rand(16, 20, 3) 176 | 177 | loss = MultipleTrajectoryPredictionLoss(1.0, 5, 20) 178 | 179 | loss(pred_cls, pred_trajectory, gt) 180 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nuscenes-devkit 2 | efficientnet_pytorch 3 | numpyencoder 4 | tensorboard 5 | # pytorch-lightning -------------------------------------------------------------------------------- /tools/extract_comma2k19.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from tqdm import tqdm 4 | import os 5 | import cv2 6 | import glob 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | import numpy as np 10 | import random 11 | random.seed(0) 12 | 13 | 14 | def main(): 15 | sequences = glob.glob('data/comma2k19/*/*/*/video.hevc') 16 | random.shuffle(sequences) 17 | 18 | num_seqs = len(sequences) 19 | print(num_seqs, 'sequences') 20 | 21 | num_train = int(0.8 * num_seqs) 22 | 23 | with open('data/comma2k19_train.txt', 'w') as f: 24 | f.writelines(seq.replace('data/comma2k19/', '').replace('/video.hevc', '\n') for seq in sequences[:num_train]) 25 | with open('data/comma2k19_val.txt', 'w') as f: 26 | f.writelines(seq.replace('data/comma2k19/', '').replace('/video.hevc', '\n') 
for seq in sequences[num_train:]) 27 | example_segment = 'data/comma2k19/Chunk_1/b0c9d2329ad1606b|2018-07-27--06-03-57/3/' 28 | frame_times = np.load(example_segment + 'global_pose/frame_times') 29 | print(frame_times.shape) 30 | 31 | # === Generating non-overlapping seqs === 32 | sequences = glob.glob('data/comma2k19/*/*/*/video.hevc') 33 | sequences = [seq.replace('data/comma2k19/', '').replace('/video.hevc', '') for seq in sequences] 34 | seq_names = list(set([seq.split('/')[1] for seq in sequences])) # NOTE: set order varies between runs (string hash randomization), so this split is not reproducible unless PYTHONHASHSEED is fixed 35 | num_seqs = len(seq_names) 36 | num_train = int(0.8 * num_seqs) 37 | train_seq_names = seq_names[:num_train] 38 | with open('data/comma2k19_train_non_overlap.txt', 'w') as f: 39 | f.writelines(seq + '\n' for seq in sequences if seq.split('/')[1] in train_seq_names) 40 | with open('data/comma2k19_val_non_overlap.txt', 'w') as f: 41 | f.writelines(seq + '\n' for seq in sequences if seq.split('/')[1] not in train_seq_names) 42 | 43 | 44 | if __name__ == '__main__': 45 | main() 46 | -------------------------------------------------------------------------------- /tools/extract_nuscenes.py: -------------------------------------------------------------------------------- 1 | import math 2 | import json 3 | import random 4 | random.seed(0) 5 | import numpy as np 6 | from tqdm import tqdm 7 | from numpyencoder import NumpyEncoder 8 | 9 | from nuscenes.nuscenes import NuScenes 10 | from nuscenes.can_bus.can_bus_api import NuScenesCanBus 11 | from scipy.spatial.transform import Rotation 12 | 13 | 14 | # Hyper-Params 15 | DATA_ROOT = 'data/nuscenes' 16 | SPLIT = 'v1.0-trainval' 17 | NUM_RGB_IMGS = 2 18 | NUM_FUTURE_TRAJECTORY_PTS = 10 19 | OUTPUT_JSON_NAME = 'data/p3_10pts_can_bus_%s_temporal.json' 20 | GET_CAN_BUS = True 21 | 22 | TEMPORAL = True 23 | 24 | 25 | sensors_tree = { 26 | 'ms_imu': 27 | [ 28 | 'linear_accel', 29 | 'q', 30 | 'rotation_rate', 31 | ], 32 | 33 | 'pose': 34 | [ 35 | 'accel', 36 | 'orientation', 37 | 'pos', 38 | 'rotation_rate', 39 | 'vel', 40 | ],
41 | 42 | 'steeranglefeedback': 43 | [ 44 | 'value', 45 | ], 46 | 47 | 'vehicle_monitor': 48 | [ 49 | 'available_distance', 50 | 'battery_level', 51 | 'brake', 52 | 'brake_switch', 53 | 'gear_position', 54 | 'left_signal', 55 | 'rear_left_rpm', 56 | 'rear_right_rpm', 57 | 'right_signal', 58 | 'steering', 59 | 'steering_speed', 60 | 'throttle', 61 | 'vehicle_speed', 62 | 'yaw_rate', 63 | ], 64 | 65 | 'zoe_veh_info': 66 | [ 67 | 'FL_wheel_speed', 68 | 'FR_wheel_speed', 69 | 'RL_wheel_speed', 70 | 'RR_wheel_speed', 71 | 'left_solar', 72 | 'longitudinal_accel', 73 | 'meanEffTorque', 74 | 'odom', 75 | 'odom_speed', 76 | 'pedal_cc', 77 | 'regen', 78 | 'requestedTorqueAfterProc', 79 | 'right_solar', 80 | 'steer_corrected', 81 | 'steer_offset_can', 82 | 'steer_raw', 83 | 'transversal_accel', 84 | ], 85 | 86 | 'zoesensors': 87 | [ 88 | 'brake_sensor', 89 | 'steering_sensor', 90 | 'throttle_sensor', 91 | ], 92 | } 93 | 94 | 95 | def find_nearest_index(array, value): 96 | idx = np.searchsorted(array, value, side="left") 97 | if idx > 0 and (idx == len(array) or math.fabs(value - array[idx-1]) < math.fabs(value - array[idx])): 98 | return idx-1 99 | else: 100 | return idx 101 | 102 | 103 | def get_samples(nusc, scenes, nusc_can=None): 104 | samples = [] 105 | # list of dicts, where 106 | # 'img': LIST of filenames 0, 1, ..., NUM_RGB_IMGS - 1. 
107 | # NUM_RGB_IMGS - 1 is the frame of 'current' timestamp 108 | # 'pt_%d': LIST of future points offset by current img 0, 1, ..., NUM_FUTURE_TRAJECTORY_PTS - 1 109 | # 0 is the point of the 'very next' timestamp 110 | for scene in tqdm(scenes, ncols=0): 111 | assert len(scene) >= NUM_RGB_IMGS + NUM_FUTURE_TRAJECTORY_PTS 112 | valid_start_tokens = scene[NUM_RGB_IMGS-1 : -NUM_FUTURE_TRAJECTORY_PTS] 113 | if TEMPORAL: 114 | cur_scene_samples = [] 115 | # CAN BUS 116 | if nusc_can is not None: 117 | can_bus_cache = dict() 118 | scene_token = nusc.get('sample', valid_start_tokens[0])['scene_token'] 119 | scene_name = nusc.get('scene', scene_token)['name'] 120 | 121 | has_can_bus_data = True 122 | for message_name, keys in sensors_tree.items(): 123 | try: 124 | can_data = nusc_can.get_messages(scene_name, message_name) 125 | except Exception: 126 | has_can_bus_data = False 127 | continue 128 | can_bus_cache['%s.utime' % (message_name)] = np.array([m['utime'] for m in can_data]) 129 | if len(can_bus_cache['%s.utime' % message_name]) == 0: 130 | has_can_bus_data = False 131 | continue 132 | for key_name in keys: 133 | can_bus_cache['%s.%s' % (message_name, key_name)] = np.array([m[key_name] for m in can_data]) 134 | 135 | if not has_can_bus_data: 136 | print('Error: %s does not have any CAN bus data!' 
% scene_name) 137 | continue 138 | 139 | for idx, cur_token in enumerate(valid_start_tokens): 140 | img_tokens = scene[idx:idx+NUM_RGB_IMGS] 141 | point_tokens = scene[idx+NUM_RGB_IMGS:idx+NUM_RGB_IMGS+NUM_FUTURE_TRAJECTORY_PTS] 142 | 143 | cam_front_data = nusc.get('sample_data', nusc.get('sample', cur_token)['data']['CAM_FRONT']) 144 | # Images 145 | imgs = list(nusc.get('sample_data', nusc.get('sample', token)['data']['CAM_FRONT'])['filename'] for token in img_tokens) 146 | 147 | # Ego poses 148 | cur_ego_pose = nusc.get('ego_pose', cam_front_data['ego_pose_token']) 149 | ego_rotation_matrix = Rotation.from_quat(np.array(cur_ego_pose['rotation'])[[1,2,3,0]]).as_matrix() 150 | ego_tranlation = np.array(cur_ego_pose['translation']) 151 | ego_rotation_matrix_inv = np.linalg.inv(ego_rotation_matrix) 152 | ego_tranlation_inv = -ego_tranlation 153 | 154 | future_poses = list(nusc.get('ego_pose', nusc.get('sample_data', nusc.get('sample', token)['data']['CAM_FRONT'])['ego_pose_token'])['translation'] for token in point_tokens) 155 | future_poses = list(ego_rotation_matrix_inv @ (np.array(future_pose)+ego_tranlation_inv) for future_pose in future_poses) 156 | future_poses = list(list(p) for p in future_poses) # for json 157 | 158 | # Camera Matrices 159 | calibration_para = nusc.get('calibrated_sensor', cam_front_data['calibrated_sensor_token']) 160 | camera_intrinsic = np.array(calibration_para['camera_intrinsic']) 161 | camera_rotation_matrix = Rotation.from_quat(np.array(calibration_para['rotation'])[[1,2,3,0]]).as_matrix() 162 | camera_translation = np.array(calibration_para['translation']) 163 | camera_rotation_matrix_inv = np.linalg.inv(camera_rotation_matrix) 164 | camera_translation_inv = -camera_translation 165 | camera_extrinsic = np.vstack((np.hstack((camera_rotation_matrix_inv, camera_translation_inv.reshape((3, 1)))), np.array([0, 0, 0, 1]))) 166 | 167 | cur_sample_to_append = dict( 168 | imgs=imgs, 169 | future_poses=future_poses, 170 | 
camera_intrinsic=camera_intrinsic.tolist(), 171 | camera_extrinsic=camera_extrinsic.tolist(), 172 | camera_translation_inv=camera_translation_inv.tolist(), 173 | camera_rotation_matrix_inv=camera_rotation_matrix_inv.tolist(), 174 | ) 175 | 176 | # CAN BUS 177 | if nusc_can is not None: 178 | img_timestamp = nusc.get('sample_data', nusc.get('sample', img_tokens[-1])['data']['CAM_FRONT'])['timestamp'] 179 | cur_sample_to_append['img_utime'] = img_timestamp 180 | for message_name, keys in sensors_tree.items(): 181 | message_utimes = can_bus_cache['%s.utime' % message_name] 182 | nearest_index = find_nearest_index(message_utimes, img_timestamp) 183 | can_bus_time_delta = abs(message_utimes[nearest_index] - img_timestamp) # ideally should be less than half the sample rate (2Hz * 2 = 4Hz) 184 | if can_bus_time_delta >= 0.25 * 1e6: 185 | print('Warning', scene_name, message_utimes[nearest_index], img_timestamp, can_bus_time_delta) 186 | cur_sample_to_append['can_bus.%s.utime' % message_name] = message_utimes[nearest_index] 187 | cur_sample_to_append['can_bus.%s.can_bus_delta' % message_name] = can_bus_time_delta 188 | for key_name in keys: 189 | can_bus_value = can_bus_cache['%s.%s' % (message_name, key_name)][nearest_index] 190 | if isinstance(can_bus_value, np.ndarray): 191 | can_bus_value = can_bus_value.tolist() 192 | cur_sample_to_append['can_bus.%s.%s' % (message_name, key_name)] = can_bus_value 193 | if TEMPORAL: 194 | cur_scene_samples.append(cur_sample_to_append) 195 | else: 196 | samples.append(cur_sample_to_append) 197 | 198 | if TEMPORAL: 199 | samples.append(cur_scene_samples) 200 | 201 | return samples 202 | 203 | 204 | # Load NuScenes dataset 205 | nusc = NuScenes(version=SPLIT, dataroot=DATA_ROOT, verbose=True) 206 | nusc_can = NuScenesCanBus(dataroot=DATA_ROOT) if GET_CAN_BUS else None 207 | 208 | # get all scenes into time structure 209 | all_scenes = [] 210 | for scene in nusc.scene: 211 | cur_token = scene['first_sample_token'] 212 | cur_scene_tokens = 
[] # saves tokens of samples in this scene 213 | while cur_token != '': 214 | cur_scene_tokens.append(cur_token) 215 | cur_sample = nusc.get('sample', cur_token) 216 | cur_token = cur_sample['next'] 217 | 218 | all_scenes.append(cur_scene_tokens) 219 | 220 | random.shuffle(all_scenes) 221 | 222 | length_all_scenes = len(all_scenes) 223 | print('Altogether', length_all_scenes, 'scenes') 224 | 225 | train_samples = get_samples(nusc, all_scenes[:int(length_all_scenes * 0.8)], nusc_can) 226 | val_samples = get_samples(nusc, all_scenes[int(length_all_scenes * 0.8):], nusc_can) 227 | 228 | json.dump(train_samples, open(OUTPUT_JSON_NAME % 'train', 'w'), indent='\t', cls=NumpyEncoder) 229 | json.dump(val_samples, open(OUTPUT_JSON_NAME % 'val', 'w'), indent='\t', cls=NumpyEncoder) 230 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | from matplotlib.axes import Axes 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from utils_comma2k19.camera import img_from_device, denormalize, view_frame_from_device_frame 11 | from cycler import cycler 12 | matplotlib.rcParams['axes.prop_cycle'] = cycler('color', 13 | ['#1f77b4', '#ff7f0e', '#2ca02c', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']) 14 | 15 | 16 | def draw_trajectory_on_ax(ax: Axes, trajectories, confs, line_type='o-', transparent=True, xlim=(-30, 30), ylim=(0, 100)): 17 | ''' 18 | ax: matplotlib.axes.Axes, the axis to draw trajectories on 19 | trajectories: List of numpy arrays of shape (num_points, 2 or 3) 20 | confs: List of numbers, 1 means gt 21 | ''' 22 | 23 | # get the max conf 24 | max_conf = max([conf for conf in confs if conf != 1]) 25 | 26 | for idx, (trajectory, conf) in enumerate(zip(trajectories, confs)): 27 | label = 'gt' if conf == 
1 else 'pred%d (%.3f)' % (idx, conf) 28 | alpha = 1.0 29 | if transparent: 30 | alpha = 1.0 if conf == max_conf else np.clip(conf, 0.1, None) 31 | plot_args = dict(label=label, alpha=alpha, linewidth=2 if alpha == 1.0 else 1) 32 | if label == 'gt': 33 | plot_args['color'] = '#d62728' 34 | ax.plot(trajectory[:, 1], # - for nuscenes and + for comma 2k19 35 | trajectory[:, 0], 36 | line_type, **plot_args) 37 | if xlim is not None: 38 | ax.set_xlim(*xlim) 39 | if ylim is not None: 40 | ax.set_ylim(*ylim) 41 | ax.legend() 42 | 43 | return ax 44 | 45 | 46 | def get_val_metric(pred_cls, pred_trajectory, labels, namespace='val'): 47 | rtn_dict = dict() 48 | bs, M, num_pts, _ = pred_trajectory.shape 49 | 50 | # Legacy metric: Prediction L2 loss 51 | pred_label = torch.argmax(pred_cls, -1) # B, 52 | pred_trajectory_single = pred_trajectory[torch.tensor(range(bs), device=pred_cls.device), pred_label, ...] 53 | l2_dists = F.mse_loss(pred_trajectory_single, labels, reduction='none') # B, num_pts, 2 or 3 54 | 55 | # Legacy metric: cls Acc 56 | gt_trajectory_M = labels[:, None, ...].expand(-1, M, -1, -1) 57 | l2_distances = F.mse_loss(pred_trajectory, gt_trajectory_M, reduction='none').sum(dim=(2, 3)) # B, M 58 | best_match = torch.argmin(l2_distances, -1) # B, 59 | rtn_dict.update({'l2_dist': l2_dists.mean(dim=(1, 2)), 'cls_acc': best_match == pred_label}) 60 | 61 | # New Metric 62 | distance_splits = ((0, 10), (10, 20), (20, 30), (30, 50), (50, 1000)) 63 | AP_thresholds = (0.5, 1, 2) 64 | euclidean_distances = l2_dists.sum(-1).sqrt() # euclidean distances over the points: [B, num_pts] 65 | x_distances = labels[..., 0] # B, num_pts 66 | 67 | for min_dst, max_dst in distance_splits: 68 | points_mask = (x_distances >= min_dst) & (x_distances < max_dst) # B, num_pts, 69 | if points_mask.sum() == 0: 70 | continue # No gt points in this range 71 | rtn_dict.update({'eucliden_%d_%d' % (min_dst, max_dst): euclidean_distances[points_mask]}) # [sum(mask), ] 72 |
rtn_dict.update({'eucliden_x_%d_%d' % (min_dst, max_dst): l2_dists[..., 0][points_mask].sqrt()}) # [sum(mask), ] 73 | rtn_dict.update({'eucliden_y_%d_%d' % (min_dst, max_dst): l2_dists[..., 1][points_mask].sqrt()}) # [sum(mask), ] 74 | 75 | for AP_threshold in AP_thresholds: 76 | hit_mask = (euclidean_distances < AP_threshold) & points_mask 77 | rtn_dict.update({'AP_%d_%d_%s' % (min_dst, max_dst, AP_threshold): hit_mask[points_mask]}) 78 | 79 | # add namespace 80 | if namespace is not None: 81 | for k in list(rtn_dict.keys()): 82 | rtn_dict['%s/%s' % (namespace, k)] = rtn_dict.pop(k) 83 | return rtn_dict 84 | 85 | 86 | def get_val_metric_keys(namespace='val'): 87 | rtn_dict = dict() 88 | rtn_dict.update({'l2_dist': [], 'cls_acc': []}) 89 | 90 | # New Metric 91 | distance_splits = ((0, 10), (10, 20), (20, 30), (30, 50), (50, 1000)) 92 | AP_thresholds = (0.5, 1, 2) 93 | 94 | for min_dst, max_dst in distance_splits: 95 | rtn_dict.update({'eucliden_%d_%d' % (min_dst, max_dst): []}) # [sum(mask), ] 96 | rtn_dict.update({'eucliden_x_%d_%d' % (min_dst, max_dst): []}) # [sum(mask), ] 97 | rtn_dict.update({'eucliden_y_%d_%d' % (min_dst, max_dst): []}) # [sum(mask), ] 98 | for AP_threshold in AP_thresholds: 99 | rtn_dict.update({'AP_%d_%d_%s' % (min_dst, max_dst, AP_threshold): []}) 100 | 101 | # add namespace 102 | if namespace is not None: 103 | for k in list(rtn_dict.keys()): 104 | rtn_dict['%s/%s' % (namespace, k)] = rtn_dict.pop(k) 105 | return rtn_dict 106 | 107 | 108 | def generate_random_params_for_warp(img, random_rate=0.1): 109 | h, w = img.shape[:2] 110 | 111 | width_max = random_rate * w 112 | height_max = random_rate * h 113 | 114 | # 8 offsets 115 | w_offsets = list(np.random.uniform(0, width_max) for _ in range(4)) 116 | h_offsets = list(np.random.uniform(0, height_max) for _ in range(4)) 117 | 118 | return w_offsets, h_offsets 119 | 120 | 121 | def warp(img, w_offsets, h_offsets): 122 | h, w = img.shape[:2] 123 | 124 | original_corner_pts = np.array( 125 | ( 
126 | (w_offsets[0], h_offsets[0]), 127 | (w - w_offsets[1], h_offsets[1]), 128 | (w_offsets[2], h - h_offsets[2]), 129 | (w - w_offsets[3], h - h_offsets[3]), 130 | ), dtype=np.float32 131 | ) 132 | 133 | target_corner_pts = np.array( 134 | ( 135 | (0, 0), # Top-left 136 | (w, 0), # Top-right 137 | (0, h), # Bottom-left 138 | (w, h), # Bottom-right 139 | ), dtype=np.float32 140 | ) 141 | 142 | transform_matrix = cv2.getPerspectiveTransform(original_corner_pts, target_corner_pts) 143 | 144 | transformed_image = cv2.warpPerspective(img, transform_matrix, (w, h)) 145 | 146 | return transformed_image 147 | 148 | 149 | def draw_path(device_path, img, width=1, height=1.2, fill_color=(128,0,255), line_color=(0,255,0)): 150 | # device_path: N, 3 151 | device_path_l = device_path + np.array([0, 0, height]) 152 | device_path_r = device_path + np.array([0, 0, height]) 153 | device_path_l[:,1] -= width 154 | device_path_r[:,1] += width 155 | 156 | img_points_norm_l = img_from_device(device_path_l) 157 | img_points_norm_r = img_from_device(device_path_r) 158 | 159 | img_pts_l = denormalize(img_points_norm_l) 160 | img_pts_r = denormalize(img_points_norm_r) 161 | # filter out things rejected along the way 162 | valid = np.logical_and(np.isfinite(img_pts_l).all(axis=1), np.isfinite(img_pts_r).all(axis=1)) 163 | img_pts_l = img_pts_l[valid].astype(int) 164 | img_pts_r = img_pts_r[valid].astype(int) 165 | 166 | for i in range(1, len(img_pts_l)): 167 | u1,v1,u2,v2 = np.append(img_pts_l[i-1], img_pts_r[i-1]) 168 | u3,v3,u4,v4 = np.append(img_pts_l[i], img_pts_r[i]) 169 | pts = np.array([[u1,v1],[u2,v2],[u4,v4],[u3,v3]], np.int32).reshape((-1,1,2)) 170 | if fill_color: 171 | cv2.fillPoly(img,[pts],fill_color) 172 | if line_color: 173 | cv2.polylines(img,[pts],True,line_color) 174 | -------------------------------------------------------------------------------- /utils_comma2k19/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 
| 3 | Copyright (c) 2018 comma.ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /utils_comma2k19/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/Openpilot-Deepdive/fae05055c071fe8b6ed0dd578bb047f29f2b4dd4/utils_comma2k19/__init__.py -------------------------------------------------------------------------------- /utils_comma2k19/benchmarks.py: -------------------------------------------------------------------------------- 1 | import coordinates as coord 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | 6 | def get_altitude_errors(frame_poss): 7 | ''' 8 | Takes in a list of 2D arrays. Where every element 9 | of the list represents all the frame positions 10 | of a segment in ECEF. 
11 | ''' 12 | 13 | # set size of squares (bins) of road which we 14 | # assume the road to have constant height 15 | # and area within which to check 16 | local = coord.LocalCoord.from_ecef([-2712470.27794758, -4262442.18438959, 3879912.32221487]) 17 | north_bounds = [-16000, 2500] 18 | east_bounds = [-1000, 7000] 19 | binsize = 5 20 | north_bins, east_bins = [], [] 21 | for i in range(north_bounds[0], north_bounds[1], binsize): 22 | north_bins.append([i,i+binsize]) 23 | for i in range(east_bounds[0], east_bounds[1], binsize): 24 | east_bins.append([i,i+binsize]) 25 | 26 | # convert positions to NED 27 | frame_poss_ned = [] 28 | for pos in frame_poss: 29 | if pos is None: 30 | continue 31 | frame_poss_ned.append(local.ecef2ned(pos)) 32 | 33 | # find bin idxs for all frame positions 34 | bins = [[[] for j in range(len(east_bins))] for i in range(len(north_bins))] 35 | for pos_ned in frame_poss_ned: 36 | north_idxs = np.clip(((pos_ned[:,0] - north_bounds[0])/binsize).astype(int), 37 | 0, 38 | len(bins)-1) 39 | east_idxs = np.clip(((pos_ned[:,1] - east_bounds[0])/binsize).astype(int), 40 | 0, 41 | len(bins[0])-1) 42 | idxs = np.column_stack((north_idxs, east_idxs)) 43 | _, uniq = np.unique(idxs, return_index=True, axis=0) 44 | for p, idx in zip(pos_ned[uniq], idxs[uniq]): 45 | bins[idx[0]][idx[1]].append(p) 46 | 47 | 48 | # Now find the errors by looking at the deviation 49 | # from the mean in each bin 50 | alt_diffs = [] 51 | k = 0 52 | for pos_ned in tqdm(frame_poss_ned): 53 | k +=1 54 | north_idxs = np.clip(((pos_ned[:,0] - north_bounds[0])/binsize).astype(int), 55 | 0, 56 | len(bins)-1) 57 | east_idxs = np.clip(((pos_ned[:,1] - east_bounds[0])/binsize).astype(int), 58 | 0, 59 | len(bins[0])-1) 60 | idxs = np.column_stack((north_idxs, east_idxs)) 61 | alt_diffs.append([]) 62 | for p, idx in zip(pos_ned, idxs): 63 | # we want more than 5 observations per bin 64 | if len(bins[idx[0]][idx[1]]) > 5: 65 | alt_diffs[-1].append(p[2] - 
np.mean(np.array(bins[idx[0]][idx[1]])[:,2])) 66 | return alt_diffs 67 | -------------------------------------------------------------------------------- /utils_comma2k19/camera.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import utils_comma2k19.orientation as orient 3 | 4 | FULL_FRAME_SIZE = (1164, 874) 5 | W, H = FULL_FRAME_SIZE[0], FULL_FRAME_SIZE[1] 6 | eon_focal_length = FOCAL = 910.0 7 | 8 | # aka 'K' aka camera_frame_from_view_frame 9 | eon_intrinsics = np.array([ 10 | [FOCAL, 0., W/2.], 11 | [ 0., FOCAL, H/2.], 12 | [ 0., 0., 1.]]) 13 | 14 | # aka 'K_inv' aka view_frame_from_camera_frame 15 | eon_intrinsics_inv = np.linalg.inv(eon_intrinsics) 16 | 17 | # device/mesh : x->forward, y-> right, z->down 18 | # view : x->right, y->down, z->forward 19 | device_frame_from_view_frame = np.array([ 20 | [ 0., 0., 1.], 21 | [ 1., 0., 0.], 22 | [ 0., 1., 0.] 23 | ]) 24 | view_frame_from_device_frame = device_frame_from_view_frame.T 25 | 26 | 27 | def get_calib_from_vp(vp): 28 | vp_norm = normalize(vp) 29 | yaw_calib = np.arctan(vp_norm[0]) 30 | pitch_calib = np.arctan(vp_norm[1]*np.cos(yaw_calib)) 31 | # TODO: should be the commented-out line below, but written 32 | # to be compatible with meshcalib and 33 | # get_view_frame_from_road_frame 34 | #pitch_calib = -np.arctan(vp_norm[1]*np.cos(yaw_calib)) 35 | roll_calib = 0 36 | return roll_calib, pitch_calib, yaw_calib 37 | 38 | # aka 'extrinsic_matrix' 39 | # road : x->forward, y -> left, z->up 40 | def get_view_frame_from_road_frame(roll, pitch, yaw, height): 41 | # TODO 42 | # calibration pitch is currently defined 43 | # opposite to pitch in device frame 44 | pitch = -pitch 45 | device_from_road = orient.rot_from_euler([roll, pitch, yaw]).dot(np.diag([1, -1, -1])) 46 | view_from_road = view_frame_from_device_frame.dot(device_from_road) 47 | return np.hstack((view_from_road, [[0], [height], [0]])) 48 | 49 | 50 | def vp_from_ke(m): 51 | """ 52 | Computes the vanishing point from 
the product of the intrinsic and extrinsic 53 | matrices C = KE. 54 | 55 | The vanishing point is defined as lim x->infinity C (x, 0, 0, 1).T 56 | """ 57 | return (m[0, 0]/m[2,0], m[1,0]/m[2,0]) 58 | 59 | def roll_from_ke(m): 60 | # note: different from calibration.h/RollAnglefromKE: i think that one's just wrong 61 | return np.arctan2(-(m[1, 0] - m[1, 1] * m[2, 0] / m[2, 1]), 62 | -(m[0, 0] - m[0, 1] * m[2, 0] / m[2, 1])) 63 | 64 | def normalize(img_pts): 65 | # normalizes image coordinates 66 | # accepts single pt or array of pts 67 | img_pts = np.array(img_pts) 68 | input_shape = img_pts.shape 69 | img_pts = np.atleast_2d(img_pts) 70 | img_pts = np.hstack((img_pts, np.ones((img_pts.shape[0],1)))) 71 | img_pts_normalized = eon_intrinsics_inv.dot(img_pts.T).T 72 | img_pts_normalized[(img_pts < 0).any(axis=1)] = np.nan 73 | return img_pts_normalized[:,:2].reshape(input_shape) 74 | 75 | def denormalize(img_pts): 76 | # denormalizes image coordinates 77 | # accepts single pt or array of pts 78 | img_pts = np.array(img_pts) 79 | input_shape = img_pts.shape 80 | img_pts = np.atleast_2d(img_pts) 81 | img_pts = np.hstack((img_pts, np.ones((img_pts.shape[0],1)))) 82 | img_pts_denormalized = eon_intrinsics.dot(img_pts.T).T 83 | img_pts_denormalized[img_pts_denormalized[:,0] > W] = np.nan 84 | img_pts_denormalized[img_pts_denormalized[:,0] < 0] = np.nan 85 | img_pts_denormalized[img_pts_denormalized[:,1] > H] = np.nan 86 | img_pts_denormalized[img_pts_denormalized[:,1] < 0] = np.nan 87 | return img_pts_denormalized[:,:2].reshape(input_shape) 88 | 89 | def device_from_ecef(pos_ecef, orientation_ecef, pt_ecef): 90 | # device from ecef frame 91 | # device frame is x -> forward, y-> right, z -> down 92 | # accepts single pt or array of pts 93 | input_shape = pt_ecef.shape 94 | pt_ecef = np.atleast_2d(pt_ecef) 95 | ecef_from_device_rot = orient.rotations_from_quats(orientation_ecef) 96 | device_from_ecef_rot = ecef_from_device_rot.T 97 | pt_ecef_rel = pt_ecef - pos_ecef 98 | 
pt_device = np.einsum('jk,ik->ij', device_from_ecef_rot, pt_ecef_rel) 99 | return pt_device.reshape(input_shape) 100 | 101 | def img_from_device(pt_device): 102 | # img coordinates from pts in device frame 103 | # first transforms to view frame, then to img coords 104 | # accepts single pt or array of pts 105 | input_shape = pt_device.shape 106 | pt_device = np.atleast_2d(pt_device) 107 | pt_view = np.einsum('jk,ik->ij', view_frame_from_device_frame, pt_device) 108 | 109 | # This function should never return negative depths 110 | pt_view[pt_view[:,2] < 0] = np.nan 111 | 112 | pt_img = pt_view/pt_view[:,2:3] 113 | return pt_img.reshape(input_shape)[:,:2] 114 | 115 | -------------------------------------------------------------------------------- /utils_comma2k19/coordinates.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Coordinate transformation module. All methods accept arrays as input 3 | with each row as a position. 4 | ''' 5 | 6 | import numpy as np 7 | 8 | a = 6378137 9 | b = 6356752.3142 10 | esq = 6.69437999014 * 0.001 11 | e1sq = 6.73949674228 * 0.001 12 | 13 | 14 | def geodetic2ecef(geodetic, radians=False): 15 | geodetic = np.array(geodetic) 16 | input_shape = geodetic.shape 17 | geodetic = np.atleast_2d(geodetic) 18 | 19 | ratio = 1.0 if radians else (np.pi / 180.0) 20 | lat = ratio*geodetic[:,0] 21 | lon = ratio*geodetic[:,1] 22 | alt = geodetic[:,2] 23 | 24 | xi = np.sqrt(1 - esq * np.sin(lat)**2) 25 | x = (a / xi + alt) * np.cos(lat) * np.cos(lon) 26 | y = (a / xi + alt) * np.cos(lat) * np.sin(lon) 27 | z = (a / xi * (1 - esq) + alt) * np.sin(lat) 28 | ecef = np.array([x, y, z]).T 29 | return ecef.reshape(input_shape) 30 | 31 | 32 | def ecef2geodetic(ecef, radians=False): 33 | """ 34 | Convert ECEF coordinates to geodetic using ferrari's method 35 | """ 36 | # Save shape and export column 37 | ecef = np.atleast_1d(ecef) 38 | input_shape = ecef.shape 39 | ecef = np.atleast_2d(ecef) 40 | x, y, z = 
ecef[:, 0], ecef[:, 1], ecef[:, 2] 41 | 42 | ratio = 1.0 if radians else (180.0 / np.pi) 43 | 44 | # Convert from ECEF to geodetic using Ferrari's method 45 | # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#Ferrari.27s_solution 46 | r = np.sqrt(x * x + y * y) 47 | Esq = a * a - b * b 48 | F = 54 * b * b * z * z 49 | G = r * r + (1 - esq) * z * z - esq * Esq 50 | C = (esq * esq * F * r * r) / (pow(G, 3)) 51 | S = np.cbrt(1 + C + np.sqrt(C * C + 2 * C)) 52 | P = F / (3 * pow((S + 1 / S + 1), 2) * G * G) 53 | Q = np.sqrt(1 + 2 * esq * esq * P) 54 | r_0 = -(P * esq * r) / (1 + Q) + np.sqrt(0.5 * a * a*(1 + 1.0 / Q) - \ 55 | P * (1 - esq) * z * z / (Q * (1 + Q)) - 0.5 * P * r * r) 56 | U = np.sqrt(pow((r - esq * r_0), 2) + z * z) 57 | V = np.sqrt(pow((r - esq * r_0), 2) + (1 - esq) * z * z) 58 | Z_0 = b * b * z / (a * V) 59 | h = U * (1 - b * b / (a * V)) 60 | lat = ratio*np.arctan((z + e1sq * Z_0) / r) 61 | lon = ratio*np.arctan2(y, x) 62 | 63 | # stack the new columns and return to the original shape 64 | geodetic = np.column_stack((lat, lon, h)) 65 | return geodetic.reshape(input_shape) 66 | 67 | class LocalCoord(object): 68 | """ 69 | Allows conversions to local frames. In this case NED. 70 | That is: North East Down from the start position in 71 | meters. 
72 | """ 73 | def __init__(self, init_geodetic, init_ecef): 74 | self.init_ecef = init_ecef 75 | lat, lon, _ = (np.pi/180)*np.array(init_geodetic) 76 | self.ned2ecef_matrix = np.array([[-np.sin(lat)*np.cos(lon), -np.sin(lon), -np.cos(lat)*np.cos(lon)], 77 | [-np.sin(lat)*np.sin(lon), np.cos(lon), -np.cos(lat)*np.sin(lon)], 78 | [np.cos(lat), 0, -np.sin(lat)]]) 79 | self.ecef2ned_matrix = self.ned2ecef_matrix.T 80 | 81 | @classmethod 82 | def from_geodetic(cls, init_geodetic): 83 | init_ecef = geodetic2ecef(init_geodetic) 84 | return LocalCoord(init_geodetic, init_ecef) 85 | 86 | @classmethod 87 | def from_ecef(cls, init_ecef): 88 | init_geodetic = ecef2geodetic(init_ecef) 89 | return LocalCoord(init_geodetic, init_ecef) 90 | 91 | 92 | def ecef2ned(self, ecef): 93 | ecef = np.array(ecef) 94 | return np.dot(self.ecef2ned_matrix, (ecef - self.init_ecef).T).T 95 | 96 | def ned2ecef(self, ned): 97 | ned = np.array(ned) 98 | # Transpose so that init_ecef will broadcast correctly for 1d or 2d ned. 
99 | return (np.dot(self.ned2ecef_matrix, ned.T).T + self.init_ecef) 100 | 101 | def geodetic2ned(self, geodetic): 102 | ecef = geodetic2ecef(geodetic) 103 | return self.ecef2ned(ecef) 104 | 105 | def ned2geodetic(self, ned): 106 | ecef = self.ned2ecef(ned) 107 | return ecef2geodetic(ecef) 108 | -------------------------------------------------------------------------------- /utils_comma2k19/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import transforms 6 | from tools.lib.framereader import FrameReader, BaseFrameReader 7 | 8 | 9 | class ToTensor(object): 10 | 11 | def __call__(self, sample): 12 | return { 13 | key: torch.from_numpy(value) for key, value in sample.items() 14 | } 15 | 16 | 17 | class CommaDataset(Dataset): 18 | 19 | def __init__(self, main_dir, transform=None): 20 | self.main_dir = main_dir 21 | self.frame_reader = FrameReader(main_dir + 'video.hevc') 22 | 23 | self.gps_times = np.load(main_dir + 'global_pose/frame_gps_times') 24 | self.orientations = np.load(main_dir + 'global_pose/frame_orientations') 25 | self.positions = np.load(main_dir + 'global_pose/frame_positions') 26 | self.times = np.load(main_dir + 'global_pose/frame_times') 27 | self.velocities = np.load(main_dir + 'global_pose/frame_velocities') 28 | 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return len(self.velocities) 33 | 34 | def __getitem__(self, idx): 35 | image = np.array(self.frame_reader.get(idx, pix_fmt='rgb24')[0], dtype=np.float64) 36 | 37 | sample = { 38 | 'image': image, 39 | 'gps_times': self.gps_times, 40 | 'orientations': self.orientations, 41 | 'positions': self.positions, 42 | 'times': self.times, 43 | 'velocities': self.velocities[idx] 44 | } 45 | 46 | if self.transform: 47 | sample = self.transform(sample) 48 | 49 | return sample 50 | 51 | 52 
| if __name__ == "__main__": 53 | 54 | example_segment = 'Example_1/b0c9d2329ad1606b|2018-08-02--08-34-47/40/' 55 | frame_idx = 200 56 | 57 | comma_dataset = CommaDataset(main_dir=example_segment, transform=transforms.Compose([ 58 | ToTensor() 59 | ])) 60 | 61 | comma_dataloader = DataLoader(comma_dataset, batch_size=4, shuffle=True, num_workers=0) 62 | 63 | # sample = comma_dataset[frame_idx] 64 | sample = next(iter(comma_dataloader)) 65 | image = sample['image'][0].numpy() 66 | velocity = sample['velocities'][0].numpy() 67 | -------------------------------------------------------------------------------- /utils_comma2k19/orientation.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Vectorized functions that transform between 3 | rotation matrices, euler angles and quaternions. 4 | All support lists, array or array of arrays as inputs. 5 | Supports both x2y and y_from_x format (y_from_x preferred!). 6 | ''' 7 | 8 | import numpy as np 9 | from numpy import dot, inner, array, linalg 10 | from utils_comma2k19.coordinates import LocalCoord 11 | 12 | 13 | def euler2quat(eulers): 14 | eulers = array(eulers) 15 | if len(eulers.shape) > 1: 16 | output_shape = (-1,4) 17 | else: 18 | output_shape = (4,) 19 | eulers = np.atleast_2d(eulers) 20 | gamma, theta, psi = eulers[:,0], eulers[:,1], eulers[:,2] 21 | 22 | q0 = np.cos(gamma / 2) * np.cos(theta / 2) * np.cos(psi / 2) + \ 23 | np.sin(gamma / 2) * np.sin(theta / 2) * np.sin(psi / 2) 24 | q1 = np.sin(gamma / 2) * np.cos(theta / 2) * np.cos(psi / 2) - \ 25 | np.cos(gamma / 2) * np.sin(theta / 2) * np.sin(psi / 2) 26 | q2 = np.cos(gamma / 2) * np.sin(theta / 2) * np.cos(psi / 2) + \ 27 | np.sin(gamma / 2) * np.cos(theta / 2) * np.sin(psi / 2) 28 | q3 = np.cos(gamma / 2) * np.cos(theta / 2) * np.sin(psi / 2) - \ 29 | np.sin(gamma / 2) * np.sin(theta / 2) * np.cos(psi / 2) 30 | 31 | quats = array([q0, q1, q2, q3]).T 32 | for i in range(len(quats)): 33 | if quats[i,0] < 0: 34 | 
quats[i] = -quats[i] 35 | return quats.reshape(output_shape) 36 | 37 | 38 | def quat2euler(quats): 39 | quats = array(quats) 40 | if len(quats.shape) > 1: 41 | output_shape = (-1,3) 42 | else: 43 | output_shape = (3,) 44 | quats = np.atleast_2d(quats) 45 | q0, q1, q2, q3 = quats[:,0], quats[:,1], quats[:,2], quats[:,3] 46 | 47 | gamma = np.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2)) 48 | theta = np.arcsin(2 * (q0 * q2 - q3 * q1)) 49 | psi = np.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2**2 + q3**2)) 50 | 51 | eulers = array([gamma, theta, psi]).T 52 | return eulers.reshape(output_shape) 53 | 54 | 55 | def quat2rot(quats): 56 | quats = array(quats) 57 | input_shape = quats.shape 58 | quats = np.atleast_2d(quats) 59 | Rs = np.zeros((quats.shape[0], 3, 3)) 60 | q0 = quats[:, 0] 61 | q1 = quats[:, 1] 62 | q2 = quats[:, 2] 63 | q3 = quats[:, 3] 64 | Rs[:, 0, 0] = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3 65 | Rs[:, 0, 1] = 2 * (q1 * q2 - q0 * q3) 66 | Rs[:, 0, 2] = 2 * (q0 * q2 + q1 * q3) 67 | Rs[:, 1, 0] = 2 * (q1 * q2 + q0 * q3) 68 | Rs[:, 1, 1] = q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3 69 | Rs[:, 1, 2] = 2 * (q2 * q3 - q0 * q1) 70 | Rs[:, 2, 0] = 2 * (q1 * q3 - q0 * q2) 71 | Rs[:, 2, 1] = 2 * (q0 * q1 + q2 * q3) 72 | Rs[:, 2, 2] = q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3 73 | 74 | if len(input_shape) < 2: 75 | return Rs[0] 76 | else: 77 | return Rs 78 | 79 | 80 | def rot2quat(rots): 81 | input_shape = rots.shape 82 | if len(input_shape) < 3: 83 | rots = array([rots]) 84 | K3 = np.empty((len(rots), 4, 4)) 85 | K3[:, 0, 0] = (rots[:, 0, 0] - rots[:, 1, 1] - rots[:, 2, 2]) / 3.0 86 | K3[:, 0, 1] = (rots[:, 1, 0] + rots[:, 0, 1]) / 3.0 87 | K3[:, 0, 2] = (rots[:, 2, 0] + rots[:, 0, 2]) / 3.0 88 | K3[:, 0, 3] = (rots[:, 1, 2] - rots[:, 2, 1]) / 3.0 89 | K3[:, 1, 0] = K3[:, 0, 1] 90 | K3[:, 1, 1] = (rots[:, 1, 1] - rots[:, 0, 0] - rots[:, 2, 2]) / 3.0 91 | K3[:, 1, 2] = (rots[:, 2, 1] + rots[:, 1, 2]) / 3.0 92 | K3[:, 1, 3] = (rots[:, 2, 0] - rots[:, 0, 2]) / 3.0 93 
| K3[:, 2, 0] = K3[:, 0, 2] 94 | K3[:, 2, 1] = K3[:, 1, 2] 95 | K3[:, 2, 2] = (rots[:, 2, 2] - rots[:, 0, 0] - rots[:, 1, 1]) / 3.0 96 | K3[:, 2, 3] = (rots[:, 0, 1] - rots[:, 1, 0]) / 3.0 97 | K3[:, 3, 0] = K3[:, 0, 3] 98 | K3[:, 3, 1] = K3[:, 1, 3] 99 | K3[:, 3, 2] = K3[:, 2, 3] 100 | K3[:, 3, 3] = (rots[:, 0, 0] + rots[:, 1, 1] + rots[:, 2, 2]) / 3.0 101 | q = np.empty((len(rots), 4)) 102 | for i in range(len(rots)): 103 | _, eigvecs = linalg.eigh(K3[i].T) 104 | eigvecs = eigvecs[:,3:] 105 | q[i, 0] = eigvecs[-1] 106 | q[i, 1:] = -eigvecs[:-1].flatten() 107 | if q[i, 0] < 0: 108 | q[i] = -q[i] 109 | 110 | if len(input_shape) < 3: 111 | return q[0] 112 | else: 113 | return q 114 | 115 | 116 | def euler2rot(eulers): 117 | return rotations_from_quats(euler2quat(eulers)) 118 | 119 | 120 | def rot2euler(rots): 121 | return quat2euler(quats_from_rotations(rots)) 122 | 123 | 124 | quats_from_rotations = rot2quat 125 | quat_from_rot = rot2quat 126 | rotations_from_quats = quat2rot 127 | rot_from_quat= quat2rot 128 | rot_from_quat= quat2rot 129 | euler_from_rot = rot2euler 130 | euler_from_quat = quat2euler 131 | rot_from_euler = euler2rot 132 | quat_from_euler = euler2quat 133 | 134 | 135 | 136 | 137 | 138 | 139 | ''' 140 | Random helpers below 141 | ''' 142 | 143 | 144 | def quat_product(q, r): 145 | t = np.zeros(4) 146 | t[0] = r[0] * q[0] - r[1] * q[1] - r[2] * q[2] - r[3] * q[3] 147 | t[1] = r[0] * q[1] + r[1] * q[0] - r[2] * q[3] + r[3] * q[2] 148 | t[2] = r[0] * q[2] + r[1] * q[3] + r[2] * q[0] - r[3] * q[1] 149 | t[3] = r[0] * q[3] - r[1] * q[2] + r[2] * q[1] + r[3] * q[0] 150 | return t 151 | 152 | 153 | def rot_matrix(roll, pitch, yaw): 154 | cr, sr = np.cos(roll), np.sin(roll) 155 | cp, sp = np.cos(pitch), np.sin(pitch) 156 | cy, sy = np.cos(yaw), np.sin(yaw) 157 | rr = array([[1,0,0],[0, cr,-sr],[0, sr, cr]]) 158 | rp = array([[cp,0,sp],[0, 1,0],[-sp, 0, cp]]) 159 | ry = array([[cy,-sy,0],[sy, cy,0],[0, 0, 1]]) 160 | return ry.dot(rp.dot(rr)) 161 | 162 | 163 
| def rot(axis, angle): 164 | # Rotates around an arbitrary axis 165 | ret_1 = (1 - np.cos(angle)) * array([[axis[0]**2, axis[0] * axis[1], axis[0] * axis[2]], [ 166 | axis[1] * axis[0], axis[1]**2, axis[1] * axis[2] 167 | ], [axis[2] * axis[0], axis[2] * axis[1], axis[2]**2]]) 168 | ret_2 = np.cos(angle) * np.eye(3) 169 | ret_3 = np.sin(angle) * array([[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], 170 | [-axis[1], axis[0], 0]]) 171 | return ret_1 + ret_2 + ret_3 172 | 173 | 174 | def ecef_euler_from_ned(ned_ecef_init, ned_pose): 175 | ''' 176 | Got it from here: 177 | Using Rotations to Build Aerospace Coordinate Systems 178 | -Don Koks 179 | ''' 180 | converter = LocalCoord.from_ecef(ned_ecef_init) 181 | x0 = converter.ned2ecef([1, 0, 0]) - converter.ned2ecef([0, 0, 0]) 182 | y0 = converter.ned2ecef([0, 1, 0]) - converter.ned2ecef([0, 0, 0]) 183 | z0 = converter.ned2ecef([0, 0, 1]) - converter.ned2ecef([0, 0, 0]) 184 | 185 | x1 = rot(z0, ned_pose[2]).dot(x0) 186 | y1 = rot(z0, ned_pose[2]).dot(y0) 187 | z1 = rot(z0, ned_pose[2]).dot(z0) 188 | 189 | x2 = rot(y1, ned_pose[1]).dot(x1) 190 | y2 = rot(y1, ned_pose[1]).dot(y1) 191 | z2 = rot(y1, ned_pose[1]).dot(z1) 192 | 193 | x3 = rot(x2, ned_pose[0]).dot(x2) 194 | y3 = rot(x2, ned_pose[0]).dot(y2) 195 | #z3 = rot(x2, ned_pose[0]).dot(z2) 196 | 197 | x0 = array([1, 0, 0]) 198 | y0 = array([0, 1, 0]) 199 | z0 = array([0, 0, 1]) 200 | 201 | psi = np.arctan2(inner(x3, y0), inner(x3, x0)) 202 | theta = np.arctan2(-inner(x3, z0), np.sqrt(inner(x3, x0)**2 + inner(x3, y0)**2)) 203 | y2 = rot(z0, psi).dot(y0) 204 | z2 = rot(y2, theta).dot(z0) 205 | phi = np.arctan2(inner(y3, z2), inner(y3, y2)) 206 | 207 | ret = array([phi, theta, psi]) 208 | return ret 209 | 210 | 211 | def ned_euler_from_ecef(ned_ecef_init, ecef_poses): 212 | ''' 213 | Got the math from here: 214 | Using Rotations to Build Aerospace Coordinate Systems 215 | -Don Koks 216 | 217 | Also accepts array of ecef_poses and array of ned_ecef_inits. 
218 | Where each row is a pose and an ecef_init. 219 | ''' 220 | ned_ecef_init = array(ned_ecef_init) 221 | ecef_poses = array(ecef_poses) 222 | output_shape = ecef_poses.shape 223 | ned_ecef_init = np.atleast_2d(ned_ecef_init) 224 | if ned_ecef_init.shape[0] == 1: 225 | ned_ecef_init = np.tile(ned_ecef_init[0], (output_shape[0], 1)) 226 | ecef_poses = np.atleast_2d(ecef_poses) 227 | 228 | ned_poses = np.zeros(ecef_poses.shape) 229 | for i, ecef_pose in enumerate(ecef_poses): 230 | converter = LocalCoord.from_ecef(ned_ecef_init[i]) 231 | x0 = array([1, 0, 0]) 232 | y0 = array([0, 1, 0]) 233 | z0 = array([0, 0, 1]) 234 | 235 | x1 = rot(z0, ecef_pose[2]).dot(x0) 236 | y1 = rot(z0, ecef_pose[2]).dot(y0) 237 | z1 = rot(z0, ecef_pose[2]).dot(z0) 238 | 239 | x2 = rot(y1, ecef_pose[1]).dot(x1) 240 | y2 = rot(y1, ecef_pose[1]).dot(y1) 241 | z2 = rot(y1, ecef_pose[1]).dot(z1) 242 | 243 | x3 = rot(x2, ecef_pose[0]).dot(x2) 244 | y3 = rot(x2, ecef_pose[0]).dot(y2) 245 | #z3 = rot(x2, ecef_pose[0]).dot(z2) 246 | 247 | x0 = converter.ned2ecef([1, 0, 0]) - converter.ned2ecef([0, 0, 0]) 248 | y0 = converter.ned2ecef([0, 1, 0]) - converter.ned2ecef([0, 0, 0]) 249 | z0 = converter.ned2ecef([0, 0, 1]) - converter.ned2ecef([0, 0, 0]) 250 | 251 | psi = np.arctan2(inner(x3, y0), inner(x3, x0)) 252 | theta = np.arctan2(-inner(x3, z0), np.sqrt(inner(x3, x0)**2 + inner(x3, y0)**2)) 253 | y2 = rot(z0, psi).dot(y0) 254 | z2 = rot(y2, theta).dot(z0) 255 | phi = np.arctan2(inner(y3, z2), inner(y3, y2)) 256 | ned_poses[i] = array([phi, theta, psi]) 257 | 258 | return ned_poses.reshape(output_shape) 259 | 260 | 261 | def ecef2car(car_ecef, psi, theta, points_ecef, ned_converter): 262 | """ 263 | TODO: add roll rotation 264 | Converts an array of points in ecef coordinates into 265 | x-forward, y-left, z-up coordinates 266 | Parameters 267 | ---------- 268 | psi: yaw, radian 269 | theta: pitch, radian 270 | Returns 271 | ------- 272 | [x, y, z] coordinates in car frame 273 | """ 274 | 275 | # 
input is an array of points in ecef coordinates 276 | # output is an array of points in car's coordinate (x-front, y-left, z-up) 277 | 278 | # convert points to NED 279 | points_ned = [] 280 | for p in points_ecef: 281 | points_ned.append(ned_converter.ecef2ned_matrix.dot(array(p) - car_ecef)) 282 | 283 | points_ned = np.vstack(points_ned).T 284 | 285 | # n, e, d -> x, y, z 286 | # Calculate relative positions and rotate with respect to heading and pitch of car 287 | invert_R = array([[1., 0., 0.], [0., -1., 0.], [0., 0., -1.]]) 288 | 289 | c, s = np.cos(psi), np.sin(psi) 290 | yaw_R = array([[c, s, 0.], [-s, c, 0.], [0., 0., 1.]]) 291 | 292 | c, s = np.cos(theta), np.sin(theta) 293 | pitch_R = array([[c, 0., -s], [0., 1., 0.], [s, 0., c]]) 294 | 295 | return dot(pitch_R, dot(yaw_R, dot(invert_R, points_ned))) 296 | -------------------------------------------------------------------------------- /utils_comma2k19/unzip_msft_fs.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This script is a workaround for Microsoft-based filesystems (exFat, NTFS etc). 3 | These filesystems don't allow the vertical pipe ('|') in paths. 4 | So, instead of unzipping files manually, run this script which will replace 5 | all pipes in path names in the zip files by underscores. 
6 | 7 | Usage: 8 | python3 unzip_msft_fs.py DATASET_DIR GOAL_DIR 9 | ''' 10 | 11 | import multiprocessing 12 | import os 13 | import shutil 14 | import sys 15 | import zipfile 16 | 17 | NUMBER_OF_CHUNKS = 10 18 | 19 | 20 | def unzip_replace(zip_dir, zip_name, extract_dir, 21 | filter_predicate, replace_me, replace_by): 22 | zip_path = os.path.join(zip_dir, zip_name) 23 | z = zipfile.ZipFile(zip_path) 24 | for f in z.infolist(): 25 | if filter_predicate(f): 26 | old = f.filename 27 | f.filename = f.filename.replace(replace_me, replace_by) 28 | z.extract(f, extract_dir) 29 | 30 | 31 | def fix_pipe(base): 32 | """ 33 | Given unzipped directory "base", 34 | creates new directories with | replaced with _, 35 | moves all contents into their respective new directories 36 | and then deletes the old (now empty) directories that contain |. 37 | """ 38 | for d in filter(lambda s: '|' in s, os.listdir(base)): 39 | old = os.path.join(base, d) 40 | new = os.path.join(base, d.replace('|', '_')) 41 | try: 42 | os.makedirs(new, exist_ok=False) 43 | contents = map(lambda s: os.path.join(old, s), os.listdir(old)) 44 | for f in contents: 45 | shutil.move(f, os.path.join(new)) 46 | os.rmdir(old) 47 | except Exception as e: 48 | print('New directory already exists \ 49 | -- did you execute this script already?') 50 | raise e 51 | 52 | 53 | def map_fn(args): 54 | unzip_replace(args['dir'], args['.zip'], args['extract'], 55 | lambda f: '|' in f.filename, '|', '_') 56 | print('Finished unzipping {}'.format( 57 | os.path.join(args['dir'], args['.zip']))) 58 | 59 | 60 | if __name__ == "__main__": 61 | if len(sys.argv) != 3: 62 | print('Usage: python3 unzip_msft_fs.py DATASET_DIR GOAL_DIR') 63 | sys.exit(1) 64 | 65 | dataset_dir = sys.argv[1] 66 | goal_dir = sys.argv[2] 67 | 68 | if not os.path.isdir(goal_dir): 69 | print('Creating directories for you...') 70 | os.makedirs(goal_dir) 71 | 72 | bases = ['Chunk_%d.zip' % i for i in range(1, NUMBER_OF_CHUNKS + 1)] 73 | """ 74 | Assuming that your hard drive is slow, 75 | so it is ok to have 
more processes than cores 76 | i.e. overhead of managing more processes << time spent unzipping 77 | """ 78 | p = multiprocessing.Pool(len(bases)) 79 | 80 | # bbases: bases with base :) 81 | bbases = map(lambda c: os.path.join(dataset_dir, c), bases) 82 | bbases = list(filter(os.path.isfile, bbases)) 83 | assert len(bbases) == NUMBER_OF_CHUNKS, \ 84 | "Could only find {} out of {} chunks in directory {}".format( 85 | len(bbases), NUMBER_OF_CHUNKS, dataset_dir) 86 | 87 | for i, b in enumerate(bases): 88 | bases[i] = { 89 | 'dir': dataset_dir, 90 | 'extract': goal_dir, 91 | '.zip': b 92 | } 93 | p.map(map_fn, bases) 94 | 95 | bbases = list(map(lambda b: b.replace('.zip', ''), bbases)) 96 | bbases = list(filter(os.path.isdir, bbases)) 97 | assert len(bbases) == NUMBER_OF_CHUNKS, \ 98 | "Could only find {} out of {} chunks in directory {}".format( 99 | len(bbases), NUMBER_OF_CHUNKS, dataset_dir) 100 | -------------------------------------------------------------------------------- /view_transform.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | import cv2 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | # from torchvision import transforms 9 | import matplotlib.pyplot as plt 10 | from tqdm import tqdm 11 | 12 | 13 | device_frame_from_view_frame = np.array([ 14 | [ 0., 0., 1.], 15 | [ 1., 0., 0.], 16 | [ 0., 1., 0.] 
17 | ]) 18 | view_frame_from_device_frame = device_frame_from_view_frame.T 19 | 20 | # MED model 21 | MEDMODEL_INPUT_SIZE = (512, 256) 22 | MEDMODEL_YUV_SIZE = (MEDMODEL_INPUT_SIZE[0], MEDMODEL_INPUT_SIZE[1] * 3 // 2) 23 | MEDMODEL_CY = 47.6 24 | 25 | medmodel_fl = 910.0 26 | medmodel_intrinsics = np.array([ 27 | [medmodel_fl, 0.0, 0.5 * MEDMODEL_INPUT_SIZE[0]], 28 | [0.0, medmodel_fl, MEDMODEL_CY], 29 | [0.0, 0.0, 1.0]]) 30 | 31 | 32 | def calibration(extrinsic_matrix, cam_intrinsics, device_frame_from_road_frame=None): 33 | if device_frame_from_road_frame is None: 34 | device_frame_from_road_frame = np.hstack((np.diag([1, -1, -1]), [[0], [0], [1.51]])) 35 | med_frame_from_ground = medmodel_intrinsics@view_frame_from_device_frame@device_frame_from_road_frame[:,(0,1,3)] 36 | ground_from_med_frame = np.linalg.inv(med_frame_from_ground) 37 | 38 | 39 | extrinsic_matrix_eigen = extrinsic_matrix[:3] 40 | camera_frame_from_road_frame = np.dot(cam_intrinsics, extrinsic_matrix_eigen) 41 | camera_frame_from_ground = np.zeros((3,3)) 42 | camera_frame_from_ground[:,0] = camera_frame_from_road_frame[:,0] 43 | camera_frame_from_ground[:,1] = camera_frame_from_road_frame[:,1] 44 | camera_frame_from_ground[:,2] = camera_frame_from_road_frame[:,3] 45 | warp_matrix = np.dot(camera_frame_from_ground, ground_from_med_frame) 46 | 47 | return warp_matrix 48 | 49 | 50 | if __name__ == '__main__': 51 | from data import PlanningDataset 52 | dataset = PlanningDataset(split='val') 53 | for idx, data in tqdm(enumerate(dataset)): 54 | imgs = data["input_img"] 55 | img0 = imgs[0] 56 | camera_rotation_matrix = np.linalg.inv(data["camera_rotation_matrix_inv"].numpy()) 57 | camera_translation = -data["camera_translation_inv"].numpy() 58 | camera_extrinsic = np.vstack((np.hstack((camera_rotation_matrix, camera_translation.reshape((3, 1)))), np.array([0, 0, 0, 1]))) 59 | camera_extrinsic = np.linalg.inv(camera_extrinsic) 60 | cv2.imshow("origin_img",img0) 61 | cv2.waitKey(0) 62 | warp_matrix = 
calibration(camera_extrinsic, data["camera_intrinsic"].numpy()) 63 | transformed_img = cv2.warpPerspective(src = img0, M = warp_matrix, dsize= (512,256), flags= cv2.WARP_INVERSE_MAP) 64 | cv2.imshow("warped_img",transformed_img) 65 | cv2.waitKey(0) 66 | --------------------------------------------------------------------------------
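The `med_frame_from_ground` homography that `calibration()` in view_transform.py builds can be checked by hand. The following is a minimal sketch in pure Python (no NumPy), assuming the constants from that file and its default `device_frame_from_road_frame` with a camera height of 1.51 m. The helpers `matmul` and `project_ground_point` are hypothetical names used only for illustration and are not part of the repository.

```python
# Sketch: project a point on the road plane (x forward, y left, z = 0)
# into the 512x256 medmodel image, mirroring med_frame_from_ground.

def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]

MEDMODEL_FL = 910.0   # focal length in pixels (medmodel_fl)
MEDMODEL_CX = 256.0   # 0.5 * MEDMODEL_INPUT_SIZE[0]
MEDMODEL_CY = 47.6    # MEDMODEL_CY
CAMERA_HEIGHT = 1.51  # default height used in calibration()

medmodel_intrinsics = [
    [MEDMODEL_FL, 0.0, MEDMODEL_CX],
    [0.0, MEDMODEL_FL, MEDMODEL_CY],
    [0.0, 0.0, 1.0],
]

# Transpose of device_frame_from_view_frame.
view_frame_from_device_frame = [
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 0.0, 0.0],
]

# Columns (0, 1, 3) of hstack(diag([1, -1, -1]), [[0], [0], [1.51]]):
# the z column is dropped because ground points have z = 0.
device_from_ground = [
    [1.0, 0.0, 0.0],
    [0.0, -1.0, 0.0],
    [0.0, 0.0, CAMERA_HEIGHT],
]

MED_FRAME_FROM_GROUND = matmul(
    medmodel_intrinsics,
    matmul(view_frame_from_device_frame, device_from_ground),
)

def project_ground_point(x_forward, y_left):
    """Return (u, v) pixel coordinates of a road-plane point."""
    u, v, w = (row[0] for row in matmul(
        MED_FRAME_FROM_GROUND, [[x_forward], [y_left], [1.0]]))
    return u / w, v / w

# A point 10 m straight ahead lands on the image centre column,
# approximately (256.0, 185.01).
print(project_ground_point(10.0, 0.0))
```

As the forward distance grows, `v` approaches `MEDMODEL_CY`, so the horizon sits at row 47.6 in the warped image. Points to the left of the car (positive `y_left`) map to columns left of the centre.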