├── .gitignore ├── LICENSE ├── README.md ├── config ├── config_eval_json.py ├── config_vis.py └── nusc │ ├── baseline │ ├── baseline_120m.py │ ├── baseline_120m_cam.py │ ├── baseline_240m.py │ ├── baseline_240m_cam.py │ ├── baseline_60m.py │ └── baseline_60m_cam.py │ ├── hd_prior │ ├── hd_120m.py │ ├── hd_120m_cam.py │ ├── hd_240m.py │ ├── hd_240m_cam.py │ ├── hd_60m.py │ └── hd_60m_cam.py │ ├── hd_prior_pretrain │ ├── hd_pretrain_120m.py │ ├── hd_pretrain_240m.py │ └── hd_pretrain_60m.py │ └── sd_prior │ ├── sd_120m.py │ ├── sd_120m_cam.py │ ├── sd_240m.py │ ├── sd_240m_cam.py │ ├── sd_60m.py │ └── sd_60m_cam.py ├── data_osm ├── __init__.py ├── av2_dataset.py ├── av2map_extractor.py ├── const.py ├── dataset.py ├── geo_opensfm.py ├── image.py ├── lidar.py ├── osm │ ├── boston-seaport.cpg │ ├── boston-seaport.dbf │ ├── boston-seaport.prj │ ├── boston-seaport.shp │ ├── boston-seaport.shx │ ├── sd_map_data_ATX.pkl │ ├── sd_map_data_DTW.pkl │ ├── sd_map_data_MIA.pkl │ ├── sd_map_data_PAO.pkl │ ├── sd_map_data_PIT.pkl │ ├── sd_map_data_WDC.pkl │ ├── singapore-hollandvillage.cpg │ ├── singapore-hollandvillage.dbf │ ├── singapore-hollandvillage.prj │ ├── singapore-hollandvillage.shp │ ├── singapore-hollandvillage.shx │ ├── singapore-onenorth.cpg │ ├── singapore-onenorth.dbf │ ├── singapore-onenorth.prj │ ├── singapore-onenorth.shp │ ├── singapore-onenorth.shx │ ├── singapore-queenstown.cpg │ ├── singapore-queenstown.dbf │ ├── singapore-queenstown.prj │ ├── singapore-queenstown.shp │ └── singapore-queenstown.shx ├── pipelines │ ├── __init__.py │ ├── formating.py │ ├── loading.py │ ├── transform.py │ └── vectorize.py ├── rasterize.py ├── utils.py └── vector_map.py ├── docs ├── getting_started.md ├── installation.md └── visualization.md ├── environment.yml ├── figs └── teaser.jpg ├── icon ├── car.png └── car_gray.png ├── model ├── __init__.py ├── hdmapnet.py ├── lift_splat.py ├── pmapnet_hd.py ├── pmapnet_sd.py └── utils │ ├── VPN.py │ ├── __init__.py │ ├── base.py │ ├── homography.py │ ├── map_mae_head.py │ ├── misc.py │ ├── pointpillar.py │ ├── position_encoding.py │ ├── sdmap_cross_attn.py │ ├── utils.py │ └── voxel.py ├── requirements.txt ├── tools ├── config.py ├── eval.py ├── evaluate_json.py ├── evaluation │ ├── AP.py │ ├── __init__.py │ ├── angle_diff.py │ ├── chamfer_distance.py │ ├── dataset.py │ ├── iou.py │ └── modules │ │ ├── lpips.py │ │ ├── networks.py │ │ └── utils.py ├── export_json.py ├── loss.py ├── postprocess │ ├── __init__.py │ ├── cluster.py │ ├── connect.py │ └── vectorize.py ├── vis_map.py ├── vis_video_av2.py └── vis_video_nus.py ├── train.py └── train_HDPrior_pretrain.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.ipynb 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *.cover 48 | .hypothesis/ 49 | .pytest_cache/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | db.sqlite3 59 | 60 | # Flask stuff: 61 | instance/ 62 | .webassets-cache 63 | 64 | # Scrapy stuff: 65 | .scrapy 66 | 67 | # Sphinx documentation 68 | docs/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | # cython generated cpp 108 | data 109 | .vscode 110 | .idea 111 | 112 | # custom 113 | # *.pkl 114 | *.gif 115 | *.pkl.json 116 | *.log.json 117 | work_dirs/ 118 | debug_img/ 119 | model_file/ 120 | exps/ 121 | *~ 122 | mmdet3d/.mim 123 | mmdetection3d 124 | # Pytorch 125 | *.pth 126 | 127 | # demo 128 | demo/ 129 | *.obj 130 | *.ply 131 | *.zip 132 | *.tar 133 | *.tar.gz 134 | *.json 135 | 136 | # datasets 137 | /datasets 138 | /data_ann 139 | 140 | # softlinks 141 | av2 142 | nuScenes 143 | dataset 144 | 145 | Work_dir -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | P-MapNet: Far-seeing Map Generator Enhanced by both SDMap and HDMap Priors
3 | 4 | [[RAL](https://ieeexplore.ieee.org/document/10643284)] [[Paper](https://arxiv.org/pdf/2403.10521.pdf)] [[Project Page](https://jike5.github.io/P-MapNet/)] 5 | 6 |
7 | 8 | ![visualization](figs/teaser.jpg) 9 | **Abstract:** 10 | Autonomous vehicles are gradually entering city roads today, with the help of high-definition maps (HDMaps). However, the reliance on HDMaps prevents autonomous vehicles from stepping into regions without this expensive digital infrastructure. This fact drives many researchers to study online HDMap construction algorithms, but the performance of these algorithms at far regions is still unsatisfying. We present P-MapNet, in which the letter P highlights the fact that we focus on incorporating map priors to improve model performance. Specifically, we exploit priors in both SDMap and HDMap. On one hand, we extract weakly aligned SDMap from OpenStreetMap, and encode it as an additional conditioning branch. Despite the misalignment challenge, our attention-based architecture adaptively attends to relevant SDMap skeletons and significantly improves performance. On the other hand, we exploit a masked autoencoder to capture the prior distribution of HDMap, which can serve as a refinement module to mitigate occlusions and artifacts. We benchmark on the nuScenes and Argoverse2 datasets. 11 | Through comprehensive experiments, we show that: (1) our SDMap prior can improve online map construction performance, using both rasterized (by up to +18.73 mIoU) and vectorized (by up to +8.50 mAP) output representations. (2) our HDMap prior can improve map perceptual metrics by up to 6.34%. (3) 12 | P-MapNet can be switched into different inference modes that covers different regions of the accuracy-efficiency trade-off landscape. (4) P-MapNet is a far-seeing solution that brings larger improvements on longer ranges. 13 | 14 | ## Model 15 | 16 | ### Results on nuScenes-val set 17 | We provide results on nuScenes-val set. 18 | 19 | | Range | Method | M | Div. | Ped. | Bound. | mIoU | Model | Config | 20 | |:-----------:|:--------:|:---:|:---:|:---:|:-----:|:--------:|:--------:|:--------:| 21 | | 60 × 30 | HDMapNet | L+C | 45.9 | 30.5 | 56.8 | 44.40 | [ckpt](https://drive.google.com/file/d/1yYCRk_as7Vhvi_rL5BxqVrmEf_u7mB3b/view?usp=drive_link) | [cfg](config/nusc/baseline/baseline_60m.py) | 22 | | 60 × 30 | P-MapNet(SD+HD Prio.) | L+C | **54.2** | **41.3** | **63.7** | **53.07** | [ckpt](https://drive.google.com/file/d/1hr9QNRDOWmiqZcW2L5WY_o_0aIZFIo0W/view?usp=drive_link) | [cfg](config/nusc/hd_prior/hd_60m.py) | 23 | | 120 × 60 | HDMapNet | L+C | 53.6 | 37.8 | 57.1 | 49.50 | [ckpt](https://drive.google.com/file/d/1L_3whc53FmEdGh8Fn1EVS7xquX0_xHZJ/view?usp=drive_link) | [cfg](config/nusc/baseline/baseline_120m.py) | 24 | | 120 × 60 | P-MapNet(SD+HD Prio.) | L+C | **65.3** | **52.0** | **68.0** | **61.77** | [ckpt](https://drive.google.com/file/d/1MG10vfqFDnf4sYiDqdO2274LlQB670ne/view?usp=drive_link) | [cfg](config/nusc/hd_prior/hd_120m.py) | 25 | | 240 × 60 | HDMapNet | L+C | 40.0 | 26.8 | 42.6 | 36.47 | [ckpt](https://drive.google.com/file/d/1oKjYPXVxu0MwDzrOJ97r-0b2GBnKxK12/view?usp=drive_link) | [cfg](config/nusc/baseline/baseline_240m.py) | 26 | | 240 × 60 | P-MapNet(SD+HD Prio.) | L+C | **53.0** | **42.6** | **54.2** | **49.93** | [ckpt](https://drive.google.com/file/d/1lcA9U9oWKYM9X20gblBaG16I2DBLt2yU/view?usp=drive_link) | [cfg](config/nusc/hd_prior/hd_240m.py) | 27 | 28 | > The model weights under **other settings** can be downloaded at [GoogleDrive](https://drive.google.com/drive/folders/1P6LuhsHy3yy4sGwlDCGT9tjVzYpcaqEb?usp=drive_link) or [百度云](https://pan.baidu.com/s/1OVI3aWgOGGg6_iGCs_gxDg?pwd=65aa). 
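For readers skimming the code, the snippet below is a minimal, self-contained sketch of the SDMap-prior conditioning idea described in the abstract: BEV features from the onboard-sensor branch act as attention queries over rasterized SDMap (OpenStreetMap skeleton) features, so each BEV location can pick out relevant SDMap structure despite the weak alignment. It is illustrative only; the class and argument names are invented for this example and do not reproduce the actual module in `model/utils/sdmap_cross_attn.py`.

```python
import torch
import torch.nn as nn


class SDMapCrossAttentionSketch(nn.Module):
    """Toy example: BEV queries attend to rasterized SDMap keys/values."""

    def __init__(self, dim=256, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, bev_feat, sd_feat):
        # bev_feat: (B, C, H, W) features from cameras and/or LiDAR
        # sd_feat:  (B, C, h, w) features of the rasterized, weakly aligned SDMap
        B, C, H, W = bev_feat.shape
        queries = bev_feat.flatten(2).transpose(1, 2)    # (B, H*W, C)
        sd_tokens = sd_feat.flatten(2).transpose(1, 2)   # (B, h*w, C)
        attended, _ = self.attn(queries, sd_tokens, sd_tokens)
        fused = self.norm(queries + attended)            # residual + norm
        return fused.transpose(1, 2).reshape(B, C, H, W)


# Example usage on downsampled BEV features (full-resolution attention would be costly).
layer = SDMapCrossAttentionSketch(dim=256)
bev = torch.randn(2, 256, 25, 50)
sd = torch.randn(2, 256, 25, 50)
out = layer(bev, sd)  # (2, 256, 25, 50)
```

The HDMap prior described in the abstract is applied analogously as a refinement step: a masked-autoencoder-style module (see `model/utils/map_mae_head.py`, `train_HDPrior_pretrain.py`, and the `mask_flag`/`mask_ratio`/`patch_h`/`patch_w` entries in the configs) is pretrained on masked maps and then used to clean up occlusions and artifacts in the initial prediction.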
29 | 30 | ## Getting Started 31 | - [Installation](docs/installation.md) 32 | - [Train and Eval](docs/getting_started.md) 33 | - [visualization](docs/visualization.md) 34 | 35 | 36 | 37 | ### TODO 38 | - [ ] Add Argoverse2 dataset model 39 | 40 | ### Citation 41 | If you found this paper or codebase useful, please cite our paper: 42 | ``` 43 | @ARTICLE{10643284, 44 | author={Jiang, Zhou and Zhu, Zhenxin and Li, Pengfei and Gao, Huan-ang and Yuan, Tianyuan and Shi, Yongliang and Zhao, Hang and Zhao, Hao}, 45 | journal={IEEE Robotics and Automation Letters}, 46 | title={P-MapNet: Far-Seeing Map Generator Enhanced by Both SDMap and HDMap Priors}, 47 | year={2024}, 48 | volume={9}, 49 | number={10}, 50 | pages={8539-8546}, 51 | keywords={Feature extraction;Skeleton;Laser radar;Generators;Encoding;Point cloud compression;Autonomous vehicles;Computer vision for transportation;semantic scene understanding;intelligent transportation systems}, 52 | doi={10.1109/LRA.2024.3447450}} 53 | 54 | ``` 55 | -------------------------------------------------------------------------------- /config/config_eval_json.py: -------------------------------------------------------------------------------- 1 | result_path = './120_sd.json' 2 | dataroot = './dataset/nuScenes' 3 | version= 'v1.0-trainval' #'v1.0-mini' 4 | 5 | CD_threshold = 5 6 | threshold_iou = 0.1 7 | xbound = [-60.0, 60.0, 0.3] 8 | ybound = [-30.0, 30.0, 0.3] 9 | batch_size = 4 10 | eval_set = 'val' #'train', 'val', 'test', 'mini_train', 'mini_val' 11 | thickness = 5 12 | max_channel = 3 13 | bidirectional = False 14 | -------------------------------------------------------------------------------- /config/config_vis.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' #'v1.0-mini' 5 | 6 | xbound = [-60.0, 60.0, 0.3] #120m*60m, bev_size:400*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # Path 14 | vis_path = './vis_map' 15 | sd_map_path='./data_osm/osm' 16 | 17 | # CHECK_POINTS 18 | modelf = 'ckpt/fusion_120_sd_model23.pt' 19 | 20 | # Model 21 | model = 'pmapnet_sd' 22 | 23 | # Morphological_process mode in the vectorized post-process 24 | morpho_mode='MORPH_CLOSE' # 'MORPH_OPEN', 'None' 25 | 26 | batch_size = 1 27 | nworkers = 20 28 | gpus = [0] 29 | 30 | 31 | direction_pred = True 32 | instance_seg = True 33 | embedding_dim = 16 34 | delta_v = 0.5 35 | delta_d = 3.0 36 | angle_class = 36 37 | 38 | # Mask config 39 | mask_flag = False 40 | mask_ratio = -1 # random ratio 41 | patch_h = 20 42 | patch_w = 20 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /config/nusc/baseline/baseline_120m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-60.0, 60.0, 0.3] #120m*60m, bev_size:400*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/baseline_120' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'HDMapNet_fusion' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm 
= 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/baseline/baseline_120m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-60.0, 60.0, 0.3] #120m*60m, bev_size:400*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/baseline_cam_120' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'HDMapNet_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/baseline/baseline_240m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-120.0, 120.0, 0.3] #240m*60m, bev_size:800*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/baseline_240' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'HDMapNet_fusion' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/baseline/baseline_240m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-120.0, 120.0, 0.3] #240m*60m, bev_size:800*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 
12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/baseline_cam_240' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'HDMapNet_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/baseline/baseline_60m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-30, 30.0, 0.15] # 60m*30m, bev_size:400*200 7 | ybound = [-15.0, 15.0, 0.15] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/baseline_60' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'HDMapNet_fusion' 18 | nepochs = 30 19 | batch_size = 16 20 | nworkers = 10 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/baseline/baseline_60m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-30, 30.0, 0.15] # 60m*30m, bev_size:400*200 7 | ybound = [-15.0, 15.0, 0.15] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/baseline_cam_60' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'HDMapNet_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # NETWORK 34 | use_aux = True 35 | griding_num = 200 36 | backbone = '18' 37 | 38 | # LOSS 39 | scale_seg = 1.0 40 | scale_var = 0.1 41 | scale_dist = 0.1 42 | scale_direction = 0.1 43 | 44 | direction_pred = True 45 | instance_seg = True 46 | embedding_dim = 16 47 | delta_v = 0.5 48 | delta_d = 3.0 49 | angle_class = 36 50 | 51 | # Mask config 52 | mask_flag = False 53 | mask_ratio = -1 # random ratio 54 | patch_h = 20 55 | patch_w = 20 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /config/nusc/hd_prior/hd_120m.py: 
-------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-60.0, 60.0, 0.3] #120m*60m, bev_size:400*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/hd_120' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_hd' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/hd_prior/hd_120m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-60.0, 60.0, 0.3] #120m*60m, bev_size:400*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/hd_cam_120' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_hd_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/hd_prior/hd_240m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-120.0, 120.0, 0.3] #240m*60m, bev_size:800*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/hd_240' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_hd' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 
46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/hd_prior/hd_240m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-120.0, 120.0, 0.3] #240m*60m, bev_size:800*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/hd_cam_240' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_hd_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/hd_prior/hd_60m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-30.0, 30.0, 0.15] # 60m*30m, bev_size:400*200 7 | ybound = [-15.0, 15.0, 0.15] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/hd_60' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_hd' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/hd_prior/hd_60m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-30.0, 30.0, 0.15] #60m*30m, bev_size:400*200 7 | ybound = [-15.0, 15.0, 0.15] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/hd_cam_60' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_hd_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | 
modelf = None 32 | 33 | # NETWORK 34 | use_aux = True 35 | griding_num = 200 36 | backbone = '18' 37 | 38 | # LOSS 39 | scale_seg = 1.0 40 | scale_var = 0.1 41 | scale_dist = 0.1 42 | scale_direction = 0.1 43 | 44 | direction_pred = True 45 | instance_seg = True 46 | embedding_dim = 16 47 | delta_v = 0.5 48 | delta_d = 3.0 49 | angle_class = 36 50 | 51 | # Mask config 52 | mask_flag = False 53 | mask_ratio = -1 # random ratio 54 | patch_h = 20 55 | patch_w = 20 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /config/nusc/hd_prior_pretrain/hd_pretrain_120m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-60.0, 60.0, 0.3] #120m*60m, bev_size:400*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/pretrain_120' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'hdmapnet_pretrain' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = True 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/hd_prior_pretrain/hd_pretrain_240m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-120.0, 120.0, 0.3] #240m*60m, bev_size:800*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/pretrain_240' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'hdmapnet_pretrain' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = True 48 | mask_ratio = 0.5 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/hd_prior_pretrain/hd_pretrain_60m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-30, 30.0, 0.15] # 60m*30m, bev_size:400*200 7 | ybound = [-15.0, 15.0, 0.15] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 
1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/pretrain_60' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'hdmapnet_pretrain' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 10 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = ".Work_dir/pretrain_60/model0.pt" 32 | vit_base = 'ckpt/mae_finetuned_vit_base.pth' # download link: https://dl.fbaipublicfiles.com/mae/finetune/mae_finetuned_vit_base.pth 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = True 48 | mask_ratio = 0.5 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/sd_prior/sd_120m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' #'v1.0-mini' 5 | 6 | xbound = [-60.0, 60.0, 0.3] #120m*60m, bev_size:400*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/sd_120' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_sd' 18 | nepochs = 30 19 | batch_size = 4 20 | nworkers = 20 21 | gpus = [0] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/sd_prior/sd_120m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-60.0, 60.0, 0.3] #120m*60m, bev_size:400*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/sd_cam_120' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_sd_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | 
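As an aside on how these configs are interpreted: the recurring comments such as `#120m*60m, bev_size:400*200` follow directly from `xbound`/`ybound`, since each axis range divided by its cell resolution gives the BEV grid size. The helper below is a hedged illustration of that arithmetic; the function name is made up for this example and is not a utility from the repository.

```python
def bev_grid_size(xbound, ybound):
    """Number of BEV cells along x and y: (axis range) / (cell resolution)."""
    nx = int(round((xbound[1] - xbound[0]) / xbound[2]))
    ny = int(round((ybound[1] - ybound[0]) / ybound[2]))
    return nx, ny


# 120 m x 60 m patch at 0.3 m resolution -> 400 x 200 grid
print(bev_grid_size([-60.0, 60.0, 0.3], [-30.0, 30.0, 0.3]))    # (400, 200)
# 240 m x 60 m patch at 0.3 m resolution -> 800 x 200 grid
print(bev_grid_size([-120.0, 120.0, 0.3], [-30.0, 30.0, 0.3]))  # (800, 200)
# 60 m x 30 m patch at 0.15 m resolution -> also a 400 x 200 grid
print(bev_grid_size([-30.0, 30.0, 0.15], [-15.0, 15.0, 0.15]))  # (400, 200)
```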
-------------------------------------------------------------------------------- /config/nusc/sd_prior/sd_240m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-120.0, 120.0, 0.3] #240m*60m, bev_size:800*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/sd_240' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_sd' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/sd_prior/sd_240m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-120.0, 120.0, 0.3] #240m*60m, bev_size:800*200 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/sd_cam_240' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_sd_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = None 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/nusc/sd_prior/sd_60m.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-30, 30.0, 0.15] # 60m*30m, bev_size:400*200 7 | ybound = [-15.0, 15.0, 0.15] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/sd_60' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_sd' 18 | nepochs = 30 19 | batch_size = 2 20 | nworkers = 20 21 | gpus = [0] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = "./Work_dir/sd_60m/model11--.pt" 32 | 33 | # LOSS 34 | scale_seg = 1.0 35 | scale_var = 0.1 36 | scale_dist = 0.1 37 | scale_direction = 0.1 38 | 39 | 
direction_pred = True 40 | instance_seg = True 41 | embedding_dim = 16 42 | delta_v = 0.5 43 | delta_d = 3.0 44 | angle_class = 36 45 | 46 | # Mask config 47 | mask_flag = False 48 | mask_ratio = -1 # random ratio 49 | patch_h = 20 50 | patch_w = 20 51 | 52 | # JSON 53 | result_path = './Work_dir/sd_60m/submission.json' 54 | max_channel = 3 55 | bidirectional = False 56 | CD_threshold = 5 57 | threshold_iou = 0.1 -------------------------------------------------------------------------------- /config/nusc/sd_prior/sd_60m_cam.py: -------------------------------------------------------------------------------- 1 | # DATA 2 | dataset='nuScenes' 3 | dataroot = './dataset/nuScenes' 4 | version= 'v1.0-trainval' 5 | 6 | xbound = [-60.0, 60.0, 0.3] 7 | ybound = [-30.0, 30.0, 0.3] 8 | 9 | zbound = [-10.0, 10.0, 20.0] 10 | dbound = [4.0, 45.0, 1.0] 11 | image_size = [128, 352] 12 | thickness = 5 13 | # EXP 14 | logdir = './Work_dir/sd_cam_60' 15 | sd_map_path='./data_osm/osm' 16 | # TRAIN 17 | model = 'pmapnet_sd_cam' 18 | nepochs = 30 19 | batch_size = 8 20 | nworkers = 20 21 | gpus = [0, 1, 2, 3] 22 | 23 | # OPT 24 | lr = 5e-4 25 | weight_decay = 1e-7 26 | max_grad_norm = 5.0 27 | pos_weight = 2.13 28 | steplr = 10 29 | 30 | # CHECK_POINTS 31 | modelf = "Work_dir/nus/sd_60m_cam/cam_60_sd_model15.pt" 32 | 33 | # NETWORK 34 | use_aux = True 35 | griding_num = 200 36 | backbone = '18' 37 | 38 | # LOSS 39 | scale_seg = 1.0 40 | scale_var = 0.1 41 | scale_dist = 0.1 42 | scale_direction = 0.1 43 | 44 | direction_pred = True 45 | instance_seg = True 46 | embedding_dim = 16 47 | delta_v = 0.5 48 | delta_d = 3.0 49 | angle_class = 36 50 | 51 | # Mask config 52 | mask_flag = False 53 | mask_ratio = -1 # random ratio 54 | patch_h = 20 55 | patch_w = 20 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /data_osm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/__init__.py -------------------------------------------------------------------------------- /data_osm/const.py: -------------------------------------------------------------------------------- 1 | MAP = ['boston-seaport', 'singapore-hollandvillage', 'singapore-onenorth', 'singapore-queenstown'] 2 | CAMS = ['CAM_FRONT_LEFT', 'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_LEFT', 'CAM_BACK', 'CAM_BACK_RIGHT'] 3 | CLASS2LABEL = { 4 | 'road_divider': 0, # 道路分隔线 5 | 'lane_divider': 0, # 车道分隔线 6 | 'ped_crossing': 1, # 人行道 7 | 'contours': 2, # 轮廓线 8 | 'others': -1 9 | } 10 | 11 | NUM_CLASSES = 3 12 | IMG_ORIGIN_H = 900 13 | IMG_ORIGIN_W = 1600 14 | -------------------------------------------------------------------------------- /data_osm/geo_opensfm.py: -------------------------------------------------------------------------------- 1 | """Copied from opensfm.geo to minimize hard dependencies.""" 2 | 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | from numpy import ndarray 7 | 8 | WGS84_a = 6378137.0 9 | WGS84_b = 6356752.314245 10 | 11 | 12 | def ecef_from_lla(lat, lon, alt: float) -> Tuple[float, ...]: 13 | """ 14 | Compute ECEF XYZ from latitude, longitude and altitude. 15 | 16 | All using the WGS84 model. 17 | Altitude is the distance to the WGS84 ellipsoid. 
18 | Check results here http://www.oc.nps.edu/oc2902w/coord/llhxyz.htm 19 | 20 | >>> lat, lon, alt = 10, 20, 30 21 | >>> x, y, z = ecef_from_lla(lat, lon, alt) 22 | >>> np.allclose(lla_from_ecef(x,y,z), [lat, lon, alt]) 23 | True 24 | """ 25 | a2 = WGS84_a**2 26 | b2 = WGS84_b**2 27 | lat = np.radians(lat) 28 | lon = np.radians(lon) 29 | L = 1.0 / np.sqrt(a2 * np.cos(lat) ** 2 + b2 * np.sin(lat) ** 2) 30 | x = (a2 * L + alt) * np.cos(lat) * np.cos(lon) 31 | y = (a2 * L + alt) * np.cos(lat) * np.sin(lon) 32 | z = (b2 * L + alt) * np.sin(lat) 33 | return x, y, z 34 | 35 | 36 | def lla_from_ecef(x, y, z): 37 | """ 38 | Compute latitude, longitude and altitude from ECEF XYZ. 39 | 40 | All using the WGS84 model. 41 | Altitude is the distance to the WGS84 ellipsoid. 42 | """ 43 | a = WGS84_a 44 | b = WGS84_b 45 | ea = np.sqrt((a**2 - b**2) / a**2) 46 | eb = np.sqrt((a**2 - b**2) / b**2) 47 | p = np.sqrt(x**2 + y**2) 48 | theta = np.arctan2(z * a, p * b) 49 | lon = np.arctan2(y, x) 50 | lat = np.arctan2( 51 | z + eb**2 * b * np.sin(theta) ** 3, p - ea**2 * a * np.cos(theta) ** 3 52 | ) 53 | N = a / np.sqrt(1 - ea**2 * np.sin(lat) ** 2) 54 | alt = p / np.cos(lat) - N 55 | return np.degrees(lat), np.degrees(lon), alt 56 | 57 | 58 | def ecef_from_topocentric_transform(lat, lon, alt: float) -> ndarray: 59 | """ 60 | Transformation from a topocentric frame at reference position to ECEF. 61 | 62 | The topocentric reference frame is a metric one with the origin 63 | at the given (lat, lon, alt) position, with the X axis heading east, 64 | the Y axis heading north and the Z axis vertical to the ellipsoid. 65 | >>> a = ecef_from_topocentric_transform(30, 20, 10) 66 | >>> b = ecef_from_topocentric_transform_finite_diff(30, 20, 10) 67 | >>> np.allclose(a, b) 68 | True 69 | """ 70 | x, y, z = ecef_from_lla(lat, lon, alt) 71 | sa = np.sin(np.radians(lat)) 72 | ca = np.cos(np.radians(lat)) 73 | so = np.sin(np.radians(lon)) 74 | co = np.cos(np.radians(lon)) 75 | return np.array( 76 | [ 77 | [-so, -sa * co, ca * co, x], 78 | [co, -sa * so, ca * so, y], 79 | [0, ca, sa, z], 80 | [0, 0, 0, 1], 81 | ] 82 | ) 83 | 84 | 85 | def ecef_from_topocentric_transform_finite_diff(lat, lon, alt: float) -> ndarray: 86 | """ 87 | Transformation from a topocentric frame at reference position to ECEF. 88 | 89 | The topocentric reference frame is a metric one with the origin 90 | at the given (lat, lon, alt) position, with the X axis heading east, 91 | the Y axis heading north and the Z axis vertical to the ellipsoid. 92 | """ 93 | eps = 1e-2 94 | x, y, z = ecef_from_lla(lat, lon, alt) 95 | v1 = ( 96 | ( 97 | np.array(ecef_from_lla(lat, lon + eps, alt)) 98 | - np.array(ecef_from_lla(lat, lon - eps, alt)) 99 | ) 100 | / 2 101 | / eps 102 | ) 103 | v2 = ( 104 | ( 105 | np.array(ecef_from_lla(lat + eps, lon, alt)) 106 | - np.array(ecef_from_lla(lat - eps, lon, alt)) 107 | ) 108 | / 2 109 | / eps 110 | ) 111 | v3 = ( 112 | ( 113 | np.array(ecef_from_lla(lat, lon, alt + eps)) 114 | - np.array(ecef_from_lla(lat, lon, alt - eps)) 115 | ) 116 | / 2 117 | / eps 118 | ) 119 | v1 /= np.linalg.norm(v1) 120 | v2 /= np.linalg.norm(v2) 121 | v3 /= np.linalg.norm(v3) 122 | return np.array( 123 | [ 124 | [v1[0], v2[0], v3[0], x], 125 | [v1[1], v2[1], v3[1], y], 126 | [v1[2], v2[2], v3[2], z], 127 | [0, 0, 0, 1], 128 | ] 129 | ) 130 | 131 | 132 | def topocentric_from_lla(lat, lon, alt: float, reflat, reflon, refalt: float): 133 | """ 134 | Transform from lat, lon, alt to topocentric XYZ. 
135 | 136 | >>> lat, lon, alt = -10, 20, 100 137 | >>> np.allclose(topocentric_from_lla(lat, lon, alt, lat, lon, alt), 138 | ... [0,0,0]) 139 | True 140 | >>> x, y, z = topocentric_from_lla(lat, lon, alt, 0, 0, 0) 141 | >>> np.allclose(lla_from_topocentric(x, y, z, 0, 0, 0), 142 | ... [lat, lon, alt]) 143 | True 144 | """ 145 | T = np.linalg.inv(ecef_from_topocentric_transform(reflat, reflon, refalt)) 146 | x, y, z = ecef_from_lla(lat, lon, alt) 147 | tx = T[0, 0] * x + T[0, 1] * y + T[0, 2] * z + T[0, 3] 148 | ty = T[1, 0] * x + T[1, 1] * y + T[1, 2] * z + T[1, 3] 149 | tz = T[2, 0] * x + T[2, 1] * y + T[2, 2] * z + T[2, 3] 150 | return tx, ty, tz 151 | 152 | 153 | def lla_from_topocentric(x, y, z, reflat, reflon, refalt: float): 154 | """ 155 | Transform from topocentric XYZ to lat, lon, alt. 156 | """ 157 | T = ecef_from_topocentric_transform(reflat, reflon, refalt) 158 | ex = T[0, 0] * x + T[0, 1] * y + T[0, 2] * z + T[0, 3] 159 | ey = T[1, 0] * x + T[1, 1] * y + T[1, 2] * z + T[1, 3] 160 | ez = T[2, 0] * x + T[2, 1] * y + T[2, 2] * z + T[2, 3] 161 | return lla_from_ecef(ex, ey, ez) 162 | 163 | 164 | class TopocentricConverter(object): 165 | """Convert to and from a topocentric reference frame.""" 166 | 167 | def __init__(self, reflat, reflon, refalt): 168 | """Init the converter given the reference origin.""" 169 | self.lat = reflat 170 | self.lon = reflon 171 | self.alt = refalt 172 | 173 | def to_topocentric(self, lat, lon, alt): 174 | """Convert lat, lon, alt to topocentric x, y, z.""" 175 | return topocentric_from_lla(lat, lon, alt, self.lat, self.lon, self.alt) 176 | 177 | def to_lla(self, x, y, z): 178 | """Convert topocentric x, y, z to lat, lon, alt.""" 179 | return lla_from_topocentric(x, y, z, self.lat, self.lon, self.alt) 180 | 181 | def __eq__(self, o): 182 | return np.allclose([self.lat, self.lon, self.alt], (o.lat, o.lon, o.alt)) -------------------------------------------------------------------------------- /data_osm/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | import torch 5 | import torchvision 6 | 7 | class NormalizeInverse(torchvision.transforms.Normalize): 8 | # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8 9 | def __init__(self, mean, std): 10 | mean = torch.as_tensor(mean) 11 | std = torch.as_tensor(std) 12 | std_inv = 1 / (std + 1e-7) 13 | mean_inv = -mean * std_inv 14 | super().__init__(mean=mean_inv, std=std_inv) 15 | 16 | def __call__(self, tensor): 17 | return super().__call__(tensor.clone()) 18 | 19 | 20 | normalize_img = torchvision.transforms.Compose(( 21 | torchvision.transforms.ToTensor(), 22 | torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 23 | std=[0.229, 0.224, 0.225]), 24 | )) 25 | 26 | normalize_tensor_img = torchvision.transforms.Compose(( 27 | # torchvision.transforms.ToTensor(), 28 | torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 29 | std=[0.229, 0.224, 0.225]), 30 | )) 31 | 32 | denormalize_img = torchvision.transforms.Compose(( 33 | NormalizeInverse(mean=[0.485, 0.456, 0.406], 34 | std=[0.229, 0.224, 0.225]), 35 | torchvision.transforms.ToPILImage(), 36 | )) 37 | 38 | 39 | def img_transform(img, resize, resize_dims): 40 | post_rot2 = torch.eye(2) 41 | post_tran2 = torch.zeros(2) 42 | 43 | img = img.resize(resize_dims) # resize到352*128 44 | 45 | rot_resize = torch.Tensor([[resize[0], 0], 46 | [0, resize[1]]]) 47 | post_rot2 = rot_resize @ post_rot2 48 | post_tran2 = rot_resize @ 
post_tran2 49 | 50 | post_tran = torch.zeros(3) 51 | post_rot = torch.eye(3) 52 | post_tran[:2] = post_tran2 53 | post_rot[:2, :2] = post_rot2 54 | return img, post_rot, post_tran 55 | 56 | 57 | def get_rot(h): 58 | return torch.Tensor([ 59 | [np.cos(h), np.sin(h)], 60 | [-np.sin(h), np.cos(h)], 61 | ]) 62 | 63 | # def img_transform(img, resize, resize_dims, crop, flip, rotate): 64 | # post_rot2 = torch.eye(2) 65 | # post_tran2 = torch.zeros(2) 66 | 67 | # # adjust image 68 | # img = img.resize(resize_dims) 69 | # img = img.crop(crop) 70 | # if flip: 71 | # img = img.transpose(method=Image.FLIP_LEFT_RIGHT) 72 | # img = img.rotate(rotate) 73 | 74 | # # post-homography transformation 75 | # post_rot2 *= resize 76 | # post_tran2 -= torch.Tensor(crop[:2]) 77 | # if flip: 78 | # A = torch.Tensor([[-1, 0], [0, 1]]) 79 | # b = torch.Tensor([crop[2] - crop[0], 0]) 80 | # post_rot2 = A.matmul(post_rot2) 81 | # post_tran2 = A.matmul(post_tran2) + b 82 | # A = get_rot(rotate/180*np.pi) 83 | # b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2 84 | # b = A.matmul(-b) + b 85 | # post_rot2 = A.matmul(post_rot2) 86 | # post_tran2 = A.matmul(post_tran2) + b 87 | 88 | # post_tran = torch.zeros(3) 89 | # post_rot = torch.eye(3) 90 | # post_tran[:2] = post_tran2 91 | # post_rot[:2, :2] = post_rot2 92 | # return img, post_rot, post_tran 93 | 94 | -------------------------------------------------------------------------------- /data_osm/lidar.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from functools import reduce 4 | 5 | from pyquaternion import Quaternion 6 | 7 | from nuscenes.utils.data_classes import LidarPointCloud 8 | from nuscenes.utils.geometry_utils import transform_matrix 9 | 10 | 11 | def get_lidar_data(nusc, sample_rec, nsweeps, min_distance): 12 | """ 13 | Returns at most nsweeps of lidar in the ego frame. 14 | Returned tensor is 5(x, y, z, reflectance, dt) x N 15 | Adapted from https://github.com/nutonomy/nuscenes-devkit/blob/master/python-sdk/nuscenes/utils/data_classes.py#L56 16 | """ 17 | points = np.zeros((5, 0)) 18 | 19 | # Get reference pose and timestamp. 20 | ref_sd_token = sample_rec['data']['LIDAR_TOP'] 21 | ref_sd_rec = nusc.get('sample_data', ref_sd_token) 22 | ref_pose_rec = nusc.get('ego_pose', ref_sd_rec['ego_pose_token']) 23 | ref_cs_rec = nusc.get('calibrated_sensor', ref_sd_rec['calibrated_sensor_token']) 24 | ref_time = 1e-6 * ref_sd_rec['timestamp'] 25 | 26 | # Homogeneous transformation matrix from global to _current_ ego car frame. 27 | car_from_global = transform_matrix(ref_pose_rec['translation'], Quaternion(ref_pose_rec['rotation']), 28 | inverse=True) 29 | 30 | # Aggregate current and previous sweeps. 31 | sample_data_token = sample_rec['data']['LIDAR_TOP'] 32 | current_sd_rec = nusc.get('sample_data', sample_data_token) 33 | for _ in range(nsweeps): 34 | # Load up the pointcloud and remove points close to the sensor. 35 | current_pc = LidarPointCloud.from_file(os.path.join(nusc.dataroot, current_sd_rec['filename'])) 36 | current_pc.remove_close(min_distance) 37 | 38 | # Get past pose. 39 | current_pose_rec = nusc.get('ego_pose', current_sd_rec['ego_pose_token']) 40 | global_from_car = transform_matrix(current_pose_rec['translation'], 41 | Quaternion(current_pose_rec['rotation']), inverse=False) 42 | 43 | # Homogeneous transformation matrix from sensor coordinate frame to ego car frame. 
44 | current_cs_rec = nusc.get('calibrated_sensor', current_sd_rec['calibrated_sensor_token']) 45 | car_from_current = transform_matrix(current_cs_rec['translation'], Quaternion(current_cs_rec['rotation']), 46 | inverse=False) 47 | 48 | # Fuse four transformation matrices into one and perform transform. 49 | trans_matrix = reduce(np.dot, [car_from_global, global_from_car, car_from_current]) 50 | current_pc.transform(trans_matrix) 51 | 52 | # Add time vector which can be used as a temporal feature. 53 | time_lag = ref_time - 1e-6 * current_sd_rec['timestamp'] 54 | times = time_lag * np.ones((1, current_pc.nbr_points())) 55 | 56 | new_points = np.concatenate((current_pc.points, times), 0) 57 | points = np.concatenate((points, new_points), 1) 58 | 59 | # Abort if there are no previous sweeps. 60 | if current_sd_rec['prev'] == '': 61 | break 62 | else: 63 | current_sd_rec = nusc.get('sample_data', current_sd_rec['prev']) 64 | 65 | return points -------------------------------------------------------------------------------- /data_osm/osm/boston-seaport.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 2 | -------------------------------------------------------------------------------- /data_osm/osm/boston-seaport.dbf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/boston-seaport.dbf -------------------------------------------------------------------------------- /data_osm/osm/boston-seaport.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]] 2 | -------------------------------------------------------------------------------- /data_osm/osm/boston-seaport.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/boston-seaport.shp -------------------------------------------------------------------------------- /data_osm/osm/boston-seaport.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/boston-seaport.shx -------------------------------------------------------------------------------- /data_osm/osm/sd_map_data_ATX.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/sd_map_data_ATX.pkl -------------------------------------------------------------------------------- /data_osm/osm/sd_map_data_DTW.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/sd_map_data_DTW.pkl -------------------------------------------------------------------------------- /data_osm/osm/sd_map_data_MIA.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/sd_map_data_MIA.pkl -------------------------------------------------------------------------------- /data_osm/osm/sd_map_data_PAO.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/sd_map_data_PAO.pkl -------------------------------------------------------------------------------- /data_osm/osm/sd_map_data_PIT.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/sd_map_data_PIT.pkl -------------------------------------------------------------------------------- /data_osm/osm/sd_map_data_WDC.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/sd_map_data_WDC.pkl -------------------------------------------------------------------------------- /data_osm/osm/singapore-hollandvillage.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 2 | -------------------------------------------------------------------------------- /data_osm/osm/singapore-hollandvillage.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]] 2 | -------------------------------------------------------------------------------- /data_osm/osm/singapore-hollandvillage.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/singapore-hollandvillage.shp -------------------------------------------------------------------------------- /data_osm/osm/singapore-hollandvillage.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/singapore-hollandvillage.shx -------------------------------------------------------------------------------- /data_osm/osm/singapore-onenorth.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 2 | -------------------------------------------------------------------------------- /data_osm/osm/singapore-onenorth.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]] 2 | -------------------------------------------------------------------------------- /data_osm/osm/singapore-onenorth.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/singapore-onenorth.shp -------------------------------------------------------------------------------- /data_osm/osm/singapore-onenorth.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/singapore-onenorth.shx -------------------------------------------------------------------------------- /data_osm/osm/singapore-queenstown.cpg: -------------------------------------------------------------------------------- 1 | UTF-8 2 | 
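The binary shapefile components above (`.shp`/`.shx`/`.dbf`, with `.prj`/`.cpg` metadata) and the `sd_map_data_*.pkl` files hold the OpenStreetMap-derived road skeletons used as the SDMap prior; the city codes (ATX, DTW, MIA, PAO, PIT, WDC) appear to correspond to the Argoverse 2 locations. A quick way to inspect them is sketched below, assuming `geopandas` is available in the environment; this is for exploration only and is not part of the repository's loading pipeline.

```python
import pickle

import geopandas as gpd

# Road-skeleton geometries for one nuScenes location, in WGS84 lat/lon (see the .prj files).
roads = gpd.read_file("data_osm/osm/boston-seaport.shp")
print(roads.head())            # attribute table
print(roads.geometry.iloc[0])  # first road polyline

# The per-city pickle files store SD-map data directly instead of shapefiles.
with open("data_osm/osm/sd_map_data_PIT.pkl", "rb") as f:
    sd_map = pickle.load(f)
print(type(sd_map))
```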
-------------------------------------------------------------------------------- /data_osm/osm/singapore-queenstown.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]] 2 | -------------------------------------------------------------------------------- /data_osm/osm/singapore-queenstown.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/singapore-queenstown.shp -------------------------------------------------------------------------------- /data_osm/osm/singapore-queenstown.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/data_osm/osm/singapore-queenstown.shx -------------------------------------------------------------------------------- /data_osm/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .loading import LoadMultiViewImagesFromFiles 2 | from .formating import FormatBundleMap 3 | from .transform import ResizeMultiViewImages, PadMultiViewImages, Normalize3D, PhotoMetricDistortionMultiViewImage 4 | from .vectorize import VectorizeMap 5 | -------------------------------------------------------------------------------- /data_osm/pipelines/formating.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from mmcv.parallel import DataContainer as DC 3 | import torch 4 | import mmcv 5 | from collections.abc import Sequence 6 | # from mmdet3d.core.points import BasePoints 7 | # from mmdet.datasets.pipelines import to_tensor 8 | 9 | # copy from mmdet:https://mmdetection.readthedocs.io/en/v2.0.0/_modules/mmdet/datasets/pipelines/formating.html 10 | def to_tensor(data): 11 | """Convert objects of various python types to :obj:`torch.Tensor`. 12 | 13 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, 14 | :class:`Sequence`, :class:`int` and :class:`float`. 15 | """ 16 | if isinstance(data, torch.Tensor): 17 | return data 18 | elif isinstance(data, np.ndarray): 19 | return torch.from_numpy(data) 20 | elif isinstance(data, Sequence) and not mmcv.is_str(data): 21 | return torch.tensor(data) 22 | elif isinstance(data, int): 23 | return torch.LongTensor([data]) 24 | elif isinstance(data, float): 25 | return torch.FloatTensor([data]) 26 | else: 27 | raise TypeError(f'type {type(data)} cannot be converted to tensor.') 28 | 29 | class FormatBundleMap(object): 30 | """Format data for map tasks and then collect data for model input. 31 | 32 | These fields are formatted as follows. 33 | 34 | - img: (1) transpose, (2) to tensor, (3) to DataContainer (stack=True) 35 | - semantic_mask (if exists): (1) to tensor, (2) to DataContainer (stack=True) 36 | - vectors (if exists): (1) to DataContainer (cpu_only=True) 37 | - img_metas: (1) to DataContainer (cpu_only=True) 38 | """ 39 | 40 | def __init__(self, process_img=True, 41 | keys=['img', 'semantic_mask', 'vectors'], 42 | meta_keys=['intrinsics', 'extrinsics']): 43 | 44 | self.process_img = process_img 45 | self.keys = keys 46 | self.meta_keys = meta_keys 47 | 48 | def __call__(self, results): 49 | """Call function to transform and format common fields in results. 
50 | 51 | Args: 52 | results (dict): Result dict contains the data to convert. 53 | 54 | Returns: 55 | dict: The result dict contains the data that is formatted with 56 | default bundle. 57 | """ 58 | # Format 3D data 59 | if 'points' in results: 60 | assert isinstance(results['points'], BasePoints) 61 | results['points'] = DC(results['points'].tensor) 62 | 63 | for key in ['voxels', 'coors', 'voxel_centers', 'num_points']: 64 | if key not in results: 65 | continue 66 | results[key] = DC(to_tensor(results[key]), stack=False) 67 | 68 | if 'img' in results and self.process_img: 69 | if isinstance(results['img'], list): 70 | # process multiple imgs in single frame 71 | imgs = [img.transpose(2, 0, 1) for img in results['img']] 72 | imgs = np.ascontiguousarray(np.stack(imgs, axis=0)) 73 | results['img'] = DC(to_tensor(imgs), stack=True) 74 | else: 75 | img = np.ascontiguousarray(results['img'].transpose(2, 0, 1)) 76 | results['img'] = DC(to_tensor(img), stack=True) 77 | 78 | if 'semantic_mask' in results: 79 | results['semantic_mask'] = DC(to_tensor(results['semantic_mask']), stack=True) 80 | 81 | if 'vectors' in results: 82 | # vectors may have different sizes 83 | vectors = results['vectors'] 84 | results['vectors'] = DC(vectors, stack=False, cpu_only=True) 85 | 86 | if 'polys' in results: 87 | results['polys'] = DC(results['polys'], stack=False, cpu_only=True) 88 | 89 | return results 90 | 91 | def __repr__(self): 92 | """str: Return a string that describes the module.""" 93 | repr_str = self.__class__.__name__ 94 | repr_str += f'(process_img={self.process_img}, ' 95 | return repr_str 96 | -------------------------------------------------------------------------------- /data_osm/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | 4 | class LoadMultiViewImagesFromFiles(object): 5 | """Load multi channel images from a list of separate channel files. 6 | 7 | Expects results['img_filename'] to be a list of filenames. 8 | 9 | Args: 10 | to_float32 (bool): Whether to convert the img to float32. 11 | Defaults to False. 12 | color_type (str): Color type of the file. Defaults to 'unchanged'. 13 | """ 14 | 15 | def __init__(self, to_float32=False, color_type='unchanged'): 16 | self.to_float32 = to_float32 17 | self.color_type = color_type 18 | 19 | def __call__(self, results): 20 | """Call function to load multi-view image from files. 21 | 22 | Args: 23 | results (dict): Result dict containing multi-view image filenames. 24 | 25 | Returns: 26 | dict: The result dict containing the multi-view image data. \ 27 | Added keys and values are described below. 28 | 29 | - filename (str): Multi-view image filenames. 30 | - img (np.ndarray): Multi-view image arrays. 31 | - img_shape (tuple[int]): Shape of multi-view image arrays. 32 | - ori_shape (tuple[int]): Shape of original image arrays. 33 | - pad_shape (tuple[int]): Shape of padded image arrays. 34 | - scale_factor (float): Scale factor. 35 | - img_norm_cfg (dict): Normalization configuration of images. 
36 | """ 37 | filename = results['img_filenames'] 38 | img = [mmcv.imread(name, self.color_type) for name in filename] 39 | if self.to_float32: 40 | img = [i.astype(np.float32) for i in img] 41 | results['img'] = img 42 | results['img_shape'] = [i.shape for i in img] 43 | results['ori_shape'] = [i.shape for i in img] 44 | # Set initial values for default meta_keys 45 | results['pad_shape'] = [i.shape for i in img] 46 | # results['scale_factor'] = 1.0 47 | num_channels = 1 if len(img[0].shape) < 3 else img[0].shape[2] 48 | results['img_norm_cfg'] = dict( 49 | mean=np.zeros(num_channels, dtype=np.float32), 50 | std=np.ones(num_channels, dtype=np.float32), 51 | to_rgb=False) 52 | results['img_fields'] = ['img'] 53 | return results 54 | 55 | def __repr__(self): 56 | """str: Return a string that describes the module.""" 57 | return f'{self.__class__.__name__} (to_float32={self.to_float32}, '\ 58 | f"color_type='{self.color_type}')" 59 | -------------------------------------------------------------------------------- /data_osm/pipelines/vectorize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from shapely.geometry import LineString 3 | from numpy.typing import NDArray 4 | from typing import List, Tuple, Union, Dict 5 | from IPython import embed 6 | 7 | class VectorizeMap(object): 8 | """Generate vectoized map and put into `semantic_mask` key. 9 | Concretely, shapely geometry objects are converted into sample points (ndarray). 10 | We use args `sample_num`, `sample_dist`, `simplify` to specify sampling method. 11 | 12 | Args: 13 | roi_size (tuple or list): bev range . 14 | normalize (bool): whether to normalize points to range (0, 1). 15 | coords_dim (int): dimension of point coordinates. 16 | simplify (bool): whether to use simpily function. If true, `sample_num` \ 17 | and `sample_dist` will be ignored. 18 | sample_num (int): number of points to interpolate from a polyline. Set to -1 to ignore. 19 | sample_dist (float): interpolate distance. Set to -1 to ignore. 20 | """ 21 | 22 | def __init__(self, 23 | roi_size: Union[Tuple, List], 24 | normalize: bool, 25 | coords_dim: int, 26 | simplify: bool=False, 27 | sample_num: int=-1, 28 | sample_dist: float=-1, 29 | permute: bool=False 30 | ): 31 | self.coords_dim = coords_dim 32 | self.sample_num = sample_num 33 | self.sample_dist = sample_dist 34 | self.roi_size = np.array(roi_size) 35 | self.normalize = normalize 36 | self.simplify = simplify 37 | self.permute = permute 38 | 39 | if sample_dist > 0: 40 | assert sample_num < 0 and not simplify 41 | self.sample_fn = self.interp_fixed_dist 42 | elif sample_num > 0: 43 | assert sample_dist < 0 and not simplify 44 | self.sample_fn = self.interp_fixed_num 45 | else: 46 | assert simplify 47 | 48 | def interp_fixed_num(self, line: LineString) -> NDArray: 49 | ''' Interpolate a line to fixed number of points. 50 | 51 | Args: 52 | line (LineString): line 53 | 54 | Returns: 55 | points (array): interpolated points, shape (N, 2) 56 | ''' 57 | 58 | distances = np.linspace(0, line.length, self.sample_num) 59 | sampled_points = np.array([list(line.interpolate(distance).coords) 60 | for distance in distances]).squeeze() 61 | 62 | return sampled_points 63 | 64 | def interp_fixed_dist(self, line: LineString) -> NDArray: 65 | ''' Interpolate a line at fixed interval. 
66 | 67 | Args: 68 | line (LineString): line 69 | 70 | Returns: 71 | points (array): interpolated points, shape (N, 2) 72 | ''' 73 | 74 | distances = list(np.arange(self.sample_dist, line.length, self.sample_dist)) 75 | # make sure to sample at least two points when sample_dist > line.length 76 | distances = [0,] + distances + [line.length,] 77 | 78 | sampled_points = np.array([list(line.interpolate(distance).coords) 79 | for distance in distances]).squeeze() 80 | 81 | return sampled_points 82 | 83 | def get_vectorized_lines(self, map_geoms: Dict) -> Dict: 84 | ''' Vectorize map elements. Iterate over the input dict and apply the 85 | specified sample funcion. 86 | 87 | Args: 88 | line (LineString): line 89 | 90 | Returns: 91 | vectors (array): dict of vectorized map elements. 92 | ''' 93 | 94 | vectors = {} 95 | for label, geom_list in map_geoms.items(): 96 | vectors[label] = [] 97 | for geom in geom_list: 98 | if geom.geom_type == 'LineString': 99 | if self.simplify: 100 | line = geom.simplify(0.2, preserve_topology=True) 101 | line = np.array(line.coords) 102 | else: 103 | line = self.sample_fn(geom) 104 | line = line[:, :self.coords_dim] 105 | 106 | if self.normalize: 107 | line = self.normalize_line(line) 108 | if self.permute: 109 | line = self.permute_line(line) 110 | vectors[label].append(line) 111 | 112 | elif geom.geom_type == 'Polygon': 113 | # polygon objects will not be vectorized 114 | continue 115 | 116 | else: 117 | raise ValueError('map geoms must be either LineString or Polygon!') 118 | return vectors 119 | 120 | def normalize_line(self, line: NDArray) -> NDArray: 121 | ''' Convert points to range (0, 1). 122 | 123 | Args: 124 | line (LineString): line 125 | 126 | Returns: 127 | normalized (array): normalized points. 128 | ''' 129 | 130 | origin = -np.array([self.roi_size[0]/2, self.roi_size[1]/2]) 131 | 132 | line[:, :2] = line[:, :2] - origin 133 | 134 | # transform from range [0, 1] to (0, 1) 135 | eps = 1e-5 136 | line[:, :2] = line[:, :2] / (self.roi_size + eps) 137 | 138 | return line 139 | 140 | def permute_line(self, line: np.ndarray, padding=1e5): 141 | ''' 142 | (num_pts, 2) -> (num_permute, num_pts, 2) 143 | where num_permute = 2 * (num_pts - 1) 144 | ''' 145 | is_closed = np.allclose(line[0], line[-1], atol=1e-3) 146 | num_points = len(line) 147 | permute_num = num_points - 1 148 | permute_lines_list = [] 149 | if is_closed: 150 | pts_to_permute = line[:-1, :] # throw away replicate start end pts 151 | for shift_i in range(permute_num): 152 | permute_lines_list.append(np.roll(pts_to_permute, shift_i, axis=0)) 153 | flip_pts_to_permute = np.flip(pts_to_permute, axis=0) 154 | for shift_i in range(permute_num): 155 | permute_lines_list.append(np.roll(flip_pts_to_permute, shift_i, axis=0)) 156 | else: 157 | permute_lines_list.append(line) 158 | permute_lines_list.append(np.flip(line, axis=0)) 159 | 160 | permute_lines_array = np.stack(permute_lines_list, axis=0) 161 | 162 | if is_closed: 163 | tmp = np.zeros((permute_num * 2, num_points, self.coords_dim)) 164 | tmp[:, :-1, :] = permute_lines_array 165 | tmp[:, -1, :] = permute_lines_array[:, 0, :] # add replicate start end pts 166 | permute_lines_array = tmp 167 | 168 | else: 169 | # padding 170 | padding = np.full([permute_num * 2 - 2, num_points, self.coords_dim], padding) 171 | permute_lines_array = np.concatenate((permute_lines_array, padding), axis=0) 172 | 173 | return permute_lines_array 174 | 175 | def __call__(self, input_dict): 176 | map_geoms = input_dict['map_geoms'] 177 | sd_map_data = 
input_dict.get('sd_vectors', None) 178 | input_dict['vectors'] = self.get_vectorized_lines(map_geoms) 179 | input_dict['sd_vectors'] = sd_map_data 180 | return input_dict 181 | 182 | def __repr__(self): 183 | repr_str = self.__class__.__name__ 184 | repr_str += f'(simplify={self.simplify}, ' 185 | repr_str += f'sample_num={self.sample_num}, ' 186 | repr_str += f'sample_dist={self.sample_dist}, ' 187 | repr_str += f'roi_size={self.roi_size}, ' 188 | repr_str += f'normalize={self.normalize}, ' 189 | repr_str += f'coords_dim={self.coords_dim})' 190 | 191 | return repr_str
-------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | ## Getting Started 2 | 3 | ### Training 4 | 5 | Run `python train.py [config-file]`, for example: 6 | 7 | ``` 8 | # Baseline model 9 | python train.py config/nusc/baseline/baseline_60m.py 10 | # SDMap Prior model 11 | python train.py config/nusc/sd_prior/sd_60m.py 12 | ``` 13 | 14 | Explanation of some parameters in `[config-file]`: 15 | * `dataroot`: the path of your nuScenes data 16 | * `logdir`: the path where log files, checkpoints, etc., are saved 17 | * `model`: model name. Currently, the following models are supported: `HDMapNet_cam`, `HDMapNet_fusion`, `pmapnet_sd[_cam]`, `pmapnet_hd`, and `hdmapnet_pretrain`. You can find them in the [file](../model/__init__.py). 18 | * `batch_size`: the total number of samples across all GPUs, i.e. `sample_per_gpu` = `batch_size` / `gpu_nums`. 19 | * `gpus`: the number of GPUs you are using. 20 | 21 | ### Evaluation 22 | 23 | #### mIoU Metric 24 | To evaluate your model with the mIoU metric, first set `modelf` in `[config-file]` to the path of your checkpoint, then run: 25 | ``` 26 | python tools/eval.py [config-file] 27 | ``` 28 | 29 | #### mAP Metric 30 | 31 | Before running the evaluation code, first generate the `submission.json` file with: 32 | ``` 33 | python tools/export_json.py 34 | ``` 35 | > Note: remember to set the value of `result_path` in `[config-file]`. 36 | 37 | Then run `python tools/evaluate_json.py` for evaluation. 38 | ``` 39 | python tools/evaluate_json.py 40 | ``` 41 |
-------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | ### Environment 2 | 3 | 1. Create the conda environment 4 | ``` 5 | conda env create -f environment.yml 6 | conda activate pmapnet 7 | ``` 8 | 2. Install PyTorch 9 | ``` 10 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 11 | ``` 12 | 3. Install dependencies 13 | ``` 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ### Preparing the datasets 18 | Download the [nuScenes dataset](https://www.nuscenes.org/) and place it in the `dataset/` folder.
-------------------------------------------------------------------------------- /docs/visualization.md: -------------------------------------------------------------------------------- 1 | # Visualization 2 | 3 | We provide all the visualization scripts under `tools/vis_*.py`. 4 | 5 | ## Visualize prediction 6 | 7 | - Set `modelf = /path/to/experiment/ckpt` in the config file, for example:
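  (The paths below are placeholders; `modelf` simply points at whichever checkpoint `train.py` saved for your experiment.)

  ```python
  # In the experiment's Python config, e.g. config/nusc/sd_prior/sd_60m.py
  modelf = 'Work_dir/sd_60m/model_best.pt'   # placeholder -- use your own trained checkpoint
  # vis_path = 'Work_dir/sd_60m/vis'         # optional custom output directory (see the note below)
  ```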
8 | 9 | ```shell 10 | python tools/vis_map.py /path/to/experiment/config 11 | ``` 12 | **Notes**: 13 | 14 | - All the visualization samples will be saved in `P_MAPNET/Work_dir/experiment/vis` automatically. If you want to customize the saving path, you can add `vis_path = /customized_path` in the config file. 15 | 16 | ## Merge results into a video 17 | 18 | We also provide scripts that merge the input, the predictions, and the ground truth into a single video for qualitative benchmarking. 19 | 20 | ```shell 21 | # visualize nuscenes dataset 22 | python tools/vis_video_nus.py /path/to/experiment/config path/to/experiment/vis 23 | # visualize argoverse2 dataset 24 | python tools/vis_video_av2.py /path/to/experiment/config path/to/experiment/vis 25 | ``` 26 | **Notes**: 27 | - The video will be saved in `P-MAPNET/Work_dir/experiment/demo.mp4`.
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pmapnet 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - asttokens=2.0.5=pyhd3eb1b0_0 8 | - backcall=0.2.0=pyhd3eb1b0_0 9 | - ca-certificates=2023.08.22=h06a4308_0 10 | - decorator=5.1.1=pyhd3eb1b0_0 11 | - executing=0.8.3=pyhd3eb1b0_0 12 | - ld_impl_linux-64=2.38=h1181459_1 13 | - libffi=3.4.4=h6a678d5_0 14 | - libgcc-ng=11.2.0=h1234567_1 15 | - libgomp=11.2.0=h1234567_1 16 | - libstdcxx-ng=11.2.0=h1234567_1 17 | - matplotlib-inline=0.1.6=py38h06a4308_0 18 | - ncurses=6.4=h6a678d5_0 19 | - openssl=3.0.11=h7f8727e_2 20 | - parso=0.8.3=pyhd3eb1b0_0 21 | - pexpect=4.8.0=pyhd3eb1b0_3 22 | - pickleshare=0.7.5=pyhd3eb1b0_1003 23 | - pip=23.2.1=py38h06a4308_0 24 | - ptyprocess=0.7.0=pyhd3eb1b0_2 25 | - pure_eval=0.2.2=pyhd3eb1b0_0 26 | - pygments=2.15.1=py38h06a4308_1 27 | - python=3.8.18=h955ad1f_0 28 | - readline=8.2=h5eee18b_0 29 | - setuptools=68.0.0=py38h06a4308_0 30 | - six=1.16.0=pyhd3eb1b0_1 31 | - sqlite=3.41.2=h5eee18b_0 32 | - stack_data=0.2.0=pyhd3eb1b0_0 33 | - tk=8.6.12=h1ccaba5_0 34 | - traitlets=5.7.1=py38h06a4308_0 35 | - typing_extensions=4.7.1=py38h06a4308_0 36 | - wheel=0.41.2=py38h06a4308_0 37 | - xz=5.4.2=h5eee18b_0 38 | - zlib=1.2.13=h5eee18b_0 -------------------------------------------------------------------------------- /figs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/figs/teaser.jpg -------------------------------------------------------------------------------- /icon/car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/icon/car.png -------------------------------------------------------------------------------- /icon/car_gray.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/icon/car_gray.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .hdmapnet import HDMapNet 2 | from .lift_splat import LiftSplat 3 | from .pmapnet_sd import PMapNet_SD 4 | from .pmapnet_hd import PMapNet_HD, PMapNet_HD16, PMapNet_HD32 5 | from .utils.map_mae_head import vit_base_patch8, vit_base_patch16,
vit_base_patch32 6 | 7 | def get_model(cfg, data_conf, instance_seg=True, embedded_dim=16, direction_pred=True, angle_class=36): 8 | patch_h = data_conf['ybound'][1] - data_conf['ybound'][0] 9 | patch_w = data_conf['xbound'][1] - data_conf['xbound'][0] 10 | canvas_h = int(patch_h / data_conf['ybound'][2]) 11 | canvas_w = int(patch_w / data_conf['xbound'][2]) 12 | 13 | method = cfg.model 14 | if "dataset" in cfg: 15 | if cfg.dataset == 'av2': 16 | data_conf.update({"num_cams":7}) 17 | 18 | if method == 'lift_splat': 19 | model = LiftSplat(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim) 20 | 21 | # HDMapNet model 22 | elif method == 'HDMapNet_cam': 23 | model = HDMapNet(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=False) 24 | elif method == 'HDMapNet_fusion': 25 | model = HDMapNet(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=True) 26 | 27 | # P-MapNet sd prior model 28 | elif method == 'pmapnet_sd': 29 | model = PMapNet_SD(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=True) 30 | elif method == 'pmapnet_sd_cam': 31 | model = PMapNet_SD(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=False) 32 | 33 | # P-MapNet hd prior model 34 | elif method == 'pmapnet_hd': 35 | model = PMapNet_HD(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=True) 36 | elif method == 'pmapnet_hd16': 37 | model = PMapNet_HD16(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=True) 38 | elif method == 'pmapnet_hd32': 39 | model = PMapNet_HD32(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=True) 40 | elif method == 'pmapnet_hd_cam': 41 | model = PMapNet_HD(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=False) 42 | elif method == 'pmapnet_hd_cam16': 43 | model = PMapNet_HD16(data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=False) 44 | 45 | # P-MapNet hd pretrain model 46 | elif method == "hdmapnet_pretrain": 47 | model = vit_base_patch8(data_conf=data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=True, img_size=(canvas_h, canvas_w)) 48 | elif method == "hdmapnet_pretrain16": 49 | model = vit_base_patch16(data_conf=data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=True, img_size=(canvas_h, canvas_w)) 50 | elif method == "hdmapnet_pretrain32": 51 | model = vit_base_patch32(data_conf=data_conf, instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=angle_class, lidar=True, img_size=(canvas_h, canvas_w)) 52 | else: 53 | raise NotImplementedError 54 | 55 | return model 56 | -------------------------------------------------------------------------------- /model/hdmapnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 
| 4 | from .utils.homography import bilinear_sampler, IPM 5 | from .utils.utils import plane_grid_2d, get_rot_2d, cam_to_pixel 6 | from .utils.pointpillar import PointPillarEncoder 7 | from .utils.base import CamEncode, BevEncode 8 | from data_osm.utils import gen_dx_bx 9 | 10 | 11 | class ViewTransformation(nn.Module): 12 | def __init__(self, fv_size, bv_size, n_views=6): 13 | super(ViewTransformation, self).__init__() 14 | self.n_views = n_views 15 | self.hw_mat = [] 16 | self.bv_size = bv_size 17 | fv_dim = fv_size[0] * fv_size[1] 18 | bv_dim = bv_size[0] * bv_size[1] 19 | for i in range(self.n_views): 20 | fc_transform = nn.Sequential( 21 | nn.Linear(fv_dim, bv_dim), 22 | nn.ReLU(), 23 | nn.Linear(bv_dim, bv_dim), 24 | nn.ReLU() 25 | ) 26 | self.hw_mat.append(fc_transform) 27 | self.hw_mat = nn.ModuleList(self.hw_mat) 28 | 29 | def forward(self, feat): 30 | B, N, C, H, W = feat.shape 31 | feat = feat.view(B, N, C, H*W) 32 | outputs = [] 33 | for i in range(N): 34 | output = self.hw_mat[i](feat[:, i]).view(B, C, self.bv_size[0], self.bv_size[1]) 35 | outputs.append(output) 36 | outputs = torch.stack(outputs, 1) 37 | return outputs 38 | 39 | 40 | class HDMapNet(nn.Module): 41 | def __init__(self, data_conf, instance_seg=True, embedded_dim=16, direction_pred=True, direction_dim=36, lidar=False): 42 | super(HDMapNet, self).__init__() 43 | self.camC = 64 44 | self.downsample = 16 45 | 46 | dx, bx, nx = gen_dx_bx(data_conf['xbound'], data_conf['ybound'], data_conf['zbound']) 47 | final_H, final_W = nx[1].item(), nx[0].item() 48 | 49 | self.camencode = CamEncode(self.camC) 50 | fv_size = (data_conf['image_size'][0]//self.downsample, data_conf['image_size'][1]//self.downsample) 51 | bv_size = (final_H//5, final_W//5) 52 | num_cams = data_conf.get('num_cams', 6) 53 | print("num_cams: ", num_cams) 54 | self.view_fusion = ViewTransformation(fv_size=fv_size, bv_size=bv_size, n_views=num_cams) 55 | 56 | res_x = bv_size[1] * 3 // 4 57 | ipm_xbound = [-res_x, res_x, 4*res_x/final_W] 58 | ipm_ybound = [-res_x/2, res_x/2, 2*res_x/final_H] 59 | self.ipm = IPM(ipm_xbound, ipm_ybound, N=num_cams, C=self.camC, extrinsic=True) 60 | self.up_sampler = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 61 | # self.up_sampler = nn.Upsample(scale_factor=5, mode='bilinear', align_corners=True) 62 | 63 | self.lidar = lidar 64 | lidar_dim = 128 65 | if lidar: 66 | self.pp = PointPillarEncoder(lidar_dim, data_conf['xbound'], data_conf['ybound'], data_conf['zbound']) 67 | self.bevencode = BevEncode(inC=self.camC+lidar_dim, outC=data_conf['num_channels'], instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=direction_dim+1) 68 | else: 69 | self.bevencode = BevEncode(inC=self.camC, outC=data_conf['num_channels'], instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=direction_dim+1) 70 | 71 | def get_Ks_RTs_and_post_RTs(self, intrins, rots, trans, post_rots, post_trans): 72 | B, N, _, _ = intrins.shape 73 | Ks = torch.eye(4, device=intrins.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 74 | 75 | Rs = torch.eye(4, device=rots.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 76 | Rs[:, :, :3, :3] = rots.transpose(-1, -2).contiguous() 77 | Ts = torch.eye(4, device=trans.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 78 | Ts[:, :, :3, 3] = -trans 79 | RTs = Rs @ Ts 80 | 81 | post_RTs = None 82 | 83 | return Ks, RTs, post_RTs 84 | 85 | def get_cam_feats(self, x): 86 | B, N, C, imH, imW = x.shape 87 | x = x.view(B*N, C, imH, imW) 88 
| x = self.camencode(x) 89 | x = x.view(B, N, self.camC, imH//self.downsample, imW//self.downsample) 90 | return x 91 | 92 | def forward(self, img, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans, yaw_pitch_roll, osm): 93 | x = self.get_cam_feats(img) 94 | # import pdb; pdb.set_trace() 95 | x = self.view_fusion(x) 96 | Ks, RTs, post_RTs = self.get_Ks_RTs_and_post_RTs(intrins, rots, trans, post_rots, post_trans) 97 | topdown = self.ipm(x, Ks, RTs, car_trans, yaw_pitch_roll, post_RTs) 98 | topdown = self.up_sampler(topdown) 99 | if self.lidar: 100 | lidar_feature = self.pp(lidar_data, lidar_mask) 101 | topdown = torch.cat([topdown, lidar_feature], dim=1) 102 | return self.bevencode(topdown) 103 | -------------------------------------------------------------------------------- /model/lift_splat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2020 NVIDIA Corporation. All rights reserved. 3 | Licensed under the NVIDIA Source Code License. See LICENSE at https://github.com/nv-tlabs/lift-splat-shoot. 4 | Authors: Jonah Philion and Sanja Fidler 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from data_osm.utils import gen_dx_bx 11 | from .utils.base import CamEncode, BevEncode 12 | 13 | 14 | def cumsum_trick(x, geom_feats, ranks): 15 | x = x.cumsum(0) 16 | kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool) 17 | kept[:-1] = (ranks[1:] != ranks[:-1]) 18 | 19 | x, geom_feats = x[kept], geom_feats[kept] 20 | x = torch.cat((x[:1], x[1:] - x[:-1])) 21 | 22 | return x, geom_feats 23 | 24 | 25 | class QuickCumsum(torch.autograd.Function): 26 | @staticmethod 27 | def forward(ctx, x, geom_feats, ranks): 28 | x = x.cumsum(0) 29 | kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool) 30 | kept[:-1] = (ranks[1:] != ranks[:-1]) 31 | 32 | x, geom_feats = x[kept], geom_feats[kept] 33 | x = torch.cat((x[:1], x[1:] - x[:-1])) 34 | 35 | # save kept for backward 36 | ctx.save_for_backward(kept) 37 | 38 | # no gradient for geom_feats 39 | ctx.mark_non_differentiable(geom_feats) 40 | 41 | return x, geom_feats 42 | 43 | @staticmethod 44 | def backward(ctx, gradx, gradgeom): 45 | kept, = ctx.saved_tensors 46 | back = torch.cumsum(kept, 0) 47 | back[kept] -= 1 48 | 49 | val = gradx[back] 50 | 51 | return val, None, None 52 | 53 | 54 | class LiftSplat(nn.Module): 55 | def __init__(self, grid_conf, data_aug_conf, outC, instance_seg, embedded_dim): 56 | super(LiftSplat, self).__init__() 57 | self.grid_conf = grid_conf 58 | self.data_aug_conf = data_aug_conf 59 | 60 | dx, bx, nx = gen_dx_bx(self.grid_conf['xbound'], 61 | self.grid_conf['ybound'], 62 | self.grid_conf['zbound'], 63 | ) 64 | self.dx = nn.Parameter(dx, requires_grad=False) 65 | self.bx = nn.Parameter(bx, requires_grad=False) 66 | self.nx = nn.Parameter(nx, requires_grad=False) 67 | 68 | self.downsample = 16 69 | self.camC = 64 70 | self.frustum = self.create_frustum() 71 | # D x H/downsample x D/downsample x 3 72 | self.D, _, _, _ = self.frustum.shape 73 | self.camencode = CamEncode(self.D, self.camC, self.downsample) 74 | self.bevencode = BevEncode(inC=self.camC, outC=outC, instance_seg=instance_seg, embedded_dim=embedded_dim) 75 | 76 | # toggle using QuickCumsum vs. 
autograd 77 | self.use_quickcumsum = True 78 | 79 | def create_frustum(self): 80 | # make grid in image plane 81 | ogfH, ogfW = self.data_aug_conf['final_dim'] 82 | fH, fW = ogfH // self.downsample, ogfW // self.downsample 83 | ds = torch.arange(*self.grid_conf['dbound'], dtype=torch.float).view(-1, 1, 1).expand(-1, fH, fW) 84 | D, _, _ = ds.shape 85 | xs = torch.linspace(0, ogfW - 1, fW, dtype=torch.float).view(1, 1, fW).expand(D, fH, fW) 86 | ys = torch.linspace(0, ogfH - 1, fH, dtype=torch.float).view(1, fH, 1).expand(D, fH, fW) 87 | 88 | # D x H x W x 3 89 | frustum = torch.stack((xs, ys, ds), -1) 90 | return nn.Parameter(frustum, requires_grad=False) 91 | 92 | def get_geometry(self, rots, trans, intrins, post_rots, post_trans): 93 | """Determine the (x,y,z) locations (in the ego frame) 94 | of the points in the point cloud. 95 | Returns B x N x D x H/downsample x W/downsample x 3 96 | """ 97 | B, N, _ = trans.shape 98 | 99 | # *undo* post-transformation 100 | # B x N x D x H x W x 3 101 | points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3) 102 | points = torch.inverse(post_rots).view(B, N, 1, 1, 1, 3, 3).matmul(points.unsqueeze(-1)) 103 | 104 | # cam_to_ego 105 | points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3], 106 | points[:, :, :, :, :, 2:3] 107 | ), 5) 108 | combine = rots.matmul(torch.inverse(intrins)) 109 | points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1) 110 | points += trans.view(B, N, 1, 1, 1, 3) 111 | 112 | return points 113 | 114 | def get_cam_feats(self, x): 115 | """Return B x N x D x H/downsample x W/downsample x C 116 | """ 117 | B, N, C, imH, imW = x.shape 118 | 119 | x = x.view(B*N, C, imH, imW) 120 | x = self.camencode(x) 121 | x = x.view(B, N, self.camC, self.D, imH//self.downsample, imW//self.downsample) 122 | x = x.permute(0, 1, 3, 4, 5, 2) 123 | 124 | return x 125 | 126 | def voxel_pooling(self, geom_feats, x): 127 | B, N, D, H, W, C = x.shape 128 | Nprime = B*N*D*H*W 129 | 130 | # flatten x 131 | x = x.reshape(Nprime, C) 132 | 133 | # flatten indices 134 | # B x N x D x H/downsample x W/downsample x 3 135 | geom_feats = ((geom_feats - (self.bx - self.dx/2.)) / self.dx).long() 136 | geom_feats = geom_feats.view(Nprime, 3) 137 | batch_ix = torch.cat([torch.full([Nprime//B, 1], ix, device=x.device, dtype=torch.long) for ix in range(B)]) 138 | geom_feats = torch.cat((geom_feats, batch_ix), 1) # x, y, z, b 139 | 140 | # filter out points that are outside box 141 | kept = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < self.nx[0])\ 142 | & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1])\ 143 | & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2]) 144 | x = x[kept] 145 | geom_feats = geom_feats[kept] 146 | 147 | # get tensors from the same voxel next to each other 148 | ranks = geom_feats[:, 0] * (self.nx[1] * self.nx[2] * B)\ 149 | + geom_feats[:, 1] * (self.nx[2] * B)\ 150 | + geom_feats[:, 2] * B\ 151 | + geom_feats[:, 3] 152 | sorts = ranks.argsort() 153 | x, geom_feats, ranks = x[sorts], geom_feats[sorts], ranks[sorts] 154 | 155 | # cumsum trick 156 | if not self.use_quickcumsum: 157 | x, geom_feats = cumsum_trick(x, geom_feats, ranks) 158 | else: 159 | x, geom_feats = QuickCumsum.apply(x, geom_feats, ranks) 160 | 161 | # griddify (B x C x Z x X x Y) 162 | final = torch.zeros((B, C, self.nx[2], self.nx[1], self.nx[0]), device=x.device) 163 | final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 1], geom_feats[:, 0]] = x 164 | 165 | # collapse Z 166 | final = torch.cat(final.unbind(dim=2), 1) 
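# note: unbind(dim=2) splits the voxel grid into its Z slices and the concatenation along dim=1 folds height into channels, yielding a 2-D BEV feature map with C * n_z channels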
167 | 168 | return final 169 | 170 | def get_voxels(self, x, rots, trans, intrins, post_rots, post_trans): 171 | # B x N x D x H/downsample x W/downsample x 3: (x,y,z) locations (in the ego frame) 172 | geom = self.get_geometry(rots, trans, intrins, post_rots, post_trans) 173 | # B x N x D x H/downsample x W/downsample x C: cam feats 174 | x = self.get_cam_feats(x) 175 | 176 | x = self.voxel_pooling(geom, x) 177 | 178 | return x 179 | 180 | def forward(self, points, points_mask, x, rots, trans, intrins, post_rots, post_trans, translation, yaw_pitch_roll): 181 | x = self.get_voxels(x, rots, trans, intrins, post_rots, post_trans) 182 | x = self.bevencode(x) 183 | return x 184 | -------------------------------------------------------------------------------- /model/pmapnet_sd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .utils.homography import IPM 4 | from .utils.pointpillar import PointPillarEncoder 5 | from .utils.base import CamEncode, BevEncode 6 | from data_osm.utils import gen_dx_bx 7 | from .utils.sdmap_cross_attn import SDMapCrossAttn 8 | from .utils.position_encoding import PositionEmbeddingSine 9 | 10 | class ViewTransformation(nn.Module): 11 | def __init__(self, fv_size, bv_size, n_views=6): 12 | super(ViewTransformation, self).__init__() 13 | self.n_views = n_views 14 | self.hw_mat = [] 15 | self.bv_size = bv_size 16 | fv_dim = fv_size[0] * fv_size[1] 17 | bv_dim = bv_size[0] * bv_size[1] 18 | for i in range(self.n_views): 19 | fc_transform = nn.Sequential( 20 | nn.Linear(fv_dim, bv_dim), 21 | nn.ReLU(), 22 | nn.Linear(bv_dim, bv_dim), 23 | nn.ReLU() 24 | ) 25 | self.hw_mat.append(fc_transform) 26 | self.hw_mat = nn.ModuleList(self.hw_mat) 27 | 28 | def forward(self, feat): 29 | B, N, C, H, W = feat.shape 30 | feat = feat.view(B, N, C, H*W) 31 | outputs = [] 32 | for i in range(N): 33 | output = self.hw_mat[i](feat[:, i]).view(B, C, self.bv_size[0], self.bv_size[1]) 34 | outputs.append(output) 35 | outputs = torch.stack(outputs, 1) 36 | return outputs 37 | 38 | 39 | class PMapNet_SD(nn.Module): 40 | def __init__(self, data_conf, instance_seg=True, embedded_dim=16, direction_pred=True, direction_dim=36, lidar=False): 41 | super(PMapNet_SD, self).__init__() 42 | 43 | self.lidar = lidar 44 | self.camC = 64 45 | self.LiDARC = 128 46 | self.downsample = 16 47 | 48 | #cross attn params 49 | hidden_dim = 64 50 | self.position_embedding = PositionEmbeddingSine(hidden_dim//2, normalize=True) 51 | 52 | if lidar: 53 | feat_numchannels = self.camC+self.LiDARC 54 | self.pp = PointPillarEncoder(self.LiDARC, data_conf['xbound'], data_conf['ybound'], data_conf['zbound']) 55 | else: 56 | feat_numchannels = self.camC 57 | 58 | self.input_proj = nn.Conv2d(feat_numchannels, hidden_dim, kernel_size=1) 59 | 60 | # sdmap_cross_attn 61 | self.sdmap_crossattn = SDMapCrossAttn(d_model=hidden_dim, num_decoder_layers=2, dropout=0.1) 62 | 63 | dx, bx, nx = gen_dx_bx(data_conf['xbound'], data_conf['ybound'], data_conf['zbound']) 64 | final_H, final_W = nx[1].item(), nx[0].item() 65 | 66 | self.camencode = CamEncode(self.camC) 67 | fv_size = (data_conf['image_size'][0]//self.downsample, data_conf['image_size'][1]//self.downsample) 68 | bv_size = (final_H//5, final_W//5) 69 | num_cams = data_conf.get('num_cams', 6) 70 | # import pdb; pdb.set_trace() 71 | self.view_fusion = ViewTransformation(fv_size=fv_size, bv_size=bv_size, n_views=num_cams) 72 | 73 | res_x = bv_size[1] * 3 // 4 74 | ipm_xbound = [-res_x, res_x, 
4*res_x/final_W] 75 | ipm_ybound = [-res_x/2, res_x/2, 2*res_x/final_H] 76 | self.ipm = IPM(ipm_xbound, ipm_ybound, N=num_cams, C=self.camC, extrinsic=True) 77 | self.up_sampler = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 78 | 79 | self.pool = nn.AvgPool2d(kernel_size=10, stride=10) 80 | self.conv_osm = self.nn_Sequential(1, hidden_dim // 2, hidden_dim) 81 | self.conv_bev = self.nn_Sequential(feat_numchannels, feat_numchannels, feat_numchannels) 82 | self.conv_up = self.nn_Sequential_Transpose(hidden_dim, feat_numchannels, feat_numchannels) 83 | 84 | self.bevencode = BevEncode(inC=feat_numchannels, outC=data_conf['num_channels'], instance_seg=instance_seg, embedded_dim=embedded_dim, direction_pred=direction_pred, direction_dim=direction_dim+1) 85 | 86 | def get_Ks_RTs_and_post_RTs(self, intrins, rots, trans, post_rots, post_trans): 87 | B, N, _, _ = intrins.shape 88 | Ks = torch.eye(4, device=intrins.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 89 | 90 | Rs = torch.eye(4, device=rots.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 91 | Rs[:, :, :3, :3] = rots.transpose(-1, -2).contiguous() 92 | Ts = torch.eye(4, device=trans.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 93 | Ts[:, :, :3, 3] = -trans 94 | RTs = Rs @ Ts 95 | 96 | post_RTs = None 97 | 98 | return Ks, RTs, post_RTs 99 | 100 | def get_cam_feats(self, x): 101 | B, N, C, imH, imW = x.shape 102 | x = x.view(B*N, C, imH, imW) 103 | x = self.camencode(x) 104 | x = x.view(B, N, self.camC, imH//self.downsample, imW//self.downsample) 105 | return x 106 | 107 | def nn_Sequential(self, in_dim=192, mid_dim=192, out_dim=192): 108 | return nn.Sequential( 109 | nn.Conv2d(in_dim, out_channels=mid_dim, kernel_size=4, stride=2, padding=1), 110 | nn.ReLU(), 111 | nn.Conv2d(mid_dim, out_channels=out_dim, kernel_size=4, stride=2, padding=1), 112 | nn.ReLU(), 113 | nn.Conv2d(out_dim, out_dim, kernel_size=4, stride=2, padding=1, bias=False), 114 | nn.BatchNorm2d(out_dim), 115 | nn.ReLU(inplace=True), 116 | ) 117 | 118 | def nn_Sequential_Transpose(self, in_dim=192, mid_dim=192, out_dim=192): 119 | return nn.Sequential( 120 | nn.ConvTranspose2d(in_dim, out_channels=mid_dim, kernel_size=4, stride=2, padding=1), 121 | nn.ReLU(), 122 | nn.ConvTranspose2d(mid_dim, out_channels=out_dim, kernel_size=4, stride=2, padding=1), 123 | nn.ReLU(), 124 | nn.ConvTranspose2d(out_dim, out_dim, kernel_size=4, stride=2, padding=1, bias=False), 125 | nn.BatchNorm2d(out_dim), 126 | nn.ReLU(inplace=True), 127 | ) 128 | def forward(self, img, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans, yaw_pitch_roll, osm): 129 | x = self.get_cam_feats(img) 130 | # import pdb; pdb.set_trace() 131 | x = self.view_fusion(x) 132 | Ks, RTs, post_RTs = self.get_Ks_RTs_and_post_RTs(intrins, rots, trans, post_rots, post_trans) 133 | topdown = self.ipm(x, Ks, RTs, car_trans, yaw_pitch_roll, post_RTs) 134 | topdown = self.up_sampler(topdown) 135 | if self.lidar: 136 | lidar_feature = self.pp(lidar_data, lidar_mask) 137 | topdown = torch.cat([topdown, lidar_feature], dim=1) 138 | 139 | bev_small = self.conv_bev(topdown) 140 | 141 | conv_osm = self.conv_osm(osm) 142 | 143 | bs,c,h,w = bev_small.shape 144 | self.mask = torch.zeros([1,h,w],dtype=torch.bool) 145 | 146 | pos = self.position_embedding(bev_small[-1], self.mask.to(bev_small.device)).to(bev_small.dtype) 147 | bs = bev_small.shape[0] 148 | pos = pos.repeat(bs, 1, 1, 1) 149 | bev_out = self.sdmap_crossattn(self.input_proj(bev_small), conv_osm, pos = pos)[0] 150 | bev_final = 
self.conv_up(bev_out) 151 | return self.bevencode(bev_final) 152 | -------------------------------------------------------------------------------- /model/utils/VPN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import CamEncode, BevEncode 5 | from .pointpillar import PointPillarEncoder 6 | 7 | 8 | class TransformModule(nn.Module): 9 | def __init__(self, dim, num_view=6): 10 | super(TransformModule, self).__init__() 11 | self.num_view = num_view 12 | self.dim = dim 13 | self.mat_list = nn.ModuleList() 14 | for i in range(self.num_view): 15 | fc_transform = nn.Sequential( 16 | nn.Linear(dim * dim, dim * dim), 17 | nn.ReLU(), 18 | nn.Linear(dim * dim, dim * dim), 19 | nn.ReLU() 20 | ) 21 | self.mat_list += [fc_transform] 22 | 23 | def forward(self, x): 24 | # shape x: B, V, C, H, W 25 | x = x.view(list(x.size()[:3]) + [self.dim * self.dim,]) 26 | view_comb = self.mat_list[0](x[:, 0]) 27 | for index in range(x.size(1))[1:]: 28 | view_comb += self.mat_list[index](x[:, index]) 29 | view_comb = view_comb.view(list(view_comb.size()[:2]) + [self.dim, self.dim]) 30 | return view_comb 31 | 32 | 33 | class VPNModel(nn.Module): 34 | def __init__(self, outC, camC=64, instance_seg=True, embedded_dim=16, extrinsic=False, lidar=False, xbound=None, ybound=None, zbound=None): 35 | super(VPNModel, self).__init__() 36 | self.camC = camC 37 | self.extrinsic = extrinsic 38 | self.downsample = 16 39 | 40 | self.camencode = CamEncode(camC) 41 | self.view_fusion = TransformModule(dim=(8, 22)) 42 | self.up_sampler = nn.Upsample(size=(200, 400), mode='bilinear', align_corners=True) 43 | self.lidar = lidar 44 | if lidar: 45 | self.pp = PointPillarEncoder(128, xbound, ybound, zbound) 46 | self.bevencode = BevEncode(inC=camC+128, outC=outC, instance_seg=instance_seg, embedded_dim=embedded_dim) 47 | else: 48 | self.bevencode = BevEncode(inC=camC, outC=outC, instance_seg=instance_seg, embedded_dim=embedded_dim) 49 | 50 | 51 | def get_Ks_RTs_and_post_RTs(self, intrins, rots, trans, post_rots, post_trans): 52 | B, N, _, _ = intrins.shape 53 | Ks = torch.eye(4, device=intrins.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 54 | # Ks[:, :, :3, :3] = intrins 55 | 56 | Rs = torch.eye(4, device=rots.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 57 | Rs[:, :, :3, :3] = rots.transpose(-1, -2).contiguous() 58 | Ts = torch.eye(4, device=trans.device).view(1, 1, 4, 4).repeat(B, N, 1, 1) 59 | Ts[:, :, :3, 3] = -trans 60 | RTs = Rs @ Ts 61 | 62 | post_RTs = None 63 | 64 | return Ks, RTs, post_RTs 65 | 66 | def get_cam_feats(self, x): 67 | """Return B x N x D x H/downsample x W/downsample x C 68 | """ 69 | B, N, C, imH, imW = x.shape 70 | 71 | x = x.view(B*N, C, imH, imW) 72 | x = self.camencode(x) 73 | x = x.view(B, N, self.camC, imH//self.downsample, imW//self.downsample) 74 | return x 75 | 76 | def forward(self, points, points_mask, x, rots, trans, intrins, post_rots, post_trans, translation, yaw_pitch_roll): 77 | x = self.get_cam_feats(x) 78 | x = self.view_fusion(x) 79 | topdown = x.mean(1) 80 | topdown = self.up_sampler(topdown) 81 | if self.lidar: 82 | lidar_feature = self.pp(points, points_mask) 83 | topdown = torch.cat([topdown, lidar_feature], dim=1) 84 | return self.bevencode(topdown) 85 | -------------------------------------------------------------------------------- /model/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/model/utils/__init__.py -------------------------------------------------------------------------------- /model/utils/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from efficientnet_pytorch import EfficientNet 5 | from torchvision.models.resnet import resnet18,resnet50 6 | 7 | class Up(nn.Module): 8 | def __init__(self, in_channels, out_channels, scale_factor=2): 9 | super().__init__() 10 | 11 | self.up = nn.Upsample(scale_factor=scale_factor, mode='bilinear', 12 | align_corners=True) 13 | 14 | self.conv = nn.Sequential( 15 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), 16 | nn.BatchNorm2d(out_channels), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | def forward(self, x1, x2): 24 | x1 = self.up(x1) 25 | x1 = torch.cat([x2, x1], dim=1) 26 | return self.conv(x1) 27 | 28 | class CamEncode(nn.Module): 29 | def __init__(self, C): 30 | super(CamEncode, self).__init__() 31 | self.C = C 32 | 33 | self.trunk = EfficientNet.from_pretrained("efficientnet-b0") 34 | self.up1 = Up(320+112, self.C) 35 | 36 | def get_eff_depth(self, x): 37 | # adapted from https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py#L231 38 | endpoints = dict() 39 | 40 | # Stem 41 | x = self.trunk._swish(self.trunk._bn0(self.trunk._conv_stem(x))) 42 | prev_x = x 43 | 44 | # Blocks 45 | for idx, block in enumerate(self.trunk._blocks): 46 | drop_connect_rate = self.trunk._global_params.drop_connect_rate 47 | if drop_connect_rate: 48 | drop_connect_rate *= float(idx) / len(self.trunk._blocks) # scale drop connect_rate 49 | x = block(x, drop_connect_rate=drop_connect_rate) 50 | if prev_x.size(2) > x.size(2): 51 | endpoints['reduction_{}'.format(len(endpoints)+1)] = prev_x 52 | prev_x = x 53 | 54 | # Head 55 | endpoints['reduction_{}'.format(len(endpoints)+1)] = x 56 | x = self.up1(endpoints['reduction_5'], endpoints['reduction_4']) 57 | return x 58 | 59 | def forward(self, x): 60 | return self.get_eff_depth(x) 61 | 62 | class maeDecode(nn.Module): 63 | def __init__(self, inC, outC, instance_seg=True, embedded_dim=16, direction_pred=True, direction_dim=37): 64 | super(maeDecode, self).__init__() 65 | trunk = resnet50(pretrained=False, zero_init_residual=True) 66 | self.conv1 = nn.Conv2d(inC, 64, kernel_size=7, stride=2, padding=3, bias=False) 67 | 68 | self.bn1 = trunk.bn1 69 | self.relu = trunk.relu 70 | 71 | self.layer1 = trunk.layer1 72 | self.layer2 = trunk.layer2 73 | self.layer3 = trunk.layer3 74 | # self.res50.conv1(), 75 | # self.res50.bn1(), 76 | # self.res50.relu(), 77 | # self.res50.maxpool(), 78 | # self.res50.layer1(), 79 | # self.res50.layer2(), 80 | # self.res50.layer3(), 81 | # self.res50.layer4(), 82 | # self.res50.avgpool() 83 | self.up1 = Up(1024 + 256, 256, scale_factor=4) 84 | self.up2 = nn.Sequential( 85 | nn.Upsample(scale_factor=2, mode='bilinear', 86 | align_corners=True), 87 | nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False), 88 | nn.BatchNorm2d(128), 89 | nn.ReLU(inplace=True), 90 | nn.Conv2d(128, outC, kernel_size=1, padding=0), 91 | ) 92 | 93 | self.instance_seg = instance_seg 94 | if instance_seg: 95 | self.up1_embedded = Up(1024 + 256, 256, scale_factor=4) 96 | self.up2_embedded = nn.Sequential( 97 | 
nn.Upsample(scale_factor=2, mode='bilinear', 98 | align_corners=True), 99 | nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False), 100 | nn.BatchNorm2d(128), 101 | nn.ReLU(inplace=True), 102 | nn.Conv2d(128, embedded_dim, kernel_size=1, padding=0), 103 | ) 104 | 105 | self.direction_pred = direction_pred 106 | if direction_pred: 107 | # self.up1_direction = Up(64 + 256, 256, scale_factor=4) 108 | self.up1_direction = Up(1024 + 256, 256, scale_factor=4) 109 | self.up2_direction = nn.Sequential( 110 | nn.Upsample(scale_factor=2, mode='bilinear', 111 | align_corners=True), 112 | nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False), 113 | nn.BatchNorm2d(128), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(128, direction_dim, kernel_size=1, padding=0), 116 | ) 117 | 118 | def forward(self, x): # x: torch.Size([bs, 128, 200, 400]) 119 | x = self.conv1(x) # x: torch.Size([bs, 64, 100, 200]) 120 | x = self.bn1(x) 121 | x = self.relu(x) 122 | 123 | x1 = self.layer1(x) # x1: torch.Size([bs, 256, 100, 200]) 124 | x = self.layer2(x1) # x: torch.Size([bs, 512, 100, 200]) 125 | x2 = self.layer3(x) # x2: torch.Size([bs, 1024, 25, 50]) 126 | 127 | x = self.up1(x2, x1) # x: torch.Size([bs, 256, 100, 200]) 128 | x = self.up2(x) # x: torch.Size([bs, 4, 200, 400]) 129 | 130 | if self.instance_seg: 131 | x_embedded = self.up1_embedded(x2, x1) 132 | x_embedded = self.up2_embedded(x_embedded) 133 | else: 134 | x_embedded = None 135 | 136 | if self.direction_pred: 137 | x_direction = self.up1_embedded(x2, x1) 138 | x_direction = self.up2_direction(x_direction) 139 | else: 140 | x_direction = None 141 | 142 | return x, x_embedded, x_direction 143 | 144 | class BevEncode(nn.Module): 145 | def __init__(self, inC, outC, instance_seg=True, embedded_dim=16, direction_pred=True, direction_dim=37): 146 | super(BevEncode, self).__init__() 147 | trunk = resnet18(pretrained=False, zero_init_residual=True) 148 | self.conv1 = nn.Conv2d(inC, 64, kernel_size=7, stride=2, padding=3, bias=False) 149 | self.bn1 = trunk.bn1 150 | self.relu = trunk.relu 151 | 152 | self.layer1 = trunk.layer1 153 | self.layer2 = trunk.layer2 154 | self.layer3 = trunk.layer3 155 | 156 | self.up1 = Up(64 + 256, 256, scale_factor=4) 157 | self.up2 = nn.Sequential( 158 | nn.Upsample(scale_factor=2, mode='bilinear', 159 | align_corners=True), 160 | nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False), 161 | nn.BatchNorm2d(128), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(128, outC, kernel_size=1, padding=0), 164 | ) 165 | 166 | self.instance_seg = instance_seg 167 | if instance_seg: 168 | self.up1_embedded = Up(64 + 256, 256, scale_factor=4) 169 | self.up2_embedded = nn.Sequential( 170 | nn.Upsample(scale_factor=2, mode='bilinear', 171 | align_corners=True), 172 | nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False), 173 | nn.BatchNorm2d(128), 174 | nn.ReLU(inplace=True), 175 | nn.Conv2d(128, embedded_dim, kernel_size=1, padding=0), 176 | ) 177 | 178 | self.direction_pred = direction_pred 179 | if direction_pred: 180 | self.up1_direction = Up(64 + 256, 256, scale_factor=4) 181 | self.up2_direction = nn.Sequential( 182 | nn.Upsample(scale_factor=2, mode='bilinear', 183 | align_corners=True), 184 | nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False), 185 | nn.BatchNorm2d(128), 186 | nn.ReLU(inplace=True), 187 | nn.Conv2d(128, direction_dim, kernel_size=1, padding=0), 188 | ) 189 | 190 | def forward(self, x): 191 | x = self.conv1(x) 192 | x = self.bn1(x) 193 | x = self.relu(x) 194 | 195 | x1 = self.layer1(x) 196 | x = self.layer2(x1) 
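# x1 keeps the 1/2-resolution, 64-channel features as a skip connection; layer2/layer3 downsample further to 1/8 resolution (256 channels) before the Up blocks fuse the two scales back to full size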
197 | x2 = self.layer3(x) 198 | 199 | x = self.up1(x2, x1) 200 | x = self.up2(x) 201 | 202 | if self.instance_seg: 203 | x_embedded = self.up1_embedded(x2, x1) 204 | x_embedded = self.up2_embedded(x_embedded) 205 | else: 206 | x_embedded = None 207 | 208 | if self.direction_pred: 209 | x_direction = self.up1_direction(x2, x1) 210 | x_direction = self.up2_direction(x_direction) 211 | else: 212 | x_direction = None 213 | 214 | return x, x_embedded, x_direction 215 | 216 | class BevEncode_bd(nn.Module): 217 | def __init__(self, inC, outC, instance_seg=True, embedded_dim=16, direction_pred=True, direction_dim=37): 218 | super(BevEncode_bd, self).__init__() 219 | trunk = resnet18(pretrained=False, zero_init_residual=True) 220 | self.conv1 = nn.Conv2d(inC, 64, kernel_size=7, stride=2, padding=3, bias=False) 221 | self.bn1 = trunk.bn1 222 | self.relu = trunk.relu 223 | 224 | self.layer1 = trunk.layer1 225 | self.layer2 = trunk.layer2 226 | self.layer3 = trunk.layer3 227 | 228 | self.up1 = Up(64 + 256, 256, scale_factor=4) 229 | self.up2 = nn.Sequential( 230 | nn.Upsample(scale_factor=2, mode='bilinear', 231 | align_corners=True), 232 | nn.Conv2d(256, 128, kernel_size=3, padding=1, bias=False), 233 | nn.BatchNorm2d(128), 234 | nn.ReLU(inplace=True), 235 | nn.Conv2d(128, outC, kernel_size=1, padding=0), 236 | ) 237 | 238 | 239 | def forward(self, x): 240 | x = self.conv1(x) 241 | x = self.bn1(x) 242 | x = self.relu(x) 243 | 244 | x1 = self.layer1(x) 245 | x = self.layer2(x1) 246 | x2 = self.layer3(x) 247 | 248 | x = self.up1(x2, x1) 249 | x = self.up2(x) 250 | 251 | return x 252 | 253 | -------------------------------------------------------------------------------- /model/utils/map_mae_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from functools import partial 4 | import timm.models.vision_transformer 5 | from .base import CamEncode, BevEncode 6 | 7 | class ConvBNReLU(nn.Module): 8 | def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding=1, 9 | dilation=1, groups=1, bias=False, has_relu=True): 10 | super().__init__() 11 | self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=kernel_size, 12 | stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 13 | self.bn = nn.BatchNorm2d(out_channel) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.has_relu = has_relu 16 | 17 | def forward(self, x): 18 | feat = self.conv(x) 19 | feat = self.bn(feat) 20 | if self.has_relu: 21 | return self.relu(feat) 22 | return feat 23 | 24 | class PatchEmbed(nn.Module): 25 | """ Image to Patch Embedding 26 | """ 27 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 28 | super().__init__() 29 | if isinstance(img_size, int): 30 | img_size = (img_size, img_size) 31 | if isinstance(patch_size, int): 32 | patch_size = (patch_size, patch_size) 33 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 34 | self.img_size = img_size 35 | self.patch_size = patch_size 36 | self.num_patches = num_patches 37 | 38 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 39 | 40 | def forward(self, x): 41 | B, C, H, W = x.shape 42 | # FIXME look at relaxing size constraints 43 | assert H == self.img_size[0] and W == self.img_size[1], \ 44 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
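# proj uses kernel_size == stride == patch_size, so it embeds non-overlapping patches of the rasterized BEV input into a (B, embed_dim, H/patch, W/patch) grid; flattening into tokens happens later in forward_features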
45 | x = self.proj(x) 46 | return x 47 | 48 | 49 | class MapVisionTransformer(timm.models.vision_transformer.VisionTransformer): 50 | """ Vision Transformer with support for global average pooling 51 | """ 52 | def __init__(self, 53 | data_conf=None, 54 | instance_seg=True, 55 | embedded_dim=16, 56 | direction_pred=True, 57 | direction_dim=36, 58 | lidar=None, 59 | **kwargs): 60 | super(MapVisionTransformer, self).__init__(**kwargs) 61 | self.bev_head = BevEncode(inC=kwargs['embed_dim'], 62 | outC=data_conf['num_channels'], 63 | instance_seg=instance_seg, 64 | embedded_dim=embedded_dim, 65 | direction_pred=direction_pred, 66 | direction_dim=direction_dim+1) 67 | patch_h = data_conf['ybound'][1] - data_conf['ybound'][0] # 30.0 68 | patch_w = data_conf['xbound'][1] - data_conf['xbound'][0] # 60.0 69 | self.canvas_h = int(patch_h / data_conf['ybound'][2]) # 200 70 | self.canvas_w = int(patch_w / data_conf['xbound'][2]) # 400 71 | self.conv_up = nn.Sequential( 72 | nn.ConvTranspose2d(kwargs['embed_dim'], kwargs['embed_dim'], kernel_size=4, stride=2, padding=1), 73 | nn.ConvTranspose2d(kwargs['embed_dim'], kwargs['embed_dim'], kernel_size=4, stride=2, padding=1), 74 | nn.Upsample(size=(self.canvas_h, self.canvas_w), mode='bilinear', align_corners=False), 75 | ConvBNReLU(kwargs['embed_dim'], kwargs['embed_dim'], 1, stride=1, padding=0, has_relu=False), 76 | ) 77 | self.map_patch_embed = PatchEmbed(kwargs['img_size'], kwargs['patch_size'], kwargs['in_chans'], kwargs['embed_dim']) 78 | 79 | def forward_features(self, x): 80 | B = x.shape[0] # (b,c,h,w) 81 | # import pdb; pdb.set_trace() 82 | x = self.map_patch_embed(x) # (b,dim,12,25) 83 | 84 | _, dim, h, w = x.shape 85 | x = x.flatten(2).transpose(1, 2) # (b,n,dim) 86 | x = x + self.pos_embed[:, :-1] 87 | x = self.pos_drop(x) 88 | 89 | for blk in self.blocks: 90 | x = blk(x) 91 | 92 | x = self.norm(x) 93 | outcome = x.permute(0,2,1).reshape(B, dim, h, w) 94 | outcome = self.conv_up(outcome) 95 | return outcome 96 | 97 | def forward(self, x): 98 | x = self.forward_features(x) 99 | x = self.bev_head(x) 100 | return x 101 | 102 | 103 | def vit_base_patch8(**kwargs): 104 | model = MapVisionTransformer( 105 | patch_size=8, 106 | embed_dim=768, 107 | depth=12, 108 | num_heads=12, 109 | mlp_ratio=4, 110 | qkv_bias=True, 111 | in_chans=4, 112 | **kwargs) 113 | return model 114 | 115 | def vit_base_patch16(**kwargs): 116 | model = MapVisionTransformer( 117 | patch_size=16, 118 | embed_dim=768, 119 | depth=12, 120 | num_heads=12, 121 | mlp_ratio=4, 122 | qkv_bias=True, 123 | in_chans=4, 124 | **kwargs) 125 | return model 126 | 127 | def vit_base_patch32(**kwargs): 128 | model = MapVisionTransformer( 129 | patch_size=32, 130 | embed_dim=768, 131 | depth=12, 132 | num_heads=12, 133 | mlp_ratio=4, 134 | qkv_bias=True, 135 | in_chans=4, 136 | **kwargs) 137 | return model -------------------------------------------------------------------------------- /model/utils/pointpillar.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_scatter 4 | 5 | from .voxel import points_to_voxels 6 | 7 | 8 | class PillarBlock(nn.Module): 9 | def __init__(self, idims=64, dims=64, num_layers=1, 10 | stride=1): 11 | super(PillarBlock, self).__init__() 12 | layers = [] 13 | self.idims = idims 14 | self.stride = stride 15 | for i in range(num_layers): 16 | layers.append(nn.Conv2d(self.idims, dims, 3, stride=self.stride, 17 | padding=1, bias=False)) 18 | layers.append(nn.BatchNorm2d(dims)) 19 | 
layers.append(nn.ReLU(inplace=True)) 20 | self.idims = dims 21 | self.stride = 1 22 | self.layers = nn.Sequential(*layers) 23 | 24 | def forward(self, x): 25 | return self.layers(x) 26 | 27 | 28 | class PointNet(nn.Module): 29 | def __init__(self, idims=64, odims=64): 30 | super(PointNet, self).__init__() 31 | self.pointnet = nn.Sequential( 32 | nn.Conv1d(idims, odims, kernel_size=1, bias=False), 33 | nn.BatchNorm1d(odims), 34 | nn.ReLU(inplace=True) 35 | ) 36 | 37 | def forward(self, points_feature, points_mask): 38 | batch_size, num_points, num_dims = points_feature.shape 39 | points_feature = points_feature.permute(0, 2, 1) 40 | mask = points_mask.view(batch_size, 1, num_points) 41 | return self.pointnet(points_feature) * mask 42 | 43 | 44 | class PointPillar(nn.Module): 45 | def __init__(self, C, xbound, ybound, zbound, embedded_dim=16, direction_dim=37): 46 | super(PointPillar, self).__init__() 47 | self.xbound = xbound 48 | self.ybound = ybound 49 | self.zbound = zbound 50 | self.embedded_dim = embedded_dim 51 | self.pn = PointNet(15, 64) 52 | self.block1 = PillarBlock(64, dims=64, num_layers=2, stride=1) 53 | self.block2 = PillarBlock(64, dims=128, num_layers=3, stride=2) 54 | self.block3 = PillarBlock(128, 256, num_layers=3, stride=2) 55 | self.up1 = nn.Sequential( 56 | nn.Conv2d(64, 64, 3, padding=1, bias=False), 57 | nn.BatchNorm2d(64), 58 | nn.ReLU(inplace=True) 59 | ) 60 | self.up2 = nn.Sequential( 61 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 62 | nn.Conv2d(128, 128, 3, stride=1, padding=1, bias=False), 63 | nn.BatchNorm2d(128), 64 | nn.ReLU(inplace=True) 65 | ) 66 | self.up3 = nn.Sequential( 67 | nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True), 68 | nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False), 69 | nn.BatchNorm2d(256), 70 | nn.ReLU(inplace=True) 71 | ) 72 | self.conv_out = nn.Sequential( 73 | nn.Conv2d(448, 256, 3, padding=1, bias=False), 74 | nn.BatchNorm2d(256), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d(256, 128, 3, padding=1, bias=False), 77 | nn.BatchNorm2d(128), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(128, C, 1), 80 | ) 81 | self.instance_conv_out = nn.Sequential( 82 | nn.Conv2d(448, 256, 3, padding=1, bias=False), 83 | nn.BatchNorm2d(256), 84 | nn.ReLU(inplace=True), 85 | nn.Conv2d(256, 128, 3, padding=1, bias=False), 86 | nn.BatchNorm2d(128), 87 | nn.ReLU(inplace=True), 88 | nn.Conv2d(128, embedded_dim, 1), 89 | ) 90 | self.direction_conv_out = nn.Sequential( 91 | nn.Conv2d(448, 256, 3, padding=1, bias=False), 92 | nn.BatchNorm2d(256), 93 | nn.ReLU(inplace=True), 94 | nn.Conv2d(256, 128, 3, padding=1, bias=False), 95 | nn.BatchNorm2d(128), 96 | nn.ReLU(inplace=True), 97 | nn.Conv2d(128, direction_dim, 1), 98 | ) 99 | 100 | def forward(self, points, points_mask, 101 | x, rots, trans, intrins, post_rots, post_trans, translation, yaw_pitch_roll): 102 | points_xyz = points[:, :, :3] 103 | points_feature = points[:, :, 3:] 104 | voxels = points_to_voxels( 105 | points_xyz, points_mask, self.xbound, self.ybound, self.zbound 106 | ) 107 | points_feature = torch.cat( 108 | [points, # 5 109 | torch.unsqueeze(voxels['voxel_point_count'], dim=-1), # 1 110 | voxels['local_points_xyz'], # 3 111 | voxels['point_centroids'], # 3 112 | points_xyz - voxels['voxel_centers'], # 3 113 | ], dim=-1 114 | ) 115 | points_feature = self.pn(points_feature, voxels['points_mask']) 116 | voxel_feature = torch_scatter.scatter_mean( 117 | points_feature, 118 | torch.unsqueeze(voxels['voxel_indices'], dim=1), 119 | dim=2, 120 | 
dim_size=voxels['num_voxels']) 121 | batch_size = points.size(0) 122 | voxel_feature = voxel_feature.view(batch_size, -1, voxels['grid_size'][0], voxels['grid_size'][1]) 123 | voxel_feature1 = self.block1(voxel_feature) 124 | voxel_feature2 = self.block2(voxel_feature1) 125 | voxel_feature3 = self.block3(voxel_feature2) 126 | voxel_feature1 = self.up1(voxel_feature1) 127 | voxel_feature2 = self.up2(voxel_feature2) 128 | voxel_feature3 = self.up3(voxel_feature3) 129 | voxel_feature = torch.cat([voxel_feature1, voxel_feature2, voxel_feature3], dim=1) 130 | return self.conv_out(voxel_feature).transpose(3, 2), self.instance_conv_out(voxel_feature).transpose(3, 2), self.direction_conv_out(voxel_feature).transpose(3, 2) 131 | 132 | 133 | class PointPillarEncoder(nn.Module): 134 | def __init__(self, C, xbound, ybound, zbound): 135 | super(PointPillarEncoder, self).__init__() 136 | self.xbound = xbound 137 | self.ybound = ybound 138 | self.zbound = zbound 139 | self.pn = PointNet(15, 64) 140 | self.block1 = PillarBlock(64, dims=64, num_layers=2, stride=1) 141 | self.block2 = PillarBlock(64, dims=128, num_layers=3, stride=2) 142 | self.block3 = PillarBlock(128, 256, num_layers=3, stride=2) 143 | self.up1 = nn.Sequential( 144 | nn.Conv2d(64, 64, 3, padding=1, bias=False), 145 | nn.BatchNorm2d(64), 146 | nn.ReLU(inplace=True) 147 | ) 148 | self.up2 = nn.Sequential( 149 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 150 | nn.Conv2d(128, 128, 3, stride=1, padding=1, bias=False), 151 | nn.BatchNorm2d(128), 152 | nn.ReLU(inplace=True) 153 | ) 154 | self.up3 = nn.Sequential( 155 | nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True), 156 | nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False), 157 | nn.BatchNorm2d(256), 158 | nn.ReLU(inplace=True) 159 | ) 160 | self.conv_out = nn.Sequential( 161 | nn.Conv2d(448, 256, 3, padding=1, bias=False), 162 | nn.BatchNorm2d(256), 163 | nn.ReLU(inplace=True), 164 | nn.Conv2d(256, 128, 3, padding=1, bias=False), 165 | nn.BatchNorm2d(128), 166 | nn.ReLU(inplace=True), 167 | nn.Conv2d(128, C, 1), 168 | ) 169 | 170 | def forward(self, points, points_mask): 171 | points_xyz = points[:, :, :3] 172 | points_feature = points[:, :, 3:] 173 | voxels = points_to_voxels( 174 | points_xyz, points_mask, self.xbound, self.ybound, self.zbound 175 | ) 176 | points_feature = torch.cat( 177 | [points, # 5 178 | torch.unsqueeze(voxels['voxel_point_count'], dim=-1), # 1 179 | voxels['local_points_xyz'], # 3 180 | voxels['point_centroids'], # 3 181 | points_xyz - voxels['voxel_centers'], # 3 182 | ], dim=-1 183 | ) 184 | points_feature = self.pn(points_feature, voxels['points_mask']) 185 | voxel_feature = torch_scatter.scatter_mean( 186 | points_feature, 187 | torch.unsqueeze(voxels['voxel_indices'], dim=1), 188 | dim=2, 189 | dim_size=voxels['num_voxels']) 190 | batch_size = points.size(0) 191 | voxel_feature = voxel_feature.view(batch_size, -1, voxels['grid_size'][0], voxels['grid_size'][1]) 192 | voxel_feature1 = self.block1(voxel_feature) 193 | voxel_feature2 = self.block2(voxel_feature1) 194 | voxel_feature3 = self.block3(voxel_feature2) 195 | voxel_feature1 = self.up1(voxel_feature1) 196 | voxel_feature2 = self.up2(voxel_feature2) 197 | voxel_feature3 = self.up3(voxel_feature3) 198 | voxel_feature = torch.cat([voxel_feature1, voxel_feature2, voxel_feature3], dim=1) 199 | return self.conv_out(voxel_feature).transpose(3, 2) 200 | -------------------------------------------------------------------------------- /model/utils/position_encoding.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from .misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, x, mask): 29 | 30 | assert mask is not None 31 | not_mask = ~mask 32 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 33 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 34 | if self.normalize: 35 | eps = 1e-6 36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 38 | 39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 41 | 42 | pos_x = x_embed[:, :, :, None] / dim_t 43 | pos_y = y_embed[:, :, :, None] / dim_t 44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 47 | return pos 48 | 49 | 50 | class PositionEmbeddingLearned(nn.Module): 51 | """ 52 | Absolute pos embedding, learned. 53 | """ 54 | def __init__(self, num_pos_feats=256): 55 | super().__init__() 56 | self.row_embed = nn.Embedding(50, num_pos_feats) 57 | self.col_embed = nn.Embedding(50, num_pos_feats) 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | nn.init.uniform_(self.row_embed.weight) 62 | nn.init.uniform_(self.col_embed.weight) 63 | 64 | def forward(self, x): 65 | # x = tensor_list.tensors 66 | h, w = x.shape[-2:] 67 | i = torch.arange(w, device=x.device) 68 | j = torch.arange(h, device=x.device) 69 | x_emb = self.col_embed(i) 70 | y_emb = self.row_embed(j) 71 | pos = torch.cat([ 72 | x_emb.unsqueeze(0).repeat(h, 1, 1), 73 | y_emb.unsqueeze(1).repeat(1, w, 1), 74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 75 | return pos 76 | 77 | 78 | def build_position_encoding(args): 79 | N_steps = args.hidden_dim // 2 80 | if args.position_embedding in ('v2', 'sine'): 81 | # TODO find a better way of exposing other arguments 82 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 83 | elif args.position_embedding in ('v3', 'learned'): 84 | position_embedding = PositionEmbeddingLearned(N_steps) 85 | else: 86 | raise ValueError(f"not supported {args.position_embedding}") 87 | 88 | return position_embedding 89 | -------------------------------------------------------------------------------- /model/utils/sdmap_cross_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | """ 3 | DETR Transformer class. 4 | 5 | Copy-paste from torch.nn.Transformer with modifications: 6 | * positional encodings are passed in MHattention 7 | * extra LN at the end of encoder is removed 8 | * decoder returns a stack of activations from all decoding layers 9 | """ 10 | import copy 11 | from typing import Optional, List 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | 18 | class SDMapCrossAttn(nn.Module): 19 | 20 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=2, 21 | num_decoder_layers=2, dim_feedforward=192, dropout=0.1, 22 | activation="relu", normalize_before=False, 23 | return_intermediate_dec=False): 24 | super().__init__() 25 | 26 | self.return_intermediate = return_intermediate_dec 27 | self.norm = nn.LayerNorm(d_model) 28 | 29 | 30 | 31 | decoder_layer = SDMapCrossAttnLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation) 33 | 34 | self.layers = _get_clones(decoder_layer, num_decoder_layers) 35 | self.num_layers = num_decoder_layers 36 | 37 | self._reset_parameters() 38 | 39 | self.d_model = d_model 40 | self.nhead = nhead 41 | 42 | def _reset_parameters(self): 43 | for p in self.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | 48 | def forward(self, bev, sdmap, 49 | tgt_mask: Optional[Tensor] = None, 50 | memory_mask: Optional[Tensor] = None, 51 | tgt_key_padding_mask: Optional[Tensor] = None, 52 | memory_key_padding_mask: Optional[Tensor] = None, 53 | pos: Optional[Tensor] = None, 54 | query_pos: Optional[Tensor] = None): 55 | 56 | assert bev.shape == sdmap.shape 57 | bs, c, h, w = bev.shape 58 | bev = bev.flatten(2).permute(2, 0, 1) 59 | sdmap = sdmap.flatten(2).permute(2, 0, 1) 60 | pos = pos.flatten(2).permute(2, 0, 1) 61 | 62 | output = bev 63 | 64 | intermediate = [] 65 | 66 | for layer in self.layers: 67 | output = layer(output, sdmap, tgt_mask=tgt_mask, 68 | memory_mask=memory_mask, 69 | tgt_key_padding_mask=tgt_key_padding_mask, 70 | memory_key_padding_mask=memory_key_padding_mask, 71 | pos=pos, query_pos=query_pos) 72 | if self.return_intermediate: 73 | intermediate.append(self.norm(output)) 74 | 75 | if self.norm is not None: 76 | output = self.norm(output) 77 | if self.return_intermediate: 78 | intermediate.pop() 79 | intermediate.append(output) 80 | 81 | if self.return_intermediate: 82 | return torch.stack(intermediate) 83 | 84 | bew_feat = output.view(h,w,bs,c).permute(2,3,0,1) 85 | 86 | return bew_feat.unsqueeze(0) 87 | 88 | 89 | class SDMapCrossAttnLayer(nn.Module): 90 | def __init__(self, d_model, nhead, dim_feedforward=192, dropout=0.1, 91 | activation="relu"): 92 | super().__init__() 93 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 94 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 95 | # Implementation of Feedforward model 96 | self.linear1 = nn.Linear(d_model, dim_feedforward) 97 | self.dropout = nn.Dropout(dropout) 98 | self.linear2 = nn.Linear(dim_feedforward, d_model) 99 | 100 | self.norm1 = nn.LayerNorm(d_model) 101 | self.norm2 = nn.LayerNorm(d_model) 102 | self.norm3 = nn.LayerNorm(d_model) 103 | self.dropout1 = nn.Dropout(dropout) 104 | self.dropout2 = nn.Dropout(dropout) 105 | self.dropout3 = nn.Dropout(dropout) 106 | 107 | self.activation = _get_activation_fn(activation) 108 | 109 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 110 | return tensor if pos is None else tensor + pos 111 | 112 | def forward_post(self, bev, sdmap, 113 | bev_mask: 
Optional[Tensor] = None, 114 | sdmap_mask: Optional[Tensor] = None, 115 | tgt_key_padding_mask: Optional[Tensor] = None, 116 | sdmap_key_padding_mask: Optional[Tensor] = None, 117 | pos: Optional[Tensor] = None, 118 | query_pos: Optional[Tensor] = None): 119 | 120 | q = k = self.with_pos_embed(bev, pos) 121 | bev2 = self.self_attn(q, k, value=bev, attn_mask=sdmap_mask, 122 | key_padding_mask=tgt_key_padding_mask)[0] 123 | bev = bev + self.dropout1(bev2) 124 | bev = self.norm1(bev) 125 | 126 | bev2 = self.multihead_attn(query=self.with_pos_embed(bev, pos), 127 | key=self.with_pos_embed(sdmap, pos), 128 | # key=sdmap, 129 | value=sdmap, attn_mask=sdmap_mask, 130 | key_padding_mask=sdmap_key_padding_mask)[0] 131 | bev = bev + self.dropout2(bev2) 132 | bev = self.norm2(bev) 133 | 134 | bev2 = self.linear2(self.dropout(self.activation(self.linear1(bev)))) 135 | bev = bev + self.dropout3(bev2) 136 | bev = self.norm3(bev) 137 | 138 | return bev 139 | 140 | 141 | def forward(self, tgt, memory, 142 | tgt_mask: Optional[Tensor] = None, 143 | memory_mask: Optional[Tensor] = None, 144 | tgt_key_padding_mask: Optional[Tensor] = None, 145 | memory_key_padding_mask: Optional[Tensor] = None, 146 | pos: Optional[Tensor] = None, 147 | query_pos: Optional[Tensor] = None): 148 | 149 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 150 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 151 | 152 | 153 | def _get_clones(module, N): 154 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 155 | 156 | 157 | def build_transformer(args): 158 | return SDMapCrossAttn( 159 | d_model=args.hidden_dim, 160 | dropout=args.dropout, 161 | nhead=args.nheads, 162 | dim_feedforward=args.dim_feedforward, 163 | num_encoder_layers=args.enc_layers, 164 | num_decoder_layers=args.dec_layers, 165 | normalize_before=args.pre_norm, 166 | return_intermediate_dec=True, 167 | ) 168 | 169 | 170 | def _get_activation_fn(activation): 171 | """Return an activation function given a string""" 172 | if activation == "relu": 173 | return F.relu 174 | if activation == "gelu": 175 | return F.gelu 176 | if activation == "glu": 177 | return F.glu 178 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") -------------------------------------------------------------------------------- /model/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def plane_grid_2d(xbound, ybound): 4 | xmin, xmax = xbound[0], xbound[1] 5 | num_x = int((xbound[1] - xbound[0]) / xbound[2]) 6 | ymin, ymax = ybound[0], ybound[1] 7 | num_y = int((ybound[1] - ybound[0]) / ybound[2]) 8 | 9 | y = torch.linspace(xmin, xmax, num_x).cuda() 10 | x = torch.linspace(ymin, ymax, num_y).cuda() 11 | y, x = torch.meshgrid(x, y) 12 | x = x.flatten() 13 | y = y.flatten() 14 | 15 | coords = torch.stack([x, y], axis=0) 16 | return coords 17 | 18 | 19 | def cam_to_pixel(points, xbound, ybound): 20 | new_points = torch.zeros_like(points) 21 | new_points[..., 0] = (points[..., 0] - xbound[0]) / xbound[2] 22 | new_points[..., 1] = (points[..., 1] - ybound[0]) / ybound[2] 23 | return new_points 24 | 25 | 26 | def get_rot_2d(yaw): 27 | sin_yaw = torch.sin(yaw) 28 | cos_yaw = torch.cos(yaw) 29 | rot = torch.zeros(list(yaw.shape) + [2, 2]).cuda() 30 | rot[..., 0, 0] = cos_yaw 31 | rot[..., 0, 1] = sin_yaw 32 | rot[..., 1, 0] = -sin_yaw 33 | rot[..., 1, 1] = cos_yaw 34 | return rot 35 | 36 | 37 | 
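# --- Illustrative usage sketch (not part of the original utils.py) ---
# The three helpers above build a metric BEV plane grid, produce 2x2 rotation
# matrices from a yaw angle, and convert metric coordinates to pixel indices.
# Below is a minimal sketch of how they might compose; the bounds, yaw and
# translation values are hypothetical assumptions (not taken from the repo
# configs), and a CUDA device is assumed because the helpers call .cuda().
import torch

xbound = [-30.0, 30.0, 0.15]                 # [min, max, resolution] in metres (assumed)
ybound = [-15.0, 15.0, 0.15]
yaw = torch.tensor([0.3]).cuda()             # ego heading in radians (assumed)
trans = torch.tensor([[2.0], [1.0]]).cuda()  # ego x/y offset in metres (assumed)

coords = plane_grid_2d(xbound, ybound)       # (2, H*W) metric grid points
rot = get_rot_2d(yaw)[0]                     # (2, 2) rotation matrix for the yaw
warped = rot @ coords + trans                # grid expressed in the rotated/shifted frame
pixels = cam_to_pixel(warped.t(), xbound, ybound)  # (H*W, 2) pixel coordinates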
-------------------------------------------------------------------------------- /model/utils/voxel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch_scatter 4 | 5 | 6 | def pad_or_trim_to_np(x, shape, pad_val=0): 7 | shape = np.asarray(shape) 8 | pad = shape - np.minimum(np.shape(x), shape) 9 | zeros = np.zeros_like(pad) 10 | x = np.pad(x, np.stack([zeros, pad], axis=1), constant_values=pad_val) 11 | return x[:shape[0], :shape[1]] 12 | 13 | 14 | def raval_index(coords, dims): 15 | dims = torch.cat((dims, torch.ones(1, device=dims.device)), dim=0)[1:] 16 | dims = torch.flip(dims, dims=[0]) 17 | dims = torch.cumprod(dims, dim=0) / dims[0] 18 | multiplier = torch.flip(dims, dims=[0]) 19 | indices = torch.sum(coords * multiplier, dim=1) 20 | return indices 21 | 22 | 23 | def points_to_voxels( 24 | points_xyz, 25 | points_mask, 26 | grid_range_x, 27 | grid_range_y, 28 | grid_range_z 29 | ): 30 | batch_size, num_points, _ = points_xyz.shape 31 | voxel_size_x = grid_range_x[2] 32 | voxel_size_y = grid_range_y[2] 33 | voxel_size_z = grid_range_z[2] 34 | grid_size = np.asarray([ 35 | (grid_range_x[1]-grid_range_x[0]) / voxel_size_x, 36 | (grid_range_y[1]-grid_range_y[0]) / voxel_size_y, 37 | (grid_range_z[1]-grid_range_z[0]) / voxel_size_z 38 | ]).astype('int32') 39 | voxel_size = np.asarray([voxel_size_x, voxel_size_y, voxel_size_z]) 40 | voxel_size = torch.Tensor(voxel_size).to(points_xyz.device) 41 | num_voxels = grid_size[0] * grid_size[1] * grid_size[2] 42 | grid_offset = torch.Tensor([grid_range_x[0], grid_range_y[0], grid_range_z[0]]).to(points_xyz.device) 43 | shifted_points_xyz = points_xyz - grid_offset 44 | voxel_xyz = shifted_points_xyz / voxel_size 45 | voxel_coords = voxel_xyz.int() 46 | grid_size = torch.from_numpy(grid_size).to(points_xyz.device) 47 | grid_size = grid_size.int() 48 | zeros = torch.zeros_like(grid_size) 49 | voxel_paddings = ((points_mask < 1.0) | 50 | torch.any((voxel_coords >= grid_size) | 51 | (voxel_coords < zeros), dim=-1)) 52 | voxel_indices = raval_index( 53 | torch.reshape(voxel_coords, [batch_size * num_points, 3]), grid_size) 54 | voxel_indices = torch.reshape(voxel_indices, [batch_size, num_points]) 55 | voxel_indices = torch.where(voxel_paddings, 56 | torch.zeros_like(voxel_indices), 57 | voxel_indices) 58 | voxel_centers = ((0.5 + voxel_coords.float()) * voxel_size + grid_offset) 59 | voxel_coords = torch.where(torch.unsqueeze(voxel_paddings, dim=-1), 60 | torch.zeros_like(voxel_coords), 61 | voxel_coords) 62 | voxel_xyz = torch.where(torch.unsqueeze(voxel_paddings, dim=-1), 63 | torch.zeros_like(voxel_xyz), 64 | voxel_xyz) 65 | voxel_paddings = voxel_paddings.float() 66 | 67 | voxel_indices = voxel_indices.long() 68 | points_per_voxel = torch_scatter.scatter_sum( 69 | torch.ones((batch_size, num_points), dtype=voxel_coords.dtype, device=voxel_coords.device) * (1-voxel_paddings), 70 | voxel_indices, 71 | dim=1, 72 | dim_size=num_voxels 73 | ) 74 | 75 | voxel_point_count = torch.gather(points_per_voxel, 76 | dim=1, 77 | index=voxel_indices) 78 | 79 | 80 | voxel_centroids = torch_scatter.scatter_mean( 81 | points_xyz, 82 | voxel_indices, 83 | dim=1, 84 | dim_size=num_voxels) 85 | point_centroids = torch.gather(voxel_centroids, dim=1, index=torch.unsqueeze(voxel_indices, dim=-1).repeat(1, 1, 3)) 86 | local_points_xyz = points_xyz - point_centroids 87 | 88 | result = { 89 | 'local_points_xyz': local_points_xyz, 90 | 'shifted_points_xyz': shifted_points_xyz, 91 | 
'point_centroids': point_centroids, 92 | 'points_xyz': points_xyz, 93 | 'grid_offset': grid_offset, 94 | 'voxel_coords': voxel_coords, 95 | 'voxel_centers': voxel_centers, 96 | 'voxel_indices': voxel_indices, 97 | 'voxel_paddings': voxel_paddings, 98 | 'points_mask': 1 - voxel_paddings, 99 | 'num_voxels': num_voxels, 100 | 'grid_size': grid_size, 101 | 'voxel_xyz': voxel_xyz, 102 | 'voxel_size': voxel_size, 103 | 'voxel_point_count': voxel_point_count, 104 | 'points_per_voxel': points_per_voxel 105 | } 106 | 107 | 108 | return result 109 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.0.0 2 | addict==2.4.0 3 | attrs==23.1.0 4 | black==23.11.0 5 | cachetools==5.3.2 6 | certifi==2023.11.17 7 | charset-normalizer==3.3.2 8 | click==8.1.7 9 | click-plugins==1.1.1 10 | cligj==0.7.2 11 | contourpy==1.1.1 12 | cycler==0.12.1 13 | descartes==1.1.0 14 | efficientnet-pytorch==0.7.1 15 | exceptiongroup==1.2.0 16 | filelock==3.13.1 17 | fiona==1.9.5 18 | fire==0.5.0 19 | flake8==6.1.0 20 | fonttools==4.45.1 21 | fsspec==2023.12.1 22 | geopandas==0.13.2 23 | google-auth==2.23.4 24 | google-auth-oauthlib==1.0.0 25 | grpcio==1.59.3 26 | huggingface-hub==0.19.4 27 | idna==3.6 28 | imageio==2.33.0 29 | importlib-metadata==6.8.0 30 | importlib-resources==6.1.1 31 | iniconfig==2.0.0 32 | ipython==8.12.2 33 | jedi==0.19.1 34 | joblib==1.3.2 35 | kiwisolver==1.4.5 36 | llvmlite==0.36.0 37 | lyft-dataset-sdk==0.0.8 38 | markdown==3.5.1 39 | markupsafe==2.1.3 40 | matplotlib==3.5.3 41 | mccabe==0.7.0 42 | mmcls==0.25.0 43 | mmcv-full==1.6.0 44 | mmdet==2.28.2 45 | mmsegmentation==0.30.0 46 | mypy-extensions==1.0.0 47 | networkx==2.2 48 | numba==0.53.0 49 | numpy==1.23.5 50 | nuscenes-devkit==1.1.11 51 | nvidia-ml-py==12.535.133 52 | nvitop==1.3.1 53 | oauthlib==3.2.2 54 | opencv-python==4.8.1.78 55 | packaging==23.2 56 | pandas==2.0.3 57 | pathspec==0.11.2 58 | pillow==10.1.0 59 | platformdirs==4.0.0 60 | plotly==5.18.0 61 | pluggy==1.3.0 62 | plyfile==1.0.2 63 | prettytable==3.9.0 64 | prompt-toolkit==3.0.41 65 | protobuf==4.25.1 66 | psutil==5.9.6 67 | pyasn1==0.5.1 68 | pyasn1-modules==0.3.0 69 | pycocotools==2.0.7 70 | pycodestyle==2.11.1 71 | pyflakes==3.1.0 72 | pyparsing==3.1.1 73 | pyproj==3.5.0 74 | pyquaternion==0.9.9 75 | pytest==7.4.3 76 | python-dateutil==2.8.2 77 | pytz==2023.3.post1 78 | pywavelets==1.4.1 79 | pyyaml==6.0.1 80 | requests==2.31.0 81 | requests-oauthlib==1.3.1 82 | rsa==4.9 83 | safetensors==0.4.1 84 | scikit-image==0.19.3 85 | scikit-learn==1.3.2 86 | scipy==1.10.1 87 | shapely==1.8.5.post1 88 | some-package==0.1 89 | tenacity==8.2.3 90 | tensorboard==2.13.0 91 | tensorboard-data-server==0.7.2 92 | tensorboardx==2.6.2.2 93 | termcolor==2.3.0 94 | terminaltables==3.1.10 95 | threadpoolctl==3.2.0 96 | tifffile==2023.7.10 97 | timm==0.9.12 98 | tomli==2.0.1 99 | torch-scatter==2.0.9 100 | tqdm==4.66.1 101 | trimesh==2.35.39 102 | tzdata==2023.3 103 | urllib3==2.1.0 104 | wcwidth==0.2.12 105 | werkzeug==3.0.1 106 | yapf==0.40.2 107 | zipp==3.17.0 -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tqdm 3 | import os 4 | import sys 5 | currentPath = os.path.split(os.path.realpath(__file__))[0] 6 | sys.path.append(currentPath + '/..') 7 | import torch 8 | from tools.config import Config 
9 | from tools.evaluation.iou import get_batch_iou 10 | from tools.evaluation import lpips 11 | from data_osm.dataset import semantic_dataset 12 | from data_osm.const import NUM_CLASSES 13 | from model import get_model 14 | from tools.postprocess.vectorize import vectorize 15 | from collections import OrderedDict 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import warnings 19 | warnings.filterwarnings("ignore") 20 | 21 | def onehot_encoding(logits, dim=1): 22 | max_idx = torch.argmax(logits, dim, keepdim=True) 23 | one_hot = logits.new_full(logits.shape, 0) 24 | one_hot.scatter_(dim, max_idx, 1) 25 | return one_hot 26 | 27 | # eval only pre-train mae 28 | def eval_pretrain(bevencode_bd, val_loader): 29 | bevencode_bd.eval() 30 | total_intersects = 0 31 | total_union = 0 32 | 33 | with torch.no_grad(): 34 | total_epe = 0 35 | index = 0 36 | for (imgs, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans, 37 | yaw_pitch_roll, semantic_gt, instance_gt, direction_gt,osm_masks, 38 | osm_vectors, masked_map, timestamp, scene_id) in tqdm.tqdm(val_loader): 39 | 40 | semantic, embedding, direction = bevencode_bd(masked_map.cuda().float()) 41 | semantic_gt = semantic_gt.cuda().float() 42 | intersects, union = get_batch_iou(onehot_encoding(semantic.cuda()), semantic_gt) 43 | total_intersects += intersects 44 | total_union += union 45 | index = index + 1 46 | return total_intersects / (total_union + 1e-7) 47 | 48 | 49 | def eval_iou(model, val_loader): 50 | model.eval() 51 | total_intersects = 0 52 | total_union = 0 53 | with torch.no_grad(): 54 | for (imgs, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans, 55 | yaw_pitch_roll, semantic_gt, instance_gt, direction_gt,osm_masks, 56 | osm_vectors, masked_map, timestamp, scene_id) in tqdm.tqdm(val_loader): 57 | 58 | semantic, embedding, direction = model(imgs.cuda(), trans.cuda(), rots.cuda(), intrins.cuda(), 59 | post_trans.cuda(), post_rots.cuda(), lidar_data.cuda(), 60 | lidar_mask.cuda(), car_trans.cuda(), yaw_pitch_roll.cuda(), osm_masks.float().cuda()) 61 | 62 | semantic_gt = semantic_gt.cuda().float() 63 | device = semantic_gt.device 64 | if semantic.device != device: 65 | semantic = semantic.to(device) 66 | embedding = embedding.to(device) 67 | direction = direction.to(device) 68 | 69 | intersects, union = get_batch_iou(onehot_encoding(semantic), semantic_gt) 70 | total_intersects += intersects 71 | total_union += union 72 | return total_intersects / (total_union + 1e-7) 73 | 74 | 75 | def eval_all(model, val_loader): 76 | model.eval() 77 | total_intersects = 0 78 | total_union = 0 79 | i=0 80 | lpipss1 = [] 81 | with torch.no_grad(): 82 | for imgs, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans, yaw_pitch_roll, semantic_gt, instance_gt, direction_gt,osm_masks, osm_vectors, masks_bd_osm, mask_bd, timestamp, scene_ids in tqdm.tqdm(val_loader): 83 | 84 | 85 | semantic, embedding, direction = model(imgs.cuda(), trans.cuda(), rots.cuda(), intrins.cuda(), 86 | post_trans.cuda(), post_rots.cuda(), lidar_data.cuda(), 87 | lidar_mask.cuda(), car_trans.cuda(), yaw_pitch_roll.cuda(), osm_masks.float().cuda()) 88 | 89 | gt = semantic_gt[:,1:4,:,:].clone().cuda() 90 | pred = semantic[:,1:4,:,:].clone().cuda() 91 | 92 | lpipss1.append(lpips(pred.float(), gt.float(), net_type='alex') / pred.shape[0]) 93 | 94 | semantic_gt = semantic_gt.cuda().float() 95 | 96 | device = semantic_gt.device 97 | if semantic.device != device: 98 | semantic = 
semantic.to(device) 99 | intersects, union = get_batch_iou(onehot_encoding(semantic), semantic_gt) 100 | total_intersects += intersects 101 | total_union += union 102 | i+=1 103 | print(" LPIPS1: {:>12.7f}".format(torch.tensor(lpipss1).mean(), ".5")) 104 | print(" IOU: {:>12.7f}".format((total_intersects / (total_union + 1e-7))), ".5") 105 | # return (total_intersects / (total_union + 1e-7)) 106 | 107 | 108 | def main(args): 109 | data_conf = { 110 | 'num_channels': NUM_CLASSES + 1, 111 | 'image_size': cfg.image_size, 112 | 'xbound': cfg.xbound, 113 | 'ybound': cfg.ybound, 114 | 'zbound': cfg.zbound, 115 | 'dbound': cfg.dbound, 116 | 'thickness': cfg.thickness, 117 | 'angle_class': cfg.angle_class, 118 | 'patch_w': cfg.patch_w, 119 | 'patch_h': cfg.patch_h, 120 | 'mask_ratio': cfg.mask_ratio, 121 | 'mask_flag': cfg.mask_flag, 122 | 'sd_map_path': cfg.sd_map_path, 123 | } 124 | 125 | train_loader, val_loader = semantic_dataset(args, args.version, args.dataroot, data_conf, 126 | args.batch_size, args.nworkers, cfg.dataset) 127 | model = get_model(args, data_conf, args.instance_seg, args.embedding_dim, args.direction_pred, args.angle_class) 128 | 129 | state_dict_model = torch.load(args.modelf) 130 | new_state_dict = OrderedDict() 131 | for k, v in state_dict_model.items(): 132 | name = k[7:] 133 | new_state_dict[name] = v 134 | model.load_state_dict(new_state_dict) 135 | model.cuda() 136 | 137 | if "pretrain" in str(args.config): 138 | print(eval_pretrain(model, val_loader)) 139 | else: 140 | print(eval_iou(model, val_loader)) 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser(description='Evaluate HDMap Construction Results..') 145 | parser.add_argument("config", help = 'path to config file', type=str, default=None) 146 | args = parser.parse_args() 147 | cfg = Config.fromfile(args.config) 148 | 149 | main(cfg) 150 | -------------------------------------------------------------------------------- /tools/evaluate_json.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | import argparse 4 | from config import Config 5 | import sys 6 | import os 7 | currentPath = os.path.split(os.path.realpath(__file__))[0] 8 | sys.path.append(currentPath + '/..') 9 | from tools.evaluation.dataset import PMapNetEvalDataset 10 | from tools.evaluation.chamfer_distance import semantic_mask_chamfer_dist_cum 11 | from tools.evaluation.AP import instance_mask_AP 12 | from tools.evaluation.iou import get_batch_iou 13 | 14 | SAMPLED_RECALLS = torch.linspace(0.1, 1, 10) 15 | # THRESHOLDS = [0.2, 0.5, 1.0] 16 | THRESHOLDS = [0.5, 1.0, 1.5] 17 | 18 | def get_val_info(args): 19 | data_conf = { 20 | 'xbound': args.xbound, 21 | 'ybound': args.ybound, 22 | 'thickness': args.thickness, 23 | 'sd_map_path': args.sd_map_path 24 | } 25 | 26 | dataset = PMapNetEvalDataset( 27 | args.version, args.dataroot, 'val', args.result_path, data_conf) 28 | 29 | data_loader = torch.utils.data.DataLoader( 30 | dataset, batch_size=args.batch_size, shuffle=False, drop_last=False) 31 | 32 | total_CD1 = torch.zeros(args.max_channel).cuda() 33 | total_CD2 = torch.zeros(args.max_channel).cuda() 34 | total_CD_num1 = torch.zeros(args.max_channel).cuda() 35 | total_CD_num2 = torch.zeros(args.max_channel).cuda() 36 | total_intersect = torch.zeros(args.max_channel).cuda() 37 | total_union = torch.zeros(args.max_channel).cuda() 38 | AP_matrix = torch.zeros((args.max_channel, len(THRESHOLDS))).cuda() 39 | AP_count_matrix = torch.zeros((args.max_channel, 
len(THRESHOLDS))).cuda() 40 | 41 | 42 | print('running eval...') 43 | for pred_map, confidence_level, gt_map in tqdm.tqdm(data_loader): 44 | 45 | pred_map = pred_map.cuda() 46 | confidence_level = confidence_level.cuda() 47 | gt_map = gt_map.cuda() 48 | 49 | 50 | intersect, union = get_batch_iou(pred_map, gt_map) 51 | CD1, CD2, num1, num2 = semantic_mask_chamfer_dist_cum( 52 | pred_map, gt_map, args.xbound[2], args.ybound[2], threshold=args.CD_threshold) 53 | 54 | instance_mask_AP(AP_matrix, AP_count_matrix, pred_map, gt_map, args.xbound[2], args.ybound[2], 55 | confidence_level, THRESHOLDS, sampled_recalls=SAMPLED_RECALLS, bidirectional=args.bidirectional, threshold_iou=args.threshold_iou) 56 | 57 | total_intersect += intersect.cuda() 58 | total_union += union.cuda() 59 | total_CD1 += CD1 60 | total_CD2 += CD2 61 | total_CD_num1 += num1 62 | total_CD_num2 += num2 63 | 64 | 65 | CD_pred = total_CD1 / total_CD_num1 66 | CD_label = total_CD2 / total_CD_num2 67 | CD = (total_CD1 + total_CD2) / (total_CD_num1 +total_CD_num2) 68 | AP = AP_matrix / AP_count_matrix 69 | 70 | return { 71 | 'iou': total_intersect / total_union, 72 | 'CD_pred': CD_pred, 73 | 'CD_label': CD_label, 74 | 'CD': CD, 75 | 'AP': AP, 76 | } 77 | 78 | 79 | if __name__ == '__main__': 80 | parser = argparse.ArgumentParser(description='Evaluate Vectorized HDMap Construction Results.') 81 | parser.add_argument("config", help = 'path to config file', type=str, default=None) 82 | 83 | args = parser.parse_args() 84 | cfg = Config.fromfile(args.config) 85 | 86 | print(get_val_info(cfg)) 87 | 88 | -------------------------------------------------------------------------------- /tools/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | def lpips(x: torch.Tensor, 6 | y: torch.Tensor, 7 | net_type: str = 'alex', 8 | version: str = '0.1'): 9 | r"""Function that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | x, y (torch.Tensor): the input tensors to compare. 14 | net_type (str): the network type to compare the features: 15 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 16 | version (str): the version of LPIPS. Default: 0.1. 
17 | """ 18 | device = x.device 19 | criterion = LPIPS(net_type, version).to(device) 20 | return criterion(x, y) -------------------------------------------------------------------------------- /tools/evaluation/angle_diff.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def onehot_encoding_spread(logits, dim=1): 5 | max_idx = torch.argmax(logits, dim, keepdim=True) 6 | one_hot = logits.new_full(logits.shape, 0) 7 | one_hot.scatter_(dim, max_idx, 1) 8 | one_hot.scatter_(dim, torch.clamp(max_idx-1, min=0), 1) 9 | one_hot.scatter_(dim, torch.clamp(max_idx-2, min=0), 1) 10 | one_hot.scatter_(dim, torch.clamp(max_idx+1, max=logits.shape[dim]-1), 1) 11 | one_hot.scatter_(dim, torch.clamp(max_idx+2, max=logits.shape[dim]-1), 1) 12 | 13 | return one_hot 14 | 15 | 16 | def get_pred_top2_direction(direction, dim=1): 17 | direction = torch.softmax(direction, dim) 18 | idx1 = torch.argmax(direction, dim) 19 | idx1_onehot_spread = onehot_encoding_spread(direction, dim) 20 | idx1_onehot_spread = idx1_onehot_spread.bool() 21 | direction[idx1_onehot_spread] = 0 22 | idx2 = torch.argmax(direction, dim) 23 | direction = torch.stack([idx1, idx2], dim) - 1 24 | return direction 25 | 26 | 27 | def calc_angle_diff(pred_mask, gt_mask, angle_class): 28 | per_angle = float(360. / angle_class) 29 | eval_mask = 1 - gt_mask[:, 0] 30 | pred_direction = get_pred_top2_direction(pred_mask, dim=1).float() 31 | gt_direction = (torch.topk(gt_mask, 2, dim=1)[1] - 1).float() 32 | 33 | pred_direction *= per_angle 34 | gt_direction *= per_angle 35 | pred_direction = pred_direction[:, :, None, :, :].repeat(1, 1, 2, 1, 1) 36 | gt_direction = gt_direction[:, None, :, :, :].repeat(1, 2, 1, 1, 1) 37 | diff_mask = torch.abs(pred_direction - gt_direction) 38 | diff_mask = torch.min(diff_mask, 360 - diff_mask) 39 | diff_mask = torch.min(diff_mask[:, 0, 0] + diff_mask[:, 1, 1], diff_mask[:, 1, 0] + diff_mask[:, 0, 1]) / 2 40 | return ((eval_mask * diff_mask).sum() / (eval_mask.sum() + 1e-6)).item() 41 | -------------------------------------------------------------------------------- /tools/evaluation/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def chamfer_distance(source_pc, target_pc, threshold, cum=False, bidirectional=True): 5 | dist = torch.cdist(source_pc.float(), target_pc.float()) 6 | dist1, _ = torch.min(dist, 2) 7 | dist2, _ = torch.min(dist, 1) 8 | if cum: 9 | len1 = dist1.shape[-1] 10 | len2 = dist2.shape[-1] 11 | dist1 = dist1.sum(-1) 12 | dist2 = dist2.sum(-1) 13 | return dist1, dist2, len1, len2 14 | dist1 = dist1.mean(-1) 15 | dist2 = dist2.mean(-1) 16 | if bidirectional: 17 | return min((dist1 + dist2) / 2, threshold) 18 | else: 19 | #return min(dist1, threshold), min(dist2, threshold) 20 | return min(dist1, threshold) 21 | 22 | 23 | def semantic_mask_chamfer_dist_cum(seg_pred, seg_label, scale_x, scale_y, threshold): 24 | # seg_label: N, C, H, W 25 | # seg_pred: N, C, H, W 26 | N, C, H, W = seg_label.shape 27 | 28 | cum_CD1 = torch.zeros(C, device=seg_label.device) 29 | cum_CD2 = torch.zeros(C, device=seg_label.device) 30 | cum_num1 = torch.zeros(C, device=seg_label.device) 31 | cum_num2 = torch.zeros(C, device=seg_label.device) 32 | for n in range(N): 33 | for c in range(C): 34 | pred_pc_x, pred_pc_y = torch.where(seg_pred[n, c] != 0) 35 | label_pc_x, label_pc_y = torch.where(seg_label[n, c] != 0) 36 | pred_pc_x = pred_pc_x.float() * scale_x 37 | pred_pc_y = 
pred_pc_y.float() * scale_y 38 | label_pc_x = label_pc_x.float() * scale_x 39 | label_pc_y = label_pc_y.float() * scale_y 40 | if len(pred_pc_x) == 0 and len(label_pc_x) == 0: 41 | continue 42 | 43 | if len(label_pc_x) == 0: 44 | cum_CD1[c] += len(pred_pc_x) * threshold 45 | cum_num1[c] += len(pred_pc_x) 46 | continue 47 | 48 | if len(pred_pc_x) == 0: 49 | cum_CD2[c] += len(label_pc_x) * threshold 50 | cum_num2[c] += len(label_pc_x) 51 | continue 52 | 53 | pred_pc_coords = torch.stack([pred_pc_x, pred_pc_y], -1).float() 54 | label_pc_coords = torch.stack([label_pc_x, label_pc_y], -1).float() 55 | CD1, CD2, len1, len2 = chamfer_distance(pred_pc_coords[None], label_pc_coords[None], threshold=threshold, cum=True) 56 | cum_CD1[c] += CD1.item() 57 | cum_CD2[c] += CD2.item() 58 | cum_num1[c] += len1 59 | cum_num2[c] += len2 60 | return cum_CD1, cum_CD2, cum_num1, cum_num2 61 | 62 | def semantic_mask_turn_cal(seg_pred, seg_label, scale_x, scale_y, threshold): 63 | # seg_label: N, C, H, W 64 | # seg_pred: N, C, H, W 65 | N, C, H, W = seg_label.shape 66 | print("N: ", N) 67 | print("C: ", C) 68 | 69 | for n in range(N): 70 | for c in range(C): 71 | label_pc_x, label_pc_y = torch.where(seg_label[n, c] != 0) 72 | label_pc_x = label_pc_x.float() * scale_x 73 | label_pc_y = label_pc_y.float() * scale_y 74 | if len(label_pc_x) == 0: 75 | continue 76 | 77 | label_pc_coords = torch.stack([label_pc_x, label_pc_y], -1).float() 78 | print("label_pc_coords.shape: ", label_pc_coords.shape) 79 | print("label_pc_coords[0]: ", label_pc_coords[0]) 80 | print("label_pc_coords[-1]: ", label_pc_coords[-1]) 81 | # CD1, CD2, len1, len2 = chamfer_distance(pred_pc_coords[None], label_pc_coords[None], threshold=threshold, cum=True) 82 | # cum_CD1[c] += CD1.item() 83 | # cum_CD2[c] += CD2.item() 84 | # cum_num1[c] += len1 85 | # cum_num2[c] += len2 86 | # return cum_CD1, cum_CD2, cum_num1, cum_num2 87 | return None -------------------------------------------------------------------------------- /tools/evaluation/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from data_osm.dataset import PMapNetDataset 7 | from data_osm.rasterize import rasterize_map 8 | from data_osm.const import NUM_CLASSES 9 | from nuscenes.utils.splits import create_splits_scenes 10 | 11 | 12 | class PMapNetEvalDataset(PMapNetDataset): 13 | def __init__(self, version, dataroot, eval_set, result_path, data_conf, max_line_count=300): 14 | self.eval_set = eval_set 15 | super(PMapNetEvalDataset, self).__init__(version, dataroot, data_conf, is_train=False) 16 | with open(result_path, 'r') as f: 17 | self.prediction = json.load(f) 18 | self.max_line_count = max_line_count 19 | self.thickness = data_conf['thickness'] 20 | 21 | def get_scenes(self, version, is_train): 22 | return create_splits_scenes()[self.eval_set] 23 | 24 | def __len__(self): 25 | return len(self.samples) 26 | 27 | def __getitem__(self, idx): 28 | rec = self.samples[idx] 29 | location = self.nusc.get('log', self.nusc.get('scene', rec['scene_token'])['log_token'])['location'] 30 | ego_pose = self.nusc.get('ego_pose', self.nusc.get('sample_data', rec['data']['LIDAR_TOP'])['ego_pose_token']) 31 | gt_vectors, polygon_geom, osm_vectors = self.vector_map.gen_vectorized_samples(location, ego_pose['translation'], ego_pose['rotation']) 32 | # import pdb; pdb.set_trace() 33 | gt_map, _ = rasterize_map(gt_vectors, self.patch_size, self.canvas_size, NUM_CLASSES, self.thickness) 34 | if 
self.prediction['meta']['vector']: 35 | pred_vectors = self.prediction['results'][rec['token']] 36 | pred_map, confidence_level = rasterize_map(pred_vectors, self.patch_size, self.canvas_size, NUM_CLASSES, self.thickness) 37 | else: 38 | pred_map = np.array(self.prediction['results'][rec['token']]['map']) 39 | confidence_level = self.prediction['results'][rec['token']]['confidence_level'] 40 | 41 | confidence_level = torch.tensor(confidence_level + [-1] * (self.max_line_count - len(confidence_level))) 42 | 43 | return pred_map, confidence_level, gt_map 44 | -------------------------------------------------------------------------------- /tools/evaluation/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_batch_iou(pred_map, gt_map): 5 | intersects = [] 6 | unions = [] 7 | with torch.no_grad(): 8 | pred_map = pred_map.bool() 9 | gt_map = gt_map.bool() 10 | 11 | for i in range(pred_map.shape[1]): 12 | pred = pred_map[:, i] 13 | tgt = gt_map[:, i] 14 | # import pdb; pdb.set_trace() 15 | intersect = (pred & tgt).sum().float() 16 | union = (pred | tgt).sum().float() 17 | intersects.append(intersect) 18 | unions.append(union) 19 | return torch.tensor(intersects), torch.tensor(unions) 20 | -------------------------------------------------------------------------------- /tools/evaluation/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 
16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /tools/evaluation/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | 
self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /tools/evaluation/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /tools/export_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tqdm 3 | import torch 4 | import mmcv 5 | from config import Config 6 | import sys 7 | import os 8 | currentPath = os.path.split(os.path.realpath(__file__))[0] 9 | sys.path.append(currentPath + '/..') 10 | from data_osm.dataset import semantic_dataset 11 | from data_osm.const import NUM_CLASSES 12 | from model import get_model 13 | from postprocess.vectorize import vectorize 14 | from collections import OrderedDict 15 | from tools.evaluation.iou import get_batch_iou 16 | import warnings 17 | warnings.filterwarnings("ignore") 18 | import matplotlib.pyplot as plt 19 | import os 20 | from PIL import Image 21 | 22 | 23 | 24 | def gen_dx_bx(xbound, ybound): 25 | dx = [row[2] for row in [xbound, ybound]] 26 | bx = [row[0] + row[2] / 2.0 for row in [xbound, ybound]] 27 | nx = [(row[1] - row[0]) / row[2] for row in [xbound, ybound]] 28 | return dx, bx, nx 29 | def onehot_encoding(logits, dim=1): 30 | max_idx = torch.argmax(logits, dim, keepdim=True) 31 | one_hot = logits.new_full(logits.shape, 0) 32 | one_hot.scatter_(dim, max_idx, 1) 33 | return one_hot 34 | 35 | def export_to_json(model, val_loader, angle_class, args): 36 | submission = { 37 | "meta": { 38 | "use_camera": True, 39 | "use_lidar": False, 40 | "use_radar": False, 41 | "use_external": False, 42 | "vector": True, 43 | }, 44 | "results": {} 45 | } # todo: add mode 46 | 47 | dx, bx, nx = gen_dx_bx(args.xbound, args.ybound) 48 | count = 0 49 | model.eval() 50 | with torch.no_grad(): 51 | for batchi, (imgs, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans, 52 | yaw_pitch_roll, semantic_gt, instance_gt, direction_gt, osm_masks, osm_vectors, masked_map, timestamp,scene_id) in enumerate(tqdm.tqdm(val_loader)): 53 | 54 | segmentation, embedding, direction = model(imgs.cuda(), trans.cuda(), rots.cuda(), intrins.cuda(), 55 | post_trans.cuda(), post_rots.cuda(), lidar_data.cuda(), 56 | lidar_mask.cuda(), car_trans.cuda(), yaw_pitch_roll.cuda(), osm_masks.float().cuda()) 57 | 58 | for si in range(segmentation.shape[0]): 59 | coords, confidences, line_types = vectorize(segmentation[si], embedding[si], direction[si], angle_class) 60 | count += 1 61 | vectors = [] 62 | for 
coord, confidence, line_type in zip(coords, confidences, line_types): 63 | vector = {'pts': coord * dx + bx, 'pts_num': len(coord), "type": line_type, "confidence_level": confidence} 64 | vectors.append(vector) 65 | rec = val_loader.dataset.samples[batchi * val_loader.batch_size + si] 66 | submission['results'][rec['token']] = vectors 67 | mmcv.dump(submission, args.result_path) 68 | 69 | 70 | def main(args): 71 | data_conf = { 72 | 'num_channels': NUM_CLASSES + 1, 73 | 'image_size': cfg.image_size, 74 | 'xbound': cfg.xbound, 75 | 'ybound': cfg.ybound, 76 | 'zbound': cfg.zbound, 77 | 'dbound': cfg.dbound, 78 | 'thickness': cfg.thickness, 79 | 'angle_class': cfg.angle_class, 80 | 'patch_w': cfg.patch_w, 81 | 'patch_h': cfg.patch_h, 82 | 'mask_ratio': cfg.mask_ratio, 83 | 'mask_flag': cfg.mask_flag, 84 | 'sd_map_path': cfg.sd_map_path, 85 | } 86 | 87 | train_loader, val_loader = semantic_dataset(args, args.version, args.dataroot, data_conf, 88 | args.batch_size, args.nworkers, cfg.dataset) 89 | model = get_model(args, data_conf, args.instance_seg, args.embedding_dim, args.direction_pred, args.angle_class) 90 | # import pdb; pdb.set_trace() 91 | state_dict_model_120 = torch.load(args.modelf) 92 | new_state_dict_120 = OrderedDict() 93 | for k, v in state_dict_model_120.items(): 94 | name = k[7:] 95 | new_state_dict_120[name] = v 96 | model.load_state_dict(new_state_dict_120, strict=True) 97 | model.cuda() 98 | 99 | export_to_json(model, val_loader, args.angle_class, args) 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser(description='Export vector results to json.') 104 | parser.add_argument("config", help = 'path to config file', type=str, default=None) 105 | args = parser.parse_args() 106 | cfg = Config.fromfile(args.config) 107 | print("cfg: ", cfg) 108 | main(cfg) 109 | -------------------------------------------------------------------------------- /tools/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FocalLoss(nn.Module): 7 | def __init__(self, alpha=1, gamma=2, reduce='mean'): 8 | super(FocalLoss, self).__init__() 9 | self.alpha = alpha 10 | self.gamma = gamma 11 | self.reduce = reduce 12 | 13 | def forward(self, inputs, targets): 14 | BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False) 15 | pt = torch.exp(-BCE_loss) 16 | F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss 17 | 18 | if self.reduce == 'mean': 19 | return torch.mean(F_loss) 20 | elif self.reduce == 'sum': 21 | return torch.sum(F_loss) 22 | else: 23 | raise NotImplementedError 24 | 25 | 26 | class SimpleLoss(torch.nn.Module): 27 | def __init__(self, pos_weight): 28 | super(SimpleLoss, self).__init__() 29 | self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([pos_weight])) 30 | 31 | def forward(self, ypred, ytgt): 32 | loss = self.loss_fn(ypred, ytgt) 33 | return loss 34 | 35 | 36 | class DiscriminativeLoss(nn.Module): 37 | def __init__(self, embed_dim, delta_v, delta_d): 38 | super(DiscriminativeLoss, self).__init__() 39 | self.embed_dim = embed_dim 40 | self.delta_v = delta_v 41 | self.delta_d = delta_d 42 | 43 | def forward(self, embedding, seg_gt): 44 | if embedding is None: 45 | return 0, 0, 0 46 | bs = embedding.shape[0] 47 | 48 | var_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device) 49 | dist_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device) 50 | reg_loss = 
torch.tensor(0, dtype=embedding.dtype, device=embedding.device) 51 | 52 | for b in range(bs): 53 | embedding_b = embedding[b] # (embed_dim, H, W) 54 | seg_gt_b = seg_gt[b] 55 | 56 | labels = torch.unique(seg_gt_b) 57 | labels = labels[labels != 0] 58 | num_lanes = len(labels) 59 | if num_lanes == 0: 60 | # please refer to issue here: https://github.com/harryhan618/LaneNet/issues/12 61 | _nonsense = embedding.sum() 62 | _zero = torch.zeros_like(_nonsense) 63 | var_loss = var_loss + _nonsense * _zero 64 | dist_loss = dist_loss + _nonsense * _zero 65 | reg_loss = reg_loss + _nonsense * _zero 66 | continue 67 | 68 | centroid_mean = [] 69 | for lane_idx in labels: 70 | seg_mask_i = (seg_gt_b == lane_idx) 71 | if not seg_mask_i.any(): 72 | continue 73 | embedding_i = embedding_b[:, seg_mask_i] 74 | 75 | mean_i = torch.mean(embedding_i, dim=1) 76 | centroid_mean.append(mean_i) 77 | 78 | # ---------- var_loss ------------- 79 | var_loss = var_loss + torch.mean(F.relu(torch.norm(embedding_i-mean_i.reshape(self.embed_dim, 1), dim=0) - self.delta_v) ** 2) / num_lanes 80 | centroid_mean = torch.stack(centroid_mean) # (n_lane, embed_dim) 81 | 82 | if num_lanes > 1: 83 | centroid_mean1 = centroid_mean.reshape(-1, 1, self.embed_dim) 84 | centroid_mean2 = centroid_mean.reshape(1, -1, self.embed_dim) 85 | dist = torch.norm(centroid_mean1-centroid_mean2, dim=2) # shape (num_lanes, num_lanes) 86 | dist = dist + torch.eye(num_lanes, dtype=dist.dtype, device=dist.device) * self.delta_d # diagonal elements are 0, now mask above delta_d 87 | 88 | # divided by two for double calculated loss above, for implementation convenience 89 | dist_loss = dist_loss + torch.sum(F.relu(-dist + self.delta_d)**2) / (num_lanes * (num_lanes-1)) / 2 90 | 91 | # reg_loss is not used in original paper 92 | # reg_loss = reg_loss + torch.mean(torch.norm(centroid_mean, dim=1)) 93 | 94 | var_loss = var_loss / bs 95 | dist_loss = dist_loss / bs 96 | reg_loss = reg_loss / bs 97 | return var_loss, dist_loss, reg_loss 98 | 99 | 100 | def calc_loss(): 101 | pass 102 | -------------------------------------------------------------------------------- /tools/postprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jike5/P-MapNet/b8b4cf2295ee75826046eef9cfa12b107fb43619/tools/postprocess/__init__.py -------------------------------------------------------------------------------- /tools/postprocess/cluster.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 18-5-30 上午10:04 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/lanenet-lane-detection 6 | # @File : lanenet_postprocess.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | LaneNet model post process 10 | """ 11 | import matplotlib.pyplot as plt 12 | import cv2 13 | import numpy as np 14 | 15 | from sklearn.cluster import DBSCAN 16 | from sklearn.preprocessing import StandardScaler 17 | 18 | 19 | def _morphological_process(image, mode='MORPH_CLOSE', kernel_size=5): 20 | """ 21 | morphological process to fill the hole in the binary segmentation result 22 | :param image: 23 | :param kernel_size: 24 | :return: 25 | """ 26 | if len(image.shape) == 3: 27 | raise ValueError('Binary segmentation result image should be a single channel image') 28 | 29 | if image.dtype is not np.uint8: 30 | image = np.array(image, np.uint8) 31 | 32 | # close operation fille hole 33 | kernel = 
cv2.getStructuringElement(shape=cv2.MORPH_ELLIPSE, ksize=(kernel_size, kernel_size)) 34 | if mode == 'MORPH_CLOSE': 35 | closing = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=1) 36 | elif mode == 'MORPH_OPEN': 37 | closing = cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel, iterations=1) 38 | else: 39 | closing = image 40 | return closing 41 | 42 | 43 | def _connect_components_analysis(image): 44 | """ 45 | connect components analysis to remove the small components 46 | :param image: 47 | :return: 48 | """ 49 | if len(image.shape) == 3: 50 | gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 51 | else: 52 | gray_image = image 53 | 54 | return cv2.connectedComponentsWithStats(gray_image, connectivity=8, ltype=cv2.CV_32S) 55 | 56 | 57 | class _LaneFeat(object): 58 | """ 59 | 60 | """ 61 | def __init__(self, feat, coord, class_id=-1): 62 | """ 63 | lane feat object 64 | :param feat: lane embeddng feats [feature_1, feature_2, ...] 65 | :param coord: lane coordinates [x, y] 66 | :param class_id: lane class id 67 | """ 68 | self._feat = feat 69 | self._coord = coord 70 | self._class_id = class_id 71 | 72 | @property 73 | def feat(self): 74 | """ 75 | 76 | :return: 77 | """ 78 | return self._feat 79 | 80 | @feat.setter 81 | def feat(self, value): 82 | """ 83 | 84 | :param value: 85 | :return: 86 | """ 87 | if not isinstance(value, np.ndarray): 88 | value = np.array(value, dtype=np.float64) 89 | 90 | if value.dtype != np.float32: 91 | value = np.array(value, dtype=np.float64) 92 | 93 | self._feat = value 94 | 95 | @property 96 | def coord(self): 97 | """ 98 | 99 | :return: 100 | """ 101 | return self._coord 102 | 103 | @coord.setter 104 | def coord(self, value): 105 | """ 106 | 107 | :param value: 108 | :return: 109 | """ 110 | if not isinstance(value, np.ndarray): 111 | value = np.array(value) 112 | 113 | if value.dtype != np.int32: 114 | value = np.array(value, dtype=np.int32) 115 | 116 | self._coord = value 117 | 118 | @property 119 | def class_id(self): 120 | """ 121 | 122 | :return: 123 | """ 124 | return self._class_id 125 | 126 | @class_id.setter 127 | def class_id(self, value): 128 | """ 129 | 130 | :param value: 131 | :return: 132 | """ 133 | if not isinstance(value, np.int64): 134 | raise ValueError('Class id must be integer') 135 | 136 | self._class_id = value 137 | 138 | 139 | class _LaneNetCluster(object): 140 | """ 141 | Instance segmentation result cluster 142 | """ 143 | 144 | def __init__(self, dbscan_eps=0.35, postprocess_min_samples=200): 145 | """ 146 | 147 | """ 148 | self.dbscan_eps = dbscan_eps 149 | self.postprocess_min_samples = postprocess_min_samples 150 | 151 | def _embedding_feats_dbscan_cluster(self, embedding_image_feats): 152 | """ 153 | dbscan cluster 154 | :param embedding_image_feats: 155 | :return: 156 | """ 157 | from sklearn.cluster import MeanShift 158 | 159 | db = DBSCAN(eps=self.dbscan_eps, min_samples=self.postprocess_min_samples) 160 | # db = MeanShift() 161 | try: 162 | features = StandardScaler().fit_transform(embedding_image_feats) 163 | db.fit(features) 164 | except Exception as err: 165 | # print(err) 166 | ret = { 167 | 'origin_features': None, 168 | 'cluster_nums': 0, 169 | 'db_labels': None, 170 | 'unique_labels': None, 171 | 'cluster_center': None 172 | } 173 | return ret 174 | db_labels = db.labels_ 175 | unique_labels = np.unique(db_labels) 176 | 177 | num_clusters = len(unique_labels) 178 | # cluster_centers = db.components_ 179 | 180 | ret = { 181 | 'origin_features': features, 182 | 'cluster_nums': num_clusters, 183 | 
'db_labels': db_labels, 184 | 'unique_labels': unique_labels, 185 | # 'cluster_center': cluster_centers 186 | } 187 | 188 | return ret 189 | 190 | @staticmethod 191 | def _get_lane_embedding_feats(binary_seg_ret, instance_seg_ret): 192 | """ 193 | get lane embedding features according to the binary segmentation result 194 | :param binary_seg_ret: 195 | :param instance_seg_ret: 196 | :return: 197 | """ 198 | idx = np.where(binary_seg_ret == 255) 199 | lane_embedding_feats = instance_seg_ret[idx] 200 | # idx_scale = np.vstack((idx[0] / 256.0, idx[1] / 512.0)).transpose() 201 | # lane_embedding_feats = np.hstack((lane_embedding_feats, idx_scale)) 202 | lane_coordinate = np.vstack((idx[1], idx[0])).transpose() 203 | 204 | assert lane_embedding_feats.shape[0] == lane_coordinate.shape[0] 205 | 206 | ret = { 207 | 'lane_embedding_feats': lane_embedding_feats, 208 | 'lane_coordinates': lane_coordinate 209 | } 210 | 211 | return ret 212 | 213 | def apply_lane_feats_cluster(self, binary_seg_result, instance_seg_result): 214 | """ 215 | 216 | :param binary_seg_result: 217 | :param instance_seg_result: 218 | :return: 219 | """ 220 | # get embedding feats and coords 221 | get_lane_embedding_feats_result = self._get_lane_embedding_feats( 222 | binary_seg_ret=binary_seg_result, 223 | instance_seg_ret=instance_seg_result 224 | ) 225 | 226 | # dbscan cluster 227 | dbscan_cluster_result = self._embedding_feats_dbscan_cluster( 228 | embedding_image_feats=get_lane_embedding_feats_result['lane_embedding_feats'] 229 | ) 230 | 231 | mask = np.zeros(shape=[binary_seg_result.shape[0], binary_seg_result.shape[1]], dtype=int) 232 | db_labels = dbscan_cluster_result['db_labels'] 233 | unique_labels = dbscan_cluster_result['unique_labels'] 234 | coord = get_lane_embedding_feats_result['lane_coordinates'] 235 | 236 | if db_labels is None: 237 | return None, None 238 | 239 | lane_coords = [] 240 | 241 | for index, label in enumerate(unique_labels.tolist()): 242 | if label == -1: 243 | continue 244 | idx = np.where(db_labels == label) 245 | pix_coord_idx = tuple((coord[idx][:, 1], coord[idx][:, 0])) 246 | mask[pix_coord_idx] = label + 1 247 | lane_coords.append(coord[idx]) 248 | 249 | return mask, lane_coords 250 | 251 | 252 | class LaneNetPostProcessor(object): 253 | """ 254 | lanenet post process for lane generation 255 | """ 256 | def __init__(self, dbscan_eps=0.35, postprocess_min_samples=200): 257 | """ 258 | 259 | :param dbscan_eps / postprocess_min_samples: eps and min_samples passed to the DBSCAN embedding clustering 260 | """ 261 | 262 | self._cluster = _LaneNetCluster(dbscan_eps, postprocess_min_samples) 263 | 264 | def postprocess(self, binary_seg_result, mode, instance_seg_result=None, min_area_threshold=100): 265 | """ 266 | 267 | :param binary_seg_result: 268 | :param mode: morphological operation mode, 'MORPH_CLOSE' or 'MORPH_OPEN' 269 | :param instance_seg_result: 270 | :param min_area_threshold: 271 | 272 | :return: 273 | """ 274 | # convert binary_seg_result to a uint8 mask in {0, 255} 275 | binary_seg_result = np.array(binary_seg_result * 255, dtype=np.uint8) 276 | 277 | # apply image morphology operation to fill in the holes and remove small areas 278 | morphological_ret = _morphological_process(binary_seg_result, mode, kernel_size=5) 279 | 280 | connect_components_analysis_ret = _connect_components_analysis(image=morphological_ret) 281 | 282 | labels = connect_components_analysis_ret[1] 283 | stats = connect_components_analysis_ret[2] 284 | for index, stat in enumerate(stats): 285 | if stat[4] <= min_area_threshold: 286 | idx = np.where(labels == index) 287 | morphological_ret[idx] = 0 288 | 289 | # apply embedding 
features cluster 290 | mask_image, lane_coords = self._cluster.apply_lane_feats_cluster( 291 | binary_seg_result=morphological_ret, 292 | instance_seg_result=instance_seg_result 293 | ) 294 | 295 | return mask_image, lane_coords 296 | -------------------------------------------------------------------------------- /tools/postprocess/connect.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | import torch 7 | 8 | 9 | def sort_points_by_dist(coords): 10 | coords = coords.astype('float') 11 | num_points = coords.shape[0] 12 | diff_matrix = np.repeat(coords[:, None], num_points, 1) - coords 13 | # x_range = np.max(np.abs(diff_matrix[..., 0])) 14 | # y_range = np.max(np.abs(diff_matrix[..., 1])) 15 | # diff_matrix[..., 1] *= x_range / y_range 16 | dist_matrix = np.sqrt(((diff_matrix) ** 2).sum(-1)) 17 | dist_matrix_full = deepcopy(dist_matrix) 18 | direction_matrix = diff_matrix / (dist_matrix.reshape(num_points, num_points, 1) + 1e-6) 19 | 20 | sorted_points = [coords[0]] 21 | sorted_indices = [0] 22 | dist_matrix[:, 0] = np.inf 23 | 24 | last_direction = np.array([0, 0]) 25 | for i in range(num_points - 1): 26 | last_idx = sorted_indices[-1] 27 | dist_metric = dist_matrix[last_idx] - 0 * (last_direction * direction_matrix[last_idx]).sum(-1) 28 | idx = np.argmin(dist_metric) % num_points 29 | new_direction = direction_matrix[last_idx, idx] 30 | if dist_metric[idx] > 3 and min(dist_matrix_full[idx][sorted_indices]) < 5: 31 | dist_matrix[:, idx] = np.inf 32 | continue 33 | if dist_metric[idx] > 10 and i > num_points * 0.9: 34 | break 35 | sorted_points.append(coords[idx]) 36 | sorted_indices.append(idx) 37 | dist_matrix[:, idx] = np.inf 38 | last_direction = new_direction 39 | 40 | return np.stack(sorted_points, 0) 41 | 42 | 43 | def connect_by_step(coords, direction_mask, sorted_points, taken_direction, step=5, per_deg=10): 44 | while True: 45 | last_point = tuple(np.flip(sorted_points[-1])) 46 | if not taken_direction[last_point][0]: 47 | direction = direction_mask[last_point][0] 48 | taken_direction[last_point][0] = True 49 | elif not taken_direction[last_point][1]: 50 | direction = direction_mask[last_point][1] 51 | taken_direction[last_point][1] = True 52 | else: 53 | break 54 | 55 | if direction == -1: 56 | continue 57 | 58 | deg = per_deg * direction 59 | vector_to_target = step * np.array([np.cos(np.deg2rad(deg)), np.sin(np.deg2rad(deg))]) 60 | last_point = deepcopy(sorted_points[-1]) 61 | 62 | # NMS 63 | coords = coords[np.linalg.norm(coords - last_point, axis=-1) > step-1] 64 | 65 | if len(coords) == 0: 66 | break 67 | 68 | target_point = np.array([last_point[0] + vector_to_target[0], last_point[1] + vector_to_target[1]]) 69 | dist_metric = np.linalg.norm(coords - target_point, axis=-1) 70 | idx = np.argmin(dist_metric) 71 | 72 | if dist_metric[idx] > 50: 73 | continue 74 | 75 | sorted_points.append(deepcopy(coords[idx])) 76 | 77 | vector_to_next = coords[idx] - last_point 78 | deg = np.rad2deg(math.atan2(vector_to_next[1], vector_to_next[0])) 79 | inverse_deg = (180 + deg) % 360 80 | target_direction = per_deg * direction_mask[tuple(np.flip(sorted_points[-1]))] 81 | tmp = np.abs(target_direction - inverse_deg) 82 | tmp = torch.min(tmp, 360 - tmp) 83 | taken = np.argmin(tmp) 84 | taken_direction[tuple(np.flip(sorted_points[-1]))][taken] = True 85 | 86 | 87 | def connect_by_direction(coords, direction_mask, step=5, per_deg=10): 88 | sorted_points = 
[deepcopy(coords[random.randint(0, coords.shape[0]-1)])] 89 | taken_direction = np.zeros_like(direction_mask, dtype=np.bool) 90 | 91 | connect_by_step(coords, direction_mask, sorted_points, taken_direction, step, per_deg) 92 | sorted_points.reverse() 93 | connect_by_step(coords, direction_mask, sorted_points, taken_direction, step, per_deg) 94 | return np.stack(sorted_points, 0) 95 | -------------------------------------------------------------------------------- /tools/postprocess/vectorize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from .cluster import LaneNetPostProcessor 6 | from .connect import sort_points_by_dist, connect_by_direction 7 | 8 | 9 | def onehot_encoding(logits, dim=0): 10 | max_idx = torch.argmax(logits, dim, keepdim=True) 11 | one_hot = logits.new_full(logits.shape, 0) 12 | one_hot.scatter_(dim, max_idx, 1) 13 | return one_hot 14 | 15 | 16 | def onehot_encoding_spread(logits, dim=1): 17 | max_idx = torch.argmax(logits, dim, keepdim=True) 18 | one_hot = logits.new_full(logits.shape, 0) 19 | one_hot.scatter_(dim, max_idx, 1) 20 | one_hot.scatter_(dim, torch.clamp(max_idx-1, min=0), 1) 21 | one_hot.scatter_(dim, torch.clamp(max_idx-2, min=0), 1) 22 | one_hot.scatter_(dim, torch.clamp(max_idx+1, max=logits.shape[dim]-1), 1) 23 | one_hot.scatter_(dim, torch.clamp(max_idx+2, max=logits.shape[dim]-1), 1) 24 | 25 | return one_hot 26 | 27 | 28 | def get_pred_top2_direction(direction, dim=1): 29 | direction = torch.softmax(direction, dim) 30 | idx1 = torch.argmax(direction, dim) 31 | idx1_onehot_spread = onehot_encoding_spread(direction, dim) 32 | idx1_onehot_spread = idx1_onehot_spread.bool() 33 | direction[idx1_onehot_spread] = 0 34 | idx2 = torch.argmax(direction, dim) 35 | direction = torch.stack([idx1, idx2], dim) - 1 #torch.Size([200, 400, 2]) 36 | return direction 37 | 38 | 39 | def vectorize(segmentation, embedding, direction, angle_class, morpho_mode='MORPH_CLOSE'): 40 | segmentation = segmentation.softmax(0) 41 | embedding = embedding.cpu() 42 | direction = direction.permute(1, 2, 0).cpu() 43 | direction = get_pred_top2_direction(direction, dim=-1) 44 | 45 | max_pool_1 = nn.MaxPool2d((1, 5), padding=(0, 2), stride=1) 46 | avg_pool_1 = nn.AvgPool2d((9, 5), padding=(4, 2), stride=1) 47 | max_pool_2 = nn.MaxPool2d((5, 1), padding=(2, 0), stride=1) 48 | avg_pool_2 = nn.AvgPool2d((5, 9), padding=(2, 4), stride=1) 49 | post_processor = LaneNetPostProcessor(dbscan_eps=1.5, postprocess_min_samples=50) 50 | 51 | oh_pred = onehot_encoding(segmentation).cpu().numpy() 52 | confidences = [] 53 | line_types = [] 54 | simplified_coords = [] 55 | for i in range(1, oh_pred.shape[0]): 56 | single_mask = oh_pred[i].astype('uint8') 57 | single_embedding = embedding.permute(1, 2, 0) 58 | 59 | single_class_inst_mask, single_class_inst_coords = post_processor.postprocess(single_mask, morpho_mode, single_embedding) 60 | if single_class_inst_mask is None: 61 | continue 62 | 63 | num_inst = len(single_class_inst_coords) 64 | 65 | prob = segmentation[i] 66 | prob[single_class_inst_mask == 0] = 0 67 | nms_mask_1 = ((max_pool_1(prob.unsqueeze(0))[0] - prob) < 0.0001).cpu().numpy() 68 | avg_mask_1 = avg_pool_1(prob.unsqueeze(0))[0].cpu().numpy() 69 | nms_mask_2 = ((max_pool_2(prob.unsqueeze(0))[0] - prob) < 0.0001).cpu().numpy() 70 | avg_mask_2 = avg_pool_2(prob.unsqueeze(0))[0].cpu().numpy() 71 | vertical_mask = avg_mask_1 > avg_mask_2 72 | horizontal_mask = ~vertical_mask 73 | 
nms_mask = (vertical_mask & nms_mask_1) | (horizontal_mask & nms_mask_2) 74 | 75 | for j in range(1, num_inst + 1): 76 | full_idx = np.where((single_class_inst_mask == j)) 77 | full_lane_coord = np.vstack((full_idx[1], full_idx[0])).transpose() 78 | confidence = prob[single_class_inst_mask == j].mean().item() 79 | 80 | idx = np.where(nms_mask & (single_class_inst_mask == j)) 81 | if len(idx[0]) == 0: 82 | continue 83 | lane_coordinate = np.vstack((idx[1], idx[0])).transpose() 84 | 85 | range_0 = np.max(full_lane_coord[:, 0]) - np.min(full_lane_coord[:, 0]) 86 | range_1 = np.max(full_lane_coord[:, 1]) - np.min(full_lane_coord[:, 1]) 87 | if range_0 > range_1: 88 | lane_coordinate = sorted(lane_coordinate, key=lambda x: x[0]) 89 | else: 90 | lane_coordinate = sorted(lane_coordinate, key=lambda x: x[1]) 91 | 92 | lane_coordinate = np.stack(lane_coordinate) 93 | lane_coordinate = sort_points_by_dist(lane_coordinate) 94 | lane_coordinate = lane_coordinate.astype('int32') 95 | lane_coordinate = connect_by_direction(lane_coordinate, direction, step=7, per_deg=360 / angle_class) 96 | 97 | simplified_coords.append(lane_coordinate) 98 | confidences.append(confidence) 99 | line_types.append(i-1) 100 | 101 | return simplified_coords, confidences, line_types 102 | -------------------------------------------------------------------------------- /tools/vis_map.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | import argparse 4 | import tqdm 5 | import os 6 | import cv2 7 | import torch 8 | from tools.evaluation.iou import get_batch_iou 9 | from tools.config import Config 10 | from data_osm.dataset import semantic_dataset 11 | from data_osm.const import NUM_CLASSES 12 | from model import get_model 13 | from postprocess.vectorize import vectorize 14 | from collections import OrderedDict 15 | import matplotlib.pyplot as plt 16 | import numpy as np 17 | from PIL import Image 18 | from tools.evaluation import lpips 19 | from data_osm.image import denormalize_img 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | Nu_SCENE_CANDIDATE = [ 25 | 'scene-0555', 'scene-0556', 'scene-0557', 'scene-0558', 26 | 'scene-1065', 'scene-1066', 'scene-1067', 'scene-1068', 27 | 'scene-0275', 'scene-0276', 'scene-0277', 'scene-0278', 28 | 'scene-0519', 'scene-0520', 'scene-0521', 'scene-0522', 29 | 'scene-0911', 'scene-0912', 'scene-0913', 'scene-0914', 30 | ] 31 | 32 | AV2_SCENE_CANDIDATE = [ 33 | 'f1275002-842e-3571-8f7d-05816bc7cf56', 34 | 'ba67827f-6b99-3d2a-96ab-7c829eb999bb', 35 | 'bf360aeb-1bbd-3c1e-b143-09cf83e4f2e4', 36 | 'ded5ef6e-46ea-3a66-9180-18a6fa0a2db4', 37 | 'e8c9fd64-fdd2-422d-a2a2-6f47500d1d12', 38 | '1f434d15-8745-3fba-9c3e-ccb026688397', 39 | '6f128f23-ee40-3ea9-8c50-c9cdb9d3e8b6', 40 | ] 41 | 42 | SCENE_CANDIDATE = None 43 | 44 | def onehot_encoding(logits, dim=1): 45 | max_idx = torch.argmax(logits, dim, keepdim=True) 46 | one_hot = logits.new_full(logits.shape, 0) 47 | one_hot.scatter_(dim, max_idx, 1) 48 | return one_hot 49 | 50 | 51 | def vis(semantic, semantic_gt, sd_map, time, scene_id, save_path, with_gt=False): 52 | car_img = Image.open('icon/car_gray.png') 53 | semantic = onehot_encoding(semantic) 54 | semantic = semantic.clone().cpu().numpy() 55 | semantic[semantic < 0.1] = np.nan 56 | semantic_gt_mask = semantic_gt.clone().cpu().numpy() 57 | semantic_gt_mask[semantic_gt < 0.1] = np.nan 58 | sd_map = sd_map.cpu().numpy() 59 | sd_map[sd_map < 0.1] = np.nan 60 | 61 | b, c, h, w = semantic.shape 62 | 
alpha = 0.8 63 | dpi = 600 64 | divier = 'Blues' 65 | ped_crossing = 'Greens' 66 | boundary = 'Purples' 67 | vmax = 1 68 | for i in range(semantic.shape[0]): 69 | if scene_id[i] not in SCENE_CANDIDATE: 70 | continue 71 | save_path_seg = os.path.join(save_path, f'{scene_id[i]}', f'{time[i]}') 72 | if not os.path.exists(save_path_seg): 73 | os.makedirs(save_path_seg) 74 | # vis hdmap gt with sd map 75 | imname = os.path.join(save_path_seg, 'gt_sd_map.png') 76 | if not os.path.exists(imname): 77 | plt.figure(figsize=(w*2/100, 4)) 78 | plt.imshow(semantic_gt_mask[i][1]*0.5, vmin=0, cmap= divier, vmax=vmax, alpha=alpha) 79 | plt.imshow(semantic_gt_mask[i][2]*0.5, vmin=0, cmap= ped_crossing, vmax=vmax, alpha=alpha) 80 | plt.imshow(semantic_gt_mask[i][3]*0.5, vmin=0, cmap=boundary, vmax=vmax, alpha=alpha) 81 | plt.imshow(sd_map[i][0]*0.8, vmin=0, cmap='Greys', vmax=1, alpha=0.9) 82 | plt.xlim(0, w) 83 | plt.ylim(0, h) 84 | plt.axis('off') 85 | plt.tight_layout() 86 | print('saving', imname) 87 | plt.savefig(imname, bbox_inches='tight', format='png', dpi=dpi) 88 | plt.close() 89 | 90 | imname = os.path.join(save_path_seg, 'sd_map.png') 91 | if not os.path.exists(imname): 92 | plt.figure(figsize=(w*2/100, 4)) 93 | plt.imshow(sd_map[i][0]*0.8, vmin=0, cmap='Greys', vmax=1, alpha=0.9) 94 | plt.xlim(0, w) 95 | plt.ylim(0, h) 96 | plt.axis('off') 97 | plt.tight_layout() 98 | print('saving', imname) 99 | plt.savefig(imname, bbox_inches='tight', format='png', dpi=dpi) 100 | plt.close() 101 | 102 | # vis pred hdmap 103 | imname = os.path.join(save_path_seg, 'pred_map.png') 104 | if not os.path.exists(imname): 105 | plt.figure(figsize=(w*2/100, 4)) 106 | plt.imshow(semantic[i][1]*0.5, vmin=0, cmap= divier, vmax=vmax, alpha=alpha) 107 | plt.imshow(semantic[i][2]*0.5, vmin=0, cmap= ped_crossing, vmax=vmax, alpha=alpha) 108 | plt.imshow(semantic[i][3]*0.5, vmin=0, cmap=boundary, vmax=vmax, alpha=alpha) 109 | plt.xlim(0, w) 110 | plt.ylim(0, h) 111 | plt.imshow(car_img, extent=[w//2-15, w//2+15, h//2-12, h//2+12]) 112 | plt.axis('off') 113 | plt.tight_layout() 114 | print('saving', imname) 115 | plt.savefig(imname, bbox_inches='tight', format='png', dpi=dpi) 116 | plt.close() 117 | 118 | if with_gt: 119 | # vis hdmap gt 120 | imname = os.path.join(save_path_seg, 'gt_map.png') 121 | if not os.path.exists(imname): 122 | plt.figure(figsize=(w*2/100, 4)) 123 | plt.imshow(semantic_gt_mask[i][1]*0.5, vmin=0, cmap=divier, vmax=vmax, alpha=alpha) 124 | plt.imshow(semantic_gt_mask[i][2]*0.5, vmin=0, cmap=ped_crossing, vmax=vmax, alpha=alpha) 125 | plt.imshow(semantic_gt_mask[i][3]*0.5, vmin=0, cmap=boundary, vmax=vmax, alpha=alpha) 126 | plt.xlim(0, w) 127 | plt.ylim(0, h) 128 | plt.imshow(car_img, extent=[w//2-15, w//2+15, h//2-12, h//2+12]) 129 | plt.axis('off') 130 | plt.tight_layout() 131 | print('saving ', imname) 132 | plt.savefig(imname, bbox_inches='tight', format='png', dpi=dpi) 133 | plt.close() 134 | 135 | 136 | def vis_vec(coords, timestamp, scene_id, save_path, h, w): 137 | save_path_vec = os.path.join(save_path, 'vec', f'{scene_id}') 138 | if not os.path.exists(save_path_vec): 139 | os.makedirs(save_path_vec) 140 | 141 | car_img = Image.open('icon/car_gray.png') 142 | 143 | plt.figure(figsize=(w*2/100, 2)) 144 | for coord in coords: 145 | plt.plot(coord[:, 0], coord[:, 1], linewidth=2) 146 | 147 | plt.xlim((0, w)) 148 | plt.ylim((0, h)) 149 | plt.axis('off') 150 | plt.grid(False) 151 | plt.imshow(car_img, extent=[w//2-15, w//2+15, h//2-12, h//2+12]) 152 | 153 | img_name = os.path.join(save_path_vec, 
f'{timestamp}_vecz_.jpg') 154 | print('saving', img_name) 155 | plt.savefig(img_name) 156 | plt.close() 157 | 158 | 159 | def eval_vis_all(model, save_path, val_loader): 160 | model.eval() 161 | total_intersects = 0 162 | total_union = 0 163 | i=0 164 | with torch.no_grad(): 165 | for (imgs, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans, 166 | yaw_pitch_roll, semantic_gt, instance_gt, direction_gt, osm_masks, osm_vectors, masked_map, timestamps, scene_ids) in tqdm.tqdm(val_loader): 167 | # import pdb; pdb.set_trace() 168 | semantic, embedding, direction = model(imgs.cuda(), trans.cuda(), rots.cuda(), intrins.cuda(), 169 | post_trans.cuda(), post_rots.cuda(), lidar_data.cuda(), 170 | lidar_mask.cuda(), car_trans.cuda(), yaw_pitch_roll.cuda(), osm_masks.float().cuda()) 171 | 172 | semantic_gt = semantic_gt.cuda().float() 173 | device = semantic_gt.device 174 | if semantic.device != device: 175 | semantic = semantic.to(device) 176 | intersects, union = get_batch_iou(onehot_encoding(semantic), semantic_gt) 177 | total_intersects += intersects 178 | total_union += union 179 | vis(semantic.cpu().float(), semantic_gt.cpu().float(), osm_masks.float(), timestamps, scene_ids, save_path, with_gt=True) 180 | i+=1 181 | return (total_intersects / (total_union + 1e-7)) 182 | 183 | def main(cfg): 184 | # import pdb; pdb.set_trace() 185 | global SCENE_CANDIDATE 186 | SCENE_CANDIDATE = Nu_SCENE_CANDIDATE 187 | if 'dataset' in cfg: 188 | if cfg.dataset == 'av2': 189 | SCENE_CANDIDATE = AV2_SCENE_CANDIDATE 190 | 191 | data_conf = { 192 | 'num_channels': NUM_CLASSES + 1, 193 | 'image_size': cfg.image_size, 194 | 'xbound': cfg.xbound, 195 | 'ybound': cfg.ybound, 196 | 'zbound': cfg.zbound, 197 | 'dbound': cfg.dbound, 198 | 'thickness': cfg.thickness, 199 | 'angle_class': cfg.angle_class, 200 | 'patch_w': cfg.patch_w, 201 | 'patch_h': cfg.patch_h, 202 | 'mask_ratio': cfg.mask_ratio, 203 | 'mask_flag': cfg.mask_flag, 204 | 'sd_map_path': cfg.sd_map_path, 205 | } 206 | 207 | train_loader, val_loader = semantic_dataset(cfg, cfg.version, cfg.dataroot, data_conf, 208 | cfg.batch_size, cfg.nworkers, cfg.dataset) 209 | model = get_model(cfg, data_conf, cfg.instance_seg, cfg.embedding_dim, cfg.direction_pred, cfg.angle_class) 210 | 211 | state_dict_model = torch.load(cfg.modelf) 212 | new_state_dict = OrderedDict() 213 | for k, v in state_dict_model.items(): 214 | name = k[7:] 215 | new_state_dict[name] = v 216 | # import pdb; pdb.set_trace() 217 | model.load_state_dict(new_state_dict, strict=True) 218 | model.cuda() 219 | if "vis_path" not in cfg: 220 | cfg.vis_path = os.path.join(cfg.logdir, "vis") 221 | eval_vis_all(model, cfg.vis_path, val_loader) 222 | 223 | if __name__ == '__main__': 224 | parser = argparse.ArgumentParser(description='P-MapNet pre-train HD Prior.') 225 | parser.add_argument("config", help = 'path to config file', type=str, default=None) 226 | args = parser.parse_args() 227 | cfg = Config.fromfile(args.config) 228 | main(cfg) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | import logging 5 | import time 6 | from tensorboardX import SummaryWriter 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import torch 10 | import torch.nn as nn 11 | from tools.config import Config 12 | from torch.optim.lr_scheduler import StepLR 13 | from tools.loss import SimpleLoss, DiscriminativeLoss 
14 | from data_osm.dataset import semantic_dataset 15 | from data_osm.const import NUM_CLASSES 16 | from tools.evaluation.iou import get_batch_iou 17 | from tools.evaluation.angle_diff import calc_angle_diff 18 | from tools.eval import onehot_encoding, eval_iou 19 | from model.utils.map_mae_head import vit_base_patch8 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | import tqdm 24 | import pdb 25 | from PIL import Image 26 | from model import get_model 27 | 28 | from collections import OrderedDict 29 | import torch.nn.functional as F 30 | from sklearn import metrics 31 | 32 | 33 | def write_log(writer, ious, title, counter): 34 | writer.add_scalar(f'{title}/iou', torch.mean(ious[1:]), counter) 35 | 36 | for i, iou in enumerate(ious): 37 | writer.add_scalar(f'{title}/class_{i}/iou', iou, counter) 38 | 39 | def train(cfg): 40 | if not os.path.exists(cfg.logdir): 41 | os.makedirs(cfg.logdir) 42 | logname = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time())) 43 | logging.basicConfig(filename=os.path.join(cfg.logdir, logname+'.log'), 44 | filemode='w', 45 | format='%(asctime)s: %(message)s', 46 | datefmt='%Y-%m-%d %H:%M:%S', 47 | level=logging.INFO) 48 | logging.getLogger('shapely.geos').setLevel(logging.CRITICAL) 49 | 50 | logger = logging.getLogger() 51 | logger.addHandler(logging.StreamHandler(sys.stdout)) 52 | 53 | data_conf = { 54 | 'num_channels': NUM_CLASSES + 1, 55 | 'image_size': cfg.image_size, 56 | 'xbound': cfg.xbound, 57 | 'ybound': cfg.ybound, 58 | 'zbound': cfg.zbound, 59 | 'dbound': cfg.dbound, 60 | 'thickness': cfg.thickness, 61 | 'angle_class': cfg.angle_class, 62 | 'patch_w': cfg.patch_w, 63 | 'patch_h': cfg.patch_h, 64 | 'mask_ratio': cfg.mask_ratio, 65 | 'mask_flag': cfg.mask_flag, 66 | 'sd_map_path': cfg.sd_map_path, 67 | } 68 | 69 | model = get_model(cfg, data_conf, cfg.instance_seg, cfg.embedding_dim, cfg.direction_pred, cfg.angle_class) 70 | # import pdb; pdb.set_trace() 71 | if "hd" in cfg.model: 72 | cfg.modelf_map = cfg.modelf_map if "modelf_map" in cfg else None 73 | cfg.modelf_mae = cfg.modelf_mae if "modelf_mae" in cfg else None 74 | if cfg.modelf_map: 75 | state_dict_model = torch.load(cfg.modelf_map) 76 | new_state_dict = OrderedDict() 77 | for k, v in state_dict_model.items(): 78 | name = k[7:] 79 | new_state_dict[name] = v 80 | model.load_state_dict(new_state_dict, strict=False) 81 | 82 | if cfg.modelf_mae: 83 | state_dict_model = torch.load(cfg.modelf_mae) 84 | new_state_dict = OrderedDict() 85 | for k, v in state_dict_model.items(): 86 | name = k.replace('module', 'mae_head') 87 | new_state_dict[name] = v 88 | model.load_state_dict(new_state_dict, strict=False) 89 | 90 | cfg.freeze_backbone = cfg.freeze_backbone if "freeze_backbone" in cfg else None 91 | if cfg.freeze_backbone: 92 | for name, param in model.named_parameters(): 93 | if 'mae_head' not in name: 94 | param.requires_grad = False 95 | 96 | if 'resume' in cfg and cfg.resume is not None: 97 | print("Loading checkpoint from cfg.resume: ", cfg.resume) 98 | state_dict_model = torch.load(cfg.resume) 99 | new_state_dict = OrderedDict() 100 | for k, v in state_dict_model.items(): 101 | name = k[7:] 102 | new_state_dict[name] = v 103 | model.load_state_dict(new_state_dict, strict=False) 104 | 105 | model = nn.DataParallel(model, device_ids=cfg.gpus) 106 | model.cuda(device=cfg.gpus[0]) 107 | # import pdb; pdb.set_trace() 108 | train_loader, val_loader = semantic_dataset(cfg, cfg.version, cfg.dataroot, data_conf, 109 | cfg.batch_size, cfg.nworkers, cfg.dataset) 110 | 111 | 
opt = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) 112 | sched = StepLR(opt, 3, 0.1) 113 | writer = SummaryWriter(logdir=cfg.logdir) 114 | 115 | loss_fn = SimpleLoss(cfg.pos_weight).cuda() 116 | embedded_loss_fn = DiscriminativeLoss(cfg.embedding_dim, cfg.delta_v, cfg.delta_d).cuda() 117 | direction_loss_fn = torch.nn.BCELoss(reduction='none') 118 | 119 | counter = 0 120 | last_idx = len(train_loader) - 1 121 | for epoch in range(cfg.nepochs): 122 | for batchi, (imgs, trans, rots, intrins, post_trans, post_rots, lidar_data, lidar_mask, car_trans, 123 | yaw_pitch_roll, semantic_gt, instance_gt, direction_gt, osm_masks, osm_vectors, masked_map, timestamps, scene_ids) in enumerate(train_loader): 124 | # import pdb; pdb.set_trace() 125 | t0 = time.time() 126 | opt.zero_grad() 127 | semantic, embedding, direction = model(imgs.cuda(), trans.cuda(), rots.cuda(), intrins.cuda(), 128 | post_trans.cuda(), post_rots.cuda(), lidar_data.cuda(), 129 | lidar_mask.cuda(), car_trans.cuda(), yaw_pitch_roll.cuda(), osm_masks.float().cuda()) 130 | 131 | semantic_gt = semantic_gt.cuda().float() 132 | instance_gt = instance_gt.cuda() 133 | 134 | device = semantic_gt.device 135 | if semantic.device != device: 136 | semantic = semantic.to(device) 137 | embedding = embedding.to(device) 138 | direction = direction.to(device) 139 | 140 | seg_loss = loss_fn(semantic, semantic_gt) 141 | if cfg.instance_seg: 142 | var_loss, dist_loss, reg_loss = embedded_loss_fn(embedding, instance_gt) 143 | else: 144 | var_loss = 0 145 | dist_loss = 0 146 | reg_loss = 0 147 | 148 | if cfg.direction_pred: 149 | direction_gt = direction_gt.cuda() 150 | lane_mask = (1 - direction_gt[:, 0]).unsqueeze(1) 151 | direction_loss = direction_loss_fn(torch.softmax(direction, 1), direction_gt) 152 | direction_loss = (direction_loss * lane_mask).sum() / (lane_mask.sum() * direction_loss.shape[1] + 1e-6) 153 | angle_diff = calc_angle_diff(direction, direction_gt, cfg.angle_class) 154 | else: 155 | direction_loss = 0 156 | angle_diff = 0 157 | 158 | final_loss = seg_loss * cfg.scale_seg + var_loss * cfg.scale_var + dist_loss * cfg.scale_dist + direction_loss * cfg.scale_direction 159 | final_loss.backward() 160 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) 161 | opt.step() 162 | counter += 1 163 | t1 = time.time() 164 | 165 | if counter % 100 == 0: 166 | intersects, union = get_batch_iou(onehot_encoding(semantic), semantic_gt) 167 | iou = intersects / (union + 1e-7) 168 | logger.info(f"TRAIN[{epoch:>3d}]: [{batchi:>4d}/{last_idx}] " 169 | f"Time: {t1-t0:>7.4f} " 170 | f"Loss: {final_loss.item():>7.4f} " 171 | f"IOU: {np.array2string(iou[1:].numpy(), precision=3, floatmode='fixed')}") 172 | 173 | write_log(writer, iou, 'train', counter) 174 | writer.add_scalar('train/step_time', t1 - t0, counter) 175 | writer.add_scalar('train/seg_loss', seg_loss, counter) 176 | writer.add_scalar('train/var_loss', var_loss, counter) 177 | writer.add_scalar('train/dist_loss', dist_loss, counter) 178 | writer.add_scalar('train/reg_loss', reg_loss, counter) 179 | writer.add_scalar('train/direction_loss', direction_loss, counter) 180 | writer.add_scalar('train/final_loss', final_loss, counter) 181 | writer.add_scalar('train/angle_diff', angle_diff, counter) 182 | 183 | model_name = os.path.join(cfg.logdir, f"model{epoch}.pt") 184 | torch.save(model.state_dict(), model_name) 185 | logger.info(f"{model_name} saved") 186 | 187 | iou = eval_iou(model, val_loader) 188 | logger.info(f"EVAL[{epoch:>2d}]: " 189 | f"IOU: 
{np.array2string(iou[1:].numpy(), precision=3, floatmode='fixed')}") 190 | write_log(writer, iou, 'eval', counter) 191 | model.train() 192 | sched.step() 193 | 194 | if __name__ == '__main__': 195 | parser = argparse.ArgumentParser(description='P-MapNet training with HD Prior.') 196 | parser.add_argument("config", help = 'path to config file', type=str, default=None) 197 | args = parser.parse_args() 198 | cfg = Config.fromfile(args.config) 199 | 200 | if not os.path.exists(cfg.logdir): 201 | os.makedirs(cfg.logdir) 202 | with open(os.path.join(cfg.logdir, 'config.txt'), 'w') as f: 203 | argsDict = cfg.__dict__ 204 | for eachArg, value in argsDict.items(): 205 | f.writelines(eachArg + " : " + str(value) + "\n") 206 | train(cfg) 207 | 208 | -------------------------------------------------------------------------------- /train_HDPrior_pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import sys 4 | import logging 5 | import time 6 | from tensorboardX import SummaryWriter 7 | import argparse 8 | import matplotlib.pyplot as plt 9 | import torch 10 | import torch.nn as nn 11 | from tools.config import Config 12 | from torch.optim.lr_scheduler import StepLR 13 | from tools.loss import SimpleLoss, DiscriminativeLoss 14 | from data_osm.dataset import semantic_dataset 15 | from data_osm.const import NUM_CLASSES 16 | from tools.evaluation.iou import get_batch_iou 17 | from tools.evaluation.angle_diff import calc_angle_diff 18 | from tools.eval import onehot_encoding, eval_pretrain 19 | from model.utils.map_mae_head import vit_base_patch8 20 | from model import get_model 21 | 22 | import warnings 23 | warnings.filterwarnings("ignore") 24 | from collections import OrderedDict 25 | 26 | def write_log(writer, ious, title, counter): 27 | writer.add_scalar(f'{title}/iou', torch.mean(ious[1:]), counter) 28 | for i, iou in enumerate(ious): 29 | writer.add_scalar(f'{title}/class_{i}/iou', iou, counter) 30 | 31 | def train(cfg): 32 | if not os.path.exists(cfg.logdir): 33 | os.makedirs(cfg.logdir) 34 | logging.basicConfig(filename=os.path.join(cfg.logdir, "results.log"), 35 | filemode='w', 36 | format='%(asctime)s: %(message)s', 37 | datefmt='%Y-%m-%d %H:%M:%S', 38 | level=logging.INFO) 39 | logging.getLogger('shapely.geos').setLevel(logging.CRITICAL) 40 | 41 | logger = logging.getLogger() 42 | logger.addHandler(logging.StreamHandler(sys.stdout)) 43 | 44 | data_conf = { 45 | 'num_channels': NUM_CLASSES + 1, 46 | 'image_size': cfg.image_size, 47 | 'xbound': cfg.xbound, 48 | 'ybound': cfg.ybound, 49 | 'zbound': cfg.zbound, 50 | 'dbound': cfg.dbound, 51 | 'thickness': cfg.thickness, 52 | 'angle_class': cfg.angle_class, 53 | 'patch_w': cfg.patch_w, 54 | 'patch_h': cfg.patch_h, 55 | 'mask_ratio': cfg.mask_ratio, 56 | 'mask_flag': cfg.mask_flag, 57 | 'sd_map_path': cfg.sd_map_path, 58 | } 59 | 60 | train_loader, val_loader = semantic_dataset(cfg, cfg.version, cfg.dataroot, data_conf, 61 | cfg.batch_size, cfg.nworkers, cfg.dataset) 62 | patch_h = data_conf['ybound'][1] - data_conf['ybound'][0] 63 | patch_w = data_conf['xbound'][1] - data_conf['xbound'][0] 64 | canvas_h = int(patch_h / data_conf['ybound'][2]) 65 | canvas_w = int(patch_w / data_conf['xbound'][2]) 66 | 67 | # # TODO: add to cfg and add support for patch32 68 | # model = vit_base_patch8(data_conf=data_conf, 69 | # instance_seg=cfg.instance_seg, 70 | # embedded_dim=cfg.embedding_dim, 71 | # direction_pred=cfg.direction_pred, 72 | # direction_dim=cfg.angle_class, 73 | # 
lidar=True, 74 | # img_size=(canvas_h, canvas_w)) 75 | model = get_model(cfg, data_conf, cfg.instance_seg, cfg.embedding_dim, cfg.direction_pred, cfg.angle_class) 76 | 77 | if 'vit_base' in cfg and cfg.vit_base is not None: 78 | state_dict_model = torch.load(cfg.vit_base) 79 | model.load_state_dict(state_dict_model, strict=False) 80 | model = nn.DataParallel(model, device_ids=cfg.gpus) 81 | opt = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) 82 | sched = StepLR(opt, 3, 0.1) 83 | writer = SummaryWriter(logdir=cfg.logdir) 84 | 85 | loss_fn = SimpleLoss(cfg.pos_weight).cuda() 86 | embedded_loss_fn = DiscriminativeLoss(cfg.embedding_dim, cfg.delta_v, cfg.delta_d).cuda() 87 | direction_loss_fn = torch.nn.BCELoss(reduction='none') 88 | 89 | model.cuda(device=cfg.gpus[0]) 90 | model.train() 91 | 92 | counter = 0 93 | last_idx = len(train_loader) - 1 94 | 95 | for epoch in range(cfg.nepochs): 96 | for batchi, (imgs, trans, rots, intrins, post_trans, post_rots, 97 | lidar_data, lidar_mask, car_trans, yaw_pitch_roll, 98 | semantic_gt, instance_gt, direction_gt, osm_masks, 99 | osm_vectors, masked_map, timestamps, scene_ids) in enumerate(train_loader): 100 | t0 = time.time() 101 | opt.zero_grad() 102 | semantic, embedding, direction = model(masked_map.float()) 103 | semantic_gt = semantic_gt.cuda().float() 104 | instance_gt = instance_gt.cuda() 105 | 106 | device = semantic_gt.device 107 | if semantic.device != device: 108 | semantic = semantic.to(device) 109 | embedding = embedding.to(device) 110 | direction = direction.to(device) 111 | 112 | seg_loss = loss_fn(semantic, semantic_gt) 113 | if cfg.instance_seg: 114 | var_loss, dist_loss, reg_loss = embedded_loss_fn(embedding, instance_gt) 115 | else: 116 | var_loss = 0 117 | dist_loss = 0 118 | reg_loss = 0 119 | 120 | if cfg.direction_pred: 121 | direction_gt = direction_gt.cuda() 122 | lane_mask = (1 - direction_gt[:, 0]).unsqueeze(1) 123 | direction_loss = direction_loss_fn(torch.softmax(direction, 1), direction_gt) 124 | direction_loss = (direction_loss * lane_mask).sum() / (lane_mask.sum() * direction_loss.shape[1] + 1e-6) 125 | angle_diff = calc_angle_diff(direction, direction_gt, cfg.angle_class) 126 | else: 127 | direction_loss = 0 128 | angle_diff = 0 129 | 130 | final_loss = seg_loss * cfg.scale_seg + var_loss * cfg.scale_var + dist_loss * cfg.scale_dist + direction_loss * cfg.scale_direction 131 | final_loss.backward() 132 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm) 133 | opt.step() 134 | counter += 1 135 | t1 = time.time() 136 | if counter % 100 == 0: 137 | intersects, union = get_batch_iou(onehot_encoding(semantic), semantic_gt) 138 | iou = intersects / (union + 1e-7) 139 | logger.info(f"TRAIN[{epoch:>3d}]: [{batchi:>4d}/{last_idx}] " 140 | f"Time: {t1-t0:>7.4f} " 141 | f"Loss: {final_loss.item():>7.4f} " 142 | f"IOU: {np.array2string(iou[1:].numpy(), precision=3, floatmode='fixed')}") 143 | 144 | write_log(writer, iou, 'train', counter) 145 | writer.add_scalar('train/step_time', t1 - t0, counter) 146 | writer.add_scalar('train/seg_loss', seg_loss, counter) 147 | writer.add_scalar('train/var_loss', var_loss, counter) 148 | writer.add_scalar('train/dist_loss', dist_loss, counter) 149 | writer.add_scalar('train/reg_loss', reg_loss, counter) 150 | writer.add_scalar('train/direction_loss', direction_loss, counter) 151 | writer.add_scalar('train/final_loss', final_loss, counter) 152 | writer.add_scalar('train/angle_diff', angle_diff, counter) 153 | cur_lr = 
opt.state_dict()['param_groups'][0]['lr'] 154 | writer.add_scalar('train/lr', cur_lr, counter) 155 | 156 | model_name = os.path.join(cfg.logdir, f"model{epoch}.pt") 157 | torch.save(model.state_dict(), model_name) 158 | 159 | logger.info(f"{model_name} saved") 160 | 161 | iou = eval_pretrain(model, val_loader) 162 | 163 | logger.info(f"EVAL[{epoch:>2d}]: " 164 | f"IOU: {np.array2string(iou[1:].numpy(), precision=3, floatmode='fixed')}") 165 | 166 | write_log(writer, iou, 'eval', counter) 167 | 168 | model.train() 169 | 170 | sched.step() 171 | 172 | 173 | if __name__ == '__main__': 174 | parser = argparse.ArgumentParser(description='P-MapNet pre-train HD Prior.') 175 | parser.add_argument("config", help = 'path to config file', type=str, default=None) 176 | args = parser.parse_args() 177 | cfg = Config.fromfile(args.config) 178 | 179 | if not os.path.exists(cfg.logdir): 180 | os.makedirs(cfg.logdir) 181 | with open(os.path.join(cfg.logdir, 'config.txt'), 'w') as f: 182 | argsDict = cfg.__dict__ 183 | for eachArg, value in argsDict.items(): 184 | f.writelines(eachArg + " : " + str(value) + "\n") 185 | train(cfg) 186 | --------------------------------------------------------------------------------
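# Illustrative usage sketch (not part of the repository): driving the post-processing
# entry point `vectorize` from tools/postprocess/vectorize.py on a single sample.
# Assumptions: per-sample model outputs of shape (C, H, W); the spatial size, the
# 16-dim embedding, angle_class=36 and the random inputs below are placeholders only;
# the import path assumes the repository root is on PYTHONPATH.
import torch
from tools.postprocess.vectorize import vectorize

angle_class = 36
semantic = torch.randn(4, 200, 400)                  # segmentation logits: background + 3 map classes
embedding = torch.randn(16, 200, 400)                # instance embedding later clustered with DBSCAN
direction = torch.randn(angle_class + 1, 200, 400)   # direction logits incl. the "no direction" bin

coords, confidences, line_types = vectorize(semantic, embedding, direction, angle_class)
for coord, confidence, line_type in zip(coords, confidences, line_types):
    print(line_type, confidence, coord.shape)        # each coord is an (N, 2) polyline in BEV pixels
# tools/export_json.py then maps these pixel polylines back to map coordinates via
# `coord * dx + bx` before writing them into the submission dictionary.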