├── LICENSE ├── README.md ├── cfgs ├── nclt_bev.yaml └── oxford_bev.yaml ├── datasets ├── augmentor.py ├── composition_bev.py ├── nclt_bev.py ├── oxford_bev.py ├── projection.py └── robotcar_sdk │ ├── extrinsics │ ├── ins.txt │ ├── ldmrs.txt │ ├── lms_front.txt │ ├── lms_rear.txt │ ├── mono_left.txt │ ├── mono_rear.txt │ ├── mono_right.txt │ ├── nlct_velodyne.txt │ ├── radar.txt │ ├── stereo.txt │ ├── velodyne_left.txt │ └── velodyne_right.txt │ └── python │ ├── README.md │ ├── __init__.py │ ├── build_pointcloud.py │ ├── camera_model.py │ ├── image.py │ ├── interpolate_poses.py │ ├── play_images.py │ ├── play_radar.py │ ├── play_velodyne.py │ ├── project_laser_into_camera.py │ ├── radar.py │ ├── requirements.txt │ ├── transform.py │ └── velodyne.py ├── img ├── nclt.gif └── oxford.gif ├── install.sh ├── log └── count_SR.py ├── merge_nclt.py ├── merge_oxford.py ├── models ├── __init__.py ├── decoders.py ├── denoiser_bev.py ├── difussion_loc_model_bev.py ├── gaussian_diffuser.py ├── image_feature_extractor_bev.py ├── layers │ ├── __init__.py │ ├── attention.py │ ├── block.py │ ├── dino_head.py │ ├── drop_path.py │ ├── layer_scale.py │ ├── mlp.py │ ├── patch_embed.py │ └── swiglu_ffn.py ├── model_utils.py └── stems.py ├── test_bev.py ├── train_bev.py └── utils ├── embedding.py ├── pose_util.py ├── train_util.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 NuBot 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BEVDiffLoc 2 | BEVDiffLoc: End-to-End LiDAR Global Localization in BEV View based on Diffusion Model 3 | 4 | # Visualization 5 | ![image](img/oxford.gif) ![image](img/nclt.gif) 6 | 7 | # Environment 8 | 9 | - python 3.10 10 | 11 | - pytorch 2.1.2 12 | 13 | - cuda 12.1 14 | 15 | ``` 16 | source install.sh 17 | ``` 18 | 19 | ## Dataset 20 | 21 | We support the [Oxford Radar RobotCar](https://oxford-robotics-institute.github.io/radar-robotcar-dataset/datasets) and [NCLT](https://robots.engin.umich.edu/nclt/) datasets right now. 
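Both datasets are handled by the multi-frame loader `MF_bev` in [composition_bev.py](datasets/composition_bev.py), which can also be used from your own code. Below is a minimal sketch (assumptions: the YAML config is loaded with OmegaConf, in line with the Hydra-style `_target_` keys in `cfgs/`, and `dataroot` in the config already points at your `data_root`):

```python
# Minimal sketch, not the official entry point (use train_bev.py / test_bev.py for that).
from omegaconf import OmegaConf            # assumed config loader
from torch.utils.data import DataLoader
from datasets.composition_bev import MF_bev

cfg = OmegaConf.load('cfgs/nclt_bev.yaml')        # or cfgs/oxford_bev.yaml
train_set = MF_bev('NCLT', cfg, split='train')    # 'Oxford' for the RobotCar data
loader = DataLoader(train_set, batch_size=cfg.train.batch_size,
                    shuffle=True, num_workers=cfg.train.num_workers)
batch = next(iter(loader))
print(batch['image'].shape, batch['pose'].shape)  # stacked BEV images and per-frame poses
```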
22 | 23 | The data of the Oxford and NCLT dataset should be organized as follows: 24 | 25 | ``` 26 | data_root 27 | ├── 2019-01-11-14-02-26-radar-oxford-10k 28 | │ ├── velodyne_left 29 | │ │ ├── xxx.bin 30 | │ │ ├── xxx.bin 31 | │ ├── gps 32 | │ │ ├── gps.csv 33 | │ │ ├── ins.csv 34 | │ ├── velodyne_left.timestamps 35 | │ ├── merge_bev (Data prepare) 36 | │ │ ├── xxx.png 37 | │ │ ├── xxx.png 38 | │ ├── merge_bev.txt (Data prepare) 39 | ├── Oxford_pose_stats.txt 40 | ├── train_split.txt 41 | ├── valid_split.txt 42 | ``` 43 | 44 | ## Data prepare 45 | 46 | - Oxford&NCLT: We use [merge_nclt.py](merge_nclt.py) and [merge_oxford.py](merge_oxford.py) to generate local scenes for data augmentation. 47 | 48 | 49 | ## Run 50 | 51 | ### Download the pretrained ViT model 52 | We initialize BEVDiffLoc's feature learner with [DINOv2](https://github.com/facebookresearch/dinov2?tab=readme-ov-file). 53 | 54 | ### Train 55 | 56 | ``` 57 | accelerate launch --num_processes 1 --mixed_precision fp16 train_bev.py 58 | ``` 59 | 60 | ### Test 61 | ``` 62 | python test_bev.py 63 | ``` 64 | 65 | ## Citation 66 | 67 | If you find this work helpful, please consider citing: 68 | ```bibtex 69 | @article{wang2025bevdiffloc, 70 | title={BEVDiffLoc: End-to-End LiDAR Global Localization in BEV View based on Diffusion Model}, 71 | author={Wang, Ziyue and Shi, Chenghao and Wang, Neng and Yu, Qinghua and Chen, Xieyuanli and Lu, Huimin}, 72 | journal={arXiv preprint arXiv:2503.11372}, 73 | year={2025} 74 | } 75 | ``` 76 | 77 | ## Acknowledgement 78 | 79 | We appreciate the code of [DiffLoc](https://github.com/liw95/DiffLoc) and [BEVPlace++](https://github.com/zjuluolun/BEVPlace2) they shared. 80 | 81 | -------------------------------------------------------------------------------- /cfgs/nclt_bev.yaml: -------------------------------------------------------------------------------- 1 | ckpt: 'log/nclt.pth' 2 | seed: 7 3 | exp_name: 'NCLT_bev' 4 | exp_dir: 'log' 5 | sampling_timesteps: 15 6 | 7 | train: 8 | batch_size: 16 9 | val_batch_size: 16 10 | image_size: [32, 32] 11 | original_image_size: [32, 32] 12 | steps: 3 13 | skip: 2 14 | use_merge: True 15 | merge_num: 1 16 | restart_num: 5 17 | lr: 5e-4 18 | weight_decay: 0.01 19 | epochs: 151 20 | ckpt_interval: 1 21 | num_workers: 16 22 | eval_interval: 5 23 | print_interval: 200 24 | persistent_workers: True 25 | 26 | pin_memory: True 27 | clip_grad: 1.0 28 | 29 | cudnnbenchmark: True 30 | 31 | warmup_sche: True 32 | dataset: 'NCLT' 33 | dataroot: '/media/wzy/data' 34 | 35 | profile: False 36 | 37 | 38 | MODEL: 39 | _target_: models.DiffusionLocModel_bev 40 | 41 | IMAGE_FEATURE_EXTRACTOR: 42 | _target_: models.image_feature_extractor_bev.ImageFeatureExtractor_bev 43 | backbone: "vit_small_patch16_384" 44 | freeze: False 45 | in_channels: 128 46 | new_patch_size: [4, 4] 47 | new_patch_stride: [4, 4] 48 | conv_stem: 'ConvStem' # 'none' or 'ConvStem' 49 | stem_base_channels: 32 50 | D_h: 256 # hidden dimension of the stem 51 | image_size: [32, 32] 52 | decoder: 'linear' 53 | pretrained_path: "dinov2_vits14_pretrain.pth" 54 | reuse_pos_emb: true 55 | reuse_patch_emb: false # no patch embedding as a convolutional stem (ConvStem) is used 56 | n_cls: 1 57 | 58 | DENOISER: 59 | _target_: models.Denoiser_bev 60 | TRANSFORMER: 61 | _target_: models.denoiser_bev.TransformerEncoderWrapper_bev 62 | d_model: 512 63 | nhead: 4 64 | num_encoder_layers: 8 65 | dim_feedforward: 1024 66 | dropout: 0.1 67 | batch_first: True 68 | norm_first: True 69 | 70 | 71 | DIFFUSER: 72 | _target_: 
models.GaussianDiffusion 73 | beta_schedule: custom 74 | 75 | # Data augmentation config 76 | augmentation: 77 | 78 | # translation 79 | p_transx: 0.5 80 | trans_xmin: -5 81 | trans_xmax: 5 82 | p_transy: 0.5 83 | trans_ymin: -3 84 | trans_ymax: 3 85 | p_transz: 0.5 86 | trans_zmin: -1 87 | trans_zmax: 0. 88 | 89 | # rotation 90 | p_rot_roll: 0.5 91 | rot_rollmin: -5 92 | rot_rollmax: 5 93 | p_rot_pitch: 0.5 94 | rot_pitchmin: -5 95 | rot_pitchmax: 5 96 | p_rot_yaw: 0.5 97 | rot_yawmin: -45 98 | rot_yawmax: 45 99 | -------------------------------------------------------------------------------- /cfgs/oxford_bev.yaml: -------------------------------------------------------------------------------- 1 | ckpt: 'log/oxford.pth' 2 | seed: 7 3 | exp_name: 'Oxford_bev' 4 | exp_dir: 'log' 5 | sampling_timesteps: 10 6 | 7 | train: 8 | batch_size: 16 9 | val_batch_size: 16 10 | image_size: [32, 32] 11 | original_image_size: [32, 32] 12 | steps: 3 13 | skip: 2 14 | use_merge: True 15 | merge_num: 1 16 | restart_num: 5 17 | lr: 5e-4 18 | weight_decay: 0.01 19 | epochs: 151 20 | ckpt_interval: 1 21 | num_workers: 16 22 | eval_interval: 5 23 | print_interval: 200 24 | persistent_workers: True 25 | 26 | pin_memory: True 27 | clip_grad: 1.0 28 | 29 | cudnnbenchmark: True 30 | 31 | warmup_sche: True 32 | dataset: 'Oxford' 33 | dataroot: '/media/wzy/data' 34 | 35 | profile: False 36 | 37 | 38 | MODEL: 39 | _target_: models.DiffusionLocModel_bev 40 | 41 | IMAGE_FEATURE_EXTRACTOR: 42 | _target_: models.image_feature_extractor_bev.ImageFeatureExtractor_bev 43 | backbone: "vit_small_patch16_384" 44 | freeze: False 45 | in_channels: 128 46 | new_patch_size: [4, 4] 47 | new_patch_stride: [4, 4] 48 | conv_stem: 'ConvStem' # 'none' or 'ConvStem' 49 | stem_base_channels: 32 50 | D_h: 256 # hidden dimension of the stem 51 | image_size: [32, 32] 52 | decoder: 'linear' 53 | pretrained_path: "dinov2_vits14_pretrain.pth" 54 | reuse_pos_emb: true 55 | reuse_patch_emb: false # no patch embedding as a convolutional stem (ConvStem) is used 56 | n_cls: 1 57 | 58 | DENOISER: 59 | _target_: models.Denoiser_bev 60 | TRANSFORMER: 61 | _target_: models.denoiser_bev.TransformerEncoderWrapper_bev 62 | d_model: 512 63 | nhead: 4 64 | num_encoder_layers: 8 65 | dim_feedforward: 1024 66 | dropout: 0.1 67 | batch_first: True 68 | norm_first: True 69 | 70 | 71 | DIFFUSER: 72 | _target_: models.GaussianDiffusion 73 | beta_schedule: custom 74 | 75 | # Data augmentation config 76 | augmentation: 77 | # flip 78 | p_flipx: 0. 79 | p_flipy: 0.5 80 | 81 | # translation 82 | p_transx: 0.5 83 | trans_xmin: -5 84 | trans_xmax: 5 85 | p_transy: 0.5 86 | trans_ymin: -3 87 | trans_ymax: 3 88 | p_transz: 0.5 89 | trans_zmin: -1 90 | trans_zmax: 0. 91 | 92 | # rotation 93 | p_rot_roll: 0.5 94 | rot_rollmin: -5 95 | rot_rollmax: 5 96 | p_rot_pitch: 0.5 97 | rot_pitchmin: -5 98 | rot_pitchmax: 5 99 | p_rot_yaw: 0.5 100 | rot_yawmin: -45 101 | rot_yawmax: 45 102 | -------------------------------------------------------------------------------- /datasets/augmentor.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | from scipy.spatial.transform import Rotation as R 5 | from datasets.robotcar_sdk.python.transform import euler_to_so3 6 | 7 | 8 | class AugmentParams(object): 9 | ''' 10 | Adapted from Z. Zhuang et al. 
https://github.com/ICEORY/PMF 11 | ''' 12 | 13 | def __init__(self, p_flipx=0., p_flipy=0., 14 | p_transx=0., trans_xmin=0., trans_xmax=0., 15 | p_transy=0., trans_ymin=0., trans_ymax=0., 16 | p_transz=0., trans_zmin=0., trans_zmax=0., 17 | p_rot_roll=0., rot_rollmin=0., rot_rollmax=0., 18 | p_rot_pitch=0., rot_pitchmin=0, rot_pitchmax=0., 19 | p_rot_yaw=0., rot_yawmin=0., rot_yawmax=0., 20 | p_scale=0., scale_min=1.0, scale_max=1.0): 21 | self.p_flipx = p_flipx 22 | self.p_flipy = p_flipy 23 | 24 | self.p_transx = p_transx 25 | self.trans_xmin = trans_xmin 26 | self.trans_xmax = trans_xmax 27 | 28 | self.p_transy = p_transy 29 | self.trans_ymin = trans_ymin 30 | self.trans_ymax = trans_ymax 31 | 32 | self.p_transz = p_transz 33 | self.trans_zmin = trans_zmin 34 | self.trans_zmax = trans_zmax 35 | 36 | self.p_rot_roll = p_rot_roll 37 | self.rot_rollmin = rot_rollmin 38 | self.rot_rollmax = rot_rollmax 39 | 40 | self.p_rot_pitch = p_rot_pitch 41 | self.rot_pitchmin = rot_pitchmin 42 | self.rot_pitchmax = rot_pitchmax 43 | 44 | self.p_rot_yaw = p_rot_yaw 45 | self.rot_yawmin = rot_yawmin 46 | self.rot_yawmax = rot_yawmax 47 | 48 | self.p_scale = p_scale 49 | self.scale_min = scale_min 50 | self.scale_max = scale_max 51 | 52 | def sefScaleParams(self, p_scale, scale_min, scale_max): 53 | self.p_scale = p_scale 54 | self.scale_min = scale_min 55 | self.scale_max = scale_max 56 | 57 | def setFlipProb(self, p_flipx, p_flipy): 58 | self.p_flipx = p_flipx 59 | self.p_flipy = p_flipy 60 | 61 | def setTranslationParams(self, 62 | p_transx=0., trans_xmin=0., trans_xmax=0., 63 | p_transy=0., trans_ymin=0., trans_ymax=0., 64 | p_transz=0., trans_zmin=0., trans_zmax=0.): 65 | self.p_transx = p_transx 66 | self.trans_xmin = trans_xmin 67 | self.trans_xmax = trans_xmax 68 | 69 | self.p_transy = p_transy 70 | self.trans_ymin = trans_ymin 71 | self.trans_ymax = trans_ymax 72 | 73 | self.p_transz = p_transz 74 | self.trans_zmin = trans_zmin 75 | self.trans_zmax = trans_zmax 76 | 77 | def setRotationParams(self, 78 | p_rot_roll=0., rot_rollmin=0., rot_rollmax=0., 79 | p_rot_pitch=0., rot_pitchmin=0, rot_pitchmax=0., 80 | p_rot_yaw=0., rot_yawmin=0., rot_yawmax=0.): 81 | 82 | self.p_rot_roll = p_rot_roll 83 | self.rot_rollmin = rot_rollmin 84 | self.rot_rollmax = rot_rollmax 85 | 86 | self.p_rot_pitch = p_rot_pitch 87 | self.rot_pitchmin = rot_pitchmin 88 | self.rot_pitchmax = rot_pitchmax 89 | 90 | self.p_rot_yaw = p_rot_yaw 91 | self.rot_yawmin = rot_yawmin 92 | self.rot_yawmax = rot_yawmax 93 | 94 | def __str__(self): 95 | print('=== Augmentor parameters ===') 96 | # print('p_flipx: {}, p_flipy: {}'.format(self.p_flipx, self.p_flipy)) 97 | print('p_transx: {}, p_transxmin: {}, p_transxmax: {}'.format( 98 | self.p_transx, self.trans_xmin, self.trans_xmax)) 99 | print('p_transy: {}, p_transymin: {}, p_transymax: {}'.format( 100 | self.p_transy, self.trans_ymin, self.trans_ymax)) 101 | print('p_transz: {}, p_transzmin: {}, p_transzmax: {}'.format( 102 | self.p_transz, self.trans_zmin, self.trans_zmax)) 103 | print('p_rotroll: {}, rot_rollmin: {}, rot_rollmax: {}'.format( 104 | self.p_rot_roll, self.rot_rollmin, self.rot_rollmax)) 105 | print('p_rotpitch: {}, rot_pitchmin: {}, rot_pitchmax: {}'.format( 106 | self.p_rot_pitch, self.rot_pitchmin, self.rot_pitchmax)) 107 | print('p_rotyaw: {}, rot_yawmin: {}, rot_yawmax: {}'.format( 108 | self.p_rot_yaw, self.rot_yawmin, self.rot_yawmax)) 109 | print('p_scale: {}, scale_min: {}, scale_max: {}'.format( 110 | self.p_scale, self.scale_min, self.scale_max)) 111 | 112 
| class Augmentor(object): 113 | def __init__(self, params: AugmentParams): 114 | self.parmas = params 115 | 116 | @staticmethod 117 | def flipX(pointcloud: np.ndarray): 118 | pointcloud[:, 0] = -pointcloud[:, 0] 119 | return pointcloud 120 | 121 | @staticmethod 122 | def flipY(pointcloud: np.ndarray): 123 | pointcloud[:, 1] = -pointcloud[:, 1] 124 | return pointcloud 125 | 126 | @staticmethod 127 | def translation(pointcloud: np.ndarray, x: float, y: float, z: float): 128 | pointcloud[:, 0] += x 129 | pointcloud[:, 1] += y 130 | pointcloud[:, 2] += z 131 | return pointcloud 132 | 133 | @staticmethod 134 | def rotation(pointcloud: np.ndarray, roll: float, pitch: float, yaw: float, degrees=True): 135 | # rot_matrix = R.from_euler( 136 | # 'zyx', [yaw, pitch, roll], degrees=degrees).as_matrix() 137 | # 需要先转换为弧度制 138 | roll = math.radians(roll) 139 | pitch = math.radians(pitch) 140 | yaw = math.radians(yaw) 141 | rot_matrix = euler_to_so3([roll, pitch, yaw])[:3, :3] # [3, 3] 142 | # pointcloud[:, :3] = np.matmul(pointcloud[:, :3], rot_matrix.T) 143 | pointcloud[:, :3] = (rot_matrix @ pointcloud[:, :3].transpose()).transpose() 144 | return pointcloud, rot_matrix 145 | 146 | @staticmethod 147 | def randomRotation(pointcloud: np.ndarray): 148 | rot_matrix = R.random(random_state=1234).as_matrix() 149 | pointcloud[:, :3] = np.matmul(pointcloud[:, :3], rot_matrix.T) 150 | return pointcloud 151 | 152 | @staticmethod 153 | def scale_cloud(pointcloud: np.ndarray, scale_min: float, scale_max: float): 154 | pointcloud = pointcloud * np.random.uniform(scale_min, scale_max) 155 | return pointcloud 156 | 157 | def doAugmentation_bev(self, pointcloud): 158 | # # flip augment 159 | # rand = random.uniform(0, 1) 160 | # if rand < self.parmas.p_flipx: 161 | # pointcloud = self.flipX(pointcloud) 162 | # 163 | # rand = random.uniform(0, 1) 164 | # if rand < self.parmas.p_flipy: 165 | # pointcloud = self.flipY(pointcloud) 166 | 167 | # scale augment 168 | # rand = random.uniform(0, 1) 169 | # if rand < self.parmas.p_scale: 170 | # pointcloud = self.scale_cloud(pointcloud, self.parmas.scale_min, self.parmas.scale_max) 171 | 172 | # translation 对每个点单独变换,可不改变pose值 173 | rand = random.uniform(0, 1) 174 | if rand < self.parmas.p_transx: 175 | trans_x = random.uniform( 176 | self.parmas.trans_xmin, self.parmas.trans_xmax) 177 | else: 178 | trans_x = 0 179 | 180 | rand = random.uniform(0, 1) 181 | if rand < self.parmas.p_transy: 182 | trans_y = random.uniform( 183 | self.parmas.trans_ymin, self.parmas.trans_ymax) 184 | else: 185 | trans_y = 0 186 | 187 | # rand = random.uniform(0, 1) 188 | # if rand < self.parmas.p_transz: 189 | # trans_z = random.uniform( 190 | # self.parmas.trans_zmin, self.parmas.trans_zmax) 191 | # else: 192 | # trans_z = 0 193 | # pointcloud = self.translation(pointcloud, trans_x, trans_y, trans_z) 194 | 195 | # rotation 对点云整体变换,需要改变pose值 196 | # rand = random.uniform(0, 1) 197 | # if rand < self.parmas.p_rot_roll: 198 | # rot_roll = random.uniform( 199 | # self.parmas.rot_rollmin, self.parmas.rot_rollmax) 200 | # else: 201 | # rot_roll = 0 202 | 203 | # rand = random.uniform(0, 1) 204 | # if rand < self.parmas.p_rot_pitch: 205 | # rot_pitch = random.uniform( 206 | # self.parmas.rot_pitchmin, self.parmas.rot_pitchmax) 207 | # else: 208 | # rot_pitch = 0 209 | 210 | rand = random.uniform(0, 1) 211 | if rand < self.parmas.p_rot_yaw: 212 | rot_yaw = random.uniform( 213 | self.parmas.rot_yawmin, self.parmas.rot_yawmax) 214 | else: 215 | rot_yaw = 0 216 | # pointcloud, rotation = 
self.rotation(pointcloud, rot_roll, rot_pitch, rot_yaw) 217 | pointcloud, rotation = self.rotation(pointcloud, 0, 0, rot_yaw) 218 | 219 | return pointcloud, rotation -------------------------------------------------------------------------------- /datasets/composition_bev.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziyue Wang and Wen Li 3 | @file: composition_bev.py 4 | @time: 2025/3/12 14:20 5 | """ 6 | 7 | import torch 8 | import numpy as np 9 | import sys 10 | sys.path.insert(0, '../') 11 | 12 | from .oxford_bev import Oxford_BEV 13 | from .nclt_bev import NCLT_BEV 14 | from torch.utils import data 15 | from utils.pose_util import calc_vos_safe_fc 16 | 17 | 18 | class MF_bev(data.Dataset): 19 | def __init__(self, dataset, config, split='train', include_vos=False): 20 | 21 | self.steps = config.train.steps 22 | self.skip = config.train.skip 23 | self.use_merge = config.train.use_merge 24 | self.train = split 25 | 26 | if dataset == 'Oxford': 27 | self.dset = Oxford_BEV(config, split) 28 | elif dataset == 'NCLT': 29 | self.dset = NCLT_BEV(config, split) 30 | else: 31 | raise NotImplementedError('{:s} dataset is not implemented!') 32 | 33 | self.L = self.steps * self.skip 34 | # GCS 35 | self.include_vos = include_vos 36 | self.vo_func = calc_vos_safe_fc 37 | 38 | 39 | def get_indices(self, index): 40 | skips = self.skip * np.ones(self.steps-1) 41 | offsets = np.insert(skips, 0, 0).cumsum() # (self.steps,) 42 | offsets -= offsets[len(offsets) // 2] 43 | offsets = offsets.astype(np.int_) 44 | idx = index + offsets 45 | idx = np.minimum(np.maximum(idx, 0), len(self.dset)-1) 46 | assert np.all(idx >= 0), '{:d}'.format(index) 47 | assert np.all(idx < len(self.dset)) 48 | return idx 49 | 50 | def get_merge_indices(self, index): 51 | skips = self.merge_skip * np.ones(self.merge_steps-1) 52 | offsets = np.insert(skips, 0, 0).cumsum() # (self.steps,) 53 | offsets -= offsets[len(offsets) // 2] 54 | offsets = offsets.astype(np.int_) 55 | idx = index + offsets 56 | idx = np.minimum(np.maximum(idx, 0), len(self.dset)-1) 57 | assert np.all(idx >= 0), '{:d}'.format(index) 58 | assert np.all(idx < len(self.dset)) 59 | return idx 60 | 61 | def __getitem__(self, index): 62 | idx = self.get_indices(index) 63 | 64 | clip = [self.dset[i] for i in idx] 65 | pcs = torch.stack([c[0] for c in clip], dim=0) # (self.steps, 1, 251, 251) 66 | poses = torch.stack([c[1] for c in clip], dim=0) # (self.steps, 3) 67 | 68 | if self.train == 'train' and self.use_merge: 69 | merge_pcs = torch.cat([c[2] for c in clip], dim=0) # (self.steps, 1, 251, 251) 70 | merge_poses = torch.cat([c[3] for c in clip], dim=0) # (self.steps, 3) 71 | 72 | pcs = torch.cat([pcs, merge_pcs], dim=0) 73 | poses = torch.cat([poses, merge_poses], dim=0) 74 | 75 | if self.include_vos: 76 | vos = self.vo_func(poses.unsqueeze(0))[0] 77 | poses = torch.cat((poses, vos), dim=0) 78 | 79 | batch = { 80 | "image": pcs, 81 | "pose": poses, 82 | } 83 | return batch 84 | 85 | def __len__(self): 86 | L = len(self.dset) 87 | return L -------------------------------------------------------------------------------- /datasets/nclt_bev.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziyue Wang and Wen Li 3 | @file: nclt_bev.py 4 | @time: 2025/3/12 14:20 5 | """ 6 | 7 | import os 8 | import cv2 9 | import h5py 10 | import torch 11 | import random 12 | import numpy as np 13 | import os.path as osp 14 | from torch.utils import data 15 | from 
datasets.projection import getBEV 16 | from datasets.augmentor import Augmentor, AugmentParams 17 | from utils.pose_util import process_poses, filter_overflow_nclt, interpolate_pose_nclt, so3_to_euler_nclt, poses_foraugmentaion 18 | 19 | BASE_DIR = osp.dirname(osp.abspath(__file__)) 20 | 21 | velodatatype = np.dtype({ 22 | 'x': ('= x_num or y_ind >= y_num or x_ind < 0 or y_ind < 0): 61 | continue 62 | if mat_global_image[ y_ind,x_ind]<10: 63 | mat_global_image[ y_ind,x_ind] += 1 64 | 65 | max_pixel = np.max(np.max(mat_global_image)) 66 | 67 | mat_global_image[mat_global_image<=1] = 0 68 | mat_global_image = mat_global_image*10 69 | 70 | mat_global_image[np.where(mat_global_image>255)]=255 71 | mat_global_image = mat_global_image/np.max(mat_global_image)*255 72 | 73 | return mat_global_image -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/ins.txt: -------------------------------------------------------------------------------- 1 | -1.7132 0.1181 1.1948 -0.0125 0.0400 0.0050 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/ldmrs.txt: -------------------------------------------------------------------------------- 1 | 1.5349 0.0090 1.3777 0.0205 0.0767 -0.0299 -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/lms_front.txt: -------------------------------------------------------------------------------- 1 | 1.7589 0.2268 1.0411 -0.0437 -1.4572 0.0356 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/lms_rear.txt: -------------------------------------------------------------------------------- 1 | -2.5850 0.2852 1.0885 -2.8405 -1.5090 -0.3614 -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/mono_left.txt: -------------------------------------------------------------------------------- 1 | -0.0905 1.6375 0.2803 0.2079 -0.2339 1.2321 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/mono_rear.txt: -------------------------------------------------------------------------------- 1 | -2.0582 0.0894 0.3675 -0.0119 -0.2498 3.1283 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/mono_right.txt: -------------------------------------------------------------------------------- 1 | -0.2587 -1.6810 0.3226 -0.1961 -0.2469 -1.2675 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/nlct_velodyne.txt: -------------------------------------------------------------------------------- 1 | 0.002 -0.004 -0.957 0.807 0.166 -90.703 -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/radar.txt: -------------------------------------------------------------------------------- 1 | -0.71813 0.12 -0.54479 0 0.05 0 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/stereo.txt: -------------------------------------------------------------------------------- 1 | 0 0 0 0 0 0 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/velodyne_left.txt: 
-------------------------------------------------------------------------------- 1 | -0.60072 -0.34077 -0.26837 -0.0053948 -0.041998 -3.1337 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/extrinsics/velodyne_right.txt: -------------------------------------------------------------------------------- 1 | -0.61153 0.55676 -0.27023 0.0027052 -0.041999 -3.1357 2 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/README.md: -------------------------------------------------------------------------------- 1 | RobotCar Dataset Python Tools 2 | ============================= 3 | 4 | This directory contains sample python code for viewing and manipulating data from the [Oxford Robotcar Dataset](http://robotcar-dataset.robots.ox.ac.uk) and [Oxford Radar Robotcar Dataset](http://ori.ox.ac.uk/datasets/radar-robotcar-dataset). 5 | 6 | Requirements 7 | ------------ 8 | The python tools have been tested on Python 2.7. 9 | Python 3.* compatibility has not been verified. 10 | 11 | The following packages are required: 12 | * numpy 13 | * matplotlib 14 | * colour_demosaicing 15 | * pillow 16 | * opencv-python 17 | * open3d-python 18 | 19 | These can be installed with pip: 20 | 21 | ``` 22 | pip install numpy matplotlib colour_demosaicing pillow opencv-python open3d-python 23 | ``` 24 | 25 | Command Line Tools 26 | ------------------ 27 | 28 | ### Viewing Images 29 | The `play_images.py` script can be used to view images from the dataset. 30 | 31 | ```bash 32 | python play_images.py /path/to/data/yyyy-mm-dd-hh-mm-ss/stereo/centre 33 | ``` 34 | 35 | If you wish to undistort the images before viewing them, pass the camera model directory as a second argument: 36 | 37 | ```bash 38 | python play_images.py /path/to/data/yyyy-mm-dd-hh-mm-ss/stereo/centre --models_dir /path/to/camera/models 39 | ``` 40 | 41 | ### Viewing Radar Scans 42 | The `play_radar.py` script can be used to view radar scans. 43 | 44 | ```bash 45 | python play_radar.py /path/to/data/yyyy-mm-dd-hh-mm-ss/radar 46 | ``` 47 | 48 | ### Viewing Velodyne Scans 49 | The `play_velodyne.py` script can be used to view velodyne scans from raw or binary form. 50 | 51 | ```bash 52 | python play_velodyne.py /path/to/data/yyyy-mm-dd-hh-mm-ss/velodyne_left 53 | ``` 54 | 55 | ### Building Pointclouds 56 | The `build_pointcloud.py` script builds and displays a 3D pointcloud by combining multiple LIDAR scans with a pose source. 57 | The pose source can be either INS data or the supplied visual odometry data. For example: 58 | 59 | ```bash 60 | python build_pointcloud.py --laser_dir /path/to/data/yyyy-mm-dd-hh-mm-ss/lms_front --extrinsics_dir ../extrinsics --poses_file /path/to/data/yyyy-mm-dd-hh-mm-ss/vo/vo.csv 61 | ``` 62 | 63 | ### Projecting pointclouds into images 64 | The `project_laser_into_camera.py` script first builds a pointcloud, then projects it into a camera image using a pinhole camera model. 65 | For example: 66 | 67 | ```bash 68 | python project_laser_into_camera.py --image_dir /path/to/data/yyyy-mm-dd-hh-mm-ss/stereo/centre --laser_dir /path/to/data/yyyy-mm-dd-hh-mm-ss/ldmrs --poses_file /path/to/data/yyyy-mm-dd-hh-mm-ss/vo/vo.csv --models_dir /path/to/models --extrinsics_dir ../extrinsics --image_idx 200 69 | ``` 70 | 71 | Usage from Python 72 | ----------------- 73 | The scripts here are also designed to be used in your own scripts. 
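For example, a minimal sketch of loading and undistorting an image from your own script (paths are placeholders; the modules used are listed below):

```python
# Minimal sketch -- paths are placeholders
import os
from camera_model import CameraModel
from image import load_image

images_dir = '/path/to/data/yyyy-mm-dd-hh-mm-ss/stereo/centre'
model = CameraModel('/path/to/camera/models', images_dir)             # enables undistortion
img = load_image(os.path.join(images_dir, '<timestamp>.png'), model)  # demosaiced, undistorted uint8 RGB
```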
74 | 75 | * `build_pointcloud.py`: function for building a pointcloud from LIDAR and odometry data 76 | * `camera_model.py`: loads camera models from disk, and provides undistortion of images and projection of pointclouds 77 | * `interpolate_poses.py`: functions for interpolating VO or INS data to obtain pose estimates at arbitrary timestamps 78 | * `transform.py`: functions for converting between various transform representations 79 | * `image.py`: function for loading, Bayer demosaicing and undistorting images 80 | * `velodyne.py`: functions for loading Velodyne scan data and converting a raw scan representation to pointcloud 81 | * `radar.py`: functions for loading radar scan data and converting a polar scan representation to Cartesian 82 | 83 | For examples of how to use these functions, see the command line tools above. 84 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/__init__.py: -------------------------------------------------------------------------------- 1 | from . import transform -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/build_pointcloud.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Geoff Pascoe (gmp@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ################################################################################ 14 | 15 | import os 16 | import re 17 | import numpy as np 18 | 19 | from transform import build_se3_transform 20 | from interpolate_poses import interpolate_vo_poses, interpolate_ins_poses 21 | from velodyne import load_velodyne_raw, load_velodyne_binary, velodyne_raw_to_pointcloud 22 | 23 | 24 | def build_pointcloud(lidar_dir, poses_file, extrinsics_dir, start_time, end_time, origin_time=-1): 25 | """Builds a pointcloud by combining multiple LIDAR scans with odometry information. 26 | 27 | Args: 28 | lidar_dir (str): Directory containing LIDAR scans. 29 | poses_file (str): Path to a file containing pose information. Can be VO or INS data. 30 | extrinsics_dir (str): Directory containing extrinsic calibrations. 31 | start_time (int): UNIX timestamp of the start of the window over which to build the pointcloud. 32 | end_time (int): UNIX timestamp of the end of the window over which to build the pointcloud. 33 | origin_time (int): UNIX timestamp of origin frame. Pointcloud coordinates are relative to this frame. 34 | 35 | Returns: 36 | numpy.ndarray: 3xn array of (x, y, z) coordinates of pointcloud 37 | numpy.array: array of n reflectance values or None if no reflectance values are recorded (LDMRS) 38 | 39 | Raises: 40 | ValueError: if specified window doesn't contain any laser scans. 41 | IOError: if scan files are not found. 
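    Example (minimal sketch; paths are placeholders, start_time/end_time are UNIX timestamps
        in microseconds taken from the corresponding .timestamps file):
        pointcloud, reflectance = build_pointcloud(
            '/path/to/data/yyyy-mm-dd-hh-mm-ss/velodyne_left',
            '/path/to/data/yyyy-mm-dd-hh-mm-ss/gps/ins.csv',
            '../extrinsics', start_time, end_time)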
42 | 43 | """ 44 | if origin_time < 0: 45 | origin_time = start_time 46 | 47 | lidar = re.search('(lms_front|lms_rear|ldmrs|velodyne_left|velodyne_right)', lidar_dir).group(0) 48 | timestamps_path = os.path.join(lidar_dir, os.pardir, lidar + '.timestamps') 49 | 50 | timestamps = [] 51 | with open(timestamps_path) as timestamps_file: 52 | for line in timestamps_file: 53 | timestamp = int(line.split(' ')[0]) 54 | if start_time <= timestamp <= end_time: 55 | timestamps.append(timestamp) 56 | 57 | if len(timestamps) == 0: 58 | raise ValueError("No LIDAR data in the given time bracket.") 59 | 60 | with open(os.path.join(extrinsics_dir, lidar + '.txt')) as extrinsics_file: 61 | extrinsics = next(extrinsics_file) 62 | G_posesource_laser = build_se3_transform([float(x) for x in extrinsics.split(' ')]) 63 | 64 | poses_type = re.search('(vo|ins|rtk)\.csv', poses_file).group(1) 65 | 66 | if poses_type in ['ins', 'rtk']: 67 | with open(os.path.join(extrinsics_dir, 'ins.txt')) as extrinsics_file: 68 | extrinsics = next(extrinsics_file) 69 | G_posesource_laser = np.linalg.solve(build_se3_transform([float(x) for x in extrinsics.split(' ')]), 70 | G_posesource_laser) 71 | 72 | poses = interpolate_ins_poses(poses_file, timestamps, origin_time, use_rtk=(poses_type == 'rtk')) 73 | else: 74 | # sensor is VO, which is located at the main vehicle frame 75 | poses = interpolate_vo_poses(poses_file, timestamps, origin_time) 76 | 77 | pointcloud = np.array([[0], [0], [0], [0]]) 78 | if lidar == 'ldmrs': 79 | reflectance = None 80 | else: 81 | reflectance = np.empty((0)) 82 | 83 | for i in range(0, len(poses)): 84 | scan_path = os.path.join(lidar_dir, str(timestamps[i]) + '.bin') 85 | if "velodyne" not in lidar: 86 | if not os.path.isfile(scan_path): 87 | continue 88 | 89 | scan_file = open(scan_path) 90 | scan = np.fromfile(scan_file, np.double) 91 | scan_file.close() 92 | 93 | scan = scan.reshape((len(scan) // 3, 3)).transpose() 94 | 95 | if lidar != 'ldmrs': 96 | # LMS scans are tuples of (x, y, reflectance) 97 | reflectance = np.concatenate((reflectance, np.ravel(scan[2, :]))) 98 | scan[2, :] = np.zeros((1, scan.shape[1])) 99 | else: 100 | if os.path.isfile(scan_path): 101 | ptcld = load_velodyne_binary(scan_path) 102 | else: 103 | scan_path = os.path.join(lidar_dir, str(timestamps[i]) + '.png') 104 | if not os.path.isfile(scan_path): 105 | continue 106 | ranges, intensities, angles, approximate_timestamps = load_velodyne_raw(scan_path) 107 | ptcld = velodyne_raw_to_pointcloud(ranges, intensities, angles) 108 | 109 | reflectance = np.concatenate((reflectance, ptcld[3])) 110 | scan = ptcld[:3] 111 | 112 | scan = np.dot(np.dot(poses[i], G_posesource_laser), np.vstack([scan, np.ones((1, scan.shape[1]))])) 113 | pointcloud = np.hstack([pointcloud, scan]) 114 | 115 | pointcloud = pointcloud[:, 1:] 116 | if pointcloud.shape[1] == 0: 117 | raise IOError("Could not find scan files for given time range in directory " + lidar_dir) 118 | 119 | return pointcloud, reflectance 120 | 121 | 122 | if __name__ == "__main__": 123 | import argparse 124 | import open3d 125 | 126 | parser = argparse.ArgumentParser(description='Build and display a pointcloud') 127 | parser.add_argument('--poses_file', type=str, default=None, help='File containing relative or absolute poses') 128 | parser.add_argument('--extrinsics_dir', type=str, default=None, 129 | help='Directory containing extrinsic calibrations') 130 | parser.add_argument('--laser_dir', type=str, default=None, help='Directory containing LIDAR data') 131 | 132 | args = 
parser.parse_args() 133 | 134 | lidar = re.search('(lms_front|lms_rear|ldmrs|velodyne_left|velodyne_right)', args.laser_dir).group(0) 135 | timestamps_path = os.path.join(args.laser_dir, os.pardir, lidar + '.timestamps') 136 | with open(timestamps_path) as timestamps_file: 137 | start_time = int(next(timestamps_file).split(' ')[0]) 138 | 139 | end_time = start_time + 2e7 140 | 141 | pointcloud, reflectance = build_pointcloud(args.laser_dir, args.poses_file, 142 | args.extrinsics_dir, start_time, end_time) 143 | 144 | if reflectance is not None: 145 | colours = (reflectance - reflectance.min()) / (reflectance.max() - reflectance.min()) 146 | colours = 1 / (1 + np.exp(-10 * (colours - colours.mean()))) 147 | else: 148 | colours = 'gray' 149 | 150 | # Pointcloud Visualisation using Open3D 151 | vis = open3d.Visualizer() 152 | vis.create_window(window_name=os.path.basename(__file__)) 153 | render_option = vis.get_render_option() 154 | render_option.background_color = np.array([0.1529, 0.1569, 0.1333], np.float32) 155 | render_option.point_color_option = open3d.PointColorOption.ZCoordinate 156 | coordinate_frame = open3d.geometry.create_mesh_coordinate_frame() 157 | vis.add_geometry(coordinate_frame) 158 | pcd = open3d.geometry.PointCloud() 159 | pcd.points = open3d.utility.Vector3dVector( 160 | -np.ascontiguousarray(pointcloud[[1, 0, 2]].transpose().astype(np.float64))) 161 | pcd.colors = open3d.utility.Vector3dVector(np.tile(colours[:, np.newaxis], (1, 3)).astype(np.float64)) 162 | # Rotate pointcloud to align displayed coordinate frame colouring 163 | pcd.transform(build_se3_transform([0, 0, 0, np.pi, 0, -np.pi / 2])) 164 | vis.add_geometry(pcd) 165 | view_control = vis.get_view_control() 166 | params = view_control.convert_to_pinhole_camera_parameters() 167 | params.extrinsic = build_se3_transform([0, 3, 10, 0, -np.pi * 0.42, -np.pi / 2]) 168 | view_control.convert_from_pinhole_camera_parameters(params) 169 | vis.run() 170 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/camera_model.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Geoff Pascoe (gmp@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ################################################################################ 14 | 15 | import re 16 | import os 17 | import numpy as np 18 | import scipy.interpolate as interp 19 | from scipy.ndimage import map_coordinates 20 | 21 | 22 | class CameraModel: 23 | """Provides intrinsic parameters and undistortion LUT for a camera. 24 | 25 | Attributes: 26 | camera (str): Name of the camera. 27 | camera sensor (str): Name of the sensor on the camera for multi-sensor cameras. 28 | focal_length (tuple[float]): Focal length of the camera in horizontal and vertical axis, in pixels. 29 | principal_point (tuple[float]): Principal point of camera for pinhole projection model, in pixels. 30 | G_camera_image (:obj: `numpy.matrixlib.defmatrix.matrix`): Transform from image frame to camera frame. 
31 | bilinear_lut (:obj: `numpy.ndarray`): Look-up table for undistortion of images, mapping pixels in an undistorted 32 | image to pixels in the distorted image 33 | 34 | """ 35 | 36 | def __init__(self, models_dir, images_dir): 37 | """Loads a camera model from disk. 38 | 39 | Args: 40 | models_dir (str): directory containing camera model files. 41 | images_dir (str): directory containing images for which to read camera model. 42 | 43 | """ 44 | self.camera = None 45 | self.camera_sensor = None 46 | self.focal_length = None 47 | self.principal_point = None 48 | self.G_camera_image = None 49 | self.bilinear_lut = None 50 | 51 | self.__load_intrinsics(models_dir, images_dir) 52 | self.__load_lut(models_dir, images_dir) 53 | 54 | def project(self, xyz, image_size): 55 | """Projects a pointcloud into the camera using a pinhole camera model. 56 | 57 | Args: 58 | xyz (:obj: `numpy.ndarray`): 3xn array, where each column is (x, y, z) point relative to camera frame. 59 | image_size (tuple[int]): dimensions of image in pixels 60 | 61 | Returns: 62 | numpy.ndarray: 2xm array of points, where each column is the (u, v) pixel coordinates of a point in pixels. 63 | numpy.array: array of depth values for points in image. 64 | 65 | Note: 66 | Number of output points m will be less than or equal to number of input points n, as points that do not 67 | project into the image are discarded. 68 | 69 | """ 70 | if xyz.shape[0] == 3: 71 | xyz = np.stack((xyz, np.ones((1, xyz.shape[1])))) 72 | xyzw = np.linalg.solve(self.G_camera_image, xyz) 73 | 74 | # Find which points lie in front of the camera 75 | in_front = [i for i in range(0, xyzw.shape[1]) if xyzw[2, i] >= 0] 76 | xyzw = xyzw[:, in_front] 77 | 78 | uv = np.vstack((self.focal_length[0] * xyzw[0, :] / xyzw[2, :] + self.principal_point[0], 79 | self.focal_length[1] * xyzw[1, :] / xyzw[2, :] + self.principal_point[1])) 80 | 81 | in_img = [i for i in range(0, uv.shape[1]) 82 | if 0.5 <= uv[0, i] <= image_size[1] and 0.5 <= uv[1, i] <= image_size[0]] 83 | 84 | return uv[:, in_img], np.ravel(xyzw[2, in_img]) 85 | 86 | def undistort(self, image): 87 | """Undistorts an image. 88 | 89 | Args: 90 | image (:obj: `numpy.ndarray`): A distorted image. Must be demosaiced - ie. must be a 3-channel RGB image. 91 | 92 | Returns: 93 | numpy.ndarray: Undistorted version of image. 94 | 95 | Raises: 96 | ValueError: if image size does not match camera model. 97 | ValueError: if image only has a single channel. 
98 | 99 | """ 100 | if image.shape[0] * image.shape[1] != self.bilinear_lut.shape[0]: 101 | raise ValueError('Incorrect image size for camera model') 102 | 103 | lut = self.bilinear_lut[:, 1::-1].T.reshape((2, image.shape[0], image.shape[1])) 104 | 105 | if len(image.shape) == 1: 106 | raise ValueError('Undistortion function only works with multi-channel images') 107 | 108 | undistorted = np.rollaxis(np.array([map_coordinates(image[:, :, channel], lut, order=1) 109 | for channel in range(0, image.shape[2])]), 0, 3) 110 | 111 | return undistorted.astype(image.dtype) 112 | 113 | def __get_model_name(self, images_dir): 114 | self.camera = re.search('(stereo|mono_(left|right|rear))', images_dir).group(0) 115 | if self.camera == 'stereo': 116 | self.camera_sensor = re.search('(left|centre|right)', images_dir).group(0) 117 | if self.camera_sensor == 'left': 118 | return 'stereo_wide_left' 119 | elif self.camera_sensor == 'right': 120 | return 'stereo_wide_right' 121 | elif self.camera_sensor == 'centre': 122 | return 'stereo_narrow_left' 123 | else: 124 | raise RuntimeError('Unknown camera model for given directory: ' + images_dir) 125 | else: 126 | return self.camera 127 | 128 | def __load_intrinsics(self, models_dir, images_dir): 129 | model_name = self.__get_model_name(images_dir) 130 | intrinsics_path = os.path.join(models_dir, model_name + '.txt') 131 | 132 | with open(intrinsics_path) as intrinsics_file: 133 | vals = [float(x) for x in next(intrinsics_file).split()] 134 | self.focal_length = (vals[0], vals[1]) 135 | self.principal_point = (vals[2], vals[3]) 136 | 137 | G_camera_image = [] 138 | for line in intrinsics_file: 139 | G_camera_image.append([float(x) for x in line.split()]) 140 | self.G_camera_image = np.array(G_camera_image) 141 | 142 | def __load_lut(self, models_dir, images_dir): 143 | model_name = self.__get_model_name(images_dir) 144 | lut_path = os.path.join(models_dir, model_name + '_distortion_lut.bin') 145 | 146 | lut = np.fromfile(lut_path, np.double) 147 | lut = lut.reshape([2, lut.size // 2]) 148 | self.bilinear_lut = lut.transpose() 149 | 150 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/image.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Geoff Pascoe (gmp@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ############################################################################### 14 | 15 | import re 16 | from PIL import Image 17 | from colour_demosaicing import demosaicing_CFA_Bayer_bilinear as demosaic 18 | import numpy as np 19 | 20 | BAYER_STEREO = 'gbrg' 21 | BAYER_MONO = 'rggb' 22 | 23 | 24 | def load_image(image_path, model=None): 25 | """Loads and rectifies an image from file. 26 | 27 | Args: 28 | image_path (str): path to an image from the dataset. 29 | model (camera_model.CameraModel): if supplied, model will be used to undistort image. 
30 | 31 | Returns: 32 | numpy.ndarray: demosaiced and optionally undistorted image 33 | 34 | """ 35 | if model: 36 | camera = model.camera 37 | else: 38 | camera = re.search('(stereo|mono_(left|right|rear))', image_path).group(0) 39 | if camera == 'stereo': 40 | pattern = BAYER_STEREO 41 | else: 42 | pattern = BAYER_MONO 43 | 44 | img = Image.open(image_path) 45 | img = demosaic(img, pattern) 46 | if model: 47 | img = model.undistort(img) 48 | 49 | return np.array(img).astype(np.uint8) 50 | 51 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/interpolate_poses.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Geoff Pascoe (gmp@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ################################################################################ 14 | 15 | import bisect 16 | import csv 17 | import numpy as np 18 | import numpy.matlib as ml 19 | import sys 20 | sys.path.append('.') 21 | from .transform import * 22 | 23 | def NCLT_interpolate_vo_poses(vo_path, pose_timestamps, origin_timestamp): 24 | """Interpolate poses from visual odometry. 25 | 26 | Args: 27 | vo_path (str): path to file containing relative poses from visual odometry. 28 | pose_timestamps (list[int]): UNIX timestamps at which interpolated poses are required. 29 | origin_timestamp (int): UNIX timestamp of origin frame. Poses will be reported relative to this frame. 30 | 31 | Returns: 32 | list[numpy.matrixlib.defmatrix.matrix]: SE3 matrix representing interpolated pose for each requested timestamp. 33 | 34 | """ 35 | with open(vo_path) as vo_file: 36 | vo_reader = csv.reader(vo_file) 37 | headers = next(vo_file) 38 | 39 | vo_timestamps = [0] 40 | abs_poses = [ml.identity(4)] 41 | 42 | lower_timestamp = min(min(pose_timestamps), origin_timestamp) 43 | upper_timestamp = max(max(pose_timestamps), origin_timestamp) 44 | 45 | for row in vo_reader: 46 | timestamp = int(row[0]) 47 | if timestamp < lower_timestamp: 48 | vo_timestamps[0] = timestamp 49 | continue 50 | 51 | vo_timestamps.append(timestamp) 52 | 53 | xyzrpy = [float(v) for v in row[1:7]] 54 | rel_pose = build_se3_transform(xyzrpy) 55 | abs_pose = abs_poses[-1] * rel_pose 56 | abs_poses.append(abs_pose) 57 | 58 | if timestamp >= upper_timestamp: 59 | break 60 | 61 | return interpolate_poses(vo_timestamps, abs_poses, pose_timestamps, origin_timestamp) 62 | 63 | def interpolate_vo_poses(vo_path, pose_timestamps, origin_timestamp): 64 | """Interpolate poses from visual odometry. 65 | 66 | Args: 67 | vo_path (str): path to file containing relative poses from visual odometry. 68 | pose_timestamps (list[int]): UNIX timestamps at which interpolated poses are required. 69 | origin_timestamp (int): UNIX timestamp of origin frame. Poses will be reported relative to this frame. 70 | 71 | Returns: 72 | list[numpy.matrixlib.defmatrix.matrix]: SE3 matrix representing interpolated pose for each requested timestamp. 
73 | 74 | """ 75 | with open(vo_path) as vo_file: 76 | vo_reader = csv.reader(vo_file) 77 | headers = next(vo_file) 78 | 79 | vo_timestamps = [0] 80 | abs_poses = [ml.identity(4)] 81 | 82 | lower_timestamp = min(min(pose_timestamps), origin_timestamp) 83 | upper_timestamp = max(max(pose_timestamps), origin_timestamp) 84 | 85 | for row in vo_reader: 86 | timestamp = int(row[0]) 87 | if timestamp < lower_timestamp: 88 | vo_timestamps[0] = timestamp 89 | continue 90 | 91 | vo_timestamps.append(timestamp) 92 | 93 | xyzrpy = [float(v) for v in row[2:8]] 94 | rel_pose = build_se3_transform(xyzrpy) 95 | abs_pose = abs_poses[-1] * rel_pose 96 | abs_poses.append(abs_pose) 97 | 98 | if timestamp >= upper_timestamp: 99 | break 100 | 101 | return interpolate_poses(vo_timestamps, abs_poses, pose_timestamps, origin_timestamp) 102 | 103 | 104 | def interpolate_ins_poses(ins_path, pose_timestamps, origin_timestamp, use_rtk=False): 105 | """Interpolate poses from INS. 106 | 107 | Args: 108 | ins_path (str): path to file containing poses from INS. 109 | pose_timestamps (list[int]): UNIX timestamps at which interpolated poses are required. 110 | origin_timestamp (int): UNIX timestamp of origin frame. Poses will be reported relative to this frame. 111 | 112 | Returns: 113 | list[numpy.matrixlib.defmatrix.matrix]: SE3 matrix representing interpolated pose for each requested timestamp. 114 | 115 | """ 116 | with open(ins_path) as ins_file: 117 | ins_reader = csv.reader(ins_file) 118 | headers = next(ins_file) 119 | 120 | ins_timestamps = [0] 121 | abs_poses = [ml.identity(4)] 122 | 123 | upper_timestamp = max(max(pose_timestamps), origin_timestamp) 124 | 125 | for row in ins_reader: 126 | timestamp = int(row[0]) 127 | ins_timestamps.append(timestamp) 128 | 129 | utm = row[5:8] if not use_rtk else row[4:7] 130 | rpy = row[-3:] if not use_rtk else row[11:14] 131 | xyzrpy = [float(v) for v in utm] + [float(v) for v in rpy] 132 | abs_pose = build_se3_transform(xyzrpy) 133 | abs_poses.append(abs_pose) 134 | 135 | if timestamp >= upper_timestamp: 136 | break 137 | 138 | ins_timestamps = ins_timestamps[1:] 139 | abs_poses = abs_poses[1:] 140 | 141 | return interpolate_poses(ins_timestamps, abs_poses, pose_timestamps, origin_timestamp) 142 | 143 | 144 | def interpolate_ins_poses_xmu(ins_path: object, pose_timestamps, origin_timestamp, use_rtk=False): 145 | """Interpolate poses from INS. 146 | 147 | Args: 148 | ins_path (str): path to file containing poses from INS. 149 | pose_timestamps (list[int]): UNIX timestamps at which interpolated poses are required. 150 | origin_timestamp (int): UNIX timestamp of origin frame. Poses will be reported relative to this frame. 151 | 152 | Returns: 153 | list[numpy.matrixlib.defmatrix.matrix]: SE3 matrix representing interpolated pose for each requested timestamp. 
154 | 155 | """ 156 | with open(ins_path) as ins_file: 157 | # ins_reader = csv.reader(ins_file) 158 | # headers = next(ins_file) 159 | ins_reader = np.loadtxt(ins_file) 160 | # mask = ins_reader[:, -1] == 2 161 | # ins_reader = ins_reader[mask] 162 | ins_timestamps = [0] 163 | abs_poses = [ml.identity(4)] 164 | 165 | upper_timestamp = max(max(pose_timestamps), origin_timestamp) 166 | 167 | for row in ins_reader: 168 | timestamp = int(row[0]) 169 | ins_timestamps.append(timestamp) 170 | 171 | utm = row[1:4] if not use_rtk else row[4:7] 172 | rpy = row[-3:] if not use_rtk else row[11:14] 173 | xyzrpy = [float(v) for v in utm] + [float(v) for v in rpy] 174 | abs_pose = build_se3_transform(xyzrpy) 175 | abs_poses.append(abs_pose) 176 | 177 | if timestamp >= upper_timestamp: 178 | break 179 | 180 | ins_timestamps = ins_timestamps[1:] 181 | abs_poses = abs_poses[1:] 182 | 183 | return interpolate_poses(ins_timestamps, abs_poses, pose_timestamps, origin_timestamp) 184 | 185 | 186 | 187 | def interpolate_poses(pose_timestamps, abs_poses, requested_timestamps, origin_timestamp): 188 | """Interpolate between absolute poses. 189 | 190 | Args: 191 | pose_timestamps (list[int]): Timestamps of supplied poses. Must be in ascending order. 192 | abs_poses (list[numpy.matrixlib.defmatrix.matrix]): SE3 matrices representing poses at the timestamps specified. 193 | requested_timestamps (list[int]): Timestamps for which interpolated timestamps are required. 194 | origin_timestamp (int): UNIX timestamp of origin frame. Poses will be reported relative to this frame. 195 | 196 | Returns: 197 | list[numpy.matrixlib.defmatrix.matrix]: SE3 matrix representing interpolated pose for each requested timestamp. 198 | 199 | Raises: 200 | ValueError: if pose_timestamps and abs_poses are not the same length 201 | ValueError: if pose_timestamps is not in ascending order 202 | 203 | """ 204 | requested_timestamps.insert(0, origin_timestamp) 205 | requested_timestamps = np.array(requested_timestamps) 206 | pose_timestamps = np.array(pose_timestamps) 207 | 208 | if len(pose_timestamps) != len(abs_poses): 209 | raise ValueError('Must supply same number of timestamps as poses') 210 | 211 | abs_quaternions = np.zeros((4, len(abs_poses))) 212 | abs_positions = np.zeros((3, len(abs_poses))) 213 | for i, pose in enumerate(abs_poses): 214 | if i > 0 and pose_timestamps[i-1] >= pose_timestamps[i]: 215 | raise ValueError('Pose timestamps must be in ascending order') 216 | 217 | abs_quaternions[:, i] = so3_to_quaternion(pose[0:3, 0:3]) 218 | abs_positions[:, i] = np.ravel(pose[0:3, 3]) 219 | 220 | upper_indices = [bisect.bisect(pose_timestamps, pt) for pt in requested_timestamps] 221 | lower_indices = [u - 1 for u in upper_indices] 222 | 223 | if max(upper_indices) >= len(pose_timestamps): 224 | upper_indices = [min(i, len(pose_timestamps) - 1) for i in upper_indices] 225 | 226 | fractions = (requested_timestamps - pose_timestamps[lower_indices]) / \ 227 | (pose_timestamps[upper_indices] - pose_timestamps[lower_indices]) 228 | # import warnings 229 | # warnings.filterwarnings('error') 230 | # try: 231 | # fractions = (requested_timestamps - pose_timestamps[lower_indices]) / \ 232 | # (pose_timestamps[upper_indices] - pose_timestamps[lower_indices]) 233 | # except Warning: 234 | # print('Warning was raised as an exception!') 235 | # print(pose_timestamps[upper_indices]) 236 | # print(pose_timestamps[lower_indices]) 237 | 238 | quaternions_lower = abs_quaternions[:, lower_indices] 239 | quaternions_upper = abs_quaternions[:, 
upper_indices] 240 | 241 | d_array = (quaternions_lower * quaternions_upper).sum(0) 242 | 243 | linear_interp_indices = np.nonzero(d_array >= 1) 244 | sin_interp_indices = np.nonzero(d_array < 1) 245 | 246 | scale0_array = np.zeros(d_array.shape) 247 | scale1_array = np.zeros(d_array.shape) 248 | 249 | scale0_array[linear_interp_indices] = 1 - fractions[linear_interp_indices] 250 | scale1_array[linear_interp_indices] = fractions[linear_interp_indices] 251 | 252 | theta_array = np.arccos(np.abs(d_array[sin_interp_indices])) 253 | 254 | scale0_array[sin_interp_indices] = \ 255 | np.sin((1 - fractions[sin_interp_indices]) * theta_array) / np.sin(theta_array) 256 | scale1_array[sin_interp_indices] = \ 257 | np.sin(fractions[sin_interp_indices] * theta_array) / np.sin(theta_array) 258 | 259 | negative_d_indices = np.nonzero(d_array < 0) 260 | scale1_array[negative_d_indices] = -scale1_array[negative_d_indices] 261 | 262 | quaternions_interp = np.tile(scale0_array, (4, 1)) * quaternions_lower \ 263 | + np.tile(scale1_array, (4, 1)) * quaternions_upper 264 | 265 | positions_lower = abs_positions[:, lower_indices] 266 | positions_upper = abs_positions[:, upper_indices] 267 | 268 | positions_interp = np.multiply(np.tile((1 - fractions), (3, 1)), positions_lower) \ 269 | + np.multiply(np.tile(fractions, (3, 1)), positions_upper) 270 | 271 | poses_mat = ml.zeros((4, 4 * len(requested_timestamps))) 272 | 273 | poses_mat[0, 0::4] = 1 - 2 * np.square(quaternions_interp[2, :]) - \ 274 | 2 * np.square(quaternions_interp[3, :]) 275 | poses_mat[0, 1::4] = 2 * np.multiply(quaternions_interp[1, :], quaternions_interp[2, :]) - \ 276 | 2 * np.multiply(quaternions_interp[3, :], quaternions_interp[0, :]) 277 | poses_mat[0, 2::4] = 2 * np.multiply(quaternions_interp[1, :], quaternions_interp[3, :]) + \ 278 | 2 * np.multiply(quaternions_interp[2, :], quaternions_interp[0, :]) 279 | 280 | poses_mat[1, 0::4] = 2 * np.multiply(quaternions_interp[1, :], quaternions_interp[2, :]) \ 281 | + 2 * np.multiply(quaternions_interp[3, :], quaternions_interp[0, :]) 282 | poses_mat[1, 1::4] = 1 - 2 * np.square(quaternions_interp[1, :]) \ 283 | - 2 * np.square(quaternions_interp[3, :]) 284 | poses_mat[1, 2::4] = 2 * np.multiply(quaternions_interp[2, :], quaternions_interp[3, :]) - \ 285 | 2 * np.multiply(quaternions_interp[1, :], quaternions_interp[0, :]) 286 | 287 | poses_mat[2, 0::4] = 2 * np.multiply(quaternions_interp[1, :], quaternions_interp[3, :]) - \ 288 | 2 * np.multiply(quaternions_interp[2, :], quaternions_interp[0, :]) 289 | poses_mat[2, 1::4] = 2 * np.multiply(quaternions_interp[2, :], quaternions_interp[3, :]) + \ 290 | 2 * np.multiply(quaternions_interp[1, :], quaternions_interp[0, :]) 291 | poses_mat[2, 2::4] = 1 - 2 * np.square(quaternions_interp[1, :]) - \ 292 | 2 * np.square(quaternions_interp[2, :]) 293 | 294 | poses_mat[0:3, 3::4] = positions_interp 295 | poses_mat[3, 3::4] = 1 296 | 297 | # don't use relative pose 298 | # poses_mat = np.linalg.solve(poses_mat[0:4, 0:4], poses_mat) 299 | 300 | poses_out = [0] * (len(requested_timestamps) - 1) 301 | for i in range(1, len(requested_timestamps)): 302 | poses_out[i - 1] = poses_mat[0:4, i * 4:(i + 1) * 4] 303 | 304 | return poses_out 305 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/play_images.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 
University of Oxford 4 | # Authors: 5 | # Geoff Pascoe (gmp@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ################################################################################ 14 | 15 | import argparse 16 | import os 17 | import re 18 | import matplotlib.pyplot as plt 19 | from datetime import datetime as dt 20 | from image import load_image 21 | from camera_model import CameraModel 22 | 23 | parser = argparse.ArgumentParser(description='Play back images from a given directory') 24 | 25 | parser.add_argument('dir', type=str, help='Directory containing images.') 26 | parser.add_argument('--models_dir', type=str, default=None, help='(optional) Directory containing camera model. If supplied, images will be undistorted before display') 27 | parser.add_argument('--scale', type=float, default=1.0, help='(optional) factor by which to scale images before display') 28 | 29 | args = parser.parse_args() 30 | 31 | camera = re.search('(stereo|mono_(left|right|rear))', args.dir).group(0) 32 | 33 | timestamps_path = os.path.join(os.path.join(args.dir, os.pardir, camera + '.timestamps')) 34 | if not os.path.isfile(timestamps_path): 35 | timestamps_path = os.path.join(args.dir, os.pardir, os.pardir, camera + '.timestamps') 36 | if not os.path.isfile(timestamps_path): 37 | raise IOError("Could not find timestamps file") 38 | 39 | model = None 40 | if args.models_dir: 41 | model = CameraModel(args.models_dir, args.dir) 42 | 43 | current_chunk = 0 44 | timestamps_file = open(timestamps_path) 45 | for line in timestamps_file: 46 | tokens = line.split() 47 | datetime = dt.utcfromtimestamp(int(tokens[0])/1000000) 48 | chunk = int(tokens[1]) 49 | 50 | filename = os.path.join(args.dir, tokens[0] + '.png') 51 | if not os.path.isfile(filename): 52 | if chunk != current_chunk: 53 | print("Chunk " + str(chunk) + " not found") 54 | current_chunk = chunk 55 | continue 56 | 57 | current_chunk = chunk 58 | 59 | img = load_image(filename, model) 60 | plt.imshow(img) 61 | plt.xlabel(datetime) 62 | plt.xticks([]) 63 | plt.yticks([]) 64 | plt.pause(0.01) 65 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/play_radar.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Dan Barnes (dbarnes@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
12 | # 13 | ################################################################################ 14 | 15 | import argparse 16 | import os 17 | from radar import load_radar, radar_polar_to_cartesian 18 | import numpy as np 19 | import cv2 20 | 21 | parser = argparse.ArgumentParser(description='Play back radar data from a given directory') 22 | 23 | parser.add_argument('dir', type=str, help='Directory containing radar data.') 24 | 25 | args = parser.parse_args() 26 | 27 | timestamps_path = os.path.join(os.path.join(args.dir, os.pardir, 'radar.timestamps')) 28 | if not os.path.isfile(timestamps_path): 29 | raise IOError("Could not find timestamps file") 30 | 31 | # Cartesian Visualsation Setup 32 | # Resolution of the cartesian form of the radar scan in metres per pixel 33 | cart_resolution = .25 34 | # Cartesian visualisation size (used for both height and width) 35 | cart_pixel_width = 501 # pixels 36 | interpolate_crossover = True 37 | 38 | title = "Radar Visualisation Example" 39 | 40 | radar_timestamps = np.loadtxt(timestamps_path, delimiter=' ', usecols=[0], dtype=np.int64) 41 | for radar_timestamp in radar_timestamps: 42 | filename = os.path.join(args.dir, str(radar_timestamp) + '.png') 43 | 44 | if not os.path.isfile(filename): 45 | raise FileNotFoundError("Could not find radar example: {}".format(filename)) 46 | 47 | timestamps, azimuths, valid, fft_data, radar_resolution = load_radar(filename) 48 | cart_img = radar_polar_to_cartesian(azimuths, fft_data, radar_resolution, cart_resolution, cart_pixel_width, 49 | interpolate_crossover) 50 | 51 | # Combine polar and cartesian for visualisation 52 | # The raw polar data is resized to the height of the cartesian representation 53 | downsample_rate = 4 54 | fft_data_vis = fft_data[:, ::downsample_rate] 55 | resize_factor = float(cart_img.shape[0]) / float(fft_data_vis.shape[0]) 56 | fft_data_vis = cv2.resize(fft_data_vis, (0, 0), None, resize_factor, resize_factor) 57 | vis = cv2.hconcat((fft_data_vis, fft_data_vis[:, :10] * 0 + 1, cart_img)) 58 | 59 | cv2.imshow(title, vis * 2.) # The data is doubled to improve visualisation 60 | cv2.waitKey(1) 61 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/play_velodyne.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Dan Barnes (dbarnes@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
12 | # 13 | ################################################################################ 14 | 15 | 16 | import argparse 17 | from argparse import RawTextHelpFormatter 18 | import os 19 | from velodyne import load_velodyne_raw, load_velodyne_binary, velodyne_raw_to_pointcloud 20 | import numpy as np 21 | import cv2 22 | from matplotlib.cm import get_cmap 23 | from scipy import interpolate 24 | import open3d 25 | from transform import build_se3_transform 26 | 27 | mode_flag_help = """Mode to run in, one of: (raw|raw_interp|raw_ptcld|bin_ptcld) 28 | - 'raw' - visualise the raw velodyne data (intensities and ranges) 29 | - 'raw_interp' - visualise the raw velodyne data (intensities and ranges) interpolated to consistent azimuth angles 30 | between scans. 31 | - 'raw_ptcld' - visualise the raw velodyne data converted to a pointcloud (converts files of the form 32 | .png to pointcloud) 33 | - 'bin_ptcld' - visualise the precomputed velodyne pointclouds (files of the form .bin). This is 34 | approximately 2x faster than running the conversion from raw data `raw_ptcld` at the cost of 35 | approximately 8x the storage space. 36 | """ 37 | 38 | parser = argparse.ArgumentParser(description='Play back velodyne data from a given directory', 39 | formatter_class=RawTextHelpFormatter) 40 | parser.add_argument('--mode', default="raw_interp", type=str, help=mode_flag_help) 41 | parser.add_argument('--scale', default=2., type=float, help="Scale visualisations by this amount") 42 | parser.add_argument('dir', type=str, help='Directory containing velodyne data.') 43 | 44 | args = parser.parse_args() 45 | 46 | 47 | def main(): 48 | velodyne_dir = args.dir 49 | if velodyne_dir[-1] == '/': 50 | velodyne_dir = velodyne_dir[:-1] 51 | 52 | velodyne_sensor = os.path.basename(velodyne_dir) 53 | if velodyne_sensor not in ["velodyne_left", "velodyne_right"]: 54 | raise ValueError("Velodyne directory not valid: {}".format(velodyne_dir)) 55 | 56 | timestamps_path = velodyne_dir + '.timestamps' 57 | if not os.path.isfile(timestamps_path): 58 | raise IOError("Could not find timestamps file: {}".format(timestamps_path)) 59 | 60 | title = "Velodyne Visualisation Example" 61 | extension = ".bin" if args.mode == "bin_ptcld" else ".png" 62 | velodyne_timestamps = np.loadtxt(timestamps_path, delimiter=' ', usecols=[0], dtype=np.int64) 63 | colourmap = (get_cmap("viridis")(np.linspace(0, 1, 255))[:, :3] * 255).astype(np.uint8)[:, ::-1] 64 | interp_angles = np.mod(np.linspace(np.pi, 3 * np.pi, 720), 2 * np.pi) 65 | vis = None 66 | 67 | for velodyne_timestamp in velodyne_timestamps: 68 | 69 | filename = os.path.join(args.dir, str(velodyne_timestamp) + extension) 70 | 71 | if args.mode == "bin_ptcld": 72 | ptcld = load_velodyne_binary(filename) 73 | else: 74 | ranges, intensities, angles, approximate_timestamps = load_velodyne_raw(filename) 75 | 76 | if args.mode == "raw_ptcld": 77 | ptcld = velodyne_raw_to_pointcloud(ranges, intensities, angles) 78 | elif args.mode == "raw_interp": 79 | intensities = interpolate.interp1d(angles[0], intensities, bounds_error=False)(interp_angles) 80 | ranges = interpolate.interp1d(angles[0], ranges, bounds_error=False)(interp_angles) 81 | intensities[np.isnan(intensities)] = 0 82 | ranges[np.isnan(ranges)] = 0 83 | 84 | if '_ptcld' in args.mode: 85 | # Pointcloud Visualisation using Open3D 86 | if vis is None: 87 | vis = open3d.Visualizer() 88 | vis.create_window(window_name=title) 89 | pcd = open3d.geometry.PointCloud() 90 | # initialise the geometry pre loop 91 | pcd.points = 
open3d.utility.Vector3dVector(ptcld[:3].transpose().astype(np.float64)) 92 | pcd.colors = open3d.utility.Vector3dVector(np.tile(ptcld[3:].transpose(), (1, 3)).astype(np.float64)) 93 | # Rotate pointcloud to align displayed coordinate frame colouring 94 | pcd.transform(build_se3_transform([0, 0, 0, np.pi, 0, -np.pi / 2])) 95 | vis.add_geometry(pcd) 96 | render_option = vis.get_render_option() 97 | render_option.background_color = np.array([0.1529, 0.1569, 0.1333], np.float32) 98 | render_option.point_color_option = open3d.PointColorOption.ZCoordinate 99 | coordinate_frame = open3d.geometry.create_mesh_coordinate_frame() 100 | vis.add_geometry(coordinate_frame) 101 | view_control = vis.get_view_control() 102 | params = view_control.convert_to_pinhole_camera_parameters() 103 | params.extrinsic = build_se3_transform([0, 3, 10, 0, -np.pi * 0.42, -np.pi / 2]) 104 | view_control.convert_from_pinhole_camera_parameters(params) 105 | 106 | pcd.points = open3d.utility.Vector3dVector(ptcld[:3].transpose().astype(np.float64)) 107 | pcd.colors = open3d.utility.Vector3dVector( 108 | np.tile(ptcld[3:].transpose(), (1, 3)).astype(np.float64) / 40) 109 | vis.update_geometry() 110 | vis.poll_events() 111 | vis.update_renderer() 112 | 113 | else: 114 | # Ranges and Intensities visualisation using OpenCV 115 | intensities_vis = colourmap[np.clip((intensities * 4).astype(np.int), 0, colourmap.shape[0] - 1)] 116 | ranges_vis = colourmap[np.clip((ranges * 4).astype(np.int), 0, colourmap.shape[0] - 1)] 117 | visualisation = np.concatenate((intensities_vis, ranges_vis), 0) 118 | visualisation = cv2.resize(visualisation, None, fy=6 * args.scale, fx=args.scale, 119 | interpolation=cv2.INTER_NEAREST) 120 | cv2.imshow(title, visualisation) 121 | cv2.waitKey(1) 122 | 123 | 124 | if __name__ == "__main__": 125 | main() 126 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/project_laser_into_camera.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Geoff Pascoe (gmp@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 
12 | # 13 | ################################################################################ 14 | 15 | import os 16 | import re 17 | import numpy as np 18 | import matplotlib.pyplot as plt 19 | import argparse 20 | 21 | from build_pointcloud import build_pointcloud 22 | from transform import build_se3_transform 23 | from image import load_image 24 | from camera_model import CameraModel 25 | 26 | parser = argparse.ArgumentParser(description='Project LIDAR data into camera image') 27 | parser.add_argument('--image_dir', type=str, help='Directory containing images') 28 | parser.add_argument('--laser_dir', type=str, help='Directory containing LIDAR scans') 29 | parser.add_argument('--poses_file', type=str, help='File containing either INS or VO poses') 30 | parser.add_argument('--models_dir', type=str, help='Directory containing camera models') 31 | parser.add_argument('--extrinsics_dir', type=str, help='Directory containing sensor extrinsics') 32 | parser.add_argument('--image_idx', type=int, help='Index of image to display') 33 | 34 | args = parser.parse_args() 35 | 36 | model = CameraModel(args.models_dir, args.image_dir) 37 | 38 | extrinsics_path = os.path.join(args.extrinsics_dir, model.camera + '.txt') 39 | with open(extrinsics_path) as extrinsics_file: 40 | extrinsics = [float(x) for x in next(extrinsics_file).split(' ')] 41 | 42 | G_camera_vehicle = build_se3_transform(extrinsics) 43 | G_camera_posesource = None 44 | 45 | poses_type = re.search('(vo|ins|rtk)\.csv', args.poses_file).group(1) 46 | if poses_type in ['ins', 'rtk']: 47 | with open(os.path.join(args.extrinsics_dir, 'ins.txt')) as extrinsics_file: 48 | extrinsics = next(extrinsics_file) 49 | G_camera_posesource = G_camera_vehicle * build_se3_transform([float(x) for x in extrinsics.split(' ')]) 50 | else: 51 | # VO frame and vehicle frame are the same 52 | G_camera_posesource = G_camera_vehicle 53 | 54 | 55 | timestamps_path = os.path.join(args.image_dir, os.pardir, model.camera + '.timestamps') 56 | if not os.path.isfile(timestamps_path): 57 | timestamps_path = os.path.join(args.image_dir, os.pardir, os.pardir, model.camera + '.timestamps') 58 | 59 | timestamp = 0 60 | with open(timestamps_path) as timestamps_file: 61 | for i, line in enumerate(timestamps_file): 62 | if i == args.image_idx: 63 | timestamp = int(line.split(' ')[0]) 64 | 65 | pointcloud, reflectance = build_pointcloud(args.laser_dir, args.poses_file, args.extrinsics_dir, 66 | timestamp - 1e7, timestamp + 1e7, timestamp) 67 | 68 | pointcloud = np.dot(G_camera_posesource, pointcloud) 69 | 70 | image_path = os.path.join(args.image_dir, str(timestamp) + '.png') 71 | image = load_image(image_path, model) 72 | 73 | uv, depth = model.project(pointcloud, image.shape) 74 | 75 | plt.imshow(image) 76 | plt.scatter(np.ravel(uv[0, :]), np.ravel(uv[1, :]), s=2, c=depth, edgecolors='none', cmap='jet') 77 | plt.xlim(0, image.shape[1]) 78 | plt.ylim(image.shape[0], 0) 79 | plt.xticks([]) 80 | plt.yticks([]) 81 | plt.show() 82 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/radar.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Dan Barnes (dbarnes@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 
9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ############################################################################### 14 | 15 | from typing import AnyStr, Tuple 16 | import numpy as np 17 | import cv2 18 | 19 | 20 | def load_radar(example_path: AnyStr) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, float]: 21 | """Decode a single Oxford Radar RobotCar Dataset radar example 22 | Args: 23 | example_path (AnyStr): Oxford Radar RobotCar Dataset Example png 24 | Returns: 25 | timestamps (np.ndarray): Timestamp for each azimuth in int64 (UNIX time) 26 | azimuths (np.ndarray): Rotation for each polar radar azimuth (radians) 27 | valid (np.ndarray) Mask of whether azimuth data is an original sensor reading or interpolated from adjacent 28 | azimuths 29 | fft_data (np.ndarray): Radar power readings along each azimuth 30 | radar_resolution (float): Resolution of the polar radar data (metres per pixel) 31 | """ 32 | # Hard coded configuration to simplify parsing code 33 | radar_resolution = np.array([0.0432], np.float32) 34 | encoder_size = 5600 35 | 36 | raw_example_data = cv2.imread(example_path, cv2.IMREAD_GRAYSCALE) 37 | timestamps = raw_example_data[:, :8].copy().view(np.int64) 38 | azimuths = (raw_example_data[:, 8:10].copy().view(np.uint16) / float(encoder_size) * 2 * np.pi).astype(np.float32) 39 | valid = raw_example_data[:, 10:11] == 255 40 | fft_data = raw_example_data[:, 11:].astype(np.float32)[:, :, np.newaxis] / 255. 41 | 42 | return timestamps, azimuths, valid, fft_data, radar_resolution 43 | 44 | 45 | def radar_polar_to_cartesian(azimuths: np.ndarray, fft_data: np.ndarray, radar_resolution: float, 46 | cart_resolution: float, cart_pixel_width: int, interpolate_crossover=True) -> np.ndarray: 47 | """Convert a polar radar scan to cartesian. 48 | Args: 49 | azimuths (np.ndarray): Rotation for each polar radar azimuth (radians) 50 | fft_data (np.ndarray): Polar radar power readings 51 | radar_resolution (float): Resolution of the polar radar data (metres per pixel) 52 | cart_resolution (float): Cartesian resolution (metres per pixel) 53 | cart_pixel_size (int): Width and height of the returned square cartesian output (pixels). Please see the Notes 54 | below for a full explanation of how this is used. 55 | interpolate_crossover (bool, optional): If true interpolates between the end and start azimuth of the scan. In 56 | practice a scan before / after should be used but this prevents nan regions in the return cartesian form. 
57 | 58 | Returns: 59 | np.ndarray: Cartesian radar power readings 60 | Notes: 61 | After using the warping grid the output radar cartesian is defined as as follows where 62 | X and Y are the `real` world locations of the pixels in metres: 63 | If 'cart_pixel_width' is odd: 64 | +------ Y = -1 * cart_resolution (m) 65 | |+----- Y = 0 (m) at centre pixel 66 | ||+---- Y = 1 * cart_resolution (m) 67 | |||+--- Y = 2 * cart_resolution (m) 68 | |||| +- Y = cart_pixel_width // 2 * cart_resolution (m) (at last pixel) 69 | |||| +-----------+ 70 | vvvv v 71 | +---------------+---------------+ 72 | | | | 73 | | | | 74 | | | | 75 | | | | 76 | | | | 77 | | | | 78 | | | | 79 | +---------------+---------------+ <-- X = 0 (m) at centre pixel 80 | | | | 81 | | | | 82 | | | | 83 | | | | 84 | | | | 85 | | | | 86 | | | | 87 | +---------------+---------------+ 88 | <-------------------------------> 89 | cart_pixel_width (pixels) 90 | If 'cart_pixel_width' is even: 91 | +------ Y = -0.5 * cart_resolution (m) 92 | |+----- Y = 0.5 * cart_resolution (m) 93 | ||+---- Y = 1.5 * cart_resolution (m) 94 | |||+--- Y = 2.5 * cart_resolution (m) 95 | |||| +- Y = (cart_pixel_width / 2 - 0.5) * cart_resolution (m) (at last pixel) 96 | |||| +----------+ 97 | vvvv v 98 | +------------------------------+ 99 | | | 100 | | | 101 | | | 102 | | | 103 | | | 104 | | | 105 | | | 106 | | | 107 | | | 108 | | | 109 | | | 110 | | | 111 | | | 112 | | | 113 | | | 114 | +------------------------------+ 115 | <------------------------------> 116 | cart_pixel_width (pixels) 117 | """ 118 | if (cart_pixel_width % 2) == 0: 119 | cart_min_range = (cart_pixel_width / 2 - 0.5) * cart_resolution 120 | else: 121 | cart_min_range = cart_pixel_width // 2 * cart_resolution 122 | coords = np.linspace(-cart_min_range, cart_min_range, cart_pixel_width, dtype=np.float32) 123 | Y, X = np.meshgrid(coords, -coords) 124 | sample_range = np.sqrt(Y * Y + X * X) 125 | sample_angle = np.arctan2(Y, X) 126 | sample_angle += (sample_angle < 0).astype(np.float32) * 2. * np.pi 127 | 128 | # Interpolate Radar Data Coordinates 129 | azimuth_step = azimuths[1] - azimuths[0] 130 | sample_u = (sample_range - radar_resolution / 2) / radar_resolution 131 | sample_v = (sample_angle - azimuths[0]) / azimuth_step 132 | 133 | # We clip the sample points to the minimum sensor reading range so that we 134 | # do not have undefined results in the centre of the image. In practice 135 | # this region is simply undefined. 
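    # The steps below clip below-minimum-range samples, optionally pad the azimuth axis so the
    # bilinear lookup can wrap across the 0 / 2*pi crossover, and finally use cv2.remap to warp
    # the polar FFT data onto the Cartesian grid defined above.
    # Illustrative call (a sketch only, mirroring the values used in play_radar.py):
    #   cart_img = radar_polar_to_cartesian(azimuths, fft_data, radar_resolution,
    #                                       cart_resolution=0.25, cart_pixel_width=501)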
136 | sample_u[sample_u < 0] = 0 137 | 138 | if interpolate_crossover: 139 | fft_data = np.concatenate((fft_data[-1:], fft_data, fft_data[:1]), 0) 140 | sample_v = sample_v + 1 141 | 142 | polar_to_cart_warp = np.stack((sample_u, sample_v), -1) 143 | cart_img = np.expand_dims(cv2.remap(fft_data, polar_to_cart_warp, None, cv2.INTER_LINEAR), -1) 144 | return cart_img 145 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | matplotlib 3 | colour_demosaicing 4 | pillow 5 | opencv-python 6 | open3d-python 7 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/transform.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Geoff Pascoe (gmp@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ################################################################################ 14 | 15 | import numpy as np 16 | import numpy.matlib as matlib 17 | from math import sin, cos, atan2, sqrt 18 | 19 | MATRIX_MATCH_TOLERANCE = 1e-4 20 | 21 | 22 | def build_se3_transform(xyzrpy): 23 | """Creates an SE3 transform from translation and Euler angles. 24 | 25 | Args: 26 | xyzrpy (list[float]): translation and Euler angles for transform. Must have six components. 27 | 28 | Returns: 29 | numpy.matrixlib.defmatrix.matrix: SE3 homogeneous transformation matrix 30 | 31 | Raises: 32 | ValueError: if `len(xyzrpy) != 6` 33 | 34 | """ 35 | if len(xyzrpy) != 6: 36 | raise ValueError("Must supply 6 values to build transform") 37 | 38 | se3 = matlib.identity(4) 39 | se3[0:3, 0:3] = euler_to_so3(xyzrpy[3:6]) 40 | se3[0:3, 3] = np.matrix(xyzrpy[0:3]).transpose() 41 | return se3 42 | 43 | 44 | def euler_to_so3(rpy): 45 | """Converts Euler angles to an SO3 rotation matrix. 46 | 47 | Args: 48 | rpy (list[float]): Euler angles (in radians). Must have three components. 49 | 50 | Returns: 51 | numpy.matrixlib.defmatrix.matrix: 3x3 SO3 rotation matrix 52 | 53 | Raises: 54 | ValueError: if `len(rpy) != 3`. 
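
    Example (illustrative):
        >>> R = euler_to_so3([0.0, 0.0, np.pi / 2])  # a pure 90-degree yaw
        >>> np.allclose(R, [[0, -1, 0], [1, 0, 0], [0, 0, 1]], atol=1e-9)
        True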
55 | 56 | """ 57 | if len(rpy) != 3: 58 | raise ValueError("Euler angles must have three components") 59 | 60 | R_x = np.matrix([[1, 0, 0], 61 | [0, cos(rpy[0]), -sin(rpy[0])], 62 | [0, sin(rpy[0]), cos(rpy[0])]]) 63 | R_y = np.matrix([[cos(rpy[1]), 0, sin(rpy[1])], 64 | [0, 1, 0], 65 | [-sin(rpy[1]), 0, cos(rpy[1])]]) 66 | R_z = np.matrix([[cos(rpy[2]), -sin(rpy[2]), 0], 67 | [sin(rpy[2]), cos(rpy[2]), 0], 68 | [0, 0, 1]]) 69 | R_zyx = R_z * R_y * R_x 70 | return R_zyx 71 | 72 | 73 | def so3_to_euler(so3): 74 | """Converts an SO3 rotation matrix to Euler angles 75 | 76 | Args: 77 | so3: 3x3 rotation matrix 78 | 79 | Returns: 80 | numpy.matrixlib.defmatrix.matrix: list of Euler angles (size 3) 81 | 82 | Raises: 83 | ValueError: if so3 is not 3x3 84 | ValueError: if a valid Euler parametrisation cannot be found 85 | 86 | """ 87 | if so3.shape != (3, 3): 88 | raise ValueError("SO3 matrix must be 3x3") 89 | roll = atan2(so3[2, 1], so3[2, 2]) 90 | yaw = atan2(so3[1, 0], so3[0, 0]) 91 | denom = sqrt(so3[0, 0] ** 2 + so3[1, 0] ** 2) 92 | pitch_poss = [atan2(-so3[2, 0], denom), atan2(-so3[2, 0], -denom)] 93 | 94 | R = euler_to_so3((roll, pitch_poss[0], yaw)) 95 | 96 | if (so3 - R).sum() < MATRIX_MATCH_TOLERANCE: 97 | return np.matrix([roll, pitch_poss[0], yaw]) 98 | else: 99 | R = euler_to_so3((roll, pitch_poss[1], yaw)) 100 | if (so3 - R).sum() > MATRIX_MATCH_TOLERANCE: 101 | raise ValueError("Could not find valid pitch angle") 102 | return np.matrix([roll, pitch_poss[1], yaw]) 103 | 104 | 105 | def so3_to_quaternion(so3): 106 | """Converts an SO3 rotation matrix to a quaternion 107 | 108 | Args: 109 | so3: 3x3 rotation matrix 110 | 111 | Returns: 112 | numpy.ndarray: quaternion [w, x, y, z] 113 | 114 | Raises: 115 | ValueError: if so3 is not 3x3 116 | """ 117 | if so3.shape != (3, 3): 118 | raise ValueError("SO3 matrix must be 3x3") 119 | 120 | R_xx = so3[0, 0] 121 | R_xy = so3[0, 1] 122 | R_xz = so3[0, 2] 123 | R_yx = so3[1, 0] 124 | R_yy = so3[1, 1] 125 | R_yz = so3[1, 2] 126 | R_zx = so3[2, 0] 127 | R_zy = so3[2, 1] 128 | R_zz = so3[2, 2] 129 | 130 | try: 131 | w = sqrt(so3.trace() + 1) / 2 132 | except(ValueError): 133 | # w is non-real 134 | w = 0 135 | 136 | # Due to numerical precision the value passed to `sqrt` may be a negative of the order 1e-15. 137 | # To avoid this error we clip these values to a minimum value of 0. 
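    # The four square roots below are candidate magnitudes for w, x, y and z taken from the
    # matrix diagonal; the largest candidate is kept and the other three components are then
    # recovered from off-diagonal sums and differences, so the division is always by the
    # largest (safely non-zero) component.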
138 | x = sqrt(max(1 + R_xx - R_yy - R_zz, 0)) / 2 139 | y = sqrt(max(1 + R_yy - R_xx - R_zz, 0)) / 2 140 | z = sqrt(max(1 + R_zz - R_yy - R_xx, 0)) / 2 141 | 142 | max_index = max(range(4), key=[w, x, y, z].__getitem__) 143 | 144 | if max_index == 0: 145 | x = (R_zy - R_yz) / (4 * w) 146 | y = (R_xz - R_zx) / (4 * w) 147 | z = (R_yx - R_xy) / (4 * w) 148 | elif max_index == 1: 149 | w = (R_zy - R_yz) / (4 * x) 150 | y = (R_xy + R_yx) / (4 * x) 151 | z = (R_zx + R_xz) / (4 * x) 152 | elif max_index == 2: 153 | w = (R_xz - R_zx) / (4 * y) 154 | x = (R_xy + R_yx) / (4 * y) 155 | z = (R_yz + R_zy) / (4 * y) 156 | elif max_index == 3: 157 | w = (R_yx - R_xy) / (4 * z) 158 | x = (R_zx + R_xz) / (4 * z) 159 | y = (R_yz + R_zy) / (4 * z) 160 | 161 | return np.array([w, x, y, z]) 162 | 163 | 164 | def se3_to_components(se3): 165 | """Converts an SE3 rotation matrix to linear translation and Euler angles 166 | 167 | Args: 168 | se3: 4x4 transformation matrix 169 | 170 | Returns: 171 | numpy.matrixlib.defmatrix.matrix: list of [x, y, z, roll, pitch, yaw] 172 | 173 | Raises: 174 | ValueError: if se3 is not 4x4 175 | ValueError: if a valid Euler parametrisation cannot be found 176 | 177 | """ 178 | if se3.shape != (4, 4): 179 | raise ValueError("SE3 transform must be a 4x4 matrix") 180 | xyzrpy = np.empty(6) 181 | xyzrpy[0:3] = se3[0:3, 3].transpose() 182 | xyzrpy[3:6] = so3_to_euler(se3[0:3, 0:3]) 183 | return xyzrpy 184 | -------------------------------------------------------------------------------- /datasets/robotcar_sdk/python/velodyne.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Dan Barnes (dbarnes@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ############################################################################### 14 | 15 | from typing import AnyStr 16 | import numpy as np 17 | import os 18 | 19 | # Hard coded configuration to simplify parsing code 20 | hdl32e_range_resolution = 0.002 # m / pixel 21 | hdl32e_minimum_range = 1.0 22 | hdl32e_elevations = np.array([-0.1862, -0.1628, -0.1396, -0.1164, -0.0930, 23 | -0.0698, -0.0466, -0.0232, 0., 0.0232, 0.0466, 0.0698, 24 | 0.0930, 0.1164, 0.1396, 0.1628, 0.1862, 0.2094, 0.2327, 25 | 0.2560, 0.2793, 0.3025, 0.3259, 0.3491, 0.3723, 0.3957, 26 | 0.4189, 0.4421, 0.4655, 0.4887, 0.5119, 0.5353])[:, np.newaxis] 27 | hdl32e_base_to_fire_height = 0.090805 28 | hdl32e_cos_elevations = np.cos(hdl32e_elevations) 29 | hdl32e_sin_elevations = np.sin(hdl32e_elevations) 30 | 31 | 32 | def load_velodyne_binary(velodyne_bin_path: AnyStr): 33 | """Decode a binary Velodyne example (of the form '.bin') 34 | Args: 35 | example_path (AnyStr): Oxford Radar RobotCar Dataset binary Velodyne pointcloud example path 36 | Returns: 37 | ptcld (np.ndarray): XYZI pointcloud from the binary Velodyne data Nx4 38 | Notes: 39 | - The pre computed points are *NOT* motion compensated. 40 | - Converting a raw velodyne scan to pointcloud can be done using the 41 | `velodyne_ranges_intensities_angles_to_pointcloud` function. 
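        - As loaded here the array is channel-major, i.e. shape (4, N) with rows x, y, z and
          intensity (see the `reshape((4, -1))` below); transpose it if N x 4 points are needed.
    Example (illustrative; the path is a placeholder):
        ptcld = load_velodyne_binary('velodyne_left/<timestamp>.bin')
        xyz = ptcld[:3].transpose()  # N x 3 point coordinates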
42 | """ 43 | ext = os.path.splitext(velodyne_bin_path)[1] 44 | if ext != ".bin": 45 | raise RuntimeError("Velodyne binary pointcloud file should have `.bin` extension but had: {}".format(ext)) 46 | if not os.path.isfile(velodyne_bin_path): 47 | raise FileNotFoundError("Could not find velodyne bin example: {}".format(velodyne_bin_path)) 48 | data = np.fromfile(velodyne_bin_path, dtype=np.float32) 49 | ptcld = data.reshape((4, -1)) 50 | return ptcld 51 | 52 | def load_velodyne_binary_seg(velodyne_bin_path: AnyStr): 53 | """Decode a binary Velodyne example (of the form '.bin') 54 | Args: 55 | example_path (AnyStr): Oxford Radar RobotCar Dataset binary Velodyne pointcloud example path 56 | Returns: 57 | ptcld (np.ndarray): XYZI pointcloud from the binary Velodyne data Nx4 58 | Notes: 59 | - The pre computed points are *NOT* motion compensated. 60 | - Converting a raw velodyne scan to pointcloud can be done using the 61 | `velodyne_ranges_intensities_angles_to_pointcloud` function. 62 | """ 63 | ext = os.path.splitext(velodyne_bin_path)[1] 64 | if ext != ".bin": 65 | raise RuntimeError("Velodyne binary pointcloud file should have `.bin` extension but had: {}".format(ext)) 66 | if not os.path.isfile(velodyne_bin_path): 67 | raise FileNotFoundError("Could not find velodyne bin example: {}".format(velodyne_bin_path)) 68 | data = np.fromfile(velodyne_bin_path, dtype=np.float32) 69 | ptcld = data.reshape((-1, 4)) 70 | return ptcld 71 | 72 | 73 | def velodyne_raw_to_pointcloud(ranges: np.ndarray, intensities: np.ndarray, angles: np.ndarray): 74 | """ Convert raw Velodyne data (from load_velodyne_raw) into a pointcloud 75 | Args: 76 | ranges (np.ndarray): Raw Velodyne range readings 77 | intensities (np.ndarray): Raw Velodyne intensity readings 78 | angles (np.ndarray): Raw Velodyne angles 79 | Returns: 80 | pointcloud (np.ndarray): XYZI pointcloud generated from the raw Velodyne data Nx4 81 | 82 | Notes: 83 | - This implementation does *NOT* perform motion compensation on the generated pointcloud. 84 | - Accessing the pointclouds in binary form via `load_velodyne_pointcloud` is approximately 2x faster at the cost 85 | of 8x the storage space 86 | """ 87 | valid = ranges > hdl32e_minimum_range 88 | z = hdl32e_sin_elevations * ranges - hdl32e_base_to_fire_height 89 | xy = hdl32e_cos_elevations * ranges 90 | x = np.sin(angles) * xy 91 | y = -np.cos(angles) * xy 92 | 93 | xf = x[valid].reshape(-1) 94 | yf = y[valid].reshape(-1) 95 | zf = z[valid].reshape(-1) 96 | intensityf = intensities[valid].reshape(-1).astype(np.float32) 97 | ptcld = np.stack((xf, yf, zf, intensityf), 0) 98 | return ptcld 99 | -------------------------------------------------------------------------------- /img/nclt.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/BEVDiffLoc/a25b1196c7ca6df092c7af36a875934deacf74b3/img/nclt.gif -------------------------------------------------------------------------------- /img/oxford.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/BEVDiffLoc/a25b1196c7ca6df092c7af36a875934deacf74b3/img/oxford.gif -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # This Script Assumes Python 3.10, CUDA 12.1 8 | 9 | conda deactivate 10 | 11 | # Set environment variables 12 | export ENV_NAME=BEVDiffLoc 13 | export PYTHON_VERSION=3.10 14 | export PYTORCH_VERSION=2.1.2 15 | export CUDA_VERSION=12.1 16 | 17 | # Create a new conda environment and activate it 18 | conda create -n $ENV_NAME python=$PYTHON_VERSION 19 | conda activate $ENV_NAME 20 | 21 | # Install PyTorch, torchvision, and PyTorch3D using conda 22 | conda install pytorch=$PYTORCH_VERSION torchvision pytorch-cuda=$CUDA_VERSION -c pytorch -c nvidia 23 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath 24 | conda install pytorch3d -c pytorch3d 25 | 26 | # Install pip packages 27 | pip install hydra-core --upgrade 28 | pip install omegaconf opencv-python einops visdom 29 | pip install accelerate==0.24.0 30 | pip install matplotlib==3.8.2 31 | pip install pandas==2.2.0 32 | pip install transforms3d==0.4.1 33 | pip install open3d==0.18.0 34 | pip install h5py==3.10.0 35 | pip install tensorboardX==2.6.2.2 36 | pip install timm==0.9.12 37 | pip install faiss-gpu 38 | pip install numpy 39 | pip install opencv-python 40 | pip install scikit-image 41 | pip install scikit-learn 42 | pip install tqdm 43 | pip install argparse 44 | pip install imgaug 45 | -------------------------------------------------------------------------------- /log/count_SR.py: -------------------------------------------------------------------------------- 1 | def calculate_count(file_t, file_q): 2 | try: 3 | with open(file_t, 'r') as file_t, open(file_q, 'r') as file_q: 4 | lines_t = file_t.readlines() 5 | lines_q = file_q.readlines() 6 | 7 | 8 | count = 0 9 | min_lines = min(len(lines_t), len(lines_q)) 10 | 11 | 12 | for i in range(min_lines): 13 | try: 14 | value_t = float(lines_t[i].strip()) 15 | value_q = float(lines_q[i].strip()) 16 | 17 | 18 | if value_t < 2 and value_q < 5: 19 | count += 1 20 | except ValueError: 21 | 22 | continue 23 | 24 | return count / min_lines 25 | 26 | except FileNotFoundError: 27 | print("Can't find File") 28 | return None 29 | 30 | file_t = 'error_t.txt' 31 | file_q = 'error_q.txt' 32 | 33 | ratio = calculate_count(file_t, file_q) 34 | if ratio is not None: 35 | print(f"SR ratio: {ratio}") 36 | 37 | -------------------------------------------------------------------------------- /merge_nclt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | import cv2 5 | import numpy as np 6 | import os.path as osp 7 | import open3d as o3d 8 | from torch.utils import data 9 | from datasets.projection import getBEV 10 | from datasets.augmentor import Augmentor, AugmentParams 11 | from utils.pose_util import process_poses, filter_overflow_nclt, interpolate_pose_nclt, so3_to_euler_nclt, poses_foraugmentaion 12 | import time 13 | import math 14 | 15 | BASE_DIR = osp.dirname(osp.abspath(__file__)) 16 | 17 | velodatatype = np.dtype({ 18 | 'x': (' 100: 168 | 169 | merged_pointcloud.clear() 170 | pcd = o3d.geometry.PointCloud() 171 | 172 | for j in range(0, len(all_pointcloud), 20): 173 | # 将pointcloud从numpy数组转为Open3D点云对象 174 | pcd.points = o3d.utility.Vector3dVector(all_pointcloud[j]) 175 | pcd.transform(all_poses[j]) 176 | merged_pointcloud += pcd 177 | 178 | x_mean = np.mean(merged_x) 179 | y_mean = np.mean(merged_y) 180 | x_std = np.std(merged_x) 181 | y_std = np.std(merged_y) 
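        # Downsample the merged window of scans, sample a virtual viewpoint around the window's
        # mean position (Gaussian in x/y, uniform yaw in [-3.14, 3.14]), render a BEV image from
        # that viewpoint with getBEV, write the image and its "x y yaw" label line out as
        # augmentation data, and finally drop the oldest scan from the window.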
182 | bev_pointcloud = merged_pointcloud.voxel_down_sample(voxel_size) 183 | 184 | # 创建绕Z轴的旋转矩阵 185 | yaw_random = np.random.uniform(-3.14, 3.14) 186 | 187 | x_new = np.random.normal(loc=x_mean, scale=x_std) 188 | y_new = np.random.normal(loc=y_mean, scale=y_std) 189 | 190 | bev_img = getBEV(bev_pointcloud.points, x_new, y_new, yaw_random) 191 | bev_img = np.tile(bev_img, (3, 1, 1)) 192 | bev_img = bev_img.transpose(1, 2, 0) 193 | 194 | cv2.imwrite(f"{image_path}{i-100}.png", bev_img) 195 | with open(pose_path, 'a') as file: 196 | file.write(f"{x_new} {y_new} {yaw_random}\n") 197 | 198 | all_pointcloud.pop(0) 199 | all_poses.pop(0) 200 | merged_x.pop(0) 201 | merged_y.pop(0) 202 | 203 | T2 = time.time() 204 | print("Time used:", T2-T1) 205 | 206 | print("Done") -------------------------------------------------------------------------------- /merge_oxford.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | import cv2 5 | import numpy as np 6 | import os.path as osp 7 | from copy import deepcopy 8 | import open3d as o3d 9 | from torch.utils import data 10 | from datasets.projection import getBEV 11 | from datasets.augmentor import Augmentor, AugmentParams 12 | from utils.pose_util import process_poses, filter_overflow_ts, poses_foraugmentaion 13 | from datasets.robotcar_sdk.python.interpolate_poses import interpolate_ins_poses 14 | from datasets.robotcar_sdk.python.transform import build_se3_transform, euler_to_so3 15 | import time 16 | import math 17 | 18 | BASE_DIR = osp.dirname(osp.abspath(__file__)) 19 | 20 | class Oxford_merge(data.Dataset): 21 | def __init__(self, split='train'): 22 | # directories 23 | if split == 'train': 24 | self.is_train = True 25 | else: 26 | self.is_train = False 27 | 28 | lidar = 'velodyne_left' 29 | data_path = '/media/wzy/data' 30 | 31 | data_dir = osp.join(data_path, 'Oxford') 32 | extrinsics_dir = osp.join(BASE_DIR, 'datasets' ,'robotcar_sdk', 'extrinsics') 33 | 34 | seqs = ['2019-01-14-12-05-52'] 35 | 36 | ps = {} 37 | ts = {} 38 | vo_stats = {} 39 | self.pcs = [] 40 | 41 | # extrinsic reading 42 | with open(os.path.join(extrinsics_dir, lidar + '.txt')) as extrinsics_file: 43 | extrinsics = next(extrinsics_file) 44 | G_posesource_laser = build_se3_transform([float(x) for x in extrinsics.split(' ')]) 45 | with open(os.path.join(extrinsics_dir, 'ins.txt')) as extrinsics_file: 46 | extrinsics = next(extrinsics_file) 47 | G_posesource_laser = np.linalg.solve(build_se3_transform([float(x) for x in extrinsics.split(' ')]), G_posesource_laser) # (4, 4) 48 | 49 | for seq in seqs: 50 | seq_dir = osp.join(data_dir, seq + '-radar-oxford-10k') 51 | # read the image timestamps 52 | h5_path = osp.join(seq_dir, lidar + '_' + 'False.h5') 53 | 54 | if not os.path.isfile(h5_path): 55 | print('interpolate ' + seq) 56 | ts_filename = osp.join(seq_dir, lidar + '.timestamps') 57 | with open(ts_filename, 'r') as f: 58 | ts_raw = [int(l.rstrip().split(' ')[0]) for l in f] 59 | # GT poses 60 | ins_filename = osp.join(seq_dir, 'gps', 'ins.csv') 61 | ts[seq] = filter_overflow_ts(ins_filename, ts_raw) 62 | p = np.asarray(interpolate_ins_poses(ins_filename, deepcopy(ts[seq]), ts[seq][0])) # (n, 4, 4) 63 | p = np.asarray([np.dot(pose, G_posesource_laser) for pose in p]) # (n, 4, 4) 64 | ps[seq] = np.reshape(p[:, :3, :], (len(p), -1)) # (n, 12) 65 | # write to h5 file 66 | print('write interpolate pose to ' + h5_path) 67 | h5_file = h5py.File(h5_path, 'w') 68 | h5_file.create_dataset('valid_timestamps', 
data=np.asarray(ts[seq], dtype=np.int64)) 69 | h5_file.create_dataset('poses', data=ps[seq]) 70 | else: 71 | # load h5 file, save pose interpolating time 72 | print("load " + seq + ' pose from ' + h5_path) 73 | h5_file = h5py.File(h5_path, 'r') 74 | ts[seq] = h5_file['valid_timestamps'][...] 75 | ps[seq] = h5_file['poses'][...] 76 | 77 | vo_stats[seq] = {'R': np.eye(3), 't': np.zeros(3), 's': 1} 78 | 79 | self.pcs.extend([osp.join(seq_dir, 'velodyne_left', '{:d}.bin'.format(t)) for t in ts[seq]]) 80 | 81 | # read / save pose normalization information 82 | poses = np.empty((0, 12)) 83 | for p in ps.values(): 84 | poses = np.vstack((poses, p)) 85 | pose_stats_filename = osp.join(data_dir, 'Oxford_pose_stats.txt') 86 | print("pose_stats_filename:",pose_stats_filename) 87 | if split == 'train': 88 | mean_t = np.mean(poses[:, [3, 7, 11]], axis=0) # (3,) 89 | std_t = np.std(poses[:, [3, 7, 11]], axis=0) # (3,) 90 | np.savetxt(pose_stats_filename, np.vstack((mean_t, std_t)), fmt='%8.7f') 91 | else: 92 | mean_t, std_t = np.loadtxt(pose_stats_filename) 93 | 94 | self.poses_3_4 = poses 95 | self.poses = np.empty((0, 6)) 96 | self.rots = np.empty((0, 3, 3)) 97 | for seq in seqs: 98 | pss, rotation, pss_max, pss_min = process_poses(poses_in=ps[seq], mean_t=mean_t, std_t=std_t, 99 | align_R=vo_stats[seq]['R'], align_t=vo_stats[seq]['t'], 100 | align_s=vo_stats[seq]['s']) 101 | self.poses = np.vstack((self.poses, pss)) 102 | self.rots = np.vstack((self.rots, rotation)) 103 | 104 | def __len__(self): 105 | return len(self.poses) 106 | 107 | def __getitem__(self, idx_N): 108 | scan_path = self.pcs[idx_N] 109 | 110 | pointcloud = np.fromfile(scan_path, dtype=np.float32).reshape(4, -1).transpose() 111 | pointcloud[:, 2] = -1 * pointcloud[:, 2] 112 | 113 | poses_3_4 = self.poses_3_4[idx_N] 114 | 115 | # Generate BEV_Image 116 | pointcloud = pointcloud[:, :3] 117 | pointcloud = pointcloud[np.where(np.abs(pointcloud[:,0])<50)[0],:] 118 | pointcloud = pointcloud[np.where(np.abs(pointcloud[:,1])<50)[0],:] 119 | pointcloud = pointcloud[np.where(np.abs(pointcloud[:,2])<50)[0],:] 120 | pointcloud = pointcloud.astype(np.float32) 121 | 122 | return pointcloud, poses_3_4 123 | 124 | if __name__ == '__main__': 125 | 126 | dataset = Oxford_merge(split='train') 127 | merged_pointcloud = o3d.geometry.PointCloud() 128 | merged_x = [] 129 | merged_y = [] 130 | all_pointcloud = [] 131 | all_poses = [] 132 | voxel_size = 0.4 133 | image_path = '/home/wzy/merge_bev/' 134 | pose_path = '/home/wzy/merge_bev.txt' 135 | with open(pose_path, 'w', encoding='utf-8'): 136 | pass 137 | 138 | if not os.path.exists(image_path): 139 | # 如果目录不存在,创建该目录 140 | os.makedirs(image_path) 141 | print(f"目录 '{image_path}' 已创建") 142 | else: 143 | print(f"目录 '{image_path}' 已存在") 144 | 145 | T1 = time.time() 146 | for i in range(len(dataset)): 147 | 148 | pointcloud, poses = dataset[i] 149 | 150 | merged_x.append(poses[3]) 151 | merged_y.append(poses[7]) 152 | all_pointcloud.append(pointcloud) 153 | 154 | poses = poses.reshape(3, 4) 155 | 156 | # 添加最后一行 [0, 0, 0, 1] 157 | last_row = np.array([0, 0, 0, 1]).reshape(1, 4) 158 | poses = np.vstack((poses, last_row)) 159 | all_poses.append(poses) 160 | 161 | if i > 100: 162 | 163 | merged_pointcloud.clear() 164 | pcd = o3d.geometry.PointCloud() 165 | 166 | for j in range(0, len(all_pointcloud), 20): 167 | # 将pointcloud从numpy数组转为Open3D点云对象 168 | pcd.points = o3d.utility.Vector3dVector(all_pointcloud[j]) 169 | pcd.transform(all_poses[j]) 170 | merged_pointcloud += pcd 171 | 172 | x_mean = np.mean(merged_x) 173 
| y_mean = np.mean(merged_y) 174 | x_std = np.std(merged_x) 175 | y_std = np.std(merged_y) 176 | bev_pointcloud = merged_pointcloud.voxel_down_sample(voxel_size) 177 | 178 | # 创建绕Z轴的旋转矩阵 179 | yaw_random = np.random.uniform(-3.14, 3.14) 180 | 181 | x_new = np.random.normal(loc=x_mean, scale=x_std) 182 | y_new = np.random.normal(loc=y_mean, scale=y_std) 183 | 184 | bev_img = getBEV(bev_pointcloud.points, x_new, y_new, yaw_random) 185 | bev_img = np.tile(bev_img, (3, 1, 1)) 186 | bev_img = bev_img.transpose(1, 2, 0) 187 | 188 | cv2.imwrite(f"{image_path}{i-100}.png", bev_img) 189 | with open(pose_path, 'a') as file: 190 | file.write(f"{x_new} {y_new} {yaw_random}\n") 191 | 192 | all_pointcloud.pop(0) 193 | all_poses.pop(0) 194 | merged_x.pop(0) 195 | merged_y.pop(0) 196 | 197 | T2 = time.time() 198 | print("Time used:", T2-T1) 199 | 200 | print("Done") -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .difussion_loc_model_bev import DiffusionLocModel_bev 8 | 9 | from .denoiser_bev import Denoiser_bev, TransformerEncoderWrapper_bev 10 | from .gaussian_diffuser import GaussianDiffusion 11 | from .image_feature_extractor_bev import ImageFeatureExtractor_bev 12 | -------------------------------------------------------------------------------- /models/decoders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 - Valeo Comfort and Driving Assistance 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from einops import rearrange 19 | from einops.layers.torch import Rearrange 20 | 21 | from models.model_utils import get_grid_size_1d, get_grid_size_2d, init_weights 22 | 23 | 24 | class DecoderLinear(nn.Module): 25 | # From R. Strudel et al. 26 | # https://github.com/rstrudel/segmenter 27 | def __init__(self, n_cls, patch_size, d_encoder, patch_stride=None): 28 | super().__init__() 29 | 30 | self.d_encoder = d_encoder 31 | self.patch_size = patch_size 32 | self.patch_stride = patch_stride 33 | self.n_cls = n_cls 34 | 35 | self.head = nn.Linear(self.d_encoder, n_cls) 36 | self.apply(init_weights) 37 | 38 | @torch.jit.ignore 39 | def no_weight_decay(self): 40 | return set() 41 | 42 | def forward(self, x, im_size, skip=None): 43 | H, W = im_size 44 | GS_H, GS_W = get_grid_size_2d(H, W, self.patch_size, self.patch_stride) 45 | # print(x.shape) 46 | x1 = self.head(x) 47 | # print(x1.shape) 48 | x2 = rearrange(x1, 'b (h w) c -> b c h w', h=GS_H) 49 | return x1, x2 50 | # return x1 51 | 52 | class DecoderLinear_bev(nn.Module): 53 | # From R. Strudel et al. 
54 | # https://github.com/rstrudel/segmenter 55 | def __init__(self, n_cls, patch_size, d_encoder, patch_stride=None): 56 | super().__init__() 57 | 58 | self.d_encoder = d_encoder 59 | self.patch_size = patch_size 60 | self.patch_stride = patch_stride 61 | self.n_cls = n_cls 62 | 63 | self.head = nn.Linear(self.d_encoder, n_cls) 64 | self.apply(init_weights) 65 | 66 | @torch.jit.ignore 67 | def no_weight_decay(self): 68 | return set() 69 | 70 | def forward(self, x, im_size, skip=None): 71 | H, W = im_size 72 | x1 = self.head(x) 73 | return x1 74 | -------------------------------------------------------------------------------- /models/denoiser_bev.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziyue Wang and Wen Li 3 | @file: denoiser_bev.py 4 | @time: 2025/3/12 14:20 5 | """ 6 | 7 | import logging 8 | from typing import Dict, List, Optional, Callable 9 | from utils.embedding import TimeStepEmbedding, PoseEmbedding 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from hydra.utils import instantiate 15 | 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class Denoiser_bev(nn.Module): 21 | def __init__( 22 | self, 23 | TRANSFORMER: Dict, 24 | target_dim: int = 3, # pose shape 25 | pivot_cam_onehot: bool = True, 26 | z_dim: int = 384, 27 | mlp_hidden_dim: bool = 128, 28 | ): 29 | super().__init__() 30 | 31 | self.pivot_cam_onehot = pivot_cam_onehot 32 | self.target_dim = target_dim 33 | self.time_embed = TimeStepEmbedding() 34 | self.pose_embed = PoseEmbedding(target_dim=self.target_dim) 35 | 36 | first_dim = ( 37 | self.time_embed.out_dim 38 | + self.pose_embed.out_dim 39 | + z_dim 40 | + int(self.pivot_cam_onehot) 41 | ) 42 | 43 | d_model = TRANSFORMER.d_model 44 | 45 | self._first = nn.Linear(first_dim, d_model) 46 | 47 | # call TransformerEncoderWrapper() to build a encoder-only transformer 48 | self._trunk = instantiate(TRANSFORMER, _recursive_=False) 49 | 50 | self._last = MLP( 51 | d_model, 52 | [mlp_hidden_dim, self.target_dim], 53 | norm_layer=nn.LayerNorm, 54 | ) 55 | 56 | def forward( 57 | self, 58 | x: torch.Tensor, # B x N x dim 59 | t: torch.Tensor, # B 60 | z: torch.Tensor, # B x N x dim_z 61 | ): 62 | B, N, _ = x.shape 63 | 64 | t_emb = self.time_embed(t) 65 | # expand t from B x C to B x N x C 66 | t_emb = t_emb.view(B, 1, t_emb.shape[-1]).expand(-1, N, -1) 67 | 68 | x_emb = self.pose_embed(x) 69 | if self.pivot_cam_onehot: 70 | # add the one hot vector identifying the first camera as pivot 71 | cam_pivot_id = torch.zeros_like(z[..., :1]) 72 | cam_pivot_id[:, 0, ...] 
= 1.0 73 | z = torch.cat([z, cam_pivot_id], dim=-1) 74 | 75 | feed_feats = torch.cat([x_emb, t_emb, z], dim=-1) 76 | 77 | input_ = self._first(feed_feats) 78 | 79 | feats_ = self._trunk(input_) 80 | 81 | output = self._last(feats_) 82 | 83 | return output 84 | 85 | 86 | def TransformerEncoderWrapper_bev( 87 | d_model: int, 88 | nhead: int, 89 | num_encoder_layers: int, 90 | dim_feedforward: int = 2048, 91 | dropout: float = 0.1, 92 | norm_first: bool = True, 93 | batch_first: bool = True, 94 | ): 95 | encoder_layer = torch.nn.TransformerEncoderLayer( 96 | d_model=d_model, 97 | nhead=nhead, 98 | dim_feedforward=dim_feedforward, 99 | dropout=dropout, 100 | batch_first=batch_first, 101 | norm_first=norm_first, 102 | ) 103 | 104 | _trunk = torch.nn.TransformerEncoder(encoder_layer, num_encoder_layers) 105 | return _trunk 106 | 107 | 108 | class MLP(torch.nn.Sequential): 109 | """This block implements the multi-layer perceptron (MLP) module. 110 | 111 | Args: 112 | in_channels (int): Number of channels of the input 113 | hidden_channels (List[int]): List of the hidden channel dimensions 114 | norm_layer (Callable[..., torch.nn.Module], optional): 115 | Norm layer that will be stacked on top of the convolution layer. 116 | If ``None`` this layer wont be used. Default: ``None`` 117 | activation_layer (Callable[..., torch.nn.Module], optional): 118 | Activation function which will be stacked on top of the 119 | normalization layer (if not None), otherwise on top of the 120 | conv layer. If ``None`` this layer wont be used. 121 | Default: ``torch.nn.ReLU`` 122 | inplace (bool): Parameter for the activation layer, which can 123 | optionally do the operation in-place. Default ``True`` 124 | bias (bool): Whether to use bias in the linear layer. Default ``True`` 125 | dropout (float): The probability for the dropout layer. 
Default: 0.0 126 | """ 127 | 128 | def __init__( 129 | self, 130 | in_channels: int, 131 | hidden_channels: List[int], 132 | norm_layer: Optional[Callable[..., torch.nn.Module]] = None, 133 | activation_layer: Optional[ 134 | Callable[..., torch.nn.Module] 135 | ] = torch.nn.ReLU, 136 | # ] = nn.LeakyReLU, 137 | inplace: Optional[bool] = True, 138 | bias: bool = True, 139 | norm_first: bool = False, 140 | dropout: float = 0.0, 141 | ): 142 | # The addition of `norm_layer` is inspired from 143 | # the implementation of TorchMultimodal: 144 | # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py 145 | params = {} if inplace is None else {"inplace": inplace} 146 | 147 | layers = [] 148 | in_dim = in_channels 149 | 150 | for hidden_dim in hidden_channels[:-1]: 151 | if norm_first and norm_layer is not None: 152 | layers.append(norm_layer(in_dim)) 153 | 154 | layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias)) 155 | 156 | if not norm_first and norm_layer is not None: 157 | layers.append(norm_layer(hidden_dim)) 158 | 159 | layers.append(activation_layer(**params)) 160 | 161 | if dropout > 0: 162 | layers.append(torch.nn.Dropout(dropout, **params)) 163 | 164 | in_dim = hidden_dim 165 | 166 | if norm_first and norm_layer is not None: 167 | layers.append(norm_layer(in_dim)) 168 | 169 | layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias)) 170 | if dropout > 0: 171 | layers.append(torch.nn.Dropout(dropout, **params)) 172 | 173 | super().__init__(*layers) -------------------------------------------------------------------------------- /models/difussion_loc_model_bev.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziyue Wang and Wen Li 3 | @file: difussion_loc_model_bev.py 4 | @time: 2025/3/12 14:20 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import logging 10 | from typing import Dict, Optional 11 | from hydra.utils import instantiate 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class DiffusionLocModel_bev(nn.Module): 17 | def __init__( 18 | self, 19 | IMAGE_FEATURE_EXTRACTOR: Dict, 20 | DIFFUSER: Dict, 21 | DENOISER: Dict, 22 | ): 23 | """ Initializes a DiffusionLoc model 24 | Args: 25 | image_feature_extractor_cfg (Dict): 26 | Configuration for the image feature extractor. 27 | diffuser_cfg (Dict): 28 | Configuration for the diffuser. 29 | denoiser_cfg (Dict): 30 | Configuration for the denoiser. 31 | """ 32 | 33 | super().__init__() 34 | 35 | self.image_feature_extractor = instantiate(IMAGE_FEATURE_EXTRACTOR, _recursive_=False) 36 | 37 | self.diffuser = instantiate(DIFFUSER, _recursive_=False) 38 | 39 | denoiser = instantiate(DENOISER, _recursive_=False) 40 | self.diffuser.model = denoiser 41 | 42 | self.target_dim = denoiser.target_dim 43 | 44 | def forward( 45 | self, 46 | image: torch.Tensor, 47 | pose: Optional[torch.Tensor] = None, 48 | sampling_timesteps = 10, 49 | training=True, 50 | ): 51 | """ 52 | Forward pass of the PoseDiffusionModel. 53 | 54 | Args: 55 | image (torch.Tensor): 56 | Input image tensor, BxNx5xHxW. 57 | pose (Optional[CamerasBase], optional): 58 | Camera object. Defaults to None. 
59 | training train or eval 60 | 61 | Return: 62 | Prected poses: BxNx6 63 | 64 | """ 65 | shapelist = list(image.shape) 66 | batch_size = len(image) 67 | 68 | if training: 69 | reshaped_image = image.reshape(shapelist[0] * shapelist[1], *shapelist[2:]) 70 | 71 | z = self.image_feature_extractor(reshaped_image) # [B, N, 384] [B, N, 256, 1] 72 | z = z.reshape(batch_size, shapelist[1], -1) # [B*N, C] [B, N, C] 73 | 74 | diffusion_results = self.diffuser(pose, z=z) 75 | diffusion_results['pred_pose'] = diffusion_results["x_0_pred"] 76 | 77 | return diffusion_results 78 | 79 | else: 80 | 81 | reshaped_image = image.reshape(shapelist[0] * shapelist[1], *shapelist[2:]) 82 | 83 | z = self.image_feature_extractor(reshaped_image) 84 | z = z.reshape(batch_size, shapelist[1], -1) 85 | B, N, _ = z.shape 86 | 87 | target_shape = [B, N, self.target_dim] 88 | 89 | # sampling 90 | # ddpm 91 | # pred_pose, pred_pose_diffusion_samples = self.diffuser.sample(shape=target_shape, z=z) 92 | # ddim 93 | pred_pose, _ = self.diffuser.ddim_sample(shape=target_shape, z=z, sampling_timesteps=sampling_timesteps) 94 | diffusion_results = { 95 | "pred_pose": pred_pose, # [B, N, 3] 96 | "z": z 97 | } 98 | 99 | return diffusion_results 100 | -------------------------------------------------------------------------------- /models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | from .dino_head import DINOHead 7 | from .mlp import Mlp 8 | from .patch_embed import PatchEmbed 9 | from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused 10 | from .block import NestedTensorBlock 11 | from .attention import MemEffAttention 12 | -------------------------------------------------------------------------------- /models/layers/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py 9 | 10 | import logging 11 | import os 12 | import warnings 13 | 14 | from torch import Tensor 15 | from torch import nn 16 | 17 | 18 | logger = logging.getLogger("dinov2") 19 | 20 | 21 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 22 | try: 23 | if XFORMERS_ENABLED: 24 | from xformers.ops import memory_efficient_attention, unbind 25 | 26 | XFORMERS_AVAILABLE = True 27 | warnings.warn("xFormers is available (Attention)") 28 | else: 29 | warnings.warn("xFormers is disabled (Attention)") 30 | raise ImportError 31 | except ImportError: 32 | XFORMERS_AVAILABLE = False 33 | warnings.warn("xFormers is not available (Attention)") 34 | 35 | 36 | class Attention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | num_heads: int = 8, 41 | qkv_bias: bool = False, 42 | proj_bias: bool = True, 43 | attn_drop: float = 0.0, 44 | proj_drop: float = 0.0, 45 | ) -> None: 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | self.scale = head_dim**-0.5 50 | 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim, bias=proj_bias) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | 56 | def forward(self, x: Tensor) -> Tensor: 57 | B, N, C = x.shape 58 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 59 | 60 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 61 | attn = q @ k.transpose(-2, -1) 62 | 63 | attn = attn.softmax(dim=-1) 64 | attn = self.attn_drop(attn) 65 | 66 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 67 | x = self.proj(x) 68 | x = self.proj_drop(x) 69 | return x 70 | 71 | 72 | class MemEffAttention(Attention): 73 | def forward(self, x: Tensor, attn_bias=None) -> Tensor: 74 | if not XFORMERS_AVAILABLE: 75 | if attn_bias is not None: 76 | raise AssertionError("xFormers is required for using nested tensors") 77 | return super().forward(x) 78 | 79 | B, N, C = x.shape 80 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 81 | 82 | q, k, v = unbind(qkv, 2) 83 | 84 | x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) 85 | x = x.reshape([B, N, C]) 86 | 87 | x = self.proj(x) 88 | x = self.proj_drop(x) 89 | return x 90 | -------------------------------------------------------------------------------- /models/layers/block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | import logging 11 | import os 12 | from typing import Callable, List, Any, Tuple, Dict 13 | import warnings 14 | 15 | import torch 16 | from torch import nn, Tensor 17 | 18 | from .attention import Attention, MemEffAttention 19 | from .drop_path import DropPath 20 | from .layer_scale import LayerScale 21 | from .mlp import Mlp 22 | 23 | 24 | logger = logging.getLogger("dinov2") 25 | 26 | 27 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 28 | try: 29 | if XFORMERS_ENABLED: 30 | from xformers.ops import fmha, scaled_index_add, index_select_cat 31 | 32 | XFORMERS_AVAILABLE = True 33 | warnings.warn("xFormers is available (Block)") 34 | else: 35 | warnings.warn("xFormers is disabled (Block)") 36 | raise ImportError 37 | except ImportError: 38 | XFORMERS_AVAILABLE = False 39 | 40 | warnings.warn("xFormers is not available (Block)") 41 | 42 | 43 | class Block(nn.Module): 44 | def __init__( 45 | self, 46 | dim: int, 47 | num_heads: int, 48 | mlp_ratio: float = 4.0, 49 | qkv_bias: bool = False, 50 | proj_bias: bool = True, 51 | ffn_bias: bool = True, 52 | drop: float = 0.0, 53 | attn_drop: float = 0.0, 54 | init_values=None, 55 | drop_path: float = 0.0, 56 | act_layer: Callable[..., nn.Module] = nn.GELU, 57 | norm_layer: Callable[..., nn.Module] = nn.LayerNorm, 58 | attn_class: Callable[..., nn.Module] = Attention, 59 | ffn_layer: Callable[..., nn.Module] = Mlp, 60 | ) -> None: 61 | super().__init__() 62 | # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") 63 | self.norm1 = norm_layer(dim) 64 | self.attn = attn_class( 65 | dim, 66 | num_heads=num_heads, 67 | qkv_bias=qkv_bias, 68 | proj_bias=proj_bias, 69 | attn_drop=attn_drop, 70 | proj_drop=drop, 71 | ) 72 | self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 73 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 74 | 75 | self.norm2 = norm_layer(dim) 76 | mlp_hidden_dim = int(dim * mlp_ratio) 77 | self.mlp = ffn_layer( 78 | in_features=dim, 79 | hidden_features=mlp_hidden_dim, 80 | act_layer=act_layer, 81 | drop=drop, 82 | bias=ffn_bias, 83 | ) 84 | self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 85 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 86 | 87 | self.sample_drop_ratio = drop_path 88 | 89 | def forward(self, x: Tensor) -> Tensor: 90 | def attn_residual_func(x: Tensor) -> Tensor: 91 | return self.ls1(self.attn(self.norm1(x))) 92 | 93 | def ffn_residual_func(x: Tensor) -> Tensor: 94 | return self.ls2(self.mlp(self.norm2(x))) 95 | 96 | if self.training and self.sample_drop_ratio > 0.1: 97 | # the overhead is compensated only for a drop path rate larger than 0.1 98 | x = drop_add_residual_stochastic_depth( 99 | x, 100 | residual_func=attn_residual_func, 101 | sample_drop_ratio=self.sample_drop_ratio, 102 | ) 103 | x = drop_add_residual_stochastic_depth( 104 | x, 105 | residual_func=ffn_residual_func, 106 | sample_drop_ratio=self.sample_drop_ratio, 107 | ) 108 | elif self.training and self.sample_drop_ratio > 0.0: 109 | x = x + self.drop_path1(attn_residual_func(x)) 110 | x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 111 | else: 112 | x = x + attn_residual_func(x) 113 | x = x + ffn_residual_func(x) 114 | return x 115 | 116 | 117 | def 
drop_add_residual_stochastic_depth( 118 | x: Tensor, 119 | residual_func: Callable[[Tensor], Tensor], 120 | sample_drop_ratio: float = 0.0, 121 | ) -> Tensor: 122 | # 1) extract subset using permutation 123 | b, n, d = x.shape 124 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 125 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 126 | x_subset = x[brange] 127 | 128 | # 2) apply residual_func to get residual 129 | residual = residual_func(x_subset) 130 | 131 | x_flat = x.flatten(1) 132 | residual = residual.flatten(1) 133 | 134 | residual_scale_factor = b / sample_subset_size 135 | 136 | # 3) add the residual 137 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 138 | return x_plus_residual.view_as(x) 139 | 140 | 141 | def get_branges_scales(x, sample_drop_ratio=0.0): 142 | b, n, d = x.shape 143 | sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) 144 | brange = (torch.randperm(b, device=x.device))[:sample_subset_size] 145 | residual_scale_factor = b / sample_subset_size 146 | return brange, residual_scale_factor 147 | 148 | 149 | def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): 150 | if scaling_vector is None: 151 | x_flat = x.flatten(1) 152 | residual = residual.flatten(1) 153 | x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) 154 | else: 155 | x_plus_residual = scaled_index_add( 156 | x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor 157 | ) 158 | return x_plus_residual 159 | 160 | 161 | attn_bias_cache: Dict[Tuple, Any] = {} 162 | 163 | 164 | def get_attn_bias_and_cat(x_list, branges=None): 165 | """ 166 | this will perform the index select, cat the tensors, and provide the attn_bias from cache 167 | """ 168 | batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] 169 | all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) 170 | if all_shapes not in attn_bias_cache.keys(): 171 | seqlens = [] 172 | for b, x in zip(batch_sizes, x_list): 173 | for _ in range(b): 174 | seqlens.append(x.shape[1]) 175 | attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) 176 | attn_bias._batch_sizes = batch_sizes 177 | attn_bias_cache[all_shapes] = attn_bias 178 | 179 | if branges is not None: 180 | cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) 181 | else: 182 | tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) 183 | cat_tensors = torch.cat(tensors_bs1, dim=1) 184 | 185 | return attn_bias_cache[all_shapes], cat_tensors 186 | 187 | 188 | def drop_add_residual_stochastic_depth_list( 189 | x_list: List[Tensor], 190 | residual_func: Callable[[Tensor, Any], Tensor], 191 | sample_drop_ratio: float = 0.0, 192 | scaling_vector=None, 193 | ) -> Tensor: 194 | # 1) generate random set of indices for dropping samples in the batch 195 | branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] 196 | branges = [s[0] for s in branges_scales] 197 | residual_scale_factors = [s[1] for s in branges_scales] 198 | 199 | # 2) get attention bias and index+concat the tensors 200 | attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) 201 | 202 | # 3) apply residual_func to get residual, and split the result 203 | residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore 204 
| 205 | outputs = [] 206 | for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): 207 | outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) 208 | return outputs 209 | 210 | 211 | class NestedTensorBlock(Block): 212 | def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: 213 | """ 214 | x_list contains a list of tensors to nest together and run 215 | """ 216 | assert isinstance(self.attn, MemEffAttention) 217 | 218 | if self.training and self.sample_drop_ratio > 0.0: 219 | 220 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 221 | return self.attn(self.norm1(x), attn_bias=attn_bias) 222 | 223 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 224 | return self.mlp(self.norm2(x)) 225 | 226 | x_list = drop_add_residual_stochastic_depth_list( 227 | x_list, 228 | residual_func=attn_residual_func, 229 | sample_drop_ratio=self.sample_drop_ratio, 230 | scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, 231 | ) 232 | x_list = drop_add_residual_stochastic_depth_list( 233 | x_list, 234 | residual_func=ffn_residual_func, 235 | sample_drop_ratio=self.sample_drop_ratio, 236 | scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, 237 | ) 238 | return x_list 239 | else: 240 | 241 | def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 242 | return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) 243 | 244 | def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: 245 | return self.ls2(self.mlp(self.norm2(x))) 246 | 247 | attn_bias, x = get_attn_bias_and_cat(x_list) 248 | x = x + attn_residual_func(x, attn_bias=attn_bias) 249 | x = x + ffn_residual_func(x) 250 | return attn_bias.split(x) 251 | 252 | def forward(self, x_or_x_list): 253 | if isinstance(x_or_x_list, Tensor): 254 | return super().forward(x_or_x_list) 255 | elif isinstance(x_or_x_list, list): 256 | if not XFORMERS_AVAILABLE: 257 | raise AssertionError("xFormers is required for using nested tensors") 258 | return self.forward_nested(x_or_x_list) 259 | else: 260 | raise AssertionError 261 | -------------------------------------------------------------------------------- /models/layers/dino_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.init import trunc_normal_ 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | class DINOHead(nn.Module): 13 | def __init__( 14 | self, 15 | in_dim, 16 | out_dim, 17 | use_bn=False, 18 | nlayers=3, 19 | hidden_dim=2048, 20 | bottleneck_dim=256, 21 | mlp_bias=True, 22 | ): 23 | super().__init__() 24 | nlayers = max(nlayers, 1) 25 | self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) 26 | self.apply(self._init_weights) 27 | self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) 28 | self.last_layer.weight_g.data.fill_(1) 29 | 30 | def _init_weights(self, m): 31 | if isinstance(m, nn.Linear): 32 | trunc_normal_(m.weight, std=0.02) 33 | if isinstance(m, nn.Linear) and m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | def forward(self, x): 37 | x = self.mlp(x) 38 | eps = 1e-6 if x.dtype == torch.float16 else 1e-12 39 | x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) 40 | x = self.last_layer(x) 41 | return x 42 | 43 | 44 | def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): 45 | if nlayers == 1: 46 | return nn.Linear(in_dim, bottleneck_dim, bias=bias) 47 | else: 48 | layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] 49 | if use_bn: 50 | layers.append(nn.BatchNorm1d(hidden_dim)) 51 | layers.append(nn.GELU()) 52 | for _ in range(nlayers - 2): 53 | layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) 54 | if use_bn: 55 | layers.append(nn.BatchNorm1d(hidden_dim)) 56 | layers.append(nn.GELU()) 57 | layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) 58 | return nn.Sequential(*layers) 59 | -------------------------------------------------------------------------------- /models/layers/drop_path.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py 9 | 10 | 11 | from torch import nn 12 | 13 | 14 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 15 | if drop_prob == 0.0 or not training: 16 | return x 17 | keep_prob = 1 - drop_prob 18 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 19 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 20 | if keep_prob > 0.0: 21 | random_tensor.div_(keep_prob) 22 | output = x * random_tensor 23 | return output 24 | 25 | 26 | class DropPath(nn.Module): 27 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 28 | 29 | def __init__(self, drop_prob=None): 30 | super(DropPath, self).__init__() 31 | self.drop_prob = drop_prob 32 | 33 | def forward(self, x): 34 | return drop_path(x, self.drop_prob, self.training) 35 | -------------------------------------------------------------------------------- /models/layers/layer_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 7 | 8 | from typing import Union 9 | 10 | import torch 11 | from torch import Tensor 12 | from torch import nn 13 | 14 | 15 | class LayerScale(nn.Module): 16 | def __init__( 17 | self, 18 | dim: int, 19 | init_values: Union[float, Tensor] = 1e-5, 20 | inplace: bool = False, 21 | ) -> None: 22 | super().__init__() 23 | self.inplace = inplace 24 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 25 | 26 | def forward(self, x: Tensor) -> Tensor: 27 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 28 | -------------------------------------------------------------------------------- /models/layers/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py 9 | 10 | 11 | from typing import Callable, Optional 12 | 13 | from torch import Tensor, nn 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__( 18 | self, 19 | in_features: int, 20 | hidden_features: Optional[int] = None, 21 | out_features: Optional[int] = None, 22 | act_layer: Callable[..., nn.Module] = nn.GELU, 23 | drop: float = 0.0, 24 | bias: bool = True, 25 | ) -> None: 26 | super().__init__() 27 | out_features = out_features or in_features 28 | hidden_features = hidden_features or in_features 29 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) 30 | self.act = act_layer() 31 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) 32 | self.drop = nn.Dropout(drop) 33 | 34 | def forward(self, x: Tensor) -> Tensor: 35 | x = self.fc1(x) 36 | x = self.act(x) 37 | x = self.drop(x) 38 | x = self.fc2(x) 39 | x = self.drop(x) 40 | return x 41 | -------------------------------------------------------------------------------- /models/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 5 | 6 | # References: 7 | # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py 8 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py 9 | 10 | from typing import Callable, Optional, Tuple, Union 11 | 12 | from torch import Tensor 13 | import torch.nn as nn 14 | 15 | 16 | def make_2tuple(x): 17 | if isinstance(x, tuple): 18 | assert len(x) == 2 19 | return x 20 | 21 | assert isinstance(x, int) 22 | return (x, x) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 27 | 2D image to patch embedding: (B,C,H,W) -> (B,N,D) 28 | 29 | Args: 30 | img_size: Image size. 31 | patch_size: Patch token size. 32 | in_chans: Number of input image channels. 33 | embed_dim: Number of linear projection output channels. 34 | norm_layer: Normalization layer. 
35 | """ 36 | 37 | def __init__( 38 | self, 39 | img_size: Union[int, Tuple[int, int]] = 224, 40 | patch_size: Union[int, Tuple[int, int]] = 16, 41 | in_chans: int = 3, 42 | embed_dim: int = 768, 43 | norm_layer: Optional[Callable] = None, 44 | flatten_embedding: bool = True, 45 | ) -> None: 46 | super().__init__() 47 | 48 | image_HW = make_2tuple(img_size) 49 | patch_HW = make_2tuple(patch_size) 50 | patch_grid_size = ( 51 | image_HW[0] // patch_HW[0], 52 | image_HW[1] // patch_HW[1], 53 | ) 54 | 55 | self.img_size = image_HW 56 | self.patch_size = patch_HW 57 | self.patches_resolution = patch_grid_size 58 | self.num_patches = patch_grid_size[0] * patch_grid_size[1] 59 | 60 | self.in_chans = in_chans 61 | self.embed_dim = embed_dim 62 | 63 | self.flatten_embedding = flatten_embedding 64 | 65 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) 66 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 67 | 68 | def forward(self, x: Tensor) -> Tensor: 69 | _, _, H, W = x.shape 70 | patch_H, patch_W = self.patch_size 71 | 72 | assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" 73 | assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" 74 | 75 | x = self.proj(x) # B C H W 76 | H, W = x.size(2), x.size(3) 77 | x = x.flatten(2).transpose(1, 2) # B HW C 78 | x = self.norm(x) 79 | if not self.flatten_embedding: 80 | x = x.reshape(-1, H, W, self.embed_dim) # B H W C 81 | return x 82 | 83 | def flops(self) -> float: 84 | Ho, Wo = self.patches_resolution 85 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 86 | if self.norm is not None: 87 | flops += Ho * Wo * self.embed_dim 88 | return flops 89 | -------------------------------------------------------------------------------- /models/layers/swiglu_ffn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the Apache License, Version 2.0 4 | # found in the LICENSE file in the root directory of this source tree. 
5 | 6 | import os 7 | from typing import Callable, Optional 8 | import warnings 9 | 10 | from torch import Tensor, nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class SwiGLUFFN(nn.Module): 15 | def __init__( 16 | self, 17 | in_features: int, 18 | hidden_features: Optional[int] = None, 19 | out_features: Optional[int] = None, 20 | act_layer: Callable[..., nn.Module] = None, 21 | drop: float = 0.0, 22 | bias: bool = True, 23 | ) -> None: 24 | super().__init__() 25 | out_features = out_features or in_features 26 | hidden_features = hidden_features or in_features 27 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 28 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 29 | 30 | def forward(self, x: Tensor) -> Tensor: 31 | x12 = self.w12(x) 32 | x1, x2 = x12.chunk(2, dim=-1) 33 | hidden = F.silu(x1) * x2 34 | return self.w3(hidden) 35 | 36 | 37 | XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None 38 | try: 39 | if XFORMERS_ENABLED: 40 | from xformers.ops import SwiGLU 41 | 42 | XFORMERS_AVAILABLE = True 43 | warnings.warn("xFormers is available (SwiGLU)") 44 | else: 45 | warnings.warn("xFormers is disabled (SwiGLU)") 46 | raise ImportError 47 | except ImportError: 48 | SwiGLU = SwiGLUFFN 49 | XFORMERS_AVAILABLE = False 50 | 51 | warnings.warn("xFormers is not available (SwiGLU)") 52 | 53 | 54 | class SwiGLUFFNFused(SwiGLU): 55 | def __init__( 56 | self, 57 | in_features: int, 58 | hidden_features: Optional[int] = None, 59 | out_features: Optional[int] = None, 60 | act_layer: Callable[..., nn.Module] = None, 61 | drop: float = 0.0, 62 | bias: bool = True, 63 | ) -> None: 64 | out_features = out_features or in_features 65 | hidden_features = hidden_features or in_features 66 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 67 | super().__init__( 68 | in_features=in_features, 69 | hidden_features=hidden_features, 70 | out_features=out_features, 71 | bias=bias, 72 | ) 73 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from R. Strudel et al. 3 | https://github.com/rstrudel/segmenter 4 | 5 | MIT License 6 | Copyright (c) 2021 Robin Strudel 7 | Copyright (c) INRIA 8 | ''' 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import math 14 | from timm.models.layers import trunc_normal_ 15 | 16 | 17 | def resize_pos_embed(posemb, grid_old_shape, grid_new_shape, num_extra_tokens): 18 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 19 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 20 | posemb_tok, posemb_grid = ( 21 | posemb[:, :num_extra_tokens], 22 | posemb[0, num_extra_tokens:], 23 | ) 24 | if grid_old_shape is None: 25 | gs_old_h = int(math.sqrt(len(posemb_grid))) 26 | gs_old_w = gs_old_h 27 | else: 28 | gs_old_h, gs_old_w = grid_old_shape 29 | 30 | gs_h, gs_w = grid_new_shape 31 | posemb_grid = posemb_grid.reshape(1, gs_old_h, gs_old_w, -1).permute(0, 3, 1, 2) 32 | posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode='bilinear') 33 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) 34 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 35 | return posemb 36 | 37 | 38 | def init_weights(m): 39 | if isinstance(m, nn.Linear): 40 | trunc_normal_(m.weight, std=0.02) 41 | if isinstance(m, nn.Linear) and m.bias is not None: 42 | nn.init.constant_(m.bias, 0) 43 | elif isinstance(m, nn.LayerNorm): 44 | nn.init.constant_(m.bias, 0) 45 | nn.init.constant_(m.weight, 1.0) 46 | 47 | 48 | def get_grid_size_1d(length, patch_size, stride): 49 | assert patch_size % stride == 0 50 | assert length % patch_size == 0 51 | return (length - patch_size) // stride + 1 52 | 53 | 54 | def get_grid_size_2d(H, W, patch_size, patch_stride): 55 | if isinstance(patch_size, int): 56 | PS_H = PS_W = patch_size 57 | else: 58 | PS_H, PS_W = patch_size 59 | 60 | if patch_stride is not None: 61 | if isinstance(patch_stride, int): 62 | patch_stride = (patch_stride, patch_stride) 63 | H_stride, W_stride = patch_stride 64 | else: 65 | H_stride = PS_H 66 | W_stride = PS_W 67 | 68 | grid_H = get_grid_size_1d(H, PS_H, H_stride) 69 | grid_W = get_grid_size_1d(W, PS_W, W_stride) 70 | return grid_H, grid_W 71 | 72 | 73 | def adapt_input_conv(in_chans, conv_weight): 74 | conv_type = conv_weight.dtype 75 | conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU 76 | O, I, J, K = conv_weight.shape 77 | if in_chans == 1: 78 | if I > 3: 79 | assert conv_weight.shape[1] % 3 == 0 80 | # For models with space2depth stems 81 | conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) 82 | conv_weight = conv_weight.sum(dim=2, keepdim=False) 83 | else: 84 | conv_weight = conv_weight.sum(dim=1, keepdim=True) 85 | elif in_chans != 3: 86 | if I != 3: 87 | raise NotImplementedError('Weight format not supported by conversion.') 88 | else: 89 | # NOTE this strategy should be better than random init, but there could be other combinations of 90 | # the original RGB input layer weights that'd work better for specific cases. 
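# Editor's note (added comment): the branch below tiles the pretrained 3-channel kernels
# until they cover in_chans channels, then rescales them by 3 / in_chans so the first
# convolution produces activations of roughly the same magnitude as with the original
# RGB input.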
91 | repeat = int(math.ceil(in_chans / 3)) 92 | conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] 93 | conv_weight *= (3 / float(in_chans)) 94 | conv_weight = conv_weight.to(conv_type) 95 | return conv_weight 96 | 97 | 98 | def padding(im, patch_size, fill_value=0): 99 | # make the image sizes divisible by patch_size 100 | H, W = im.size(2), im.size(3) 101 | pad_h, pad_w = 0, 0 102 | if isinstance(patch_size, int): 103 | patch_size_H = patch_size_W = patch_size 104 | else: 105 | patch_size_H, patch_size_W = patch_size 106 | if H % patch_size_H > 0: 107 | pad_h = patch_size_H - (H % patch_size_H) 108 | if W % patch_size_W > 0: 109 | pad_w = patch_size_W - (W % patch_size_W) 110 | im_padded = im 111 | if pad_h > 0 or pad_w > 0: 112 | im_padded = F.pad(im, (0, pad_w, 0, pad_h), value=fill_value) 113 | return im_padded 114 | 115 | 116 | def unpadding(y, target_size): 117 | H, W = target_size 118 | H_pad, W_pad = y.size(2), y.size(3) 119 | # crop predictions on extra pixels coming from padding 120 | extra_h = H_pad - H 121 | extra_w = W_pad - W 122 | if extra_h > 0: 123 | y = y[:, :, :-extra_h] 124 | if extra_w > 0: 125 | y = y[:, :, :, :-extra_w] 126 | return y 127 | 128 | 129 | class GeMPooling(nn.Module): 130 | def __init__(self, p=3, eps=1e-6): 131 | super(GeMPooling, self).__init__() 132 | self.p = nn.Parameter(torch.ones(1) * p) 133 | self.eps = eps 134 | 135 | def forward(self, x): 136 | # This implicitly applies ReLU on x (clamps negative values) 137 | temp = x.clamp(min=self.eps).pow(self.p) 138 | # global average pooling 139 | temp = temp.mean(1) 140 | 141 | return temp.pow(1. / self.p) 142 | -------------------------------------------------------------------------------- /models/stems.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 - Valeo Comfort and Driving Assistance 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
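# Editor's note (added comment): this module defines the patch-embedding stems used by the
# BEV feature extractor: a plain convolutional PatchEmbedding and a ConvStem built from
# SalsaNext-style ResContextBlock / ResBlock layers.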
14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .model_utils import get_grid_size_1d, get_grid_size_2d 20 | 21 | 22 | class PatchEmbedding(nn.Module): 23 | def __init__(self, image_size, patch_size, patch_stride, embed_dim, channels, resize_emb=True): 24 | super().__init__() 25 | if isinstance(patch_size, int): 26 | patch_size = (patch_size, patch_size) 27 | if patch_stride is None: 28 | patch_stride = patch_size 29 | else: 30 | if isinstance(patch_stride, int): 31 | patch_stride = (patch_stride, patch_stride) 32 | assert isinstance(patch_size, (list, tuple)) 33 | assert isinstance(patch_stride, (list, tuple)) 34 | assert len(patch_stride) == 2 35 | assert len(patch_size) == 2 36 | patch_size = tuple(patch_size) 37 | patch_stride = tuple(patch_stride) 38 | 39 | self.image_size = image_size 40 | if image_size[0] % patch_size[0] != 0 or image_size[1] % patch_size[1] != 0: 41 | raise ValueError('image dimensions must be divisible by the patch size') 42 | self.grid_size = ( 43 | get_grid_size_1d(image_size[0], patch_size[0], patch_stride[0]), 44 | get_grid_size_1d(image_size[1], patch_size[1], patch_stride[1])) 45 | 46 | self.num_patches = self.grid_size[0] * self.grid_size[1] 47 | self.patch_size = patch_size 48 | self.patch_stride = patch_stride 49 | self.proj = nn.Conv2d( 50 | channels, embed_dim, kernel_size=patch_size, stride=patch_stride 51 | ) 52 | self.resize_emb = resize_emb 53 | 54 | def get_grid_size(self, H, W): 55 | return get_grid_size_2d(H, W, self.patch_size, self.patch_stride) 56 | 57 | def forward(self, im): 58 | B, C, H, W = im.shape 59 | if self.resize_emb: 60 | x = self.proj(im).flatten(2).transpose(1, 2) # shape: B, N, D 61 | else: 62 | x = self.proj(im) # shape: B, D, new_H, new_W 63 | return x, None 64 | 65 | 66 | class ConvStem(nn.Module): 67 | def __init__(self, 68 | in_channels=3, 69 | base_channels=32, 70 | img_size=(32, 384), 71 | patch_stride=(2, 8), 72 | embed_dim=384, 73 | flatten=True, 74 | hidden_dim=None): 75 | super().__init__() 76 | 77 | if hidden_dim is None: 78 | hidden_dim = 2 * base_channels 79 | 80 | self.base_channels = base_channels 81 | self.dropout_ratio = 0.2 82 | 83 | # Build stem, similar to the design in https://github.com/TiagoCortinhal/SalsaNext 84 | self.conv_block = nn.Sequential( 85 | ResContextBlock(in_channels, base_channels), 86 | ResContextBlock(base_channels, base_channels), 87 | ResContextBlock(base_channels, base_channels), 88 | ResBlock(base_channels, hidden_dim, self.dropout_ratio, pooling=False, drop_out=False)) 89 | 90 | assert patch_stride[0] % 2 == 0 91 | assert patch_stride[1] % 2 == 0 92 | 93 | kernel_size = (patch_stride[0] + 1, patch_stride[1] + 1) 94 | padding = (patch_stride[0] // 2, patch_stride[1] // 2) 95 | patch_stride = (patch_stride[0], patch_stride[1]) 96 | self.proj_block = nn.Sequential( 97 | nn.AvgPool2d(kernel_size=kernel_size, stride=patch_stride, padding=padding), 98 | nn.Conv2d(hidden_dim, embed_dim, kernel_size=1)) 99 | 100 | self.patch_stride = patch_stride 101 | self.patch_size = patch_stride 102 | self.grid_size = (img_size[0] // patch_stride[0], img_size[1] // patch_stride[1]) 103 | self.num_patches = self.grid_size[0] * self.grid_size[1] 104 | # print("num_", self.num_patches) 105 | self.flatten = flatten 106 | 107 | def get_grid_size(self, H, W): 108 | return get_grid_size_2d(H, W, self.patch_size, self.patch_stride) 109 | 110 | def forward(self, x): 111 | B, C, H, W = x.shape # B, in_channels, image_size[0], image_size[1] 112 | x_base = 
self.conv_block(x) # B, hidden_dim, image_size[0], image_size[1] 113 | x = self.proj_block(x_base) 114 | if self.flatten: 115 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 116 | 117 | return x, x_base 118 | 119 | 120 | class ResContextBlock(nn.Module): 121 | # From T. Cortinhal et al. 122 | # https://github.com/TiagoCortinhal/SalsaNext 123 | def __init__(self, in_filters, out_filters): 124 | super(ResContextBlock, self).__init__() 125 | self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), stride=1) 126 | self.act1 = nn.LeakyReLU() 127 | 128 | self.conv2 = nn.Conv2d(out_filters, out_filters, (3, 3), padding=1) 129 | self.act2 = nn.LeakyReLU() 130 | self.bn1 = nn.BatchNorm2d(out_filters) 131 | 132 | self.conv3 = nn.Conv2d(out_filters, out_filters, (3, 3), dilation=2, padding=2) 133 | self.act3 = nn.LeakyReLU() 134 | self.bn2 = nn.BatchNorm2d(out_filters) 135 | 136 | def forward(self, x): 137 | shortcut = self.conv1(x) 138 | shortcut = self.act1(shortcut) 139 | 140 | resA = self.conv2(shortcut) 141 | resA = self.act2(resA) 142 | resA1 = self.bn1(resA) 143 | 144 | resA = self.conv3(resA1) 145 | resA = self.act3(resA) 146 | resA2 = self.bn2(resA) 147 | 148 | output = shortcut + resA2 149 | return output 150 | 151 | 152 | class ResBlock(nn.Module): 153 | # From T. Cortinhal et al. 154 | # https://github.com/TiagoCortinhal/SalsaNext 155 | def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3), stride=1, 156 | pooling=True, drop_out=True): 157 | super(ResBlock, self).__init__() 158 | self.pooling = pooling 159 | self.drop_out = drop_out 160 | self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), stride=stride) 161 | self.act1 = nn.LeakyReLU() 162 | 163 | self.conv2 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1) 164 | self.act2 = nn.LeakyReLU() 165 | self.bn1 = nn.BatchNorm2d(out_filters) 166 | 167 | self.conv3 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), dilation=2, padding=2) 168 | self.act3 = nn.LeakyReLU() 169 | self.bn2 = nn.BatchNorm2d(out_filters) 170 | 171 | self.conv4 = nn.Conv2d(out_filters, out_filters, kernel_size=(2, 2), dilation=2, padding=1) 172 | self.act4 = nn.LeakyReLU() 173 | self.bn3 = nn.BatchNorm2d(out_filters) 174 | 175 | self.conv5 = nn.Conv2d(out_filters * 3, out_filters, kernel_size=(1, 1)) 176 | self.act5 = nn.LeakyReLU() 177 | self.bn4 = nn.BatchNorm2d(out_filters) 178 | 179 | if pooling: 180 | self.dropout = nn.Dropout2d(p=dropout_rate) 181 | self.pool = nn.AvgPool2d(kernel_size=kernel_size, stride=2, padding=1) 182 | else: 183 | self.dropout = nn.Dropout2d(p=dropout_rate) 184 | 185 | def forward(self, x): 186 | shortcut = self.conv1(x) 187 | shortcut = self.act1(shortcut) 188 | 189 | resA = self.conv2(x) 190 | resA = self.act2(resA) 191 | resA1 = self.bn1(resA) 192 | 193 | resA = self.conv3(resA1) 194 | resA = self.act3(resA) 195 | resA2 = self.bn2(resA) 196 | 197 | resA = self.conv4(resA2) 198 | resA = self.act4(resA) 199 | resA3 = self.bn3(resA) 200 | 201 | concat = torch.cat((resA1, resA2, resA3), dim=1) 202 | resA = self.conv5(concat) 203 | resA = self.act5(resA) 204 | resA = self.bn4(resA) 205 | resA = shortcut + resA 206 | 207 | if self.pooling: 208 | if self.drop_out: 209 | resB = self.dropout(resA) 210 | else: 211 | resB = resA 212 | resB = self.pool(resB) 213 | 214 | return resB, resA 215 | else: 216 | if self.drop_out: 217 | resB = self.dropout(resA) 218 | else: 219 | resB = resA 220 | return resB 221 | 
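# ---------------------------------------------------------------------------
# Editor's addition: a minimal usage sketch for ConvStem, not part of the
# original sources above. The 3-channel 32x32 input and the (4, 4) patch
# stride are assumptions made purely for illustration, not values taken from
# the training configuration.
if __name__ == "__main__":
    stem = ConvStem(in_channels=3, base_channels=32, img_size=(32, 32),
                    patch_stride=(4, 4), embed_dim=384)
    bev = torch.randn(2, 3, 32, 32)    # a dummy BEV batch: B, C, H, W
    tokens, base_feat = stem(bev)
    print(tokens.shape)                # torch.Size([2, 64, 384]): 8x8 grid of patch tokens
    print(base_feat.shape)             # torch.Size([2, 64, 32, 32]): pre-pooling stem features
# ---------------------------------------------------------------------------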
-------------------------------------------------------------------------------- /test_bev.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziyue Wang and Wen Li 3 | @file: test_bev.py 4 | @time: 2025/3/12 14:20 5 | """ 6 | 7 | import time 8 | import matplotlib 9 | import os.path as osp 10 | matplotlib.use('Agg') 11 | 12 | from hydra.utils import instantiate 13 | from omegaconf import OmegaConf, DictConfig 14 | from utils.train_util import * 15 | from utils.utils import seed_all_random_engines 16 | from utils.pose_util import qexp, val_translation, val_rotation, r_to_d 17 | from datasets.composition_bev import MF_bev 18 | from tensorboardX import SummaryWriter 19 | 20 | 21 | TOTAL_ITERATIONS = 0 22 | 23 | def log_string(out_str): 24 | LOG_FOUT.write(out_str + '\n') 25 | LOG_FOUT.flush() 26 | print(out_str) 27 | 28 | 29 | def test(cfg: DictConfig): 30 | # NOTE carefully double check the instruction from huggingface! 31 | global TOTAL_ITERATIONS 32 | OmegaConf.set_struct(cfg, False) 33 | 34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | # Instantiate the model 36 | model = instantiate(cfg.MODEL, _recursive_=False) 37 | 38 | eval_dataset = MF_bev(cfg.train.dataset, cfg, split='eval') 39 | 40 | ckpt_path = os.path.join(cfg.ckpt) 41 | if os.path.isfile(ckpt_path): 42 | checkpoint = torch.load(ckpt_path, map_location=device) 43 | model.load_state_dict(checkpoint, strict=True) 44 | print(f"Loaded checkpoint from: {ckpt_path}") 45 | else: 46 | raise ValueError(f"No checkpoint found at: {ckpt_path}") 47 | 48 | if cfg.train.num_workers > 0: 49 | persistent_workers = cfg.train.persistent_workers 50 | else: 51 | persistent_workers = False 52 | 53 | eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=cfg.train.val_batch_size, 54 | num_workers=cfg.train.num_workers, 55 | pin_memory=cfg.train.pin_memory, 56 | persistent_workers=persistent_workers, 57 | shuffle=False) # collate 58 | 59 | # Move model and images to the GPU 60 | model = model.to(device) 61 | 62 | # Evaluation Mode 63 | model.eval() 64 | 65 | # Seed random engines 66 | seed_all_random_engines(cfg.seed) 67 | 68 | # pose mean and std 69 | pose_stats = os.path.join(cfg.train.dataroot, cfg.train.dataset, cfg.train.dataset + '_pose_stats.txt') 70 | pose_m, pose_s = np.loadtxt(pose_stats) 71 | pose_m = pose_m[:2] 72 | pose_s = pose_s[:2] 73 | # results 74 | gt_translation = np.zeros((len(eval_dataset), 2)) 75 | pred_translation = np.zeros((len(eval_dataset), 2)) 76 | gt_rotation = np.zeros((len(eval_dataset), 1)) 77 | pred_rotation = np.zeros((len(eval_dataset), 1)) 78 | error_t = np.zeros(len(eval_dataset)) 79 | error_q = np.zeros(len(eval_dataset)) 80 | 81 | T1 = time.time() 82 | 83 | for step, batch in enumerate(eval_dataloader): 84 | val_pose = batch["pose"][:, -1, :] 85 | start_idx = step * cfg.train.val_batch_size 86 | end_idx = min((step + 1) * cfg.train.val_batch_size, len(eval_dataset)) 87 | gt_translation[start_idx:end_idx, :] = val_pose[:, :2].numpy() * pose_s + pose_m 88 | gt_rotation[start_idx:end_idx, :] = np.asarray([r_to_d(q).flatten() for q in val_pose[:, 2].numpy()]) 89 | images = batch["image"].to(device) 90 | with torch.no_grad(): 91 | predictions = model(images, sampling_timesteps=cfg.sampling_timesteps, training=False) 92 | # predicted pose 93 | pred = predictions['pred_pose'] 94 | s = pred.size() # out.shape = [B, N, 6] 95 | pred_t = pred[..., :2] 96 | pred_q = pred[..., 2] 97 | # last frame 98 | pred_t = 
pred_t.view(s[0], s[1], 2) 99 | pred_q = pred_q.view(s[0], s[1], 1) 100 | pred_t = pred_t[:, -1, :] 101 | pred_q = pred_q[:, -1, :] 102 | 103 | # RTE / RRE 104 | pred_translation[start_idx:end_idx, :] = pred_t.cpu().numpy() * pose_s + pose_m 105 | pred_rotation[start_idx:end_idx, :] = np.asarray([r_to_d(q) for q in pred_q.cpu().numpy()]) 106 | error_t[start_idx:end_idx] = np.asarray([val_translation(p, q) for p, q in zip(pred_translation[start_idx:end_idx, :], gt_translation[start_idx:end_idx, :])]) 107 | error_q[start_idx:end_idx] = np.asarray([abs(p - q).squeeze() for p, q in zip(pred_rotation[start_idx:end_idx, :], gt_rotation[start_idx:end_idx, :])]) 108 | error_q[start_idx:end_idx] = np.where(error_q[start_idx:end_idx] > 180, abs(360 - error_q[start_idx:end_idx]), error_q[start_idx:end_idx]) 109 | 110 | log_string('MeanTE(m): %f' % np.mean(error_t[start_idx:end_idx], axis=0)) 111 | log_string('MeanRE(degrees): %f' % np.mean(error_q[start_idx:end_idx], axis=0)) 112 | log_string('MedianTE(m): %f' % np.median(error_t[start_idx:end_idx], axis=0)) 113 | log_string('MedianRE(degrees): %f' % np.median(error_q[start_idx:end_idx], axis=0)) 114 | 115 | T2 = time.time() 116 | print("time:", T2-T1) 117 | 118 | mean_ATE = np.mean(error_t) 119 | mean_ARE = np.mean(error_q) 120 | median_ATE = np.median(error_t) 121 | median_ARE = np.median(error_q) 122 | 123 | log_string('Mean Position Error(m): %f' % mean_ATE) 124 | log_string('Mean Orientation Error(degrees): %f' % mean_ARE) 125 | log_string('Median Position Error(m): %f' % median_ATE) 126 | log_string('Median Orientation Error(degrees): %f' % median_ARE) 127 | 128 | val_writer.add_scalar('MeanATE', mean_ATE, TOTAL_ITERATIONS) 129 | val_writer.add_scalar('MeanARE', mean_ARE, TOTAL_ITERATIONS) 130 | val_writer.add_scalar('MedianATE', median_ATE, TOTAL_ITERATIONS) 131 | val_writer.add_scalar('MedianARE', median_ARE, TOTAL_ITERATIONS) 132 | 133 | # save error and trajectory 134 | real_pose = pred_translation - pose_m 135 | gt_pose = gt_translation - pose_m 136 | error_t_filename = osp.join(cfg.exp_dir, 'error_t.txt') 137 | error_q_filename = osp.join(cfg.exp_dir, 'error_q.txt') 138 | pred_t_filename = osp.join(cfg.exp_dir, 'pred_t.txt') 139 | gt_t_filename = osp.join(cfg.exp_dir, 'gt_t.txt') 140 | pred_q_filename = osp.join(cfg.exp_dir, 'pred_q.txt') 141 | gt_q_filename = osp.join(cfg.exp_dir, 'gt_q.txt') 142 | np.savetxt(error_t_filename, error_t, fmt='%8.7f') 143 | np.savetxt(error_q_filename, error_q, fmt='%8.7f') 144 | np.savetxt(pred_t_filename, real_pose, fmt='%8.7f') 145 | np.savetxt(gt_t_filename, gt_pose, fmt='%8.7f') 146 | np.savetxt(pred_q_filename, pred_rotation, fmt='%8.7f') 147 | np.savetxt(gt_q_filename, gt_rotation, fmt='%8.7f') 148 | 149 | 150 | if __name__ == '__main__': 151 | # oxford_bev.yaml / nclt_bev.yaml / 152 | conf = OmegaConf.load('cfgs/oxford_bev.yaml') 153 | LOG_FOUT = open(os.path.join(conf.exp_dir, 'log.txt'), 'w') 154 | LOG_FOUT.write(str(conf) + '\n') 155 | val_writer = SummaryWriter(os.path.join(conf.exp_dir, 'valid')) 156 | # 5 cpu core 157 | torch.set_num_threads(5) 158 | test(conf) 159 | -------------------------------------------------------------------------------- /train_bev.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Ziyue Wang and Wen Li 3 | @file: train_bev.py 4 | @time: 2025/3/12 14:20 5 | """ 6 | 7 | import io 8 | import time 9 | import pstats 10 | import cProfile 11 | import torch.nn as nn 12 | 13 | from hydra.utils import instantiate 14 | from 
collections import OrderedDict 15 | from omegaconf import OmegaConf, DictConfig 16 | from pytorch3d.implicitron.tools import vis_utils 17 | from accelerate import Accelerator, DistributedDataParallelKwargs 18 | from utils.train_util import * 19 | from datasets.composition_bev import MF_bev 20 | from tqdm import tqdm 21 | from tensorboardX import SummaryWriter 22 | 23 | writer = SummaryWriter(log_dir='./runs/03_10') 24 | 25 | def prefix_with_module(checkpoint): 26 | prefixed_checkpoint = OrderedDict() 27 | for key, value in checkpoint.items(): 28 | prefixed_key = "module." + key 29 | prefixed_checkpoint[prefixed_key] = value 30 | return prefixed_checkpoint 31 | 32 | 33 | # Wrapper for cProfile.Profile for easily make optional, turn on/off and printing 34 | class Profiler: 35 | def __init__(self, active: bool): 36 | self.c_profiler = cProfile.Profile() 37 | self.active = active 38 | 39 | def enable(self): 40 | if self.active: 41 | self.c_profiler.enable() 42 | 43 | def disable(self): 44 | if self.active: 45 | self.c_profiler.disable() 46 | 47 | def print(self): 48 | if self.active: 49 | s = io.StringIO() 50 | sortby = pstats.SortKey.CUMULATIVE 51 | ps = pstats.Stats(self.c_profiler, stream=s).sort_stats(sortby) 52 | ps.print_stats() 53 | print(s.getvalue()) 54 | 55 | 56 | def get_thread_count(var_name): 57 | return os.environ.get(var_name) 58 | 59 | 60 | def train_fn(cfg: DictConfig): 61 | # NOTE carefully double check the instruction from huggingface! 62 | 63 | OmegaConf.set_struct(cfg, False) 64 | 65 | # Initialize the accelerator 66 | accelerator = Accelerator(even_batches=False, device_placement=False) 67 | 68 | accelerator.print("Model Config:") 69 | accelerator.print(OmegaConf.to_yaml(cfg)) 70 | 71 | accelerator.print("Accelerator State:") 72 | accelerator.print(accelerator.state) 73 | 74 | torch.backends.cudnn.benchmark = cfg.train.cudnnbenchmark 75 | 76 | set_seed_and_print(cfg.seed) 77 | 78 | if accelerator.is_main_process: 79 | viz = vis_utils.get_visdom_connection( 80 | server="http://127.0.0.1", 81 | port=int(os.environ.get("VISDOM_PORT", 8097)), 82 | ) 83 | 84 | viz = vis_utils.get_visdom_connection(server="http://127.0.0.1",port=int(os.environ.get("VISDOM_PORT", 8097))) 85 | 86 | accelerator.print(f"!!!!!!!!!!!!!!!!!!!!!!!!!! OMP_NUM_THREADS: {get_thread_count('OMP_NUM_THREADS')}") 87 | accelerator.print(f"!!!!!!!!!!!!!!!!!!!!!!!!!! MKL_NUM_THREADS: {get_thread_count('MKL_NUM_THREADS')}") 88 | 89 | accelerator.print(f"!!!!!!!!!!!!!!!!!!!!!!!!!! SLURM_CPU_BIND: {get_thread_count('SLURM_CPU_BIND')}") 90 | accelerator.print( 91 | f"!!!!!!!!!!!!!!!!!!!!!!!!!! 
SLURM_JOB_CPUS_PER_NODE: {get_thread_count('SLURM_JOB_CPUS_PER_NODE')}") 92 | 93 | train_dataset = MF_bev(cfg.train.dataset, cfg, split='train') 94 | eval_dataset = MF_bev(cfg.train.dataset, cfg, split='eval') 95 | 96 | if cfg.train.num_workers > 0: 97 | persistent_workers = cfg.train.persistent_workers 98 | else: 99 | persistent_workers = False 100 | 101 | dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=cfg.train.batch_size, 102 | num_workers=cfg.train.num_workers, 103 | pin_memory=cfg.train.pin_memory, 104 | shuffle=True, drop_last=True, 105 | persistent_workers=persistent_workers 106 | ) # collate_fn 107 | eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=cfg.train.batch_size, 108 | num_workers=cfg.train.num_workers, pin_memory=cfg.train.pin_memory, 109 | shuffle=False, persistent_workers=persistent_workers) # collate_fn 110 | 111 | accelerator.print("length of train dataloader is: ", len(dataloader)) 112 | accelerator.print("length of eval dataloader is: ", len(eval_dataloader)) 113 | 114 | # Instantiate the model 115 | model = instantiate(cfg.MODEL, _recursive_=False) 116 | 117 | model = model.to(accelerator.device) 118 | criterion = nn.BCEWithLogitsLoss() 119 | 120 | # Define the numer of epoch 121 | num_epochs = cfg.train.epochs 122 | 123 | # log 124 | if os.path.exists(cfg.exp_dir) == 0: 125 | os.mkdir(cfg.exp_dir) 126 | # Define the optimizer 127 | if cfg.train.warmup_sche: 128 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=cfg.train.lr) 129 | lr_scheduler = WarmupCosineLR(optimizer=optimizer, lr=cfg.train.lr, 130 | warmup_steps=cfg.train.restart_num * len(dataloader), momentum=0.9, 131 | max_steps=len(dataloader) * (cfg.train.epochs - cfg.train.restart_num)) 132 | else: 133 | optimizer = torch.optim.AdamW(params=model.parameters(), lr=cfg.train.lr, weight_decay=cfg.train.weight_decay) 134 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=80, gamma=0.5) 135 | 136 | 137 | model, dataloader, optimizer, lr_scheduler = accelerator.prepare(model, dataloader, optimizer, lr_scheduler) 138 | 139 | accelerator.print(f"xxxxxxxxxxxxxxxxxx dataloader has {dataloader.num_workers} num_workers") 140 | 141 | start_epoch = 0 142 | 143 | to_plot = ("loss", "lr", "diffloss", "error_t", "error_q") 144 | 145 | stats = VizStats(to_plot) 146 | 147 | for epoch in range(start_epoch, num_epochs): 148 | stats.new_epoch() 149 | 150 | set_seed_and_print(cfg.seed + epoch) 151 | 152 | # Evaluation 153 | if (epoch != 0) and (epoch % cfg.train.eval_interval == 0): 154 | # if (epoch % cfg.train.eval_interval == 0): 155 | accelerator.print(f"----------Start to eval at epoch {epoch}----------") 156 | _train_or_eval_fn(model, criterion, eval_dataloader, cfg, optimizer, stats, accelerator, lr_scheduler, epoch, training=False) 157 | accelerator.print(f"----------Finish the eval at epoch {epoch}----------") 158 | else: 159 | accelerator.print(f"----------Skip the eval at epoch {epoch}----------") 160 | 161 | # Training 162 | accelerator.print(f"----------Start to train at epoch {epoch}----------") 163 | _train_or_eval_fn(model, criterion, dataloader, cfg, optimizer, stats, accelerator, lr_scheduler, epoch, training=True) 164 | accelerator.print(f"----------Finish the train at epoch {epoch}----------") 165 | 166 | if accelerator.is_main_process: 167 | for g in optimizer.param_groups: 168 | lr = g['lr'] 169 | break 170 | accelerator.print(f"----------LR is {lr}----------") 171 | accelerator.print(f"----------Saving stats to 
{cfg.exp_name}----------") 172 | stats.update({"lr": lr}, stat_set="train") 173 | stats.plot_stats(viz=viz, visdom_env=cfg.exp_name) 174 | accelerator.print(f"----------Done----------") 175 | 176 | if epoch >= 40: 177 | accelerator.wait_for_everyone() 178 | ckpt_path = os.path.join(cfg.exp_dir, f"ckpt_{epoch:06}.pth") 179 | accelerator.print(f"----------Saving the ckpt at epoch {epoch} to {ckpt_path}----------") 180 | unwrapped_model = accelerator.unwrap_model(model) 181 | if epoch % 5 == 0: 182 | accelerator.save(unwrapped_model.state_dict(), ckpt_path) 183 | 184 | if accelerator.is_main_process: 185 | stats.save(cfg.exp_dir + "stats") 186 | 187 | return True 188 | 189 | 190 | def _train_or_eval_fn(model, criterion, dataloader, cfg, optimizer, stats, accelerator, lr_scheduler, epoch, training=True): 191 | if training: 192 | model.train() 193 | else: 194 | model.eval() 195 | 196 | # print(f"Start the loop for process {accelerator.process_index}") 197 | 198 | time_start = time.time() 199 | max_it = len(dataloader) 200 | 201 | pose_stats = os.path.join(cfg.train.dataroot, cfg.train.dataset, cfg.train.dataset + '_pose_stats.txt') 202 | pose_m, pose_s = np.loadtxt(pose_stats) 203 | pose_s = torch.from_numpy(pose_s).to(accelerator.device) 204 | pose_m = torch.from_numpy(pose_m).to(accelerator.device) 205 | 206 | tqdm_loader = tqdm(dataloader, total=len(dataloader)) 207 | for step, batch in enumerate(tqdm_loader): 208 | images = batch["image"].to(accelerator.device) # [B, N, 3, 251, 251] 209 | batch_size, frame_size = images.size(0), images.size(1) 210 | poses = batch["pose"].to(accelerator.device) # [B, N, 3] 211 | H, W = images.size(-2), images.size(-1) 212 | 213 | if training: 214 | predictions = model(images, poses, training=True) 215 | predictions["diffloss"] = predictions["diffloss"] 216 | loss = predictions["diffloss"] 217 | writer.add_scalar('train/diffloss', loss.item(), step + epoch * max_it) 218 | else: 219 | with torch.no_grad(): 220 | predictions = model(images, training=False) 221 | 222 | # calculate metric 223 | frame_num = frame_size * batch_size 224 | pred_poses = predictions['pred_pose'].reshape(frame_num, 3) # [B*N, 3] 225 | gt_poses = poses.reshape(frame_num, 3) # [B*N, 3] 226 | 227 | iou = 0. 
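# Editor's note (added comment): the loop below accumulates, over all B*N frames in the
# batch, the translation error in metres (val_translation de-normalizes the poses with
# pose_s / pose_m before taking the norm) and the absolute yaw error in degrees (r_to_d
# converts radians to degrees); both sums are then averaged into predictions['error_t']
# and predictions['error_q'].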
228 | for i in range(frame_num): 229 | if i == 0: 230 | error_t = t_error(pred_poses[i, :2], gt_poses[i, :2], pose_s[:2], pose_m[:2]) 231 | error_q = q_error(pred_poses[i, 2], gt_poses[i, 2]) 232 | else: 233 | error_t += (t_error(pred_poses[i, :2], gt_poses[i, :2], pose_s[:2], pose_m[:2])) 234 | error_q += (q_error(pred_poses[i, 2], gt_poses[i, 2])) 235 | 236 | predictions['error_t'] = error_t / frame_num 237 | predictions['error_q'] = error_q / frame_num 238 | 239 | if training: 240 | writer.add_scalar('train/error_t', predictions['error_t'].item(), step + epoch * max_it) 241 | writer.add_scalar('train/error_q', predictions['error_q'].item(), step + epoch * max_it) 242 | 243 | if training: 244 | stats.update(predictions, time_start=time_start, stat_set="train") 245 | if step % cfg.train.print_interval == 0: 246 | accelerator.print(stats.print(stat_set="train", max_it=max_it)) 247 | else: 248 | stats.update(predictions, time_start=time_start, stat_set="eval") 249 | if step % cfg.train.print_interval == 0: 250 | accelerator.print(stats.print(stat_set="eval", max_it=max_it)) 251 | 252 | if training: 253 | optimizer.zero_grad() 254 | accelerator.backward(loss) 255 | if cfg.train.clip_grad > 0 and accelerator.sync_gradients: 256 | accelerator.clip_grad_norm_(model.parameters(), cfg.train.clip_grad) 257 | optimizer.step() 258 | lr_scheduler.step() 259 | 260 | return True 261 | 262 | def t_error(pred_poses, gt_poses, pose_s, pose_mean): 263 | with torch.no_grad(): 264 | error_t = val_translation(pred_poses, gt_poses, pose_s, pose_mean) 265 | 266 | return error_t 267 | 268 | def q_error(pred_poses, gt_poses): 269 | with torch.no_grad(): 270 | p = r_to_d(pred_poses) 271 | q = r_to_d(gt_poses) 272 | error_q = abs(p - q) 273 | 274 | return error_q 275 | 276 | def r_to_d(r): 277 | 278 | d = r * 180 / np.pi 279 | 280 | return d 281 | 282 | def val_translation(pred_p, gt_p, pose_s, pose_mean): 283 | """ 284 | test model, compute error (numpy) 285 | input: 286 | pred_p: [3,] 287 | gt_p: [3,] 288 | returns: 289 | translation error (m): 290 | """ 291 | pred_p = (pred_p * pose_s) + pose_mean 292 | gt_p = (gt_p * pose_s) + pose_mean 293 | error = torch.linalg.norm(gt_p - pred_p) 294 | 295 | return error 296 | 297 | 298 | if __name__ == '__main__': 299 | # oxford_bev.yaml / nclt_bev.yaml 300 | conf = OmegaConf.load('cfgs/oxford_bev.yaml') 301 | # conf = OmegaConf.load('cfgs/nclt_bev.yaml') 302 | train_fn(conf) 303 | -------------------------------------------------------------------------------- /utils/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from pytorch3d.renderer import HarmonicEmbedding 5 | 6 | 7 | class TimeStepEmbedding(nn.Module): 8 | # learned from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py 9 | def __init__(self, dim=256, max_period=10000): 10 | super().__init__() 11 | self.dim = dim 12 | self.max_period = max_period 13 | 14 | self.linear = nn.Sequential( 15 | nn.Linear(dim, dim // 2), 16 | nn.SiLU(), 17 | nn.Linear(dim // 2, dim // 2), 18 | ) 19 | 20 | self.out_dim = dim // 2 21 | 22 | def _compute_freqs(self, half): 23 | freqs = torch.exp( 24 | -math.log(self.max_period) 25 | * torch.arange(start=0, end=half, dtype=torch.float32) 26 | / half 27 | ) 28 | return freqs 29 | 30 | def forward(self, timesteps): 31 | half = self.dim // 2 32 | freqs = self._compute_freqs(half).to(device=timesteps.device) 33 | args = timesteps[:, None].float() * freqs[None] 34 
| embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 35 | if self.dim % 2: 36 | embedding = torch.cat( 37 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 38 | ) 39 | 40 | output = self.linear(embedding) 41 | return output 42 | 43 | 44 | class PoseEmbedding(nn.Module): 45 | def __init__(self, target_dim, n_harmonic_functions=10, append_input=True): 46 | super().__init__() 47 | 48 | self._emb_pose = HarmonicEmbedding( 49 | n_harmonic_functions=n_harmonic_functions, append_input=append_input 50 | ) 51 | # print("target_dim", target_dim) 52 | 53 | self.out_dim = self._emb_pose.get_output_dim(target_dim) 54 | 55 | def forward(self, pose_encoding): 56 | e_pose_encoding = self._emb_pose(pose_encoding) 57 | return e_pose_encoding 58 | -------------------------------------------------------------------------------- /utils/train_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import math 4 | import torch 5 | import logging 6 | import warnings 7 | import torch.optim 8 | import numpy as np 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | from itertools import cycle 12 | from matplotlib import colors as mcolors 13 | from torch.utils.data import BatchSampler 14 | from pytorch3d.transforms import so3_relative_angle 15 | from pytorch3d.implicitron.tools.stats import Stats 16 | from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection 17 | from accelerate.utils import set_seed as accelerate_set_seed 18 | import torch.optim.lr_scheduler as toptim 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | def set_seed_and_print(seed): 24 | accelerate_set_seed(seed, device_specific=True) 25 | print(f"----------Seed is set to {np.random.get_state()[1][0]} now----------") 26 | 27 | 28 | class VizStats(Stats): 29 | def plot_stats( 30 | self, viz=None, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None 31 | ): 32 | # use the cached visdom env if none supplied 33 | if visdom_env is None: 34 | visdom_env = self.visdom_env 35 | if visdom_server is None: 36 | visdom_server = self.visdom_server 37 | if visdom_port is None: 38 | visdom_port = self.visdom_port 39 | if plot_file is None: 40 | plot_file = self.plot_file 41 | 42 | stat_sets = list(self.stats.keys()) 43 | 44 | logger.debug( 45 | f"printing charts to visdom env '{visdom_env}' ({visdom_server}:{visdom_port})" 46 | ) 47 | 48 | novisdom = False 49 | 50 | if viz is None: 51 | viz = get_visdom_connection(server=visdom_server, port=visdom_port) 52 | 53 | if viz is None or not viz.check_connection(): 54 | logger.info("no visdom server! 
-> skipping visdom plots") 55 | novisdom = True 56 | 57 | lines = [] 58 | 59 | # plot metrics 60 | if not novisdom: 61 | viz.close(env=visdom_env, win=None) 62 | 63 | for stat in self.log_vars: 64 | vals = [] 65 | stat_sets_now = [] 66 | for stat_set in stat_sets: 67 | val = self.stats[stat_set][stat].get_epoch_averages() 68 | if val is None: 69 | continue 70 | else: 71 | val = np.array(val).reshape(-1) 72 | stat_sets_now.append(stat_set) 73 | vals.append(val) 74 | 75 | if len(vals) == 0: 76 | continue 77 | 78 | lines.append((stat_sets_now, stat, vals)) 79 | 80 | if not novisdom: 81 | for tmodes, stat, vals in lines: 82 | title = "%s" % stat 83 | opts = {"title": title, "legend": list(tmodes)} 84 | for i, (tmode, val) in enumerate(zip(tmodes, vals)): 85 | update = "append" if i > 0 else None 86 | valid = np.where(np.isfinite(val))[0] 87 | if len(valid) == 0: 88 | continue 89 | x = np.arange(len(val)) 90 | viz.line( 91 | Y=val[valid], 92 | X=x[valid], 93 | env=visdom_env, 94 | opts=opts, 95 | win=f"stat_plot_{title}", 96 | name=tmode, 97 | update=update, 98 | ) 99 | 100 | if plot_file: 101 | logger.info(f"plotting stats to {plot_file}") 102 | ncol = 3 103 | nrow = int(np.ceil(float(len(lines)) / ncol)) 104 | matplotlib.rcParams.update({"font.size": 5}) 105 | color = cycle(plt.cm.tab10(np.linspace(0, 1, 10))) 106 | fig = plt.figure(1) 107 | plt.clf() 108 | for idx, (tmodes, stat, vals) in enumerate(lines): 109 | c = next(color) 110 | plt.subplot(nrow, ncol, idx + 1) 111 | plt.gca() 112 | for vali, vals_ in enumerate(vals): 113 | c_ = c * (1.0 - float(vali) * 0.3) 114 | valid = np.where(np.isfinite(vals_))[0] 115 | if len(valid) == 0: 116 | continue 117 | x = np.arange(len(vals_)) 118 | plt.plot(x[valid], vals_[valid], c=c_, linewidth=1) 119 | plt.ylabel(stat) 120 | plt.xlabel("epoch") 121 | plt.gca().yaxis.label.set_color(c[0:3] * 0.75) 122 | plt.legend(tmodes) 123 | gcolor = np.array(mcolors.to_rgba("lightgray")) 124 | grid_params = {"visible": True, "color": gcolor} 125 | plt.grid(**grid_params, which="major", linestyle="-", linewidth=0.4) 126 | plt.grid(**grid_params, which="minor", linestyle="--", linewidth=0.2) 127 | plt.minorticks_on() 128 | 129 | plt.tight_layout() 130 | plt.show() 131 | try: 132 | fig.savefig(plot_file) 133 | except PermissionError: 134 | warnings.warn("Cant dump stats due to insufficient permissions!") 135 | 136 | 137 | def rotation_angle(rot_gt, rot_pred, batch_size=None): 138 | # rot_gt, rot_pred (B, 3, 3) 139 | # masks_flat: B, 1 140 | rel_angle_cos = so3_relative_angle(rot_gt, rot_pred, eps=1e-4) 141 | rel_rangle_deg = rel_angle_cos * 180 / np.pi 142 | 143 | if batch_size is not None: 144 | rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1) 145 | 146 | return rel_rangle_deg 147 | 148 | 149 | def translation_angle(tvec_gt, tvec_pred, batch_size=None): 150 | rel_tangle_deg = evaluate_translation_batch(tvec_gt, tvec_pred) 151 | rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi 152 | 153 | if batch_size is not None: 154 | rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1) 155 | 156 | return rel_tangle_deg 157 | 158 | 159 | def evaluate_translation_batch(t_gt, t, eps=1e-15, default_err=1e6): 160 | """Normalize the translation vectors and compute the angle between them.""" 161 | t_norm = torch.norm(t, dim=1, keepdim=True) 162 | t = t / (t_norm + eps) 163 | 164 | t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True) 165 | t_gt = t_gt / (t_gt_norm + eps) 166 | 167 | loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps) 168 | err_t = 
torch.acos(torch.sqrt(1 - loss_t)) 169 | 170 | err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err 171 | return err_t 172 | 173 | 174 | def batched_all_pairs(B, N): 175 | # B, N = se3.shape[:2] 176 | i1_, i2_ = torch.combinations( 177 | torch.arange(N), 2, with_replacement=False 178 | ).unbind(-1) 179 | i1, i2 = [ 180 | (i[None] + torch.arange(B)[:, None] * N).reshape(-1) 181 | for i in [i1_, i2_] 182 | ] 183 | 184 | return i1, i2 185 | 186 | 187 | class WarmupCosineRestarts(torch.optim.lr_scheduler._LRScheduler): 188 | def __init__(self, optimizer, T_0, iters_per_epoch, T_mult=1, eta_min=0, warmup_ratio=0.1, warmup_lr_init=1e-7, 189 | last_epoch=-1): 190 | self.T_0 = T_0 * iters_per_epoch # 50 * 156988 191 | self.T_mult = T_mult # 1 192 | self.eta_min = eta_min # 0 193 | self.warmup_iters = int(T_0 * warmup_ratio * iters_per_epoch) # int(50 * 0.1 * 156988) 194 | self.warmup_lr_init = warmup_lr_init # 1e-7 195 | super(WarmupCosineRestarts, self).__init__(optimizer, last_epoch) 196 | 197 | def get_lr(self): 198 | if self.T_mult == 1: 199 | i_restart = self.last_epoch // self.T_0 # index of the current restart cycle 200 | T_cur = self.last_epoch - i_restart * self.T_0 # progress within the current cycle 201 | else: 202 | n = int(math.log((self.last_epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 203 | T_cur = self.last_epoch - self.T_0 * (self.T_mult ** n - 1) // (self.T_mult - 1) 204 | 205 | if T_cur < self.warmup_iters: 206 | warmup_ratio = T_cur / self.warmup_iters 207 | return [self.warmup_lr_init + (base_lr - self.warmup_lr_init) * warmup_ratio for base_lr in self.base_lrs] 208 | else: 209 | T_cur_adjusted = T_cur - self.warmup_iters 210 | T_i = self.T_0 - self.warmup_iters 211 | return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * T_cur_adjusted / T_i)) / 2 212 | for base_lr in self.base_lrs] 213 | 214 | 215 | class WarmupCosineLR(toptim._LRScheduler): 216 | ''' Warmup learning rate scheduler. 217 | Initially, increases the learning rate from 0 to the final value, in a 218 | certain number of steps. After this number of steps, each step anneals the 219 | LR following a cosine schedule.
220 | ''' 221 | 222 | def __init__(self, optimizer, lr, warmup_steps, momentum, max_steps): 223 | # cyclic params 224 | self.optimizer = optimizer 225 | self.lr = lr 226 | self.warmup_steps = warmup_steps 227 | self.momentum = momentum 228 | 229 | # cap to one 230 | if self.warmup_steps < 1: 231 | self.warmup_steps = 1 232 | 233 | # cyclic lr 234 | self.cosine_scheduler = toptim.CosineAnnealingLR( 235 | self.optimizer, T_max=max_steps) 236 | 237 | self.initial_scheduler = toptim.CyclicLR(self.optimizer, 238 | base_lr=0, 239 | max_lr=self.lr, 240 | step_size_up=self.warmup_steps, 241 | step_size_down=self.warmup_steps, 242 | cycle_momentum=False, 243 | base_momentum=self.momentum, 244 | max_momentum=self.momentum) 245 | 246 | self.last_epoch = -1 247 | self.finished = False 248 | super().__init__(optimizer) 249 | 250 | def step(self, epoch=None): 251 | if self.finished or self.initial_scheduler.last_epoch >= self.warmup_steps: 252 | if not self.finished: 253 | self.base_lrs = [self.lr for _ in self.base_lrs] 254 | self.finished = True 255 | return self.cosine_scheduler.step(epoch) 256 | else: 257 | return self.initial_scheduler.step(epoch) 258 | 259 | # Learning-rate schedule helpers (warmup + polynomial decay) 260 | POWER = 0.9 261 | 262 | 263 | def lr_poly(base_lr, iter, max_iter, power): 264 | return base_lr * ((1 - float(iter) / max_iter) ** power) 265 | 266 | 267 | def lr_warmup(base_lr, iter, max_iter, warmup_iter): 268 | return base_lr * (float(iter) / warmup_iter) 269 | 270 | 271 | def adjust_learning_rate(lr, i_iter, max_iter, PREHEAT_STEPS): 272 | if i_iter < PREHEAT_STEPS: 273 | lr = lr_warmup(lr, i_iter, max_iter, PREHEAT_STEPS) 274 | else: 275 | lr = lr_poly(lr, i_iter, max_iter, POWER) 276 | 277 | return lr 278 | 279 | 280 | class DynamicBatchSampler(BatchSampler): 281 | def __init__(self, num_sequences, dataset_len=1024, max_images=128, images_per_seq=(3, 20)): 282 | # len(dataset): 32; cfg.train.len_train: 16384; max_image: 512; images_per_seq: [3, 51] 283 | # self.dataset = dataset 284 | self.max_images = max_images 285 | self.images_per_seq = list(range(images_per_seq[0], images_per_seq[1])) 286 | self.num_sequences = num_sequences 287 | self.dataset_len = dataset_len 288 | 289 | def _capped_random_choice(self, x, size, replace: bool = True): 290 | len_x = x if isinstance(x, int) else len(x) 291 | if replace: 292 | return np.random.choice(x, size=size, replace=len_x < size) 293 | else: 294 | return np.random.choice(x, size=min(size, len_x), replace=False) 295 | 296 | def __iter__(self): 297 | for batch_idx in range(self.dataset_len): 298 | # NOTE batch_idx is never used later 299 | # print(f"process {batch_idx}") 300 | n_per_seq = np.random.choice(self.images_per_seq) 301 | n_seqs = (self.max_images // n_per_seq) 302 | 303 | chosen_seq = self._capped_random_choice(self.num_sequences, n_seqs) 304 | # print(f"get the chosen_seq for {batch_idx}") 305 | 306 | batches = [(bidx, n_per_seq) for bidx in chosen_seq] 307 | # print(f"yield the batches for {batch_idx}") 308 | yield batches 309 | 310 | def __len__(self): 311 | return self.dataset_len 312 | 313 | 314 | class FixBatchSampler(torch.utils.data.Sampler): 315 | def __init__(self, dataset, dataset_len=1024, batch_size=64, max_images=128, images_per_seq=(3, 20)): 316 | # dataset; len_train: 16384; max_images: 512; images_per_seq: [3, 51] 317 | self.dataset = dataset 318 | self.max_images = max_images # 32 319 | self.images_per_seq = list(range(images_per_seq[0], images_per_seq[1])) # [3, ..., 21] 320 | self.num_sequences = len(self.dataset) # 4 321 | self.dataset_len = dataset_len # 156988
322 | self.batch_size = 48 # NOTE: hard-coded; the batch_size argument is ignored 323 | self.fix_images_per_seq = True 324 | 325 | def _capped_random_choice(self, x, size, replace: bool = True): 326 | len_x = x if isinstance(x, int) else len(x) 327 | if replace: 328 | return np.random.choice(x, size=size, replace=len_x < size) 329 | else: 330 | return np.random.choice(x, size=min(size, len_x), replace=False) 331 | 332 | def __iter__(self): 333 | for batch_idx in range(self.dataset_len): 334 | # NOTE batch_idx is never used later 335 | # print(f"process {batch_idx}") 336 | if self.fix_images_per_seq: 337 | # n_per_seq = 12 338 | n_per_seq = 1 339 | else: 340 | n_per_seq = np.random.choice(self.images_per_seq) 341 | 342 | n_seqs = self.batch_size 343 | 344 | chosen_seq = self._capped_random_choice(self.num_sequences, n_seqs) 345 | # print(f"get the chosen_seq for {batch_idx}") 346 | 347 | batches = [(bidx, n_per_seq) for bidx in chosen_seq] 348 | # print(f"yield the batches for {batch_idx}") 349 | yield batches 350 | 351 | def __len__(self): 352 | return self.dataset_len 353 | 354 | 355 | def find_last_checkpoint( 356 | exp_dir, any_path: bool = False, all_checkpoints: bool = False 357 | ): 358 | if any_path: 359 | exts = [".pth", "_stats.jgz", "_opt.pth"] 360 | else: 361 | exts = [".pth"] 362 | 363 | for ext in exts: 364 | fls = sorted( 365 | glob.glob( 366 | os.path.join(glob.escape(exp_dir), "model_epoch_" + "[0-9]" * 8 + ext) 367 | ) 368 | ) 369 | if len(fls) > 0: 370 | break 371 | # pyre-fixme[61]: `fls` is undefined, or not always defined. 372 | if len(fls) == 0: 373 | fl = None 374 | else: 375 | if all_checkpoints: 376 | # pyre-fixme[61]: `fls` is undefined, or not always defined. 377 | fl = [f[0: -len(ext)] + ".pth" for f in fls] 378 | else: 379 | # pyre-fixme[61]: `ext` is undefined, or not always defined. 380 | fl = fls[-1][0: -len(ext)] + ".pth" 381 | 382 | return fl 383 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import tempfile 6 | 7 | 8 | def seed_all_random_engines(seed: int) -> None: 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | random.seed(seed) 12 | 13 | 14 | def calc_vos_simple(poses):  # simple visual-odometry targets: deltas between consecutive poses 15 | vos = [] 16 | for p in poses: 17 | pvos = [p[i+1].unsqueeze(0) - p[i].unsqueeze(0) for i in range(len(p)-1)] 18 | vos.append(torch.cat(pvos, dim=0)) 19 | vos = torch.stack(vos, dim=0) 20 | return vos --------------------------------------------------------------------------------
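The warmup-cosine schedulers in utils/train_util.py are stepped once per training iteration rather than once per epoch. The snippet below is a minimal usage sketch, not part of the repository; the toy model, optimizer settings, and epoch/iteration counts are illustrative placeholders.

```
import torch
from utils.train_util import WarmupCosineRestarts

# Toy model and optimizer, only to drive the scheduler (placeholders).
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01)

epochs = 10               # placeholder epoch count
iters_per_epoch = 100     # placeholder; normally len(train_loader)
scheduler = WarmupCosineRestarts(optimizer, T_0=epochs,
                                 iters_per_epoch=iters_per_epoch,
                                 warmup_ratio=0.1)

for epoch in range(epochs):
    for it in range(iters_per_epoch):
        # ... forward/backward pass elided ...
        optimizer.step()
        scheduler.step()  # stepped per iteration, not per epoch
```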
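Similarly, a small sketch of the pose-error helpers, again illustrative rather than part of the repository; it assumes PyTorch3D is installed, as the imports in utils/train_util.py require.

```
import torch
from utils.train_util import rotation_angle, translation_angle

B = 4
rot_gt = torch.eye(3).unsqueeze(0).repeat(B, 1, 1)    # ground-truth rotations (B, 3, 3)
rot_pred = torch.eye(3).unsqueeze(0).repeat(B, 1, 1)  # predicted rotations (B, 3, 3)
t_gt = torch.randn(B, 3)                              # ground-truth translations
t_pred = torch.randn(B, 3)                            # predicted translations

print(rotation_angle(rot_gt, rot_pred))    # relative rotation error in degrees, shape (B,)
print(translation_angle(t_gt, t_pred))     # angle between translation directions, in degrees
```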