├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── label_mapping
│   │   └── nuscenes-occ.yaml
│   ├── train_dome.py
│   ├── train_dome_resample.py
│   └── train_occvae.py
├── dataset
│   ├── __init__.py
│   ├── dataset.py
│   ├── dataset_wrapper.py
│   └── sampler.py
├── diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   ├── gaussian_diffusion.py
│   ├── respace.py
│   └── timestep_sampler.py
├── environment.yml
├── loss
│   ├── __init__.py
│   ├── base_loss.py
│   ├── ce_loss.py
│   ├── emb_loss.py
│   ├── multi_loss.py
│   └── recon_loss.py
├── model
│   ├── VAE
│   │   ├── quantizer.py
│   │   └── vae_2d_resnet.py
│   ├── __init__.py
│   ├── dome.py
│   └── pose_encoder.py
├── resample
│   ├── astar.py
│   ├── launch.py
│   ├── main.py
│   ├── requirements.txt
│   └── utils.py
├── static
│   └── images
│       ├── dome_pipe.png
│       ├── favicon.ico
│       ├── occvae.png
│       ├── overall_pipeline4.png
│       ├── teaser12.png
│       └── vis_demo_cmp_2.png
├── tools
│   ├── eval.sh
│   ├── eval_metric.py
│   ├── eval_vae.py
│   ├── eval_vae.sh
│   ├── train_diffusion.py
│   ├── train_diffusion.sh
│   ├── train_vae.py
│   ├── train_vae.sh
│   ├── vis_diffusion.sh
│   ├── vis_gif.py
│   ├── vis_utils.py
│   ├── vis_vae.sh
│   ├── visualize_demo.py
│   └── visualize_demo_vae.py
└── utils
    ├── __init__.py
    ├── ema.py
    ├── freeze_model.py
    ├── load_save_util.py
    └── metric_util.py
/.gitignore:
--------------------------------------------------------------------------------
 1 | data/
 2 | out/
 3 | results/
 4 | __pycache__/
 5 | visualization/results/
 6 | analyze/
 7 | *.png
 8 | *.jpg
 9 | *.log
10 | ckpts/
11 | .DS_Store
12 | dev/
13 | kernels/
14 | work_dir/
15 | *.pth
16 | *.zip
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # DOME: Taming Diffusion Model into High-Fidelity Controllable Occupancy World Model
4 |
5 |
6 | ### [Project Page](https://gusongen.github.io/DOME/) | [Paper](https://arxiv.org/abs/2410.10429v1)
7 |
8 |
9 |
10 |
11 |
12 | https://github.com/user-attachments/assets/7da724d5-acbc-40f7-b5f8-dac38cfbe24a
13 |
14 |
15 | https://github.com/user-attachments/assets/9ab4237c-67c0-4718-8576-32068a74bccf
16 |
17 |
18 | https://github.com/user-attachments/assets/f29149f3-a749-4c55-9777-3d05912bebe1
19 |
20 |
21 |
22 |
23 |
24 |
25 | Our Occupancy World Model can generate long-duration occupancy forecasts and can be effectively controlled by trajectory conditions.
26 |
27 |
28 | # 📖 Overview
29 |
30 | Our method consists of two components: (a) the Occ-VAE pipeline encodes occupancy frames into a continuous latent space, enabling efficient data compression, and (b) the DOME pipeline learns to predict 4D occupancy based on historical occupancy observations.
31 |
32 |
33 |
34 | ## 🗓️ News
35 | - [2025.1.1] We release the code and checkpoints.
36 | - [2024.11.18] Project page is online!
37 |
38 | ## 🗓️ TODO
39 | - [x] Code release.
40 | - [x] Checkpoint release.
41 |
42 |
43 | ## 🚀 Setup
44 | ### Clone the repo
45 | ```
46 | git clone https://github.com/gusongen/DOME.git
47 | cd DOME
48 | ```
49 |
50 | ### Environment setup
51 | ```
52 | conda env create --file environment.yml
53 | ```
54 |
55 | ### Data preparation
56 | 1. Create a soft link from `data/nuscenes` to your nuScenes dataset path, e.g. `ln -s /path/to/nuscenes data/nuscenes`.
57 |
58 | 2. Prepare the `gts` semantic occupancy annotations introduced in [Occ3D](https://github.com/Tsinghua-MARS-Lab/Occ3D).
59 |
60 | 3. Download our generated train/val pickle files and put them in `data/`:
61 |
62 | [nuscenes_infos_train_temporal_v3_scene.pkl](https://cloud.tsinghua.edu.cn/d/9e231ed16e4a4caca3bd/)
63 |
64 | [nuscenes_infos_val_temporal_v3_scene.pkl](https://cloud.tsinghua.edu.cn/d/9e231ed16e4a4caca3bd/)
65 |
66 | The dataset should be organized as follows:
67 |
68 | ```
69 | .
70 | └── data/
71 | ├── nuscenes # downloaded from www.nuscenes.org/
72 | │ ├── lidarseg
73 | │ ├── maps
74 | │ ├── samples
75 | │ ├── sweeps
76 | │ ├── v1.0-trainval
77 | │ └── gts # download from Occ3d
78 | ├── nuscenes_infos_train_temporal_v3_scene.pkl
79 | └── nuscenes_infos_val_temporal_v3_scene.pkl
80 | ```
81 | ### Checkpoint preparation
82 | Download the pretrained weights from [here](https://drive.google.com/drive/folders/1D1HugOG7JurEqmnQo4XbW_-Ji0chEq-e?usp=sharing) and put them in the `ckpts` folder.
83 |
84 | ## 🏃 Run the code
85 | ### (Optional) Preprocess resampled data
86 | ```
87 | cd resample
88 |
89 | python launch.py \
90 | --dst ../data/resampled_occ \
91 | --imageset ../data/nuscenes_infos_train_temporal_v3_scene.pkl \
92 | --data_path ../data/nuscenes
93 | ```
94 |
95 | ### Occ-VAE
96 | ```shell
97 | # train
98 | sh tools/train_vae.sh
99 |
100 | # eval
101 | sh tools/eval_vae.sh
102 |
103 | # visualize
104 | sh tools/vis_vae.sh
105 | ```
106 |
107 | ### DOME
108 | ```shell
109 | # train
110 | sh tools/train_diffusion.sh
111 |
112 | # eval
113 | sh tools/eval.sh
114 |
115 | # visualize
116 | sh tools/vis_diffusion.sh
117 | ```
118 |
119 | ## 🎫 Acknowledgment
120 | This code draws inspiration from the following projects. We sincerely appreciate their excellent contributions.
121 | - [OccWorld](https://github.com/wzzheng/OccWorld)
122 | - [Latte](https://github.com/Vchitect/Latte)
123 | - [Vista](https://github.com/OpenDriveLab/Vista.git)
124 | - [PyTorch-VAE](https://github.com/AntixK/PyTorch-VAE)
125 | - [A* search](https://www.redblobgames.com/pathfinding/a-star/)
126 |
127 | ## 🖊️ Citation
128 | ```
129 | @article{gu2024dome,
130 | title={Dome: Taming diffusion model into high-fidelity controllable occupancy world model},
131 | author={Gu, Songen and Yin, Wei and Jin, Bu and Guo, Xiaoyang and Wang, Junming and Li, Haodong and Zhang, Qian and Long, Xiaoxiao},
132 | journal={arXiv preprint arXiv:2410.10429},
133 | year={2024}
134 | }
135 | ```
136 |
137 |
138 |
--------------------------------------------------------------------------------
/config/label_mapping/nuscenes-occ.yaml:
--------------------------------------------------------------------------------
1 | labels:
2 | 0: 'noise'
3 | 1: 'animal'
4 | 2: 'human.pedestrian.adult'
5 | 3: 'human.pedestrian.child'
6 | 4: 'human.pedestrian.construction_worker'
7 | 5: 'human.pedestrian.personal_mobility'
8 | 6: 'human.pedestrian.police_officer'
9 | 7: 'human.pedestrian.stroller'
10 | 8: 'human.pedestrian.wheelchair'
11 | 9: 'movable_object.barrier'
12 | 10: 'movable_object.debris'
13 | 11: 'movable_object.pushable_pullable'
14 | 12: 'movable_object.trafficcone'
15 | 13: 'static_object.bicycle_rack'
16 | 14: 'vehicle.bicycle'
17 | 15: 'vehicle.bus.bendy'
18 | 16: 'vehicle.bus.rigid'
19 | 17: 'vehicle.car'
20 | 18: 'vehicle.construction'
21 | 19: 'vehicle.emergency.ambulance'
22 | 20: 'vehicle.emergency.police'
23 | 21: 'vehicle.motorcycle'
24 | 22: 'vehicle.trailer'
25 | 23: 'vehicle.truck'
26 | 24: 'flat.driveable_surface'
27 | 25: 'flat.other'
28 | 26: 'flat.sidewalk'
29 | 27: 'flat.terrain'
30 | 28: 'static.manmade'
31 | 29: 'static.other'
32 | 30: 'static.vegetation'
33 | 31: 'vehicle.ego'
34 | 32: 'empty'
35 | labels_16:
36 | 0: 'others'
37 | 1: 'barrier'
38 | 2: 'bicycle'
39 | 3: 'bus'
40 | 4: 'car'
41 | 5: 'construction_vehicle'
42 | 6: 'motorcycle'
43 | 7: 'pedestrian'
44 | 8: 'traffic_cone'
45 | 9: 'trailer'
46 | 10: 'truck'
47 | 11: 'driveable_surface'
48 | 12: 'other_flat'
49 | 13: 'sidewalk'
50 | 14: 'terrain'
51 | 15: 'manmade'
52 | 16: 'vegetation'
53 | 17: 'empty'
54 | learning_map:
55 | 0: 0
56 | 1: 1
57 | 2: 2
58 | 3: 3
59 | 4: 4
60 | 5: 5
61 | 6: 6
62 | 7: 7
63 | 8: 8
64 | 9: 9
65 | 10: 10
66 | 11: 11
67 | 12: 12
68 | 13: 13
69 | 14: 14
70 | 15: 15
71 | 16: 16
72 | 17: 17
--------------------------------------------------------------------------------
/config/train_dome.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | work_dir = './work_dir/dome'
4 |
5 | start_frame = 0   # first history (conditioning) frame
6 | mid_frame = 4     # frames [start_frame, mid_frame) are history; the rest are predicted
7 | end_frame = 10    # last predicted frame (exclusive)
8 | eval_length = end_frame - mid_frame  # number of predicted frames scored at evaluation
9 |
10 | return_len_train = 11  # sequence length loaded per training sample
11 | return_len_ = 10       # sequence length for validation
12 | grad_max_norm = 1
13 | print_freq = 1
14 | max_epochs = 2000
15 | warmup_iters = 50
16 | ema = True
17 |
18 | load_from = ''
19 | vae_load_from = 'ckpts/occvae_latest.pth'
20 | port = 25098
21 | revise_ckpt = 3
22 | eval_every_epochs = 10
23 | save_every_epochs = 10
24 |
25 | multisteplr = True
26 | multisteplr_config = dict(
27 | decay_rate=1,
28 | decay_t=[
29 | 0,
30 | ],
31 | t_in_epochs=False,
32 | warmup_lr_init=1e-06,
33 | warmup_t=0)
34 | optimizer = dict(optimizer=dict(lr=0.0001, type='AdamW', weight_decay=0.0001))
35 |
36 | schedule = dict(
37 | beta_end=0.02,
38 | beta_schedule='linear',
39 | beta_start=0.0001,
40 | variance_type='learned_range')
41 |
42 | sample = dict(
43 | enable_temporal_attentions=True,
44 | enable_vae_temporal_decoder=True,
45 | guidance_scale=7.5,
46 | n_conds=4,
47 | num_sampling_steps=20,
48 | run_time=0,
49 | sample_method='ddpm',
50 | seed=None)
51 | p_use_pose_condition = 0.9
52 |
53 | replace_cond_frames = True
54 | cond_frames_choices = [
55 | [],
56 | [0],
57 | [0,1],
58 | [0,1,2],
59 | [0,1,2,3],
60 | ]
61 | data_path = 'data/nuscenes/'
62 |
63 | train_dataset_config = dict(
64 | type='nuScenesSceneDatasetLidar',
65 | data_path = data_path,
66 | return_len = return_len_train,
67 | offset = 0,
68 | times=5,
69 | imageset = 'data/nuscenes_infos_train_temporal_v3_scene.pkl',
70 | )
71 |
72 | val_dataset_config = dict(
73 | data_path='data/nuscenes/',
74 | imageset='data/nuscenes_infos_val_temporal_v3_scene.pkl',
75 | new_rel_pose=True,
76 | offset=0,
77 | return_len=return_len_,
78 | test_mode=True,
79 | times=1,
80 | type='nuScenesSceneDatasetLidar')
81 | train_wrapper_config = dict(phase='train', type='tpvformer_dataset_nuscenes')
82 | val_wrapper_config = dict(phase='val', type='tpvformer_dataset_nuscenes')
83 | train_loader = dict(batch_size=8, num_workers=1, shuffle=True)
84 | val_loader = dict(batch_size=1, num_workers=1, shuffle=False)
85 | loss = dict(
86 | loss_cfgs=[
87 | dict(
88 | input_dict=dict(ce_inputs='ce_inputs', ce_labels='ce_labels'),
89 | type='CeLoss',
90 | weight=1.0),
91 | ],
92 | type='MultiLoss')
93 | loss_input_convertion = dict()
94 |
95 | _dim_ = 16
96 | base_channel = 64
97 | expansion = 8
98 | n_e_ = 512
99 | num_heads=12
100 | hidden_size=768
101 |
102 | model = dict(
103 | delta_input=False,
104 | world_model=dict(
105 | attention_mode='xformers',
106 | class_dropout_prob=0.1,
107 | depth=28,
108 | extras=1,
109 | hidden_size=hidden_size,
110 | in_channels=64,
111 | input_size=25,
112 | learn_sigma=True,
113 | mlp_ratio=4.0,
114 | num_classes=1000,
115 | num_frames=return_len_train,
116 | num_heads=num_heads,
117 | patch_size=1,
118 | pose_encoder=dict(
119 | do_proj=True,
120 | in_channels=2,
121 | num_fut_ts=1,
122 | num_layers=2,
123 | num_modes=3,
124 | out_channels=hidden_size,
125 | type='PoseEncoder_fourier',
126 | zero_init=False),
127 | type='Dome'),
128 | sampling_method='SAMPLE',
129 | topk=10,
130 | vae=dict(
131 | encoder_cfg=dict(
132 | attn_resolutions=(50, ),
133 | ch=base_channel,
134 | ch_mult=(
135 | 1,
136 | 2,
137 | 4,
138 | 8,
139 | ),
140 | double_z=False,
141 | dropout=0.0,
142 | in_channels=128,
143 | num_res_blocks=2,
144 | out_ch=base_channel,
145 | resamp_with_conv=True,
146 | resolution=200,
147 | type='Encoder2D',
148 | z_channels=base_channel*2),
149 | decoder_cfg=dict(
150 | attn_resolutions=(50, ),
151 | ch=base_channel,
152 | ch_mult=(
153 | 1,
154 | 2,
155 | 4,
156 | 8,
157 | ),
158 | dropout=0.0,
159 | give_pre_end=False,
160 | in_channels=_dim_ * expansion,
161 | num_res_blocks=2,
162 | out_ch=_dim_ * expansion,
163 | resamp_with_conv=True,
164 | resolution=200,
165 | type='Decoder3D',
166 | z_channels=base_channel),
167 | expansion=expansion,
168 | num_classes=18,
169 | scaling_factor=0.18215,
170 | type='VAERes3D'))
171 | shapes = [[200,200],[100,100],[50,50],[25,25]]
172 |
173 | unique_label = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
174 | label_mapping = './config/label_mapping/nuscenes-occ.yaml'
175 |
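176 | # Conditioning sketch (inferred from the names above, not a definitive description):
177 | # with `replace_cond_frames=True`, training randomly keeps one prefix from
178 | # `cond_frames_choices` (0-4 ground-truth history frames) un-noised, which is what
179 | # lets the model forecast the remaining frames from history at inference time.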
--------------------------------------------------------------------------------
/config/train_dome_resample.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | work_dir = './work_dir/dome'
4 |
5 | start_frame = 0
6 | mid_frame = 4
7 | end_frame = 10
8 | eval_length = end_frame-mid_frame
9 |
10 | return_len_train = 11
11 | return_len_ = 10
12 | grad_max_norm = 1
13 | print_freq = 1
14 | max_epochs = 2000
15 | warmup_iters = 50
16 | ema = True
17 |
18 | load_from = ''
19 | vae_load_from = 'ckpts/occvae_latest.pth'
20 | port = 25098
21 | revise_ckpt = 3
22 | eval_every_epochs = 10
23 | save_every_epochs = 10
24 |
25 | multisteplr = True
26 | multisteplr_config = dict(
27 | decay_rate=1,
28 | decay_t=[
29 | 0,
30 | ],
31 | t_in_epochs=False,
32 | warmup_lr_init=1e-06,
33 | warmup_t=0)
34 | optimizer = dict(optimizer=dict(lr=0.0001, type='AdamW', weight_decay=0.0001))
35 |
36 | schedule = dict(
37 | beta_end=0.02,
38 | beta_schedule='linear',
39 | beta_start=0.0001,
40 | variance_type='learned_range')
41 |
42 | sample = dict(
43 | enable_temporal_attentions=True,
44 | enable_vae_temporal_decoder=True,
45 | guidance_scale=7.5,
46 | n_conds=4,
47 | num_sampling_steps=20,
48 | run_time=0,
49 | sample_method='ddpm',
50 | seed=None)
51 | p_use_pose_condition = 0.9
52 |
53 | replace_cond_frames = True
54 | cond_frames_choices = [
55 | [],
56 | [0],
57 | [0,1],
58 | [0,1,2],
59 | [0,1,2,3],
60 | ]
61 | data_path = 'data/nuscenes/'
62 |
63 | train_dataset_config = dict(
64 | data_path='data/resampled_occ',
65 | imageset='data/nuscenes_infos_train_temporal_v3_scene.pkl',
66 | offset=0,
67 | raw_times=10,
68 | resample_times=1,
69 | return_len=return_len_train,
70 | times=1,
71 | type='nuScenesSceneDatasetLidarResample'
72 | )
73 |
74 | val_dataset_config = dict(
75 | data_path='data/nuscenes/',
76 | imageset='data/nuscenes_infos_val_temporal_v3_scene.pkl',
77 | new_rel_pose=True,
78 | offset=0,
79 | return_len=return_len_,
80 | test_mode=True,
81 | times=1,
82 | type='nuScenesSceneDatasetLidar')
83 | train_wrapper_config = dict(phase='train', type='tpvformer_dataset_nuscenes')
84 | val_wrapper_config = dict(phase='val', type='tpvformer_dataset_nuscenes')
85 | train_loader = dict(batch_size=8, num_workers=1, shuffle=True)
86 | val_loader = dict(batch_size=1, num_workers=1, shuffle=False)
87 | loss = dict(
88 | loss_cfgs=[
89 | dict(
90 | input_dict=dict(ce_inputs='ce_inputs', ce_labels='ce_labels'),
91 | type='CeLoss',
92 | weight=1.0),
93 | ],
94 | type='MultiLoss')
95 | loss_input_convertion = dict()
96 |
97 | _dim_ = 16
98 | base_channel = 64
99 | expansion = 8
100 | n_e_ = 512
101 | num_heads=12
102 | hidden_size=768
103 |
104 | model = dict(
105 | delta_input=False,
106 | world_model=dict(
107 | attention_mode='xformers',
108 | class_dropout_prob=0.1,
109 | depth=28,
110 | extras=1,
111 | hidden_size=hidden_size,
112 | in_channels=64,
113 | input_size=25,
114 | learn_sigma=True,
115 | mlp_ratio=4.0,
116 | num_classes=1000,
117 | num_frames=return_len_train,
118 | num_heads=num_heads,
119 | patch_size=1,
120 | pose_encoder=dict(
121 | do_proj=True,
122 | in_channels=2,
123 | num_fut_ts=1,
124 | num_layers=2,
125 | num_modes=3,
126 | out_channels=hidden_size,
127 | type='PoseEncoder_fourier',
128 | zero_init=False),
129 | type='Dome'),
130 | sampling_method='SAMPLE',
131 | topk=10,
132 | vae=dict(
133 | encoder_cfg=dict(
134 | attn_resolutions=(50, ),
135 | ch=base_channel,
136 | ch_mult=(
137 | 1,
138 | 2,
139 | 4,
140 | 8,
141 | ),
142 | double_z=False,
143 | dropout=0.0,
144 | in_channels=128,
145 | num_res_blocks=2,
146 | out_ch=base_channel,
147 | resamp_with_conv=True,
148 | resolution=200,
149 | type='Encoder2D',
150 | z_channels=base_channel*2),
151 | decoder_cfg=dict(
152 | attn_resolutions=(50, ),
153 | ch=base_channel,
154 | ch_mult=(
155 | 1,
156 | 2,
157 | 4,
158 | 8,
159 | ),
160 | dropout=0.0,
161 | give_pre_end=False,
162 | in_channels=_dim_ * expansion,
163 | num_res_blocks=2,
164 | out_ch=_dim_ * expansion,
165 | resamp_with_conv=True,
166 | resolution=200,
167 | type='Decoder3D',
168 | z_channels=base_channel),
169 | expansion=expansion,
170 | num_classes=18,
171 | scaling_factor=0.18215,
172 | type='VAERes3D'))
173 | shapes = [[200,200],[100,100],[50,50],[25,25]]
174 |
175 | unique_label = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
176 | label_mapping = './config/label_mapping/nuscenes-occ.yaml'
177 |
--------------------------------------------------------------------------------
/config/train_occvae.py:
--------------------------------------------------------------------------------
1 | grad_max_norm = 35
2 | print_freq = 10
3 | max_epochs = 200
4 | warmup_iters = 200
5 | # return_len_ = 16  # max sequence length for a Conv3D encoder-decoder
6 | return_len_ = 12  # max sequence length for the Conv2D encoder + Conv3D decoder used here
7 |
8 |
9 | eval_every_epochs = 10
10 | save_every_epochs = 10
11 |
12 | multisteplr = False
13 | multisteplr_config = dict(
14 | decay_t = [87 * 500],
15 | decay_rate = 0.1,
16 | warmup_t = warmup_iters,
17 | warmup_lr_init = 1e-6,
18 | t_in_epochs = False
19 | )
20 |
21 |
22 | optimizer = dict(
23 | optimizer=dict(
24 | type='AdamW',
25 | lr=1e-3,
26 | weight_decay=0.01,
27 | ),
28 | )
29 |
30 | data_path = 'data/nuscenes/'
31 |
32 |
33 | train_dataset_config = dict(
34 | type='nuScenesSceneDatasetLidar',
35 | data_path = data_path,
36 | return_len = return_len_,
37 | offset = 0,
38 | imageset = 'data/nuscenes_infos_train_temporal_v3_scene.pkl',
39 | )
40 |
41 | val_dataset_config = dict(
42 | type='nuScenesSceneDatasetLidar',
43 | data_path = data_path,
44 | return_len = return_len_,
45 | offset = 0,
46 | imageset = 'data/nuscenes_infos_val_temporal_v3_scene.pkl',
47 | )
48 |
49 | train_wrapper_config = dict(
50 | type='tpvformer_dataset_nuscenes',
51 | phase='train',
52 | )
53 |
54 | val_wrapper_config = dict(
55 | type='tpvformer_dataset_nuscenes',
56 | phase='val',
57 | )
58 |
59 | train_loader = dict(
60 | batch_size = 1,
61 | shuffle = True,
62 |     num_workers = 0,
63 | )
64 |
65 | val_loader = dict(
66 | batch_size = 1,
67 | shuffle = False,
68 |     num_workers = 0,
69 | )
70 |
71 |
72 |
73 | loss = dict(
74 | type='MultiLoss',
75 | loss_cfgs=[
76 | dict(
77 | type='ReconLoss',
78 | weight=10.0,
79 | ignore_label=-100,
80 | use_weight=False,
81 | cls_weight=None,
82 | input_dict={
83 | 'logits': 'logits',
84 | 'labels': 'inputs'}),
85 | dict(
86 | type='LovaszLoss',
87 | weight=1.0,
88 | input_dict={
89 | 'logits': 'logits',
90 | 'labels': 'inputs'}),
91 | dict(
92 | type='KLLoss',
93 | # weight=0.00025,
94 | weight=0.00005,
95 | input_dict={
96 | 'z_mu': 'z_mu',
97 | 'z_sigma': 'z_sigma'}),
98 | # dict(
99 | # type='VQVAEEmbedLoss',
100 | # weight=1.0),
101 | ])
102 |
103 | loss_input_convertion = dict(
104 | logits='logits',
105 | # embed_loss='embed_loss'
106 | )
107 |
108 |
109 | load_from = ''
110 |
111 | _dim_ = 16
112 | expansion = 8  # class embedding dimension
113 | base_channel = 64
114 | n_e_ = 512
115 | model = dict(
116 | type = 'VAERes3D',
117 | encoder_cfg=dict(
118 | type='Encoder2D',
119 | # type='Encoder3D',#debug
120 | ch = base_channel,
121 | out_ch = base_channel,
122 | ch_mult = (1,2,4,8),
123 | num_res_blocks = 2,
124 | attn_resolutions = (50,),
125 | dropout = 0.0,
126 | resamp_with_conv = True,
127 | in_channels = _dim_ * expansion,
128 | resolution = 200,
129 | z_channels = base_channel * 2,
130 | double_z = False,
131 | # temporal_downsample= (
132 | # "",
133 | # "TimeDownsample2x",
134 | # "TimeDownsample2x",
135 | # "",
136 | # ),
137 | ),
138 | decoder_cfg=dict(
139 | type='Decoder3D',
140 | ch = base_channel,
141 | out_ch = _dim_ * expansion,
142 | ch_mult = (1,2,4,8),
143 | num_res_blocks = 2,
144 | attn_resolutions = (50,),
145 | dropout = 0.0,
146 | resamp_with_conv = True,
147 | in_channels = _dim_ * expansion,
148 | resolution = 200,
149 | z_channels = base_channel,
150 | give_pre_end = False,
151 | # temporal_upsample = ("", "", "TimeUpsample2x", "TimeUpsample2x"),
152 | ),
153 | num_classes=18,
154 | expansion=expansion,
155 | # vqvae_cfg=dict(
156 | # type='VectorQuantizer',
157 | # n_e = n_e_,
158 | # e_dim = base_channel * 2,
159 | # beta = 1.,
160 | # z_channels = base_channel * 2,
161 | # use_voxel=False)
162 | )
163 |
164 | shapes = [[200, 200], [100, 100], [50, 50], [25, 25]]
165 |
166 | unique_label = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
167 | label_mapping = "./config/label_mapping/nuscenes-occ.yaml"
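168 |
169 | # Objective implied by the MultiLoss config above (a sketch using the configured weights):
170 | #   L = 10.0 * ReconLoss(logits, labels)
171 | #     +  1.0 * LovaszLoss(logits, labels)
172 | #     +  5e-5 * KL(q(z | x) || N(0, I))   # computed from (z_mu, z_sigma)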
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from mmengine.registry import Registry
3 | OPENOCC_DATASET = Registry('openocc_dataset')
4 | OPENOCC_DATAWRAPPER = Registry('openocc_datawrapper')
5 |
6 | from .dataset import nuScenesSceneDatasetLidar, nuScenesSceneDatasetLidarTraverse
7 | from .dataset_wrapper import tpvformer_dataset_nuscenes, custom_collate_fn_temporal
8 | from .sampler import CustomDistributedSampler
9 | from torch.utils.data.distributed import DistributedSampler
10 | from torch.utils.data.dataloader import DataLoader
11 | import yaml
12 |
13 | def get_dataloader(
14 | train_dataset_config,
15 | val_dataset_config,
16 | train_wrapper_config,
17 | val_wrapper_config,
18 | train_loader,
19 | val_loader,
20 | nusc=dict(
21 | version='v1.0-trainval',
22 | dataroot='data/nuscenes'),
23 | dist=False,
24 | iter_resume=False,
25 | train_sampler_config=dict(
26 | shuffle=True,
27 | drop_last=True),
28 | val_sampler_config=dict(
29 | shuffle=False,
30 | drop_last=False),
31 | ):
32 | # if nusc is not None:
33 | # from nuscenes import NuScenes
34 | # nusc = NuScenes(**nusc)
35 |
36 | train_dataset = OPENOCC_DATASET.build(
37 | train_dataset_config,
38 | default_args={'nusc': nusc})
39 | val_dataset = OPENOCC_DATASET.build(
40 | val_dataset_config,
41 | default_args={'nusc': nusc})
42 |
43 | train_wrapper = OPENOCC_DATAWRAPPER.build(
44 | train_wrapper_config,
45 | default_args={'in_dataset': train_dataset})
46 | val_wrapper = OPENOCC_DATAWRAPPER.build(
47 | val_wrapper_config,
48 | default_args={'in_dataset': val_dataset})
49 |
50 | train_sampler = val_sampler = None
51 | if dist:
52 | if iter_resume:
53 | train_sampler = CustomDistributedSampler(train_wrapper, **train_sampler_config)
54 | else:
55 | train_sampler = DistributedSampler(train_wrapper, **train_sampler_config)
56 | val_sampler = DistributedSampler(val_wrapper, **val_sampler_config)
57 |
58 | train_dataset_loader = DataLoader(
59 | dataset=train_wrapper,
60 | batch_size=train_loader["batch_size"],
61 | collate_fn=custom_collate_fn_temporal,
62 | shuffle=False if dist else train_loader["shuffle"],
63 | sampler=train_sampler,
64 | num_workers=train_loader["num_workers"],
65 | pin_memory=True)
66 | val_dataset_loader = DataLoader(
67 | dataset=val_wrapper,
68 | batch_size=val_loader["batch_size"],
69 | collate_fn=custom_collate_fn_temporal,
70 | shuffle=False,
71 | sampler=val_sampler,
72 | num_workers=val_loader["num_workers"],
73 | pin_memory=True)
74 |
75 | return train_dataset_loader, val_dataset_loader
76 |
77 |
78 | def get_nuScenes_label_name(label_mapping):
79 | with open(label_mapping, 'r') as stream:
80 | nuScenesyaml = yaml.safe_load(stream)
81 | nuScenes_label_name = dict()
82 | for i in sorted(list(nuScenesyaml['learning_map'].keys()))[::-1]:
83 | val_ = nuScenesyaml['learning_map'][i]
84 | nuScenes_label_name[val_] = nuScenesyaml['labels_16'][val_]
85 | return nuScenes_label_name
86 |
87 |
88 |
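89 | # Usage sketch (run from the repo root; the mapping file ships in config/):
90 | #   from dataset import get_nuScenes_label_name
91 | #   label_name = get_nuScenes_label_name('./config/label_mapping/nuscenes-occ.yaml')
92 | #   label_name[4], label_name[17]  # -> ('car', 'empty')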
--------------------------------------------------------------------------------
/dataset/dataset_wrapper.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np, torch
3 | from torch.utils import data
4 | import torch.nn.functional as F
5 | from copy import deepcopy
6 | from mmengine import MMLogger
7 | logger = MMLogger.get_instance('genocc')
8 | try:
9 |     from . import OPENOCC_DATAWRAPPER
10 | except ImportError:
11 |     from mmengine.registry import Registry
12 |     OPENOCC_DATAWRAPPER = Registry('openocc_datawrapper')
13 |
14 |
15 | @OPENOCC_DATAWRAPPER.register_module()
16 | class tpvformer_dataset_nuscenes(data.Dataset):
17 | def __init__(
18 | self,
19 | in_dataset,
20 | phase='train',
21 | ):
22 | 'Initialization'
23 | self.point_cloud_dataset = in_dataset
24 | self.phase = phase
25 |
26 | def __len__(self):
27 | return len(self.point_cloud_dataset)
28 |
29 | def to_tensor(self, imgs):
30 | imgs = np.stack(imgs).astype(np.float32)
31 | imgs = torch.from_numpy(imgs)
32 | imgs = imgs.permute(0, 3, 1, 2)
33 | return imgs
34 |
35 | def __getitem__(self, index):
36 | input, target, metas = self.point_cloud_dataset[index]
37 | #### adapt to vae input
38 | input = torch.from_numpy(input)
39 | target = torch.from_numpy(target)
40 | return input, target, metas
41 |
42 |
43 | def custom_collate_fn_temporal(data):
44 | data_tuple = []
45 | for i, item in enumerate(data[0]):
46 | if isinstance(item, torch.Tensor):
47 | data_tuple.append(torch.stack([d[i] for d in data]))
48 | elif isinstance(item, (dict, str)):
49 | data_tuple.append([d[i] for d in data])
50 | elif item is None:
51 | data_tuple.append(None)
52 | else:
53 | raise NotImplementedError
54 | return data_tuple
55 |
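56 | if __name__ == '__main__':
57 |     # Collation sketch (illustrative shapes only): two (input, target, metas) samples
58 |     # become stacked tensors plus a list of metas dicts.
59 |     sample = (torch.zeros(2, 200, 200, 16), torch.zeros(2, 200, 200, 16), {'scene': 'demo'})
60 |     batch = custom_collate_fn_temporal([sample, sample])
61 |     print(batch[0].shape, batch[2])  # torch.Size([2, 2, 200, 200, 16]) [{'scene': 'demo'}, {'scene': 'demo'}]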
--------------------------------------------------------------------------------
/dataset/sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import TypeVar, Optional, Iterator
3 |
4 | import torch
5 | from torch.utils.data import Sampler, Dataset
6 | import torch.distributed as dist
7 | from typing import Callable
8 | import pandas as pd
9 | from torch.utils.data.distributed import DistributedSampler
10 | import torchvision
11 |
12 |
13 |
14 | T_co = TypeVar('T_co', covariant=True)
15 |
16 |
17 | class CustomDistributedSampler(Sampler[T_co]):
18 | r"""Sampler that restricts data loading to a subset of the dataset.
19 |
20 | It is especially useful in conjunction with
21 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
22 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
23 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
24 | original dataset that is exclusive to it.
25 |
26 | .. note::
27 | Dataset is assumed to be of constant size.
28 |
29 | Arguments:
30 | dataset: Dataset used for sampling.
31 | num_replicas (int, optional): Number of processes participating in
32 | distributed training. By default, :attr:`rank` is retrieved from the
33 | current distributed group.
34 | rank (int, optional): Rank of the current process within :attr:`num_replicas`.
35 | By default, :attr:`rank` is retrieved from the current distributed
36 | group.
37 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
38 | indices.
39 | seed (int, optional): random seed used to shuffle the sampler if
40 | :attr:`shuffle=True`. This number should be identical across all
41 | processes in the distributed group. Default: ``0``.
42 | drop_last (bool, optional): if ``True``, then the sampler will drop the
43 | tail of the data to make it evenly divisible across the number of
44 | replicas. If ``False``, the sampler will add extra indices to make
45 | the data evenly divisible across the replicas. Default: ``False``.
46 |
47 | .. warning::
48 | In distributed mode, calling the :meth:`set_epoch` method at
49 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator
50 | is necessary to make shuffling work properly across multiple epochs. Otherwise,
51 | the same ordering will be always used.
52 |
53 | Example::
54 |
55 | >>> sampler = DistributedSampler(dataset) if is_distributed else None
56 | >>> loader = DataLoader(dataset, shuffle=(sampler is None),
57 | ... sampler=sampler)
58 | >>> for epoch in range(start_epoch, n_epochs):
59 | ... if is_distributed:
60 | ... sampler.set_epoch(epoch)
61 | ... train(loader)
62 | """
63 |
64 | def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
65 | rank: Optional[int] = None, shuffle: bool = True,
66 | seed: int = 0, drop_last: bool = False, last_iter: int = 0) -> None:
67 | if num_replicas is None:
68 | if not dist.is_available():
69 | raise RuntimeError("Requires distributed package to be available")
70 | num_replicas = dist.get_world_size()
71 | if rank is None:
72 | if not dist.is_available():
73 | raise RuntimeError("Requires distributed package to be available")
74 | rank = dist.get_rank()
75 | self.dataset = dataset
76 | self.num_replicas = num_replicas
77 | self.rank = rank
78 | self.epoch = 0
79 | self.drop_last = drop_last
80 | # If the dataset length is evenly divisible by # of replicas, then there
81 | # is no need to drop any data, since the dataset will be split equally.
82 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore
83 | # Split to nearest available length that is evenly divisible.
84 | # This is to ensure each rank receives the same amount of data when
85 | # using this Sampler.
86 | self.num_samples = math.ceil(
87 | # `type:ignore` is required because Dataset cannot provide a default __len__
88 | # see NOTE in pytorch/torch/utils/data/sampler.py
89 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore
90 | )
91 | else:
92 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore
93 | self.total_size = self.num_samples * self.num_replicas
94 | self.shuffle = shuffle
95 | self.seed = seed
96 | self.first_run = True
97 | self.last_iter = last_iter
98 |
99 | def __iter__(self) -> Iterator[T_co]:
100 | if self.shuffle:
101 | # deterministically shuffle based on epoch and seed
102 | g = torch.Generator()
103 | g.manual_seed(self.seed + self.epoch)
104 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore
105 | else:
106 | indices = list(range(len(self.dataset))) # type: ignore
107 |
108 | if not self.drop_last:
109 | # add extra samples to make it evenly divisible
110 | indices += indices[:(self.total_size - len(indices))]
111 | else:
112 | # remove tail of data to make it evenly divisible.
113 | indices = indices[:self.total_size]
114 | assert len(indices) == self.total_size
115 |
116 | # subsample
117 | indices = indices[self.rank:self.total_size:self.num_replicas]
118 | if not self.first_run:
119 | assert len(indices) == self.num_samples
120 | else:
121 | indices = indices[self.last_iter:]
122 | self.last_iter = 0
123 | self.first_run = False
124 |
125 | return iter(indices)
126 |
127 | def __len__(self) -> int:
128 | return self.num_samples
129 |
130 | def set_epoch(self, epoch: int) -> None:
131 | r"""
132 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
133 | use a different random ordering for each epoch. Otherwise, the next iteration of this
134 | sampler will yield the same ordering.
135 |
136 | Arguments:
137 | epoch (int): Epoch number.
138 | """
139 | self.epoch = epoch
140 |
141 | def set_last_iter(self, last_iter: int):
142 | self.last_iter = last_iter
143 |
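144 | if __name__ == "__main__":
145 |     # Resume sketch: passing num_replicas/rank explicitly avoids needing torch.distributed.
146 |     # With last_iter=2, the first epoch skips the two indices already consumed before a restart.
147 |     sampler = CustomDistributedSampler(list(range(10)), num_replicas=2, rank=0, shuffle=False, last_iter=2)
148 |     print(list(sampler))  # [4, 6, 8] instead of [0, 2, 4, 6, 8]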
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | # learn_sigma=False,
18 | rescale_learned_sigmas=False,
19 | diffusion_steps=1000,
20 | beta_start = 0.0001,
21 | beta_end = 0.02,
22 | replace_cond_frames=False,
23 | cond_frames_choices=None,
24 | ):
25 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps,beta_start,beta_end)
26 | if use_kl:
27 | loss_type = gd.LossType.RESCALED_KL
28 | elif rescale_learned_sigmas:
29 | loss_type = gd.LossType.RESCALED_MSE
30 | else:
31 | loss_type = gd.LossType.MSE
32 | if timestep_respacing is None or timestep_respacing == "":
33 | timestep_respacing = [diffusion_steps]
34 | return SpacedDiffusion(
35 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
36 | betas=betas,
37 | model_mean_type=(
38 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
39 | ),
40 | model_var_type=(
41 | (
42 | gd.ModelVarType.FIXED_LARGE
43 | if not sigma_small
44 | else gd.ModelVarType.FIXED_SMALL
45 | )
46 | if not learn_sigma
47 | else gd.ModelVarType.LEARNED_RANGE
48 | ),
49 | loss_type=loss_type,
50 | replace_cond_frames=replace_cond_frames,
51 | cond_frames_choices=cond_frames_choices,
52 | # rescale_timesteps=rescale_timesteps,
53 | )
54 |
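55 | # Usage sketch (values mirror the `schedule`/`sample` dicts in config/train_dome.py;
56 | # the training scripts build this object themselves):
57 | #   from diffusion import create_diffusion
58 | #   diffusion = create_diffusion(
59 | #       timestep_respacing="",  # "" keeps the full 1000-step chain; "ddim20" would respace it
60 | #       noise_schedule="linear",
61 | #       learn_sigma=True,       # matches variance_type='learned_range'
62 | #       beta_start=0.0001,
63 | #       beta_end=0.02,
64 | #       replace_cond_frames=True,
65 | #       cond_frames_choices=[[], [0], [0, 1], [0, 1, 2], [0, 1, 2, 3]],
66 | #   )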
--------------------------------------------------------------------------------
/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
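90 | if __name__ == "__main__":
91 |     # Sanity-check sketch: the KL between two identical Gaussians is zero, and
92 |     # scalar (log-)variances broadcast against tensor means as documented above.
93 |     mu = th.zeros(4)
94 |     print(normal_kl(mu, mu, 0.0, 0.0))  # tensor([0., 0., 0., 0.])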
--------------------------------------------------------------------------------
/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 | import torch
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | # @torch.compile
95 | def training_losses(
96 | self, model, *args, **kwargs
97 | ): # pylint: disable=signature-differs
98 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
99 |
100 | def condition_mean(self, cond_fn, *args, **kwargs):
101 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
102 |
103 | def condition_score(self, cond_fn, *args, **kwargs):
104 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
105 |
106 | def _wrap_model(self, model):
107 | if isinstance(model, _WrappedModel):
108 | return model
109 | return _WrappedModel(
110 | model, self.timestep_map, self.original_num_steps
111 | )
112 |
113 | def _scale_timesteps(self, t):
114 | # Scaling is done by the wrapped model.
115 | return t
116 |
117 |
118 | class _WrappedModel:
119 | def __init__(self, model, timestep_map, original_num_steps):
120 | self.model = model
121 | self.timestep_map = timestep_map
122 | # self.rescale_timesteps = rescale_timesteps
123 | self.original_num_steps = original_num_steps
124 |
125 | def __call__(self, x, ts, **kwargs):
126 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
127 | new_ts = map_tensor[ts]
128 | # if self.rescale_timesteps:
129 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
130 | return self.model(x, new_ts, **kwargs)
131 |
--------------------------------------------------------------------------------
/diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 |         timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
96 |         loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
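
A minimal usage sketch for the samplers above (not part of the repository; it assumes the repo root is on PYTHONPATH and stands in a dummy object for the diffusion instance, since only its num_timesteps attribute is read here):

import torch
from types import SimpleNamespace
from diffusion.timestep_sampler import create_named_schedule_sampler

diffusion = SimpleNamespace(num_timesteps=1000)  # stand-in for the real diffusion object
sampler = create_named_schedule_sampler("loss-second-moment", diffusion)

t, weights = sampler.sample(batch_size=8, device="cpu")  # timesteps + importance weights
losses = torch.rand(8)            # per-sample diffusion losses would go here
loss = (weights * losses).mean()  # importance-weighted training objective

# Single-process update; update_with_local_losses() additionally requires
# torch.distributed to be initialized.
sampler.update_with_all_losses(t.tolist(), losses.tolist())
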
/environment.yml:
--------------------------------------------------------------------------------
1 | name: Occworld
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
7 | - conda-forge
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
11 | dependencies:
12 | - _libgcc_mutex=0.1=main
13 | - _openmp_mutex=5.1=1_gnu
14 | - arrow=1.2.3=py38h06a4308_1
15 | - attrs=23.1.0=pyh71513ae_1
16 | - backoff=2.2.1=pyhd8ed1ab_0
17 | - backports=1.0=pyhd8ed1ab_3
18 | - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
19 | - beautifulsoup4=4.12.2=pyha770c72_0
20 | - blas=1.0=mkl
21 | - blessed=1.19.1=pyhe4f9e05_2
22 | - boto3=1.28.66=pyhd8ed1ab_0
23 | - botocore=1.31.66=pyhd8ed1ab_0
24 | - brotlipy=0.7.0=py38h27cfd23_1003
25 | - bzip2=1.0.8=h7b6447c_0
26 | - ca-certificates=2024.7.2=h06a4308_0
27 | - cachecontrol=0.12.11=py38h06a4308_1
28 | - certifi=2024.7.4=py38h06a4308_0
29 | - cffi=1.15.0=py38h7f8727e_0
30 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
31 | - cleo=2.0.1=pyhd8ed1ab_0
32 | - click=8.1.7=unix_pyh707e725_0
33 | - colorama=0.4.6=pyhd8ed1ab_0
34 | - commonmark=0.9.1=pyhd3eb1b0_0
35 | - crashtest=0.4.1=pyhd8ed1ab_0
36 | - croniter=1.4.1=pyhd8ed1ab_0
37 | - cryptography=3.4.8=py38h3e25421_1
38 | - cuda-cudart=11.8.89=0
39 | - cuda-cupti=11.8.87=0
40 | - cuda-libraries=11.8.0=0
41 | - cuda-nvrtc=11.8.89=0
42 | - cuda-nvtx=11.8.86=0
43 | - cuda-runtime=11.8.0=0
44 | - dataclasses=0.8=pyh6d0b6a4_7
45 | - dateutils=0.6.12=py_0
46 | - dbus=1.13.18=hb2f20db_0
47 | - deepdiff=5.8.1=pyhd8ed1ab_0
48 | - distlib=0.3.7=pyhd8ed1ab_0
49 | - dulwich=0.21.3=py38h5eee18b_0
50 | - exceptiongroup=1.1.3=pyhd8ed1ab_0
51 | - expat=2.2.10=h9c3ff4c_0
52 | - fastapi=0.103.2=pyhd8ed1ab_0
53 | - ffmpeg=4.3=hf484d3e_0
54 | - filelock=3.9.0=py38h06a4308_0
55 | - freetype=2.12.1=h4a9f257_0
56 | - future=0.18.3=py38h06a4308_0
57 | - gettext=0.21.0=h39681ba_1
58 | - giflib=5.2.1=h5eee18b_3
59 | - glib=2.56.2=had28632_1001
60 | - gmp=6.2.1=h295c915_3
61 | - gmpy2=2.1.2=py38heeb90bb_0
62 | - gnutls=3.6.15=he1e5248_0
63 | - h11=0.14.0=pyhd8ed1ab_0
64 | - html5lib=1.1=pyh9f0ad1d_0
65 | - icu=58.2=hf484d3e_1000
66 | - idna=3.4=py38h06a4308_0
67 | - importlib-metadata=6.8.0=pyha770c72_0
68 | - importlib_metadata=6.8.0=hd8ed1ab_0
69 | - inquirer=3.1.3=pyhd8ed1ab_0
70 | - intel-openmp=2021.4.0=h06a4308_3561
71 | - itsdangerous=2.1.2=pyhd8ed1ab_0
72 | - jaraco.classes=3.3.0=pyhd8ed1ab_0
73 | - jeepney=0.8.0=pyhd8ed1ab_0
74 | - jinja2=3.1.2=py38h06a4308_0
75 | - jpeg=9e=h5eee18b_1
76 | - jsonschema=4.19.2=py38h06a4308_0
77 | - jsonschema-specifications=2023.7.1=py38h06a4308_0
78 | - keyring=23.13.1=py38h578d9bd_0
79 | - lame=3.100=h7b6447c_0
80 | - lcms2=2.12=h3be6417_0
81 | - lerc=3.0=h295c915_0
82 | - libaio=0.3.113=h5eee18b_0
83 | - libcublas=11.11.3.6=0
84 | - libcufft=10.9.0.58=0
85 | - libcufile=1.7.2.10=0
86 | - libcurand=10.3.3.141=0
87 | - libcusolver=11.4.1.48=0
88 | - libcusparse=11.7.5.86=0
89 | - libdeflate=1.17=h5eee18b_0
90 | - libedit=3.1.20221030=h5eee18b_0
91 | - libffi=3.2.1=hf484d3e_1007
92 | - libgcc-ng=11.2.0=h1234567_1
93 | - libgomp=11.2.0=h1234567_1
94 | - libiconv=1.16=h7f8727e_2
95 | - libidn2=2.3.4=h5eee18b_0
96 | - libnpp=11.8.0.86=0
97 | - libnvjpeg=11.9.0.86=0
98 | - libpng=1.6.39=h5eee18b_0
99 | - libstdcxx-ng=11.2.0=h1234567_1
100 | - libtasn1=4.19.0=h5eee18b_0
101 | - libtiff=4.5.1=h6a678d5_0
102 | - libunistring=0.9.10=h27cfd23_0
103 | - libwebp=1.2.4=h11a3e52_1
104 | - libwebp-base=1.2.4=h5eee18b_1
105 | - libxml2=2.10.4=hcbfbd50_0
106 | - lightning=2.1.0=pyhd8ed1ab_0
107 | - lightning-cloud=0.5.42=pyhd8ed1ab_0
108 | - lightning-utilities=0.9.0=pyhd8ed1ab_0
109 | - lockfile=0.12.2=py_1
110 | - lz4-c=1.9.4=h6a678d5_0
111 | - markdown-it-py=3.0.0=pyhd8ed1ab_0
112 | - markupsafe=2.1.1=py38h7f8727e_0
113 | - mkl=2021.4.0=h06a4308_640
114 | - mkl-service=2.4.0=py38h7f8727e_0
115 | - mkl_fft=1.3.1=py38hd3c417c_0
116 | - mkl_random=1.2.2=py38h51133e4_0
117 | - more-itertools=10.1.0=pyhd8ed1ab_0
118 | - mpc=1.1.0=h10f8cd9_1
119 | - mpfr=4.0.2=hb69a4c5_1
120 | - mpmath=1.3.0=py38h06a4308_0
121 | - msgpack-python=1.0.3=py38hd09550d_0
122 | - ncurses=6.4=h6a678d5_0
123 | - nettle=3.7.3=hbbd107a_1
124 | - networkx=3.1=py38h06a4308_0
125 | - numpy=1.24.3=py38h14f4228_0
126 | - numpy-base=1.24.3=py38h31eccc5_0
127 | - openh264=2.1.1=h4ff587b_0
128 | - openssl=1.1.1w=h7f8727e_0
129 | - ordered-set=4.1.0=pyhd8ed1ab_0
130 | - pcre=8.45=h9c3ff4c_0
131 | - pexpect=4.8.0=pyh1a96a4e_2
132 | - pillow=9.4.0=py38h6a678d5_0
133 | - pip=23.2.1=py38h06a4308_0
134 | - pkginfo=1.9.6=pyhd8ed1ab_0
135 | - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1
136 | - poetry=1.4.0=py38h06a4308_0
137 | - poetry-core=1.5.1=py38h06a4308_0
138 | - poetry-plugin-export=1.3.0=py38h4849bfd_0
139 | - ptyprocess=0.7.0=pyhd3deb0d_0
140 | - pycparser=2.21=pyhd3eb1b0_0
141 | - pydantic=1.10.12=py38h5eee18b_1
142 | - pygments=2.16.1=pyhd8ed1ab_0
143 | - pyjwt=2.8.0=pyhd8ed1ab_0
144 | - pyopenssl=20.0.1=pyhd8ed1ab_0
145 | - pyproject_hooks=1.0.0=pyhd8ed1ab_0
146 | - pyrsistent=0.18.0=py38heee7806_0
147 | - pysocks=1.7.1=py38h06a4308_0
148 | - python=3.8.0=h0371630_2
149 | - python-build=0.10.0=pyhd8ed1ab_1
150 | - python-dateutil=2.8.2=pyhd8ed1ab_0
151 | - python-editor=1.0.4=py_0
152 | - python-installer=0.6.0=py38h06a4308_0
153 | - python-multipart=0.0.6=pyhd8ed1ab_0
154 | - python_abi=3.8=2_cp38
155 | - pytorch=2.0.1=py3.8_cuda11.8_cudnn8.7.0_0
156 | - pytorch-cuda=11.8=h7e8668a_5
157 | - pytorch-lightning=2.1.0=pyhd8ed1ab_0
158 | - pytorch-mutex=1.0=cuda
159 | - pytz=2023.3.post1=pyhd8ed1ab_0
160 | - pyyaml=6.0.1=py38h5eee18b_0
161 | - rapidfuzz=2.13.7=py38h417a72b_0
162 | - readchar=4.0.5=pyhd8ed1ab_0
163 | - readline=7.0=h7b6447c_5
164 | - referencing=0.30.2=py38h06a4308_0
165 | - requests-toolbelt=0.10.1=pyhd8ed1ab_0
166 | - s3transfer=0.7.0=pyhd8ed1ab_0
167 | - secretstorage=3.3.3=py38h578d9bd_2
168 | - shellingham=1.5.3=pyhd8ed1ab_0
169 | - six=1.16.0=pyhd3eb1b0_1
170 | - sniffio=1.3.0=pyhd8ed1ab_0
171 | - soupsieve=2.5=pyhd8ed1ab_1
172 | - sqlite=3.33.0=h62c20be_0
173 | - starlette=0.27.0=pyhd8ed1ab_0
174 | - starsessions=1.3.0=pyhd8ed1ab_0
175 | - sympy=1.11.1=py38h06a4308_0
176 | - tk=8.6.12=h1ccaba5_0
177 | - tomli=2.0.1=pyhd8ed1ab_0
178 | - tomlkit=0.12.1=pyha770c72_0
179 | - torchaudio=2.0.2=py38_cu118
180 | - torchmetrics=1.2.0=pyhd8ed1ab_0
181 | - torchtriton=2.0.0=py38
182 | - torchvision=0.15.2=py38_cu118
183 | - trove-classifiers=2023.10.18=pyhd8ed1ab_0
184 | - types-python-dateutil=2.8.19.14=pyhd8ed1ab_0
185 | - typing-extensions=4.7.1=py38h06a4308_0
186 | - typing_extensions=4.7.1=py38h06a4308_0
187 | - urllib3=1.26.16=py38h06a4308_0
188 | - uvicorn=0.23.2=py38h578d9bd_1
189 | - virtualenv=20.17.1=py38h578d9bd_0
190 | - webencodings=0.5.1=pyhd8ed1ab_2
191 | - websockets=10.4=py38h5eee18b_1
192 | - wheel=0.38.4=py38h06a4308_0
193 | - xz=5.4.2=h5eee18b_0
194 | - yaml=0.2.5=h7f98852_2
195 | - zlib=1.2.13=h5eee18b_0
196 | - zstd=1.5.5=hc292b87_0
197 | - pip:
198 | - absl-py==2.0.0
199 | - accelerate==0.30.1
200 | - addict==2.4.0
201 | - aiofiles==22.1.0
202 | - aiosqlite==0.20.0
203 | - aliyun-python-sdk-core==2.13.36
204 | - aliyun-python-sdk-kms==2.16.2
205 | - ansi2html==1.8.0
206 | - anyio==4.0.0
207 | - apptools==5.2.1
208 | - argon2-cffi==23.1.0
209 | - argon2-cffi-bindings==21.2.0
210 | - asttokens==2.4.0
211 | - async-lru==2.0.4
212 | - av==12.0.0
213 | - babel==2.12.1
214 | - backcall==0.2.0
215 | - black==23.11.0
216 | - bleach==6.0.0
217 | - blinker==1.7.0
218 | - cachetools==5.3.1
219 | - comm==0.1.4
220 | - configargparse==1.7
221 | - configobj==5.0.8
222 | - contourpy==1.1.0
223 | - crcmod==1.7
224 | - cycler==0.11.0
225 | - dash==2.14.1
226 | - dash-core-components==2.0.0
227 | - dash-html-components==2.0.0
228 | - dash-table==5.0.0
229 | - debugpy==1.7.0
230 | - decorator==5.1.1
231 | - decord==0.6.0
232 | - defusedxml==0.7.1
233 | - deprecation==2.1.0
234 | - descartes==1.1.0
235 | - diffusers==0.24.0
236 | - einops==0.7.0
237 | - entrypoints==0.4
238 | - envisage==7.0.3
239 | - executing==1.2.0
240 | - fastjsonschema==2.18.0
241 | - fire==0.5.0
242 | - flake8==5.0.4
243 | - flask==3.0.0
244 | - fonttools==4.42.1
245 | - fqdn==1.5.1
246 | - fsspec==2023.9.0
247 | - google-auth==2.23.4
248 | - google-auth-oauthlib==1.0.0
249 | - grpcio==1.59.2
250 | - huggingface-hub==0.23.1
251 | - imageio==2.32.0
252 | - imageio-ffmpeg==0.4.9
253 | - importlib-resources==6.0.1
254 | - iniconfig==2.0.0
255 | - ipdb==0.13.13
256 | - ipykernel==6.25.2
257 | - ipython==8.12.2
258 | - ipython-genutils==0.2.0
259 | - ipywidgets==8.1.0
260 | - isoduration==20.11.0
261 | - jedi==0.19.0
262 | - jmespath==0.10.0
263 | - joblib==1.3.2
264 | - json5==0.9.14
265 | - jsonpointer==2.4
266 | - jupyter==1.0.0
267 | - jupyter-client==7.4.9
268 | - jupyter-console==6.6.3
269 | - jupyter-core==5.3.1
270 | - jupyter-events==0.7.0
271 | - jupyter-lsp==2.2.0
272 | - jupyter-packaging==0.12.3
273 | - jupyter-server==2.7.3
274 | - jupyter-server-fileid==0.9.1
275 | - jupyter-server-terminals==0.4.4
276 | - jupyter-server-ydoc==0.8.0
277 | - jupyter-ydoc==0.2.5
278 | - jupyterlab==3.6.7
279 | - jupyterlab-pygments==0.2.2
280 | - jupyterlab-server==2.24.0
281 | - jupyterlab-widgets==3.0.8
282 | - kiwisolver==1.4.5
283 | - lazy-loader==0.3
284 | - llvmlite==0.40.1
285 | - lyft-dataset-sdk==0.0.8
286 | - markdown==3.4.4
287 | - matplotlib==3.5.2
288 | - matplotlib-inline==0.1.6
289 | - mayavi==4.8.1
290 | - mccabe==0.7.0
291 | - mdurl==0.1.2
292 | - mistune==3.0.1
293 | - mmcv==2.0.1
294 | - mmdet==3.3.0
295 | - mmdet3d==1.4.0
296 | - mmengine==0.8.4
297 | - model-index==0.1.11
298 | - mypy-extensions==1.0.0
299 | - nbclassic==1.0.0
300 | - nbclient==0.8.0
301 | - nbconvert==7.8.0
302 | - nbformat==5.7.0
303 | - nest-asyncio==1.5.7
304 | - notebook==6.5.6
305 | - notebook-shim==0.2.3
306 | - numba==0.57.1
307 | - nuscenes-devkit==1.1.10
308 | - oauthlib==3.2.2
309 | - open3d==0.13.0
310 | - open3d-python==0.3.0.0
311 | - opencv-python==4.9.0.80
312 | - opendatalab==0.0.10
313 | - openmim==0.3.9
314 | - openxlab==0.0.24
315 | - oss2==2.17.0
316 | - overrides==7.4.0
317 | - packaging==23.1
318 | - pandas==2.0.3
319 | - pandocfilters==1.5.0
320 | - parso==0.8.3
321 | - pathspec==0.11.2
322 | - pickleshare==0.7.5
323 | - platformdirs==3.10.0
324 | - plotly==5.18.0
325 | - pluggy==1.3.0
326 | - plyfile==1.0.1
327 | - prettytable==3.9.0
328 | - prometheus-client==0.17.1
329 | - prompt-toolkit==3.0.39
330 | - protobuf==4.25.0
331 | - psutil==5.9.5
332 | - pure-eval==0.2.2
333 | - pyasn1==0.5.0
334 | - pyasn1-modules==0.3.0
335 | - pycocotools==2.0.7
336 | - pycodestyle==2.9.1
337 | - pycryptodome==3.18.0
338 | - pyface==8.0.0
339 | - pyflakes==2.5.0
340 | - pyparsing==3.0.9
341 | - pyqt5==5.15.10
342 | - pyqt5-qt5==5.15.2
343 | - pyqt5-sip==12.13.0
344 | - pyquaternion==0.9.9
345 | - pytest==7.4.3
346 | - python-json-logger==2.0.7
347 | - pytorch-fast-transformers==0.4.0
348 | - pytorch-fid==0.3.0
349 | - pyvirtualdisplay==3.0
350 | - pywavelets==1.4.1
351 | - pyzmq==24.0.1
352 | - qtconsole==5.4.4
353 | - qtpy==2.4.0
354 | - regex==2024.5.15
355 | - requests==2.28.2
356 | - requests-oauthlib==1.3.1
357 | - retrying==1.3.4
358 | - rfc3339-validator==0.1.4
359 | - rfc3986-validator==0.1.1
360 | - rich==13.4.2
361 | - rpds-py==0.10.2
362 | - rsa==4.9
363 | - safetensors==0.4.3
364 | - scikit-image==0.21.0
365 | - scikit-learn==1.3.0
366 | - scipy==1.10.1
367 | - send2trash==1.8.2
368 | - setuptools==59.5.0
369 | - shapely==1.8.5
370 | - stack-data==0.6.2
371 | - swin-window-process==0.0.0
372 | - tabulate==0.9.0
373 | - tenacity==8.2.3
374 | - tensorboard==2.14.0
375 | - tensorboard-data-server==0.7.2
376 | - termcolor==2.3.0
377 | - terminado==0.17.1
378 | - terminaltables==3.1.10
379 | - thop==0.1.1-2209072238
380 | - threadpoolctl==3.2.0
381 | - tifffile==2023.7.10
382 | - timm==0.9.7
383 | - tinycss2==1.2.1
384 | - tokenizers==0.19.1
385 | - tornado==6.3.3
386 | - tqdm==4.65.2
387 | - traitlets==5.9.0
388 | - traits==6.4.2
389 | - traitsui==8.0.0
390 | - transformers==4.41.1
391 | - trimesh==4.0.4
392 | - tzdata==2023.3
393 | - uri-template==1.3.0
394 | - vtk==9.2.6
395 | - wcwidth==0.2.6
396 | - webcolors==1.13
397 | - websocket-client==1.6.3
398 | - werkzeug==3.0.1
399 | - widgetsnbextension==4.0.8
400 | - xformers==0.0.22
401 | - y-py==0.6.2
402 | - yapf==0.40.1
403 | - ypy-websocket==0.8.4
404 | - zipp==3.16.2
405 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from mmengine.registry import Registry
2 | OPENOCC_LOSS = Registry('openocc_loss')
3 |
4 | from .multi_loss import MultiLoss
5 | from .ce_loss import CeLoss
6 | from .emb_loss import VQVAEEmbedLoss
7 | from .recon_loss import ReconLoss, LovaszLoss
--------------------------------------------------------------------------------
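
A short usage sketch (not from the repository): loss terms are built from config dicts through this registry, which is exactly how MultiLoss instantiates its terms below. Assumes mmengine is installed and the repo root is importable:

from loss import OPENOCC_LOSS

ce = OPENOCC_LOSS.build(dict(type='CeLoss', weight=1.0))  # returns a CeLoss instance
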
/loss/base_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | writer = None
3 |
4 | class BaseLoss(nn.Module):
5 |
6 |     """ Base loss class.
7 |     args:
8 |         weight: weight of the current loss term.
9 |         input_dict: maps argument names of loss_func to the keys under which
10 |             they are stored in the shared "inputs" dict. Since "inputs" may
11 |             contain many different fields, input_dict selects the right ones.
12 |         loss_func: the actual loss function used to compute the loss.
13 |     """
14 |
15 |     def __init__(
16 |         self,
17 |         weight=1.0,
18 |         input_dict=None,  # avoid a shared mutable default argument
19 |         **kwargs):
20 |         super().__init__()
21 |         self.weight = weight
22 |         self.input_dict = input_dict if input_dict is not None \
23 |             else {'input': 'input'}
24 |         self.loss_func = lambda: 0
25 |         self.writer = writer
26 |
27 | # def calculate_loss(self, **kwargs):
28 | # return self.loss_func(*[kwargs[key] for key in self.input_keys])
29 |
30 | def forward(self, inputs):
31 | actual_inputs = {}
32 | for input_key, input_val in self.input_dict.items():
33 | actual_inputs.update({input_key: inputs[input_val]})
34 | # return self.weight * self.calculate_loss(**actual_inputs)
35 | return self.weight * self.loss_func(**actual_inputs)
36 |
--------------------------------------------------------------------------------
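
A sketch of the subclass pattern BaseLoss expects (the L1Loss class here is purely illustrative, not from the repository): set input_dict to map the argument names of loss_func to the keys of the shared inputs dict, and point loss_func at the actual computation.

import torch
from loss.base_loss import BaseLoss

class L1Loss(BaseLoss):  # hypothetical example subclass
    def __init__(self, weight=1.0, input_dict=None, **kwargs):
        super().__init__(weight)
        self.input_dict = input_dict if input_dict is not None \
            else {'pred': 'pred', 'target': 'target'}
        self.loss_func = self.l1_loss

    def l1_loss(self, pred, target):
        return (pred - target).abs().mean()

inputs = {'pred': torch.randn(4), 'target': torch.zeros(4)}
print(L1Loss(weight=0.5)(inputs))  # forward() remaps the inputs and scales by weight
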
/loss/ce_loss.py:
--------------------------------------------------------------------------------
1 | from .base_loss import BaseLoss
2 | from . import OPENOCC_LOSS
3 | import torch.nn.functional as F
4 | import torch
5 |
6 | @OPENOCC_LOSS.register_module()
7 | class CeLoss(BaseLoss):
8 |
9 | def __init__(self, weight=1.0, ignore_label=-100,
10 | use_weight=False, cls_weight=None, input_dict=None, **kwargs):
11 | super().__init__(weight)
12 |
13 | if input_dict is None:
14 | self.input_dict = {
15 | 'ce_inputs': 'ce_inputs',
16 | 'ce_labels': 'ce_labels'
17 | }
18 | else:
19 | self.input_dict = input_dict
20 | self.loss_func = self.ce_loss
21 | self.ignore = ignore_label
22 | self.use_weight = use_weight
23 | self.cls_weight = torch.tensor(cls_weight) if cls_weight is not None else None
24 |
25 |     def ce_loss(self, ce_inputs, ce_labels):
26 |         # ce_inputs: (-1, C) logits; ce_labels: (-1,) class indices
27 |         weight = self.cls_weight.to(ce_inputs.device) \
28 |             if self.use_weight and self.cls_weight is not None else None
29 |         ce_loss = F.cross_entropy(ce_inputs, ce_labels, weight=weight, ignore_index=self.ignore)
30 |         return ce_loss
--------------------------------------------------------------------------------
/loss/emb_loss.py:
--------------------------------------------------------------------------------
1 | from .base_loss import BaseLoss
2 | from . import OPENOCC_LOSS
3 |
4 | @OPENOCC_LOSS.register_module()
5 | class VQVAEEmbedLoss(BaseLoss):
6 |
7 | def __init__(self, weight=1.0, input_dict=None, **kwargs):
8 | super().__init__(weight)
9 |
10 | if input_dict is None:
11 | self.input_dict = {
12 | 'embed_loss': 'embed_loss'
13 | }
14 | else:
15 | self.input_dict = input_dict
16 | self.loss_func = self.embed_loss
17 |
18 | def embed_loss(self, embed_loss):
19 | return embed_loss
--------------------------------------------------------------------------------
/loss/multi_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from . import OPENOCC_LOSS
3 | writer = None
4 |
5 | @OPENOCC_LOSS.register_module()
6 | class MultiLoss(nn.Module):
7 |
8 | def __init__(self, loss_cfgs):
9 | super().__init__()
10 |
11 | assert isinstance(loss_cfgs, list)
12 | self.num_losses = len(loss_cfgs)
13 |
14 | losses = []
15 | for loss_cfg in loss_cfgs:
16 | losses.append(OPENOCC_LOSS.build(loss_cfg))
17 | self.losses = nn.ModuleList(losses)
18 | self.iter_counter = 0
19 |
20 | def forward(self, inputs):
21 |
22 | loss_dict = {}
23 | tot_loss = 0.
24 | for loss_func in self.losses:
25 | loss = loss_func(inputs)
26 | tot_loss += loss
27 | loss_dict.update({
28 | loss_func.__class__.__name__: \
29 | loss.detach().item() / loss_func.weight
30 | })
31 | if writer and self.iter_counter % 10 == 0:
32 | writer.add_scalar(
33 | f'loss/{loss_func.__class__.__name__}',
34 | loss.detach().item(), self.iter_counter)
35 | if writer and self.iter_counter % 10 == 0:
36 | writer.add_scalar(
37 | 'loss/total', tot_loss.detach().item(), self.iter_counter)
38 | self.iter_counter += 1
39 |
40 | return tot_loss, loss_dict
--------------------------------------------------------------------------------
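
A sketch of composing loss terms with MultiLoss (not from the repository; shapes and weights are illustrative). Each term is built via the registry, and all terms read from the same inputs dict:

import torch
from loss import MultiLoss

multi = MultiLoss([
    dict(type='CeLoss', weight=1.0),
    dict(type='VQVAEEmbedLoss', weight=0.25),
])
inputs = {
    'ce_inputs': torch.randn(16, 18),          # (N, C) logits
    'ce_labels': torch.randint(0, 18, (16,)),  # (N,) class indices
    'embed_loss': torch.tensor(0.1),           # scalar from the VQ layer
}
total, per_term = multi(inputs)  # weighted total + dict of unweighted values
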
/loss/recon_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 | from .base_loss import BaseLoss
4 | from . import OPENOCC_LOSS
5 | import torch.nn.functional as F
6 | import torch  # used by ReconLoss, KLLoss and the Lovasz helpers below
7 | @OPENOCC_LOSS.register_module()
8 | class ReconLoss(BaseLoss):
9 |
10 | def __init__(self, weight=1.0, ignore_label=-100, use_weight=False, cls_weight=None, input_dict=None, **kwargs):
11 | super().__init__(weight)
12 |
13 | if input_dict is None:
14 | self.input_dict = {
15 | 'logits': 'logits',
16 | 'labels': 'labels'
17 | }
18 | else:
19 | self.input_dict = input_dict
20 | self.loss_func = self.recon_loss
21 | self.ignore = ignore_label
22 | self.use_weight = use_weight
23 | self.cls_weight = torch.tensor(cls_weight) if cls_weight is not None else None
24 |
25 | def recon_loss(self, logits, labels):
26 | weight = None
27 | if self.use_weight:
28 | if self.cls_weight is not None:
29 | weight = self.cls_weight
30 | else:
31 | one_hot_labels = F.one_hot(labels, num_classes=logits.shape[-1]) # bs, F, H, W, D, C
32 | cls_freq = torch.sum(one_hot_labels, dim=[0, 1, 2, 3, 4]) # C
33 | weight = 1.0 / cls_freq.clamp_min_(1) * torch.numel(labels) / logits.shape[-1]
34 |
35 | rec_loss = F.cross_entropy(logits.permute(0, 5, 1, 2, 3, 4), labels, ignore_index=self.ignore, weight=weight)
36 | return rec_loss
37 |
38 | @OPENOCC_LOSS.register_module()
39 | class LovaszLoss(BaseLoss):
40 |
41 | def __init__(self, weight=1.0, input_dict=None, **kwargs):
42 | super().__init__(weight)
43 |
44 | if input_dict is None:
45 | self.input_dict = {
46 | 'logits': 'logits',
47 | 'labels': 'labels'
48 | }
49 | else:
50 | self.input_dict = input_dict
51 | self.loss_func = self.lovasz_loss
52 |
53 | def lovasz_loss(self, logits, labels):
54 | logits = logits.flatten(0, 1).permute(0, 4, 1, 2, 3).softmax(dim=1)
55 | labels = labels.flatten(0, 1)
56 | loss = lovasz_softmax(logits, labels)
57 | return loss
58 |
59 |
60 | @OPENOCC_LOSS.register_module()
61 | class KLLoss(BaseLoss):
62 |
63 | def __init__(self, weight=1.0, input_dict=None, **kwargs):
64 | super().__init__(weight)
65 |
66 | if input_dict is None:
67 | self.input_dict = {
68 | 'z_mu': 'z_mu',
69 | 'z_sigma': 'z_sigma'
70 | }
71 | else:
72 | self.input_dict = input_dict
73 | self.loss_func = self.KL_loss
74 |
75 | def KL_loss(self, z_mu, z_sigma):
76 | z_mu = z_mu.permute(0,2,1,3,4).flatten(0, 1) #mu
77 | z_sigma = z_sigma.permute(0,2,1,3,4).flatten(0, 1) #logvar
78 |         # KL( N(mu, exp(logvar)) || N(0, I) ), summed over the latent
79 |         # dimensions and averaged over the batch:
80 | loss = torch.mean(-0.5 * torch.sum(1 + z_sigma - z_mu ** 2 - z_sigma.exp(), dim = [1,2,3]), dim = 0)
81 | return loss
82 |
83 |
84 | """
85 | Lovasz-Softmax and Jaccard hinge loss in PyTorch
86 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
87 | """
88 |
89 | import torch
90 | from torch.autograd import Variable
91 | import torch.nn.functional as F
92 | try:
93 | from itertools import ifilterfalse
94 | except ImportError: # py3k
95 | from itertools import filterfalse as ifilterfalse
96 |
97 |
98 | def lovasz_grad(gt_sorted):
99 | """
100 | Computes gradient of the Lovasz extension w.r.t sorted errors
101 | See Alg. 1 in paper
102 | """
103 | p = len(gt_sorted)
104 | gts = gt_sorted.sum()
105 | intersection = gts - gt_sorted.float().cumsum(0)
106 | union = gts + (1 - gt_sorted).float().cumsum(0)
107 | jaccard = 1. - intersection / union
108 | if p > 1: # cover 1-pixel case
109 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
110 | return jaccard
111 |
112 | # --------------------------- MULTICLASS LOSSES ---------------------------
113 |
114 |
115 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
116 | """
117 | Multi-class Lovasz-Softmax loss
118 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
119 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
120 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
121 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
122 | per_image: compute the loss per image instead of per batch
123 | ignore: void class labels
124 | """
125 | if per_image:
126 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
127 | for prob, lab in zip(probas, labels))
128 | else:
129 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
130 | return loss
131 |
132 |
133 | def lovasz_softmax_flat(probas, labels, classes='present'):
134 | """
135 | Multi-class Lovasz-Softmax loss
136 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
137 | labels: [P] Tensor, ground truth labels (between 0 and C - 1)
138 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
139 | """
140 | if probas.numel() == 0:
141 | # only void pixels, the gradients should be 0
142 | return 0.#probas * 0.
143 | #print(probas.size())
144 | C = probas.size(1)
145 | losses = []
146 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
147 | for c in class_to_sum:
148 | fg = (labels == c).float() # foreground for class c
149 | if (classes == 'present' and fg.sum() == 0):
150 | continue
151 | if C == 1:
152 | if len(classes) > 1:
153 | raise ValueError('Sigmoid output possible only with 1 class')
154 | class_pred = probas[:, 0]
155 | else:
156 | class_pred = probas[:, c]
157 | errors = (Variable(fg) - class_pred).abs()
158 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
159 | perm = perm.data
160 | fg_sorted = fg[perm]
161 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
162 | return mean(losses)
163 |
164 |
165 | def flatten_probas(probas, labels, ignore=None):
166 | """
167 | Flattens predictions in the batch
168 | """
169 | if probas.dim() == 3:
170 | # assumes output of a sigmoid layer
171 | B, H, W = probas.size()
172 | probas = probas.view(B, 1, H, W)
173 | elif probas.dim() == 5:
174 | #3D segmentation
175 | B, C, L, H, W = probas.size()
176 | probas = probas.contiguous().view(B, C, L, H*W)
177 | B, C, H, W = probas.size()
178 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C
179 | labels = labels.view(-1)
180 | if ignore is None:
181 | return probas, labels
182 | valid = (labels != ignore)
183 | vprobas = probas[valid]#.nonzero().squeeze()]
184 | # print(labels)
185 | # print(valid)
186 | vlabels = labels[valid]
187 | return vprobas, vlabels
188 |
189 | # --------------------------- HELPER FUNCTIONS ---------------------------
190 |
191 | def isnan(x):
192 | return x != x
193 |
194 | def mean(l, ignore_nan=False, empty=0):
195 | """
196 | nanmean compatible with generators.
197 | """
198 | l = iter(l)
199 | if ignore_nan:
200 | l = ifilterfalse(isnan, l)
201 | try:
202 | n = 1
203 | acc = next(l)
204 | except StopIteration:
205 | if empty == 'raise':
206 | raise ValueError('Empty mean')
207 | return empty
208 | for n, v in enumerate(l, 2):
209 | acc += v
210 | if n == 1:
211 | return acc
212 | return acc / n
213 |
--------------------------------------------------------------------------------
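
A toy call of the Lovasz-Softmax helper above (illustrative; shapes follow the [B, C, H, W] convention documented in lovasz_softmax):

import torch
from loss.recon_loss import lovasz_softmax

probas = torch.softmax(torch.randn(2, 3, 8, 8), dim=1)  # [B, C, H, W] probabilities
labels = torch.randint(0, 3, (2, 8, 8))                 # [B, H, W] labels in [0, C-1]
print(lovasz_softmax(probas, labels, classes='present'))
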
/model/VAE/quantizer.py:
--------------------------------------------------------------------------------
1 |
2 | """ adapted from: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py """
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import numpy as np
8 | from einops import rearrange
9 | # from mmseg.models import BACKBONES
10 | from mmengine.registry import MODELS
11 | from mmengine.model import BaseModule
12 |
13 | @MODELS.register_module()
14 | class VectorQuantizer(BaseModule):
15 | """
16 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
17 | avoids costly matrix multiplications and allows for post-hoc remapping of indices.
18 | """
19 | # NOTE: due to a bug the beta term was applied to the wrong term. for
20 | # backwards compatibility we use the buggy version by default, but you can
21 | # specify legacy=False to fix it.
22 | def __init__(self, n_e, e_dim, beta, z_channels, remap=None, unknown_index="random",
23 | sane_index_shape=False, legacy=True, use_voxel=True):
24 | super().__init__()
25 | self.n_e = n_e
26 | self.e_dim = e_dim
27 | self.beta = beta
28 | self.legacy = legacy
29 |
30 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
31 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
32 |
33 | self.remap = remap
34 | if self.remap is not None:
35 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
36 | self.re_embed = self.used.shape[0]
37 | self.unknown_index = unknown_index # "random" or "extra" or integer
38 | if self.unknown_index == "extra":
39 | self.unknown_index = self.re_embed
40 | self.re_embed = self.re_embed+1
41 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
42 | f"Using {self.unknown_index} for unknown indices.")
43 | else:
44 | self.re_embed = n_e
45 |
46 | self.sane_index_shape = sane_index_shape
47 |
48 | conv_class = torch.nn.Conv3d if use_voxel else torch.nn.Conv2d
49 | self.quant_conv = conv_class(z_channels, self.e_dim, 1)
50 | self.post_quant_conv = conv_class(self.e_dim, z_channels, 1)
51 |
52 | def remap_to_used(self, inds):
53 | ishape = inds.shape
54 | assert len(ishape)>1
55 | inds = inds.reshape(ishape[0],-1)
56 | used = self.used.to(inds)
57 | match = (inds[:,:,None]==used[None,None,...]).long()
58 | new = match.argmax(-1)
59 | unknown = match.sum(2)<1
60 | if self.unknown_index == "random":
61 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
62 | else:
63 | new[unknown] = self.unknown_index
64 | return new.reshape(ishape)
65 |
66 | def unmap_to_all(self, inds):
67 | ishape = inds.shape
68 | assert len(ishape)>1
69 | inds = inds.reshape(ishape[0],-1)
70 | used = self.used.to(inds)
71 | if self.re_embed > self.used.shape[0]: # extra token
72 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero
73 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
74 | return back.reshape(ishape)
75 |
76 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False, is_voxel=False):
77 | z = self.quant_conv(z)
78 | z_q, loss, (perplexity, min_encodings, min_encoding_indices) = self.forward_quantizer(z, temp, rescale_logits, return_logits, is_voxel)
79 | z_q = self.post_quant_conv(z_q)
80 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
81 | def forward_quantizer(self, z, temp=None, rescale_logits=False, return_logits=False, is_voxel=False):
82 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
83 | assert rescale_logits==False, "Only for interface compatible with Gumbel"
84 | assert return_logits==False, "Only for interface compatible with Gumbel"
85 |
86 | # reshape z -> (batch, height, width, channel) and flatten
87 | if not is_voxel:
88 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
89 | else:
90 | z = rearrange(z, 'b c d h w -> b d h w c').contiguous()
91 | z_flattened = z.view(-1, self.e_dim)
92 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
93 |
94 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
95 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
96 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
97 |
98 | min_encoding_indices = torch.argmin(d, dim=1)
99 | z_q = self.embedding(min_encoding_indices).view(z.shape)
100 | perplexity = None
101 | min_encodings = None
102 |
103 | # compute loss for embedding
104 | if not self.legacy:
105 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
106 | torch.mean((z_q - z.detach()) ** 2)
107 | else:
108 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
109 | torch.mean((z_q - z.detach()) ** 2)
110 |
111 | # preserve gradients
112 | z_q = z + (z_q - z).detach()
113 |
114 | # reshape back to match original input shape
115 | if not is_voxel:
116 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
117 | else:
118 | z_q = rearrange(z_q, 'b d h w c -> b c d h w').contiguous()
119 |
120 |
121 |
122 | if self.remap is not None:
123 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
124 | min_encoding_indices = self.remap_to_used(min_encoding_indices)
125 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
126 |
127 | if self.sane_index_shape:
128 | if not is_voxel:
129 | min_encoding_indices = min_encoding_indices.reshape(
130 | z_q.shape[0], z_q.shape[2], z_q.shape[3])
131 | else:
132 | min_encoding_indices = min_encoding_indices.reshape(
133 | z_q.shape[0], z_q.shape[2], z_q.shape[3], z_q.shape[4])
134 |
135 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
136 | def get_codebook_entry(self, indices, shape):
137 | # shape specifying (batch, height, width, channel)
138 | if self.remap is not None:
139 | indices = indices.reshape(shape[0],-1) # add batch axis
140 | indices = self.unmap_to_all(indices)
141 | indices = indices.reshape(-1) # flatten again
142 |
143 | # get quantized latent vectors
144 | # z_q = self.embedding(indices)
145 |         mask = (indices > 0) & (indices < self.n_e)
146 |         z_q = self.embedding.weight.new_zeros(len(indices), self.e_dim)
147 |         z_q[mask] = self.embedding(indices[mask])
148 |
149 |         if shape is not None:
150 |             z_q = z_q.view(shape)
151 |             # reshape back to match original input shape
152 |             z_q = z_q.permute(0, 3, 1, 2).contiguous()
153 |
154 |         return z_q
155 |
156 |     def get_code(self, z, is_voxels=False):
157 |         if not is_voxels:
158 |             b, c, h, w = z.shape
159 |             z = rearrange(z, 'b c h w -> b h w c').contiguous()
160 | else:
161 |             b, c, dd, h, w = z.shape  # dd = depth; d is reused for the distance matrix below
162 | z = rearrange(z, 'b c d h w -> b d h w c').contiguous()
163 | z_flattened = z.view(-1, self.e_dim)
164 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
165 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
166 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
167 | min_encoding_indices = torch.argmin(d, dim=1)
168 | if not is_voxels:
169 | min_encoding_indices = min_encoding_indices.reshape(b, h, w)
170 | else:
171 |             min_encoding_indices = min_encoding_indices.reshape(b, dd, h, w)
172 | return min_encoding_indices
173 |
174 |
--------------------------------------------------------------------------------
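
A round-trip sketch for VectorQuantizer (not from the repository; all sizes are illustrative, and mmengine must be installed). With use_voxel=False the quant/post-quant convolutions are 2-D, so the input is a (B, z_channels, H, W) feature map:

import torch
from model.VAE.quantizer import VectorQuantizer

vq = VectorQuantizer(n_e=512, e_dim=64, beta=0.25, z_channels=128, use_voxel=False)
z = torch.randn(2, 128, 16, 16)    # (B, z_channels, H, W)
z_q, vq_loss, (_, _, idx) = vq(z)  # z_q has the same shape as z
print(z_q.shape, vq_loss.item(), idx.shape)
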
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .VAE.vae_2d_resnet import VAERes2D,VAERes3D
2 | from .VAE.quantizer import VectorQuantizer
3 |
4 | from .pose_encoder import PoseEncoder,PoseEncoder_fourier
5 |
6 | from .dome import Dome
--------------------------------------------------------------------------------
/model/pose_encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from mmengine.registry import MODELS
4 | from mmengine.model import BaseModule
5 | from einops import rearrange, repeat
6 | import torch.nn.functional as F
7 |
8 | @MODELS.register_module()
9 | class PoseEncoder(BaseModule):
10 | def __init__(
11 | self,
12 | in_channels,
13 | out_channels,
14 | num_layers=2,
15 | num_modes=3,
16 | num_fut_ts=1,
17 | init_cfg=None
18 | ):
19 | super().__init__(init_cfg)
20 | self.num_modes = num_modes
21 | self.num_fut_ts = num_fut_ts
22 | assert num_fut_ts == 1
23 |
24 | pose_encoder = []
25 |
26 | for _ in range(num_layers - 1):
27 | pose_encoder.extend([
28 | nn.Linear(in_channels, out_channels),
29 | nn.ReLU(True)])
30 | in_channels = out_channels
31 | pose_encoder.append(nn.Linear(out_channels, out_channels))
32 | self.pose_encoder = nn.Sequential(*pose_encoder)
33 |
34 | def forward(self,x):
35 | # x: N*2,
36 | pose_feat = self.pose_encoder(x)
37 | return pose_feat
38 |
39 |
40 | @MODELS.register_module()
41 | class PoseEncoder_fourier(BaseModule):
42 | def __init__(
43 | self,
44 | in_channels,
45 | out_channels,
46 | num_layers=2,
47 | num_modes=3,
48 | num_fut_ts=1,
49 | fourier_freqs=8,
50 | init_cfg=None,
51 | do_proj=False,
52 | max_length=77,
53 | # zero_init=False
54 | **kwargs
55 | ):
56 | super().__init__(init_cfg)
57 | self.num_modes = num_modes
58 | self.num_fut_ts = num_fut_ts
59 | assert num_fut_ts == 1
60 | # assert in_channels==2,"only support 2d coordinates for now, include gt_mode etc later"
61 | self.fourier_freqs=fourier_freqs
62 | self.position_dim = fourier_freqs * 2 * in_channels # 2: sin/cos, 2: xy
63 | in_channels=self.position_dim
64 |
65 |
66 | pose_encoder = []
67 | for _ in range(num_layers - 1):
68 | pose_encoder.extend([
69 | nn.Linear(in_channels, out_channels),
70 | nn.ReLU(True)])
71 | in_channels = out_channels
72 | pose_encoder.append(nn.Linear(out_channels, out_channels))
73 | self.pose_encoder = nn.Sequential(*pose_encoder)
74 |
75 | # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
76 | self.do_proj=do_proj
77 | self.max_length=max_length
78 | if do_proj:
79 | # proj b*f*c -> b*c
80 | self.embedding_projection =nn.Linear(max_length * out_channels, out_channels, bias=True)
81 | self.null_position_feature = torch.nn.Parameter(torch.zeros([out_channels]))
82 | # if zero_init:
83 | # self.zero_module()
84 |
85 | def zero_module(self):
86 | """
87 | Zero out the parameters of a module and return it.
88 | """
89 | for p in self.parameters():
90 | p.detach().zero_()
91 |
92 | def forward(self,x,mask=None):
93 | # x: N*2,
94 | b,f=x.shape[:2]
95 | x = rearrange(x, 'b f d -> (b f) d')
96 | x=get_fourier_embeds_from_coordinates(self.fourier_freqs,x) # N*dim (bf)*32 # 2,11,32
97 | # if mask is not None: #TODO
98 | # # learnable null embedding
99 | # xyxy_null = self.null_position_feature.view(1, 1, -1)
100 | # # replace padding with learnable null embedding
101 | # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
102 | x = self.pose_encoder(x) #([2, 11, 768])
103 | if self.do_proj:
104 | x = rearrange(x, '(b f) d -> b f d', b=b) # 2,11,32
105 | x_pad=F.pad(x,(0,0,0,self.max_length-f)) # 2,77,32
106 | xyxy_null = self.null_position_feature.view(1, 1, -1)
107 | mask=torch.zeros(b,self.max_length,1).to(x.device)
108 | mask[:,:f]=1
109 | x = x_pad * mask + (1 - mask) * xyxy_null
110 | x = rearrange(x, 'b f d -> b (f d)')
111 | x=self.embedding_projection(x) # b d
112 | else:
113 | x = rearrange(x, '(b f) d -> b f d', b=b)
114 |
115 | return x
116 |
117 | @MODELS.register_module()
118 | class PoseEncoder_fourier_yaw(BaseModule):
119 | def __init__(
120 | self,
121 | in_channels,
122 | out_channels,
123 | num_layers=2,
124 | num_modes=3,
125 | num_fut_ts=1,
126 | fourier_freqs=8,
127 | init_cfg=None,
128 | do_proj=False,
129 | max_length=77,
130 | # zero_init=False
131 | **kwargs
132 | ):
133 | super().__init__(init_cfg)
134 | self.num_modes = num_modes
135 | self.num_fut_ts = num_fut_ts
136 | assert num_fut_ts == 1
137 | # assert in_channels==2,"only support 2d coordinates for now, include gt_mode etc later"
138 | self.fourier_freqs=fourier_freqs
139 |         self.position_dim = fourier_freqs * 2 * (in_channels-1) # 2: sin/cos, per xy dim
140 |         self.position_dim_yaw = fourier_freqs * 2 * (1) # 2: sin/cos, 1: yaw
141 |
142 |
143 | in_channels=self.position_dim
144 | pose_encoder = []
145 | for _ in range(num_layers - 1):
146 | pose_encoder.extend([
147 | nn.Linear(in_channels, out_channels),
148 | nn.ReLU(True)])
149 | in_channels = out_channels
150 | pose_encoder.append(nn.Linear(out_channels, out_channels))
151 | self.pose_encoder = nn.Sequential(*pose_encoder)
152 |
153 | in_channels=self.position_dim_yaw
154 | pose_encoder_yaw = []
155 | for _ in range(num_layers - 1):
156 | pose_encoder_yaw.extend([
157 | nn.Linear(in_channels, out_channels),
158 | nn.ReLU(True)])
159 | in_channels = out_channels
160 | pose_encoder_yaw.append(nn.Linear(out_channels, out_channels))
161 | self.pose_encoder_yaw = nn.Sequential(*pose_encoder_yaw)
162 |
163 | # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
164 | self.do_proj=do_proj
165 | self.max_length=max_length
166 | if do_proj:
167 | # proj b*f*c -> b*c
168 | self.embedding_projection =nn.Linear(max_length * out_channels, out_channels, bias=True)
169 | self.null_position_feature = torch.nn.Parameter(torch.zeros([out_channels]))
170 | self.embedding_projection_yaw =nn.Linear(max_length * out_channels, out_channels, bias=True)
171 | self.null_position_feature_yaw = torch.nn.Parameter(torch.zeros([out_channels]))
172 | # if zero_init:
173 | # self.zero_module()
174 |
175 | def zero_module(self, zero_params=None):
176 | """
177 | Zero out the parameters of a module based on the given list of parameter names and return it.
178 |
179 | Args:
180 | zero_params (list): List of parameter names to zero. If None, all parameters will be zeroed.
181 | """
182 | for name, p in self.named_parameters():
183 | if zero_params is None or name in zero_params:
184 | p.detach().zero_()
185 |
186 | def forward(self,x,mask=None):
187 | # x: N*2,
188 | b,f=x.shape[:2]
189 | x = rearrange(x, 'b f d -> (b f) d')
190 | x,x_yaw=x[:,:-1],x[:,-1:]
191 | x=get_fourier_embeds_from_coordinates(self.fourier_freqs,x) # N*dim (bf)*32 # 2,11,32
192 | x_yaw=get_fourier_embeds_from_coordinates(self.fourier_freqs,x_yaw) # N*dim (bf)*32 # 2,11,32
193 | # if mask is not None: #TODO
194 | # # learnable null embedding
195 | # xyxy_null = self.null_position_feature.view(1, 1, -1)
196 | # # replace padding with learnable null embedding
197 | # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
198 | x = self.pose_encoder(x) #([2, 11, 768])
199 | x_yaw = self.pose_encoder_yaw(x_yaw) #([2, 11, 768])
200 | if self.do_proj:
201 | x = rearrange(x, '(b f) d -> b f d', b=b) # 2,11,32
202 | x_pad=F.pad(x,(0,0,0,self.max_length-f)) # 2,77,32
203 | xyxy_null = self.null_position_feature.view(1, 1, -1)
204 | mask=torch.zeros(b,self.max_length,1).to(x.device)
205 | mask[:,:f]=1
206 | x = x_pad * mask + (1 - mask) * xyxy_null
207 | x = rearrange(x, 'b f d -> b (f d)')
208 | x=self.embedding_projection(x) # b d
209 | else:
210 | x = rearrange(x, '(b f) d -> b f d', b=b)
211 |
212 | if self.do_proj:
213 | x_yaw = rearrange(x_yaw, '(b f) d -> b f d', b=b) # 2,11,32
214 | x_pad=F.pad(x_yaw,(0,0,0,self.max_length-f)) # 2,77,32
215 |             xyxy_null = self.null_position_feature_yaw.view(1, 1, -1)
216 | mask=torch.zeros(b,self.max_length,1).to(x_yaw.device)
217 | mask[:,:f]=1
218 | x_yaw = x_pad * mask + (1 - mask) * xyxy_null
219 | x_yaw = rearrange(x_yaw, 'b f d -> b (f d)')
220 |             x_yaw=self.embedding_projection_yaw(x_yaw) # b d
221 | else:
222 | x_yaw = rearrange(x_yaw, '(b f) d -> b f d', b=b)
223 |
224 | x=x+x_yaw
225 |
226 | return x
227 |
228 | def get_fourier_embeds_from_coordinates(embed_dim, xys):
229 | """
230 | Args:
231 |         embed_dim: int, number of frequency bands
232 |         xys: [B x C] tensor of coordinates to embed (bounding boxes in the original GLIGEN pipeline)
233 |     Returns:
234 |         [B x (embed_dim * 2 * C)] tensor of positional embeddings
235 | """
236 |
237 | batch_size = xys.shape[0]
238 | ch= xys.shape[-1]
239 |
240 | emb = 100 ** (torch.arange(embed_dim) / embed_dim)
241 | emb = emb[None, None].to(device=xys.device, dtype=xys.dtype)
242 | emb = emb * xys.unsqueeze(-1)
243 |
244 | emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
245 | emb = emb.permute(0, 2, 3, 1).reshape(batch_size, embed_dim * 2 * ch)
246 |
247 | return emb
248 |
249 |
250 |
--------------------------------------------------------------------------------
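
A sketch of the Fourier embedding helper above (illustrative): with 8 frequency bands and 2-D coordinates, the output width is 8 * 2 * 2 = 32, matching the (bf)*32 comments in PoseEncoder_fourier:

import torch
from model.pose_encoder import get_fourier_embeds_from_coordinates

xy = torch.rand(6, 2)                             # (N, 2) coordinates
emb = get_fourier_embeds_from_coordinates(8, xy)  # (6, 8 * 2 * 2) = (6, 32)
print(emb.shape)
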
/resample/astar.py:
--------------------------------------------------------------------------------
1 | # Sample code from https://www.redblobgames.com/pathfinding/a-star/
2 | import heapq
3 | import collections
4 | import numpy as np
5 | import math
6 | from sklearn.neighbors import NearestNeighbors  # used by create_bev_from_pc below
7 | class SimpleGraph:
8 | def __init__(self):
9 | self.edges = {}
10 |
11 | def neighbors(self, id):
12 | return self.edges[id]
13 |
14 |
15 | class Queue:
16 | def __init__(self):
17 | self.elements = collections.deque()
18 |
19 | def empty(self):
20 | return len(self.elements) == 0
21 |
22 | def put(self, x):
23 | self.elements.append(x)
24 |
25 | def get(self):
26 | return self.elements.popleft()
27 |
28 | # utility functions for dealing with square grids
29 | def from_id_width(id, width):
30 | return (id % width, id // width)
31 |
32 | def draw_tile(graph, id, style, width):
33 | r = "."
34 | if 'number' in style and id in style['number']: r = "%d" % style['number'][id]
35 | if 'point_to' in style and style['point_to'].get(id, None) is not None:
36 | (x1, y1) = id
37 | (x2, y2) = style['point_to'][id]
38 | if x2 == x1 + 1: r = ">"
39 | if x2 == x1 - 1: r = "<"
40 | if y2 == y1 + 1: r = "v"
41 | if y2 == y1 - 1: r = "^"
42 | if 'start' in style and id == style['start']: r = "A"
43 | if 'goal' in style and id == style['goal']: r = "Z"
44 | if 'path' in style and id in style['path']: r = "@"
45 | if id in graph.walls: r = "#" * width
46 | return r
47 |
48 | def draw_grid(graph, width=2, **style):
49 | for y in range(graph.height):
50 | for x in range(graph.width):
51 | print("%%-%ds" % width % draw_tile(graph, (x, y), style, width), end="")
52 | print()
53 |
54 |
55 | class SquareGrid:
56 | def __init__(self, width, height):
57 | self.width = width
58 | self.height = height
59 | self.walls = []
60 |
61 | def in_bounds(self, id):
62 | (x, y) = id
63 | return 0 <= x < self.width and 0 <= y < self.height
64 |
65 | def passable(self, id):
66 | return id not in self.walls
67 |
68 | def neighbors(self, id):
69 | (x, y) = id
70 | results = [(x+1, y), (x, y-1), (x-1, y), (x, y+1)]
71 | if (x + y) % 2 == 0: results.reverse() # aesthetics
72 | results = filter(self.in_bounds, results)
73 | results = filter(self.passable, results)
74 | return results
75 |
76 | @classmethod
77 | def from_voxel_bev_map(cls, mask,mark_margin_obs=True):
78 | # assert mask.dtype == np.bool and mask.ndim == 2
79 | width,height = mask.shape[:2]
80 | grid = cls(width, height)
81 | if mark_margin_obs:
82 | mask[:,[0,-1]]=0
83 | mask[[0,-1],:]=0
84 | grid.walls = [(x,y) for x,y in np.argwhere(mask==0)]
85 | return grid
86 |
87 |
88 | class GridWithWeights(SquareGrid):
89 | def __init__(self, width, height):
90 | super().__init__(width, height)
91 | self.weights = {}
92 |
93 | def cost(self, from_node, to_node):
94 |
95 | return self.weights.get(to_node, 1)
96 | @classmethod
97 | def from_voxel_bev_map(cls, mask,cost_map=None,mark_margin_obs=True):
98 | # assert mask.dtype == np.bool and mask.ndim == 2
99 | width,height = mask.shape[:2]
100 | if mark_margin_obs:
101 | mask[:,[0,-1]]=0
102 | mask[[0,-1],:]=0
103 | grid = cls(width, height)
104 | grid.walls = [(x,y) for x,y in np.argwhere(mask==0)]
105 | if cost_map is not None:
106 | grid.weights = {(x,y):cost_map[x,y] for x in range(width) for y in range(height)}
107 | return grid
108 |
109 |
110 |
111 | class PriorityQueue:
112 | def __init__(self):
113 | self.elements = []
114 |
115 | def empty(self):
116 | return len(self.elements) == 0
117 |
118 | def put(self, item, priority):
119 | heapq.heappush(self.elements, (priority, item))
120 |
121 | def get(self):
122 | return heapq.heappop(self.elements)[1]
123 |
124 | def reconstruct_path(came_from, start, goal):
125 | current = goal
126 | path = []
127 | while current != start:
128 | path.append(current)
129 | current = came_from[current]
130 | path.append(start) # optional
131 | path.reverse() # optional
132 | return path
133 |
134 | # def heuristic(a, b):
135 | # (x1, y1) = a
136 | # (x2, y2) = b
137 | # return abs(x1 - x2) + abs(y1 - y2)
138 | def heuristic(a, b):
139 | (x1, y1) = a
140 | (x2, y2) = b
141 | return math.sqrt((x1 - x2)**2 + (y1 - y2)**2)
142 |
143 | def a_star_search(graph, start, goal):
144 | frontier = PriorityQueue()
145 | frontier.put(start, 0)
146 | came_from = {}
147 | cost_so_far = {}
148 | came_from[start] = None
149 | cost_so_far[start] = 0
150 |
151 | while not frontier.empty():
152 | current = frontier.get()
153 |
154 | if current == goal:
155 | break
156 |
157 | for next in graph.neighbors(current):
158 | new_cost = cost_so_far[current] + graph.cost(current, next)
159 | if next not in cost_so_far or new_cost < cost_so_far[next]:
160 | cost_so_far[next] = new_cost
161 | priority = new_cost + heuristic(goal, next)
162 | frontier.put(next, priority)
163 | came_from[next] = current
164 |
165 | return came_from, cost_so_far
166 |
167 | # def a_star_search_distance_cost(graph,distancecostmap, start, goal):
168 | # frontier = PriorityQueue()
169 | # frontier.put(start, 0)
170 | # came_from = {}
171 | # cost_so_far = {}
172 | # distance_cost = {}
173 | # came_from[start] = None
174 | # cost_so_far[start] = 0
175 |
176 | # while not frontier.empty():
177 | # current = frontier.get()
178 |
179 | # if current == goal:
180 | # break
181 |
182 | # for next in graph.neighbors(current):
183 | # new_cost = cost_so_far[current] + graph.cost(current, next)
184 | # if next not in cost_so_far or new_cost < cost_so_far[next]:
185 | # cost_so_far[next] = new_cost
186 | # distance_cost
187 | # priority = new_cost + heuristic(goal, next)
188 | # frontier.put(next, priority)
189 | # came_from[next] = current
190 |
191 | # return came_from, cost_so_far
192 |
193 | def breadth_first_search_3(graph, start, goal):
194 | frontier = Queue()
195 | frontier.put(start)
196 | came_from = {}
197 | came_from[start] = None
198 |
199 | while not frontier.empty():
200 | current = frontier.get()
201 |
202 | if current == goal:
203 | break
204 |
205 | for next in graph.neighbors(current):
206 | if next not in came_from:
207 | frontier.put(next)
208 | came_from[next] = current
209 |
210 | return came_from
211 |
212 |
213 |
214 | def distcost(distancecostmap,x, y, safty_value=2,w=50):
215 |     # A large safety value pushes the path farther away from walls.
216 |     # However, if it is too large, almost every grid cell gets the max cost,
217 |     # which defeats the purpose of the distance cost.
218 | max_distance_cost = np.max(distancecostmap)
219 | distance_cost = max_distance_cost-distancecostmap[x][y]
220 | #if distance_cost > (max_distance_cost/safty_value):
221 | # distance_cost = 1000
222 | # return distance_cost
223 | return w * distance_cost # E5 223 - 50
224 |
225 | # start, goal = (1, 4), (7, 8)
226 | # came_from, cost_so_far = a_star_search(diagram4, start, goal)
227 | # draw_grid(diagram4, width=3, point_to=came_from, start=start, goal=goal)
228 | # print()
229 | # draw_grid(diagram4, width=3, number=cost_so_far, start=start, goal=goal)
230 | # print()
231 | # draw_grid(diagram4, width=3, path=reconstruct_path(came_from, start=start, goal=goal))
232 |
233 |
234 |
235 | def create_bev_from_pc(occ_pc_road, resolution_2d, max_dist):
236 | """
237 | Create BEV map from point cloud
238 | Args:
239 | occ_pc_road: point cloud of the road
240 | resolution_2d: resolution of the BEV map
241 | max_dist: maximum distance to the road
242 | Returns:
243 | bev_road: BEV map of the road, Binary map
244 | nearest_distance: nearest distance to the road
245 | (x_min, y_min): minimum coordinates of the BEV map
246 | """
247 | # set z=0 and compute grid parameters
248 | occ_pc_road = occ_pc_road[:,:2]
249 | x_min, x_max = occ_pc_road[:,0].min(), occ_pc_road[:,0].max()
250 | y_min, y_max = occ_pc_road[:,1].min(), occ_pc_road[:,1].max()
251 | n_x = int((x_max - x_min) / resolution_2d)
252 | n_y = int((y_max - y_min) / resolution_2d)
253 |
254 | # Create grid coordinates efficiently
255 | x_coords = np.linspace(x_min + resolution_2d/2, x_max - resolution_2d/2, n_x)
256 | y_coords = np.linspace(y_min + resolution_2d/2, y_max - resolution_2d/2, n_y)
257 | grid_x, grid_y = np.meshgrid(x_coords, y_coords)
258 | grid_coords = np.column_stack((grid_x.ravel(), grid_y.ravel()))
259 |
260 | # Fit KNN model and compute distances
261 | knn = NearestNeighbors(n_neighbors=3, algorithm='ball_tree', n_jobs=-1).fit(occ_pc_road)
262 | distances, _ = knn.kneighbors(grid_coords)
263 | mean_distances = distances.mean(axis=1)
264 |
265 | # Create and populate nd and occ_2d arrays
266 | nearest_distance = mean_distances.reshape(n_y, n_x)
267 | bev_road = (nearest_distance < max_dist).astype(np.int32)
268 |
269 | return bev_road, nearest_distance, (x_min, y_min)
--------------------------------------------------------------------------------
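
A toy run of the A* utilities above (it mirrors the commented-out demo; the grid and obstacle are made up, and the import assumes the script is run from inside resample/):

from astar import GridWithWeights, a_star_search, reconstruct_path

grid = GridWithWeights(10, 10)
grid.walls = [(5, y) for y in range(2, 8)]  # a vertical obstacle
start, goal = (1, 4), (8, 4)
came_from, cost_so_far = a_star_search(grid, start, goal)
path = reconstruct_path(came_from, start=start, goal=goal)  # list of (x, y) cells
print(path)
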
/resample/launch.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import sys,os
3 | import pickle
4 | from tqdm import tqdm
5 |
6 |
7 | def run_command(command: str) -> None:
8 |     """Run a command and report an error if it fails
9 |
10 | Args:
11 | command: command to run
12 | """
13 | ret_code = subprocess.call(command, shell=True)
14 | if ret_code != 0:
15 |         print(f"Error: `{command}` failed.")
16 | # sys.exit(1)
17 |
18 |
19 | def parse_args():
20 | import argparse
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--rank', type=int, default=0)
23 | parser.add_argument('--n_rank', type=int, default=1)
24 | parser.add_argument('--dst', type=str, required=True)
25 | parser.add_argument('--imageset', type=str, default='../data/nuscenes_infos_train_temporal_v3_scene.pkl')
26 | parser.add_argument('--input_dataset', type=str, default='gts')
27 | parser.add_argument('--data_path', type=str, default='../data/nuscenes')
28 |
29 |
30 | return parser.parse_args()
31 |
32 | args=parse_args()
33 | with open(args.imageset, 'rb') as f:
34 | data = pickle.load(f)
35 | nusc_infos = data['infos']
36 |
37 |
38 | for i in tqdm(range(args.rank,len(nusc_infos),args.n_rank),desc='launch processing'):
39 | print('#'*50,'idx',i)
40 | dst=os.path.join(args.dst,f'src_scene-{i+1:04d}')
41 | cmd = [
42 | "python", "main.py","--idx",f"{i}",
43 | "--dst",f"{dst}",
44 | "--imageset",f"{args.imageset}",
45 | "--input_dataset",f"{args.input_dataset}",
46 | "--data_path",f"{args.data_path}",
47 | ]
48 | cmd = " ".join(cmd)
49 | print('@'*50,cmd)
50 | run_command(cmd)
51 |
52 |
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
/resample/main.py:
--------------------------------------------------------------------------------
1 | from utils import (
2 | visualize_point_cloud,
3 | vis_pose_mesh,
4 | colors, write_pc,
5 | get_inliers_outliers,
6 | sample_points_in_roi,
7 | sample_points_in_roi_v2,
8 | approximate_b_spline_path,
9 | sampling_occ_from_pc,
10 | get_3d_pose_from_2d,
11 | create_bev_from_pc,
12 | ransac,
13 | downsample_pc_with_label,
14 | get_mask_from_path
15 | )
16 | import numpy as np
17 | import open3d as o3d
18 | from matplotlib import pyplot as plt
19 | from astar import GridWithWeights, a_star_search, reconstruct_path
20 | import time
21 | import os
22 | from tqdm import tqdm
23 | from functools import partial
24 | import joblib
25 | import cv2
26 | import pickle
27 | import shutil
28 | from pyquaternion import Quaternion
29 |
30 |
31 | def parse_args():
32 | import argparse
33 | parser = argparse.ArgumentParser()
34 | parser.add_argument('--idx', type=int, required=True)
35 | parser.add_argument('--dst', type=str, required=True)
36 | parser.add_argument('--imageset', type=str, default='../data/nuscenes_infos_train_temporal_v3_scene.pkl')
37 | parser.add_argument('--input_dataset', type=str, default='gts')
38 | parser.add_argument('--data_path', type=str, default='../data/nuscenes')
39 |
40 |
41 | return parser.parse_args()
42 |
43 | # Configuration
44 | args = parse_args()
45 | idx = args.idx
46 | dst_dir = args.dst
47 | os.makedirs(dst_dir, exist_ok=True)
48 |
49 |
50 | moving_cls_id = [
51 | 2, # 'bicycle'
52 | 3, # 'bus'
53 | 4, # 'car'
54 | 6, # motorcycle
55 | 7, # 'pedestrian'
56 | 9, # trailer
57 | 10, # 'truck'
58 | ]
59 | road_cls_id = 11
60 |
61 | ## voxelization param
62 | pc_voxel_downsample=0.2
63 | # resolution_2d=1
64 | resolution_2d=2
65 | max_dist=0.5
66 |
67 |
68 | ## st en point sampling param
69 | path_expand_radius=2
70 | seed=0
71 | n_sample_pair=10
72 | min_distance_st_en=10
73 |
74 | ## A* param
75 | distance_cost_weigth=250
76 | n_traj_point_ds=4
77 |
78 | ## traj valid check
79 | min_min_traj_len_st_en=2
80 | delta_min_traj_len_st_en=10
81 | min_traj_len_st_en=30
82 | max_traj_len_st_en=50
83 | fail_cnt_thres=10
84 |
85 | ## resampling occ from pc
86 | n_sample_occ=40 # each traj only sample 40
87 | voxel_size= 0.4
88 | pc_range= [-40, -40, -1, 40, 40, 5.4]
89 | occ_size= [200, 200, 16]
90 |
91 | ###################################
92 |
93 | # data=np.load(f'./occ_{idx}.npz')
94 | # occ,trans,rot=data['occ'],data['e2g_rel0_t'],data['e2g_rel0_r']
95 | def load_occ(args):
96 | with open(args.imageset, 'rb') as f:
97 | data = pickle.load(f)
98 | nusc_infos = data['infos']
99 |     assert args.idx < len(nusc_infos)
162 |
163 | # save pc
164 | # repeat n times
165 | # write_pc(occ_pc_road,f'{dst_dir}/vis_gt_road.ply',colors[None,road_cls_id-1].repeat(len(occ_pc_road),axis=0)[:,:3]/255.0) #debug
166 |
167 | # create bev map
168 | bev_road, nearest_distance, (x_min, y_min,n_x,n_y) = create_bev_from_pc(occ_pc_road, resolution_2d, max_dist)
169 | distance_cost=distance_cost_weigth*nearest_distance
170 | origin_path=trans.copy()
171 | origin_path=((origin_path-np.array([x_min,y_min,0]))/resolution_2d).astype(np.int32)
172 |
173 | # Generate mask from the original path
174 | mask_origin_path = get_mask_from_path(origin_path, (n_x, n_y), expand_radius=path_expand_radius)
175 |
176 | # Refine the mask by considering only areas that are both in the original path and on the road
177 | mask_origin_path = np.logical_and(mask_origin_path > 0, bev_road > 0).astype(np.uint8)
178 |
179 | # Calculate map parameters
180 | bev_map=GridWithWeights.from_voxel_bev_map(bev_road,cost_map=distance_cost)
181 |
182 | n_r=n_c=np.ceil(n_sample_pair**0.5).astype(np.int32)
183 | fig,ax=plt.subplots(n_r,n_c,figsize=(5*n_c,5*n_r))
184 | ax=ax.flatten()
185 |
186 | sampled_trajs = []
187 | valid_path_count=0
188 | np.random.seed(seed)
189 |
190 | # Plotting
191 | ax[0].imshow(bev_road)
192 | # draw raw path
193 | ax[1].imshow(mask_origin_path)
194 | ax[1].plot(origin_path[:,1], origin_path[:,0], linewidth=1.5, color='k', zorder=0,alpha=0.8)
195 | ax[1].scatter(origin_path[0,1], origin_path[0,0], marker='o', c='r')
196 | ax[1].scatter(origin_path[-1,1], origin_path[-1,0], marker='x', c='b')
197 | ax[1].set_title(f'origin path')#, area {mask_origin_path.sum()}')
198 | plt.tight_layout()
199 | plt.savefig(f'{dst_dir}/sampled_trajectories.png', dpi=300, bbox_inches='tight')
200 |
201 | fail_cnt=0
202 | fail_v2_flag=False
203 | pbar=tqdm(total=n_sample_pair,desc='sampling traj')
204 | while valid_path_count < n_sample_pair:
205 |     if fail_cnt > fail_cnt_thres:
206 | if min_traj_len_st_en < min_min_traj_len_st_en:
207 | raise ValueError('Failed too many times: unable to generate valid trajectory')
208 | elif fail_v2_flag:
209 | print('Sampling method v1 failed, attempting to reduce min_traj_len_st_en')
210 | if min_traj_len_st_en > min_min_traj_len_st_en:
211 | min_traj_len_st_en = max(min_traj_len_st_en - delta_min_traj_len_st_en, min_min_traj_len_st_en)
212 | else:
213 | raise ValueError('Failed too many times: unable to generate valid trajectory')
214 | fail_cnt = 0
215 | else:
216 | print('Sampling method v2 failed, switching to v1')
217 | fail_v2_flag = True
218 | fail_cnt = 0
219 | # sample points in roi
220 | if not fail_v2_flag:
221 | # if mask_origin_path.sum()>min_path_area:
222 | st, en, dist=sample_points_in_roi_v2(bev_road.shape[0],bev_road.shape[1],num_points=1,resolution=resolution_2d,mask=bev_road,
223 | mask_path=mask_origin_path,
224 | min_distance_threshold=min_distance_st_en,verbose=False)[0]
225 | else:
226 | st, en, dist=sample_points_in_roi(bev_road.shape[0],bev_road.shape[1],num_points=1,resolution=resolution_2d,mask=bev_road,min_distance_threshold=min_distance_st_en,verbose=False)[0]
227 |
228 | start, goal = tuple(st), tuple(en)
229 | # tic = time.time()
230 | came_from, cost_so_far = a_star_search(bev_map, start, goal)
231 | # print('A* Time:', time.time() - tic)
232 |
233 | if len(came_from) <= 1 or goal not in came_from:
234 | print('@filtered, no solution')
235 | fail_cnt+=1
236 | continue
237 |
238 | path = np.array(reconstruct_path(came_from, start=start, goal=goal))
239 | path_raw=path.copy()
240 |
241 | print(f'@len traj {len(path_raw)} @dist {dist}')
242 | if len(path_raw) < min_traj_len_st_en:
243 | print('@filtered, too short traj')
244 | fail_cnt+=1
245 | continue
246 | elif len(path_raw) > max_traj_len_st_en:
247 | print('@filtered, too long traj')
248 | fail_cnt+=1
249 | continue
250 |
251 | fail_cnt=0
252 | if n_traj_point_ds>1:
253 | path = np.concatenate([path[:-1:n_traj_point_ds], path[-1:]], axis=0)
254 | degree=min(5,len(path)-1)
255 | path_smooth = approximate_b_spline_path(path[:, 0], path[:, 1], n_path_points=1000, degree=degree)
256 | path_occ, dd = approximate_b_spline_path(path[:, 0], path[:, 1], n_path_points=n_sample_occ, degree=degree,with_derivatives=True)
257 | sampled_trajs.append(get_3d_pose_from_2d(path_occ, dd, [x_min, y_min], resolution_2d,plane_model=plane_model))
258 | pbar.update(1)
259 | # Plotting
260 | ax[valid_path_count+2].imshow(bev_road)
261 | # draw raw path
262 | # ax[valid_path_count].plot(origin_path[:,1], origin_path[:,0], linewidth=1.5, color='k', zorder=0,alpha=0.8)
263 | ax[valid_path_count+2].plot(path_raw[:,1], path_raw[:,0], linewidth=1.5, color='g', zorder=0,alpha=0.8)
264 | ax[valid_path_count+2].plot(path[:,1], path[:,0], linewidth=1.5, color='b', zorder=0,alpha=0.8)
265 | ax[valid_path_count+2].plot(path_smooth[:,1], path_smooth[:,0], linewidth=1.5, color='r', zorder=0,alpha=0.8)
266 | ax[valid_path_count+2].scatter(st[1], st[0], marker='o', c='r')
267 | ax[valid_path_count+2].scatter(en[1], en[0], marker='x', c='b')
268 | ax[valid_path_count+2].set_title(f'resample {valid_path_count}\ndist: {dist:.2f}, traj: {len(path_raw)}')
269 | valid_path_count += 1
270 | # break #debug
271 |
272 |
273 | plt.tight_layout()
274 | plt.savefig(f'{dst_dir}/sampled_trajectories.png', dpi=300, bbox_inches='tight')
275 |
276 | sampled_trajs=np.array(sampled_trajs)
277 | np.save(f'{dst_dir}/sampled_trajectories.npy',sampled_trajs)
278 | assert len(sampled_trajs)==n_sample_pair
279 |
280 | # ## debug use orgin
281 | # sampled_trajs=np.array([np.eye(4)]*len(trans))
282 | # sampled_trajs[:,:3,:3]=np.array([Quaternion(r).rotation_matrix for r in rot])
283 | # sampled_trajs[:,:3,3]=trans
284 | # sampled_trajs=sampled_trajs[None]
285 | # n_sample_pair=1
286 |
287 | def process_trajectory(j, sampled_traj, occ_pc_static, cls_label, pc_range, voxel_size, occ_size, dst_dir_i):
288 | dst_dir_j = f'{dst_dir_i}/traj-{j:06d}' # TODO: use local id
289 | os.makedirs(dst_dir_j, exist_ok=True)
290 |
291 | # Transform to ego coordinates
292 | w2e = np.linalg.inv(sampled_traj)
293 | occ_pc_static_e = occ_pc_static @ w2e[:3,:3].T + w2e[:3,3]
294 |
295 | dense_voxels_with_semantic, voxel_semantic = sampling_occ_from_pc(occ_pc_static_e, cls_label, pc_range, voxel_size, occ_size)
296 | np.savez(f'{dst_dir_j}/labels.npz', semantics=voxel_semantic,pose=sampled_traj)
297 | # write_pc(dense_voxels_with_semantic[:,:3],f'{dst_dir_j}/occ_vis.ply',colors[dense_voxels_with_semantic[:,3]-1][:,:3]/255.0) #debug
298 | return voxel_semantic
299 |
300 | for i in tqdm(range(n_sample_pair),desc='resampling occ'):
301 | # if i>0:
302 | # break #debug
303 |
304 | voxel_semantic_all=[]
305 | # vis_pose_mesh(sampled_trajs[i],cmp_dir=dst_dir,fn=f'vis_resample_all_w_traj_sample_{i}.ply') #debug
306 | dst_dir_i=f'{dst_dir}/scene-{i:04d}' #TODO global id
307 |
308 | process_func = partial(process_trajectory,
309 | occ_pc_static=occ_pc_static,
310 | cls_label=cls_label,
311 | pc_range=pc_range,
312 | voxel_size=voxel_size,
313 | occ_size=occ_size,
314 | dst_dir_i=dst_dir_i)
315 |
316 | voxel_semantic_all = joblib.Parallel(n_jobs=-1)(
317 | joblib.delayed(process_func)(j, sampled_traj)
318 | for j, sampled_traj in enumerate(sampled_trajs[i])
319 | )
320 | # for j in range(len(sampled_trajs[i])):
321 | # tic=time.time()
322 | # voxel_semantic_all.append(process_func(j, sampled_trajs[i][j]))
323 | # print(f'@process_func {time.time()-tic}')
324 |
325 | visualize_point_cloud(
326 | voxel_semantic_all,
327 | None,
328 | None,
329 | abs_trans=sampled_trajs[i],
330 | cmp_dir=dst_dir_i,
331 | frame_type='',
332 | # frame_type='e+w',
333 | key='resample'
334 | )
335 |
--------------------------------------------------------------------------------
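The resampling constants in main.py are mutually consistent: with pc_range [-40, -40, -1, 40, 40, 5.4] and voxel_size 0.4, the grid spans (40 - (-40)) / 0.4 = 200 cells in x and y and (5.4 - (-1)) / 0.4 = 16 in z, matching occ_size [200, 200, 16]. A minimal sketch of the implied point-to-voxel mapping (the actual implementation, sampling_occ_from_pc, lives in resample/utils.py and is not reproduced here):

import numpy as np

pc_range = np.array([-40, -40, -1, 40, 40, 5.4])
voxel_size = 0.4
occ_size = np.array([200, 200, 16])

def point_to_voxel(points):
    # map (N, 3) ego-frame points to integer voxel indices, dropping out-of-range points
    idx = np.floor((points - pc_range[:3]) / voxel_size).astype(np.int64)
    keep = np.all((idx >= 0) & (idx < occ_size), axis=1)
    return idx[keep]

print(point_to_voxel(np.array([[0.0, 0.0, 0.0], [39.9, -39.9, 5.3]])))
# [[100 100   2]
#  [199   0  15]]
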
/resample/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | open3d
3 | tqdm
4 | scipy
5 | scikit-learn
6 | opencv-python
7 | matplotlib
8 | pyquaternion
9 | joblib
--------------------------------------------------------------------------------
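The resample pipeline declares its own dependencies; assuming a working Python environment, `pip install -r resample/requirements.txt` (run from the repository root) should cover the imports used by main.py, launch.py, and utils.py.
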
/static/images/dome_pipe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/dome_pipe.png
--------------------------------------------------------------------------------
/static/images/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/favicon.ico
--------------------------------------------------------------------------------
/static/images/occvae.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/occvae.png
--------------------------------------------------------------------------------
/static/images/overall_pipeline4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/overall_pipeline4.png
--------------------------------------------------------------------------------
/static/images/teaser12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/teaser12.png
--------------------------------------------------------------------------------
/static/images/vis_demo_cmp_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gusongen/DOME/da390aeeab1990fa06e037ad4d486729ae4fe712/static/images/vis_demo_cmp_2.png
--------------------------------------------------------------------------------
/tools/eval.sh:
--------------------------------------------------------------------------------
1 |
2 | # export CUDA_VISIBLE_DEVICES=4,5,6,7
3 | # export CUDA_VISIBLE_DEVICES=7
4 |
5 |
6 | cfg=./config/train_dome.py
7 | dir=./work_dir/dome
8 | ckpt=ckpts/dome_latest.pth
9 | vae_ckpt=ckpts/occvae_latest.pth
10 |
11 |
12 | python tools/eval_metric.py \
13 | --py-config $cfg \
14 | --work-dir $dir \
15 | --resume-from $ckpt \
16 | --vae-resume-from $vae_ckpt
17 |
--------------------------------------------------------------------------------
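The eval scripts assume the pretrained checkpoints sit under ckpts/. To pin a run to specific GPUs, uncomment one of the CUDA_VISIBLE_DEVICES exports above or prefix the invocation, e.g. `CUDA_VISIBLE_DEVICES=0 bash tools/eval.sh`.
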
/tools/eval_vae.py:
--------------------------------------------------------------------------------
1 | import time, argparse, os.path as osp, os
2 | import torch, numpy as np
3 | import torch.distributed as dist
4 | from copy import deepcopy
5 |
6 | import mmcv
7 | from mmengine import Config
8 | from mmengine.runner import set_random_seed
9 | from mmengine.optim import build_optim_wrapper
10 | from mmengine.logging import MMLogger
11 | from mmengine.utils import symlink
12 | from mmengine.registry import MODELS
13 | from timm.scheduler import CosineLRScheduler, MultiStepLRScheduler
14 | import sys
15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16 | from utils.load_save_util import revise_ckpt, revise_ckpt_1, load_checkpoint
17 | from torch.utils.tensorboard import SummaryWriter
18 | import warnings
19 | from einops import rearrange
20 | warnings.filterwarnings("ignore")
21 |
22 |
23 | def pass_print(*args, **kwargs):
24 | pass
25 |
26 | @torch.no_grad()
27 | def main(local_rank, args):
28 | # global settings
29 | set_random_seed(args.seed)
30 | torch.backends.cudnn.deterministic = False
31 | torch.backends.cudnn.benchmark = True
32 |
33 | # load config
34 | cfg = Config.fromfile(args.py_config)
35 | cfg.work_dir = args.work_dir
36 |
37 | # init DDP
38 | if args.gpus > 1:
39 | distributed = True
40 | ip = os.environ.get("MASTER_ADDR", "127.0.0.1")
41 | port = os.environ.get("MASTER_PORT", cfg.get("port", 29510))
42 | hosts = int(os.environ.get("WORLD_SIZE", 1)) # number of nodes
43 | rank = int(os.environ.get("RANK", 0)) # node id
44 | gpus = torch.cuda.device_count() # gpus per node
45 | print(f"tcp://{ip}:{port}")
46 | dist.init_process_group(
47 | backend="nccl", init_method=f"tcp://{ip}:{port}",
48 | world_size=hosts * gpus, rank=rank * gpus + local_rank)
49 | world_size = dist.get_world_size()
50 | cfg.gpu_ids = range(world_size)
51 | torch.cuda.set_device(local_rank)
52 |
53 | if local_rank != 0:
54 | import builtins
55 | builtins.print = pass_print
56 | else:
57 | distributed = False
58 | world_size = 1
59 |
60 | if local_rank == 0:
61 | os.makedirs(args.work_dir, exist_ok=True)
62 | cfg.dump(osp.join(args.work_dir, osp.basename(args.py_config)))
63 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
64 | log_file = osp.join(args.work_dir, f'{timestamp}.log')
65 | logger = MMLogger('genocc', log_file=log_file)
66 | MMLogger._instance_dict['genocc'] = logger
67 | logger.info(f'Config:\n{cfg.pretty_text}')
68 | tb_dir=args.tb_dir if args.tb_dir else osp.join(args.work_dir, 'tb_log')
69 | writer = SummaryWriter(tb_dir)
70 |
71 | # build model
72 | import model
73 | from dataset import get_dataloader, get_nuScenes_label_name
74 | from loss import OPENOCC_LOSS
75 | from utils.metric_util import MeanIoU, multi_step_MeanIou
76 | from utils.freeze_model import freeze_model
77 |
78 | my_model = MODELS.build(cfg.model)
79 | my_model.init_weights()
80 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
81 | logger.info(f'Number of params: {n_parameters}')
82 | if cfg.get('freeze_dict', False):
83 | logger.info(f'Freezing model according to freeze_dict:{cfg.freeze_dict}')
84 | freeze_model(my_model, cfg.freeze_dict)
85 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
86 | logger.info(f'Number of params after freezing: {n_parameters}')
87 | if distributed:
88 | if cfg.get('syncBN', True):
89 | my_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(my_model)
90 | logger.info('converted sync bn.')
91 |
92 | find_unused_parameters = cfg.get('find_unused_parameters', True)
93 | ddp_model_module = torch.nn.parallel.DistributedDataParallel
94 | my_model = ddp_model_module(
95 | my_model.cuda(),
96 | device_ids=[torch.cuda.current_device()],
97 | broadcast_buffers=False,
98 | find_unused_parameters=find_unused_parameters)
99 | raw_model = my_model.module
100 | else:
101 | my_model = my_model.cuda()
102 | raw_model = my_model
103 | logger.info('done ddp model')
104 |
105 | train_dataset_loader, val_dataset_loader = get_dataloader(
106 | cfg.train_dataset_config,
107 | cfg.val_dataset_config,
108 | cfg.train_wrapper_config,
109 | cfg.val_wrapper_config,
110 | cfg.train_loader,
111 | cfg.val_loader,
112 | dist=distributed,
113 | iter_resume=args.iter_resume)
114 |
115 | # get optimizer, loss, scheduler
116 | optimizer = build_optim_wrapper(my_model, cfg.optimizer)
117 | loss_func = OPENOCC_LOSS.build(cfg.loss).cuda()
118 | max_num_epochs = cfg.max_epochs
119 | if cfg.get('multisteplr', False):
120 | scheduler = MultiStepLRScheduler(
121 | optimizer,
122 | **cfg.multisteplr_config)
123 | else:
124 | scheduler = CosineLRScheduler(
125 | optimizer,
126 | t_initial=len(train_dataset_loader) * max_num_epochs,
127 | lr_min=1e-6,
128 | warmup_t=cfg.get('warmup_iters', 500),
129 | warmup_lr_init=1e-6,
130 | t_in_epochs=False)
131 |
132 | # resume and load
133 | epoch = 0
134 | global_iter = 0
135 | last_iter = 0
136 | best_val_iou = [0]*cfg.get('return_len_', 10)
137 | best_val_miou = [0]*cfg.get('return_len_', 10)
138 |
139 | cfg.resume_from = ''
140 | # if osp.exists(osp.join(args.work_dir, 'latest.pth')):
141 | # cfg.resume_from = osp.join(args.work_dir, 'latest.pth')
142 | if args.resume_from:
143 | cfg.resume_from = args.resume_from
144 | if args.load_from:
145 | cfg.load_from = args.load_from
146 |
147 | logger.info('resume from: ' + cfg.resume_from)
148 | logger.info('load from: ' + cfg.load_from)
149 | logger.info('work dir: ' + args.work_dir)
150 |
151 | if cfg.resume_from and osp.exists(cfg.resume_from):
152 | map_location = 'cpu'
153 | ckpt = torch.load(cfg.resume_from, map_location=map_location)
154 | print(raw_model.load_state_dict(ckpt['state_dict'], strict=False))
155 | optimizer.load_state_dict(ckpt['optimizer'])
156 | scheduler.load_state_dict(ckpt['scheduler'])
157 | epoch = ckpt['epoch']
158 | global_iter = ckpt['global_iter']
159 | last_iter = ckpt['last_iter'] if 'last_iter' in ckpt else 0
160 | if 'best_val_iou' in ckpt:
161 | best_val_iou = ckpt['best_val_iou']
162 | if 'best_val_miou' in ckpt:
163 | best_val_miou = ckpt['best_val_miou']
164 |
165 | if hasattr(train_dataset_loader.sampler, 'set_last_iter'):
166 | train_dataset_loader.sampler.set_last_iter(last_iter)
167 | print(f'successfully resumed from epoch {epoch}')
168 | elif cfg.load_from:
169 | ckpt = torch.load(cfg.load_from, map_location='cpu')
170 | if 'state_dict' in ckpt:
171 | state_dict = ckpt['state_dict']
172 | else:
173 | state_dict = ckpt
174 | if cfg.get('revise_ckpt', False):
175 | if cfg.revise_ckpt == 1:
176 | print('revise_ckpt')
177 | print(raw_model.load_state_dict(revise_ckpt(state_dict), strict=False))
178 | elif cfg.revise_ckpt == 2:
179 | print('revise_ckpt_1')
180 | print(raw_model.load_state_dict(revise_ckpt_1(state_dict), strict=False))
181 | elif cfg.revise_ckpt == 3:
182 | print('revise_ckpt_2')
183 | print(raw_model.vae.load_state_dict(state_dict, strict=False))
184 | else:
185 | # print(raw_model.load_state_dict(state_dict, strict=False))
186 | load_checkpoint(raw_model, state_dict, strict=False) #TODO may need to remove module.xxx prefix
187 |
188 | # evaluation
189 | print_freq = cfg.print_freq
190 | grad_norm = 0
191 |
192 | label_name = get_nuScenes_label_name(cfg.label_mapping)
193 | unique_label = np.asarray(cfg.unique_label)
194 | unique_label_str = [label_name[l] for l in unique_label]
195 | # CalMeanIou_sem = multi_step_MeanIou(unique_label, cfg.get('ignore_label', -100), unique_label_str, 'sem', times=cfg.get('return_len_', 10))
196 | # CalMeanIou_vox = multi_step_MeanIou([1], cfg.get('ignore_label', -100), ['occupied'], 'vox', times=cfg.get('return_len_', 10))
197 | CalMeanIou_sem = multi_step_MeanIou(unique_label, cfg.get('ignore_label', -100), unique_label_str, 'sem', times=1)#cfg.get('return_len_', 10))
198 | CalMeanIou_vox = multi_step_MeanIou([1], cfg.get('ignore_label', -100), ['occupied'], 'vox', times=1)#cfg.get('return_len_', 10))
199 | # logger.info('compiling model')
200 | # my_model = torch.compile(my_model)
201 | # logger.info('done compile model')
202 | best_plan_loss = 100000
203 | # max_num_epochs=1 #debug
204 | if True:
205 | my_model.eval()
206 | os.environ['eval'] = 'true'
207 | val_loss_list = []
208 | CalMeanIou_sem.reset()
209 | CalMeanIou_vox.reset()
210 | plan_loss = 0
211 |
212 | with torch.no_grad():
213 | for i_iter_val, (input_occs, target_occs, metas) in enumerate(val_dataset_loader):
214 | # input_occs=rearrange(input_occs,'b (f1 f) h w d-> (b f) (f1 1) h w d',f1=2)
215 | # target_occs=rearrange(target_occs,'b f h w d-> (b f) 1 h w d')
216 |
217 | input_occs = input_occs.cuda()
218 | target_occs = target_occs.cuda()
219 | data_time_e = time.time()
220 |
221 | result_dict = my_model(x=input_occs, metas=metas)
222 |
223 | loss_input = {
224 | 'inputs': input_occs,
225 | 'target_occs': target_occs,
226 | # 'metas': metas
227 | **result_dict
228 | }
229 | for loss_input_key, loss_input_val in cfg.loss_input_convertion.items():
230 | loss_input.update({
231 | loss_input_key: result_dict[loss_input_val]
232 | })
233 | loss, loss_dict = loss_func(loss_input)
234 | plan_loss += loss_dict.get('PlanRegLoss', 0)
235 | plan_loss += loss_dict.get('PlanRegLossLidar', 0)
236 | if result_dict.get('target_occs', None) is not None:
237 | target_occs = result_dict['target_occs']
238 | target_occs_iou = deepcopy(target_occs)
239 | target_occs_iou[target_occs_iou != 17] = 1  # class 17 is free space; everything else counts as occupied
240 | target_occs_iou[target_occs_iou == 17] = 0
241 |
242 | CalMeanIou_sem._after_step(
243 | rearrange(result_dict['sem_pred'],'b f h w d-> (b f) 1 h w d'),
244 | rearrange(target_occs,'b f h w d-> (b f) 1 h w d'))
245 | CalMeanIou_vox._after_step(
246 | rearrange(result_dict['iou_pred'],'b f h w d-> (b f) 1 h w d'),
247 | rearrange(target_occs_iou,'b f h w d-> (b f) 1 h w d'))
248 | val_loss_list.append(loss.detach().cpu().numpy())
249 | if i_iter_val % print_freq == 0 and local_rank == 0:
250 | logger.info('[EVAL] Epoch %d Iter %5d: Loss: %.3f (%.3f)'%(
251 | epoch, i_iter_val, loss.item(), np.mean(val_loss_list)))
252 | writer.add_scalar(f'val/loss', loss.item(), global_iter)
253 | detailed_loss = []
254 | for loss_name, loss_value in loss_dict.items():
255 | detailed_loss.append(f'{loss_name}: {loss_value:.5f}')
256 | writer.add_scalar(f'val/{loss_name}', loss_value, global_iter)
257 | detailed_loss = ', '.join(detailed_loss)
258 | logger.info(detailed_loss)
259 | # break #debug
260 | val_miou, _ = CalMeanIou_sem._after_epoch()
261 | val_iou, _ = CalMeanIou_vox._after_epoch()
262 |
263 | del target_occs, input_occs
264 | plan_loss = plan_loss/len(val_dataset_loader)
265 | if plan_loss < best_plan_loss:
266 | best_plan_loss = plan_loss
267 | logger.info(f'PlanRegLoss is {plan_loss} while the best plan loss is {best_plan_loss}')
268 | #logger.info(f'PlanRegLoss is {plan_loss/len(val_dataset_loader)}')
269 | best_val_iou = val_iou#[max(best_val_iou[i], val_iou[i]) for i in range(len(best_val_iou))]
270 | best_val_miou = val_miou#[max(best_val_miou[i], val_miou[i]) for i in range(len(best_val_miou))]
271 | #logger.info(f'PlanRegLoss is {plan_loss/len(val_dataset_loader)}')
272 | logger.info(f'Current val iou is {val_iou} while the best val iou is {best_val_iou}')
273 | logger.info(f'Current val miou is {val_miou} while the best val miou is {best_val_miou}')
274 | torch.cuda.empty_cache()
275 |
276 |
277 | if __name__ == '__main__':
278 | # Training settings
279 | parser = argparse.ArgumentParser(description='')
280 | parser.add_argument('--py-config', default='config/tpv_lidarseg.py')
281 | parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg')
282 | parser.add_argument('--tb-dir', type=str, default=None)
283 | parser.add_argument('--resume-from', type=str, default='')
284 | parser.add_argument('--iter-resume', action='store_true', default=False)
285 | parser.add_argument('--seed', type=int, default=42)
286 | parser.add_argument('--load_from', type=str, default=None)
287 | args = parser.parse_args()
288 |
289 | ngpus = torch.cuda.device_count()
290 | args.gpus = ngpus
291 | print(args)
292 |
293 | if ngpus > 1:
294 | torch.multiprocessing.spawn(main, args=(args,), nprocs=args.gpus)
295 | else:
296 | main(0, args)
297 |
--------------------------------------------------------------------------------
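The DDP bootstrap above (mirrored in tools/train_vae.py below) derives the global rank as rank * gpus + local_rank, with world_size = hosts * gpus. A worked example, assuming a hypothetical 2-node job with 4 GPUs per node:

hosts, gpus = 2, 4                  # WORLD_SIZE (number of nodes), GPUs per node
world_size = hosts * gpus           # 8 processes in total
for rank in range(hosts):           # RANK: node id
    for local_rank in range(gpus):  # one spawned process per GPU
        print(f'node {rank}, gpu {local_rank} -> global rank {rank * gpus + local_rank}')
# node 0 covers global ranks 0-3, node 1 covers global ranks 4-7
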
/tools/eval_vae.sh:
--------------------------------------------------------------------------------
1 |
2 | # export CUDA_VISIBLE_DEVICES=4,5,6,7
3 | # export CUDA_VISIBLE_DEVICES=7
4 |
5 |
6 | cfg=./config/train_occvae.py
7 | dir=./work_dir/occ_vae
8 | vae_ckpt=ckpts/occvae_latest.pth
9 |
10 | python tools/eval_vae.py \
11 | --py-config $cfg \
12 | --work-dir $dir \
13 | --load_from $vae_ckpt
14 |
--------------------------------------------------------------------------------
/tools/train_diffusion.sh:
--------------------------------------------------------------------------------
1 | # export CUDA_VISIBLE_DEVICES=5
2 |
3 |
4 | cfg=./config/train_dome.py
5 | dir=./work_dir/dome
6 | vae_ckpt=ckpts/occvae_latest.pth
7 |
8 |
9 | python tools/train_diffusion.py \
10 | --py-config $cfg \
11 | --work-dir $dir
--------------------------------------------------------------------------------
/tools/train_vae.py:
--------------------------------------------------------------------------------
1 | import time, argparse, os.path as osp, os
2 | import torch, numpy as np
3 | import torch.distributed as dist
4 | from copy import deepcopy
5 |
6 | import mmcv
7 | from mmengine import Config
8 | from mmengine.runner import set_random_seed
9 | from mmengine.optim import build_optim_wrapper
10 | from mmengine.logging import MMLogger
11 | from mmengine.utils import symlink
12 | from mmengine.registry import MODELS
13 | from timm.scheduler import CosineLRScheduler, MultiStepLRScheduler
14 | import sys
15 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16 | from utils.load_save_util import revise_ckpt, revise_ckpt_1, load_checkpoint
17 | from torch.utils.tensorboard import SummaryWriter
18 | import warnings
19 | warnings.filterwarnings("ignore")
20 |
21 |
22 | def pass_print(*args, **kwargs):
23 | pass
24 |
25 | def main(local_rank, args):
26 | # global settings
27 | set_random_seed(args.seed)
28 | torch.backends.cudnn.deterministic = False
29 | torch.backends.cudnn.benchmark = True
30 |
31 | # load config
32 | cfg = Config.fromfile(args.py_config)
33 | cfg.work_dir = args.work_dir
34 |
35 | # init DDP
36 | if args.gpus > 1:
37 | distributed = True
38 | ip = os.environ.get("MASTER_ADDR", "127.0.0.1")
39 | port = os.environ.get("MASTER_PORT", cfg.get("port", 29510))
40 | hosts = int(os.environ.get("WORLD_SIZE", 1)) # number of nodes
41 | rank = int(os.environ.get("RANK", 0)) # node id
42 | gpus = torch.cuda.device_count() # gpus per node
43 | print(f"tcp://{ip}:{port}")
44 | dist.init_process_group(
45 | backend="nccl", init_method=f"tcp://{ip}:{port}",
46 | world_size=hosts * gpus, rank=rank * gpus + local_rank)
47 | world_size = dist.get_world_size()
48 | cfg.gpu_ids = range(world_size)
49 | torch.cuda.set_device(local_rank)
50 |
51 | if local_rank != 0:
52 | import builtins
53 | builtins.print = pass_print
54 | else:
55 | distributed = False
56 | world_size = 1
57 |
58 | if local_rank == 0:
59 | os.makedirs(args.work_dir, exist_ok=True)
60 | cfg.dump(osp.join(args.work_dir, osp.basename(args.py_config)))
61 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
62 | log_file = osp.join(args.work_dir, f'{timestamp}.log')
63 | logger = MMLogger('genocc', log_file=log_file)
64 | MMLogger._instance_dict['genocc'] = logger
65 | logger.info(f'Config:\n{cfg.pretty_text}')
66 | tb_dir=args.tb_dir if args.tb_dir else osp.join(args.work_dir, 'tb_log')
67 | writer = SummaryWriter(tb_dir)
68 |
69 | # build model
70 | import model
71 | from dataset import get_dataloader, get_nuScenes_label_name
72 | from loss import OPENOCC_LOSS
73 | from utils.metric_util import MeanIoU, multi_step_MeanIou
74 | from utils.freeze_model import freeze_model
75 |
76 | my_model = MODELS.build(cfg.model)
77 | my_model.init_weights()
78 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
79 | logger.info(f'Number of params: {n_parameters}')
80 | if cfg.get('freeze_dict', False):
81 | logger.info(f'Freezing model according to freeze_dict:{cfg.freeze_dict}')
82 | freeze_model(my_model, cfg.freeze_dict)
83 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
84 | logger.info(f'Number of params after freezing: {n_parameters}')
85 | if distributed:
86 | if cfg.get('syncBN', True):
87 | my_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(my_model)
88 | logger.info('converted sync bn.')
89 |
90 | find_unused_parameters = cfg.get('find_unused_parameters', False)
91 | ddp_model_module = torch.nn.parallel.DistributedDataParallel
92 | my_model = ddp_model_module(
93 | my_model.cuda(),
94 | device_ids=[torch.cuda.current_device()],
95 | broadcast_buffers=False,
96 | find_unused_parameters=find_unused_parameters)
97 | raw_model = my_model.module
98 | else:
99 | my_model = my_model.cuda()
100 | raw_model = my_model
101 | logger.info('done ddp model')
102 |
103 | train_dataset_loader, val_dataset_loader = get_dataloader(
104 | cfg.train_dataset_config,
105 | cfg.val_dataset_config,
106 | cfg.train_wrapper_config,
107 | cfg.val_wrapper_config,
108 | cfg.train_loader,
109 | cfg.val_loader,
110 | dist=distributed,
111 | iter_resume=args.iter_resume)
112 |
113 | # get optimizer, loss, scheduler
114 | optimizer = build_optim_wrapper(my_model, cfg.optimizer)
115 | loss_func = OPENOCC_LOSS.build(cfg.loss).cuda()
116 | max_num_epochs = cfg.max_epochs
117 | if cfg.get('multisteplr', False):
118 | scheduler = MultiStepLRScheduler(
119 | optimizer,
120 | **cfg.multisteplr_config)
121 | else:
122 | scheduler = CosineLRScheduler(
123 | optimizer,
124 | t_initial=len(train_dataset_loader) * max_num_epochs,
125 | lr_min=1e-6,
126 | warmup_t=cfg.get('warmup_iters', 500),
127 | warmup_lr_init=1e-6,
128 | t_in_epochs=False)
129 |
130 | # resume and load
131 | epoch = 0
132 | global_iter = 0
133 | last_iter = 0
134 | best_val_iou = [0]*cfg.get('return_len_', 10)
135 | best_val_miou = [0]*cfg.get('return_len_', 10)
136 |
137 | cfg.resume_from = ''
138 | # if osp.exists(osp.join(args.work_dir, 'latest.pth')):
139 | # cfg.resume_from = osp.join(args.work_dir, 'latest.pth')
140 | if args.resume_from:
141 | cfg.resume_from = args.resume_from
142 | if args.load_from:
143 | cfg.load_from = args.load_from
144 |
145 | logger.info('resume from: ' + cfg.resume_from)
146 | logger.info('load from: ' + cfg.load_from)
147 | logger.info('work dir: ' + args.work_dir)
148 |
149 | if cfg.resume_from and osp.exists(cfg.resume_from):
150 | map_location = 'cpu'
151 | ckpt = torch.load(cfg.resume_from, map_location=map_location)
152 | print(raw_model.load_state_dict(ckpt['state_dict'], strict=False))
153 | optimizer.load_state_dict(ckpt['optimizer'])
154 | scheduler.load_state_dict(ckpt['scheduler'])
155 | epoch = ckpt['epoch']
156 | global_iter = ckpt['global_iter']
157 | last_iter = ckpt['last_iter'] if 'last_iter' in ckpt else 0
158 | if 'best_val_iou' in ckpt:
159 | best_val_iou = ckpt['best_val_iou']
160 | if 'best_val_miou' in ckpt:
161 | best_val_miou = ckpt['best_val_miou']
162 |
163 | if hasattr(train_dataset_loader.sampler, 'set_last_iter'):
164 | train_dataset_loader.sampler.set_last_iter(last_iter)
165 | print(f'successfully resumed from epoch {epoch}')
166 | elif cfg.load_from:
167 | ckpt = torch.load(cfg.load_from, map_location='cpu')
168 | if 'state_dict' in ckpt:
169 | state_dict = ckpt['state_dict']
170 | else:
171 | state_dict = ckpt
172 | if cfg.get('revise_ckpt', False):
173 | if cfg.revise_ckpt == 1:
174 | print('revise_ckpt')
175 | print(raw_model.load_state_dict(revise_ckpt(state_dict), strict=False))
176 | elif cfg.revise_ckpt == 2:
177 | print('revise_ckpt_1')
178 | print(raw_model.load_state_dict(revise_ckpt_1(state_dict), strict=False))
179 | elif cfg.revise_ckpt == 3:
180 | print('revise_ckpt_2')
181 | print(raw_model.vae.load_state_dict(state_dict, strict=False))
182 | else:
183 | # print(raw_model.load_state_dict(state_dict, strict=False))
184 | load_checkpoint(raw_model, state_dict, strict=False) #TODO may need to remove module.xxx prefix
185 |
186 | # training
187 | print_freq = cfg.print_freq
188 | first_run = True
189 | grad_norm = 0
190 |
191 | label_name = get_nuScenes_label_name(cfg.label_mapping)
192 | unique_label = np.asarray(cfg.unique_label)
193 | unique_label_str = [label_name[l] for l in unique_label]
194 | CalMeanIou_sem = multi_step_MeanIou(unique_label, cfg.get('ignore_label', -100), unique_label_str, 'sem', times=cfg.get('return_len_', 10))
195 | CalMeanIou_vox = multi_step_MeanIou([1], cfg.get('ignore_label', -100), ['occupied'], 'vox', times=cfg.get('return_len_', 10))
196 | # logger.info('compiling model')
197 | # my_model = torch.compile(my_model)
198 | # logger.info('done compile model')
199 | # max_num_epochs=1 #debug
200 | while epoch < max_num_epochs:
201 |
202 | my_model.train()
203 | os.environ['eval'] = 'false'
204 | if hasattr(train_dataset_loader.sampler, 'set_epoch'):
205 | train_dataset_loader.sampler.set_epoch(epoch)
206 | loss_list = []
207 | time.sleep(10)
208 | data_time_s = time.time()
209 | time_s = time.time()
210 | for i_iter, (input_occs, target_occs, metas) in enumerate(train_dataset_loader):
211 | if first_run:
212 | i_iter = i_iter + last_iter
213 |
214 | input_occs = input_occs.cuda()
215 | target_occs = target_occs.cuda()
216 | data_time_e = time.time()
217 | use_pose_condition = torch.rand(1) < cfg.get('p_use_pose_condition',0)
218 | result_dict = my_model(x=input_occs, metas=metas,use_pose_condition=use_pose_condition)
219 |
220 | loss_input = {
221 | 'inputs': input_occs,
222 | 'target_occs': target_occs,
223 | # 'metas': metas
224 | **result_dict,
225 | }
226 |
227 | for loss_input_key, loss_input_val in cfg.loss_input_convertion.items():
228 | input_=result_dict[loss_input_val]
229 | if 'temperal_mask' in result_dict:
230 | t_mask=result_dict['temperal_mask']
231 | if input_.dim()==4:
232 | t_mask= t_mask.unsqueeze(1)
233 | input_*=t_mask
234 | loss_input.update({
235 | loss_input_key: input_})
236 | loss, loss_dict = loss_func(loss_input)
237 | optimizer.zero_grad()
238 | loss.backward()
239 | grad_norm = torch.nn.utils.clip_grad_norm_(my_model.parameters(), cfg.grad_max_norm)
240 | optimizer.step()
241 |
242 | loss_list.append(loss.detach().cpu().item())
243 | scheduler.step_update(global_iter)
244 | time_e = time.time()
245 |
246 | global_iter += 1
247 | if i_iter % print_freq == 0 and local_rank == 0:
248 | lr = optimizer.param_groups[0]['lr']
249 | logger.info('[TRAIN] Epoch %d Iter %5d/%d: Loss: %.3f (%.3f), grad_norm: %.3f, lr: %.7f, time: %.3f (%.3f)'%(
250 | epoch, i_iter, len(train_dataset_loader),
251 | loss.item(), np.mean(loss_list), grad_norm, lr,
252 | time_e - time_s, data_time_e - data_time_s))
253 | writer.add_scalar(f'train/loss', loss.item(), global_iter)
254 | detailed_loss = []
255 | for loss_name, loss_value in loss_dict.items():
256 | detailed_loss.append(f'{loss_name}: {loss_value:.5f}')
257 | writer.add_scalar(f'train/{loss_name}', loss_value, global_iter)
258 | detailed_loss = ', '.join(detailed_loss)
259 | logger.info(detailed_loss)
260 | loss_list = []
261 | # exit(0) #debug
262 | data_time_s = time.time()
263 | time_s = time.time()
264 |
265 | if args.iter_resume:
266 | if (i_iter + 1) % 50 == 0 and local_rank == 0:
267 | dict_to_save = {
268 | 'state_dict': raw_model.state_dict(),
269 | 'optimizer': optimizer.state_dict(),
270 | 'scheduler': scheduler.state_dict(),
271 | 'epoch': epoch,
272 | 'global_iter': global_iter,
273 | 'last_iter': i_iter + 1,
274 | }
275 | save_file_name = os.path.join(os.path.abspath(args.work_dir), 'iter.pth')
276 | torch.save(dict_to_save, save_file_name)
277 | dst_file = osp.join(args.work_dir, 'latest.pth')
278 | symlink(save_file_name, dst_file)
279 | logger.info(f'iter ckpt {i_iter + 1} saved!')
280 | # break #debug
281 |
282 | # save checkpoint
283 | if local_rank == 0 and (epoch+1) % cfg.get('save_every_epochs', 1) == 0:
284 | dict_to_save = {
285 | 'state_dict': raw_model.state_dict(),
286 | 'optimizer': optimizer.state_dict(),
287 | 'scheduler': scheduler.state_dict(),
288 | 'epoch': epoch + 1,
289 | 'global_iter': global_iter,
290 | }
291 | save_file_name = os.path.join(os.path.abspath(args.work_dir), f'epoch_{epoch+1}.pth')
292 | torch.save(dict_to_save, save_file_name)
293 | dst_file = osp.join(args.work_dir, 'latest.pth')
294 | symlink(save_file_name, dst_file)
295 |
296 | epoch += 1
297 | first_run = False
298 |
299 | # eval
300 | if epoch % cfg.get('eval_every_epochs', 1) != 0:
301 | continue
302 | my_model.eval()
303 | os.environ['eval'] = 'true'
304 | val_loss_list = []
305 | CalMeanIou_sem.reset()
306 | CalMeanIou_vox.reset()
307 |
308 | with torch.no_grad():
309 | for i_iter_val, (input_occs, target_occs, metas) in enumerate(val_dataset_loader):
310 |
311 | input_occs = input_occs.cuda()
312 | target_occs = target_occs.cuda()
313 | data_time_e = time.time()
314 |
315 | result_dict = my_model(x=input_occs, metas=metas)
316 |
317 | loss_input = {
318 | 'inputs': input_occs,
319 | 'target_occs': target_occs,
320 | # 'metas': metas
321 | **result_dict
322 | }
323 | for loss_input_key, loss_input_val in cfg.loss_input_convertion.items():
324 | loss_input.update({
325 | loss_input_key: result_dict[loss_input_val]
326 | })
327 | loss, loss_dict = loss_func(loss_input)
328 | if result_dict.get('target_occs', None) is not None:
329 | target_occs = result_dict['target_occs']
330 | target_occs_iou = deepcopy(target_occs)
331 | target_occs_iou[target_occs_iou != 17] = 1  # class 17 is free space; everything else counts as occupied
332 | target_occs_iou[target_occs_iou == 17] = 0
333 |
334 | CalMeanIou_sem._after_step(result_dict['sem_pred'], target_occs)
335 | CalMeanIou_vox._after_step(result_dict['iou_pred'], target_occs_iou)
336 | val_loss_list.append(loss.detach().cpu().numpy())
337 | if i_iter_val % print_freq == 0 and local_rank == 0:
338 | logger.info('[EVAL] Epoch %d Iter %5d: Loss: %.3f (%.3f)'%(
339 | epoch, i_iter_val, loss.item(), np.mean(val_loss_list)))
340 | writer.add_scalar(f'val/loss', loss.item(), global_iter)
341 | detailed_loss = []
342 | for loss_name, loss_value in loss_dict.items():
343 | detailed_loss.append(f'{loss_name}: {loss_value:.5f}')
344 | writer.add_scalar(f'val/{loss_name}', loss_value, global_iter)
345 | detailed_loss = ', '.join(detailed_loss)
346 | logger.info(detailed_loss)
347 | # break #debug
348 | val_miou, _ = CalMeanIou_sem._after_epoch()
349 | val_iou, _ = CalMeanIou_vox._after_epoch()
350 |
351 | del target_occs, input_occs
352 | best_val_iou = [max(best_val_iou[i], val_iou[i]) for i in range(len(best_val_iou))]
353 | best_val_miou = [max(best_val_miou[i], val_miou[i]) for i in range(len(best_val_miou))]
354 | logger.info(f'Current val iou is {val_iou} while the best val iou is {best_val_iou}')
355 | logger.info(f'Current val miou is {val_miou} while the best val miou is {best_val_miou}')
356 | torch.cuda.empty_cache()
357 |
358 |
359 | if __name__ == '__main__':
360 | # Training settings
361 | parser = argparse.ArgumentParser(description='')
362 | parser.add_argument('--py-config', default='config/tpv_lidarseg.py')
363 | parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg')
364 | parser.add_argument('--tb-dir', type=str, default=None)
365 | parser.add_argument('--resume-from', type=str, default='')
366 | parser.add_argument('--iter-resume', action='store_true', default=False)
367 | parser.add_argument('--seed', type=int, default=42)
368 | parser.add_argument('--load_from', type=str, default=None)
369 | args = parser.parse_args()
370 |
371 | ngpus = torch.cuda.device_count()
372 | args.gpus = ngpus
373 | print(args)
374 |
375 | if ngpus > 1:
376 | torch.multiprocessing.spawn(main, args=(args,), nprocs=args.gpus)
377 | else:
378 | main(0, args)
379 |
--------------------------------------------------------------------------------
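Checkpoints written by the training loop above are plain dictionaries, so they can be inspected without building the model; the path below assumes the ./work_dir/occ_vae work dir used by the surrounding scripts:

import torch

ckpt = torch.load('./work_dir/occ_vae/latest.pth', map_location='cpu')
print(sorted(ckpt.keys()))  # epoch, global_iter, optimizer, scheduler, state_dict (plus last_iter for iter.pth)
print(ckpt['epoch'], ckpt['global_iter'])
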
/tools/train_vae.sh:
--------------------------------------------------------------------------------
1 |
2 | # export CUDA_VISIBLE_DEVICES=4,5,6,7
3 | # export CUDA_VISIBLE_DEVICES=7
4 |
5 |
6 | cfg=./config/train_occvae.py
7 | dir=./work_dir/occ_vae
8 |
9 |
10 | python tools/train_vae.py \
11 | --py-config $cfg \
12 | --work-dir $dir
--------------------------------------------------------------------------------
/tools/vis_diffusion.sh:
--------------------------------------------------------------------------------
1 | # export CUDA_VISIBLE_DEVICES=5
2 |
3 | cfg=./config/train_dome.py
4 | dir=./work_dir/dome
5 |
6 | vae_ckpt=ckpts/occvae_latest.pth
7 | ckpt=ckpts/dome_latest.pth
8 |
9 |
10 | python tools/visualize_demo.py \
11 | --py-config $cfg \
12 | --work-dir $dir \
13 | --resume-from $ckpt \
14 | --vae-resume-from $vae_ckpt \
15 |
16 |
--------------------------------------------------------------------------------
/tools/vis_gif.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from tqdm import tqdm
3 | import imageio
4 |
5 | src=r"/home/users/songen.gu/adwm/OccWorld/out/occworld/visgts_autoreg/200/nuScenesSceneDatasetLidar"
6 |
7 |
8 |
9 | def create_gif(src, fps=10):
10 | images = []
11 | for img in sorted(Path(src).rglob('*.png')):  # sort so frames stay in order
12 | images.append(imageio.imread(img))
13 | imageio.mimsave(f'{src}/vis.gif', images, fps=fps)
14 |
15 | def create_mp4(src, fps=2):
16 | with imageio.get_writer(f'{src}/vis.mp4', mode='I', fps=fps) as writer:
17 | for img in sorted(Path(src).rglob('vis*.png')):  # sort so frames stay in order
18 | writer.append_data(imageio.imread(img))
19 |
20 | if __name__ == '__main__':
21 | for dir in tqdm(Path(src).iterdir()):
22 | if dir.is_dir():
23 | create_mp4(dir)
24 | create_gif(dir)
25 |
--------------------------------------------------------------------------------
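Assuming a run directory already populated with vis*.png frames, the helpers can also be driven directly (the path is illustrative):

from vis_gif import create_gif, create_mp4

create_mp4('./work_dir/dome/vis_run', fps=2)
create_gif('./work_dir/dome/vis_run', fps=10)
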
/tools/vis_utils.py:
--------------------------------------------------------------------------------
1 | from pyvirtualdisplay import Display
2 | display = Display(visible=False, size=(2560, 1440))
3 | display.start()
4 | from mayavi import mlab
5 | import mayavi
6 | mlab.options.offscreen = True
7 | print("Set mlab.options.offscreen={}".format(mlab.options.offscreen))
8 | import numpy as np
9 | import os
10 | from pyquaternion import Quaternion
11 | try:
12 | import open3d as o3d
13 | except:
14 | pass
15 | from functools import reduce
16 | import mmcv
17 |
18 | colors = np.array(
19 | [
20 | [255, 120, 50, 255], # barrier orange
21 | [255, 192, 203, 255], # bicycle pink
22 | [255, 255, 0, 255], # bus yellow
23 | [ 0, 150, 245, 255], # car blue
24 | [ 0, 255, 255, 255], # construction_vehicle cyan
25 | [255, 127, 0, 255], # motorcycle dark orange
26 | [255, 0, 0, 255], # pedestrian red
27 | [255, 240, 150, 255], # traffic_cone light yellow
28 | [135, 60, 0, 255], # trailer brown
29 | [160, 32, 240, 255], # truck purple
30 | [255, 0, 255, 255], # driveable_surface dark pink
31 | # [175, 0, 75, 255], # other_flat dark red
32 | [139, 137, 137, 255],
33 | [ 75, 0, 75, 255], # sidewalk dark purple
34 | [150, 240, 80, 255], # terrain light green
35 | [230, 230, 250, 255], # manmade white
36 | [ 0, 175, 0, 255], # vegetation green
37 | # [ 0, 255, 127, 255], # ego car dark cyan
38 | # [255, 99, 71, 255], # ego car
39 | # [ 0, 191, 255, 255] # ego car
40 | ]
41 | ).astype(np.uint8)
42 |
43 | def pass_print(*args, **kwargs):
44 | pass
45 |
46 |
47 | def get_grid_coords(dims, resolution):
48 | """
49 | :param dims: the dimensions of the grid [x, y, z] (e.g. [256, 256, 32])
50 | :return coords_grid: is the center coords of voxels in the grid
51 | """
52 |
53 | g_xx = np.arange(0, dims[0]) # [0, 1, ..., dims[0] - 1]
54 | # g_xx = g_xx[::-1]
55 | g_yy = np.arange(0, dims[1]) # [0, 1, ..., dims[1] - 1]
56 | # g_yy = g_yy[::-1]
57 | g_zz = np.arange(0, dims[2]) # [0, 1, ..., dims[2] - 1]
58 |
59 | # Obtaining the grid with coords...
60 | xx, yy, zz = np.meshgrid(g_xx, g_yy, g_zz)
61 | coords_grid = np.array([xx.flatten(), yy.flatten(), zz.flatten()]).T
62 | coords_grid = coords_grid.astype(np.float32)
63 | resolution = np.array(resolution, dtype=np.float32).reshape([1, 3])
64 |
65 | coords_grid = (coords_grid * resolution) + resolution / 2
66 |
67 | return coords_grid
68 |
69 | def draw(
70 | voxels, # semantic occupancy predictions
71 | pred_pts, # lidarseg predictions
72 | vox_origin,
73 | voxel_size=0.2, # voxel size in the real world
74 | grid=None, # voxel coordinates of point cloud
75 | pt_label=None, # label of point cloud
76 | save_dir=None,
77 | cam_positions=None,
78 | focal_positions=None,
79 | timestamp=None,
80 | mode=0,
81 | sem=False,
82 | show_ego=False
83 | ):
84 | w, h, z = voxels.shape
85 |
86 | # assert show_ego
87 | if show_ego:
88 | assert voxels.shape==(200, 200, 16)
89 | voxels[96:104, 96:104, 2:7] = 15
90 | voxels[104:106, 96:104, 2:5] = 3
91 |
92 | # Compute the voxels coordinates
93 | grid_coords = get_grid_coords(
94 | [voxels.shape[0], voxels.shape[1], voxels.shape[2]], voxel_size
95 | ) + np.array(vox_origin, dtype=np.float32).reshape([1, 3])
96 |
97 | if mode == 0:
98 | grid_coords = np.vstack([grid_coords.T, voxels.reshape(-1)]).T
99 | elif mode == 1:
100 | indexes = grid[:, 0] * h * z + grid[:, 1] * z + grid[:, 2]
101 | indexes, pt_index = np.unique(indexes, return_index=True)
102 | pred_pts = pred_pts[pt_index]
103 | grid_coords = grid_coords[indexes]
104 | grid_coords = np.vstack([grid_coords.T, pred_pts.reshape(-1)]).T
105 | elif mode == 2:
106 | indexes = grid[:, 0] * h * z + grid[:, 1] * z + grid[:, 2]
107 | indexes, pt_index = np.unique(indexes, return_index=True)
108 | gt_label = pt_label[pt_index]
109 | grid_coords = grid_coords[indexes]
110 | grid_coords = np.vstack([grid_coords.T, gt_label.reshape(-1)]).T
111 | else:
112 | raise NotImplementedError
113 |
114 | # Get the voxels inside FOV
115 | fov_grid_coords = grid_coords
116 |
117 | # Remove empty and unknown voxels
118 | fov_voxels = fov_grid_coords[
119 | (fov_grid_coords[:, 3] > 0) & (fov_grid_coords[:, 3] < 17)
120 | ]
121 | print(len(fov_voxels))
122 |
123 |
124 | figure = mlab.figure(size=(2560, 1440), bgcolor=(1, 1, 1))
125 | voxel_size = sum(voxel_size) / 3
126 | plt_plot_fov = mlab.points3d(
127 | # fov_voxels[:, 1],
128 | # fov_voxels[:, 0],
129 | fov_voxels[:, 0],
130 | fov_voxels[:, 1],
131 | fov_voxels[:, 2],
132 | fov_voxels[:, 3],
133 | scale_factor=1.0 * voxel_size,
134 | mode="cube",
135 | opacity=1.0,
136 | vmin=1,
137 | vmax=16, # 16
138 | )
139 |
140 | plt_plot_fov.glyph.scale_mode = "scale_by_vector"
141 | plt_plot_fov.module_manager.scalar_lut_manager.lut.table = colors
142 | dst=os.path.join(save_dir, f'vis_{timestamp}.png')
143 | mlab.savefig(dst)
144 | mlab.close()
145 | # crop
146 | im3=mmcv.imread(dst)[:,550:-530,:]
147 | # im3=mmcv.imread(dst)[:,590:-600,230:-230]
148 | # im3=mmcv.imread(dst)[360:-230,590:-600]
149 | mmcv.imwrite(im3, dst)
150 | return dst
151 |
152 |
153 | def write_pc(pc,dst,c=None):
154 | # pc=pc
155 | import open3d as o3d
156 | pcd = o3d.geometry.PointCloud()
157 | pcd.points = o3d.utility.Vector3dVector(pc)
158 | if c is not None:
159 | pcd.colors = o3d.utility.Vector3dVector(c)
160 | o3d.io.write_point_cloud(dst, pcd)
161 |
162 | def merge_mesh(meshes):
163 | return reduce(lambda x,y:x+y, meshes)
164 |
165 |
166 | def get_pose_mesh(trans_mat,s=5):
167 |
168 | # Create a coordinate frame with x-axis (red), y-axis (green), and z-axis (blue)
169 | mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=s, origin=[0, 0, 0])
170 | mesh_frame.transform(trans_mat)
171 | # Save the coordinate frame to a file
172 | return mesh_frame
173 |
174 |
175 | def visualize_point_cloud(
176 | all_pred,
177 | abs_pose,
178 | abs_rot,
179 | vox_origin=[-40, -40, -1],
180 | resolution=0.4, #voxel size
181 | cmp_dir="./",
182 | key='gt'
183 | ):
184 | assert len(all_pred)==len(abs_pose)==len(abs_rot)
185 | all_occ,all_color=[],[]
186 | pose_mesh=[]
187 | for i,(occ,pose,rot) in enumerate(zip(all_pred,abs_pose,abs_rot)):
188 | occ=occ.reshape(-1)#.flatten()
189 | mask=(occ>=1)&(occ<16) # ignore GO
190 | cc=colors[occ[mask]-1][:,:3]/255.0 #[...,::-1]
191 |
192 | # occ_x,occ_y,occ_z=np.meshgrid(np.arange(200),np.arange(200),np.arange(16))
193 | # occ_x=occ_x.flatten()
194 | # occ_y=occ_y.flatten()
195 | # occ_z=occ_z.flatten()
196 | # occ_xyz=np.concatenate([occ_x[:,None],occ_y[:,None],occ_z[:,None]],axis=1)
197 | # occ_xyz=(occ_xyz * resolution) + resolution / 2 # to center
198 | # occ_xyz+=np.array([-40,-40,-1]) # to ego
199 | # Compute the voxels coordinates in ego frame
200 | occ_xyz = get_grid_coords(
201 | [200,200,16], [resolution]*3
202 | ) + np.array(vox_origin, dtype=np.float32).reshape([1, 3])
203 | write_pc(occ_xyz[mask],os.path.join(cmp_dir, f'vis_{key}_{i}_e.ply'),c=cc)
204 |
205 | # ego to world
206 | rot_m=Quaternion(rot).rotation_matrix[:3,:3]
207 | # rot_m=rr@rot_m
208 | trans_mat=np.eye(4)
209 | trans_mat[:3,:3]=rot_m
210 | trans_mat[:3,3]=pose
211 | rr=np.array([
212 | [0,1,0],
213 | [1,0,0],
214 | [0,0,1]
215 | ])
216 | occ_xyz=occ_xyz@rr.T
217 | occ_xyz=occ_xyz@rot_m.T +pose
218 |
219 | write_pc(occ_xyz[mask],os.path.join(cmp_dir, f'vis_{key}_{i}_w.ply'),c=cc)
220 |
221 | all_occ.append(occ_xyz[mask])
222 | all_color.append(cc)
223 | pose_mesh.append(get_pose_mesh(trans_mat))
224 |
225 |
226 | all_occ=np.concatenate(all_occ, axis=0)
227 | all_color=np.concatenate(all_color, axis=0)
228 |
229 | write_pc(all_occ,os.path.join(cmp_dir, f'vis_{key}_all_w.ply'),c=all_color)
230 | o3d.io.write_triangle_mesh(os.path.join(cmp_dir, f'vis_{key}_all_w_traj.ply'),merge_mesh(pose_mesh))
231 |
232 |
233 | def visualize_point_cloud_no_pose(
234 | all_pred,
235 | vox_origin=[-40, -40, -1],
236 | resolution=0.4, #voxel size
237 | cmp_dir="./",
238 | key='000000',
239 | key2='gt',
240 | offset=0,
241 | ):
242 | for i,occ in enumerate(all_pred):
243 | occ_d=occ.copy()
244 | occ=occ.reshape(-1)#.flatten()
245 | mask=(occ>=1)&(occ<16) # ignore GO
246 | cc=colors[occ[mask]-1][:,:3]/255.0 #[...,::-1]
247 |
248 | occ_xyz = get_grid_coords(
249 | [200,200,16], [resolution]*3
250 | ) + np.array(vox_origin, dtype=np.float32).reshape([1, 3])
251 | write_pc(occ_xyz[mask],os.path.join(cmp_dir, f'vis_{key}_{i+offset:02d}_e_{key2}.ply'),c=cc)
252 |
253 | np.save(os.path.join(cmp_dir, f'vis_{key}_{i+offset:02d}_e_{key2}.npy'),occ_d)
254 |
255 |
256 | if __name__=='__main__':
257 | # np.savez('/home/users/songen.gu/adwm/OccWorld/visualizations/aaaa.npz',input_occs0=input_occs0,input_occs=input_occs,metas0=metas0,metas=metas)
258 | # load
259 | data=np.load('/home/users/songen.gu/adwm/OccWorld/visualizations/aaaa.npz')
260 | input_occs0=data['input_occs0']
261 | input_occs=data['input_occs']
262 | dst_dir='/home/users/songen.gu/adwm/OccWorld/visualizations/abccc'
263 | os.makedirs(dst_dir,exist_ok=True)
264 | dst_wm=draw(input_occs0[10],
265 | None, # predict_pts,
266 | [-40, -40, -1],
267 | [0.4] * 3,
268 | None, # grid.squeeze(0).cpu().numpy(),
269 | None,# pt_label.squeeze(-1),
270 | dst_dir,#recon_dir,
271 | None, # img_metas[0]['cam_positions'],
272 | None, # img_metas[0]['focal_positions'],
273 | timestamp=10,
274 | mode=0,
275 | sem=False)
276 | dst_dir='/home/users/songen.gu/adwm/OccWorld/visualizations/abcc2'
277 | os.makedirs(dst_dir,exist_ok=True)
278 | dst_wm=draw(input_occs[10],
279 | None, # predict_pts,
280 | [-40, -40, -1],
281 | [0.4] * 3,
282 | None, # grid.squeeze(0).cpu().numpy(),
283 | None,# pt_label.squeeze(-1),
284 | dst_dir,#recon_dir,
285 | None, # img_metas[0]['cam_positions'],
286 | None, # img_metas[0]['focal_positions'],
287 | timestamp=10,
288 | mode=0,
289 | sem=False)
--------------------------------------------------------------------------------
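Since get_grid_coords returns voxel centers, the [200, 200, 16] grid with 0.4 m voxels and the vox_origin [-40, -40, -1] used throughout places the first center at (-39.8, -39.8, -0.8) and the last at (39.8, 39.8, 5.2). A quick self-contained check of that arithmetic:

import numpy as np

# center of voxel (i, j, k): index * resolution + resolution / 2, shifted by vox_origin
res, origin = 0.4, np.array([-40, -40, -1], dtype=np.float32)
first = np.array([0, 0, 0]) * res + res / 2 + origin      # [-39.8 -39.8 -0.8]
last = np.array([199, 199, 15]) * res + res / 2 + origin  # [ 39.8  39.8   5.2]
print(first, last)
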
/tools/vis_vae.sh:
--------------------------------------------------------------------------------
1 | # export CUDA_VISIBLE_DEVICES=5
2 |
3 | cfg=./config/train_occvae.py
4 | dir=./work_dir/occ_vae
5 |
6 | vae_ckpt=ckpts/occvae_latest.pth
7 |
8 | python tools/visualize_demo_vae.py \
9 | --py-config $cfg \
10 | --work-dir $dir \
11 | --resume-from $vae_ckpt \
12 | --export_pcd
13 |
14 |
--------------------------------------------------------------------------------
/tools/visualize_demo.py:
--------------------------------------------------------------------------------
1 | from pyvirtualdisplay import Display
2 | display = Display(visible=False, size=(2560, 1440))
3 | display.start()
4 | from mayavi import mlab
5 | import mayavi
6 | mlab.options.offscreen = True
7 | print("Set mlab.options.offscreen={}".format(mlab.options.offscreen))
8 |
9 | import pdb
10 | import time, argparse, os.path as osp, os
11 | import torch, numpy as np
12 | import sys
13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14 |
15 |
16 | import mmcv
17 | from mmengine import Config
18 | from mmengine.runner import set_random_seed
19 | from mmengine.logging import MMLogger
20 | from mmengine.registry import MODELS
21 | import cv2
22 | from vis_gif import create_mp4
23 | import warnings
24 | warnings.filterwarnings("ignore")
25 | from einops import rearrange
26 | from diffusion import create_diffusion
27 | from vis_utils import draw
28 |
29 |
30 |
31 | def main(args):
32 | # global settings
33 | set_random_seed(args.seed)
34 | torch.backends.cudnn.deterministic = False
35 | torch.backends.cudnn.benchmark = True
36 | # load config
37 | cfg = Config.fromfile(args.py_config)
38 | cfg.work_dir = args.work_dir
39 |
40 | os.makedirs(args.work_dir, exist_ok=True)
41 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
42 | args.dir_name=f'{args.dir_name}_{timestamp}'
43 |
44 | log_file = osp.join(args.work_dir, f'{cfg.get("data_type", "gts")}_visualize_{timestamp}.log')
45 | logger = MMLogger('genocc', log_file=log_file)
46 | MMLogger._instance_dict['genocc'] = logger
47 | logger.info(f'Config:\n{cfg.pretty_text}')
48 |
49 | # build model
50 | import model
51 | my_model = MODELS.build(cfg.model.world_model)
52 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
53 | logger.info(f'Number of params: {n_parameters}')
54 | my_model = my_model.cuda()
55 | raw_model = my_model
56 | vae=MODELS.build(cfg.model.vae).cuda()
57 |
58 | vae.requires_grad_(False)
59 | vae.eval()
60 |
61 | logger.info('done ddp model')
62 | from dataset import get_dataloader
63 | cfg.val_dataset_config.test_mode=True
64 | cfg.val_loader.num_workers=0
65 | cfg.train_loader.num_workers=0
66 |
67 | # cfg.val_dataset_config.new_rel_pose=False ## TODO
68 | # cfg.train_dataset_config.test_index_offset=args.test_index_offset
69 | cfg.val_dataset_config.test_index_offset=args.test_index_offset
70 | if args.return_len is not None:
71 | cfg.train_dataset_config.return_len=max(cfg.train_dataset_config.return_len,args.return_len)
72 | cfg.val_dataset_config.return_len=max(cfg.val_dataset_config.return_len,args.return_len)
73 | # cfg.val_dataset_config.return_len=60
74 |
75 | train_dataset_loader, val_dataset_loader = get_dataloader(
76 | cfg.train_dataset_config,
77 | cfg.val_dataset_config,
78 | cfg.train_wrapper_config,
79 | cfg.val_wrapper_config,
80 | cfg.train_loader,
81 | cfg.val_loader,
82 | dist=False)
83 | cfg.resume_from = ''
84 | if osp.exists(osp.join(args.work_dir, 'latest.pth')):
85 | cfg.resume_from = osp.join(args.work_dir, 'latest.pth')
86 | else:
87 | ckpts=[i for i in os.listdir(args.work_dir) if
88 | i.endswith('.pth') and i.replace('.pth','').replace('epoch_','').isdigit()]
89 | if len(ckpts)>0:
90 | ckpts.sort(key=lambda x:int(x.replace('.pth','').replace('epoch_','')))
91 | cfg.resume_from = osp.join(args.work_dir, ckpts[-1])
92 |
93 | if args.resume_from:
94 | cfg.resume_from = args.resume_from
95 | if args.vae_resume_from:
96 | cfg.vae_load_from=args.vae_resume_from
97 | logger.info('resume from: ' + cfg.resume_from)
98 | logger.info('vae resume from: ' + cfg.vae_load_from)
99 | logger.info('work dir: ' + args.work_dir)
100 |
101 | epoch = 'last'
102 | if cfg.resume_from and osp.exists(cfg.resume_from):
103 | map_location = 'cpu'
104 | ckpt = torch.load(cfg.resume_from, map_location=map_location)
105 | print(raw_model.load_state_dict(ckpt['state_dict'], strict=False))
106 | epoch = ckpt['epoch']
107 | print(f'successfully resumed from epoch {epoch}')
108 | elif cfg.load_from:
109 | ckpt = torch.load(cfg.load_from, map_location='cpu')
110 | if 'state_dict' in ckpt:
111 | state_dict = ckpt['state_dict']
112 | else:
113 | state_dict = ckpt
114 | print(raw_model.load_state_dict(state_dict, strict=False))
115 | print(vae.load_state_dict(torch.load(cfg.vae_load_from)['state_dict']))
116 |
117 | # eval
118 | my_model.eval()
119 | os.environ['eval'] = 'true'
120 | recon_dir = os.path.join(args.work_dir, args.dir_name)
121 | os.makedirs(recon_dir, exist_ok=True)
122 | os.environ['recon_dir']=recon_dir
123 |
124 | diffusion = create_diffusion(
125 | # timestep_respacing=str(cfg.sample.num_sampling_steps),
126 | timestep_respacing=str(args.num_sampling_steps),
127 | beta_start=cfg.schedule.beta_start,
128 | beta_end=cfg.schedule.beta_end,
129 | replace_cond_frames=cfg.replace_cond_frames,
130 | cond_frames_choices=cfg.cond_frames_choices,
131 | predict_xstart=cfg.schedule.get('predict_xstart',False),
132 | )
133 | if args.pose_control:
134 | cfg.sample.n_conds=1
135 | print(len(val_dataset_loader))
136 | with torch.no_grad():
137 | for i_iter_val, (input_occs, _, metas) in enumerate(val_dataset_loader):
138 | if i_iter_val > max(args.scene_idx):
139 |     break
140 | if i_iter_val not in args.scene_idx:
141 |     continue
142 | start_frame=cfg.get('start_frame', 0)
143 | mid_frame=cfg.get('mid_frame', 3)
144 | # end_frame=cfg.get('end_frame', 9)
145 | end_frame=input_occs.shape[1] if args.end_frame is None else args.end_frame
146 |
147 | if args.pose_control:
148 | # start_frame=0
149 | mid_frame=1
150 | # end_frame=10
151 | assert cfg.sample.n_conds==mid_frame
152 | # __import__('ipdb').set_trace()
153 | input_occs = input_occs.cuda() #torch.Size([1, 16, 200, 200, 16])
154 | bs,f,_,_,_=input_occs.shape
155 | encoded_latent, shape=vae.forward_encoder(input_occs)
156 | encoded_latent,_,_=vae.sample_z(encoded_latent) #bchw
157 | # encoded_latent = self.vae.vqvae.quant_conv(encoded_latent)
158 | # encoded_latent, _,_ = vae.vqvae(encoded_latent, is_voxel=False)
159 | input_latents=encoded_latent*cfg.model.vae.scaling_factor
160 | if input_latents.dim()==4:
161 | input_latents = rearrange(input_latents, '(b f) c h w -> b f c h w', b=bs).contiguous()
162 | elif input_latents.dim()==5:
163 | input_latents = rearrange(input_latents, 'b c f h w -> b f c h w', b=bs).contiguous()
164 | else:
165 | raise NotImplementedError
166 |
167 |
168 | # from debug_vis import visualize_tensor_pca
169 | # TODO fix dim bug torch.Size([1, 64, 12, 25, 25])
170 | # visualize_tensor_pca(encoded_latent.permute(0,2,3,1).cpu(), save_dir=recon_dir+'/debug_feature', filename=f'vis_vae_encode_{i_iter_val}.png')
171 | os.environ.update({'i_iter_val': str(i_iter_val)})
172 | os.environ.update({'recon_dir': str(recon_dir)})
173 |
174 |
175 | # gaussian diffusion pipeline
176 | w=h=cfg.model.vae.encoder_cfg.resolution
177 | vae_scale_factor = 2 ** (len(cfg.model.vae.encoder_cfg.ch_mult) - 1)
178 |             vae_decoder_shapes = cfg.shapes[:len(cfg.model.vae.encoder_cfg.ch_mult) - 1]
179 | w//=vae_scale_factor
180 | h//=vae_scale_factor
181 |
182 |             # extra kwargs forwarded to the denoising model at every sampling step
183 |             model_kwargs = dict()
184 |             if args.pose or args.pose_control:
185 |                 # pose-conditioned sampling: pass the ego poses through metas
186 |                 model_kwargs['metas'] = metas
187 |
188 |             noise_shape = (bs, end_frame, cfg.base_channel, w, h)
189 |             # indices of frames clamped to the GT latents during sampling
190 |             initial_cond_indices = None
191 |             n_conds = cfg.sample.get('n_conds', 0)
192 |             if n_conds:
193 |                 initial_cond_indices = list(range(n_conds))
194 |
195 | # Sample images:
196 | if cfg.sample.sample_method == 'ddim':
197 | latents = diffusion.ddim_sample_loop(
198 | my_model, noise_shape, None, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device='cuda'
199 | )
200 | elif cfg.sample.sample_method == 'ddpm':
201 |                 if args.rolling_sampling_n < 2:
202 |                     # single-shot sampling over the whole clip
203 | latents = diffusion.p_sample_loop(
204 | my_model, noise_shape, None, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device='cuda',
205 | initial_cond_indices=initial_cond_indices,
206 | initial_cond_frames=input_latents,
207 | )
208 |             else:
209 |                 # autoregressive rollout: re-condition each window on previously generated frames
210 |                 latents = diffusion.p_sample_loop_cond_rollout(
211 |                     my_model, noise_shape, None, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device='cuda',
212 | input_latents=input_latents,
213 | rolling_sampling_n=args.rolling_sampling_n,
214 | n_conds=n_conds,
215 | n_conds_roll=args.n_conds_roll
216 | )
217 |             end_frame = latents.shape[1]
218 |             # undo the training-time latent scaling before decoding
219 |             latents = latents / cfg.model.vae.scaling_factor
220 | if cfg.model.vae.decoder_cfg.type=='Decoder3D':
221 | latents = rearrange(latents,'b f c h w-> b c f h w')
222 |             else:
223 |                 # a 2D decoder consumes frames as a flat batch
224 |                 latents = rearrange(latents, 'b f c h w -> (b f) c h w')
225 |
226 |             logits = vae.forward_decoder(
227 |                 latents, shapes=vae_decoder_shapes, input_shape=[bs, end_frame, *cfg.shapes[0], cfg._dim_]
228 |             )
229 | dst_dir = os.path.join(recon_dir, str(i_iter_val),'pred')
230 | input_dir = os.path.join(recon_dir, f'{i_iter_val}','input')
231 |
232 | os.makedirs(dst_dir, exist_ok=True)
233 | os.makedirs(input_dir, exist_ok=True)
234 |
235 |
236 |             # visualize the ego trajectory used for conditioning
237 |             import matplotlib.pyplot as plt
238 |             plt.clf()
239 |             plt.plot(metas[0]['rel_poses'][:, 0], metas[0]['rel_poses'][:, 1], marker='o', alpha=0.5)
240 |             plt.savefig(os.path.join(dst_dir, 'pose.png'))
241 |             plt.clf()
242 |
243 |             # trajectory in world coordinates, relative to the first frame
244 |             plt.plot(metas[0]['e2g_rel0_t'][:, 0], metas[0]['e2g_rel0_t'][:, 1])
245 |             plt.scatter([0], [0], c='r')
246 |             plt.annotate('start', xy=(0, 0), textcoords="offset points", xytext=(0, 10), ha='center')
247 |             plt.savefig(os.path.join(dst_dir, 'pose_w.png'))
248 |
256 | all_pred=[]
257 |             for frame in range(start_frame, end_frame):
258 |                 # tag each dump with the scene index and (optionally offset) frame index
259 |                 tt = str(i_iter_val) + '_' + str(frame + args.test_index_offset)
260 |
261 |                 # render the GT frame when it exists and GT rendering is not skipped
262 |                 if frame < input_occs.shape[1] and not args.skip_gt:
267 | input_occ = input_occs[:, frame, ...].squeeze().cpu().numpy()
268 | draw(input_occ,
269 | None, # predict_pts,
270 | [-40, -40, -1],
271 | [0.4] * 3,
272 | None, # grid.squeeze(0).cpu().numpy(),
273 | None,# pt_label.squeeze(-1),
274 | input_dir,#recon_dir,
275 | None, # img_metas[0]['cam_positions'],
276 | None, # img_metas[0]['focal_positions'],
277 | timestamp=tt,
278 | mode=0,
279 | sem=False,
280 | show_ego=args.show_ego)
281 |                 # render the predicted frame
282 |                 if True:
283 | logit = logits[:, frame, ...]
284 | pred = logit.argmax(dim=-1).squeeze().cpu().numpy() # 1, 1, 200, 200, 16
285 |                     all_pred.append(pred)
286 |
287 |
288 |                     # dump the prediction using the same renderer as the GT frames
289 | draw(pred,
290 | None, # predict_pts,
291 | [-40, -40, -1],
292 | [0.4] * 3,
293 | None, # grid.squeeze(0).cpu().numpy(),
294 | None,# pt_label.squeeze(-1),
295 | dst_dir,#recon_dir,
296 | None, # img_metas[0]['cam_positions'],
297 | None, # img_metas[0]['focal_positions'],
298 | timestamp=tt,
299 | mode=0,
300 | sem=False,
301 | show_ego=args.show_ego)
302 | logger.info('[EVAL] Iter %5d / %5d'%(i_iter_val, len(val_dataset_loader)))
303 | create_mp4(dst_dir)
304 |
305 | if args.export_pcd:
306 | from vis_utils import visualize_point_cloud
307 |
308 | abs_pose=metas[0]['e2g_t']
309 | abs_rot=metas[0]['e2g_r']
310 | n_gt=min(len(all_pred),len(abs_pose))
311 | visualize_point_cloud(all_pred[:n_gt],abs_pose=abs_pose[:n_gt],abs_rot=abs_rot[:n_gt],cmp_dir=dst_dir,key='pred')
312 |
313 |
314 | if __name__ == '__main__':
315 | # Eval settings
316 | parser = argparse.ArgumentParser(description='')
317 | parser.add_argument('--py-config', default='config/tpv_lidarseg.py')
318 | parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg')
319 | parser.add_argument('--resume-from', type=str, default='')
320 | parser.add_argument('--vae-resume-from', type=str, default='')
321 | parser.add_argument('--dir-name', type=str, default='vis')
322 | parser.add_argument('--num_sampling_steps', type=int, default=20)
323 | parser.add_argument('--seed', type=int, default=42)
324 | parser.add_argument('--end_frame', type=int, default=None)
325 | parser.add_argument('--n_conds_roll', type=int, default=None)
326 | parser.add_argument('--return_len', type=int, default=None)
327 | parser.add_argument('--num-trials', type=int, default=10)
328 | parser.add_argument('--frame-idx', nargs='+', type=int, default=[0, 10])
329 | #########################################
330 |     # more scenes, e.g.: --scene-idx 6 7 16 18 19 87 89 96 101
331 |     parser.add_argument('--scene-idx', nargs='+', type=int, default=[6, 7])
332 | parser.add_argument('--rolling_sampling_n', type=int, default=1)
333 | parser.add_argument('--pose_control', action='store_true', default=False)
334 | parser.add_argument('--pose', action='store_true', default=True, help='Enable pose (default is True)')
335 | parser.add_argument('--no-pose', action='store_false', dest='pose', help='Disable pose')
336 | parser.add_argument('--test_index_offset',type=int, default=0)
337 | parser.add_argument('--ts',type=str, default=None)
338 |     parser.add_argument('--skip_gt', action='store_true', default=False, help='Skip rendering the ground-truth frames')
339 |     parser.add_argument('--show_ego', action='store_true', default=False, help='Draw the ego vehicle in the rendered frames')
340 |     parser.add_argument('--export_pcd', action='store_true', default=False, help='Export an aggregated point cloud of the predictions')
341 |
342 | args = parser.parse_args()
343 |
344 | ngpus = 1
345 | args.gpus = ngpus
346 | print(args)
347 | main(args)
348 |
349 |
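350 |
351 | # Example invocation (illustrative; the config and checkpoint paths below are
352 | # placeholders, substitute your own):
353 | #   python tools/visualize_demo.py --py-config config/your_dome_config.py \
354 | #       --work-dir out/dome --resume-from ckpts/dome_latest.pth \
355 | #       --vae-resume-from ckpts/occvae_latest.pth --scene-idx 6 7 --num_sampling_steps 20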
--------------------------------------------------------------------------------
/tools/visualize_demo_vae.py:
--------------------------------------------------------------------------------
1 | from pyvirtualdisplay import Display
2 | display = Display(visible=False, size=(2560, 1440))
3 | display.start()
4 |
5 | from mayavi import mlab
6 | import mayavi
7 | mlab.options.offscreen = True
8 | print("Set mlab.options.offscreen={}".format(mlab.options.offscreen))
9 |
10 | import time, argparse, os.path as osp, os
11 | import sys
12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
13 | import torch, numpy as np
14 | import mmcv
15 | from mmengine import Config
16 | from mmengine.runner import set_random_seed
17 | from mmengine.logging import MMLogger
18 | from mmengine.registry import MODELS
19 | import cv2
20 | from vis_gif import create_mp4
21 | import warnings
22 | warnings.filterwarnings("ignore")
23 | from vis_utils import draw
24 |
25 |
26 |
27 | def main(args):
28 | # global settings
29 | set_random_seed(args.seed)
30 | torch.backends.cudnn.deterministic = False
31 | torch.backends.cudnn.benchmark = True
32 | # load config
33 | cfg = Config.fromfile(args.py_config)
34 | cfg.work_dir = args.work_dir
35 |
36 | os.makedirs(args.work_dir, exist_ok=True)
37 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
38 | args.dir_name=f'{args.dir_name}_{timestamp}'
39 |
40 | log_file = osp.join(args.work_dir, f'{cfg.get("data_type", "gts")}_visualize_{timestamp}.log')
41 | logger = MMLogger('genocc', log_file=log_file)
42 | MMLogger._instance_dict['genocc'] = logger
43 | logger.info(f'Config:\n{cfg.pretty_text}')
44 |
45 | # build model
46 | import model
47 | my_model = MODELS.build(cfg.model)
48 | n_parameters = sum(p.numel() for p in my_model.parameters() if p.requires_grad)
49 | logger.info(f'Number of params: {n_parameters}')
50 | my_model = my_model.cuda()
51 | raw_model = my_model
52 | logger.info('done ddp model')
53 | from dataset import get_dataloader
54 | train_dataset_loader, val_dataset_loader = get_dataloader(
55 | cfg.train_dataset_config,
56 | cfg.val_dataset_config,
57 | cfg.train_wrapper_config,
58 | cfg.val_wrapper_config,
59 | cfg.train_loader,
60 | cfg.val_loader,
61 | dist=False)
62 | cfg.resume_from = ''
63 | if osp.exists(osp.join(args.work_dir, 'latest.pth')):
64 | cfg.resume_from = osp.join(args.work_dir, 'latest.pth')
65 | else:
66 | ckpts=[i for i in os.listdir(args.work_dir) if
67 | i.endswith('.pth') and i.replace('.pth','').replace('epoch_','').isdigit()]
68 | if len(ckpts)>0:
69 | ckpts.sort(key=lambda x:int(x.replace('.pth','').replace('epoch_','')))
70 | cfg.resume_from = osp.join(args.work_dir, ckpts[-1])
71 |
72 | if args.resume_from:
73 | cfg.resume_from = args.resume_from
74 | logger.info('resume from: ' + cfg.resume_from)
75 | logger.info('work dir: ' + args.work_dir)
76 |
77 | epoch = 'last'
78 | if cfg.resume_from and osp.exists(cfg.resume_from):
79 | map_location = 'cpu'
80 | ckpt = torch.load(cfg.resume_from, map_location=map_location)
81 | print(raw_model.load_state_dict(ckpt['state_dict'], strict=False))
82 | epoch = ckpt['epoch']
83 | print(f'successfully resumed from epoch {epoch}')
84 | elif cfg.load_from:
85 | ckpt = torch.load(cfg.load_from, map_location='cpu')
86 | if 'state_dict' in ckpt:
87 | state_dict = ckpt['state_dict']
88 | else:
89 | state_dict = ckpt
90 | print(raw_model.load_state_dict(state_dict, strict=False))
91 |
92 | # eval
93 | my_model.eval()
94 | os.environ['eval'] = 'true'
95 | recon_dir = os.path.join(args.work_dir, args.dir_name)
96 | os.makedirs(recon_dir, exist_ok=True)
97 | with torch.no_grad():
98 | for i_iter_val, (input_occs, _, metas) in enumerate(val_dataset_loader):
99 |             if i_iter_val > max(args.scene_idx):
100 |                 break
101 |             if i_iter_val not in args.scene_idx:
102 |                 continue
103 | input_occs = input_occs.cuda() #torch.Size([1, 16, 200, 200, 16])
104 | result = my_model(x=input_occs, metas=metas)
105 | start_frame=cfg.get('start_frame', 0)
106 |             # reconstruct every frame in the clip
107 | end_frame=input_occs.shape[1]
108 | logits = result['logits'] #torch.Size([1, 6, 200, 200, 16, 18])
109 | dst_dir = os.path.join(recon_dir, str(i_iter_val),'pred')
110 | input_dir = os.path.join(recon_dir, f'{i_iter_val}','input')
111 | cmp_dir = os.path.join(recon_dir, f'{i_iter_val}','cmp')
112 |
113 | os.makedirs(dst_dir, exist_ok=True)
114 | os.makedirs(input_dir, exist_ok=True)
115 | os.makedirs(cmp_dir, exist_ok=True)
116 | all_pred=[]
117 | for frame in range(start_frame,end_frame):
118 | tt=str(i_iter_val) + '_' + str(frame)
119 | input_occ = input_occs[:, frame, ...].squeeze().cpu().numpy()
120 | dst_input=draw(input_occ,
121 | None, # predict_pts,
122 | [-40, -40, -1],
123 | [0.4] * 3,
124 | None, # grid.squeeze(0).cpu().numpy(),
125 | None,# pt_label.squeeze(-1),
126 | input_dir,#recon_dir,
127 | None, # img_metas[0]['cam_positions'],
128 | None, # img_metas[0]['focal_positions'],
129 | timestamp=tt,
130 | mode=0,
131 | sem=False)
132 | im=mmcv.imread(dst_input)
133 | cv2.putText(im, f'GT_{frame:02d}', (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 0, 255), 2)
134 | mmcv.imwrite(im,dst_input)
135 |                 if True:  # render the reconstructed frame
136 | logit = logits[:, frame, ...]
137 | pred = logit.argmax(dim=-1).squeeze().cpu().numpy() # 1, 1, 200, 200, 16
138 |                     all_pred.append(pred)
139 |
140 | dst_wm=draw(pred,
141 | None, # predict_pts,
142 | [-40, -40, -1],
143 | [0.4] * 3,
144 | None, # grid.squeeze(0).cpu().numpy(),
145 | None,# pt_label.squeeze(-1),
146 | dst_dir,#recon_dir,
147 | None, # img_metas[0]['cam_positions'],
148 | None, # img_metas[0]['focal_positions'],
149 | timestamp=tt,
150 | mode=0,
151 | sem=False)
152 | im=mmcv.imread(dst_wm)
153 | cv2.putText(im, f'predict_{frame:02d}', (20, 100), cv2.FONT_HERSHEY_SIMPLEX, 3, (0, 0, 255), 2)
154 | mmcv.imwrite(im,dst_wm)
155 |                     # concatenate GT and reconstruction side by side for comparison
156 |
157 |
158 |                     im1 = mmcv.imread(dst_input)
159 |                     im2 = mmcv.imread(dst_wm)
160 | mmcv.imwrite(np.concatenate([im1, im2], axis=1), os.path.join(cmp_dir, f'vis_{tt}.png'))
161 | logger.info('[EVAL] Iter %5d / %5d'%(i_iter_val, len(val_dataset_loader)))
162 | create_mp4(dst_dir)
163 |
164 | if args.export_pcd:
165 | from vis_utils import visualize_point_cloud
166 |
167 | abs_pose=metas[0]['e2g_t']
168 | abs_rot=metas[0]['e2g_r']
169 | n_gt=min(len(all_pred),len(abs_pose))
170 | visualize_point_cloud(all_pred[:n_gt],abs_pose=abs_pose[:n_gt],abs_rot=abs_rot[:n_gt],cmp_dir=dst_dir,key='pred')
171 |
172 |
173 |
174 |
175 | if __name__ == '__main__':
176 | # Eval settings
177 | parser = argparse.ArgumentParser(description='')
178 | parser.add_argument('--py-config', default='config/tpv_lidarseg.py')
179 | parser.add_argument('--work-dir', type=str, default='./out/tpv_lidarseg')
180 | parser.add_argument('--resume-from', type=str, default='')
181 | parser.add_argument('--dir-name', type=str, default='vis')
182 |     parser.add_argument('--seed', type=int, default=42)
183 | parser.add_argument('--num-trials', type=int, default=10)
184 | parser.add_argument('--frame-idx', nargs='+', type=int, default=[0, 10])
185 |     parser.add_argument('--scene-idx', nargs='+', type=int, default=[6])  # e.g. 6 7 16 18 19 87 89 96 101
186 |     parser.add_argument('--export_pcd', action='store_true', default=False, help='Export an aggregated point cloud of the reconstructions')
187 |
188 | args = parser.parse_args()
189 |
190 | ngpus = 1
191 | args.gpus = ngpus
192 | print(args)
193 | main(args)
194 |
195 |
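196 | # Example invocation (illustrative; the config and checkpoint paths are placeholders):
197 | #   python tools/visualize_demo_vae.py --py-config config/your_occvae_config.py \
198 | #       --work-dir out/occvae --resume-from ckpts/occvae_latest.pth --scene-idx 6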
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 |
4 | @torch.no_grad()
5 | def update_ema(ema_model, model, decay=0.9999):
6 | """
7 | Step the EMA model towards the current model.
8 | """
9 | ema_params = OrderedDict(ema_model.named_parameters())
10 | model_params = OrderedDict(model.named_parameters())
11 |
12 | for name, param in model_params.items():
13 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
14 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
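15 |
16 | if __name__ == '__main__':
17 |     # minimal sanity check (illustrative, not used by the training scripts):
18 |     # after one EMA step with decay=0.9, the EMA weight should lag 90% of the
19 |     # distance behind the live weight
20 |     from copy import deepcopy
21 |     model = torch.nn.Linear(4, 4)
22 |     ema_model = deepcopy(model)
23 |     model.weight.data.add_(1.0)  # pretend an optimizer step moved the weights
24 |     update_ema(ema_model, model, decay=0.9)
25 |     print((ema_model.weight - model.weight).abs().max().item())  # 0.9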
--------------------------------------------------------------------------------
/utils/freeze_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mmengine import MMLogger
3 | logger = MMLogger.get_instance('genocc')
4 | import torch.distributed as dist
5 |
6 | def freeze_model(model, freeze_dict):
7 | # given a model and a dictionary of booleans, freeze the model
8 | # according to the dictionary
9 | for key in freeze_dict:
10 | if freeze_dict[key]:
11 | for param in getattr(model, key).parameters():
12 | param.requires_grad = False
13 | logger = MMLogger.get_current_instance()
14 |             logger.info(f'Froze {key} parameters')
15 |
16 | if __name__ == '__main__':
17 | model = torch.nn.Sequential(
18 | torch.nn.Linear(1, 1),
19 | torch.nn.Linear(1, 1),
20 | torch.nn.Linear(1, 1)
21 | )
22 | print(model)
23 | freeze_dict = {'0': True, '1': False, '2': True}
24 | freeze_model(model, freeze_dict)
25 |     print([p.requires_grad for p in model.parameters()])  # expect [False, False, True, True, False, False]
--------------------------------------------------------------------------------
/utils/load_save_util.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | import torch
3 |
4 | def revise_ckpt(state_dict):
5 | tmp_k = list(state_dict.keys())[0]
6 | if not tmp_k.startswith('module.'):
7 | state_dict = OrderedDict(
8 | {('module.' + k): v
9 | for k, v in state_dict.items()})
10 | return state_dict
11 |
12 | def revise_ckpt_1(state_dict):
13 | tmp_k = list(state_dict.keys())[0]
14 | if tmp_k.startswith('module.'):
15 | state_dict = OrderedDict(
16 | {k[7:]: v
17 | for k, v in state_dict.items()})
18 | return state_dict
19 |
20 | def revise_ckpt_2(state_dict):
21 | param_names = list(state_dict.keys())
22 | for param_name in param_names:
23 | if 'img_neck.lateral_convs' in param_name or 'img_neck.fpn_convs' in param_name:
24 | del state_dict[param_name]
25 | return state_dict
26 |
27 | def load_checkpoint(raw_model, state_dict, strict=True) -> None:
28 | # state_dict = checkpoint["state_dict"]
29 | model_state_dict = raw_model.state_dict()
30 | is_changed = False
31 | for k in state_dict:
32 | if k in model_state_dict:
33 | if state_dict[k].shape != model_state_dict[k].shape:
34 | # process embedding
35 | if k=='class_embeds.weight':
36 | l1=state_dict[k].shape[0]
37 | l2=model_state_dict[k].shape[0]
38 | if l1>l2:
39 | state_dict[k] = state_dict[k][:l2]
40 | else:
41 | state_dict_k_new=torch.zeros_like(model_state_dict[k])
42 | state_dict_k_new[:l1]=state_dict[k]
43 | state_dict[k]=state_dict_k_new
44 | else:
45 | print(f"Skip loading parameter: {k}, "
46 | f"required shape: {model_state_dict[k].shape}, "
47 | f"loaded shape: {state_dict[k].shape}")
48 | state_dict[k] = model_state_dict[k]
49 | is_changed = True
50 |
51 | else:
52 | print(f"Dropping parameter {k}")
53 | is_changed = True
54 |
55 | # if is_changed:
56 | # checkpoint.pop("optimizer_states", None)
57 |
58 | print(raw_model.load_state_dict(state_dict, strict=strict))
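59 |
60 | if __name__ == '__main__':
61 |     # illustrative round-trip: strip a DDP 'module.' prefix from a checkpoint
62 |     # before loading it into a bare (non-DDP) model
63 |     model = torch.nn.Linear(2, 2)
64 |     wrapped = {('module.' + k): v for k, v in model.state_dict().items()}
65 |     load_checkpoint(model, revise_ckpt_1(wrapped), strict=True)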
--------------------------------------------------------------------------------
/utils/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from mmengine import MMLogger
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 |
8 | class MeanIoU:
9 |
10 | def __init__(self,
11 | class_indices,
12 | ignore_label: int,
13 | label_str,
14 | name
15 | # empty_class: int
16 | ):
17 | self.class_indices = class_indices
18 | self.num_classes = len(class_indices)
19 | self.ignore_label = ignore_label
20 | self.label_str = label_str
21 | self.name = name
22 |
23 | def reset(self) -> None:
24 | self.total_seen = torch.zeros(self.num_classes).cuda()
25 | self.total_correct = torch.zeros(self.num_classes).cuda()
26 | self.total_positive = torch.zeros(self.num_classes).cuda()
27 |
28 | def _after_step(self, outputs, targets, log_current=False):
29 | outputs = outputs[targets != self.ignore_label]
30 | targets = targets[targets != self.ignore_label]
31 |
32 | for i, c in enumerate(self.class_indices):
33 | self.total_seen[i] += torch.sum(targets == c).item()
34 | self.total_correct[i] += torch.sum((targets == c)
35 | & (outputs == c)).item()
36 | self.total_positive[i] += torch.sum(outputs == c).item()
37 |
38 | def _after_epoch(self):
39 | if dist.is_initialized():
40 | dist.all_reduce(self.total_seen)
41 | dist.all_reduce(self.total_correct)
42 | dist.all_reduce(self.total_positive)
43 |
44 | ious = []
45 |
46 | for i in range(self.num_classes):
47 | if self.total_seen[i] == 0:
48 | ious.append(1)
49 | else:
50 | cur_iou = self.total_correct[i] / (self.total_seen[i]
51 | + self.total_positive[i]
52 | - self.total_correct[i])
53 | ious.append(cur_iou.item())
54 |
55 | miou = np.mean(ious)
56 | logger = MMLogger.get_current_instance()
57 | logger.info(f'Validation per class iou {self.name}:')
58 | for iou, label_str in zip(ious, self.label_str):
59 | logger.info('%s : %.2f%%' % (label_str, iou * 100))
60 |
61 | return miou * 100
62 |
63 |
64 | class multi_step_MeanIou:
65 | def __init__(self,
66 | class_indices,
67 | ignore_label: int,
68 | label_str,
69 | name,
70 | times=1):
71 | self.class_indices = class_indices
72 | self.num_classes = len(class_indices)
73 | self.ignore_label = ignore_label
74 | self.label_str = label_str
75 | self.name = name
76 | self.times = times
77 |
78 | def reset(self) -> None:
79 | self.total_seen = torch.zeros(self.times, self.num_classes).cuda()
80 | self.total_correct = torch.zeros(self.times, self.num_classes).cuda()
81 | self.total_positive = torch.zeros(self.times, self.num_classes).cuda()
82 | self.current_seen = torch.zeros(self.times, self.num_classes).cuda()
83 | self.current_correct = torch.zeros(self.times, self.num_classes).cuda()
84 | self.current_positive = torch.zeros(self.times, self.num_classes).cuda()
85 |
86 |     def _after_step(self, outputses, targetses, log_current=False):
87 |
88 | assert outputses.shape[1] == self.times, f'{outputses.shape[1]} != {self.times}'
89 | assert targetses.shape[1] == self.times, f'{targetses.shape[1]} != {self.times}'
90 | mious = []
91 | for t in range(self.times):
92 | ious = []
93 | outputs = outputses[:,t, ...][targetses[:,t, ...] != self.ignore_label].cuda()
94 | targets = targetses[:,t, ...][targetses[:,t, ...] != self.ignore_label].cuda()
95 | for j, c in enumerate(self.class_indices):
96 | self.total_seen[t, j] += torch.sum(targets == c).item()
97 | self.total_correct[t, j] += torch.sum((targets == c)
98 | & (outputs == c)).item()
99 | self.total_positive[t, j] += torch.sum(outputs == c).item()
100 |                 # per-step counts, kept separate from the running totals above
101 |                 current_seen = torch.sum(targets == c).item()
102 |                 current_correct = torch.sum((targets == c) & (outputs == c)).item()
103 |                 current_positive = torch.sum(outputs == c).item()
104 |                 if current_seen == 0:
105 |                     ious.append(1)
106 |                 else:
107 |                     cur_iou = current_correct / (current_seen + current_positive - current_correct)
108 |                     ious.append(cur_iou)
109 |             miou = np.mean(ious)
110 |             if log_current:
111 |                 logger = MMLogger.get_current_instance()
112 |                 logger.info(f'current:: per class iou {self.name} at time {t}:')
113 |                 for iou, label_str in zip(ious, self.label_str):
114 |                     logger.info('%s : %.2f%%' % (label_str, iou * 100))
115 |                 logger.info(f'mIoU {self.name} at time {t}: %.2f%%' % (miou * 100))
116 |             mious.append(miou * 100)
117 |         m_miou = np.mean(mious)
118 |
119 |         return mious, m_miou
120 |
121 | def _after_epoch(self):
122 |         logger = MMLogger.get_current_instance()
123 | if dist.is_initialized():
124 | dist.all_reduce(self.total_seen)
125 | dist.all_reduce(self.total_correct)
126 | dist.all_reduce(self.total_positive)
127 | logger.info(f'_after_epoch::total_seen: {self.total_seen.sum()}')
128 | logger.info(f'_after_epoch::total_correct: {self.total_correct.sum()}')
129 | logger.info(f'_after_epoch::total_positive: {self.total_positive.sum()}')
130 | mious = []
131 | for t in range(self.times):
132 | ious = []
133 | for i in range(self.num_classes):
134 | if self.total_seen[t, i] == 0:
135 | ious.append(1)
136 | else:
137 | cur_iou = self.total_correct[t, i] / (self.total_seen[t, i]
138 | + self.total_positive[t, i]
139 | - self.total_correct[t, i])
140 | ious.append(cur_iou.item())
141 | miou = np.mean(ious)
142 | logger.info(f'per class iou {self.name} at time {t}:')
143 | for iou, label_str in zip(ious, self.label_str):
144 | logger.info('%s : %.2f%%' % (label_str, iou * 100))
145 | logger.info(f'mIoU {self.name} at time {t}: %.2f%%' % (miou * 100))
146 | mious.append(miou * 100)
147 | return mious, np.mean(mious)
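148 |
149 | if __name__ == '__main__':
150 |     # minimal sanity check (illustrative; requires a CUDA device, since the
151 |     # counters are allocated on GPU): predictions identical to the targets
152 |     # should yield 100% mIoU at every timestep
153 |     metric = multi_step_MeanIou(class_indices=[0, 1], ignore_label=255,
154 |                                 label_str=['free', 'occupied'], name='demo', times=2)
155 |     metric.reset()
156 |     targets = torch.randint(0, 2, (1, 2, 8, 8))
157 |     metric._after_step(targets.clone(), targets)
158 |     mious, m_miou = metric._after_epoch()
159 |     print(mious, m_miou)  # expect [100.0, 100.0] 100.0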
--------------------------------------------------------------------------------