├── .gitignore ├── LICENSE ├── README.md ├── docs ├── config_tutorial.md ├── data_annotation_tutorial.md └── data_intro.md ├── images └── v2xvit.png ├── requirements.txt ├── setup.py └── v2xvit ├── __init__.py ├── data_utils ├── __init__.py ├── augmentor │ ├── __init__.py │ ├── augment_utils.py │ └── data_augmentor.py ├── datasets │ ├── __init__.py │ ├── basedataset.py │ ├── early_fusion_dataset.py │ ├── early_fusion_vis_dataset.py │ ├── intermediate_fusion_dataset.py │ └── late_fusion_dataset.py ├── post_processor │ ├── __init__.py │ ├── base_postprocessor.py │ ├── bev_postprocessor.py │ └── voxel_postprocessor.py └── pre_processor │ ├── __init__.py │ ├── base_preprocessor.py │ ├── bev_preprocessor.py │ ├── sp_voxel_preprocessor.py │ └── voxel_preprocessor.py ├── hypes_yaml ├── __init__.py ├── point_pillar_early_fusion.yaml ├── point_pillar_fcooper.yaml ├── point_pillar_late_fusion.yaml ├── point_pillar_opv2v.yaml ├── point_pillar_v2vnet.yaml ├── point_pillar_v2xvit.yaml ├── visualization.yaml └── yaml_utils.py ├── loss ├── __init__.py ├── pixor_loss.py ├── point_pillar_loss.py └── voxel_net_loss.py ├── models ├── __init__.py ├── point_pillar.py ├── point_pillar_fcooper.py ├── point_pillar_opv2v.py ├── point_pillar_transformer.py ├── point_pillar_v2vnet.py └── sub_modules │ ├── __init__.py │ ├── base_bev_backbone.py │ ├── base_transformer.py │ ├── convgru.py │ ├── downsample_conv.py │ ├── f_cooper_fuse.py │ ├── fuse_utils.py │ ├── hmsa.py │ ├── mswin.py │ ├── naive_compress.py │ ├── pillar_vfe.py │ ├── point_pillar_scatter.py │ ├── self_attn.py │ ├── split_attn.py │ ├── torch_transformation_utils.py │ ├── v2v_fuse.py │ └── v2xvit_basic.py ├── tools ├── __init__.py ├── debug_utils.py ├── inference.py ├── infrence_utils.py ├── train.py └── train_utils.py ├── utils ├── __init__.py ├── box_overlaps.pyx ├── box_utils.py ├── common_utils.py ├── eval_utils.py ├── pcd_utils.py ├── setup.py └── transformation_utils.py ├── version.py └── visualization ├── __init__.py ├── pinhole_param.json ├── vis_data_sequence.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | logs/ 131 | *.c 132 | *.so 133 | .idea 134 | opv2x 135 | .DS_Store 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Runsheng Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/v2x-vit-vehicle-to-everything-cooperative/3d-object-detection-on-v2xset)](https://paperswithcode.com/sota/3d-object-detection-on-v2xset?p=v2x-vit-vehicle-to-everything-cooperative) 2 | 3 | # [V2X-ViT](https://arxiv.org/abs/2203.10638): Vehicle-to-Everything Cooperative Perception with Vision Transformer (ECCV 2022) 4 | 5 | [![paper](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2203.10638) 6 | [![supplement](https://img.shields.io/badge/Supplementary-Material-red)]() 7 | [![video](https://img.shields.io/badge/Video-Presentation-F9D371)]() 8 | 9 | 10 | This is the official implementation of the ECCV 2022 paper "V2X-ViT: Vehicle-to-Everything Cooperative Perception with Vision Transformer". 11 | [Runsheng Xu](https://derrickxunu.github.io/), [Hao Xiang](https://xhwind.github.io/), [Zhengzhong Tu](https://github.com/vztu), [Xin Xia](https://scholar.google.com/citations?user=vCYqMTIAAAAJ&hl=en), [Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en), [Jiaqi Ma](https://mobility-lab.seas.ucla.edu/) 12 | 13 | UCLA, UT-Austin, Google Research, UC-Merced 14 | 15 | **Important Notice**: [OpenCOOD](https://github.com/DerrickXuNu/OpenCOOD) now supports V2X-ViT and V2XSet! We will **no longer** update this repo, and all new features (e.g. the multi-GPU implementation) will only be added to OpenCOOD. 16 | 17 | ![teaser](images/v2xvit.png) 18 | 19 | ## Installation 20 | ```bash 21 | # Clone repo 22 | git clone https://github.com/DerrickXuNu/v2x-vit 23 | 24 | cd v2x-vit 25 | 26 | # Set up conda environment 27 | conda create -y --name v2xvit python=3.7 28 | 29 | conda activate v2xvit 30 | # pytorch >= 1.8.1, the newest version works well 31 | conda install -y pytorch torchvision cudatoolkit=11.3 -c pytorch 32 | # spconv 2.0 install, choose the correct cuda version for your system 33 | pip install spconv-cu113 34 | 35 | # Install dependencies 36 | pip install -r requirements.txt 37 | # Install the CUDA version of the bbox NMS calculation 38 | python v2xvit/utils/setup.py build_ext --inplace 39 | 40 | # Install v2xvit into the environment 41 | python setup.py develop 42 | ``` 43 | 44 | ## Data 45 | ### Download 46 | The data can be downloaded from [this url](https://ucla.app.box.com/v/UCLA-MobilityLab-V2XVIT). Since the data for train/validate/test 47 | is very large, we split each data set into small chunks, which can be found in the directories ending with `_chunks`, such as `train_chunks`. After downloading, please run the following command for each set to merge the chunks together: 48 | 49 | ``` 50 | cat train.zip.part* > train.zip 51 | unzip train.zip 52 | ``` 53 | If you have a fast internet connection, you can also directly download the whole zip file, e.g. `train.zip`. 54 | ### Structure 55 | After downloading is finished, please structure the files as follows: 56 | 57 | ```sh 58 | v2x-vit # root of your v2xvit 59 | ├── v2xset # the downloaded v2xset data 60 | │ ├── train 61 | │ ├── validate 62 | │ ├── test 63 | ├── v2xvit # the core codebase 64 | 65 | ``` 66 | ### Details 67 | Our data label format is very similar to the one in [OPV2V](https://github.com/DerrickXuNu/OpenCOOD). For more details, please refer to the [data tutorial](docs/data_intro.md). 
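As a quick sanity check that the chunks were merged correctly and the folders follow the layout shown above, you can count the scenario and agent folders per split. The snippet below is only an illustrative sketch and is not part of the codebase; the `v2xset` root path is an assumption, so point it to wherever you placed the data.

```python
import os

root = "v2xset"  # assumed dataset root, adjust to your local path

for split in ["train", "validate", "test"]:
    split_dir = os.path.join(root, split)
    if not os.path.isdir(split_dir):
        print(f"{split}: folder not found, skipping")
        continue
    # every sub-folder of a split is one scenario, e.g. 2021_08_22_21_41_24
    scenarios = [d for d in sorted(os.listdir(split_dir))
                 if os.path.isdir(os.path.join(split_dir, d))]
    # inside a scenario, each sub-folder is one agent (negative ids are infrastructure)
    num_agents = sum(
        len([a for a in os.listdir(os.path.join(split_dir, s))
             if os.path.isdir(os.path.join(split_dir, s, a))])
        for s in scenarios)
    print(f"{split}: {len(scenarios)} scenarios, {num_agents} agent folders")
```

The scenario counts across the three splits should sum to the 58 scenarios reported in the [data tutorial](docs/data_intro.md).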
68 | 69 | ### Noise Simulation 70 | One important feature of V2XSet is the capability to add different types of communication noise. This is done in a post-processing fashion through our flexible coding framework. To set different noise levels, please 71 | refer to the [config yaml tutorial](docs/config_tutorial.md). 72 | 73 | ## Getting Started 74 | ### Data sequence visualization 75 | To quickly visualize the LiDAR stream in the V2XSet dataset, first modify the `validate_dir` 76 | in your `v2xvit/hypes_yaml/visualization.yaml` to the V2XSet data path on your local machine, e.g. `v2xset/validate`, 77 | and then run the following command: 78 | ```bash 79 | cd ~/v2x-vit 80 | python v2xvit/visualization/vis_data_sequence.py [--color_mode ${COLOR_RENDERING_MODE}] 81 | ``` 82 | Arguments Explanation: 83 | - `color_mode`: str type, indicating the lidar color rendering mode. You can choose from 'constant', 'intensity' or 'z-value'. 84 | 85 | ### Test with pretrained model 86 | To test the pretrained model of V2X-ViT, first download the model file from [google url](https://drive.google.com/drive/folders/1h2UOPP2tNRkV_s6cbKcSfMvTgb8_ZFj9?usp=sharing) and 87 | then put it under v2x-vit/logs/v2x-vit. Change the `validate_dir` in `v2x-vit/logs/v2x-vit/config.yaml` to `v2xset/test`. 88 | 89 | To test under the perfect setting, change both `async` and `loc_err` to false in `v2x-vit/logs/v2x-vit/config.yaml`. 90 | 91 | To test under the noisy setting in our paper, change the `wild_setting` as follows: 92 | ``` 93 | wild_setting: 94 | async: true 95 | async_mode: 'sim' 96 | async_overhead: 100 97 | backbone_delay: 10 98 | data_size: 1.06 99 | loc_err: true 100 | ryp_std: 0.2 101 | seed: 25 102 | transmission_speed: 27 103 | xyz_std: 0.2 104 | ``` 105 | Finally, run the following command to perform the test: 106 | ```bash 107 | python v2xvit/tools/inference.py --model_dir ${CHECKPOINT_FOLDER} --fusion_method ${FUSION_STRATEGY} [--show_vis] [--show_sequence] 108 | ``` 109 | Arguments Explanation: 110 | - `model_dir`: the path to your saved model. 111 | - `fusion_method`: indicates the fusion strategy; currently 'early', 'late', and 'intermediate' are supported. 112 | - `show_vis`: whether to visualize the detection overlay with the point cloud. 113 | - `show_sequence`: the detection results will be visualized in a video stream. It can NOT be set together with `show_vis`. 114 | 115 | 116 | 117 | 118 | ### Train your model 119 | V2X-ViT uses yaml files to configure all the parameters for training. To train your own model 120 | from scratch or from a saved checkpoint, run the following commands: 121 | ```bash 122 | python v2xvit/tools/train.py --hypes_yaml ${CONFIG_FILE} [--model_dir ${CHECKPOINT_FOLDER} --half] 123 | ``` 124 | Arguments Explanation: 125 | - `hypes_yaml`: the path of the training configuration file, e.g. `v2xvit/hypes_yaml/point_pillar_v2xvit.yaml`, meaning you want to train V2X-ViT. 126 | - `model_dir` (optional): the path of the checkpoints. This is used to fine-tune trained models. When `model_dir` is 127 | given, the trainer will discard the `hypes_yaml` and load the `config.yaml` in the checkpoint folder. 128 | - `half` (optional): if specified, mixed-precision training will be used to reduce memory consumption. 129 | 130 | Important Notes for Training: 131 | 1. When you train from scratch, please first set `async` and `loc_err` to false to train on the perfect setting. Also, set `compression` to 0 at the beginning. 132 | 2. 
After the model converges on the perfect setting, set `compression` to 32 (please change the config yaml in your trained model directory) and continue training on the perfect setting for another 1-2 epochs. 133 | 3. Next, set `async` to true, `async_mode` to 'real', `async_overhead` to 200 or 300, `loc_err` to true, `xyz_std` to 0.2, `ryp_std` to 0.2, and then continue training your model on this noisy setting. Please note that you are free to change these noise settings during training to obtain better performance. 134 | 4. Eventually, use the model fine-tuned on the noisy setting as the test model for both the perfect and noisy settings. 135 | 136 | ## Citation 137 | If you are using our V2X-ViT model or the V2XSet dataset for your research, please cite the following paper: 138 | ```bibtex 139 | @inproceedings{xu2022v2xvit, 140 | author = {Runsheng Xu and Hao Xiang and Zhengzhong Tu and Xin Xia and Ming-Hsuan Yang and Jiaqi Ma}, 141 | title = {V2X-ViT: Vehicle-to-Everything Cooperative Perception with Vision Transformer}, 142 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 143 | year = {2022}} 144 | ``` 145 | 146 | ## Acknowledgement 147 | V2X-ViT is built upon [OpenCOOD](https://github.com/DerrickXuNu/OpenCOOD), which is the first Open Cooperative Detection framework for autonomous driving. 148 | 149 | V2XSet is collected using [OpenCDA](https://github.com/ucla-mobility/OpenCDA), which is the first open co-simulation-based research/engineering framework integrated with prototype cooperative driving automation pipelines as well as regular automated driving components (e.g., perception, localization, planning, control). 150 | -------------------------------------------------------------------------------- /docs/config_tutorial.md: -------------------------------------------------------------------------------- 1 | ## Tutorial: Config System in V2X-ViT 2 | 3 | --- 4 | We incorporate a modular and inheritance-based design into the config system to enable users to conveniently 5 | modify the model/training/inference parameters. Specifically, we use **yaml** files to configure all the 6 | important parameters. 7 | 8 | ### Config File Location 9 | To train a model from scratch, all the yaml files should be saved in `v2xvit/hypes_yaml`, and users should use the `load_yaml()` function in [`v2xvit/hypes_yaml/yaml_utils.py`](https://github.com/DerrickXuNu/OpenCOOD/blob/main/opencood/hypes_yaml/yaml_utils.py#L8) to load the parameters into a dictionary. 10 | 11 | To train from a saved checkpoint or to test, the yaml file should be located together with the checkpoint in `v2xvit/logs/model_name`. 12 | 13 | ### Config Name Style 14 | We follow the style below to name config yaml files. 15 | ``` 16 | {backbone}_{fusion_strategy}.yaml 17 | ``` 18 | 19 | ### Noise Simulation 20 | Communication noise simulation is defined in the `wild_setting` group in the yaml file. 21 | ``` 22 | wild_setting: # setting related to noise 23 | async: true 24 | async_mode: 'sim' 25 | async_overhead: 100 26 | backbone_delay: 10 27 | data_size: 1.06 28 | loc_err: true 29 | ryp_std: 0.2 30 | seed: 25 31 | transmission_speed: 27 32 | xyz_std: 0.2 33 | ``` 34 | `async`: whether to add communication delay.
35 | `async_mode`: 'sim' or 'real' mode. In sim mode, the delay is a constant, while in real mode the delay follows a uniform distribution. 36 | The major experiments in the paper used sim mode, whereas the 'Effects of transmission size' study used real 37 | mode.
38 | `async_overhead`: the communication delay in ms. In sim mode, it represents a constant number. In real mode, 39 | the systematic delay will be a random number drawn uniformly from 0 to `async_overhead`.
40 | `backbone_delay`: an estimate of the backbone computation time. Only used in real mode.
41 | `data_size`: transmission data size in Mb. Only used in real mode.
42 | `transmission_speed`: the data transmission speed during communication, 27 Mb/s by default. Only used in real mode.
43 | `loc_err`: whether to add localization error.
44 | `xyz_std`: the standard deviation of positional GPS error.
45 | `ryp_std`: the standard deviation of angular GPS error.
46 | `seed`: random seed for noise simulation. please keep it as 25 during testing . 47 | 48 | 49 | 50 | 51 | ### A concrete example 52 | Now let's go through the `point_pillar_opv2v_fusion.yaml` as an example. 53 | 54 | ```yaml 55 | name: point_pillar_intermediate_fusion # this parameter together with the current timestamp will define the name of the saved folder for the model. 56 | root_dir: "v2xset/train" # this is where the training data locate 57 | validate_dir: "v2xset/validate" # during training, it defines the validation folder. during testing, it defines the testing folder path. 58 | 59 | yaml_parser: "load_point_pillar_params" # we need specific loading functions for different backbones. 60 | train_params: # the common training parameters 61 | batch_size: &batch_size 2 62 | epoches: 60 63 | eval_freq: 1 64 | save_freq: 1 65 | 66 | wild_setting: # setting related to noise 67 | async: true 68 | async_mode: 'sim' 69 | async_overhead: 100 70 | backbone_delay: 10 71 | data_size: 1.06 72 | loc_err: true 73 | ryp_std: 0.2 74 | seed: 25 75 | transmission_speed: 27 76 | xyz_std: 0.2 77 | 78 | fusion: 79 | core_method: 'IntermediateFusionDataset' # LateFusionDataset, EarlyFusionDataset, and IntermediateFusionDataset are supported 80 | args: 81 | cur_ego_pose_flag: True 82 | # when the cur_ego_pose_flag is set to True, there is no time gap 83 | # between the time when the LiDAR data is captured by connected 84 | # agents and when the extracted features are received by 85 | # the ego vehicle, which is equal to implement STCM. When set to False, 86 | # STCM has to be used. To validate STCM, V2X-ViT will set this as False. 87 | 88 | # preprocess-related 89 | preprocess: 90 | # options: BasePreprocessor, SpVoxelPreprocessor, BevPreprocessor 91 | core_method: 'SpVoxelPreprocessor' 92 | args: 93 | voxel_size: &voxel_size [0.4, 0.4, 4] # the voxel resolution for PointPillar 94 | max_points_per_voxel: 32 # maximum points allowed in each voxel 95 | max_voxel_train: 32000 # the maximum voxel number during training 96 | max_voxel_test: 70000 # the maximum voxel number during testing 97 | # LiDAR point cloud cropping range 98 | cav_lidar_range: &cav_lidar [-140.8, -40, -3, 140.8, 40, 1] 99 | 100 | # data augmentation options. 101 | data_augment: 102 | - NAME: random_world_flip 103 | ALONG_AXIS_LIST: [ 'x' ] 104 | 105 | - NAME: random_world_rotation 106 | WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ] 107 | 108 | - NAME: random_world_scaling 109 | WORLD_SCALE_RANGE: [ 0.95, 1.05 ] 110 | 111 | # post processing related. 112 | postprocess: 113 | core_method: 'VoxelPostprocessor' # VoxelPostprocessor and BevPostprocessor are supported 114 | anchor_args: # anchor generator parameters 115 | cav_lidar_range: *cav_lidar # the range is consistent with the lidar cropping range to generate the correct ancrhors 116 | l: 3.9 # the default length of the anchor 117 | w: 1.6 # the default width 118 | h: 1.56 # the default height 119 | r: [0, 90] # the yaw angles. 0, 90 meaning for each voxel, two anchors will be generated with 0 and 90 degree yaw angle 120 | feature_stride: 2 # the feature map is shrank twice compared the input voxel tensor 121 | num: &achor_num 2 # for each location in the feature map, 2 anchors will be generated 122 | target_args: # used to generate positive and negative samples for object detection 123 | pos_threshold: 0.6 124 | neg_threshold: 0.45 125 | score_threshold: 0.20 126 | order: 'hwl' # hwl or lwh 127 | max_num: 100 # maximum number of objects in a single frame. 
use this number to make sure different frames have the same dimension in the same batch 128 | nms_thresh: 0.15 129 | 130 | # model related 131 | model: 132 | core_method: point_pillar_opv2v # trainer will load the corresponding model python file with the same name 133 | args: # detailed parameters of the point pillar model 134 | voxel_size: *voxel_size 135 | lidar_range: *cav_lidar 136 | anchor_number: *achor_num 137 | 138 | pillar_vfe: 139 | use_norm: true 140 | with_distance: false 141 | use_absolute_xyz: true 142 | num_filters: [64] 143 | point_pillar_scatter: 144 | num_features: 64 145 | 146 | base_bev_backbone: 147 | layer_nums: [3, 5, 8] 148 | layer_strides: [2, 2, 2] 149 | num_filters: [64, 128, 256] 150 | upsample_strides: [1, 2, 4] 151 | num_upsample_filter: [128, 128, 128] 152 | compression: 0 # whether to compress the features before fusion to reduce the bandwidth 153 | backbone_fix: false # whether fix the pointpillar backbone weights during training. 154 | anchor_num: *achor_num 155 | 156 | loss: # loss function 157 | core_method: point_pillar_loss # trainer will load the loss function with the same name 158 | args: 159 | cls_weight: 1.0 # classification weights 160 | reg: 2.0 # regression weights 161 | 162 | optimizer: # optimzer setup 163 | core_method: Adam # the name has to exist in Pytorch optimizer library 164 | lr: 0.002 165 | args: 166 | eps: 1e-10 167 | weight_decay: 1e-4 168 | 169 | lr_scheduler: # learning rate schedular 170 | core_method: multistep #step, multistep and Exponential are supported 171 | gamma: 0.1 172 | step_size: [15, 30] 173 | 174 | ``` -------------------------------------------------------------------------------- /docs/data_annotation_tutorial.md: -------------------------------------------------------------------------------- 1 | ## Data Annotation Introduction 2 | 3 | --- 4 | We save all groundtruth annotations per agent per timestamp in the yaml files. For instance, 5 | `2021_08_24_21_29_28/4805/000069.yaml` refers to the data annotations with the perspective of te 6 | agent 4805 at timestamp 69 in the scenario database `2021_08_24_21_29_28`. Here we go through an example: 7 | 8 | ```yaml 9 | camera0: # parameters for frontal camera 10 | cords: # the x,y,z,roll,yaw,pitch under CARLA map coordinate 11 | - 141.35067749023438 12 | - -388.642578125 13 | - 1.0410505533218384 14 | - 0.07589337974786758 15 | - 174.18048095703125 16 | - 0.20690691471099854 17 | extrinsic: # extrinsic matrix from camera to LiDAR 18 | - - 0.9999999999999999 19 | - -5.1230071481984265e-18 20 | - 9.322129061605055e-20 21 | - -2.999993025731527 22 | - - -2.5011383190939924e-18 23 | - 1.0 24 | - 1.1458579204685086e-19 25 | - -3.934422863949294e-06 26 | - - 2.7713237218713775e-20 27 | - 3.7310309839064755e-20 28 | - 1.0 29 | - 0.8999999040861146 30 | - - 0.0 31 | - 0.0 32 | - 0.0 33 | - 1.0 34 | intrinsic: # camera intrinsic matrix 35 | - - 335.639852470912 36 | - 0.0 37 | - 400.0 38 | - - 0.0 39 | - 335.639852470912 40 | - 300.0 41 | - - 0.0 42 | - 0.0 43 | - 1.0 44 | camera1: ... # params of right rear camera 45 | camera2: ... # params of left rear camera 46 | canera3: ... # params of back camera 47 | ego_speed: 18.13 # agent's current speed, km/h 48 | lidar_pose: # LiDAR pose under CARLA map coordinate system 49 | - 144.33 50 | - -388.94 51 | - 1.93 52 | - 0.078 53 | - 174.18 54 | - 0.21 55 | plan_trajectory: # agent's planning trajectory 56 | - - 140. 
57 | - -388 58 | - 87 59 | predicted_ego_pos: # agent's localization (x,y,z,roll,yaw,pitch) gained from GPS 60 | - 143.78 61 | - -388.94 62 | - 0.036 63 | - 0.080 64 | - -185.95 65 | - 0.18 66 | true_ego_pos: # agent's true localization 67 | - 143.83 68 | - -388.89 69 | - 0.032 70 | - 0.075 71 | - 174.18 72 | - 0.21 73 | vehicles: # the surrounding vehicles that have at least one LiDAR point hit from the agent 74 | 4796: # the id of the vehicle (i.e. object) 75 | angle: # roll, yaw, pitch under CARLA map coordinate system 76 | - 0.096 77 | - -177.86 78 | - 0.197 79 | center: # the relative position from bounding box center to the frontal axis of this vehicle 80 | - 0.0004 81 | - 0.0005 82 | - 0.71 83 | extent: # half length, width and height of the vehicle in meter 84 | - 2.45 85 | - 1.06 86 | - 0.75 87 | location: # x, y ,z position of the center in the frontal axis of the vehicle under CARLA map coordinate system 88 | - 158.55 89 | - -385.75 90 | - 0.032 91 | speed: 19.47 # vehicle's speed 92 | 4880: ... 93 | ``` 94 | 95 | -------------------------------------------------------------------------------- /docs/data_intro.md: -------------------------------------------------------------------------------- 1 | ## Data Introduction 2 | 3 | --- 4 | 5 | V2XSet data is structured as following: 6 | 7 | ```sh 8 | V2XSet 9 | ├── train # data for training 10 | │ ├── 2021_08_22_21_41_24 # scenario folder 11 | │ ├── data_protocol.yaml # the simulation parameters used to collect the data in Carla 12 | │ └── -1 # The infra's id 13 | │ └── 00000.pcd - 00700.pcd # the point clouds data from timestamp 0 to 700 14 | │ ├── 00000.yaml - 00700.yaml # corresponding metadata for each timestamp 15 | │ ├── 00000_camera0.png - 00700_camera0.png # frontal camera images 16 | │ ├── 00000_camera1.png - 00700_camera1.png # right rear camera images 17 | │ ├── 00000_camera2.png - 00700_camera2.png # left rear camera images 18 | │ └── 00000_camera3.png - 00700_camera3.png # back camera images 19 | | └── 112 # The connected vehicle id 20 | ├── validate 21 | ├── test 22 | ``` 23 | 24 | ### 1. Data Split 25 | OPV2V dataset can be divided into 4 different folders: `train`, `validation`, `test` 26 | - `train`: contains all training data 27 | - `validate`: used for validation during training 28 | - `test`: test set 29 | 30 | ### 2. Scenario Database 31 | V2XSet has 58 scenarios in total, where each of them contains data stream from different agents across different timestamps. 32 | Each scenario is named by the time it was gathered, e.g., `2021_08_22_21_41_24`. 33 | 34 | ### 3. Agent Contents 35 | Under each scenario folder, the data of every intelligent agent~(i.e. infrastructure or connected automated vehicle) appearing in the current scenario is saved in different folders. Each folder is named by the agent's unique id, e.g., 1732. A negative id means infrastructure. 36 | 37 | In each agent folder, data across different timestamps will be saved. Those timestamps are represented by five digits integers 38 | as the prefix of the filenames (e.g., 00700.pcd). There are three types of files inside the agent folders: LiDAR point clouds (`.pcd` files), camera images (`.png` files), and metadata (`.yaml` files). 39 | 40 | #### 3.1 Lidar point cloud 41 | The LiDAR data is saved with Open3d package and has a postfix ".pcd" in the name. 
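Below is a minimal sketch of reading one frame with the Open3D API. The file path is only a hypothetical example, and the repository's own loading helpers in `v2xvit/utils/pcd_utils.py` should be treated as the reference implementation, since they also recover the intensity channel used by the pre-processors.

```python
import numpy as np
import open3d as o3d

# hypothetical example frame: infrastructure agent (-1) in one training scenario
pcd_path = "v2xset/train/2021_08_22_21_41_24/-1/00000.pcd"

pcd = o3d.io.read_point_cloud(pcd_path)  # open3d.geometry.PointCloud
xyz = np.asarray(pcd.points)             # (N, 3) point coordinates in the LiDAR frame

print("number of points:", xyz.shape[0])
print("first three points:\n", xyz[:3])
```

The `.yaml` file sharing the same timestamp prefix carries the pose and annotation metadata described in the following sections.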
42 | 43 | #### 3.2 Camera images 44 | Each CAV and Infra is equipped with 4 RGB cameras to capture the 360 degree of view of the surrounding scene.`camera0`, `camera1`, `camera2`, and `camera3` represent the front, right rear, left rear, and back cameras respectively. 45 | 46 | #### 3.3 Data Annotation 47 | All the metadata is saved in yaml files. It records the following important information at the current timestamp: 48 | - **ego information**: Current ego pose with and without GPS noise under Carla world coordinates, ego speed in km/h, the LiDAR pose, and future planning trajectories. 49 | - **calibration**: The intrinsic matrix and extrinsic matrix from each camera to the LiDAR sensor. 50 | - **objects annotation**: The pose and velocity of each surrounding human driving vehicle that has at least one point hit by the agent's LiDAR sensor. See [data annotation section](data_annotation_tutorial.md) for more details. 51 | 52 | ### 4. Data Collection Protocol 53 | Besides agent contents, every scenario database also has a yaml file named `data_protocol.yaml`. 54 | This yaml file records the simulation configuration to collect the current scenario. 55 | 56 | -------------------------------------------------------------------------------- /images/v2xvit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/images/v2xvit.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numpy 3 | open3d 4 | opencv-python 5 | cython 6 | tensorboardX 7 | shapely 8 | einops 9 | 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Author: Runsheng Xu 3 | # License: TDG-Attribution-NonCommercial-NoDistrib 4 | 5 | 6 | from os.path import dirname, realpath 7 | from setuptools import setup, find_packages, Distribution 8 | from v2xvit.version import __version__ 9 | 10 | 11 | def _read_requirements_file(): 12 | """Return the elements in requirements.txt.""" 13 | req_file_path = '%s/requirements.txt' % dirname(realpath(__file__)) 14 | with open(req_file_path) as f: 15 | return [line.strip() for line in f] 16 | 17 | 18 | setup( 19 | name='V2XViT', 20 | version=__version__, 21 | packages=find_packages(), 22 | url='https://github.com/ucla-mobility/OpenCDA.git', 23 | license='MIT', 24 | author='Runsheng Xu, Hao Xiang, Zhengzhong Tu', 25 | author_email='rxx3386@ucla.edu', 26 | description='An opensource pytorch framework for autonomous driving ' 27 | 'cooperative detection', 28 | long_description=open("README.md").read(), 29 | install_requires=_read_requirements_file(), 30 | ) 31 | -------------------------------------------------------------------------------- /v2xvit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/__init__.py -------------------------------------------------------------------------------- /v2xvit/data_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/data_utils/__init__.py 
-------------------------------------------------------------------------------- /v2xvit/data_utils/augmentor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/data_utils/augmentor/__init__.py -------------------------------------------------------------------------------- /v2xvit/data_utils/augmentor/augment_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from v2xvit.utils import common_utils 4 | 5 | 6 | def random_flip_along_x(gt_boxes, points): 7 | """ 8 | Args: 9 | gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] 10 | points: (M, 3 + C) 11 | Returns: 12 | """ 13 | enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5]) 14 | if enable: 15 | gt_boxes[:, 1] = -gt_boxes[:, 1] 16 | gt_boxes[:, 6] = -gt_boxes[:, 6] 17 | points[:, 1] = -points[:, 1] 18 | 19 | if gt_boxes.shape[1] > 7: 20 | gt_boxes[:, 8] = -gt_boxes[:, 8] 21 | 22 | return gt_boxes, points 23 | 24 | 25 | def random_flip_along_y(gt_boxes, points): 26 | """ 27 | Args: 28 | gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] 29 | points: (M, 3 + C) 30 | Returns: 31 | """ 32 | enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5]) 33 | if enable: 34 | gt_boxes[:, 0] = -gt_boxes[:, 0] 35 | gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi) 36 | points[:, 0] = -points[:, 0] 37 | 38 | if gt_boxes.shape[1] > 7: 39 | gt_boxes[:, 7] = -gt_boxes[:, 7] 40 | 41 | return gt_boxes, points 42 | 43 | 44 | def global_rotation(gt_boxes, points, rot_range): 45 | """ 46 | Args: 47 | gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]] 48 | points: (M, 3 + C), 49 | rot_range: [min, max] 50 | Returns: 51 | """ 52 | noise_rotation = np.random.uniform(rot_range[0], 53 | rot_range[1]) 54 | points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], 55 | np.array([noise_rotation]))[0] 56 | 57 | gt_boxes[:, 0:3] = \ 58 | common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3], 59 | np.array([noise_rotation]))[0] 60 | gt_boxes[:, 6] += noise_rotation 61 | 62 | if gt_boxes.shape[1] > 7: 63 | gt_boxes[:, 7:9] = common_utils.rotate_points_along_z( 64 | np.hstack((gt_boxes[:, 7:9], np.zeros((gt_boxes.shape[0], 1))))[ 65 | np.newaxis, :, :], 66 | np.array([noise_rotation]))[0][:, 0:2] 67 | 68 | return gt_boxes, points 69 | 70 | 71 | def global_scaling(gt_boxes, points, scale_range): 72 | """ 73 | Args: 74 | gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading] 75 | points: (M, 3 + C), 76 | scale_range: [min, max] 77 | Returns: 78 | """ 79 | if scale_range[1] - scale_range[0] < 1e-3: 80 | return gt_boxes, points 81 | noise_scale = np.random.uniform(scale_range[0], scale_range[1]) 82 | points[:, :3] *= noise_scale 83 | gt_boxes[:, :6] *= noise_scale 84 | 85 | return gt_boxes, points 86 | -------------------------------------------------------------------------------- /v2xvit/data_utils/augmentor/data_augmentor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class for data augmentation 3 | """ 4 | from functools import partial 5 | 6 | from v2xvit.data_utils.augmentor import augment_utils 7 | 8 | 9 | class DataAugmentor(object): 10 | """ 11 | Data Augmentor. 12 | 13 | Parameters 14 | ---------- 15 | augment_config : list 16 | A list of augmentation configuration. 
17 | 18 | Attributes 19 | ---------- 20 | data_augmentor_queue : list 21 | The list of data augmented functions. 22 | """ 23 | 24 | def __init__(self, augment_config, train=True): 25 | self.data_augmentor_queue = [] 26 | self.train = train 27 | 28 | for cur_cfg in augment_config: 29 | cur_augmentor = getattr(self, cur_cfg['NAME'])(config=cur_cfg) 30 | self.data_augmentor_queue.append(cur_augmentor) 31 | 32 | def random_world_flip(self, data_dict=None, config=None): 33 | if data_dict is None: 34 | return partial(self.random_world_flip, config=config) 35 | 36 | gt_boxes, gt_mask, points = data_dict['object_bbx_center'], \ 37 | data_dict['object_bbx_mask'], \ 38 | data_dict['lidar_np'] 39 | gt_boxes_valid = gt_boxes[gt_mask == 1] 40 | 41 | for cur_axis in config['ALONG_AXIS_LIST']: 42 | assert cur_axis in ['x', 'y'] 43 | gt_boxes_valid, points = getattr(augment_utils, 44 | 'random_flip_along_%s' % cur_axis)( 45 | gt_boxes_valid, points, 46 | ) 47 | 48 | gt_boxes[:gt_boxes_valid.shape[0], :] = gt_boxes_valid 49 | 50 | data_dict['object_bbx_center'] = gt_boxes 51 | data_dict['object_bbx_mask'] = gt_mask 52 | data_dict['lidar_np'] = points 53 | 54 | return data_dict 55 | 56 | def random_world_rotation(self, data_dict=None, config=None): 57 | if data_dict is None: 58 | return partial(self.random_world_rotation, config=config) 59 | 60 | rot_range = config['WORLD_ROT_ANGLE'] 61 | if not isinstance(rot_range, list): 62 | rot_range = [-rot_range, rot_range] 63 | 64 | gt_boxes, gt_mask, points = data_dict['object_bbx_center'], \ 65 | data_dict['object_bbx_mask'], \ 66 | data_dict['lidar_np'] 67 | gt_boxes_valid = gt_boxes[gt_mask == 1] 68 | gt_boxes_valid, points = augment_utils.global_rotation( 69 | gt_boxes_valid, points, rot_range=rot_range 70 | ) 71 | gt_boxes[:gt_boxes_valid.shape[0], :] = gt_boxes_valid 72 | 73 | data_dict['object_bbx_center'] = gt_boxes 74 | data_dict['object_bbx_mask'] = gt_mask 75 | data_dict['lidar_np'] = points 76 | 77 | return data_dict 78 | 79 | def random_world_scaling(self, data_dict=None, config=None): 80 | if data_dict is None: 81 | return partial(self.random_world_scaling, config=config) 82 | 83 | gt_boxes, gt_mask, points = data_dict['object_bbx_center'], \ 84 | data_dict['object_bbx_mask'], \ 85 | data_dict['lidar_np'] 86 | gt_boxes_valid = gt_boxes[gt_mask == 1] 87 | 88 | gt_boxes_valid, points = augment_utils.global_scaling( 89 | gt_boxes_valid, points, config['WORLD_SCALE_RANGE'] 90 | ) 91 | gt_boxes[:gt_boxes_valid.shape[0], :] = gt_boxes_valid 92 | 93 | data_dict['object_bbx_center'] = gt_boxes 94 | data_dict['object_bbx_mask'] = gt_mask 95 | data_dict['lidar_np'] = points 96 | 97 | return data_dict 98 | 99 | def forward(self, data_dict): 100 | """ 101 | Args: 102 | data_dict: 103 | points: (N, 3 + C_in) 104 | gt_boxes: optional, (N, 7) [x, y, z, dx, dy, dz, heading] 105 | gt_names: optional, (N), string 106 | ... 
107 | 108 | Returns: 109 | """ 110 | if self.train: 111 | for cur_augmentor in self.data_augmentor_queue: 112 | data_dict = cur_augmentor(data_dict=data_dict) 113 | 114 | return data_dict 115 | -------------------------------------------------------------------------------- /v2xvit/data_utils/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from v2xvit.data_utils.datasets.late_fusion_dataset import LateFusionDataset 2 | from v2xvit.data_utils.datasets.early_fusion_dataset import EarlyFusionDataset 3 | from v2xvit.data_utils.datasets.intermediate_fusion_dataset import IntermediateFusionDataset 4 | 5 | __all__ = { 6 | 'LateFusionDataset': LateFusionDataset, 7 | 'EarlyFusionDataset': EarlyFusionDataset, 8 | 'IntermediateFusionDataset': IntermediateFusionDataset 9 | } 10 | 11 | # the final range for evaluation 12 | GT_RANGE = [-140, -40, -3, 140, 40, 1] 13 | # The communication range for cavs 14 | COM_RANGE = 70 15 | 16 | 17 | def build_dataset(dataset_cfg, visualize=False, train=True): 18 | dataset_name = dataset_cfg['fusion']['core_method'] 19 | error_message = f"{dataset_name} is not found. " \ 20 | f"Please add your processor file's name in opencood/" \ 21 | f"data_utils/datasets/init.py" 22 | assert dataset_name in ['LateFusionDataset', 'EarlyFusionDataset', 23 | 'IntermediateFusionDataset'], error_message 24 | 25 | dataset = __all__[dataset_name]( 26 | params=dataset_cfg, 27 | visualize=visualize, 28 | train=train 29 | ) 30 | 31 | return dataset 32 | -------------------------------------------------------------------------------- /v2xvit/data_utils/datasets/early_fusion_vis_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a dataset for early fusion visualization only. 3 | """ 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from v2xvit.utils import box_utils 10 | from v2xvit.data_utils.post_processor import build_postprocessor 11 | from v2xvit.data_utils.datasets import basedataset 12 | from v2xvit.data_utils.pre_processor import build_preprocessor 13 | from v2xvit.utils.pcd_utils import \ 14 | mask_points_by_range, mask_ego_points, shuffle_points, \ 15 | downsample_lidar_minimum 16 | 17 | 18 | class EarlyFusionVisDataset(basedataset.BaseDataset): 19 | def __init__(self, params, visualize, train=True): 20 | super(EarlyFusionVisDataset, self).__init__(params, visualize, train) 21 | self.pre_processor = build_preprocessor(params['preprocess'], 22 | train) 23 | self.post_processor = build_postprocessor(params['postprocess'], train) 24 | 25 | def __getitem__(self, idx): 26 | base_data_dict = self.retrieve_base_data(idx) 27 | 28 | processed_data_dict = OrderedDict() 29 | processed_data_dict['ego'] = {} 30 | 31 | ego_id = -1 32 | ego_lidar_pose = [] 33 | 34 | # first find the ego vehicle's lidar pose 35 | for cav_id, cav_content in base_data_dict.items(): 36 | if cav_content['ego']: 37 | ego_id = cav_id 38 | ego_lidar_pose = cav_content['params']['lidar_pose'] 39 | break 40 | 41 | assert ego_id != -1 42 | assert len(ego_lidar_pose) > 0 43 | 44 | projected_lidar_stack = [] 45 | object_stack = [] 46 | object_id_stack = [] 47 | 48 | # loop over all CAVs to process information 49 | for cav_id, selected_cav_base in base_data_dict.items(): 50 | selected_cav_processed = self.get_item_single_car( 51 | selected_cav_base, 52 | ego_lidar_pose) 53 | # all these lidar and object coordinates are projected to ego 54 | # already. 
55 | projected_lidar_stack.append( 56 | selected_cav_processed['projected_lidar']) 57 | object_stack.append(selected_cav_processed['object_bbx_center']) 58 | object_id_stack += selected_cav_processed['object_ids'] 59 | 60 | # exclude all repetitive objects 61 | unique_indices = \ 62 | [object_id_stack.index(x) for x in set(object_id_stack)] 63 | object_stack = np.vstack(object_stack) 64 | object_stack = object_stack[unique_indices] 65 | 66 | # make sure bounding boxes across all frames have the same number 67 | object_bbx_center = \ 68 | np.zeros((self.params['postprocess']['max_num'], 7)) 69 | mask = np.zeros(self.params['postprocess']['max_num']) 70 | object_bbx_center[:object_stack.shape[0], :] = object_stack 71 | mask[:object_stack.shape[0]] = 1 72 | 73 | # convert list to numpy array, (N, 4) 74 | projected_lidar_stack = np.vstack(projected_lidar_stack) 75 | 76 | # data augmentation 77 | projected_lidar_stack, object_bbx_center, mask = \ 78 | self.augment(projected_lidar_stack, object_bbx_center, mask) 79 | 80 | # we do lidar filtering in the stacked lidar 81 | projected_lidar_stack = mask_points_by_range(projected_lidar_stack, 82 | self.params['preprocess'][ 83 | 'cav_lidar_range']) 84 | # augmentation may remove some of the bbx out of range 85 | object_bbx_center_valid = object_bbx_center[mask == 1] 86 | object_bbx_center_valid = \ 87 | box_utils.mask_boxes_outside_range_numpy(object_bbx_center_valid, 88 | self.params['preprocess'][ 89 | 'cav_lidar_range'], 90 | self.params['postprocess'][ 91 | 'order'] 92 | ) 93 | mask[object_bbx_center_valid.shape[0]:] = 0 94 | object_bbx_center[:object_bbx_center_valid.shape[0]] = \ 95 | object_bbx_center_valid 96 | object_bbx_center[object_bbx_center_valid.shape[0]:] = 0 97 | 98 | processed_data_dict['ego'].update( 99 | {'object_bbx_center': object_bbx_center, 100 | 'object_bbx_mask': mask, 101 | 'object_ids': [object_id_stack[i] for i in unique_indices], 102 | 'origin_lidar': projected_lidar_stack 103 | }) 104 | 105 | return processed_data_dict 106 | 107 | def get_item_single_car(self, selected_cav_base, ego_pose): 108 | """ 109 | Project the lidar and bbx to ego space first, and then do clipping. 110 | 111 | Parameters 112 | ---------- 113 | selected_cav_base : dict 114 | The dictionary contains a single CAV's raw information. 115 | ego_pose : list 116 | The ego vehicle lidar pose under world coordinate. 117 | 118 | Returns 119 | ------- 120 | selected_cav_processed : dict 121 | The dictionary contains the cav's processed information. 
122 | """ 123 | selected_cav_processed = {} 124 | 125 | # calculate the transformation matrix 126 | transformation_matrix = \ 127 | selected_cav_base['params']['transformation_matrix'] 128 | 129 | # retrieve objects under ego coordinates 130 | object_bbx_center, object_bbx_mask, object_ids = \ 131 | self.post_processor.generate_object_center([selected_cav_base], 132 | ego_pose) 133 | 134 | # filter lidar 135 | lidar_np = selected_cav_base['lidar_np'] 136 | lidar_np = shuffle_points(lidar_np) 137 | # remove points that hit itself 138 | lidar_np = mask_ego_points(lidar_np) 139 | # project the lidar to ego space 140 | lidar_np[:, :3] = \ 141 | box_utils.project_points_by_matrix_torch(lidar_np[:, :3], 142 | transformation_matrix) 143 | 144 | selected_cav_processed.update( 145 | {'object_bbx_center': object_bbx_center[object_bbx_mask == 1], 146 | 'object_ids': object_ids, 147 | 'projected_lidar': lidar_np}) 148 | 149 | return selected_cav_processed 150 | 151 | def collate_batch_train(self, batch): 152 | """ 153 | Customized collate function for pytorch dataloader during training 154 | for late fusion dataset. 155 | 156 | Parameters 157 | ---------- 158 | batch : dict 159 | 160 | Returns 161 | ------- 162 | batch : dict 163 | Reformatted batch. 164 | """ 165 | # during training, we only care about ego. 166 | output_dict = {'ego': {}} 167 | 168 | object_bbx_center = [] 169 | object_bbx_mask = [] 170 | origin_lidar = [] 171 | 172 | for i in range(len(batch)): 173 | ego_dict = batch[i]['ego'] 174 | object_bbx_center.append(ego_dict['object_bbx_center']) 175 | object_bbx_mask.append(ego_dict['object_bbx_mask']) 176 | origin_lidar.append(ego_dict['origin_lidar']) 177 | 178 | # convert to numpy, (B, max_num, 7) 179 | object_bbx_center = torch.from_numpy(np.array(object_bbx_center)) 180 | object_bbx_mask = torch.from_numpy(np.array(object_bbx_mask)) 181 | output_dict['ego'].update({'object_bbx_center': object_bbx_center, 182 | 'object_bbx_mask': object_bbx_mask}) 183 | 184 | origin_lidar = \ 185 | np.array(downsample_lidar_minimum(pcd_np_list=origin_lidar)) 186 | origin_lidar = torch.from_numpy(origin_lidar) 187 | output_dict['ego'].update({'origin_lidar': origin_lidar}) 188 | 189 | return output_dict 190 | -------------------------------------------------------------------------------- /v2xvit/data_utils/post_processor/__init__.py: -------------------------------------------------------------------------------- 1 | from v2xvit.data_utils.post_processor.voxel_postprocessor import VoxelPostprocessor 2 | from v2xvit.data_utils.post_processor.bev_postprocessor import BevPostprocessor 3 | 4 | __all__ = { 5 | 'VoxelPostprocessor': VoxelPostprocessor, 6 | 'BevPostprocessor': BevPostprocessor, 7 | } 8 | 9 | 10 | def build_postprocessor(anchor_cfg, train): 11 | process_method_name = anchor_cfg['core_method'] 12 | assert process_method_name in ['VoxelPostprocessor', 'BevPostprocessor'] 13 | anchor_generator = __all__[process_method_name]( 14 | anchor_params=anchor_cfg, 15 | train=train 16 | ) 17 | 18 | return anchor_generator 19 | -------------------------------------------------------------------------------- /v2xvit/data_utils/post_processor/base_postprocessor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Template for AnchorGenerator 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from v2xvit.utils import box_utils 9 | 10 | 11 | class BasePostprocessor(object): 12 | """ 13 | Template for Anchor generator. 
14 | 15 | Parameters 16 | ---------- 17 | anchor_params : dict 18 | The dictionary containing all anchor-related parameters. 19 | train : bool 20 | Indicate train or test mode. 21 | 22 | Attributes 23 | ---------- 24 | bbx_dict : dictionary 25 | Contain all objects information across the cav, key: id, value: bbx 26 | coordinates (1, 7) 27 | """ 28 | 29 | def __init__(self, anchor_params, train=True): 30 | self.params = anchor_params 31 | self.bbx_dict = {} 32 | self.train = train 33 | 34 | def generate_anchor_box(self): 35 | # needs to be overloaded 36 | return None 37 | 38 | def generate_label(self, *argv): 39 | return None 40 | 41 | def generate_gt_bbx(self, data_dict): 42 | """ 43 | The base postprocessor will generate 3d groundtruth bounding box. 44 | 45 | Parameters 46 | ---------- 47 | data_dict : dict 48 | The dictionary containing the origin input data of model. 49 | 50 | Returns 51 | ------- 52 | gt_box3d_tensor : torch.Tensor 53 | The groundtruth bounding box tensor, shape (N, 8, 3). 54 | """ 55 | gt_box3d_list = [] 56 | # used to avoid repetitive bounding box 57 | object_id_list = [] 58 | 59 | for cav_id, cav_content in data_dict.items(): 60 | # used to project gt bounding box to ego space. 61 | # the transformation matrix for gt should always be based on 62 | # current timestamp (object transformation matrix is for 63 | # late fusion only since other fusion method already did 64 | # the transformation in the preprocess) 65 | transformation_matrix = cav_content['transformation_matrix'] \ 66 | if 'gt_transformation_matrix' not in cav_content \ 67 | else cav_content['gt_transformation_matrix'] 68 | 69 | object_bbx_center = cav_content['object_bbx_center'] 70 | object_bbx_mask = cav_content['object_bbx_mask'] 71 | object_ids = cav_content['object_ids'] 72 | object_bbx_center = object_bbx_center[object_bbx_mask == 1] 73 | 74 | # convert center to corner 75 | object_bbx_corner = \ 76 | box_utils.boxes_to_corners_3d(object_bbx_center, 77 | self.params['order']) 78 | projected_object_bbx_corner = \ 79 | box_utils.project_box3d(object_bbx_corner.float(), 80 | transformation_matrix) 81 | gt_box3d_list.append(projected_object_bbx_corner) 82 | 83 | # append the corresponding ids 84 | object_id_list += object_ids 85 | 86 | # gt bbx 3d 87 | gt_box3d_list = torch.vstack(gt_box3d_list) 88 | # some of the bbx may be repetitive, use the id list to filter 89 | gt_box3d_selected_indices = \ 90 | [object_id_list.index(x) for x in set(object_id_list)] 91 | gt_box3d_tensor = gt_box3d_list[gt_box3d_selected_indices] 92 | 93 | # filter the gt_box to make sure all bbx are in the range 94 | mask = \ 95 | box_utils.get_mask_for_boxes_within_range_torch(gt_box3d_tensor) 96 | gt_box3d_tensor = gt_box3d_tensor[mask, :, :] 97 | 98 | return gt_box3d_tensor 99 | 100 | def generate_object_center(self, 101 | cav_contents, 102 | reference_lidar_pose): 103 | """ 104 | Retrieve all objects in a format of (n, 7), where 7 represents 105 | x, y, z, l, w, h, yaw or x, y, z, h, w, l, yaw. 106 | 107 | Parameters 108 | ---------- 109 | cav_contents : list 110 | List of dictionary, save all cavs' information. 111 | 112 | reference_lidar_pose : list 113 | The final target lidar pose with length 6. 114 | 115 | Returns 116 | ------- 117 | object_np : np.ndarray 118 | Shape is (max_num, 7). 119 | mask : np.ndarray 120 | Shape is (max_num,). 121 | object_ids : list 122 | Length is number of bbx in current sample. 
123 | """ 124 | from v2xvit.data_utils.datasets import GT_RANGE 125 | 126 | tmp_object_dict = {} 127 | for cav_content in cav_contents: 128 | tmp_object_dict.update(cav_content['params']['vehicles']) 129 | 130 | output_dict = {} 131 | filter_range = self.params['anchor_args']['cav_lidar_range'] \ 132 | if self.train else GT_RANGE 133 | 134 | box_utils.project_world_objects(tmp_object_dict, 135 | output_dict, 136 | reference_lidar_pose, 137 | filter_range, 138 | self.params['order']) 139 | 140 | object_np = np.zeros((self.params['max_num'], 7)) 141 | mask = np.zeros(self.params['max_num']) 142 | object_ids = [] 143 | 144 | for i, (object_id, object_bbx) in enumerate(output_dict.items()): 145 | object_np[i] = object_bbx[0, :] 146 | mask[i] = 1 147 | object_ids.append(object_id) 148 | 149 | return object_np, mask, object_ids 150 | -------------------------------------------------------------------------------- /v2xvit/data_utils/pre_processor/__init__.py: -------------------------------------------------------------------------------- 1 | from v2xvit.data_utils.pre_processor.base_preprocessor import BasePreprocessor 2 | from v2xvit.data_utils.pre_processor.voxel_preprocessor import VoxelPreprocessor 3 | from v2xvit.data_utils.pre_processor.bev_preprocessor import BevPreprocessor 4 | from v2xvit.data_utils.pre_processor.sp_voxel_preprocessor import SpVoxelPreprocessor 5 | 6 | __all__ = { 7 | 'BasePreprocessor': BasePreprocessor, 8 | 'VoxelPreprocessor': VoxelPreprocessor, 9 | 'BevPreprocessor': BevPreprocessor, 10 | 'SpVoxelPreprocessor': SpVoxelPreprocessor 11 | } 12 | 13 | 14 | def build_preprocessor(preprocess_cfg, train): 15 | process_method_name = preprocess_cfg['core_method'] 16 | error_message = f"{process_method_name} is not found. " \ 17 | f"Please add your processor file's name in opencood/" \ 18 | f"data_utils/processor/init.py" 19 | assert process_method_name in ['BasePreprocessor', 'VoxelPreprocessor', 20 | 'BevPreprocessor', 'SpVoxelPreprocessor'], \ 21 | error_message 22 | 23 | processor = __all__[process_method_name]( 24 | preprocess_params=preprocess_cfg, 25 | train=train 26 | ) 27 | 28 | return processor 29 | -------------------------------------------------------------------------------- /v2xvit/data_utils/pre_processor/base_preprocessor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from v2xvit.utils import pcd_utils 4 | 5 | 6 | class BasePreprocessor(object): 7 | """ 8 | Basic Lidar pre-processor. 9 | 10 | Parameters 11 | ---------- 12 | preprocess_params : dict 13 | The dictionary containing all parameters of the preprocessing. 14 | 15 | train : bool 16 | Train or test mode. 17 | """ 18 | 19 | def __init__(self, preprocess_params, train): 20 | self.params = preprocess_params 21 | self.train = train 22 | 23 | def preprocess(self, pcd_np): 24 | """ 25 | Preprocess the lidar points by simple sampling. 26 | 27 | Parameters 28 | ---------- 29 | pcd_np : np.ndarray 30 | The raw lidar. 31 | 32 | Returns 33 | ------- 34 | data_dict : the output dictionary. 35 | """ 36 | data_dict = {} 37 | sample_num = self.params['args']['sample_num'] 38 | 39 | pcd_np = pcd_utils.downsample_lidar(pcd_np, sample_num) 40 | data_dict['downsample_lidar'] = pcd_np 41 | 42 | return data_dict 43 | 44 | def project_points_to_bev_map(self, points, ratio=0.1): 45 | """ 46 | Project points to BEV occupancy map with default ratio=0.1. 
47 | 48 | Parameters 49 | ---------- 50 | points : np.ndarray 51 | (N, 3) / (N, 4) 52 | 53 | ratio : float 54 | Discretization parameters. Default is 0.1. 55 | 56 | Returns 57 | ------- 58 | bev_map : np.ndarray 59 | BEV occupancy map including projected points with shape 60 | (img_row, img_col). 61 | 62 | """ 63 | L1, W1, H1, L2, W2, H2 = self.params["cav_lidar_range"] 64 | img_row = int((L2 - L1) / ratio) 65 | img_col = int((W2 - W1) / ratio) 66 | bev_map = np.zeros((img_row, img_col)) 67 | bev_origin = np.array([L1, W1, H1]).reshape(1, -1) 68 | # (N, 3) 69 | indices = ((points[:, :3] - bev_origin) / ratio).astype(int) 70 | mask = np.logical_and(indices[:, 0] > 0, indices[:, 0] < img_row) 71 | mask = np.logical_and(mask, np.logical_and(indices[:, 1] > 0, 72 | indices[:, 1] < img_col)) 73 | indices = indices[mask, :] 74 | bev_map[indices[:, 0], indices[:, 1]] = 1 75 | return bev_map 76 | -------------------------------------------------------------------------------- /v2xvit/data_utils/pre_processor/bev_preprocessor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert lidar to bev 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | from v2xvit.data_utils.pre_processor.base_preprocessor import \ 8 | BasePreprocessor 9 | 10 | class BevPreprocessor(BasePreprocessor): 11 | def __init__(self, preprocess_params, train): 12 | super(BevPreprocessor, self).__init__(preprocess_params, train) 13 | self.lidar_range = self.params['cav_lidar_range'] 14 | self.geometry_param = preprocess_params["geometry_param"] 15 | 16 | def preprocess(self, pcd_raw): 17 | """ 18 | Preprocess the lidar points to BEV representations. 19 | 20 | Parameters 21 | ---------- 22 | pcd_raw : np.ndarray 23 | The raw lidar. 24 | 25 | Returns 26 | ------- 27 | data_dict : the structured output dictionary. 28 | """ 29 | bev = np.zeros(self.geometry_param['input_shape'], dtype=np.float32) 30 | intensity_map_count = np.zeros((bev.shape[0], bev.shape[1]), dtype=np.int) 31 | bev_origin = np.array( 32 | [self.geometry_param["L1"], self.geometry_param["W1"], 33 | self.geometry_param["H1"]]).reshape(1, -1) 34 | 35 | indices = ((pcd_raw[:, :3] - bev_origin) / self.geometry_param[ 36 | "res"]).astype(int) 37 | ## bev[indices[:, 0], indices[:, 1], indices[:, 2]] = 1 38 | # np.add.at(bev, (indices[:, 0], indices[:, 1], indices[:, 2]), 1) 39 | # bev[indices[:, 0], indices[:, 1], -1] += pcd_raw[:, 3] 40 | # intensity_map_count[indices[:, 0], indices[:, 1]] += 1 41 | 42 | for i in range(indices.shape[0]): 43 | bev[indices[i, 0], indices[i, 1], indices[i, 2]] = 1 44 | bev[indices[i, 0], indices[i, 1], -1] += pcd_raw[i, 3] 45 | intensity_map_count[indices[i, 0], indices[i, 1]] += 1 46 | divide_mask = intensity_map_count!=0 47 | bev[divide_mask, -1] = np.divide(bev[divide_mask, -1], intensity_map_count[divide_mask]) 48 | 49 | data_dict = { 50 | "bev_input": np.transpose(bev, (2, 0, 1)) 51 | } 52 | return data_dict 53 | 54 | @staticmethod 55 | def collate_batch_list(batch): 56 | """ 57 | Customized pytorch data loader collate function. 58 | 59 | Parameters 60 | ---------- 61 | batch : list 62 | List of dictionary. Each dictionary represent a single frame. 63 | 64 | Returns 65 | ------- 66 | processed_batch : dict 67 | Updated lidar batch. 68 | """ 69 | bev_input_list = [ 70 | x["bev_input"][np.newaxis, ...] 
for x in batch 71 | ] 72 | processed_batch = { 73 | "bev_input": torch.from_numpy( 74 | np.concatenate(bev_input_list, axis=0)) 75 | } 76 | return processed_batch 77 | @staticmethod 78 | def collate_batch_dict(batch): 79 | """ 80 | Customized pytorch data loader collate function. 81 | 82 | Parameters 83 | ---------- 84 | batch : dict 85 | Dict of list. Each element represents a CAV. 86 | 87 | Returns 88 | ------- 89 | processed_batch : dict 90 | Updated lidar batch. 91 | """ 92 | bev_input_list = [ 93 | x[np.newaxis, ...] for x in batch["bev_input"] 94 | ] 95 | processed_batch = { 96 | "bev_input": torch.from_numpy( 97 | np.concatenate(bev_input_list, axis=0)) 98 | } 99 | return processed_batch 100 | 101 | def collate_batch(self, batch): 102 | """ 103 | Customized pytorch data loader collate function. 104 | 105 | Parameters 106 | ---------- 107 | batch : list / dict 108 | Batched data. 109 | Returns 110 | ------- 111 | processed_batch : dict 112 | Updated lidar batch. 113 | """ 114 | if isinstance(batch, list): 115 | return self.collate_batch_list(batch) 116 | elif isinstance(batch, dict): 117 | return self.collate_batch_dict(batch) 118 | else: 119 | raise NotImplemented 120 | 121 | -------------------------------------------------------------------------------- /v2xvit/data_utils/pre_processor/sp_voxel_preprocessor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transform points to voxels using sparse conv library 3 | """ 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | from cumm import tensorview as tv 9 | from spconv.utils import Point2VoxelCPU3d 10 | 11 | from v2xvit.data_utils.pre_processor.base_preprocessor import \ 12 | BasePreprocessor 13 | 14 | 15 | class SpVoxelPreprocessor(BasePreprocessor): 16 | def __init__(self, preprocess_params, train): 17 | super(SpVoxelPreprocessor, self).__init__(preprocess_params, 18 | train) 19 | 20 | self.lidar_range = self.params['cav_lidar_range'] 21 | self.voxel_size = self.params['args']['voxel_size'] 22 | self.max_points_per_voxel = self.params['args']['max_points_per_voxel'] 23 | 24 | if train: 25 | self.max_voxels = self.params['args']['max_voxel_train'] 26 | else: 27 | self.max_voxels = self.params['args']['max_voxel_test'] 28 | 29 | grid_size = (np.array(self.lidar_range[3:6]) - 30 | np.array(self.lidar_range[0:3])) / np.array(self.voxel_size) 31 | self.grid_size = np.round(grid_size).astype(np.int64) 32 | 33 | # use sparse conv library to generate voxel 34 | self.voxel_generator = Point2VoxelCPU3d( 35 | vsize_xyz=self.voxel_size, 36 | coors_range_xyz=self.lidar_range, 37 | max_num_points_per_voxel=self.max_points_per_voxel, 38 | num_point_features=4, 39 | max_num_voxels=self.max_voxels 40 | ) 41 | 42 | def preprocess(self, pcd_np): 43 | data_dict = {} 44 | pcd_tv = tv.from_numpy(pcd_np) 45 | voxel_output = self.voxel_generator.point_to_voxel(pcd_tv) 46 | if isinstance(voxel_output, dict): 47 | voxels, coordinates, num_points = \ 48 | voxel_output['voxels'], voxel_output['coordinates'], \ 49 | voxel_output['num_points_per_voxel'] 50 | else: 51 | voxels, coordinates, num_points = voxel_output 52 | 53 | data_dict['voxel_features'] = voxels.numpy() 54 | data_dict['voxel_coords'] = coordinates.numpy() 55 | data_dict['voxel_num_points'] = num_points.numpy() 56 | 57 | return data_dict 58 | 59 | def collate_batch(self, batch): 60 | """ 61 | Customized pytorch data loader collate function. 62 | 63 | Parameters 64 | ---------- 65 | batch : list or dict 66 | List or dictionary. 
67 | 68 | Returns 69 | ------- 70 | processed_batch : dict 71 | Updated lidar batch. 72 | """ 73 | 74 | if isinstance(batch, list): 75 | return self.collate_batch_list(batch) 76 | elif isinstance(batch, dict): 77 | return self.collate_batch_dict(batch) 78 | else: 79 | sys.exit('Batch has too be a list or a dictionarn') 80 | 81 | @staticmethod 82 | def collate_batch_list(batch): 83 | """ 84 | Customized pytorch data loader collate function. 85 | 86 | Parameters 87 | ---------- 88 | batch : list 89 | List of dictionary. Each dictionary represent a single frame. 90 | 91 | Returns 92 | ------- 93 | processed_batch : dict 94 | Updated lidar batch. 95 | """ 96 | voxel_features = [] 97 | voxel_num_points = [] 98 | voxel_coords = [] 99 | 100 | for i in range(len(batch)): 101 | voxel_features.append(batch[i]['voxel_features']) 102 | voxel_num_points.append(batch[i]['voxel_num_points']) 103 | coords = batch[i]['voxel_coords'] 104 | voxel_coords.append( 105 | np.pad(coords, ((0, 0), (1, 0)), 106 | mode='constant', constant_values=i)) 107 | 108 | voxel_num_points = torch.from_numpy(np.concatenate(voxel_num_points)) 109 | voxel_features = torch.from_numpy(np.concatenate(voxel_features)) 110 | voxel_coords = torch.from_numpy(np.concatenate(voxel_coords)) 111 | 112 | return {'voxel_features': voxel_features, 113 | 'voxel_coords': voxel_coords, 114 | 'voxel_num_points': voxel_num_points} 115 | 116 | @staticmethod 117 | def collate_batch_dict(batch: dict): 118 | """ 119 | Collate batch if the batch is a dictionary, 120 | eg: {'voxel_features': [feature1, feature2...., feature n]} 121 | 122 | Parameters 123 | ---------- 124 | batch : dict 125 | 126 | Returns 127 | ------- 128 | processed_batch : dict 129 | Updated lidar batch. 130 | """ 131 | voxel_features = \ 132 | torch.from_numpy(np.concatenate(batch['voxel_features'])) 133 | voxel_num_points = \ 134 | torch.from_numpy(np.concatenate(batch['voxel_num_points'])) 135 | coords = batch['voxel_coords'] 136 | voxel_coords = [] 137 | 138 | for i in range(len(coords)): 139 | voxel_coords.append( 140 | np.pad(coords[i], ((0, 0), (1, 0)), 141 | mode='constant', constant_values=i)) 142 | voxel_coords = torch.from_numpy(np.concatenate(voxel_coords)) 143 | 144 | return {'voxel_features': voxel_features, 145 | 'voxel_coords': voxel_coords, 146 | 'voxel_num_points': voxel_num_points} 147 | -------------------------------------------------------------------------------- /v2xvit/data_utils/pre_processor/voxel_preprocessor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert lidar to voxel 3 | """ 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from v2xvit.data_utils.pre_processor.base_preprocessor import \ 10 | BasePreprocessor 11 | 12 | 13 | class VoxelPreprocessor(BasePreprocessor): 14 | def __init__(self, preprocess_params, train): 15 | super(VoxelPreprocessor, self).__init__(preprocess_params, train) 16 | self.lidar_range = self.params['cav_lidar_range'] 17 | 18 | self.vw = self.params['args']['vw'] 19 | self.vh = self.params['args']['vh'] 20 | self.vd = self.params['args']['vd'] 21 | self.T = self.params['args']['T'] 22 | 23 | def preprocess(self, pcd_np): 24 | """ 25 | Preprocess the lidar points by voxelization. 26 | 27 | Parameters 28 | ---------- 29 | pcd_np : np.ndarray 30 | The raw lidar. 31 | 32 | Returns 33 | ------- 34 | data_dict : the structured output dictionary. 
35 | """ 36 | data_dict = {} 37 | 38 | # calculate the voxel coordinates 39 | voxel_coords = ((pcd_np[:, :3] - 40 | np.floor(np.array([self.lidar_range[0], 41 | self.lidar_range[1], 42 | self.lidar_range[2]])) / ( 43 | self.vw, self.vh, self.vd))).astype(np.int32) 44 | 45 | # convert to (D, H, W) as the paper 46 | voxel_coords = voxel_coords[:, [2, 1, 0]] 47 | voxel_coords, inv_ind, voxel_counts = np.unique(voxel_coords, axis=0, 48 | return_inverse=True, 49 | return_counts=True) 50 | 51 | voxel_features = [] 52 | 53 | for i in range(len(voxel_coords)): 54 | voxel = np.zeros((self.T, 7), dtype=np.float32) 55 | pts = pcd_np[inv_ind == i] 56 | if voxel_counts[i] > self.T: 57 | pts = pts[:self.T, :] 58 | voxel_counts[i] = self.T 59 | 60 | # augment the points 61 | voxel[:pts.shape[0], :] = np.concatenate((pts, pts[:, :3] - 62 | np.mean(pts[:, :3], 0)), 63 | axis=1) 64 | voxel_features.append(voxel) 65 | 66 | data_dict['voxel_features'] = np.array(voxel_features) 67 | data_dict['voxel_coords'] = voxel_coords 68 | 69 | return data_dict 70 | 71 | def collate_batch(self, batch): 72 | """ 73 | Customized pytorch data loader collate function. 74 | 75 | Parameters 76 | ---------- 77 | batch : list or dict 78 | List or dictionary. 79 | 80 | Returns 81 | ------- 82 | processed_batch : dict 83 | Updated lidar batch. 84 | """ 85 | 86 | if isinstance(batch, list): 87 | return self.collate_batch_list(batch) 88 | elif isinstance(batch, dict): 89 | return self.collate_batch_dict(batch) 90 | else: 91 | sys.exit('Batch has too be a list or a dictionarn') 92 | 93 | @staticmethod 94 | def collate_batch_list(batch): 95 | """ 96 | Customized pytorch data loader collate function. 97 | 98 | Parameters 99 | ---------- 100 | batch : list 101 | List of dictionary. Each dictionary represent a single frame. 102 | 103 | Returns 104 | ------- 105 | processed_batch : dict 106 | Updated lidar batch. 107 | """ 108 | voxel_features = [] 109 | voxel_coords = [] 110 | 111 | for i in range(len(batch)): 112 | voxel_features.append(batch[i]['voxel_features']) 113 | coords = batch[i]['voxel_coords'] 114 | voxel_coords.append( 115 | np.pad(coords, ((0, 0), (1, 0)), 116 | mode='constant', constant_values=i)) 117 | 118 | voxel_features = torch.from_numpy(np.concatenate(voxel_features)) 119 | voxel_coords = torch.from_numpy(np.concatenate(voxel_coords)) 120 | 121 | return {'voxel_features': voxel_features, 122 | 'voxel_coords': voxel_coords} 123 | 124 | @staticmethod 125 | def collate_batch_dict(batch: dict): 126 | """ 127 | Collate batch if the batch is a dictionary, 128 | eg: {'voxel_features': [feature1, feature2...., feature n]} 129 | 130 | Parameters 131 | ---------- 132 | batch : dict 133 | 134 | Returns 135 | ------- 136 | processed_batch : dict 137 | Updated lidar batch. 
138 | """ 139 | voxel_features = \ 140 | torch.from_numpy(np.concatenate(batch['voxel_features'])) 141 | coords = batch['voxel_coords'] 142 | voxel_coords = [] 143 | 144 | for i in range(len(coords)): 145 | voxel_coords.append( 146 | np.pad(coords[i], ((0, 0), (1, 0)), 147 | mode='constant', constant_values=i)) 148 | voxel_coords = torch.from_numpy(np.concatenate(voxel_coords)) 149 | 150 | return {'voxel_features': voxel_features, 151 | 'voxel_coords': voxel_coords} 152 | -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/hypes_yaml/__init__.py -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/point_pillar_early_fusion.yaml: -------------------------------------------------------------------------------- 1 | name: point_pillar_early_fusion 2 | root_dir: 'v2xset/train' 3 | validate_dir: 'v2xset/validate' 4 | yaml_parser: "load_point_pillar_params" 5 | 6 | wild_setting: 7 | async: false 8 | async_overhead: 60 9 | seed: 20 10 | loc_err: false 11 | xyz_std: 0.2 12 | ryp_std: 0.2 13 | 14 | train_params: 15 | batch_size: &batch_size 4 16 | epoches: 22 17 | eval_freq: 1 18 | save_freq: 1 19 | 20 | fusion: 21 | core_method: 'EarlyFusionDataset' # LateFusionDataset, EarlyFusionDataset, IntermediateFusionDataset supported 22 | args: [] 23 | 24 | # preprocess-related 25 | preprocess: 26 | # options: BasePreprocessor, VoxelPreprocessor, BevPreprocessor 27 | core_method: 'SpVoxelPreprocessor' 28 | args: 29 | voxel_size: &voxel_size [0.4, 0.4, 4] 30 | max_points_per_voxel: 32 31 | max_voxel_train: 32000 32 | max_voxel_test: 70000 33 | # lidar range for each individual cav. 34 | cav_lidar_range: &cav_lidar [-140.8, -40, -3, 140.8, 40, 1] 35 | 36 | data_augment: 37 | - NAME: random_world_flip 38 | ALONG_AXIS_LIST: [ 'x' ] 39 | 40 | - NAME: random_world_rotation 41 | WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ] 42 | 43 | - NAME: random_world_scaling 44 | WORLD_SCALE_RANGE: [ 0.95, 1.05 ] 45 | 46 | # anchor box related 47 | postprocess: 48 | core_method: 'VoxelPostprocessor' # VoxelPostprocessor, BevPostprocessor supported 49 | anchor_args: 50 | cav_lidar_range: *cav_lidar 51 | l: 3.9 52 | w: 1.6 53 | h: 1.56 54 | r: [0, 90] 55 | num: &achor_num 2 56 | feature_stride: 4 57 | target_args: 58 | pos_threshold: 0.6 59 | neg_threshold: 0.45 60 | score_threshold: 0.20 61 | order: 'hwl' # hwl or lwh 62 | max_num: 100 # maximum number of objects in a single frame. 
use this number to make sure different frames has the same dimension in the same batch 63 | nms_thresh: 0.15 64 | 65 | # model related 66 | model: 67 | core_method: point_pillar 68 | args: 69 | voxel_size: *voxel_size 70 | lidar_range: *cav_lidar 71 | anchor_number: *achor_num 72 | pillar_vfe: 73 | use_norm: true 74 | with_distance: false 75 | use_absolute_xyz: true 76 | num_filters: [64] 77 | point_pillar_scatter: 78 | num_features: 64 79 | 80 | base_bev_backbone: 81 | layer_nums: [3, 5, 8] 82 | layer_strides: [2, 2, 2] 83 | num_filters: [64, 128, 256] 84 | upsample_strides: [1, 2, 4] 85 | num_upsample_filter: [128, 128, 128] 86 | 87 | shrink_header: 88 | kernal_size: [ 3 ] 89 | stride: [ 2 ] 90 | padding: [ 1 ] 91 | dim: [ 256 ] 92 | input_dim: 384 # 128 * 3 93 | 94 | cls_head_dim: 256 95 | 96 | anchor_num: *achor_num 97 | 98 | loss: 99 | core_method: point_pillar_loss 100 | args: 101 | cls_weight: 1.0 102 | reg: 2.0 103 | 104 | optimizer: 105 | core_method: Adam 106 | lr: 0.002 107 | args: 108 | eps: 1e-10 109 | weight_decay: 1e-4 110 | 111 | lr_scheduler: 112 | core_method: multistep #step, multistep and Exponential support 113 | gamma: 0.1 114 | step_size: [20, 30] 115 | 116 | -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/point_pillar_fcooper.yaml: -------------------------------------------------------------------------------- 1 | name: point_pillar_fcooper 2 | root_dir: 'v2xset/train' 3 | validate_dir: 'v2xset/validate' 4 | wild_setting: 5 | async: false 6 | async_overhead: 100 7 | seed: 20 8 | loc_err: false 9 | xyz_std: 0.2 10 | ryp_std: 0.2 11 | data_size: 1.06 # Mb!! 12 | transmission_speed: 27 # Mbps!! 13 | backbone_delay: 10 # ms 14 | 15 | yaml_parser: "load_point_pillar_params" 16 | train_params: 17 | batch_size: &batch_size 4 18 | epoches: 60 19 | eval_freq: 1 20 | save_freq: 1 21 | max_cav: &max_cav 5 22 | 23 | fusion: 24 | core_method: 'IntermediateFusionDataset' # LateFusionDataset, EarlyFusionDataset, IntermediateFusionDataset supported 25 | args: 26 | cur_ego_pose_flag: True 27 | # when the cur_ego_pose_flag is set to True, there is no time gap 28 | # between the time when the LiDAR data is captured by connected 29 | # agents and when the extracted features are received by 30 | # the ego vehicle, which is equal to implement STCM. When set to False, 31 | # STCM has to be used. 32 | 33 | 34 | # preprocess-related 35 | preprocess: 36 | # options: BasePreprocessor, VoxelPreprocessor, BevPreprocessor 37 | core_method: 'SpVoxelPreprocessor' 38 | args: 39 | voxel_size: &voxel_size [0.4, 0.4, 4] 40 | max_points_per_voxel: 32 41 | max_voxel_train: 32000 42 | max_voxel_test: 70000 43 | # lidar range for each individual cav. 44 | cav_lidar_range: &cav_lidar [-140.8, -38.4, -3, 140.8, 38.4, 1] 45 | 46 | data_augment: 47 | - NAME: random_world_flip 48 | ALONG_AXIS_LIST: [ 'x' ] 49 | 50 | - NAME: random_world_rotation 51 | WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ] 52 | 53 | - NAME: random_world_scaling 54 | WORLD_SCALE_RANGE: [ 0.95, 1.05 ] 55 | 56 | # anchor box related 57 | postprocess: 58 | core_method: 'VoxelPostprocessor' # VoxelPostprocessor, BevPostprocessor supported 59 | anchor_args: 60 | cav_lidar_range: *cav_lidar 61 | l: 3.9 62 | w: 1.6 63 | h: 1.56 64 | r: [0, 90] 65 | feature_stride: 4 66 | num: &achor_num 2 67 | target_args: 68 | pos_threshold: 0.6 69 | neg_threshold: 0.45 70 | score_threshold: 0.20 71 | order: 'hwl' # hwl or lwh 72 | max_num: 100 # maximum number of objects in a single frame. 
use this number to make sure different frames has the same dimension in the same batch 73 | nms_thresh: 0.15 74 | 75 | # model related 76 | model: 77 | core_method: point_pillar_fcooper 78 | args: 79 | voxel_size: *voxel_size 80 | lidar_range: *cav_lidar 81 | anchor_number: *achor_num 82 | max_cav: *max_cav 83 | compression: 0 # compression rate 84 | backbone_fix: false 85 | 86 | pillar_vfe: 87 | use_norm: true 88 | with_distance: false 89 | use_absolute_xyz: true 90 | num_filters: [64] 91 | point_pillar_scatter: 92 | num_features: 64 93 | 94 | base_bev_backbone: 95 | layer_nums: [3, 5, 8] 96 | layer_strides: [2, 2, 2] 97 | num_filters: [64, 128, 256] 98 | upsample_strides: [1, 2, 4] 99 | num_upsample_filter: [128, 128, 128] 100 | shrink_header: 101 | kernal_size: [3] 102 | stride: [2] 103 | padding: [1] 104 | dim: [256] 105 | input_dim: 384 # 128 * 3 106 | 107 | # add decoder later 108 | 109 | loss: 110 | core_method: point_pillar_loss 111 | args: 112 | cls_weight: 1.0 113 | reg: 2.0 114 | 115 | optimizer: 116 | core_method: Adam 117 | lr: 0.001 118 | args: 119 | eps: 1e-10 120 | weight_decay: 1e-4 121 | 122 | lr_scheduler: 123 | core_method: multistep #step, multistep and Exponential support 124 | gamma: 0.1 125 | step_size: [15, 50] 126 | 127 | -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/point_pillar_late_fusion.yaml: -------------------------------------------------------------------------------- 1 | name: point_pillar_late_fusion 2 | root_dir: 'v2xset/train' 3 | validate_dir: 'v2xset/validate' 4 | yaml_parser: "load_point_pillar_params" 5 | 6 | wild_setting: 7 | async: false 8 | async_overhead: 100 9 | seed: 20 10 | loc_err: false 11 | xyz_std: 0.2 12 | ryp_std: 0.2 13 | 14 | 15 | train_params: 16 | batch_size: &batch_size 8 17 | epoches: 25 18 | eval_freq: 1 19 | save_freq: 1 20 | 21 | fusion: 22 | core_method: 'LateFusionDataset' # LateFusionDataset, EarlyFusionDataset, IntermediateFusionDataset supported 23 | args: [] 24 | 25 | # preprocess-related 26 | preprocess: 27 | # options: BasePreprocessor, VoxelPreprocessor, BevPreprocessor 28 | core_method: 'SpVoxelPreprocessor' 29 | args: 30 | voxel_size: &voxel_size [0.4, 0.4, 4] 31 | max_points_per_voxel: 32 32 | max_voxel_train: 16000 33 | max_voxel_test: 40000 34 | # lidar range for each individual cav. 35 | cav_lidar_range: &cav_lidar [-70.4, -40, -3, 70.4, 40, 1] 36 | 37 | data_augment: 38 | - NAME: random_world_flip 39 | ALONG_AXIS_LIST: [ 'x' ] 40 | 41 | - NAME: random_world_rotation 42 | WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ] 43 | 44 | - NAME: random_world_scaling 45 | WORLD_SCALE_RANGE: [ 0.95, 1.05 ] 46 | 47 | # anchor box related 48 | postprocess: 49 | core_method: 'VoxelPostprocessor' # VoxelPostprocessor, BevPostprocessor supported 50 | anchor_args: 51 | cav_lidar_range: *cav_lidar 52 | l: 3.9 53 | w: 1.6 54 | h: 1.56 55 | r: [0, 90] 56 | feature_stride: 4 57 | num: &achor_num 2 58 | target_args: 59 | pos_threshold: 0.6 60 | neg_threshold: 0.45 61 | score_threshold: 0.20 62 | order: 'hwl' # hwl or lwh 63 | max_num: 100 # maximum number of objects in a single frame. 
use this number to make sure different frames has the same dimension in the same batch 64 | nms_thresh: 0.15 65 | 66 | # model related 67 | model: 68 | core_method: point_pillar 69 | args: 70 | voxel_size: *voxel_size 71 | lidar_range: *cav_lidar 72 | anchor_number: *achor_num 73 | pillar_vfe: 74 | use_norm: true 75 | with_distance: false 76 | use_absolute_xyz: true 77 | num_filters: [64] 78 | point_pillar_scatter: 79 | num_features: 64 80 | 81 | base_bev_backbone: 82 | layer_nums: [3, 5, 8] 83 | layer_strides: [2, 2, 2] 84 | num_filters: [64, 128, 256] 85 | upsample_strides: [1, 2, 4] 86 | num_upsample_filter: [128, 128, 128] 87 | 88 | shrink_header: 89 | kernal_size: [ 3 ] 90 | stride: [ 2 ] 91 | padding: [ 1 ] 92 | dim: [ 256 ] 93 | input_dim: 384 # 128 * 3 94 | 95 | cls_head_dim: 256 96 | 97 | anchor_num: *achor_num 98 | 99 | loss: 100 | core_method: point_pillar_loss 101 | args: 102 | cls_weight: 1.0 103 | reg: 2.0 104 | 105 | optimizer: 106 | core_method: Adam 107 | lr: 0.002 108 | args: 109 | eps: 1e-10 110 | weight_decay: 1e-4 111 | 112 | lr_scheduler: 113 | core_method: multistep #step, multistep and Exponential support 114 | gamma: 0.1 115 | step_size: [20, 30] 116 | 117 | -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/point_pillar_opv2v.yaml: -------------------------------------------------------------------------------- 1 | name: point_pillar_opv2v 2 | #root_dir: '/home/runshengxu/project/Cooperative_perception/opencood/tmp' 3 | root_dir: 'v2xset/train' 4 | validate_dir: 'v2xset/validate' 5 | 6 | wild_setting: 7 | async: false 8 | async_overhead: 100 9 | seed: 20 10 | loc_err: false 11 | xyz_std: 0.2 12 | ryp_std: 0.2 13 | data_size: 1.06 # Mb!! 14 | transmission_speed: 27 # Mbps!! 15 | backbone_delay: 10 # ms 16 | 17 | yaml_parser: "load_point_pillar_params" 18 | train_params: 19 | batch_size: &batch_size 2 20 | epoches: 60 21 | eval_freq: 1 22 | save_freq: 1 23 | max_cav: &max_cav 5 24 | 25 | fusion: 26 | core_method: 'IntermediateFusionDataset' # LateFusionDataset, EarlyFusionDataset, IntermediateFusionDataset supported 27 | args: 28 | cur_ego_pose_flag: True 29 | # when the cur_ego_pose_flag is set to True, there is no time gap 30 | # between the time when the LiDAR data is captured by connected 31 | # agents and when the extracted features are received by 32 | # the ego vehicle, which is equal to implement STCM. When set to False, 33 | # STCM has to be used. 34 | 35 | 36 | # preprocess-related 37 | preprocess: 38 | # options: BasePreprocessor, VoxelPreprocessor, BevPreprocessor 39 | core_method: 'SpVoxelPreprocessor' 40 | args: 41 | voxel_size: &voxel_size [0.4, 0.4, 4] 42 | max_points_per_voxel: 32 43 | max_voxel_train: 32000 44 | max_voxel_test: 70000 45 | # lidar range for each individual cav. 
46 | cav_lidar_range: &cav_lidar [-140.8, -38.4, -3, 140.8, 38.4, 1] 47 | 48 | data_augment: 49 | - NAME: random_world_flip 50 | ALONG_AXIS_LIST: [ 'x' ] 51 | 52 | - NAME: random_world_rotation 53 | WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ] 54 | 55 | - NAME: random_world_scaling 56 | WORLD_SCALE_RANGE: [ 0.95, 1.05 ] 57 | 58 | # anchor box related 59 | postprocess: 60 | core_method: 'VoxelPostprocessor' # VoxelPostprocessor, BevPostprocessor supported 61 | anchor_args: 62 | cav_lidar_range: *cav_lidar 63 | l: 3.9 64 | w: 1.6 65 | h: 1.56 66 | r: [0, 90] 67 | feature_stride: 4 68 | num: &achor_num 2 69 | target_args: 70 | pos_threshold: 0.6 71 | neg_threshold: 0.45 72 | score_threshold: 0.20 73 | order: 'hwl' # hwl or lwh 74 | max_num: 100 # maximum number of objects in a single frame. use this number to make sure different frames has the same dimension in the same batch 75 | nms_thresh: 0.15 76 | 77 | # model related 78 | model: 79 | core_method: point_pillar_opv2v 80 | args: 81 | voxel_size: *voxel_size 82 | lidar_range: *cav_lidar 83 | anchor_number: *achor_num 84 | max_cav: *max_cav 85 | compression: 32 # compression rate 86 | backbone_fix: false 87 | 88 | pillar_vfe: 89 | use_norm: true 90 | with_distance: false 91 | use_absolute_xyz: true 92 | num_filters: [64] 93 | point_pillar_scatter: 94 | num_features: 64 95 | 96 | base_bev_backbone: 97 | layer_nums: [3, 5, 8] 98 | layer_strides: [2, 2, 2] 99 | num_filters: [64, 128, 256] 100 | upsample_strides: [1, 2, 4] 101 | num_upsample_filter: [128, 128, 128] 102 | shrink_header: 103 | kernal_size: [3] 104 | stride: [2] 105 | padding: [1] 106 | dim: [256] 107 | input_dim: 384 # 128 * 3 108 | 109 | # add decoder later 110 | 111 | loss: 112 | core_method: point_pillar_loss 113 | args: 114 | cls_weight: 1.0 115 | reg: 2.0 116 | 117 | optimizer: 118 | core_method: Adam 119 | lr: 0.001 120 | args: 121 | eps: 1e-10 122 | weight_decay: 1e-4 123 | 124 | lr_scheduler: 125 | core_method: multistep #step, multistep and Exponential support 126 | gamma: 0.1 127 | step_size: [15, 50] 128 | 129 | -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/point_pillar_v2vnet.yaml: -------------------------------------------------------------------------------- 1 | name: point_pillar_v2vnet 2 | root_dir: 'v2xset/train' 3 | validate_dir: 'v2xset/validate' 4 | 5 | wild_setting: 6 | async: false 7 | async_overhead: 100 8 | seed: 20 9 | loc_err: false 10 | xyz_std: 0.2 11 | ryp_std: 0.2 12 | data_size: 1.06 # Mb!! 13 | transmission_speed: 27 # Mbps!! 14 | backbone_delay: 10 # ms 15 | 16 | yaml_parser: "load_point_pillar_params" 17 | train_params: 18 | batch_size: &batch_size 4 19 | epoches: 60 20 | eval_freq: 1 21 | save_freq: 1 22 | max_cav: &max_cav 5 23 | 24 | fusion: 25 | core_method: 'IntermediateFusionDataset' # LateFusionDataset, EarlyFusionDataset, IntermediateFusionDataset supported 26 | args: 27 | cur_ego_pose_flag: True 28 | 29 | # preprocess-related 30 | preprocess: 31 | # options: BasePreprocessor, VoxelPreprocessor, BevPreprocessor 32 | core_method: 'SpVoxelPreprocessor' 33 | args: 34 | voxel_size: &voxel_size [0.4, 0.4, 4] 35 | max_points_per_voxel: 32 36 | max_voxel_train: 32000 37 | max_voxel_test: 70000 38 | # lidar range for each individual cav. 
39 | cav_lidar_range: &cav_lidar [-140.8, -38.4, -3, 140.8, 38.4, 1] 40 | 41 | data_augment: 42 | - NAME: random_world_flip 43 | ALONG_AXIS_LIST: [ 'x' ] 44 | 45 | - NAME: random_world_rotation 46 | WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ] 47 | 48 | - NAME: random_world_scaling 49 | WORLD_SCALE_RANGE: [ 0.95, 1.05 ] 50 | 51 | # anchor box related 52 | postprocess: 53 | core_method: 'VoxelPostprocessor' # VoxelPostprocessor, BevPostprocessor supported 54 | anchor_args: 55 | cav_lidar_range: *cav_lidar 56 | l: 3.9 57 | w: 1.6 58 | h: 1.56 59 | r: [0, 90] 60 | feature_stride: 4 61 | num: &achor_num 2 62 | target_args: 63 | pos_threshold: 0.6 64 | neg_threshold: 0.45 65 | score_threshold: 0.20 66 | order: 'hwl' # hwl or lwh 67 | max_num: 100 # maximum number of objects in a single frame. use this number to make sure different frames has the same dimension in the same batch 68 | nms_thresh: 0.15 69 | 70 | # model related 71 | model: 72 | core_method: point_pillar_v2vnet 73 | args: 74 | voxel_size: *voxel_size 75 | lidar_range: *cav_lidar 76 | anchor_number: *achor_num 77 | max_cav: *max_cav 78 | compression: 0 # compression rate 79 | backbone_fix: false 80 | 81 | pillar_vfe: 82 | use_norm: true 83 | with_distance: false 84 | use_absolute_xyz: true 85 | num_filters: [64] 86 | point_pillar_scatter: 87 | num_features: 64 88 | 89 | base_bev_backbone: 90 | layer_nums: [3, 5, 8] 91 | layer_strides: [2, 2, 2] 92 | num_filters: [64, 128, 256] 93 | upsample_strides: [1, 2, 4] 94 | num_upsample_filter: [128, 128, 128] 95 | shrink_header: 96 | kernal_size: [3] 97 | stride: [2] 98 | padding: [1] 99 | dim: [256] 100 | input_dim: 384 # 128 * 3 101 | 102 | v2vfusion: 103 | use_temporal_encoding: true 104 | voxel_size: *voxel_size 105 | downsample_rate: 4 106 | num_iteration: 3 107 | in_channels: 256 108 | gru_flag: false 109 | agg_operator: "avg" # max or avg 110 | conv_gru: 111 | H: 48 112 | W: 176 113 | num_layers: 1 114 | kernel_size: [[3,3]] 115 | 116 | 117 | # add decoder later 118 | 119 | loss: 120 | core_method: point_pillar_loss 121 | args: 122 | cls_weight: 1.0 123 | reg: 2.0 124 | 125 | optimizer: 126 | core_method: Adam 127 | lr: 0.001 128 | args: 129 | eps: 1e-10 130 | weight_decay: 1e-4 131 | 132 | lr_scheduler: 133 | core_method: multistep #step, multistep and Exponential support 134 | gamma: 0.1 135 | step_size: [15, 50] 136 | 137 | -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/point_pillar_v2xvit.yaml: -------------------------------------------------------------------------------- 1 | name: point_pillar_v2xvit 2 | root_dir: 'v2xset/train' 3 | validate_dir: 'v2xset/validate' 4 | 5 | wild_setting: 6 | async: false 7 | async_mode: 'sim' 8 | async_overhead: 100 9 | seed: 25 10 | loc_err: false 11 | xyz_std: 0.2 12 | ryp_std: 0.2 13 | data_size: 1.06 # Mb!! 14 | transmission_speed: 27 # Mbps!! 15 | backbone_delay: 10 # ms 16 | 17 | yaml_parser: "load_point_pillar_params" 18 | train_params: 19 | batch_size: &batch_size 2 20 | epoches: 60 21 | eval_freq: 1 22 | save_freq: 1 23 | max_cav: &max_cav 5 24 | 25 | fusion: 26 | core_method: 'IntermediateFusionDataset' # LateFusionDataset, EarlyFusionDataset, IntermediateFusionDataset supported 27 | args: 28 | cur_ego_pose_flag: False 29 | # when the cur_ego_pose_flag is set to True, there is no time gap 30 | # between the time when the LiDAR data is captured by connected 31 | # agents and when the extracted features are received by 32 | # the ego vehicle, which is equal to implement STCM. 
When set to False, 33 | # STCM has to be used. 34 | 35 | # preprocess-related 36 | preprocess: 37 | # options: BasePreprocessor, VoxelPreprocessor, BevPreprocessor 38 | core_method: 'SpVoxelPreprocessor' 39 | args: 40 | voxel_size: &voxel_size [0.4, 0.4, 4] 41 | max_points_per_voxel: 32 42 | max_voxel_train: 32000 43 | max_voxel_test: 70000 44 | # lidar range for each individual cav. 45 | cav_lidar_range: &cav_lidar [-140.8, -38.4, -3, 140.8, 38.4, 1] 46 | 47 | data_augment: 48 | - NAME: random_world_flip 49 | ALONG_AXIS_LIST: [ 'x' ] 50 | 51 | - NAME: random_world_rotation 52 | WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ] 53 | 54 | - NAME: random_world_scaling 55 | WORLD_SCALE_RANGE: [ 0.95, 1.05 ] 56 | 57 | # anchor box related 58 | postprocess: 59 | core_method: 'VoxelPostprocessor' # VoxelPostprocessor, BevPostprocessor supported 60 | anchor_args: 61 | cav_lidar_range: *cav_lidar 62 | l: 3.9 63 | w: 1.6 64 | h: 1.56 65 | r: [0, 90] 66 | feature_stride: 4 67 | num: &achor_num 2 68 | target_args: 69 | pos_threshold: 0.6 70 | neg_threshold: 0.45 71 | score_threshold: 0.27 72 | order: 'hwl' # hwl or lwh 73 | max_num: 100 # maximum number of objects in a single frame. use this number to make sure different frames has the same dimension in the same batch 74 | nms_thresh: 0.15 75 | 76 | # model related 77 | model: 78 | core_method: point_pillar_transformer 79 | args: 80 | voxel_size: *voxel_size 81 | lidar_range: *cav_lidar 82 | anchor_number: *achor_num 83 | max_cav: *max_cav 84 | compression: 0 # compression rate 85 | backbone_fix: false 86 | 87 | pillar_vfe: 88 | use_norm: true 89 | with_distance: false 90 | use_absolute_xyz: true 91 | num_filters: [64] 92 | point_pillar_scatter: 93 | num_features: 64 94 | 95 | base_bev_backbone: 96 | layer_nums: [3, 5, 8] 97 | layer_strides: [2, 2, 2] 98 | num_filters: [64, 128, 256] 99 | upsample_strides: [1, 2, 4] 100 | num_upsample_filter: [128, 128, 128] 101 | shrink_header: 102 | kernal_size: [3] 103 | stride: [2] 104 | padding: [1] 105 | dim: [256] 106 | input_dim: 384 # 128 * 3 107 | 108 | transformer: 109 | encoder: &encoder 110 | # number of fusion blocks per encoder layer 111 | num_blocks: 1 112 | # number of encoder layers 113 | depth: 3 114 | use_roi_mask: true 115 | use_RTE: &use_RTE true 116 | RTE_ratio: &RTE_ratio 2 # 2 means the dt has 100ms interval while 1 means 50 ms interval 117 | # agent-wise attention 118 | cav_att_config: &cav_att_config 119 | dim: 256 120 | use_hetero: true 121 | use_RTE: *use_RTE 122 | RTE_ratio: *RTE_ratio 123 | heads: 8 124 | dim_head: 32 125 | dropout: 0.3 126 | # spatial-wise attention 127 | pwindow_att_config: &pwindow_att_config 128 | dim: 256 129 | heads: [16, 8, 4] 130 | dim_head: [16, 32, 64] 131 | dropout: 0.3 132 | window_size: [4, 8, 16] 133 | relative_pos_embedding: true 134 | fusion_method: 'split_attn' 135 | # feedforward condition 136 | feed_forward: &feed_forward 137 | mlp_dim: 256 138 | dropout: 0.3 139 | sttf: &sttf 140 | voxel_size: *voxel_size 141 | downsample_rate: 4 142 | 143 | # add decoder later 144 | 145 | loss: 146 | core_method: point_pillar_loss 147 | args: 148 | cls_weight: 1.0 149 | reg: 2.0 150 | 151 | optimizer: 152 | core_method: Adam 153 | lr: 0.001 154 | args: 155 | eps: 1e-10 156 | weight_decay: 1e-4 157 | 158 | lr_scheduler: 159 | core_method: multistep #step, multistep and Exponential support 160 | gamma: 0.1 161 | step_size: [15, 50] 162 | 163 | -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/visualization.yaml: 
-------------------------------------------------------------------------------- 1 | # this yaml is only for visualization 2 | name: visualization 3 | 4 | yaml_parser: "load_voxel_params" 5 | root_dir: 'v2xset/train' 6 | validate_dir: 'v2xset/validate' 7 | 8 | train_params: 9 | batch_size: &batch_size 4 10 | epoches: 100 11 | eval_freq: 1 12 | save_freq: 1 13 | 14 | fusion: 15 | core_method: 'EarlyFusionDataset' # LateFusionDataset, EarlyFusionDataset, IntermediateFusionDataset supported 16 | args: [] 17 | 18 | # preprocess-related 19 | preprocess: 20 | # options: BasePreprocessor, VoxelPreprocessor, BevPreprocessor 21 | core_method: 'SpVoxelPreprocessor' 22 | args: 23 | voxel_size: &voxel_size [0.4, 0.4, 0.4] 24 | max_points_per_voxel: &T 32 25 | max_voxel_train: 36000 26 | max_voxel_test: 70000 27 | # lidar range for each individual cav. 28 | cav_lidar_range: &cav_lidar [-140.8, -40, -3, 140.8, 40, 1] 29 | 30 | data_augment: 31 | - NAME: random_world_flip 32 | ALONG_AXIS_LIST: [ 'x' ] 33 | 34 | - NAME: random_world_rotation 35 | WORLD_ROT_ANGLE: [ -0.78539816, 0.78539816 ] 36 | 37 | - NAME: random_world_scaling 38 | WORLD_SCALE_RANGE: [ 0.95, 1.05 ] 39 | 40 | # anchor box related 41 | postprocess: 42 | core_method: 'VoxelPostprocessor' # VoxelPostprocessor, BevPostprocessor supported 43 | anchor_args: 44 | cav_lidar_range: *cav_lidar 45 | l: 3.9 46 | w: 1.6 47 | h: 1.56 48 | r: [0, 90] 49 | num: &achor_num 2 50 | target_args: 51 | pos_threshold: 0.6 52 | neg_threshold: 0.45 53 | score_threshold: 0.96 54 | order: 'hwl' # hwl or lwh 55 | max_num: 100 # maximum number of objects in a single frame. use this number to make sure different frames has the same dimension in the same batch 56 | nms_thresh: 0.15 -------------------------------------------------------------------------------- /v2xvit/hypes_yaml/yaml_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import yaml 3 | import os 4 | import math 5 | 6 | import numpy as np 7 | 8 | 9 | def load_yaml(file, opt=None): 10 | """ 11 | Load yaml file and return a dictionary. 12 | 13 | Parameters 14 | ---------- 15 | file : string 16 | yaml file path. 17 | 18 | opt : argparser 19 | Argparser. 20 | Returns 21 | ------- 22 | param : dict 23 | A dictionary that contains defined parameters. 24 | """ 25 | if opt and opt.model_dir: 26 | file = os.path.join(opt.model_dir, 'config.yaml') 27 | 28 | stream = open(file, 'r') 29 | loader = yaml.Loader 30 | loader.add_implicit_resolver( 31 | u'tag:yaml.org,2002:float', 32 | re.compile(u'''^(?: 33 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? 34 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) 35 | |\\.[0-9_]+(?:[eE][-+][0-9]+)? 36 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 37 | |[-+]?\\.(?:inf|Inf|INF) 38 | |\\.(?:nan|NaN|NAN))$''', re.X), 39 | list(u'-+0123456789.')) 40 | param = yaml.load(stream, Loader=loader) 41 | if "yaml_parser" in param: 42 | param = eval(param["yaml_parser"])(param) 43 | 44 | return param 45 | 46 | 47 | def load_voxel_params(param): 48 | """ 49 | Based on the lidar range and resolution of voxel, calcuate the anchor box 50 | and target resolution. 51 | 52 | Parameters 53 | ---------- 54 | param : dict 55 | Original loaded parameter dictionary. 
56 | 57 | Returns 58 | ------- 59 | param : dict 60 | Modified parameter dictionary with new attribute `anchor_args[W][H][L]` 61 | """ 62 | anchor_args = param['postprocess']['anchor_args'] 63 | cav_lidar_range = anchor_args['cav_lidar_range'] 64 | voxel_size = param['preprocess']['args']['voxel_size'] 65 | 66 | vw = voxel_size[0] 67 | vh = voxel_size[1] 68 | vd = voxel_size[2] 69 | 70 | anchor_args['vw'] = vw 71 | anchor_args['vh'] = vh 72 | anchor_args['vd'] = vd 73 | 74 | anchor_args['W'] = int((cav_lidar_range[3] - cav_lidar_range[0]) / vw) 75 | anchor_args['H'] = int((cav_lidar_range[4] - cav_lidar_range[1]) / vh) 76 | anchor_args['D'] = int((cav_lidar_range[5] - cav_lidar_range[2]) / vd) 77 | 78 | param['postprocess'].update({'anchor_args': anchor_args}) 79 | # sometimes we just want to visualize the data without implementing model 80 | if 'model' in param: 81 | param['model']['args']['W'] = anchor_args['W'] 82 | param['model']['args']['H'] = anchor_args['H'] 83 | param['model']['args']['D'] = anchor_args['D'] 84 | return param 85 | 86 | 87 | def load_point_pillar_params(param): 88 | """ 89 | Based on the lidar range and resolution of voxel, calcuate the anchor box 90 | and target resolution. 91 | 92 | Parameters 93 | ---------- 94 | param : dict 95 | Original loaded parameter dictionary. 96 | 97 | Returns 98 | ------- 99 | param : dict 100 | Modified parameter dictionary with new attribute. 101 | """ 102 | cav_lidar_range = param['preprocess']['cav_lidar_range'] 103 | voxel_size = param['preprocess']['args']['voxel_size'] 104 | 105 | grid_size = (np.array(cav_lidar_range[3:6]) - np.array( 106 | cav_lidar_range[0:3])) / \ 107 | np.array(voxel_size) 108 | grid_size = np.round(grid_size).astype(np.int64) 109 | param['model']['args']['point_pillar_scatter']['grid_size'] = grid_size 110 | 111 | anchor_args = param['postprocess']['anchor_args'] 112 | 113 | vw = voxel_size[0] 114 | vh = voxel_size[1] 115 | vd = voxel_size[2] 116 | 117 | anchor_args['vw'] = vw 118 | anchor_args['vh'] = vh 119 | anchor_args['vd'] = vd 120 | 121 | anchor_args['W'] = math.ceil((cav_lidar_range[3] - cav_lidar_range[0]) / vw) 122 | anchor_args['H'] = math.ceil((cav_lidar_range[4] - cav_lidar_range[1]) / vh) 123 | anchor_args['D'] = math.ceil((cav_lidar_range[5] - cav_lidar_range[2]) / vd) 124 | 125 | param['postprocess'].update({'anchor_args': anchor_args}) 126 | 127 | return param 128 | 129 | def load_second_params(param): 130 | """ 131 | Based on the lidar range and resolution of voxel, calcuate the anchor box 132 | and target resolution. 133 | 134 | Parameters 135 | ---------- 136 | param : dict 137 | Original loaded parameter dictionary. 138 | 139 | Returns 140 | ------- 141 | param : dict 142 | Modified parameter dictionary with new attribute. 
143 | """ 144 | cav_lidar_range = param['preprocess']['cav_lidar_range'] 145 | voxel_size = param['preprocess']['args']['voxel_size'] 146 | 147 | grid_size = (np.array(cav_lidar_range[3:6]) - np.array( 148 | cav_lidar_range[0:3])) / \ 149 | np.array(voxel_size) 150 | grid_size = np.round(grid_size).astype(np.int64) 151 | param['model']['args']['grid_size'] = grid_size 152 | 153 | anchor_args = param['postprocess']['anchor_args'] 154 | 155 | vw = voxel_size[0] 156 | vh = voxel_size[1] 157 | vd = voxel_size[2] 158 | 159 | anchor_args['vw'] = vw 160 | anchor_args['vh'] = vh 161 | anchor_args['vd'] = vd 162 | 163 | anchor_args['W'] = math.ceil((cav_lidar_range[3] - cav_lidar_range[0]) / vw) 164 | anchor_args['H'] = math.ceil((cav_lidar_range[4] - cav_lidar_range[1]) / vh) 165 | anchor_args['D'] = math.ceil((cav_lidar_range[5] - cav_lidar_range[2]) / vd) 166 | 167 | param['postprocess'].update({'anchor_args': anchor_args}) 168 | 169 | return param 170 | 171 | def load_bev_params(param): 172 | """ 173 | Load bev related geometry parameters s.t. boundary, resolutions, input 174 | shape, target shape etc. 175 | 176 | Parameters 177 | ---------- 178 | param : dict 179 | Original loaded parameter dictionary. 180 | 181 | Returns 182 | ------- 183 | param : dict 184 | Modified parameter dictionary with new attribute `geometry_param`. 185 | 186 | """ 187 | res = param["preprocess"]["args"]["res"] 188 | L1, W1, H1, L2, W2, H2 = param["preprocess"]["cav_lidar_range"] 189 | downsample_rate = param["preprocess"]["args"]["downsample_rate"] 190 | 191 | def f(low, high, r): 192 | return int((high - low) / r) 193 | 194 | input_shape = ( 195 | int((f(L1, L2, res))), 196 | int((f(W1, W2, res))), 197 | int((f(H1, H2, res)) + 1) 198 | ) 199 | label_shape = ( 200 | int(input_shape[0] / downsample_rate), 201 | int(input_shape[1] / downsample_rate), 202 | 7 203 | ) 204 | geometry_param = { 205 | 'L1': L1, 206 | 'L2': L2, 207 | 'W1': W1, 208 | 'W2': W2, 209 | 'H1': H1, 210 | 'H2': H2, 211 | "downsample_rate": downsample_rate, 212 | "input_shape": input_shape, 213 | "label_shape": label_shape, 214 | "res": res 215 | } 216 | param["preprocess"]["geometry_param"] = geometry_param 217 | param["postprocess"]["geometry_param"] = geometry_param 218 | param["model"]["args"]["geometry_param"] = geometry_param 219 | return param 220 | 221 | 222 | def save_yaml(data, save_name): 223 | """ 224 | Save the dictionary into a yaml file. 225 | 226 | Parameters 227 | ---------- 228 | data : dict 229 | The dictionary contains all data. 230 | 231 | save_name : string 232 | Full path of the output yaml file. 
233 | """ 234 | 235 | with open(save_name, 'w') as outfile: 236 | yaml.dump(data, outfile, default_flow_style=False) 237 | -------------------------------------------------------------------------------- /v2xvit/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/loss/__init__.py -------------------------------------------------------------------------------- /v2xvit/loss/pixor_loss.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class PixorLoss(nn.Module): 9 | def __init__(self, args): 10 | super(PixorLoss, self).__init__() 11 | self.alpha = args["alpha"] 12 | self.beta = args["beta"] 13 | self.loss_dict = {} 14 | 15 | def forward(self, output_dict, target_dict): 16 | """ 17 | Compute loss for pixor network 18 | Parameters 19 | ---------- 20 | output_dict : dict 21 | The dictionary that contains the output. 22 | 23 | target_dict : dict 24 | The dictionary that contains the target. 25 | 26 | Returns 27 | ------- 28 | total_loss : torch.Tensor 29 | Total loss. 30 | 31 | """ 32 | targets = target_dict["label_map"] 33 | cls_preds, loc_preds = output_dict["cls"], output_dict["reg"] 34 | 35 | cls_targets, loc_targets = targets.split([1, 6], dim=1) 36 | pos_count = cls_targets.sum() 37 | neg_count = (cls_targets == 0).sum() 38 | w1, w2 = neg_count / (pos_count + neg_count), pos_count / ( 39 | pos_count + neg_count) 40 | weights = torch.ones_like(cls_preds.reshape(-1)) 41 | weights[cls_targets.reshape(-1) == 1] = w1 42 | weights[cls_targets.reshape(-1) == 0] = w2 43 | # cls_targets = cls_targets.float() 44 | # cls_loss = F.binary_cross_entropy_with_logits(input=cls_preds.reshape(-1), target=cls_targets.reshape(-1), weight=weights, 45 | # reduction='mean') 46 | cls_loss = F.binary_cross_entropy_with_logits( 47 | input=cls_preds, target=cls_targets, 48 | reduction='mean') 49 | pos_pixels = cls_targets.sum() 50 | 51 | loc_loss = F.smooth_l1_loss(cls_targets * loc_preds, 52 | cls_targets * loc_targets, 53 | reduction='sum') 54 | loc_loss = loc_loss / pos_pixels if pos_pixels > 0 else loc_loss 55 | 56 | total_loss = self.alpha * cls_loss + self.beta * loc_loss 57 | 58 | self.loss_dict.update({'total_loss': total_loss, 59 | 'reg_loss': loc_loss, 60 | 'cls_loss': cls_loss}) 61 | 62 | return total_loss 63 | 64 | def logging(self, epoch, batch_id, batch_len, writer): 65 | """ 66 | Print out the loss function for current iteration. 67 | 68 | Parameters 69 | ---------- 70 | epoch : int 71 | Current epoch for training. 72 | batch_id : int 73 | The current batch. 
74 | batch_len : int 75 | Total batch length in one iteration of training, 76 | writer : SummaryWriter 77 | Used to visualize on tensorboard 78 | """ 79 | total_loss = self.loss_dict['total_loss'] 80 | reg_loss = self.loss_dict['reg_loss'] 81 | cls_loss = self.loss_dict['cls_loss'] 82 | 83 | print("[epoch %d][%d/%d], || Loss: %.4f || cls Loss: %.4f" 84 | " || reg Loss: %.4f" % ( 85 | epoch, batch_id + 1, batch_len, 86 | total_loss.item(), cls_loss.item(), reg_loss.item())) 87 | 88 | writer.add_scalar('Regression_loss', reg_loss.item(), 89 | epoch * batch_len + batch_id) 90 | writer.add_scalar('Confidence_loss', cls_loss.item(), 91 | epoch * batch_len + batch_id) 92 | 93 | 94 | def test(): 95 | torch.manual_seed(0) 96 | loss = PixorLoss(None) 97 | pred = torch.sigmoid(torch.randn(1, 7, 2, 3)) 98 | label = torch.zeros(1, 7, 2, 3) 99 | loss = loss(pred, label) 100 | print(loss) 101 | 102 | 103 | if __name__ == "__main__": 104 | test() 105 | -------------------------------------------------------------------------------- /v2xvit/loss/voxel_net_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VoxelNetLoss(nn.Module): 7 | def __init__(self, args): 8 | super(VoxelNetLoss, self).__init__() 9 | self.smoothl1loss = nn.SmoothL1Loss(size_average=False) 10 | self.alpha = args['alpha'] 11 | self.beta = args['beta'] 12 | self.reg_coe = args['reg'] 13 | self.loss_dict = {} 14 | 15 | def forward(self, output_dict, target_dict): 16 | """ 17 | Parameters 18 | ---------- 19 | output_dict : dict 20 | target_dict : dict 21 | """ 22 | rm = output_dict['rm'] 23 | psm = output_dict['psm'] 24 | 25 | pos_equal_one = target_dict['pos_equal_one'] 26 | neg_equal_one = target_dict['neg_equal_one'] 27 | targets = target_dict['targets'] 28 | 29 | p_pos = F.sigmoid(psm.permute(0, 2, 3, 1)) 30 | rm = rm.permute(0, 2, 3, 1).contiguous() 31 | rm = rm.view(rm.size(0), rm.size(1), rm.size(2), -1, 7) 32 | targets = targets.view(targets.size(0), targets.size(1), 33 | targets.size(2), -1, 7) 34 | pos_equal_one_for_reg = pos_equal_one.unsqueeze( 35 | pos_equal_one.dim()).expand(-1, -1, -1, -1, 7) 36 | 37 | rm_pos = rm * pos_equal_one_for_reg 38 | targets_pos = targets * pos_equal_one_for_reg 39 | 40 | cls_pos_loss = -pos_equal_one * torch.log(p_pos + 1e-6) 41 | cls_pos_loss = cls_pos_loss.sum() / (pos_equal_one.sum() + 1e-6) 42 | 43 | cls_neg_loss = -neg_equal_one * torch.log(1 - p_pos + 1e-6) 44 | cls_neg_loss = cls_neg_loss.sum() / (neg_equal_one.sum() + 1e-6) 45 | 46 | reg_loss = self.smoothl1loss(rm_pos, targets_pos) 47 | reg_loss = reg_loss / (pos_equal_one.sum() + 1e-6) 48 | conf_loss = self.alpha * cls_pos_loss + self.beta * cls_neg_loss 49 | 50 | total_loss = self.reg_coe * reg_loss + conf_loss 51 | 52 | self.loss_dict.update({'total_loss': total_loss, 53 | 'reg_loss': reg_loss, 54 | 'conf_loss': conf_loss}) 55 | 56 | return total_loss 57 | 58 | def logging(self, epoch, batch_id, batch_len, writer): 59 | """ 60 | Print out the loss function for current iteration. 61 | 62 | Parameters 63 | ---------- 64 | epoch : int 65 | Current epoch for training. 66 | batch_id : int 67 | The current batch. 
68 | batch_len : int 69 | Total batch length in one iteration of training, 70 | writer : SummaryWriter 71 | Used to visualize on tensorboard 72 | """ 73 | total_loss = self.loss_dict['total_loss'] 74 | reg_loss = self.loss_dict['reg_loss'] 75 | conf_loss = self.loss_dict['conf_loss'] 76 | 77 | print("[epoch %d][%d/%d], || Loss: %.4f || Conf Loss: %.4f" 78 | " || Loc Loss: %.4f" % ( 79 | epoch, batch_id + 1, batch_len, 80 | total_loss.item(), conf_loss.item(), reg_loss.item())) 81 | 82 | writer.add_scalar('Regression_loss', reg_loss.item(), 83 | epoch*batch_len + batch_id) 84 | writer.add_scalar('Confidence_loss', conf_loss.item(), 85 | epoch*batch_len + batch_id) 86 | -------------------------------------------------------------------------------- /v2xvit/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/models/__init__.py -------------------------------------------------------------------------------- /v2xvit/models/point_pillar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Vanilla pointpillar for early and late fusion. 3 | """ 4 | import torch.nn as nn 5 | 6 | from v2xvit.models.sub_modules.pillar_vfe import PillarVFE 7 | from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter 8 | from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone 9 | from v2xvit.models.sub_modules.downsample_conv import DownsampleConv 10 | 11 | 12 | class PointPillar(nn.Module): 13 | def __init__(self, args): 14 | super(PointPillar, self).__init__() 15 | 16 | # PIllar VFE 17 | self.pillar_vfe = PillarVFE(args['pillar_vfe'], 18 | num_point_features=4, 19 | voxel_size=args['voxel_size'], 20 | point_cloud_range=args['lidar_range']) 21 | self.scatter = PointPillarScatter(args['point_pillar_scatter']) 22 | self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64) 23 | # used to downsample the feature map for efficient computation 24 | self.shrink_flag = False 25 | if 'shrink_header' in args: 26 | self.shrink_flag = True 27 | self.shrink_conv = DownsampleConv(args['shrink_header']) 28 | 29 | self.cls_head = nn.Conv2d(args['cls_head_dim'], args['anchor_number'], 30 | kernel_size=1) 31 | self.reg_head = nn.Conv2d(args['cls_head_dim'], 32 | 7 * args['anchor_number'], 33 | kernel_size=1) 34 | 35 | def forward(self, data_dict): 36 | 37 | voxel_features = data_dict['processed_lidar']['voxel_features'] 38 | voxel_coords = data_dict['processed_lidar']['voxel_coords'] 39 | voxel_num_points = data_dict['processed_lidar']['voxel_num_points'] 40 | 41 | batch_dict = {'voxel_features': voxel_features, 42 | 'voxel_coords': voxel_coords, 43 | 'voxel_num_points': voxel_num_points} 44 | 45 | batch_dict = self.pillar_vfe(batch_dict) 46 | batch_dict = self.scatter(batch_dict) 47 | batch_dict = self.backbone(batch_dict) 48 | 49 | spatial_features_2d = batch_dict['spatial_features_2d'] 50 | if self.shrink_flag: 51 | spatial_features_2d = self.shrink_conv(spatial_features_2d) 52 | 53 | psm = self.cls_head(spatial_features_2d) 54 | rm = self.reg_head(spatial_features_2d) 55 | 56 | output_dict = {'psm': psm, 57 | 'rm': rm} 58 | 59 | return output_dict -------------------------------------------------------------------------------- /v2xvit/models/point_pillar_fcooper.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from 
v2xvit.models.sub_modules.pillar_vfe import PillarVFE 4 | from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter 5 | from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone 6 | from v2xvit.models.sub_modules.downsample_conv import DownsampleConv 7 | from v2xvit.models.sub_modules.naive_compress import NaiveCompressor 8 | from v2xvit.models.sub_modules.f_cooper_fuse import SpatialFusion 9 | 10 | 11 | class PointPillarFCooper(nn.Module): 12 | def __init__(self, args): 13 | super(PointPillarFCooper, self).__init__() 14 | 15 | self.max_cav = args['max_cav'] 16 | # PIllar VFE 17 | self.pillar_vfe = PillarVFE(args['pillar_vfe'], 18 | num_point_features=4, 19 | voxel_size=args['voxel_size'], 20 | point_cloud_range=args['lidar_range']) 21 | self.scatter = PointPillarScatter(args['point_pillar_scatter']) 22 | self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64) 23 | # used to downsample the feature map for efficient computation 24 | self.shrink_flag = False 25 | if 'shrink_header' in args: 26 | self.shrink_flag = True 27 | self.shrink_conv = DownsampleConv(args['shrink_header']) 28 | self.compression = False 29 | 30 | if args['compression'] > 0: 31 | self.compression = True 32 | self.naive_compressor = NaiveCompressor(256, args['compression']) 33 | 34 | self.fusion_net = SpatialFusion() 35 | 36 | self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'], 37 | kernel_size=1) 38 | self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'], 39 | kernel_size=1) 40 | 41 | if args['backbone_fix']: 42 | self.backbone_fix() 43 | 44 | def backbone_fix(self): 45 | """ 46 | Fix the parameters of backbone during finetune on timedelay。 47 | """ 48 | for p in self.pillar_vfe.parameters(): 49 | p.requires_grad = False 50 | 51 | for p in self.scatter.parameters(): 52 | p.requires_grad = False 53 | 54 | for p in self.backbone.parameters(): 55 | p.requires_grad = False 56 | 57 | if self.compression: 58 | for p in self.naive_compressor.parameters(): 59 | p.requires_grad = False 60 | if self.shrink_flag: 61 | for p in self.shrink_conv.parameters(): 62 | p.requires_grad = False 63 | 64 | for p in self.cls_head.parameters(): 65 | p.requires_grad = False 66 | for p in self.reg_head.parameters(): 67 | p.requires_grad = False 68 | 69 | def forward(self, data_dict): 70 | voxel_features = data_dict['processed_lidar']['voxel_features'] 71 | voxel_coords = data_dict['processed_lidar']['voxel_coords'] 72 | voxel_num_points = data_dict['processed_lidar']['voxel_num_points'] 73 | record_len = data_dict['record_len'] 74 | spatial_correction_matrix = data_dict['spatial_correction_matrix'] 75 | 76 | batch_dict = {'voxel_features': voxel_features, 77 | 'voxel_coords': voxel_coords, 78 | 'voxel_num_points': voxel_num_points, 79 | 'record_len': record_len} 80 | # n, 4 -> n, c 81 | batch_dict = self.pillar_vfe(batch_dict) 82 | # n, c -> N, C, H, W 83 | batch_dict = self.scatter(batch_dict) 84 | batch_dict = self.backbone(batch_dict) 85 | 86 | spatial_features_2d = batch_dict['spatial_features_2d'] 87 | # downsample feature to reduce memory 88 | if self.shrink_flag: 89 | spatial_features_2d = self.shrink_conv(spatial_features_2d) 90 | # compressor 91 | if self.compression: 92 | spatial_features_2d = self.naive_compressor(spatial_features_2d) 93 | 94 | fused_feature = self.fusion_net(spatial_features_2d, record_len) 95 | 96 | psm = self.cls_head(fused_feature) 97 | rm = self.reg_head(fused_feature) 98 | 99 | output_dict = {'psm': psm, 100 | 'rm': rm} 101 | 102 | return output_dict 103 | 
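Note on the fusion step in point_pillar_fcooper.py above: the model delegates to SpatialFusion from v2xvit/models/sub_modules/f_cooper_fuse.py, whose source is not reproduced in this dump. For orientation, F-Cooper-style spatial fusion is typically an element-wise max taken over the per-CAV feature maps belonging to the same sample, with record_len giving how many CAVs each sample contributed. The snippet below is a minimal, self-contained sketch of that idea under those assumptions; the class name, shapes, and details are illustrative only and it is not the repository's actual SpatialFusion implementation.

import torch
import torch.nn as nn


class MaxSpatialFusion(nn.Module):
    """Hypothetical F-Cooper-style fusion: element-wise max across CAVs.

    Assumes the input feature tensor stacks all CAVs of all samples along
    dim 0, i.e. shape (sum(record_len), C, H, W), matching how
    spatial_features_2d and record_len are passed to self.fusion_net in the
    forward pass above.
    """

    def forward(self, spatial_features_2d, record_len):
        # Split (sum(record_len), C, H, W) into one chunk per ego sample.
        split_sizes = record_len.tolist() if torch.is_tensor(record_len) \
            else list(record_len)
        chunks = torch.split(spatial_features_2d, split_sizes, dim=0)

        fused = []
        for chunk in chunks:
            # (num_cav, C, H, W) -> (1, C, H, W): keep the strongest response
            # among all agents at each spatial location.
            fused.append(torch.max(chunk, dim=0, keepdim=True)[0])

        # (batch_size, C, H, W): one fused feature map per ego sample.
        return torch.cat(fused, dim=0)


if __name__ == '__main__':
    # Toy check: two samples with 2 and 3 CAVs respectively.
    feats = torch.randn(5, 256, 48, 176)
    record_len = torch.tensor([2, 3])
    out = MaxSpatialFusion()(feats, record_len)
    print(out.shape)  # torch.Size([2, 256, 48, 176])

Because the fused map keeps the channel width of a single agent (here 256, i.e. 128 * 2), it can feed the 1x1 cls_head and reg_head defined above without any extra projection, which is why no decoder is needed after the fusion step.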
-------------------------------------------------------------------------------- /v2xvit/models/point_pillar_opv2v.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from v2xvit.models.sub_modules.pillar_vfe import PillarVFE 4 | from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter 5 | from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone 6 | from v2xvit.models.sub_modules.downsample_conv import DownsampleConv 7 | from v2xvit.models.sub_modules.naive_compress import NaiveCompressor 8 | from v2xvit.models.sub_modules.self_attn import AttFusion 9 | 10 | 11 | class PointPillarOPV2V(nn.Module): 12 | def __init__(self, args): 13 | super(PointPillarOPV2V, self).__init__() 14 | 15 | self.max_cav = args['max_cav'] 16 | # PIllar VFE 17 | self.pillar_vfe = PillarVFE(args['pillar_vfe'], 18 | num_point_features=4, 19 | voxel_size=args['voxel_size'], 20 | point_cloud_range=args['lidar_range']) 21 | self.scatter = PointPillarScatter(args['point_pillar_scatter']) 22 | self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64) 23 | # used to downsample the feature map for efficient computation 24 | self.shrink_flag = False 25 | if 'shrink_header' in args: 26 | self.shrink_flag = True 27 | self.shrink_conv = DownsampleConv(args['shrink_header']) 28 | self.compression = False 29 | 30 | if args['compression'] > 0: 31 | self.compression = True 32 | self.naive_compressor = NaiveCompressor(256, args['compression']) 33 | 34 | self.fusion_net = AttFusion(256) 35 | 36 | self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'], 37 | kernel_size=1) 38 | self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'], 39 | kernel_size=1) 40 | 41 | if args['backbone_fix']: 42 | self.backbone_fix() 43 | 44 | def backbone_fix(self): 45 | """ 46 | Fix the parameters of backbone during finetune on timedelay。 47 | """ 48 | for p in self.pillar_vfe.parameters(): 49 | p.requires_grad = False 50 | 51 | for p in self.scatter.parameters(): 52 | p.requires_grad = False 53 | 54 | for p in self.backbone.parameters(): 55 | p.requires_grad = False 56 | 57 | if self.compression: 58 | for p in self.naive_compressor.parameters(): 59 | p.requires_grad = False 60 | if self.shrink_flag: 61 | for p in self.shrink_conv.parameters(): 62 | p.requires_grad = False 63 | 64 | for p in self.cls_head.parameters(): 65 | p.requires_grad = False 66 | for p in self.reg_head.parameters(): 67 | p.requires_grad = False 68 | 69 | def forward(self, data_dict): 70 | voxel_features = data_dict['processed_lidar']['voxel_features'] 71 | voxel_coords = data_dict['processed_lidar']['voxel_coords'] 72 | voxel_num_points = data_dict['processed_lidar']['voxel_num_points'] 73 | record_len = data_dict['record_len'] 74 | spatial_correction_matrix = data_dict['spatial_correction_matrix'] 75 | 76 | # B, max_cav, 3(dt dv infra), 1, 1 77 | prior_encoding =\ 78 | data_dict['prior_encoding'].unsqueeze(-1).unsqueeze(-1) 79 | 80 | batch_dict = {'voxel_features': voxel_features, 81 | 'voxel_coords': voxel_coords, 82 | 'voxel_num_points': voxel_num_points, 83 | 'record_len': record_len} 84 | # n, 4 -> n, c 85 | batch_dict = self.pillar_vfe(batch_dict) 86 | # n, c -> N, C, H, W 87 | batch_dict = self.scatter(batch_dict) 88 | batch_dict = self.backbone(batch_dict) 89 | 90 | spatial_features_2d = batch_dict['spatial_features_2d'] 91 | # downsample feature to reduce memory 92 | if self.shrink_flag: 93 | spatial_features_2d = self.shrink_conv(spatial_features_2d) 94 | # 
compressor 95 | if self.compression: 96 | spatial_features_2d = self.naive_compressor(spatial_features_2d) 97 | 98 | fused_feature = self.fusion_net(spatial_features_2d, record_len) 99 | 100 | psm = self.cls_head(fused_feature) 101 | rm = self.reg_head(fused_feature) 102 | 103 | output_dict = {'psm': psm, 104 | 'rm': rm} 105 | 106 | return output_dict 107 | -------------------------------------------------------------------------------- /v2xvit/models/point_pillar_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from v2xvit.models.sub_modules.pillar_vfe import PillarVFE 5 | from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter 6 | from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone 7 | from v2xvit.models.sub_modules.fuse_utils import regroup 8 | from v2xvit.models.sub_modules.downsample_conv import DownsampleConv 9 | from v2xvit.models.sub_modules.naive_compress import NaiveCompressor 10 | from v2xvit.models.sub_modules.v2xvit_basic import V2XTransformer 11 | 12 | 13 | class PointPillarTransformer(nn.Module): 14 | def __init__(self, args): 15 | super(PointPillarTransformer, self).__init__() 16 | 17 | self.max_cav = args['max_cav'] 18 | # PIllar VFE 19 | self.pillar_vfe = PillarVFE(args['pillar_vfe'], 20 | num_point_features=4, 21 | voxel_size=args['voxel_size'], 22 | point_cloud_range=args['lidar_range']) 23 | self.scatter = PointPillarScatter(args['point_pillar_scatter']) 24 | self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64) 25 | # used to downsample the feature map for efficient computation 26 | self.shrink_flag = False 27 | if 'shrink_header' in args: 28 | self.shrink_flag = True 29 | self.shrink_conv = DownsampleConv(args['shrink_header']) 30 | self.compression = False 31 | 32 | if args['compression'] > 0: 33 | self.compression = True 34 | self.naive_compressor = NaiveCompressor(256, args['compression']) 35 | 36 | self.fusion_net = V2XTransformer(args['transformer']) 37 | 38 | self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'], 39 | kernel_size=1) 40 | self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'], 41 | kernel_size=1) 42 | 43 | if args['backbone_fix']: 44 | self.backbone_fix() 45 | 46 | def backbone_fix(self): 47 | """ 48 | Fix the parameters of backbone during finetune on timedelay。 49 | """ 50 | for p in self.pillar_vfe.parameters(): 51 | p.requires_grad = False 52 | 53 | for p in self.scatter.parameters(): 54 | p.requires_grad = False 55 | 56 | for p in self.backbone.parameters(): 57 | p.requires_grad = False 58 | 59 | if self.compression: 60 | for p in self.naive_compressor.parameters(): 61 | p.requires_grad = False 62 | if self.shrink_flag: 63 | for p in self.shrink_conv.parameters(): 64 | p.requires_grad = False 65 | 66 | for p in self.cls_head.parameters(): 67 | p.requires_grad = False 68 | for p in self.reg_head.parameters(): 69 | p.requires_grad = False 70 | 71 | def forward(self, data_dict): 72 | voxel_features = data_dict['processed_lidar']['voxel_features'] 73 | voxel_coords = data_dict['processed_lidar']['voxel_coords'] 74 | voxel_num_points = data_dict['processed_lidar']['voxel_num_points'] 75 | record_len = data_dict['record_len'] 76 | spatial_correction_matrix = data_dict['spatial_correction_matrix'] 77 | 78 | # B, max_cav, 3(dt dv infra), 1, 1 79 | prior_encoding =\ 80 | data_dict['prior_encoding'].unsqueeze(-1).unsqueeze(-1) 81 | 82 | batch_dict = {'voxel_features': voxel_features, 83 | 
'voxel_coords': voxel_coords, 84 | 'voxel_num_points': voxel_num_points, 85 | 'record_len': record_len} 86 | # n, 4 -> n, c 87 | batch_dict = self.pillar_vfe(batch_dict) 88 | # n, c -> N, C, H, W 89 | batch_dict = self.scatter(batch_dict) 90 | batch_dict = self.backbone(batch_dict) 91 | 92 | spatial_features_2d = batch_dict['spatial_features_2d'] 93 | # downsample feature to reduce memory 94 | if self.shrink_flag: 95 | spatial_features_2d = self.shrink_conv(spatial_features_2d) 96 | # compressor 97 | if self.compression: 98 | spatial_features_2d = self.naive_compressor(spatial_features_2d) 99 | # N, C, H, W -> B, L, C, H, W 100 | regroup_feature, mask = regroup(spatial_features_2d, 101 | record_len, 102 | self.max_cav) 103 | # prior encoding added 104 | prior_encoding = prior_encoding.repeat(1, 1, 1, 105 | regroup_feature.shape[3], 106 | regroup_feature.shape[4]) 107 | regroup_feature = torch.cat([regroup_feature, prior_encoding], dim=2) 108 | 109 | # b l c h w -> b l h w c 110 | regroup_feature = regroup_feature.permute(0, 1, 3, 4, 2) 111 | # transformer fusion 112 | fused_feature = self.fusion_net(regroup_feature, mask, spatial_correction_matrix) 113 | # b h w c -> b c h w 114 | fused_feature = fused_feature.permute(0, 3, 1, 2) 115 | 116 | psm = self.cls_head(fused_feature) 117 | rm = self.reg_head(fused_feature) 118 | 119 | output_dict = {'psm': psm, 120 | 'rm': rm} 121 | 122 | return output_dict 123 | -------------------------------------------------------------------------------- /v2xvit/models/point_pillar_v2vnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from v2xvit.models.sub_modules.pillar_vfe import PillarVFE 5 | from v2xvit.models.sub_modules.point_pillar_scatter import PointPillarScatter 6 | from v2xvit.models.sub_modules.base_bev_backbone import BaseBEVBackbone 7 | from v2xvit.models.sub_modules.downsample_conv import DownsampleConv 8 | from v2xvit.models.sub_modules.naive_compress import NaiveCompressor 9 | from v2xvit.models.sub_modules.v2v_fuse import V2VNetFusion 10 | 11 | 12 | class PointPillarV2VNet(nn.Module): 13 | def __init__(self, args): 14 | super(PointPillarV2VNet, self).__init__() 15 | 16 | self.max_cav = args['max_cav'] 17 | # PIllar VFE 18 | self.pillar_vfe = PillarVFE(args['pillar_vfe'], 19 | num_point_features=4, 20 | voxel_size=args['voxel_size'], 21 | point_cloud_range=args['lidar_range']) 22 | self.scatter = PointPillarScatter(args['point_pillar_scatter']) 23 | self.backbone = BaseBEVBackbone(args['base_bev_backbone'], 64) 24 | # used to downsample the feature map for efficient computation 25 | self.shrink_flag = False 26 | if 'shrink_header' in args: 27 | self.shrink_flag = True 28 | self.shrink_conv = DownsampleConv(args['shrink_header']) 29 | self.compression = False 30 | 31 | if args['compression'] > 0: 32 | self.compression = True 33 | self.naive_compressor = NaiveCompressor(256, args['compression']) 34 | 35 | self.fusion_net = V2VNetFusion(args['v2vfusion']) 36 | 37 | self.cls_head = nn.Conv2d(128 * 2, args['anchor_number'], 38 | kernel_size=1) 39 | self.reg_head = nn.Conv2d(128 * 2, 7 * args['anchor_number'], 40 | kernel_size=1) 41 | 42 | if args['backbone_fix']: 43 | self.backbone_fix() 44 | 45 | def backbone_fix(self): 46 | """ 47 | Fix the parameters of backbone during finetune on timedelay。 48 | """ 49 | for p in self.pillar_vfe.parameters(): 50 | p.requires_grad = False 51 | 52 | for p in self.scatter.parameters(): 53 | p.requires_grad = False 54 | 55 | 
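        # Note: backbone_fix freezes the whole feature extractor (PillarVFE,
        # scatter, BEV backbone, optional compressor/shrink conv) and both
        # detection heads, so only the fusion module keeps requires_grad=True
        # when fine-tuning under communication delay.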
for p in self.backbone.parameters(): 56 | p.requires_grad = False 57 | 58 | if self.compression: 59 | for p in self.naive_compressor.parameters(): 60 | p.requires_grad = False 61 | if self.shrink_flag: 62 | for p in self.shrink_conv.parameters(): 63 | p.requires_grad = False 64 | 65 | for p in self.cls_head.parameters(): 66 | p.requires_grad = False 67 | for p in self.reg_head.parameters(): 68 | p.requires_grad = False 69 | 70 | def unpad_prior_encoding(self, x, record_len): 71 | # remove padded zeros to form tensor with shape (N, 3) 72 | # x: (B, L, 3); record_len: (B) 73 | B = x.shape[0] 74 | out = [] 75 | for i in range(B): 76 | # (valid_len, 3) 77 | out.append(x[i, :record_len[i], :]) 78 | out = torch.cat(out, dim=0) 79 | # (N, 3) 80 | return out 81 | 82 | def forward(self, data_dict): 83 | voxel_features = data_dict['processed_lidar']['voxel_features'] 84 | voxel_coords = data_dict['processed_lidar']['voxel_coords'] 85 | voxel_num_points = data_dict['processed_lidar']['voxel_num_points'] 86 | record_len = data_dict['record_len'] 87 | spatial_correction_matrix = data_dict['spatial_correction_matrix'] 88 | pairwise_t_matrix = data_dict['pairwise_t_matrix'] 89 | prior_encoding = data_dict['prior_encoding'] 90 | prior_encoding = self.unpad_prior_encoding(prior_encoding, record_len) 91 | 92 | batch_dict = {'voxel_features': voxel_features, 93 | 'voxel_coords': voxel_coords, 94 | 'voxel_num_points': voxel_num_points, 95 | 'record_len': record_len} 96 | # n, 4 -> n, c 97 | batch_dict = self.pillar_vfe(batch_dict) 98 | # n, c -> N, C, H, W 99 | batch_dict = self.scatter(batch_dict) 100 | batch_dict = self.backbone(batch_dict) 101 | 102 | spatial_features_2d = batch_dict['spatial_features_2d'] 103 | # downsample feature to reduce memory 104 | if self.shrink_flag: 105 | spatial_features_2d = self.shrink_conv(spatial_features_2d) 106 | # compressor 107 | if self.compression: 108 | spatial_features_2d = self.naive_compressor(spatial_features_2d) 109 | fused_feature = self.fusion_net(spatial_features_2d, 110 | record_len, 111 | pairwise_t_matrix, 112 | prior_encoding) 113 | 114 | psm = self.cls_head(fused_feature) 115 | rm = self.reg_head(fused_feature) 116 | 117 | output_dict = {'psm': psm, 118 | 'rm': rm} 119 | 120 | return output_dict 121 | -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/models/sub_modules/__init__.py -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/base_bev_backbone.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BaseBEVBackbone(nn.Module): 7 | def __init__(self, model_cfg, input_channels): 8 | super().__init__() 9 | self.model_cfg = model_cfg 10 | 11 | if 'layer_nums' in self.model_cfg: 12 | 13 | assert len(self.model_cfg['layer_nums']) == \ 14 | len(self.model_cfg['layer_strides']) == \ 15 | len(self.model_cfg['num_filters']) 16 | 17 | layer_nums = self.model_cfg['layer_nums'] 18 | layer_strides = self.model_cfg['layer_strides'] 19 | num_filters = self.model_cfg['num_filters'] 20 | else: 21 | layer_nums = layer_strides = num_filters = [] 22 | 23 | if 'upsample_strides' in self.model_cfg: 24 | assert len(self.model_cfg['upsample_strides']) \ 25 | 
== len(self.model_cfg['num_upsample_filter']) 26 | 27 | num_upsample_filters = self.model_cfg['num_upsample_filter'] 28 | upsample_strides = self.model_cfg['upsample_strides'] 29 | 30 | else: 31 | upsample_strides = num_upsample_filters = [] 32 | 33 | num_levels = len(layer_nums) 34 | c_in_list = [input_channels, *num_filters[:-1]] 35 | 36 | self.blocks = nn.ModuleList() 37 | self.deblocks = nn.ModuleList() 38 | 39 | for idx in range(num_levels): 40 | cur_layers = [ 41 | nn.ZeroPad2d(1), 42 | nn.Conv2d( 43 | c_in_list[idx], num_filters[idx], kernel_size=3, 44 | stride=layer_strides[idx], padding=0, bias=False 45 | ), 46 | nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01), 47 | nn.ReLU() 48 | ] 49 | for k in range(layer_nums[idx]): 50 | cur_layers.extend([ 51 | nn.Conv2d(num_filters[idx], num_filters[idx], 52 | kernel_size=3, padding=1, bias=False), 53 | nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01), 54 | nn.ReLU() 55 | ]) 56 | 57 | self.blocks.append(nn.Sequential(*cur_layers)) 58 | if len(upsample_strides) > 0: 59 | stride = upsample_strides[idx] 60 | if stride >= 1: 61 | self.deblocks.append(nn.Sequential( 62 | nn.ConvTranspose2d( 63 | num_filters[idx], num_upsample_filters[idx], 64 | upsample_strides[idx], 65 | stride=upsample_strides[idx], bias=False 66 | ), 67 | nn.BatchNorm2d(num_upsample_filters[idx], 68 | eps=1e-3, momentum=0.01), 69 | nn.ReLU() 70 | )) 71 | else: 72 | stride = np.round(1 / stride).astype(np.int) 73 | self.deblocks.append(nn.Sequential( 74 | nn.Conv2d( 75 | num_filters[idx], num_upsample_filters[idx], 76 | stride, 77 | stride=stride, bias=False 78 | ), 79 | nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, 80 | momentum=0.01), 81 | nn.ReLU() 82 | )) 83 | 84 | c_in = sum(num_upsample_filters) 85 | if len(upsample_strides) > num_levels: 86 | self.deblocks.append(nn.Sequential( 87 | nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], 88 | stride=upsample_strides[-1], bias=False), 89 | nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01), 90 | nn.ReLU(), 91 | )) 92 | 93 | self.num_bev_features = c_in 94 | 95 | def forward(self, data_dict): 96 | spatial_features = data_dict['spatial_features'] 97 | 98 | ups = [] 99 | ret_dict = {} 100 | x = spatial_features 101 | 102 | for i in range(len(self.blocks)): 103 | x = self.blocks[i](x) 104 | 105 | stride = int(spatial_features.shape[2] / x.shape[2]) 106 | ret_dict['spatial_features_%dx' % stride] = x 107 | 108 | if len(self.deblocks) > 0: 109 | ups.append(self.deblocks[i](x)) 110 | else: 111 | ups.append(x) 112 | 113 | if len(ups) > 1: 114 | x = torch.cat(ups, dim=1) 115 | elif len(ups) == 1: 116 | x = ups[0] 117 | 118 | if len(self.deblocks) > len(self.blocks): 119 | x = self.deblocks[-1](x) 120 | 121 | data_dict['spatial_features_2d'] = x 122 | return data_dict 123 | -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/base_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange 5 | 6 | 7 | class PreNorm(nn.Module): 8 | def __init__(self, dim, fn): 9 | super().__init__() 10 | self.norm = nn.LayerNorm(dim) 11 | self.fn = fn 12 | 13 | def forward(self, x, **kwargs): 14 | return self.fn(self.norm(x), **kwargs) 15 | 16 | 17 | class FeedForward(nn.Module): 18 | def __init__(self, dim, hidden_dim, dropout=0.): 19 | super().__init__() 20 | self.net = nn.Sequential( 21 | nn.Linear(dim, hidden_dim), 22 | nn.GELU(), 23 | nn.Dropout(dropout), 24 | 
nn.Linear(hidden_dim, dim), 25 | nn.Dropout(dropout) 26 | ) 27 | 28 | def forward(self, x): 29 | return self.net(x) 30 | 31 | 32 | class CavAttention(nn.Module): 33 | """ 34 | Vanilla CAV attention. 35 | """ 36 | def __init__(self, dim, heads, dim_head=64, dropout=0.1): 37 | super().__init__() 38 | inner_dim = heads * dim_head 39 | 40 | self.heads = heads 41 | self.scale = dim_head ** -0.5 42 | 43 | self.attend = nn.Softmax(dim=-1) 44 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 45 | 46 | self.to_out = nn.Sequential( 47 | nn.Linear(inner_dim, dim), 48 | nn.Dropout(dropout) 49 | ) 50 | 51 | def forward(self, x, mask, prior_encoding): 52 | # x: (B, L, H, W, C) -> (B, H, W, L, C) 53 | # mask: (B, L) 54 | x = x.permute(0, 2, 3, 1, 4) 55 | # mask: (B, 1, H, W, L, 1) 56 | mask = mask.unsqueeze(1) 57 | 58 | # qkv: [(B, H, W, L, C_inner) *3] 59 | qkv = self.to_qkv(x).chunk(3, dim=-1) 60 | # q: (B, M, H, W, L, C) 61 | q, k, v = map(lambda t: rearrange(t, 'b h w l (m c) -> b m h w l c', 62 | m=self.heads), qkv) 63 | 64 | # attention, (B, M, H, W, L, L) 65 | att_map = torch.einsum('b m h w i c, b m h w j c -> b m h w i j', 66 | q, k) * self.scale 67 | # add mask 68 | att_map = att_map.masked_fill(mask == 0, -float('inf')) 69 | # softmax 70 | att_map = self.attend(att_map) 71 | 72 | # out:(B, M, H, W, L, C_head) 73 | out = torch.einsum('b m h w i j, b m h w j c -> b m h w i c', att_map, 74 | v) 75 | out = rearrange(out, 'b m h w l c -> b h w l (m c)', 76 | m=self.heads) 77 | out = self.to_out(out) 78 | # (B L H W C) 79 | out = out.permute(0, 3, 1, 2, 4) 80 | return out 81 | 82 | 83 | class BaseEncoder(nn.Module): 84 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.): 85 | super().__init__() 86 | self.layers = nn.ModuleList([]) 87 | for _ in range(depth): 88 | self.layers.append(nn.ModuleList([ 89 | PreNorm(dim, CavAttention(dim, 90 | heads=heads, 91 | dim_head=dim_head, 92 | dropout=dropout)), 93 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)) 94 | ])) 95 | 96 | def forward(self, x, mask): 97 | for attn, ff in self.layers: 98 | x = attn(x, mask=mask) + x 99 | x = ff(x) + x 100 | return x 101 | 102 | 103 | class BaseTransformer(nn.Module): 104 | def __init__(self, args): 105 | super().__init__() 106 | 107 | dim = args['dim'] 108 | depth = args['depth'] 109 | heads = args['heads'] 110 | dim_head = args['dim_head'] 111 | mlp_dim = args['mlp_dim'] 112 | dropout = args['dropout'] 113 | max_cav = args['max_cav'] 114 | 115 | self.encoder = BaseEncoder(dim, depth, heads, dim_head, mlp_dim, 116 | dropout) 117 | 118 | def forward(self, x, mask): 119 | # B, L, H, W, C 120 | output = self.encoder(x, mask) 121 | # B, H, W, C 122 | output = output[:, 0] 123 | 124 | return output -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/downsample_conv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class used to downsample features by 3*3 conv 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class DoubleConv(nn.Module): 10 | """ 11 | Double convoltuion 12 | Args: 13 | in_channels: input channel num 14 | out_channels: output channel num 15 | """ 16 | 17 | def __init__(self, in_channels, out_channels, kernel_size, 18 | stride, padding): 19 | super().__init__() 20 | self.double_conv = nn.Sequential( 21 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 22 | stride=stride, padding=padding), 23 | nn.ReLU(inplace=True), 24 | nn.Conv2d(out_channels, 
out_channels, kernel_size=3, padding=1), 25 | nn.ReLU(inplace=True) 26 | ) 27 | 28 | def forward(self, x): 29 | return self.double_conv(x) 30 | 31 | 32 | class DownsampleConv(nn.Module): 33 | def __init__(self, config): 34 | super(DownsampleConv, self).__init__() 35 | self.layers = nn.ModuleList([]) 36 | input_dim = config['input_dim'] 37 | 38 | for (ksize, dim, stride, padding) in zip(config['kernal_size'], 39 | config['dim'], 40 | config['stride'], 41 | config['padding']): 42 | self.layers.append(DoubleConv(input_dim, 43 | dim, 44 | kernel_size=ksize, 45 | stride=stride, 46 | padding=padding)) 47 | input_dim = dim 48 | 49 | def forward(self, x): 50 | for i in range(len(self.layers)): 51 | x = self.layers[i](x) 52 | return x -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/f_cooper_fuse.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of F-cooper maxout fusing. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class SpatialFusion(nn.Module): 9 | def __init__(self): 10 | super(SpatialFusion, self).__init__() 11 | 12 | def regroup(self, x, record_len): 13 | cum_sum_len = torch.cumsum(record_len, dim=0) 14 | split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu()) 15 | return split_x 16 | 17 | def forward(self, x, record_len): 18 | # x: B, C, H, W, split x:[(B1, C, W, H), (B2, C, W, H)] 19 | split_x = self.regroup(x, record_len) 20 | out = [] 21 | 22 | for xx in split_x: 23 | xx = torch.max(xx, dim=0, keepdim=True)[0] 24 | out.append(xx) 25 | return torch.cat(out, dim=0) -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/fuse_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from einops import rearrange 5 | from v2xvit.utils.common_utils import torch_tensor_to_numpy 6 | 7 | 8 | def regroup(dense_feature, record_len, max_len): 9 | """ 10 | Regroup the data based on the record_len. 11 | 12 | Parameters 13 | ---------- 14 | dense_feature : torch.Tensor 15 | N, C, H, W 16 | record_len : list 17 | [sample1_len, sample2_len, ...] 
18 | max_len : int 19 | Maximum cav number 20 | 21 | Returns 22 | ------- 23 | regroup_feature : torch.Tensor 24 | B, L, C, H, W 25 | """ 26 | cum_sum_len = list(np.cumsum(torch_tensor_to_numpy(record_len))) 27 | split_features = torch.tensor_split(dense_feature, 28 | cum_sum_len[:-1]) 29 | regroup_features = [] 30 | mask = [] 31 | 32 | for split_feature in split_features: 33 | # M, C, H, W 34 | feature_shape = split_feature.shape 35 | 36 | # the maximum M is 5 as most 5 cavs 37 | padding_len = max_len - feature_shape[0] 38 | mask.append([1] * feature_shape[0] + [0] * padding_len) 39 | 40 | padding_tensor = torch.zeros(padding_len, feature_shape[1], 41 | feature_shape[2], feature_shape[3]) 42 | padding_tensor = padding_tensor.to(split_feature.device) 43 | 44 | split_feature = torch.cat([split_feature, padding_tensor], 45 | dim=0) 46 | 47 | # 1, 5C, H, W 48 | split_feature = split_feature.view(-1, 49 | feature_shape[2], 50 | feature_shape[3]).unsqueeze(0) 51 | regroup_features.append(split_feature) 52 | 53 | # B, 5C, H, W 54 | regroup_features = torch.cat(regroup_features, dim=0) 55 | # B, L, C, H, W 56 | regroup_features = rearrange(regroup_features, 57 | 'b (l c) h w -> b l c h w', 58 | l=max_len) 59 | mask = torch.from_numpy(np.array(mask)).to(regroup_features.device) 60 | 61 | return regroup_features, mask 62 | -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/hmsa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange 5 | 6 | 7 | class HGTCavAttention(nn.Module): 8 | def __init__(self, dim, heads, num_types=2, 9 | num_relations=4, dim_head=64, dropout=0.1): 10 | super().__init__() 11 | inner_dim = heads * dim_head 12 | 13 | self.heads = heads 14 | self.scale = dim_head ** -0.5 15 | self.num_types = num_types 16 | 17 | self.attend = nn.Softmax(dim=-1) 18 | self.drop_out = nn.Dropout(dropout) 19 | self.k_linears = nn.ModuleList() 20 | self.q_linears = nn.ModuleList() 21 | self.v_linears = nn.ModuleList() 22 | self.a_linears = nn.ModuleList() 23 | self.norms = nn.ModuleList() 24 | for t in range(num_types): 25 | self.k_linears.append(nn.Linear(dim, inner_dim)) 26 | self.q_linears.append(nn.Linear(dim, inner_dim)) 27 | self.v_linears.append(nn.Linear(dim, inner_dim)) 28 | self.a_linears.append(nn.Linear(inner_dim, dim)) 29 | 30 | self.relation_att = nn.Parameter( 31 | torch.Tensor(num_relations, heads, dim_head, dim_head)) 32 | self.relation_msg = nn.Parameter( 33 | torch.Tensor(num_relations, heads, dim_head, dim_head)) 34 | 35 | torch.nn.init.xavier_uniform(self.relation_att) 36 | torch.nn.init.xavier_uniform(self.relation_msg) 37 | 38 | def to_qkv(self, x, types): 39 | # x: (B,H,W,L,C) 40 | # types: (B,L) 41 | q_batch = [] 42 | k_batch = [] 43 | v_batch = [] 44 | 45 | for b in range(x.shape[0]): 46 | q_list = [] 47 | k_list = [] 48 | v_list = [] 49 | 50 | for i in range(x.shape[-2]): 51 | # (H,W,1,C) 52 | q_list.append( 53 | self.q_linears[types[b, i]](x[b, :, :, i, :].unsqueeze(2))) 54 | k_list.append( 55 | self.k_linears[types[b, i]](x[b, :, :, i, :].unsqueeze(2))) 56 | v_list.append( 57 | self.v_linears[types[b, i]](x[b, :, :, i, :].unsqueeze(2))) 58 | # (1,H,W,L,C) 59 | q_batch.append(torch.cat(q_list, dim=2).unsqueeze(0)) 60 | k_batch.append(torch.cat(k_list, dim=2).unsqueeze(0)) 61 | v_batch.append(torch.cat(v_list, dim=2).unsqueeze(0)) 62 | # (B,H,W,L,C) 63 | q = torch.cat(q_batch, dim=0) 64 | k = torch.cat(k_batch, 
dim=0) 65 | v = torch.cat(v_batch, dim=0) 66 | return q, k, v 67 | 68 | def get_relation_type_index(self, type1, type2): 69 | return type1 * self.num_types + type2 70 | 71 | def get_hetero_edge_weights(self, x, types): 72 | w_att_batch = [] 73 | w_msg_batch = [] 74 | 75 | for b in range(x.shape[0]): 76 | w_att_list = [] 77 | w_msg_list = [] 78 | 79 | for i in range(x.shape[-2]): 80 | w_att_i_list = [] 81 | w_msg_i_list = [] 82 | 83 | for j in range(x.shape[-2]): 84 | e_type = self.get_relation_type_index(types[b, i], 85 | types[b, j]) 86 | w_att_i_list.append(self.relation_att[e_type].unsqueeze(0)) 87 | w_msg_i_list.append(self.relation_msg[e_type].unsqueeze(0)) 88 | w_att_list.append(torch.cat(w_att_i_list, dim=0).unsqueeze(0)) 89 | w_msg_list.append(torch.cat(w_msg_i_list, dim=0).unsqueeze(0)) 90 | 91 | w_att_batch.append(torch.cat(w_att_list, dim=0).unsqueeze(0)) 92 | w_msg_batch.append(torch.cat(w_msg_list, dim=0).unsqueeze(0)) 93 | 94 | # (B,M,L,L,C_head,C_head) 95 | w_att = torch.cat(w_att_batch, dim=0).permute(0, 3, 1, 2, 4, 5) 96 | w_msg = torch.cat(w_msg_batch, dim=0).permute(0, 3, 1, 2, 4, 5) 97 | return w_att, w_msg 98 | 99 | def to_out(self, x, types): 100 | out_batch = [] 101 | for b in range(x.shape[0]): 102 | out_list = [] 103 | for i in range(x.shape[-2]): 104 | out_list.append( 105 | self.a_linears[types[b, i]](x[b, :, :, i, :].unsqueeze(2))) 106 | out_batch.append(torch.cat(out_list, dim=2).unsqueeze(0)) 107 | out = torch.cat(out_batch, dim=0) 108 | return out 109 | 110 | def forward(self, x, mask, prior_encoding): 111 | # x: (B, L, H, W, C) -> (B, H, W, L, C) 112 | # mask: (B, H, W, L, 1) 113 | # prior_encoding: (B,L,H,W,3) 114 | x = x.permute(0, 2, 3, 1, 4) 115 | # mask: (B, 1, H, W, L, 1) 116 | mask = mask.unsqueeze(1) 117 | # (B,L) 118 | velocities, dts, types = [itm.squeeze(-1) for itm in 119 | prior_encoding[:, :, 0, 0, :].split( 120 | [1, 1, 1], dim=-1)] 121 | types = types.to(torch.int) 122 | dts = dts.to(torch.int) 123 | qkv = self.to_qkv(x, types) 124 | # (B,M,L,L,C_head,C_head) 125 | w_att, w_msg = self.get_hetero_edge_weights(x, types) 126 | 127 | # q: (B, M, H, W, L, C) 128 | q, k, v = map(lambda t: rearrange(t, 'b h w l (m c) -> b m h w l c', 129 | m=self.heads), (qkv)) 130 | # attention, (B, M, H, W, L, L) 131 | att_map = torch.einsum( 132 | 'b m h w i p, b m i j p q, bm h w j q -> b m h w i j', 133 | [q, w_att, k]) * self.scale 134 | # add mask 135 | att_map = att_map.masked_fill(mask == 0, -float('inf')) 136 | # softmax 137 | att_map = self.attend(att_map) 138 | 139 | # out:(B, M, H, W, L, C_head) 140 | v_msg = torch.einsum('b m i j p c, b m h w j p -> b m h w i j c', 141 | w_msg, v) 142 | out = torch.einsum('b m h w i j, b m h w i j c -> b m h w i c', 143 | att_map, v_msg) 144 | 145 | out = rearrange(out, 'b m h w l c -> b h w l (m c)', 146 | m=self.heads) 147 | out = self.to_out(out, types) 148 | out = self.drop_out(out) 149 | # (B L H W C) 150 | out = out.permute(0, 3, 1, 2, 4) 151 | return out -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/mswin.py: -------------------------------------------------------------------------------- 1 | """ 2 | Multi-scale window transformer 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | from einops import rearrange 9 | from v2xvit.models.sub_modules.split_attn import SplitAttn 10 | 11 | 12 | def get_relative_distances(window_size): 13 | indices = torch.tensor(np.array( 14 | [[x, y] for x in range(window_size) for y in 
range(window_size)])) 15 | distances = indices[None, :, :] - indices[:, None, :] 16 | return distances 17 | 18 | 19 | class BaseWindowAttention(nn.Module): 20 | def __init__(self, dim, heads, dim_head, drop_out, window_size, 21 | relative_pos_embedding): 22 | super().__init__() 23 | inner_dim = dim_head * heads 24 | 25 | self.heads = heads 26 | self.scale = dim_head ** -0.5 27 | self.window_size = window_size 28 | self.relative_pos_embedding = relative_pos_embedding 29 | 30 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 31 | 32 | if self.relative_pos_embedding: 33 | self.relative_indices = get_relative_distances(window_size) + \ 34 | window_size - 1 35 | self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 36 | 2 * window_size - 1)) 37 | else: 38 | self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, 39 | window_size ** 2)) 40 | 41 | self.to_out = nn.Sequential( 42 | nn.Linear(inner_dim, dim), 43 | nn.Dropout(drop_out) 44 | ) 45 | 46 | def forward(self, x): 47 | b, l, h, w, c, m = *x.shape, self.heads 48 | 49 | qkv = self.to_qkv(x).chunk(3, dim=-1) 50 | new_h = h // self.window_size 51 | new_w = w // self.window_size 52 | 53 | # q : (b, l, m, new_h*new_w, window_size^2, c_head) 54 | q, k, v = map( 55 | lambda t: rearrange(t, 56 | 'b l (new_h w_h) (new_w w_w) (m c) -> b l m (new_h new_w) (w_h w_w) c', 57 | m=m, w_h=self.window_size, 58 | w_w=self.window_size), qkv) 59 | # b l m h window_size window_size 60 | dots = torch.einsum('b l m h i c, b l m h j c -> b l m h i j', 61 | q, k, ) * self.scale 62 | # consider prior knowledge of the local window 63 | if self.relative_pos_embedding: 64 | dots += self.pos_embedding[self.relative_indices[:, :, 0], 65 | self.relative_indices[:, :, 1]] 66 | else: 67 | dots += self.pos_embedding 68 | 69 | attn = dots.softmax(dim=-1) 70 | 71 | out = torch.einsum('b l m h i j, b l m h j c -> b l m h i c', attn, v) 72 | # b l h w c 73 | out = rearrange(out, 74 | 'b l m (new_h new_w) (w_h w_w) c -> b l (new_h w_h) (new_w w_w) (m c)', 75 | m=self.heads, w_h=self.window_size, 76 | w_w=self.window_size, 77 | new_w=new_w, new_h=new_h) 78 | out = self.to_out(out) 79 | 80 | return out 81 | 82 | 83 | class PyramidWindowAttention(nn.Module): 84 | def __init__(self, dim, heads, dim_heads, drop_out, window_size, 85 | relative_pos_embedding, fuse_method='naive'): 86 | super().__init__() 87 | 88 | assert isinstance(window_size, list) 89 | assert isinstance(heads, list) 90 | assert isinstance(dim_heads, list) 91 | assert len(dim_heads) == len(heads) 92 | 93 | self.pwmsa = nn.ModuleList([]) 94 | 95 | for (head, dim_head, ws) in zip(heads, dim_heads, window_size): 96 | self.pwmsa.append(BaseWindowAttention(dim, 97 | head, 98 | dim_head, 99 | drop_out, 100 | ws, 101 | relative_pos_embedding)) 102 | self.fuse_mehod = fuse_method 103 | if fuse_method == 'split_attn': 104 | self.split_attn = SplitAttn(256) 105 | 106 | def forward(self, x): 107 | output = None 108 | # naive fusion will just sum up all window attention output and do a 109 | # mean 110 | if self.fuse_mehod == 'naive': 111 | for wmsa in self.pwmsa: 112 | output = wmsa(x) if output is None else output + wmsa(x) 113 | return output / len(self.pwmsa) 114 | 115 | elif self.fuse_mehod == 'split_attn': 116 | window_list = [] 117 | for wmsa in self.pwmsa: 118 | window_list.append(wmsa(x)) 119 | return self.split_attn(window_list) -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/naive_compress.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class NaiveCompressor(nn.Module): 6 | def __init__(self, input_dim, compress_raito): 7 | super().__init__() 8 | self.encoder = nn.Sequential( 9 | nn.Conv2d(input_dim, input_dim//compress_raito, kernel_size=3, 10 | stride=1, padding=1), 11 | nn.BatchNorm2d(input_dim//compress_raito, eps=1e-3, momentum=0.01), 12 | nn.ReLU() 13 | ) 14 | self.decoder = nn.Sequential( 15 | nn.Conv2d(input_dim//compress_raito, input_dim, kernel_size=3, 16 | stride=1, padding=1), 17 | nn.BatchNorm2d(input_dim, eps=1e-3, momentum=0.01), 18 | nn.ReLU(), 19 | nn.Conv2d(input_dim, input_dim, kernel_size=3, stride=1, padding=1), 20 | nn.BatchNorm2d(input_dim, eps=1e-3, 21 | momentum=0.01), 22 | nn.ReLU() 23 | ) 24 | 25 | def forward(self, x): 26 | x = self.encoder(x) 27 | x = self.decoder(x) 28 | 29 | return x -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/pillar_vfe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pillar VFE, credits to OpenPCDet. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class PFNLayer(nn.Module): 11 | def __init__(self, 12 | in_channels, 13 | out_channels, 14 | use_norm=True, 15 | last_layer=False): 16 | super().__init__() 17 | 18 | self.last_vfe = last_layer 19 | self.use_norm = use_norm 20 | if not self.last_vfe: 21 | out_channels = out_channels // 2 22 | 23 | if self.use_norm: 24 | self.linear = nn.Linear(in_channels, out_channels, bias=False) 25 | self.norm = nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01) 26 | else: 27 | self.linear = nn.Linear(in_channels, out_channels, bias=True) 28 | 29 | self.part = 50000 30 | 31 | def forward(self, inputs): 32 | if inputs.shape[0] > self.part: 33 | # nn.Linear performs randomly when batch size is too large 34 | num_parts = inputs.shape[0] // self.part 35 | part_linear_out = [self.linear( 36 | inputs[num_part * self.part:(num_part + 1) * self.part]) 37 | for num_part in range(num_parts + 1)] 38 | x = torch.cat(part_linear_out, dim=0) 39 | else: 40 | x = self.linear(inputs) 41 | torch.backends.cudnn.enabled = False 42 | x = self.norm(x.permute(0, 2, 1)).permute(0, 2, 43 | 1) if self.use_norm else x 44 | torch.backends.cudnn.enabled = True 45 | x = F.relu(x) 46 | x_max = torch.max(x, dim=1, keepdim=True)[0] 47 | 48 | if self.last_vfe: 49 | return x_max 50 | else: 51 | x_repeat = x_max.repeat(1, inputs.shape[1], 1) 52 | x_concatenated = torch.cat([x, x_repeat], dim=2) 53 | return x_concatenated 54 | 55 | 56 | class PillarVFE(nn.Module): 57 | def __init__(self, model_cfg, num_point_features, voxel_size, 58 | point_cloud_range): 59 | super().__init__() 60 | self.model_cfg = model_cfg 61 | 62 | self.use_norm = self.model_cfg['use_norm'] 63 | self.with_distance = self.model_cfg['with_distance'] 64 | 65 | self.use_absolute_xyz = self.model_cfg['use_absolute_xyz'] 66 | num_point_features += 6 if self.use_absolute_xyz else 3 67 | if self.with_distance: 68 | num_point_features += 1 69 | 70 | self.num_filters = self.model_cfg['num_filters'] 71 | assert len(self.num_filters) > 0 72 | num_filters = [num_point_features] + list(self.num_filters) 73 | 74 | pfn_layers = [] 75 | for i in range(len(num_filters) - 1): 76 | in_filters = num_filters[i] 77 | out_filters = num_filters[i + 1] 78 | pfn_layers.append( 79 | PFNLayer(in_filters, out_filters, self.use_norm, 80 | 
last_layer=(i >= len(num_filters) - 2)) 81 | ) 82 | self.pfn_layers = nn.ModuleList(pfn_layers) 83 | 84 | self.voxel_x = voxel_size[0] 85 | self.voxel_y = voxel_size[1] 86 | self.voxel_z = voxel_size[2] 87 | self.x_offset = self.voxel_x / 2 + point_cloud_range[0] 88 | self.y_offset = self.voxel_y / 2 + point_cloud_range[1] 89 | self.z_offset = self.voxel_z / 2 + point_cloud_range[2] 90 | 91 | def get_output_feature_dim(self): 92 | return self.num_filters[-1] 93 | 94 | @staticmethod 95 | def get_paddings_indicator(actual_num, max_num, axis=0): 96 | actual_num = torch.unsqueeze(actual_num, axis + 1) 97 | max_num_shape = [1] * len(actual_num.shape) 98 | max_num_shape[axis + 1] = -1 99 | max_num = torch.arange(max_num, 100 | dtype=torch.int, 101 | device=actual_num.device).view(max_num_shape) 102 | paddings_indicator = actual_num.int() > max_num 103 | return paddings_indicator 104 | 105 | def forward(self, batch_dict): 106 | 107 | voxel_features, voxel_num_points, coords = \ 108 | batch_dict['voxel_features'], batch_dict['voxel_num_points'], \ 109 | batch_dict['voxel_coords'] 110 | points_mean = \ 111 | voxel_features[:, :, :3].sum(dim=1, keepdim=True) / \ 112 | voxel_num_points.type_as(voxel_features).view(-1, 1, 1) 113 | f_cluster = voxel_features[:, :, :3] - points_mean 114 | 115 | f_center = torch.zeros_like(voxel_features[:, :, :3]) 116 | f_center[:, :, 0] = voxel_features[:, :, 0] - ( 117 | coords[:, 3].to(voxel_features.dtype).unsqueeze( 118 | 1) * self.voxel_x + self.x_offset) 119 | f_center[:, :, 1] = voxel_features[:, :, 1] - ( 120 | coords[:, 2].to(voxel_features.dtype).unsqueeze( 121 | 1) * self.voxel_y + self.y_offset) 122 | f_center[:, :, 2] = voxel_features[:, :, 2] - ( 123 | coords[:, 1].to(voxel_features.dtype).unsqueeze( 124 | 1) * self.voxel_z + self.z_offset) 125 | 126 | if self.use_absolute_xyz: 127 | features = [voxel_features, f_cluster, f_center] 128 | else: 129 | features = [voxel_features[..., 3:], f_cluster, f_center] 130 | 131 | if self.with_distance: 132 | points_dist = torch.norm(voxel_features[:, :, :3], 2, 2, 133 | keepdim=True) 134 | features.append(points_dist) 135 | features = torch.cat(features, dim=-1) 136 | 137 | voxel_count = features.shape[1] 138 | mask = self.get_paddings_indicator(voxel_num_points, voxel_count, 139 | axis=0) 140 | mask = torch.unsqueeze(mask, -1).type_as(voxel_features) 141 | features *= mask 142 | for pfn in self.pfn_layers: 143 | features = pfn(features) 144 | features = features.squeeze() 145 | batch_dict['pillar_features'] = features 146 | return batch_dict 147 | -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/point_pillar_scatter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class PointPillarScatter(nn.Module): 6 | def __init__(self, model_cfg): 7 | super().__init__() 8 | 9 | self.model_cfg = model_cfg 10 | self.num_bev_features = self.model_cfg['num_features'] 11 | self.nx, self.ny, self.nz = model_cfg['grid_size'] 12 | assert self.nz == 1 13 | 14 | def forward(self, batch_dict): 15 | pillar_features, coords = batch_dict['pillar_features'], batch_dict[ 16 | 'voxel_coords'] 17 | batch_spatial_features = [] 18 | batch_size = coords[:, 0].max().int().item() + 1 19 | 20 | for batch_idx in range(batch_size): 21 | spatial_feature = torch.zeros( 22 | self.num_bev_features, 23 | self.nz * self.nx * self.ny, 24 | dtype=pillar_features.dtype, 25 | device=pillar_features.device) 26 | 27 | 
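            # voxel_coords rows follow the [batch_idx, z, y, x] convention used
            # by PillarVFE above; with nz == 1 the flattened index computed
            # below is effectively y * nx + x, i.e. each pillar feature is
            # scattered into its (y, x) cell of the dense BEV canvas.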
batch_mask = coords[:, 0] == batch_idx 28 | this_coords = coords[batch_mask, :] 29 | 30 | indices = this_coords[:, 1] + \ 31 | this_coords[:, 2] * self.nx + \ 32 | this_coords[:, 3] 33 | indices = indices.type(torch.long) 34 | 35 | pillars = pillar_features[batch_mask, :] 36 | pillars = pillars.t() 37 | spatial_feature[:, indices] = pillars 38 | batch_spatial_features.append(spatial_feature) 39 | 40 | batch_spatial_features = \ 41 | torch.stack(batch_spatial_features, 0) 42 | batch_spatial_features = \ 43 | batch_spatial_features.view(batch_size, self.num_bev_features * 44 | self.nz, self.ny, self.nx) 45 | batch_dict['spatial_features'] = batch_spatial_features 46 | 47 | return batch_dict 48 | 49 | -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/self_attn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class ScaledDotProductAttention(nn.Module): 8 | """ 9 | Scaled Dot-Product Attention proposed in "Attention Is All You Need" 10 | Compute the dot products of the query with all keys, divide each by sqrt(dim), 11 | and apply a softmax function to obtain the weights on the values 12 | Args: dim, mask 13 | dim (int): dimention of attention 14 | mask (torch.Tensor): tensor containing indices to be masked 15 | Inputs: query, key, value, mask 16 | - **query** (batch, q_len, d_model): tensor containing projection vector for decoder. 17 | - **key** (batch, k_len, d_model): tensor containing projection vector for encoder. 18 | - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence. 19 | - **mask** (-): tensor containing indices to be masked 20 | Returns: context, attn 21 | - **context**: tensor containing the context vector from attention mechanism. 22 | - **attn**: tensor containing the attention (alignment) from the encoder outputs. 
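    Note: AttFusion below treats, at every spatial location, the per-CAV
    feature vectors as the query/key/value sequence and keeps only the first
    CAV's attended output (the ego vehicle).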
23 | """ 24 | 25 | def __init__(self, dim): 26 | super(ScaledDotProductAttention, self).__init__() 27 | self.sqrt_dim = np.sqrt(dim) 28 | 29 | def forward(self, query, key, value): 30 | score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim 31 | attn = F.softmax(score, -1) 32 | context = torch.bmm(attn, value) 33 | return context 34 | 35 | 36 | class AttFusion(nn.Module): 37 | def __init__(self, feature_dim): 38 | super(AttFusion, self).__init__() 39 | self.att = ScaledDotProductAttention(feature_dim) 40 | 41 | def forward(self, x, record_len): 42 | split_x = self.regroup(x, record_len) 43 | batch_size = len(record_len) 44 | C, W, H = split_x[0].shape[1:] 45 | out = [] 46 | for xx in split_x: 47 | cav_num = xx.shape[0] 48 | xx = xx.view(cav_num, C, -1).permute(2, 0, 1) 49 | h = self.att(xx, xx, xx) 50 | h = h.permute(1, 2, 0).view(cav_num, C, W, H)[0, ...].unsqueeze(0) 51 | out.append(h) 52 | return torch.cat(out, dim=0) 53 | 54 | def regroup(self, x, record_len): 55 | cum_sum_len = torch.cumsum(record_len, dim=0) 56 | split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu()) 57 | return split_x 58 | -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/split_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class RadixSoftmax(nn.Module): 7 | def __init__(self, radix, cardinality): 8 | super(RadixSoftmax, self).__init__() 9 | self.radix = radix 10 | self.cardinality = cardinality 11 | 12 | def forward(self, x): 13 | # x: (B, L, 1, 1, 3C) 14 | batch = x.size(0) 15 | cav_num = x.size(1) 16 | 17 | if self.radix > 1: 18 | # x: (B, L, 1, 3, C) 19 | x = x.view(batch, 20 | cav_num, 21 | self.cardinality, self.radix, -1) 22 | x = F.softmax(x, dim=3) 23 | # B, 3LC 24 | x = x.reshape(batch, -1) 25 | else: 26 | x = torch.sigmoid(x) 27 | return x 28 | 29 | 30 | class SplitAttn(nn.Module): 31 | def __init__(self, input_dim): 32 | super(SplitAttn, self).__init__() 33 | self.input_dim = input_dim 34 | 35 | self.fc1 = nn.Linear(input_dim, input_dim, bias=False) 36 | self.bn1 = nn.LayerNorm(input_dim) 37 | self.act1 = nn.ReLU() 38 | self.fc2 = nn.Linear(input_dim, input_dim * 3, bias=False) 39 | 40 | self.rsoftmax = RadixSoftmax(3, 1) 41 | 42 | def forward(self, window_list): 43 | # window list: [(B, L, H, W, C) * 3] 44 | assert len(window_list) == 3, 'only 3 windows are supported' 45 | 46 | sw, mw, bw = window_list[0], window_list[1], window_list[2] 47 | B, L = sw.shape[0], sw.shape[1] 48 | 49 | # global average pooling, B, L, H, W, C 50 | x_gap = sw + mw + bw 51 | # B, L, 1, 1, C 52 | x_gap = x_gap.mean((2, 3), keepdim=True) 53 | x_gap = self.act1(self.bn1(self.fc1(x_gap))) 54 | # B, L, 1, 1, 3C 55 | x_attn = self.fc2(x_gap) 56 | # B L 1 1 3C 57 | x_attn = self.rsoftmax(x_attn).view(B, L, 1, 1, -1) 58 | 59 | out = sw * x_attn[:, :, :, :, 0:self.input_dim] + \ 60 | mw * x_attn[:, :, :, :, self.input_dim:2*self.input_dim] +\ 61 | bw * x_attn[:, :, :, :, self.input_dim*2:] 62 | 63 | return out 64 | -------------------------------------------------------------------------------- /v2xvit/models/sub_modules/v2v_fuse.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of V2VNet Fusion 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from v2xvit.models.sub_modules.torch_transformation_utils import \ 9 | get_discretized_transformation_matrix, 
get_transformation_matrix, \ 10 | warp_affine, get_rotated_roi 11 | from v2xvit.models.sub_modules.convgru import ConvGRU 12 | 13 | 14 | class V2VNetFusion(nn.Module): 15 | def __init__(self, args): 16 | super(V2VNetFusion, self).__init__() 17 | in_channels = args['in_channels'] 18 | H, W = args['conv_gru']['H'], args['conv_gru']['W'] 19 | kernel_size = args['conv_gru']['kernel_size'] 20 | num_gru_layers = args['conv_gru']['num_layers'] 21 | 22 | self.use_temporal_encoding = args['use_temporal_encoding'] 23 | self.discrete_ratio = args['voxel_size'][0] 24 | self.downsample_rate = args['downsample_rate'] 25 | self.num_iteration = args['num_iteration'] 26 | self.gru_flag = args['gru_flag'] 27 | self.agg_operator = args['agg_operator'] 28 | 29 | self.cnn = nn.Conv2d(in_channels + 1, in_channels, kernel_size=3, 30 | stride=1, padding=1) 31 | self.msg_cnn = nn.Conv2d(in_channels * 2, in_channels, kernel_size=3, 32 | stride=1, padding=1) 33 | self.conv_gru = ConvGRU(input_size=(H, W), 34 | input_dim=in_channels * 2, 35 | hidden_dim=[in_channels], 36 | kernel_size=kernel_size, 37 | num_layers=num_gru_layers, 38 | batch_first=True, 39 | bias=True, 40 | return_all_layers=False) 41 | self.mlp = nn.Linear(in_channels, in_channels) 42 | 43 | def regroup(self, x, record_len): 44 | cum_sum_len = torch.cumsum(record_len, dim=0) 45 | split_x = torch.tensor_split(x, cum_sum_len[:-1].cpu()) 46 | return split_x 47 | 48 | def forward(self, x, record_len, pairwise_t_matrix, prior_encoding): 49 | # x: (B,C,H,W) 50 | # record_len: (B) 51 | # pairwise_t_matrix: (B,L,L,4,4) 52 | # prior_encoding: (B,3) 53 | _, C, H, W = x.shape 54 | B, L = pairwise_t_matrix.shape[:2] 55 | 56 | if self.use_temporal_encoding: 57 | # (B,1,1,1) 58 | dt = prior_encoding[:, 1].to(torch.int).unsqueeze(1).unsqueeze( 59 | 2).unsqueeze(3) 60 | x = torch.cat([x, dt.repeat(1, 1, H, W)], dim=1) 61 | x = self.cnn(x) 62 | 63 | # split x:[(L1, C, H, W), (L2, C, H, W)] 64 | split_x = self.regroup(x, record_len) 65 | # (B,L,L,2,3) 66 | pairwise_t_matrix = get_discretized_transformation_matrix( 67 | pairwise_t_matrix.reshape(-1, L, 4, 4), self.discrete_ratio, 68 | self.downsample_rate).reshape(B, L, L, 2, 3) 69 | # (B*L,L,1,H,W) 70 | roi_mask = get_rotated_roi((B * L, L, 1, H, W), 71 | pairwise_t_matrix.reshape(B * L * L, 2, 3)) 72 | roi_mask = roi_mask.reshape(B, L, L, 1, H, W) 73 | 74 | batch_node_features = split_x 75 | # iteratively update the features for num_iteration times 76 | for l in range(self.num_iteration): 77 | 78 | batch_updated_node_features = [] 79 | # iterate each batch 80 | for b in range(B): 81 | 82 | # number of valid agent 83 | N = record_len[b] 84 | # (N,N,4,4) 85 | # t_matrix[i, j]-> from i to j 86 | t_matrix = pairwise_t_matrix[b][:N, :N, :, :] 87 | updated_node_features = [] 88 | # update each node i 89 | for i in range(N): 90 | # (N,1,H,W) 91 | mask = roi_mask[b, :N, i, ...] 
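                    # The steps below warp every neighbor's feature map into
                    # node i's frame, pair it with node i's own feature, turn
                    # the pairs into messages via msg_cnn (masked by the
                    # rotated ROI), aggregate them by mean/max, and update the
                    # node state with the ConvGRU (or a plain residual sum
                    # when gru_flag is False).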
92 | 93 | current_t_matrix = t_matrix[:, i, :, :] 94 | current_t_matrix = get_transformation_matrix( 95 | current_t_matrix, (H, W)) 96 | 97 | # (N,C,H,W) 98 | neighbor_feature = warp_affine(batch_node_features[b], 99 | current_t_matrix, 100 | (H, W)) 101 | # (N,C,H,W) 102 | ego_agent_feature = batch_node_features[b][i].unsqueeze( 103 | 0).repeat(N, 1, 1, 1) 104 | #(N,2C,H,W) 105 | neighbor_feature = torch.cat( 106 | [neighbor_feature, ego_agent_feature], dim=1) 107 | # (N,C,H,W) 108 | message = self.msg_cnn(neighbor_feature) * mask 109 | 110 | # (C,H,W) 111 | if self.agg_operator=="avg": 112 | agg_feature = torch.mean(message, dim=0) 113 | elif self.agg_operator=="max": 114 | agg_feature = torch.max(message, dim=0)[0] 115 | else: 116 | raise ValueError("agg_operator has wrong value") 117 | # (2C, H, W) 118 | cat_feature = torch.cat( 119 | [batch_node_features[b][i, ...], agg_feature], dim=0) 120 | # (C,H,W) 121 | if self.gru_flag: 122 | gru_out = \ 123 | self.conv_gru(cat_feature.unsqueeze(0).unsqueeze(0))[ 124 | 0][ 125 | 0].squeeze(0).squeeze(0) 126 | else: 127 | gru_out = batch_node_features[b][i, ...] + agg_feature 128 | updated_node_features.append(gru_out.unsqueeze(0)) 129 | # (N,C,H,W) 130 | batch_updated_node_features.append( 131 | torch.cat(updated_node_features, dim=0)) 132 | batch_node_features = batch_updated_node_features 133 | # (B,C,H,W) 134 | out = torch.cat( 135 | [itm[0, ...].unsqueeze(0) for itm in batch_node_features], dim=0) 136 | # (B,C,H,W) 137 | out = self.mlp(out.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 138 | 139 | return out 140 | -------------------------------------------------------------------------------- /v2xvit/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/tools/__init__.py -------------------------------------------------------------------------------- /v2xvit/tools/debug_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | import v2xvit.hypes_yaml.yaml_utils as yaml_utils 7 | from v2xvit.tools import train_utils 8 | from v2xvit.data_utils.datasets import build_dataset 9 | from v2xvit.visualization import vis_utils 10 | 11 | 12 | def test_parser(): 13 | parser = argparse.ArgumentParser(description="synthetic data generation") 14 | parser.add_argument('--model_dir', type=str, required=True, 15 | help='Continued training path') 16 | parser.add_argument('--fusion_method', type=str, default='late', 17 | help='late, early or intermediate') 18 | opt = parser.parse_args() 19 | return opt 20 | 21 | 22 | def test_bev_post_processing(): 23 | opt = test_parser() 24 | assert opt.fusion_method in ['late', 'early', 'intermediate'] 25 | 26 | hypes = yaml_utils.load_yaml(None, opt) 27 | 28 | print('Dataset Building') 29 | opencood_dataset = build_dataset(hypes, visualize=True, train=False) 30 | data_loader = DataLoader(opencood_dataset, 31 | batch_size=1, 32 | num_workers=0, 33 | collate_fn=opencood_dataset.collate_batch_test, 34 | shuffle=False, 35 | pin_memory=False, 36 | drop_last=False) 37 | 38 | print('Creating Model') 39 | model = train_utils.create_model(hypes) 40 | # we assume gpu is necessary 41 | if torch.cuda.is_available(): 42 | model.cuda() 43 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 44 | 45 | print('Loading Model from checkpoint') 46 | saved_path 
= opt.model_dir 47 | _, model = train_utils.load_saved_model(saved_path, model) 48 | model.eval() 49 | for i, batch_data in enumerate(data_loader): 50 | batch_data = train_utils.to_device(batch_data, device) 51 | label_map = batch_data["ego"]["label_dict"]["label_map"] 52 | output_dict = { 53 | "cls": label_map[:, 0, :, :], 54 | "reg": label_map[:, 1:, :, :] 55 | } 56 | gt_box_tensor, _ = opencood_dataset.post_processor.post_process_debug( 57 | batch_data["ego"], output_dict) 58 | vis_utils.visualize_single_sample_output_bev(gt_box_tensor, 59 | batch_data['ego'][ 60 | 'origin_lidar'].squeeze( 61 | 0), 62 | opencood_dataset) 63 | 64 | 65 | if __name__ == '__main__': 66 | test_bev_post_processing() 67 | -------------------------------------------------------------------------------- /v2xvit/tools/infrence_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from v2xvit.utils.common_utils import torch_tensor_to_numpy 8 | 9 | 10 | def inference_late_fusion(batch_data, model, dataset): 11 | """ 12 | Model inference for late fusion. 13 | 14 | Parameters 15 | ---------- 16 | batch_data : dict 17 | model : opencood.object 18 | dataset : opencood.LateFusionDataset 19 | 20 | Returns 21 | ------- 22 | pred_box_tensor : torch.Tensor 23 | The tensor of prediction bounding box after NMS. 24 | gt_box_tensor : torch.Tensor 25 | The tensor of gt bounding box. 26 | """ 27 | output_dict = OrderedDict() 28 | 29 | for cav_id, cav_content in batch_data.items(): 30 | output_dict[cav_id] = model(cav_content) 31 | 32 | pred_box_tensor, pred_score, gt_box_tensor = \ 33 | dataset.post_process(batch_data, 34 | output_dict) 35 | 36 | return pred_box_tensor, pred_score, gt_box_tensor 37 | 38 | 39 | def inference_early_fusion(batch_data, model, dataset): 40 | """ 41 | Model inference for early fusion. 42 | 43 | Parameters 44 | ---------- 45 | batch_data : dict 46 | model : opencood.object 47 | dataset : opencood.EarlyFusionDataset 48 | 49 | Returns 50 | ------- 51 | pred_box_tensor : torch.Tensor 52 | The tensor of prediction bounding box after NMS. 53 | gt_box_tensor : torch.Tensor 54 | The tensor of gt bounding box. 55 | """ 56 | output_dict = OrderedDict() 57 | cav_content = batch_data['ego'] 58 | 59 | output_dict['ego'] = model(cav_content) 60 | 61 | pred_box_tensor, pred_score, gt_box_tensor = \ 62 | dataset.post_process(batch_data, 63 | output_dict) 64 | 65 | return pred_box_tensor, pred_score, gt_box_tensor 66 | 67 | 68 | def inference_intermediate_fusion(batch_data, model, dataset): 69 | """ 70 | Model inference for early fusion. 71 | 72 | Parameters 73 | ---------- 74 | batch_data : dict 75 | model : opencood.object 76 | dataset : opencood.EarlyFusionDataset 77 | 78 | Returns 79 | ------- 80 | pred_box_tensor : torch.Tensor 81 | The tensor of prediction bounding box after NMS. 82 | gt_box_tensor : torch.Tensor 83 | The tensor of gt bounding box. 84 | """ 85 | return inference_early_fusion(batch_data, model, dataset) 86 | 87 | 88 | def save_prediction_gt(pred_tensor, gt_tensor, pcd, timestamp, save_path): 89 | """ 90 | Save prediction and gt tensor to txt file. 
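    The arrays are stored as NumPy files via np.save, named '%04d_pcd.npy',
    '%04d_pred.npy' and '%04d_gt.npy' under save_path.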
91 | """ 92 | pred_np = torch_tensor_to_numpy(pred_tensor) 93 | gt_np = torch_tensor_to_numpy(gt_tensor) 94 | pcd_np = torch_tensor_to_numpy(pcd) 95 | 96 | np.save(os.path.join(save_path, '%04d_pcd.npy' % timestamp), pcd_np) 97 | np.save(os.path.join(save_path, '%04d_pred.npy' % timestamp), pred_np) 98 | np.save(os.path.join(save_path, '%04d_gt.npy' % timestamp), gt_np) 99 | -------------------------------------------------------------------------------- /v2xvit/tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import statistics 4 | 5 | import torch 6 | import tqdm 7 | from torch.utils.data import DataLoader 8 | from tensorboardX import SummaryWriter 9 | 10 | import v2xvit.hypes_yaml.yaml_utils as yaml_utils 11 | from v2xvit.tools import train_utils 12 | from v2xvit.data_utils.datasets import build_dataset 13 | 14 | 15 | def train_parser(): 16 | parser = argparse.ArgumentParser(description="synthetic data generation") 17 | parser.add_argument("--hypes_yaml", type=str, required=True, 18 | help='data generation yaml file needed ') 19 | parser.add_argument('--model_dir', default='', 20 | help='Continued training path') 21 | parser.add_argument("--half", action='store_true', help="whether train with half precision") 22 | opt = parser.parse_args() 23 | return opt 24 | 25 | 26 | def main(): 27 | opt = train_parser() 28 | hypes = yaml_utils.load_yaml(opt.hypes_yaml, opt) 29 | 30 | print('Dataset Building') 31 | opencood_train_dataset = build_dataset(hypes, visualize=False, train=True) 32 | opencood_validate_dataset = build_dataset(hypes, 33 | visualize=False, 34 | train=False) 35 | 36 | train_loader = DataLoader(opencood_train_dataset, 37 | batch_size=hypes['train_params']['batch_size'], 38 | num_workers=8, 39 | collate_fn=opencood_train_dataset.collate_batch_train, 40 | shuffle=True, 41 | pin_memory=False, 42 | drop_last=True) 43 | val_loader = DataLoader(opencood_validate_dataset, 44 | batch_size=hypes['train_params']['batch_size'], 45 | num_workers=8, 46 | collate_fn=opencood_train_dataset.collate_batch_train, 47 | shuffle=False, 48 | pin_memory=False, 49 | drop_last=True) 50 | 51 | print('Creating Model') 52 | model = train_utils.create_model(hypes) 53 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 54 | 55 | # we assume gpu is necessary 56 | if torch.cuda.is_available(): 57 | model.to(device) 58 | 59 | # define the loss 60 | criterion = train_utils.create_loss(hypes) 61 | 62 | # optimizer setup 63 | optimizer = train_utils.setup_optimizer(hypes, model) 64 | # lr scheduler setup 65 | scheduler = train_utils.setup_lr_schedular(hypes, optimizer) 66 | 67 | # if we want to train from last checkpoint. 
68 | if opt.model_dir: 69 | saved_path = opt.model_dir 70 | init_epoch, model = train_utils.load_saved_model(saved_path, model) 71 | 72 | else: 73 | init_epoch = 0 74 | # if we train the model from scratch, we need to create a folder 75 | # to save the model, 76 | saved_path = train_utils.setup_train(hypes) 77 | 78 | # record training 79 | writer = SummaryWriter(saved_path) 80 | 81 | # half precision training 82 | if opt.half: 83 | scaler = torch.cuda.amp.GradScaler() 84 | 85 | print('Training start') 86 | epoches = hypes['train_params']['epoches'] 87 | # used to help schedule learning rate 88 | for epoch in range(init_epoch, max(epoches, init_epoch)): 89 | scheduler.step(epoch) 90 | for param_group in optimizer.param_groups: 91 | print('learning rate %f' % param_group["lr"]) 92 | pbar2 = tqdm.tqdm(total=len(train_loader), leave=True) 93 | for i, batch_data in enumerate(train_loader): 94 | # the model will be evaluation mode during validation 95 | model.train() 96 | model.zero_grad() 97 | optimizer.zero_grad() 98 | 99 | batch_data = train_utils.to_device(batch_data, device) 100 | 101 | # case1 : late fusion train --> only ego needed 102 | # case2 : early fusion train --> all data projected to ego 103 | # case3 : intermediate fusion --> ['ego']['processed_lidar'] 104 | # becomes a list, which containing all data from other cavs 105 | # as well 106 | if not opt.half: 107 | ouput_dict = model(batch_data['ego']) 108 | # first argument is always your output dictionary, 109 | # second argument is always your label dictionary. 110 | final_loss = criterion(ouput_dict, batch_data['ego']['label_dict']) 111 | else: 112 | with torch.cuda.amp.autocast(): 113 | ouput_dict = model(batch_data['ego']) 114 | final_loss = criterion(ouput_dict, batch_data['ego']['label_dict']) 115 | 116 | criterion.logging(epoch, i, len(train_loader), writer, pbar=pbar2) 117 | pbar2.update(1) 118 | # back-propagation 119 | if not opt.half: 120 | final_loss.backward() 121 | optimizer.step() 122 | else: 123 | scaler.scale(final_loss).backward() 124 | scaler.step(optimizer) 125 | scaler.update() 126 | if epoch % hypes['train_params']['eval_freq'] == 0: 127 | valid_ave_loss = [] 128 | 129 | with torch.no_grad(): 130 | for i, batch_data in enumerate(val_loader): 131 | model.eval() 132 | 133 | batch_data = train_utils.to_device(batch_data, device) 134 | ouput_dict = model(batch_data['ego']) 135 | 136 | final_loss = criterion(ouput_dict, 137 | batch_data['ego']['label_dict']) 138 | valid_ave_loss.append(final_loss.item()) 139 | valid_ave_loss = statistics.mean(valid_ave_loss) 140 | print('At epoch %d, the validation loss is %f' % (epoch, 141 | valid_ave_loss)) 142 | 143 | writer.add_scalar('Validate_Loss', valid_ave_loss, epoch) 144 | 145 | if epoch % hypes['train_params']['save_freq'] == 0: 146 | torch.save(model.state_dict(), 147 | os.path.join(saved_path, 148 | 'net_epoch%d.pth' % (epoch + 1))) 149 | 150 | print('Training Finished, checkpoints saved to %s' % saved_path) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /v2xvit/tools/train_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import importlib 3 | import yaml 4 | import os 5 | import re 6 | from datetime import datetime 7 | 8 | import torch 9 | import torch.optim as optim 10 | 11 | 12 | def load_saved_model(saved_path, model): 13 | """ 14 | Load saved model if exiseted 15 | 16 | Parameters 17 | __________ 18 | 
saved_path : str 19 | model saved path 20 | model : opencood object 21 | The model instance. 22 | 23 | Returns 24 | ------- 25 | model : opencood object 26 | The model instance loaded pretrained params. 27 | """ 28 | assert os.path.exists(saved_path), '{} not found'.format(saved_path) 29 | 30 | def findLastCheckpoint(save_dir): 31 | file_list = glob.glob(os.path.join(save_dir, '*epoch*.pth')) 32 | if file_list: 33 | epochs_exist = [] 34 | for file_ in file_list: 35 | result = re.findall(".*epoch(.*).pth.*", file_) 36 | epochs_exist.append(int(result[0])) 37 | initial_epoch_ = max(epochs_exist) 38 | else: 39 | initial_epoch_ = 0 40 | return initial_epoch_ 41 | 42 | initial_epoch = findLastCheckpoint(saved_path) 43 | if initial_epoch > 0: 44 | print('resuming by loading epoch %d' % initial_epoch) 45 | model.load_state_dict(torch.load( 46 | os.path.join(saved_path, 47 | 'net_epoch%d.pth' % initial_epoch)), strict=False) 48 | 49 | return initial_epoch, model 50 | 51 | 52 | def setup_train(hypes): 53 | """ 54 | Create folder for saved model based on current timestep and model name 55 | 56 | Parameters 57 | ---------- 58 | hypes: dict 59 | Config yaml dictionary for training: 60 | """ 61 | model_name = hypes['name'] 62 | current_time = datetime.now() 63 | 64 | folder_name = current_time.strftime("_%Y_%m_%d_%H_%M_%S") 65 | folder_name = model_name + folder_name 66 | 67 | current_path = os.path.dirname(__file__) 68 | current_path = os.path.join(current_path, '../logs') 69 | 70 | full_path = os.path.join(current_path, folder_name) 71 | 72 | if not os.path.exists(full_path): 73 | os.makedirs(full_path) 74 | # save the yaml file 75 | save_name = os.path.join(full_path, 'config.yaml') 76 | with open(save_name, 'w') as outfile: 77 | yaml.dump(hypes, outfile) 78 | 79 | return full_path 80 | 81 | 82 | def create_model(hypes): 83 | """ 84 | Import the module "models/[model_name].py 85 | 86 | Parameters 87 | __________ 88 | hypes : dict 89 | Dictionary containing parameters. 90 | 91 | Returns 92 | ------- 93 | model : opencood,object 94 | Model object. 95 | """ 96 | backbone_name = hypes['model']['core_method'] 97 | backbone_config = hypes['model']['args'] 98 | 99 | model_filename = "v2xvit.models." + backbone_name 100 | model_lib = importlib.import_module(model_filename) 101 | model = None 102 | target_model_name = backbone_name.replace('_', '') 103 | 104 | for name, cls in model_lib.__dict__.items(): 105 | if name.lower() == target_model_name.lower(): 106 | model = cls 107 | 108 | if model is None: 109 | print('backbone not found in models folder. Please make sure you ' 110 | 'have a python file named %s and has a class ' 111 | 'called %s ignoring upper/lower case' % (model_filename, 112 | target_model_name)) 113 | exit(0) 114 | instance = model(backbone_config) 115 | return instance 116 | 117 | 118 | def create_loss(hypes): 119 | """ 120 | Create the loss function based on the given loss name. 121 | 122 | Parameters 123 | ---------- 124 | hypes : dict 125 | Configuration params for training. 126 | Returns 127 | ------- 128 | criterion : opencood.object 129 | The loss function. 130 | """ 131 | loss_func_name = hypes['loss']['core_method'] 132 | loss_func_config = hypes['loss']['args'] 133 | 134 | loss_filename = "v2xvit.loss." 
+ loss_func_name
135 |     loss_lib = importlib.import_module(loss_filename)
136 |     loss_func = None
137 |     target_loss_name = loss_func_name.replace('_', '')
138 | 
139 |     for name, lfunc in loss_lib.__dict__.items():
140 |         if name.lower() == target_loss_name.lower():
141 |             loss_func = lfunc
142 | 
143 |     if loss_func is None:
144 |         print('loss function not found in loss folder. Please make sure you '
145 |               'have a python file named %s and has a class '
146 |               'called %s ignoring upper/lower case' % (loss_filename,
147 |                                                        target_loss_name))
148 |         exit(0)
149 | 
150 |     criterion = loss_func(loss_func_config)
151 |     return criterion
152 | 
153 | 
154 | def setup_optimizer(hypes, model):
155 |     """
156 |     Create the optimizer corresponding to the yaml file.
157 | 
158 |     Parameters
159 |     ----------
160 |     hypes : dict
161 |         The training configurations.
162 |     model : opencood model
163 |         The pytorch model.
164 |     """
165 |     method_dict = hypes['optimizer']
166 |     optimizer_method = getattr(optim, method_dict['core_method'], None)
167 |     if not optimizer_method:
168 |         raise ValueError('{} is not supported'.format(method_dict['core_method']))
169 |     if 'args' in method_dict:
170 |         return optimizer_method(filter(lambda p: p.requires_grad,
171 |                                        model.parameters()),
172 |                                 lr=method_dict['lr'],
173 |                                 **method_dict['args'])
174 |     else:
175 |         return optimizer_method(filter(lambda p: p.requires_grad,
176 |                                        model.parameters()),
177 |                                 lr=method_dict['lr'])
178 | 
179 | 
180 | def setup_lr_schedular(hypes, optimizer):
181 |     """
182 |     Set up the learning rate scheduler.
183 | 
184 |     Parameters
185 |     ----------
186 |     hypes : dict
187 |         The training configurations.
188 | 
189 |     optimizer : torch.optim.Optimizer
190 |     """
191 |     lr_schedule_config = hypes['lr_scheduler']
192 | 
193 |     if lr_schedule_config['core_method'] == 'step':
194 |         from torch.optim.lr_scheduler import StepLR
195 |         step_size = lr_schedule_config['step_size']
196 |         gamma = lr_schedule_config['gamma']
197 |         scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
198 | 
199 |     elif lr_schedule_config['core_method'] == 'multistep':
200 |         from torch.optim.lr_scheduler import MultiStepLR
201 |         milestones = lr_schedule_config['step_size']
202 |         gamma = lr_schedule_config['gamma']
203 |         scheduler = MultiStepLR(optimizer,
204 |                                 milestones=milestones,
205 |                                 gamma=gamma)
206 | 
207 |     else:
208 |         from torch.optim.lr_scheduler import ExponentialLR
209 |         gamma = lr_schedule_config['gamma']
210 |         scheduler = ExponentialLR(optimizer, gamma)
211 | 
212 |     return scheduler
213 | 
214 | 
215 | def to_device(inputs, device):
216 |     if isinstance(inputs, list):
217 |         return [to_device(x, device) for x in inputs]
218 |     elif isinstance(inputs, dict):
219 |         return {k: to_device(v, device) for k, v in inputs.items()}
220 |     else:
221 |         if isinstance(inputs, int) or isinstance(inputs, float) \
222 |                 or isinstance(inputs, str):
223 |             return inputs
224 |         return inputs.to(device)
225 | 
--------------------------------------------------------------------------------
/v2xvit/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/utils/__init__.py
--------------------------------------------------------------------------------
/v2xvit/utils/box_overlaps.pyx:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fast R-CNN
3 | # Copyright (c) 2015 Microsoft
4 | # Licensed under The 
MIT License [see LICENSE for details] 5 | # Written by Sergey Karayev 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | cimport numpy as np 10 | from cython.parallel import prange, parallel 11 | 12 | 13 | DTYPE = np.float32 14 | ctypedef float DTYPE_t 15 | 16 | 17 | def bbox_overlaps( 18 | np.ndarray[DTYPE_t, ndim=2] boxes, 19 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 20 | """ 21 | Parameters 22 | ---------- 23 | boxes: (N, 4) ndarray of float 24 | query_boxes: (K, 4) ndarray of float 25 | Returns 26 | ------- 27 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 28 | """ 29 | cdef unsigned int N = boxes.shape[0] 30 | cdef unsigned int K = query_boxes.shape[0] 31 | cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE) 32 | cdef DTYPE_t iw, ih, box_area 33 | cdef DTYPE_t ua 34 | cdef unsigned int k, n 35 | for k in range(K): 36 | box_area = ( 37 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 38 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 39 | ) 40 | for n in range(N): 41 | iw = ( 42 | min(boxes[n, 2], query_boxes[k, 2]) - 43 | max(boxes[n, 0], query_boxes[k, 0]) + 1 44 | ) 45 | if iw > 0: 46 | ih = ( 47 | min(boxes[n, 3], query_boxes[k, 3]) - 48 | max(boxes[n, 1], query_boxes[k, 1]) + 1 49 | ) 50 | if ih > 0: 51 | ua = float( 52 | (boxes[n, 2] - boxes[n, 0] + 1) * 53 | (boxes[n, 3] - boxes[n, 1] + 1) + 54 | box_area - iw * ih 55 | ) 56 | overlaps[n, k] = iw * ih / ua 57 | return overlaps 58 | 59 | def bbox_intersections( 60 | np.ndarray[DTYPE_t, ndim=2] boxes, 61 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 62 | """ 63 | For each query box compute the intersection ratio covered by boxes 64 | ---------- 65 | Parameters 66 | ---------- 67 | boxes: (N, 4) ndarray of float 68 | query_boxes: (K, 4) ndarray of float 69 | Returns 70 | ------- 71 | overlaps: (N, K) ndarray of intersec between boxes and query_boxes 72 | """ 73 | cdef unsigned int N = boxes.shape[0] 74 | cdef unsigned int K = query_boxes.shape[0] 75 | cdef np.ndarray[DTYPE_t, ndim=2] intersec = np.zeros((N, K), dtype=DTYPE) 76 | cdef DTYPE_t iw, ih, box_area 77 | cdef DTYPE_t ua 78 | cdef unsigned int k, n 79 | for k in range(K): 80 | box_area = ( 81 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 82 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 83 | ) 84 | for n in range(N): 85 | iw = ( 86 | min(boxes[n, 2], query_boxes[k, 2]) - 87 | max(boxes[n, 0], query_boxes[k, 0]) + 1 88 | ) 89 | if iw > 0: 90 | ih = ( 91 | min(boxes[n, 3], query_boxes[k, 3]) - 92 | max(boxes[n, 1], query_boxes[k, 1]) + 1 93 | ) 94 | if ih > 0: 95 | intersec[n, k] = iw * ih / box_area 96 | return intersec 97 | 98 | # Compute bounding box voting 99 | def box_vote( 100 | np.ndarray[float, ndim=2] dets_NMS, 101 | np.ndarray[float, ndim=2] dets_all): 102 | cdef np.ndarray[float, ndim=2] dets_voted = np.zeros((dets_NMS.shape[0], dets_NMS.shape[1]), dtype=np.float32) 103 | cdef unsigned int N = dets_NMS.shape[0] 104 | cdef unsigned int M = dets_all.shape[0] 105 | 106 | cdef np.ndarray[float, ndim=1] det 107 | cdef np.ndarray[float, ndim=1] acc_box 108 | cdef float acc_score 109 | 110 | cdef np.ndarray[float, ndim=1] det2 111 | cdef float bi0, bi1, bit2, bi3 112 | cdef float iw, ih, ua 113 | 114 | cdef float thresh=0.5 115 | 116 | for i in range(N): 117 | det = dets_NMS[i, :] 118 | acc_box = np.zeros((4), dtype=np.float32) 119 | acc_score = 0.0 120 | 121 | for m in range(M): 122 | det2 = dets_all[m, :] 123 | 124 | bi0 = max(det[0], det2[0]) 125 | bi1 = max(det[1], det2[1]) 126 | bi2 = 
min(det[2], det2[2]) 127 | bi3 = min(det[3], det2[3]) 128 | 129 | iw = bi2 - bi0 + 1 130 | ih = bi3 - bi1 + 1 131 | 132 | if not (iw > 0 and ih > 0): 133 | continue 134 | 135 | ua = (det[2] - det[0] + 1) * (det[3] - det[1] + 1) + (det2[2] - det2[0] + 1) * (det2[3] - det2[1] + 1) - iw * ih 136 | ov = iw * ih / ua 137 | 138 | if (ov < thresh): 139 | continue 140 | 141 | acc_box += det2[4] * det2[0:4] 142 | acc_score += det2[4] 143 | 144 | dets_voted[i][0:4] = acc_box / acc_score 145 | dets_voted[i][4] = det[4] # Keep the original score 146 | 147 | return dets_voted 148 | -------------------------------------------------------------------------------- /v2xvit/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common utilities 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | from shapely.geometry import Polygon 8 | 9 | 10 | def check_numpy_to_torch(x): 11 | if isinstance(x, np.ndarray): 12 | return torch.from_numpy(x).float(), True 13 | return x, False 14 | 15 | 16 | def check_contain_nan(x): 17 | if isinstance(x, dict): 18 | return any(check_contain_nan(v) for k, v in x.items()) 19 | if isinstance(x, list): 20 | return any(check_contain_nan(itm) for itm in x) 21 | if isinstance(x, int) or isinstance(x, float): 22 | return False 23 | if isinstance(x, np.ndarray): 24 | return np.any(np.isnan(x)) 25 | return torch.any(x.isnan()).detach().cpu().item() 26 | 27 | 28 | def rotate_points_along_z(points, angle): 29 | """ 30 | Args: 31 | points: (B, N, 3 + C) 32 | angle: (B), radians, angle along z-axis, angle increases x ==> y 33 | Returns: 34 | 35 | """ 36 | points, is_numpy = check_numpy_to_torch(points) 37 | angle, _ = check_numpy_to_torch(angle) 38 | 39 | cosa = torch.cos(angle) 40 | sina = torch.sin(angle) 41 | zeros = angle.new_zeros(points.shape[0]) 42 | ones = angle.new_ones(points.shape[0]) 43 | rot_matrix = torch.stack(( 44 | cosa, sina, zeros, 45 | -sina, cosa, zeros, 46 | zeros, zeros, ones 47 | ), dim=1).view(-1, 3, 3).float() 48 | points_rot = torch.matmul(points[:, :, 0:3].float(), rot_matrix) 49 | points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1) 50 | return points_rot.numpy() if is_numpy else points_rot 51 | 52 | 53 | def rotate_points_along_z_2d(points, angle): 54 | """ 55 | Rorate the points along z-axis. 56 | Parameters 57 | ---------- 58 | points : torch.Tensor / np.ndarray 59 | (N, 2). 60 | angle : torch.Tensor / np.ndarray 61 | (N,) 62 | 63 | Returns 64 | ------- 65 | points_rot : torch.Tensor / np.ndarray 66 | Rorated points with shape (N, 2) 67 | 68 | """ 69 | points, is_numpy = check_numpy_to_torch(points) 70 | angle, _ = check_numpy_to_torch(angle) 71 | cosa = torch.cos(angle) 72 | sina = torch.sin(angle) 73 | # (N, 2, 2) 74 | rot_matrix = torch.stack((cosa, sina, -sina, cosa), dim=1).view(-1, 2, 75 | 2).float() 76 | points_rot = torch.einsum("ik, ikj->ij", points.float(), rot_matrix) 77 | return points_rot.numpy() if is_numpy else points_rot 78 | 79 | 80 | def remove_ego_from_objects(objects, ego_id): 81 | """ 82 | Avoid adding ego vehicle to the object dictionary. 83 | 84 | Parameters 85 | ---------- 86 | objects : dict 87 | The dictionary contained all objects. 88 | 89 | ego_id : int 90 | Ego id. 91 | """ 92 | if ego_id in objects: 93 | del objects[ego_id] 94 | 95 | 96 | def retrieve_ego_id(base_data_dict): 97 | """ 98 | Retrieve the ego vehicle id from sample(origin format). 99 | 100 | Parameters 101 | ---------- 102 | base_data_dict : dict 103 | Data sample in origin format. 
104 | 105 | Returns 106 | ------- 107 | ego_id : str 108 | The id of ego vehicle. 109 | """ 110 | ego_id = None 111 | 112 | for cav_id, cav_content in base_data_dict.items(): 113 | if cav_content['ego']: 114 | ego_id = cav_id 115 | break 116 | return ego_id 117 | 118 | 119 | def compute_iou(box, boxes): 120 | """ 121 | Compute iou between box and boxes list 122 | Parameters 123 | ---------- 124 | box : shapely.geometry.Polygon 125 | Bounding box Polygon. 126 | 127 | boxes : list 128 | List of shapely.geometry.Polygon. 129 | 130 | Returns 131 | ------- 132 | iou : np.ndarray 133 | Array of iou between box and boxes. 134 | 135 | """ 136 | # Calculate intersection areas 137 | iou = [box.intersection(b).area / box.union(b).area for b in boxes] 138 | 139 | return np.array(iou, dtype=np.float32) 140 | 141 | 142 | def convert_format(boxes_array): 143 | """ 144 | Convert boxes array to shapely.geometry.Polygon format. 145 | Parameters 146 | ---------- 147 | boxes_array : np.ndarray 148 | (N, 4, 2) or (N, 8, 3). 149 | 150 | Returns 151 | ------- 152 | list of converted shapely.geometry.Polygon object. 153 | 154 | """ 155 | polygons = [Polygon([(box[i, 0], box[i, 1]) for i in range(4)]) for box in 156 | boxes_array] 157 | return np.array(polygons) 158 | 159 | 160 | def torch_tensor_to_numpy(torch_tensor): 161 | """ 162 | Convert a torch tensor to numpy. 163 | 164 | Parameters 165 | ---------- 166 | torch_tensor : torch.Tensor 167 | 168 | Returns 169 | ------- 170 | A numpy array. 171 | """ 172 | return torch_tensor.numpy() if not torch_tensor.is_cuda else \ 173 | torch_tensor.cpu().detach().numpy() 174 | -------------------------------------------------------------------------------- /v2xvit/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from v2xvit.utils import common_utils 7 | from v2xvit.hypes_yaml import yaml_utils 8 | 9 | 10 | def voc_ap(rec, prec): 11 | """ 12 | VOC 2010 Average Precision. 13 | """ 14 | rec.insert(0, 0.0) 15 | rec.append(1.0) 16 | mrec = rec[:] 17 | 18 | prec.insert(0, 0.0) 19 | prec.append(0.0) 20 | mpre = prec[:] 21 | 22 | for i in range(len(mpre) - 2, -1, -1): 23 | mpre[i] = max(mpre[i], mpre[i + 1]) 24 | 25 | i_list = [] 26 | for i in range(1, len(mrec)): 27 | if mrec[i] != mrec[i - 1]: 28 | i_list.append(i) 29 | 30 | ap = 0.0 31 | for i in i_list: 32 | ap += ((mrec[i] - mrec[i - 1]) * mpre[i]) 33 | return ap, mrec, mpre 34 | 35 | 36 | def caluclate_tp_fp(det_boxes, det_score, gt_boxes, result_stat, iou_thresh): 37 | """ 38 | Calculate the true positive and false positive numbers of the current 39 | frames. 40 | 41 | Parameters 42 | ---------- 43 | det_boxes : torch.Tensor 44 | The detection bounding box, shape (N, 8, 3) or (N, 4, 2). 45 | det_score :torch.Tensor 46 | The confidence score for each preditect bounding box. 47 | gt_boxes : torch.Tensor 48 | The groundtruth bounding box. 49 | result_stat: dict 50 | A dictionary contains fp, tp and gt number. 51 | iou_thresh : float 52 | The iou thresh. 
53 | """ 54 | # fp, tp and gt in the current frame 55 | fp = [] 56 | tp = [] 57 | gt = gt_boxes.shape[0] 58 | if det_boxes is not None: 59 | # convert bounding boxes to numpy array 60 | det_boxes = common_utils.torch_tensor_to_numpy(det_boxes) 61 | det_score = common_utils.torch_tensor_to_numpy(det_score) 62 | gt_boxes = common_utils.torch_tensor_to_numpy(gt_boxes) 63 | 64 | # sort the prediction bounding box by score 65 | score_order_descend = np.argsort(-det_score) 66 | det_polygon_list = list(common_utils.convert_format(det_boxes)) 67 | gt_polygon_list = list(common_utils.convert_format(gt_boxes)) 68 | 69 | # match prediction and gt bounding box 70 | for i in range(score_order_descend.shape[0]): 71 | det_polygon = det_polygon_list[score_order_descend[i]] 72 | ious = common_utils.compute_iou(det_polygon, gt_polygon_list) 73 | 74 | if len(gt_polygon_list) == 0 or np.max(ious) < iou_thresh: 75 | fp.append(1) 76 | tp.append(0) 77 | continue 78 | 79 | fp.append(0) 80 | tp.append(1) 81 | 82 | gt_index = np.argmax(ious) 83 | gt_polygon_list.pop(gt_index) 84 | 85 | result_stat[iou_thresh]['fp'] += fp 86 | result_stat[iou_thresh]['tp'] += tp 87 | result_stat[iou_thresh]['gt'] += gt 88 | 89 | 90 | def calculate_ap(result_stat, iou): 91 | """ 92 | Calculate the average precision and recall, and save them into a txt. 93 | 94 | Parameters 95 | ---------- 96 | result_stat : dict 97 | A dictionary contains fp, tp and gt number. 98 | iou : float 99 | """ 100 | iou_5 = result_stat[iou] 101 | 102 | fp = iou_5['fp'] 103 | tp = iou_5['tp'] 104 | assert len(fp) == len(tp) 105 | 106 | gt_total = iou_5['gt'] 107 | 108 | cumsum = 0 109 | for idx, val in enumerate(fp): 110 | fp[idx] += cumsum 111 | cumsum += val 112 | 113 | cumsum = 0 114 | for idx, val in enumerate(tp): 115 | tp[idx] += cumsum 116 | cumsum += val 117 | 118 | rec = tp[:] 119 | for idx, val in enumerate(tp): 120 | rec[idx] = float(tp[idx]) / gt_total 121 | 122 | prec = tp[:] 123 | for idx, val in enumerate(tp): 124 | prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx]) 125 | 126 | ap, mrec, mprec = voc_ap(rec[:], prec[:]) 127 | 128 | return ap, mrec, mprec 129 | 130 | 131 | def eval_final_results(result_stat, save_path): 132 | dump_dict = {} 133 | 134 | ap_30, mrec_30, mpre_30 = calculate_ap(result_stat, 0.30) 135 | ap_50, mrec_50, mpre_50 = calculate_ap(result_stat, 0.50) 136 | ap_70, mrec_70, mpre_70 = calculate_ap(result_stat, 0.70) 137 | 138 | dump_dict.update({'ap30': ap_30, 139 | 'ap_50': ap_50, 140 | 'ap_70': ap_70, 141 | 'mpre_50': mpre_50, 142 | 'mrec_50': mrec_50, 143 | 'mpre_70': mpre_70, 144 | 'mrec_70': mrec_70, 145 | }) 146 | yaml_utils.save_yaml(dump_dict, os.path.join(save_path, 'eval.yaml')) 147 | 148 | print('The Average Precision at IOU 0.3 is %.2f, ' 149 | 'The Average Precision at IOU 0.5 is %.2f, ' 150 | 'The Average Precision at IOU 0.7 is %.2f' % (ap_30, ap_50, ap_70)) 151 | -------------------------------------------------------------------------------- /v2xvit/utils/pcd_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions related to point cloud 3 | """ 4 | 5 | import open3d as o3d 6 | import numpy as np 7 | 8 | 9 | def pcd_to_np(pcd_file): 10 | """ 11 | Read pcd and return numpy array. 12 | 13 | Parameters 14 | ---------- 15 | pcd_file : str 16 | The pcd file that contains the point cloud. 
17 | 18 | Returns 19 | ------- 20 | pcd : o3d.PointCloud 21 | PointCloud object, used for visualization 22 | pcd_np : np.ndarray 23 | The lidar data in numpy format, shape:(n, 4) 24 | 25 | """ 26 | pcd = o3d.io.read_point_cloud(pcd_file) 27 | 28 | xyz = np.asarray(pcd.points) 29 | # we save the intensity in the first channel 30 | intensity = np.expand_dims(np.asarray(pcd.colors)[:, 0], -1) 31 | pcd_np = np.hstack((xyz, intensity)) 32 | 33 | return np.asarray(pcd_np, dtype=np.float32) 34 | 35 | 36 | def mask_points_by_range(points, limit_range): 37 | """ 38 | Remove the lidar points out of the boundary. 39 | 40 | Parameters 41 | ---------- 42 | points : np.ndarray 43 | Lidar points under lidar sensor coordinate system. 44 | 45 | limit_range : list 46 | [x_min, y_min, z_min, x_max, y_max, z_max] 47 | 48 | Returns 49 | ------- 50 | points : np.ndarray 51 | Filtered lidar points. 52 | """ 53 | 54 | mask = (points[:, 0] > limit_range[0]) & (points[:, 0] < limit_range[3])\ 55 | & (points[:, 1] > limit_range[1]) & ( 56 | points[:, 1] < limit_range[4]) \ 57 | & (points[:, 2] > limit_range[2]) & ( 58 | points[:, 2] < limit_range[5]) 59 | 60 | points = points[mask] 61 | 62 | return points 63 | 64 | 65 | def mask_ego_points(points): 66 | """ 67 | Remove the lidar points of the ego vehicle itself. 68 | 69 | Parameters 70 | ---------- 71 | points : np.ndarray 72 | Lidar points under lidar sensor coordinate system. 73 | 74 | Returns 75 | ------- 76 | points : np.ndarray 77 | Filtered lidar points. 78 | """ 79 | mask = (points[:, 0] >= -1.95) & (points[:, 0] <= 2.95) \ 80 | & (points[:, 1] >= -1.1) & (points[:, 1] <= 1.1) 81 | points = points[np.logical_not(mask)] 82 | 83 | return points 84 | 85 | 86 | def shuffle_points(points): 87 | shuffle_idx = np.random.permutation(points.shape[0]) 88 | points = points[shuffle_idx] 89 | 90 | return points 91 | 92 | 93 | def lidar_project(lidar_data, extrinsic): 94 | """ 95 | Given the extrinsic matrix, project lidar data to another space. 96 | 97 | Parameters 98 | ---------- 99 | lidar_data : np.ndarray 100 | Lidar data, shape: (n, 4) 101 | 102 | extrinsic : np.ndarray 103 | Extrinsic matrix, shape: (4, 4) 104 | 105 | Returns 106 | ------- 107 | projected_lidar : np.ndarray 108 | Projected lida data, shape: (n, 4) 109 | """ 110 | 111 | lidar_xyz = lidar_data[:, :3].T 112 | # (3, n) -> (4, n), homogeneous transformation 113 | lidar_xyz = np.r_[lidar_xyz, [np.ones(lidar_xyz.shape[1])]] 114 | lidar_int = lidar_data[:, 3] 115 | 116 | # transform to ego vehicle space, (3, n) 117 | project_lidar_xyz = np.dot(extrinsic, lidar_xyz)[:3, :] 118 | # (n, 3) 119 | project_lidar_xyz = project_lidar_xyz.T 120 | # concatenate the intensity with xyz, (n, 4) 121 | projected_lidar = np.hstack((project_lidar_xyz, 122 | np.expand_dims(lidar_int, -1))) 123 | 124 | return projected_lidar 125 | 126 | 127 | def projected_lidar_stack(projected_lidar_list): 128 | """ 129 | Stack all projected lidar together. 130 | 131 | Parameters 132 | ---------- 133 | projected_lidar_list : list 134 | The list containing all projected lidar. 135 | 136 | Returns 137 | ------- 138 | stack_lidar : np.ndarray 139 | Stack all projected lidar data together. 140 | """ 141 | stack_lidar = [] 142 | for lidar_data in projected_lidar_list: 143 | stack_lidar.append(lidar_data) 144 | 145 | return np.vstack(stack_lidar) 146 | 147 | 148 | def downsample_lidar(pcd_np, num): 149 | """ 150 | Downsample the lidar points to a certain number. 
151 | 152 | Parameters 153 | ---------- 154 | pcd_np : np.ndarray 155 | The lidar points, (n, 4). 156 | 157 | num : int 158 | The downsample target number. 159 | 160 | Returns 161 | ------- 162 | pcd_np : np.ndarray 163 | The downsampled lidar points. 164 | """ 165 | assert pcd_np.shape[0] >= num 166 | 167 | selected_index = np.random.choice((pcd_np.shape[0]), 168 | num, 169 | replace=False) 170 | pcd_np = pcd_np[selected_index] 171 | 172 | return pcd_np 173 | 174 | 175 | def downsample_lidar_minimum(pcd_np_list): 176 | """ 177 | Given a list of pcd, find the minimum number and downsample all 178 | point clouds to the minimum number. 179 | 180 | Parameters 181 | ---------- 182 | pcd_np_list : list 183 | A list of pcd numpy array(n, 4). 184 | Returns 185 | ------- 186 | pcd_np_list : list 187 | Downsampled point clouds. 188 | """ 189 | minimum = np.Inf 190 | 191 | for i in range(len(pcd_np_list)): 192 | num = pcd_np_list[i].shape[0] 193 | minimum = num if minimum > num else minimum 194 | 195 | for (i, pcd_np) in enumerate(pcd_np_list): 196 | pcd_np_list[i] = downsample_lidar(pcd_np, minimum) 197 | 198 | return pcd_np_list 199 | -------------------------------------------------------------------------------- /v2xvit/utils/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | import numpy 4 | setup( 5 | name='box overlaps', 6 | ext_modules=cythonize('v2xvit/utils/box_overlaps.pyx'), 7 | include_dirs=[numpy.get_include()] 8 | ) -------------------------------------------------------------------------------- /v2xvit/utils/transformation_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformation utils 3 | """ 4 | 5 | import numpy as np 6 | 7 | 8 | def x_to_world(pose): 9 | """ 10 | The transformation matrix from x-coordinate system to carla world system 11 | 12 | Parameters 13 | ---------- 14 | pose : list 15 | [x, y, z, roll, yaw, pitch] 16 | 17 | Returns 18 | ------- 19 | matrix : np.ndarray 20 | The transformation matrix. 21 | """ 22 | x, y, z, roll, yaw, pitch = pose[:] 23 | 24 | # used for rotation matrix 25 | c_y = np.cos(np.radians(yaw)) 26 | s_y = np.sin(np.radians(yaw)) 27 | c_r = np.cos(np.radians(roll)) 28 | s_r = np.sin(np.radians(roll)) 29 | c_p = np.cos(np.radians(pitch)) 30 | s_p = np.sin(np.radians(pitch)) 31 | 32 | matrix = np.identity(4) 33 | # translation matrix 34 | matrix[0, 3] = x 35 | matrix[1, 3] = y 36 | matrix[2, 3] = z 37 | 38 | # rotation matrix 39 | matrix[0, 0] = c_p * c_y 40 | matrix[0, 1] = c_y * s_p * s_r - s_y * c_r 41 | matrix[0, 2] = -c_y * s_p * c_r - s_y * s_r 42 | matrix[1, 0] = s_y * c_p 43 | matrix[1, 1] = s_y * s_p * s_r + c_y * c_r 44 | matrix[1, 2] = -s_y * s_p * c_r + c_y * s_r 45 | matrix[2, 0] = s_p 46 | matrix[2, 1] = -c_p * s_r 47 | matrix[2, 2] = c_p * c_r 48 | 49 | return matrix 50 | 51 | 52 | def x1_to_x2(x1, x2): 53 | """ 54 | Transformation matrix from x1 to x2. 55 | 56 | Parameters 57 | ---------- 58 | x1 : list 59 | The pose of x1 under world coordinates. 60 | x2 : list 61 | The pose of x2 under world coordinates. 62 | 63 | Returns 64 | ------- 65 | transformation_matrix : np.ndarray 66 | The transformation matrix. 
67 | 68 | """ 69 | x1_to_world = x_to_world(x1) 70 | x2_to_world = x_to_world(x2) 71 | world_to_x2 = np.linalg.inv(x2_to_world) 72 | 73 | transformation_matrix = np.dot(world_to_x2, x1_to_world) 74 | return transformation_matrix 75 | 76 | 77 | def dist_to_continuous(p_dist, displacement_dist, res, downsample_rate): 78 | """ 79 | Convert points discretized format to continuous space for BEV representation. 80 | Parameters 81 | ---------- 82 | p_dist : numpy.array 83 | Points in discretized coorindates. 84 | 85 | displacement_dist : numpy.array 86 | Discretized coordinates of bottom left origin. 87 | 88 | res : float 89 | Discretization resolution. 90 | 91 | downsample_rate : int 92 | Dowmsamping rate. 93 | 94 | Returns 95 | ------- 96 | p_continuous : numpy.array 97 | Points in continuous coorindates. 98 | 99 | """ 100 | p_dist = np.copy(p_dist) 101 | p_dist = p_dist + displacement_dist 102 | p_continuous = p_dist * res * downsample_rate 103 | return p_continuous 104 | -------------------------------------------------------------------------------- /v2xvit/version.py: -------------------------------------------------------------------------------- 1 | """Specifies the current version number of v2xvit.""" 2 | 3 | __version__ = "0.1.0" 4 | -------------------------------------------------------------------------------- /v2xvit/visualization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DerrickXuNu/v2x-vit/f0e6c13f41e916548b2d8aba61e42a18ce980416/v2xvit/visualization/__init__.py -------------------------------------------------------------------------------- /v2xvit/visualization/pinhole_param.json: -------------------------------------------------------------------------------- 1 | { 2 | "class_name" : "PinholeCameraParameters", 3 | "extrinsic" : 4 | [ 5 | 1.0, 6 | -0.0, 7 | -0.0, 8 | 0.0, 9 | 0.0, 10 | -1.0, 11 | -0.0, 12 | 0.0, 13 | 0.0, 14 | -0.0, 15 | -1.0, 16 | 0.0, 17 | 14.870189666748047, 18 | 0.0001621246337890625, 19 | 141.0903074604017, 20 | 1.0 21 | ], 22 | "intrinsic" : 23 | { 24 | "height" : 1025, 25 | "intrinsic_matrix" : 26 | [ 27 | 887.67603887904966, 28 | 0.0, 29 | 0.0, 30 | 0.0, 31 | 887.67603887904966, 32 | 0.0, 33 | 926.0, 34 | 512.0, 35 | 1.0 36 | ], 37 | "width" : 1853 38 | }, 39 | "version_major" : 1, 40 | "version_minor" : 0 41 | } -------------------------------------------------------------------------------- /v2xvit/visualization/vis_data_sequence.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from torch.utils.data import DataLoader 4 | 5 | from v2xvit.hypes_yaml.yaml_utils import load_yaml 6 | from v2xvit.visualization import vis_utils 7 | from v2xvit.data_utils.datasets.early_fusion_vis_dataset import \ 8 | EarlyFusionVisDataset 9 | 10 | 11 | def vis_parser(): 12 | parser = argparse.ArgumentParser(description="data visualization") 13 | parser.add_argument('--color_mode', type=str, default="intensity", 14 | help='lidar color rendering mode, e.g. 
intensity,' 15 | 'z-value or constant.') 16 | opt = parser.parse_args() 17 | return opt 18 | 19 | 20 | if __name__ == '__main__': 21 | current_path = os.path.dirname(os.path.realpath(__file__)) 22 | params = load_yaml(os.path.join(current_path, 23 | '../hypes_yaml/visualization.yaml')) 24 | 25 | opencda_dataset = EarlyFusionVisDataset(params, visualize=True, 26 | train=False) 27 | data_loader = DataLoader(opencda_dataset, batch_size=1, num_workers=8, 28 | collate_fn=opencda_dataset.collate_batch_train, 29 | shuffle=False, 30 | pin_memory=False) 31 | 32 | opt = vis_parser() 33 | vis_utils.visualize_sequence_dataloader(data_loader, 34 | params['postprocess']['order'], 35 | color_mode=opt.color_mode) 36 | --------------------------------------------------------------------------------
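Note: the snippet below is a minimal usage sketch, not a file in the repository. It mirrors what train.py and vis_data_sequence.py do with the utilities shown above; the config path is only an example, and any yaml under v2xvit/hypes_yaml can be substituted.

    # build the core training objects the same way train.py does
    import v2xvit.hypes_yaml.yaml_utils as yaml_utils
    from v2xvit.tools import train_utils

    # example config path (assumption); pick any yaml shipped in v2xvit/hypes_yaml
    hypes = yaml_utils.load_yaml('v2xvit/hypes_yaml/point_pillar_v2xvit.yaml')

    model = train_utils.create_model(hypes)          # imports v2xvit.models.<core_method>
    criterion = train_utils.create_loss(hypes)       # imports v2xvit.loss.<core_method>
    optimizer = train_utils.setup_optimizer(hypes, model)
    scheduler = train_utils.setup_lr_schedular(hypes, optimizer)

From the command line, training is typically launched through the argparse interface defined in train_parser() (python v2xvit/tools/train.py --hypes_yaml <config> [--model_dir <checkpoint_dir>] [--half]), and the visualization demo through vis_parser() (python v2xvit/visualization/vis_data_sequence.py --color_mode intensity).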