├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── bedrooms_mixed.yaml
│   ├── diningrooms_mixed.yaml
│   └── livingrooms_mixed.yaml
├── external_licenses
│   ├── ATISS
│   ├── DiffuScene
│   └── VQ-Diffusion
├── media
│   ├── architecture_full.png
│   └── video.png
├── midiffusion
│   ├── datasets
│   │   └── threed_front_encoding.py
│   ├── evaluation
│   │   └── utils.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── denoising_net
│   │   │   ├── continuous_transformer.py
│   │   │   ├── mixed_transformer.py
│   │   │   ├── time_embedding.py
│   │   │   ├── transformer_utils.py
│   │   │   └── unet1D.py
│   │   ├── diffusion_base.py
│   │   ├── diffusion_d3pm.py
│   │   ├── diffusion_ddpm.py
│   │   ├── diffusion_mixed.py
│   │   ├── diffusion_scene_layout_ddpm.py
│   │   ├── diffusion_scene_layout_mixed.py
│   │   ├── feature_extractors.py
│   │   ├── frozen_batchnorm.py
│   │   └── loss.py
│   └── stats_logger.py
├── scripts
│   ├── generate_results.py
│   ├── train_diffusion.py
│   └── utils.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # virtual env
2 | venv*/
3 |
4 | # IDE
5 | .vscode/
6 | .idea/
7 |
8 | # build files
9 | build/
10 | *.egg-info
11 | *.egg
12 | *.o
13 | *.so
14 | *__pycache__*
15 |
16 | # output dir
17 | output/
18 | scripts/wandb/
19 | wandb/
20 |
21 | # deprecated code
22 | _*/
23 |
24 | # eval scripts symbolically linked from ThreedFront
25 | scripts/evaluate_kl_divergence_object_category.py
26 | scripts/render_results.py
27 | scripts/compute_fid_scores.py
28 | scripts/synthetic_vs_real_classifier.py
29 |
30 | # log files
31 | *.log-*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More_considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 | Section 1 -- Definitions.
71 |
72 | a. Adapted Material means material subject to Copyright and Similar
73 | Rights that is derived from or based upon the Licensed Material
74 | and in which the Licensed Material is translated, altered,
75 | arranged, transformed, or otherwise modified in a manner requiring
76 | permission under the Copyright and Similar Rights held by the
77 | Licensor. For purposes of this Public License, where the Licensed
78 | Material is a musical work, performance, or sound recording,
79 | Adapted Material is always produced where the Licensed Material is
80 | synched in timed relation with a moving image.
81 |
82 | b. Adapter's License means the license You apply to Your Copyright
83 | and Similar Rights in Your contributions to Adapted Material in
84 | accordance with the terms and conditions of this Public License.
85 |
86 | c. Copyright and Similar Rights means copyright and/or similar rights
87 | closely related to copyright including, without limitation,
88 | performance, broadcast, sound recording, and Sui Generis Database
89 | Rights, without regard to how the rights are labeled or
90 | categorized. For purposes of this Public License, the rights
91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
92 | Rights.
93 | d. Effective Technological Measures means those measures that, in the
94 | absence of proper authority, may not be circumvented under laws
95 | fulfilling obligations under Article 11 of the WIPO Copyright
96 | Treaty adopted on December 20, 1996, and/or similar international
97 | agreements.
98 |
99 | e. Exceptions and Limitations means fair use, fair dealing, and/or
100 | any other exception or limitation to Copyright and Similar Rights
101 | that applies to Your use of the Licensed Material.
102 |
103 | f. Licensed Material means the artistic or literary work, database,
104 | or other material to which the Licensor applied this Public
105 | License.
106 |
107 | g. Licensed Rights means the rights granted to You subject to the
108 | terms and conditions of this Public License, which are limited to
109 | all Copyright and Similar Rights that apply to Your use of the
110 | Licensed Material and that the Licensor has authority to license.
111 |
112 | h. Licensor means the individual(s) or entity(ies) granting rights
113 | under this Public License.
114 |
115 | i. NonCommercial means not primarily intended for or directed towards
116 | commercial advantage or monetary compensation. For purposes of
117 | this Public License, the exchange of the Licensed Material for
118 | other material subject to Copyright and Similar Rights by digital
119 | file-sharing or similar means is NonCommercial provided there is
120 | no payment of monetary compensation in connection with the
121 | exchange.
122 |
123 | j. Share means to provide material to the public by any means or
124 | process that requires permission under the Licensed Rights, such
125 | as reproduction, public display, public performance, distribution,
126 | dissemination, communication, or importation, and to make material
127 | available to the public including in ways that members of the
128 | public may access the material from a place and at a time
129 | individually chosen by them.
130 |
131 | k. Sui Generis Database Rights means rights other than copyright
132 | resulting from Directive 96/9/EC of the European Parliament and of
133 | the Council of 11 March 1996 on the legal protection of databases,
134 | as amended and/or succeeded, as well as other essentially
135 | equivalent rights anywhere in the world.
136 |
137 | l. You means the individual or entity exercising the Licensed Rights
138 | under this Public License. Your has a corresponding meaning.
139 |
140 | Section 2 -- Scope.
141 |
142 | a. License grant.
143 |
144 | 1. Subject to the terms and conditions of this Public License,
145 | the Licensor hereby grants You a worldwide, royalty-free,
146 | non-sublicensable, non-exclusive, irrevocable license to
147 | exercise the Licensed Rights in the Licensed Material to:
148 |
149 | a. reproduce and Share the Licensed Material, in whole or
150 | in part, for NonCommercial purposes only; and
151 |
152 | b. produce, reproduce, and Share Adapted Material for
153 | NonCommercial purposes only.
154 |
155 | 2. Exceptions and Limitations. For the avoidance of doubt, where
156 | Exceptions and Limitations apply to Your use, this Public
157 | License does not apply, and You do not need to comply with
158 | its terms and conditions.
159 |
160 | 3. Term. The term of this Public License is specified in Section
161 | 6(a).
162 |
163 | 4. Media and formats; technical modifications allowed. The
164 | Licensor authorizes You to exercise the Licensed Rights in
165 | all media and formats whether now known or hereafter created,
166 | and to make technical modifications necessary to do so. The
167 | Licensor waives and/or agrees not to assert any right or
168 | authority to forbid You from making technical modifications
169 | necessary to exercise the Licensed Rights, including
170 | technical modifications necessary to circumvent Effective
171 | Technological Measures. For purposes of this Public License,
172 | simply making modifications authorized by this Section 2(a)
173 | (4) never produces Adapted Material.
174 |
175 | 5. Downstream recipients.
176 |
177 | a. Offer from the Licensor -- Licensed Material. Every
178 | recipient of the Licensed Material automatically
179 | receives an offer from the Licensor to exercise the
180 | Licensed Rights under the terms and conditions of this
181 | Public License.
182 |
183 | b. No downstream restrictions. You may not offer or impose
184 | any additional or different terms or conditions on, or
185 | apply any Effective Technological Measures to, the
186 | Licensed Material if doing so restricts exercise of the
187 | Licensed Rights by any recipient of the Licensed
188 | Material.
189 |
190 | 6. No endorsement. Nothing in this Public License constitutes or
191 | may be construed as permission to assert or imply that You
192 | are, or that Your use of the Licensed Material is, connected
193 | with, or sponsored, endorsed, or granted official status by,
194 | the Licensor or others designated to receive attribution as
195 | provided in Section 3(a)(1)(A)(i).
196 |
197 | b. Other rights.
198 |
199 | 1. Moral rights, such as the right of integrity, are not
200 | licensed under this Public License, nor are publicity,
201 | privacy, and/or other similar personality rights; however, to
202 | the extent possible, the Licensor waives and/or agrees not to
203 | assert any such rights held by the Licensor to the limited
204 | extent necessary to allow You to exercise the Licensed
205 | Rights, but not otherwise.
206 |
207 | 2. Patent and trademark rights are not licensed under this
208 | Public License.
209 |
210 | 3. To the extent possible, the Licensor waives any right to
211 | collect royalties from You for the exercise of the Licensed
212 | Rights, whether directly or through a collecting society
213 | under any voluntary or waivable statutory or compulsory
214 | licensing scheme. In all other cases the Licensor expressly
215 | reserves any right to collect such royalties, including when
216 | the Licensed Material is used other than for NonCommercial
217 | purposes.
218 |
219 | Section 3 -- License Conditions.
220 |
221 | Your exercise of the Licensed Rights is expressly made subject to the
222 | following conditions.
223 |
224 | a. Attribution.
225 |
226 | 1. If You Share the Licensed Material (including in modified
227 | form), You must:
228 |
229 | a. retain the following if it is supplied by the Licensor
230 | with the Licensed Material:
231 |
232 | i. identification of the creator(s) of the Licensed
233 | Material and any others designated to receive
234 | attribution, in any reasonable manner requested by
235 | the Licensor (including by pseudonym if
236 | designated);
237 |
238 | ii. a copyright notice;
239 |
240 | iii. a notice that refers to this Public License;
241 |
242 | iv. a notice that refers to the disclaimer of
243 | warranties;
244 |
245 | v. a URI or hyperlink to the Licensed Material to the
246 | extent reasonably practicable;
247 |
248 | b. indicate if You modified the Licensed Material and
249 | retain an indication of any previous modifications; and
250 |
251 | c. indicate the Licensed Material is licensed under this
252 | Public License, and include the text of, or the URI or
253 | hyperlink to, this Public License.
254 |
255 | 2. You may satisfy the conditions in Section 3(a)(1) in any
256 | reasonable manner based on the medium, means, and context in
257 | which You Share the Licensed Material. For example, it may be
258 | reasonable to satisfy the conditions by providing a URI or
259 | hyperlink to a resource that includes the required
260 | information.
261 |
262 | 3. If requested by the Licensor, You must remove any of the
263 | information required by Section 3(a)(1)(A) to the extent
264 | reasonably practicable.
265 |
266 | 4. If You Share Adapted Material You produce, the Adapter's
267 | License You apply must not prevent recipients of the Adapted
268 | Material from complying with this Public License.
269 |
270 | Section 4 -- Sui Generis Database Rights.
271 |
272 | Where the Licensed Rights include Sui Generis Database Rights that
273 | apply to Your use of the Licensed Material:
274 |
275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276 | to extract, reuse, reproduce, and Share all or a substantial
277 | portion of the contents of the database for NonCommercial purposes
278 | only;
279 |
280 | b. if You include all or a substantial portion of the database
281 | contents in a database in which You have Sui Generis Database
282 | Rights, then the database in which You have Sui Generis Database
283 | Rights (but not its individual contents) is Adapted Material; and
284 |
285 | c. You must comply with the conditions in Section 3(a) if You Share
286 | all or a substantial portion of the contents of the database.
287 |
288 | For the avoidance of doubt, this Section 4 supplements and does not
289 | replace Your obligations under this Public License where the Licensed
290 | Rights include other Copyright and Similar Rights.
291 |
292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293 |
294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304 |
305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314 |
315 | c. The disclaimer of warranties and limitation of liability provided
316 | above shall be interpreted in a manner that, to the extent
317 | possible, most closely approximates an absolute disclaimer and
318 | waiver of all liability.
319 |
320 | Section 6 -- Term and Termination.
321 |
322 | a. This Public License applies for the term of the Copyright and
323 | Similar Rights licensed here. However, if You fail to comply with
324 | this Public License, then Your rights under this Public License
325 | terminate automatically.
326 |
327 | b. Where Your right to use the Licensed Material has terminated under
328 | Section 6(a), it reinstates:
329 |
330 | 1. automatically as of the date the violation is cured, provided
331 | it is cured within 30 days of Your discovery of the
332 | violation; or
333 |
334 | 2. upon express reinstatement by the Licensor.
335 |
336 | For the avoidance of doubt, this Section 6(b) does not affect any
337 | right the Licensor may have to seek remedies for Your violations
338 | of this Public License.
339 |
340 | c. For the avoidance of doubt, the Licensor may also offer the
341 | Licensed Material under separate terms or conditions or stop
342 | distributing the Licensed Material at any time; however, doing so
343 | will not terminate this Public License.
344 |
345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346 | License.
347 |
348 | Section 7 -- Other Terms and Conditions.
349 |
350 | a. The Licensor shall not be bound by any additional or different
351 | terms or conditions communicated by You unless expressly agreed.
352 |
353 | b. Any arrangements, understandings, or agreements regarding the
354 | Licensed Material not stated herein are separate from and
355 | independent of the terms and conditions of this Public License.
356 |
357 | Section 8 -- Interpretation.
358 |
359 | a. For the avoidance of doubt, this Public License does not, and
360 | shall not be interpreted to, reduce, limit, restrict, or impose
361 | conditions on any use of the Licensed Material that could lawfully
362 | be made without permission under this Public License.
363 |
364 | b. To the extent possible, if any provision of this Public License is
365 | deemed unenforceable, it shall be automatically reformed to the
366 | minimum extent necessary to make it enforceable. If the provision
367 | cannot be reformed, it shall be severed from this Public License
368 | without affecting the enforceability of the remaining terms and
369 | conditions.
370 |
371 | c. No term or condition of this Public License will be waived and no
372 | failure to comply consented to unless expressly agreed to by the
373 | Licensor.
374 |
375 | d. Nothing in this Public License constitutes or may be interpreted
376 | as a limitation upon, or waiver of, any privileges and immunities
377 | that apply to the Licensor or You, including from the legal
378 | processes of any jurisdiction or authority.
379 |
380 | =======================================================================
381 |
382 | Creative Commons is not a party to its public
383 | licenses. Notwithstanding, Creative Commons may elect to apply one of
384 | its public licenses to material it publishes and in those instances
385 | will be considered the “Licensor.” The text of the Creative Commons
386 | public licenses is dedicated to the public domain under the CC0 Public
387 | Domain Dedication. Except for the limited purpose of indicating that
388 | material is shared under a Creative Commons public license or as
389 | otherwise permitted by the Creative Commons policies published at
390 | creativecommons.org/policies, Creative Commons does not authorize the
391 | use of the trademark "Creative Commons" or any other trademark or logo
392 | of Creative Commons without its prior written consent including,
393 | without limitation, in connection with any unauthorized modifications
394 | to any of its public licenses or any other arrangements,
395 | understandings, or agreements concerning use of licensed material. For
396 | the avoidance of doubt, this paragraph does not form part of the
397 | public licenses.
398 |
399 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mixed Diffusion Models for 3D Indoor Scene Synthesis
2 | This repository contains the model code that accompanies our paper [Mixed Diffusion for 3D Indoor Scene Synthesis](https://arxiv.org/abs/2405.21066).
3 | We present **MiDiffusion**, a novel mixed discrete-continuous diffusion model architecture, designed to synthesize plausible 3D indoor scenes from given room types, floor plans, and potentially pre-existing objects. Our approach uniquely implements structured corruption across the mixed discrete semantic and continuous geometric domains, resulting in a better conditioned problem for the reverse denoising step.
4 |
5 | ![MiDiffusion architecture overview](media/architecture_full.png)
6 |
7 | The preprocessing and evaluation scripts for the [3D-FRONT](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-scene-dataset) and [3D-FUTURE](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future) datasets, adapted from [ATISS](https://github.com/nv-tlabs/ATISS/), are kept in our [ThreedFront dataset](https://github.com/MIT-SPARK/ThreedFront) repository to facilitate comparisons with other 3D scene synthesis methods on the same datasets.
8 | ThreedFront also contains dataset class implementations as a standalone `threed_front` package, which is a dependency of this repository.
9 | We borrow code from [VQ-Diffusion](https://github.com/microsoft/VQ-Diffusion) and [DiffuScene](https://github.com/tangjiapeng/DiffuScene) for the discrete and continuous domain diffusion implementations, respectively. Please refer to the related licensing information in `external_licenses/`.
10 |
11 | If you found this work useful, please consider citing our paper:
12 | ```
13 | @article{Hu24arxiv-MiDiffusion,
14 | author={Siyi Hu and Diego Martin Arroyo and Stephanie Debats and Fabian Manhardt and Luca Carlone and Federico Tombari},
15 | title={Mixed Diffusion for 3D Indoor Scene Synthesis},
16 | journal = {arXiv preprint: 2405.21066},
17 | pdf = {https://arxiv.org/abs/2405.21066},
18 | Year = {2024}
19 | }
20 | ```
21 |
22 | ## Installation & Dependencies
23 | Our code is developed in Python 3.8 with PyTorch 1.12.1 and CUDA 11.3.
24 |
25 | First, from this root directory, clone [ThreedFront](https://github.com/MIT-SPARK/ThreedFront):
26 | ```
27 | git clone git@github.com:MIT-SPARK/ThreedFront.git ../ThreedFront
28 | ```
29 | You can either install all dependencies listed in [ThreedFront](https://github.com/MIT-SPARK/ThreedFront), or, if you also want to use `threed_front` for other projects, install `threed_front` in its own environment and add that environment's `site-packages` directory to this project's Python path. For example, if you use virtualenv, run
30 | ```
31 | echo "<path_to_threed_front_env>/lib/python3.x/site-packages" > <path_to_this_env>/lib/python3.x/site-packages/threed-front.pth
32 | ```
33 |
34 | Then install `threed_front` and `midiffusion`. `midiffusion` requires two additional dependencies: [einops==0.8.0](https://einops.rocks/) and [wandb==0.17.1](https://docs.wandb.ai/quickstart).
35 | ```
36 | # install threed-front
37 | pip install -e ../ThreedFront
38 |
39 | # install midiffusion
40 | python setup.py build_ext --inplace
41 | pip install -e .
42 | ```
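
After this step, both packages should be importable from the current environment. A quick optional sanity check (not an official script from this repository):

```python
# Optional: verify that both packages resolve after installation.
import threed_front
import midiffusion

print(threed_front)
print(midiffusion)
```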
43 |
44 | ## Dataset
45 | We use [3D-FRONT](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-scene-dataset) and [3D-FUTURE](https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future) datasets for training and testing of our model.
46 | Please follow the data preprocessing steps in [ThreedFront](https://github.com/MIT-SPARK/ThreedFront).
47 | We use the same data files as those included in `ThreedFront/data_files` for training and evaluation steps. Please check that `PATH_TO_DATASET_FILES` and `PATH_TO_PROCESSED_DATA` in `scripts/utils.py` are pointing to the right directories.
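
For reference, those two constants are plain module-level paths. A minimal sketch of what they might look like is below; the values are assumptions for illustration only, not the actual contents of `scripts/utils.py`:

```python
# Illustrative values only -- adjust to your local ThreedFront setup.
PATH_TO_DATASET_FILES = "../ThreedFront/data_files"   # dataset splits and stats files
PATH_TO_PROCESSED_DATA = "../ThreedFront/output"      # preprocessed 3D-FRONT scenes
```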
48 |
49 | ## Training
50 | To train MiDiffusion on 3D-FRONT bedrooms, you can run
51 | ```
52 | python scripts/train_diffusion.py <path_to_config_file> --experiment_tag <experiment_tag>
53 | ```
54 | We provide example config files in the `config/` directory. This training script saves a copy of the config file (as `config.yaml`) and logs intermediate model weights to `output/log/` unless `--output_directory` is set otherwise.
55 |
56 | ## Experiment
57 | The `scripts/generate_results.py` script generates synthetic layouts with a trained model and pickles them through the `threed_front` package. We provide example trained models [here](https://drive.google.com/drive/folders/14N87Ap90KNaDlRv5u6UeCV1h_MT9QqaN?usp=sharing).
58 | ```
59 | python scripts/generate_results.py <path_to_weight_file> --result_tag <result_tag>
60 | ```
61 | This script loads the config from the `config.yaml` file in the same directory as `<path_to_weight_file>` if not otherwise specified.
62 | The results will be saved to `output/predicted_results/<result_tag>/results.pkl` unless `--output_directory` is set otherwise.
63 | We can run experiments with different object constraints using the same model by setting the `--experiment` argument. The options include:
64 | - **synthesis** (default): scene synthesis problem given input floor plans.
65 | - **scene_completion**: scene completion given floor plans and existing objects (specified via `--n_known_objects`).
66 | - **furniture_arrangement**: scene completion given floor plans, object labels, and sizes.
67 | - **object_conditioned**: scene completion given floor plans and object labels.
68 | - **scene_completion_conditioned**: scene completion given floor plans, existing objects, and labels of remaining objects.
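
Under the hood, each of these options corresponds to a boolean feature mask over the per-object attribute vectors (see `get_feature_mask` in `midiffusion/evaluation/utils.py`), marking which entries are treated as known during sampling. The sketch below mirrors that logic with illustrative dimensions; the specific numbers are assumptions, not values from a config:

```python
import torch

# Illustrative dimensions: 12 objects per scene, geometric attributes
# (translation + size + angle) followed by class one-hot entries.
num_points, class_dim = 12, 22
translation_dim, size_dim, angle_dim = 3, 3, 2
bbox_dim = translation_dim + size_dim + angle_dim  # geometric part of each row
point_dim = bbox_dim                               # assuming no extra object features

mask = torch.zeros(num_points, point_dim + class_dim, dtype=torch.bool)

experiment, n_known = "scene_completion_conditioned", 4
if experiment == "scene_completion":
    mask[:n_known] = True                                        # first n_known objects fully known
elif experiment == "furniture_arrangement":
    mask[:, translation_dim:translation_dim + size_dim] = True   # sizes known
    mask[:, bbox_dim:bbox_dim + class_dim] = True                # labels known
elif experiment == "object_conditioned":
    mask[:, bbox_dim:bbox_dim + class_dim] = True                # labels known
elif experiment == "scene_completion_conditioned":
    mask[:n_known] = True                                        # existing objects known
    mask[:, bbox_dim:bbox_dim + class_dim] = True                # labels of remaining objects known
# "synthesis" (the default) uses no mask: everything is generated from the floor plan.
```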
69 |
70 | You can then render the predicted layout to top-down projection images using `scripts/render_results.py` in [ThreedFront](https://github.com/MIT-SPARK/ThreedFront) for evaluation.
71 | ```
72 | python ../ThreedFront/scripts/render_results.py output/predicted_results/<result_tag>/results.pkl
73 | ```
74 | Please read this script for rendering options.
75 |
76 | ## Evaluation
77 | The evaluation scripts in the `scripts/` directory of [ThreedFront](https://github.com/MIT-SPARK/ThreedFront) include:
78 | - `evaluate_kl_divergence_object_category.py`: Compute **KL-divergence** between ground-truth and synthesized object category distributions.
79 | - `compute_fid_scores.py`: Compute average **FID** or **KID** (if run with the `--compute_kid` flag) between ground-truth and synthesized layout images.
80 | - `synthetic_vs_real_classifier.py`: Train an image classifier to distinguish real from synthetic projection images, and compute the average **classification accuracy**.
81 | - `bbox_analysis.py`: Count the number of **out-of-boundary** object bounding boxes and compute the pairwise bounding-box **IoU** (this requires sampled floor plan boundary and normal points); a conceptual IoU sketch follows this list.
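
The pairwise IoU in `bbox_analysis.py` handles the full oriented-box setting; purely as a conceptual reference, a minimal axis-aligned 3D IoU sketch (an assumption-level simplification, not the ThreedFront implementation) looks like this:

```python
import numpy as np

def pairwise_iou_aabb(boxes_a, boxes_b):
    """IoU between axis-aligned 3D boxes given as (N, 6) arrays of [min_xyz, max_xyz]."""
    # Intersection extents for every pair, broadcast to (N, M, 3).
    lo = np.maximum(boxes_a[:, None, :3], boxes_b[None, :, :3])
    hi = np.minimum(boxes_a[:, None, 3:], boxes_b[None, :, 3:])
    inter = np.clip(hi - lo, 0.0, None).prod(axis=-1)
    vol_a = (boxes_a[:, 3:] - boxes_a[:, :3]).prod(axis=-1)
    vol_b = (boxes_b[:, 3:] - boxes_b[:, :3]).prod(axis=-1)
    union = vol_a[:, None] + vol_b[None, :] - inter
    return inter / np.maximum(union, 1e-8)
```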
82 |
83 | ## Video
84 | An overview of MiDiffusion is available on [YouTube](https://www.youtube.com/watch?v=sLOMhsweb8Y):
85 |
86 | [
](https://www.youtube.com/watch?v=sLOMhsweb8Y)
87 |
88 |
89 | ## Relevant Research
90 |
91 | Please also check out the following papers that explore similar ideas:
92 | - Fast and Flexible Indoor Scene Synthesis via Deep Convolutional Generative Models [pdf](https://arxiv.org/pdf/1811.12463.pdf)
93 | - Sceneformer: Indoor Scene Generation with Transformers [pdf](https://arxiv.org/pdf/2012.09793.pdf)
94 | - ATISS: Autoregressive Transformers for Indoor Scene Synthesis [pdf](https://arxiv.org/pdf/2110.03675.pdf)
95 | - Indoor Scene Generation from a Collection of Semantic-Segmented Depth Images [pdf](https://arxiv.org/abs/2108.09022)
96 | - Scene Synthesis via Uncertainty-Driven Attribute Synchronization [pdf](https://arxiv.org/abs/2108.13499)
97 | - LEGO-Net: Learning Regular Rearrangements of Objects in Rooms [pdf](https://arxiv.org/abs/2301.09629)
98 | - DiffuScene: Denoising Diffusion Models for Generative Indoor Scene Synthesis [pdf](https://arxiv.org/abs/2303.14207)
99 |
--------------------------------------------------------------------------------
/config/bedrooms_mixed.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_type: "cached_threedfront"
3 | encoding_type: "cached_diffusion_cosin_angle_wocm"
4 | dataset_directory: "bedroom"
5 | annotation_file: "bedroom_threed_front_splits.csv"
6 | augmentations: ["fixed_rotations"]
7 | train_stats: "dataset_stats.txt"
8 | room_layout_size: "64,64"
9 |
10 | network:
11 | type: "diffusion_scene_layout_mixed"
12 |
13 | # encoding dim
14 | sample_num_points: 12 # max_length
15 | angle_dim: 2
16 |
17 | # room mask condition
18 | room_mask_condition: true
19 | room_latent_dim: 64
20 |
21 | # position condition
22 | position_condition: false
23 | position_emb_dim: 0
24 |
25 | # diffusion config
26 | time_num: 1000
27 | diffusion_semantic_kwargs:
28 | att_1: 0.99999
29 | att_T: 0.000009
30 | ctt_1: 0.000009
31 | ctt_T: 0.99999
32 | model_output_type: 'x0'
33 | mask_weight: 1
34 | auxiliary_loss_weight: 0.0005
35 | adaptive_auxiliary_loss: True
36 | diffusion_geometric_kwargs:
37 | schedule_type: 'linear'
38 | beta_start: 0.0001
39 | beta_end: 0.02
40 | loss_type: 'mse'
41 | model_mean_type: 'eps'
42 | model_var_type: 'fixedsmall'
43 |
44 | # denoising net
45 | net_type: "transformer"
46 | net_kwargs:
47 | seperate_all: True
48 | n_layer: 8
49 | n_embd: 512
50 | n_head: 4
51 | dim_feedforward: 2048
52 | dropout: 0.1
53 | activate: 'GELU'
54 | timestep_type: 'adalayernorm_abs'
55 | mlp_type: 'fc'
56 |
57 | feature_extractor:
58 | name: "pointnet_simple"
59 | feat_units: [4, 64, 64, 512, 64]
60 |
61 | training:
62 | splits: ["train", "val"]
63 | epochs: 50000
64 | batch_size: 512
65 | save_frequency: 2000
66 | max_grad_norm: 10
67 | # optimizer
68 | optimizer: Adam
69 | weight_decay: 0.0
70 | # schedule
71 | schedule: 'step'
72 | lr: 0.0002
73 | lr_step: 10000
74 | lr_decay: 0.5
75 |
76 | validation:
77 | splits: ["test"]
78 | frequency: 100
79 | batch_size: 512
80 |
81 | logger:
82 | type: "wandb"
83 | project: "MiDiffusion"
84 |
--------------------------------------------------------------------------------
/config/diningrooms_mixed.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_type: "cached_threedfront"
3 | encoding_type: "cached_diffusion_cosin_angle_wocm"
4 | dataset_directory: "diningroom"
5 | annotation_file: "diningroom_threed_front_splits.csv"
6 | augmentations: ["fixed_rotations"]
7 | train_stats: "dataset_stats.txt"
8 | room_layout_size: "64,64"
9 |
10 | network:
11 | type: "diffusion_scene_layout_mixed"
12 |
13 | # encoding dim
14 | sample_num_points: 21 # max_length
15 | angle_dim: 2
16 |
17 | # room mask condition
18 | room_mask_condition: true
19 | room_latent_dim: 64
20 |
21 | # position condition
22 | position_condition: false
23 | position_emb_dim: 0
24 |
25 | # diffusion config
26 | time_num: 1000
27 | diffusion_semantic_kwargs:
28 | att_1: 0.99999
29 | att_T: 0.000009
30 | ctt_1: 0.000009
31 | ctt_T: 0.99999
32 | model_output_type: 'x0'
33 | mask_weight: 1
34 | auxiliary_loss_weight: 0.0005
35 | adaptive_auxiliary_loss: True
36 | diffusion_geometric_kwargs:
37 | schedule_type: 'linear'
38 | beta_start: 0.0001
39 | beta_end: 0.02
40 | loss_type: 'mse'
41 | model_mean_type: 'eps'
42 | model_var_type: 'fixedsmall'
43 |
44 | # denoising net
45 | net_type: "transformer"
46 | net_kwargs:
47 | seperate_all: True
48 | n_layer: 8
49 | n_embd: 512
50 | n_head: 4
51 | dim_feedforward: 2048
52 | dropout: 0.1
53 | activate: 'GELU'
54 | timestep_type: 'adalayernorm_abs'
55 | mlp_type: 'fc'
56 |
57 | feature_extractor:
58 | name: "pointnet_simple"
59 | feat_units: [4, 64, 64, 512, 64]
60 |
61 | training:
62 | splits: ["train", "val"]
63 | epochs: 100000
64 | batch_size: 512
65 | save_frequency: 2000
66 | max_grad_norm: 10
67 | # optimizer
68 | optimizer: Adam
69 | weight_decay: 0.0
70 | # schedule
71 | schedule: 'step'
72 | lr: 0.0002
73 | lr_step: 15000
74 | lr_decay: 0.5
75 |
76 | validation:
77 | splits: ["test"]
78 | frequency: 100
79 | batch_size: 512
80 |
81 | logger:
82 | type: "wandb"
83 | project: "MiDiffusion"
84 |
--------------------------------------------------------------------------------
/config/livingrooms_mixed.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_type: "cached_threedfront"
3 | encoding_type: "cached_diffusion_cosin_angle_wocm"
4 | dataset_directory: "livingroom"
5 | annotation_file: "livingroom_threed_front_splits.csv"
6 | augmentations: ["fixed_rotations"]
7 | train_stats: "dataset_stats.txt"
8 | room_layout_size: "64,64"
9 |
10 | network:
11 | type: "diffusion_scene_layout_mixed"
12 |
13 | # encoding dim
14 | sample_num_points: 21 # max_length
15 | angle_dim: 2
16 |
17 | # room mask condition
18 | room_mask_condition: true
19 | room_latent_dim: 64
20 |
21 | # position condition
22 | position_condition: false
23 | position_emb_dim: 0
24 |
25 | # diffusion config
26 | time_num: 1000
27 | diffusion_semantic_kwargs:
28 | att_1: 0.99999
29 | att_T: 0.000009
30 | ctt_1: 0.000009
31 | ctt_T: 0.99999
32 | model_output_type: 'x0'
33 | mask_weight: 1
34 | auxiliary_loss_weight: 0.0005
35 | adaptive_auxiliary_loss: True
36 | diffusion_geometric_kwargs:
37 | schedule_type: 'linear'
38 | beta_start: 0.0001
39 | beta_end: 0.02
40 | loss_type: 'mse'
41 | model_mean_type: 'eps'
42 | model_var_type: 'fixedsmall'
43 |
44 | # denoising net
45 | net_type: "transformer"
46 | net_kwargs:
47 | seperate_all: True
48 | n_layer: 8
49 | n_embd: 512
50 | n_head: 4
51 | dim_feedforward: 2048
52 | dropout: 0.1
53 | activate: 'GELU'
54 | timestep_type: 'adalayernorm_abs'
55 | mlp_type: 'fc'
56 |
57 | feature_extractor:
58 | name: "pointnet_simple"
59 | feat_units: [4, 64, 64, 512, 64]
60 |
61 | training:
62 | splits: ["train", "val"]
63 | epochs: 100000
64 | batch_size: 512
65 | save_frequency: 2000
66 | max_grad_norm: 10
67 | # optimizer
68 | optimizer: Adam
69 | weight_decay: 0.0
70 | # schedule
71 | schedule: 'step'
72 | lr: 0.0002
73 | lr_step: 15000
74 | lr_decay: 0.5
75 |
76 | validation:
77 | splits: ["test"]
78 | frequency: 100
79 | batch_size: 512
80 |
81 | logger:
82 | type: "wandb"
83 | project: "MiDiffusion"
84 |
--------------------------------------------------------------------------------
/external_licenses/ATISS:
--------------------------------------------------------------------------------
1 | NVIDIA Source Code License for ATISS
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 | “Software” means the original work of authorship made available under this License.
7 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License.
8 | “NVIDIA Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by NVIDIA or its affiliates.
9 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
10 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
11 |
12 | 2. License Grant
13 |
14 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
15 |
16 | 3. Limitations
17 |
18 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
19 |
20 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
21 |
22 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially and with NVIDIA Processors. Notwithstanding the foregoing, NVIDIA and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
23 |
24 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grant in Section 2.1) will terminate immediately.
25 |
26 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
27 |
28 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grant in Section 2.1) will terminate immediately.
29 |
30 | 4. Disclaimer of Warranty.
31 |
32 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
33 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
34 |
35 | 5. Limitation of Liability.
36 |
37 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
--------------------------------------------------------------------------------
/external_licenses/DiffuScene:
--------------------------------------------------------------------------------
1 | Copyright 2023 Sony Group Corporation.
2 | All rights reserved.
3 |
4 | TERMS AND CONDITIONS FOR USE, MODIFICATION, REPRODUCTION AND DISTRIBUTION
5 |
6 | 1. Definitions.
7 | "License" means the terms and conditions for use, modification, reproduction and distribution as set forth under this document.
8 |
9 | "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting this License.
10 |
11 | "Legal Entity" means the union of the acting entity and all other entities that control, are controlled by, or are under common control with that acting entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
12 |
13 | "You" (or "Your") means an individual or Legal Entity, except for Sony Group Corporation and all other entities that control, are controlled by, or are under common control with Sony Group Corporation, exercising permissions granted under this License.
14 |
15 | "Source" form means the preferred form for making modifications, including but not limited to human-readable software source code and configuration files.
16 |
17 | "Object" form means any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code and conversions to other media types.
18 |
19 | "Work" means the work of authorship, whether in Source or Object form, available under this License, as indicated by a copyright notice included in or attached to the work.
20 |
21 | "Derivative Works" mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, elaborations or other modifications represent, as a whole, an original work of authorship.
22 |
23 | "Contribution" means any work of authorship, including the original version of the Work and any modifications to, additions to or deletions of any part of that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in or application to the Work by the copyright owner or by an individual or Legal Entity authorized by the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing services, source code control systems and issue tracking systems that are managed by, or on behalf of, the Licensor or its representatives, but excluding communication explicitly designated as "Not a Contribution” by the copyright owner.
24 |
25 | "Contributor" means (i) Licensor and (ii) any individual or Legal Entity on behalf of whom a Contribution has been made and subsequently incorporated within the Work.
26 |
27 | 2. Grant of Copyright License.
28 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free and irrevocable copyright license to reproduce, make Derivative Works of, publicly display, publicly perform, sublicense and distribute the Work and such Derivative Works in Source or Object form.
29 |
30 | 3. Responsibilities for Distribution
31 | 3.1 Distribution of Object Form
32 | If You distribute the Work or Derivative Works in Object form, such Work or Derivative Works must also be available in Source form, as described in Section 3.2, and You must inform recipients of the Object form how they can obtain a copy of such Source form through a medium customarily used for software exchange, in a timely manner, at a charge no more than the cost of distribution to the recipient.
33 |
34 | 3.2 Distribution of Source Form
35 | When distributing the Work or Derivative Works in Source form, including any modifications that You create or to which You contribute, You must inform recipients that the Source form of the Work or Derivative Works are governed by this License, and how they can obtain a copy of this License.
36 |
37 | 3.3 Remote Network Interaction
38 | If Your version of the Work or Derivative Works support interaction remotely through a computer network, such version must also be available in Source form, as described in Section 3.1, and You must inform all interacting users how they can obtain a copy of such Source form through a medium customarily used for software exchange, in a timely manner, at a charge no more than the cost of distribution to the user.
39 |
40 | 3.4 Notes
41 | In any case of Section 3.1/3.2/3.3, You must clearly document how Your version of the Work or Derivative Works differ from the Work provided by the copyright owner, including but not limited to documenting any modified or added features, executables or modules, and You must fulfill all of the following requirements:
42 | (a) You must allow the copyright owner to include your modifications to the Work;
43 | (b) You must ensure that installation of Your version of the Work or Derivative Works do not prevent the user from installing or running the Work provided by the copyright owner. In addition, Your version must bear a name that is different from the name of the Work provided by the copyright owner; and
44 | (c) You must allow anyone who receives a copy of Your version of the Work or Derivative Works to make the Source or Object form of Your version, with or without further modifications, available to others under this License.
45 |
46 | 4. Commercial Use Limitation
47 | The Work and Derivative Works only may be used or intended for use non-commercially. Notwithstanding the foregoing, Sony Group Corporation, the original copyright owner, and affiliates authorized by Sony Group Corporation may use the Work and Derivative Works commercially. As used herein, “non-commercially” means research, education or evaluation purposes only.
48 |
49 | 5. Military Use Limitation
50 | It is forbidden to use the Work and Derivative Works in any system that is designed or intended to physically hurt or kill a human being, or cooperates with such system, including any kind of reconnaissance. This applies to any form of law enforcement and military operations as well as any kind of activity in an organization that designs, builds or evaluates hardware or software systems for those purposes.
51 |
52 | 6. Patent
53 | If You institute or threaten to institute patent litigation against any Licensor (including but not limited to a cross-claim or counterclaim in a lawsuit) alleging that the Work, Derivative Works or a Contribution incorporated within the Work or Derivative Works constitutes direct or contributory patent infringement, then all Your rights under this License (including the grant in Section 2) will terminate immediately.
54 |
55 | 7. Disclaimer of Warranty
56 | Licensor provides the Work or Derivative Works (and each Contributor provides its Contributions) on an "AS IS" basis, without warranties or conditions of any kind, either express or implied, including, without limitation, any warranties or conditions of title, non-infringement, merchantability, or fitness for a particular purpose. You are solely responsible for determining the appropriateness of using or redistributing the Work or Derivative Works and assume any risks associated with Your exercise of permissions under this License.
57 |
58 | 8. Limitation of Liability
59 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Licensor or Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work or Derivative Works (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Licensor or Contributor has been advised of the possibility of such damages.
60 |
61 | END OF TERMS AND CONDITIONS
--------------------------------------------------------------------------------
/external_licenses/VQ-Diffusion:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/media/architecture_full.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-SPARK/MiDiffusion/25b594338a29b31c6783b4b1ea87eeeb0ce12bbe/media/architecture_full.png
--------------------------------------------------------------------------------
/media/video.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MIT-SPARK/MiDiffusion/25b594338a29b31c6783b4b1ea87eeeb0ce12bbe/media/video.png
--------------------------------------------------------------------------------
/midiffusion/datasets/threed_front_encoding.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/tangjiapeng/DiffuScene.
4 | #
5 |
6 | import numpy as np
7 | from torch.utils.data import dataloader
8 |
9 | from threed_front.datasets.threed_front_encoding_base import *
10 | from threed_front.datasets import get_raw_dataset
11 |
12 |
13 | class Diffusion(DatasetDecoratorBase):
14 | def __init__(self, dataset, max_length=None):
15 | super().__init__(dataset)
16 |
17 | if max_length is None:
18 | self._max_length = dataset.max_length
19 | else:
20 | assert max_length >= dataset.max_length
21 | self._max_length = max_length
22 |
23 | def __getitem__(self, idx):
24 | sample_params = self._dataset[idx]
25 |
26 | # Add the number of bounding boxes in the scene
27 | sample_params["length"] = sample_params["class_labels"].shape[0]
28 |
29 | sample_params_target = {}
30 | # Compute the target from the input
31 | for k, v in sample_params.items():
32 | if k in [
33 | "room_layout", "length", "fpbpn"
34 | ]:
35 | pass
36 |
37 | elif k == "class_labels":
38 | if self._dataset.n_classes == self._dataset.n_object_types + 2:
39 | # Delete the 'start' label and keep the last as 'empty' label
40 | class_labels = np.hstack([v[:, :-2], v[:, -1:]])
41 | else:
42 | assert self._dataset.n_classes == self._dataset.n_object_types + 1
43 | class_labels = v
44 | # Pad the 'empty' label in the end of each sequence,
45 | # and convert the class labels to -1, 1
46 | L, C = class_labels.shape
47 | empty_label = np.eye(C)[-1]
48 | sample_params_target[k] = np.vstack([
49 | class_labels,
50 | np.tile(empty_label[None, :], [self._max_length - L, 1])
51 | ]).astype(np.float32) * 2.0 - 1.0
52 |
53 | else:
54 | # Set the attributes for the 'empty' label
55 | L, C = v.shape
56 | sample_params_target[k] = np.vstack([
57 | v, np.zeros((self._max_length - L, C))
58 | ]).astype(np.float32)
59 |
60 | sample_params.update(sample_params_target)
61 |
62 | return sample_params
63 |
64 | @property
65 | def max_length(self):
66 | return self._max_length
67 |
68 | def collate_fn(self, samples):
69 |         ''' Collate function that stacks each data field into a tensor whose outer
70 |         dimension is the batch size.
71 |         Args:
72 |             samples: list of sample dicts (None entries are dropped)
73 | '''
74 |
75 | samples = list(filter(lambda x: x is not None, samples))
76 | return dataloader.default_collate(samples)
77 |
78 |
79 | def get_dataset_raw_and_encoded(
80 | config,
81 | filter_fn=lambda s: s,
82 | path_to_bounds=None,
83 | augmentations=None,
84 | split=["train", "val"],
85 | max_length=None,
86 | include_room_mask=True,
87 | ):
88 | dataset = get_raw_dataset(
89 | config, filter_fn, path_to_bounds, split,
90 | include_room_mask=include_room_mask
91 | )
92 | encoding = dataset_encoding_factory(
93 | config.get("encoding_type"),
94 | dataset,
95 | augmentations,
96 | config.get("box_ordering", None),
97 | max_length
98 | )
99 |
100 | return dataset, encoding
101 |
102 |
103 | def get_encoded_dataset(
104 | config,
105 | filter_fn=lambda s: s,
106 | path_to_bounds=None,
107 | augmentations=None,
108 | split=["train", "val"],
109 | max_length=None,
110 | include_room_mask=True
111 | ):
112 | _, encoding = get_dataset_raw_and_encoded(
113 | config, filter_fn, path_to_bounds, augmentations, split, max_length,
114 | include_room_mask
115 | )
116 | return encoding
117 |
118 | def dataset_encoding_factory(
119 | name,
120 | dataset,
121 | augmentations=None,
122 | box_ordering=None,
123 | max_length=None,
124 | ):
125 | # list of object features
126 | feature_keys = ["class_labels", "translations", "sizes", "angles"]
127 | if "objfeats" in name:
128 | if "lat32" in name:
129 | feature_keys.append("objfeats_32")
130 | print("use lat32 as objfeats")
131 | else:
132 | feature_keys.append("objfeats")
133 | print("use lat64 as objfeats")
134 |
135 | # NOTE: The ordering might change after augmentations so really it should
136 | # be done after the augmentations. For class frequencies it is fine
137 | # though.
138 | if "cached" in name:
139 | dataset_collection = CachedDatasetCollection(dataset)
140 | if box_ordering:
141 | dataset_collection = \
142 | OrderedDataset(dataset_collection, feature_keys, box_ordering)
143 | else:
144 | box_ordered_dataset = BoxOrderedDataset(dataset, box_ordering)
145 |
146 | class_labels = ClassLabelsEncoder(box_ordered_dataset)
147 | translations = TranslationEncoder(box_ordered_dataset)
148 | sizes = SizeEncoder(box_ordered_dataset)
149 | angles = AngleEncoder(box_ordered_dataset)
150 | objfeats = ObjFeatEncoder(box_ordered_dataset)
151 | objfeats_32 = ObjFeat32Encoder(box_ordered_dataset)
152 |
153 | if name == "basic":
154 | return DatasetCollection(
155 | class_labels,
156 | translations,
157 | sizes,
158 | angles,
159 | objfeats,
160 | objfeats_32
161 | )
162 |
163 | room_layout = RoomLayoutEncoder(box_ordered_dataset)
164 | dataset_collection = DatasetCollection(
165 | room_layout,
166 | class_labels,
167 | translations,
168 | sizes,
169 | angles,
170 | objfeats,
171 | objfeats_32
172 | )
173 |
174 | if isinstance(augmentations, list):
175 | for aug_type in augmentations:
176 | if aug_type == "rotations":
177 | print("Applying rotation augmentations")
178 | dataset_collection = RotationAugmentation(dataset_collection)
179 | elif aug_type == "fixed_rotations":
180 | print("Applying fixed rotation augmentations")
181 | dataset_collection = RotationAugmentation(dataset_collection, fixed=True)
182 | elif aug_type == "jitter":
183 | print("Applying jittering augmentations")
184 | dataset_collection = Jitter(dataset_collection)
185 |
186 | # Scale the input
187 | if "cosin_angle" in name:
188 | dataset_collection = Scale_CosinAngle(dataset_collection)
189 | else:
190 | dataset_collection = Scale(dataset_collection)
191 |
192 | # for diffusion (represent objectness as the last channel of class label)
193 | if "diffusion" in name:
194 | if "eval" in name:
195 | return Diffusion(dataset_collection, max_length)
196 | elif "wocm_no_prm" in name:
197 | return Diffusion(dataset_collection, max_length)
198 | elif "wocm" in name:
199 | dataset_collection = Permutation(dataset_collection, feature_keys)
200 | return Diffusion(dataset_collection, max_length)
201 | else:
202 | raise NotImplementedError()
--------------------------------------------------------------------------------
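dataset_encoding_factory above switches purely on substrings of the encoding name, so the wrappers that get applied are determined by how that string is composed in the config. A minimal standalone sketch of that flag parsing (the example name is assumed for illustration, not taken from a shipped config):

# Illustrative only: which dataset wrappers a given encoding name selects.
name = "cached_diffusion_cosin_angle_objfeats_lat32_wocm"   # assumed example name
flags = {
    "cached":      "cached" in name,       # CachedDatasetCollection (+ optional OrderedDataset)
    "objfeats":    "objfeats" in name,     # append an object-feature key
    "lat32":       "lat32" in name,        # 32-dim latents ("objfeats_32") instead of 64-dim
    "cosin_angle": "cosin_angle" in name,  # Scale_CosinAngle instead of Scale
    "diffusion":   "diffusion" in name,    # wrap in Diffusion at the end
    "wocm":        "wocm" in name,         # apply Permutation before Diffusion (unless "_no_prm")
}
print(flags)
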
/midiffusion/evaluation/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import dataloader
4 | from tqdm import tqdm
5 |
6 | from midiffusion.networks.diffusion_scene_layout_ddpm import DiffusionSceneLayout_DDPM
7 | from midiffusion.networks.diffusion_scene_layout_mixed import DiffusionSceneLayout_Mixed
8 | from midiffusion.datasets.threed_front_encoding import Diffusion
9 |
10 |
11 | def get_feature_mask(network, experiment, num_known_objects, device):
12 | if experiment == "synthesis":
13 | feature_mask = None
14 | print("Experiment: scene synthesis.")
15 | elif experiment == "scene_completion":
16 | assert num_known_objects > 0
17 | feature_mask = torch.zeros(
18 | (network.sample_num_points, network.point_dim + network.class_dim),
19 | dtype=torch.bool, device=device
20 | )
21 | feature_mask[:num_known_objects] = True
22 | print("Experiment: scene completion (given {} objects) using corruption-and-masking."\
23 | .format(num_known_objects))
24 | elif experiment == "furniture_arrangement":
25 | feature_mask = torch.zeros(
26 | network.point_dim + network.class_dim, dtype=torch.bool, device=device
27 | )
28 | feature_mask[network.translation_dim:
29 | network.translation_dim + network.size_dim] = True # size
30 | feature_mask[network.bbox_dim:
31 | network.bbox_dim + network.class_dim] = True # class
32 | feature_mask = feature_mask.repeat(network.sample_num_points, 1)
33 | print("Experiment: furniture arrangement.")
34 | elif experiment == "object_conditioned":
35 | feature_mask = torch.zeros(
36 | network.sample_num_points, network.point_dim + network.class_dim,
37 | dtype=torch.bool, device=device
38 | )
39 | feature_mask[:, network.bbox_dim: network.bbox_dim + network.class_dim] = True # class
40 | print("Experiment: object conditioned synthesis using corruption-and-masking.")
41 | elif experiment == "scene_completion_conditioned":
42 | feature_mask = torch.zeros(
43 | network.sample_num_points, network.point_dim + network.class_dim,
44 | dtype=torch.bool, device=device
45 | )
46 | feature_mask[:num_known_objects] = True # existing objects
47 | feature_mask[:, network.bbox_dim: network.bbox_dim + network.class_dim] = True # class
48 | print("Experiment: scene completion (given {} objects) conditioned on labels using corruption-and-masking."\
49 | .format(num_known_objects))
50 | else:
51 |         raise NotImplementedError
52 | return feature_mask
53 |
54 |
55 | def generate_layouts(network:DiffusionSceneLayout_DDPM, encoded_dataset:Diffusion,
56 | config, num_syn_scenes, sampling_rule="random",
57 | experiment="synthesis", num_known_objects=0,
58 | batch_size=16, device="cpu"):
59 |     """Generate the specified number of object layouts and also return a list of
60 |     scene indices corresponding to the floor plans. Each layout is a 2D array where
61 |     each row contains the concatenated attributes of one object.
62 | (Note: this code assumes "end" is the last object label, and, if used,
63 | "start" is the second to last label.)"""
64 |
65 | # Sample floor layout
66 | if sampling_rule == "random":
67 | sampled_indices = np.random.choice(len(encoded_dataset), num_syn_scenes).tolist()
68 | elif sampling_rule == "uniform":
69 | sampled_indices = np.arange(len(encoded_dataset)).tolist() * \
70 | (num_syn_scenes // len(encoded_dataset))
71 | sampled_indices += \
72 | np.random.choice(len(encoded_dataset),
73 | num_syn_scenes - len(sampled_indices)).tolist()
74 | else:
75 |         raise NotImplementedError
76 |
77 | # network params
78 | with_room_mask = config["network"].get("room_mask_condition", True)
79 | print("Floor condition: {}.".format(with_room_mask))
80 | feature_mask = get_feature_mask(network, experiment, num_known_objects, device)
81 |
82 | # Generate layouts
83 | network.to(device)
84 | network.eval()
85 | layout_list = []
86 | for i in tqdm(range(0, num_syn_scenes, batch_size)):
87 | scene_indices = sampled_indices[i: min(i + batch_size, num_syn_scenes)]
88 |
89 | room_feature = None
90 | if with_room_mask:
91 | if config["feature_extractor"]["name"] == "resnet18":
92 | room_feature = torch.from_numpy(np.stack([
93 | encoded_dataset[ind]["room_layout"] for ind in scene_indices
94 | ], axis=0)).to(device)
95 | elif config["feature_extractor"]["name"] == "pointnet_simple":
96 | room_feature = torch.from_numpy(np.stack([
97 | encoded_dataset[ind]["fpbpn"] for ind in scene_indices
98 | ], axis=0)).to(device)
99 |
100 | if experiment == "synthesis":
101 | input_boxes = None
102 | else:
103 | samples = list(encoded_dataset[ind] for ind in scene_indices)
104 | sample_params = dataloader.default_collate(samples)
105 | input_boxes = network.unpack_data(sample_params).to(device)
106 |
107 | bbox_params_list = network.generate_layout(
108 | room_feature=room_feature,
109 | batch_size=len(scene_indices),
110 | input_boxes=input_boxes,
111 | feature_mask=feature_mask,
112 | device=device,
113 | )
114 | for bbox_params_dict in bbox_params_list:
115 | boxes = encoded_dataset.post_process(bbox_params_dict)
116 | bbox_params = {k: v.numpy()[0] for k, v in boxes.items()}
117 | layout_list.append(bbox_params)
118 |
119 | return sampled_indices, layout_list
120 |
--------------------------------------------------------------------------------
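For reference, here is a toy-dimension sketch of the furniture_arrangement mask that get_feature_mask builds above: size and class channels are marked as known (True) while translations and angles are left for the sampler to generate. All dimensions below are assumptions; the real point_dim may also include object-feature latents.

import torch

translation_dim, size_dim, angle_dim, class_dim = 3, 3, 2, 5   # assumed toy values
bbox_dim = translation_dim + size_dim + angle_dim
point_dim = bbox_dim                 # no objectness / objfeat channels in this toy setup
num_points = 4                       # object slots per scene

mask = torch.zeros(point_dim + class_dim, dtype=torch.bool)
mask[translation_dim:translation_dim + size_dim] = True    # sizes are given
mask[bbox_dim:bbox_dim + class_dim] = True                 # class labels are given
mask = mask.repeat(num_points, 1)                          # (num_points, point_dim + class_dim)
print(mask.shape)                                          # torch.Size([4, 13])
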
/midiffusion/networks/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/tangjiapeng/DiffuScene
4 | #
5 |
6 | import os
7 | import math
8 | import torch
9 | from torch.nn.utils import clip_grad_norm_
10 |
11 | from .feature_extractors import get_feature_extractor
12 | from .diffusion_scene_layout_ddpm import DiffusionSceneLayout_DDPM
13 | from .diffusion_scene_layout_mixed import DiffusionSceneLayout_Mixed
14 | from ..stats_logger import StatsLogger
15 |
16 |
17 | def optimizer_factory(config, parameters):
18 | """Based on the provided config create the suitable optimizer."""
19 | optimizer = config.get("optimizer", "Adam")
20 | lr = config.get("lr", 1e-3)
21 | momentum = config.get("momentum", 0.9)
22 | # weight_decay = config.get("weight_decay", 0.0)
23 | # Weight decay was set to 0.0 in the paper's experiments. We note that
24 | # increasing the weight_decay deteriorates performance.
25 | weight_decay = 0.0
26 |
27 | if optimizer == "SGD":
28 | return torch.optim.SGD(
29 | parameters, lr=lr, momentum=momentum, weight_decay=weight_decay
30 | )
31 | elif optimizer == "Adam":
32 | return torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
33 | elif optimizer == "RAdam":
34 | return torch.optim.RAdam(parameters, lr=lr, weight_decay=weight_decay)
35 | else:
36 | raise NotImplementedError()
37 |
38 |
39 | def train_on_batch(model, optimizer, sample_params, max_grad_norm=None):
40 | # Make sure that everything has the correct size
41 | optimizer.zero_grad()
42 | # Compute the loss
43 | loss, loss_dict = model.get_loss(sample_params)
44 | for k, v in loss_dict.items():
45 | StatsLogger.instance()[k].value = v.item()
46 | # Do the backpropagation
47 | loss.backward()
48 | # Compute model norm
49 | if max_grad_norm is not None:
50 | grad_norm = clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
51 | StatsLogger.instance()["gradnorm"].value = grad_norm.item()
52 | # log learning rate
53 | StatsLogger.instance()["lr"].value = optimizer.param_groups[0]['lr']
54 | # Do the update
55 | optimizer.step()
56 |
57 | return loss.item()
58 |
59 |
60 | @torch.no_grad()
61 | def validate_on_batch(model, sample_params):
62 | # Compute the loss
63 | loss, loss_dict = model.get_loss(sample_params)
64 | for k, v in loss_dict.items():
65 | StatsLogger.instance()[k].value = v.item()
66 | return loss.item()
67 |
68 |
69 | def build_network(n_object_types, config, weight_file=None, device="cpu"):
70 | network_type = config["network"]["type"]
71 |
72 | feature_extractor = get_feature_extractor(
73 | **config["feature_extractor"]
74 | ) if config["network"].get("room_mask_condition", True) else None
75 |
76 | if network_type == "diffusion_scene_layout_ddpm":
77 | network = DiffusionSceneLayout_DDPM(
78 | n_object_types,
79 | feature_extractor,
80 | config["network"],
81 | os.path.join(config["data"]["dataset_directory"],
82 | config["data"]["train_stats"])
83 | )
84 | elif network_type == "diffusion_scene_layout_mixed":
85 | network = DiffusionSceneLayout_Mixed(
86 | n_object_types,
87 | feature_extractor,
88 | config["network"],
89 | os.path.join(config["data"]["dataset_directory"],
90 | config["data"]["train_stats"])
91 | )
92 | else:
93 | raise NotImplementedError()
94 |
95 | # Check whether there is a weight file provided to continue training from
96 | if weight_file is not None:
97 | print("Loading weight file from {}".format(weight_file))
98 | network.load_state_dict(
99 | torch.load(weight_file, map_location=device)
100 | )
101 | network.to(device)
102 | return network, train_on_batch, validate_on_batch
103 |
104 |
105 | # set up learning scheduler
106 | class LearningRateSchedule:
107 | def get_learning_rate(self, epoch):
108 | raise NotImplementedError()
109 |
110 |
111 | class StepLearningRateSchedule(LearningRateSchedule):
112 | def __init__(self, specs):
113 | print(specs)
114 | self.initial = specs["initial"]
115 | self.interval = specs["interval"]
116 | self.factor = specs["factor"]
117 |
118 | def get_learning_rate(self, epoch):
119 | return self.initial * (self.factor ** (epoch // self.interval))
120 |
121 |
122 | class LambdaLearningRateSchedule(LearningRateSchedule):
123 | def __init__(self, specs):
124 | print(specs)
125 | self.start_epoch = specs["start_epoch"]
126 | self.end_epoch = specs["end_epoch"]
127 | self.start_lr = specs["start_lr"]
128 | self.end_lr = specs["end_lr"]
129 |
130 | def lr_func(self, epoch):
131 | if epoch <= self.start_epoch:
132 | return 1.0
133 | elif epoch <= self.end_epoch:
134 | total = self.end_epoch - self.start_epoch
135 | delta = epoch - self.start_epoch
136 | frac = delta / total
137 | return (1-frac) * 1.0 + frac * (self.end_lr / self.start_lr)
138 | else:
139 | return self.end_lr / self.start_lr
140 |
141 | def get_learning_rate(self, epoch):
142 | lambda_factor = self.lr_func(epoch)
143 | return self.start_lr * lambda_factor
144 |
145 |
146 | class WarmupCosineLearningRateSchedule(LearningRateSchedule):
147 | def __init__(self, specs):
148 | print(specs)
149 | self.warmup_epochs = specs["warmup_epochs"]
150 | self.total_epochs = specs["total_epochs"]
151 | self.lr = specs["lr"]
152 | self.min_lr = specs["min_lr"]
153 |
154 | def get_learning_rate(self, epoch):
155 | if epoch <= self.warmup_epochs:
156 | lr = self.lr
157 | else:
158 | lr = self.min_lr + (self.lr - self.min_lr) * 0.5 * \
159 | (1.0 + math.cos(math.pi * (epoch-self.warmup_epochs) / \
160 | (self.total_epochs - self.warmup_epochs)))
161 | return lr
162 |
163 |
164 | def adjust_learning_rate(lr_schedules, optimizer, epoch):
165 |     if isinstance(lr_schedules, list):
166 |         for i, param_group in enumerate(optimizer.param_groups):
167 |             param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)
168 |     else:
169 |         for param_group in optimizer.param_groups:
170 |             param_group["lr"] = lr_schedules.get_learning_rate(epoch)
171 |
172 |
173 | def schedule_factory(config):
174 | """Based on the provided config create the suitable learning schedule."""
175 | schedule = config.get("schedule", "lambda")
176 |
177 | # Set up LearningRateSchedule
178 | if schedule == "step" or schedule == "Step":
179 | lr_schedule = StepLearningRateSchedule({
180 | "type": "step",
181 | "initial" : config.get("lr", 1e-3),
182 | "interval": config.get("lr_step", 100),
183 | "factor" : config.get("lr_decay", 0.1),
184 | },)
185 |
186 | elif schedule == "lambda" or schedule == "Lambda":
187 | lr_schedule = LambdaLearningRateSchedule({
188 | "type": "lambda",
189 | "start_epoch": config.get("start_epoch", 1000),
190 | "end_epoch" : config.get("end_epoch", 1000),
191 | "start_lr" : config.get("start_lr", 0.002),
192 | "end_lr" : config.get("end_lr", 0.002),
193 | },)
194 |
195 | elif schedule == "warmupcosine" or schedule == "WarmupCosine":
196 | lr_schedule = WarmupCosineLearningRateSchedule({
197 | "type": "warmupcosine",
198 | "warmup_epochs" : config.get("warmup_epochs", 10),
199 | "total_epochs" : config.get("total_epochs", 2000),
200 | "lr" : config.get("lr", 2e-4),
201 | "min_lr" : config.get("min_lr", 1e-6),
202 | },)
203 |
204 | else:
205 | raise NotImplementedError()
206 |
207 | return lr_schedule
--------------------------------------------------------------------------------
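A quick way to sanity-check the learning-rate schedules defined above is to query a few epochs directly. This sketch assumes the package has been installed (e.g. pip install -e . via setup.py) so that midiffusion.networks is importable:

from midiffusion.networks import schedule_factory

lr_schedule = schedule_factory({
    "schedule": "warmupcosine",      # assumed example config values
    "warmup_epochs": 10, "total_epochs": 100,
    "lr": 2e-4, "min_lr": 1e-6,
})
for epoch in (0, 10, 55, 100):
    # constant lr during warm-up, then cosine decay towards min_lr
    print(epoch, lr_schedule.get_learning_rate(epoch))
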
/midiffusion/networks/denoising_net/continuous_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .transformer_utils import DenoiseTransformer
3 |
4 |
5 | class ContinuousDenoiseTransformer(DenoiseTransformer):
6 | """Continuous denoising transformer network where all object properties are
7 | treated as continuous"""
8 | def __init__(
9 | self,
10 | network_dim,
11 | seperate_all=True,
12 | n_layer=4,
13 | n_embd=512,
14 | n_head=8,
15 | dim_feedforward=2048,
16 | dropout=0.1,
17 | activate='GELU',
18 | num_timesteps=1000,
19 | timestep_type='adalayernorm_abs',
20 | context_dim=256,
21 | mlp_type='fc',
22 | ):
23 | # initialize self.tf_blocks, the transformer backbone
24 | super().__init__(
25 | n_layer, n_embd, n_head, dim_feedforward, dropout, activate,
26 | num_timesteps, timestep_type, context_dim, mlp_type
27 | )
28 |
29 | # feature dimensions
30 | self.objectness_dim, self.class_dim, self.objfeat_dim = \
31 | network_dim["objectness_dim"], network_dim["class_dim"], \
32 | network_dim["objfeat_dim"]
33 | self.translation_dim, self.size_dim, self.angle_dim = \
34 | network_dim["translation_dim"], network_dim["size_dim"], \
35 | network_dim["angle_dim"]
36 | self.bbox_dim = self.translation_dim + self.size_dim + self.angle_dim
37 | self.channels = self.bbox_dim + self.objectness_dim + self.class_dim + self.objfeat_dim
38 |
39 | # Initial feature specific processing
40 | self.seperate_all = seperate_all
41 | if self.seperate_all:
42 | self.bbox_embedf = self._encoder_mlp(n_embd, self.bbox_dim)
43 | self.bbox_hidden2output = self._decoder_mlp(n_embd, self.bbox_dim)
44 | feature_str = "translation/size/angle"
45 |
46 | if self.class_dim > 0:
47 | self.class_embedf = self._encoder_mlp(n_embd, self.class_dim)
48 | feature_str += "/class"
49 | if self.objectness_dim > 0:
50 | self.objectness_embedf = self._encoder_mlp(n_embd, self.objectness_dim)
51 | feature_str += "/objectness"
52 | if self.objfeat_dim > 0:
53 | self.objfeat_embedf = self._encoder_mlp(n_embd, self.objfeat_dim)
54 | feature_str += "/objfeat"
55 |             print('separate encoder/decoder MLPs for {}'.format(feature_str))
56 |         else:
57 |             self.init_mlp = self._encoder_mlp(n_embd, self.channels)
58 |             print('single encoder MLP for all object properties')
59 |
60 | # Final feature specific processing
61 | if self.seperate_all:
62 | self.bbox_hidden2output = self._decoder_mlp(n_embd, self.bbox_dim)
63 | if self.class_dim > 0:
64 | self.class_hidden2output = self._decoder_mlp(n_embd, self.class_dim)
65 | if self.objectness_dim > 0:
66 | self.objectness_hidden2output = self._decoder_mlp(n_embd, self.objectness_dim)
67 | if self.objfeat_dim > 0:
68 | self.objfeat_hidden2output = self._decoder_mlp(n_embd, self.objfeat_dim)
69 | else:
70 | self.hidden2output = self._decoder_mlp(n_embd, self.channels)
71 |
72 | def forward(self, x, time, context=None, context_cross=None):
73 | # x: (B, N, C)
74 | if context_cross is not None:
75 |             raise NotImplementedError  # TODO
76 |
77 | # initial processing
78 | if self.seperate_all:
79 | x_bbox = self.bbox_embedf(x[:, :, 0:self.bbox_dim])
80 |
81 | if self.class_dim > 0:
82 | start_index = self.bbox_dim
83 | x_class = self.class_embedf(
84 | x[:, :, start_index:start_index+self.class_dim]
85 | )
86 | else:
87 | x_class = 0
88 |
89 | if self.objectness_dim > 0:
90 | start_index = self.bbox_dim+self.class_dim
91 | x_object = self.objectness_embedf(
92 | x[:, :, start_index:start_index+self.objectness_dim]
93 | )
94 | else:
95 | x_object = 0
96 |
97 | if self.objfeat_dim > 0:
98 | start_index = self.bbox_dim+self.class_dim+self.objectness_dim
99 | x_objfeat = self.objfeat_embedf(
100 | x[:, :, start_index:start_index+self.objfeat_dim]
101 | )
102 | else:
103 | x_objfeat = 0
104 |
105 | x = x_bbox + x_class + x_object + x_objfeat
106 | else:
107 | x = self.init_mlp(x)
108 |
109 | # transformer
110 | for block_idx in range(len(self.tf_blocks)):
111 | x = self.tf_blocks[block_idx](
112 | x, time, context
113 | )
114 |
115 | # final processing
116 | if self.seperate_all:
117 | out = self.bbox_hidden2output(x)
118 | if self.class_dim > 0:
119 | out_class = self.class_hidden2output(x)
120 | out = torch.cat([out, out_class], dim=2).contiguous()
121 | if self.objectness_dim > 0:
122 | out_object = self.objectness_hidden2output(x)
123 | out = torch.cat([out, out_object], dim=2).contiguous()
124 | if self.objfeat_dim > 0:
125 | out_objfeat = self.objfeat_hidden2output(x)
126 | out = torch.cat([out, out_objfeat], dim=2).contiguous()
127 | else:
128 | out = self.hidden2output(x)
129 |
130 | return out
--------------------------------------------------------------------------------
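A minimal forward-pass sketch of the continuous denoiser, with toy dimensions and context_dim=0 (self-attention only) so that no floor-plan feature is needed. All sizes here are assumptions for illustration, not the trained configuration:

import torch
from midiffusion.networks.denoising_net.continuous_transformer import ContinuousDenoiseTransformer

network_dim = {"objectness_dim": 0, "class_dim": 22, "objfeat_dim": 0,
               "translation_dim": 3, "size_dim": 3, "angle_dim": 2}
net = ContinuousDenoiseTransformer(network_dim, n_layer=2, n_embd=64, n_head=4,
                                   dim_feedforward=128, context_dim=0)

B, N = 2, 12                         # 2 scenes, 12 object slots each
C = 3 + 3 + 2 + 22                   # bbox_dim + class_dim
x_t = torch.randn(B, N, C)           # noisy per-object features
t = torch.randint(0, 1000, (B,))     # one diffusion timestep per scene
print(net(x_t, t).shape)             # torch.Size([2, 12, 30])
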
/midiffusion/networks/denoising_net/mixed_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .transformer_utils import DenoiseTransformer
4 |
5 |
6 | class MixedDenoiseTransformer(DenoiseTransformer):
7 | """Mixed denoising transformer network where class labels are treated as
8 | discrete variables and (optional) geometric features are treated as
9 | continuous variables."""
10 | def __init__(
11 | self,
12 | network_dim,
13 | seperate_all=True,
14 | n_layer=4,
15 | n_embd=512,
16 | n_head=8,
17 | dim_feedforward=2048,
18 | dropout=0.1,
19 | activate='GELU',
20 | num_timesteps=1000,
21 | timestep_type='adalayernorm_abs',
22 | context_dim=256,
23 | mlp_type='fc',
24 | concate_features=False,
25 | ):
26 | # initialize self.tf_blocks, the transformer backbone
27 | super().__init__(
28 | n_layer, n_embd, n_head, dim_feedforward, dropout, activate,
29 | num_timesteps, timestep_type, context_dim, mlp_type
30 | )
31 | assert network_dim["class_dim"] > 0
32 |
33 | # feature dimensions
34 | self.objectness_dim, self.class_dim, self.objfeat_dim = \
35 | network_dim["objectness_dim"], network_dim["class_dim"], \
36 | network_dim["objfeat_dim"]
37 | self.translation_dim, self.size_dim, self.angle_dim = \
38 | network_dim["translation_dim"], network_dim["size_dim"], \
39 | network_dim["angle_dim"]
40 | self.bbox_dim = self.translation_dim + self.size_dim + self.angle_dim
41 | self.geo_dim = self.bbox_dim + self.objectness_dim + self.objfeat_dim
42 |
43 | # Feature specific processing
44 | self.concate_features = concate_features
45 | if concate_features:
46 |             n_features = 2 + (self.objectness_dim > 0) + (self.objfeat_dim > 0)
47 | feat_embd = n_embd // n_features
48 | geo_embd = feat_embd * (n_features - 1)
49 | class_embd = n_embd - geo_embd
50 | decode_embd = n_embd + network_dim["class_dim"] # geometric decoder
51 | print("concatenate features (class embd: {}, geometric embd: {})"
52 | .format(class_embd, feat_embd))
53 | else:
54 | feat_embd = n_embd
55 | geo_embd = n_embd
56 | class_embd = n_embd
57 | decode_embd = n_embd
58 |
59 | # semantic feature - add additional [mask] embedding
60 | self.class_emb = nn.Embedding(network_dim["class_dim"] + 1, class_embd)
61 | self.to_logits = nn.Sequential(
62 | nn.LayerNorm(n_embd),
63 | nn.Linear(n_embd, network_dim["class_dim"]),
64 | )
65 |
66 | # geometric features
67 | self.seperate_all = seperate_all
68 | if self.seperate_all:
69 | if self.bbox_dim > 0:
70 | self.bbox_embedf = self._encoder_mlp(feat_embd, self.bbox_dim)
71 | self.bbox_hidden2output = self._decoder_mlp(decode_embd, self.bbox_dim)
72 | if self.objectness_dim > 0:
73 | self.objectness_embedf = self._encoder_mlp(feat_embd, self.objectness_dim)
74 | self.objectness_hidden2output = self._decoder_mlp(decode_embd, self.objectness_dim)
75 | if self.objfeat_dim > 0:
76 | self.objfeat_embedf = self._encoder_mlp(feat_embd, self.objfeat_dim)
77 | self.objfeat_hidden2output = self._decoder_mlp(decode_embd, self.objfeat_dim)
78 | elif self.geo_dim > 0:
79 | self.init_mlp = self._encoder_mlp(geo_embd, self.geo_dim)
80 | self.geo_hidden2output = self._decoder_mlp(decode_embd, self.geo_dim)
81 |
82 | def forward(self, x_semantic, x_geometry, time, context=None, context_cross=None):
83 | B, N = x_semantic.shape
84 | if context_cross is not None:
85 |             raise NotImplementedError  # TODO
86 |
87 | # initial processing
88 | x_class = self.class_emb(x_semantic)
89 | if self.seperate_all:
90 | if self.bbox_dim > 0:
91 | x_geo = self.bbox_embedf(x_geometry[:, :, 0:self.bbox_dim])
92 | else:
93 | x_geo = torch.empty(size=(B, N, 0), device=x_semantic.device)
94 | if self.objectness_dim > 0:
95 | start_index = self.bbox_dim
96 | x_object = self.objectness_embedf(
97 | x_geometry[:, :, start_index:start_index+self.objectness_dim]
98 | )
99 | if self.concate_features:
100 | x_geo = torch.cat([x_geo, x_object], dim=2).contiguous()
101 | else:
102 | x_geo += x_object
103 | if self.objfeat_dim > 0:
104 | start_index = self.bbox_dim+self.objectness_dim
105 | x_objfeat = self.objfeat_embedf(
106 | x_geometry[:, :, start_index:start_index+self.objfeat_dim]
107 | )
108 | if self.concate_features:
109 | x_geo = torch.cat([x_geo, x_objfeat], dim=2).contiguous()
110 | else:
111 | x_geo += x_objfeat
112 | elif self.geo_dim > 0:
113 | x_geo = self.init_mlp(x_geometry)
114 |
115 | if self.geo_dim > 0:
116 | if self.concate_features:
117 | x = torch.cat([x_class, x_geo], dim=2).contiguous()
118 | else:
119 | x = x_class + x_geo
120 | else:
121 | x = x_class
122 |
123 | # transformer
124 | for block_idx in range(len(self.tf_blocks)):
125 | x = self.tf_blocks[block_idx](x, time, context)
126 |
127 | # final processing
128 | out_class = self.to_logits(x)
129 | if self.concate_features:
130 | x = torch.cat([x, out_class], dim=2).contiguous()
131 | if self.seperate_all:
132 | if self.bbox_dim > 0:
133 | out = self.bbox_hidden2output(x)
134 | else:
135 | out = torch.empty(size=(B, N, 0), device=x.device)
136 | if self.objectness_dim > 0:
137 | out_object = self.objectness_hidden2output(x)
138 | out = torch.cat([out, out_object], dim=2).contiguous()
139 | if self.objfeat_dim > 0:
140 | out_objfeat = self.objfeat_hidden2output(x)
141 | out = torch.cat([out, out_objfeat], dim=2).contiguous()
142 | elif self.geo_dim > 0:
143 | out = self.geo_hidden2output(x)
144 | else:
145 | out = None
146 |
147 | return out_class, out
148 |
--------------------------------------------------------------------------------
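The mixed denoiser takes class labels as integer indices (index class_dim is reserved for the [mask] token) and the geometric attributes as a separate continuous tensor. A toy-dimension smoke test, again with assumed sizes and no floor-plan context:

import torch
from midiffusion.networks.denoising_net.mixed_transformer import MixedDenoiseTransformer

network_dim = {"objectness_dim": 0, "class_dim": 22, "objfeat_dim": 0,
               "translation_dim": 3, "size_dim": 3, "angle_dim": 2}
net = MixedDenoiseTransformer(network_dim, n_layer=2, n_embd=64, n_head=4,
                              dim_feedforward=128, context_dim=0)

B, N = 2, 12
x_semantic = torch.randint(0, 23, (B, N))    # 22 classes + [mask] token (index 22)
x_geometry = torch.randn(B, N, 3 + 3 + 2)    # translation / size / angle channels
t = torch.randint(0, 1000, (B,))
logits, geo = net(x_semantic, x_geometry, t)
print(logits.shape, geo.shape)               # (2, 12, 22) and (2, 12, 8)
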
/midiffusion/networks/denoising_net/time_embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn, Tensor
4 | from einops import rearrange
5 |
6 |
7 | class SinusoidalPosEmb(nn.Module):
8 | """https://github.com/microsoft/VQ-Diffusion/blob/main/image_synthesis/modeling/transformers/transformer_utils.py"""
9 | def __init__(self, dim: int, num_steps: int=4000, rescale_steps: int=4000):
10 | super().__init__()
11 | self.dim = dim
12 | if num_steps != rescale_steps:
13 | self.num_steps = float(num_steps)
14 | self.rescale_steps = float(rescale_steps)
15 | self.input_scaling = True
16 | else:
17 | self.input_scaling = False
18 |
19 | def forward(self, x: Tensor):
20 | # (B) -> (B, self.dim)
21 | if self.input_scaling:
22 | x = x / self.num_steps * self.rescale_steps
23 | device = x.device
24 | half_dim = self.dim // 2
25 | emb = math.log(10000) / (half_dim - 1)
26 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
27 | emb = x[:, None] * emb[None, :]
28 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
29 | return emb
30 |
31 |
32 | class RandomOrLearnedSinusoidalPosEmb(nn.Module):
33 | """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """
34 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
35 |
36 | def __init__(self, dim, is_random = False):
37 | super().__init__()
38 | assert (dim % 2) == 0
39 | half_dim = dim // 2
40 | self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random)
41 |
42 | def forward(self, x):
43 | x = rearrange(x, 'b -> b 1')
44 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
45 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
46 | fouriered = torch.cat((x, fouriered), dim = -1)
47 | return fouriered
48 |
--------------------------------------------------------------------------------
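A tiny check of the sinusoidal timestep embedding (dimensions assumed): timesteps are rescaled from num_steps onto rescale_steps before encoding, and the output concatenates a sine half and a cosine half:

import torch
from midiffusion.networks.denoising_net.time_embedding import SinusoidalPosEmb

emb = SinusoidalPosEmb(dim=8, num_steps=1000)   # rescales t from [0, 1000) onto [0, 4000)
t = torch.tensor([0.0, 10.0, 999.0])
print(emb(t).shape)                             # torch.Size([3, 8])
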
/midiffusion/networks/denoising_net/transformer_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/cientgu/VQ-Diffusion/blob/main/image_synthesis/modeling/transformers/transformer_utils.py
4 | #
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from einops.layers.torch import Rearrange
9 | from torch import Tensor, nn
10 | from .time_embedding import SinusoidalPosEmb
11 |
12 | LAYER_NROM_EPS = 1e-5 # pytorch's default: 1e-5
13 |
14 |
15 | class GELU2(nn.Module):
16 | def __init__(self):
17 | super().__init__()
18 | def forward(self, x):
19 |         return x * torch.sigmoid(1.702 * x)
20 |
21 |
22 | class _AdaNorm(nn.Module):
23 |     """Base normalization layer that incorporates timestep embeddings"""
24 | def __init__(
25 | self, n_embd: int, max_timestep: int, emb_type: str="adalayernorm_abs"
26 | ):
27 | super().__init__()
28 | assert n_embd % 2 == 0
29 | if "abs" in emb_type:
30 | self.emb = SinusoidalPosEmb(n_embd, num_steps=max_timestep)
31 | elif "mlp" in emb_type:
32 | self.emb = nn.Sequential(
33 | Rearrange("b -> b 1"),
34 | nn.Linear(1, n_embd // 2),
35 | nn.ReLU(),
36 | nn.Linear(n_embd // 2, n_embd),
37 | )
38 | else:
39 | self.emb = nn.Embedding(max_timestep, n_embd)
40 | self.silu = nn.SiLU()
41 | self.linear = nn.Linear(n_embd, n_embd * 2)
42 |
43 |
44 | class AdaLayerNorm(_AdaNorm):
45 | """Norm layer modified to incorporate timestep embeddings"""
46 | def __init__(
47 | self, n_embd: int, max_timestep: int, emb_type: str="adalayernorm_abs"
48 | ):
49 | super().__init__(n_embd, max_timestep, emb_type)
50 | self.layernorm = nn.LayerNorm(n_embd, eps=LAYER_NROM_EPS, elementwise_affine=False)
51 |
52 | def forward(self, x: Tensor, timestep: Tensor):
53 | # (B, N, n_embd),(B,) -> (B, N, n_embd)
54 | emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) # B, 1, 2*n_embd
55 | scale, shift = torch.chunk(emb, 2, dim=2)
56 | x = self.layernorm(x) * (1 + scale) + shift
57 | return x
58 |
59 |
60 | class AdaInsNorm(_AdaNorm):
61 |     """Instance normalization layer modified to incorporate timestep embeddings"""
62 | def __init__(
63 | self, n_embd: int, max_timestep: int, emb_type: str="adainsnorm_abs"
64 | ):
65 | super().__init__(n_embd, max_timestep, emb_type)
66 | self.instancenorm = nn.InstanceNorm1d(n_embd, eps=LAYER_NROM_EPS)
67 |
68 | def forward(self, x: Tensor, timestep: Tensor):
69 | # (B, N, n_embd),(B,) -> (B, N, n_embd)
70 | emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) # B, 1, 2*n_embd
71 | scale, shift = torch.chunk(emb, 2, dim=2)
72 | x = self.instancenorm(x.transpose(-1, -2)).transpose(-1, -2) * (1 + scale) \
73 | + shift
74 | return x
75 |
76 |
77 | class SelfAttention(nn.Module):
78 | """Multi-head self attention block - in transformer encoder"""
79 | def __init__(self, n_embd, n_head, dropout=0.1, batch_first=True):
80 | super().__init__()
81 | assert n_embd % n_head == 0
82 | self.mha = nn.MultiheadAttention(
83 | n_embd, n_head, dropout, batch_first=batch_first
84 | )
85 |
86 | def forward(self, x, attn_mask=None, key_padding_mask=None):
87 | return self.mha(
88 | x, x, x, need_weights=False,
89 | attn_mask=attn_mask, key_padding_mask=key_padding_mask
90 | )[0]
91 |
92 |
93 | class CrossAttention(nn.Module):
94 | """Multi-head cross attention block - in transformer decoder"""
95 | def __init__(self, n_embd, n_head, dropout=0.1, batch_first=True, kv_embd=None):
96 | super().__init__()
97 | assert n_embd % n_head == 0
98 | self.mha = nn.MultiheadAttention(
99 | n_embd, n_head, dropout, batch_first=batch_first, kdim=kv_embd, vdim=kv_embd
100 | )
101 |
102 | def forward(self, q, kv, attn_mask=None, key_padding_mask=None):
103 | return self.mha(
104 | query=q, key=kv, value=kv, need_weights=False,
105 | attn_mask=attn_mask, key_padding_mask=key_padding_mask
106 | )[0]
107 |
108 |
109 | class Block(nn.Module):
110 | """ Time-conditioned transformer block """
111 | def __init__(self,
112 | n_embd=512,
113 | n_head=8,
114 | dim_feedforward=2048,
115 | dropout=0.1,
116 | activate='GELU',
117 | num_timesteps=1000,
118 | timestep_type='adalayernorm_abs',
119 | attn_type='self',
120 | num_labels=None, # attn_type = 'selfcondition'
121 | label_type='adalayernorm', # attn_type = 'selfcondition'
122 | cond_emb_dim=None, # attn_type = 'selfcross'
123 | mlp_type = 'fc',
124 | ):
125 | super().__init__()
126 | self.attn_type = attn_type
127 |
128 | if "adalayernorm" in timestep_type:
129 | self.ln1 = AdaLayerNorm(n_embd, num_timesteps, timestep_type)
130 | elif "adainnorm" in timestep_type:
131 | self.ln1 = AdaInsNorm(n_embd, num_timesteps, timestep_type)
132 | else:
133 | raise ValueError(f"timestep_type={timestep_type} not valid.")
134 |
135 | if attn_type == 'self':
136 | self.attn = SelfAttention(
137 | n_embd=n_embd, n_head=n_head, dropout=dropout
138 | )
139 | self.ln2 = nn.LayerNorm(n_embd)
140 | elif attn_type == 'selfcondition': # conditioned on int labels
141 | self.attn = SelfAttention(
142 | n_embd=n_embd, n_head=n_head, dropout=dropout
143 | )
144 | if 'adalayernorm' in label_type:
145 | self.ln2 = AdaLayerNorm(n_embd, num_labels, label_type)
146 | else:
147 | self.ln2 = AdaInsNorm(n_embd, num_labels, label_type)
148 | elif attn_type == 'selfcross': # cross attention with cond_emb
149 | self.attn1 = SelfAttention(
150 | n_embd=n_embd, n_head=n_head, dropout=dropout
151 | )
152 | self.attn2 = CrossAttention(
153 | n_embd=n_embd, n_head=n_head, dropout=dropout, kv_embd=cond_emb_dim,
154 | )
155 | if 'adalayernorm' in timestep_type:
156 | self.ln1_1 = AdaLayerNorm(n_embd, num_timesteps, timestep_type)
157 | else:
158 | raise ValueError(f"timestep_type={timestep_type} not valid.")
159 | self.ln2 = nn.LayerNorm(n_embd)
160 | else:
161 | raise ValueError(f"attn_type={attn_type} not valid.")
162 |
163 | assert activate in ['GELU', 'GELU2']
164 | act = nn.GELU() if activate == 'GELU' else GELU2()
165 | if mlp_type == 'fc':
166 | self.mlp = nn.Sequential(
167 | nn.Linear(n_embd, dim_feedforward),
168 | act,
169 | nn.Linear(dim_feedforward, n_embd),
170 | nn.Dropout(dropout),
171 | )
172 | else:
173 |             raise NotImplementedError
174 |
175 | def forward(self, x, timestep, cond_output=None, mask=None):
176 | if self.attn_type == "self":
177 | x = x + self.attn(self.ln1(x, timestep), attn_mask=mask)
178 | x = x + self.mlp(self.ln2(x))
179 | elif self.attn_type == "selfcondition":
180 | x = x + self.attn(self.ln1(x, timestep), attn_mask=mask)
181 | x = x + self.mlp(self.ln2(x, cond_output))
182 | elif self.attn_type == "selfcross":
183 | x = x + self.attn1(self.ln1(x, timestep), attn_mask=mask)
184 | x = x + self.attn2(self.ln1_1(x, timestep), cond_output, attn_mask=mask)
185 | x = x + self.mlp(self.ln2(x))
186 | else:
187 |             raise NotImplementedError
188 | return x
189 |
190 |
191 | class DenoiseTransformer(nn.Module):
192 | """Base denoising transformer class"""
193 | def __init__(
194 | self,
195 | n_layer=4,
196 | n_embd=512,
197 | n_head=8,
198 | dim_feedforward=2048,
199 | dropout=0.1,
200 | activate='GELU',
201 | num_timesteps=1000,
202 | timestep_type='adalayernorm_abs',
203 | context_dim=256,
204 | mlp_type='fc',
205 | ):
206 | super().__init__()
207 |
208 | # transformer backbone
209 | if context_dim == 0:
210 | self.tf_blocks = nn.Sequential(*[Block(
211 | n_embd, n_head, dim_feedforward, dropout, activate,
212 | num_timesteps, timestep_type, mlp_type=mlp_type,
213 | attn_type='self',
214 | ) for _ in range(n_layer)])
215 | else:
216 | self.tf_blocks = nn.Sequential(*[Block(
217 | n_embd, n_head, dim_feedforward, dropout, activate,
218 | num_timesteps, timestep_type, mlp_type=mlp_type,
219 | attn_type='selfcross', cond_emb_dim=context_dim,
220 | ) for _ in range(n_layer)])
221 |
222 | @staticmethod
223 | def _encoder_mlp(hidden_size, input_size):
224 | mlp_layers = [
225 | nn.Linear(input_size, hidden_size),
226 | nn.GELU(),
227 | nn.Linear(hidden_size, hidden_size*2),
228 | nn.GELU(),
229 | nn.Linear(hidden_size*2, hidden_size),
230 | ]
231 | return nn.Sequential(*mlp_layers)
232 |
233 | @staticmethod
234 | def _decoder_mlp(hidden_size, output_size):
235 | mlp_layers = [
236 | nn.Linear(hidden_size, hidden_size*2),
237 | nn.GELU(),
238 | nn.Linear(hidden_size*2, hidden_size),
239 | nn.GELU(),
240 | nn.Linear(hidden_size, output_size),
241 | ]
242 | return nn.Sequential(*mlp_layers)
243 |
--------------------------------------------------------------------------------
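The denoisers above are stacks of the time-conditioned Block; AdaLayerNorm turns the timestep embedding into a per-channel scale and shift that modulates the normalized activations. A standalone smoke test with assumed toy sizes:

import torch
from midiffusion.networks.denoising_net.transformer_utils import Block

block = Block(n_embd=64, n_head=4, dim_feedforward=128,
              num_timesteps=1000, attn_type='self')
x = torch.randn(2, 12, 64)           # (B, N, n_embd)
t = torch.randint(0, 1000, (2,))     # one timestep per batch element
print(block(x, t).shape)             # torch.Size([2, 12, 64])
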
/midiffusion/networks/denoising_net/unet1D.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/tangjiapeng/DiffuScene
4 | #
5 |
6 | from functools import partial
7 | import torch
8 | from torch import nn, einsum
9 | import torch.nn.functional as F
10 | from einops import rearrange, reduce
11 | from .time_embedding import SinusoidalPosEmb, RandomOrLearnedSinusoidalPosEmb
12 |
13 | # helpers functions
14 |
15 | def exists(x):
16 | return x is not None
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if callable(d) else d
23 |
24 |
25 | class Residual(nn.Module):
26 | def __init__(self, fn):
27 | super().__init__()
28 | self.fn = fn
29 |
30 | def forward(self, x, *args, **kwargs):
31 | return self.fn(x, *args, **kwargs) + x
32 |
33 |
34 | class ResidualCross(nn.Module):
35 | def __init__(self, fn):
36 | super().__init__()
37 | self.fn = fn
38 |
39 | def forward(self, x, context, *args, **kwargs):
40 | return self.fn(x, context, *args, **kwargs) + x
41 |
42 |
43 | def Upsample(dim, dim_out=None):
44 | if dim_out is None or dim == dim_out:
45 | return nn.Identity()
46 | else:
47 | return nn.Sequential(
48 | #nn.Upsample(scale_factor = 2, mode = 'nearest'),
49 |
50 | nn.Conv1d(dim, default(dim_out, dim), 1)
51 | )
52 |
53 |
54 | def Downsample(dim, dim_out=None):
55 | if dim_out is None or dim == dim_out:
56 | return nn.Identity()
57 | else:
58 | return nn.Sequential(
59 | #return nn.Conv1d(dim, default(dim_out, dim), 4, 2, 1)
60 |
61 | nn.Conv1d(dim, default(dim_out, dim), 1)
62 | )
63 |
64 | # ResNet block
65 |
66 | class WeightStandardizedConv1d(nn.Conv1d):
67 | """
68 | https://arxiv.org/abs/1903.10520
69 | weight standardization purportedly works synergistically with group normalization
70 | """
71 | def forward(self, x):
72 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
73 |
74 | weight = self.weight
75 | mean = reduce(weight, 'o ... -> o 1 1', 'mean')
76 | var = reduce(weight, 'o ... -> o 1 1', partial(torch.var, unbiased=False))
77 | normalized_weight = (weight - mean) * (var + eps).rsqrt()
78 |
79 | return F.conv1d(x, normalized_weight, self.bias, self.stride,
80 | self.padding, self.dilation, self.groups)
81 |
82 |
83 | class Block(nn.Module):
84 | def __init__(self, dim, dim_out, groups=8):
85 | super().__init__()
86 | self.proj = WeightStandardizedConv1d(dim, dim_out, 1, padding=0) # 3-->1
87 | self.norm = nn.GroupNorm(groups, dim_out)
88 | self.act = nn.SiLU()
89 |
90 | def forward(self, x, scale_shift=None):
91 | x = self.proj(x)
92 | x = self.norm(x)
93 |
94 | if exists(scale_shift):
95 | scale, shift = scale_shift
96 | x = x * (scale + 1) + shift
97 |
98 | x = self.act(x)
99 | return x
100 |
101 |
102 | class ResnetBlock(nn.Module):
103 | """https://arxiv.org/abs/1512.03385"""
104 |
105 | def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
106 | super().__init__()
107 | self.mlp = nn.Sequential(
108 | nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)
109 | ) if exists(time_emb_dim) else None
110 |
111 | self.block1 = Block(dim, dim_out, groups=groups)
112 | self.block2 = Block(dim_out, dim_out, groups=groups)
113 | self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
114 |
115 | def forward(self, x, time_emb=None):
116 | scale_shift = None
117 | if exists(self.mlp) and exists(time_emb):
118 | time_emb = self.mlp(time_emb)
119 | if len(time_emb.shape) == 2:
120 | time_emb = rearrange(time_emb, 'b c -> b c 1')
121 | else:
122 | time_emb = rearrange(time_emb, 'b n c -> b c n')
123 | scale_shift = time_emb.chunk(2, dim=1)
124 |
125 | h = self.block1(x, scale_shift=scale_shift)
126 | h = self.block2(h)
127 |
128 | return h + self.res_conv(x)
129 |
130 |
131 | # Attention module
132 |
133 | class Attention(nn.Module):
134 | def __init__(self, dim, heads=4, dim_head=32):
135 | super().__init__()
136 | self.scale = dim_head ** -0.5
137 | self.heads = heads
138 | hidden_dim = dim_head * heads
139 |
140 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
141 | self.to_out = nn.Conv1d(hidden_dim, dim, 1)
142 |
143 | def forward(self, x):
144 | b, c, n = x.shape
145 | qkv = self.to_qkv(x).chunk(3, dim=1)
146 | q, k, v = map(
147 | lambda t: rearrange(t, 'b (h c) n -> b h c n', h=self.heads), qkv
148 | )
149 |
150 | q = q * self.scale
151 |
152 | sim = einsum('b h d i, b h d j -> b h i j', q, k)
153 | # sim = sim - sim.amax(dim=-1, keepdim=True).detach()
154 | attn = sim.softmax(dim=-1)
155 |
156 | out = einsum('b h i j, b h d j -> b h i d', attn, v)
157 | out = rearrange(out, 'b h n d -> b (h d) n')
158 | return self.to_out(out)
159 |
160 |
161 | class LinearAttention(nn.Module):
162 | def __init__(self, dim, heads=4, dim_head=32):
163 | super().__init__()
164 | self.scale = dim_head ** -0.5
165 | self.heads = heads
166 | hidden_dim = dim_head * heads
167 | self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
168 |
169 | self.to_out = nn.Sequential(
170 | nn.Conv1d(hidden_dim, dim, 1), LayerNorm(dim)
171 | )
172 |
173 | def forward(self, x):
174 | b, c, n = x.shape
175 | qkv = self.to_qkv(x).chunk(3, dim = 1)
176 | q, k, v = map(
177 | lambda t: rearrange(t, 'b (h c) n -> b h c n', h=self.heads), qkv
178 | )
179 |
180 | q = q.softmax(dim=-2)
181 | k = k.softmax(dim=-1)
182 |
183 | q = q * self.scale
184 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
185 |
186 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
187 | out = rearrange(out, 'b h c n -> b (h c) n', h=self.heads)
188 | return self.to_out(out)
189 |
190 |
191 | class LinearAttentionCross(nn.Module):
192 | def __init__(self, dim, context_dim=None, heads = 4, dim_head = 32):
193 | super().__init__()
194 | self.scale = dim_head ** -0.5
195 | self.heads = heads
196 | hidden_dim = dim_head * heads
197 |
198 | if context_dim is None:
199 | context_dim = dim
200 | self.to_q = nn.Conv1d(dim, hidden_dim, 1, bias = False)
201 | self.to_kv = nn.Conv1d(context_dim, hidden_dim*2, 1, bias = False)
202 |
203 | self.to_out = nn.Sequential(
204 | nn.Conv1d(hidden_dim, dim, 1),
205 | LayerNorm(dim)
206 | )
207 |
208 | def forward(self, x, context):
209 | b, c, n = x.shape
210 | q = self.to_q(x)
211 | kv = self.to_kv(context).chunk(2, dim = 1)
212 | q = rearrange(q, 'b (h c) n -> b h c n', h = self.heads)
213 | k, v = map(
214 | lambda t: rearrange(t, 'b (h c) n -> b h c n', h = self.heads), kv
215 | )
216 |
217 | q = q.softmax(dim = -2)
218 | k = k.softmax(dim = -1)
219 |
220 | q = q * self.scale
221 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
222 |
223 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
224 | out = rearrange(out, 'b h c n -> b (h c) n', h = self.heads)
225 | return self.to_out(out)
226 |
227 |
228 | # Group normalization
229 |
230 | class LayerNorm(nn.Module): # similar to nn.GroupNorm without bias param
231 | def __init__(self, dim):
232 | super().__init__()
233 | self.g = nn.Parameter(torch.ones(1, dim, 1))
234 |
235 | def forward(self, x):
236 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
237 | var = torch.var(x, dim=1, unbiased=False, keepdim=True)
238 | mean = torch.mean(x, dim=1, keepdim=True)
239 | return (x - mean) * (var + eps).rsqrt() * self.g
240 |
241 |
242 | class PreNorm(nn.Module):
243 | def __init__(self, dim, fn):
244 | super().__init__()
245 | self.fn = fn
246 | self.norm = LayerNorm(dim)
247 |
248 | def forward(self, x):
249 | x = self.norm(x)
250 | return self.fn(x)
251 |
252 |
253 | class PreNormCross(nn.Module):
254 | def __init__(self, dim, fn):
255 | super().__init__()
256 | self.fn = fn
257 | self.norm = LayerNorm(dim)
258 |
259 | def forward(self, x, context):
260 | x = self.norm(x)
261 | return self.fn(x, context)
262 |
263 |
264 | # Conditional U-Net
265 |
266 | class Unet1D(nn.Module):
267 | def __init__(
268 | self,
269 | network_dim,
270 | dim=256,
271 | init_dim = None, # default: dim
272 | out_dim = None,
273 | dim_mults=(1, 2, 4, 8),
274 | channels = 3, # ignored if seperate_all=True
275 | seperate_all=False,
276 | context_dim = 256,
277 | cross_condition=False,
278 | cross_condition_dim=256,
279 | resnet_block_groups = 8,
280 | learned_variance = False,
281 | learned_sinusoidal_cond = False,
282 | random_fourier_features = False,
283 | learned_sinusoidal_dim = 16
284 | ):
285 | super().__init__()
286 |
287 | # model flags
288 | self.cross_condition = cross_condition
289 | self.seperate_all = seperate_all
290 |
291 | # feature dimensions
292 | self.objectness_dim, self.class_dim, self.objfeat_dim = \
293 | network_dim["objectness_dim"], network_dim["class_dim"], \
294 | network_dim["objfeat_dim"]
295 | self.translation_dim, self.size_dim, self.angle_dim = \
296 | network_dim["translation_dim"], network_dim["size_dim"], \
297 | network_dim["angle_dim"]
298 | self.bbox_dim = self.translation_dim + self.size_dim + self.angle_dim
299 | self.context_dim = context_dim
300 | if cross_condition:
301 | self.cross_condition_dim = cross_condition_dim
302 |
303 | # Initial feature specific processing
304 | if self.seperate_all:
305 | self.bbox_embedf = Unet1D._encoder_mlp(dim, self.bbox_dim)
306 | self.bbox_hidden2output = Unet1D._decoder_mlp(dim, self.bbox_dim)
307 | feature_str = "translation/size/angle"
308 |
309 | if self.class_dim > 0:
310 | self.class_embedf = Unet1D._encoder_mlp(dim, self.class_dim)
311 | feature_str += "/class"
312 | if self.objectness_dim > 0:
313 | self.objectness_embedf = Unet1D._encoder_mlp(dim, self.objectness_dim)
314 | feature_str += "/objectness"
315 | if self.objfeat_dim > 0:
316 | self.objfeat_embedf = Unet1D._encoder_mlp(dim, self.objfeat_dim)
317 | feature_str += "/objfeat"
318 |
319 | input_channels = dim
320 | print('separate unet1d encoder/decoder of {}'.format(feature_str))
321 | else:
322 | input_channels = channels
323 | print('unet1d encoder of all object properties')
324 |
325 | # U-Net initialization
326 | init_dim = default(init_dim, dim)
327 | self.init_conv = nn.Conv1d(input_channels, init_dim, 1) #nn.Conv1d(input_channels, init_dim, 7, padding = 3)
328 |
329 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
330 | in_out = list(zip(dims[:-1], dims[1:]))
331 |
332 | block_klass = partial(ResnetBlock, groups=resnet_block_groups)
333 |
334 | # time embeddings
335 | time_dim = dim * 4
336 |
337 | self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features
338 | if self.random_or_learned_sinusoidal_cond:
339 | sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(
340 | learned_sinusoidal_dim, random_fourier_features
341 | )
342 | fourier_dim = learned_sinusoidal_dim + 1
343 | else:
344 | sinu_pos_emb = SinusoidalPosEmb(dim)
345 | fourier_dim = dim
346 |
347 | self.time_mlp = nn.Sequential(
348 | sinu_pos_emb,
349 | nn.Linear(fourier_dim, time_dim),
350 | nn.GELU(),
351 | nn.Linear(time_dim, time_dim)
352 | )
353 |
354 | # layers
355 | self.downs = nn.ModuleList([])
356 | self.ups = nn.ModuleList([])
357 | num_resolutions = len(in_out)
358 |
359 | for ind, (dim_in, dim_out) in enumerate(in_out):
360 | is_last = ind >= (num_resolutions - 1)
361 |
362 | self.downs.append(nn.ModuleList([
363 | block_klass(dim_in, dim_in, time_emb_dim=context_dim),
364 | block_klass(dim_in, dim_in, time_emb_dim=time_dim),
365 | ResidualCross(PreNormCross(
366 | dim_in, LinearAttentionCross(dim_in, cross_condition_dim)
367 | )) if cross_condition else nn.Identity(),
368 | block_klass(dim_in, dim_in, time_emb_dim=time_dim),
369 | Residual(PreNorm(dim_in, LinearAttention(dim_in))),
370 | Downsample(dim_in, dim_out) if not is_last
371 | else nn.Conv1d(dim_in, dim_out, 1) #3, padding = 1)
372 | ]))
373 |
374 | mid_dim = dims[-1]
375 | self.mid_block0 = block_klass(mid_dim, mid_dim, time_emb_dim=context_dim)
376 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
377 | self.mid_attn_cross = ResidualCross(PreNormCross(
378 | mid_dim, LinearAttentionCross(mid_dim, cross_condition_dim)
379 | )) if cross_condition else nn.Identity()
380 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
381 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
382 |
383 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
384 | is_last = ind == (len(in_out) - 1)
385 |
386 | self.ups.append(nn.ModuleList([
387 |                 block_klass(dim_out, dim_out, time_emb_dim=context_dim),  # stay at dim_out; the skip-concat feeds block1
388 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
389 | ResidualCross(PreNormCross(
390 | dim_out, LinearAttentionCross(dim_out, cross_condition_dim)
391 | )) if cross_condition else nn.Identity(),
392 | block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
393 | Residual(PreNorm(dim_out, LinearAttention(dim_out))),
394 | Upsample(dim_out, dim_in) if not is_last
395 | else nn.Conv1d(dim_out, dim_in, 1) #3, padding = 1)
396 | ]))
397 |
398 | self.final_res_block = block_klass(init_dim * 2, dim, time_emb_dim=time_dim)
399 |
400 | # Final feature specific processing
401 | if self.seperate_all:
402 | self.bbox_hidden2output = Unet1D._decoder_mlp(dim, self.bbox_dim)
403 |
404 | if self.class_dim > 0:
405 | self.class_hidden2output = Unet1D._decoder_mlp(dim, self.class_dim)
406 | if self.objectness_dim > 0:
407 | self.objectness_hidden2output = Unet1D._decoder_mlp(dim, self.objectness_dim)
408 | if self.objfeat_dim > 0:
409 | self.objfeat_hidden2output = Unet1D._decoder_mlp(dim, self.objfeat_dim)
410 |
411 | else:
412 | default_out_dim = channels * (1 if not learned_variance else 2)
413 | self.out_dim = default(out_dim, default_out_dim)
414 | self.final_conv = nn.Conv1d(dim, self.out_dim, 1)
415 |
416 | @staticmethod
417 | def _encoder_mlp(hidden_size, input_size):
418 | mlp_layers = [
419 | nn.Conv1d(input_size, hidden_size, 1),
420 | nn.GELU(),
421 | nn.Conv1d(hidden_size, hidden_size*2, 1),
422 | nn.GELU(),
423 | nn.Conv1d(hidden_size*2, hidden_size, 1),
424 | ]
425 | return nn.Sequential(*mlp_layers)
426 |
427 | @staticmethod
428 | def _decoder_mlp(hidden_size, output_size):
429 | mlp_layers = [
430 | nn.Conv1d(hidden_size, hidden_size*2, 1),
431 | nn.GELU(),
432 | nn.Conv1d(hidden_size*2, hidden_size, 1),
433 | nn.GELU(),
434 | nn.Conv1d(hidden_size, output_size, 1),
435 | ]
436 | return nn.Sequential(*mlp_layers)
437 |
438 |
439 | def forward(self, x, time, context=None, context_cross=None):
440 | # (B, N, C) --> (B, C, N)
441 | x = torch.permute(x, (0, 2, 1)).contiguous()
442 | if context_cross is not None:
443 | # [B, N, C] --> [B, C, N]
444 | context_cross = torch.permute(context_cross, (0, 2, 1)).contiguous()
445 |
446 | # initial processing
447 | if self.seperate_all:
448 | x_bbox = self.bbox_embedf(x[:, 0:self.bbox_dim, :])
449 |
450 | if self.class_dim > 0:
451 | start_index = self.bbox_dim
452 | x_class = self.class_embedf(
453 | x[:, start_index:start_index+self.class_dim, :]
454 | )
455 | else:
456 | x_class = 0
457 |
458 | if self.objectness_dim > 0:
459 | start_index = self.bbox_dim+self.class_dim
460 | x_object = self.objectness_embedf(
461 | x[:, start_index:start_index+self.objectness_dim, :]
462 | )
463 | else:
464 | x_object = 0
465 |
466 | if self.objfeat_dim > 0:
467 | start_index = self.bbox_dim+self.class_dim+self.objectness_dim
468 | x_objfeat = self.objfeat_embedf(
469 | x[:, start_index:start_index+self.objfeat_dim, :]
470 | )
471 | else:
472 | x_objfeat = 0
473 |
474 | x = x_bbox + x_class + x_object + x_objfeat
475 |
476 | # unet-1D
477 | x = self.init_conv(x)
478 | r = x.clone()
479 |
480 | t = self.time_mlp(time)
481 |
482 | h = []
483 |
484 | for block0, block1, attncross, block2, attn, downsample in self.downs:
485 | x = block0(x, context)
486 | x = block1(x, t)
487 | h.append(x)
488 |
489 | x = attncross(x, context_cross) if self.cross_condition else attncross(x)
490 | x = block2(x, t)
491 | x = attn(x)
492 | h.append(x)
493 |
494 | x = downsample(x)
495 |
496 | x = self.mid_block0(x, context)
497 | x = self.mid_block1(x, t)
498 | x = self.mid_attn_cross(x, context_cross) if self.cross_condition else self.mid_attn_cross(x)
499 | x = self.mid_attn(x)
500 | x = self.mid_block2(x, t)
501 |
502 | for block0, block1, attncross, block2, attn, upsample in self.ups:
503 | x = block0(x, context)
504 | x = torch.cat((x, h.pop()), dim = 1)
505 | x = block1(x, t)
506 |
507 | x = attncross(x, context_cross) if self.cross_condition else attncross(x)
508 | x = torch.cat((x, h.pop()), dim = 1)
509 | x = block2(x, t)
510 | x = attn(x)
511 |
512 | x = upsample(x)
513 |
514 | x = torch.cat((x, r), dim = 1)
515 |
516 | x = self.final_res_block(x, t)
517 |
518 | # final processing
519 | if self.seperate_all:
520 | out = self.bbox_hidden2output(x)
521 | if self.class_dim > 0:
522 | out_class = self.class_hidden2output(x)
523 | out = torch.cat([out, out_class], dim=1).contiguous()
524 | if self.objectness_dim > 0:
525 | out_object = self.objectness_hidden2output(x)
526 | out = torch.cat([out, out_object], dim=1).contiguous()
527 | if self.objfeat_dim > 0:
528 | out_objfeat = self.objfeat_hidden2output(x)
529 | out = torch.cat([out, out_objfeat], dim=1).contiguous()
530 | else:
531 | out = self.final_conv(x)
532 |
533 | # (B, N, C) <-- (B, C, N)
534 | out = torch.permute(out, (0, 2, 1)).contiguous()
535 | return out
--------------------------------------------------------------------------------
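The Unet1D and the continuous transformer slice the same per-object feature vector in the order translation / size / angle / class / objectness / objfeats (the mixed transformer keeps the class indices separate). A standalone sketch of that layout with assumed dimensions; the real values come from the dataset encoding:

network_dim = {"translation_dim": 3, "size_dim": 3, "angle_dim": 2,
               "class_dim": 22, "objectness_dim": 0, "objfeat_dim": 32}  # assumed example
bbox_dim = (network_dim["translation_dim"] + network_dim["size_dim"]
            + network_dim["angle_dim"])
channels = (bbox_dim + network_dim["class_dim"]
            + network_dim["objectness_dim"] + network_dim["objfeat_dim"])
print(bbox_dim, channels)   # 8 62
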
/midiffusion/networks/diffusion_base.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.utils.data
3 |
4 | EPS_PROB = 1e-30 # minimum probability to make sure log prob is numerically stable
5 | LOG_ZERO = -69 # substitute of log(0)
6 |
7 |
8 | class BaseDiffusion(nn.Module):
9 | """
10 | Base class for diffusion model.
11 |     Note: while the diffusion step is theoretically t = 1, ..., num_steps, the
12 |     function argument t starts from 0 (i.e. it is off by one) due to Python indexing.
13 | """
14 | def __init__(self):
15 | super().__init__()
16 | self.device = torch.device("cpu")
17 |
18 | def q_pred(self, x_0, t):
19 | """
20 | Compute probability q(x_t | x_0)
21 | """
22 | raise NotImplementedError
23 |
24 | def q_sample(self, x_0, t):
25 | """
26 | Diffuse the data, i.e. sample from q(x_t | x_0)
27 | """
28 | raise NotImplementedError
29 |
30 | def q_posterior(self, x_0, x_t, t):
31 | """
32 | Compute posterior probability q(x_{t-1} | x_t, x_0)
33 | """
34 | raise NotImplementedError
35 |
36 | def p_pred(self, denoise_fn, x_t, t, condition, **kwargs):
37 | """
38 | Compute denoising probability p(x_{t-1} | x_t)
39 | """
40 | raise NotImplementedError
41 |
42 | def p_sample(self, denoise_fn, x_t, t, condition, **kwargs):
43 | """
44 | Denoise the data, i.e. sample from p(x_{t-1} | x_t)
45 | """
46 | raise NotImplementedError
47 |
48 | def p_sample_loop(self, denoise_fn, shape, condition, sample_freq=None, **kwargs):
49 | """
50 | Generate data by denoising recursively
51 | """
52 | raise NotImplementedError
53 |
54 | def p_losses(self, denoise_fn, x_0, condition, **kwargs):
55 | """
56 | Training loss calculation
57 | """
58 | raise NotImplementedError
59 |
60 | @staticmethod
61 | def _extract(a, t, x_shape):
62 | """
63 | Extract some coefficients at specified timesteps,
64 | then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
65 | """
66 | b, *_ = t.shape
67 | out = a.gather(-1, t)
68 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
69 |
70 | def _move_tensors(self, device):
71 | """
72 | Move pre-computed parameters to specified device
73 | """
74 | for attr_name in dir(self):
75 | attr = getattr(self, attr_name)
76 | if isinstance(attr, torch.Tensor):
77 | setattr(self, attr_name, attr.to(device))
78 | self.device = torch.device(device)
79 |
80 | @staticmethod
81 | def sum_last_dims(data_tensor, keep_dims=1):
82 | """
83 | Sum over the last dimensions. Default: sum input data over each batch.
84 | """
85 | return data_tensor.reshape(*data_tensor.shape[:keep_dims], -1).sum(-1)
86 |
87 | @staticmethod
88 | def mean_per_batch(data_tensor, feat_size=None, start_ind=0, mask=None):
89 | """
90 | Given B x N x C input data, return a 1-D average tensor of size B,
91 | with option to specify a B x N mask or a feature range to select data.
92 | """
93 | B, N, C = data_tensor.shape
94 |
95 | # Handle the case where feat_size is zero
96 | if feat_size == 0:
97 | return torch.zeros(B, device=data_tensor.device)
98 |
99 | # Determine feature size if not specified
100 | if feat_size is None:
101 | feat_size = C - start_ind
102 |
103 | # Select the relevant feature range from the data
104 | data_selected = data_tensor[:, :, start_ind:start_ind + feat_size]
105 |
106 | # Compute mean with or without mask
107 | if mask is not None:
108 | assert mask.shape == (B, N) and (mask.sum(dim=1) > 0).all()
109 | masked_sum = (data_selected * mask.unsqueeze(-1)).sum(dim=[1, 2])
110 | return masked_sum / (mask.sum(dim=1) * feat_size)
111 | else:
112 | return data_selected.mean(dim=[1, 2])
113 |
--------------------------------------------------------------------------------
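A small illustration (assumed toy schedule) of what _extract does for the subclasses: gather one coefficient per batch element at its timestep, then reshape so it broadcasts against (B, N, C) data:

import torch

alphas_cumprod = torch.linspace(0.9999, 0.01, steps=1000)   # assumed noise schedule
t = torch.tensor([0, 499, 999])                             # one timestep per batch element
b = t.shape[0]
coeff = alphas_cumprod.gather(-1, t).reshape(b, 1, 1)       # same as _extract(a, t, (B, N, C))
print(coeff.shape)                                          # torch.Size([3, 1, 1])
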
/midiffusion/networks/diffusion_d3pm.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/microsoft/VQ-Diffusion/blob/main/image_synthesis/modeling/transformers/diffusion_transformer.py
4 | #
5 |
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.utils.data
9 | import torch.nn.functional as F
10 | from .diffusion_base import BaseDiffusion, EPS_PROB, LOG_ZERO
11 | from .denoising_net.mixed_transformer import MixedDenoiseTransformer
12 |
13 |
14 | '''
15 | helper functions
16 | '''
17 |
18 | def log_1_min_a(a):
19 | """log(1 - exp(a))"""
20 | return torch.log(1 - a.exp() + EPS_PROB)
21 |
22 |
23 | def log_add_exp(a, b):
24 | """log(exp(a) + exp(b))"""
25 | maximum = torch.max(a, b)
26 | return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
27 |
28 |
29 | def log_categorical(log_x_start, log_prob):
30 | return (log_x_start.exp() * log_prob).sum(dim=1)
31 |
32 |
33 | def index_to_log_onehot(x, num_classes):
34 | """convert (*, N) index tensor to (*, C, N) log 1-hot tensor"""
35 | assert x.max().item() < num_classes, f"Error: {x.max().item()} >= {num_classes}"
36 | x_onehot = F.one_hot(x, num_classes)
37 | permute_order = (0, -1) + tuple(range(1, x.ndim))
38 | x_onehot = x_onehot.permute(permute_order)
39 | log_x = torch.log(x_onehot.float().clamp(min=EPS_PROB))
40 | return log_x
41 |
42 |
43 | def log_onehot_to_index(log_x):
44 | """argmax(log_x, dim=1)"""
45 | return log_x.argmax(1)
46 |
47 |
48 | def alpha_schedule(
49 | num_timesteps, N=100, att_1=0.99999, att_T=0.000009, ctt_1=0.000009, ctt_T=0.99999
50 | ):
51 |     # note: 0.0 tends to cause unexpected behaviour (e.g., log(0.0)), so avoid it
52 | assert att_1 > 0.0 and att_T > 0.0 and ctt_1 > 0.0 and ctt_T > 0.0
53 | assert att_1 + ctt_1 <= 1.0 and att_T + ctt_T <= 1.0
54 |
55 | att = np.arange(0, num_timesteps) / (num_timesteps - 1) * (att_T - att_1) + att_1
56 | att = np.concatenate(([1], att))
57 | at = att[1:] / att[:-1]
58 | ctt = np.arange(0, num_timesteps) / (num_timesteps - 1) * (ctt_T - ctt_1) + ctt_1
59 | ctt = np.concatenate(([0], ctt))
60 | one_minus_ctt = 1 - ctt
61 | one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
62 | ct = 1 - one_minus_ct
63 | bt = (1 - at - ct) / N
64 | att = np.concatenate((att[1:], [1]))
65 | ctt = np.concatenate((ctt[1:], [0]))
66 | btt = (1 - att - ctt) / N
67 |
68 | def _f(x):
69 | return torch.tensor(x.astype("float64"))
70 |
71 | return _f(at), _f(bt), _f(ct), _f(att), _f(btt), _f(ctt)
72 |
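Each row of the returned schedule is a valid transition: the probability of keeping a token (at), of replacing it with one of the N categories (N * bt), and of masking it (ct) sums to one, and likewise for the cumulative version. A small sanity check (illustrative only):

    import torch
    from midiffusion.networks.diffusion_d3pm import alpha_schedule

    at, bt, ct, att, btt, ctt = alpha_schedule(num_timesteps=100, N=22)
    assert torch.allclose(at + 22 * bt + ct, torch.ones_like(at))
    assert torch.allclose(att + 22 * btt + ctt, torch.ones_like(att))
    # the extra trailing entry (att=1, ctt=0) encodes the identity transition
    # used when q_pred is queried with t - 1 = -1
    assert att[-1].item() == 1.0 and ctt[-1].item() == 0.0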
73 | '''
74 | model
75 | '''
76 |
77 | class MaskAndReplaceDiffusion(BaseDiffusion):
78 | def __init__(self, num_classes, noise_params, model_output_type="x0",
79 | mask_weight=1, auxiliary_loss_weight=0,
80 | adaptive_auxiliary_loss=False):
81 | super().__init__()
82 |
83 | assert model_output_type in ["x0", "x_prev"]
84 | assert auxiliary_loss_weight >= 0
85 | assert mask_weight >= 0
86 |         self.num_classes = num_classes # TODO: currently, this includes 'empty' and 'mask'
87 | self.model_output_type = model_output_type
88 | self.auxiliary_loss_weight = auxiliary_loss_weight
89 | self.adaptive_auxiliary_loss = adaptive_auxiliary_loss
90 | self.mask_weight = mask_weight # 'mask' token weight
91 |
92 | # diffusion noise params
93 | at, bt, ct, att, btt, ctt = noise_params
94 | assert at.shape[0] == bt.shape[0] == ct.shape[0]
95 | assert att.shape[0] == btt.shape[0] == ctt.shape[0] == at.shape[0] + 1
96 | self.num_timesteps = at.shape[0]
97 |
98 | log_at, log_bt, log_ct = torch.log(at), torch.log(bt), torch.log(ct)
99 | log_cumprod_at, log_cumprod_bt, log_cumprod_ct = \
100 | torch.log(att), torch.log(btt), torch.log(ctt)
101 |
102 | log_1_min_ct = log_1_min_a(log_ct)
103 | log_1_min_cumprod_ct = log_1_min_a(log_cumprod_ct)
104 |
105 | assert log_add_exp(log_ct, log_1_min_ct).abs().sum().item() < 1.0e-5
106 | assert (
107 | log_add_exp(log_cumprod_ct, log_1_min_cumprod_ct).abs().sum().item()
108 | < 1.0e-5
109 | )
110 |
111 | self.diffusion_acc_list = [0] * self.num_timesteps
112 | self.diffusion_keep_list = [0] * self.num_timesteps
113 | # Convert to float32 and register buffers.
114 | self.register_buffer("log_at", log_at.float())
115 | self.register_buffer("log_bt", log_bt.float())
116 | self.register_buffer("log_ct", log_ct.float())
117 | self.register_buffer("log_cumprod_at", log_cumprod_at.float())
118 | self.register_buffer("log_cumprod_bt", log_cumprod_bt.float())
119 | self.register_buffer("log_cumprod_ct", log_cumprod_ct.float())
120 | self.register_buffer("log_1_min_ct", log_1_min_ct.float())
121 | self.register_buffer("log_1_min_cumprod_ct", log_1_min_cumprod_ct.float())
122 |
123 | self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))
124 | self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))
125 |
126 | def multinomial_kl(self, log_prob1, log_prob2): # compute KL loss on log_prob
127 | kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
128 | return kl
129 |
130 | def q_pred_one_timestep(self, log_x_t_1, t):
131 | """
132 | log(Q_t * exp(log_x_t_1)), diffusion step: q(x_t | x_{t-1})
133 | """
134 | # log_x_t_1 (B, C, N)
135 | log_at = self._extract(self.log_at, t, log_x_t_1.shape) # at
136 | log_bt = self._extract(self.log_bt, t, log_x_t_1.shape) # bt
137 | log_ct = self._extract(self.log_ct, t, log_x_t_1.shape) # ct
138 | log_1_min_ct = self._extract(self.log_1_min_ct, t, log_x_t_1.shape) # 1-ct
139 |
140 | log_probs = torch.cat([
141 | log_add_exp(log_x_t_1[:, :-1, :] + log_at, log_bt), # dropped a small term
142 | log_add_exp(log_x_t_1[:, -1:, :] + log_1_min_ct, log_ct),
143 | ], dim=1)
144 |
145 | return log_probs
146 |
147 | def q_pred(self, log_x_start, t):
148 | """
149 | log(bar{Q}_t * exp(log_x_start)), diffuse the data to time t: q(x_t | x_0)
150 | """
151 | t = (t + (self.num_timesteps + 1)) % (self.num_timesteps + 1)
152 | log_cumprod_at = self._extract(self.log_cumprod_at, t, log_x_start.shape) # at~
153 | log_cumprod_bt = self._extract(self.log_cumprod_bt, t, log_x_start.shape) # bt~
154 | log_cumprod_ct = self._extract(self.log_cumprod_ct, t, log_x_start.shape) # ct~
155 | log_1_min_cumprod_ct = self._extract(
156 | self.log_1_min_cumprod_ct, t, log_x_start.shape
157 | ) # 1-ct~
158 |
159 | log_probs = torch.cat([
160 | log_add_exp(log_x_start[:, :-1, :] + log_cumprod_at, log_cumprod_bt),
161 | log_add_exp(
162 | log_x_start[:, -1:, :] + log_1_min_cumprod_ct, log_cumprod_ct
163 | ), # simplified
164 | ], dim=1)
165 |
166 | return log_probs
167 |
168 | def q_posterior(self, log_x_start, log_x_t, t):
169 | """
170 |         log of posterior probability q(x_{t-1}|x_t,x_0')
171 | """
172 | B, C, N = log_x_start.shape
173 | log_one_vector = torch.zeros(B, 1, 1).type_as(log_x_t)
174 | log_zero_vector = torch.full((B, 1, N), LOG_ZERO).type_as(log_x_t)
175 |
176 | # notice that log_x_t is onehot
177 | onehot_x_t = log_onehot_to_index(log_x_t)
178 | mask = (onehot_x_t == self.num_classes - 1).unsqueeze(1)
179 |
180 | log_qt = self.q_pred(log_x_t, t) # q(xt|x0)
181 | # log_qt = torch.cat((log_qt[:,:-1,:], log_zero_vector), dim=1)
182 | log_qt = log_qt[:, :-1, :]
183 | log_cumprod_ct = self._extract(self.log_cumprod_ct, t, log_x_start.shape) # ct~
184 | ct_cumprod_vector = log_cumprod_ct.expand(-1, self.num_classes - 1, -1)
185 | # ct_cumprod_vector = torch.cat((ct_cumprod_vector, log_one_vector), dim=1)
186 | log_qt = (~mask) * log_qt + mask * ct_cumprod_vector
187 |
188 | log_qt_one_timestep = self.q_pred_one_timestep(log_x_t, t) # q(xt|xt_1)
189 | log_qt_one_timestep = torch.cat(
190 | (log_qt_one_timestep[:, :-1, :], log_zero_vector), dim=1
191 | )
192 | log_ct = self._extract(self.log_ct, t, log_x_start.shape) # ct
193 | ct_vector = log_ct.expand(-1, self.num_classes - 1, -1)
194 | ct_vector = torch.cat((ct_vector, log_one_vector), dim=1)
195 | log_qt_one_timestep = (~mask) * log_qt_one_timestep + mask * ct_vector
196 |
197 | # log_x_start = torch.cat((log_x_start, log_zero_vector), dim=1)
198 | # q = log_x_start - log_qt
199 | q = log_x_start[:, :-1, :] - log_qt
200 | q = torch.cat((q, log_zero_vector), dim=1)
201 | q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
202 | q = q - q_log_sum_exp
203 | log_EV_xtmin_given_xt_given_xstart = \
204 | self.q_pred(q, t - 1) + log_qt_one_timestep + q_log_sum_exp
205 |
206 | return torch.clamp(log_EV_xtmin_given_xt_given_xstart, LOG_ZERO, 0)
207 |
208 | @staticmethod
209 | def log_pred_from_denoise_out(denoise_out):
210 | """
211 | convert output of denoising network to log probability over classes and [mask]
212 | """
213 | out = denoise_out.permute((0, 2, 1)) # (B, N, C-1) -> (B, C-1, N)
214 | B, _, N = out.shape
215 |
216 | log_pred = F.log_softmax(out.double(), dim=1).float()
217 | log_pred = torch.clamp(log_pred, LOG_ZERO, 0)
218 | log_zero_vector = torch.full((B, 1, N), LOG_ZERO).type_as(log_pred)
219 | return torch.cat((log_pred, log_zero_vector), dim=1)
220 |
221 | def predict_denoise(self, denoise_fn, log_x_t, t, condition=None,
222 | condition_cross=None):
223 | """
224 | compute denoise_fn(data, t, condition, condition_cross) and convert output to log prob
225 | """
226 | x_t = log_onehot_to_index(log_x_t) # (B, N)
227 | out = denoise_fn(x_t, t, condition, condition_cross)
228 | log_pred = self.log_pred_from_denoise_out(out)
229 | assert log_pred.shape == log_x_t.shape
230 |
231 | return log_pred
232 |
233 | def p_pred(self, denoise_fn, log_x_t, t, condition=None, condition_cross=None):
234 | """
235 | log denoising probability, denoising step: p(x_{t-1} | x_t)
236 | """
237 | if self.model_output_type == 'x0':
238 |             # if x0, first p(x0|xt), then sum(q(xt-1|xt,x0)*p(x0|xt))
239 | log_x_recon = self.predict_denoise(
240 | denoise_fn, log_x_t, t, condition, condition_cross
241 | )
242 | log_model_pred = self.q_posterior(
243 | log_x_start=log_x_recon, log_x_t=log_x_t, t=t
244 | )
245 | return log_model_pred, log_x_recon
246 | elif self.model_output_type == 'x_prev':
247 | log_model_pred = self.predict_denoise(
248 | denoise_fn, log_x_t, t, condition, condition_cross
249 | )
250 | return log_model_pred, None
251 | else:
252 |             raise NotImplementedError
253 |
254 | '''
255 | sampling
256 | '''
257 |
258 | def q_sample_one_step(self, log_x_t_1, t, no_mask=False):
259 | """
260 | sample from q(x_t | x_{t-1})
261 | """
262 | log_EV_qxt = self.q_pred_one_timestep(log_x_t_1, t)
263 | log_sample = self.log_sample_categorical(log_EV_qxt, no_mask)
264 | return log_sample
265 |
266 | def q_sample(self, log_x_start, t, no_mask=False):
267 | """
268 | sample from q(x_t | x_0)
269 | """
270 | log_EV_qxt_x0 = self.q_pred(log_x_start, t)
271 | log_sample = self.log_sample_categorical(log_EV_qxt_x0, no_mask)
272 | return log_sample
273 |
274 | @torch.no_grad()
275 | def p_sample(self, denoise_fn, log_x_t, t, condition, condition_cross=None):
276 | """
277 | sample x_{t-1} from p(x_{t-1} | x_t)
278 | """
279 | model_log_prob, _ = self.p_pred(denoise_fn, log_x_t, t, condition, condition_cross)
280 | log_sample = self.log_sample_categorical(model_log_prob)
281 | return log_sample
282 |
283 | def log_sample_categorical(self, logits, no_mask=False):
284 | """
285 | sample from log probability under gumbel noise, return results as log of 1-hot embedding
286 | (no_mask=True means sampling without the last [mask] class)
287 | """
288 | # use gumbel to sample onehot vector from log probability
289 | uniform = torch.rand_like(logits)
290 | gumbel_noise = -torch.log(-torch.log(uniform + EPS_PROB) + EPS_PROB)
291 | if no_mask:
292 | sample = (gumbel_noise + logits)[:, :-1, :].argmax(dim=1)
293 | else:
294 | sample = (gumbel_noise + logits).argmax(dim=1)
295 | log_sample = index_to_log_onehot(sample, self.num_classes)
296 | return log_sample
297 |
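The sampling above relies on the Gumbel-max trick: adding independent Gumbel(0, 1) noise to log-probabilities and taking the argmax draws an exact sample from the corresponding categorical distribution. The same idea in isolation (illustrative only):

    import torch

    probs = torch.tensor([0.1, 0.3, 0.6])
    logits = probs.log().expand(100_000, 3)
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-30) + 1e-30)
    samples = (logits + gumbel).argmax(dim=-1)
    print(torch.bincount(samples, minlength=3) / samples.numel())  # approx. 0.1, 0.3, 0.6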
298 | def sample_time(self, b, device, method='uniform'):
299 | if method == 'importance':
300 | if not (self.Lt_count > 10).all():
301 | return self.sample_time(b, device, method='uniform')
302 |
303 | Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
304 | Lt_sqrt[0] = Lt_sqrt[1] # Overwrite decoder term with L1.
305 | pt_all = Lt_sqrt / Lt_sqrt.sum()
306 |
307 | t = torch.multinomial(pt_all, num_samples=b, replacement=True)
308 | pt = pt_all.gather(dim=0, index=t)
309 | elif method == 'uniform':
310 | t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
311 | pt = torch.ones_like(t).float() / self.num_timesteps
312 | else:
313 | raise ValueError
314 | return t, pt
315 |
316 | '''
317 | loss
318 | '''
319 | def compute_kl_loss(self, log_x_start, log_x_t, t, log_pred_prob):
320 | """compute train loss of each variable"""
321 | log_q_prob = self.q_posterior(log_x_start, log_x_t, t)
322 | kl = self.multinomial_kl(log_q_prob, log_pred_prob)
323 | decoder_nll = -log_categorical(log_x_start, log_pred_prob)
324 |
325 | t0_mask = (t == 0).unsqueeze(1).repeat(1, log_x_start.shape[-1])
326 | kl_loss = torch.where(t0_mask, decoder_nll, kl)
327 | return kl_loss
328 |
329 | def compute_aux_loss(self, log_x_start, log_x0_recon, t):
330 |         """compute auxiliary loss regularizing the predicted x0"""
331 | aux_loss = self.multinomial_kl(
332 | log_x_start[:, :-1, :], log_x0_recon[:, :-1, :]
333 | )
334 |
335 | t0_mask = (t == 0).unsqueeze(1).repeat(1, log_x_start.shape[-1])
336 | aux_loss = torch.where(t0_mask, torch.zeros_like(aux_loss), aux_loss)
337 | return aux_loss
338 |
339 | def p_losses(self, denoise_fn, x_start, t=None, pt=None, condition=None):
340 | assert self.model_output_type == 'x0'
341 | if t is None or pt is None:
342 | t, pt = self.sample_time(x_start.size(0), x_start.device, "uniform")
343 |
344 | log_xstart = index_to_log_onehot(x_start, self.num_classes)
345 | log_xt = self.q_sample(log_x_start=log_xstart, t=t)
346 |
347 | log_model_prob, log_x0_recon = self.p_pred(denoise_fn, log_xt, t, condition)
348 |
349 | x0_recon = log_onehot_to_index(log_x0_recon)
350 | x0_real = x_start
351 | xt_1_recon = log_onehot_to_index(log_model_prob)
352 | xt_recon = log_onehot_to_index(log_xt)
353 | for index in range(t.size(0)):
354 | this_t = t[index].item()
355 | same_rate = \
356 | (x0_recon[index] == x0_real[index]).sum().cpu() / x0_real.size(1)
357 | self.diffusion_acc_list[this_t] = \
358 | same_rate.item() * 0.1 + self.diffusion_acc_list[this_t] * 0.9
359 | same_rate = \
360 | (xt_1_recon[index] == xt_recon[index]).sum().cpu() / xt_recon.size(1)
361 | self.diffusion_keep_list[this_t] = \
362 | same_rate.item() * 0.1 + self.diffusion_keep_list[this_t] * 0.9
363 |
364 | # Compute train loss
365 | loss_tensor = self.compute_kl_loss(log_xstart, log_xt, t, log_model_prob)
366 | if self.mask_weight != 1: # adjust [mask] token weight
367 | mask_region = (log_onehot_to_index(log_xt) == self.num_classes - 1)
368 | loss_tensor = torch.where(
369 | mask_region, self.mask_weight * loss_tensor, loss_tensor
370 | )
371 | kl_loss = self.sum_last_dims(loss_tensor, keep_dims=1)
372 |
373 | Lt2 = kl_loss.pow(2)
374 | Lt2_prev = self.Lt_history.gather(dim=0, index=t)
375 | new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
376 | self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
377 | self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))
378 |
379 |         # Importance-weight the KL term by the inverse timestep sampling probability
380 | loss1 = kl_loss / pt
381 | vb_loss = loss1.mean()
382 | losses_dict = {"kl_loss": loss1.mean()}
383 |
384 | if self.auxiliary_loss_weight > 0:
385 | loss_tensor = self.compute_aux_loss(log_xstart, log_x0_recon, t)
386 | if self.mask_weight != 1: # adjust [mask] token weight
387 | loss_tensor = torch.where(
388 | mask_region, self.mask_weight * loss_tensor, loss_tensor
389 | )
390 | aux_loss = self.sum_last_dims(loss_tensor, keep_dims=1)
391 |             if self.adaptive_auxiliary_loss:
392 | addition_loss_weight = (1 - t / self.num_timesteps) + 1.0
393 | else:
394 | addition_loss_weight = 1.0
395 |
396 | loss2 = addition_loss_weight * self.auxiliary_loss_weight * aux_loss / pt
397 | losses_dict["aux_loss"] = loss2.mean()
398 | vb_loss += loss2.mean()
399 |
400 | return vb_loss, losses_dict
401 |
402 | def p_sample_loop(self, denoise_fn, log_x_end, condition, condition_cross=None,
403 | sample_freq=None):
404 | B, C, N = log_x_end.shape
405 | assert C == self.num_classes
406 | if sample_freq:
407 | pred_traj = [log_onehot_to_index(log_x_end)]
408 |
409 | log_x_t = log_x_end
410 | total_steps = self.num_timesteps
411 | for t in reversed(range(0, total_steps)):
412 | t_ = torch.full((B,), t, dtype=torch.int64, device=self.device)
413 | log_x_t = self.p_sample(
414 | denoise_fn=denoise_fn, log_x_t=log_x_t, t=t_,
415 | condition=condition, condition_cross=condition_cross
416 | ) # log_x_t is log_onehot
417 | if sample_freq and (t % sample_freq == 0 or t == total_steps - 1):
418 | pred_traj.append(log_onehot_to_index(log_x_t))
419 |
420 | if sample_freq:
421 | return pred_traj
422 | else:
423 | return log_onehot_to_index(log_x_t)
424 |
425 |
426 | class DiscreteDiffusionPoint(nn.Module):
427 | def __init__(self, denoise_net:nn.Module, class_dim, time_num=1000,
428 | model_output_type="x0", mask_weight=1, auxiliary_loss_weight=0,
429 | adaptive_auxiliary_loss=False, **kwargs):
430 |
431 | super(DiscreteDiffusionPoint, self).__init__()
432 |
433 | noise_params = alpha_schedule(time_num, class_dim, **kwargs)
434 |
435 | self.diffusion = MaskAndReplaceDiffusion(
436 | class_dim + 1, noise_params, model_output_type,
437 | mask_weight, auxiliary_loss_weight, adaptive_auxiliary_loss
438 | )
439 | self.model = denoise_net
440 | assert self.diffusion.num_classes - 1 == self.model.class_dim
441 |
442 | def _denoise(self, data, t, condition, condition_cross=None):
443 | B, N = data.shape
444 | C = self.diffusion.num_classes - 1
445 | assert t.shape == torch.Size([B]) and t.dtype == torch.int64
446 |
447 | if isinstance(self.model, MixedDenoiseTransformer):
448 | out, _ = self.model(
449 | x_semantic=data, x_geometry=None, time=t,
450 | context=condition, context_cross=condition_cross
451 | )
452 | else:
453 | out = self.model(data, t, condition)
454 | assert out.shape == torch.Size([B, N, C]), f"Error: {out.shape} != {B, N, C}"
455 | return out
456 |
457 | def get_loss_iter(self, data, condition=None):
458 | device = data.device
459 | self.diffusion._move_tensors(device)
460 | self.model.to(device)
461 |
462 | losses, loss_dict = self.diffusion.p_losses(
463 | denoise_fn=self._denoise, x_start=data, condition=condition,
464 | )
465 | return losses, loss_dict
466 |
467 | def gen_samples(self, shape, device, condition, condition_cross=None,
468 | freq=None):
469 | B, N = shape
470 | self.diffusion._move_tensors(device)
471 |
472 | log_zeros = torch.full((B, self.diffusion.num_classes-1, N), LOG_ZERO, device=device)
473 | log_ones = torch.zeros((B, 1, N), device=device)
474 | log_x_end = torch.cat((log_zeros, log_ones), dim=1)
475 |
476 | return self.diffusion.p_sample_loop(
477 | self._denoise, log_x_end=log_x_end,
478 | condition=condition, condition_cross=condition_cross,
479 | sample_freq=freq
480 | )
481 |
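gen_samples starts the reverse process from the fully masked state; a shape-level sketch of that starting point (toy sizes, assuming the midiffusion package is importable):

    import torch
    from midiffusion.networks.diffusion_base import LOG_ZERO
    from midiffusion.networks.diffusion_d3pm import log_onehot_to_index

    B, N, num_classes = 2, 12, 23                  # e.g. 22 category labels + [mask]
    log_zeros = torch.full((B, num_classes - 1, N), LOG_ZERO)
    log_ones = torch.zeros((B, 1, N))
    log_x_end = torch.cat((log_zeros, log_ones), dim=1)
    # every position starts as the [mask] token (index num_classes - 1)
    assert (log_onehot_to_index(log_x_end) == num_classes - 1).all()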
--------------------------------------------------------------------------------
/midiffusion/networks/diffusion_mixed.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import torch.nn as nn
3 | import torch.utils.data
4 | from .diffusion_base import BaseDiffusion, LOG_ZERO
5 | from .diffusion_d3pm import MaskAndReplaceDiffusion, alpha_schedule, index_to_log_onehot, log_onehot_to_index
6 | from .diffusion_ddpm import GaussianDiffusion, get_betas
7 |
8 |
9 | def extract_params(func, param_dict):
10 | func_args = inspect.signature(func).parameters
11 | return {k:v for k, v in param_dict.items() if k in func_args}
12 |
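extract_params simply drops any keys the callee does not accept, so a single config dict can feed several constructors. For example (illustrative values):

    from midiffusion.networks.diffusion_d3pm import alpha_schedule
    from midiffusion.networks.diffusion_mixed import extract_params

    d3pm_config = {"att_1": 0.99999, "model_output_type": "x0"}
    print(extract_params(alpha_schedule, d3pm_config))
    # {'att_1': 0.99999} -- 'model_output_type' is not an alpha_schedule argument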
13 |
14 | class MixedDiffusionPoint(nn.Module):
15 | def __init__(self, denoise_net:nn.Module, network_dim, time_num,
16 | d3pm_config, ddpm_config):
17 | super(MixedDiffusionPoint, self).__init__()
18 |
19 | self.num_timesteps = time_num
20 | self.num_classes = network_dim["class_dim"] # object categories plus [empty]
21 |
22 | # discrete semantic diffusion
23 | noise_params = alpha_schedule(
24 | num_timesteps=time_num, N=network_dim["class_dim"],
25 | **extract_params(alpha_schedule, d3pm_config)
26 | )
27 | self.diffusion_semantic = MaskAndReplaceDiffusion(
28 | network_dim["class_dim"] + 1, noise_params,
29 | **extract_params(MaskAndReplaceDiffusion.__init__, d3pm_config)
30 | )
31 | assert self.diffusion_semantic.num_classes == self.num_classes + 1
32 | assert self.diffusion_semantic.model_output_type == "x0"
33 |
34 | # continuous geometric diffusion
35 | network_dim["class_dim"] = 0
36 | betas = get_betas(
37 | time_num=time_num,
38 | **extract_params(get_betas, ddpm_config)
39 | )
40 | self.diffusion_geometric = GaussianDiffusion(
41 | network_dim, betas,
42 | **extract_params(GaussianDiffusion.__init__, ddpm_config)
43 | )
44 | assert self.diffusion_geometric.loss_type == "mse"
45 |
46 | # denoising net that takes in semantic and geometric features,
47 |         # and outputs corresponding discrete and continuous predictions
48 | self.model = denoise_net
49 |
50 | def _denoise(self, data_semantic, data_geometric, t, condition, out_type="all"):
51 | out_class, out_bbox = \
52 | self.model(data_semantic, data_geometric, t, condition)
53 | if out_type == "semantic":
54 | return out_class
55 | elif out_type == "geometric":
56 | return out_bbox
57 | else:
58 | return out_class, out_bbox
59 |
60 | def get_loss_iter(self, data_semantic, data_geometric, condition=None):
61 | B, N, C = data_geometric.shape
62 | device = data_geometric.device
63 | assert data_semantic.shape == (B, N)
64 | assert data_semantic.device == device
65 |
66 | # Move models and pre-computed tensors to data device
67 | self.diffusion_semantic._move_tensors(device)
68 | self.diffusion_geometric._move_tensors(device)
69 | self.model.to(device)
70 |
71 | # Sample q(x_t | x_0)
72 | t = torch.randint(0, self.num_timesteps, size=(B,), device=device)
73 | # x_t_class: (B, N)
74 | log_xstart = index_to_log_onehot(data_semantic, self.num_classes + 1)
75 | log_xt = self.diffusion_semantic.q_sample(log_x_start=log_xstart, t=t)
76 | x_t_class = log_onehot_to_index(log_xt)
77 | # x_t_geometric: (B, N, C)
78 | noise_geometric = torch.randn_like(data_geometric)
79 | x_t_geometric = self.diffusion_geometric.q_sample(
80 | x_0=data_geometric, t=t, eps=noise_geometric
81 | )
82 |
83 | # Send x_t, t, condition through denoising net
84 | denoise_out_class, denoise_out_geometric = \
85 | self.model(x_t_class, x_t_geometric, t, context=condition)
86 |
87 | # Compute loss
88 |         feat_separated_losses = dict()  # per-feature losses, keyed "loss.<feature>"
89 |
90 | # semantic
91 | log_x0_recon = self.diffusion_semantic.log_pred_from_denoise_out(denoise_out_class) # p_theta(x0|xt)
92 | log_p_prob = self.diffusion_semantic.q_posterior(
93 | log_x_start=log_x0_recon, log_x_t=log_xt, t=t
94 | ) # go through q(xt_1| xt, x0)
95 | # semantic - train loss
96 | loss_tensor = self.diffusion_semantic.compute_kl_loss(log_xstart, log_xt, t, log_p_prob)
97 | if self.diffusion_semantic.mask_weight != 1: # adjust [mask] token weight
98 | mask_token_region = (x_t_class == self.num_classes)
99 | loss_tensor = torch.where(
100 | mask_token_region,
101 | self.diffusion_semantic.mask_weight * loss_tensor,
102 | loss_tensor
103 | )
104 | loss_class = loss_tensor.mean(dim=1) # average over (B, N) loss tensor
105 | feat_separated_losses["loss.class"] = loss_class
106 | # semantic - aux loss
107 | if self.diffusion_semantic.auxiliary_loss_weight > 0:
108 | loss_tensor = self.diffusion_semantic.compute_aux_loss(log_xstart, log_x0_recon, t)
109 | if self.diffusion_semantic.mask_weight != 1: # adjust [mask] token weight
110 | loss_tensor = torch.where(
111 | mask_token_region,
112 | self.diffusion_semantic.mask_weight * loss_tensor,
113 | loss_tensor
114 | )
115 | aux_loss = loss_tensor.mean(dim=1) # average over (B, N) loss tensor
116 |             if self.diffusion_semantic.adaptive_auxiliary_loss:
117 | addition_loss_weight = (1 - t / self.num_timesteps) + 1.0
118 | else:
119 | addition_loss_weight = 1.0
120 | loss_class_aux = (addition_loss_weight * \
121 | self.diffusion_semantic.auxiliary_loss_weight * aux_loss).mean(dim=-1)
122 | else:
123 | loss_class_aux = torch.zeros(B, device=device)
124 | feat_separated_losses["loss.class_aux"] = loss_class_aux
125 |
126 | # geometric
127 | object_mask = None # TODO
128 | # object_mask = torch.logical_and(
129 | # (x_t_class != self.num_classes), (x_t_class != self.num_classes - 1)
130 | # )
131 | loss_tensor = self.diffusion_geometric.compute_mse_loss(
132 | data_geometric, x_t_geometric, t, denoise_out_geometric, noise_geometric
133 | )
134 | weight_at_t = BaseDiffusion._extract(
135 | self.diffusion_geometric.loss_weight, t, torch.Size([B])
136 | )
137 | for feat_name, (feat_dim, start_ind) in self.diffusion_geometric.dim_dict.items():
138 | if feat_name in ["class", "object"] or feat_dim == 0:
139 | continue
140 | feat_loss = BaseDiffusion.mean_per_batch(
141 | loss_tensor, feat_dim, start_ind, mask=object_mask
142 | )
143 | feat_separated_losses["loss." + feat_name] = feat_loss * weight_at_t
144 | loss_geometric = feat_separated_losses["loss.bbox"]
145 |         if "loss.objfeat" in feat_separated_losses:
146 |             loss_geometric = loss_geometric + feat_separated_losses["loss.objfeat"]
147 | # additional bounding box regularization loss on mean of p(x_{t-1} | x_t)
148 | if self.diffusion_geometric.loss_iou:
149 | x_recon = self.diffusion_geometric.predict_start_from_denoise_out(
150 | denoise_out_geometric, x_t_geometric, t, clip_x_start=False
151 | )
152 | loss_iou = self.diffusion_geometric.bbox_iou_losses(x_recon, t)
153 | feat_separated_losses['loss.liou'] = loss_iou
154 | else:
155 | loss_iou = 0
156 |
157 | return ((loss_class + loss_class_aux + loss_geometric + loss_iou).mean(),
158 | {k: v.mean() for k, v in feat_separated_losses.items()})
159 |
160 | def gen_samples(self, shape_geo, device, condition,
161 | x0_class_partial=None, class_mask=None,
162 | x0_geometric_partial=None, geometry_mask=None,
163 | freq=None, clip_denoised=False):
164 | B, N, C = shape_geo
165 | self.diffusion_semantic._move_tensors(device)
166 | self.diffusion_geometric._move_tensors(device)
167 | self.model.to(device)
168 |
169 | # Generate complete diffusion of input partial data
170 | if x0_class_partial is None:
171 | partial_class = False
172 | else:
173 | partial_class = True
174 | assert x0_class_partial.shape == torch.Size((B, N))
175 | assert len(class_mask) == N
176 |
177 | # create a full sequence of noisy x_t's
178 | log_x_t_class_partials = []
179 | log_x_t_class_partial = index_to_log_onehot(x0_class_partial, self.num_classes + 1)
180 | for t in range(self.num_timesteps):
181 | t_ = torch.full((B,), t, dtype=torch.int64, device=device)
182 | log_x_t_class_partial = self.diffusion_semantic\
183 | .q_sample_one_step(log_x_t_class_partial, t_, no_mask=True)
184 | log_x_t_class_partials.append(log_x_t_class_partial)
185 |
186 | if x0_geometric_partial is None:
187 | partial_geometry = False
188 | else:
189 | partial_geometry = True
190 | assert x0_geometric_partial.shape == torch.Size((B, N, C))
191 | assert geometry_mask.shape == torch.Size((N, C))
192 |
193 | # create a full sequence of noisy x_t's
194 | x_t_geometric_partials = []
195 | x_t_geometric_partial = x0_geometric_partial
196 | for t in range(self.num_timesteps):
197 | t_ = torch.full((B,), t, dtype=torch.int64, device=device)
198 | x_t_geometric_partial = self.diffusion_geometric\
199 | .q_sample_one_step(x_t_geometric_partial, t_)
200 | x_t_geometric_partials.append(x_t_geometric_partial)
201 |
202 | # Initialize data
203 | log_zeros = torch.full((B, self.num_classes, N), LOG_ZERO, device=device)
204 | log_ones = torch.zeros((B, 1, N), device=device)
205 | log_x_end_class = torch.cat((log_zeros, log_ones), dim=1)
206 | x_end_geometric = torch.randn(size=shape_geo, dtype=torch.float, device=device)
207 |
208 | # Backward denoising steps
209 | if freq is not None:
210 | pred_traj = []
211 |         log_x_t_class = log_x_end_class  # (B, num_classes+1, N) - flip last two axes for output
212 | x_t_geometric = x_end_geometric # (B, N, num_features)
213 | for t in reversed(range(0, self.num_timesteps)):
214 | t_ = torch.full((B,), t, dtype=torch.int64, device=device)
215 |
216 | if partial_class:
217 | log_x_t_class[:, :, class_mask] = log_x_t_class_partials[t][:, :, class_mask]
218 | if partial_geometry:
219 | x_t_geometric = torch.where(
220 | geometry_mask.unsqueeze(0).expand_as(x_t_geometric),
221 | x_t_geometric_partials[t],
222 | x_t_geometric
223 | )
224 |
225 | x_t_class = log_onehot_to_index(log_x_t_class)
226 | denoise_out_class, denoise_out_geometric = \
227 | self.model(x_t_class, x_t_geometric, t_, context=condition)
228 | # semantic label probability distribution
229 | log_x_recon = self.diffusion_semantic.log_pred_from_denoise_out(denoise_out_class)
230 | if t == 0:
231 | log_EV_qxt_x0 = log_x_recon
232 | else:
233 | log_EV_qxt_x0 = self.diffusion_semantic.q_posterior(
234 | log_x_start=log_x_recon, log_x_t=log_x_t_class, t=t_
235 | )
236 | # semantic label sampling
237 | log_x_t_class = self.diffusion_semantic.log_sample_categorical(log_EV_qxt_x0)
238 | # geometric feature probability distribution
239 | p_mean, _, p_log_var = self.diffusion_geometric\
240 | .p_pred_from_denoise_out(denoise_out_geometric, x_t_geometric, t_)
241 | # geometric feature sampling
242 | noise = torch.randn_like(p_mean)
243 | x_t_geometric = p_mean + torch.exp(0.5 * p_log_var) * noise
244 | if clip_denoised:
245 | x_t_geometric = self.diffusion_geometric._clip_bbox(x_t_geometric)
246 |
247 | if freq is not None and (t % freq == 0 or t == self.num_timesteps - 1):
248 | pred_traj.append(
249 | (log_x_t_class[:, :-1, :].permute(0, 2, 1), x_t_geometric)
250 | ) # (B, N, num_classes), (B, N, num_features)
251 |
252 | # Replace final prediction with partial inputs
253 | if partial_class:
254 | log_x_t_class[:, :, class_mask] = \
255 | index_to_log_onehot(x0_class_partial, self.num_classes + 1)[:, :, class_mask]
256 | if partial_geometry:
257 | x_t_geometric = torch.where(
258 | geometry_mask.unsqueeze(0).expand_as(x_t_geometric),
259 | x0_geometric_partial,
260 | x_t_geometric
261 | )
262 |
263 | if (partial_class or partial_geometry) and freq is not None:
264 | pred_traj[-1] = (
265 | (log_x_t_class[:, :-1, :].permute(0, 2, 1), x_t_geometric)
266 | )
267 |
268 | if freq is not None:
269 | return pred_traj
270 | else:
271 | return log_x_t_class[:, :-1, :].permute(0, 2, 1), x_t_geometric
272 |
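The scene-completion path above freezes known entries by overwriting them with the pre-noised partial input at every reverse step; the masking pattern in isolation (toy tensors):

    import torch

    B, N, C = 2, 4, 8
    x_t = torch.zeros(B, N, C)               # current noisy sample
    x_partial_t = torch.ones(B, N, C)        # pre-noised partial input
    geometry_mask = torch.zeros(N, C, dtype=torch.bool)
    geometry_mask[0] = True                  # all features of object 0 are given
    x_t = torch.where(geometry_mask.unsqueeze(0).expand_as(x_t), x_partial_t, x_t)
    assert x_t[:, 0].eq(1).all() and x_t[:, 1:].eq(0).all()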
--------------------------------------------------------------------------------
/midiffusion/networks/diffusion_scene_layout_ddpm.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/tangjiapeng/DiffuScene/blob/master/scene_synthesis/networks/diffusion_scene_layout_ddpm.py
4 | # This version (1) simplifies some parts of the original implementation, and
5 | # (2) works with transformer denoising network as well as U-Net.
6 | #
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import Module
11 |
12 | from .diffusion_ddpm import DiffusionPoint
13 | from .denoising_net.unet1D import Unet1D
14 | from .denoising_net.continuous_transformer import ContinuousDenoiseTransformer
15 | from .feature_extractors import ResNet18, PointNet_Point
16 |
17 |
18 | class DiffusionSceneLayout_DDPM(Module):
19 | """Scene synthesis via continuous state diffusion"""
20 | def __init__(self, n_object_types, feature_extractor, config, train_stats_file):
21 | super().__init__()
22 | self.n_object_types = n_object_types
23 | self.config = config
24 |
25 | if "class_dim" in config:
26 | assert config["class_dim"] == n_object_types + 1
27 | else:
28 | config["class_dim"] = n_object_types + 1
29 |
30 | # read object property dimension
31 | self.class_dim = config.get("class_dim")
32 | self.translation_dim = config.get("translation_dim", 3)
33 | self.size_dim = config.get("size_dim", 3)
34 | self.angle_dim = config.get("angle_dim", 1)
35 | self.bbox_dim = self.translation_dim + self.size_dim + self.angle_dim
36 | self.objectness_dim = config.get("objectness_dim", 0)
37 | self.objfeat_dim = config.get("objfeat_dim", 0)
38 | self.network_dim = {k: getattr(self, k) for k in [
39 | "objectness_dim", "class_dim", "translation_dim", "size_dim",
40 | "angle_dim", "objfeat_dim"
41 | ]}
42 |
43 | # maximum number of points
44 | self.sample_num_points = config.get("sample_num_points", 12)
45 |
46 |         # initialize conditional input dimension
47 | self.context_dim = 0
48 |
49 | # room_mask_condition: if yes, define the feature extractor for the room mask
50 | self.room_mask_condition = config.get("room_mask_condition", True)
51 | if self.room_mask_condition:
52 | self.room_latent_dim = config["room_latent_dim"]
53 | self.feature_extractor = feature_extractor
54 | self.fc_room_f = nn.Linear(
55 | self.feature_extractor.feature_size, self.room_latent_dim
56 | ) if self.feature_extractor.feature_size != self.room_latent_dim \
57 | else nn.Identity()
58 | print('use room mask as condition')
59 | self.context_dim += self.room_latent_dim
60 |
61 | # define positional embeddings
62 | self.position_condition = config.get("position_condition", False)
63 | self.learnable_embedding = config.get("learnable_embedding", False)
64 |
65 | if self.position_condition:
66 | self.position_emb_dim = config.get("position_emb_dim", 64)
67 | if self.learnable_embedding:
68 | self.register_parameter("positional_embedding", nn.Parameter(
69 | torch.randn(self.sample_num_points, self.position_emb_dim)
70 | ))
71 | else:
72 | self.fc_position_condition = nn.Sequential(
73 | nn.Linear(self.sample_num_points, self.position_emb_dim,
74 | bias=False),
75 | nn.LeakyReLU(0.1, inplace=True),
76 | nn.Linear(self.position_emb_dim, self.position_emb_dim,
77 | bias=False),
78 | )
79 | print('use position embedding for object index')
80 | self.context_dim += self.position_emb_dim
81 |
82 | if "diffusion_kwargs" in config.keys():
83 | # define the denoising network
84 | if config["net_type"] == "unet1d":
85 | denoise_net = Unet1D(
86 | network_dim = self.network_dim,
87 | context_dim = self.context_dim,
88 | **config["net_kwargs"]
89 | )
90 | elif config["net_type"] == "transformer":
91 | denoise_net = ContinuousDenoiseTransformer(
92 | network_dim = self.network_dim,
93 | context_dim = self.context_dim,
94 | num_timesteps = config["diffusion_kwargs"]["time_num"],
95 | **config["net_kwargs"]
96 | )
97 | else:
98 | raise NotImplementedError()
99 |
100 | # define the diffusion type
101 | self.diffusion = DiffusionPoint(
102 | denoise_net = denoise_net,
103 | network_dim = self.network_dim,
104 | train_stats_file = train_stats_file,
105 | **config["diffusion_kwargs"]
106 | )
107 | self.point_dim = sum(dim for dim in self.network_dim.values())
108 |
109 | def unpack_data(self, sample_params):
110 | # read data
111 | class_labels = sample_params["class_labels"]
112 | translations = sample_params["translations"]
113 | sizes = sample_params["sizes"]
114 | angles = sample_params["angles"]
115 | if self.objectness_dim > 0:
116 | objectness = sample_params["objectness"]
117 | if self.objfeat_dim == 32:
118 | objfeats = sample_params["objfeats_32"]
119 | elif self.objfeat_dim == 64:
120 | objfeats = sample_params["objfeats"]
121 | elif self.objfeat_dim != 0:
122 |             raise NotImplementedError
123 |
124 | # get desired diffusion target
125 | room_layout_target = \
126 | torch.cat([translations, sizes, angles, class_labels], dim=-1)
127 | if self.objectness_dim > 0:
128 | room_layout_target = \
129 | torch.cat([room_layout_target, objectness], dim=-1)
130 | if self.objfeat_dim > 0:
131 | room_layout_target = \
132 | torch.cat([room_layout_target, objfeats], dim=-1)
133 |
134 | return room_layout_target
135 |
136 | def unpack_condition(self, batch_size, device, room_feature=None):
137 | # condition to denoise_net
138 | condition = None
139 |
140 | # get the latent feature of room_mask
141 | if self.room_mask_condition:
142 | room_layout_f = self.fc_room_f(self.feature_extractor(room_feature))
143 | condition = room_layout_f[:, None, :].repeat(1, self.sample_num_points, 1)
144 |
145 | # process instance position condition f
146 | if self.position_condition:
147 | if self.learnable_embedding:
148 | position_condition_f = self.positional_embedding[None, :]\
149 | .repeat(batch_size, 1, 1)
150 | else:
151 | instance_label = torch.eye(self.sample_num_points).float()\
152 | .to(device)[None, ...].repeat(batch_size, 1, 1)
153 | position_condition_f = self.fc_position_condition(instance_label)
154 |
155 | condition = torch.cat(
156 | [condition, position_condition_f], dim=-1
157 | ).contiguous() if condition is not None else position_condition_f
158 |
159 | return condition
160 |
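The resulting condition is the room latent repeated across the object slots, optionally concatenated with a per-slot positional embedding. With, say, room_latent_dim = 64, position_emb_dim = 64 and the default sample_num_points = 12 (illustrative numbers), the shapes work out as:

    import torch

    B, N = 8, 12                              # batch size, sample_num_points
    room_latent = torch.randn(B, 64)          # fc_room_f(feature_extractor(...))
    room_cond = room_latent[:, None, :].repeat(1, N, 1)   # (8, 12, 64)
    pos_cond = torch.randn(B, N, 64)                      # positional embedding
    condition = torch.cat([room_cond, pos_cond], dim=-1)
    assert condition.shape == (8, 12, 128)                # context_dim = 128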
161 | def get_loss(self, sample_params):
162 | # unpack sample_params
163 | room_layout_target = self.unpack_data(sample_params)
164 |
165 | # unpack condition
166 | room_feature = None
167 | if self.room_mask_condition:
168 | if isinstance(self.feature_extractor, ResNet18):
169 | room_feature = sample_params["room_layout"]
170 | elif isinstance(self.feature_extractor, PointNet_Point):
171 | room_feature = sample_params["fpbpn"]
172 | condition = self.unpack_condition(
173 | room_layout_target.shape[0], room_layout_target.device, room_feature
174 | )
175 |
176 | # denoise loss function
177 | loss, loss_dict = self.diffusion.get_loss_iter(
178 | room_layout_target, condition=condition
179 | )
180 | return loss, loss_dict
181 |
182 | def sample(self, room_feature=None, batch_size=1, input_boxes=None,
183 | feature_mask=None, clip_denoised=False, ret_traj=False, freq=40,
184 | device="cpu"):
185 | # condition to denoise_net
186 | condition = self.unpack_condition(batch_size, device, room_feature)
187 |
188 | # reverse sampling
189 | data_shape = (batch_size, self.sample_num_points, self.point_dim)
190 | return self.diffusion.gen_samples(
191 | data_shape, device=device, condition=condition,
192 | clip_denoised=clip_denoised, freq=freq if ret_traj else None
193 | )
194 |
195 | @torch.no_grad()
196 | def generate_layout(self, room_feature=None, batch_size=1, input_boxes=None,
197 | feature_mask=None, clip_denoised=False, device="cpu"):
198 |         """Generate a list of bbox_params dicts, each corresponding to one layout
199 |         that can be processed by the dataset's post_process() method.
200 |         The features in each dict are tensors of [0, Ni, ?] dimensions where
201 |         Ni is the number of objects predicted."""
202 | if self.room_mask_condition:
203 | assert room_feature.size(0) == batch_size
204 |
205 | samples = self.sample(
206 | room_feature, batch_size, input_boxes=input_boxes, feature_mask=feature_mask,
207 | clip_denoised=clip_denoised, device=device, ret_traj=False
208 | )
209 |
210 | return self.delete_empty_from_network_samples(samples)
211 |
212 | @torch.no_grad()
213 | def generate_layout_progressive(self, room_feature=None, batch_size=1,
214 | input_boxes=None, feature_mask=None,
215 | clip_denoised=False, device="cpu",
216 | save_freq=100):
217 |         """Generate a list of tuples. Each tuple stores a trajectory of one
218 |         predicted layout. It contains a collection of bbox_params dicts, each
219 |         corresponding to a frame in the reverse diffusion process that can be
220 |         processed by the dataset's post_process() method.
221 |         The features in each dict are tensors of [0, Ni, ?] dimensions where
222 |         Ni is the number of objects predicted."""
223 | # generate results at each time step
224 | samples_traj = self.sample(
225 | room_feature, batch_size, input_boxes=input_boxes, feature_mask=feature_mask,
226 | clip_denoised=clip_denoised, device=device, ret_traj=True, freq=save_freq,
227 | )
228 | results_by_time = []
229 | for samples_t in samples_traj:
230 | samples_t_list = self.delete_empty_from_network_samples(samples_t)
231 | results_by_time.append(samples_t_list)
232 |
233 | # combine results of the same frame to a tuple
234 | traj_list = []
235 | for b in range(batch_size):
236 | traj_list.append(tuple(results[b] for results in results_by_time))
237 |
238 | return traj_list
239 |
240 | @torch.no_grad()
241 | def delete_empty_from_network_samples(self, samples):
242 | """Remove objects with 'empty' label given samples of [B, N, C]
243 | dimensions. The output is a list of dictionaries with features
244 | of [0, N_i, ?] dimensions for each object feature."""
245 | object_max, object_max_ind = torch.max(
246 | samples[:, :, self.bbox_dim:self.bbox_dim+self.n_object_types],
247 | dim=-1
248 | )
249 | samples_dict = {
250 | "translations": samples[:, :, 0:self.translation_dim].contiguous(),
251 | "sizes": samples[:, :, self.translation_dim:self.translation_dim+
252 | self.size_dim].contiguous(),
253 | "angles": samples[:, :, self.translation_dim+self.size_dim:
254 | self.bbox_dim].contiguous(),
255 | "class_labels": nn.functional.one_hot(
256 | object_max_ind, num_classes=self.n_object_types
257 | ),
258 | "is_empty": samples[:, :, self.bbox_dim+self.class_dim-1] > object_max,
259 | }
260 | if self.objfeat_dim > 0:
261 | samples_dict["objfeats"] = \
262 | samples[:, :, self.bbox_dim+self.class_dim:self.bbox_dim+
263 | self.class_dim+self.objfeat_dim]
264 |
265 | return self.delete_empty_boxes(samples_dict)
266 |
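The slicing above assumes the flat feature ordering produced by unpack_data; with the default dimensions (translation 3, size 3, angle 1) and an illustrative class_dim of 22, the columns decompose as:

    import torch

    translation_dim, size_dim, angle_dim, class_dim = 3, 3, 1, 22
    bbox_dim = translation_dim + size_dim + angle_dim      # 7
    samples = torch.randn(4, 12, bbox_dim + class_dim)     # (B, N, point_dim)

    translations = samples[:, :, :translation_dim]
    sizes = samples[:, :, translation_dim:translation_dim + size_dim]
    angles = samples[:, :, translation_dim + size_dim:bbox_dim]
    class_scores = samples[:, :, bbox_dim:bbox_dim + class_dim]  # last column is 'empty'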
267 | @torch.no_grad()
268 | def delete_empty_boxes(self, samples_dict):
269 | """Remove objects with 'empty' label given samples_dict with features of
270 | [B, N, ?] dimensions. The output is a list of dictionaries with features
271 | of [0, N_i, ?] dimensions for each object feature."""
272 | batch_size = samples_dict["class_labels"].size(0)
273 | max_boxes = samples_dict["class_labels"].size(1)
274 |
275 | return_properties = ["class_labels", "translations", "sizes", "angles"]
276 | if self.objfeat_dim > 0:
277 | return_properties.append("objfeats")
278 |
279 | boxes_list = []
280 | for b in range(batch_size):
281 | # Initialize empty dict
282 | boxes = {
283 | "class_labels": torch.zeros(1, 0, self.n_object_types),
284 | "translations": torch.zeros(1, 0, self.translation_dim),
285 | "sizes": torch.zeros(1, 0, self.size_dim),
286 | "angles": torch.zeros(1, 0, self.angle_dim)
287 | }
288 | if self.objfeat_dim > 0:
289 | boxes["objfeats"] = torch.zeros(1, 0, self.objfeat_dim)
290 |
291 | for i in range(max_boxes):
292 | # Check if we have the end symbol
293 | if samples_dict["is_empty"][b, i]:
294 | continue
295 | else:
296 | for k in return_properties:
297 | boxes[k] = torch.cat([
298 | boxes[k],
299 | samples_dict[k][b:b+1, i:i+1, :].to("cpu")
300 | ], dim=1)
301 |
302 | boxes_list.append(boxes)
303 |
304 | return boxes_list
305 |
--------------------------------------------------------------------------------
/midiffusion/networks/diffusion_scene_layout_mixed.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .diffusion_mixed import MixedDiffusionPoint
4 | from .denoising_net.mixed_transformer import MixedDenoiseTransformer
5 | from .feature_extractors import ResNet18, PointNet_Point
6 | from .diffusion_scene_layout_ddpm import DiffusionSceneLayout_DDPM
7 |
8 |
9 | class DiffusionSceneLayout_Mixed(DiffusionSceneLayout_DDPM):
10 | """Scene synthesis via mixed discrete-continuous diffusion"""
11 | def __init__(self, n_object_types, feature_extractor, config, train_stats_file):
12 | super(DiffusionSceneLayout_Mixed, self).__init__(
13 | n_object_types, feature_extractor, config, train_stats_file
14 | )
15 |
16 | # define the denoising network
17 | assert config["net_type"] == "transformer"
18 | denoise_net = MixedDenoiseTransformer(
19 | network_dim = self.network_dim,
20 | context_dim = self.context_dim,
21 | num_timesteps = config["time_num"],
22 | **config["net_kwargs"]
23 | )
24 |
25 | # define the diffusion type
26 | config["diffusion_geometric_kwargs"]["train_stats_file"] = train_stats_file
27 | self.diffusion = MixedDiffusionPoint(
28 | denoise_net = denoise_net,
29 | network_dim = self.network_dim,
30 | time_num = config["time_num"],
31 | d3pm_config = config["diffusion_semantic_kwargs"],
32 | ddpm_config = config["diffusion_geometric_kwargs"],
33 | )
34 | self.point_dim = sum(dim for k, dim in self.network_dim.items()
35 |                                 if k != "class_dim")
36 |
37 | def get_loss(self, sample_params):
38 | # unpack sample_params
39 | room_layout_target = self.unpack_data(sample_params)
40 |
41 | # unpack condition
42 | room_feature = None
43 | if self.room_mask_condition:
44 | if isinstance(self.feature_extractor, ResNet18):
45 | room_feature = sample_params["room_layout"]
46 | elif isinstance(self.feature_extractor, PointNet_Point):
47 | room_feature = sample_params["fpbpn"]
48 | condition = self.unpack_condition(
49 | room_layout_target.shape[0], room_layout_target.device, room_feature
50 | )
51 | semantic_target = \
52 | room_layout_target[:, :, self.bbox_dim:self.bbox_dim+self.class_dim]\
53 | .argmax(dim=-1)
54 | geometric_target = torch.cat([
55 | room_layout_target[:, :, :self.bbox_dim],
56 | room_layout_target[:, :, self.bbox_dim+self.class_dim:]
57 | ], dim=-1).contiguous()
58 |
59 | # denoise loss function
60 | loss, loss_dict = self.diffusion.get_loss_iter(
61 | semantic_target, geometric_target, condition=condition
62 | )
63 | return loss, loss_dict
64 |
65 | def sample(self, room_feature=None, batch_size=1, input_boxes=None,
66 | feature_mask=None, clip_denoised=False, ret_traj=False, freq=40,
67 | device="cpu"):
68 | # condition to denoise_net
69 | condition = self.unpack_condition(batch_size, device, room_feature)
70 |
71 | # retrieve known features from input_boxes if available
72 | x0_class, x0_geometric, class_mask, geometry_mask = None, None, None, None
73 | if input_boxes is not None:
74 | assert input_boxes.shape == torch.Size(
75 | (batch_size, self.sample_num_points, self.point_dim+self.class_dim)
76 | )
77 | assert feature_mask.shape == torch.Size(
78 | (self.sample_num_points, self.point_dim+self.class_dim)
79 | )
80 | class_mask = torch.all(
81 | feature_mask[:, self.bbox_dim:self.bbox_dim+self.class_dim],
82 | dim=1
83 | ) # 1-D
84 | geometry_mask = torch.cat([
85 | feature_mask[:, :self.bbox_dim],
86 | feature_mask[:, self.bbox_dim+self.class_dim:]
87 | ], dim=1) # 2-D
88 |
89 | if class_mask.any():
90 | x0_class = \
91 | input_boxes[:, :, self.bbox_dim:self.bbox_dim+self.class_dim]\
92 | .argmax(dim=-1)
93 | if geometry_mask.any():
94 | x0_geometric = torch.cat([
95 | input_boxes[:, :, :self.bbox_dim],
96 | input_boxes[:, :, self.bbox_dim+self.class_dim:]
97 | ], dim=-1)
98 |
99 | # reverse sampling
100 | geometric_data_shape = (batch_size, self.sample_num_points, self.point_dim)
101 | if ret_traj:
102 | samples_traj = self.diffusion.gen_samples(
103 | geometric_data_shape, device=device, condition=condition,
104 | x0_class_partial=x0_class, class_mask=class_mask,
105 | x0_geometric_partial=x0_geometric, geometry_mask=geometry_mask,
106 | freq=freq, clip_denoised=clip_denoised
107 | )
108 | samples_list = []
109 | for samples_class, samples_geometric in samples_traj:
110 | samples_list.append(torch.cat([
111 | samples_geometric[:, :, :self.bbox_dim],
112 | torch.exp(samples_class),
113 | samples_geometric[:, :, self.bbox_dim:]
114 | ], dim=-1))
115 | return samples_list
116 | else:
117 | samples_class, samples_geometric = self.diffusion.gen_samples(
118 | geometric_data_shape, device=device, condition=condition,
119 | x0_class_partial=x0_class, class_mask=class_mask,
120 | x0_geometric_partial=x0_geometric, geometry_mask=geometry_mask,
121 | clip_denoised=clip_denoised
122 | )
123 | return torch.cat([
124 | samples_geometric[:, :, :self.bbox_dim],
125 | torch.exp(samples_class),
126 | samples_geometric[:, :, self.bbox_dim:]
127 | ], dim=-1)
128 |
--------------------------------------------------------------------------------
/midiffusion/networks/feature_extractors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision import models
4 |
5 | from .frozen_batchnorm import FrozenBatchNorm2d
6 |
7 |
8 | class BaseFeatureExtractor(nn.Module):
9 |     """Holds functionality common to all feature extractor networks"""
10 | @property
11 | def feature_size(self):
12 | return self._feature_size
13 |
14 | def forward(self, data):
15 | raise NotImplementedError
16 |
17 |
18 | class ResNet18(BaseFeatureExtractor):
19 |     """Build a feature extractor using the ResNet18 architecture for
20 |     image-based inputs"""
21 | """https://github.com/nv-tlabs/ATISS/blob/master/scene_synthesis/networks/feature_extractors.py"""
22 | def __init__(self, freeze_bn, input_channels, feature_size):
23 | super(ResNet18, self).__init__()
24 | self._feature_size = feature_size
25 |
26 | self._feature_extractor = models.resnet18(weights=None)
27 | if freeze_bn:
28 | FrozenBatchNorm2d.freeze(self._feature_extractor)
29 |
30 | self._feature_extractor.conv1 = torch.nn.Conv2d(
31 | input_channels,
32 | 64,
33 | kernel_size=(7, 7),
34 | stride=(2, 2),
35 | padding=(3, 3),
36 | bias=False
37 | )
38 |
39 | self._feature_extractor.fc = nn.Sequential(
40 | nn.Linear(512, 512), nn.ReLU(),
41 | nn.Linear(512, self.feature_size)
42 | )
43 | self._feature_extractor.avgpool = nn.AdaptiveAvgPool2d((1, 1))
44 |
45 | def forward(self, X):
46 | return self._feature_extractor(X)
47 |
48 |
49 | class PointNet_Point(BaseFeatureExtractor):
50 | """The 'pointnet_simple' encoder in train script"""
51 | """https://github.com/QiuhongAnnaWei/LEGO-Net/blob/main/model/floorplan_encoder.py"""
52 |
53 | def __init__(self, activation=torch.nn.LeakyReLU(), nfpbpn=256, feat_units=[4, 64, 64, 512, 511]):
54 | """ feat_units[0] should be 4 (x, y, nx, ny for fpbpn) (in_dim)
55 | feat_units[-1] should be transformer_input_d-1 (out_dim)
56 | maxpool happens before the last linear layer
57 | """
58 | super().__init__()
59 | self.activation = activation
60 | self._feature_size = feat_units[-1]
61 |
62 | layers = []
63 | for i in range(1, len(feat_units)):
64 | layers.append(torch.nn.Linear(feat_units[i-1], feat_units[i]))
65 | self.layers = torch.nn.ModuleList(layers)
66 |
67 |         self.fp_maxpool = torch.nn.MaxPool1d(nfpbpn)  # nfpbpn -> 1
68 |
69 |
70 | def forward(self, fpbpn):
71 |         """ fpbpn: [batch_size, nfpbpn, 4], floor_plan_boundary_points_normals """
72 |
73 | B = fpbpn.shape[0]
74 |
75 | for i in range(len(self.layers)-1): # first 3 layers
76 | fpbpn = self.activation(self.layers[i](fpbpn)) # [B, nfpbp, feat_units[-2]]
77 |
78 | fpbpn = fpbpn.permute((0,2,1)) # [B, feat_units[-2], nfpbp]
79 | scene_fp_feat = self.fp_maxpool(fpbpn).reshape(B, -1) # [B, feat_units[-2], 1] -> [B, feat_units[-2]]
80 |
81 | return self.layers[-1](scene_fp_feat) # [B, feat_units[-1]]
82 |
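A minimal shape check for the encoder above (random input, default hyper-parameters; assumes the midiffusion package is importable):

    import torch
    from midiffusion.networks.feature_extractors import PointNet_Point

    net = PointNet_Point(nfpbpn=256)      # the max-pool expects 256 boundary points
    fpbpn = torch.rand(2, 256, 4)         # (B, points, [x, y, nx, ny])
    assert net(fpbpn).shape == (2, 511)   # feat_units[-1] = 511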
83 |
84 | class PointNet_Line(nn.Module):
85 | """ The 'pointnet' encoder in train script.
86 | Has 3 sections: point processing, line processing, floor plan feature processing """
87 | def __init__(self, activation=torch.nn.LeakyReLU(), maxnfpoc=25, corner_feat_units = [2, 64, 128], line_feat_units = [256, 512, 1024], fp_units = [1024, 511]):
88 | """ corner_feat_units[0] should be pos_dim (in_dim)
89 | line_feat_units[0] should be corner_feat_units[-1]*2
90 | fp_units[0] should be line_feat_units[-1]
91 | fp_units[-1] should be transformer_input_d-1 (out_dim)
92 | """
93 | super().__init__()
94 | self.activation = activation # torch.nn.LeakyReLU()
95 | self._feature_size = fp_units[-1]
96 |
97 | corner_feat = []
98 | for i in range(1, len(corner_feat_units)): # 2 layers
99 | corner_feat.append(torch.nn.Linear(corner_feat_units[i-1], corner_feat_units[i]))
100 | self.corner_feat = torch.nn.ModuleList(corner_feat)
101 |
102 | line_feat = []
103 | for i in range(1, len(line_feat_units)): # 2 layers
104 | line_feat.append(torch.nn.Linear(line_feat_units[i-1], line_feat_units[i]))
105 | self.line_feat = torch.nn.ModuleList(line_feat)
106 |         self.fp_maxpool = torch.nn.MaxPool1d(maxnfpoc)  # maxnfpoc -> 1
107 |
108 | # want floor plan to appear as an object token
109 | fp_feat = []
110 | for i in range(1, len(fp_units)): # 1 layer
111 | fp_feat.append(torch.nn.Linear(fp_units[i-1], fp_units[i]))
112 | self.fp_feat = torch.nn.ModuleList(fp_feat)
113 |
114 |
115 | def forward(self, fpoc, nfpc):
116 | """ fpoc : [batch_size, maxnfpoc, pos=2], with padded 0 beyond the num of floor plan ordered corners for each scene
117 | nfpc : [batch_size], num of floor plan ordered corners for each scene
118 | """
119 | device = fpoc.device
120 |
121 | # each point -> MLP -> each point has point features
122 | for i in range(len(self.corner_feat)-1):
123 | fpoc = self.activation(self.corner_feat[i](fpoc))
124 | fpoc = self.corner_feat[-1](fpoc) # [B, nfpv, corner_feat_units[-1]=128]
125 |
126 | # concatenate 2 point features for each line pair -> MLP -> maxpool
127 |         B, maxnfpoc, cornerpt_feat_d = fpoc.shape
128 | line_pairpt_input = torch.zeros(B, maxnfpoc, cornerpt_feat_d*2).to(device)
129 | for s_i in range(B):
130 | for l_i in range(nfpc[s_i]): # otherwise padded with 0; clockwise ordering of lines
131 | line_pairpt_input[s_i, l_i, :cornerpt_feat_d] = fpoc[s_i, l_i, :] # index slicing does not make copy
132 | line_pairpt_input[s_i, l_i, cornerpt_feat_d:] = fpoc[s_i, (l_i+1)%nfpc[s_i], :]
133 | line_pairpt_input = line_pairpt_input.to(device)
134 |
135 | for i in range(len(self.line_feat)-1):
136 | line_pairpt_input = self.activation(self.line_feat[i](line_pairpt_input))
137 | line_pairpt_input = self.line_feat[-1](line_pairpt_input) # [B, nfpv, line_feat_units[-1]=1024]
138 |
139 | line_pairpt_input_padded = torch.zeros(line_pairpt_input.shape).to(device) # [B, nfpv, 1024]
140 | for s_i in range(B):
141 | line_pairpt_input_padded[s_i, :nfpc[s_i], :] = line_pairpt_input[s_i, :nfpc[s_i], :]
142 | line_pairpt_input_padded[s_i, nfpc[s_i]:, :] = line_pairpt_input[s_i, 0, :] # duplicate first feat to not impact maxpool
143 | line_pairpt_input_padded = line_pairpt_input_padded.to(device)
144 |
145 | line_pairpt_input_padded = line_pairpt_input_padded.permute((0,2,1)) # [B, 1024, nfpv]
146 | scene_fpoc = self.fp_maxpool(line_pairpt_input_padded).reshape(B, -1) # [B, 1024, 1] -> [B, 1024]
147 |
148 | # One floor plan feat per scene -> MLP -> input to transformer as last obj token
149 | for i in range(len(self.fp_feat)-1):
150 | scene_fpoc = self.activation(self.fp_feat[i](scene_fpoc))
151 | scene_fpoc = torch.unsqueeze(self.fp_feat[-1](scene_fpoc), dim=1) # [B, None->1, transformer_input_d-1]
152 |
153 | return scene_fpoc
154 |
155 |
156 | def get_feature_extractor(name, **kwargs):
157 | """Based on the name return the appropriate feature extractor."""
158 | if name == "resnet18":
159 | return ResNet18(**kwargs)
160 | elif name == "pointnet_simple":
161 | return PointNet_Point(**kwargs)
162 | elif name == "pointnet":
163 | return PointNet_Line(**kwargs)
164 | else:
165 |         raise NotImplementedError
166 |
--------------------------------------------------------------------------------
/midiffusion/networks/frozen_batchnorm.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/nv-tlabs/ATISS
4 | #
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.parameter import Parameter
9 |
10 |
11 | class FrozenBatchNorm2d(nn.Module):
12 |     """A drop-in replacement for PyTorch's BatchNorm2d where the batch
13 |     statistics are fixed"""
14 | def __init__(self, num_features):
15 | super(FrozenBatchNorm2d, self).__init__()
16 | self.num_features = num_features
17 | self.register_parameter("weight", Parameter(torch.ones(num_features)))
18 | self.register_parameter("bias", Parameter(torch.zeros(num_features)))
19 | self.register_buffer("running_mean", torch.zeros(num_features))
20 | self.register_buffer("running_var", torch.ones(num_features))
21 |
22 | def extra_repr(self):
23 | return '{num_features}'.format(**self.__dict__)
24 |
25 | @classmethod
26 | def from_batch_norm(cls, bn):
27 | fbn = cls(bn.num_features)
28 | # Update the weight and biases based on the corresponding weights and
29 | # biases of the pre-trained bn layer
30 | with torch.no_grad():
31 | fbn.weight[...] = bn.weight
32 | fbn.bias[...] = bn.bias
33 | fbn.running_mean[...] = bn.running_mean
34 | fbn.running_var[...] = bn.running_var + bn.eps
35 | return fbn
36 |
37 | @staticmethod
38 | def _getattr_nested(m, module_names):
39 | if len(module_names) == 1:
40 | return getattr(m, module_names[0])
41 | else:
42 | return FrozenBatchNorm2d._getattr_nested(
43 | getattr(m, module_names[0]), module_names[1:]
44 | )
45 |
46 | @staticmethod
47 | def freeze(m):
48 | for (name, layer) in m.named_modules():
49 | if isinstance(layer, nn.BatchNorm2d):
50 | nest = name.split(".")
51 | if len(nest) == 1:
52 | setattr(m, name, FrozenBatchNorm2d.from_batch_norm(layer))
53 | else:
54 | setattr(
55 | FrozenBatchNorm2d._getattr_nested(m, nest[:-1]),
56 | nest[-1],
57 | FrozenBatchNorm2d.from_batch_norm(layer)
58 | )
59 |
60 | def forward(self, x):
61 | # Cast all fixed parameters to half() if necessary
62 | if x.dtype == torch.float16:
63 | self.weight = self.weight.half()
64 | self.bias = self.bias.half()
65 | self.running_mean = self.running_mean.half()
66 | self.running_var = self.running_var.half()
67 |
68 | scale = self.weight * self.running_var.rsqrt()
69 | bias = self.bias - self.running_mean * scale
70 | scale = scale.reshape(1, -1, 1, 1)
71 | bias = bias.reshape(1, -1, 1, 1)
72 | return x * scale + bias
73 |
74 |
75 | def freeze_network(network, freeze=False):
76 | if freeze:
77 | for p in network.parameters():
78 | p.requires_grad = False
79 | return network
80 |
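A small sketch of freeze() in action: every nn.BatchNorm2d in a module tree is swapped for a frozen copy whose statistics no longer update (toy model, illustrative only):

    import torch
    import torch.nn as nn
    from midiffusion.networks.frozen_batchnorm import FrozenBatchNorm2d

    m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    FrozenBatchNorm2d.freeze(m)
    assert isinstance(m[1], FrozenBatchNorm2d)
    out = m(torch.randn(2, 3, 16, 16))    # forward pass still works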
--------------------------------------------------------------------------------
/midiffusion/networks/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def axis_aligned_bbox_overlaps_3d(bboxes1,
5 | bboxes2,
6 | mode='iou',
7 | is_aligned=False,
8 | eps=1e-6):
9 | """Calculate overlap between two set of axis aligned 3D bboxes. If
10 | ``is_aligned`` is ``False``, then calculate the overlaps between each bbox
11 | of bboxes1 and bboxes2, otherwise the overlaps between each aligned pair of
12 | bboxes1 and bboxes2.
13 | Args:
14 |         bboxes1 (Tensor): shape (B, m, 6) in <x1, y1, z1, x2, y2, z2>
15 |             format or empty.
16 |         bboxes2 (Tensor): shape (B, n, 6) in <x1, y1, z1, x2, y2, z2>
17 |             format or empty.
18 | B indicates the batch dim, in shape (B1, B2, ..., Bn).
19 | If ``is_aligned`` is ``True``, then m and n must be equal.
20 | mode (str): "iou" (intersection over union) or "giou" (generalized
21 | intersection over union).
22 | is_aligned (bool, optional): If True, then m and n must be equal.
23 | Defaults to False.
24 | eps (float, optional): A value added to the denominator for numerical
25 | stability. Defaults to 1e-6.
26 | Returns:
27 | Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,)
28 | """
29 | """https://github.com/open-mmlab/mmdetection3d/blob/master/mmdet3d/core/bbox/iou_calculators/iou3d_calculator.py"""
30 | assert mode in ['iou', 'giou'], f'Unsupported mode {mode}'
31 |     # Either the boxes are empty or the length of the boxes' last dimension is 6
32 | assert (bboxes1.size(-1) == 6 or bboxes1.size(0) == 0)
33 | assert (bboxes2.size(-1) == 6 or bboxes2.size(0) == 0)
34 |
35 | # Batch dim must be the same
36 | # Batch dim: (B1, B2, ... Bn)
37 | assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
38 | batch_shape = bboxes1.shape[:-2]
39 |
40 | rows = bboxes1.size(-2)
41 | cols = bboxes2.size(-2)
42 | if is_aligned:
43 | assert rows == cols
44 |
45 | if rows * cols == 0:
46 | if is_aligned:
47 | return bboxes1.new(batch_shape + (rows, ))
48 | else:
49 | return bboxes1.new(batch_shape + (rows, cols))
50 |
51 | area1 = (bboxes1[..., 3] -
52 | bboxes1[..., 0]) * (bboxes1[..., 4] - bboxes1[..., 1]) * (
53 | bboxes1[..., 5] - bboxes1[..., 2])
54 | area2 = (bboxes2[..., 3] -
55 | bboxes2[..., 0]) * (bboxes2[..., 4] - bboxes2[..., 1]) * (
56 | bboxes2[..., 5] - bboxes2[..., 2])
57 |
58 | if is_aligned:
59 | lt = torch.max(bboxes1[..., :3], bboxes2[..., :3]) # [B, rows, 3]
60 | rb = torch.min(bboxes1[..., 3:], bboxes2[..., 3:]) # [B, rows, 3]
61 |
62 | wh = (rb - lt).clamp(min=0) # [B, rows, 3]
63 | overlap = wh[..., 0] * wh[..., 1] * wh[..., 2]
64 |
65 | if mode in ['iou', 'giou']:
66 | union = area1 + area2 - overlap
67 | else:
68 | union = area1
69 | if mode == 'giou':
70 | enclosed_lt = torch.min(bboxes1[..., :3], bboxes2[..., :3])
71 | enclosed_rb = torch.max(bboxes1[..., 3:], bboxes2[..., 3:])
72 | else:
73 | lt = torch.max(bboxes1[..., :, None, :3],
74 | bboxes2[..., None, :, :3]) # [B, rows, cols, 3]
75 | rb = torch.min(bboxes1[..., :, None, 3:],
76 | bboxes2[..., None, :, 3:]) # [B, rows, cols, 3]
77 |
78 | wh = (rb - lt).clamp(min=0) # [B, rows, cols, 3]
79 | overlap = wh[..., 0] * wh[..., 1] * wh[..., 2]
80 |
81 | if mode in ['iou', 'giou']:
82 | union = area1[..., None] + area2[..., None, :] - overlap
83 | if mode == 'giou':
84 | enclosed_lt = torch.min(bboxes1[..., :, None, :3],
85 | bboxes2[..., None, :, :3])
86 | enclosed_rb = torch.max(bboxes1[..., :, None, 3:],
87 | bboxes2[..., None, :, 3:])
88 |
89 | eps = union.new_tensor([eps])
90 | union = torch.max(union, eps)
91 | ious = overlap / union
92 | if mode in ['iou']:
93 | return ious
94 | # calculate gious
95 | enclose_wh = (enclosed_rb - enclosed_lt).clamp(min=0)
96 | enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] * enclose_wh[..., 2]
97 | enclose_area = torch.max(enclose_area, eps)
98 | gious = ious - (enclose_area - union) / enclose_area
99 | return gious
100 |
--------------------------------------------------------------------------------
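
A small illustrative check of axis_aligned_bbox_overlaps_3d on two hand-picked boxes (the inputs and expected values below are examples, not from this repo):

    # Two unit cubes, the second shifted by 0.5 along x, overlap in a
    # 0.5 x 1 x 1 slab, so IoU = 0.5 / (1 + 1 - 0.5) = 1/3.
    import torch
    from midiffusion.networks.loss import axis_aligned_bbox_overlaps_3d

    boxes_a = torch.tensor([[[0.0, 0.0, 0.0, 1.0, 1.0, 1.0]]])  # (B=1, m=1, 6)
    boxes_b = torch.tensor([[[0.5, 0.0, 0.0, 1.5, 1.0, 1.0]]])  # (B=1, n=1, 6)

    iou = axis_aligned_bbox_overlaps_3d(boxes_a, boxes_b, mode='iou')
    giou = axis_aligned_bbox_overlaps_3d(boxes_a, boxes_b, mode='giou', is_aligned=True)
    print(iou)   # tensor([[[0.3333]]]) -- pairwise shape (B, m, n)
    print(giou)  # tensor([[0.3333]])   -- aligned shape (B, m); the enclosing box equals the union here
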
/midiffusion/stats_logger.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/nv-tlabs/ATISS.
4 | #
5 |
6 | import sys
7 | import wandb
8 |
9 |
10 | class AverageAggregator(object):
11 | def __init__(self):
12 | self._value = 0
13 | self._count = 0
14 |
15 | @property
16 | def value(self):
17 | return self._value / self._count
18 |
19 | @value.setter
20 | def value(self, val):
21 | self._value += val
22 | self._count += 1
23 |
24 |
25 | class StatsLogger(object):
26 | __INSTANCE = None
27 |
28 | def __init__(self):
29 | if StatsLogger.__INSTANCE is not None:
30 | raise RuntimeError("StatsLogger should not be directly created")
31 |
32 | self._values = dict()
33 | self._loss = AverageAggregator()
34 | self._output_files = [sys.stdout]
35 |
36 | def add_output_file(self, f):
37 | self._output_files.append(f)
38 |
39 | def __getitem__(self, key):
40 | if key not in self._values:
41 | self._values[key] = AverageAggregator()
42 | return self._values[key]
43 |
44 | def clear(self):
45 | self._values.clear()
46 | self._loss = AverageAggregator()
47 | for f in self._output_files:
48 | if f.isatty():
49 | print(file=f, flush=True)
50 |
51 | def print_progress(self, epoch, batch, loss, precision="{:.5f}"):
52 | self._loss.value = loss
53 | fmt = "epoch: {} - batch: {} - loss: " + precision
54 | msg = fmt.format(epoch, batch, self._loss.value)
55 | for k, v in self._values.items():
56 | msg += " - " + k + ": " + precision.format(v.value)
57 | for f in self._output_files:
58 | if f.isatty():
59 | print(msg + "\b"*len(msg), end="", flush=True, file=f)
60 | else:
61 | print(msg, flush=True, file=f)
62 |
63 | @classmethod
64 | def instance(cls):
65 | if StatsLogger.__INSTANCE is None:
66 | StatsLogger.__INSTANCE = cls()
67 | return StatsLogger.__INSTANCE
68 |
69 |
70 | class WandB(StatsLogger):
71 | """Log the metrics in weights and biases. Code adapted from
72 | https://github.com/angeloskath/pytorch-boilerplate/blob/main/pbp/callbacks/wandb.py
73 |
74 | Arguments
75 | ---------
76 | project: str, the project name to use in weights and biases
77 | (default: '')
78 | watch: bool, use wandb.watch() on the model (default: True)
79 | log_frequency: int, the log frequency passed to wandb.watch
80 | (default: 10)
81 | """
82 | def init(
83 | self,
84 | experiment_arguments,
85 | model,
86 | project="experiment",
87 | name="experiment_name",
88 | watch=True,
89 | log_frequency=10,
90 | id=None
91 | ):
92 | self.project = project
93 | self.experiment_name = name
94 | self.watch = watch
95 | self.log_frequency = log_frequency
96 | self._epoch = 0
97 | self._validation = False
98 | # Login to wandb
99 | wandb.login()
100 |
101 | # Init the run
102 | wandb.init(
103 | project=(self.project or None),
104 | name=(self.experiment_name or None),
105 | config=dict(experiment_arguments.items()),
106 | id=id
107 | )
108 | self.id = wandb.run.id
109 |
110 | if self.watch:
111 | wandb.watch(model, log_freq=self.log_frequency)
112 |
113 | def print_progress(self, epoch, batch, loss, precision="{:.5f}"):
114 | super().print_progress(epoch, batch, loss, precision)
115 |
116 | self._validation = epoch < 0
117 | if not self._validation:
118 | self._epoch = epoch
119 |
120 | def clear(self):
121 | # Before clearing everything out send it to wandb
122 | prefix = "val_" if self._validation else ""
123 | values = {
124 | prefix+k: v.value
125 | for k, v in self._values.items()
126 | }
127 | values[prefix+"loss"] = self._loss.value
128 | wandb.log(values, step=self._epoch)
129 |
130 | super().clear()
131 |
--------------------------------------------------------------------------------
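
A minimal sketch of how the StatsLogger singleton above is typically driven inside a training loop; the metric name, loss values, and the stats.txt path are placeholders:

    from midiffusion.stats_logger import StatsLogger

    logger = StatsLogger.instance()
    logger.add_output_file(open("stats.txt", "w"))    # mirror progress to a text file

    for epoch in range(1, 3):
        for batch, loss in enumerate([0.9, 0.7, 0.5], start=1):
            logger["recon_loss"].value = loss * 0.5   # extra metrics are running-averaged per key
            logger.print_progress(epoch, batch, loss)
        logger.clear()                                # reset the aggregators after each epoch
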
/scripts/generate_results.py:
--------------------------------------------------------------------------------
1 | """Script used for generating results using a previously trained model."""
2 | import argparse
3 | import os
4 | import sys
5 | import shutil
6 | import pickle
7 |
8 | import numpy as np
9 | import torch
10 |
11 | from utils import PROJ_DIR, load_config, update_data_file_paths
12 | from threed_front.datasets import get_raw_dataset
13 | from threed_front.evaluation import ThreedFrontResults
14 | from midiffusion.datasets.threed_front_encoding import get_dataset_raw_and_encoded
15 | from midiffusion.networks import build_network
16 | from midiffusion.evaluation.utils import generate_layouts
17 |
18 |
19 | def main(argv):
20 | parser = argparse.ArgumentParser(
21 | description="Generate scenes using a previously trained model"
22 | )
23 |
24 | parser.add_argument(
25 | "weight_file",
26 | help="Path to a pretrained model"
27 | )
28 | parser.add_argument(
29 | "--config_file",
30 | default=None,
31 | help="Path to the file that contains the experiment configuration"
32 | "(default: config.yaml in the model directory)"
33 | )
34 | parser.add_argument(
35 | "--output_directory",
36 | default=PROJ_DIR+"/output/predicted_results/",
37 | help="Path to the output directory"
38 | )
39 | parser.add_argument(
40 | "--n_known_objects",
41 | default=0,
42 | type=int,
43 | help="Number of existing objects for scene completion task"
44 | )
45 | parser.add_argument(
46 | "--experiment",
47 | default="synthesis",
48 | choices=[
49 | "synthesis",
50 | "scene_completion",
51 | "furniture_arrangement",
52 | "object_conditioned",
53 | "scene_completion_conditioned"
54 | ],
55 | help="Experiment name"
56 | )
57 | parser.add_argument(
58 | "--seed",
59 | type=int,
60 | default=0,
61 | help="Seed for the sampling floor plan"
62 | )
63 | parser.add_argument(
64 | "--n_syn_scenes",
65 | default=1000,
66 | type=int,
67 | help="Number of scenes to be synthesized"
68 | )
69 | parser.add_argument(
70 | "--batch_size",
71 | default=128,
72 | type=int,
73 | help="Number of synthesized scene in each batch"
74 | )
75 | parser.add_argument(
76 | "--result_tag",
77 | default=None,
78 | help="Result sub-directory name"
79 | )
80 | parser.add_argument(
81 | "--gpu",
82 | type=int,
83 | default=0,
84 | help="GPU ID"
85 | )
86 |
87 | args = parser.parse_args(argv)
88 |
89 | # Set the random seed
90 | np.random.seed(args.seed)
91 | torch.manual_seed(np.random.randint(np.iinfo(np.int32).max))
92 | if torch.cuda.is_available():
93 | torch.cuda.manual_seed_all(np.random.randint(np.iinfo(np.int32).max))
94 |
95 | if args.gpu < torch.cuda.device_count():
96 | device = torch.device("cuda:{}".format(args.gpu))
97 | else:
98 | device = torch.device("cpu")
99 | print("Running code on", device)
100 |
101 | # Check if result_dir exists and if it doesn't create it
102 | if args.result_tag is None:
103 | result_dir = args.output_directory
104 | else:
105 | result_dir = os.path.join(args.output_directory, args.result_tag)
106 | if os.path.exists(result_dir) and \
107 | len(os.listdir(result_dir)) > 0:
108 | input("{} direcotry is non-empty. Press any key to remove all files..." \
109 | .format(result_dir))
110 | for fi in os.listdir(result_dir):
111 | os.remove(os.path.join(result_dir, fi))
112 | else:
113 | os.makedirs(result_dir, exist_ok=True)
114 |
115 | # Run control files to save
116 | path_to_config = os.path.join(result_dir, "config.yaml")
117 | path_to_results = os.path.join(result_dir, "results.pkl")
118 |
119 | # Parse the config file
120 | if args.config_file is None:
121 | args.config_file = os.path.join(os.path.dirname(args.weight_file), "config.yaml")
122 | config = load_config(args.config_file)
123 | if "_eval" not in config["data"]["encoding_type"] and args.experiment == "synthesis":
124 | config["data"]["encoding_type"] += "_eval"
125 | if not os.path.exists(path_to_config) or \
126 | not os.path.samefile(args.config_file, path_to_config):
127 | shutil.copyfile(args.config_file, path_to_config)
128 |
129 | # Raw training data (for record keeping)
130 | raw_train_dataset = get_raw_dataset(
131 | update_data_file_paths(config["data"]),
132 | split=config["training"].get("splits", ["train", "val"]),
133 | include_room_mask=config["network"].get("room_mask_condition", True)
134 | )
135 |
136 | # Get Scaled dataset encoding (without data augmentation)
137 | raw_dataset, encoded_dataset = get_dataset_raw_and_encoded(
138 | update_data_file_paths(config["data"]),
139 | split=config["validation"].get("splits", ["test"]),
140 | max_length=config["network"]["sample_num_points"],
141 | include_room_mask=config["network"].get("room_mask_condition", True)
142 | )
143 | print("Loaded {} scenes with {} object types ({} labels):".format(
144 | len(encoded_dataset), encoded_dataset.n_object_types, encoded_dataset.n_classes))
145 | print(encoded_dataset.class_labels)
146 |
147 | # Build network with saved weights
148 | network, _, _ = build_network(
149 | encoded_dataset.n_object_types, config, args.weight_file, device=device
150 | )
151 | network.eval()
152 |
153 | # Generate final results
154 | sampled_indices, layout_list = generate_layouts(
155 | network, encoded_dataset, config, args.n_syn_scenes, "random",
156 | experiment=args.experiment, num_known_objects=args.n_known_objects,
157 | batch_size=args.batch_size, device=device
158 | )
159 |
160 | threed_front_results = ThreedFrontResults(
161 | raw_train_dataset, raw_dataset, config, sampled_indices, layout_list
162 | )
163 |
164 | pickle.dump(threed_front_results, open(path_to_results, "wb"))
165 | print("Saved result to:", path_to_results)
166 |
167 | kl_divergence = threed_front_results.kl_divergence()
168 | print("object category kl divergence:", kl_divergence)
169 |
170 |
171 | if __name__ == "__main__":
172 | main(sys.argv[1:])
--------------------------------------------------------------------------------
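
A hypothetical invocation of the script above; the checkpoint path and result tag are placeholders, while every flag corresponds to an argparse option defined in main():

    python scripts/generate_results.py output/log/my_experiment/model_01000 \
        --experiment scene_completion --n_known_objects 3 \
        --n_syn_scenes 256 --result_tag completion_demo
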
/scripts/train_diffusion.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/nv-tlabs/ATISS.
4 | #
5 |
6 | """Script used to train a diffusion models."""
7 | import argparse
8 | import os
9 | import sys
10 | import shutil
11 | import time
12 | import numpy as np
13 |
14 | import torch
15 | from torch.utils.data import DataLoader
16 |
17 | from utils import PROJ_DIR, update_data_file_paths, id_generator, save_experiment_params, \
18 | load_config, get_time_str, load_checkpoints, save_checkpoints
19 | from midiffusion.datasets.threed_front_encoding import get_encoded_dataset
20 | from midiffusion.networks import build_network, optimizer_factory, schedule_factory, adjust_learning_rate
21 | from midiffusion.stats_logger import StatsLogger, WandB
22 |
23 |
24 | def main(argv):
25 | parser = argparse.ArgumentParser(
26 | description="Train a generative model on bounding boxes"
27 | )
28 |
29 | parser.add_argument(
30 | "config_file",
31 | help="Path to the file that contains the experiment configuration"
32 | )
33 | parser.add_argument(
34 | "--output_directory",
35 | default=PROJ_DIR+"/output/log",
36 | help="Path to the output directory"
37 | )
38 | parser.add_argument(
39 | "--weight_file",
40 | default=None,
41 | help=("The path to a previously trained model to continue"
42 | " the training from")
43 | )
44 | parser.add_argument(
45 | "--continue_from_epoch",
46 | default=0,
47 | type=int,
48 | help="Continue training from epoch (default=0)"
49 | )
50 | parser.add_argument(
51 | "--n_processes",
52 | type=int,
53 | default=0,
54 | help="The number of processed spawned by the batch provider"
55 | )
56 | parser.add_argument(
57 | "--seed",
58 | type=int,
59 | default=0,
60 | help="Seed for the PRNG"
61 | )
62 | parser.add_argument(
63 | "--experiment_tag",
64 | default=None,
65 | help="Tag that refers to the current experiment"
66 | )
67 | parser.add_argument(
68 | "--with_wandb_logger",
69 | action="store_true",
70 | help="Use wandB for logging the training progress"
71 | )
72 | parser.add_argument(
73 | "--gpu",
74 | type=int,
75 | default=0,
76 | help="GPU ID"
77 | )
78 |
79 | args = parser.parse_args(argv)
80 |
81 | # Set the random seed
82 | np.random.seed(args.seed)
83 | torch.manual_seed(np.random.randint(np.iinfo(np.int32).max))
84 | if torch.cuda.is_available():
85 | torch.cuda.manual_seed_all(np.random.randint(np.iinfo(np.int32).max))
86 |
87 | if args.gpu < torch.cuda.device_count():
88 | device = torch.device("cuda:{}".format(args.gpu))
89 | else:
90 | device = torch.device("cpu")
91 | print("Running code on", device)
92 |
93 | # Create an experiment directory using the experiment_tag
94 | if args.experiment_tag is None:
95 | experiment_tag = id_generator(9)
96 | else:
97 | experiment_tag = args.experiment_tag
98 | experiment_directory = os.path.join(args.output_directory, experiment_tag)
99 | os.makedirs(experiment_directory, exist_ok=True)
100 | # output files
101 | path_to_config = os.path.join(experiment_directory, "config.yaml")
102 | path_to_bounds = os.path.join(experiment_directory, "bounds.npz")
103 | path_to_params = os.path.join(experiment_directory, "params.json")
104 | path_to_stats = os.path.join(experiment_directory, "stats.txt")
105 | path_to_best_model = os.path.join(experiment_directory, "best_model.pt")
106 |
107 | # Parse the config file
108 | config = load_config(args.config_file)
109 | shutil.copyfile(args.config_file, path_to_config)
110 |
111 | train_dataset = get_encoded_dataset(
112 | update_data_file_paths(config["data"]),
113 | path_to_bounds=None,
114 | augmentations=config["data"].get("augmentations", None),
115 | split=config["training"].get("splits", ["train", "val"]),
116 | max_length=config["network"]["sample_num_points"],
117 | include_room_mask=(config["network"]["room_mask_condition"] and \
118 | config["feature_extractor"]["name"]=="resnet18")
119 | )
120 | # Compute the bounds for this experiment, save them to a file in the
121 | # experiment directory and pass them to the validation dataset
122 | np.savez(
123 | path_to_bounds,
124 | sizes=train_dataset.bounds["sizes"],
125 | translations=train_dataset.bounds["translations"],
126 | angles=train_dataset.bounds["angles"],
127 | #add objfeats
128 | objfeats=train_dataset.bounds["objfeats"],
129 | )
130 |
131 | validation_dataset = get_encoded_dataset(
132 | update_data_file_paths(config["data"]),
133 | path_to_bounds=path_to_bounds,
134 | augmentations=None,
135 | split=config["validation"].get("splits", ["test"]),
136 | max_length=config["network"]["sample_num_points"],
137 | include_room_mask=(config["network"]["room_mask_condition"] and \
138 | config["feature_extractor"]["name"]=="resnet18")
139 | )
140 |
141 | train_loader = DataLoader(
142 | train_dataset,
143 | batch_size=config["training"].get("batch_size", 128),
144 | num_workers=args.n_processes,
145 | collate_fn=train_dataset.collate_fn,
146 | shuffle=True
147 | )
148 | val_loader = DataLoader(
149 | validation_dataset,
150 | batch_size=config["validation"].get("batch_size", 1),
151 | num_workers=args.n_processes,
152 | collate_fn=validation_dataset.collate_fn,
153 | shuffle=False
154 | )
155 |
156 | # Make sure that the train_dataset and the validation_dataset have the same
157 | # number of object categories
158 | assert train_dataset.object_types == validation_dataset.object_types
159 | print("Saved dataset bounds to {}".format(path_to_bounds))
160 | print(" Loaded {} training scenes with {} object types".format(
161 | len(train_dataset), train_dataset.n_object_types)
162 | )
163 | print(" Loaded {} validation scenes with {} object types".format(
164 | len(validation_dataset), validation_dataset.n_object_types)
165 | )
166 |
167 | # Build the network architecture to be used for training
168 | network, train_on_batch, validate_on_batch = build_network(
169 | train_dataset.n_object_types, config, args.weight_file, device=device
170 | )
171 | n_all_params = int(sum([np.prod(p.size()) for p in network.parameters()]))
172 | n_trainable_params = int(sum([np.prod(p.size()) for p in \
173 | filter(lambda p: p.requires_grad, network.parameters())]))
174 | print(f"Number of parameters in {network.__class__.__name__}: "
175 | f"{n_trainable_params} / {n_all_params}")
176 | config["network"]["n_params"] = n_trainable_params
177 |
178 | # Build an optimizer object to compute the gradients of the parameters
179 | optimizer = optimizer_factory(config["training"], \
180 | filter(lambda p: p.requires_grad, network.parameters()))
181 |
182 | # Load the checkpoints if they exist in the experiment directory
183 | wandb_id = load_checkpoints(
184 | network, optimizer, experiment_directory, args, device
185 | )
186 | epochs = config["training"].get("epochs", 150)
187 | if args.continue_from_epoch == epochs:
188 | return
189 |
190 | # Load the learning rate scheduler
191 | lr_scheduler = schedule_factory(config["training"])
192 |
193 | # Initialize the logger
194 | if args.with_wandb_logger:
195 | WandB.instance().init(
196 | config,
197 | model=network,
198 | project=config["logger"].get("project", "MiDiffusion"),
199 | name=experiment_tag,
200 | watch=False,
201 | log_frequency=10,
202 | id=wandb_id,
203 | )
204 | args.with_wandb_logger = WandB.instance().id
205 |
206 | # Log the stats to a file
207 | StatsLogger.instance().add_output_file(open(path_to_stats, "w"))
208 |
209 | # Save the parameters of this run to a file
210 | save_experiment_params(args, experiment_tag, path_to_params)
211 | print("Save experiment statistics in {}".format(experiment_directory))
212 |
213 | # Do the training
214 | max_grad_norm = config["training"].get("max_grad_norm", None)
215 | save_every = config["training"].get("save_frequency", 10)
216 | val_every = config["validation"].get("frequency", 100)
217 |
218 | min_val_loss = float("inf")
219 | min_val_loss_epoch = 0
220 | tic = time.perf_counter()
221 | for i in range(args.continue_from_epoch + 1, epochs + 1):
222 | # adjust learning rate
223 | adjust_learning_rate(lr_scheduler, optimizer, i)
224 |
225 | network.train()
226 | #for b, sample in zip(range(steps_per_epoch), yield_forever(train_loader)):
227 | for b, sample in enumerate(train_loader):
228 | # Move everything to device
229 | for k, v in sample.items():
230 | if not isinstance(v, list):
231 | sample[k] = v.to(device)
232 | batch_loss = train_on_batch(network, optimizer, sample, max_grad_norm)
233 | StatsLogger.instance().print_progress(i, b+1, batch_loss)
234 |
235 | if (i % save_every) == 0:
236 | save_checkpoints(i, network, optimizer, experiment_directory)
237 | StatsLogger.instance().clear()
238 |
239 | if i % val_every == 0 and i > 0:
240 | print("====> Validation Epoch ====>")
241 | network.eval()
242 | val_loss_total = 0.0
243 | for b, sample in enumerate(val_loader):
244 | # Move everything to device
245 | for k, v in sample.items():
246 | if not isinstance(v, list):
247 | sample[k] = v.to(device)
248 | batch_loss = validate_on_batch(network, sample)
249 | StatsLogger.instance().print_progress(-1, b+1, batch_loss)
250 | val_loss_total += batch_loss
251 | StatsLogger.instance().clear()
252 |
253 | toc = time.perf_counter()
254 | elapsed_time = toc - tic
255 | estimated_total_time = (elapsed_time / (i - args.continue_from_epoch))\
256 | * (epochs - args.continue_from_epoch)
257 | print("====> [Elapsed time: {}] / [Estimated total: {}] ====>".format(
258 | get_time_str(elapsed_time), get_time_str(estimated_total_time)
259 | ))
260 |
261 | if val_loss_total < min_val_loss:
262 | # Overwrite best_model.pt
263 | min_val_loss = val_loss_total
264 | min_val_loss_epoch = i
265 | torch.save(network.state_dict(), path_to_best_model)
266 |
267 | print("Best model saved at epcho {} with validation loss = {}".format(
268 | min_val_loss_epoch, min_val_loss), file=open(path_to_stats, "a")
269 | )
270 |
271 |
272 | if __name__ == "__main__":
273 | main(sys.argv[1:])
274 |
--------------------------------------------------------------------------------
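
A hypothetical training run using the script above; the config file is one of those shipped under config/, and the experiment tag and worker count are placeholders:

    python scripts/train_diffusion.py config/bedrooms_mixed.yaml \
        --experiment_tag bedrooms_mixed_demo --n_processes 4 --with_wandb_logger
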
/scripts/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from:
3 | # https://github.com/nv-tlabs/ATISS.
4 | #
5 |
6 | import yaml
7 | try:
8 | from yaml import CLoader as Loader
9 | except ImportError:
10 | from yaml import Loader
11 | import json
12 | import time
13 | import string
14 | import os
15 | import random
16 | import subprocess
17 | import torch
18 |
19 | PROJ_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
20 | PATH_TO_DATASET_FILES = os.path.join(PROJ_DIR, "../ThreedFront/dataset_files/")
21 | PATH_TO_PROCESSED_DATA = os.path.join(PROJ_DIR, "../ThreedFront/output/3d_front_processed/")
22 |
23 |
24 | def load_config(config_file):
25 | with open(config_file, "r") as f:
26 | config = yaml.load(f, Loader=Loader)
27 | return config
28 |
29 |
30 | def update_data_file_paths(config_data):
31 | config_data["dataset_directory"] = \
32 | os.path.join(PATH_TO_PROCESSED_DATA, config_data["dataset_directory"])
33 | config_data["annotation_file"] = \
34 | os.path.join(PATH_TO_DATASET_FILES, config_data["annotation_file"])
35 | return config_data
36 |
37 |
38 | def id_generator(size=6, chars=string.ascii_uppercase + string.digits):
39 | return ''.join(random.choice(chars) for _ in range(size))
40 |
41 |
42 | def save_experiment_params(args, experiment_tag, path_to_params):
43 | t = vars(args)
44 | params = {k: str(v) for k, v in t.items()}
45 |
46 | git_dir = os.path.dirname(os.path.realpath(__file__))
47 | git_head_hash = "foo"
48 | try:
49 | git_head_hash = subprocess.check_output(
50 | ['git', 'rev-parse', 'HEAD']
51 | ).strip()
52 | except subprocess.CalledProcessError:
53 | # Keep the current working directory to move back in a bit
54 | cwd = os.getcwd()
55 | os.chdir(git_dir)
56 | git_head_hash = subprocess.check_output(
57 | ['git', 'rev-parse', 'HEAD']
58 | ).strip()
59 | os.chdir(cwd)
60 | params["git-commit"] = str(git_head_hash)
61 | params["experiment_tag"] = experiment_tag
62 | for k, v in list(params.items()):
63 | if v == "":
64 | params[k] = None
65 | if hasattr(args, "config_file"):
66 | config = load_config(args.config_file)
67 | params.update(config)
68 | with open(path_to_params, "w") as f:
69 | json.dump(params, f, indent=4)
70 |
71 |
72 | def get_time_str(seconds):
73 | hms_str = time.strftime("%H:%M:%S", time.gmtime(seconds))
74 | if seconds < 24 * 3600:
75 | return hms_str
76 |
77 | day_str = f"{int(seconds // (24 * 3600))} day "
78 | return day_str + hms_str
79 |
80 |
81 | def yield_forever(iterator):
82 | while True:
83 | for x in iterator:
84 | yield x
85 |
86 |
87 | def load_checkpoints(model, optimizer, experiment_directory, args, device):
88 | model_files = [
89 | f for f in os.listdir(experiment_directory)
90 | if f.startswith("model_")
91 | ]
92 | if len(model_files) == 0:
93 | return None
94 | ids = [int(f[6:]) for f in model_files]
95 | max_id = max(ids)
96 | model_path = os.path.join(
97 | experiment_directory, "model_{:05d}".format(max_id)
98 | )
99 | opt_path = os.path.join(
100 | experiment_directory, "opt_{:05d}".format(max_id)
101 | )
102 | if not (os.path.exists(model_path) and os.path.exists(opt_path)):
103 | return None
104 |
105 | print("Loading model checkpoint from {}".format(model_path))
106 | model.load_state_dict(torch.load(model_path, map_location=device))
107 | print("Loading optimizer checkpoint from {}".format(opt_path))
108 | optimizer.load_state_dict(
109 | torch.load(opt_path, map_location=device)
110 | )
111 | args.continue_from_epoch = max_id
112 |
113 | params_path = os.path.join(experiment_directory, "params.json")
114 | if os.path.exists(params_path) and args.with_wandb_logger:
115 | wandb_id = json.load(open(params_path, "r")).get("with_wandb_logger")
116 | return wandb_id
117 | else:
118 | return None
119 |
120 |
121 | def save_checkpoints(epoch, model, optimizer, experiment_directory):
122 | torch.save(
123 | model.state_dict(),
124 | os.path.join(experiment_directory, "model_{:05d}").format(epoch)
125 | )
126 | torch.save(
127 | optimizer.state_dict(),
128 | os.path.join(experiment_directory, "opt_{:05d}").format(epoch)
129 | )
130 |
--------------------------------------------------------------------------------
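
A few of the helpers above in isolation, assuming the interpreter is started from the scripts/ directory so that "from utils import ..." resolves; the printed values are examples only:

    from utils import id_generator, get_time_str

    print(id_generator(9))      # e.g. 'Q3ZP7K2AB' -- random experiment tag
    print(get_time_str(3661))   # '01:01:01'
    print(get_time_str(90000))  # '1 day 01:00:00' (25 hours)
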
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 |
4 | def get_install_requirements():
5 | return [
6 | "einops",
7 | "numpy",
8 | "pyyaml",
9 | "torch",
10 | "torchvision",
11 | "scipy",
12 | "tqdm",
13 | "wandb",
14 | "threed-front",
15 | ]
16 |
17 |
18 | def setup_package():
19 | with open("README.md") as f:
20 | long_description = f.read()
21 | setup(
22 | name='MiDiffusion',
23 | maintainer="Siyi Hu",
24 | maintainer_email="siyihu02@gmail.com",
25 | version='0.1',
26 | license='BSD-2-Clause',
27 | classifiers=[
28 | "Intended Audience :: Science/Research",
29 | "Intended Audience :: Developers",
30 | "Topic :: Scientific/Engineering",
31 | "Programming Language :: Python :: 3.8",
32 | ],
33 | description='Mixed diffusion models for synthetic 3D scene generation',
34 | long_description=long_description,
35 | long_description_content_type="text/markdown",
36 | packages=find_packages(include=['midiffusion']),
37 | install_requires=get_install_requirements(),
38 | )
39 |
40 |
41 | if __name__ == "__main__":
42 | setup_package()
43 |
--------------------------------------------------------------------------------