├── .gitignore
├── images
│   └── diagram.PNG
├── requirements.txt
├── scripts
│   ├── _init_paths.py
│   ├── gen_h5_file.py
│   ├── prepare_flow_im.py
│   ├── cal_im_flow2uv.py
│   ├── cal_flow.py
│   ├── prd_full_v.py
│   ├── test_association.py
│   ├── train_association.py
│   ├── pt_wise_error.py
│   └── split_data.py
├── lib
│   ├── nb_utils.py
│   ├── pyramidNet.py
│   ├── data_loader_v.py
│   └── gt_velocity.py
└── README.md

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data
data/*
external/*
lib/__pycache__
scripts/__pycache__
push.sh

--------------------------------------------------------------------------------
/images/diagram.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longyunf/radar-full-velocity/HEAD/images/diagram.PNG

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
torchvision
nuscenes-devkit
numpy
glob2
matplotlib
h5py
pyquaternion
shapely
argparse
tqdm
scikit-image
Pillow

--------------------------------------------------------------------------------
/scripts/_init_paths.py:
--------------------------------------------------------------------------------
import os
import sys

def add_path(path):
    if path not in sys.path:
        sys.path.insert(0, path)

this_dir = os.path.dirname(__file__)
lib_dir = os.path.join(this_dir, '..', 'lib')

# Add library path to PYTHONPATH
add_path(lib_dir)

--------------------------------------------------------------------------------
/lib/nb_utils.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt

class sparse_neighbor_connection:
    def __init__(self, left, right, top, bottom, skip=1):
        self.xy, self.hn = self.getXYoffset(left, right, top, bottom, skip)
        self.num_nbs = len(self.xy)

    def getXYoffset(self, left, right, top, bottom, skip):
        # Build the sparse grid of (x, y) pixel offsets around each radar point.
        xy = []
        step = skip + 1
        x_pos = np.concatenate( [ np.arange(0, -left-1, -step)[::-1], np.arange(step, right+1, step) ] )
        y_pos = np.concatenate( [ np.arange(0, -top-1, -step)[::-1], np.arange(step, bottom+1, step) ] )

        for x in x_pos:
            for y in y_pos:
                xy.append([x, y])

        hn = max([left, right, top, bottom])

        return xy, hn

    def plot_neighbor(self):
        xy = self.xy
        hn = self.hn

        M = np.zeros((2*hn + 1, 2*hn + 1), dtype=np.uint8)
        for x, y in xy:
            x += hn
            y += hn
            M[y, x] = 255
        M[hn, hn] = 128

        plt.imshow(M, cmap='gray')
        plt.title('%d neighbors' % len(xy))
        plt.show()

--------------------------------------------------------------------------------
/scripts/gen_h5_file.py:
--------------------------------------------------------------------------------
import argparse
import os
from os.path import join
import numpy as np
import h5py
from tqdm import tqdm
import skimage.io as io
import torch

def create_data_group(hf, mode, sample_indices, dir_label):
    group = hf.create_group('%s' % mode)

    im_list = []
    for idx in tqdm(sample_indices, '%s:im' % mode):
        im1 = io.imread(join(dir_label, '%05d_im.jpg' % idx))
        im_list.append(im1)
    group.create_dataset('im', data=np.array(im_list))
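    # Editor's note (illustrative sketch, not part of the original script): the
    # file written here can be read back lazily without loading whole arrays,
    # assuming the layout created by this script ('train'/'val'/'test' groups,
    # each holding 'im', 'im_uv' and 'indices'):
    #
    #   with h5py.File(path_h5_file, 'r') as hf:
    #       ims = hf['train']['im']        # (N, h, w, 3) uint8, read on demand
    #       uv2 = hf['train']['im_uv'][0]  # first normalized-flow map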
    del im_list

    uv2_im_list = []
    for idx in tqdm(sample_indices, '%s:uv2_im' % mode):
        uv2_im = np.load(join(dir_label, '%05d_im_uv.npy' % idx))
        uv2_im_list.append(uv2_im)
    group.create_dataset('im_uv', data=np.array(uv2_im_list))
    del uv2_im_list

    group.create_dataset('indices', data=np.array(sample_indices))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_data', type=str)
    args = parser.parse_args()

    if args.dir_data is None:
        this_dir = os.path.dirname(__file__)
        args.dir_data = join(this_dir, '..', 'data')

    dir_label = join(args.dir_data, 'prepared_data')
    path_h5_file = join(args.dir_data, 'prepared_data.h5')

    # load the split once instead of re-reading the file for each subset
    sample_split = torch.load(join(args.dir_data, 'sample_split.tar'))
    train_sample_indices = sample_split['train_sample_indices']
    val_sample_indices = sample_split['val_sample_indices']
    test_sample_indices = sample_split['test_sample_indices']

    hf = h5py.File(path_h5_file, 'w')

    create_data_group(hf, 'train', train_sample_indices, dir_label)
    create_data_group(hf, 'val', val_sample_indices, dir_label)
    create_data_group(hf, 'test', test_sample_indices, dir_label)

    hf.close()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Full-Velocity Radar Returns by Radar-Camera Fusion

![example figure](images/diagram.PNG)
**The full radar velocity is estimated from the Doppler velocity and optical flow; the flow can be computed with (a) the previous image or (b) the next image.**


## Directories
```plain
data/
  nuscenes/
    annotations/
    maps/
    samples/
    sweeps/
    v1.0-trainval/
lib/
scripts/
external/
  RAFT/
```


## Setup
- Create a conda environment called pda
```bash
conda create -n pda python=3.6
```
- Install required packages
```bash
pip install -r requirements.txt
```
- Download the [nuScenes dataset](https://www.nuscenes.org/) (Full dataset (v1.0), Trainval) into data/nuscenes/
- Clone the external repo [RAFT](https://github.com/princeton-vl/RAFT) into external/

## Code
**1. Data preparation**

```bash
cd scripts

# 1) split data
python split_data.py

# 2) extract images for flow computation
python prepare_flow_im.py

# 3) compute image flow
python cal_flow.py

# 4) transform image flow to normalized expression (u2,v2)
python cal_im_flow2uv.py

# 5) create .h5 dataset file
python gen_h5_file.py
```

**2. Estimate radar-camera association**
```bash
python train_association.py    # train
python test_association.py    # demo
```
Download [pre-trained weights](https://drive.google.com/drive/folders/1Yz9_mtq5QqLlyAAhVoeJTaq0BiQdJMpA?usp=sharing)

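The association network predicts, for each radar-projected pixel, affinity scores over a sparse neighborhood of candidate image locations (40 offsets under the default `--left_right 4 --top 10 --bottom 4 --skip 1`). A minimal sketch of restoring the best weights outside these scripts, assuming the checkpoint layout written by `train_association.py` (the `state_dict_best` key and the `4_10_4` result-folder naming come from that script; `model` is a `PyramidCNN` built as in the scripts):

```python
# Sketch: restore the best association weights saved by train_association.py.
import torch

ckpt = torch.load('data/train_result/4_10_4/checkpoint.tar', map_location='cpu')
model.load_state_dict(ckpt['state_dict_best'])  # best-on-validation weights
model.eval()
```
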
**3. Predict radar full velocity**

```bash
# 1) generate offsets of radar projections based on associations
python test_association.py --gen_offset

# 2) demo of full velocity prediction
python prd_full_v.py

# 3) evaluation of point-wise velocity
python pt_wise_error.py
```


## Citation
```plain
@InProceedings{Long_2021_ICCV,
    author    = {Long, Yunfei and Morris, Daniel and Liu, Xiaoming and Castro, Marcos and Chakravarty, Punarjay and Narayanan, Praveen},
    title     = {Full-Velocity Radar Returns by Radar-Camera Fusion},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision},
    month     = {October},
    year      = {2021}
}
```

--------------------------------------------------------------------------------
/scripts/prepare_flow_im.py:
--------------------------------------------------------------------------------
'''
Prepare image pairs (key frame and the next one) for optical flow
'''

import skimage.io as io
import os
from os.path import join
import glob
import argparse
from skimage.transform import resize

import torch
from nuscenes.nuscenes import NuScenes


def downsample_im(im, downsample_scale, y_cutoff):
    h_im, w_im = im.shape[0:2]
    h_im = int( h_im / downsample_scale )
    w_im = int( w_im / downsample_scale )

    im = resize(im, (h_im, w_im, 3), order=1, preserve_range=True, anti_aliasing=False)
    im = im.astype('uint8')
    im = im[y_cutoff:, ...]
    return im


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_data', type=str, help='data directory')
    parser.add_argument('--version', type=str, default='v1.0-trainval', help='dataset split')

    args = parser.parse_args()

    if args.dir_data is None:
        this_dir = os.path.dirname(__file__)
        args.dir_data = os.path.join(this_dir, '..', 'data')

    dir_nuscenes = join(args.dir_data, 'nuscenes')

    downsample_scale = 4
    y_cutoff = 33

    nusc = NuScenes(args.version, dataroot=dir_nuscenes, verbose=False)

    dir_data_out = join(args.dir_data, 'prepared_data')
    if not os.path.exists(dir_data_out):
        os.makedirs(dir_data_out)

    # remove all files in the output folder
    f_list = glob.glob(join(dir_data_out, '*'))
    for f in f_list:
        os.remove(f)
    print('removed %d old files in output folder' % len(f_list))

    sample_indices = torch.load(join(args.dir_data, 'sample_split.tar'))['all_indices']

    ct = 0
    for sample_idx in sample_indices:

        cam_token = nusc.sample[sample_idx]['data']['CAM_FRONT']
        cam_data = nusc.get('sample_data', cam_token)

        if cam_data['next']:
            cam_path = join(nusc.dataroot, cam_data['filename'])
            im1 = io.imread(cam_path)

            cam_token2 = cam_data['next']
            cam_data2 = nusc.get('sample_data', cam_token2)
            cam_path2 = join(nusc.dataroot, cam_data2['filename'])
            im2 = io.imread(cam_path2)

            im = downsample_im(im1, downsample_scale, y_cutoff)
            im_next = downsample_im(im2, downsample_scale, y_cutoff)

            io.imsave(join(dir_data_out, '%05d_im_full.jpg' % sample_idx), im1)
            io.imsave(join(dir_data_out, '%05d_im_full_next.jpg' % sample_idx), im2)

            io.imsave(join(dir_data_out, '%05d_im.jpg' % sample_idx), im)
            io.imsave(join(dir_data_out, '%05d_im_next.jpg' % sample_idx), 
im_next) 80 | 81 | ct += 1 82 | print('Save image %d/%d' % ( ct, len(sample_indices) ) ) 83 | -------------------------------------------------------------------------------- /scripts/cal_im_flow2uv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os.path import join 4 | import numpy as np 5 | from tqdm import tqdm 6 | import torch 7 | from nuscenes.nuscenes import NuScenes 8 | 9 | 10 | def get_intrinsic_matrix(nusc, cam_token): 11 | cam_data = nusc.get('sample_data', cam_token) 12 | cs_rec = nusc.get('calibrated_sensor', cam_data['calibrated_sensor_token']) 13 | 14 | return np.array( cs_rec['camera_intrinsic'] ) 15 | 16 | 17 | def flow2uv_full(flow, K): 18 | ''' 19 | uv_map: h x w x 2 20 | ''' 21 | f = K[0,0] 22 | cx = K[0,2] 23 | cy = K[1,2] 24 | 25 | h,w = flow.shape[:2] 26 | x_map, y_map = np.meshgrid(np.arange(w), np.arange(h)) 27 | x_map, y_map = x_map.astype('float32'), y_map.astype('float32') 28 | x_map += flow[..., 0] 29 | y_map += flow[..., 1] 30 | 31 | u_map = (x_map - cx) / f 32 | v_map = (y_map - cy) / f 33 | 34 | uv_map = np.stack([u_map,v_map], axis=2) 35 | 36 | return uv_map 37 | 38 | 39 | def downsample_flow(flow_full, downsample_scale, y_cutoff): 40 | H, W, nc = flow_full.shape 41 | h = int( H / downsample_scale ) 42 | w = int( W / downsample_scale ) 43 | 44 | x_map, y_map = np.meshgrid(np.arange(w), np.arange(h)) 45 | 46 | x_map_old = np.round( np.clip( x_map * downsample_scale, 0, W-1) ).astype(int).ravel() 47 | y_map_old = np.round( np.clip( y_map * downsample_scale, 0, H-1) ).astype(int).ravel() 48 | 49 | flow_list = [] 50 | for i in range(nc): 51 | flow_list.append(flow_full[y_map_old, x_map_old, i]) 52 | 53 | flow = np.stack(flow_list, axis=1) 54 | flow = np.reshape(flow, (h,w,-1)) 55 | flow = flow[y_cutoff:,...] 
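    # Shape check (editor's note, derived from the constants used in this repo):
    # a 900 x 1600 nuScenes flow map with downsample_scale=4 becomes 225 x 400,
    # and dropping the y_cutoff=33 top rows leaves 192 x 400, matching the
    # downsampled images written by prepare_flow_im.py.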

    return flow


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_data', type=str, help='data directory')
    parser.add_argument('--version', type=str, default='v1.0-trainval', help='dataset split')

    args = parser.parse_args()

    if args.dir_data is None:
        this_dir = os.path.dirname(__file__)
        args.dir_data = os.path.join(this_dir, '..', 'data')

    dir_nuscenes = join(args.dir_data, 'nuscenes')
    out_dir = join(args.dir_data, 'prepared_data')

    nusc = NuScenes(args.version, dataroot=dir_nuscenes, verbose=False)

    downsample_scale = 4
    y_cutoff = 33

    sample_indices = torch.load(join(args.dir_data, 'sample_split.tar'))['all_indices']

    ct = 0
    for sample_idx in tqdm(sample_indices):

        f_flow = join(out_dir, '%05d_full_flow.npy' % sample_idx)
        flow = np.load(f_flow)

        cam_token = nusc.sample[sample_idx]['data']['CAM_FRONT']

        K = get_intrinsic_matrix(nusc, cam_token)

        flow_downsampled = downsample_flow(flow, downsample_scale, y_cutoff)
        flow_downsampled /= downsample_scale

        uv_map = flow2uv_full(flow, K)
        uv_map = downsample_flow(uv_map, downsample_scale, y_cutoff)

        np.save(f_flow[:-13] + 'im_uv.npy', uv_map)
        np.save(f_flow[:-13] + 'flow.npy', flow_downsampled)

--------------------------------------------------------------------------------
/scripts/cal_flow.py:
--------------------------------------------------------------------------------
'''
Compute flow
Based on RAFT (https://github.com/princeton-vl/RAFT)
'''
import sys
import argparse
import os
import glob
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from os.path import join

raft_path = join(os.path.dirname(__file__), '..', 'external', 'RAFT', 'core')
if raft_path not in sys.path:
    sys.path.insert(0, raft_path)
from raft import RAFT

DEVICE = 'cuda'

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
    parser.add_argument('--dir_data', type=str, help='dataset directory')
    parser.add_argument('--start_idx', type=int)
    parser.add_argument('--end_idx', type=int)

    args = parser.parse_args()

    if args.dir_data is None:
        this_dir = os.path.dirname(__file__)
        args.dir_data = join(this_dir, '..', 'data')

    if args.model is None:
        this_dir = os.path.dirname(__file__)
        args.model = join(this_dir, '..', 'external', 'RAFT', 'models', 'raft-kitti.pth')

    start_idx = args.start_idx
    end_idx = args.end_idx
    out_dir = join(args.dir_data, 'prepared_data')

    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))

    model = model.module
    model.to(DEVICE)
    model.eval()

    im_list = np.array(np.sort(glob.glob(join(out_dir, '*im_full.jpg'))))

    N = len(im_list)

    print('Total sample number:', N)

    if start_idx is None:
        start_idx = 0

    if end_idx is None or end_idx > N - 1:
end_idx = N - 1 64 | 65 | for sample_idx in tqdm(range(start_idx, end_idx + 1)): 66 | 67 | f_im1 = im_list[sample_idx] 68 | 69 | im1 = np.array(Image.open(f_im1)).astype(np.uint8) 70 | 71 | f_im_next = f_im1[:-4] + '_next.jpg' 72 | f_im_prev = f_im1[:-4] + '_prev.jpg' 73 | 74 | if os.path.exists(f_im_next): 75 | im2 = np.array(Image.open(f_im_next)).astype(np.uint8) 76 | else: 77 | im2 = np.array(Image.open(f_im_prev)).astype(np.uint8) 78 | 79 | im1 = np.pad(im1, ((2,2),(0,0),(0,0)), 'constant') 80 | im2 = np.pad(im2, ((2,2),(0,0),(0,0)), 'constant') 81 | 82 | im1 = torch.from_numpy(im1).permute(2, 0, 1).float() 83 | im1 = im1[None,].to(DEVICE) 84 | 85 | im2 = torch.from_numpy(im2).permute(2, 0, 1).float() 86 | im2 = im2[None,].to(DEVICE) 87 | 88 | with torch.no_grad(): 89 | flow_low, flow_up = model(im1, im2, iters=20, test_mode=True) 90 | flow = flow_up[0].permute(1,2,0).cpu().numpy()[2:-2,...] 91 | 92 | path_flow = f_im1[:-11] + 'full_flow.npy' 93 | np.save(path_flow, flow) 94 | 95 | -------------------------------------------------------------------------------- /lib/pyramidNet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Pyramid network 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | __all__ = ['PyramidCNN', 'pnet'] 9 | 10 | class Block(nn.Module): 11 | 12 | def __init__(self, nChannels, doRes=False, doBN=False, doELU=False): 13 | super(Block, self).__init__() 14 | 15 | if doBN: 16 | self.bn1 = nn.BatchNorm2d(nChannels) 17 | self.bn2 = nn.BatchNorm2d(nChannels) 18 | else: 19 | self.bn1 = [] 20 | self.bn2 = [] 21 | self.conv1 = nn.Conv2d(nChannels, nChannels, kernel_size=3, padding=1, bias=True) 22 | self.conv2 = nn.Conv2d(nChannels, nChannels, kernel_size=3, padding=1, bias=True) 23 | if doELU: 24 | self.relu = nn.ELU() 25 | else: 26 | self.relu = nn.ReLU() 27 | self.doRes = doRes 28 | 29 | def forward(self, x): 30 | 31 | out = x 32 | 33 | if self.bn1: 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | out = self.conv1(out) 37 | 38 | if self.bn2: 39 | out = self.bn2(out) 40 | out = self.relu(out) 41 | out = self.conv2(out) 42 | 43 | if self.doRes: 44 | out += x 45 | 46 | return out 47 | 48 | def NetBlock(nPerBlock, nChannels, doRes, doBN, doELU): 49 | layers = [] 50 | for _ in range(nPerBlock): 51 | layers.append(Block(nChannels, doRes, doBN, doELU)) 52 | return nn.Sequential(*layers) 53 | 54 | 55 | def PredictBlock(nChannels, outChannels, doBN, doELU): 56 | layers = [] 57 | 58 | if doBN: 59 | layers.append( nn.BatchNorm2d(nChannels) ) 60 | if doELU: 61 | layers.append( nn.ELU() ) 62 | else: 63 | layers.append( nn.ReLU() ) 64 | layers.append( nn.Conv2d(nChannels, outChannels, kernel_size=3, padding=1, bias=True) ) 65 | return nn.Sequential(*layers) 66 | 67 | def PredictBox( nChannels, doBN, doELU): 68 | layers = [] 69 | 70 | if doBN: 71 | layers.append( nn.BatchNorm2d(nChannels) ) 72 | if doELU: 73 | layers.append( nn.ELU() ) 74 | else: 75 | layers.append( nn.ReLU() ) 76 | layers.append( nn.Conv2d(nChannels, 4, kernel_size=3, padding=1, bias=True) ) 77 | layers.append( nn.ReLU() ) 78 | return nn.Sequential(*layers) 79 | 80 | def PredictPixels( nChannels, doBN, doELU): 81 | layers = [] 82 | 83 | if doBN: 84 | layers.append( nn.BatchNorm2d(nChannels) ) 85 | if doELU: 86 | layers.append( nn.ELU() ) 87 | else: 88 | layers.append( nn.ReLU() ) 89 | layers.append( nn.Conv2d(nChannels, 2, kernel_size=3, padding=1, bias=True) ) 90 | return nn.Sequential(*layers) 91 | 92 | class 
PyramidCNN(nn.Module): 93 | def __init__(self, nLevels, nPred, nPerBlock, nChannels, inChannels, outChannels, doRes, doBN, doELU, predPix, predBoxes): 94 | super(PyramidCNN, self).__init__() 95 | assert( nLevels > 1 ) 96 | #assert( nPred > 0 ) 97 | self.nLevels = nLevels 98 | self.nPred = nPred 99 | self.predBoxes = predBoxes 100 | self.predPix = predPix 101 | if inChannels!=nChannels: 102 | self.inConv = nn.Conv2d(inChannels, nChannels, kernel_size=3, padding=1, bias=True) 103 | else: 104 | self.inConv = [] 105 | self.blocksUp = nn.ModuleList() 106 | self.blocksDown = nn.ModuleList() 107 | self.blocksPred = nn.ModuleList() 108 | if self.predBoxes: 109 | self.predictBox = PredictBox( nChannels, doBN, doELU ) 110 | if self.predPix: 111 | self.predictPix = PredictPixels( nChannels, doBN, doELU) 112 | 113 | for _ in range(nLevels-1): 114 | self.blocksUp.append( NetBlock( nPerBlock, nChannels, doRes, doBN, doELU) ) 115 | for _ in range(nLevels): 116 | self.blocksDown.append( NetBlock( nPerBlock, nChannels, False, doBN, doELU) ) #no res-blocks going down 117 | for _ in range(nPred): 118 | self.blocksPred.append( PredictBlock( nChannels, outChannels, doBN, doELU) ) 119 | 120 | def _addLevel(self, x, level): 121 | x = F.max_pool2d(x, 2, stride=2) 122 | 123 | if level < self.nLevels-1: 124 | x = self.blocksUp[level](x) 125 | y, out = self._addLevel( x, level + 1) 126 | x = x + y 127 | x = self.blocksDown[level](x) 128 | w = F.interpolate( x, scale_factor=2, mode='nearest') 129 | 130 | if level < self.nPred: 131 | z = self.blocksPred[level](x) 132 | out = [z] + out 133 | else: 134 | out = [] 135 | 136 | return w, out 137 | 138 | 139 | def forward(self, x): 140 | 141 | if self.inConv: 142 | x = self.inConv( x ) 143 | 144 | level = 0 145 | x = self.blocksUp[level](x) 146 | y, out = self._addLevel( x, level + 1) 147 | z = x + y 148 | w = self.blocksDown[level](z) 149 | if level < self.nPred: 150 | z = self.blocksPred[level](w) 151 | out = [z] + out 152 | 153 | if self.predBoxes: 154 | out.append( self.predictBox(w) ) 155 | 156 | if self.predPix: 157 | out.append( self.predictPix(w) ) 158 | 159 | return out 160 | 161 | def pnet(**kwargs): 162 | model = PyramidCNN( nLevels = kwargs['nLevels'], 163 | nPred = kwargs['nPred'], 164 | nPerBlock = kwargs['nPerBlock'], 165 | nChannels = kwargs['nChannels'], 166 | inChannels = kwargs['inChannels'], 167 | outChannels = kwargs['outChannels'], 168 | doRes = kwargs['doRes'], 169 | doBN = kwargs['doBN'], 170 | doELU = kwargs['doELU'], 171 | predPix = kwargs['predPix'], 172 | predBoxes = kwargs['predBoxes'] ) 173 | return model 174 | -------------------------------------------------------------------------------- /scripts/prd_full_v.py: -------------------------------------------------------------------------------- 1 | """ 2 | Predict and visualize full velocity 3 | """ 4 | import os 5 | import matplotlib.pyplot as plt 6 | from os.path import join 7 | import numpy as np 8 | import argparse 9 | import torch 10 | from nuscenes.nuscenes import NuScenes 11 | from nuscenes.utils.data_classes import LidarPointCloud 12 | 13 | import _init_paths 14 | from gt_velocity import gt_box_key, cal_trans_matrix, proj2im, flow2uv_map 15 | from pt_wise_error import get_im_pair, cal_full_v_in_radar, correct_coord, upsample_coord, downsample_coord 16 | 17 | 18 | def plot_flow(im, flow, title='flow', color='cyan', step=20): 19 | h, w = im.shape[:2] 20 | 21 | x1, y1 = np.meshgrid(np.arange(0,w), np.arange(0,h)) 22 | 23 | dx = flow[...,0] 24 | dy = flow[...,1] 25 | 26 | # plt.figure() 27 | 
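    # Editor's sketch (assumption, not in the original): the per-arrow loop
    # below could be drawn in one vectorized call, e.g.
    #   plt.quiver(x1[::step, ::step], y1[::step, ::step],
    #              dx[::step, ::step], dy[::step, ::step],
    #              color=color, angles='xy', scale_units='xy', scale=1)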
plt.imshow(im) 28 | for i in range(0, h, step): 29 | for j in range(0, w, step): 30 | plt.arrow(x1[i,j], y1[i,j], dx[i,j], dy[i,j], length_includes_head=True, width=0.2, head_width=2, color=color) 31 | 32 | plt.title(title) 33 | plt.show() 34 | 35 | 36 | def pltRadarWithV(x,y,vx,vy, color='red', zorder=1): 37 | plt.scatter(x,y,s=5) 38 | for i in range(len(x)): 39 | plt.arrow(x[i], y[i], vx[i], vy[i], length_includes_head=True, width=0.05, head_width=0.3, color=color, zorder=zorder) 40 | 41 | plt.xlabel('x') 42 | plt.ylabel('y') 43 | plt.axis('equal') 44 | # plt.show() 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--dir_data', type=str) 50 | parser.add_argument('--version', type=str, default='v1.0-trainval', help='dataset split') 51 | parser.add_argument('--sample_idx', type=int, default=17659) 52 | 53 | args = parser.parse_args() 54 | 55 | if args.dir_data == None: 56 | this_dir = os.path.dirname(__file__) 57 | args.dir_data = join(this_dir, '..', 'data') 58 | 59 | args.dir_nuscenes = join(args.dir_data, 'nuscenes') 60 | dir_files = join(args.dir_data, 'prepared_data') 61 | 62 | sample_indices = torch.load(join(args.dir_data,'sample_split.tar'))['test_sample_indices'] 63 | nusc = NuScenes(args.version, dataroot = args.dir_nuscenes, verbose=False) 64 | 65 | cam_token1 = nusc.sample[args.sample_idx]['data']['CAM_FRONT'] 66 | im1, im2, f, cx, cy, dt, cam_token2 = get_im_pair(nusc, cam_token1) 67 | 68 | gt = gt_box_key(nusc, args.sample_idx) 69 | rd_token = gt.radar_token 70 | 71 | gt_rd = gt.radar 72 | x_list, y_list, vx_list, vy_list, msk_gt = gt_rd['x'], gt_rd['y'], gt_rd['vx'], gt_rd['vy'], gt_rd['having_truth'] 73 | 74 | pc = LidarPointCloud( np.stack([x_list, y_list, np.zeros_like(x_list), np.ones_like(x_list)]) ) 75 | 76 | T_r2c = cal_trans_matrix(nusc, rd_token, cam_token1) 77 | T_c2r = cal_trans_matrix(nusc, cam_token1, rd_token) 78 | T_c2c = cal_trans_matrix(nusc, cam_token1, cam_token2) 79 | 80 | pc.transform(T_r2c) 81 | 82 | xi_list, yi_list, d_list, msk_in_im = proj2im(nusc, pc, cam_token1) 83 | xi_list0, yi_list0 = xi_list, yi_list 84 | xi_list, yi_list = downsample_coord(xi_list, yi_list, downsample_scale=4, y_cutoff=33) 85 | 86 | rd_offset = np.load(join(dir_files, '%05d_offset.npy' % args.sample_idx)) 87 | flow = np.load(join(dir_files, '%05d_full_flow.npy' % args.sample_idx)) 88 | u1_map, v1_map, u2_map, v2_map = flow2uv_map(flow, cx, cy, f) 89 | 90 | xi_list, yi_list = correct_coord(xi_list, yi_list, rd_offset) 91 | xi_list, yi_list = upsample_coord(xi_list, yi_list, downsample_scale=4, y_cutoff=33) 92 | 93 | prd_vxf = [] 94 | prd_vyf = [] 95 | prd_msk = [] 96 | for xi, yi, vx, vy, d, is_in_im in zip(xi_list, yi_list, vx_list, vy_list, d_list, msk_in_im): 97 | 98 | if [xi,yi] != [-1,-1]: 99 | xp, yp = int(round(xi)), int(round(yi)) 100 | u1,v1,u2,v2 = u1_map[yp,xp], v1_map[yp,xp], u2_map[yp,xp], v2_map[yp,xp] 101 | vx_f, vy_f, vz_f = cal_full_v_in_radar(vx, vy, d, u1, v1, u2, v2, T_c2r, T_c2c, dt) 102 | 103 | prd_vxf.append(vx_f) 104 | prd_vyf.append(vy_f) 105 | prd_msk.append(True) 106 | else: 107 | prd_vxf.append(0) 108 | prd_vyf.append(0) 109 | prd_msk.append(False) 110 | 111 | vxf_list, vyf_list, gt_msk = gt_rd['vxf'], gt_rd['vyf'], gt_rd['having_truth'] 112 | 113 | plt.close('all') 114 | ## plot raw radar depth and radar-pixel association 115 | plt.figure() 116 | plt.imshow(im1) 117 | for x0,y0,x1,y1 in zip(xi_list0, yi_list0, xi_list, yi_list): 118 | if x1 != -1: 119 | plt.arrow(x=x0, y=y0, dx=(x1-x0), 
dy=(y1-y0), length_includes_head=True, width=2, head_width=5, color='yellow')
    plt.scatter(xi_list0, yi_list0, c=d_list, s=10, cmap='jet')

    plt.figure()
    plot_flow(im1, flow)

    # plot predicted velocity and Doppler velocity
    plt.figure()
    for x, y, vx, vy, prd_vx, prd_vy, gt_vx, gt_vy, gt_valid, prd_valid in zip(x_list, y_list, vx_list, vy_list, prd_vxf, prd_vyf, vxf_list, vyf_list, gt_msk, prd_msk):
        if gt_valid and prd_valid:
            pltRadarWithV([x], [y], [prd_vx], [prd_vy], 'black', zorder=5)
            pltRadarWithV([x], [y], [vx], [vy], 'red')

    # plot box with GT velocity
    for obj_token in gt.boxes_radar:
        box_rd = gt.boxes_radar[obj_token]

        poly = box_rd['polygon']
        xy_list = list(poly.exterior.coords)
        x_temp = [xy[0] for xy in xy_list]
        y_temp = [xy[1] for xy in xy_list]

        plt.plot(x_temp, y_temp, color='purple', linewidth=2)
        x_ct, y_ct = box_rd['center'][:2]
        plt.arrow(x_ct, y_ct, dx=box_rd['v'][0], dy=box_rd['v'][1], length_includes_head=True, width=0.08, head_width=0.3, color='green')
    plt.axis('equal')
    plt.show()

--------------------------------------------------------------------------------
/lib/data_loader_v.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from os.path import join
import h5py
import torch
from torch.utils import data
from nuscenes.nuscenes import NuScenes

from gt_velocity import gt_box_key, cal_depthMap, cal_trans_matrix, cal_matrix_refSensor_to_global
from nb_utils import sparse_neighbor_connection


def get_intrinsic_matrix(nusc, cam_token):
    cam_data = nusc.get('sample_data', cam_token)
    cs_rec = nusc.get('calibrated_sensor', cam_data['calibrated_sensor_token'])

    return np.array( cs_rec['camera_intrinsic'] )


def cal_uv1(h, w, K, downsample_scale=4, y_cutoff=33):
    '''
    uv_map: h x w x 2
    '''
    f = K[0,0]
    cx = K[0,2]
    cy = K[1,2]

    x_map, y_map = np.meshgrid(np.arange(w), np.arange(h))
    x_map, y_map = x_map.astype('float32'), y_map.astype('float32')

    cx = cx / downsample_scale
    cy = cy / downsample_scale - y_cutoff
    f = f / downsample_scale

    u_map = (x_map - cx) / f
    v_map = (y_map - cy) / f

    uv_map = np.stack([u_map, v_map], axis=2)

    return uv_map


def cal_uv_translation(uv2, R, msk_uv2=None):
    '''
    inputs:
        uv2: (2 x h x w); full flow
        R: rotation matrix (from u1,v1 -> u2,v2)
        msk_uv2: h x w
    output:
        uvt2: flow from translation inv(R)*t
    '''
    u2, v2 = uv2[0], uv2[1]

    R_inv = np.linalg.inv(R)
    r11, r12, r13, r21, r22, r23, r31, r32, r33 = R_inv.flatten()
    ut = (u2*r11 + v2*r12 + r13) / (u2*r31 + v2*r32 + r33)
    vt = (u2*r21 + v2*r22 + r23) / (u2*r31 + v2*r32 + r33)

    if msk_uv2 is not None:
        ut = ut * msk_uv2
        vt = vt * msk_uv2

    uvt2 = np.stack([ut, vt])

    return uvt2


def init_data_loader(args, mode):

    if mode == 'train':
        batch_size = args.batch_size
        if args.no_data_shuffle:
            shuffle = False
        else:
            shuffle = True
    else:
        batch_size = args.test_batch_size
        shuffle = False

    nusc = NuScenes(version=args.version, dataroot=args.dir_nuscenes, verbose=False)

    args_dataset = {'path_data_file': args.path_data_file,
                    'mode': mode,
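                    # Editor's note (not in the original): args.nb is a
                    # sparse_neighbor_connection; with the training defaults
                    # (left=right=4, top=10, bottom=4, skip=1) it yields
                    # len(nb.xy) == 40 candidate offsets, which is what the
                    # scripts pass to the model as outChannels.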
84 | 'nusc':nusc, 85 | 'nb':args.nb} 86 | args_data_loader = {'batch_size': batch_size, 87 | 'shuffle': shuffle, 88 | 'num_workers': args.num_workers} 89 | dataset = Dataset(**args_dataset) 90 | data_loader = torch.utils.data.DataLoader(dataset, **args_data_loader) 91 | 92 | return data_loader 93 | 94 | 95 | class Dataset(data.Dataset): 96 | def __init__(self, path_data_file, mode, nusc, nb): 97 | data = h5py.File(path_data_file, 'r')[mode] 98 | self.nusc = nusc 99 | self.im_list = data['im'][...] 100 | self.uv2_im_list = data['im_uv'][...].astype('f4') 101 | self.indices = data['indices'] 102 | self.nb = nb 103 | 104 | def __len__(self): 105 | return len(self.indices) 106 | 107 | def __getitem__(self, idx): 108 | 109 | sample_idx = self.indices[idx] 110 | gt = gt_box_key(self.nusc, sample_idx) 111 | rd = gt.radar 112 | uv2_im = self.uv2_im_list[idx].astype('float32') 113 | h, w = uv2_im.shape[:2] 114 | 115 | cam_data = self.nusc.get('sample_data', gt.cam_token) 116 | cam_token2 = cam_data['next'] 117 | dt = (self.nusc.get('sample_data', cam_token2)['timestamp'] - cam_data['timestamp']) * 1e-6 118 | 119 | cam_token1 = gt.cam_token 120 | rd_token = gt.radar_token 121 | 122 | K = get_intrinsic_matrix(self.nusc, cam_token1) 123 | uv1_im = cal_uv1(h, w, K, downsample_scale=4, y_cutoff=33) 124 | uv = np.concatenate([uv1_im, uv2_im], axis=2) 125 | 126 | T_c2r = cal_trans_matrix(self.nusc, cam_token1, rd_token) 127 | T_c2c = cal_trans_matrix(self.nusc, cam_token1, cam_token2) 128 | T_c2w = cal_matrix_refSensor_to_global(self.nusc, cam_token1) 129 | 130 | depth_map, _, msk_nb_map, vx_nb_map, vy_nb_map, _, vx_gt_map, vy_gt_map = \ 131 | cal_depthMap(rd, uv, T_c2r, T_c2c, T_c2w, dt, self.nb, downsample_scale=4, y_cutoff=33) 132 | 133 | error_map = ( (vx_nb_map - vx_gt_map[...,None])**2 + (vy_nb_map - vy_gt_map[...,None])**2 )**0.5 134 | 135 | error_map = error_map.transpose((2,0,1)) # (n,h,w) 136 | msk_nb_map = msk_nb_map.transpose((2,0,1)) 137 | 138 | im1 = self.im_list[idx].astype('float32').transpose((2,0,1))/255 # (3,h,w) 139 | R = T_c2c[:3,:3] 140 | 141 | uv1_im = uv1_im.transpose((2,0,1)) # (2,h,w) 142 | uv2_im = uv2_im.transpose((2,0,1)) # (2,h,w) 143 | 144 | d_radar = depth_map[None,...].astype('float32') # (1,h,w) 145 | 146 | scale_factor = 30 147 | 148 | uvt2_im = cal_uv_translation(uv2_im, R) 149 | duv_im = (uvt2_im - uv1_im) * scale_factor 150 | 151 | data_in = np.concatenate((im1, uv1_im, duv_im, d_radar), axis=0) # (8,h,w) 152 | 153 | sample = {'data_in': data_in, 'sample_idx': self.indices[idx], 'error': error_map, 'msk': msk_nb_map} 154 | 155 | return sample 156 | 157 | 158 | if __name__=='__main__': 159 | this_dir = os.path.dirname(__file__) 160 | dir_data = join(this_dir, '..', 'data') 161 | dir_nuscenes = join(dir_data, 'nuscenes') 162 | path_data_file = join(dir_data, 'prepared_data.h5') 163 | nb = sparse_neighbor_connection(*(4, 4, 10, 4)) 164 | nusc = NuScenes(version = 'v1.0-trainval', dataroot = dir_nuscenes, verbose=False) 165 | 166 | args_dataset = {'path_data_file': path_data_file, 167 | 'mode': 'train', 168 | 'nusc': nusc, 169 | 'nb':nb} 170 | args_train_loader = {'batch_size': 6, 171 | 'shuffle': True, 172 | 'num_workers': 0} 173 | 174 | train_set = Dataset(**args_dataset) 175 | train_loader = torch.utils.data.DataLoader(train_set, **args_train_loader) 176 | 177 | data_iterator = enumerate(train_loader) 178 | 179 | batch_idx, sample = next(data_iterator) 180 | 181 | print('batch_idx', batch_idx) 182 | print('data_in', sample['data_in'].shape, 
type(sample['data_in']),sample['data_in'].dtype) 183 | print('error', sample['error'].shape, type(sample['error']), sample['error'].dtype) 184 | print('msk', sample['msk'].shape, type(sample['msk']), sample['msk'].dtype) 185 | -------------------------------------------------------------------------------- /scripts/test_association.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import argparse 4 | import os 5 | from os.path import join 6 | import sys 7 | from tqdm import tqdm 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | 11 | import _init_paths 12 | from pyramidNet import PyramidCNN 13 | from data_loader_v import init_data_loader 14 | from nb_utils import sparse_neighbor_connection 15 | 16 | 17 | def load_weights(args, model): 18 | f_checkpoint = join(args.dir_result, 'checkpoint.tar') 19 | if os.path.isfile(f_checkpoint): 20 | print('load best model') 21 | model.load_state_dict(torch.load(f_checkpoint)['state_dict_best']) 22 | else: 23 | sys.exit('No model found') 24 | 25 | 26 | def init_env(): 27 | use_cuda = torch.cuda.is_available() 28 | device = torch.device("cuda" if use_cuda else "cpu") 29 | cudnn.benchmark = True if use_cuda else False 30 | return device 31 | 32 | 33 | def plt_depth_on_im(depth_map, im, title = '', ptsSize = 1): 34 | 35 | h,w = im.shape[0:2] 36 | x_map, y_map = np.meshgrid(np.arange(w), np.arange(h)) 37 | msk = depth_map > 0 38 | 39 | plt.figure() 40 | plt.imshow(im) 41 | plt.scatter(x_map[msk], y_map[msk], c=depth_map[msk], s=ptsSize, cmap='jet') 42 | plt.title(title, fontsize=20) 43 | plt.colorbar() 44 | plt.axis('off') 45 | 46 | 47 | def cal_radar_position_offset(prd_aff, d_radar, nb, thres_aff=0.5): 48 | ''' 49 | inputs: 50 | prd_aff: n_nb x h x w 51 | d_radar: h x w 52 | outputs: 53 | offset: numpy (h,w,2); at each pixel (dx,dy); no association (dx,dy)=(-1000,-1000) 54 | ''' 55 | h,w = d_radar.shape 56 | offset = -1000 * np.ones((h,w,2)) 57 | xy_list = nb.xy 58 | 59 | max_aff = np.max(prd_aff, axis=0) 60 | idx_max = np.argmax(prd_aff, axis=0) 61 | 62 | msk_radar = d_radar > 0 63 | max_aff = max_aff * msk_radar 64 | 65 | for i in range(h): 66 | for j in range(w): 67 | if d_radar[i,j] > 0 and max_aff[i,j] > thres_aff: 68 | offset[i,j,:] = xy_list[idx_max[i,j]] 69 | 70 | return offset, max_aff 71 | 72 | 73 | def prd_one_sample(model, nb, test_loader, device, sample_idx = 1, thres_aff = 0.3): 74 | 75 | def plt_association_on_im(depth_map, im, pos_offset, title = '', ptsSize = 1): 76 | h,w = im.shape[0:2] 77 | x_map, y_map = np.meshgrid(np.arange(w), np.arange(h)) 78 | msk = depth_map > 0 79 | 80 | plt.figure() 81 | plt.imshow(im) 82 | 83 | for i in range(h): 84 | for j in range(w): 85 | dx, dy = pos_offset[i,j,:] 86 | if [dx,dy] != [-1000,-1000]: 87 | plt.arrow(j, i, dx, dy, length_includes_head=True, width=0.1, head_width=0.2, color='yellow') 88 | 89 | plt.scatter(x_map[msk], y_map[msk], c=depth_map[msk], s=ptsSize, cmap='jet') 90 | plt.title(title, fontsize=20) 91 | plt.colorbar() 92 | plt.axis('off') 93 | plt.show() 94 | 95 | with torch.no_grad(): 96 | for ct, sample in enumerate(test_loader): 97 | s_idx = sample['sample_idx'] 98 | if s_idx == sample_idx: 99 | data_in = sample['data_in'].to(device) 100 | prd = torch.sigmoid( model(data_in)[0] ) 101 | d_radar_tensor = data_in[:,[7],...] 
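                # Channel layout note (from data_loader_v.py): data_in stacks
                # im1 (3) + uv1 (2) + duv (2) + d_radar (1) = 8 channels, so
                # index 7 is the sparse radar depth channel extracted here.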
102 | im = data_in[0][0:3].permute(1,2,0).to('cpu').numpy() 103 | d_radar = d_radar_tensor[0][0].to('cpu').numpy() 104 | prd = prd[0].cpu().numpy() 105 | break 106 | 107 | pos_offset, max_aff = cal_radar_position_offset(prd, d_radar, nb, thres_aff) 108 | plt.close('all') 109 | plt_association_on_im(d_radar, im, pos_offset, title = 'Association', ptsSize = 30) 110 | 111 | 112 | def gen_offset_map(model, test_loader, device, args): 113 | test_indices = [] 114 | with torch.no_grad(): 115 | for ct, sample in enumerate(tqdm(test_loader)): 116 | data_in, sample_idx = sample['data_in'].to(device), sample['sample_idx'][0].item() 117 | test_indices.append(sample_idx) 118 | prd = torch.sigmoid( model(data_in)[0] ) 119 | d_radar = data_in[0][7].to('cpu').numpy() 120 | prd = prd[0].cpu().numpy() 121 | pos_offset, _ = cal_radar_position_offset(prd, d_radar, args.nb, thres_aff=0.3) 122 | np.save(join(args.output_folder, '%05d_offset.npy' % sample_idx), pos_offset) 123 | 124 | 125 | def main(args): 126 | if args.dir_data == None: 127 | this_dir = os.path.dirname(__file__) 128 | args.dir_data = join(this_dir, '..', 'data') 129 | 130 | if not args.dir_result: 131 | args.dir_result = join(args.dir_data, 'train_result', '%d_%d_%d' % (args.left_right, args.top, args.bottom)) 132 | args.path_data_file = join(args.dir_data, 'prepared_data.h5') 133 | args.output_folder = join(args.dir_data, 'prepared_data') 134 | args.dir_nuscenes = join(args.dir_data, 'nuscenes') 135 | 136 | args.nb = sparse_neighbor_connection(*(args.left_right, args.left_right, args.top, args.bottom, args.skip)) 137 | args.outChannels = len(args.nb.xy) 138 | 139 | device = init_env() 140 | 141 | test_loader = init_data_loader(args, 'test') 142 | 143 | model = PyramidCNN(args.nLevels, args.nPred, args.nPerBlock, 144 | args.nChannels, args.inChannels, args.outChannels, 145 | args.doRes, args.doBN, doELU=False, 146 | predPix=False, predBoxes=False).to(device) 147 | 148 | load_weights(args, model) 149 | model.eval() 150 | 151 | if args.gen_offset: 152 | gen_offset_map(model, test_loader, device, args) 153 | else: 154 | sample_idx = 17659 155 | thres_aff = 0.3 156 | prd_one_sample(model, args.nb, test_loader, device, sample_idx, thres_aff) 157 | 158 | 159 | if __name__ == '__main__': 160 | parser = argparse.ArgumentParser() 161 | parser.add_argument('--dir_data', type=str) 162 | parser.add_argument('--dir_result', type=str) 163 | parser.add_argument('--version', type=str, default='v1.0-trainval', help='dataset split') 164 | 165 | parser.add_argument('--test_batch_size', type=int, default=1) 166 | parser.add_argument('--batch_size', type=int, default=1) 167 | parser.add_argument('--no_data_shuffle', type=bool, default=True) 168 | parser.add_argument('--num_workers', type=int, default=0) 169 | 170 | parser.add_argument('--nLevels', type=int, default=5) 171 | parser.add_argument('--nPred', type=int, default=1) 172 | parser.add_argument('--nPerBlock', type=int, default=1) 173 | parser.add_argument('--nChannels', type=int, default=64) 174 | parser.add_argument('--inChannels', type=int, default=8) 175 | parser.add_argument('--doRes', type=bool, default=True) 176 | parser.add_argument('--doBN', type=bool, default=True) 177 | 178 | parser.add_argument('--left_right', type=int, default=4) 179 | parser.add_argument('--top', type=int, default=10) 180 | parser.add_argument('--bottom', type=int, default=4) 181 | parser.add_argument('--skip', type=int, default=1) 182 | 183 | parser.add_argument('--gen_offset', action='store_true', default=False, help='generate 
predicted offsets') 184 | 185 | args = parser.parse_args() 186 | 187 | main(args) 188 | 189 | -------------------------------------------------------------------------------- /scripts/train_association.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import argparse 4 | import os 5 | from os.path import join 6 | from timeit import default_timer as timer 7 | import copy 8 | from tqdm import tqdm 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | 12 | import _init_paths 13 | from data_loader_v import init_data_loader 14 | from pyramidNet import PyramidCNN 15 | from nb_utils import sparse_neighbor_connection 16 | 17 | 18 | def CE_loss(aff_prd, error, msk, thres=0.5): 19 | c = thres**2/np.log(2) 20 | msk = msk.float() 21 | lb_aff = torch.exp(-error**2/c) 22 | n_pixel = torch.sum(msk>0).float() 23 | loss = torch.sum( (-lb_aff * aff_prd + torch.log(1+torch.exp(aff_prd)) ) * msk ) / n_pixel 24 | return loss 25 | 26 | 27 | def train(log_interval, model, device, train_loader, optimizer, epoch, nb): 28 | model.train() 29 | ave_loss=0 30 | 31 | for batch_idx, sample in enumerate(train_loader): 32 | data_in, error, msk = sample['data_in'].to(device), sample['error'].to(device), sample['msk'].to(device) 33 | 34 | optimizer.zero_grad() 35 | 36 | prd = model(data_in)[0] 37 | 38 | loss = CE_loss(prd, error, msk) 39 | ave_loss += loss.item() 40 | 41 | loss.backward() 42 | optimizer.step() 43 | if batch_idx % log_interval == 0: 44 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.10f}'.format( 45 | epoch, batch_idx * len(data_in), len(train_loader.dataset), 46 | 100. * batch_idx / len(train_loader), loss.item())) 47 | 48 | ave_loss/=len(train_loader) 49 | print('\nTraining set: Average loss: {:.7f}\n'.format(ave_loss)) 50 | return ave_loss 51 | 52 | 53 | def test(model, device, test_loader, nb): 54 | model.eval() 55 | test_loss = 0 56 | with torch.no_grad(): 57 | for sample in tqdm(test_loader, 'Validation'): 58 | data_in, error, msk = sample['data_in'].to(device), sample['error'].to(device), sample['msk'].to(device) 59 | prd = model(data_in)[0] 60 | loss = CE_loss(prd, error, msk) 61 | 62 | test_loss += loss.item() 63 | 64 | test_loss/= len(test_loader) 65 | 66 | print('\nTest set: Average loss: {:.7f}\n'.format(test_loss)) 67 | 68 | return test_loss 69 | 70 | 71 | def save_arguments(args): 72 | f = open(join(args.dir_result,'args.txt'),'w') 73 | f.write(repr(args)+'\n') 74 | f.close() 75 | 76 | 77 | def mkdir(dir1): 78 | if not os.path.exists(dir1): 79 | os.makedirs(dir1) 80 | print('make directory %s' % dir1) 81 | 82 | 83 | def init_params(args, model, optimizer): 84 | loss_train=[] 85 | loss_val=[] 86 | 87 | start_epoch = 1 88 | state_dict_best = None 89 | loss_val_min = None 90 | if args.resume == True: 91 | f_checkpoint = join(args.dir_result, 'checkpoint.tar') 92 | if os.path.isfile(f_checkpoint): 93 | print('Resume training') 94 | checkpoint = torch.load(f_checkpoint) 95 | model.load_state_dict(checkpoint['state_dict']) 96 | optimizer.load_state_dict(checkpoint['optimizer']) 97 | start_epoch = checkpoint['epoch'] + 1 98 | loss_train, loss_val = checkpoint['loss'] 99 | loss_val_min = checkpoint['loss_val_min'] 100 | state_dict_best = checkpoint['state_dict_best'] 101 | else: 102 | print('No checkpoint file is found.') 103 | 104 | return loss_train, loss_val, start_epoch, state_dict_best, loss_val_min 105 | 106 | 107 | def save_checkpoint(epoch, model, optimizer, loss_train, loss_val, 
loss_val_min, state_dict_best, args): 108 | if epoch == 1: 109 | loss_val_min = loss_val[-1] 110 | state_dict_best = copy.deepcopy( model.state_dict() ) 111 | elif loss_val[-1] < loss_val_min: 112 | loss_val_min = loss_val[-1] 113 | state_dict_best = copy.deepcopy( model.state_dict() ) 114 | 115 | state = {'epoch': epoch, 116 | 'state_dict': model.state_dict(), 117 | 'optimizer': optimizer.state_dict(), 118 | 'loss': [loss_train, loss_val], 119 | 'loss_val_min': loss_val_min, 120 | 'state_dict_best': state_dict_best} 121 | 122 | torch.save(state, join(args.dir_result, 'checkpoint.tar')) 123 | if epoch % 5 == 0: 124 | torch.save(state, join(args.dir_result, 'checkpoint_%d.tar' % epoch)) 125 | 126 | return loss_val_min, state_dict_best 127 | 128 | 129 | def plot_and_save_loss_curve(epoch, loss_train, loss_val): 130 | plt.close('all') 131 | plt.figure() 132 | t=np.arange(1,epoch+1) 133 | plt.plot(t,loss_train,'b.-') 134 | plt.plot(t,loss_val,'r.-') 135 | plt.grid() 136 | plt.legend(['training loss','testing loss'],loc='best') 137 | plt.xlabel('Epoch') 138 | plt.ylabel('Loss') 139 | plt.yscale('log') 140 | plt.title('loss in logscale') 141 | plt.savefig(join(args.dir_result, 'loss.png')) 142 | 143 | 144 | def init_env(): 145 | torch.manual_seed(args.seed) 146 | use_cuda = torch.cuda.is_available() 147 | device = torch.device("cuda" if use_cuda else "cpu") 148 | cudnn.benchmark = True if use_cuda else False 149 | return device 150 | 151 | 152 | def main(args): 153 | 154 | if args.dir_data == None: 155 | this_dir = os.path.dirname(__file__) 156 | args.dir_data = join(this_dir, '..', 'data') 157 | 158 | args.dir_nuscenes = join(args.dir_data, 'nuscenes') 159 | 160 | if not args.dir_result: 161 | args.dir_result = join(args.dir_data, 'train_result', '%d_%d_%d' % (args.left_right, args.top, args.bottom)) 162 | mkdir(args.dir_result) 163 | 164 | args.nb = sparse_neighbor_connection(*(args.left_right, args.left_right, args.top, args.bottom, args.skip)) 165 | args.outChannels = len(args.nb.xy) 166 | print('output channels: ', args.outChannels) 167 | 168 | args.path_data_file = join(args.dir_data, 'prepared_data.h5') 169 | save_arguments(args) 170 | device = init_env() 171 | 172 | model = PyramidCNN(args.nLevels, args.nPred, args.nPerBlock, 173 | args.nChannels, args.inChannels, args.outChannels, 174 | args.doRes, args.doBN, doELU=False, 175 | predPix=False, predBoxes=False).to(device) 176 | 177 | optimizer = torch.optim.RMSprop(model.parameters(), 178 | lr = args.lr, 179 | weight_decay = 0, 180 | momentum = args.momentum) 181 | 182 | loss_train, loss_val, start_epoch, state_dict_best, loss_val_min = \ 183 | init_params(args, model, optimizer) 184 | 185 | train_loader = init_data_loader(args, 'train') 186 | val_loader = init_data_loader(args, 'val') 187 | 188 | for epoch in range(start_epoch, args.epochs + 1): 189 | start = timer() 190 | 191 | loss_train.append(train(args.log_interval, model, device, train_loader, optimizer, epoch, args.nb)) 192 | loss_val_epoch = test(model, device, val_loader, args.nb) 193 | 194 | loss_val.append(loss_val_epoch) 195 | 196 | loss_val_min, state_dict_best = save_checkpoint(epoch, model, optimizer, loss_train, loss_val, loss_val_min, state_dict_best, args) 197 | plot_and_save_loss_curve(epoch, loss_train, loss_val) 198 | 199 | end = timer(); t = (end - start) / 60; print('Time used: %.1f minutes\n' % t) 200 | 201 | if args.do_test: 202 | test_loader = init_data_loader(args, 'test') 203 | f_checkpoint = join(args.dir_result, 'checkpoint.tar') 204 | if 
os.path.isfile(f_checkpoint): 205 | print('load best model') 206 | checkpoint = torch.load(f_checkpoint) 207 | model.load_state_dict(checkpoint['state_dict_best']) 208 | loss_test = test(model, device, test_loader, args.nb) 209 | print('testing loss:', loss_test) 210 | 211 | if __name__ == '__main__': 212 | parser = argparse.ArgumentParser(description='training parameters') 213 | parser.add_argument('--dir_data', type=str) 214 | parser.add_argument('--dir_result', type=str) 215 | parser.add_argument('--version', type=str, default='v1.0-trainval', help='dataset split') 216 | 217 | parser.add_argument('--seed', type=int, default=1) 218 | parser.add_argument('--epochs', type=int, default=10) 219 | parser.add_argument('--resume', action='store_true', default=False, help='resume training from checkpoint') 220 | parser.add_argument('--batch_size', type=int, default=8) 221 | parser.add_argument('--test_batch_size', type=int, default=4) 222 | parser.add_argument('--log_interval', type=int, default=5) 223 | parser.add_argument('--lr', type=float, default=5e-5, help='Learning rate') 224 | parser.add_argument('--momentum', type=float, default=0.9) 225 | parser.add_argument('--num_workers', type=int, default=0) 226 | parser.add_argument('--no_data_shuffle', type=bool, default=False) 227 | 228 | parser.add_argument('--nLevels', type=int, default=5) 229 | parser.add_argument('--nPred', type=int, default=1) 230 | parser.add_argument('--nPerBlock', type=int, default=1) 231 | parser.add_argument('--nChannels', type=int, default=64) 232 | parser.add_argument('--inChannels', type=int, default=8) 233 | parser.add_argument('--doRes', type=bool, default=True) 234 | parser.add_argument('--doBN', type=bool, default=True) 235 | parser.add_argument('--do_test', type=bool, default=True, help='compute loss for testing set') 236 | 237 | parser.add_argument('--left_right', type=int, default=4) 238 | parser.add_argument('--top', type=int, default=10) 239 | parser.add_argument('--bottom', type=int, default=4) 240 | parser.add_argument('--skip', type=int, default=1) 241 | 242 | args = parser.parse_args() 243 | main(args) 244 | 245 | -------------------------------------------------------------------------------- /scripts/pt_wise_error.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | import numpy as np 4 | import argparse 5 | import skimage.io as io 6 | from tqdm import tqdm 7 | from nuscenes.nuscenes import NuScenes 8 | import torch 9 | from nuscenes.utils.data_classes import LidarPointCloud 10 | 11 | import _init_paths 12 | from gt_velocity import gt_box_key, cal_trans_matrix, proj2im, flow2uv_map 13 | 14 | 15 | def downsample_coord(xi_list, yi_list, downsample_scale=4, y_cutoff=33): 16 | 17 | h_im, w_im = 900, 1600 18 | h_new = int( h_im / downsample_scale ) 19 | w_new = int( w_im / downsample_scale ) 20 | 21 | xi_list_new = (xi_list + 0.5) / downsample_scale - 0.5 22 | yi_list_new = (yi_list + 0.5) / downsample_scale - 0.5 23 | 24 | xi_list_new = np.clip(xi_list_new, 0, w_new - 1) 25 | yi_list_new = np.clip(yi_list_new, 0, h_new - 1) - y_cutoff 26 | 27 | return xi_list_new, yi_list_new 28 | 29 | 30 | def upsample_coord(xi_list, yi_list, downsample_scale=4, y_cutoff=33): 31 | 32 | msk = xi_list == -1 33 | 34 | h_im, w_im = 900, 1600 35 | 36 | xi_list_new = xi_list 37 | yi_list_new = yi_list + y_cutoff 38 | 39 | xi_list_new *= downsample_scale 40 | yi_list_new *= downsample_scale 41 | 42 | xi_list_new = np.clip(xi_list_new, 0, w_im-1) 43 | 
    yi_list_new = np.clip(yi_list_new, 0, h_im-1)

    xi_list_new[msk] = -1
    yi_list_new[msk] = -1

    return xi_list_new, yi_list_new


def correct_coord(xi_list, yi_list, rd_offset):

    xi_list_new = []
    yi_list_new = []

    for xi, yi in zip(xi_list, yi_list):
        x_one, y_one = int(round( xi )), int(round( yi ))
        dx, dy = rd_offset[y_one, x_one]
        if [dx, dy] != [-1000, -1000]:
            x_new, y_new = x_one + dx, y_one + dy
        else:
            x_new, y_new = -1, -1
        xi_list_new.append(x_new)
        yi_list_new.append(y_new)

    xi_list_new = np.array(xi_list_new)
    yi_list_new = np.array(yi_list_new)

    return xi_list_new, yi_list_new


def get_im_pair(nusc, cam_token):
    cam_data = nusc.get('sample_data', cam_token)
    K = np.array( nusc.get('calibrated_sensor', cam_data['calibrated_sensor_token'])['camera_intrinsic'] )
    f = K[0,0]
    cx = K[0,2]
    cy = K[1,2]

    cam_path = join(nusc.dataroot, cam_data['filename'])
    im1 = io.imread(cam_path)

    cam_token2 = cam_data['next']
    cam_data2 = nusc.get('sample_data', cam_token2)
    cam_path2 = join(nusc.dataroot, cam_data2['filename'])
    im2 = io.imread(cam_path2)

    dt = (cam_data2['timestamp'] - cam_data['timestamp']) * 1e-6

    return im1, im2, f, cx, cy, dt, cam_token2


def cal_full_v_in_radar(vx, vy, d, u1, v1, u2, v2, T_c2r, T_c2c, dt):
    # output in radar coordinates
    r11, r12, r13 = T_c2r[0,:3]
    r21, r22, r23 = T_c2r[1,:3]

    ra11, ra12, ra13, btx = T_c2c[0,:]
    ra21, ra22, ra23, bty = T_c2c[1,:]
    ra31, ra32, ra33, btz = T_c2c[2,:]

    A = np.array([[ra11-u2*ra31, ra12-u2*ra32, ra13-u2*ra33], \
                  [ra21-v2*ra31, ra22-v2*ra32, ra23-v2*ra33], \
                  [r11*vx+r21*vy, r12*vx+r22*vy, r13*vx+r23*vy]] )

    b = np.array([[((ra31*u1+ra32*v1+ra33)*u2-(ra11*u1+ra12*v1+ra13))*d+u2*btz-btx],\
                  [((ra31*u1+ra32*v1+ra33)*v2-(ra21*u1+ra22*v1+ra23))*d+v2*btz-bty],\
                  [(vx**2 + vy**2)*dt]])

    x = np.squeeze( np.dot( np.linalg.inv(A), b ) )

    vx_c, vy_c, vz_c = x[0]/dt, x[1]/dt, x[2]/dt

    vr = np.squeeze( np.dot(T_c2r[:3,:3], np.array([[vx_c], [vy_c], [vz_c]])) )

    vx_f, vy_f, vz_f = vr[0], vr[1], vr[2]

    return vx_f, vy_f, vz_f


def decompose_v(vx, vy, x, y):
    '''
    Decompose full velocity into radial and tangential components
    inputs:
        x, y: radar coordinates
        vx, vy: full velocity
    outputs:
        radial_v: np.array([vx_radial, vy_radial])
        tangent_v: np.array([vx_tangent, vy_tangent])
    '''
    radial_v = (vx * x + vy * y) * np.array([x, y]) / (x**2 + y**2)
    tangent_v = np.array([vx, vy]) - radial_v

    return radial_v, tangent_v


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dir_data', type=str)
    parser.add_argument('--version', type=str, default='v1.0-trainval', help='dataset split')

    args = parser.parse_args()

    if args.dir_data is None:
        this_dir = os.path.dirname(__file__)
        args.dir_data = join(this_dir, '..', 'data')

    args.dir_nuscenes = join(args.dir_data, 'nuscenes')
    dir_files = join(args.dir_data, 'prepared_data')

    sample_indices = torch.load(join(args.dir_data, 'sample_split.tar'))['test_sample_indices']
    nusc = NuScenes(args.version, dataroot=args.dir_nuscenes, verbose=False)

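    # Editor's summary of cal_full_v_in_radar above (same symbols as the code):
    # for a point at camera-frame depth d, seen at normalized coordinates
    # (u1, v1) and flowed to (u2, v2), the unknown camera-frame displacement
    # x = (dx, dy, dz) over dt satisfies a 3x3 linear system A x = b:
    #   rows 1-2: the reprojection constraint that the moved point, mapped
    #             through the camera1 -> camera2 transform T_c2c (rotation
    #             ra**, translation bt*), lands at (u2, v2);
    #   row 3:    the Doppler constraint that the radar-frame displacement,
    #             projected onto the measured radial-velocity direction
    #             (vx, vy), matches the radial travel (vx^2 + vy^2) * dt.
    # Solving for x and dividing by dt gives the full 3D velocity, which is
    # then rotated into the radar frame with T_c2r.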
    error_list = []
    error_radial_list = []
    error_tan_list = []

    error_list2 = []
    error_radial_list2 = []
    error_tan_list2 = []

    N = len(sample_indices)
    d_min = 0; d_max = float('inf')

    for idx in tqdm(range(N)):
        sample_idx = sample_indices[idx]
        cam_token1 = nusc.sample[sample_idx]['data']['CAM_FRONT']
        im1, im2, f, cx, cy, dt, cam_token2 = get_im_pair(nusc, cam_token1)

        gt = gt_box_key(nusc, sample_idx)
        rd_token = gt.radar_token

        gt_rd = gt.radar
        x_list, y_list, vx_list, vy_list, msk_gt = gt_rd['x'], gt_rd['y'], gt_rd['vx'], gt_rd['vy'], gt_rd['having_truth']

        pc = LidarPointCloud( np.stack([x_list, y_list, np.zeros_like(x_list), np.ones_like(x_list)]) )

        T_r2c = cal_trans_matrix(nusc, rd_token, cam_token1)
        T_c2r = cal_trans_matrix(nusc, cam_token1, rd_token)
        T_c2c = cal_trans_matrix(nusc, cam_token1, cam_token2)

        pc.transform(T_r2c)

        xi_list, yi_list, d_list, msk_in_im = proj2im(nusc, pc, cam_token1)
        xi_list, yi_list = downsample_coord(xi_list, yi_list, downsample_scale=4, y_cutoff=33)

        rd_offset = np.load(join(dir_files, '%05d_offset.npy' % sample_idx))
        flow = np.load(join(dir_files, '%05d_full_flow.npy' % sample_idx))
        u1_map, v1_map, u2_map, v2_map = flow2uv_map(flow, cx, cy, f)

        xi_list, yi_list = correct_coord(xi_list, yi_list, rd_offset)
        xi_list, yi_list = upsample_coord(xi_list, yi_list, downsample_scale=4, y_cutoff=33)

        prd_vxf = []
        prd_vyf = []
        prd_msk = []
        for xi, yi, vx, vy, d, is_in_im in zip(xi_list, yi_list, vx_list, vy_list, d_list, msk_in_im):

            if [xi, yi] != [-1, -1] and d_min <= d < d_max and is_in_im:
                xp, yp = int(round(xi)), int(round(yi))
                u1, v1, u2, v2 = u1_map[yp,xp], v1_map[yp,xp], u2_map[yp,xp], v2_map[yp,xp]
                vx_f, vy_f, vz_f = cal_full_v_in_radar(vx, vy, d, u1, v1, u2, v2, T_c2r, T_c2c, dt)

                prd_vxf.append(vx_f)
                prd_vyf.append(vy_f)
                prd_msk.append(True)
            else:
                prd_vxf.append(0)
                prd_vyf.append(0)
                prd_msk.append(False)

        vxf_list, vyf_list, gt_msk = gt_rd['vxf'], gt_rd['vyf'], gt_rd['having_truth']

        # error for the proposed method
        for x, y, vx, vy, gt_vx, gt_vy, gt_valid, prd_valid in zip(x_list, y_list, prd_vxf, prd_vyf, vxf_list, vyf_list, gt_msk, prd_msk):
            if gt_valid and prd_valid:
                error = ( (vx - gt_vx)**2 + (vy - gt_vy)**2 ) ** 0.5

                rad_v, tan_v = decompose_v(vx, vy, x, y)
                gt_rad_v, gt_tan_v = decompose_v(gt_vx, gt_vy, x, y)

                error_radial = np.linalg.norm(rad_v - gt_rad_v)
                error_tan = np.linalg.norm(tan_v - gt_tan_v)

                error_list.append(error)
                error_radial_list.append(error_radial)
                error_tan_list.append(error_tan)

        # error for the baseline (raw Doppler velocity)
        for x, y, vx, vy, gt_vx, gt_vy, gt_valid, prd_valid in zip(x_list, y_list, vx_list, vy_list, vxf_list, vyf_list, gt_msk, prd_msk):
            if gt_valid and prd_valid:
                error = ( (vx - gt_vx)**2 + (vy - gt_vy)**2 ) ** 0.5

                rad_v, tan_v = decompose_v(vx, vy, x, y)
                gt_rad_v, gt_tan_v = decompose_v(gt_vx, gt_vy, x, y)

                error_radial = np.linalg.norm(rad_v - gt_rad_v)
                error_tan = np.linalg.norm(tan_v - gt_tan_v)

                error_list2.append(error)
                error_radial_list2.append(error_radial)
                error_tan_list2.append(error_tan)

    ave_error = np.mean(error_list)
    std_error = 
213 |         # error for the proposed method
214 |         for x, y, vx, vy, gt_vx, gt_vy, gt_valid, prd_valid in zip(x_list, y_list, prd_vxf, prd_vyf, vxf_list, vyf_list, gt_msk, prd_msk):
215 |             if gt_valid and prd_valid:
216 |                 error = ( (vx - gt_vx)**2 + (vy - gt_vy)**2 ) ** 0.5
217 |
218 |                 rad_v, tan_v = decompose_v(vx, vy, x, y)
219 |                 gt_rad_v, gt_tan_v = decompose_v(gt_vx, gt_vy, x, y)
220 |
221 |                 error_radial = np.linalg.norm(rad_v - gt_rad_v)
222 |                 error_tan = np.linalg.norm(tan_v - gt_tan_v)
223 |
224 |                 error_list.append(error)
225 |                 error_radial_list.append(error_radial)
226 |                 error_tan_list.append(error_tan)
227 |
228 |         # error for the baseline
229 |         for x, y, vx, vy, gt_vx, gt_vy, gt_valid, prd_valid in zip(x_list, y_list, vx_list, vy_list, vxf_list, vyf_list, gt_msk, prd_msk):
230 |             if gt_valid and prd_valid:
231 |                 error = ( (vx - gt_vx)**2 + (vy - gt_vy)**2 ) ** 0.5
232 |
233 |                 rad_v, tan_v = decompose_v(vx, vy, x, y)
234 |                 gt_rad_v, gt_tan_v = decompose_v(gt_vx, gt_vy, x, y)
235 |
236 |                 error_radial = np.linalg.norm(rad_v - gt_rad_v)
237 |                 error_tan = np.linalg.norm(tan_v - gt_tan_v)
238 |
239 |                 error_list2.append(error)
240 |                 error_radial_list2.append(error_radial)
241 |                 error_tan_list2.append(error_tan)
242 |
243 |     ave_error = np.mean(error_list)
244 |     std_error = np.std(error_list)
245 |
246 |     ave_radial_error = np.mean(error_radial_list)
247 |     std_radial_error = np.std(error_radial_list)
248 |
249 |     ave_tan_error = np.mean(error_tan_list)
250 |     std_tan_error = np.std(error_tan_list)
251 |
252 |     print('Ours')
253 |     print('ave_error:', ave_error)
254 |     print('std_error:', std_error)
255 |     print('ave_radial_error:', ave_radial_error)
256 |     print('std_radial_error:', std_radial_error)
257 |     print('ave_tan_error:', ave_tan_error)
258 |     print('std_tan_error:', std_tan_error)
259 |
260 |     ave_error2 = np.mean(error_list2)
261 |     std_error2 = np.std(error_list2)
262 |
263 |     ave_radial_error2 = np.mean(error_radial_list2)
264 |     std_radial_error2 = np.std(error_radial_list2)
265 |
266 |     ave_tan_error2 = np.mean(error_tan_list2)
267 |     std_tan_error2 = np.std(error_tan_list2)
268 |
269 |     print('Baseline')
270 |     print('ave_error2:', ave_error2)
271 |     print('std_error2:', std_error2)
272 |     print('ave_radial_error2:', ave_radial_error2)
273 |     print('std_radial_error2:', std_radial_error2)
274 |     print('ave_tan_error2:', ave_tan_error2)
275 |     print('std_tan_error2:', std_tan_error2)
276 |
277 |
--------------------------------------------------------------------------------
/lib/gt_velocity.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from pyquaternion import Quaternion
4 | from functools import reduce
5 | from shapely.geometry import Point, MultiPoint
6 | from nuscenes.utils.data_classes import RadarPointCloud, LidarPointCloud
7 | from nuscenes.utils.geometry_utils import view_points, transform_matrix
8 |
9 |
10 | def cal_matrix_refSensor_from_global(nusc, sensor_token):
11 |     sensor_data = nusc.get('sample_data', sensor_token)
12 |     ref_pose_rec = nusc.get('ego_pose', sensor_data['ego_pose_token'])
13 |     ref_cs_rec = nusc.get('calibrated_sensor', sensor_data['calibrated_sensor_token'])
14 |     ref_from_car = transform_matrix(ref_cs_rec['translation'], Quaternion(ref_cs_rec['rotation']), inverse=True)
15 |     car_from_global = transform_matrix(ref_pose_rec['translation'], Quaternion(ref_pose_rec['rotation']), inverse=True)
16 |     M_ref_from_global = reduce(np.dot, [ref_from_car, car_from_global])
17 |     return M_ref_from_global
18 |
19 |
20 | def cal_matrix_refSensor_to_global(nusc, sensor_token):
21 |     sensor_data = nusc.get('sample_data', sensor_token)
22 |     current_pose_rec = nusc.get('ego_pose', sensor_data['ego_pose_token'])
23 |     global_from_car = transform_matrix(current_pose_rec['translation'],
24 |                                        Quaternion(current_pose_rec['rotation']), inverse=False)
25 |     current_cs_rec = nusc.get('calibrated_sensor', sensor_data['calibrated_sensor_token'])
26 |     car_from_current = transform_matrix(current_cs_rec['translation'], Quaternion(current_cs_rec['rotation']), inverse=False)
27 |     M_ref_to_global = reduce(np.dot, [global_from_car, car_from_current])
28 |     return M_ref_to_global
29 |
30 |
31 | def cal_trans_matrix(nusc, sensor1_token, sensor2_token):
32 |     '''
33 |     calculate transformation matrix from sensor1 to sensor2 (4 x 4)
34 |     '''
35 |     M_ref_to_global = cal_matrix_refSensor_to_global(nusc, sensor1_token)
36 |     M_ref_from_global = cal_matrix_refSensor_from_global(nusc, sensor2_token)
37 |     trans_matrix = reduce(np.dot, [M_ref_from_global, M_ref_to_global])
38 |     return trans_matrix
39 |
40 |
41 | def flow2uv_map(flow, cx, cy, f):
42 |     # flow: h x w x 2
43 |     h, w = flow.shape[:2]
44 |     x1_map, y1_map = np.meshgrid(np.arange(0, w), np.arange(0, h))
45 |     dx_map, dy_map = flow[...,0], flow[...,1]
46 |     x2_map, y2_map = x1_map + dx_map, y1_map + dy_map
47 |     u1_map, v1_map, u2_map, v2_map = (x1_map - cx)/f, (y1_map - cy)/f, (x2_map - cx)/f, (y2_map - cy)/f
48 |
49 |     return u1_map, v1_map, u2_map, v2_map
50 |
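# flow2uv_map converts a dense optical-flow field to normalized image
# coordinates: pixel (x, y) maps to (u, v) = ((x - cx)/f, (y - cy)/f), so
# (u1, v1) is a pixel's normalized position in the first image and (u2, v2)
# its flow-displaced position in the second. Illustrative numbers (not from
# the dataset): with f = 1000, cx = 800, cy = 450, the pixel (1000, 450) with
# a flow of (10, 0) gives (u1, v1) = (0.2, 0.0) and (u2, v2) = (0.21, 0.0).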
51 |
52 | def proj2im(nusc, pc_cam, cam_token, min_z = 2):
53 |     cam_data = nusc.get('sample_data', cam_token)
54 |     cs_rec = nusc.get('calibrated_sensor', cam_data['calibrated_sensor_token'])
55 |     depth = pc_cam.points[2]
56 |     msk = pc_cam.points[2] >= min_z
57 |     points = view_points(pc_cam.points[:3, :], np.array(cs_rec['camera_intrinsic']), normalize=True)
58 |     x, y = points[0], points[1]
59 |     msk = reduce(np.logical_and, [x>0, x<1600, y>0, y<900, msk])   # nuScenes camera images are 1600 x 900
60 |     return x, y, depth, msk
61 |
62 |
63 | def loadMovingRadar(nusc, radar_token, disable_filters = True):
64 |     radar_sample = nusc.get('sample_data', radar_token)
65 |     pcl_path = os.path.join(nusc.dataroot, radar_sample['filename'])
66 |     if disable_filters:
67 |         RadarPointCloud.disable_filters()
68 |     pc = RadarPointCloud.from_file(pcl_path)
69 |
70 |     pts = pc.points
71 |     dynamic_prop = pts[3]
72 |     msk_mv = reduce(np.logical_or, [dynamic_prop==0, dynamic_prop==2, dynamic_prop==6]) # moving mask: dynProp 0 (moving), 2 (oncoming), 6 (crossing moving)
73 |
74 |     x, y = pts[0], pts[1]
75 |     vx, vy = pts[8], pts[9]   # ego-motion-compensated velocity (vx_comp, vy_comp)
76 |
77 |     x, y = x[msk_mv], y[msk_mv]
78 |     vx, vy = vx[msk_mv], vy[msk_mv]
79 |
80 |     return x,y,vx,vy
81 |
82 |
83 | class gt_box_key:
84 |     def __init__(self, nusc, sample_idx, thres_v = 0.4, disable_radar_filters = True):
85 |
86 |         def judge_moving(v, thres_v):
87 |             if np.isnan( v[0] ):
88 |                 return False
89 |             v_L2 = (v[0] ** 2 + v[1] ** 2) ** 0.5
90 |             if v_L2 > thres_v:
91 |                 return True
92 |             else:
93 |                 return False
94 |
95 |         sample = nusc.sample[sample_idx]
96 |
97 |         self.cam_token = nusc.sample[sample_idx]['data']['CAM_FRONT']
98 |         self.radar_token = nusc.sample[sample_idx]['data']['RADAR_FRONT']
99 |
100 |         self.boxes_world = {} # boxes in world coordinates
101 |         for ann_token in sample['anns']:
102 |             ann = nusc.get('sample_annotation', ann_token)
103 |             v = nusc.box_velocity(ann['token'])
104 |             is_moving = judge_moving(v, thres_v)
105 |             if is_moving and ( 'vehicle' in ann['category_name'] ):
106 |                 obj_token = ann['instance_token']
107 |                 self.boxes_world[obj_token] = { k : ann[k] for k in ['translation', 'size', 'rotation', 'instance_token', 'category_name'] }
108 |                 self.boxes_world[obj_token]['v'] = v
109 |
110 |         self.boxes_radar = {} # boxes in radar coordinates
111 |         M_radar_from_global = cal_matrix_refSensor_from_global(nusc, self.radar_token)
112 |         for obj_token in self.boxes_world:
113 |             box = self.boxes_world[obj_token]
114 |             M_object_to_global = transform_matrix(box['translation'], Quaternion(box['rotation']), inverse=False)
115 |             w,l,h = box['size']
116 |             vx, vy, vz = box['v'] # global coordinates
117 |             corners_l = np.array([[-l/2,l/2,l/2,-l/2],[-w/2,-w/2,w/2,w/2], [-h/2,-h/2,-h/2,-h/2], [0,0,0,0]])
118 |             corners_h = np.array([[-l/2,l/2,l/2,-l/2],[-w/2,-w/2,w/2,w/2], [h/2,h/2,h/2,h/2], [0,0,0,0]])
119 |             center = np.array([[0],[0],[0],[0]])
120 |             v = np.array([[0,vx],[0,vy], [0,vz], [0,0]])
121 |             keyPts = LidarPointCloud(np.concatenate([corners_l, corners_h, center], axis=1))
122 |             keyPts.transform(M_object_to_global)
123 |             pts = np.concatenate([keyPts.points,v],axis=1)
124 |             keyPts.points = pts
125 |             keyPts.transform(M_radar_from_global)
126 |
127 |             x_list = keyPts.points[0, 0:-3]
128 |             y_list = keyPts.points[1, 0:-3]
129 |             box_center = keyPts.points[0:2, -3]
130 |
131 |             if box_center[0] > 1:
132 |                 vx = keyPts.points[0,-1] - keyPts.points[0,-2]
133 |                 vy = keyPts.points[1,-1]
- keyPts.points[1,-2] 134 | vz = keyPts.points[2,-1] - keyPts.points[2,-2] 135 | polygon = MultiPoint( [(x,y) for (x,y) in zip(x_list, y_list)] ).convex_hull 136 | v =(vx, vy, vz) 137 | self.boxes_radar[obj_token] = {'v': v, 'center': box_center, 'polygon': polygon} 138 | 139 | self.boxes_im= {} # boxes on image 140 | M_cam_from_global = cal_matrix_refSensor_from_global(nusc, self.cam_token) 141 | for obj_token in self.boxes_world: 142 | box = self.boxes_world[obj_token] 143 | M_object_to_global = transform_matrix(box['translation'], Quaternion(box['rotation']), inverse=False) 144 | w,l,h = box['size'] 145 | vx, vy, vz = box['v'] # global coordinates 146 | corners_l = np.array([[-l/2,l/2,l/2,-l/2],[-w/2,-w/2,w/2,w/2], [-h/2,-h/2,-h/2,-h/2], [0,0,0,0]]) 147 | corners_h = np.array([[-l/2,l/2,l/2,-l/2],[-w/2,-w/2,w/2,w/2], [h/2,h/2,h/2,h/2], [0,0,0,0]]) 148 | center = np.array([[0],[0],[0],[0]]) 149 | v = np.array([[0,vx],[0,vy], [0,vz], [0,0]]) 150 | keyPts = LidarPointCloud(np.concatenate([corners_l, corners_h, center], axis=1)) 151 | keyPts.transform(M_object_to_global) 152 | pts = np.concatenate([keyPts.points,v],axis=1) 153 | keyPts.points = pts 154 | keyPts.transform(M_cam_from_global) 155 | d_box = keyPts.points[2,-3] 156 | 157 | if d_box > 1: 158 | keyPts.points = keyPts.points[:, 0:-3] 159 | xs, ys, _, msks = proj2im(nusc, keyPts, self.cam_token, min_z = 2) 160 | polygon = MultiPoint( [(x,y) for (x,y) in zip(xs,ys)] ).convex_hull 161 | vx = keyPts.points[0,-1] - keyPts.points[0,-2] 162 | vy = keyPts.points[1,-1] - keyPts.points[1,-2] 163 | vz = keyPts.points[2,-1] - keyPts.points[2,-2] 164 | v =(vx, vy, vz) 165 | self.boxes_im[obj_token] = {'v': v, 'd_box_center': d_box, 'polygon': polygon } 166 | 167 | obj_list = [] 168 | vxf_list, vyf_list = [], [] 169 | vxf_w_list, vyf_w_list = [], [] # in global coordinates 170 | having_truth = [] 171 | x_list, y_list, vx_list, vy_list = loadMovingRadar(nusc, self.radar_token, disable_filters = disable_radar_filters) 172 | 173 | # filter out points outside image 174 | pc = LidarPointCloud( np.stack([x_list, y_list, np.zeros_like(x_list), np.ones_like(x_list)]) ) 175 | T_r2c = cal_trans_matrix(nusc, self.radar_token, self.cam_token) 176 | pc.transform(T_r2c) 177 | x_i, y_i, depth, msk = proj2im(nusc, pc, self.cam_token) 178 | x_i, y_i, depth = x_i[msk], y_i[msk], depth[msk] 179 | x_list, y_list, vx_list, vy_list = x_list[msk], y_list[msk], vx_list[msk], vy_list[msk] 180 | 181 | for x,y,vx,vy in zip(x_list, y_list, vx_list, vy_list): 182 | p1 = Point(x,y) 183 | box_found = False 184 | for obj_token in self.boxes_radar: 185 | box = self.boxes_radar[obj_token] 186 | poly = box['polygon'] 187 | if p1.distance(poly) < 0.5: 188 | vxf, vyf = box['v'][:2] 189 | vxf_w, vyf_w = self.boxes_world[obj_token]['v'][:2] 190 | error = abs( (vxf*vx + vyf*vy)/(vx**2 + vy**2) - 1 ) 191 | if error < 0.2: 192 | obj_list.append(obj_token) 193 | vxf_list.append(vxf) 194 | vyf_list.append(vyf) 195 | vxf_w_list.append(vxf_w) 196 | vyf_w_list.append(vyf_w) 197 | having_truth.append(True) 198 | box_found = True 199 | break 200 | if box_found == False: 201 | obj_list.append('') 202 | vxf_list.append(0) 203 | vyf_list.append(0) 204 | vxf_w_list.append(0) 205 | vyf_w_list.append(0) 206 | having_truth.append(False) 207 | 208 | # (x,y) radar coordinates; (x_i, y_i, depth) image depth 209 | # (vx,vy): radial velocity; (vxf,vyf): GT full velocity 210 | self.radar = {'x': x_list, 'y': y_list, 'x_i': x_i, 'y_i': y_i, 'depth': depth, 'vx': vx_list, 'vy':vy_list, 'vxf': vxf_list, 
'vyf': vyf_list, 'vxf_w': vxf_w_list, 'vyf_w': vyf_w_list, 'having_truth': having_truth, 'obj_list': obj_list} 211 | 212 | def v_label_exist(self, thres_n_gt_pts=2): 213 | having_truth = self.radar['having_truth'] 214 | if len(having_truth)>0 and np.sum(having_truth) >= thres_n_gt_pts: 215 | return True 216 | else: 217 | return False 218 | 219 | 220 | def cal_full_v(vx, vy, d, u1, v1, u2, v2, T_c2r, T_c2c, dt): 221 | ''' 222 | inputs: 223 | (vx,vy): radial velocity in radar coordinates 224 | d: depth in camera coordinates 225 | u1,v1,u2,v2: image flow 226 | outputs: 227 | full velocity in camera coordinates 228 | ''' 229 | r11, r12, r13 = T_c2r[0,:3] 230 | r21, r22, r23 = T_c2r[1,:3] 231 | 232 | ra11, ra12, ra13, btx = T_c2c[0,:] 233 | ra21, ra22, ra23, bty = T_c2c[1,:] 234 | ra31, ra32, ra33, btz = T_c2c[2,:] 235 | 236 | A = np.array([[ra11-u2*ra31, ra12-u2*ra32, ra13-u2*ra33], \ 237 | [ra21-v2*ra31, ra22-v2*ra32, ra23-v2*ra33], \ 238 | [r11*vx+r21*vy, r12*vx+r22*vy, r13*vx+r23*vy]] ) 239 | 240 | b = np.array([[((ra31*u1+ra32*v1+ra33)*u2-(ra11*u1+ra12*v1+ra13))*d+u2*btz-btx],\ 241 | [((ra31*u1+ra32*v1+ra33)*v2-(ra21*u1+ra22*v1+ra23))*d+v2*btz-bty],\ 242 | [(vx**2 + vy**2)*dt]]) 243 | 244 | x = np.squeeze( np.dot( np.linalg.inv(A), b ) ) 245 | 246 | vx_c, vy_c, vz_c = x[0]/dt, x[1]/dt, x[2]/dt 247 | 248 | return (vx_c, vy_c, vz_c) 249 | 250 | 251 | def transform_velocity(v, T): 252 | vx, vy, vz = v[0], v[1], v[2] 253 | v2 = np.squeeze( np.dot(T[:3,:3], np.array([[vx], [vy], [vz]])) ) 254 | return v2 255 | 256 | 257 | def cal_depthMap(rd, uv, Tc2r, Tc2c, Tc2w, dt, nb, downsample_scale, y_cutoff): 258 | 259 | x_i, y_i, depth, vx, vy, msk = rd['x_i'], rd['y_i'], rd['depth'], rd['vx'], rd['vy'], rd['having_truth'] 260 | vx_gt, vy_gt = rd['vxf_w'], rd['vyf_w'] 261 | h_im, w_im = 900, 1600 262 | n_nb = len(nb.xy) 263 | h_new = int( h_im / downsample_scale ) - y_cutoff 264 | w_new = int( w_im / downsample_scale ) 265 | 266 | depth_map = np.zeros( (h_new, w_new) , dtype=float) 267 | msk_map = np.zeros( (h_new, w_new) , dtype=bool) 268 | 269 | vx_gt_map = np.zeros( (h_new, w_new) , dtype=float) 270 | vy_gt_map = np.zeros( (h_new, w_new) , dtype=float) 271 | 272 | vx_nb_map = np.zeros( (h_new, w_new, n_nb) , dtype=float) 273 | vy_nb_map = np.zeros( (h_new, w_new, n_nb) , dtype=float) 274 | vz_nb_map = np.zeros( (h_new, w_new, n_nb) , dtype=float) 275 | msk_nb_map = np.zeros( (h_new, w_new, n_nb) , dtype=bool) 276 | 277 | x_i = (x_i + 0.5) / downsample_scale - 0.5 278 | y_i = (y_i + 0.5) / downsample_scale - 0.5 - y_cutoff 279 | x_i = np.clip(x_i, 0, w_new - 1) 280 | y_i = np.clip(y_i, 0, h_new - 1) 281 | 282 | for i in range(len(x_i)): 283 | x_one, y_one = int(round( x_i[i] )), int(round( y_i[i] )) 284 | if depth_map[y_one,x_one] == 0 or depth_map[y_one,x_one] > depth[i]: 285 | depth_map[y_one,x_one] = depth[i] 286 | msk_map[y_one,x_one] = msk[i] 287 | for nb_idx, (ox,oy) in enumerate(nb.xy): 288 | x_new = x_one + ox 289 | y_new = y_one + oy 290 | if (0 <= x_new < w_new) and (0 <= y_new < h_new): 291 | msk_nb_map[y_one, x_one, nb_idx] = True 292 | u1, v1, u2, v2 = uv[y_new,x_new] 293 | v_cam = cal_full_v(vx[i], vy[i], depth[i], u1, v1, u2, v2, Tc2r, Tc2c, dt) 294 | v_world = transform_velocity(v_cam, Tc2w) 295 | vx_nb_map[y_one, x_one, nb_idx] = v_world[0] 296 | vy_nb_map[y_one, x_one, nb_idx] = v_world[1] 297 | vz_nb_map[y_one, x_one, nb_idx] = v_world[2] 298 | vx_gt_map[y_one, x_one] = vx_gt[i] 299 | vy_gt_map[y_one, x_one] = vy_gt[i] 300 | 301 | return depth_map, msk_map, msk_nb_map, vx_nb_map, 
vy_nb_map, vz_nb_map, vx_gt_map, vy_gt_map 302 | 303 | -------------------------------------------------------------------------------- /scripts/split_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Generate train/val/test split 3 | ''' 4 | import os 5 | from collections import defaultdict 6 | from os.path import join 7 | import numpy as np 8 | from tqdm import tqdm 9 | import argparse 10 | import torch 11 | from nuscenes.nuscenes import NuScenes 12 | 13 | import _init_paths 14 | from gt_velocity import gt_box_key 15 | 16 | np.random.seed(1) 17 | 18 | SCENE_SPLITS = { 19 | 'train': 20 | ['scene-0001', 'scene-0002', 'scene-0004', 'scene-0005', 'scene-0006', 'scene-0007', 'scene-0008', 'scene-0009', 21 | 'scene-0010', 'scene-0011', 'scene-0019', 'scene-0020', 'scene-0021', 'scene-0022', 'scene-0023', 'scene-0024', 22 | 'scene-0025', 'scene-0026', 'scene-0027', 'scene-0028', 'scene-0029', 'scene-0030', 'scene-0031', 'scene-0032', 23 | 'scene-0033', 'scene-0034', 'scene-0041', 'scene-0042', 'scene-0043', 'scene-0044', 'scene-0045', 'scene-0046', 24 | 'scene-0047', 'scene-0048', 'scene-0049', 'scene-0050', 'scene-0051', 'scene-0052', 'scene-0053', 'scene-0054', 25 | 'scene-0055', 'scene-0056', 'scene-0057', 'scene-0058', 'scene-0059', 'scene-0060', 'scene-0061', 'scene-0062', 26 | 'scene-0063', 'scene-0064', 'scene-0065', 'scene-0066', 'scene-0067', 'scene-0068', 'scene-0069', 'scene-0070', 27 | 'scene-0071', 'scene-0072', 'scene-0073', 'scene-0074', 'scene-0075', 'scene-0076', 'scene-0120', 'scene-0121', 28 | 'scene-0122', 'scene-0123', 'scene-0124', 'scene-0125', 'scene-0126', 'scene-0127', 'scene-0128', 'scene-0129', 29 | 'scene-0130', 'scene-0131', 'scene-0132', 'scene-0133', 'scene-0134', 'scene-0135', 'scene-0138', 'scene-0139', 30 | 'scene-0149', 'scene-0150', 'scene-0151', 'scene-0152', 'scene-0154', 'scene-0155', 'scene-0157', 'scene-0158', 31 | 'scene-0159', 'scene-0160', 'scene-0161', 'scene-0162', 'scene-0163', 'scene-0164', 'scene-0165', 'scene-0166', 32 | 'scene-0167', 'scene-0168', 'scene-0170', 'scene-0171', 'scene-0172', 'scene-0173', 'scene-0174', 'scene-0175', 33 | 'scene-0176', 'scene-0177', 'scene-0178', 'scene-0179', 'scene-0180', 'scene-0181', 'scene-0182', 'scene-0183', 34 | 'scene-0184', 'scene-0185', 'scene-0187', 'scene-0188', 'scene-0190', 'scene-0191', 'scene-0192', 'scene-0193', 35 | 'scene-0194', 'scene-0195', 'scene-0196', 'scene-0199', 'scene-0200', 'scene-0202', 'scene-0203', 'scene-0204', 36 | 'scene-0206', 'scene-0207', 'scene-0208', 'scene-0209', 'scene-0210', 'scene-0211', 'scene-0212', 'scene-0213', 37 | 'scene-0214', 'scene-0218', 'scene-0219', 'scene-0220', 'scene-0222', 'scene-0224', 'scene-0225', 'scene-0226', 38 | 'scene-0227', 'scene-0228', 'scene-0229', 'scene-0230', 'scene-0231', 'scene-0232', 'scene-0233', 'scene-0234', 39 | 'scene-0235', 'scene-0236', 'scene-0237', 'scene-0238', 'scene-0239', 'scene-0240', 'scene-0241', 'scene-0242', 40 | 'scene-0243', 'scene-0244', 'scene-0245', 'scene-0246', 'scene-0247', 'scene-0248', 'scene-0249', 'scene-0250', 41 | 'scene-0251', 'scene-0252', 'scene-0253', 'scene-0254', 'scene-0255', 'scene-0256', 'scene-0257', 'scene-0258', 42 | 'scene-0259', 'scene-0260', 'scene-0261', 'scene-0262', 'scene-0263', 'scene-0264', 'scene-0283', 'scene-0284', 43 | 'scene-0285', 'scene-0286', 'scene-0287', 'scene-0288', 'scene-0289', 'scene-0290', 'scene-0291', 'scene-0292', 44 | 'scene-0293', 'scene-0294', 'scene-0295', 'scene-0296', 'scene-0297', 'scene-0298', 'scene-0299', 
'scene-0300', 45 | 'scene-0301', 'scene-0302', 'scene-0303', 'scene-0304', 'scene-0305', 'scene-0306', 'scene-0315', 'scene-0316', 46 | 'scene-0317', 'scene-0318', 'scene-0321', 'scene-0323', 'scene-0324', 'scene-0328', 'scene-0347', 'scene-0348', 47 | 'scene-0349', 'scene-0350', 'scene-0351', 'scene-0352', 'scene-0353', 'scene-0354', 'scene-0355', 'scene-0356', 48 | 'scene-0357', 'scene-0358', 'scene-0359', 'scene-0360', 'scene-0361', 'scene-0362', 'scene-0363', 'scene-0364', 49 | 'scene-0365', 'scene-0366', 'scene-0367', 'scene-0368', 'scene-0369', 'scene-0370', 'scene-0371', 'scene-0372', 50 | 'scene-0373', 'scene-0374', 'scene-0375', 'scene-0376', 'scene-0377', 'scene-0378', 'scene-0379', 'scene-0380', 51 | 'scene-0381', 'scene-0382', 'scene-0383', 'scene-0384', 'scene-0385', 'scene-0386', 'scene-0388', 'scene-0389', 52 | 'scene-0390', 'scene-0391', 'scene-0392', 'scene-0393', 'scene-0394', 'scene-0395', 'scene-0396', 'scene-0397', 53 | 'scene-0398', 'scene-0399', 'scene-0400', 'scene-0401', 'scene-0402', 'scene-0403', 'scene-0405', 'scene-0406', 54 | 'scene-0407', 'scene-0408', 'scene-0410', 'scene-0411', 'scene-0412', 'scene-0413', 'scene-0414', 'scene-0415', 55 | 'scene-0416', 'scene-0417', 'scene-0418', 'scene-0419', 'scene-0420', 'scene-0421', 'scene-0422', 'scene-0423', 56 | 'scene-0424', 'scene-0425', 'scene-0426', 'scene-0427', 'scene-0428', 'scene-0429', 'scene-0430', 'scene-0431', 57 | 'scene-0432', 'scene-0433', 'scene-0434', 'scene-0435', 'scene-0436', 'scene-0437', 'scene-0438', 'scene-0439', 58 | 'scene-0440', 'scene-0441', 'scene-0442', 'scene-0443', 'scene-0444', 'scene-0445', 'scene-0446', 'scene-0447', 59 | 'scene-0448', 'scene-0449', 'scene-0450', 'scene-0451', 'scene-0452', 'scene-0453', 'scene-0454', 'scene-0455', 60 | 'scene-0456', 'scene-0457', 'scene-0458', 'scene-0459', 'scene-0461', 'scene-0462', 'scene-0463', 'scene-0464', 61 | 'scene-0465', 'scene-0467', 'scene-0468', 'scene-0469', 'scene-0471', 'scene-0472', 'scene-0474', 'scene-0475', 62 | 'scene-0476', 'scene-0477', 'scene-0478', 'scene-0479', 'scene-0480', 'scene-0499', 'scene-0500', 'scene-0501', 63 | 'scene-0502', 'scene-0504', 'scene-0505', 'scene-0506', 'scene-0507', 'scene-0508', 'scene-0509', 'scene-0510', 64 | 'scene-0511', 'scene-0512', 'scene-0513', 'scene-0514', 'scene-0515', 'scene-0517', 'scene-0518', 'scene-0525', 65 | 'scene-0526', 'scene-0527', 'scene-0528', 'scene-0529', 'scene-0530', 'scene-0531', 'scene-0532', 'scene-0533', 66 | 'scene-0534', 'scene-0535', 'scene-0536', 'scene-0537', 'scene-0538', 'scene-0539', 'scene-0541', 'scene-0542', 67 | 'scene-0543', 'scene-0544', 'scene-0545', 'scene-0546', 'scene-0566', 'scene-0568', 'scene-0570', 'scene-0571', 68 | 'scene-0572', 'scene-0573', 'scene-0574', 'scene-0575', 'scene-0576', 'scene-0577', 'scene-0578', 'scene-0580', 69 | 'scene-0582', 'scene-0583', 'scene-0584', 'scene-0585', 'scene-0586', 'scene-0587', 'scene-0588', 'scene-0589', 70 | 'scene-0590', 'scene-0591', 'scene-0592', 'scene-0593', 'scene-0594', 'scene-0595', 'scene-0596', 'scene-0597', 71 | 'scene-0598', 'scene-0599', 'scene-0600', 'scene-0639', 'scene-0640', 'scene-0641', 'scene-0642', 'scene-0643', 72 | 'scene-0644', 'scene-0645', 'scene-0646', 'scene-0647', 'scene-0648', 'scene-0649', 'scene-0650', 'scene-0651', 73 | 'scene-0652', 'scene-0653', 'scene-0654', 'scene-0655', 'scene-0656', 'scene-0657', 'scene-0658', 'scene-0659', 74 | 'scene-0660', 'scene-0661', 'scene-0662', 'scene-0663', 'scene-0664', 'scene-0665', 'scene-0666', 'scene-0667', 75 | 'scene-0668', 
'scene-0669', 'scene-0670', 'scene-0671', 'scene-0672', 'scene-0673', 'scene-0674', 'scene-0675', 76 | 'scene-0676', 'scene-0677', 'scene-0678', 'scene-0679', 'scene-0681', 'scene-0683', 'scene-0684', 'scene-0685', 77 | 'scene-0686', 'scene-0687', 'scene-0688', 'scene-0689', 'scene-0695', 'scene-0696', 'scene-0697', 'scene-0698', 78 | 'scene-0700', 'scene-0701', 'scene-0703', 'scene-0704', 'scene-0705', 'scene-0706', 'scene-0707', 'scene-0708', 79 | 'scene-0709', 'scene-0710', 'scene-0711', 'scene-0712', 'scene-0713', 'scene-0714', 'scene-0715', 'scene-0716', 80 | 'scene-0717', 'scene-0718', 'scene-0719', 'scene-0726', 'scene-0727', 'scene-0728', 'scene-0730', 'scene-0731', 81 | 'scene-0733', 'scene-0734', 'scene-0735', 'scene-0736', 'scene-0737', 'scene-0738', 'scene-0739', 'scene-0740', 82 | 'scene-0741', 'scene-0744', 'scene-0746', 'scene-0747', 'scene-0749', 'scene-0750', 'scene-0751', 'scene-0752', 83 | 'scene-0757', 'scene-0758', 'scene-0759', 'scene-0760', 'scene-0761', 'scene-0762', 'scene-0763', 'scene-0764', 84 | 'scene-0765', 'scene-0767', 'scene-0768', 'scene-0769', 'scene-0786', 'scene-0787', 'scene-0789', 'scene-0790', 85 | 'scene-0791', 'scene-0792', 'scene-0803', 'scene-0804', 'scene-0805', 'scene-0806', 'scene-0808', 'scene-0809', 86 | 'scene-0810', 'scene-0811', 'scene-0812', 'scene-0813', 'scene-0815', 'scene-0816', 'scene-0817', 'scene-0819', 87 | 'scene-0820', 'scene-0821', 'scene-0822', 'scene-0847', 'scene-0848', 'scene-0849', 'scene-0850', 'scene-0851', 88 | 'scene-0852', 'scene-0853', 'scene-0854', 'scene-0855', 'scene-0856', 'scene-0858', 'scene-0860', 'scene-0861', 89 | 'scene-0862', 'scene-0863', 'scene-0864', 'scene-0865', 'scene-0866', 'scene-0868', 'scene-0869', 'scene-0870', 90 | 'scene-0871', 'scene-0872', 'scene-0873', 'scene-0875', 'scene-0876', 'scene-0877', 'scene-0878', 'scene-0880', 91 | 'scene-0882', 'scene-0883', 'scene-0884', 'scene-0885', 'scene-0886', 'scene-0887', 'scene-0888', 'scene-0889', 92 | 'scene-0890', 'scene-0891', 'scene-0892', 'scene-0893', 'scene-0894', 'scene-0895', 'scene-0896', 'scene-0897', 93 | 'scene-0898', 'scene-0899', 'scene-0900', 'scene-0901', 'scene-0902', 'scene-0903', 'scene-0945', 'scene-0947', 94 | 'scene-0949', 'scene-0952', 'scene-0953', 'scene-0955', 'scene-0956', 'scene-0957', 'scene-0958', 'scene-0959', 95 | 'scene-0960', 'scene-0961', 'scene-0975', 'scene-0976', 'scene-0977', 'scene-0978', 'scene-0979', 'scene-0980', 96 | 'scene-0981', 'scene-0982', 'scene-0983', 'scene-0984', 'scene-0988', 'scene-0989', 'scene-0990', 'scene-0991', 97 | 'scene-0992', 'scene-0994', 'scene-0995', 'scene-0996', 'scene-0997', 'scene-0998', 'scene-0999', 'scene-1000', 98 | 'scene-1001', 'scene-1002', 'scene-1003', 'scene-1004', 'scene-1005', 'scene-1006', 'scene-1007', 'scene-1008', 99 | 'scene-1009', 'scene-1010', 'scene-1011', 'scene-1012', 'scene-1013', 'scene-1014', 'scene-1015', 'scene-1016', 100 | 'scene-1017', 'scene-1018', 'scene-1019', 'scene-1020', 'scene-1021', 'scene-1022', 'scene-1023', 'scene-1024', 101 | 'scene-1025', 'scene-1044', 'scene-1045', 'scene-1046', 'scene-1047', 'scene-1048', 'scene-1049', 'scene-1050', 102 | 'scene-1051', 'scene-1052', 'scene-1053', 'scene-1054', 'scene-1055', 'scene-1056', 'scene-1057', 'scene-1058', 103 | 'scene-1074', 'scene-1075', 'scene-1076', 'scene-1077', 'scene-1078', 'scene-1079', 'scene-1080', 'scene-1081', 104 | 'scene-1082', 'scene-1083', 'scene-1084', 'scene-1085', 'scene-1086', 'scene-1087', 'scene-1088', 'scene-1089', 105 | 'scene-1090', 'scene-1091', 'scene-1092', 
'scene-1093', 'scene-1094', 'scene-1095', 'scene-1096', 'scene-1097', 106 | 'scene-1098', 'scene-1099', 'scene-1100', 'scene-1101', 'scene-1102', 'scene-1104', 'scene-1105', 'scene-1106', 107 | 'scene-1107', 'scene-1108', 'scene-1109', 'scene-1110'], 108 | 'val': 109 | ['scene-0003', 'scene-0012', 'scene-0013', 'scene-0014', 'scene-0015', 'scene-0016', 'scene-0017', 'scene-0018', 110 | 'scene-0035', 'scene-0036', 'scene-0038', 'scene-0039', 'scene-0092', 'scene-0093', 'scene-0094', 'scene-0095', 111 | 'scene-0096', 'scene-0097', 'scene-0098', 'scene-0099', 'scene-0100', 'scene-0101', 'scene-0102', 'scene-0103', 112 | 'scene-0104', 'scene-0105', 'scene-0106', 'scene-0107', 'scene-0108', 'scene-0109', 'scene-0110', 'scene-0221', 113 | 'scene-0268', 'scene-0269', 'scene-0270', 'scene-0271', 'scene-0272', 'scene-0273', 'scene-0274', 'scene-0275', 114 | 'scene-0276', 'scene-0277', 'scene-0278', 'scene-0329', 'scene-0330', 'scene-0331', 'scene-0332', 'scene-0344', 115 | 'scene-0345', 'scene-0346', 'scene-0519', 'scene-0520', 'scene-0521', 'scene-0522', 'scene-0523', 'scene-0524', 116 | 'scene-0552', 'scene-0553', 'scene-0554', 'scene-0555', 'scene-0556', 'scene-0557', 'scene-0558', 'scene-0559', 117 | 'scene-0560', 'scene-0561', 'scene-0562', 'scene-0563', 'scene-0564', 'scene-0565', 'scene-0625', 'scene-0626', 118 | 'scene-0627', 'scene-0629', 'scene-0630', 'scene-0632', 'scene-0633', 'scene-0634', 'scene-0635', 'scene-0636', 119 | 'scene-0637', 'scene-0638', 'scene-0770', 'scene-0771', 'scene-0775', 'scene-0777', 'scene-0778', 'scene-0780', 120 | 'scene-0781', 'scene-0782', 'scene-0783', 'scene-0784', 'scene-0794', 'scene-0795', 'scene-0796', 'scene-0797', 121 | 'scene-0798', 'scene-0799', 'scene-0800', 'scene-0802', 'scene-0904', 'scene-0905', 'scene-0906', 'scene-0907', 122 | 'scene-0908', 'scene-0909', 'scene-0910', 'scene-0911', 'scene-0912', 'scene-0913', 'scene-0914', 'scene-0915', 123 | 'scene-0916', 'scene-0917', 'scene-0919', 'scene-0920', 'scene-0921', 'scene-0922', 'scene-0923', 'scene-0924', 124 | 'scene-0925', 'scene-0926', 'scene-0927', 'scene-0928', 'scene-0929', 'scene-0930', 'scene-0931', 'scene-0962', 125 | 'scene-0963', 'scene-0966', 'scene-0967', 'scene-0968', 'scene-0969', 'scene-0971', 'scene-0972', 'scene-1059', 126 | 'scene-1060', 'scene-1061', 'scene-1062', 'scene-1063', 'scene-1064', 'scene-1065', 'scene-1066', 'scene-1067', 127 | 'scene-1068', 'scene-1069', 'scene-1070', 'scene-1071', 'scene-1072', 'scene-1073'], 128 | 'mini_train': 129 | ['scene-0061', 'scene-0553', 'scene-0655', 'scene-0757', 'scene-0796', 'scene-1077', 'scene-1094', 'scene-1100'], 130 | 'mini_val': 131 | ['scene-0103', 'scene-0916'], 132 | } 133 | 134 | 135 | def is_first_2_sample_in_scene(idx): 136 | if not nusc.sample[idx]['prev']: 137 | return True 138 | elif not nusc.sample[idx-1]['prev']: 139 | return True 140 | else: 141 | return False 142 | 143 | def is_last_2_sample_in_scene(idx): 144 | if not nusc.sample[idx]['next']: 145 | return True 146 | elif not nusc.sample[idx+1]['next']: 147 | return True 148 | else: 149 | return False 150 | 151 | 152 | if __name__ == '__main__': 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument('--dir_data', type=str, help='data directory') 155 | parser.add_argument('--version', type=str, default='v1.0-trainval', help='dataset split') 156 | 157 | args = parser.parse_args() 158 | 159 | if args.dir_data == None: 160 | this_dir = os.path.dirname(__file__) 161 | args.dir_data = os.path.join(this_dir, '..', 'data') 162 | 163 | 164 | dir_nuscenes = 
join(args.dir_data, 'nuscenes')
165 |
166 |     train_ratio = 0.9
167 |     val_ratio = 0.1
168 |
169 |     nusc = NuScenes(args.version, dataroot = dir_nuscenes, verbose=False)
170 |
171 |     trainval_scenes = defaultdict(list)
172 |     test_scenes = defaultdict(list)
173 |     for scene in tqdm(nusc.scene):
174 |         if 'rain' in scene['description'].lower() or 'night' in scene['description'].lower():   # skip rainy and night scenes
175 |             continue
176 |         name = scene['name']
177 |         if name in SCENE_SPLITS['train']:
178 |             trainval_scenes['names'].append(name)
179 |             trainval_scenes['tokens'].append(scene['token'])
180 |         elif name in SCENE_SPLITS['val']:
181 |             test_scenes['names'].append(name)
182 |             test_scenes['tokens'].append(scene['token'])
183 |
184 |
185 |     trainval_scenes['names'], trainval_scenes['tokens'] = np.array(trainval_scenes['names']), np.array(trainval_scenes['tokens'])
186 |     N_trainval = len(trainval_scenes['names'])
187 |     indices = np.arange(N_trainval)
188 |     np.random.shuffle(indices)
189 |     trainval_scenes['names'] = trainval_scenes['names'][indices]
190 |     trainval_scenes['tokens'] = trainval_scenes['tokens'][indices]
191 |
192 |
193 |     n_train_scenes = int(round(N_trainval * train_ratio))
194 |     n_val_scenes = N_trainval - n_train_scenes
195 |
196 |     train_scenes = defaultdict(list)
197 |     val_scenes = defaultdict(list)
198 |
199 |     train_scenes['names'] = trainval_scenes['names'][:n_train_scenes]
200 |     train_scenes['tokens'] = trainval_scenes['tokens'][:n_train_scenes]
201 |     val_scenes['names'] = trainval_scenes['names'][n_train_scenes:]
202 |     val_scenes['tokens'] = trainval_scenes['tokens'][n_train_scenes:]
203 |
204 |     train_sample_idx = []
205 |     val_sample_idx = []
206 |     test_sample_idx = []
207 |
208 |     for idx, sample in tqdm(enumerate(nusc.sample)):
209 |         if is_first_2_sample_in_scene(idx) or is_last_2_sample_in_scene(idx):
210 |             continue
211 |         if sample['scene_token'] in train_scenes['tokens']:
212 |             train_sample_idx.append(idx)
213 |         elif sample['scene_token'] in val_scenes['tokens']:
214 |             val_sample_idx.append(idx)
215 |         elif sample['scene_token'] in test_scenes['tokens']:
216 |             test_sample_idx.append(idx)
217 |
218 |     print(len(train_sample_idx), len(val_sample_idx), len(test_sample_idx))
219 |
220 |     all_idx = train_sample_idx + val_sample_idx + test_sample_idx
221 |
222 |     sample_split = {'all_indices': all_idx,
223 |                     'train_sample_indices': train_sample_idx,
224 |                     'val_sample_indices': val_sample_idx,
225 |                     'test_sample_indices': test_sample_idx }
226 |
227 |
228 |     train_sample_idx = []
229 |     val_sample_idx = []
230 |     test_sample_idx = []
231 |
232 |     for sample_idx in tqdm(sample_split['all_indices']):
233 |         gt = gt_box_key(nusc, sample_idx)
234 |         is_valid_frame = gt.v_label_exist(thres_n_gt_pts=2) # ignore frames having fewer than 2 radar points with GT velocity
235 |         if is_valid_frame:
236 |             if sample_idx in sample_split['train_sample_indices']:
237 |                 train_sample_idx.append(sample_idx)
238 |             elif sample_idx in sample_split['val_sample_indices']:
239 |                 val_sample_idx.append(sample_idx)
240 |             elif sample_idx in sample_split['test_sample_indices']:
241 |                 test_sample_idx.append(sample_idx)
242 |
243 |
244 |     all_idx = train_sample_idx + val_sample_idx + test_sample_idx
245 |
246 |     sample_split_new = {'all_indices': all_idx,
247 |                         'train_sample_indices': train_sample_idx,
248 |                         'val_sample_indices': val_sample_idx,
249 |                         'test_sample_indices': test_sample_idx }
250 |
251 |
252 |     print('train: %d, val: %d, test: %d' % ( len(train_sample_idx), len(val_sample_idx), len(test_sample_idx) ) )
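# The second pass above re-filters the split with gt_box_key so that every
# kept frame has at least two radar points carrying a ground-truth full
# velocity. The dict saved below is reused by the later stages; e.g. the
# point-wise evaluation script earlier in this repo reads it back with
#     torch.load(join(args.dir_data, 'sample_split.tar'))['test_sample_indices']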
253 |     torch.save(sample_split_new, join(args.dir_data, 'sample_split.tar'))
254 |
255 |
--------------------------------------------------------------------------------