├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── label_mapping
│   │   └── nuscenes-occ.yaml
│   ├── train_dome.py
│   ├── train_dome_resample.py
│   └── train_occvae.py
├── dataset
│   ├── __init__.py
│   ├── dataset.py
│   ├── dataset_wrapper.py
│   └── sampler.py
├── diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   ├── gaussian_diffusion.py
│   ├── respace.py
│   └── timestep_sampler.py
├── environment.yml
├── loss
│   ├── __init__.py
│   ├── base_loss.py
│   ├── ce_loss.py
│   ├── emb_loss.py
│   ├── multi_loss.py
│   └── recon_loss.py
├── model
│   ├── VAE
│   │   ├── quantizer.py
│   │   └── vae_2d_resnet.py
│   ├── __init__.py
│   ├── dome.py
│   └── pose_encoder.py
├── resample
│   ├── astar.py
│   ├── launch.py
│   ├── main.py
│   ├── requirements.txt
│   └── utils.py
├── static
│   └── images
│       ├── dome_pipe.png
│       ├── favicon.ico
│       ├── occvae.png
│       ├── overall_pipeline4.png
│       ├── teaser12.png
│       └── vis_demo_cmp_2.png
├── tools
│   ├── eval.sh
│   ├── eval_metric.py
│   ├── eval_vae.py
│   ├── eval_vae.sh
│   ├── train_diffusion.py
│   ├── train_diffusion.sh
│   ├── train_vae.py
│   ├── train_vae.sh
│   ├── vis_diffusion.sh
│   ├── vis_gif.py
│   ├── vis_utils.py
│   ├── vis_vae.sh
│   ├── visualize_demo.py
│   └── visualize_demo_vae.py
└── utils
    ├── __init__.py
    ├── ema.py
    ├── freeze_model.py
    ├── load_save_util.py
    └── metric_util.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

 1 | data/
 2 | data
 3 | out
 4 | out/
 5 | results/
 6 | __pycache__/
 7 | visualization/results/
 8 | analyze/
 9 | *.png
10 | *.jpg
11 | *.log
12 | ckpts/
13 | results/
14 | .DS_Store
15 | dev/
16 | kernels/
17 | work_dir/
18 | *.pth
19 | *.zip

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

 1 | Apache License
 2 | Version 2.0, January 2004
 3 | http://www.apache.org/licenses/
 4 | 
 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 6 | 
 7 | 1. Definitions.
 8 | 
 9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # DOME: Taming Diffusion Model into High-Fidelity Controllable Occupancy World Model 4 |
5 | 6 | ### [Project Page](https://gusongen.github.io/DOME/) | [Paper](https://arxiv.org/abs/2410.10429v1) 7 | 8 |
9 | 
10 | 
11 | 
12 | https://github.com/user-attachments/assets/7da724d5-acbc-40f7-b5f8-dac38cfbe24a
13 | 
14 | 
15 | https://github.com/user-attachments/assets/9ab4237c-67c0-4718-8576-32068a74bccf
16 | 
17 | 
18 | https://github.com/user-attachments/assets/f29149f3-a749-4c55-9777-3d05912bebe1
19 | 
20 | 
21 | 
22 | 
23 | 
24 | teaser
25 | Our Occupancy World Model can generate long-duration occupancy forecasts and can be effectively controlled by trajectory conditions.
26 | 
27 | 
28 | # 📖 Overview
29 | overview
30 | Our method consists of two components: (a) the Occ-VAE Pipeline encodes occupancy frames into a continuous latent space, enabling efficient data compression; (b) the DOME Pipeline learns to predict 4D occupancy based on historical occupancy observations.
31 | 
32 | 
33 | 
34 | ## 🗓️ News
35 | - [2025.1.1] We release the code and checkpoints.
36 | - [2024.11.18] Project page is online!
37 | 
38 | ## 🗓️ TODO
39 | - [x] Code release.
40 | - [x] Checkpoint release.
41 | 
42 | 
43 | ## 🚀 Setup
44 | ### clone the repo
45 | ```
46 | git clone https://github.com/gusongen/DOME.git
47 | cd DOME
48 | ```
49 | 
50 | ### environment setup
51 | ```
52 | conda env create --file environment.yml
53 | ```
54 | 
55 | ### data preparation
56 | 1. Create a soft link `data/nuscenes` pointing to your nuScenes root (`your_nuscenes_path`)
57 | 
58 | 2. Prepare the `gts` semantic occupancy annotations introduced in [Occ3d](https://github.com/Tsinghua-MARS-Lab/Occ3D)
59 | 
60 | 3. Download our generated train/val pickle files and put them in `data/`
61 | 
62 | [nuscenes_infos_train_temporal_v3_scene.pkl](https://cloud.tsinghua.edu.cn/d/9e231ed16e4a4caca3bd/)
63 | 
64 | [nuscenes_infos_val_temporal_v3_scene.pkl](https://cloud.tsinghua.edu.cn/d/9e231ed16e4a4caca3bd/)
65 | 
66 | The dataset should be organized as follows:
67 | 
68 | ```
69 | .
70 | └── data/
71 |     ├── nuscenes          # downloaded from www.nuscenes.org/
72 |     │   ├── lidarseg
73 |     │   ├── maps
74 |     │   ├── samples
75 |     │   ├── sweeps
76 |     │   ├── v1.0-trainval
77 |     │   └── gts           # download from Occ3d
78 |     ├── nuscenes_infos_train_temporal_v3_scene.pkl
79 |     └── nuscenes_infos_val_temporal_v3_scene.pkl
80 | ```
81 | ### ckpt preparation
82 | Download the pretrained weights from [here](https://drive.google.com/drive/folders/1D1HugOG7JurEqmnQo4XbW_-Ji0chEq-e?usp=sharing) and put them in the `ckpts` folder.
83 | 
84 | ## 🏃 Run the code
85 | ### (optional) Preprocess resampled data
86 | ```
87 | cd resample
88 | 
89 | python launch.py \
90 |     --dst ../data/resampled_occ \
91 |     --imageset ../data/nuscenes_infos_train_temporal_v3_scene.pkl \
92 |     --data_path ../data/nuscenes
93 | ```
94 | 
95 | ### OCC-VAE
96 | ```shell
97 | # train
98 | sh tools/train_vae.sh
99 | 
100 | # eval
101 | sh tools/eval_vae.sh
102 | 
103 | # visualize
104 | sh tools/vis_vae.sh
105 | ```
106 | 
107 | ### DOME
108 | ```shell
109 | # train
110 | sh tools/train_diffusion.sh
111 | 
112 | # eval
113 | sh tools/eval.sh
114 | 
115 | # visualize
116 | sh tools/vis_diffusion.sh
117 | ```
118 | 
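### (optional) sanity-check the dataloader in Python
A minimal sketch for pulling one batch through the same dataloader the trainers use (assumptions: run from the repo root inside the env above, with `data/` laid out as in "data preparation"; the official entry points remain the scripts under `tools/`):

```python
# Load a config and fetch one training batch to verify the data pipeline.
from mmengine import Config
from dataset import get_dataloader

cfg = Config.fromfile('config/train_occvae.py')
train_loader, val_loader = get_dataloader(
    cfg.train_dataset_config,
    cfg.val_dataset_config,
    cfg.train_wrapper_config,
    cfg.val_wrapper_config,
    cfg.train_loader,
    cfg.val_loader,
    dist=False)

# Each batch is (input, target, metas); we expect return_len_ frames of
# 200x200x16 occupancy labels per clip.
inputs, targets, metas = next(iter(train_loader))
print(inputs.shape, targets.shape)
```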
119 | ## 🎫 Acknowledgment
120 | This code draws inspiration from the following works. We sincerely appreciate their excellent contributions.
121 | - [OccWorld](https://github.com/wzzheng/OccWorld)
122 | - [Latte](https://github.com/Vchitect/Latte)
123 | - [Vista](https://github.com/OpenDriveLab/Vista.git)
124 | - [PyTorch-VAE](https://github.com/AntixK/PyTorch-VAE)
125 | - [A* search](https://www.redblobgames.com/pathfinding/a-star/)
126 | 
127 | ## 🖊️ Citation
128 | ```
129 | @article{gu2024dome,
130 |   title={DOME: Taming Diffusion Model into High-Fidelity Controllable Occupancy World Model},
131 |   author={Gu, Songen and Yin, Wei and Jin, Bu and Guo, Xiaoyang and Wang, Junming and Li, Haodong and Zhang, Qian and Long, Xiaoxiao},
132 |   journal={arXiv preprint arXiv:2410.10429},
133 |   year={2024}
134 | }
135 | ```
136 | 
137 | 
138 | 
--------------------------------------------------------------------------------
/config/label_mapping/nuscenes-occ.yaml:
--------------------------------------------------------------------------------

 1 | labels:
 2 |   0: 'noise'
 3 |   1: 'animal'
 4 |   2: 'human.pedestrian.adult'
 5 |   3: 'human.pedestrian.child'
 6 |   4: 'human.pedestrian.construction_worker'
 7 |   5: 'human.pedestrian.personal_mobility'
 8 |   6: 'human.pedestrian.police_officer'
 9 |   7: 'human.pedestrian.stroller'
10 |   8: 'human.pedestrian.wheelchair'
11 |   9: 'movable_object.barrier'
12 |   10: 'movable_object.debris'
13 |   11: 'movable_object.pushable_pullable'
14 |   12: 'movable_object.trafficcone'
15 |   13: 'static_object.bicycle_rack'
16 |   14: 'vehicle.bicycle'
17 |   15: 'vehicle.bus.bendy'
18 |   16: 'vehicle.bus.rigid'
19 |   17: 'vehicle.car'
20 |   18: 'vehicle.construction'
21 |   19: 'vehicle.emergency.ambulance'
22 |   20: 'vehicle.emergency.police'
23 |   21: 'vehicle.motorcycle'
24 |   22: 'vehicle.trailer'
25 |   23: 'vehicle.truck'
26 |   24: 'flat.driveable_surface'
27 |   25: 'flat.other'
28 |   26: 'flat.sidewalk'
29 |   27: 'flat.terrain'
30 |   28: 'static.manmade'
31 |   29: 'static.other'
32 |   30: 'static.vegetation'
33 |   31: 'vehicle.ego'
34 |   32: 'empty'
35 | labels_16:
36 |   0: 'others'
37 |   1: 'barrier'
38 |   2: 'bicycle'
39 |   3: 'bus'
40 |   4: 'car'
41 |   5: 'construction_vehicle'
42 |   6: 'motorcycle'
43 |   7: 'pedestrian'
44 |   8: 'traffic_cone'
45 |   9: 'trailer'
46 |   10: 'truck'
47 |   11: 'driveable_surface'
48 |   12: 'other_flat'
49 |   13: 'sidewalk'
50 |   14: 'terrain'
51 |   15: 'manmade'
52 |   16: 'vegetation'
53 |   17: 'empty'
54 | learning_map:
55 |   0: 0
56 |   1: 1
57 |   2: 2
58 |   3: 3
59 |   4: 4
60 |   5: 5
61 |   6: 6
62 |   7: 7
63 |   8: 8
64 |   9: 9
65 |   10: 10
66 |   11: 11
67 |   12: 12
68 |   13: 13
69 |   14: 14
70 |   15: 15
71 |   16: 16
72 |   17: 17
--------------------------------------------------------------------------------
/config/train_dome.py:
--------------------------------------------------------------------------------

 1 | 
 2 | 
 3 | work_dir = './work_dir/dome'
 4 | 
 5 | start_frame = 0
 6 | mid_frame = 4
 7 | end_frame = 10
 8 | eval_length = end_frame-mid_frame
 9 | 
10 | return_len_train = 11
11 | return_len_ = 10
12 | grad_max_norm = 1
13 | print_freq = 1
14 | max_epochs = 2000
15 | warmup_iters = 50
16 | ema = True
17 | 
18 | load_from = ''
19 | vae_load_from = 'ckpts/occvae_latest.pth'
20 | port = 25098
21 | revise_ckpt = 3
22 | eval_every_epochs = 10
23 | save_every_epochs = 10
24 | 
25 | multisteplr = True
26 | multisteplr_config = dict(
27 |     decay_rate=1,
28 |     decay_t=[
29 |         0,
30 |     ],
31 |     t_in_epochs=False,
32 |     warmup_lr_init=1e-06,
33 |     warmup_t=0)
34 | optimizer = dict(optimizer=dict(lr=0.0001, type='AdamW', weight_decay=0.0001))
35 | 
36 | schedule = dict(
37 |     beta_end=0.02,
38 |     beta_schedule='linear',
39 |     beta_start=0.0001,
40 | 
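    # linear schedule: betas ramp from beta_start to beta_end over the training
    # steps (create_diffusion defaults to 1000); 'learned_range' pairs with
    # learn_sigma=True in the world_model config below, i.e. the network also
    # predicts an interpolation between the fixed min/max variances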
variance_type='learned_range') 41 | 42 | sample = dict( 43 | enable_temporal_attentions=True, 44 | enable_vae_temporal_decoder=True, 45 | guidance_scale=7.5, 46 | n_conds=4, 47 | num_sampling_steps=20, 48 | run_time=0, 49 | sample_method='ddpm', 50 | seed=None) 51 | p_use_pose_condition = 0.9 52 | 53 | replace_cond_frames = True 54 | cond_frames_choices = [ 55 | [], 56 | [0], 57 | [0,1], 58 | [0,1,2], 59 | [0,1,2,3], 60 | ] 61 | data_path = 'data/nuscenes/' 62 | 63 | train_dataset_config = dict( 64 | type='nuScenesSceneDatasetLidar', 65 | data_path = data_path, 66 | return_len = return_len_train, 67 | offset = 0, 68 | times=5, 69 | imageset = 'data/nuscenes_infos_train_temporal_v3_scene.pkl', 70 | ) 71 | 72 | val_dataset_config = dict( 73 | data_path='data/nuscenes/', 74 | imageset='data/nuscenes_infos_val_temporal_v3_scene.pkl', 75 | new_rel_pose=True, 76 | offset=0, 77 | return_len=return_len_, 78 | test_mode=True, 79 | times=1, 80 | type='nuScenesSceneDatasetLidar') 81 | train_wrapper_config = dict(phase='train', type='tpvformer_dataset_nuscenes') 82 | val_wrapper_config = dict(phase='val', type='tpvformer_dataset_nuscenes') 83 | train_loader = dict(batch_size=8, num_workers=1, shuffle=True) 84 | val_loader = dict(batch_size=1, num_workers=1, shuffle=False) 85 | loss = dict( 86 | loss_cfgs=[ 87 | dict( 88 | input_dict=dict(ce_inputs='ce_inputs', ce_labels='ce_labels'), 89 | type='CeLoss', 90 | weight=1.0), 91 | ], 92 | type='MultiLoss') 93 | loss_input_convertion = dict() 94 | 95 | _dim_ = 16 96 | base_channel = 64 97 | expansion = 8 98 | n_e_ = 512 99 | num_heads=12 100 | hidden_size=768 101 | 102 | model = dict( 103 | delta_input=False, 104 | world_model=dict( 105 | attention_mode='xformers', 106 | class_dropout_prob=0.1, 107 | depth=28, 108 | extras=1, 109 | hidden_size=hidden_size, 110 | in_channels=64, 111 | input_size=25, 112 | learn_sigma=True, 113 | mlp_ratio=4.0, 114 | num_classes=1000, 115 | num_frames=return_len_train, 116 | num_heads=num_heads, 117 | patch_size=1, 118 | pose_encoder=dict( 119 | do_proj=True, 120 | in_channels=2, 121 | num_fut_ts=1, 122 | num_layers=2, 123 | num_modes=3, 124 | out_channels=hidden_size, 125 | type='PoseEncoder_fourier', 126 | zero_init=False), 127 | type='Dome'), 128 | sampling_method='SAMPLE', 129 | topk=10, 130 | vae=dict( 131 | encoder_cfg=dict( 132 | attn_resolutions=(50, ), 133 | ch=base_channel, 134 | ch_mult=( 135 | 1, 136 | 2, 137 | 4, 138 | 8, 139 | ), 140 | double_z=False, 141 | dropout=0.0, 142 | in_channels=128, 143 | num_res_blocks=2, 144 | out_ch=base_channel, 145 | resamp_with_conv=True, 146 | resolution=200, 147 | type='Encoder2D', 148 | z_channels=base_channel*2), 149 | decoder_cfg=dict( 150 | attn_resolutions=(50, ), 151 | ch=base_channel, 152 | ch_mult=( 153 | 1, 154 | 2, 155 | 4, 156 | 8, 157 | ), 158 | dropout=0.0, 159 | give_pre_end=False, 160 | in_channels=_dim_ * expansion, 161 | num_res_blocks=2, 162 | out_ch=_dim_ * expansion, 163 | resamp_with_conv=True, 164 | resolution=200, 165 | type='Decoder3D', 166 | z_channels=base_channel), 167 | expansion=expansion, 168 | num_classes=18, 169 | scaling_factor=0.18215, 170 | type='VAERes3D')) 171 | shapes = [[200,200],[100,100],[50,50],[25,25]] 172 | 173 | unique_label = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] 174 | label_mapping = './config/label_mapping/nuscenes-occ.yaml' 175 | -------------------------------------------------------------------------------- /config/train_dome_resample.py: -------------------------------------------------------------------------------- 1 | 
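# Variant of config/train_dome.py that trains on the trajectory-resampled
# occupancy produced by resample/launch.py (see "(optional) Preprocess
# resampled data" in the README). Only train_dataset_config differs: the type
# switches to nuScenesSceneDatasetLidarResample, data_path points at
# data/resampled_occ, and raw_times / resample_times presumably control how
# many raw vs. resampled copies of each scene are drawn per epoch.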
2 | 3 | work_dir = './work_dir/dome' 4 | 5 | start_frame = 0 6 | mid_frame = 4 7 | end_frame = 10 8 | eval_length = end_frame-mid_frame 9 | 10 | return_len_train = 11 11 | return_len_ = 10 12 | grad_max_norm = 1 13 | print_freq = 1 14 | max_epochs = 2000 15 | warmup_iters = 50 16 | ema = True 17 | 18 | load_from = '' 19 | vae_load_from = 'ckpts/occvae_latest.pth' 20 | port = 25098 21 | revise_ckpt = 3 22 | eval_every_epochs = 10 23 | save_every_epochs = 10 24 | 25 | multisteplr = True 26 | multisteplr_config = dict( 27 | decay_rate=1, 28 | decay_t=[ 29 | 0, 30 | ], 31 | t_in_epochs=False, 32 | warmup_lr_init=1e-06, 33 | warmup_t=0) 34 | optimizer = dict(optimizer=dict(lr=0.0001, type='AdamW', weight_decay=0.0001)) 35 | 36 | schedule = dict( 37 | beta_end=0.02, 38 | beta_schedule='linear', 39 | beta_start=0.0001, 40 | variance_type='learned_range') 41 | 42 | sample = dict( 43 | enable_temporal_attentions=True, 44 | enable_vae_temporal_decoder=True, 45 | guidance_scale=7.5, 46 | n_conds=4, 47 | num_sampling_steps=20, 48 | run_time=0, 49 | sample_method='ddpm', 50 | seed=None) 51 | p_use_pose_condition = 0.9 52 | 53 | replace_cond_frames = True 54 | cond_frames_choices = [ 55 | [], 56 | [0], 57 | [0,1], 58 | [0,1,2], 59 | [0,1,2,3], 60 | ] 61 | data_path = 'data/nuscenes/' 62 | 63 | train_dataset_config = dict( 64 | data_path='data/resampled_occ', 65 | imageset='data/nuscenes_infos_train_temporal_v3_scene.pkl', 66 | offset=0, 67 | raw_times=10, 68 | resample_times=1, 69 | return_len=return_len_train, 70 | times=1, 71 | type='nuScenesSceneDatasetLidarResample' 72 | ) 73 | 74 | val_dataset_config = dict( 75 | data_path='data/nuscenes/', 76 | imageset='data/nuscenes_infos_val_temporal_v3_scene.pkl', 77 | new_rel_pose=True, 78 | offset=0, 79 | return_len=return_len_, 80 | test_mode=True, 81 | times=1, 82 | type='nuScenesSceneDatasetLidar') 83 | train_wrapper_config = dict(phase='train', type='tpvformer_dataset_nuscenes') 84 | val_wrapper_config = dict(phase='val', type='tpvformer_dataset_nuscenes') 85 | train_loader = dict(batch_size=8, num_workers=1, shuffle=True) 86 | val_loader = dict(batch_size=1, num_workers=1, shuffle=False) 87 | loss = dict( 88 | loss_cfgs=[ 89 | dict( 90 | input_dict=dict(ce_inputs='ce_inputs', ce_labels='ce_labels'), 91 | type='CeLoss', 92 | weight=1.0), 93 | ], 94 | type='MultiLoss') 95 | loss_input_convertion = dict() 96 | 97 | _dim_ = 16 98 | base_channel = 64 99 | expansion = 8 100 | n_e_ = 512 101 | num_heads=12 102 | hidden_size=768 103 | 104 | model = dict( 105 | delta_input=False, 106 | world_model=dict( 107 | attention_mode='xformers', 108 | class_dropout_prob=0.1, 109 | depth=28, 110 | extras=1, 111 | hidden_size=hidden_size, 112 | in_channels=64, 113 | input_size=25, 114 | learn_sigma=True, 115 | mlp_ratio=4.0, 116 | num_classes=1000, 117 | num_frames=return_len_train, 118 | num_heads=num_heads, 119 | patch_size=1, 120 | pose_encoder=dict( 121 | do_proj=True, 122 | in_channels=2, 123 | num_fut_ts=1, 124 | num_layers=2, 125 | num_modes=3, 126 | out_channels=hidden_size, 127 | type='PoseEncoder_fourier', 128 | zero_init=False), 129 | type='Dome'), 130 | sampling_method='SAMPLE', 131 | topk=10, 132 | vae=dict( 133 | encoder_cfg=dict( 134 | attn_resolutions=(50, ), 135 | ch=base_channel, 136 | ch_mult=( 137 | 1, 138 | 2, 139 | 4, 140 | 8, 141 | ), 142 | double_z=False, 143 | dropout=0.0, 144 | in_channels=128, 145 | num_res_blocks=2, 146 | out_ch=base_channel, 147 | resamp_with_conv=True, 148 | resolution=200, 149 | type='Encoder2D', 150 | 
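        # 2x latent channels here: the encoder output is presumably split into
        # the z_mu / z_sigma halves used by the VAE's KL term, with one half
        # (base_channel=64) matching the world model's in_channels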
z_channels=base_channel*2),
151 |     decoder_cfg=dict(
152 |         attn_resolutions=(50, ),
153 |         ch=base_channel,
154 |         ch_mult=(
155 |             1,
156 |             2,
157 |             4,
158 |             8,
159 |         ),
160 |         dropout=0.0,
161 |         give_pre_end=False,
162 |         in_channels=_dim_ * expansion,
163 |         num_res_blocks=2,
164 |         out_ch=_dim_ * expansion,
165 |         resamp_with_conv=True,
166 |         resolution=200,
167 |         type='Decoder3D',
168 |         z_channels=base_channel),
169 |     expansion=expansion,
170 |     num_classes=18,
171 |     scaling_factor=0.18215,
172 |     type='VAERes3D'))
173 | shapes = [[200,200],[100,100],[50,50],[25,25]]
174 | 
175 | unique_label = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
176 | label_mapping = './config/label_mapping/nuscenes-occ.yaml'
177 | 
--------------------------------------------------------------------------------
/config/train_occvae.py:
--------------------------------------------------------------------------------

 1 | grad_max_norm = 35
 2 | print_freq = 10
 3 | max_epochs = 200
 4 | warmup_iters = 200
 5 | return_len_ = 16  # max clip length for the conv3d encoder-decoder
 6 | return_len_ = 12  # max clip length for the conv2d encoder + conv3d decoder (overrides the line above)
 7 | # return_len_ = 1  # debug
 8 | 
 9 | eval_every_epochs = 10
10 | save_every_epochs = 10
11 | 
12 | multisteplr = False
13 | multisteplr_config = dict(
14 |     decay_t = [87 * 500],
15 |     decay_rate = 0.1,
16 |     warmup_t = warmup_iters,
17 |     warmup_lr_init = 1e-6,
18 |     t_in_epochs = False
19 | )
20 | 
21 | 
22 | optimizer = dict(
23 |     optimizer=dict(
24 |         type='AdamW',
25 |         lr=1e-3,
26 |         weight_decay=0.01,
27 |     ),
28 | )
29 | 
30 | data_path = 'data/nuscenes/'
31 | 
32 | 
33 | train_dataset_config = dict(
34 |     type='nuScenesSceneDatasetLidar',
35 |     data_path = data_path,
36 |     return_len = return_len_,
37 |     offset = 0,
38 |     imageset = 'data/nuscenes_infos_train_temporal_v3_scene.pkl',
39 | )
40 | 
41 | val_dataset_config = dict(
42 |     type='nuScenesSceneDatasetLidar',
43 |     data_path = data_path,
44 |     return_len = return_len_,
45 |     offset = 0,
46 |     imageset = 'data/nuscenes_infos_val_temporal_v3_scene.pkl',
47 | )
48 | 
49 | train_wrapper_config = dict(
50 |     type='tpvformer_dataset_nuscenes',
51 |     phase='train',
52 | )
53 | 
54 | val_wrapper_config = dict(
55 |     type='tpvformer_dataset_nuscenes',
56 |     phase='val',
57 | )
58 | 
59 | train_loader = dict(
60 |     batch_size = 1,
61 |     shuffle = True,
62 |     num_workers = 0,#1,
63 | )
64 | 
65 | val_loader = dict(
66 |     batch_size = 1,
67 |     shuffle = False,
68 |     num_workers = 0,#1,
69 | )
70 | 
71 | 
72 | 
73 | loss = dict(
74 |     type='MultiLoss',
75 |     loss_cfgs=[
76 |         dict(
77 |             type='ReconLoss',
78 |             weight=10.0,
79 |             ignore_label=-100,
80 |             use_weight=False,
81 |             cls_weight=None,
82 |             input_dict={
83 |                 'logits': 'logits',
84 |                 'labels': 'inputs'}),
85 |         dict(
86 |             type='LovaszLoss',
87 |             weight=1.0,
88 |             input_dict={
89 |                 'logits': 'logits',
90 |                 'labels': 'inputs'}),
91 |         dict(
92 |             type='KLLoss',
93 |             # weight=0.00025,
94 |             weight=0.00005,
95 |             input_dict={
96 |                 'z_mu': 'z_mu',
97 |                 'z_sigma': 'z_sigma'}),
98 |         # dict(
99 |         #     type='VQVAEEmbedLoss',
100 |         #     weight=1.0),
101 |     ])
102 | 
103 | loss_input_convertion = dict(
104 |     logits='logits',
105 |     # embed_loss='embed_loss'
106 | )
107 | 
108 | 
109 | load_from = ''
110 | 
111 | _dim_ = 16
112 | expansion = 8  # class embedding dimension
113 | base_channel = 64
114 | n_e_ = 512
115 | model = dict(
116 |     type = 'VAERes3D',
117 |     encoder_cfg=dict(
118 |         type='Encoder2D',
119 |         # type='Encoder3D',#debug
120 |         ch = base_channel,
121 |         out_ch = base_channel,
122 |         ch_mult = (1,2,4,8),
123 |         num_res_blocks = 2,
124 |         attn_resolutions = 
(50,), 125 | dropout = 0.0, 126 | resamp_with_conv = True, 127 | in_channels = _dim_ * expansion, 128 | resolution = 200, 129 | z_channels = base_channel * 2, 130 | double_z = False, 131 | # temporal_downsample= ( 132 | # "", 133 | # "TimeDownsample2x", 134 | # "TimeDownsample2x", 135 | # "", 136 | # ), 137 | ), 138 | decoder_cfg=dict( 139 | type='Decoder3D', 140 | ch = base_channel, 141 | out_ch = _dim_ * expansion, 142 | ch_mult = (1,2,4,8), 143 | num_res_blocks = 2, 144 | attn_resolutions = (50,), 145 | dropout = 0.0, 146 | resamp_with_conv = True, 147 | in_channels = _dim_ * expansion, 148 | resolution = 200, 149 | z_channels = base_channel, 150 | give_pre_end = False, 151 | # temporal_upsample = ("", "", "TimeUpsample2x", "TimeUpsample2x"), 152 | ), 153 | num_classes=18, 154 | expansion=expansion, 155 | # vqvae_cfg=dict( 156 | # type='VectorQuantizer', 157 | # n_e = n_e_, 158 | # e_dim = base_channel * 2, 159 | # beta = 1., 160 | # z_channels = base_channel * 2, 161 | # use_voxel=False) 162 | ) 163 | 164 | shapes = [[200, 200], [100, 100], [50, 50], [25, 25]] 165 | 166 | unique_label = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] 167 | label_mapping = "./config/label_mapping/nuscenes-occ.yaml" -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from mmengine.registry import Registry 3 | OPENOCC_DATASET = Registry('openocc_dataset') 4 | OPENOCC_DATAWRAPPER = Registry('openocc_datawrapper') 5 | 6 | from .dataset import nuScenesSceneDatasetLidar, nuScenesSceneDatasetLidarTraverse 7 | from .dataset_wrapper import tpvformer_dataset_nuscenes, custom_collate_fn_temporal 8 | from .sampler import CustomDistributedSampler 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.utils.data.dataloader import DataLoader 11 | import yaml 12 | 13 | def get_dataloader( 14 | train_dataset_config, 15 | val_dataset_config, 16 | train_wrapper_config, 17 | val_wrapper_config, 18 | train_loader, 19 | val_loader, 20 | nusc=dict( 21 | version='v1.0-trainval', 22 | dataroot='data/nuscenes'), 23 | dist=False, 24 | iter_resume=False, 25 | train_sampler_config=dict( 26 | shuffle=True, 27 | drop_last=True), 28 | val_sampler_config=dict( 29 | shuffle=False, 30 | drop_last=False), 31 | ): 32 | # if nusc is not None: 33 | # from nuscenes import NuScenes 34 | # nusc = NuScenes(**nusc) 35 | 36 | train_dataset = OPENOCC_DATASET.build( 37 | train_dataset_config, 38 | default_args={'nusc': nusc}) 39 | val_dataset = OPENOCC_DATASET.build( 40 | val_dataset_config, 41 | default_args={'nusc': nusc}) 42 | 43 | train_wrapper = OPENOCC_DATAWRAPPER.build( 44 | train_wrapper_config, 45 | default_args={'in_dataset': train_dataset}) 46 | val_wrapper = OPENOCC_DATAWRAPPER.build( 47 | val_wrapper_config, 48 | default_args={'in_dataset': val_dataset}) 49 | 50 | train_sampler = val_sampler = None 51 | if dist: 52 | if iter_resume: 53 | train_sampler = CustomDistributedSampler(train_wrapper, **train_sampler_config) 54 | else: 55 | train_sampler = DistributedSampler(train_wrapper, **train_sampler_config) 56 | val_sampler = DistributedSampler(val_wrapper, **val_sampler_config) 57 | 58 | train_dataset_loader = DataLoader( 59 | dataset=train_wrapper, 60 | batch_size=train_loader["batch_size"], 61 | collate_fn=custom_collate_fn_temporal, 62 | shuffle=False if dist else train_loader["shuffle"], 63 | sampler=train_sampler, 64 | num_workers=train_loader["num_workers"], 
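        # pinned host memory enables faster (and potentially asynchronous)
        # CPU-to-GPU copies of the occupancy tensors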
65 |         pin_memory=True)
66 |     val_dataset_loader = DataLoader(
67 |         dataset=val_wrapper,
68 |         batch_size=val_loader["batch_size"],
69 |         collate_fn=custom_collate_fn_temporal,
70 |         shuffle=False,
71 |         sampler=val_sampler,
72 |         num_workers=val_loader["num_workers"],
73 |         pin_memory=True)
74 | 
75 |     return train_dataset_loader, val_dataset_loader
76 | 
77 | 
78 | def get_nuScenes_label_name(label_mapping):
79 |     with open(label_mapping, 'r') as stream:
80 |         nuScenesyaml = yaml.safe_load(stream)
81 |     nuScenes_label_name = dict()
82 |     for i in sorted(list(nuScenesyaml['learning_map'].keys()))[::-1]:
83 |         val_ = nuScenesyaml['learning_map'][i]
84 |         nuScenes_label_name[val_] = nuScenesyaml['labels_16'][val_]
85 |     return nuScenes_label_name
86 | 
87 | 
--------------------------------------------------------------------------------
/dataset/dataset_wrapper.py:
--------------------------------------------------------------------------------

 1 | 
 2 | import numpy as np, torch
 3 | from torch.utils import data
 4 | import torch.nn.functional as F
 5 | from copy import deepcopy
 6 | from mmengine import MMLogger
 7 | logger = MMLogger.get_instance('genocc')
 8 | try:
 9 |     from . import OPENOCC_DATAWRAPPER
10 | except ImportError:  # fall back to a local registry when run outside the package
11 |     from mmengine.registry import Registry
12 |     OPENOCC_DATAWRAPPER = Registry('openocc_datawrapper')
13 | import torch
14 | 
15 | @OPENOCC_DATAWRAPPER.register_module()
16 | class tpvformer_dataset_nuscenes(data.Dataset):
17 |     def __init__(
18 |         self, 
19 |         in_dataset,
20 |         phase='train',
21 |     ):
22 |         'Initialization'
23 |         self.point_cloud_dataset = in_dataset
24 |         self.phase = phase
25 | 
26 |     def __len__(self):
27 |         return len(self.point_cloud_dataset)
28 | 
29 |     def to_tensor(self, imgs):
30 |         imgs = np.stack(imgs).astype(np.float32)
31 |         imgs = torch.from_numpy(imgs)
32 |         imgs = imgs.permute(0, 3, 1, 2)
33 |         return imgs
34 | 
35 |     def __getitem__(self, index):
36 |         input, target, metas = self.point_cloud_dataset[index]
37 |         #### adapt to vae input
38 |         input = torch.from_numpy(input)
39 |         target = torch.from_numpy(target)
40 |         return input, target, metas
41 | 
42 | 
43 | def custom_collate_fn_temporal(data):
44 |     data_tuple = []
45 |     for i, item in enumerate(data[0]):
46 |         if isinstance(item, torch.Tensor):
47 |             data_tuple.append(torch.stack([d[i] for d in data]))
48 |         elif isinstance(item, (dict, str)):
49 |             data_tuple.append([d[i] for d in data])
50 |         elif item is None:
51 |             data_tuple.append(None)
52 |         else:
53 |             raise NotImplementedError
54 |     return data_tuple
55 | 
--------------------------------------------------------------------------------
/dataset/sampler.py:
--------------------------------------------------------------------------------

 1 | import math
 2 | from typing import TypeVar, Optional, Iterator
 3 | 
 4 | import torch
 5 | from torch.utils.data import Sampler, Dataset
 6 | import torch.distributed as dist
 7 | from typing import Callable
 8 | import pandas as pd
 9 | from torch.utils.data.distributed import DistributedSampler
10 | import torchvision
11 | 
12 | 
13 | 
14 | T_co = TypeVar('T_co', covariant=True)
15 | 
16 | 
17 | class CustomDistributedSampler(Sampler[T_co]):
18 |     r"""Sampler that restricts data loading to a subset of the dataset.
19 | 
20 |     It is especially useful in conjunction with
21 |     :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
22 |     process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
23 |     :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
24 |     original dataset that is exclusive to it.
25 | 
26 |     .. note::
27 |         Dataset is assumed to be of constant size.
28 | 
29 |     Arguments:
30 |         dataset: Dataset used for sampling.
31 |         num_replicas (int, optional): Number of processes participating in
32 |             distributed training. By default, :attr:`world_size` is retrieved from the
33 |             current distributed group.
34 |         rank (int, optional): Rank of the current process within :attr:`num_replicas`.
35 |             By default, :attr:`rank` is retrieved from the current distributed
36 |             group.
37 |         shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
38 |             indices.
39 |         seed (int, optional): random seed used to shuffle the sampler if
40 |             :attr:`shuffle=True`. This number should be identical across all
41 |             processes in the distributed group. Default: ``0``.
42 |         drop_last (bool, optional): if ``True``, then the sampler will drop the
43 |             tail of the data to make it evenly divisible across the number of
44 |             replicas. If ``False``, the sampler will add extra indices to make
45 |             the data evenly divisible across the replicas. Default: ``False``.
46 | 
47 |     .. warning::
48 |         In distributed mode, calling the :meth:`set_epoch` method at
49 |         the beginning of each epoch **before** creating the :class:`DataLoader` iterator
50 |         is necessary to make shuffling work properly across multiple epochs. Otherwise,
51 |         the same ordering will always be used.
52 | 
53 |     Example::
54 | 
55 |         >>> sampler = DistributedSampler(dataset) if is_distributed else None
56 |         >>> loader = DataLoader(dataset, shuffle=(sampler is None),
57 |         ...                     sampler=sampler)
58 |         >>> for epoch in range(start_epoch, n_epochs):
59 |         ...     if is_distributed:
60 |         ...         sampler.set_epoch(epoch)
61 |         ...     train(loader)
62 |     """
63 | 
64 |     def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
65 |                  rank: Optional[int] = None, shuffle: bool = True,
66 |                  seed: int = 0, drop_last: bool = False, last_iter: int = 0) -> None:
67 |         if num_replicas is None:
68 |             if not dist.is_available():
69 |                 raise RuntimeError("Requires distributed package to be available")
70 |             num_replicas = dist.get_world_size()
71 |         if rank is None:
72 |             if not dist.is_available():
73 |                 raise RuntimeError("Requires distributed package to be available")
74 |             rank = dist.get_rank()
75 |         self.dataset = dataset
76 |         self.num_replicas = num_replicas
77 |         self.rank = rank
78 |         self.epoch = 0
79 |         self.drop_last = drop_last
80 |         # If the dataset length is evenly divisible by # of replicas, then there
81 |         # is no need to drop any data, since the dataset will be split equally.
82 |         if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore
83 |             # Split to nearest available length that is evenly divisible.
84 |             # This is to ensure each rank receives the same amount of data when
85 |             # using this Sampler.
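            # e.g. len(dataset)=10 on num_replicas=4 ranks with drop_last=True:
            # ceil((10 - 4) / 4) = 2 samples per rank, so 8 samples are used and 2 dropped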
86 | self.num_samples = math.ceil( 87 | # `type:ignore` is required because Dataset cannot provide a default __len__ 88 | # see NOTE in pytorch/torch/utils/data/sampler.py 89 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore 90 | ) 91 | else: 92 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore 93 | self.total_size = self.num_samples * self.num_replicas 94 | self.shuffle = shuffle 95 | self.seed = seed 96 | self.first_run = True 97 | self.last_iter = last_iter 98 | 99 | def __iter__(self) -> Iterator[T_co]: 100 | if self.shuffle: 101 | # deterministically shuffle based on epoch and seed 102 | g = torch.Generator() 103 | g.manual_seed(self.seed + self.epoch) 104 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore 105 | else: 106 | indices = list(range(len(self.dataset))) # type: ignore 107 | 108 | if not self.drop_last: 109 | # add extra samples to make it evenly divisible 110 | indices += indices[:(self.total_size - len(indices))] 111 | else: 112 | # remove tail of data to make it evenly divisible. 113 | indices = indices[:self.total_size] 114 | assert len(indices) == self.total_size 115 | 116 | # subsample 117 | indices = indices[self.rank:self.total_size:self.num_replicas] 118 | if not self.first_run: 119 | assert len(indices) == self.num_samples 120 | else: 121 | indices = indices[self.last_iter:] 122 | self.last_iter = 0 123 | self.first_run = False 124 | 125 | return iter(indices) 126 | 127 | def __len__(self) -> int: 128 | return self.num_samples 129 | 130 | def set_epoch(self, epoch: int) -> None: 131 | r""" 132 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 133 | use a different random ordering for each epoch. Otherwise, the next iteration of this 134 | sampler will yield the same ordering. 135 | 136 | Arguments: 137 | epoch (int): Epoch number. 138 | """ 139 | self.epoch = epoch 140 | 141 | def set_last_iter(self, last_iter: int): 142 | self.last_iter = last_iter 143 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . 
import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | # learn_sigma=False, 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000, 20 | beta_start = 0.0001, 21 | beta_end = 0.02, 22 | replace_cond_frames=False, 23 | cond_frames_choices=None, 24 | ): 25 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps,beta_start,beta_end) 26 | if use_kl: 27 | loss_type = gd.LossType.RESCALED_KL 28 | elif rescale_learned_sigmas: 29 | loss_type = gd.LossType.RESCALED_MSE 30 | else: 31 | loss_type = gd.LossType.MSE 32 | if timestep_respacing is None or timestep_respacing == "": 33 | timestep_respacing = [diffusion_steps] 34 | return SpacedDiffusion( 35 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 36 | betas=betas, 37 | model_mean_type=( 38 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 39 | ), 40 | model_var_type=( 41 | ( 42 | gd.ModelVarType.FIXED_LARGE 43 | if not sigma_small 44 | else gd.ModelVarType.FIXED_SMALL 45 | ) 46 | if not learn_sigma 47 | else gd.ModelVarType.LEARNED_RANGE 48 | ), 49 | loss_type=loss_type, 50 | replace_cond_frames=replace_cond_frames, 51 | cond_frames_choices=cond_frames_choices, 52 | # rescale_timesteps=rescale_timesteps, 53 | ) 54 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 
54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | import torch 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | # @torch.compile 95 | def training_losses( 96 | self, model, *args, **kwargs 97 | ): # pylint: disable=signature-differs 98 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 99 | 100 | def condition_mean(self, cond_fn, *args, **kwargs): 101 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 102 | 103 | def condition_score(self, cond_fn, *args, **kwargs): 104 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 105 | 106 | def _wrap_model(self, model): 107 | if isinstance(model, _WrappedModel): 108 | return model 109 | return _WrappedModel( 110 | model, self.timestep_map, self.original_num_steps 111 | ) 112 | 113 | def _scale_timesteps(self, t): 114 | # Scaling is done by the wrapped model. 
115 | return t 116 | 117 | 118 | class _WrappedModel: 119 | def __init__(self, model, timestep_map, original_num_steps): 120 | self.model = model 121 | self.timestep_map = timestep_map 122 | # self.rescale_timesteps = rescale_timesteps 123 | self.original_num_steps = original_num_steps 124 | 125 | def __call__(self, x, ts, **kwargs): 126 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 127 | new_ts = map_tensor[ts] 128 | # if self.rescale_timesteps: 129 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 130 | return self.model(x, new_ts, **kwargs) 131 | -------------------------------------------------------------------------------- /diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 
75 |         Call this method from each rank with a batch of timesteps and the
76 |         corresponding losses for each of those timesteps.
77 |         This method will perform synchronization to make sure all of the ranks
78 |         maintain the exact same reweighting.
79 |         :param local_ts: an integer Tensor of timesteps.
80 |         :param local_losses: a 1D Tensor of losses.
81 |         """
82 |         batch_sizes = [
83 |             th.tensor([0], dtype=th.int32, device=local_ts.device)
84 |             for _ in range(dist.get_world_size())
85 |         ]
86 |         dist.all_gather(
87 |             batch_sizes,
88 |             th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 |         )
90 | 
91 |         # Pad all_gather batches to be the maximum batch size.
92 |         batch_sizes = [x.item() for x in batch_sizes]
93 |         max_bs = max(batch_sizes)
94 | 
95 |         timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 |         loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 |         dist.all_gather(timestep_batches, local_ts)
98 |         dist.all_gather(loss_batches, local_losses)
99 |         timesteps = [
100 |             x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 |         ]
102 |         losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 |         self.update_with_all_losses(timesteps, losses)
104 | 
105 |     @abstractmethod
106 |     def update_with_all_losses(self, ts, losses):
107 |         """
108 |         Update the reweighting using losses from a model.
109 |         Sub-classes should override this method to update the reweighting
110 |         using losses from the model.
111 |         This method directly updates the reweighting without synchronizing
112 |         between workers. It is called by update_with_local_losses from all
113 |         ranks with identical arguments. Thus, it should have deterministic
114 |         behavior to maintain state across workers.
115 |         :param ts: a list of int timesteps.
116 |         :param losses: a list of float losses, one per timestep.
117 |         """
118 | 
119 | 
120 | class LossSecondMomentResampler(LossAwareSampler):
121 |     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 |         self.diffusion = diffusion
123 |         self.history_per_term = history_per_term
124 |         self.uniform_prob = uniform_prob
125 |         self._loss_history = np.zeros(
126 |             [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 |         )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129 | 
130 |     def weights(self):
131 |         if not self._warmed_up():
132 |             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 |         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 |         weights /= np.sum(weights)
135 |         weights *= 1 - self.uniform_prob
136 |         weights += self.uniform_prob / len(weights)
137 |         return weights
138 | 
139 |     def update_with_all_losses(self, ts, losses):
140 |         for t, loss in zip(ts, losses):
141 |             if self._loss_counts[t] == self.history_per_term:
142 |                 # Shift out the oldest loss term.
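                # (the history buffer is a per-timestep FIFO of the most
                # recent history_per_term losses)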
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: Occworld 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 7 | - conda-forge 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 11 | dependencies: 12 | - _libgcc_mutex=0.1=main 13 | - _openmp_mutex=5.1=1_gnu 14 | - arrow=1.2.3=py38h06a4308_1 15 | - attrs=23.1.0=pyh71513ae_1 16 | - backoff=2.2.1=pyhd8ed1ab_0 17 | - backports=1.0=pyhd8ed1ab_3 18 | - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0 19 | - beautifulsoup4=4.12.2=pyha770c72_0 20 | - blas=1.0=mkl 21 | - blessed=1.19.1=pyhe4f9e05_2 22 | - boto3=1.28.66=pyhd8ed1ab_0 23 | - botocore=1.31.66=pyhd8ed1ab_0 24 | - brotlipy=0.7.0=py38h27cfd23_1003 25 | - bzip2=1.0.8=h7b6447c_0 26 | - ca-certificates=2024.7.2=h06a4308_0 27 | - cachecontrol=0.12.11=py38h06a4308_1 28 | - certifi=2024.7.4=py38h06a4308_0 29 | - cffi=1.15.0=py38h7f8727e_0 30 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 31 | - cleo=2.0.1=pyhd8ed1ab_0 32 | - click=8.1.7=unix_pyh707e725_0 33 | - colorama=0.4.6=pyhd8ed1ab_0 34 | - commonmark=0.9.1=pyhd3eb1b0_0 35 | - crashtest=0.4.1=pyhd8ed1ab_0 36 | - croniter=1.4.1=pyhd8ed1ab_0 37 | - cryptography=3.4.8=py38h3e25421_1 38 | - cuda-cudart=11.8.89=0 39 | - cuda-cupti=11.8.87=0 40 | - cuda-libraries=11.8.0=0 41 | - cuda-nvrtc=11.8.89=0 42 | - cuda-nvtx=11.8.86=0 43 | - cuda-runtime=11.8.0=0 44 | - dataclasses=0.8=pyh6d0b6a4_7 45 | - dateutils=0.6.12=py_0 46 | - dbus=1.13.18=hb2f20db_0 47 | - deepdiff=5.8.1=pyhd8ed1ab_0 48 | - distlib=0.3.7=pyhd8ed1ab_0 49 | - dulwich=0.21.3=py38h5eee18b_0 50 | - exceptiongroup=1.1.3=pyhd8ed1ab_0 51 | - expat=2.2.10=h9c3ff4c_0 52 | - fastapi=0.103.2=pyhd8ed1ab_0 53 | - ffmpeg=4.3=hf484d3e_0 54 | - filelock=3.9.0=py38h06a4308_0 55 | - freetype=2.12.1=h4a9f257_0 56 | - future=0.18.3=py38h06a4308_0 57 | - gettext=0.21.0=h39681ba_1 58 | - giflib=5.2.1=h5eee18b_3 59 | - glib=2.56.2=had28632_1001 60 | - gmp=6.2.1=h295c915_3 61 | - gmpy2=2.1.2=py38heeb90bb_0 62 | - gnutls=3.6.15=he1e5248_0 63 | - h11=0.14.0=pyhd8ed1ab_0 64 | - html5lib=1.1=pyh9f0ad1d_0 65 | - icu=58.2=hf484d3e_1000 66 | - idna=3.4=py38h06a4308_0 67 | - importlib-metadata=6.8.0=pyha770c72_0 68 | - importlib_metadata=6.8.0=hd8ed1ab_0 69 | - inquirer=3.1.3=pyhd8ed1ab_0 70 | - intel-openmp=2021.4.0=h06a4308_3561 71 | - itsdangerous=2.1.2=pyhd8ed1ab_0 72 | - jaraco.classes=3.3.0=pyhd8ed1ab_0 73 | - jeepney=0.8.0=pyhd8ed1ab_0 74 | - jinja2=3.1.2=py38h06a4308_0 75 | - jpeg=9e=h5eee18b_1 76 | - jsonschema=4.19.2=py38h06a4308_0 77 | - jsonschema-specifications=2023.7.1=py38h06a4308_0 78 | - keyring=23.13.1=py38h578d9bd_0 79 | - lame=3.100=h7b6447c_0 80 | - lcms2=2.12=h3be6417_0 81 | - lerc=3.0=h295c915_0 82 | - libaio=0.3.113=h5eee18b_0 83 | - libcublas=11.11.3.6=0 84 | - libcufft=10.9.0.58=0 85 | - libcufile=1.7.2.10=0 86 | - libcurand=10.3.3.141=0 87 | - libcusolver=11.4.1.48=0 88 | - 
libcusparse=11.7.5.86=0 89 | - libdeflate=1.17=h5eee18b_0 90 | - libedit=3.1.20221030=h5eee18b_0 91 | - libffi=3.2.1=hf484d3e_1007 92 | - libgcc-ng=11.2.0=h1234567_1 93 | - libgomp=11.2.0=h1234567_1 94 | - libiconv=1.16=h7f8727e_2 95 | - libidn2=2.3.4=h5eee18b_0 96 | - libnpp=11.8.0.86=0 97 | - libnvjpeg=11.9.0.86=0 98 | - libpng=1.6.39=h5eee18b_0 99 | - libstdcxx-ng=11.2.0=h1234567_1 100 | - libtasn1=4.19.0=h5eee18b_0 101 | - libtiff=4.5.1=h6a678d5_0 102 | - libunistring=0.9.10=h27cfd23_0 103 | - libwebp=1.2.4=h11a3e52_1 104 | - libwebp-base=1.2.4=h5eee18b_1 105 | - libxml2=2.10.4=hcbfbd50_0 106 | - lightning=2.1.0=pyhd8ed1ab_0 107 | - lightning-cloud=0.5.42=pyhd8ed1ab_0 108 | - lightning-utilities=0.9.0=pyhd8ed1ab_0 109 | - lockfile=0.12.2=py_1 110 | - lz4-c=1.9.4=h6a678d5_0 111 | - markdown-it-py=3.0.0=pyhd8ed1ab_0 112 | - markupsafe=2.1.1=py38h7f8727e_0 113 | - mkl=2021.4.0=h06a4308_640 114 | - mkl-service=2.4.0=py38h7f8727e_0 115 | - mkl_fft=1.3.1=py38hd3c417c_0 116 | - mkl_random=1.2.2=py38h51133e4_0 117 | - more-itertools=10.1.0=pyhd8ed1ab_0 118 | - mpc=1.1.0=h10f8cd9_1 119 | - mpfr=4.0.2=hb69a4c5_1 120 | - mpmath=1.3.0=py38h06a4308_0 121 | - msgpack-python=1.0.3=py38hd09550d_0 122 | - ncurses=6.4=h6a678d5_0 123 | - nettle=3.7.3=hbbd107a_1 124 | - networkx=3.1=py38h06a4308_0 125 | - numpy=1.24.3=py38h14f4228_0 126 | - numpy-base=1.24.3=py38h31eccc5_0 127 | - openh264=2.1.1=h4ff587b_0 128 | - openssl=1.1.1w=h7f8727e_0 129 | - ordered-set=4.1.0=pyhd8ed1ab_0 130 | - pcre=8.45=h9c3ff4c_0 131 | - pexpect=4.8.0=pyh1a96a4e_2 132 | - pillow=9.4.0=py38h6a678d5_0 133 | - pip=23.2.1=py38h06a4308_0 134 | - pkginfo=1.9.6=pyhd8ed1ab_0 135 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1 136 | - poetry=1.4.0=py38h06a4308_0 137 | - poetry-core=1.5.1=py38h06a4308_0 138 | - poetry-plugin-export=1.3.0=py38h4849bfd_0 139 | - ptyprocess=0.7.0=pyhd3deb0d_0 140 | - pycparser=2.21=pyhd3eb1b0_0 141 | - pydantic=1.10.12=py38h5eee18b_1 142 | - pygments=2.16.1=pyhd8ed1ab_0 143 | - pyjwt=2.8.0=pyhd8ed1ab_0 144 | - pyopenssl=20.0.1=pyhd8ed1ab_0 145 | - pyproject_hooks=1.0.0=pyhd8ed1ab_0 146 | - pyrsistent=0.18.0=py38heee7806_0 147 | - pysocks=1.7.1=py38h06a4308_0 148 | - python=3.8.0=h0371630_2 149 | - python-build=0.10.0=pyhd8ed1ab_1 150 | - python-dateutil=2.8.2=pyhd8ed1ab_0 151 | - python-editor=1.0.4=py_0 152 | - python-installer=0.6.0=py38h06a4308_0 153 | - python-multipart=0.0.6=pyhd8ed1ab_0 154 | - python_abi=3.8=2_cp38 155 | - pytorch=2.0.1=py3.8_cuda11.8_cudnn8.7.0_0 156 | - pytorch-cuda=11.8=h7e8668a_5 157 | - pytorch-lightning=2.1.0=pyhd8ed1ab_0 158 | - pytorch-mutex=1.0=cuda 159 | - pytz=2023.3.post1=pyhd8ed1ab_0 160 | - pyyaml=6.0.1=py38h5eee18b_0 161 | - rapidfuzz=2.13.7=py38h417a72b_0 162 | - readchar=4.0.5=pyhd8ed1ab_0 163 | - readline=7.0=h7b6447c_5 164 | - referencing=0.30.2=py38h06a4308_0 165 | - requests-toolbelt=0.10.1=pyhd8ed1ab_0 166 | - s3transfer=0.7.0=pyhd8ed1ab_0 167 | - secretstorage=3.3.3=py38h578d9bd_2 168 | - shellingham=1.5.3=pyhd8ed1ab_0 169 | - six=1.16.0=pyhd3eb1b0_1 170 | - sniffio=1.3.0=pyhd8ed1ab_0 171 | - soupsieve=2.5=pyhd8ed1ab_1 172 | - sqlite=3.33.0=h62c20be_0 173 | - starlette=0.27.0=pyhd8ed1ab_0 174 | - starsessions=1.3.0=pyhd8ed1ab_0 175 | - sympy=1.11.1=py38h06a4308_0 176 | - tk=8.6.12=h1ccaba5_0 177 | - tomli=2.0.1=pyhd8ed1ab_0 178 | - tomlkit=0.12.1=pyha770c72_0 179 | - torchaudio=2.0.2=py38_cu118 180 | - torchmetrics=1.2.0=pyhd8ed1ab_0 181 | - torchtriton=2.0.0=py38 182 | - torchvision=0.15.2=py38_cu118 183 | - trove-classifiers=2023.10.18=pyhd8ed1ab_0 184 | - 
types-python-dateutil=2.8.19.14=pyhd8ed1ab_0 185 | - typing-extensions=4.7.1=py38h06a4308_0 186 | - typing_extensions=4.7.1=py38h06a4308_0 187 | - urllib3=1.26.16=py38h06a4308_0 188 | - uvicorn=0.23.2=py38h578d9bd_1 189 | - virtualenv=20.17.1=py38h578d9bd_0 190 | - webencodings=0.5.1=pyhd8ed1ab_2 191 | - websockets=10.4=py38h5eee18b_1 192 | - wheel=0.38.4=py38h06a4308_0 193 | - xz=5.4.2=h5eee18b_0 194 | - yaml=0.2.5=h7f98852_2 195 | - zlib=1.2.13=h5eee18b_0 196 | - zstd=1.5.5=hc292b87_0 197 | - pip: 198 | - absl-py==2.0.0 199 | - accelerate==0.30.1 200 | - addict==2.4.0 201 | - aiofiles==22.1.0 202 | - aiosqlite==0.20.0 203 | - aliyun-python-sdk-core==2.13.36 204 | - aliyun-python-sdk-kms==2.16.2 205 | - ansi2html==1.8.0 206 | - anyio==4.0.0 207 | - apptools==5.2.1 208 | - argon2-cffi==23.1.0 209 | - argon2-cffi-bindings==21.2.0 210 | - asttokens==2.4.0 211 | - async-lru==2.0.4 212 | - av==12.0.0 213 | - babel==2.12.1 214 | - backcall==0.2.0 215 | - black==23.11.0 216 | - bleach==6.0.0 217 | - blinker==1.7.0 218 | - cachetools==5.3.1 219 | - comm==0.1.4 220 | - configargparse==1.7 221 | - configobj==5.0.8 222 | - contourpy==1.1.0 223 | - crcmod==1.7 224 | - cycler==0.11.0 225 | - dash==2.14.1 226 | - dash-core-components==2.0.0 227 | - dash-html-components==2.0.0 228 | - dash-table==5.0.0 229 | - debugpy==1.7.0 230 | - decorator==5.1.1 231 | - decord==0.6.0 232 | - defusedxml==0.7.1 233 | - deprecation==2.1.0 234 | - descartes==1.1.0 235 | - diffusers==0.24.0 236 | - einops==0.7.0 237 | - entrypoints==0.4 238 | - envisage==7.0.3 239 | - executing==1.2.0 240 | - fastjsonschema==2.18.0 241 | - fire==0.5.0 242 | - flake8==5.0.4 243 | - flask==3.0.0 244 | - fonttools==4.42.1 245 | - fqdn==1.5.1 246 | - fsspec==2023.9.0 247 | - google-auth==2.23.4 248 | - google-auth-oauthlib==1.0.0 249 | - grpcio==1.59.2 250 | - huggingface-hub==0.23.1 251 | - imageio==2.32.0 252 | - imageio-ffmpeg==0.4.9 253 | - importlib-resources==6.0.1 254 | - iniconfig==2.0.0 255 | - ipdb==0.13.13 256 | - ipykernel==6.25.2 257 | - ipython==8.12.2 258 | - ipython-genutils==0.2.0 259 | - ipywidgets==8.1.0 260 | - isoduration==20.11.0 261 | - jedi==0.19.0 262 | - jmespath==0.10.0 263 | - joblib==1.3.2 264 | - json5==0.9.14 265 | - jsonpointer==2.4 266 | - jupyter==1.0.0 267 | - jupyter-client==7.4.9 268 | - jupyter-console==6.6.3 269 | - jupyter-core==5.3.1 270 | - jupyter-events==0.7.0 271 | - jupyter-lsp==2.2.0 272 | - jupyter-packaging==0.12.3 273 | - jupyter-server==2.7.3 274 | - jupyter-server-fileid==0.9.1 275 | - jupyter-server-terminals==0.4.4 276 | - jupyter-server-ydoc==0.8.0 277 | - jupyter-ydoc==0.2.5 278 | - jupyterlab==3.6.7 279 | - jupyterlab-pygments==0.2.2 280 | - jupyterlab-server==2.24.0 281 | - jupyterlab-widgets==3.0.8 282 | - kiwisolver==1.4.5 283 | - lazy-loader==0.3 284 | - llvmlite==0.40.1 285 | - lyft-dataset-sdk==0.0.8 286 | - markdown==3.4.4 287 | - matplotlib==3.5.2 288 | - matplotlib-inline==0.1.6 289 | - mayavi==4.8.1 290 | - mccabe==0.7.0 291 | - mdurl==0.1.2 292 | - mistune==3.0.1 293 | - mmcv==2.0.1 294 | - mmdet==3.3.0 295 | - mmdet3d==1.4.0 296 | - mmengine==0.8.4 297 | - model-index==0.1.11 298 | - mypy-extensions==1.0.0 299 | - nbclassic==1.0.0 300 | - nbclient==0.8.0 301 | - nbconvert==7.8.0 302 | - nbformat==5.7.0 303 | - nest-asyncio==1.5.7 304 | - notebook==6.5.6 305 | - notebook-shim==0.2.3 306 | - numba==0.57.1 307 | - nuscenes-devkit==1.1.10 308 | - oauthlib==3.2.2 309 | - open3d==0.13.0 310 | - open3d-python==0.3.0.0 311 | - opencv-python==4.9.0.80 312 | - opendatalab==0.0.10 313 
| - openmim==0.3.9 314 | - openxlab==0.0.24 315 | - oss2==2.17.0 316 | - overrides==7.4.0 317 | - packaging==23.1 318 | - pandas==2.0.3 319 | - pandocfilters==1.5.0 320 | - parso==0.8.3 321 | - pathspec==0.11.2 322 | - pickleshare==0.7.5 323 | - platformdirs==3.10.0 324 | - plotly==5.18.0 325 | - pluggy==1.3.0 326 | - plyfile==1.0.1 327 | - prettytable==3.9.0 328 | - prometheus-client==0.17.1 329 | - prompt-toolkit==3.0.39 330 | - protobuf==4.25.0 331 | - psutil==5.9.5 332 | - pure-eval==0.2.2 333 | - pyasn1==0.5.0 334 | - pyasn1-modules==0.3.0 335 | - pycocotools==2.0.7 336 | - pycodestyle==2.9.1 337 | - pycryptodome==3.18.0 338 | - pyface==8.0.0 339 | - pyflakes==2.5.0 340 | - pyparsing==3.0.9 341 | - pyqt5==5.15.10 342 | - pyqt5-qt5==5.15.2 343 | - pyqt5-sip==12.13.0 344 | - pyquaternion==0.9.9 345 | - pytest==7.4.3 346 | - python-json-logger==2.0.7 347 | - pytorch-fast-transformers==0.4.0 348 | - pytorch-fid==0.3.0 349 | - pyvirtualdisplay==3.0 350 | - pywavelets==1.4.1 351 | - pyzmq==24.0.1 352 | - qtconsole==5.4.4 353 | - qtpy==2.4.0 354 | - regex==2024.5.15 355 | - requests==2.28.2 356 | - requests-oauthlib==1.3.1 357 | - retrying==1.3.4 358 | - rfc3339-validator==0.1.4 359 | - rfc3986-validator==0.1.1 360 | - rich==13.4.2 361 | - rpds-py==0.10.2 362 | - rsa==4.9 363 | - safetensors==0.4.3 364 | - scikit-image==0.21.0 365 | - scikit-learn==1.3.0 366 | - scipy==1.10.1 367 | - send2trash==1.8.2 368 | - setuptools==59.5.0 369 | - shapely==1.8.5 370 | - stack-data==0.6.2 371 | - swin-window-process==0.0.0 372 | - tabulate==0.9.0 373 | - tenacity==8.2.3 374 | - tensorboard==2.14.0 375 | - tensorboard-data-server==0.7.2 376 | - termcolor==2.3.0 377 | - terminado==0.17.1 378 | - terminaltables==3.1.10 379 | - thop==0.1.1-2209072238 380 | - threadpoolctl==3.2.0 381 | - tifffile==2023.7.10 382 | - timm==0.9.7 383 | - tinycss2==1.2.1 384 | - tokenizers==0.19.1 385 | - tornado==6.3.3 386 | - tqdm==4.65.2 387 | - traitlets==5.9.0 388 | - traits==6.4.2 389 | - traitsui==8.0.0 390 | - transformers==4.41.1 391 | - trimesh==4.0.4 392 | - tzdata==2023.3 393 | - uri-template==1.3.0 394 | - vtk==9.2.6 395 | - wcwidth==0.2.6 396 | - webcolors==1.13 397 | - websocket-client==1.6.3 398 | - werkzeug==3.0.1 399 | - widgetsnbextension==4.0.8 400 | - xformers==0.0.22 401 | - y-py==0.6.2 402 | - yapf==0.40.1 403 | - ypy-websocket==0.8.4 404 | - zipp==3.16.2 405 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from mmengine.registry import Registry 2 | OPENOCC_LOSS = Registry('openocc_loss') 3 | 4 | from .multi_loss import MultiLoss 5 | from .ce_loss import CeLoss 6 | from .emb_loss import VQVAEEmbedLoss 7 | from .recon_loss import ReconLoss, LovaszLoss -------------------------------------------------------------------------------- /loss/base_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | writer = None 3 | 4 | class BaseLoss(nn.Module): 5 | 6 | """ Base loss class. 7 | args: 8 | weight: weight of current loss. 9 | input_keys: keys for actual inputs to calculate_loss(). 10 | Since "inputs" may contain many different fields, we use input_keys 11 | to distinguish them. 12 | loss_func: the actual loss func to calculate loss. 
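        Subclasses are expected to assign self.loss_func and self.input_dict;
        forward() then looks up inputs[input_val] for every
        (input_key, input_val) pair in input_dict and calls
        loss_func(**actual_inputs), scaling the result by self.weight.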
13 | """ 14 | 15 | def __init__( 16 | self, 17 | weight=1.0, 18 | input_dict={ 19 | 'input': 'input'}, 20 | **kwargs): 21 | super().__init__() 22 | self.weight = weight 23 | self.input_dict = input_dict 24 | self.loss_func = lambda: 0 25 | self.writer = writer 26 | 27 | # def calculate_loss(self, **kwargs): 28 | # return self.loss_func(*[kwargs[key] for key in self.input_keys]) 29 | 30 | def forward(self, inputs): 31 | actual_inputs = {} 32 | for input_key, input_val in self.input_dict.items(): 33 | actual_inputs.update({input_key: inputs[input_val]}) 34 | # return self.weight * self.calculate_loss(**actual_inputs) 35 | return self.weight * self.loss_func(**actual_inputs) 36 | -------------------------------------------------------------------------------- /loss/ce_loss.py: -------------------------------------------------------------------------------- 1 | from .base_loss import BaseLoss 2 | from . import OPENOCC_LOSS 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | @OPENOCC_LOSS.register_module() 7 | class CeLoss(BaseLoss): 8 | 9 | def __init__(self, weight=1.0, ignore_label=-100, 10 | use_weight=False, cls_weight=None, input_dict=None, **kwargs): 11 | super().__init__(weight) 12 | 13 | if input_dict is None: 14 | self.input_dict = { 15 | 'ce_inputs': 'ce_inputs', 16 | 'ce_labels': 'ce_labels' 17 | } 18 | else: 19 | self.input_dict = input_dict 20 | self.loss_func = self.ce_loss 21 | self.ignore = ignore_label 22 | self.use_weight = use_weight 23 | self.cls_weight = torch.tensor(cls_weight) if cls_weight is not None else None 24 | 25 | def ce_loss(self, ce_inputs, ce_labels): 26 | # input: -1, c 27 | # output: -1, 1 28 | ce_loss = F.cross_entropy(ce_inputs, ce_labels) 29 | return ce_loss -------------------------------------------------------------------------------- /loss/emb_loss.py: -------------------------------------------------------------------------------- 1 | from .base_loss import BaseLoss 2 | from . import OPENOCC_LOSS 3 | 4 | @OPENOCC_LOSS.register_module() 5 | class VQVAEEmbedLoss(BaseLoss): 6 | 7 | def __init__(self, weight=1.0, input_dict=None, **kwargs): 8 | super().__init__(weight) 9 | 10 | if input_dict is None: 11 | self.input_dict = { 12 | 'embed_loss': 'embed_loss' 13 | } 14 | else: 15 | self.input_dict = input_dict 16 | self.loss_func = self.embed_loss 17 | 18 | def embed_loss(self, embed_loss): 19 | return embed_loss -------------------------------------------------------------------------------- /loss/multi_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from . import OPENOCC_LOSS 3 | writer = None 4 | 5 | @OPENOCC_LOSS.register_module() 6 | class MultiLoss(nn.Module): 7 | 8 | def __init__(self, loss_cfgs): 9 | super().__init__() 10 | 11 | assert isinstance(loss_cfgs, list) 12 | self.num_losses = len(loss_cfgs) 13 | 14 | losses = [] 15 | for loss_cfg in loss_cfgs: 16 | losses.append(OPENOCC_LOSS.build(loss_cfg)) 17 | self.losses = nn.ModuleList(losses) 18 | self.iter_counter = 0 19 | 20 | def forward(self, inputs): 21 | 22 | loss_dict = {} 23 | tot_loss = 0. 
24 | for loss_func in self.losses: 25 | loss = loss_func(inputs) 26 | tot_loss += loss 27 | loss_dict.update({ 28 | loss_func.__class__.__name__: \ 29 | loss.detach().item() / loss_func.weight 30 | }) 31 | if writer and self.iter_counter % 10 == 0: 32 | writer.add_scalar( 33 | f'loss/{loss_func.__class__.__name__}', 34 | loss.detach().item(), self.iter_counter) 35 | if writer and self.iter_counter % 10 == 0: 36 | writer.add_scalar( 37 | 'loss/total', tot_loss.detach().item(), self.iter_counter) 38 | self.iter_counter += 1 39 | 40 | return tot_loss, loss_dict -------------------------------------------------------------------------------- /loss/recon_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | from .base_loss import BaseLoss 4 | from . import OPENOCC_LOSS 5 | import torch.nn.functional as F 6 | 7 | @OPENOCC_LOSS.register_module() 8 | class ReconLoss(BaseLoss): 9 | 10 | def __init__(self, weight=1.0, ignore_label=-100, use_weight=False, cls_weight=None, input_dict=None, **kwargs): 11 | super().__init__(weight) 12 | 13 | if input_dict is None: 14 | self.input_dict = { 15 | 'logits': 'logits', 16 | 'labels': 'labels' 17 | } 18 | else: 19 | self.input_dict = input_dict 20 | self.loss_func = self.recon_loss 21 | self.ignore = ignore_label 22 | self.use_weight = use_weight 23 | self.cls_weight = torch.tensor(cls_weight) if cls_weight is not None else None 24 | 25 | def recon_loss(self, logits, labels): 26 | weight = None 27 | if self.use_weight: 28 | if self.cls_weight is not None: 29 | weight = self.cls_weight 30 | else: 31 | one_hot_labels = F.one_hot(labels, num_classes=logits.shape[-1]) # bs, F, H, W, D, C 32 | cls_freq = torch.sum(one_hot_labels, dim=[0, 1, 2, 3, 4]) # C 33 | weight = 1.0 / cls_freq.clamp_min_(1) * torch.numel(labels) / logits.shape[-1] 34 | 35 | rec_loss = F.cross_entropy(logits.permute(0, 5, 1, 2, 3, 4), labels, ignore_index=self.ignore, weight=weight) 36 | return rec_loss 37 | 38 | @OPENOCC_LOSS.register_module() 39 | class LovaszLoss(BaseLoss): 40 | 41 | def __init__(self, weight=1.0, input_dict=None, **kwargs): 42 | super().__init__(weight) 43 | 44 | if input_dict is None: 45 | self.input_dict = { 46 | 'logits': 'logits', 47 | 'labels': 'labels' 48 | } 49 | else: 50 | self.input_dict = input_dict 51 | self.loss_func = self.lovasz_loss 52 | 53 | def lovasz_loss(self, logits, labels): 54 | logits = logits.flatten(0, 1).permute(0, 4, 1, 2, 3).softmax(dim=1) 55 | labels = labels.flatten(0, 1) 56 | loss = lovasz_softmax(logits, labels) 57 | return loss 58 | 59 | 60 | @OPENOCC_LOSS.register_module() 61 | class KLLoss(BaseLoss): 62 | 63 | def __init__(self, weight=1.0, input_dict=None, **kwargs): 64 | super().__init__(weight) 65 | 66 | if input_dict is None: 67 | self.input_dict = { 68 | 'z_mu': 'z_mu', 69 | 'z_sigma': 'z_sigma' 70 | } 71 | else: 72 | self.input_dict = input_dict 73 | self.loss_func = self.KL_loss 74 | 75 | def KL_loss(self, z_mu, z_sigma): 76 | z_mu = z_mu.permute(0,2,1,3,4).flatten(0, 1) #mu 77 | z_sigma = z_sigma.permute(0,2,1,3,4).flatten(0, 1) #logvar 78 | # mu 79 | # log_var 80 | loss = torch.mean(-0.5 * torch.sum(1 + z_sigma - z_mu ** 2 - z_sigma.exp(), dim = [1,2,3]), dim = 0) 81 | return loss 82 | 83 | 84 | """ 85 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 86 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 87 | """ 88 | 89 | import torch 90 | from torch.autograd import Variable 91 | import torch.nn.functional as F 92 | try: 93 | from 
itertools import ifilterfalse 94 | except ImportError: # py3k 95 | from itertools import filterfalse as ifilterfalse 96 | 97 | 98 | def lovasz_grad(gt_sorted): 99 | """ 100 | Computes gradient of the Lovasz extension w.r.t sorted errors 101 | See Alg. 1 in paper 102 | """ 103 | p = len(gt_sorted) 104 | gts = gt_sorted.sum() 105 | intersection = gts - gt_sorted.float().cumsum(0) 106 | union = gts + (1 - gt_sorted).float().cumsum(0) 107 | jaccard = 1. - intersection / union 108 | if p > 1: # cover 1-pixel case 109 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 110 | return jaccard 111 | 112 | # --------------------------- MULTICLASS LOSSES --------------------------- 113 | 114 | 115 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 116 | """ 117 | Multi-class Lovasz-Softmax loss 118 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 119 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 120 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 121 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 122 | per_image: compute the loss per image instead of per batch 123 | ignore: void class labels 124 | """ 125 | if per_image: 126 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 127 | for prob, lab in zip(probas, labels)) 128 | else: 129 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 130 | return loss 131 | 132 | 133 | def lovasz_softmax_flat(probas, labels, classes='present'): 134 | """ 135 | Multi-class Lovasz-Softmax loss 136 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 137 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 138 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 139 | """ 140 | if probas.numel() == 0: 141 | # only void pixels, the gradients should be 0 142 | return 0.#probas * 0. 
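    # per-class Lovasz step below: compute the errors |fg - p_c|, sort them in
    # decreasing order, and take their dot product with the Lovasz gradient of
    # the sorted ground truth, which yields a convex surrogate of the
    # per-class Jaccard loss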
143 | #print(probas.size()) 144 | C = probas.size(1) 145 | losses = [] 146 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 147 | for c in class_to_sum: 148 | fg = (labels == c).float() # foreground for class c 149 | if (classes == 'present' and fg.sum() == 0): 150 | continue 151 | if C == 1: 152 | if len(classes) > 1: 153 | raise ValueError('Sigmoid output possible only with 1 class') 154 | class_pred = probas[:, 0] 155 | else: 156 | class_pred = probas[:, c] 157 | errors = (Variable(fg) - class_pred).abs() 158 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 159 | perm = perm.data 160 | fg_sorted = fg[perm] 161 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 162 | return mean(losses) 163 | 164 | 165 | def flatten_probas(probas, labels, ignore=None): 166 | """ 167 | Flattens predictions in the batch 168 | """ 169 | if probas.dim() == 3: 170 | # assumes output of a sigmoid layer 171 | B, H, W = probas.size() 172 | probas = probas.view(B, 1, H, W) 173 | elif probas.dim() == 5: 174 | #3D segmentation 175 | B, C, L, H, W = probas.size() 176 | probas = probas.contiguous().view(B, C, L, H*W) 177 | B, C, H, W = probas.size() 178 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 179 | labels = labels.view(-1) 180 | if ignore is None: 181 | return probas, labels 182 | valid = (labels != ignore) 183 | vprobas = probas[valid]#.nonzero().squeeze()] 184 | # print(labels) 185 | # print(valid) 186 | vlabels = labels[valid] 187 | return vprobas, vlabels 188 | 189 | # --------------------------- HELPER FUNCTIONS --------------------------- 190 | 191 | def isnan(x): 192 | return x != x 193 | 194 | def mean(l, ignore_nan=False, empty=0): 195 | """ 196 | nanmean compatible with generators. 197 | """ 198 | l = iter(l) 199 | if ignore_nan: 200 | l = ifilterfalse(isnan, l) 201 | try: 202 | n = 1 203 | acc = next(l) 204 | except StopIteration: 205 | if empty == 'raise': 206 | raise ValueError('Empty mean') 207 | return empty 208 | for n, v in enumerate(l, 2): 209 | acc += v 210 | if n == 1: 211 | return acc 212 | return acc / n 213 | -------------------------------------------------------------------------------- /model/VAE/quantizer.py: -------------------------------------------------------------------------------- 1 | 2 | """ adapted from: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from einops import rearrange 9 | # from mmseg.models import BACKBONES 10 | from mmengine.registry import MODELS 11 | from mmengine.model import BaseModule 12 | 13 | @MODELS.register_module() 14 | class VectorQuantizer(BaseModule): 15 | """ 16 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 17 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 18 | """ 19 | # NOTE: due to a bug the beta term was applied to the wrong term. for 20 | # backwards compatibility we use the buggy version by default, but you can 21 | # specify legacy=False to fix it. 
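    # A minimal usage sketch (hypothetical shapes and hyper-parameters, not
    # taken from this repo's configs):
    #   vq = VectorQuantizer(n_e=512, e_dim=64, beta=0.25, z_channels=128,
    #                        use_voxel=True)
    #   z_q, vq_loss, _ = vq(z)  # z: (B, 128, D, H, W) voxel latents
    # forward() applies quant_conv, a nearest-neighbour codebook lookup with a
    # straight-through gradient, then post_quant_conv.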
22 | def __init__(self, n_e, e_dim, beta, z_channels, remap=None, unknown_index="random", 23 | sane_index_shape=False, legacy=True, use_voxel=True): 24 | super().__init__() 25 | self.n_e = n_e 26 | self.e_dim = e_dim 27 | self.beta = beta 28 | self.legacy = legacy 29 | 30 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 31 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 32 | 33 | self.remap = remap 34 | if self.remap is not None: 35 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 36 | self.re_embed = self.used.shape[0] 37 | self.unknown_index = unknown_index # "random" or "extra" or integer 38 | if self.unknown_index == "extra": 39 | self.unknown_index = self.re_embed 40 | self.re_embed = self.re_embed+1 41 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " 42 | f"Using {self.unknown_index} for unknown indices.") 43 | else: 44 | self.re_embed = n_e 45 | 46 | self.sane_index_shape = sane_index_shape 47 | 48 | conv_class = torch.nn.Conv3d if use_voxel else torch.nn.Conv2d 49 | self.quant_conv = conv_class(z_channels, self.e_dim, 1) 50 | self.post_quant_conv = conv_class(self.e_dim, z_channels, 1) 51 | 52 | def remap_to_used(self, inds): 53 | ishape = inds.shape 54 | assert len(ishape)>1 55 | inds = inds.reshape(ishape[0],-1) 56 | used = self.used.to(inds) 57 | match = (inds[:,:,None]==used[None,None,...]).long() 58 | new = match.argmax(-1) 59 | unknown = match.sum(2)<1 60 | if self.unknown_index == "random": 61 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 62 | else: 63 | new[unknown] = self.unknown_index 64 | return new.reshape(ishape) 65 | 66 | def unmap_to_all(self, inds): 67 | ishape = inds.shape 68 | assert len(ishape)>1 69 | inds = inds.reshape(ishape[0],-1) 70 | used = self.used.to(inds) 71 | if self.re_embed > self.used.shape[0]: # extra token 72 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 73 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 74 | return back.reshape(ishape) 75 | 76 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False, is_voxel=False): 77 | z = self.quant_conv(z) 78 | z_q, loss, (perplexity, min_encodings, min_encoding_indices) = self.forward_quantizer(z, temp, rescale_logits, return_logits, is_voxel) 79 | z_q = self.post_quant_conv(z_q) 80 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 81 | def forward_quantizer(self, z, temp=None, rescale_logits=False, return_logits=False, is_voxel=False): 82 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 83 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 84 | assert return_logits==False, "Only for interface compatible with Gumbel" 85 | 86 | # reshape z -> (batch, height, width, channel) and flatten 87 | if not is_voxel: 88 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 89 | else: 90 | z = rearrange(z, 'b c d h w -> b d h w c').contiguous() 91 | z_flattened = z.view(-1, self.e_dim) 92 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 93 | 94 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 95 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 96 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 97 | 98 | min_encoding_indices = torch.argmin(d, dim=1) 99 | z_q = self.embedding(min_encoding_indices).view(z.shape) 100 | perplexity = None 101 | min_encodings = None 102 | 103 | # compute loss for embedding 104 | if not 
self.legacy:
105 |             loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
106 |                    torch.mean((z_q - z.detach()) ** 2)
107 |         else:
108 |             loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
109 |                    torch.mean((z_q - z.detach()) ** 2)
110 | 
111 |         # preserve gradients
112 |         z_q = z + (z_q - z).detach()
113 | 
114 |         # reshape back to match original input shape
115 |         if not is_voxel:
116 |             z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
117 |         else:
118 |             z_q = rearrange(z_q, 'b d h w c -> b c d h w').contiguous()
119 | 
120 | 
121 | 
122 |         if self.remap is not None:
123 |             min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
124 |             min_encoding_indices = self.remap_to_used(min_encoding_indices)
125 |             min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
126 | 
127 |         if self.sane_index_shape:
128 |             if not is_voxel:
129 |                 min_encoding_indices = min_encoding_indices.reshape(
130 |                     z_q.shape[0], z_q.shape[2], z_q.shape[3])
131 |             else:
132 |                 min_encoding_indices = min_encoding_indices.reshape(
133 |                     z_q.shape[0], z_q.shape[2], z_q.shape[3], z_q.shape[4])
134 | 
135 |         return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
136 |     def get_codebook_entry(self, indices, shape):
137 |         # shape specifying (batch, height, width, channel)
138 |         if self.remap is not None:
139 |             indices = indices.reshape(shape[0],-1) # add batch axis
140 |             indices = self.unmap_to_all(indices)
141 |             indices = indices.reshape(-1) # flatten again
142 | 
143 |         # get quantized latent vectors
144 |         # z_q = self.embedding(indices)
145 |         mask = (indices > 0) & (indices < self.n_e)
146 |         # NOTE: the rest of this method and the header of the next one were
147 |         # garbled in extraction; the following is an assumed reconstruction
148 |         # in the style of the taming-transformers quantizer
149 |         z_q = torch.zeros(indices.shape[0], self.e_dim).to(self.embedding.weight)
150 |         z_q[mask] = self.embedding(indices[mask])
151 |         if shape is not None:
152 |             z_q = z_q.view(shape)
153 |             z_q = z_q.permute(0, 3, 1, 2).contiguous()
154 |         return z_q
155 | 
156 |     def get_code(self, z, is_voxels=False):  # assumed name: maps latents to nearest codebook indices
157 |         if not is_voxels:
158 |             b, c, h, w = z.shape
159 |             z = rearrange(z, 'b c h w -> b h w c').contiguous()
160 |         else:
161 |             b, c, d, h, w = z.shape
162 |             z = rearrange(z, 'b c d h w -> b d h w c').contiguous()
163 |         z_flattened = z.view(-1, self.e_dim)
164 |         d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
165 |             torch.sum(self.embedding.weight**2, dim=1) - 2 * \
166 |             torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
167 |         min_encoding_indices = torch.argmin(d, dim=1)
168 |         if not is_voxels:
169 |             min_encoding_indices = min_encoding_indices.reshape(b, h, w)
170 |         else:
171 |             min_encoding_indices = min_encoding_indices.reshape(b, d, h, w)
172 |         return min_encoding_indices
173 | 
174 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .VAE.vae_2d_resnet import VAERes2D,VAERes3D
2 | from .VAE.quantizer import VectorQuantizer
3 | 
4 | from .pose_encoder import PoseEncoder,PoseEncoder_fourier
5 | 
6 | from .dome import Dome
--------------------------------------------------------------------------------
/model/pose_encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from mmengine.registry import MODELS
4 | from mmengine.model import BaseModule
5 | from einops import rearrange, repeat
6 | import torch.nn.functional as F
7 | 
8 | @MODELS.register_module()
9 | class PoseEncoder(BaseModule):
10 |     def __init__(
11 |         self,
12 |         in_channels,
13 |         out_channels,
14 |         num_layers=2,
15 |         num_modes=3,
16 |         num_fut_ts=1,
17 |         init_cfg=None
18 |     ):
19 |         super().__init__(init_cfg)
20 |         self.num_modes = num_modes
21 |         self.num_fut_ts = num_fut_ts
22 |         assert num_fut_ts == 1
23 | 
24 |         pose_encoder = []
25 | 
26 |         for _ in range(num_layers - 1):
27 |             pose_encoder.extend([
28 |                 nn.Linear(in_channels, out_channels),
29 |                 nn.ReLU(True)])
30 |             in_channels = 
out_channels 31 | pose_encoder.append(nn.Linear(out_channels, out_channels)) 32 | self.pose_encoder = nn.Sequential(*pose_encoder) 33 | 34 | def forward(self,x): 35 | # x: N*2, 36 | pose_feat = self.pose_encoder(x) 37 | return pose_feat 38 | 39 | 40 | @MODELS.register_module() 41 | class PoseEncoder_fourier(BaseModule): 42 | def __init__( 43 | self, 44 | in_channels, 45 | out_channels, 46 | num_layers=2, 47 | num_modes=3, 48 | num_fut_ts=1, 49 | fourier_freqs=8, 50 | init_cfg=None, 51 | do_proj=False, 52 | max_length=77, 53 | # zero_init=False 54 | **kwargs 55 | ): 56 | super().__init__(init_cfg) 57 | self.num_modes = num_modes 58 | self.num_fut_ts = num_fut_ts 59 | assert num_fut_ts == 1 60 | # assert in_channels==2,"only support 2d coordinates for now, include gt_mode etc later" 61 | self.fourier_freqs=fourier_freqs 62 | self.position_dim = fourier_freqs * 2 * in_channels # 2: sin/cos, 2: xy 63 | in_channels=self.position_dim 64 | 65 | 66 | pose_encoder = [] 67 | for _ in range(num_layers - 1): 68 | pose_encoder.extend([ 69 | nn.Linear(in_channels, out_channels), 70 | nn.ReLU(True)]) 71 | in_channels = out_channels 72 | pose_encoder.append(nn.Linear(out_channels, out_channels)) 73 | self.pose_encoder = nn.Sequential(*pose_encoder) 74 | 75 | # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) 76 | self.do_proj=do_proj 77 | self.max_length=max_length 78 | if do_proj: 79 | # proj b*f*c -> b*c 80 | self.embedding_projection =nn.Linear(max_length * out_channels, out_channels, bias=True) 81 | self.null_position_feature = torch.nn.Parameter(torch.zeros([out_channels])) 82 | # if zero_init: 83 | # self.zero_module() 84 | 85 | def zero_module(self): 86 | """ 87 | Zero out the parameters of a module and return it. 88 | """ 89 | for p in self.parameters(): 90 | p.detach().zero_() 91 | 92 | def forward(self,x,mask=None): 93 | # x: N*2, 94 | b,f=x.shape[:2] 95 | x = rearrange(x, 'b f d -> (b f) d') 96 | x=get_fourier_embeds_from_coordinates(self.fourier_freqs,x) # N*dim (bf)*32 # 2,11,32 97 | # if mask is not None: #TODO 98 | # # learnable null embedding 99 | # xyxy_null = self.null_position_feature.view(1, 1, -1) 100 | # # replace padding with learnable null embedding 101 | # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null 102 | x = self.pose_encoder(x) #([2, 11, 768]) 103 | if self.do_proj: 104 | x = rearrange(x, '(b f) d -> b f d', b=b) # 2,11,32 105 | x_pad=F.pad(x,(0,0,0,self.max_length-f)) # 2,77,32 106 | xyxy_null = self.null_position_feature.view(1, 1, -1) 107 | mask=torch.zeros(b,self.max_length,1).to(x.device) 108 | mask[:,:f]=1 109 | x = x_pad * mask + (1 - mask) * xyxy_null 110 | x = rearrange(x, 'b f d -> b (f d)') 111 | x=self.embedding_projection(x) # b d 112 | else: 113 | x = rearrange(x, '(b f) d -> b f d', b=b) 114 | 115 | return x 116 | 117 | @MODELS.register_module() 118 | class PoseEncoder_fourier_yaw(BaseModule): 119 | def __init__( 120 | self, 121 | in_channels, 122 | out_channels, 123 | num_layers=2, 124 | num_modes=3, 125 | num_fut_ts=1, 126 | fourier_freqs=8, 127 | init_cfg=None, 128 | do_proj=False, 129 | max_length=77, 130 | # zero_init=False 131 | **kwargs 132 | ): 133 | super().__init__(init_cfg) 134 | self.num_modes = num_modes 135 | self.num_fut_ts = num_fut_ts 136 | assert num_fut_ts == 1 137 | # assert in_channels==2,"only support 2d coordinates for now, include gt_mode etc later" 138 | self.fourier_freqs=fourier_freqs 139 | self.position_dim = fourier_freqs * 2 * (in_channels-1) # 2: sin/cos, 2: xy 140 | 
self.position_dim_yaw = fourier_freqs * 2 * (1) # 2: sin/cos, 1: yaw
141 | 
142 | 
143 |         in_channels=self.position_dim
144 |         pose_encoder = []
145 |         for _ in range(num_layers - 1):
146 |             pose_encoder.extend([
147 |                 nn.Linear(in_channels, out_channels),
148 |                 nn.ReLU(True)])
149 |             in_channels = out_channels
150 |         pose_encoder.append(nn.Linear(out_channels, out_channels))
151 |         self.pose_encoder = nn.Sequential(*pose_encoder)
152 | 
153 |         in_channels=self.position_dim_yaw
154 |         pose_encoder_yaw = []
155 |         for _ in range(num_layers - 1):
156 |             pose_encoder_yaw.extend([
157 |                 nn.Linear(in_channels, out_channels),
158 |                 nn.ReLU(True)])
159 |             in_channels = out_channels
160 |         pose_encoder_yaw.append(nn.Linear(out_channels, out_channels))
161 |         self.pose_encoder_yaw = nn.Sequential(*pose_encoder_yaw)
162 | 
163 |         # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
164 |         self.do_proj=do_proj
165 |         self.max_length=max_length
166 |         if do_proj:
167 |             # proj b*f*c -> b*c
168 |             self.embedding_projection =nn.Linear(max_length * out_channels, out_channels, bias=True)
169 |             self.null_position_feature = torch.nn.Parameter(torch.zeros([out_channels]))
170 |             self.embedding_projection_yaw =nn.Linear(max_length * out_channels, out_channels, bias=True)
171 |             self.null_position_feature_yaw = torch.nn.Parameter(torch.zeros([out_channels]))
172 |         # if zero_init:
173 |         #     self.zero_module()
174 | 
175 |     def zero_module(self, zero_params=None):
176 |         """
177 |         Zero out the parameters of a module based on the given list of parameter names and return it.
178 | 
179 |         Args:
180 |             zero_params (list): List of parameter names to zero. If None, all parameters will be zeroed.
181 |         """
182 |         for name, p in self.named_parameters():
183 |             if zero_params is None or name in zero_params:
184 |                 p.detach().zero_()
185 | 
186 |     def forward(self,x,mask=None):
187 |         # x: (b, f, d) poses; the last channel is yaw
188 |         b,f=x.shape[:2]
189 |         x = rearrange(x, 'b f d -> (b f) d')
190 |         x,x_yaw=x[:,:-1],x[:,-1:]
191 |         x=get_fourier_embeds_from_coordinates(self.fourier_freqs,x) # N*dim (bf)*32 # 2,11,32
192 |         x_yaw=get_fourier_embeds_from_coordinates(self.fourier_freqs,x_yaw) # N*dim (bf)*32 # 2,11,32
193 |         # if mask is not None: #TODO
194 |         #     # learnable null embedding
195 |         #     xyxy_null = self.null_position_feature.view(1, 1, -1)
196 |         #     # replace padding with learnable null embedding
197 |         #     xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
198 |         x = self.pose_encoder(x) #([2, 11, 768])
199 |         x_yaw = self.pose_encoder_yaw(x_yaw) #([2, 11, 768])
200 |         if self.do_proj:
201 |             x = rearrange(x, '(b f) d -> b f d', b=b) # 2,11,32
202 |             x_pad=F.pad(x,(0,0,0,self.max_length-f)) # 2,77,32
203 |             xyxy_null = self.null_position_feature.view(1, 1, -1)
204 |             mask=torch.zeros(b,self.max_length,1).to(x.device)
205 |             mask[:,:f]=1
206 |             x = x_pad * mask + (1 - mask) * xyxy_null
207 |             x = rearrange(x, 'b f d -> b (f d)')
208 |             x=self.embedding_projection(x) # b d
209 |         else:
210 |             x = rearrange(x, '(b f) d -> b f d', b=b)
211 | 
212 |         if self.do_proj:
213 |             x_yaw = rearrange(x_yaw, '(b f) d -> b f d', b=b) # 2,11,32
214 |             x_pad=F.pad(x_yaw,(0,0,0,self.max_length-f)) # 2,77,32
215 |             xyxy_null = self.null_position_feature_yaw.view(1, 1, -1)  # yaw branch uses its dedicated null feature
216 |             mask=torch.zeros(b,self.max_length,1).to(x_yaw.device)
217 |             mask[:,:f]=1
218 |             x_yaw = x_pad * mask + (1 - mask) * xyxy_null
219 |             x_yaw = rearrange(x_yaw, 'b f d -> b (f d)')
220 |             x_yaw=self.embedding_projection_yaw(x_yaw) # b d; use the yaw-specific projection defined in __init__
221 |         else:
222 |             x_yaw = rearrange(x_yaw, '(b f) d -> b f d', b=b)
223 | 
224 |         x=x+x_yaw
225
| 226 | return x 227 | 228 | def get_fourier_embeds_from_coordinates(embed_dim, xys): 229 | """ 230 | Args: 231 | embed_dim: int 232 | xys: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline 233 | Returns: 234 | [B x N x embed_dim] tensor of positional embeddings 235 | """ 236 | 237 | batch_size = xys.shape[0] 238 | ch= xys.shape[-1] 239 | 240 | emb = 100 ** (torch.arange(embed_dim) / embed_dim) 241 | emb = emb[None, None].to(device=xys.device, dtype=xys.dtype) 242 | emb = emb * xys.unsqueeze(-1) 243 | 244 | emb = torch.stack((emb.sin(), emb.cos()), dim=-1) 245 | emb = emb.permute(0, 2, 3, 1).reshape(batch_size, embed_dim * 2 * ch) 246 | 247 | return emb 248 | 249 | 250 | -------------------------------------------------------------------------------- /resample/astar.py: -------------------------------------------------------------------------------- 1 | # Sample code from https://www.redblobgames.com/pathfinding/a-star/ 2 | import heapq 3 | import collections 4 | import numpy as np 5 | import math 6 | 7 | class SimpleGraph: 8 | def __init__(self): 9 | self.edges = {} 10 | 11 | def neighbors(self, id): 12 | return self.edges[id] 13 | 14 | 15 | class Queue: 16 | def __init__(self): 17 | self.elements = collections.deque() 18 | 19 | def empty(self): 20 | return len(self.elements) == 0 21 | 22 | def put(self, x): 23 | self.elements.append(x) 24 | 25 | def get(self): 26 | return self.elements.popleft() 27 | 28 | # utility functions for dealing with square grids 29 | def from_id_width(id, width): 30 | return (id % width, id // width) 31 | 32 | def draw_tile(graph, id, style, width): 33 | r = "." 34 | if 'number' in style and id in style['number']: r = "%d" % style['number'][id] 35 | if 'point_to' in style and style['point_to'].get(id, None) is not None: 36 | (x1, y1) = id 37 | (x2, y2) = style['point_to'][id] 38 | if x2 == x1 + 1: r = ">" 39 | if x2 == x1 - 1: r = "<" 40 | if y2 == y1 + 1: r = "v" 41 | if y2 == y1 - 1: r = "^" 42 | if 'start' in style and id == style['start']: r = "A" 43 | if 'goal' in style and id == style['goal']: r = "Z" 44 | if 'path' in style and id in style['path']: r = "@" 45 | if id in graph.walls: r = "#" * width 46 | return r 47 | 48 | def draw_grid(graph, width=2, **style): 49 | for y in range(graph.height): 50 | for x in range(graph.width): 51 | print("%%-%ds" % width % draw_tile(graph, (x, y), style, width), end="") 52 | print() 53 | 54 | 55 | class SquareGrid: 56 | def __init__(self, width, height): 57 | self.width = width 58 | self.height = height 59 | self.walls = [] 60 | 61 | def in_bounds(self, id): 62 | (x, y) = id 63 | return 0 <= x < self.width and 0 <= y < self.height 64 | 65 | def passable(self, id): 66 | return id not in self.walls 67 | 68 | def neighbors(self, id): 69 | (x, y) = id 70 | results = [(x+1, y), (x, y-1), (x-1, y), (x, y+1)] 71 | if (x + y) % 2 == 0: results.reverse() # aesthetics 72 | results = filter(self.in_bounds, results) 73 | results = filter(self.passable, results) 74 | return results 75 | 76 | @classmethod 77 | def from_voxel_bev_map(cls, mask,mark_margin_obs=True): 78 | # assert mask.dtype == np.bool and mask.ndim == 2 79 | width,height = mask.shape[:2] 80 | grid = cls(width, height) 81 | if mark_margin_obs: 82 | mask[:,[0,-1]]=0 83 | mask[[0,-1],:]=0 84 | grid.walls = [(x,y) for x,y in np.argwhere(mask==0)] 85 | return grid 86 | 87 | 88 | class GridWithWeights(SquareGrid): 89 | def __init__(self, width, height): 90 | super().__init__(width, height) 91 | self.weights = {} 92 | 93 | def cost(self, from_node, 
to_node): 94 | 95 | return self.weights.get(to_node, 1) 96 | @classmethod 97 | def from_voxel_bev_map(cls, mask,cost_map=None,mark_margin_obs=True): 98 | # assert mask.dtype == np.bool and mask.ndim == 2 99 | width,height = mask.shape[:2] 100 | if mark_margin_obs: 101 | mask[:,[0,-1]]=0 102 | mask[[0,-1],:]=0 103 | grid = cls(width, height) 104 | grid.walls = [(x,y) for x,y in np.argwhere(mask==0)] 105 | if cost_map is not None: 106 | grid.weights = {(x,y):cost_map[x,y] for x in range(width) for y in range(height)} 107 | return grid 108 | 109 | 110 | 111 | class PriorityQueue: 112 | def __init__(self): 113 | self.elements = [] 114 | 115 | def empty(self): 116 | return len(self.elements) == 0 117 | 118 | def put(self, item, priority): 119 | heapq.heappush(self.elements, (priority, item)) 120 | 121 | def get(self): 122 | return heapq.heappop(self.elements)[1] 123 | 124 | def reconstruct_path(came_from, start, goal): 125 | current = goal 126 | path = [] 127 | while current != start: 128 | path.append(current) 129 | current = came_from[current] 130 | path.append(start) # optional 131 | path.reverse() # optional 132 | return path 133 | 134 | # def heuristic(a, b): 135 | # (x1, y1) = a 136 | # (x2, y2) = b 137 | # return abs(x1 - x2) + abs(y1 - y2) 138 | def heuristic(a, b): 139 | (x1, y1) = a 140 | (x2, y2) = b 141 | return math.sqrt((x1 - x2)**2 + (y1 - y2)**2) 142 | 143 | def a_star_search(graph, start, goal): 144 | frontier = PriorityQueue() 145 | frontier.put(start, 0) 146 | came_from = {} 147 | cost_so_far = {} 148 | came_from[start] = None 149 | cost_so_far[start] = 0 150 | 151 | while not frontier.empty(): 152 | current = frontier.get() 153 | 154 | if current == goal: 155 | break 156 | 157 | for next in graph.neighbors(current): 158 | new_cost = cost_so_far[current] + graph.cost(current, next) 159 | if next not in cost_so_far or new_cost < cost_so_far[next]: 160 | cost_so_far[next] = new_cost 161 | priority = new_cost + heuristic(goal, next) 162 | frontier.put(next, priority) 163 | came_from[next] = current 164 | 165 | return came_from, cost_so_far 166 | 167 | # def a_star_search_distance_cost(graph,distancecostmap, start, goal): 168 | # frontier = PriorityQueue() 169 | # frontier.put(start, 0) 170 | # came_from = {} 171 | # cost_so_far = {} 172 | # distance_cost = {} 173 | # came_from[start] = None 174 | # cost_so_far[start] = 0 175 | 176 | # while not frontier.empty(): 177 | # current = frontier.get() 178 | 179 | # if current == goal: 180 | # break 181 | 182 | # for next in graph.neighbors(current): 183 | # new_cost = cost_so_far[current] + graph.cost(current, next) 184 | # if next not in cost_so_far or new_cost < cost_so_far[next]: 185 | # cost_so_far[next] = new_cost 186 | # distance_cost 187 | # priority = new_cost + heuristic(goal, next) 188 | # frontier.put(next, priority) 189 | # came_from[next] = current 190 | 191 | # return came_from, cost_so_far 192 | 193 | def breadth_first_search_3(graph, start, goal): 194 | frontier = Queue() 195 | frontier.put(start) 196 | came_from = {} 197 | came_from[start] = None 198 | 199 | while not frontier.empty(): 200 | current = frontier.get() 201 | 202 | if current == goal: 203 | break 204 | 205 | for next in graph.neighbors(current): 206 | if next not in came_from: 207 | frontier.put(next) 208 | came_from[next] = current 209 | 210 | return came_from 211 | 212 | 213 | 214 | def distcost(distancecostmap,x, y, safty_value=2,w=50): 215 | # large safty value makes the path more away from the wall 216 | # However, if it is too large, almost grid will 
get max cost
217 |     # which would eliminate the effect of the distance cost.
218 |     max_distance_cost = np.max(distancecostmap)
219 |     distance_cost = max_distance_cost-distancecostmap[x][y]
220 |     #if distance_cost > (max_distance_cost/safty_value):
221 |     #    distance_cost = 1000
222 |     #    return distance_cost
223 |     return w * distance_cost # E5 223 - 50
224 | 
225 | # start, goal = (1, 4), (7, 8)
226 | # came_from, cost_so_far = a_star_search(diagram4, start, goal)
227 | # draw_grid(diagram4, width=3, point_to=came_from, start=start, goal=goal)
228 | # print()
229 | # draw_grid(diagram4, width=3, number=cost_so_far, start=start, goal=goal)
230 | # print()
231 | # draw_grid(diagram4, width=3, path=reconstruct_path(came_from, start=start, goal=goal))
232 | 
233 | 
234 | 
235 | def create_bev_from_pc(occ_pc_road, resolution_2d, max_dist):
236 |     """
237 |     Create BEV map from point cloud
238 |     Args:
239 |         occ_pc_road: point cloud of the road
240 |         resolution_2d: resolution of the BEV map
241 |         max_dist: maximum distance to the road
242 |     Returns:
243 |         bev_road: BEV map of the road, Binary map
244 |         nearest_distance: nearest distance to the road
245 |         (x_min, y_min): minimum coordinates of the BEV map
246 |     """
247 |     # NearestNeighbors is used below but was never imported at module top;
248 |     # import it locally so the function is self-contained
249 |     from sklearn.neighbors import NearestNeighbors
250 | 
251 |     # set z=0 and compute grid parameters
252 |     occ_pc_road = occ_pc_road[:,:2]
253 |     x_min, x_max = occ_pc_road[:,0].min(), occ_pc_road[:,0].max()
254 |     y_min, y_max = occ_pc_road[:,1].min(), occ_pc_road[:,1].max()
255 |     n_x = int((x_max - x_min) / resolution_2d)
256 |     n_y = int((y_max - y_min) / resolution_2d)
257 | 
258 |     # Create grid coordinates efficiently
259 |     x_coords = np.linspace(x_min + resolution_2d/2, x_max - resolution_2d/2, n_x)
260 |     y_coords = np.linspace(y_min + resolution_2d/2, y_max - resolution_2d/2, n_y)
261 |     grid_x, grid_y = np.meshgrid(x_coords, y_coords)
262 |     grid_coords = np.column_stack((grid_x.ravel(), grid_y.ravel()))
263 | 
264 |     # Fit KNN model and compute distances
265 |     knn = NearestNeighbors(n_neighbors=3, algorithm='ball_tree', n_jobs=-1).fit(occ_pc_road)
266 |     distances, _ = knn.kneighbors(grid_coords)
267 |     mean_distances = distances.mean(axis=1)
268 | 
269 |     # Create and populate nd and occ_2d arrays
270 |     nearest_distance = mean_distances.reshape(n_y, n_x)
271 |     bev_road = (nearest_distance < max_dist).astype(np.int32)
272 | 
273 |     return bev_road, nearest_distance, (x_min, y_min)
--------------------------------------------------------------------------------
/resample/launch.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys,os
3 | import pickle
4 | from tqdm import tqdm
5 | 
6 | 
7 | def run_command(command: str) -> None:
8 |     """Run a command and warn if it fails
9 | 
10 |     Args:
11 |         command: command to run
12 |     """
13 |     ret_code = subprocess.call(command, shell=True)
14 |     if ret_code != 0:
15 |         print(f"Error: `{command}` failed. Exiting...")
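        # exiting is intentionally disabled (see the commented-out sys.exit
        # below) so that one failed scene does not abort the whole launch loop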
Exiting...") 16 | # sys.exit(1) 17 | 18 | 19 | def parse_args(): 20 | import argparse 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--rank', type=int, default=0) 23 | parser.add_argument('--n_rank', type=int, default=1) 24 | parser.add_argument('--dst', type=str, required=True) 25 | parser.add_argument('--imageset', type=str, default='../data/nuscenes_infos_train_temporal_v3_scene.pkl') 26 | parser.add_argument('--input_dataset', type=str, default='gts') 27 | parser.add_argument('--data_path', type=str, default='../data/nuscenes') 28 | 29 | 30 | return parser.parse_args() 31 | 32 | args=parse_args() 33 | with open(args.imageset, 'rb') as f: 34 | data = pickle.load(f) 35 | nusc_infos = data['infos'] 36 | 37 | 38 | for i in tqdm(range(args.rank,len(nusc_infos),args.n_rank),desc='lauch processing'): 39 | print('#'*50,'idx',i) 40 | dst=os.path.join(args.dst,f'src_scene-{i+1:04d}') 41 | cmd = [ 42 | "python", "main.py","--idx",f"{i}", 43 | "--dst",f"{dst}", 44 | "--imageset",f"{args.imageset}", 45 | "--input_dataset",f"{args.input_dataset}", 46 | "--data_path",f"{args.data_path}", 47 | ] 48 | cmd = " ".join(cmd) 49 | print('@'*50,cmd) 50 | run_command(cmd) 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /resample/main.py: -------------------------------------------------------------------------------- 1 | from utils import ( 2 | visualize_point_cloud, 3 | vis_pose_mesh, 4 | colors, write_pc, 5 | get_inliers_outliers, 6 | sample_points_in_roi, 7 | sample_points_in_roi_v2, 8 | approximate_b_spline_path, 9 | sampling_occ_from_pc, 10 | get_3d_pose_from_2d, 11 | create_bev_from_pc, 12 | ransac, 13 | downsample_pc_with_label, 14 | get_mask_from_path 15 | ) 16 | import numpy as np 17 | import open3d as o3d 18 | from matplotlib import pyplot as plt 19 | from astar import GridWithWeights, a_star_search, reconstruct_path 20 | import time 21 | import os 22 | from tqdm import tqdm 23 | from functools import partial 24 | import joblib 25 | import cv2 26 | import pickle 27 | import shutil 28 | from pyquaternion import Quaternion 29 | 30 | 31 | def parse_args(): 32 | import argparse 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--idx', type=int, required=True) 35 | parser.add_argument('--dst', type=str, required=True) 36 | parser.add_argument('--imageset', type=str, default='../data/nuscenes_infos_train_temporal_v3_scene.pkl') 37 | parser.add_argument('--input_dataset', type=str, default='gts') 38 | parser.add_argument('--data_path', type=str, default='../data/nuscenes') 39 | 40 | 41 | return parser.parse_args() 42 | 43 | # Configuration 44 | args = parse_args() 45 | idx = args.idx 46 | dst_dir = args.dst 47 | os.makedirs(dst_dir, exist_ok=True) 48 | 49 | 50 | moving_cls_id = [ 51 | 2, # 'bicycle' 52 | 3, # 'bus' 53 | 4, # 'car' 54 | 6, # motorcycle 55 | 7, # 'pedestrian' 56 | 9, # trailer 57 | 10, # 'truck' 58 | ] 59 | road_cls_id = 11 60 | 61 | ## voxelization param 62 | pc_voxel_downsample=0.2 63 | # resolution_2d=1 64 | resolution_2d=2 65 | max_dist=0.5 66 | 67 | 68 | ## st en point sampling param 69 | path_expand_radius=2 70 | seed=0 71 | n_sample_pair=10 72 | min_distance_st_en=10 73 | 74 | ## A* param 75 | distance_cost_weigth=250 76 | n_traj_point_ds=4 77 | 78 | ## traj valid check 79 | min_min_traj_len_st_en=2 80 | delta_min_traj_len_st_en=10 81 | min_traj_len_st_en=30 82 | max_traj_len_st_en=50 83 | fail_cnt_thres=10 84 | 85 | ## resampling occ from pc 86 | n_sample_occ=40 # each traj only sample 
87 | voxel_size= 0.4
88 | pc_range= [-40, -40, -1, 40, 40, 5.4]
89 | occ_size= [200, 200, 16]
90 | 
91 | ###################################
92 | 
93 | # data=np.load(f'./occ_{idx}.npz')
94 | # occ,trans,rot=data['occ'],data['e2g_rel0_t'],data['e2g_rel0_r']
95 | def load_occ(args):
96 |     with open(args.imageset, 'rb') as f:
97 |         data = pickle.load(f)
98 |     nusc_infos = data['infos']
99 |     assert args.idx < len(nusc_infos)
# [lines 100-161 of main.py were lost in extraction; judging from the code
#  below, they load the occupancy frames and ego poses (occ, trans, rot) for
#  scene idx, RANSAC-fit the ground plane (plane_model), and build the
#  occ_pc_static / cls_label / occ_pc_road point clouds consumed below]
162 | 
163 | # save pc
164 | # repeat n times
165 | # write_pc(occ_pc_road,f'{dst_dir}/vis_gt_road.ply',colors[None,road_cls_id-1].repeat(len(occ_pc_road),axis=0)[:,:3]/255.0) #debug
166 | 
167 | # create bev map
168 | bev_road, nearest_distance, (x_min, y_min,n_x,n_y) = create_bev_from_pc(occ_pc_road, resolution_2d, max_dist)
169 | distance_cost=distance_cost_weight*nearest_distance
170 | origin_path=trans.copy()
171 | origin_path=((origin_path-np.array([x_min,y_min,0]))/resolution_2d).astype(np.int32)
172 | 
173 | # Generate mask from the original path
174 | mask_origin_path = get_mask_from_path(origin_path, (n_x, n_y), expand_radius=path_expand_radius)
175 | 
176 | # Refine the mask by considering only areas that are both in the original path and on the road
177 | mask_origin_path = np.logical_and(mask_origin_path > 0, bev_road > 0).astype(np.uint8)
178 | 
179 | # Calculate map parameters
180 | bev_map=GridWithWeights.from_voxel_bev_map(bev_road,cost_map=distance_cost)
181 | 
182 | n_r=n_c=np.ceil(n_sample_pair**0.5).astype(np.int32)
183 | fig,ax=plt.subplots(n_r,n_c,figsize=(5*n_c,5*n_r))
184 | ax=ax.flatten()
185 | 
186 | sampled_trajs = []
187 | valid_path_count=0
188 | np.random.seed(seed)
189 | 
190 | # Plotting
191 | ax[0].imshow(bev_road)
192 | # draw raw path
193 | ax[1].imshow(mask_origin_path)
194 | ax[1].plot(origin_path[:,1], origin_path[:,0], linewidth=1.5, color='k', zorder=0,alpha=0.8)
195 | ax[1].scatter(origin_path[0,1], origin_path[0,0], marker='o', c='r')
196 | ax[1].scatter(origin_path[-1,1], origin_path[-1,0], marker='x', c='b')
197 | ax[1].set_title('origin path')#, area {mask_origin_path.sum()}')
198 | plt.tight_layout()
199 | plt.savefig(f'{dst_dir}/sampled_trajectories.png', dpi=300, bbox_inches='tight')
200 | 
201 | fail_cnt=0
202 | fail_v2_flag=False
203 | pbar=tqdm(total=n_sample_pair,desc='sampling traj')
204 | while valid_path_count<n_sample_pair:
205 |     if fail_cnt>fail_cnt_thres:
206 |         if min_traj_len_st_en < min_min_traj_len_st_en:
207 |             raise ValueError('Failed too many times: unable to generate valid trajectory')
208 |         elif fail_v2_flag:
209 |             print('Sampling method v1 failed, attempting to reduce min_traj_len_st_en')
210 |             if min_traj_len_st_en > min_min_traj_len_st_en:
211 |                 min_traj_len_st_en = max(min_traj_len_st_en - delta_min_traj_len_st_en, min_min_traj_len_st_en)
212 |             else:
213 |                 raise ValueError('Failed too many times: unable to generate valid trajectory')
214 |             fail_cnt = 0
215 |         else:
216 |             print('Sampling method v2 failed, switching to v1')
217 |             fail_v2_flag = True
218 |             fail_cnt = 0
219 |     # sample points in roi
220 |     if not fail_v2_flag:
221 |         # if mask_origin_path.sum()>min_path_area:
222 |         st, en, dist=sample_points_in_roi_v2(bev_road.shape[0],bev_road.shape[1],num_points=1,resolution=resolution_2d,mask=bev_road,
223 |                                              mask_path=mask_origin_path,
224 |                                              min_distance_threshold=min_distance_st_en,verbose=False)[0]
225 |     else:
226 |         st, en, dist=sample_points_in_roi(bev_road.shape[0],bev_road.shape[1],num_points=1,resolution=resolution_2d,mask=bev_road,min_distance_threshold=min_distance_st_en,verbose=False)[0]
227 | 
228 |     start, goal = tuple(st), tuple(en)
229 |     # tic = 
time.time() 230 | came_from, cost_so_far = a_star_search(bev_map, start, goal) 231 | # print('A* Time:', time.time() - tic) 232 | 233 | if len(came_from) <= 1 or goal not in came_from: 234 | print('@filtered, no solution') 235 | fail_cnt+=1 236 | continue 237 | 238 | path = np.array(reconstruct_path(came_from, start=start, goal=goal)) 239 | path_raw=path.copy() 240 | 241 | print(f'@len traj {len(path_raw)} @dist {dist}') 242 | if len(path_raw) < min_traj_len_st_en: 243 | print('@filtered, too short traj') 244 | fail_cnt+=1 245 | continue 246 | elif len(path_raw) > max_traj_len_st_en: 247 | print('@filtered, too long traj') 248 | fail_cnt+=1 249 | continue 250 | 251 | fail_cnt=0 252 | if n_traj_point_ds>1: 253 | path = np.concatenate([path[:-1:n_traj_point_ds], path[-1:]], axis=0) 254 | degree=min(5,len(path)-1) 255 | path_smooth = approximate_b_spline_path(path[:, 0], path[:, 1], n_path_points=1000, degree=degree) 256 | path_occ, dd = approximate_b_spline_path(path[:, 0], path[:, 1], n_path_points=n_sample_occ, degree=degree,with_derivatives=True) 257 | sampled_trajs.append(get_3d_pose_from_2d(path_occ, dd, [x_min, y_min], resolution_2d,plane_model=plane_model)) 258 | pbar.update(1) 259 | # Plotting 260 | ax[valid_path_count+2].imshow(bev_road) 261 | # draw raw path 262 | # ax[valid_path_count].plot(origin_path[:,1], origin_path[:,0], linewidth=1.5, color='k', zorder=0,alpha=0.8) 263 | ax[valid_path_count+2].plot(path_raw[:,1], path_raw[:,0], linewidth=1.5, color='g', zorder=0,alpha=0.8) 264 | ax[valid_path_count+2].plot(path[:,1], path[:,0], linewidth=1.5, color='b', zorder=0,alpha=0.8) 265 | ax[valid_path_count+2].plot(path_smooth[:,1], path_smooth[:,0], linewidth=1.5, color='r', zorder=0,alpha=0.8) 266 | ax[valid_path_count+2].scatter(st[1], st[0], marker='o', c='r') 267 | ax[valid_path_count+2].scatter(en[1], en[0], marker='x', c='b') 268 | ax[valid_path_count+2].set_title(f'resample {valid_path_count}\ndist: {dist:.2f}, traj: {len(path_raw)}') 269 | valid_path_count += 1 270 | # break #debug 271 | 272 | 273 | plt.tight_layout() 274 | plt.savefig(f'{dst_dir}/sampled_trajectories.png', dpi=300, bbox_inches='tight') 275 | 276 | sampled_trajs=np.array(sampled_trajs) 277 | np.save(f'{dst_dir}/sampled_trajectories.npy',sampled_trajs) 278 | assert len(sampled_trajs)==n_sample_pair 279 | 280 | # ## debug use orgin 281 | # sampled_trajs=np.array([np.eye(4)]*len(trans)) 282 | # sampled_trajs[:,:3,:3]=np.array([Quaternion(r).rotation_matrix for r in rot]) 283 | # sampled_trajs[:,:3,3]=trans 284 | # sampled_trajs=sampled_trajs[None] 285 | # n_sample_pair=1 286 | 287 | def process_trajectory(j, sampled_traj, occ_pc_static, cls_label, pc_range, voxel_size, occ_size, dst_dir_i): 288 | dst_dir_j = f'{dst_dir_i}/traj-{j:06d}' # TODO: use local id 289 | os.makedirs(dst_dir_j, exist_ok=True) 290 | 291 | # Transform to ego coordinates 292 | w2e = np.linalg.inv(sampled_traj) 293 | occ_pc_static_e = occ_pc_static @ w2e[:3,:3].T + w2e[:3,3] 294 | 295 | dense_voxels_with_semantic, voxel_semantic = sampling_occ_from_pc(occ_pc_static_e, cls_label, pc_range, voxel_size, occ_size) 296 | np.savez(f'{dst_dir_j}/labels.npz', semantics=voxel_semantic,pose=sampled_traj) 297 | # write_pc(dense_voxels_with_semantic[:,:3],f'{dst_dir_j}/occ_vis.ply',colors[dense_voxels_with_semantic[:,3]-1][:,:3]/255.0) #debug 298 | return voxel_semantic 299 | 300 | for i in tqdm(range(n_sample_pair),desc='resampling occ'): 301 | # if i>0: 302 | # break #debug 303 | 304 | voxel_semantic_all=[] 305 | # 
vis_pose_mesh(sampled_trajs[i],cmp_dir=dst_dir,fn=f'vis_resample_all_w_traj_sample_{i}.ply') #debug 306 | dst_dir_i=f'{dst_dir}/scene-{i:04d}' #TODO global id 307 | 308 | process_func = partial(process_trajectory, 309 | occ_pc_static=occ_pc_static, 310 | cls_label=cls_label, 311 | pc_range=pc_range, 312 | voxel_size=voxel_size, 313 | occ_size=occ_size, 314 | dst_dir_i=dst_dir_i) 315 | 316 | voxel_semantic_all = joblib.Parallel(n_jobs=-1)( 317 | joblib.delayed(process_func)(j, sampled_traj) 318 | for j, sampled_traj in enumerate(sampled_trajs[i]) 319 | ) 320 | # for j in range(len(sampled_trajs[i])): 321 | # tic=time.time() 322 | # voxel_semantic_all.append(process_func(j, sampled_trajs[i][j])) 323 | # print(f'@process_func {time.time()-tic}') 324 | 325 | visualize_point_cloud( 326 | voxel_semantic_all, 327 | None, 328 | None, 329 | abs_trans=sampled_trajs[i], 330 | cmp_dir=dst_dir_i, 331 | frame_type='', 332 | # frame_type='e+w', 333 | key='resample' 334 | ) 335 | -------------------------------------------------------------------------------- /resample/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | open3d 3 | tqdm 4 | scipy 5 | scikit-learn 6 | opencv-python 7 | matplotlib 8 | pyquaternion 9 | joblib -------------------------------------------------------------------------------- /static/images/dome_pipe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/dome_pipe.png -------------------------------------------------------------------------------- /static/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/favicon.ico -------------------------------------------------------------------------------- /static/images/occvae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/occvae.png -------------------------------------------------------------------------------- /static/images/overall_pipeline4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/overall_pipeline4.png -------------------------------------------------------------------------------- /static/images/teaser12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/teaser12.png -------------------------------------------------------------------------------- /static/images/vis_demo_cmp_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/vis_demo_cmp_2.png -------------------------------------------------------------------------------- /tools/eval.sh: -------------------------------------------------------------------------------- 1 | 2 | # export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | # export CUDA_VISIBLE_DEVICES=7 4 | 5 | 6 | cfg=./config/train_dome.py 7 | dir=./work_dir/dome 8 | ckpt=ckpts/dome_latest.pth 9 | vae_ckpt=ckpts/occvae_latest.pth 10 | 11 
| 12 | python tools/eval_metric.py \ 13 | --py-config $cfg \ 14 | --work-dir $dir \ 15 | --resume-from $ckpt \ 16 | --vae-resume-from $vae_ckpt 17 | -------------------------------------------------------------------------------- /tools/eval_vae.py: -------------------------------------------------------------------------------- 1 | import time, argparse, os.path as osp, os 2 | import torch, numpy as np 3 | import torch.distributed as dist 4 | from copy import deepcopy 5 | 6 | import mmcv 7 | from mmengine import Config 8 | from mmengine.runner import set_random_seed 9 | from mmengine.optim import build_optim_wrapper 10 | from mmengine.logging import MMLogger 11 | from mmengine.utils import symlink 12 | from mmengine.registry import MODELS 13 | from timm.scheduler import CosineLRScheduler, MultiStepLRScheduler 14 | import sys 15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 16 | from utils.load_save_util import revise_ckpt, revise_ckpt_1, load_checkpoint 17 | from torch.utils.tensorboard import SummaryWriter 18 | import warnings 19 | from einops import rearrange 20 | warnings.filterwarnings("ignore") 21 | 22 | 23 | def pass_print(*args, **kwargs): 24 | pass 25 | 26 | @torch.no_grad() 27 | def main(local_rank, args): 28 | # global settings 29 | set_random_seed(args.seed) 30 | torch.backends.cudnn.deterministic = False 31 | torch.backends.cudnn.benchmark = True 32 | 33 | # load config 34 | cfg = Config.fromfile(args.py_config) 35 | cfg.work_dir = args.work_dir 36 | 37 | # init DDP 38 | if args.gpus > 1: 39 | distributed = True 40 | ip = os.environ.get("MASTER_ADDR", "127.0.0.1") 41 | port = os.environ.get("MASTER_PORT", cfg.get("port", 29510)) 42 | hosts = int(os.environ.get("WORLD_SIZE", 1)) # number of nodes 43 | rank = int(os.environ.get("RANK", 0)) # node id 44 | gpus = torch.cuda.device_count() # gpus per node 45 | print(f"tcp://{ip}:{port}") 46 | dist.init_process_group( 47 | backend="nccl", init_method=f"tcp://{ip}:{port}", 48 | world_size=hosts * gpus, rank=rank * gpus + local_rank) 49 | world_size = dist.get_world_size() 50 | cfg.gpu_ids = range(world_size) 51 | torch.cuda.set_device(local_rank) 52 | 53 | if local_rank != 0: 54 | import builtins 55 | builtins.print = pass_print 56 | else: 57 | distributed = False 58 | world_size = 1 59 | 60 | if local_rank == 0: 61 | os.makedirs(args.work_dir, exist_ok=True) 62 | cfg.dump(osp.join(args.work_dir, osp.basename(args.py_config))) 63 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 64 | log_file = osp.join(args.work_dir, f'{timestamp}.log') 65 | logger = MMLogger('genocc', log_file=log_file) 66 | MMLogger._instance_dict['genocc'] = logger 67 | logger.info(f'Config:\n{cfg.pretty_text}') 68 | tb_dir=args.tb_dir if args.tb_dir else osp.join(args.work_dir, 'tb_log') 69 | writer = SummaryWriter(tb_dir) 70 | 71 | # build model 72 | import model 73 | from dataset import get_dataloader, get_nuScenes_label_name 74 | from loss import OPENOCC_LOSS 75 | from utils.metric_util import MeanIoU, multi_step_MeanIou 76 | from utils.freeze_model import freeze_model 77 | 78 | my_model = MODELS.build(cfg.model) 79 | my_model.init_weights() 80 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad) 81 | logger.info(f'Number of params: {n_parameters}') 82 | if cfg.get('freeze_dict', False): 83 | logger.info(f'Freezing model according to freeze_dict:{cfg.freeze_dict}') 84 | freeze_model(my_model, cfg.freeze_dict) 85 | n_parameters = sum(p.numel() for p in my_model.parameters() if 
p.requires_grad) 86 | logger.info(f'Number of params after freezed: {n_parameters}') 87 | if distributed: 88 | if cfg.get('syncBN', True): 89 | my_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(my_model) 90 | logger.info('converted sync bn.') 91 | 92 | find_unused_parameters = cfg.get('find_unused_parameters', True) 93 | ddp_model_module = torch.nn.parallel.DistributedDataParallel 94 | my_model = ddp_model_module( 95 | my_model.cuda(), 96 | device_ids=[torch.cuda.current_device()], 97 | broadcast_buffers=False, 98 | find_unused_parameters=find_unused_parameters) 99 | raw_model = my_model.module 100 | else: 101 | my_model = my_model.cuda() 102 | raw_model = my_model 103 | logger.info('done ddp model') 104 | 105 | train_dataset_loader, val_dataset_loader = get_dataloader( 106 | cfg.train_dataset_config, 107 | cfg.val_dataset_config, 108 | cfg.train_wrapper_config, 109 | cfg.val_wrapper_config, 110 | cfg.train_loader, 111 | cfg.val_loader, 112 | dist=distributed, 113 | iter_resume=args.iter_resume) 114 | 115 | # get optimizer, loss, scheduler 116 | optimizer = build_optim_wrapper(my_model, cfg.optimizer) 117 | loss_func = OPENOCC_LOSS.build(cfg.loss).cuda() 118 | max_num_epochs = cfg.max_epochs 119 | if cfg.get('multisteplr', False): 120 | scheduler = MultiStepLRScheduler( 121 | optimizer, 122 | **cfg.multisteplr_config) 123 | else: 124 | scheduler = CosineLRScheduler( 125 | optimizer, 126 | t_initial=len(train_dataset_loader) * max_num_epochs, 127 | lr_min=1e-6, 128 | warmup_t=cfg.get('warmup_iters', 500), 129 | warmup_lr_init=1e-6, 130 | t_in_epochs=False) 131 | 132 | # resume and load 133 | epoch = 0 134 | global_iter = 0 135 | last_iter = 0 136 | best_val_iou = [0]*cfg.get('return_len_', 10) 137 | best_val_miou = [0]*cfg.get('return_len_', 10) 138 | 139 | cfg.resume_from = '' 140 | # if osp.exists(osp.join(args.work_dir, 'latest.pth')): 141 | # cfg.resume_from = osp.join(args.work_dir, 'latest.pth') 142 | if args.resume_from: 143 | cfg.resume_from = args.resume_from 144 | if args.load_from: 145 | cfg.load_from = args.load_from 146 | 147 | logger.info('resume from: ' + cfg.resume_from) 148 | logger.info('load from: ' + cfg.load_from) 149 | logger.info('work dir: ' + args.work_dir) 150 | 151 | if cfg.resume_from and osp.exists(cfg.resume_from): 152 | map_location = 'cpu' 153 | ckpt = torch.load(cfg.resume_from, map_location=map_location) 154 | print(raw_model.load_state_dict(ckpt['state_dict'], strict=False)) 155 | optimizer.load_state_dict(ckpt['optimizer']) 156 | scheduler.load_state_dict(ckpt['scheduler']) 157 | epoch = ckpt['epoch'] 158 | global_iter = ckpt['global_iter'] 159 | last_iter = ckpt['last_iter'] if 'last_iter' in ckpt else 0 160 | if 'best_val_iou' in ckpt: 161 | best_val_iou = ckpt['best_val_iou'] 162 | if 'best_val_miou' in ckpt: 163 | best_val_miou = ckpt['best_val_miou'] 164 | 165 | if hasattr(train_dataset_loader.sampler, 'set_last_iter'): 166 | train_dataset_loader.sampler.set_last_iter(last_iter) 167 | print(f'successfully resumed from epoch {epoch}') 168 | elif cfg.load_from: 169 | ckpt = torch.load(cfg.load_from, map_location='cpu') 170 | if 'state_dict' in ckpt: 171 | state_dict = ckpt['state_dict'] 172 | else: 173 | state_dict = ckpt 174 | if cfg.get('revise_ckpt', False): 175 | if cfg.revise_ckpt == 1: 176 | print('revise_ckpt') 177 | print(raw_model.load_state_dict(revise_ckpt(state_dict), strict=False)) 178 | elif cfg.revise_ckpt == 2: 179 | print('revise_ckpt_1') 180 | print(raw_model.load_state_dict(revise_ckpt_1(state_dict), strict=False)) 181 | elif 
cfg.revise_ckpt == 3: 182 | print('revise_ckpt_2') 183 | print(raw_model.vae.load_state_dict(state_dict, strict=False)) 184 | else: 185 | # print(raw_model.load_state_dict(state_dict, strict=False)) 186 | load_checkpoint(raw_model,state_dict, strict=False) #TODO may need to remove moudle.xxx 187 | 188 | # training 189 | print_freq = cfg.print_freq 190 | grad_norm = 0 191 | 192 | label_name = get_nuScenes_label_name(cfg.label_mapping) 193 | unique_label = np.asarray(cfg.unique_label) 194 | unique_label_str = [label_name[l] for l in unique_label] 195 | # CalMeanIou_sem = multi_step_MeanIou(unique_label, cfg.get('ignore_label', -100), unique_label_str, 'sem', times=cfg.get('return_len_', 10)) 196 | # CalMeanIou_vox = multi_step_MeanIou([1], cfg.get('ignore_label', -100), ['occupied'], 'vox', times=cfg.get('return_len_', 10)) 197 | CalMeanIou_sem = multi_step_MeanIou(unique_label, cfg.get('ignore_label', -100), unique_label_str, 'sem', times=1)#cfg.get('return_len_', 10)) 198 | CalMeanIou_vox = multi_step_MeanIou([1], cfg.get('ignore_label', -100), ['occupied'], 'vox', times=1)#cfg.get('return_len_', 10)) 199 | # logger.info('compiling model') 200 | # my_model = torch.compile(my_model) 201 | # logger.info('done compile model') 202 | best_plan_loss = 100000 203 | # max_num_epochs=1 #debug 204 | if True: 205 | my_model.eval() 206 | os.environ['eval'] = 'true' 207 | val_loss_list = [] 208 | CalMeanIou_sem.reset() 209 | CalMeanIou_vox.reset() 210 | plan_loss = 0 211 | 212 | with torch.no_grad(): 213 | for i_iter_val, (input_occs, target_occs, metas) in enumerate(val_dataset_loader): 214 | # input_occs=rearrange(input_occs,'b (f1 f) h w d-> (b f) (f1 1) h w d',f1=2) 215 | # target_occs=rearrange(target_occs,'b f h w d-> (b f) 1 h w d') 216 | 217 | input_occs = input_occs.cuda() 218 | target_occs = target_occs.cuda() 219 | data_time_e = time.time() 220 | 221 | result_dict = my_model(x=input_occs, metas=metas) 222 | 223 | loss_input = { 224 | 'inputs': input_occs, 225 | 'target_occs': target_occs, 226 | # 'metas': metas 227 | **result_dict 228 | } 229 | for loss_input_key, loss_input_val in cfg.loss_input_convertion.items(): 230 | loss_input.update({ 231 | loss_input_key: result_dict[loss_input_val] 232 | }) 233 | loss, loss_dict = loss_func(loss_input) 234 | plan_loss += loss_dict.get('PlanRegLoss', 0) 235 | plan_loss += loss_dict.get('PlanRegLossLidar', 0) 236 | if result_dict.get('target_occs', None) is not None: 237 | target_occs = result_dict['target_occs'] 238 | target_occs_iou = deepcopy(target_occs) 239 | target_occs_iou[target_occs_iou != 17] = 1 240 | target_occs_iou[target_occs_iou == 17] = 0 241 | 242 | CalMeanIou_sem._after_step( 243 | rearrange(result_dict['sem_pred'],'b f h w d-> (b f) 1 h w d'), 244 | rearrange(target_occs,'b f h w d-> (b f) 1 h w d')) 245 | CalMeanIou_vox._after_step( 246 | rearrange(result_dict['iou_pred'],'b f h w d-> (b f) 1 h w d'), 247 | rearrange(target_occs_iou,'b f h w d-> (b f) 1 h w d')) 248 | val_loss_list.append(loss.detach().cpu().numpy()) 249 | if i_iter_val % print_freq == 0 and local_rank == 0: 250 | logger.info('[EVAL] Epoch %d Iter %5d: Loss: %.3f (%.3f)'%( 251 | epoch, i_iter_val, loss.item(), np.mean(val_loss_list))) 252 | writer.add_scalar(f'val/loss', loss.item(), global_iter) 253 | detailed_loss = [] 254 | for loss_name, loss_value in loss_dict.items(): 255 | detailed_loss.append(f'{loss_name}: {loss_value:.5f}') 256 | writer.add_scalar(f'val/{loss_name}', loss_value, global_iter) 257 | detailed_loss = ', '.join(detailed_loss) 258 | 
logger.info(detailed_loss) 259 | # break #debug 260 | val_miou, _ = CalMeanIou_sem._after_epoch() 261 | val_iou, _ = CalMeanIou_vox._after_epoch() 262 | 263 | del target_occs, input_occs 264 | plan_loss = plan_loss/len(val_dataset_loader) 265 | if plan_loss < best_plan_loss: 266 | best_plan_loss = plan_loss 267 | logger.info(f'PlanRegLoss is {plan_loss} while the best plan loss is {best_plan_loss}') 268 | #logger.info(f'PlanRegLoss is {plan_loss/len(val_dataset_loader)}') 269 | best_val_iou = val_iou#[max(best_val_iou[i], val_iou[i]) for i in range(len(best_val_iou))] 270 | best_val_miou = val_miou#[max(best_val_miou[i], val_miou[i]) for i in range(len(best_val_miou))] 271 | #logger.info(f'PlanRegLoss is {plan_loss/len(val_dataset_loader)}') 272 | logger.info(f'Current val iou is {val_iou} while the best val iou is {best_val_iou}') 273 | logger.info(f'Current val miou is {val_miou} while the best val miou is {best_val_miou}') 274 | torch.cuda.empty_cache() 275 | 276 | 277 | if __name__ == '__main__': 278 | # Training settings 279 | parser = argparse.ArgumentParser(description='') 280 | parser.add_argument('--py-config', default='config/tpv_lidarseg.py') 281 | parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg') 282 | parser.add_argument('--tb-dir', type=str, default=None) 283 | parser.add_argument('--resume-from', type=str, default='') 284 | parser.add_argument('--iter-resume', action='store_true', default=False) 285 | parser.add_argument('--seed', type=int, default=42) 286 | parser.add_argument('--load_from', type=str, default=None) 287 | args = parser.parse_args() 288 | 289 | ngpus = torch.cuda.device_count() 290 | args.gpus = ngpus 291 | print(args) 292 | 293 | if ngpus > 1: 294 | torch.multiprocessing.spawn(main, args=(args,), nprocs=args.gpus) 295 | else: 296 | main(0, args) 297 | -------------------------------------------------------------------------------- /tools/eval_vae.sh: -------------------------------------------------------------------------------- 1 | 2 | # export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | # export CUDA_VISIBLE_DEVICES=7 4 | 5 | 6 | cfg=./config/train_occvae.py 7 | dir=./work_dir/occ_vae 8 | vae_ckpt=ckpts/occvae_latest.pth 9 | 10 | python tools/eval_vae.py \ 11 | --py-config $cfg \ 12 | --work-dir $dir \ 13 | --load_from $vae_ckpt 14 | -------------------------------------------------------------------------------- /tools/train_diffusion.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=5 2 | 3 | 4 | cfg=./config/train_dome.py 5 | dir=./work_dir/dome 6 | vae_ckpt=ckpts/occvae_latest.pth 7 | 8 | 9 | python tools/train_diffusion.py \ 10 | --py-config $cfg \ 11 | --work-dir $dir -------------------------------------------------------------------------------- /tools/train_vae.py: -------------------------------------------------------------------------------- 1 | import time, argparse, os.path as osp, os 2 | import torch, numpy as np 3 | import torch.distributed as dist 4 | from copy import deepcopy 5 | 6 | import mmcv 7 | from mmengine import Config 8 | from mmengine.runner import set_random_seed 9 | from mmengine.optim import build_optim_wrapper 10 | from mmengine.logging import MMLogger 11 | from mmengine.utils import symlink 12 | from mmengine.registry import MODELS 13 | from timm.scheduler import CosineLRScheduler, MultiStepLRScheduler 14 | import sys 15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 16 | from utils.load_save_util import 
revise_ckpt, revise_ckpt_1, load_checkpoint 17 | from torch.utils.tensorboard import SummaryWriter 18 | import warnings 19 | warnings.filterwarnings("ignore") 20 | 21 | 22 | def pass_print(*args, **kwargs): 23 | pass 24 | 25 | def main(local_rank, args): 26 | # global settings 27 | set_random_seed(args.seed) 28 | torch.backends.cudnn.deterministic = False 29 | torch.backends.cudnn.benchmark = True 30 | 31 | # load config 32 | cfg = Config.fromfile(args.py_config) 33 | cfg.work_dir = args.work_dir 34 | 35 | # init DDP 36 | if args.gpus > 1: 37 | distributed = True 38 | ip = os.environ.get("MASTER_ADDR", "127.0.0.1") 39 | port = os.environ.get("MASTER_PORT", cfg.get("port", 29510)) 40 | hosts = int(os.environ.get("WORLD_SIZE", 1)) # number of nodes 41 | rank = int(os.environ.get("RANK", 0)) # node id 42 | gpus = torch.cuda.device_count() # gpus per node 43 | print(f"tcp://{ip}:{port}") 44 | dist.init_process_group( 45 | backend="nccl", init_method=f"tcp://{ip}:{port}", 46 | world_size=hosts * gpus, rank=rank * gpus + local_rank) 47 | world_size = dist.get_world_size() 48 | cfg.gpu_ids = range(world_size) 49 | torch.cuda.set_device(local_rank) 50 | 51 | if local_rank != 0: 52 | import builtins 53 | builtins.print = pass_print 54 | else: 55 | distributed = False 56 | world_size = 1 57 | 58 | if local_rank == 0: 59 | os.makedirs(args.work_dir, exist_ok=True) 60 | cfg.dump(osp.join(args.work_dir, osp.basename(args.py_config))) 61 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 62 | log_file = osp.join(args.work_dir, f'{timestamp}.log') 63 | logger = MMLogger('genocc', log_file=log_file) 64 | MMLogger._instance_dict['genocc'] = logger 65 | logger.info(f'Config:\n{cfg.pretty_text}') 66 | tb_dir=args.tb_dir if args.tb_dir else osp.join(args.work_dir, 'tb_log') 67 | writer = SummaryWriter(tb_dir) 68 | 69 | # build model 70 | import model 71 | from dataset import get_dataloader, get_nuScenes_label_name 72 | from loss import OPENOCC_LOSS 73 | from utils.metric_util import MeanIoU, multi_step_MeanIou 74 | from utils.freeze_model import freeze_model 75 | 76 | my_model = MODELS.build(cfg.model) 77 | my_model.init_weights() 78 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad) 79 | logger.info(f'Number of params: {n_parameters}') 80 | if cfg.get('freeze_dict', False): 81 | logger.info(f'Freezing model according to freeze_dict:{cfg.freeze_dict}') 82 | freeze_model(my_model, cfg.freeze_dict) 83 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad) 84 | logger.info(f'Number of params after freezed: {n_parameters}') 85 | if distributed: 86 | if cfg.get('syncBN', True): 87 | my_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(my_model) 88 | logger.info('converted sync bn.') 89 | 90 | find_unused_parameters = cfg.get('find_unused_parameters', False) 91 | ddp_model_module = torch.nn.parallel.DistributedDataParallel 92 | my_model = ddp_model_module( 93 | my_model.cuda(), 94 | device_ids=[torch.cuda.current_device()], 95 | broadcast_buffers=False, 96 | find_unused_parameters=find_unused_parameters) 97 | raw_model = my_model.module 98 | else: 99 | my_model = my_model.cuda() 100 | raw_model = my_model 101 | logger.info('done ddp model') 102 | 103 | train_dataset_loader, val_dataset_loader = get_dataloader( 104 | cfg.train_dataset_config, 105 | cfg.val_dataset_config, 106 | cfg.train_wrapper_config, 107 | cfg.val_wrapper_config, 108 | cfg.train_loader, 109 | cfg.val_loader, 110 | dist=distributed, 111 | iter_resume=args.iter_resume) 112 | 
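# ---- Editor's note (annotation, not part of the original file) ------------
# The scheduler built below runs in iteration space (t_in_epochs=False), so
# the training loop advances it with scheduler.step_update(global_iter) once
# per batch rather than scheduler.step(epoch) once per epoch. A minimal,
# hypothetical sketch of the same warmup + cosine pattern with timm's
# CosineLRScheduler (`iters_per_epoch` and `train_one_iter` are illustrative
# names, not from this repo):
#
#     from timm.scheduler import CosineLRScheduler
#     scheduler = CosineLRScheduler(
#         optimizer,
#         t_initial=iters_per_epoch * max_epochs,  # total updates, not epochs
#         lr_min=1e-6,
#         warmup_t=500,              # linear warmup over the first 500 updates
#         warmup_lr_init=1e-6,
#         t_in_epochs=False)         # interpret t_initial/warmup_t as updates
#     for it in range(iters_per_epoch * max_epochs):
#         train_one_iter()
#         scheduler.step_update(it)  # per-iteration LR update
# ----------------------------------------------------------------------------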
113 | # get optimizer, loss, scheduler 114 | optimizer = build_optim_wrapper(my_model, cfg.optimizer) 115 | loss_func = OPENOCC_LOSS.build(cfg.loss).cuda() 116 | max_num_epochs = cfg.max_epochs 117 | if cfg.get('multisteplr', False): 118 | scheduler = MultiStepLRScheduler( 119 | optimizer, 120 | **cfg.multisteplr_config) 121 | else: 122 | scheduler = CosineLRScheduler( 123 | optimizer, 124 | t_initial=len(train_dataset_loader) * max_num_epochs, 125 | lr_min=1e-6, 126 | warmup_t=cfg.get('warmup_iters', 500), 127 | warmup_lr_init=1e-6, 128 | t_in_epochs=False) 129 | 130 | # resume and load 131 | epoch = 0 132 | global_iter = 0 133 | last_iter = 0 134 | best_val_iou = [0]*cfg.get('return_len_', 10) 135 | best_val_miou = [0]*cfg.get('return_len_', 10) 136 | 137 | cfg.resume_from = '' 138 | # if osp.exists(osp.join(args.work_dir, 'latest.pth')): 139 | # cfg.resume_from = osp.join(args.work_dir, 'latest.pth') 140 | if args.resume_from: 141 | cfg.resume_from = args.resume_from 142 | if args.load_from: 143 | cfg.load_from = args.load_from 144 | 145 | logger.info('resume from: ' + cfg.resume_from) 146 | logger.info('load from: ' + cfg.load_from) 147 | logger.info('work dir: ' + args.work_dir) 148 | 149 | if cfg.resume_from and osp.exists(cfg.resume_from): 150 | map_location = 'cpu' 151 | ckpt = torch.load(cfg.resume_from, map_location=map_location) 152 | print(raw_model.load_state_dict(ckpt['state_dict'], strict=False)) 153 | optimizer.load_state_dict(ckpt['optimizer']) 154 | scheduler.load_state_dict(ckpt['scheduler']) 155 | epoch = ckpt['epoch'] 156 | global_iter = ckpt['global_iter'] 157 | last_iter = ckpt['last_iter'] if 'last_iter' in ckpt else 0 158 | if 'best_val_iou' in ckpt: 159 | best_val_iou = ckpt['best_val_iou'] 160 | if 'best_val_miou' in ckpt: 161 | best_val_miou = ckpt['best_val_miou'] 162 | 163 | if hasattr(train_dataset_loader.sampler, 'set_last_iter'): 164 | train_dataset_loader.sampler.set_last_iter(last_iter) 165 | print(f'successfully resumed from epoch {epoch}') 166 | elif cfg.load_from: 167 | ckpt = torch.load(cfg.load_from, map_location='cpu') 168 | if 'state_dict' in ckpt: 169 | state_dict = ckpt['state_dict'] 170 | else: 171 | state_dict = ckpt 172 | if cfg.get('revise_ckpt', False): 173 | if cfg.revise_ckpt == 1: 174 | print('revise_ckpt') 175 | print(raw_model.load_state_dict(revise_ckpt(state_dict), strict=False)) 176 | elif cfg.revise_ckpt == 2: 177 | print('revise_ckpt_1') 178 | print(raw_model.load_state_dict(revise_ckpt_1(state_dict), strict=False)) 179 | elif cfg.revise_ckpt == 3: 180 | print('revise_ckpt_2') 181 | print(raw_model.vae.load_state_dict(state_dict, strict=False)) 182 | else: 183 | # print(raw_model.load_state_dict(state_dict, strict=False)) 184 | load_checkpoint(raw_model,state_dict, strict=False) #TODO may need to remove moudle.xxx 185 | 186 | # training 187 | print_freq = cfg.print_freq 188 | first_run = True 189 | grad_norm = 0 190 | 191 | label_name = get_nuScenes_label_name(cfg.label_mapping) 192 | unique_label = np.asarray(cfg.unique_label) 193 | unique_label_str = [label_name[l] for l in unique_label] 194 | CalMeanIou_sem = multi_step_MeanIou(unique_label, cfg.get('ignore_label', -100), unique_label_str, 'sem', times=cfg.get('return_len_', 10)) 195 | CalMeanIou_vox = multi_step_MeanIou([1], cfg.get('ignore_label', -100), ['occupied'], 'vox', times=cfg.get('return_len_', 10)) 196 | # logger.info('compiling model') 197 | # my_model = torch.compile(my_model) 198 | # logger.info('done compile model') 199 | # max_num_epochs=1 #debug 200 | while 
epoch < max_num_epochs: 201 | 202 | my_model.train() 203 | os.environ['eval'] = 'false' 204 | if hasattr(train_dataset_loader.sampler, 'set_epoch'): 205 | train_dataset_loader.sampler.set_epoch(epoch) 206 | loss_list = [] 207 | time.sleep(10) 208 | data_time_s = time.time() 209 | time_s = time.time() 210 | for i_iter, (input_occs, target_occs, metas) in enumerate(train_dataset_loader): 211 | if first_run: 212 | i_iter = i_iter + last_iter 213 | 214 | input_occs = input_occs.cuda() 215 | target_occs = target_occs.cuda() 216 | data_time_e = time.time() 217 | use_pose_condition = torch.rand(1) < cfg.get('p_use_pose_condition',0) 218 | result_dict = my_model(x=input_occs, metas=metas,use_pose_condition=use_pose_condition) 219 | 220 | loss_input = { 221 | 'inputs': input_occs, 222 | 'target_occs': target_occs, 223 | # 'metas': metas 224 | **result_dict, 225 | } 226 | 227 | for loss_input_key, loss_input_val in cfg.loss_input_convertion.items(): 228 | input_=result_dict[loss_input_val] 229 | if 'temperal_mask' in result_dict: 230 | t_mask=result_dict['temperal_mask'] 231 | if input_.dim()==4: 232 | t_mask= t_mask.unsqueeze(1) 233 | input_*=t_mask 234 | loss_input.update({ 235 | loss_input_key: input_}) 236 | loss, loss_dict = loss_func(loss_input) 237 | optimizer.zero_grad() 238 | loss.backward() 239 | grad_norm = torch.nn.utils.clip_grad_norm_(my_model.parameters(), cfg.grad_max_norm) 240 | optimizer.step() 241 | 242 | loss_list.append(loss.detach().cpu().item()) 243 | scheduler.step_update(global_iter) 244 | time_e = time.time() 245 | 246 | global_iter += 1 247 | if i_iter % print_freq == 0 and local_rank == 0: 248 | lr = optimizer.param_groups[0]['lr'] 249 | logger.info('[TRAIN] Epoch %d Iter %5d/%d: Loss: %.3f (%.3f), grad_norm: %.3f, lr: %.7f, time: %.3f (%.3f)'%( 250 | epoch, i_iter, len(train_dataset_loader), 251 | loss.item(), np.mean(loss_list), grad_norm, lr, 252 | time_e - time_s, data_time_e - data_time_s)) 253 | writer.add_scalar(f'train/loss', loss.item(), global_iter) 254 | detailed_loss = [] 255 | for loss_name, loss_value in loss_dict.items(): 256 | detailed_loss.append(f'{loss_name}: {loss_value:.5f}') 257 | writer.add_scalar(f'train/{loss_name}', loss_value, global_iter) 258 | detailed_loss = ', '.join(detailed_loss) 259 | logger.info(detailed_loss) 260 | loss_list = [] 261 | # exit(0) #debug 262 | data_time_s = time.time() 263 | time_s = time.time() 264 | 265 | if args.iter_resume: 266 | if (i_iter + 1) % 50 == 0 and local_rank == 0: 267 | dict_to_save = { 268 | 'state_dict': raw_model.state_dict(), 269 | 'optimizer': optimizer.state_dict(), 270 | 'scheduler': scheduler.state_dict(), 271 | 'epoch': epoch, 272 | 'global_iter': global_iter, 273 | 'last_iter': i_iter + 1, 274 | } 275 | save_file_name = os.path.join(os.path.abspath(args.work_dir), 'iter.pth') 276 | torch.save(dict_to_save, save_file_name) 277 | dst_file = osp.join(args.work_dir, 'latest.pth') 278 | symlink(save_file_name, dst_file) 279 | logger.info(f'iter ckpt {i_iter + 1} saved!') 280 | # break #debug 281 | 282 | # save checkpoint 283 | if local_rank == 0 and (epoch+1) % cfg.get('save_every_epochs', 1) == 0: 284 | dict_to_save = { 285 | 'state_dict': raw_model.state_dict(), 286 | 'optimizer': optimizer.state_dict(), 287 | 'scheduler': scheduler.state_dict(), 288 | 'epoch': epoch + 1, 289 | 'global_iter': global_iter, 290 | } 291 | save_file_name = os.path.join(os.path.abspath(args.work_dir), f'epoch_{epoch+1}.pth') 292 | torch.save(dict_to_save, save_file_name) 293 | dst_file = osp.join(args.work_dir, 
'latest.pth') 294 | symlink(save_file_name, dst_file) 295 | 296 | epoch += 1 297 | first_run = False 298 | 299 | # eval 300 | if epoch % cfg.get('eval_every_epochs', 1) != 0: 301 | continue 302 | my_model.eval() 303 | os.environ['eval'] = 'true' 304 | val_loss_list = [] 305 | CalMeanIou_sem.reset() 306 | CalMeanIou_vox.reset() 307 | 308 | with torch.no_grad(): 309 | for i_iter_val, (input_occs, target_occs, metas) in enumerate(val_dataset_loader): 310 | 311 | input_occs = input_occs.cuda() 312 | target_occs = target_occs.cuda() 313 | data_time_e = time.time() 314 | 315 | result_dict = my_model(x=input_occs, metas=metas) 316 | 317 | loss_input = { 318 | 'inputs': input_occs, 319 | 'target_occs': target_occs, 320 | # 'metas': metas 321 | **result_dict 322 | } 323 | for loss_input_key, loss_input_val in cfg.loss_input_convertion.items(): 324 | loss_input.update({ 325 | loss_input_key: result_dict[loss_input_val] 326 | }) 327 | loss, loss_dict = loss_func(loss_input) 328 | if result_dict.get('target_occs', None) is not None: 329 | target_occs = result_dict['target_occs'] 330 | target_occs_iou = deepcopy(target_occs) 331 | target_occs_iou[target_occs_iou != 17] = 1 332 | target_occs_iou[target_occs_iou == 17] = 0 333 | 334 | CalMeanIou_sem._after_step(result_dict['sem_pred'], target_occs) 335 | CalMeanIou_vox._after_step(result_dict['iou_pred'], target_occs_iou) 336 | val_loss_list.append(loss.detach().cpu().numpy()) 337 | if i_iter_val % print_freq == 0 and local_rank == 0: 338 | logger.info('[EVAL] Epoch %d Iter %5d: Loss: %.3f (%.3f)'%( 339 | epoch, i_iter_val, loss.item(), np.mean(val_loss_list))) 340 | writer.add_scalar(f'val/loss', loss.item(), global_iter) 341 | detailed_loss = [] 342 | for loss_name, loss_value in loss_dict.items(): 343 | detailed_loss.append(f'{loss_name}: {loss_value:.5f}') 344 | writer.add_scalar(f'val/{loss_name}', loss_value, global_iter) 345 | detailed_loss = ', '.join(detailed_loss) 346 | logger.info(detailed_loss) 347 | # break #debug 348 | val_miou, _ = CalMeanIou_sem._after_epoch() 349 | val_iou, _ = CalMeanIou_vox._after_epoch() 350 | 351 | del target_occs, input_occs 352 | best_val_iou = [max(best_val_iou[i], val_iou[i]) for i in range(len(best_val_iou))] 353 | best_val_miou = [max(best_val_miou[i], val_miou[i]) for i in range(len(best_val_miou))] 354 | logger.info(f'Current val iou is {val_iou} while the best val iou is {best_val_iou}') 355 | logger.info(f'Current val miou is {val_miou} while the best val miou is {best_val_miou}') 356 | torch.cuda.empty_cache() 357 | 358 | 359 | if __name__ == '__main__': 360 | # Training settings 361 | parser = argparse.ArgumentParser(description='') 362 | parser.add_argument('--py-config', default='config/tpv_lidarseg.py') 363 | parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg') 364 | parser.add_argument('--tb-dir', type=str, default=None) 365 | parser.add_argument('--resume-from', type=str, default='') 366 | parser.add_argument('--iter-resume', action='store_true', default=False) 367 | parser.add_argument('--seed', type=int, default=42) 368 | parser.add_argument('--load_from', type=str, default=None) 369 | args = parser.parse_args() 370 | 371 | ngpus = torch.cuda.device_count() 372 | args.gpus = ngpus 373 | print(args) 374 | 375 | if ngpus > 1: 376 | torch.multiprocessing.spawn(main, args=(args,), nprocs=args.gpus) 377 | else: 378 | main(0, args) 379 | -------------------------------------------------------------------------------- /tools/train_vae.sh: 
-------------------------------------------------------------------------------- 1 | 2 | # export CUDA_VISIBLE_DEVICES=4,5,6,7 3 | # export CUDA_VISIBLE_DEVICES=7 4 | 5 | 6 | cfg=./config/train_occvae.py 7 | dir=./work_dir/occ_vae 8 | 9 | 10 | python tools/train_vae.py \ 11 | --py-config $cfg \ 12 | --work-dir $dir -------------------------------------------------------------------------------- /tools/vis_diffusion.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=5 2 | 3 | cfg=./config/train_dome.py 4 | dir=./work_dir/dome 5 | 6 | vae_ckpt=ckpts/occvae_latest.pth 7 | ckpt=ckpts/dome_latest.pth 8 | 9 | 10 | python tools/visualize_demo.py \ 11 | --py-config $cfg \ 12 | --work-dir $dir \ 13 | --resume-from $ckpt \ 14 | --vae-resume-from $vae_ckpt \ 15 | 16 | -------------------------------------------------------------------------------- /tools/vis_gif.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | import imageio 4 | 5 | src=r"/home/users/songen.gu/adwm/OccWorld/out/occworld/visgts_autoreg/200/nuScenesSceneDatasetLidar" 6 | 7 | 8 | 9 | def create_gif(src, fps=10): 10 | images = [] 11 | for img in Path(src).rglob('*.png'): 12 | images.append(imageio.imread(img)) 13 | imageio.mimsave(f'{src}/vis.gif', images, fps=fps) 14 | 15 | def create_mp4(src, fps=2): 16 | with imageio.get_writer(f'{src}/vis.mp4', mode='I', fps=fps) as writer: 17 | for img in Path(src).rglob('vis*.png'): 18 | writer.append_data(imageio.imread(img)) 19 | 20 | if __name__ == '__main__': 21 | for dir in tqdm(Path(src).iterdir()): 22 | if dir.is_dir(): 23 | create_mp4(dir) 24 | create_gif(dir) 25 | -------------------------------------------------------------------------------- /tools/vis_utils.py: -------------------------------------------------------------------------------- 1 | from pyvirtualdisplay import Display 2 | display = Display(visible=False, size=(2560, 1440)) 3 | display.start() 4 | from mayavi import mlab 5 | import mayavi 6 | mlab.options.offscreen = True 7 | print("Set mlab.options.offscreen={}".format(mlab.options.offscreen)) 8 | import numpy as np 9 | import os 10 | from pyquaternion import Quaternion 11 | try: 12 | import open3d as o3d 13 | except: 14 | pass 15 | from functools import reduce 16 | import mmcv 17 | 18 | colors = np.array( 19 | [ 20 | [255, 120, 50, 255], # barrier orange 21 | [255, 192, 203, 255], # bicycle pink 22 | [255, 255, 0, 255], # bus yellow 23 | [ 0, 150, 245, 255], # car blue 24 | [ 0, 255, 255, 255], # construction_vehicle cyan 25 | [255, 127, 0, 255], # motorcycle dark orange 26 | [255, 0, 0, 255], # pedestrian red 27 | [255, 240, 150, 255], # traffic_cone light yellow 28 | [135, 60, 0, 255], # trailer brown 29 | [160, 32, 240, 255], # truck purple 30 | [255, 0, 255, 255], # driveable_surface dark pink 31 | # [175, 0, 75, 255], # other_flat dark red 32 | [139, 137, 137, 255], 33 | [ 75, 0, 75, 255], # sidewalk dard purple 34 | [150, 240, 80, 255], # terrain light green 35 | [230, 230, 250, 255], # manmade white 36 | [ 0, 175, 0, 255], # vegetation green 37 | # [ 0, 255, 127, 255], # ego car dark cyan 38 | # [255, 99, 71, 255], # ego car 39 | # [ 0, 191, 255, 255] # ego car 40 | ] 41 | ).astype(np.uint8) 42 | 43 | def pass_print(*args, **kwargs): 44 | pass 45 | 46 | 47 | def get_grid_coords(dims, resolution): 48 | """ 49 | :param dims: the dimensions of the grid [x, y, z] (i.e. 
[256, 256, 32]) 50 | :return coords_grid: is the center coords of voxels in the grid 51 | """ 52 | 53 | g_xx = np.arange(0, dims[0]) # [0, 1, ..., 256] 54 | # g_xx = g_xx[::-1] 55 | g_yy = np.arange(0, dims[1]) # [0, 1, ..., 256] 56 | # g_yy = g_yy[::-1] 57 | g_zz = np.arange(0, dims[2]) # [0, 1, ..., 32] 58 | 59 | # Obtaining the grid with coords... 60 | xx, yy, zz = np.meshgrid(g_xx, g_yy, g_zz) 61 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T 62 | coords_grid = coords_grid.astype(np.float32) 63 | resolution = np.array(resolution, dtype=np.float32).reshape([1, 3]) 64 | 65 | coords_grid = (coords_grid * resolution) + resolution / 2 66 | 67 | return coords_grid 68 | 69 | def draw( 70 | voxels, # semantic occupancy predictions 71 | pred_pts, # lidarseg predictions 72 | vox_origin, 73 | voxel_size=0.2, # voxel size in the real world 74 | grid=None, # voxel coordinates of point cloud 75 | pt_label=None, # label of point cloud 76 | save_dir=None, 77 | cam_positions=None, 78 | focal_positions=None, 79 | timestamp=None, 80 | mode=0, 81 | sem=False, 82 | show_ego=False 83 | ): 84 | w, h, z = voxels.shape 85 | 86 | # assert show_ego 87 | if show_ego: 88 | assert voxels.shape==(200, 200, 16) 89 | voxels[96:104, 96:104, 2:7] = 15 90 | voxels[104:106, 96:104, 2:5] = 3 91 | 92 | # Compute the voxels coordinates 93 | grid_coords = get_grid_coords( 94 | [voxels.shape[0], voxels.shape[1], voxels.shape[2]], voxel_size 95 | ) + np.array(vox_origin, dtype=np.float32).reshape([1, 3]) 96 | 97 | if mode == 0: 98 | grid_coords = np.vstack([grid_coords.T, voxels.reshape(-1)]).T 99 | elif mode == 1: 100 | indexes = grid[:, 0] * h * z + grid[:, 1] * z + grid[:, 2] 101 | indexes, pt_index = np.unique(indexes, return_index=True) 102 | pred_pts = pred_pts[pt_index] 103 | grid_coords = grid_coords[indexes] 104 | grid_coords = np.vstack([grid_coords.T, pred_pts.reshape(-1)]).T 105 | elif mode == 2: 106 | indexes = grid[:, 0] * h * z + grid[:, 1] * z + grid[:, 2] 107 | indexes, pt_index = np.unique(indexes, return_index=True) 108 | gt_label = pt_label[pt_index] 109 | grid_coords = grid_coords[indexes] 110 | grid_coords = np.vstack([grid_coords.T, gt_label.reshape(-1)]).T 111 | else: 112 | raise NotImplementedError 113 | 114 | # Get the voxels inside FOV 115 | fov_grid_coords = grid_coords 116 | 117 | # Remove empty and unknown voxels 118 | fov_voxels = fov_grid_coords[ 119 | (fov_grid_coords[:, 3] > 0) & (fov_grid_coords[:, 3] < 17) 120 | ] 121 | print(len(fov_voxels)) 122 | 123 | 124 | figure = mlab.figure(size=(2560, 1440), bgcolor=(1, 1, 1)) 125 | voxel_size = sum(voxel_size) / 3 126 | plt_plot_fov = mlab.points3d( 127 | # fov_voxels[:, 1], 128 | # fov_voxels[:, 0], 129 | fov_voxels[:, 0], 130 | fov_voxels[:, 1], 131 | fov_voxels[:, 2], 132 | fov_voxels[:, 3], 133 | scale_factor=1.0 * voxel_size, 134 | mode="cube", 135 | opacity=1.0, 136 | vmin=1, 137 | vmax=16, # 16 138 | ) 139 | 140 | plt_plot_fov.glyph.scale_mode = "scale_by_vector" 141 | plt_plot_fov.module_manager.scalar_lut_manager.lut.table = colors 142 | dst=os.path.join(save_dir, f'vis_{timestamp}.png') 143 | mlab.savefig(dst) 144 | mlab.close() 145 | # crop 146 | im3=mmcv.imread(dst)[:,550:-530,:] 147 | # im3=mmcv.imread(dst)[:,590:-600,230:-230] 148 | # im3=mmcv.imread(dst)[360:-230,590:-600] 149 | mmcv.imwrite(im3, dst) 150 | return dst 151 | 152 | 153 | def write_pc(pc,dst,c=None): 154 | # pc=pc 155 | import open3d as o3d 156 | pcd = o3d.geometry.PointCloud() 157 | pcd.points = o3d.utility.Vector3dVector(pc) 158 | if c is not 
None: 159 | pcd.colors = o3d.utility.Vector3dVector(c) 160 | o3d.io.write_point_cloud(dst, pcd) 161 | 162 | def merge_mesh(meshes): 163 | return reduce(lambda x,y:x+y, meshes) 164 | 165 | 166 | def get_pose_mesh(trans_mat,s=5): 167 | 168 | # Create a coordinate frame with x-axis (red), y-axis (green), and z-axis (blue) 169 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=s, origin=[0, 0, 0]) 170 | mesh_frame.transform(trans_mat) 171 | # Save the coordinate frame to a file 172 | return mesh_frame 173 | 174 | 175 | def visualize_point_cloud( 176 | all_pred, 177 | abs_pose, 178 | abs_rot, 179 | vox_origin=[-40, -40, -1], 180 | resolution=0.4, #voxel size 181 | cmp_dir="./", 182 | key='gt' 183 | ): 184 | assert len(all_pred)==len(abs_pose)==len(abs_rot) 185 | all_occ,all_color=[],[] 186 | pose_mesh=[] 187 | for i,(occ,pose,rot) in enumerate(zip(all_pred,abs_pose,abs_rot)): 188 | occ=occ.reshape(-1)#.flatten() 189 | mask=(occ>=1)&(occ<16) # ignore GO 190 | cc=colors[occ[mask]-1][:,:3]/255.0 #[...,::-1] 191 | 192 | # occ_x,occ_y,occ_z=np.meshgrid(np.arange(200),np.arange(200),np.arange(16)) 193 | # occ_x=occ_x.flatten() 194 | # occ_y=occ_y.flatten() 195 | # occ_z=occ_z.flatten() 196 | # occ_xyz=np.concatenate([occ_x[:,None],occ_y[:,None],occ_z[:,None]],axis=1) 197 | # occ_xyz=(occ_xyz * resolution) + resolution / 2 # to center 198 | # occ_xyz+=np.array([-40,-40,-1]) # to ego 199 | # Compute the voxels coordinates in ego frame 200 | occ_xyz = get_grid_coords( 201 | [200,200,16], [resolution]*3 202 | ) + np.array(vox_origin, dtype=np.float32).reshape([1, 3]) 203 | write_pc(occ_xyz[mask],os.path.join(cmp_dir, f'vis_{key}_{i}_e.ply'),c=cc) 204 | 205 | # ego to world 206 | rot_m=Quaternion(rot).rotation_matrix[:3,:3] 207 | # rot_m=rr@rot_m 208 | trans_mat=np.eye(4) 209 | trans_mat[:3,:3]=rot_m 210 | trans_mat[:3,3]=pose 211 | rr=np.array([ 212 | [0,1,0], 213 | [1,0,0], 214 | [0,0,1] 215 | ]) 216 | occ_xyz=occ_xyz@rr.T 217 | occ_xyz=occ_xyz@rot_m.T +pose 218 | 219 | write_pc(occ_xyz[mask],os.path.join(cmp_dir, f'vis_{key}_{i}_w.ply'),c=cc) 220 | 221 | all_occ.append(occ_xyz[mask]) 222 | all_color.append(cc) 223 | pose_mesh.append(get_pose_mesh(trans_mat)) 224 | 225 | 226 | all_occ=np.concatenate(all_occ, axis=0) 227 | all_color=np.concatenate(all_color, axis=0) 228 | 229 | write_pc(all_occ,os.path.join(cmp_dir, f'vis_{key}_all_w.ply'),c=all_color) 230 | o3d.io.write_triangle_mesh(os.path.join(cmp_dir, f'vis_{key}_all_w_traj.ply'),merge_mesh(pose_mesh)) 231 | 232 | 233 | def visualize_point_cloud_no_pose( 234 | all_pred, 235 | vox_origin=[-40, -40, -1], 236 | resolution=0.4, #voxel size 237 | cmp_dir="./", 238 | key='000000', 239 | key2='gt', 240 | offset=0, 241 | ): 242 | for i,occ in enumerate(all_pred): 243 | occ_d=occ.copy() 244 | occ=occ.reshape(-1)#.flatten() 245 | mask=(occ>=1)&(occ<16) # ignore GO 246 | cc=colors[occ[mask]-1][:,:3]/255.0 #[...,::-1] 247 | 248 | occ_xyz = get_grid_coords( 249 | [200,200,16], [resolution]*3 250 | ) + np.array(vox_origin, dtype=np.float32).reshape([1, 3]) 251 | write_pc(occ_xyz[mask],os.path.join(cmp_dir, f'vis_{key}_{i+offset:02d}_e_{key2}.ply'),c=cc) 252 | 253 | np.save(os.path.join(cmp_dir, f'vis_{key}_{i+offset:02d}_e_{key2}.npy'),occ_d) 254 | 255 | 256 | if __name__=='__main__': 257 | # np.savez('/home/users/songen.gu/adwm/OccWorld/visualizations/aaaa.npz',input_occs0=input_occs0,input_occs=input_occs,metas0=metas0,metas=metas) 258 | # load 259 | data=np.load('/home/users/songen.gu/adwm/OccWorld/visualizations/aaaa.npz') 260 | 
input_occs0=data['input_occs0'] 261 | input_occs=data['input_occs'] 262 | dst_dir='/home/users/songen.gu/adwm/OccWorld/visualizations/abccc' 263 | os.makedirs(dst_dir,exist_ok=True) 264 | dst_wm=draw(input_occs0[10], 265 | None, # predict_pts, 266 | [-40, -40, -1], 267 | [0.4] * 3, 268 | None, # grid.squeeze(0).cpu().numpy(), 269 | None,# pt_label.squeeze(-1), 270 | dst_dir,#recon_dir, 271 | None, # img_metas[0]['cam_positions'], 272 | None, # img_metas[0]['focal_positions'], 273 | timestamp=10, 274 | mode=0, 275 | sem=False) 276 | dst_dir='/home/users/songen.gu/adwm/OccWorld/visualizations/abcc2' 277 | os.makedirs(dst_dir,exist_ok=True) 278 | dst_wm=draw(input_occs[10], 279 | None, # predict_pts, 280 | [-40, -40, -1], 281 | [0.4] * 3, 282 | None, # grid.squeeze(0).cpu().numpy(), 283 | None,# pt_label.squeeze(-1), 284 | dst_dir,#recon_dir, 285 | None, # img_metas[0]['cam_positions'], 286 | None, # img_metas[0]['focal_positions'], 287 | timestamp=10, 288 | mode=0, 289 | sem=False) -------------------------------------------------------------------------------- /tools/vis_vae.sh: -------------------------------------------------------------------------------- 1 | # export CUDA_VISIBLE_DEVICES=5 2 | 3 | cfg=./config/train_occvae.py 4 | dir=./work_dir/occ_vae 5 | 6 | vae_ckpt=ckpts/occvae_latest.pth 7 | 8 | python tools/visualize_demo_vae.py \ 9 | --py-config $cfg \ 10 | --work-dir $dir \ 11 | --resume-from $vae_ckpt \ 12 | --export_pcd 13 | 14 | -------------------------------------------------------------------------------- /tools/visualize_demo.py: -------------------------------------------------------------------------------- 1 | from pyvirtualdisplay import Display 2 | display = Display(visible=False, size=(2560, 1440)) 3 | display.start() 4 | from mayavi import mlab 5 | import mayavi 6 | mlab.options.offscreen = True 7 | print("Set mlab.options.offscreen={}".format(mlab.options.offscreen)) 8 | 9 | import pdb 10 | import time, argparse, os.path as osp, os 11 | import torch, numpy as np 12 | import sys 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 14 | 15 | 16 | import mmcv 17 | from mmengine import Config 18 | from mmengine.runner import set_random_seed 19 | from mmengine.logging import MMLogger 20 | from mmengine.registry import MODELS 21 | import cv2 22 | from vis_gif import create_mp4 23 | import warnings 24 | warnings.filterwarnings("ignore") 25 | from einops import rearrange 26 | from diffusion import create_diffusion 27 | from vis_utils import draw 28 | 29 | 30 | 31 | def main(args): 32 | # global settings 33 | set_random_seed(args.seed) 34 | torch.backends.cudnn.deterministic = False 35 | torch.backends.cudnn.benchmark = True 36 | # load config 37 | cfg = Config.fromfile(args.py_config) 38 | cfg.work_dir = args.work_dir 39 | 40 | os.makedirs(args.work_dir, exist_ok=True) 41 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 42 | args.dir_name=f'{args.dir_name}_{timestamp}' 43 | 44 | log_file = osp.join(args.work_dir, f'{cfg.get("data_type", "gts")}_visualize_{timestamp}.log') 45 | logger = MMLogger('genocc', log_file=log_file) 46 | MMLogger._instance_dict['genocc'] = logger 47 | logger.info(f'Config:\n{cfg.pretty_text}') 48 | 49 | # build model 50 | import model 51 | my_model = MODELS.build(cfg.model.world_model) 52 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad) 53 | logger.info(f'Number of params: {n_parameters}') 54 | my_model = my_model.cuda() 55 | raw_model = my_model 56 | 
vae=MODELS.build(cfg.model.vae).cuda() 57 | 58 | vae.requires_grad_(False) 59 | vae.eval() 60 | 61 | logger.info('done ddp model') 62 | from dataset import get_dataloader 63 | cfg.val_dataset_config.test_mode=True 64 | cfg.val_loader.num_workers=0 65 | cfg.train_loader.num_workers=0 66 | 67 | # cfg.val_dataset_config.new_rel_pose=False ## TODO 68 | # cfg.train_dataset_config.test_index_offset=args.test_index_offset 69 | cfg.val_dataset_config.test_index_offset=args.test_index_offset 70 | if args.return_len is not None: 71 | cfg.train_dataset_config.return_len=max(cfg.train_dataset_config.return_len,args.return_len) 72 | cfg.val_dataset_config.return_len=max(cfg.val_dataset_config.return_len,args.return_len) 73 | # cfg.val_dataset_config.return_len=60 74 | 75 | train_dataset_loader, val_dataset_loader = get_dataloader( 76 | cfg.train_dataset_config, 77 | cfg.val_dataset_config, 78 | cfg.train_wrapper_config, 79 | cfg.val_wrapper_config, 80 | cfg.train_loader, 81 | cfg.val_loader, 82 | dist=False) 83 | cfg.resume_from = '' 84 | if osp.exists(osp.join(args.work_dir, 'latest.pth')): 85 | cfg.resume_from = osp.join(args.work_dir, 'latest.pth') 86 | else: 87 | ckpts=[i for i in os.listdir(args.work_dir) if 88 | i.endswith('.pth') and i.replace('.pth','').replace('epoch_','').isdigit()] 89 | if len(ckpts)>0: 90 | ckpts.sort(key=lambda x:int(x.replace('.pth','').replace('epoch_',''))) 91 | cfg.resume_from = osp.join(args.work_dir, ckpts[-1]) 92 | 93 | if args.resume_from: 94 | cfg.resume_from = args.resume_from 95 | if args.vae_resume_from: 96 | cfg.vae_load_from=args.vae_resume_from 97 | logger.info('resume from: ' + cfg.resume_from) 98 | logger.info('vae resume from: ' + cfg.vae_load_from) 99 | logger.info('work dir: ' + args.work_dir) 100 | 101 | epoch = 'last' 102 | if cfg.resume_from and osp.exists(cfg.resume_from): 103 | map_location = 'cpu' 104 | ckpt = torch.load(cfg.resume_from, map_location=map_location) 105 | print(raw_model.load_state_dict(ckpt['state_dict'], strict=False)) 106 | epoch = ckpt['epoch'] 107 | print(f'successfully resumed from epoch {epoch}') 108 | elif cfg.load_from: 109 | ckpt = torch.load(cfg.load_from, map_location='cpu') 110 | if 'state_dict' in ckpt: 111 | state_dict = ckpt['state_dict'] 112 | else: 113 | state_dict = ckpt 114 | print(raw_model.load_state_dict(state_dict, strict=False)) 115 | print(vae.load_state_dict(torch.load(cfg.vae_load_from)['state_dict'])) 116 | 117 | # eval 118 | my_model.eval() 119 | os.environ['eval'] = 'true' 120 | recon_dir = os.path.join(args.work_dir, args.dir_name) 121 | os.makedirs(recon_dir, exist_ok=True) 122 | os.environ['recon_dir']=recon_dir 123 | 124 | diffusion = create_diffusion( 125 | # timestep_respacing=str(cfg.sample.num_sampling_steps), 126 | timestep_respacing=str(args.num_sampling_steps), 127 | beta_start=cfg.schedule.beta_start, 128 | beta_end=cfg.schedule.beta_end, 129 | replace_cond_frames=cfg.replace_cond_frames, 130 | cond_frames_choices=cfg.cond_frames_choices, 131 | predict_xstart=cfg.schedule.get('predict_xstart',False), 132 | ) 133 | if args.pose_control: 134 | cfg.sample.n_conds=1 135 | print(len(val_dataset_loader)) 136 | with torch.no_grad(): 137 | for i_iter_val, (input_occs, _, metas) in enumerate(val_dataset_loader): 138 | if i_iter_val not in args.scene_idx: 139 | continue 140 | if i_iter_val > max(args.scene_idx): 141 | break 142 | start_frame=cfg.get('start_frame', 0) 143 | mid_frame=cfg.get('mid_frame', 3) 144 | # end_frame=cfg.get('end_frame', 9) 145 | end_frame=input_occs.shape[1] if args.end_frame 
is None else args.end_frame 146 | 147 | if args.pose_control: 148 | # start_frame=0 149 | mid_frame=1 150 | # end_frame=10 151 | assert cfg.sample.n_conds==mid_frame 152 | # __import__('ipdb').set_trace() 153 | input_occs = input_occs.cuda() #torch.Size([1, 16, 200, 200, 16]) 154 | bs,f,_,_,_=input_occs.shape 155 | encoded_latent, shape=vae.forward_encoder(input_occs) 156 | encoded_latent,_,_=vae.sample_z(encoded_latent) #bchw 157 | # encoded_latent = self.vae.vqvae.quant_conv(encoded_latent) 158 | # encoded_latent, _,_ = vae.vqvae(encoded_latent, is_voxel=False) 159 | input_latents=encoded_latent*cfg.model.vae.scaling_factor 160 | if input_latents.dim()==4: 161 | input_latents = rearrange(input_latents, '(b f) c h w -> b f c h w', b=bs).contiguous() 162 | elif input_latents.dim()==5: 163 | input_latents = rearrange(input_latents, 'b c f h w -> b f c h w', b=bs).contiguous() 164 | else: 165 | raise NotImplementedError 166 | 167 | 168 | # from debug_vis import visualize_tensor_pca 169 | # TODO fix dim bug torch.Size([1, 64, 12, 25, 25]) 170 | # visualize_tensor_pca(encoded_latent.permute(0,2,3,1).cpu(), save_dir=recon_dir+'/debug_feature', filename=f'vis_vae_encode_{i_iter_val}.png') 171 | os.environ.update({'i_iter_val': str(i_iter_val)}) 172 | os.environ.update({'recon_dir': str(recon_dir)}) 173 | # rencon_occs=vae.forward_decoder(encoded_latent, shape, input_occs.shape) 174 | 175 | # gaussian diffusion pipeline 176 | w=h=cfg.model.vae.encoder_cfg.resolution 177 | vae_scale_factor = 2 ** (len(cfg.model.vae.encoder_cfg.ch_mult) - 1) 178 | vae_docoder_shapes=cfg.shapes[:len(cfg.model.vae.encoder_cfg.ch_mult) - 1] 179 | w//=vae_scale_factor 180 | h//=vae_scale_factor 181 | 182 | model_kwargs=dict( 183 | # # cfg_scale=cfg.sample.guidance_scale 184 | # metas=metas 185 | ) 186 | if args.pose or args.pose_control: 187 | # assert False #debug pure gen 188 | model_kwargs['metas']=metas 189 | noise_shape=(bs, end_frame,cfg.base_channel, w,h,) 190 | initial_cond_indices=None 191 | n_conds=cfg.sample.get('n_conds',0) 192 | if n_conds: 193 | initial_cond_indices=[index for index in range(n_conds)] 194 | 195 | # Sample images: 196 | if cfg.sample.sample_method == 'ddim': 197 | latents = diffusion.ddim_sample_loop( 198 | my_model, noise_shape, None, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device='cuda' 199 | ) 200 | elif cfg.sample.sample_method == 'ddpm': 201 | if args.rolling_sampling_n<2: 202 | 203 | latents = diffusion.p_sample_loop( 204 | my_model, noise_shape, None, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device='cuda', 205 | initial_cond_indices=initial_cond_indices, 206 | initial_cond_frames=input_latents, 207 | ) 208 | else: 209 | latents=diffusion.p_sample_loop_cond_rollout( 210 | my_model, noise_shape, None, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device='cuda', 211 | # initial_cond_indices=initial_cond_indices, 212 | input_latents=input_latents, 213 | rolling_sampling_n=args.rolling_sampling_n, 214 | n_conds=n_conds, 215 | n_conds_roll=args.n_conds_roll 216 | ) 217 | end_frame=latents.shape[1] 218 | latents = 1 / cfg.model.vae.scaling_factor * latents 219 | 220 | if cfg.model.vae.decoder_cfg.type=='Decoder3D': 221 | latents = rearrange(latents,'b f c h w-> b c f h w') 222 | else: 223 | # assert False #debug 224 | latents = rearrange(latents,'b f c h w -> (b f) c h w') 225 | 226 | logits = vae.forward_decoder( 227 | latents , shapes=vae_docoder_shapes,input_shape=[bs,end_frame,*cfg.shapes[0],cfg._dim_] 228 | ) 229 | dst_dir 
= os.path.join(recon_dir, str(i_iter_val),'pred') 230 | input_dir = os.path.join(recon_dir, f'{i_iter_val}','input') 231 | # input_occs = result['input_occs'] 232 | os.makedirs(dst_dir, exist_ok=True) 233 | os.makedirs(input_dir, exist_ok=True) 234 | 235 | 236 | if True: 237 | import matplotlib.pyplot as plt 238 | plt.clf() 239 | plt.plot(metas[0]['rel_poses'][:,0],metas[0]['rel_poses'][:,1],marker='o',alpha=0.5) 240 | plt.savefig(os.path.join(dst_dir, f'pose.png')) 241 | plt.clf() 242 | # for i, xyz in enumerate(e2g_t): 243 | # xy=xyz[:2] 244 | # gt_mode=gt_modes[i].astype('int').tolist().index(1) 245 | # ax2.annotate(f"{i+1}({gt_mode})", xy=xy, textcoords="offset points", xytext=(0,10), ha='center') 246 | # ax2.set_title('ego2global_translation (xy) (idx+gt_mode)') 247 | 248 | plt.plot(metas[0]['e2g_rel0_t'][:,0],metas[0]['e2g_rel0_t'][:,1]) 249 | plt.scatter([0],[0],c='r') 250 | 251 | plt.annotate(f"start", xy=(0,0), textcoords="offset points", xytext=(0,10),ha='center') 252 | 253 | 254 | plt.savefig(os.path.join(dst_dir, f'pose_w.png')) 255 | # exit(0) 256 | all_pred=[] 257 | for frame in range(start_frame,end_frame): 258 | # for frame in range(0,end_frame): 259 | # if frame >15 and frame%10!=0: 260 | # continue 261 | # tt=str(i_iter_val) + '_' + str(frame) 262 | tt=str(i_iter_val) + '_' + str(frame+args.test_index_offset) 263 | # if frame < rencon_occs.shape[1]: 264 | # input_occ = rencon_occs[:, frame, ...].argmax(-1).squeeze().cpu().numpy() 265 | if frame < input_occs.shape[1] and not args.skip_gt: 266 | # if True: 267 | input_occ = input_occs[:, frame, ...].squeeze().cpu().numpy() 268 | draw(input_occ, 269 | None, # predict_pts, 270 | [-40, -40, -1], 271 | [0.4] * 3, 272 | None, # grid.squeeze(0).cpu().numpy(), 273 | None,# pt_label.squeeze(-1), 274 | input_dir,#recon_dir, 275 | None, # img_metas[0]['cam_positions'], 276 | None, # img_metas[0]['focal_positions'], 277 | timestamp=tt, 278 | mode=0, 279 | sem=False, 280 | show_ego=args.show_ego) 281 | if True: 282 | # if frame>=mid_frame: 283 | logit = logits[:, frame, ...] 
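# Editor's note (annotation, not part of the original file): `logits` is the
# decoded occupancy volume with a trailing per-voxel class dimension, roughly
# [bs, n_frames, 200, 200, 16, n_classes] as inferred from the
# vae.forward_decoder call above. Slicing one frame here and taking
# argmax(dim=-1) on the next line collapses the class dimension into a
# [200, 200, 16] grid of semantic voxel labels that draw() can render.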
284 |                     pred = logit.argmax(dim=-1).squeeze().cpu().numpy()  # 1, 1, 200, 200, 16
285 |                     all_pred.append(pred)
286 | 
287 | 
288 | 
289 |                     draw(pred,
290 |                          None,  # predict_pts,
291 |                          [-40, -40, -1],
292 |                          [0.4] * 3,
293 |                          None,  # grid.squeeze(0).cpu().numpy(),
294 |                          None,  # pt_label.squeeze(-1),
295 |                          dst_dir,  # recon_dir,
296 |                          None,  # img_metas[0]['cam_positions'],
297 |                          None,  # img_metas[0]['focal_positions'],
298 |                          timestamp=tt,
299 |                          mode=0,
300 |                          sem=False,
301 |                          show_ego=args.show_ego)
302 |             logger.info('[EVAL] Iter %5d / %5d' % (i_iter_val, len(val_dataset_loader)))
303 |             create_mp4(dst_dir)
304 |             # create_mp4(cmp_dir)
305 |             if args.export_pcd:
306 |                 from vis_utils import visualize_point_cloud
307 | 
308 |                 abs_pose = metas[0]['e2g_t']
309 |                 abs_rot = metas[0]['e2g_r']
310 |                 n_gt = min(len(all_pred), len(abs_pose))
311 |                 visualize_point_cloud(all_pred[:n_gt], abs_pose=abs_pose[:n_gt], abs_rot=abs_rot[:n_gt], cmp_dir=dst_dir, key='pred')
312 | 
313 | 
314 | if __name__ == '__main__':
315 |     # Eval settings
316 |     parser = argparse.ArgumentParser(description='')
317 |     parser.add_argument('--py-config', default='config/tpv_lidarseg.py')
318 |     parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg')
319 |     parser.add_argument('--resume-from', type=str, default='')
320 |     parser.add_argument('--vae-resume-from', type=str, default='')
321 |     parser.add_argument('--dir-name', type=str, default='vis')
322 |     parser.add_argument('--num_sampling_steps', type=int, default=20)
323 |     parser.add_argument('--seed', type=int, default=42)
324 |     parser.add_argument('--end_frame', type=int, default=None)
325 |     parser.add_argument('--n_conds_roll', type=int, default=None)
326 |     parser.add_argument('--return_len', type=int, default=None)
327 |     parser.add_argument('--num-trials', type=int, default=10)
328 |     parser.add_argument('--frame-idx', nargs='+', type=int, default=[0, 10])
329 |     #########################################
330 |     # parser.add_argument('--scene-idx', nargs='+', type=int, default=[6,7,16,18,19,87,89,96,101])
331 |     parser.add_argument('--scene-idx', nargs='+', type=int, default=[6, 7])
332 |     parser.add_argument('--rolling_sampling_n', type=int, default=1)
333 |     parser.add_argument('--pose_control', action='store_true', default=False)
334 |     parser.add_argument('--pose', action='store_true', default=True, help='Enable pose (default is True)')
335 |     parser.add_argument('--no-pose', action='store_false', dest='pose', help='Disable pose')
336 |     parser.add_argument('--test_index_offset', type=int, default=0)
337 |     parser.add_argument('--ts', type=str, default=None)
338 |     parser.add_argument('--skip_gt', action='store_true', default=False, help='Skip drawing ground-truth frames')
339 |     parser.add_argument('--show_ego', action='store_true', default=False, help='Draw the ego vehicle in the visualization')
340 |     parser.add_argument('--export_pcd', action='store_true', default=False, help='Export predictions as an aggregated point cloud')
341 | 
342 |     args = parser.parse_args()
343 | 
344 |     ngpus = 1
345 |     args.gpus = ngpus
346 |     print(args)
347 |     main(args)
348 | 
349 | 
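For readers tracing the sampling call above: `p_sample_loop` receives `initial_cond_indices` and `initial_cond_frames`, so the encoded ground-truth latents of the first `n_conds` frames anchor the generated clip. Below is a minimal sketch of that inpainting-style conditioning, with hypothetical `denoise_step` and `add_noise` helpers; the repo's actual logic lives in diffusion/gaussian_diffusion.py and differs in detail.

import torch

def p_sample_loop_sketch(denoise_step, add_noise, noise_shape, cond_frames, cond_indices, num_steps):
    # Illustrative sketch only. `denoise_step(x, t)` runs one reverse step
    # x_t -> x_{t-1}; `add_noise(x0, t)` forward-diffuses clean latents to level t.
    x = torch.randn(noise_shape)
    for t in reversed(range(num_steps)):
        x = denoise_step(x, t)
        if cond_indices is not None:
            known = cond_frames[:, cond_indices]
            if t > 0:
                known = add_noise(known, t - 1)  # match the current noise level
            x[:, cond_indices] = known  # clamp conditioning frames to ground truth
    return x

Re-noising the known frames to the current timestep keeps them statistically consistent with the partially denoised latents, so the free frames can attend to them without a distribution mismatch.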
--------------------------------------------------------------------------------
/tools/visualize_demo_vae.py:
--------------------------------------------------------------------------------
1 | from pyvirtualdisplay import Display
2 | display = Display(visible=False, size=(2560, 1440))
3 | display.start()
4 | import pdb
5 | from mayavi import mlab
6 | import mayavi
7 | mlab.options.offscreen = True
8 | print("Set mlab.options.offscreen={}".format(mlab.options.offscreen))
9 | 
10 | import time, argparse, os.path as osp, os
11 | import sys
12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13 | import torch, numpy as np
14 | import mmcv
15 | from mmengine import Config
16 | from mmengine.runner import set_random_seed
17 | from mmengine.logging import MMLogger
18 | from mmengine.registry import MODELS
19 | import cv2
20 | from vis_gif import create_mp4
21 | import warnings
22 | warnings.filterwarnings("ignore")
23 | from vis_utils import draw
24 | 
25 | 
26 | 
27 | def main(args):
28 |     # global settings
29 |     set_random_seed(args.seed)
30 |     torch.backends.cudnn.deterministic = False
31 |     torch.backends.cudnn.benchmark = True
32 |     # load config
33 |     cfg = Config.fromfile(args.py_config)
34 |     cfg.work_dir = args.work_dir
35 | 
36 |     os.makedirs(args.work_dir, exist_ok=True)
37 |     timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
38 |     args.dir_name = f'{args.dir_name}_{timestamp}'
39 | 
40 |     log_file = osp.join(args.work_dir, f'{cfg.get("data_type", "gts")}_visualize_{timestamp}.log')
41 |     logger = MMLogger('genocc', log_file=log_file)
42 |     MMLogger._instance_dict['genocc'] = logger
43 |     logger.info(f'Config:\n{cfg.pretty_text}')
44 | 
45 |     # build model
46 |     import model
47 |     my_model = MODELS.build(cfg.model)
48 |     n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
49 |     logger.info(f'Number of params: {n_parameters}')
50 |     my_model = my_model.cuda()
51 |     raw_model = my_model
52 |     logger.info('done building model')
53 |     from dataset import get_dataloader
54 |     train_dataset_loader, val_dataset_loader = get_dataloader(
55 |         cfg.train_dataset_config,
56 |         cfg.val_dataset_config,
57 |         cfg.train_wrapper_config,
58 |         cfg.val_wrapper_config,
59 |         cfg.train_loader,
60 |         cfg.val_loader,
61 |         dist=False)
62 |     cfg.resume_from = ''
63 |     if osp.exists(osp.join(args.work_dir, 'latest.pth')):
64 |         cfg.resume_from = osp.join(args.work_dir, 'latest.pth')
65 |     else:
66 |         ckpts = [i for i in os.listdir(args.work_dir) if
67 |                  i.endswith('.pth') and i.replace('.pth', '').replace('epoch_', '').isdigit()]
68 |         if len(ckpts) > 0:
69 |             ckpts.sort(key=lambda x: int(x.replace('.pth', '').replace('epoch_', '')))
70 |             cfg.resume_from = osp.join(args.work_dir, ckpts[-1])
71 | 
72 |     if args.resume_from:
73 |         cfg.resume_from = args.resume_from
74 |     logger.info('resume from: ' + cfg.resume_from)
75 |     logger.info('work dir: ' + args.work_dir)
76 | 
77 |     epoch = 'last'
78 |     if cfg.resume_from and osp.exists(cfg.resume_from):
79 |         map_location = 'cpu'
80 |         ckpt = torch.load(cfg.resume_from, map_location=map_location)
81 |         print(raw_model.load_state_dict(ckpt['state_dict'], strict=False))
82 |         epoch = ckpt['epoch']
83 |         print(f'successfully resumed from epoch {epoch}')
84 |     elif cfg.load_from:
85 |         ckpt = torch.load(cfg.load_from, map_location='cpu')
86 |         if 'state_dict' in ckpt:
87 |             state_dict = ckpt['state_dict']
88 |         else:
89 |             state_dict = ckpt
90 |         print(raw_model.load_state_dict(state_dict, strict=False))
91 | 
92 |     # eval
93 |     my_model.eval()
94 |     os.environ['eval'] = 'true'
95 |     recon_dir = os.path.join(args.work_dir, args.dir_name)
96 |     os.makedirs(recon_dir, exist_ok=True)
97 |     with torch.no_grad():
98 |         for i_iter_val, (input_occs, _, metas) in enumerate(val_dataset_loader):
99 |             if i_iter_val not in args.scene_idx:
100 |                 continue
101 |             if i_iter_val > max(args.scene_idx):
102 |                 break
103 |             input_occs = input_occs.cuda()  # torch.Size([1, 16, 200, 200, 16])
104 |             result = my_model(x=input_occs, metas=metas)
105 |             start_frame = cfg.get('start_frame', 0)
106 |             # end_frame = cfg.get('end_frame', 11)
107 |             end_frame = input_occs.shape[1]
108 |             logits = result['logits']  # torch.Size([1, 6, 200, 200, 16, 18])
109 |             dst_dir = os.path.join(recon_dir, str(i_iter_val), 'pred')
110 |             input_dir = os.path.join(recon_dir, f'{i_iter_val}', 'input')
111 |             cmp_dir = os.path.join(recon_dir, f'{i_iter_val}', 'cmp')
112 |             # input_occs = result['input_occs']
113 |             os.makedirs(dst_dir, exist_ok=True)
114 |             os.makedirs(input_dir, exist_ok=True)
115 |             os.makedirs(cmp_dir, exist_ok=True)
116 |             all_pred = []
117 |             for frame in range(start_frame, end_frame):
118 |                 tt = str(i_iter_val) + '_' + str(frame)
119 |                 input_occ = input_occs[:, frame, ...].squeeze().cpu().numpy()
120 |                 dst_input = draw(input_occ,
121 |                                  None,  # predict_pts,
122 |                                  [-40, -40, -1],
123 |                                  [0.4] * 3,
124 |                                  None,  # grid.squeeze(0).cpu().numpy(),
125 |                                  None,  # pt_label.squeeze(-1),
126 |                                  input_dir,  # recon_dir,
127 |                                  None,  # img_metas[0]['cam_positions'],
128 |                                  None,  # img_metas[0]['focal_positions'],
129 |                                  timestamp=tt,
130 |                                  mode=0,
131 |                                  sem=False)
132 |                 im = mmcv.imread(dst_input)
133 |                 cv2.putText(im, f'GT_{frame:02d}', (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 0, 255), 2)
134 |                 mmcv.imwrite(im, dst_input)
135 |                 if True:
136 |                     logit = logits[:, frame, ...]
137 |                     pred = logit.argmax(dim=-1).squeeze().cpu().numpy()  # 1, 1, 200, 200, 16
138 |                     all_pred.append(pred)
139 | 
140 |                     dst_wm = draw(pred,
141 |                                   None,  # predict_pts,
142 |                                   [-40, -40, -1],
143 |                                   [0.4] * 3,
144 |                                   None,  # grid.squeeze(0).cpu().numpy(),
145 |                                   None,  # pt_label.squeeze(-1),
146 |                                   dst_dir,  # recon_dir,
147 |                                   None,  # img_metas[0]['cam_positions'],
148 |                                   None,  # img_metas[0]['focal_positions'],
149 |                                   timestamp=tt,
150 |                                   mode=0,
151 |                                   sem=False)
152 |                     im = mmcv.imread(dst_wm)
153 |                     cv2.putText(im, f'predict_{frame:02d}', (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 0, 255), 2)
154 |                     mmcv.imwrite(im, dst_wm)
155 |                 # concatenate GT and prediction side by side
156 | 
157 | 
158 |                 im1 = mmcv.imread(dst_input)  # [:, 550:-530, :]
159 |                 im2 = mmcv.imread(dst_wm)  # [:, 550:-530, :]
160 |                 mmcv.imwrite(np.concatenate([im1, im2], axis=1), os.path.join(cmp_dir, f'vis_{tt}.png'))
161 |             logger.info('[EVAL] Iter %5d / %5d' % (i_iter_val, len(val_dataset_loader)))
162 |             create_mp4(dst_dir)
163 |             # create_mp4(cmp_dir)
164 |             if args.export_pcd:
165 |                 from vis_utils import visualize_point_cloud
166 | 
167 |                 abs_pose = metas[0]['e2g_t']
168 |                 abs_rot = metas[0]['e2g_r']
169 |                 n_gt = min(len(all_pred), len(abs_pose))
170 |                 visualize_point_cloud(all_pred[:n_gt], abs_pose=abs_pose[:n_gt], abs_rot=abs_rot[:n_gt], cmp_dir=dst_dir, key='pred')
171 | 
172 | 
173 | 
174 | 
175 | if __name__ == '__main__':
176 |     # Eval settings
177 |     parser = argparse.ArgumentParser(description='')
178 |     parser.add_argument('--py-config', default='config/tpv_lidarseg.py')
179 |     parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg')
180 |     parser.add_argument('--resume-from', type=str, default='')
181 |     parser.add_argument('--dir-name', type=str, default='vis')
182 |     parser.add_argument('--seed', type=int, default=42)  # alternative seeds: 1023, 333, 256
183 |     parser.add_argument('--num-trials', type=int, default=10)
184 |     parser.add_argument('--frame-idx', nargs='+', type=int, default=[0, 10])
185 |     parser.add_argument('--scene-idx', nargs='+', type=int, default=[6])  # other demo scenes: 7, 16, 18, 19, 87, 89, 96, 101
186 |     parser.add_argument('--export_pcd', action='store_true', default=False, help='Export predictions as an aggregated point cloud')
187 | 
188 |     args = parser.parse_args()
189 | 
190 |     ngpus = 1
191 |     args.gpus = ngpus
192 |     print(args)
193 |     main(args)
194 | 
195 | 
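Both demo scripts share one latent convention: encoded latents are multiplied by `cfg.model.vae.scaling_factor` before entering the diffusion model and divided by it again before decoding. A condensed sketch of that round trip follows; the `vae` methods mirror the calls in the scripts above, while the concrete 0.18215 value is only illustrative.

import torch
from einops import rearrange

SCALING_FACTOR = 0.18215  # illustrative; the real value is cfg.model.vae.scaling_factor

def encode_for_diffusion(vae, occs, bs):
    # Encode occupancy to latents and scale them for the diffusion model.
    latent, shape = vae.forward_encoder(occs)
    latent, _, _ = vae.sample_z(latent)
    latent = latent * SCALING_FACTOR
    return rearrange(latent, '(b f) c h w -> b f c h w', b=bs), shape

def decode_from_diffusion(vae, latents, shapes, input_shape):
    # Undo the scaling, flatten the frame axis, and decode to class logits.
    latents = latents / SCALING_FACTOR
    latents = rearrange(latents, 'b f c h w -> (b f) c h w')
    return vae.forward_decoder(latents, shapes=shapes, input_shape=input_shape)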
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 | 
4 | @torch.no_grad()
5 | def update_ema(ema_model, model, decay=0.9999):
6 |     """
7 |     Step the EMA model towards the current model.
8 |     """
9 |     ema_params = OrderedDict(ema_model.named_parameters())
10 |     model_params = OrderedDict(model.named_parameters())
11 | 
12 |     for name, param in model_params.items():
13 |         # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
14 |         ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
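A typical way to drive `update_ema` from a training loop; the model, optimizer, and loss below are illustrative stand-ins, not the repo's training code:

import copy
import torch

from utils.ema import update_ema

model = torch.nn.Linear(8, 2)                 # stand-in for the real network
ema_model = copy.deepcopy(model)              # same architecture, EMA weights
ema_model.requires_grad_(False)               # EMA copy never needs gradients
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for step in range(100):
    x = torch.randn(4, 8)
    loss = model(x).pow(2).mean()             # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    update_ema(ema_model, model, decay=0.9999)  # EMA tracks the online weights

Keeping the EMA copy out of the autograd graph avoids paying gradient memory for weights that are only ever updated in-place.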
--------------------------------------------------------------------------------
/utils/freeze_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mmengine import MMLogger
3 | logger = MMLogger.get_instance('genocc')
4 | import torch.distributed as dist
5 | 
6 | def freeze_model(model, freeze_dict):
7 |     # given a model and a dictionary of booleans, freeze the model
8 |     # according to the dictionary
9 |     for key in freeze_dict:
10 |         if freeze_dict[key]:
11 |             for param in getattr(model, key).parameters():
12 |                 param.requires_grad = False
13 |             logger = MMLogger.get_current_instance()
14 |             logger.info(f'Froze {key} parameters')
15 | 
16 | if __name__ == '__main__':
17 |     model = torch.nn.Sequential(
18 |         torch.nn.Linear(1, 1),
19 |         torch.nn.Linear(1, 1),
20 |         torch.nn.Linear(1, 1)
21 |     )
22 |     print(model)
23 |     freeze_dict = {'0': True, '1': False, '2': True}
24 |     freeze_model(model, freeze_dict)
25 |     import pdb; pdb.set_trace()
--------------------------------------------------------------------------------
/utils/load_save_util.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 | 
4 | def revise_ckpt(state_dict):
5 |     tmp_k = list(state_dict.keys())[0]
6 |     if not tmp_k.startswith('module.'):
7 |         state_dict = OrderedDict(
8 |             {('module.' + k): v
9 |              for k, v in state_dict.items()})
10 |     return state_dict
11 | 
12 | def revise_ckpt_1(state_dict):
13 |     tmp_k = list(state_dict.keys())[0]
14 |     if tmp_k.startswith('module.'):
15 |         state_dict = OrderedDict(
16 |             {k[7:]: v
17 |              for k, v in state_dict.items()})
18 |     return state_dict
19 | 
20 | def revise_ckpt_2(state_dict):
21 |     param_names = list(state_dict.keys())
22 |     for param_name in param_names:
23 |         if 'img_neck.lateral_convs' in param_name or 'img_neck.fpn_convs' in param_name:
24 |             del state_dict[param_name]
25 |     return state_dict
26 | 
27 | def load_checkpoint(raw_model, state_dict, strict=True) -> None:
28 |     # state_dict = checkpoint["state_dict"]
29 |     model_state_dict = raw_model.state_dict()
30 |     is_changed = False
31 |     for k in state_dict:
32 |         if k in model_state_dict:
33 |             if state_dict[k].shape != model_state_dict[k].shape:
34 |                 # process embedding
35 |                 if k == 'class_embeds.weight':
36 |                     l1 = state_dict[k].shape[0]
37 |                     l2 = model_state_dict[k].shape[0]
38 |                     if l1 > l2:
39 |                         state_dict[k] = state_dict[k][:l2]
40 |                     else:
41 |                         state_dict_k_new = torch.zeros_like(model_state_dict[k])
42 |                         state_dict_k_new[:l1] = state_dict[k]
43 |                         state_dict[k] = state_dict_k_new
44 |                 else:
45 |                     print(f"Skip loading parameter: {k}, "
46 |                           f"required shape: {model_state_dict[k].shape}, "
47 |                           f"loaded shape: {state_dict[k].shape}")
48 |                     state_dict[k] = model_state_dict[k]
49 |                     is_changed = True
50 | 
51 |         else:
52 |             print(f"Dropping parameter {k}")
53 |             is_changed = True
54 | 
55 |     # if is_changed:
56 |     #     checkpoint.pop("optimizer_states", None)
57 | 
58 |     print(raw_model.load_state_dict(state_dict, strict=strict))
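A sketch of how these helpers combine when loading a DDP-trained checkpoint into a single-GPU model; the path and the stand-in model are hypothetical:

import torch

from utils.load_save_util import revise_ckpt_1, load_checkpoint

model = torch.nn.Linear(4, 4)  # stand-in for the real model

ckpt = torch.load('work_dir/latest.pth', map_location='cpu')  # hypothetical path
state_dict = ckpt.get('state_dict', ckpt)

# Strip a leading 'module.' left by DDP training, then load with
# shape-mismatch handling (mismatched tensors are skipped or padded).
state_dict = revise_ckpt_1(state_dict)
load_checkpoint(model, state_dict, strict=False)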
--------------------------------------------------------------------------------
/utils/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mmengine import MMLogger
3 | import torch
4 | import torch.distributed as dist
5 | # logger = MMLogger.get_instance('genocc', distributed=dist.is_initialized())
6 | 
7 | 
8 | class MeanIoU:
9 | 
10 |     def __init__(self,
11 |                  class_indices,
12 |                  ignore_label: int,
13 |                  label_str,
14 |                  name
15 |                  # empty_class: int
16 |                  ):
17 |         self.class_indices = class_indices
18 |         self.num_classes = len(class_indices)
19 |         self.ignore_label = ignore_label
20 |         self.label_str = label_str
21 |         self.name = name
22 | 
23 |     def reset(self) -> None:
24 |         self.total_seen = torch.zeros(self.num_classes).cuda()
25 |         self.total_correct = torch.zeros(self.num_classes).cuda()
26 |         self.total_positive = torch.zeros(self.num_classes).cuda()
27 | 
28 |     def _after_step(self, outputs, targets, log_current=False):
29 |         outputs = outputs[targets != self.ignore_label]
30 |         targets = targets[targets != self.ignore_label]
31 | 
32 |         for i, c in enumerate(self.class_indices):
33 |             self.total_seen[i] += torch.sum(targets == c).item()
34 |             self.total_correct[i] += torch.sum((targets == c)
35 |                                                & (outputs == c)).item()
36 |             self.total_positive[i] += torch.sum(outputs == c).item()
37 | 
38 |     def _after_epoch(self):
39 |         if dist.is_initialized():
40 |             dist.all_reduce(self.total_seen)
41 |             dist.all_reduce(self.total_correct)
42 |             dist.all_reduce(self.total_positive)
43 | 
44 |         ious = []
45 | 
46 |         for i in range(self.num_classes):
47 |             if self.total_seen[i] == 0:
48 |                 ious.append(1)
49 |             else:
50 |                 cur_iou = self.total_correct[i] / (self.total_seen[i]
51 |                                                    + self.total_positive[i]
52 |                                                    - self.total_correct[i])
53 |                 ious.append(cur_iou.item())
54 | 
55 |         miou = np.mean(ious)
56 |         logger = MMLogger.get_current_instance()
57 |         logger.info(f'Validation per class iou {self.name}:')
58 |         for iou, label_str in zip(ious, self.label_str):
59 |             logger.info('%s : %.2f%%' % (label_str, iou * 100))
60 | 
61 |         return miou * 100
62 | 
63 | 
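# Note on the accumulators above: per-class IoU is the standard
#   IoU_c = TP_c / (TP_c + FP_c + FN_c),
# with total_correct = TP, total_seen = TP + FN, and total_positive = TP + FP,
# so total_correct / (total_seen + total_positive - total_correct) is exactly
# that quantity. multi_step_MeanIou below repeats the same computation once
# per predicted future timestep and averages the per-step mIoUs.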
64 | class multi_step_MeanIou:
65 |     def __init__(self,
66 |                  class_indices,
67 |                  ignore_label: int,
68 |                  label_str,
69 |                  name,
70 |                  times=1):
71 |         self.class_indices = class_indices
72 |         self.num_classes = len(class_indices)
73 |         self.ignore_label = ignore_label
74 |         self.label_str = label_str
75 |         self.name = name
76 |         self.times = times
77 | 
78 |     def reset(self) -> None:
79 |         self.total_seen = torch.zeros(self.times, self.num_classes).cuda()
80 |         self.total_correct = torch.zeros(self.times, self.num_classes).cuda()
81 |         self.total_positive = torch.zeros(self.times, self.num_classes).cuda()
82 |         self.current_seen = torch.zeros(self.times, self.num_classes).cuda()
83 |         self.current_correct = torch.zeros(self.times, self.num_classes).cuda()
84 |         self.current_positive = torch.zeros(self.times, self.num_classes).cuda()
85 | 
86 |     def _after_step(self, outputses, targetses, log_current=False):
87 | 
88 |         assert outputses.shape[1] == self.times, f'{outputses.shape[1]} != {self.times}'
89 |         assert targetses.shape[1] == self.times, f'{targetses.shape[1]} != {self.times}'
90 |         mious = []
91 |         for t in range(self.times):
92 |             ious = []
93 |             outputs = outputses[:, t, ...][targetses[:, t, ...] != self.ignore_label].cuda()
94 |             targets = targetses[:, t, ...][targetses[:, t, ...] != self.ignore_label].cuda()
95 |             for j, c in enumerate(self.class_indices):
96 |                 self.total_seen[t, j] += torch.sum(targets == c).item()
97 |                 self.total_correct[t, j] += torch.sum((targets == c)
98 |                                                       & (outputs == c)).item()
99 |                 self.total_positive[t, j] += torch.sum(outputs == c).item()
100 |                 if log_current:
101 |                     current_seen = torch.sum(targets == c).item()
102 |                     current_correct = torch.sum((targets == c) & (outputs == c)).item()
103 |                     current_positive = torch.sum(outputs == c).item()
104 |                     if current_seen == 0:
105 |                         ious.append(1)
106 |                     else:
107 |                         cur_iou = current_correct / (current_seen + current_positive - current_correct)
108 |                         ious.append(cur_iou)
109 |             if log_current:
110 |                 miou = np.mean(ious)
111 |                 logger = MMLogger.get_current_instance()
112 |                 logger.info(f'current:: per class iou {self.name} at time {t}:')
113 |                 for iou, label_str in zip(ious, self.label_str):
114 |                     logger.info('%s : %.2f%%' % (label_str, iou * 100))
115 |                 logger.info(f'mIoU {self.name} at time {t}: %.2f%%' % (miou * 100))
116 |                 mious.append(miou * 100)
117 |         m_miou = np.mean(mious)  # NOTE: mious is only filled when log_current=True; otherwise this is NaN
118 |         # mious = torch.tensor(mious).cuda()
119 |         return mious, m_miou
120 | 
121 |     def _after_epoch(self):
122 |         logger = MMLogger.get_current_instance()
123 |         if dist.is_initialized():
124 |             dist.all_reduce(self.total_seen)
125 |             dist.all_reduce(self.total_correct)
126 |             dist.all_reduce(self.total_positive)
127 |         logger.info(f'_after_epoch::total_seen: {self.total_seen.sum()}')
128 |         logger.info(f'_after_epoch::total_correct: {self.total_correct.sum()}')
129 |         logger.info(f'_after_epoch::total_positive: {self.total_positive.sum()}')
130 |         mious = []
131 |         for t in range(self.times):
132 |             ious = []
133 |             for i in range(self.num_classes):
134 |                 if self.total_seen[t, i] == 0:
135 |                     ious.append(1)
136 |                 else:
137 |                     cur_iou = self.total_correct[t, i] / (self.total_seen[t, i]
138 |                                                           + self.total_positive[t, i]
139 |                                                           - self.total_correct[t, i])
140 |                     ious.append(cur_iou.item())
141 |             miou = np.mean(ious)
142 |             logger.info(f'per class iou {self.name} at time {t}:')
143 |             for iou, label_str in zip(ious, self.label_str):
144 |                 logger.info('%s : %.2f%%' % (label_str, iou * 100))
145 |             logger.info(f'mIoU {self.name} at time {t}: %.2f%%' % (miou * 100))
146 |             mious.append(miou * 100)
147 |         return mious, np.mean(mious)
--------------------------------------------------------------------------------
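To close, a toy run of `multi_step_MeanIou`, assuming a CUDA device is available (reset() allocates its accumulators on the GPU). Inputs follow the (batch, times, ...) layout the asserts expect, and `log_current=True` is required for `_after_step` to return meaningful per-step values (see the note on line 117):

import torch

from utils.metric_util import multi_step_MeanIou

# Two timesteps, two classes.
metric = multi_step_MeanIou(class_indices=[0, 1], ignore_label=255,
                            label_str=['free', 'occupied'], name='toy', times=2)
metric.reset()  # allocates the accumulators on the GPU

targets = torch.randint(0, 2, (1, 2, 4, 4))
outputs = torch.randint(0, 2, (1, 2, 4, 4))
mious, m_miou = metric._after_step(outputs, targets, log_current=True)
print(mious, m_miou)          # per-timestep mIoU (percent) and their mean
print(metric._after_epoch())  # aggregated over every step seen so far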