├── .gitignore ├── LICENSE ├── README.md ├── assets ├── predictions.jpg └── teaser.jpg ├── data ├── environment.yml └── examples │ └── coco │ ├── panoptic_images │ ├── 000000012280.png │ ├── 000000014226.png │ ├── 000000042888.png │ ├── 000000084752.png │ ├── 000000124975.png │ ├── 000000193884.png │ ├── 000000215259.png │ ├── 000000221693.png │ ├── 000000242934.png │ ├── 000000298697.png │ ├── 000000301376.png │ ├── 000000383443.png │ ├── 000000395343.png │ ├── 000000475064.png │ ├── 000000488664.png │ └── 000000521405.png │ └── rgb_images │ ├── 000000012280.jpg │ ├── 000000014226.jpg │ ├── 000000042888.jpg │ ├── 000000084752.jpg │ ├── 000000124975.jpg │ ├── 000000193884.jpg │ ├── 000000215259.jpg │ ├── 000000221693.jpg │ ├── 000000242934.jpg │ ├── 000000298697.jpg │ ├── 000000301376.jpg │ ├── 000000383443.jpg │ ├── 000000395343.jpg │ ├── 000000475064.jpg │ ├── 000000488664.jpg │ └── 000000521405.jpg ├── ldmseg ├── __init__.py ├── data │ ├── __init__.py │ ├── blip_captions │ │ ├── captions_train2017.json │ │ └── captions_val2017.json │ ├── coco.py │ ├── dataset_base.py │ └── util │ │ ├── mask_generator.py │ │ ├── mypath.py │ │ └── pil_transforms.py ├── evaluations │ ├── __init__.py │ ├── panoptic_evaluation.py │ ├── panoptic_evaluation_agnostic.py │ └── semseg_evaluation.py ├── models │ ├── __init__.py │ ├── descriptors.py │ ├── unet.py │ ├── upscaler.py │ └── vae.py ├── schedulers │ ├── __init__.py │ └── ddim_scheduler.py ├── trainers │ ├── __init__.py │ ├── losses.py │ ├── optim.py │ ├── trainers_ae.py │ └── trainers_ldm_cond.py ├── utils │ ├── __init__.py │ ├── config.py │ ├── detectron2_utils.py │ └── utils.py └── version.py ├── setup.py └── tools ├── configs ├── base │ └── base.yaml ├── config.yaml ├── datasets │ └── coco.yaml ├── distributed │ └── local.yaml └── env │ └── root_paths.yaml ├── main_ae.py ├── main_ldm.py ├── main_ldm_slurm.py └── scripts ├── eval.sh ├── install_env.sh ├── install_env_manual.sh ├── train_ae.sh └── train_diffusion.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # custom 7 | # vsode env 8 | .vscode 9 | outputs 10 | output 11 | pretrained 12 | visualizations 13 | deprecated 14 | *npz 15 | *pb 16 | *.log 17 | *.pkl 18 | *.pt 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | 43 | # Jupyter Notebook 44 | .ipynb_checkpoints 45 | 46 | # pyenv 47 | .python-version 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. 
Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. 
Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 
137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. 
To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. 
To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Simple Latent Diffusion Approach for Panoptic Segmentation and Mask Inpainting
2 |
3 | This repo contains the PyTorch implementation of LDMSeg: a simple latent diffusion approach for panoptic segmentation and mask inpainting. The provided code includes both training and evaluation.
4 |
5 | > [**A Simple Latent Diffusion Approach for Panoptic Segmentation and Mask Inpainting**](https://arxiv.org/abs/2401.10227)
6 | >
7 | > [Wouter Van Gansbeke](https://wvangansbeke.github.io/) and [Bert De Brabandere](https://scholar.google.be/citations?user=KcMb_7EAAAAJ)
8 |
9 |
10 |
11 | ## Contents
12 | 1. [Introduction](#-introduction)
13 | 0. [Installation](#-installation)
14 | - [Automatic Installation](#automatic-installation)
15 | - [Manual Installation](#manual-installation)
16 | 0. [Training](#-training)
17 | - [Step 1: Train Auto-Encoder](#step-1-train-an-auto-encoder-on-panoptic-segmentation-maps)
18 | - [Step 2: Train LDM](#step-2-train-an-ldm-for-panoptic-segmentation-conditioned-on-rgb-images)
19 | 0. [Pretrained Models](#-pretrained-models)
20 | 0. [Citation](#-citation)
21 | 0. [License](#license)
22 | 0. [Acknowledgements](#acknowledgements)
23 |
24 |
25 | ## 📋 Introduction
26 | This paper presents a conditional latent diffusion approach to tackle the task of panoptic segmentation.
27 | The aim is to remove the need for specialized architectures (e.g., region proposal networks or object queries), complex loss functions (e.g., Hungarian matching or losses based on bounding boxes), and additional post-processing methods (e.g., clustering, NMS, or object pasting).
28 | As a result, we rely on Stable Diffusion, which is a task-agnostic framework. The presented approach consists of two steps: (1) project the panoptic segmentation masks to a latent space with a shallow auto-encoder; (2) train a diffusion model in latent space, conditioned on RGB images.
29 |
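
The snippet below is a toy sketch of this two-step design, using dummy modules and shapes for illustration only (the actual implementations live in `ldmseg/models/vae.py` and `ldmseg/models/unet.py`; the module names and the 512→64 resolutions here are assumptions, not the repo's API):

```python
import torch
from torch import nn

# Step 1: a shallow auto-encoder maps a (color-encoded) panoptic map to a small latent grid.
class ToySegmentationAE(nn.Module):
    def __init__(self, latent_channels: int = 4):
        super().__init__()
        self.encoder = nn.Conv2d(3, latent_channels, kernel_size=8, stride=8)           # 512 -> 64
        self.decoder = nn.ConvTranspose2d(latent_channels, 3, kernel_size=8, stride=8)  # 64 -> 512

    def encode(self, seg_map): return self.encoder(seg_map)
    def decode(self, latents): return self.decoder(latents)

ae = ToySegmentationAE()
seg_map = torch.randn(1, 3, 512, 512)      # panoptic map encoded as an image
image_latents = torch.randn(1, 4, 64, 64)  # conditioning features derived from the RGB image
seg_latents = ae.encode(seg_map)           # (1, 4, 64, 64)

# Step 2: a denoiser (a UNet in practice) predicts the noise on the segmentation latents,
# conditioned on the image; a single conv layer stands in for the UNet here.
noisy_latents = seg_latents + torch.randn_like(seg_latents)
denoiser = nn.Conv2d(8, 4, kernel_size=3, padding=1)
noise_pred = denoiser(torch.cat([noisy_latents, image_latents], dim=1))
reconstruction = ae.decode(noisy_latents - noise_pred)  # back to a 512x512 panoptic map
print(seg_latents.shape, reconstruction.shape)
```
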
30 | __Key Contributions__: Our contributions are threefold:
31 |
32 | 1. __Generative Framework__: We propose a fully generative
33 | approach based on Latent Diffusion Models (LDMs) for
34 | panoptic segmentation. Our approach builds upon Stable
35 | Diffusion to strive for simplicity and to reduce compute.
36 | We first study the class-agnostic setup to liberate
37 | panoptic segmentation from predefined classes.
38 | 2. __General-Purpose Design__: Our approach circumvents the
39 | specialized architectures, complex loss functions, and object
40 | detection modules present in the majority of prevailing
41 | methods. Here, the denoising objective removes the need for
42 | object queries, region proposals, and Hungarian matching.
43 | This simple and general approach paves the way for future
44 | extensions to a wide range of dense prediction tasks, e.g.,
45 | depth prediction, saliency estimation, etc.
46 | 3. __Mask Inpainting__: We successfully apply our approach to
47 | scene-centric datasets and demonstrate its mask inpainting
48 | capabilities for different sparsity levels. The approach
49 | shows promising results for global mask inpainting.
50 |
51 | ## 🛠 Installation
52 |
53 | The code runs with recent PyTorch versions, e.g., 2.0.
54 | Further, you can create a Python environment with [Anaconda](https://docs.anaconda.com/anaconda/install/):
55 | ```shell
56 | conda create -n LDMSeg python=3.11
57 | conda activate LDMSeg
58 | ```
59 | ### Automatic Installation
60 | We recommend following the automatic installation (see `tools/scripts/install_env.sh`). Run the following commands to install the project in editable mode. Note that all dependencies will be installed automatically.
61 | As this might not always work (e.g., due to CUDA or gcc issues), please have a look at the manual installation steps.
62 |
63 | ```shell
64 | python -m pip install -e .
65 | pip install git+https://github.com/facebookresearch/detectron2.git
66 | pip install git+https://github.com/cocodataset/panopticapi.git
67 | ```
68 |
69 | ### Manual Installation
70 | The most important packages can be quickly installed with pip as:
71 | ```shell
72 | pip install torch torchvision einops # Main framework
73 | pip install diffusers transformers xformers accelerate timm # For using pretrained models
74 | pip install scipy opencv-python # For augmentations or loss
75 | pip install pyyaml easydict hydra-core # For using config files
76 | pip install termcolor wandb # For printing and logging
77 | ```
78 | See `data/environment.yml` for a copy of my environment. We also rely on some dependencies from [detectron2](https://detectron2.readthedocs.io/en/latest/tutorials/install.html) and [panopticapi](https://github.com/cocodataset/panopticapi). Please follow their docs.
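
As a quick sanity check after either installation route, the key packages can be verified with a short Python snippet (nothing repo-specific, just the imports listed above):

```python
import torch
import diffusers
import transformers
import detectron2    # installed from facebookresearch/detectron2
import panopticapi   # installed from cocodataset/panopticapi

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers:", diffusers.__version__, "| transformers:", transformers.__version__)
```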
79 |
80 | ## 🗃️ Dataset
81 | We currently support the [COCO](https://cocodataset.org/#download) dataset. Please follow its documentation to download the images and the corresponding panoptic segmentation masks. Also, take a look at the `ldmseg/data/` directory for a few examples on the COCO dataset. As a side note, the adopted structure should be fairly standard:
82 | ```
83 | .
84 | └── coco
85 | ├── annotations
86 | ├── panoptic_semseg_train2017
87 | ├── panoptic_semseg_val2017
88 | ├── panoptic_train2017 -> annotations/panoptic_train2017
89 | ├── panoptic_val2017 -> annotations/panoptic_val2017
90 | ├── test2017
91 | ├── train2017
92 | └── val2017
93 | ```
94 |
95 | Last but not least, change the paths in `tools/configs/env/root_paths.yaml` to your dataset root and your desired output directory, respectively.
96 |
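If useful, the layout above can be sanity-checked with a few lines of Python (the root below is a placeholder; point it to the same dataset root as in your root_paths config):

```python
from pathlib import Path

coco_root = Path("/path/to/datasets/coco")  # placeholder: your dataset root
expected = [
    "annotations", "train2017", "val2017", "test2017",
    "panoptic_train2017", "panoptic_val2017",
    "panoptic_semseg_train2017", "panoptic_semseg_val2017",
]
for name in expected:
    print(f"{name:28s}", "ok" if (coco_root / name).exists() else "missing")
```
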
97 | ## ⏳ Training
98 | The presented approach is two-pronged: first, we train an auto-encoder to represent segmentation maps in a lower-dimensional space (e.g., 64x64). Next, we start from a pretrained Latent Diffusion Model (LDM), particularly Stable Diffusion, to train a model that generates panoptic masks from RGB images.
99 | The models can be trained by running the following commands. By default, we train on the COCO dataset with the base config file defined in `tools/configs/base/base.yaml`. Note that this file is loaded automatically, as we rely on the `hydra` package.
100 |
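For reference, the entry points follow the usual `hydra` pattern, roughly as sketched below (this is not the exact code of `tools/main_ae.py` or `tools/main_ldm.py`; the config path and name are assumptions based on the `tools/configs` folder above):

```python
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="configs", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    # Command-line overrides such as `datasets=coco` or `base.optimizer_kwargs.lr=1e-4`
    # are merged into `cfg` before this function runs.
    print(OmegaConf.to_yaml(cfg))


if __name__ == "__main__":
    main()
```
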
101 | ### Step 1: Train an Auto-Encoder on Panoptic Segmentation Maps
102 | ```shell
103 | python -W ignore tools/main_ae.py \
104 | datasets=coco \
105 | base.train_kwargs.fp16=True \
106 | base.optimizer_name=adamw \
107 | base.optimizer_kwargs.lr=1e-4 \
108 | base.optimizer_kwargs.weight_decay=0.05
109 | ```
110 | More details on passing arguments can be found in `tools/scripts/train_ae.sh`. For example, I trained this model for 50k iterations on a single 23 GB GPU with a total batch size of 16.
111 |
112 | ### Step 2: Train an LDM for Panoptic Segmentation Conditioned on RGB Images
113 | ```shell
114 | python -W ignore tools/main_ldm.py \
115 | datasets=coco \
116 | base.train_kwargs.gradient_checkpointing=True \
117 | base.train_kwargs.fp16=True \
118 | base.train_kwargs.weight_dtype=float16 \
119 | base.optimizer_zero_redundancy=True \
120 | base.optimizer_name=adamw \
121 | base.optimizer_kwargs.lr=1e-4 \
122 | base.optimizer_kwargs.weight_decay=0.05 \
123 | base.scheduler_kwargs.weight='max_clamp_snr' \
124 | base.vae_model_kwargs.pretrained_path="$AE_MODEL"
125 | ```
126 | `$AE_MODEL` denotes the path to the model obtained from the previous step.
127 | More details on passing arguments can be found in `tools/scripts/train_diffusion.sh`. For example, I ran this model for 200k iterations on eight 16 GB GPUs with a total batch size of 256.
128 |
129 | ## 📊 Pretrained Models
130 |
131 | We're planning to release several trained models. The (class-agnostic) PQ metric is provided on the COCO validation set.
132 |
133 | | Model | \#Params | Dataset | Iters | PQ | SQ | RQ | Download link |
134 | |----------------------------|-----|------------|------------|--------|---|---|---------------------------------------------------------------------------------------------------------|
135 | | [AE](#training) | ~2M | COCO | 66k | - | - | - | [Download](https://drive.google.com/file/d/1wmOGB-Ue47DPGFiPxiBFxHv1h5g5Zooe/view?usp=sharing) (23 MB) |
136 | | [LDM](#training) | ~800M | COCO | 200k | 51.7 | 82.0 | 63.0 | [Download](https://drive.google.com/file/d/1EKuOm_DnSGa0Ff-EkIl6Q1wknZxm4ygB/view?usp=sharing) (3.3 GB) |
137 |
138 | Note: A less powerful AE (i.e., fewer downsampling or upsampling layers) can often benefit inpainting, as we don't perform additional finetuning.
139 |
140 | The evaluation can be run as follows:
141 | ```shell
142 | python -W ignore tools/main_ldm.py \
143 | datasets=coco \
144 | base.sampling_kwargs.num_inference_steps=50 \
145 | base.eval_only=True \
146 | base.load_path=$PRETRAINED_MODEL_PATH
147 | ```
148 | You can add parameters if necessary. Higher thresholds, such as `base.eval_kwargs.count_th=700` or `base.eval_kwargs.mask_th=0.9`, can further boost the numbers.
149 | However, for the evaluation we use standard values: a mask threshold of 0.5 and a minimum segment area of 512 pixels.
150 |
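For intuition, these two values roughly correspond to a post-processing step along the following lines (a simplified sketch, not the actual evaluation code in `ldmseg/evaluations`):

```python
import torch

def filter_segments(soft_masks: torch.Tensor, mask_th: float = 0.5, count_th: int = 512):
    """soft_masks: (num_segments, H, W) predicted masks with values in [0, 1]."""
    kept = []
    for soft_mask in soft_masks:
        binary = soft_mask > mask_th       # binarize each predicted mask
        if int(binary.sum()) >= count_th:  # drop segments with a small area
            kept.append(binary)
    if not kept:
        return torch.empty(0, *soft_masks.shape[1:], dtype=torch.bool)
    return torch.stack(kept)

print(filter_segments(torch.rand(5, 128, 128)).shape)  # e.g., torch.Size([k, 128, 128])
```
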
151 | To evaluate a pretrained model from above, run `tools/scripts/eval.sh`.
152 |
153 |
154 | Here, we visualize the results:
155 | ![Qualitative results](assets/predictions.jpg)
156 |
157 |
158 |
159 | ## 🪧 Citation
160 | If you find this repository useful for your research, please consider citing the following paper:
161 |
162 | ```bibtex
163 | @article{vangansbeke2024ldmseg,
164 | title={A Simple Latent Diffusion Approach for Panoptic Segmentation and Mask Inpainting},
165 | author={Van Gansbeke, Wouter and De Brabandere, Bert},
166 | journal={arXiv preprint arXiv:2401.10227},
167 | year={2024}
168 | }
169 | ```
170 | For any enquiries, please contact the [main author](https://github.com/wvangansbeke).
171 |
172 | ## License
173 |
174 | This software is released under a Creative Commons license that allows personal and research use only. For a commercial license, please contact the authors. You can view a license summary [here](http://creativecommons.org/licenses/by-nc/4.0/).
175 |
176 |
177 | ## Acknowledgements
178 |
179 | I'm thankful for all the public repositories (see also references in the code), and in particular for the [detectron2](https://github.com/facebookresearch/detectron2) and [diffusers](https://github.com/huggingface/diffusers) libraries.
180 |
--------------------------------------------------------------------------------
/assets/predictions.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/assets/predictions.jpg
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/assets/teaser.jpg
--------------------------------------------------------------------------------
/data/environment.yml:
--------------------------------------------------------------------------------
1 | name: LDMSeg
2 | channels:
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - _libgcc_mutex=0.1=conda_forge
7 | - _openmp_mutex=4.5=2_gnu
8 | - bzip2=1.0.8=h7f98852_4
9 | - ca-certificates=2023.5.7=hbcca054_0
10 | - ld_impl_linux-64=2.40=h41732ed_0
11 | - libexpat=2.5.0=hcb278e6_1
12 | - libffi=3.4.2=h7f98852_5
13 | - libgcc-ng=12.2.0=h65d4601_19
14 | - libgomp=12.2.0=h65d4601_19
15 | - libnsl=2.0.0=h7f98852_0
16 | - libsqlite=3.42.0=h2797004_0
17 | - libuuid=2.38.1=h0b41bf4_0
18 | - libzlib=1.2.13=h166bdaf_4
19 | - ncurses=6.3=h27087fc_1
20 | - openssl=3.1.0=hd590300_3
21 | - pip=23.1.2=pyhd8ed1ab_0
22 | - python=3.11.3=h2755cc3_0_cpython
23 | - readline=8.2=h8228510_1
24 | - setuptools=67.7.2=pyhd8ed1ab_0
25 | - tk=8.6.12=h27826a3_0
26 | - wheel=0.40.0=pyhd8ed1ab_0
27 | - xz=5.2.6=h166bdaf_0
28 | - pip:
29 | - absl-py==1.4.0
30 | - accelerate==0.19.0
31 | - aiohttp==3.8.4
32 | - aiosignal==1.3.1
33 | - antlr4-python3-runtime==4.9.3
34 | - appdirs==1.4.4
35 | - async-timeout==4.0.2
36 | - attrs==23.1.0
37 | - bitsandbytes==0.41.1
38 | - black==23.7.0
39 | - cachetools==5.3.1
40 | - certifi==2023.5.7
41 | - charset-normalizer==3.1.0
42 | - cityscapesscripts==2.2.2
43 | - click==8.1.3
44 | - cloudpickle==2.2.1
45 | - cmake==3.26.3
46 | - coloredlogs==15.0.1
47 | - contourpy==1.0.7
48 | - cycler==0.11.0
49 | - datasets==2.12.0
50 | - diffusers==0.16.1
51 | - dill==0.3.6
52 | - docker-pycreds==0.4.0
53 | - easydict==1.10
54 | - einops==0.6.1
55 | - filelock==3.12.0
56 | - flake8==6.1.0
57 | - fonttools==4.39.4
58 | - frozenlist==1.3.3
59 | - fsspec==2023.5.0
60 | - fvcore==0.1.5.post20221221
61 | - gitdb==4.0.10
62 | - gitpython==3.1.31
63 | - google-auth==2.22.0
64 | - google-auth-oauthlib==1.0.0
65 | - grpcio==1.56.2
66 | - huggingface-hub==0.14.1
67 | - humanfriendly==10.0
68 | - hydra-core==1.3.2
69 | - idna==3.4
70 | - imageio==2.29.0
71 | - importlib-metadata==6.6.0
72 | - iopath==0.1.9
73 | - jinja2==3.1.2
74 | - kiwisolver==1.4.4
75 | - lit==16.0.5
76 | - markdown==3.4.4
77 | - markupsafe==2.1.2
78 | - matplotlib==3.7.1
79 | - mccabe==0.7.0
80 | - mpmath==1.3.0
81 | - multidict==6.0.4
82 | - multiprocess==0.70.14
83 | - mypy-extensions==1.0.0
84 | - networkx==3.1
85 | - numpy==1.24.3
86 | - nvidia-cublas-cu11==11.10.3.66
87 | - nvidia-cuda-cupti-cu11==11.7.101
88 | - nvidia-cuda-nvrtc-cu11==11.7.99
89 | - nvidia-cuda-runtime-cu11==11.7.99
90 | - nvidia-cudnn-cu11==8.5.0.96
91 | - nvidia-cufft-cu11==10.9.0.58
92 | - nvidia-curand-cu11==10.2.10.91
93 | - nvidia-cusolver-cu11==11.4.0.1
94 | - nvidia-cusparse-cu11==11.7.4.91
95 | - nvidia-nccl-cu11==2.14.3
96 | - nvidia-nvtx-cu11==11.7.91
97 | - oauthlib==3.2.2
98 | - omegaconf==2.3.0
99 | - opencv-python==4.7.0.72
100 | - packaging==23.1
101 | - pandas==2.0.1
102 | - panopticapi==0.1
103 | - pathspec==0.11.2
104 | - pathtools==0.1.2
105 | - pillow==9.5.0
106 | - platformdirs==3.10.0
107 | - portalocker==2.7.0
108 | - protobuf==4.23.1
109 | - psutil==5.9.5
110 | - pyarrow==12.0.0
111 | - pyasn1==0.5.0
112 | - pyasn1-modules==0.3.0
113 | - pycocotools==2.0.6
114 | - pycodestyle==2.11.0
115 | - pyflakes==3.1.0
116 | - pyparsing==3.0.9
117 | - pyquaternion==0.9.9
118 | - pyre-extensions==0.0.29
119 | - python-dateutil==2.8.2
120 | - pytz==2023.3
121 | - pyyaml==6.0
122 | - regex==2023.5.5
123 | - requests==2.31.0
124 | - requests-oauthlib==1.3.1
125 | - responses==0.18.0
126 | - rsa==4.9
127 | - safetensors==0.3.1
128 | - scipy==1.11.1
129 | - sentencepiece==0.1.99
130 | - sentry-sdk==1.24.0
131 | - setproctitle==1.3.2
132 | - six==1.16.0
133 | - smmap==5.0.0
134 | - sympy==1.12
135 | - tabulate==0.9.0
136 | - tensorboard==2.13.0
137 | - tensorboard-data-server==0.7.1
138 | - termcolor==2.3.0
139 | - timm==0.9.2
140 | - tokenizers==0.13.3
141 | - torch==2.0.1
142 | - torchvision==0.15.2
143 | - tqdm==4.65.0
144 | - transformers==4.29.2
145 | - triton==2.0.0
146 | - typing==3.7.4.3
147 | - typing-extensions==4.6.2
148 | - typing-inspect==0.9.0
149 | - tzdata==2023.3
150 | - urllib3==1.26.16
151 | - wandb==0.15.3
152 | - werkzeug==2.3.6
153 | - xformers==0.0.20
154 | - xxhash==3.2.0
155 | - yacs==0.1.8
156 | - yarl==1.9.2
157 | - zipp==3.15.0
158 | prefix: /home/ubuntu/datasets/miniconda3/envs/LDMSeg
159 |
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000012280.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000012280.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000014226.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000014226.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000042888.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000042888.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000084752.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000084752.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000124975.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000124975.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000193884.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000193884.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000215259.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000215259.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000221693.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000221693.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000242934.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000242934.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000298697.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000298697.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000301376.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000301376.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000383443.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000383443.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000395343.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000395343.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000475064.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000475064.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000488664.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000488664.png
--------------------------------------------------------------------------------
/data/examples/coco/panoptic_images/000000521405.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/panoptic_images/000000521405.png
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000012280.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000012280.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000014226.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000014226.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000042888.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000042888.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000084752.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000084752.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000124975.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000124975.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000193884.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000193884.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000215259.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000215259.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000221693.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000221693.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000242934.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000242934.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000298697.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000298697.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000301376.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000301376.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000383443.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000383443.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000395343.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000395343.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000475064.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000475064.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000488664.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000488664.jpg
--------------------------------------------------------------------------------
/data/examples/coco/rgb_images/000000521405.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/data/examples/coco/rgb_images/000000521405.jpg
--------------------------------------------------------------------------------
/ldmseg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/segments-ai/latent-diffusion-segmentation/66b4d172ccb94de81f5a63636d425dcd15805fb3/ldmseg/__init__.py
--------------------------------------------------------------------------------
/ldmseg/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .coco import COCO
2 | from .dataset_base import DatasetBase
3 |
4 | __all__ = ['COCO', 'DatasetBase']
5 |
--------------------------------------------------------------------------------
/ldmseg/data/dataset_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | Dataset class to be used for training and evaluation
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 |
9 | import torch
10 | from torch import nn
11 | from torchvision import transforms as T
12 | from typing import Callable, Dict, Tuple, Any, Optional
13 |
14 |
15 | class DatasetBase(object):
16 |
17 | def __init__(
18 | self,
19 | data_dir: str
20 | ) -> None:
21 | """ Base class for datasets
22 | """
23 |
24 | self.data_dir = data_dir
25 |
26 | def get_train_transforms(
27 | self,
28 | p: Dict[str, Any]
29 | ) -> Callable:
30 | """ Returns a composition of transformations to be applied to the training images
31 | """
32 |
33 | normalize = T.Normalize(**p['normalize_params']) if p['normalize'] else nn.Identity()
34 | if p['type'] == 'crop_resize_pil':
35 | from .util import pil_transforms as pil_tr
36 |
37 | size = p['size']
38 | normalize = pil_tr.Normalize(**p['normalize_params']) if p['normalize'] else nn.Identity()
39 | transforms = T.Compose([
40 | pil_tr.RandomHorizontalFlip() if p['flip'] else nn.Identity(),
41 | pil_tr.CropResize((size, size), crop_mode=None),
42 | pil_tr.ToTensor(),
43 | normalize
44 | ])
45 |
46 | else:
47 | raise NotImplementedError(f'Unknown transformation type {p["type"]}')
48 |
49 | return transforms
50 |
51 | def get_val_transforms(
52 | self,
53 | p: Dict
54 | ) -> Callable:
55 | """ Returns a composition of transformations to be applied to the validation images
56 | """
57 |
58 | normalize = T.Normalize(**p['normalize_params']) if p['normalize'] else nn.Identity()
59 | if p['type'] in ['crop_resize_pil', 'random_crop_resize_pil']:
60 | from .util import pil_transforms as pil_tr
61 | size = p['size']
62 | normalize = pil_tr.Normalize(**p['normalize_params']) if p['normalize'] else nn.Identity()
63 | transforms = T.Compose([
64 | pil_tr.CropResize((size, size), crop_mode=None),
65 | pil_tr.ToTensor(),
66 | normalize
67 | ])
68 |
69 | else:
70 | raise NotImplementedError(f'Unknown transformation type {p["type"]}')
71 |
72 | return transforms
73 |
74 | def get_dataset(
75 | self,
76 | db_name: str,
77 | *,
78 | split: Any,
79 | tokenizer: Optional[Callable] = None,
80 | transform: Optional[Callable] = None,
81 | remap_labels: bool = False,
82 | caption_dropout: float = 0.0,
83 | download: bool = False,
84 | overfit: bool = False,
85 | encoding_mode: str = 'color',
86 | caption_type: Optional[str] = 'none',
87 | inpaint_mask_size: Optional[Tuple[int]] = None,
88 | num_classes: Optional[int] = None,
89 | fill_value: Optional[int] = None,
90 | ignore_label: Optional[int] = None,
91 | inpainting_strength: Optional[float] = None,
92 | ) -> Any:
93 | """ Returns the dataset to be used for training or evaluation
94 | """
95 |
96 |         if db_name == 'coco':
97 | from .coco import COCO
98 | dataset_cls = COCO
99 | else:
100 | raise NotImplementedError(f'Unknown dataset {db_name}')
101 |
102 | if isinstance(split, list):
103 | datasets = [
104 | dataset_cls(
105 | prefix=self.data_dir,
106 | split=sp,
107 | transform=transform,
108 | download=download,
109 | remap_labels=remap_labels,
110 | tokenizer=tokenizer,
111 | caption_dropout=caption_dropout,
112 | overfit=overfit,
113 | caption_type=caption_type,
114 | encoding_mode=encoding_mode,
115 | inpaint_mask_size=inpaint_mask_size,
116 | num_classes=num_classes,
117 | fill_value=fill_value,
118 | ignore_label=ignore_label,
119 | inpainting_strength=inpainting_strength,
120 | ) for sp in split
121 | ]
122 | return torch.utils.data.ConcatDataset(datasets)
123 | else:
124 | dataset = dataset_cls(
125 | prefix=self.data_dir,
126 | split=split,
127 | transform=transform,
128 | download=download,
129 | remap_labels=remap_labels,
130 | tokenizer=tokenizer,
131 | caption_dropout=caption_dropout,
132 | overfit=overfit,
133 | caption_type=caption_type,
134 | encoding_mode=encoding_mode,
135 | inpaint_mask_size=inpaint_mask_size,
136 | num_classes=num_classes,
137 | fill_value=fill_value,
138 | ignore_label=ignore_label,
139 | inpainting_strength=inpainting_strength,
140 | )
141 | return dataset
142 |
--------------------------------------------------------------------------------
/ldmseg/data/util/mask_generator.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | import numpy as np
4 |
5 |
6 | class MaskingGenerator:
7 | def __init__(
8 | self,
9 | input_size=(32, 32),
10 | num_masking_patches=512,
11 | min_num_patches=4,
12 | max_num_patches=128,
13 | min_aspect=0.3,
14 | max_aspect=None,
15 | mode='random_global',
16 | ):
17 |
18 | if not isinstance(input_size, (tuple, list)):
19 | input_size = (input_size, ) * 2
20 | self.height, self.width = input_size
21 |
22 | self.num_patches = self.height * self.width
23 | self.num_masking_patches = num_masking_patches
24 |
25 | self.min_num_patches = min_num_patches
26 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
27 |
28 | max_aspect = max_aspect or 1 / min_aspect
29 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
30 | self.fill_percentage = float(self.num_masking_patches) / self.num_patches
31 | self.mode = mode
32 |
33 | def __repr__(self):
34 | repr_str = "Generator in mode %s with params (%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
35 | self.mode,
36 | self.height, self.width, self.min_num_patches, self.max_num_patches,
37 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1])
38 | return repr_str
39 |
40 | def get_shape(self):
41 | return self.height, self.width
42 |
43 | def _mask(self, mask, max_mask_patches):
44 | delta = 0
45 | for _ in range(10):
46 | target_area = random.uniform(self.min_num_patches, max_mask_patches)
47 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
48 | h = int(round(math.sqrt(target_area * aspect_ratio)))
49 | w = int(round(math.sqrt(target_area / aspect_ratio)))
50 | if w < self.width and h < self.height:
51 | top = random.randint(0, self.height - h)
52 | left = random.randint(0, self.width - w)
53 |
54 | num_masked = mask[top: top + h, left: left + w].sum()
55 | # Overlap
56 | if 0 < h * w - num_masked <= max_mask_patches:
57 | for i in range(top, top + h):
58 | for j in range(left, left + w):
59 | if mask[i, j] == 0:
60 | mask[i, j] = 1
61 | delta += 1
62 |
63 | if delta > 0:
64 | break
65 | return delta
66 |
67 | def _get_global_mask(self, mask, verbose=False):
68 | mask_count = 0
69 | # self.num_masking_patches = random.randint(128, 1024)
70 |
71 | num_iters = 0
72 | while mask_count < self.num_masking_patches:
73 | max_mask_patches = self.num_masking_patches - mask_count
74 | max_mask_patches = min(max_mask_patches, self.max_num_patches)
75 | num_iters += 1
76 |
77 | delta = self._mask(mask, max_mask_patches)
78 | if delta == 0:
79 | break
80 | else:
81 | mask_count += delta
82 |
83 | if verbose:
84 | print('num_iters =', num_iters)
85 | return mask
86 |
87 | def _get_local_mask(self, mask, verbose=False, strength=0.5):
88 | mask[np.random.rand(*self.get_shape()) < strength] = 1
89 | if verbose:
90 | print('mask.sum() =', mask.sum())
91 | return mask
92 |
93 | def __call__(self, t=0.5, verbose=False):
94 | # init mask with zeros
95 | mask = np.zeros(shape=self.get_shape(), dtype=np.int64)
96 |
97 | if self.mode == 'random_local':
98 | return self._get_local_mask(mask, verbose=verbose, strength=t)
99 |
100 | elif self.mode == 'random_global':
101 | return self._get_global_mask(mask, verbose=verbose)
102 |
103 | elif self.mode == 'random_global_plus_local':
104 | return (self._get_global_mask(mask, verbose=verbose) +
105 | self._get_local_mask(mask, verbose=verbose, strength=t)) > 0
106 |
107 | elif self.mode == 'object':
108 | raise NotImplementedError('object mode is not implemented')
109 |
110 | else:
111 | raise NotImplementedError
112 |
113 |
114 | if __name__ == "__main__":
115 |
116 | from PIL import Image
117 |
118 | # generator = MaskingGenerator((32, 32), 512, 16, 32, mode='random_global_plus_local')
119 | generator = MaskingGenerator((64, 64), mode='random_local')
120 | print(generator)
121 | mask = generator(verbose=True, t=0.15)
122 |
123 | # save mask as image
124 | mask = mask * 255
125 | mask = Image.fromarray(mask.astype(np.uint8))
126 | mask.save("mask.png")
127 |
--------------------------------------------------------------------------------
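
The commented-out line in the demo hints at the combined global-plus-local mode. A minimal sketch of that configuration, with an illustrative final step that uses the boolean mask to blank out latent positions (the gating itself is not taken from this file):

import numpy as np
import torch
from ldmseg.data.util.mask_generator import MaskingGenerator

# Block-wise ('global') masking plus per-pixel ('local') dropout on a 32x32 latent grid.
generator = MaskingGenerator((32, 32), 512, 16, 32, mode='random_global_plus_local')
mask = generator(t=0.25)                        # boolean array, True where masked
mask_t = torch.from_numpy(mask.astype(np.float32))

latents = torch.randn(1, 4, 32, 32)             # stand-in latent tensor
masked_latents = latents * (1.0 - mask_t)       # zero out the masked positions
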
/ldmseg/data/util/mypath.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | File with the root path to different datasets
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import os
9 |
10 |
11 | class MyPath(object):
12 | @staticmethod
13 | def db_root_dir(database='', prefix='/efs/datasets/'):
14 |
15 | db_names = {'coco', 'cityscapes'}
16 | assert (database in db_names), 'Database {} not available.'.format(database)
17 |
18 | return os.path.join(prefix, database)
19 |
--------------------------------------------------------------------------------
/ldmseg/data/util/pil_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | File with augmentations based on PIL
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import numbers
9 | from collections.abc import Sequence
10 | from PIL import Image, ImageFilter
11 | import numpy as np
12 | import torch
13 | import torchvision
14 | import torchvision.transforms.functional as F
15 | import random
16 | import math
17 | import sys
18 | from typing import Tuple
19 |
20 | # define interpolation modes
21 | INT_MODES = {
22 | 'image': 'bicubic',
23 | 'semseg': 'nearest',
24 | 'class_labels': 'nearest',
25 | 'mask': 'nearest',
26 | 'image_semseg': 'bicubic',
27 | 'image_class_labels': 'bicubic',
28 | }
29 |
30 |
31 | def resize_operation(img, h, w, mode='bicubic'):
32 | if mode == 'bicubic':
33 | img = img.resize((w, h), resample=getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None)
34 | elif mode == 'bilinear':
35 | img = img.resize((w, h), resample=getattr(Image, 'Resampling', Image).BILINEAR, reducing_gap=None)
36 | elif mode == 'nearest':
37 | img = img.resize((w, h), resample=getattr(Image, 'Resampling', Image).NEAREST, reducing_gap=None)
38 | else:
39 | raise NotImplementedError
40 | return img
41 |
42 |
43 | class RandomHorizontalFlip(object):
44 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5."""
45 |
46 | def __call__(self, sample):
47 |
48 | if random.random() < 0.5:
49 | for elem in sample.keys():
50 | if elem in ['meta', 'text']:
51 | continue
52 | else:
53 | sample[elem] = F.hflip(sample[elem])
54 |
55 | return sample
56 |
57 | def __str__(self):
58 | return 'RandomHorizontalFlip(p=0.5)'
59 |
60 |
61 | class RandomColorJitter(object):
62 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5."""
63 |
64 | def __init__(self) -> None:
65 | self.colorjitter = torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
66 |
67 | def __call__(self, sample):
68 |
69 | if random.random() < 0.5:
70 | for elem in sample.keys():
71 | if elem in ['image']:
72 | sample[elem] = self.colorjitter(sample[elem])
73 |
74 | return sample
75 |
76 | def __str__(self):
77 | return 'RandomColorJitter(p=0.5)'
78 |
79 |
80 | class RandomGaussianBlur(object):
81 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5."""
82 |
83 | def __init__(self, sigma=[.1, 2.], p=0.2):
84 | self.sigma = sigma
85 | self.p = p
86 |
87 | def __call__(self, sample):
88 | if random.random() < self.p:
89 | for elem in sample.keys():
90 | if elem in ['image', 'image_semseg']:
91 | sigma = random.uniform(self.sigma[0], self.sigma[1])
92 | sample[elem] = sample[elem].filter(ImageFilter.GaussianBlur(radius=sigma))
93 | return sample
94 |
95 | def __str__(self):
96 | return f'RandomGaussianBlur(p={self.p})'
97 |
98 |
99 | class CropResize(object):
100 | def __init__(self, size, crop_mode=None):
101 | self.size = size
102 | self.crop_mode = crop_mode
103 | assert self.crop_mode in ['centre', 'random', None]
104 |
105 | def crop_and_resize(self, img, h, w, mode='bicubic'):
106 | # crop
107 | if self.crop_mode == 'centre':
108 | img_w, img_h = img.size
109 | min_size = min(img_h, img_w)
110 | if min_size == img_h:
111 | margin = (img_w - min_size) // 2
112 | new_img = img.crop((margin, 0, margin+min_size, min_size))
113 | else:
114 | margin = (img_h - min_size) // 2
115 | new_img = img.crop((0, margin, min_size, margin+min_size))
116 |
117 | elif self.crop_mode == 'random':
118 | img_w, img_h = img.size
119 | min_size = min(img_h, img_w)
120 | if min_size == img_h:
121 | margin = random.randint(0, (img_w - min_size) // 2)
122 | new_img = img.crop((margin, 0, margin+min_size, min_size))
123 | else:
124 | margin = random.randint(0, (img_h - min_size) // 2)
125 | new_img = img.crop((0, margin, min_size, margin+min_size))
126 | else:
127 | new_img = img
128 |
129 | # resize
130 | if mode == 'bicubic':
131 | new_img = new_img.resize((w, h), resample=getattr(Image, 'Resampling', Image).BICUBIC, reducing_gap=None)
132 | elif mode == 'bilinear':
133 | new_img = new_img.resize((w, h), resample=getattr(Image, 'Resampling', Image).BILINEAR, reducing_gap=None)
134 | elif mode == 'nearest':
135 | new_img = new_img.resize((w, h), resample=getattr(Image, 'Resampling', Image).NEAREST, reducing_gap=None)
136 | else:
137 | raise NotImplementedError
138 | return new_img
139 |
140 | def __call__(self, sample):
141 | for elem in sample.keys():
142 | if elem in ['image', 'image_semseg', 'semseg', 'mask', 'class_labels', 'image_class_labels']:
143 | sample[elem] = self.crop_and_resize(sample[elem], self.size[0], self.size[1], mode=INT_MODES[elem])
144 | return sample
145 |
146 | def __str__(self) -> str:
147 | return f"CropResize(size={self.size}, crop_mode={self.crop_mode})"
148 |
149 |
150 | class ToTensor(object):
151 | """Convert ndarrays in sample to Tensors."""
152 | def __init__(self):
153 | self.to_tensor = torchvision.transforms.ToTensor()
154 |
155 | def __call__(self, sample):
156 |
157 | for elem in sample.keys():
158 | if 'meta' in elem or 'text' in elem:
159 | continue
160 |
161 | elif elem in ['image', 'image_semseg', 'image_class_labels']:
162 | sample[elem] = self.to_tensor(sample[elem]) # Regular ToTensor operation
163 |
164 | elif elem in ['semseg', 'mask', 'class_labels']:
165 | sample[elem] = torch.from_numpy(np.array(sample[elem])).long() # Torch Long
166 |
167 | else:
168 | raise NotImplementedError
169 |
170 | return sample
171 |
172 | def __str__(self):
173 | return 'ToTensor'
174 |
175 |
176 | class Normalize(object):
177 | """Normalize a tensor image with mean and standard deviation.
178 | Args:
179 | mean (sequence): Sequence of means for each channel.
180 | std (sequence): Sequence of standard deviations for each channel.
181 | """
182 |
183 | def __init__(self, mean, std):
184 | self.normalize = torchvision.transforms.Normalize(mean, std)
185 |
186 | def __call__(self, sample):
187 |
188 | for elem in sample.keys():
189 | if 'meta' in elem or 'text' in elem:
190 | continue
191 |
192 | elif elem in ['image', 'image_semseg']:
193 | sample[elem] = self.normalize(sample[elem])
194 |
195 | else:
196 | raise NotImplementedError
197 |
198 | return sample
199 |
200 | def __str__(self):
201 | return f"Normalize(mean={self.normalize.mean}, std={self.normalize.std})"
202 |
--------------------------------------------------------------------------------
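
These transforms all act on a dict-style sample and skip 'meta'/'text' entries; note that Normalize only accepts image-like keys ('image', 'image_semseg') and raises on anything else. A minimal composition sketch over a toy sample (sizes are illustrative):

import numpy as np
import torchvision
from PIL import Image
from ldmseg.data.util import pil_transforms as T

transform = torchvision.transforms.Compose([
    T.RandomHorizontalFlip(),
    T.CropResize(size=(512, 512), crop_mode='random'),
    T.ToTensor(),
])

sample = {
    'image': Image.new('RGB', (640, 480)),
    'semseg': Image.fromarray(np.zeros((480, 640), dtype=np.uint8)),
    'meta': {'im_size': (480, 640)},   # left untouched by every transform
}
out = transform(sample)
print(out['image'].shape, out['semseg'].shape)  # [3, 512, 512] and [512, 512]
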
/ldmseg/evaluations/__init__.py:
--------------------------------------------------------------------------------
1 | from .semseg_evaluation import SemsegMeter
2 | from .panoptic_evaluation import PanopticEvaluator
3 | from .panoptic_evaluation_agnostic import PanopticEvaluatorAgnostic
4 |
5 | __all__ = [
6 | 'SemsegMeter',
7 | 'PanopticEvaluator',
8 | 'PanopticEvaluatorAgnostic',
9 | ]
10 |
--------------------------------------------------------------------------------
/ldmseg/evaluations/panoptic_evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | Panoptic evaluation class
5 | Mostly copied from detectron2 with some modifications to make it compatible with our codebase
6 | See https://github.com/facebookresearch/detectron2 for more details and license.
7 | """
8 |
9 | import contextlib
10 | import io
11 | import itertools
12 | import json
13 | import logging
14 | import os
15 | import tempfile
16 | from collections import OrderedDict
17 | from typing import Optional
18 | from PIL import Image
19 | from tabulate import tabulate
20 |
21 | from detectron2.utils import comm
22 | from detectron2.utils.file_io import PathManager
23 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
24 | from detectron2.evaluation.evaluator import DatasetEvaluator
25 |
26 | from termcolor import colored
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | class PanopticEvaluator(DatasetEvaluator):
32 | """
33 | Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
34 | It saves panoptic segmentation predictions in `output_dir`
35 |
36 | It contains a synchronize call and has to be called from all workers.
37 | """
38 |
39 | def __init__(self, metadata, output_dir: Optional[str] = None):
40 | """
41 | Args:
42 | metadata: dataset metadata providing the category id mappings and the panoptic ground-truth paths
43 | output_dir: output directory to save results for evaluation.
44 | """
45 | self._metadata = metadata
46 | self._thing_contiguous_id_to_dataset_id = {
47 | v: k for k, v in self._metadata["thing_dataset_id_to_contiguous_id"].items()
48 | }
49 | self._stuff_contiguous_id_to_dataset_id = {
50 | v: k for k, v in self._metadata["stuff_dataset_id_to_contiguous_id"].items()
51 | }
52 |
53 | self._output_dir = output_dir
54 | if self._output_dir is not None:
55 | PathManager.mkdirs(self._output_dir)
56 |
57 | def reset(self):
58 | self._predictions = []
59 |
60 | def _convert_category_id(self, segment_info):
61 | isthing = segment_info.pop("isthing", None)
62 | if isthing is None:
63 | # the model produces panoptic category id directly. No more conversion needed
64 | return segment_info
65 | if isthing is True:
66 | segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
67 | segment_info["category_id"]
68 | ]
69 | else:
70 | segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
71 | segment_info["category_id"]
72 | ]
73 | return segment_info
74 |
75 | def process(self, file_names, outputs):
76 | from panopticapi.utils import id2rgb
77 |
78 | for file_name, output in zip(file_names, outputs):
79 | panoptic_img, segments_info = output["panoptic_seg"]
80 | panoptic_img = panoptic_img.cpu().numpy()
81 |
82 | file_name = os.path.basename(file_name)
83 | file_name_png = os.path.splitext(file_name)[0] + ".png"
84 | with io.BytesIO() as out:
85 | Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
86 | segments_info = [self._convert_category_id(x) for x in segments_info]
87 | self._predictions.append(
88 | {
89 | "image_id": int(file_name.split(".")[0]),
90 | "file_name": file_name_png,
91 | "png_string": out.getvalue(),
92 | "segments_info": segments_info,
93 | }
94 | )
95 |
96 | def evaluate(self):
97 | comm.synchronize()
98 |
99 | self._predictions = comm.gather(self._predictions)
100 | self._predictions = list(itertools.chain(*self._predictions))
101 | if not comm.is_main_process():
102 | return
103 |
104 | # PanopticApi requires local files
105 | gt_json = PathManager.get_local_path(self._metadata['panoptic_json'])
106 | gt_folder = PathManager.get_local_path(self._metadata['panoptic_root'])
107 |
108 | with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
109 | logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
110 | for p in self._predictions:
111 | with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
112 | f.write(p.pop("png_string"))
113 |
114 | with open(gt_json, "r") as f:
115 | json_data = json.load(f)
116 | json_data["annotations"] = self._predictions
117 |
118 | output_dir = self._output_dir or pred_dir
119 | predictions_json = os.path.join(output_dir, "predictions.json")
120 | with PathManager.open(predictions_json, "w") as f:
121 | f.write(json.dumps(json_data))
122 |
123 | from panopticapi.evaluation import pq_compute
124 |
125 | with contextlib.redirect_stdout(io.StringIO()):
126 | pq_res = pq_compute(
127 | gt_json,
128 | PathManager.get_local_path(predictions_json),
129 | gt_folder=gt_folder,
130 | pred_folder=pred_dir,
131 | )
132 |
133 | res = {}
134 | res["PQ"] = 100 * pq_res["All"]["pq"]
135 | res["SQ"] = 100 * pq_res["All"]["sq"]
136 | res["RQ"] = 100 * pq_res["All"]["rq"]
137 | res["PQ_th"] = 100 * pq_res["Things"]["pq"]
138 | res["SQ_th"] = 100 * pq_res["Things"]["sq"]
139 | res["RQ_th"] = 100 * pq_res["Things"]["rq"]
140 | res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
141 | res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
142 | res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]
143 |
144 | results = OrderedDict({"panoptic_seg": res})
145 | print(colored(get_table(pq_res), 'yellow'))
146 |
147 | return results
148 |
149 |
150 | def get_table(pq_res):
151 | headers = ["", "PQ", "SQ", "RQ", "#categories"]
152 | data = []
153 | for name in ["All", "Things", "Stuff"]:
154 | if name not in pq_res:
155 | continue
156 | row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
157 | data.append(row)
158 | table = tabulate(
159 | data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
160 | )
161 | return table
162 |
163 |
164 | def _print_panoptic_results(pq_res):
165 | headers = ["", "PQ", "SQ", "RQ", "#categories"]
166 | data = []
167 | for name in ["All", "Things", "Stuff"]:
168 | row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
169 | data.append(row)
170 | table = tabulate(
171 | data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
172 | )
173 | logger.info("Panoptic Evaluation Results:\n" + table)
174 |
175 |
176 | if __name__ == "__main__":
177 | from detectron2.utils.logger import setup_logger
178 |
179 | logger = setup_logger()
180 | import argparse
181 |
182 | parser = argparse.ArgumentParser()
183 | parser.add_argument("--gt-json")
184 | parser.add_argument("--gt-dir")
185 | parser.add_argument("--pred-json")
186 | parser.add_argument("--pred-dir")
187 | args = parser.parse_args()
188 |
189 | from panopticapi.evaluation import pq_compute
190 |
191 | with contextlib.redirect_stdout(io.StringIO()):
192 | pq_res = pq_compute(
193 | args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir
194 | )
195 | _print_panoptic_results(pq_res)
196 |
--------------------------------------------------------------------------------
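
The evaluator is driven in three steps: reset(), repeated process() calls with per-image outputs, and a final evaluate() that gathers results and runs panopticapi. A compressed sketch with stubbed metadata and a dummy prediction; the dict keys mirror the accesses above, and the ground-truth paths are placeholders to be replaced:

import torch
from ldmseg.evaluations import PanopticEvaluator

metadata = {
    'thing_dataset_id_to_contiguous_id': {1: 0},
    'stuff_dataset_id_to_contiguous_id': {200: 80},
    'panoptic_json': '/path/to/panoptic_val2017.json',
    'panoptic_root': '/path/to/panoptic_val2017',
}
evaluator = PanopticEvaluator(metadata, output_dir='./output/panoptic')
evaluator.reset()

# one entry per image: an (H, W) segment-id map plus its segments_info list
panoptic_map = torch.ones(480, 640, dtype=torch.int32)
segments_info = [{'id': 1, 'category_id': 0, 'isthing': True}]
evaluator.process(['000000012280.jpg'], [{'panoptic_seg': (panoptic_map, segments_info)}])

results = evaluator.evaluate()  # only the main process returns the PQ/SQ/RQ dict
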
/ldmseg/evaluations/panoptic_evaluation_agnostic.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | Panoptic evaluation class for class agnostic models
5 | Mostly copied from detectron2 with some modifications to make it compatible with our codebase
6 | See https://github.com/facebookresearch/detectron2 for more details and license.
7 | """
8 |
9 | import torch
10 | import contextlib
11 | import io
12 | import itertools
13 | import json
14 | import logging
15 | import os
16 | import tempfile
17 | from collections import OrderedDict
18 | from PIL import Image
19 | from tabulate import tabulate
20 | import numpy as np
21 |
22 | from detectron2.utils import comm
23 | from detectron2.utils.file_io import PathManager
24 | from detectron2.evaluation.evaluator import DatasetEvaluator
25 | from termcolor import colored
26 | from typing import Optional, List, Dict, Union
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | class PanopticEvaluatorAgnostic(DatasetEvaluator):
32 | """
33 | Evaluate Panoptic Quality metrics on COCO using PanopticAPI.
34 | It saves panoptic segmentation predictions in `output_dir`
35 |
36 | It contains a synchronize call and has to be called from all workers.
37 | """
38 |
39 | def __init__(self, output_dir: str = './output/predictions_panoptic/', meta: Optional[Dict] = None):
40 | """
41 | Args:
42 | meta: dataset metadata providing the category id mappings and the panoptic ground-truth paths
43 | output_dir: output directory to save results for evaluation.
44 | """
45 | self._metadata = meta
46 | self.thing_dataset_id_to_contiguous_id = self._metadata["thing_dataset_id_to_contiguous_id"]
47 | self.stuff_dataset_id_to_contiguous_id = self._metadata["stuff_dataset_id_to_contiguous_id"]
48 | self.panoptic_json = self._metadata['panoptic_json']
49 | self.panoptic_root = self._metadata['panoptic_root']
50 | self.label_divisor = 1
51 | self._thing_contiguous_id_to_dataset_id = {
52 | v: k for k, v in self._metadata["thing_dataset_id_to_contiguous_id"].items()
53 | }
54 | self._stuff_contiguous_id_to_dataset_id = {
55 | v: k for k, v in self._metadata["stuff_dataset_id_to_contiguous_id"].items()
56 | }
57 |
58 | self.class_agnostic = True
59 | if self.class_agnostic:
60 | # we need to modify the json of the ground truth to make it class agnostic and save it
61 | gt_json = PathManager.get_local_path(self.panoptic_json)
62 | gt_json_agnostic = gt_json.replace(".json", "_agnostic.json")
63 | self.panoptic_json = gt_json_agnostic
64 | if not os.path.exists(gt_json_agnostic):
65 | with open(gt_json, "r") as f:
66 | json_data = json.load(f)
67 | for anno in json_data["annotations"]:
68 | for seg in anno["segments_info"]:
69 | seg["category_id"] = 1
70 | json_data['categories'] = [{'id': 1, 'name': 'object', 'supercategory': 'object', 'isthing': 1}]
71 | with PathManager.open(gt_json_agnostic, "w") as f:
72 | f.write(json.dumps(json_data))
73 |
74 | self._output_dir = output_dir
75 | if self._output_dir is not None:
76 | PathManager.mkdirs(self._output_dir)
77 |
78 | def reset(self):
79 | self._predictions = []
80 |
81 | def _convert_category_id(self, segment_info: dict) -> dict:
82 | isthing = segment_info.pop("isthing", None)
83 | if isthing is None:
84 | # the model produces panoptic category id directly. No more conversion needed
85 | return segment_info
86 | if isthing is True:
87 | segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[
88 | segment_info["category_id"]
89 | ]
90 | else:
91 | segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[
92 | segment_info["category_id"]
93 | ]
94 | return segment_info
95 |
96 | def process(
97 | self,
98 | file_names: List[str],
99 | image_ids: List[int],
100 | outputs: Dict[str, Union[torch.Tensor, np.ndarray, dict]],
101 | ):
102 | from panopticapi.utils import id2rgb
103 |
104 | for file_name, image_id, output in zip(file_names, image_ids, outputs):
105 | panoptic_img, segments_info = output["panoptic_seg"]
106 | if isinstance(panoptic_img, torch.Tensor):
107 | panoptic_img = panoptic_img.cpu().numpy()
108 |
109 | for seg_id in segments_info:
110 | seg_id['category_id'] = 1
111 | seg_id['isthing'] = True
112 |
113 | file_name = os.path.basename(file_name)
114 | file_name_png = os.path.splitext(file_name)[0] + ".png"
115 | with io.BytesIO() as out:
116 | Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG")
117 | if not self.class_agnostic:
118 | segments_info = [self._convert_category_id(x) for x in segments_info]
119 | self._predictions.append(
120 | {
121 | "image_id": image_id,
122 | "file_name": file_name_png,
123 | "png_string": out.getvalue(),
124 | "segments_info": segments_info,
125 | }
126 | )
127 |
128 | def evaluate(self):
129 | comm.synchronize()
130 |
131 | self._predictions = comm.gather(self._predictions)
132 | self._predictions = list(itertools.chain(*self._predictions))
133 | if not comm.is_main_process():
134 | return
135 |
136 | # PanopticApi requires local files
137 | gt_json = PathManager.get_local_path(self.panoptic_json)
138 | gt_folder = PathManager.get_local_path(self.panoptic_root)
139 |
140 | with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir:
141 | logger.info("Writing all panoptic predictions to {} ...".format(pred_dir))
142 | print(colored('Writing all panoptic predictions to {}...'.format(pred_dir), 'blue'))
143 | for p in self._predictions:
144 | with open(os.path.join(pred_dir, p["file_name"]), "wb") as f:
145 | f.write(p.pop("png_string"))
146 |
147 | with open(gt_json, "r") as f:
148 | json_data = json.load(f)
149 | json_data["annotations"] = self._predictions
150 |
151 | output_dir = self._output_dir or pred_dir
152 | predictions_json = os.path.join(output_dir, "predictions.json")
153 | with PathManager.open(predictions_json, "w") as f:
154 | f.write(json.dumps(json_data))
155 |
156 | with contextlib.redirect_stdout(io.StringIO()):
157 | pq_res, pq_stat_per_cat, num_preds = pq_compute(
158 | gt_json,
159 | PathManager.get_local_path(predictions_json),
160 | gt_folder=gt_folder,
161 | pred_folder=pred_dir,
162 | )
163 |
164 | res = {}
165 | res["PQ"] = 100 * pq_res["All"]["pq"]
166 | res["SQ"] = 100 * pq_res["All"]["sq"]
167 | res["RQ"] = 100 * pq_res["All"]["rq"]
168 | res["PQ_th"] = 100 * pq_res["Things"]["pq"]
169 | res["SQ_th"] = 100 * pq_res["Things"]["sq"]
170 | res["RQ_th"] = 100 * pq_res["Things"]["rq"]
171 | if "Stuff" in pq_res:
172 | res["PQ_st"] = 100 * pq_res["Stuff"]["pq"]
173 | res["SQ_st"] = 100 * pq_res["Stuff"]["sq"]
174 | res["RQ_st"] = 100 * pq_res["Stuff"]["rq"]
175 |
176 | results = OrderedDict({"panoptic_seg": res})
177 |
178 | precision = pq_stat_per_cat[1].tp / (pq_stat_per_cat[1].tp + pq_stat_per_cat[1].fp + 1e-8)
179 | recall = pq_stat_per_cat[1].tp / (pq_stat_per_cat[1].tp + pq_stat_per_cat[1].fn + 1e-8)
180 | print('')
181 | print('precision: ', precision*100)
182 | print('recall: ', recall*100)
183 | print('found {} predictions'.format(num_preds))
184 | print(colored(get_table(pq_res), 'yellow'))
185 | return results
186 |
187 |
188 | def pq_compute(
189 | gt_json_file: str,
190 | pred_json_file: str,
191 | gt_folder=None,
192 | pred_folder=None,
193 | ):
194 | from panopticapi.evaluation import pq_compute_multi_core
195 |
196 | with open(gt_json_file, 'r') as f:
197 | gt_json = json.load(f)
198 | with open(pred_json_file, 'r') as f:
199 | pred_json = json.load(f)
200 |
201 | if gt_folder is None:
202 | gt_folder = gt_json_file.replace('.json', '')
203 | if pred_folder is None:
204 | pred_folder = pred_json_file.replace('.json', '')
205 | categories = {el['id']: el for el in gt_json['categories']}
206 |
207 | if not os.path.isdir(gt_folder):
208 | raise Exception("Folder {} with ground truth segmentations doesn't exist".format(gt_folder))
209 | if not os.path.isdir(pred_folder):
210 | raise Exception("Folder {} with predicted segmentations doesn't exist".format(pred_folder))
211 |
212 | pred_annotations = {el['image_id']: el for el in pred_json['annotations']}
213 | matched_annotations_list = []
214 | for gt_ann in gt_json['annotations']:
215 | image_id = gt_ann['image_id']
216 | if image_id not in pred_annotations:
217 | # raise Exception('no prediction for the image with id: {}'.format(image_id))
218 | continue
219 | matched_annotations_list.append((gt_ann, pred_annotations[image_id]))
220 |
221 | pq_stat = pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories)
222 |
223 | metrics = [("All", None), ("Things", True)]
224 | results = {}
225 | for name, isthing in metrics:
226 | results[name], per_class_results = pq_stat.pq_average(categories, isthing=isthing)
227 | if name == 'All':
228 | results['per_class'] = per_class_results
229 |
230 | return results, pq_stat.pq_per_cat, len(pred_annotations)
231 |
232 |
233 | def get_table(pq_res: dict):
234 | headers = ["", "PQ", "SQ", "RQ", "#categories"]
235 | data = []
236 | for name in ["All", "Things", "Stuff"]:
237 | if name not in pq_res:
238 | continue
239 | row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]]
240 | data.append(row)
241 | table = tabulate(
242 | data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center"
243 | )
244 | return table
245 |
--------------------------------------------------------------------------------
/ldmseg/evaluations/semseg_evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors: Wouter Van Gansbeke
3 |
4 | Semantic segmentation evaluation utils
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import numpy as np
9 | import torch
10 | import torch.distributed as dist
11 | from ldmseg.utils.utils import is_dist_avail_and_initialized
12 |
13 |
14 | class SemsegMeter(object):
15 | def __init__(self, num_classes, class_names, has_bg=True, ignore_index=255, gpu_idx='cuda'):
16 | self.num_classes = num_classes + int(has_bg)
17 | self.class_names = class_names
18 | self.tp = [0] * self.num_classes
19 | self.fp = [0] * self.num_classes
20 | self.fn = [0] * self.num_classes
21 | self.ignore_index = ignore_index
22 | self.gpu_idx = gpu_idx
23 |
24 | @torch.no_grad()
25 | def update(self, pred, gt):
26 | valid = (gt != self.ignore_index)
27 |
28 | for i_part in range(0, self.num_classes):
29 | tmp_gt = (gt == i_part)
30 | tmp_pred = (pred == i_part)
31 | self.tp[i_part] += torch.sum(tmp_gt & tmp_pred & valid).item()
32 | self.fp[i_part] += torch.sum(~tmp_gt & tmp_pred & valid).item()
33 | self.fn[i_part] += torch.sum(tmp_gt & ~tmp_pred & valid).item()
34 |
35 | def reset(self):
36 | self.tp = [0] * self.num_classes
37 | self.fp = [0] * self.num_classes
38 | self.fn = [0] * self.num_classes
39 |
40 | def return_score(self, verbose=True, name='dataset', suppress_prints=False):
41 | jac = [0] * self.num_classes
42 | for i_part in range(self.num_classes):
43 | jac[i_part] = float(self.tp[i_part]) / max(float(self.tp[i_part] + self.fp[i_part] + self.fn[i_part]), 1e-8)
44 |
45 | eval_result = dict()
46 | eval_result['jaccards_all_categs'] = jac
47 | eval_result['mIoU'] = np.mean(jac)
48 |
49 | if not suppress_prints or verbose:
50 | print(f'Evaluation for semantic segmentation - {name}')
51 | print('mIoU is %.2f' % (100*eval_result['mIoU']))
52 | if verbose:
53 | for i_part in range(self.num_classes):
54 | print('IoU class %s is %.2f' % (self.class_names[i_part], 100*jac[i_part]))
55 |
56 | return eval_result
57 |
58 | def synchronize_between_processes(self):
59 | """
60 | Synchronizes the tp/fp/fn counts across processes.
61 | """
62 | if not is_dist_avail_and_initialized():
63 | return
64 | t = torch.tensor([self.tp, self.fp, self.fn], dtype=torch.float64, device=self.gpu_idx)
65 | dist.barrier()
66 | dist.all_reduce(t)
67 | self.tp = t[0]
68 | self.fp = t[1]
69 | self.fn = t[2]
70 |
71 | def __str__(self):
72 | res = self.return_score(verbose=False, suppress_prints=True)['mIoU']*100
73 | fmtstr = "IoU ({0:.2f})"
74 | return fmtstr.format(res)
75 |
--------------------------------------------------------------------------------
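
SemsegMeter accumulates per-class TP/FP/FN counts and reports the mean IoU. A toy example; importing the module directly avoids pulling in the detectron2-based panoptic evaluators via the package __init__:

import torch
from ldmseg.evaluations.semseg_evaluation import SemsegMeter

# 3 foreground classes plus background -> 4 classes in total
meter = SemsegMeter(num_classes=3, class_names=['bg', 'person', 'car', 'tree'], has_bg=True)

pred = torch.randint(0, 4, (2, 64, 64))   # predicted label maps
gt = torch.randint(0, 4, (2, 64, 64))     # ground truth (255 would be ignored)
meter.update(pred, gt)

scores = meter.return_score(verbose=False, name='toy')
print('mIoU: %.2f' % (100 * scores['mIoU']))
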
/ldmseg/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet import UNet
2 | from .vae import GeneralVAESeg, GeneralVAEImage
3 | from .descriptors import get_image_descriptor_model
4 | from .upscaler import Upscaler
5 |
6 | __all__ = ['UNet', 'GeneralVAESeg', 'GeneralVAEImage', 'get_image_descriptor_model', 'Upscaler']
7 |
--------------------------------------------------------------------------------
/ldmseg/models/descriptors.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | File with descriptor models for latent diffusion training
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | from typing import Optional
11 | from transformers import CLIPVisionModel, CLIPVisionModelWithProjection, CLIPTokenizer, CLIPTextModel
12 | from functools import partial
13 |
14 |
15 | class MyCLIPVisionModel(CLIPVisionModel):
16 | def forward(
17 | self,
18 | pixel_values: Optional[torch.FloatTensor] = None,
19 | output_attentions: Optional[bool] = None,
20 | output_hidden_states: Optional[bool] = None,
21 | return_dict: Optional[bool] = None,
22 | ):
23 |
24 | out = self.vision_model(
25 | pixel_values=pixel_values,
26 | output_attentions=output_attentions,
27 | output_hidden_states=output_hidden_states,
28 | return_dict=return_dict,
29 | )
30 |
31 | return {'last_feat': out.last_hidden_state.permute(0, 2, 1)}
32 |
33 |
34 | class MyCLIPVisionModelWithProjection(CLIPVisionModelWithProjection):
35 | def forward(
36 | self,
37 | pixel_values: Optional[torch.FloatTensor] = None,
38 | output_attentions: Optional[bool] = None,
39 | output_hidden_states: Optional[bool] = None,
40 | return_dict: Optional[bool] = None,
41 | ):
42 | vision_outputs = self.vision_model(
43 | pixel_values=pixel_values,
44 | output_attentions=output_attentions,
45 | output_hidden_states=output_hidden_states,
46 | return_dict=return_dict,
47 | )
48 |
49 | pooled_output = vision_outputs[1] # pooled_output
50 |
51 | image_embeds = self.visual_projection(pooled_output)
52 | # last_hidden_state = vision_outputs.last_hidden_state
53 | # hidden_states = vision_outputs.hidden_states
54 | # attentions = vision_outputs.attentions
55 |
56 | return {'last_feat': image_embeds.unsqueeze(-1)}
57 |
58 |
59 | def get_dino_image_descriptor_model():
60 | raise NotImplementedError('Not yet supported')
61 |
62 |
63 | def get_mae_image_descriptor_model():
64 | raise NotImplementedError('Not yet supported')
65 |
66 |
67 | def get_image_descriptor_model(descriptor_name, pretrained_model_path, unet):
68 | text_encoder = tokenizer = image_descriptor_model = None
69 | if descriptor_name == 'clip_image':
70 | # image_descriptor_model = MyCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
71 | image_descriptor_model = MyCLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
72 | unet.modify_encoder_hidden_state_proj(1024, 768)
73 |
74 | elif descriptor_name == 'clip_image_proj':
75 | # image_descriptor_model = MyCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
76 | image_descriptor_model = MyCLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
77 |
78 | elif descriptor_name == 'dino_image':
79 | raise NotImplementedError('DINO is not yet supported')
80 | get_dino_image_descriptor_model()
81 | unet.modify_encoder_hidden_state_proj(768, 768)
82 | print('adding linear projection to unet for image descriptors')
83 |
84 | elif descriptor_name == 'mae':
85 | raise NotImplementedError('MAE is not yet supported')
86 | get_mae_image_descriptor_model()
87 | unet.modify_encoder_hidden_state_proj(768, 768)
88 | print('adding linear projection to unet for image descriptors')
89 |
90 | elif descriptor_name == 'learnable':
91 | unet.define_learnable_embeddings(128, 768)
92 | print(f'Successfully added learnable object queries to unet as {unet.object_queries}')
93 |
94 | elif descriptor_name == 'remove':
95 | unet.remove_cross_attention()
96 | print('Successfully removed cross attention layers from unet')
97 |
98 | else:
99 | assert descriptor_name == 'none'
100 | # load the pretrained CLIP model
101 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
102 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
103 | print('Successfully loaded pretrained CLIP text encoder')
104 |
105 | return image_descriptor_model, text_encoder, tokenizer
106 |
--------------------------------------------------------------------------------
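
For the default 'none' descriptor the function simply loads the CLIP tokenizer and text encoder from a Stable Diffusion style checkpoint and leaves the UNet untouched. A minimal sketch; the checkpoint path is a placeholder, and passing unet=None is only valid on this branch because the argument is not used there:

from ldmseg.models import get_image_descriptor_model

image_model, text_encoder, tokenizer = get_image_descriptor_model(
    descriptor_name='none',
    pretrained_model_path='/path/to/stable-diffusion-checkpoint',  # placeholder path
    unet=None,
)

tokens = tokenizer(['a photo of a cat'], padding='max_length', truncation=True, return_tensors='pt')
text_embeddings = text_encoder(tokens.input_ids)[0]   # (1, 77, 768) with a ViT-L/14 text encoder
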
/ldmseg/models/upscaler.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | File with a simple upscaler (decoder) model for latent diffusion training
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from typing import Optional
12 | try:
13 | from diffusers.models.unet_2d_blocks import UNetMidBlock2D
14 | except ImportError:
15 | print("Diffusers package not found to train VAEs. Please install diffusers")
16 | raise ImportError
17 |
18 |
19 | class Upscaler(nn.Module):
20 | def __init__(
21 | self,
22 | latent_channels: int = 4,
23 | int_channels: int = 256,
24 | upscaler_channels: int = 256,
25 | out_channels: int = 128,
26 | num_mid_blocks: int = 0,
27 | num_upscalers: int = 1,
28 | fuse_rgb: bool = False,
29 | downsample_factor: int = 8,
30 | norm_num_groups: int = 32,
31 | pretrained_path: Optional[str] = None,
32 | ) -> None:
33 |
34 | super().__init__()
35 |
36 | self.enable_mid_block = num_mid_blocks > 0
37 | self.num_mid_blocks = num_mid_blocks
38 | self.downsample_factor = downsample_factor
39 | self.interpolation_factor = self.downsample_factor // (2 ** num_upscalers)
40 |
41 | self.fuse_rgb = fuse_rgb
42 | multiplier = 2 if self.fuse_rgb else 1
43 | self.define_decoder(out_channels, int_channels, upscaler_channels, norm_num_groups,
44 | latent_channels * multiplier, num_upscalers=num_upscalers)
45 | self.gradient_checkpoint = False
46 | if pretrained_path is not None:
47 | self.load_pretrained(pretrained_path)
48 | print('Interpolation factor: ', self.interpolation_factor)
49 |
50 | def enable_gradient_checkpointing(self):
51 | raise NotImplementedError("Gradient checkpointing not implemented for Upscaler")
52 |
53 | def load_pretrained(self, pretrained_path):
54 | data = torch.load(pretrained_path, map_location='cpu')
55 | # remove the module prefix from the state dict
56 | data['vae'] = {k.replace('module.', ''): v for k, v in data['vae'].items()}
57 | msg = self.load_state_dict(data['vae'], strict=False)
58 | print(f'Loaded pretrained decoder from VAE checkpoint {pretrained_path} with message {msg}')
59 |
60 | def define_decoder(
61 | self,
62 | num_classes: int,
63 | int_channels: int = 256,
64 | upscaler_channels: int = 256,
65 | norm_num_groups: int = 32,
66 | latent_channels: int = 4,
67 | num_upscalers: int = 1,
68 | ):
69 |
70 | decoder_in_conv = nn.Conv2d(latent_channels, int_channels, kernel_size=3, padding=1)
71 |
72 | if self.enable_mid_block:
73 | decoder_mid_block = [UNetMidBlock2D(
74 | in_channels=int_channels,
75 | resnet_eps=1e-6,
76 | resnet_act_fn='silu',
77 | output_scale_factor=1,
78 | resnet_time_scale_shift="default",
79 | attn_num_head_channels=None,
80 | resnet_groups=norm_num_groups,
81 | temb_channels=None,
82 | ) for _ in range(self.num_mid_blocks)]
83 | else:
84 | decoder_mid_block = [nn.Identity()]
85 |
86 | dim = upscaler_channels
87 | upscaler = []
88 | for i in range(num_upscalers):
89 | in_channels = int_channels if i == 0 else dim
90 | upscaler.extend(
91 | [
92 | nn.ConvTranspose2d(in_channels, dim, kernel_size=2, stride=2),
93 | LayerNorm2d(dim),
94 | nn.SiLU()
95 | ]
96 | )
97 | upscaler.extend(
98 | [
99 | nn.GroupNorm(norm_num_groups, dim),
100 | nn.SiLU(),
101 | nn.Conv2d(dim, num_classes, 3, padding=1),
102 | ]
103 | )
104 |
105 | self.decoder = nn.Sequential(
106 | decoder_in_conv,
107 | *decoder_mid_block,
108 | *upscaler,
109 | )
110 |
111 | def freeze_layers(self):
112 | raise NotImplementedError
113 |
114 | def decode(self, z, interpolate=True):
115 | x = self.decoder(z)
116 | if interpolate:
117 | x = F.interpolate(x, scale_factor=self.interpolation_factor, mode='bilinear', align_corners=False)
118 | return x
119 |
120 | def forward(
121 | self,
122 | z: torch.Tensor,
123 | interpolate: bool = False,
124 | z_rgb: Optional[torch.Tensor] = None
125 | ) -> torch.Tensor:
126 |
127 | if z_rgb is not None and self.fuse_rgb:
128 | z = torch.cat([z, z_rgb], dim=1)
129 |
130 | return self.decode(z, interpolate=interpolate)
131 |
132 |
133 | class LayerNorm2d(nn.Module):
134 | # copied from detectron2
135 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
136 | super().__init__()
137 | self.weight = nn.Parameter(torch.ones(num_channels))
138 | self.bias = nn.Parameter(torch.zeros(num_channels))
139 | self.eps = eps
140 |
141 | def forward(self, x: torch.Tensor) -> torch.Tensor:
142 | u = x.mean(1, keepdim=True)
143 | s = (x - u).pow(2).mean(1, keepdim=True)
144 | x = (x - u) / torch.sqrt(s + self.eps)
145 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
146 | return x
147 |
--------------------------------------------------------------------------------
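
With the default settings the single ConvTranspose2d doubles the spatial resolution and the optional bilinear interpolation covers the remaining 4x, i.e. 8x in total. A small shape check:

import torch
from ldmseg.models import Upscaler

upscaler = Upscaler(latent_channels=4, out_channels=128)   # defaults: 8x total upsampling

z = torch.randn(2, 4, 64, 64)              # latents at 1/8 of a 512x512 input
logits = upscaler(z, interpolate=True)     # ConvTranspose2d (2x) + interpolation (4x)
print(logits.shape)                        # torch.Size([2, 128, 512, 512])
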
/ldmseg/models/vae.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | File with VAE models for latent diffusion training
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from einops import rearrange
12 | from typing import Tuple, Optional, Union
13 | from ldmseg.utils import OutputDict
14 | try:
15 | from diffusers.models.unet_2d_blocks import UNetMidBlock2D
16 | from diffusers import AutoencoderKL
17 | except ImportError:
18 | print("Diffusers package not found to train VAEs. Please install diffusers")
19 | raise ImportError
20 |
21 |
22 | class RangeDict(OutputDict):
23 | min: torch.Tensor
24 | max: torch.Tensor
25 |
26 |
27 | class VAEOutput(OutputDict):
28 | sample: torch.Tensor
29 | posterior: torch.Tensor
30 |
31 |
32 | class EncoderOutput(OutputDict):
33 | latent_dist: torch.Tensor
34 |
35 |
36 | class GeneralVAEImage(AutoencoderKL):
37 |
38 | def set_scaling_factor(self, scaling_factor):
39 | self.scaling_factor = scaling_factor
40 |
41 |
42 | class GeneralVAESeg(nn.Module):
43 | def __init__(
44 | self,
45 | in_channels: int = 3,
46 | int_channels: int = 256,
47 | out_channels: int = 128,
48 | block_out_channels: Tuple[int] = (32, 64, 128, 256),
49 | latent_channels: int = 4,
50 | norm_num_groups: int = 32,
51 | scaling_factor: float = 0.18215,
52 | pretrained_path: Optional[str] = None,
53 | encoder: Optional[nn.Module] = None,
54 | num_mid_blocks: int = 0,
55 | num_latents: int = 2,
56 | num_upscalers: int = 1,
57 | upscale_channels: int = 256,
58 | parametrization: str = 'gaussian',
59 | fuse_rgb: bool = False,
60 | resize_input: bool = False,
61 | act_fn: str = 'none',
62 | clamp_output: bool = False,
63 | freeze_codebook: bool = False,
64 | skip_encoder: bool = False,
65 | ) -> None:
66 |
67 | super().__init__()
68 |
69 | self.enable_mid_block = num_mid_blocks > 0
70 | self.num_mid_blocks = num_mid_blocks
71 | self.downsample_factor = 2 ** (len(block_out_channels) - 1)
72 | self.interpolation_factor = self.downsample_factor // (2 ** num_upscalers)
73 | if "discrete" in parametrization:
74 | num_embeddings = 128
75 | if freeze_codebook:
76 | print('Freezing codebook')
77 | gen = torch.Generator().manual_seed(42)
78 | Q = torch.linalg.qr(torch.randn(num_embeddings, latent_channels, generator=gen))[0]
79 | self.codebook = nn.Embedding.from_pretrained(Q, freeze=True)
80 | else:
81 | self.codebook = nn.Embedding(num_embeddings, latent_channels, max_norm=None)
82 | num_latents = num_embeddings // latent_channels
83 | elif parametrization == 'auto':
84 | num_latents = 1
85 |
86 | if encoder is None:
87 | if fuse_rgb:
88 | in_channels += 3
89 | self.define_encoder(in_channels, block_out_channels, int_channels, norm_num_groups,
90 | latent_channels, num_latents=num_latents, resize_input=resize_input,
91 | skip_encoder=skip_encoder)
92 | else:
93 | self.encoder = encoder
94 | self.freeze_encoder()
95 |
96 | self.define_decoder(out_channels, int_channels, norm_num_groups, latent_channels,
97 | num_upscalers=num_upscalers, upscale_channels=upscale_channels)
98 | self.scaling_factor = scaling_factor
99 | self.gradient_checkpoint = False
100 | if pretrained_path is not None:
101 | self.load_pretrained(pretrained_path)
102 | self.parametrization = parametrization
103 | self.interpolation_factor = self.downsample_factor // (2 ** num_upscalers)
104 | self.num_latents = num_latents
105 | self.act_fn = act_fn
106 | self.clamp_output = clamp_output
107 | print('Interpolation factor: ', self.interpolation_factor)
108 | print('Parametrization: ', self.parametrization)
109 | print('Activation function: ', self.act_fn)
110 | assert self.parametrization in ['gaussian', 'discrete_gumbel_softmax', 'discrete_codebook', 'auto']
111 | assert self.num_latents in [1, 2, 32]
112 |
113 | def enable_gradient_checkpointing(self):
114 | raise NotImplementedError("Gradient checkpointing not implemented for a shallow VAE")
115 |
116 | def load_pretrained(self, pretrained_path):
117 | data = torch.load(pretrained_path, map_location='cpu')
118 | # remove the module prefix from the state dict
119 | data['vae'] = {k.replace('module.', ''): v for k, v in data['vae'].items()}
120 | msg = self.load_state_dict(data['vae'], strict=True)
121 | print(f'Loaded pretrained VAE from {pretrained_path} with message {msg}')
122 |
123 | def define_decoder(
124 | self,
125 | num_classes: int,
126 | int_channels: int = 256,
127 | norm_num_groups: int = 32,
128 | latent_channels: int = 4,
129 | num_upscalers: int = 1,
130 | upscale_channels: int = 256,
131 | ):
132 |
133 | decoder_in_conv = nn.Conv2d(latent_channels, int_channels, kernel_size=3, padding=1)
134 |
135 | if self.enable_mid_block:
136 | decoder_mid_block = UNetMidBlock2D(
137 | in_channels=int_channels,
138 | resnet_eps=1e-6,
139 | resnet_act_fn='silu',
140 | output_scale_factor=1,
141 | resnet_time_scale_shift="default",
142 | resnet_groups=norm_num_groups,
143 | temb_channels=None,
144 | add_attention=False,
145 | )
146 | else:
147 | decoder_mid_block = nn.Identity()
148 |
149 | dim = upscale_channels
150 | upscaler = []
151 | for i in range(num_upscalers):
152 | in_channels = int_channels if i == 0 else dim
153 | upscaler.extend(
154 | [
155 | nn.ConvTranspose2d(in_channels, dim, kernel_size=2, stride=2),
156 | LayerNorm2d(dim),
157 | nn.SiLU()
158 | ]
159 | )
160 | upscaler.extend(
161 | [
162 | nn.GroupNorm(norm_num_groups, dim),
163 | nn.SiLU(),
164 | nn.Conv2d(dim, num_classes, 3, padding=1),
165 | ]
166 | )
167 |
168 | self.decoder = nn.Sequential(
169 | decoder_in_conv,
170 | decoder_mid_block,
171 | *upscaler,
172 | )
173 |
174 | def define_encoder(
175 | self,
176 | in_channels: int,
177 | block_out_channels: Tuple[int],
178 | int_channels: int = 256,
179 | norm_num_groups: int = 32,
180 | latent_channels: int = 4,
181 | num_latents: int = 2,
182 | resize_input: bool = False,
183 | skip_encoder: bool = False,
184 | ):
185 | # define semseg encoder
186 | if skip_encoder:
187 | self.encoder = nn.Conv2d(in_channels, latent_channels * num_latents, 8, stride=8)
188 | return
189 |
190 | encoder_in_block = [
191 | nn.Conv2d(in_channels, block_out_channels[0] if not resize_input else int_channels,
192 | kernel_size=3, padding=1),
193 | nn.SiLU(),
194 | ]
195 |
196 | if not resize_input:
197 | down_blocks_semseg = []
198 | for i in range(len(block_out_channels) - 1):
199 | channel_in = block_out_channels[i]
200 | channel_out = block_out_channels[i + 1]
201 | down_blocks_semseg.extend(
202 | [
203 | nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1),
204 | nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2),
205 | nn.SiLU(),
206 | ]
207 | )
208 | else:
209 | down_blocks_semseg = [
210 | nn.Upsample(scale_factor=1. / self.downsample_factor, mode='bilinear', align_corners=False)
211 | ]
212 | encoder_down_blocks = [
213 | *down_blocks_semseg,
214 | nn.Conv2d(block_out_channels[-1], int_channels, kernel_size=3, padding=1),
215 | ]
216 |
217 | encoder_mid_blocks = []
218 | if self.enable_mid_block:
219 | for _ in range(self.num_mid_blocks):
220 | encoder_mid_blocks.append(UNetMidBlock2D(
221 | in_channels=int_channels,
222 | resnet_eps=1e-6,
223 | resnet_act_fn='silu',
224 | output_scale_factor=1,
225 | resnet_time_scale_shift="default",
226 | resnet_groups=norm_num_groups,
227 | temb_channels=None,
228 | add_attention=False,
229 | ))
230 | else:
231 | encoder_mid_blocks = [nn.Identity()]
232 |
233 | encoder_out_block = [
234 | nn.GroupNorm(num_channels=int_channels, num_groups=norm_num_groups, eps=1e-6),
235 | nn.SiLU(),
236 | nn.Conv2d(int_channels, latent_channels * num_latents, 3, padding=1)
237 | ]
238 |
239 | self.encoder = nn.Sequential(
240 | *encoder_in_block,
241 | *encoder_down_blocks,
242 | *encoder_mid_blocks,
243 | *encoder_out_block,
244 | )
245 |
246 | def freeze_layers(self):
247 | raise NotImplementedError
248 |
249 | def freeze_encoder(self):
250 | self.encoder.requires_grad_(False)
251 |
252 | def encode(self, semseg):
253 | moments = self.encoder(semseg)
254 | if self.parametrization == 'gaussian':
255 | posterior = DiagonalGaussianDistribution(
256 | moments, clamp_output=self.clamp_output, act_fn=self.act_fn)
257 | elif self.parametrization == 'discrete_gumbel_softmax':
258 | posterior = GumbelSoftmaxDistribution(
259 | moments, self.codebook, clamp_output=self.clamp_output, act_fn=self.act_fn)
260 | elif self.parametrization == 'discrete_codebook':
261 | posterior = DiscreteCodebookAssignemnt(
262 | moments, self.codebook, clamp_output=self.clamp_output, act_fn=self.act_fn)
263 | elif self.parametrization == 'auto':
264 | posterior = Bottleneck(moments, act_fn=self.act_fn)
265 | return EncoderOutput(latent_dist=posterior)
266 |
267 | def decode(self, z, interpolate=True):
268 | x = self.decoder(z)
269 | if interpolate:
270 | x = F.interpolate(x, scale_factor=self.interpolation_factor, mode='bilinear', align_corners=False)
271 | return x
272 |
273 | def forward(
274 | self,
275 | sample: torch.FloatTensor,
276 | sample_posterior: bool = True,
277 | return_dict: bool = True,
278 | generator: Optional[torch.Generator] = None,
279 | rgb_sample: Optional[torch.FloatTensor] = None,
280 | valid_mask: Optional[torch.FloatTensor] = None,
281 | ) -> Union[VAEOutput, torch.FloatTensor]:
282 |
283 | x = sample
284 |
285 | # encode
286 | if rgb_sample is not None:
287 | x = torch.cat([x, rgb_sample], dim=1)
288 |
289 | posterior = self.encode(x).latent_dist
290 |
291 | # sample from posterior
292 | if sample_posterior:
293 | z = posterior.sample(generator=generator)
294 | else:
295 | z = posterior.mode()
296 |
297 | # (optional) mask out invalid pixels
298 | if valid_mask is not None:
299 | z = z * valid_mask[:, None]
300 |
301 | # decode
302 | dec = self.decode(z, interpolate=False)
303 |
304 | if not return_dict:
305 | return (dec,)
306 | return VAEOutput(sample=dec, posterior=posterior)
307 |
308 |
309 | class LayerNorm2d(nn.Module):
310 | # copied from detectron2
311 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
312 | super().__init__()
313 | self.weight = nn.Parameter(torch.ones(num_channels))
314 | self.bias = nn.Parameter(torch.zeros(num_channels))
315 | self.eps = eps
316 |
317 | def forward(self, x: torch.Tensor) -> torch.Tensor:
318 | u = x.mean(1, keepdim=True)
319 | s = (x - u).pow(2).mean(1, keepdim=True)
320 | x = (x - u) / torch.sqrt(s + self.eps)
321 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
322 | return x
323 |
324 |
325 | class Bottleneck(object):
326 | """
327 | Simple bottleneck class to be used in the AE.
328 | """
329 |
330 | def __init__(
331 | self,
332 | parameters: torch.Tensor,
333 | act_fn: str = 'none',
334 | ):
335 | self.mean = parameters
336 | self.mean = self.to_range(self.mean, act_fn)
337 | self.act_fn = act_fn
338 |
339 | def to_range(self, x, act_fn):
340 | if act_fn == 'sigmoid':
341 | return 2 * F.sigmoid(x) - 1
342 | elif act_fn == 'tanh':
343 | return F.tanh(x)
344 | elif act_fn == 'clip':
345 | return torch.clamp(x, -5.0, 5.0)
346 | elif act_fn == 'l2':
347 | return F.normalize(x, dim=1, p=2)
348 | elif act_fn == 'none':
349 | return x
350 | else:
351 | raise NotImplementedError
352 |
353 | def mode(self):
354 | return self.mean
355 |
356 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
357 | return self.mean
358 |
359 | def kl(self):
360 | return torch.sum(torch.pow(self.mean, 2), dim=[1, 2, 3])
361 |
362 | def get_range(self):
363 | return RangeDict(min=self.mean.min(), max=self.mean.max())
364 |
365 | def __str__(self) -> str:
366 | return f"Bottleneck(mean={self.mean}, " \
367 | f"act_fn={self.act_fn})"
368 |
369 |
370 | class DiagonalGaussianDistribution(object):
371 | """
372 | Parametrizes a diagonal Gaussian distribution with a mean and log-variance.
373 | Allows computing the KL divergence with a standard diagonal Gaussian distribution.
374 | Added functionalities to diffusers library: bottleneck clamp and activation function.
375 | """
376 |
377 | def __init__(
378 | self,
379 | parameters: torch.Tensor,
380 | clamp_output: bool = False,
381 | act_fn: str = 'none',
382 | ):
383 |
384 | self.parameters = parameters
385 | if clamp_output:
386 | parameters = torch.clamp(parameters, -5.0, 5.0)
387 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
388 | self.mean = self.to_range(self.mean, act_fn)
389 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
390 | self.std = torch.exp(0.5 * self.logvar)
391 | self.var = torch.exp(self.logvar)
392 | self.clamp_output = clamp_output
393 | self.act_fn = act_fn
394 |
395 | def to_range(self, x, act_fn):
396 | if act_fn == 'sigmoid':
397 | return 2 * F.sigmoid(x) - 1
398 | elif act_fn == 'tanh':
399 | return F.tanh(x)
400 | elif act_fn == 'clip':
401 | return torch.clamp(x, -1, 1)
402 | elif act_fn == 'none':
403 | return x
404 | else:
405 | raise NotImplementedError
406 |
407 | def mode(self):
408 | return self.mean
409 |
410 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
411 | sample = torch.randn(
412 | self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype)
413 | x = self.mean + self.std * sample
414 | return x
415 |
416 | def kl(self):
417 | return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
418 |
419 | def get_range(self):
420 | return RangeDict(min=self.mean.min(), max=self.mean.max())
421 |
422 | def __str__(self) -> str:
423 | return f"DiagonalGaussianDistribution(mean={self.mean}, var={self.var}, " \
424 | f"clamp_output={self.clamp_output}, act_fn={self.act_fn})"
425 |
426 |
427 | class GumbelSoftmaxDistribution(object):
428 | """
429 | Parametrizes a uniform gumbel softmax distribution.
430 | """
431 |
432 | def __init__(
433 | self,
434 | parameters: torch.Tensor,
435 | codebook: nn.Embedding,
436 | clamp_output: bool = False,
437 | act_fn: str = 'none',
438 | ):
439 |
440 | self.parameters = parameters
441 | if clamp_output:
442 | parameters = torch.clamp(parameters, -5.0, 5.0)
443 |
444 | self.clamp_output = clamp_output
445 | self.act_fn = act_fn
446 | self.straight_through = True
447 | self.temp = 0.2
448 | self.codebook = codebook
449 | self.num_tokens = codebook.weight.shape[0]
450 | assert self.num_tokens == 128
451 | assert self.parameters.shape[1] == self.num_tokens
452 |
453 | def to_range(self, x, act_fn):
454 | if act_fn == 'sigmoid':
455 | return 2 * F.sigmoid(x) - 1
456 | elif act_fn == 'tanh':
457 | return F.tanh(x)
458 | elif act_fn == 'clip':
459 | return torch.clamp(x, -1, 1)
460 | elif act_fn == 'none':
461 | return x
462 | else:
463 | raise NotImplementedError
464 |
465 | def mode(self) -> torch.FloatTensor:
466 | indices = self.get_codebook_indices()
467 | # one_hot = torch.scatter(torch.zeros_like(self.parameters), 1, indices[:, None], 1.0)
468 | one_hot = F.one_hot(indices, num_classes=self.num_tokens).permute(0, 3, 1, 2).float()
469 | sampled = torch.einsum('b n h w, n d -> b d h w', one_hot, self.codebook.weight)
470 | return sampled
471 |
472 | def get_codebook_indices(self) -> torch.LongTensor:
473 | return self.parameters.argmax(dim=1)
474 |
475 | def get_codebook_probs(self) -> torch.FloatTensor:
476 | return nn.Softmax(dim=1)(self.parameters)
477 |
478 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
479 | soft_one_hot = F.gumbel_softmax(self.parameters, tau=self.temp, dim=1, hard=self.straight_through)
480 | sampled = torch.einsum('b n h w, n d -> b d h w', soft_one_hot, self.codebook.weight)
481 | return sampled
482 |
483 | def kl(self) -> torch.FloatTensor:
484 | logits = rearrange(self.parameters, 'b n h w -> b (h w) n')
485 | qy = F.softmax(logits, dim=-1)
486 | log_qy = torch.log(qy + 1e-10)
487 | log_uniform = torch.log(torch.tensor([1. / self.num_tokens], device=log_qy.device))
488 | kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target=True)
489 | return kl_div
490 |
491 | def get_range(self):
492 | raise NotImplementedError
493 |
494 | def __str__(self) -> str:
495 | return f"DiscreteGumbelSoftmaxDistribution(mean={self.parameters}, " \
496 | f"clamp_output={self.clamp_output}, act_fn={self.act_fn})"
497 |
498 |
499 | class DiscreteCodebookAssignemnt(object):
500 | """
501 | Parametrizes a discrete codebook distribution.
502 | """
503 |
504 | def __init__(
505 | self,
506 | parameters: torch.Tensor,
507 | codebook: nn.Embedding,
508 | clamp_output: bool = False,
509 | act_fn: str = 'none',
510 | ):
511 |
512 | self.parameters = parameters
513 | if clamp_output:
514 | parameters = torch.clamp(parameters, -5.0, 5.0)
515 | self.clamp_output = clamp_output
516 | self.act_fn = act_fn
517 | self.straight_through = True
518 | self.temp = 1.0
519 | self.codebook = codebook
520 | self.num_tokens = codebook.weight.shape[0]
521 | assert self.num_tokens == 128
522 | assert self.parameters.shape[1] == self.num_tokens
523 |
524 | def to_range(self, x, act_fn):
525 | if act_fn == 'sigmoid':
526 | return 2 * F.sigmoid(x) - 1
527 | elif act_fn == 'tanh':
528 | return F.tanh(x)
529 | elif act_fn == 'clip':
530 | return torch.clamp(x, -1, 1)
531 | elif act_fn == 'none':
532 | return x
533 | else:
534 | raise NotImplementedError
535 |
536 | def mode(self) -> torch.FloatTensor:
537 | indices = self.get_codebook_indices()
538 | # one_hot = torch.scatter(torch.zeros_like(self.parameters), 1, indices[:, None], 1.0)
539 | one_hot = F.one_hot(indices, num_classes=self.num_tokens).permute(0, 3, 1, 2).float()
540 | sampled = torch.einsum('b n h w, n d -> b d h w', one_hot, self.codebook.weight)
541 | return sampled
542 |
543 | def get_codebook_indices(self) -> torch.LongTensor:
544 | return self.parameters.argmax(dim=1)
545 |
546 | def get_codebook_probs(self) -> torch.FloatTensor:
547 | return nn.Softmax(dim=1)(self.parameters)
548 |
549 | def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
550 | _, indices = self.parameters.max(dim=1)
551 | y_hard = F.one_hot(indices, num_classes=self.num_tokens).permute(0, 3, 1, 2).float()
552 | y_hard = (y_hard - self.parameters).detach() + self.parameters
553 | sampled = torch.einsum('b n h w, n d -> b d h w', y_hard, self.codebook.weight)
554 | return sampled
555 |
556 | def kl(self):
557 | logits = rearrange(self.parameters, 'b n h w -> b (h w) n')
558 | qy = F.softmax(logits, dim=-1)
559 | log_qy = torch.log(qy + 1e-10)
560 | log_uniform = torch.log(torch.tensor([1. / self.num_tokens], device=log_qy.device))
561 | kl_div = F.kl_div(log_uniform, log_qy, None, None, 'batchmean', log_target=True)
562 | return kl_div
563 |
564 | def get_range(self):
565 | raise NotImplementedError
566 |
567 | def __str__(self) -> str:
568 | return f"{self.__class__.__name__}(mean={self.parameters}, " \
569 | f"clamp_output={self.clamp_output}, act_fn={self.act_fn})"
570 |
--------------------------------------------------------------------------------
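For quick reference, here is a minimal, hypothetical usage sketch of the codebook descriptors defined above. The shapes, the toy data and the direct import from `ldmseg.models.descriptors` are assumptions (the full ldmseg environment, including diffusers and detectron2, is assumed to be installed); only the 128-token codebook size is taken from the assertions in the constructors.

# Illustrative sketch only; toy shapes and data.
import torch
import torch.nn as nn
from ldmseg.models.descriptors import DiscreteCodebookAssignemnt  # class name as spelled in descriptors.py

B, H, W, D = 2, 16, 16, 4                 # assumed batch, spatial and embedding dims
codebook = nn.Embedding(128, D)           # 128 tokens, as asserted in __init__
logits = torch.randn(B, 128, H, W)        # per-pixel scores over the codebook

dist = DiscreteCodebookAssignemnt(logits, codebook)
z_sample = dist.sample()                  # straight-through argmax lookup -> (B, D, H, W)
z_mode = dist.mode()                      # deterministic argmax lookup -> (B, D, H, W)
kl = dist.kl()                            # KL between softmax(logits) and a uniform prior
print(z_sample.shape, z_mode.shape, kl.item())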
/ldmseg/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | from .ddim_scheduler import DDIMNoiseScheduler
2 |
3 | __all__ = [
4 | 'DDIMNoiseScheduler',
5 | ]
6 |
--------------------------------------------------------------------------------
/ldmseg/schedulers/ddim_scheduler.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | This file contains the noise scheduler for the diffusion process.
5 | Based on the implementation in the diffusers library (Apache License): https://github.com/huggingface/diffusers
6 | Added features to the DDIM scheduler (https://arxiv.org/abs/2102.09672), in summary:
7 | - Define a method to remove noise from the noisy samples according to the adopted scheduler.
8 | - Define loss weights for each timestep, used to scale the loss per timestep
9 | (i.e., small timesteps are weighted less than large timesteps).
10 | - Add the GLIDE cosine schedule from diffusers to the DDIM scheduler as well.
11 | - Use a `steps_offset` by default during inference for sampling segmentation maps from Gaussian noise.
12 | """
13 |
14 | import math
15 | import torch
16 | import numpy as np
17 | from ldmseg.utils import OutputDict
18 | from typing import Optional, Union
19 |
20 |
21 | class DDIMNoiseSchedulerOutput(OutputDict):
22 | prev_sample: torch.FloatTensor
23 | pred_original_sample: Optional[torch.FloatTensor] = None
24 |
25 |
26 | class DDIMNoiseScheduler(object):
27 | """
28 | Noise scheduler for the diffusion process.
29 | Implementation is adapted from the diffusers library
30 | """
31 |
32 | def __init__(
33 | self,
34 | num_train_timesteps: int = 1000,
35 | beta_start: float = 0.0001,
36 | beta_end: float = 0.02,
37 | beta_schedule: str = "linear",
38 | clip_sample: bool = True,
39 | set_alpha_to_one: bool = True,
40 | steps_offset: int = 0,
41 | prediction_type: str = "epsilon",
42 | thresholding: bool = False,
43 | dynamic_thresholding_ratio: float = 0.995,
44 | clip_sample_range: float = 1.0,
45 | sample_max_value: float = 1.0,
46 | weight: str = 'none',
47 | max_snr: float = 5.0,
48 | device: Union[str, torch.device] = None,
49 | verbose: bool = True,
50 | ):
51 | if beta_schedule == "linear":
52 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
53 | elif beta_schedule == "scaled_linear":
54 | # this schedule is very specific to the latent diffusion model.
55 | self.betas = (
56 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
57 | )
58 | elif beta_schedule == "squaredcos_cap_v2":
59 | # Glide cosine schedule
60 | self.betas = self.get_betas_for_alpha_bar(num_train_timesteps)
61 | elif beta_schedule == "sigmoid":
62 | # GeoDiff sigmoid schedule
63 | betas = torch.linspace(-6, 6, num_train_timesteps)
64 | self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
65 | else:
66 | raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
67 |
68 | self.alphas = 1.0 - self.betas
69 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
70 |
71 | # At every step in ddim, we are looking into the previous alphas_cumprod
72 | # For the final step, there is no previous alphas_cumprod because we are already at 0
73 | # `set_alpha_to_one` decides whether we set this parameter simply to one or
74 | # whether we use the final alpha of the "non-previous" one.
75 | self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
76 |
77 | # compute loss weights
78 | self.compute_loss_weights(mode=weight, max_snr=max_snr)
79 | self.weights = self.weights.to(device)
80 |
81 | # set other parameters
82 | self.num_train_timesteps = num_train_timesteps
83 | self.num_inference_steps = None
84 | self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
85 | self.clip_sample = clip_sample
86 | self.clip_sample_range = clip_sample_range
87 | self.prediction_type = prediction_type
88 | self.thresholding = thresholding
89 | self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
90 | self.steps_offset = steps_offset
91 | self.beta_schedule = beta_schedule
92 | self.beta_start = beta_start
93 | self.beta_end = beta_end
94 | self.init_noise_sigma = 1.0
95 | self.verbose = verbose
96 |
97 | def compute_loss_weights(self, mode='max_clamp_snr', max_snr=5.0):
98 | """
99 | Compute loss weights for each timestep. The weights are used to scale the loss of each timestep.
100 | Small timesteps are weighted less than large timesteps.
101 | """
102 |
103 | assert mode in ['inverse_log_snr', 'max_clamp_snr', 'linear', 'fixed', 'none']
104 | self.weight_mode = mode
105 | snr = self.alphas_cumprod / (1 - self.alphas_cumprod)
106 | if mode == 'inverse_log_snr':
107 | self.weights = torch.log(1. / snr).clamp(min=1)
108 | self.weights /= self.weights[-1] # normalize
109 | elif mode == 'max_clamp_snr':
110 | self.weights = snr.clamp(max=max_snr) / snr
111 | elif mode == 'fixed':
112 | self.weights = snr
113 | self.weights[:len(self.weights) // 4] = 0.1
114 | elif mode == 'linear':
115 | self.weights = torch.arange(1, len(snr) + 1) / len(snr)
116 | else:
117 | self.weights = torch.ones_like(snr)
118 |
119 | def set_timesteps_inference(self, num_inference_steps: int, device: Union[str, torch.device] = None, tmin: int = 0):
120 | """
121 | Set the timesteps for inference. This is used to compute the noise schedule for inference.
122 | We shift the timesteps by `steps_offset` to make sure the final timestep is always included (i.e., t = 999)
123 | """
124 |
125 | self.num_inference_steps = num_inference_steps
126 | step_ratio = self.num_train_timesteps // self.num_inference_steps
127 | self.steps_offset = step_ratio - 1
128 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
129 | self.timesteps = torch.from_numpy(timesteps).to(device)
130 | self.timesteps += self.steps_offset
131 | self.timesteps = self.timesteps[self.timesteps >= tmin]
132 |
133 | def move_timesteps_to(self, device: Union[str, torch.device]):
134 | """ Move timesteps to `device`
135 | """
136 | self.timesteps = self.timesteps.to(device)
137 |
138 | def get_betas_for_alpha_bar(self, num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
139 | """
140 | Used for Glide cosine schedule.
141 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
142 | (1-beta) over time from t = [0,1].
143 | """
144 |
145 | def alpha_bar(time_step):
146 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
147 |
148 | betas = []
149 | for i in range(num_diffusion_timesteps):
150 | t1 = i / num_diffusion_timesteps
151 | t2 = (i + 1) / num_diffusion_timesteps
152 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
153 | return torch.tensor(betas, dtype=torch.float32)
154 |
155 | def add_noise(
156 | self,
157 | original_samples: torch.FloatTensor,
158 | noise: torch.FloatTensor,
159 | timesteps: torch.IntTensor,
160 | scale: float = 1.0,
161 | mask_noise_perc: Optional[float] = None,
162 | ) -> torch.FloatTensor:
163 | """
164 | Add noise to the original samples according to the noise schedule.
165 | The core function of the diffusion process.
166 | """
167 |
168 | alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
169 | timesteps = timesteps.to(original_samples.device)
170 |
171 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
172 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
173 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
174 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
175 |
176 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
177 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
178 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
179 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
180 |
181 | if mask_noise_perc is not None:
182 | # only keep noise for a `mask_noise_perc` fraction of the elements (i.e. zero out the rest)
183 | mask = torch.rand_like(original_samples) < mask_noise_perc
184 | noise *= mask
185 |
186 | noisy_samples = sqrt_alpha_prod * scale * original_samples + sqrt_one_minus_alpha_prod * noise
187 | return noisy_samples
188 |
189 | @torch.no_grad()
190 | def remove_noise(
191 | self,
192 | noisy_samples: torch.FloatTensor,
193 | noise: torch.FloatTensor,
194 | timesteps: torch.IntTensor,
195 | scale: float = 1.0,
196 | ) -> torch.FloatTensor:
197 | """
198 | Remove predicted noise from the noisy samples according to the defined noise scheduler.
199 | """
200 |
201 | # Make sure alphas_cumprod and timesteps have the same device and dtype as noisy_samples
202 | alphas_cumprod = self.alphas_cumprod.to(device=noisy_samples.device, dtype=noisy_samples.dtype)
203 | timesteps = timesteps.to(noisy_samples.device)
204 |
205 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
206 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
207 | while len(sqrt_alpha_prod.shape) < len(noisy_samples.shape):
208 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
209 |
210 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
211 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
212 | while len(sqrt_one_minus_alpha_prod.shape) < len(noisy_samples.shape):
213 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
214 |
215 | original_samples = (noisy_samples - sqrt_one_minus_alpha_prod * noise) / (sqrt_alpha_prod * scale)
216 | return original_samples
217 |
218 | def step(
219 | self,
220 | model_output: torch.FloatTensor,
221 | timestep: int,
222 | sample: torch.FloatTensor,
223 | use_clipped_model_output: bool = False,
224 | ) -> DDIMNoiseSchedulerOutput:
225 | """
226 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
227 | process from the learned model outputs (most often the predicted noise).
228 | """
229 |
230 | # 1. get previous step value (=t-1)
231 | prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
232 |
233 | # 2. compute alphas, betas
234 | alpha_prod_t = self.alphas_cumprod[timestep]
235 | alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
236 | beta_prod_t = 1 - alpha_prod_t
237 |
238 | # 3. compute the predicted original sample (also called "predicted x_0") from the predicted noise
239 | if self.prediction_type == "epsilon":
240 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
241 | pred_epsilon = model_output
242 | elif self.prediction_type == "sample":
243 | pred_original_sample = model_output
244 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
245 | elif self.prediction_type == "v_prediction":
246 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
247 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
248 | else:
249 | raise NotImplementedError
250 |
251 | # 4. Clip or threshold "predicted x_0"
252 | if self.thresholding:
253 | raise NotImplementedError
254 | elif self.clip_sample:
255 | pred_original_sample = pred_original_sample.clamp(
256 | -self.clip_sample_range, self.clip_sample_range
257 | )
258 |
259 | if use_clipped_model_output:
260 | # the pred_epsilon is always re-derived from the clipped x_0 in Glide
261 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
262 |
263 | # 6. compute "direction pointing to x_t"
264 | pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
265 |
266 | # 7. compute x_t without "random noise"
267 | prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
268 |
269 | return DDIMNoiseSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
270 |
271 | def __str__(self) -> str:
272 | print_weights = self.weights if self.verbose else 'VerboseDisabled'
273 | return f"DDIMScheduler(num_inference_steps={self.num_inference_steps}, " \
274 | f"num_train_timesteps={self.num_train_timesteps}, " \
275 | f"prediction_type={self.prediction_type}, " \
276 | f"beta_start={self.beta_start}, " \
277 | f"beta_end={self.beta_end}, " \
278 | f"beta_schedule={self.beta_schedule}, " \
279 | f"clip_sample={self.clip_sample}, " \
280 | f"clip_sample_range={self.clip_sample_range}, " \
281 | f"thresholding={self.thresholding}, " \
282 | f"dynamic_thresholding_ratio={self.dynamic_thresholding_ratio}, " \
283 | f"steps_offset={self.steps_offset}, " \
284 | f"weight_mode={self.weight_mode}, " \
285 | f"weights={print_weights})"
286 |
287 | def __repr__(self) -> str:
288 | return self.__str__()
289 |
290 | def __len__(self) -> int:
291 | return self.num_train_timesteps
292 |
--------------------------------------------------------------------------------
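A short, self-contained sketch of how the scheduler above is meant to be driven; the dummy 4-channel latents and the identity stand-in for the denoising UNet are assumptions, while the constructor arguments mirror the noise_scheduler_kwargs in tools/configs/base/base.yaml.

# Illustrative sketch only; toy shapes and data.
import torch
from ldmseg.schedulers import DDIMNoiseScheduler

scheduler = DDIMNoiseScheduler(prediction_type='epsilon', beta_schedule='scaled_linear',
                               beta_start=0.00085, beta_end=0.012, clip_sample=False)

# training-style call: corrupt clean latents at random timesteps
latents = torch.randn(2, 4, 64, 64)            # assumed latent shape
noise = torch.randn_like(latents)
t = torch.randint(0, len(scheduler), (2,))
noisy = scheduler.add_noise(latents, noise, t)

# inference-style loop: denoise pure Gaussian noise in 50 DDIM steps
scheduler.set_timesteps_inference(num_inference_steps=50)
sample = torch.randn(2, 4, 64, 64)
denoiser = lambda x, t: torch.zeros_like(x)    # stand-in for the trained UNet
for timestep in scheduler.timesteps:
    eps = denoiser(sample, timestep)
    sample = scheduler.step(eps, int(timestep), sample).prev_sample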
/ldmseg/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainers_ae import TrainerAE
2 | from .trainers_ldm_cond import TrainerDiffusion
3 |
4 |
5 | __all__ = [
6 | "TrainerAE",
7 | "TrainerDiffusion",
8 | ]
9 |
--------------------------------------------------------------------------------
/ldmseg/trainers/losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | File with loss functions
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | from typing import Optional, Dict
9 | from scipy.optimize import linear_sum_assignment
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | import torch.nn as nn
14 | import torch.distributed as dist
15 |
16 | from ldmseg.utils.utils import get_world_size
17 | from ldmseg.utils.detectron2_utils import (
18 | get_uncertain_point_coords_with_randomness,
19 | point_sample,
20 | )
21 |
22 |
23 | class SegmentationLosses(nn.Module):
24 | def __init__(
25 | self,
26 | num_points=12544,
27 | oversample_ratio=3,
28 | importance_sample_ratio=0.75,
29 | ignore_label=0,
30 | cost_mask=1.0,
31 | cost_class=1.0,
32 | temperature=1.0,
33 | ):
34 | super().__init__()
35 | self.num_points = num_points
36 | self.oversample_ratio = oversample_ratio
37 | self.importance_sample_ratio = importance_sample_ratio
38 | self.ignore_label = ignore_label
39 | self.temperature = temperature
40 | self.cost_mask = cost_mask
41 | self.cost_class = cost_class
42 | self.world_size = get_world_size()
43 |
44 | @torch.no_grad()
45 | def matcher(self, outputs, targets, pred_logits=None):
46 | """
47 | Matcher comes from Mask2Former: https://arxiv.org/abs/2112.01527
48 | This function is not used by default.
49 | """
50 |
51 | bs = len(outputs)
52 | num_queries = outputs.shape[1]
53 | indices = []
54 |
55 | for b in range(bs):
56 | out_mask = outputs[b] # [num_queries, H_pred, W_pred]
57 | tgt_mask = targets[b]['masks']
58 | if tgt_mask is None:
59 | indices.append(None)
60 | continue
61 |
62 | cost_class = 0
63 | if pred_logits is not None:
64 | tgt_ids = targets[b]["labels"]
65 | out_prob = pred_logits[b].softmax(-1) # [num_queries, num_classes]
66 | cost_class = -out_prob[:, tgt_ids]
67 | # cost_class = -out_prob.view(-1, 1)
68 |
69 | out_mask = out_mask[:, None]
70 | tgt_mask = tgt_mask[:, None]
71 | # all masks share the same set of points for efficient matching!
72 | point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device)
73 | # get gt labels
74 | tgt_mask = point_sample(
75 | tgt_mask,
76 | point_coords.repeat(tgt_mask.shape[0], 1, 1),
77 | align_corners=False,
78 | ).squeeze(1)
79 |
80 | out_mask = point_sample(
81 | out_mask,
82 | point_coords.repeat(out_mask.shape[0], 1, 1),
83 | align_corners=False,
84 | ).squeeze(1)
85 |
86 | with torch.cuda.amp.autocast(enabled=False):
87 | out_mask = out_mask.float()
88 | tgt_mask = tgt_mask.float()
89 | cost_mask = self.matcher_sigmoid_ce_loss(out_mask, tgt_mask)
90 | cost_dice = self.matcher_dice_loss(out_mask, tgt_mask)
91 |
92 | # Final cost matrix
93 | C = self.cost_mask * (cost_mask + cost_dice) + self.cost_class * cost_class
94 | C = C.reshape(num_queries, -1).cpu()
95 |
96 | indices.append(linear_sum_assignment(C))
97 |
98 | return [
99 | (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
100 | for i, j in indices
101 | ]
102 |
103 | @torch.no_grad()
104 | def _get_src_permutation_idx(self, indices):
105 | # permute predictions following indices
106 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
107 | src_idx = torch.cat([src for (src, _) in indices])
108 | return batch_idx, src_idx
109 |
110 | @torch.no_grad()
111 | def _get_tgt_permutation_idx(self, indices):
112 | # permute targets following indices
113 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
114 | tgt_idx = torch.cat([tgt for (_, tgt) in indices])
115 | return batch_idx, tgt_idx
116 |
117 | def loss_masks(
118 | self,
119 | outputs: torch.Tensor,
120 | targets: Dict,
121 | indices=None,
122 | ) -> torch.Tensor:
123 | """
124 | Uncertainty loss for instance segmentation as used in Mask2Former: https://arxiv.org/abs/2112.01527
125 | Only minor modifications (i.e., simplified matching ids for ground truth and filtering of empty masks)
126 | """
127 |
128 | # we first need to convert the targets to the format expected by the loss
129 | if indices is None:
130 | targets, indices = self.prepare_targets(targets, ignore_label=self.ignore_label)
131 |
132 | # filter empty masks
133 | masks = [t["masks"] for t in targets]
134 | valids = [m is not None for m in masks]
135 | outputs = outputs[valids]
136 | indices = [idx for idx, v in zip(indices, valids) if v]
137 | masks = [m for m in masks if m is not None]
138 |
139 | # skip if no masks in current batch
140 | num_masks = sum(len(m) for m in masks)
141 | if num_masks == 0:
142 | return outputs.sum() * 0.0
143 | num_masks = torch.as_tensor([num_masks], dtype=torch.float, device=outputs.device)
144 | if dist.is_available() and dist.is_initialized():
145 | torch.distributed.all_reduce(num_masks)
146 | num_masks = torch.clamp(num_masks / self.world_size, min=1).item()
147 |
148 | src_idx = self._get_src_permutation_idx(indices)
149 | src_masks = outputs[src_idx]
150 | target_masks = torch.cat([t[idx[1]] for t, idx in zip(masks, indices)])
151 |
152 | src_masks = src_masks[:, None]
153 | target_masks = target_masks[:, None]
154 | with torch.no_grad():
155 | # sample point_coords
156 | if self.oversample_ratio > 0:
157 | point_coords = get_uncertain_point_coords_with_randomness(
158 | src_masks,
159 | lambda logits: self.calculate_uncertainty(logits),
160 | self.num_points,
161 | self.oversample_ratio,
162 | self.importance_sample_ratio,
163 | )
164 | else:
165 | point_coords = torch.rand(src_masks.shape[0], self.num_points, 2, device=src_masks.device)
166 |
167 | # get gt labels
168 | point_labels = point_sample(
169 | target_masks,
170 | point_coords,
171 | align_corners=False,
172 | ).squeeze(1)
173 |
174 | point_logits = point_sample(
175 | src_masks,
176 | point_coords,
177 | align_corners=False,
178 | ).squeeze(1)
179 |
180 | del src_masks
181 | del target_masks
182 |
183 | loss_mask = self.sigmoid_ce_loss(point_logits, point_labels, num_masks)
184 | loss_dice = self.dice_loss(point_logits, point_labels, num_masks)
185 | return loss_mask + loss_dice
186 |
187 | def dice_loss(
188 | self,
189 | inputs: torch.Tensor,
190 | targets: torch.Tensor,
191 | num_masks: float,
192 | ):
193 | """
194 | Compute the DICE loss, similar to generalized IOU for masks
195 | Args:
196 | inputs: A float tensor of arbitrary shape.
197 | The predictions for each example.
198 | targets: A float tensor with the same shape as inputs. Stores the binary
199 | classification label for each element in inputs
200 | (0 for the negative class and 1 for the positive class).
201 | """
202 | inputs = inputs.sigmoid()
203 | inputs = inputs.flatten(1)
204 | numerator = 2 * (inputs * targets).sum(-1)
205 | denominator = inputs.sum(-1) + targets.sum(-1)
206 | loss = 1 - (numerator + 1) / (denominator + 1)
207 | return loss.sum() / num_masks
208 |
209 | def matcher_dice_loss(
210 | self,
211 | inputs: torch.Tensor,
212 | targets: torch.Tensor,
213 | ):
214 | """
215 | Compute the DICE loss, similar to generalized IOU for masks
216 | Args:
217 | inputs: A float tensor of arbitrary shape.
218 | The predictions for each example.
219 | targets: A float tensor with the same shape as inputs. Stores the binary
220 | classification label for each element in inputs
221 | (0 for the negative class and 1 for the positive class).
222 | """
223 | inputs = inputs.sigmoid()
224 | inputs = inputs.flatten(1)
225 | numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
226 | denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
227 | loss = 1 - (numerator + 1) / (denominator + 1)
228 | return loss
229 |
230 | def sigmoid_ce_loss(
231 | self,
232 | inputs: torch.Tensor,
233 | targets: torch.Tensor,
234 | num_masks: float,
235 | ):
236 | """
237 | Args:
238 | inputs: A float tensor of arbitrary shape.
239 | The predictions for each example.
240 | targets: A float tensor with the same shape as inputs. Stores the binary
241 | classification label for each element in inputs
242 | (0 for the negative class and 1 for the positive class).
243 | Returns:
244 | Loss tensor
245 | """
246 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
247 | return loss.mean(1).sum() / num_masks
248 |
249 | def matcher_sigmoid_ce_loss(
250 | self,
251 | inputs: torch.Tensor,
252 | targets: torch.Tensor,
253 | ):
254 | """
255 | Args:
256 | inputs: A float tensor of arbitrary shape.
257 | The predictions for each example.
258 | targets: A float tensor with the same shape as inputs. Stores the binary
259 | classification label for each element in inputs
260 | (0 for the negative class and 1 for the positive class).
261 | Returns:
262 | Loss tensor
263 | """
264 | hw = inputs.shape[1]
265 |
266 | pos = F.binary_cross_entropy_with_logits(
267 | inputs, torch.ones_like(inputs), reduction="none"
268 | )
269 | neg = F.binary_cross_entropy_with_logits(
270 | inputs, torch.zeros_like(inputs), reduction="none"
271 | )
272 |
273 | loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum(
274 | "nc,mc->nm", neg, (1 - targets)
275 | )
276 |
277 | return loss / hw
278 |
279 | def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
280 | """
281 | Calculates the uncertainty when using sigmoid loss.
282 | Defined according to PointRend: https://arxiv.org/abs/1912.08193
283 |
284 | Args:
285 | logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or
286 | class-agnostic, where R is the total number of predicted masks in all images and C is
287 | the number of foreground classes. The values are logits.
288 | Returns:
289 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with
290 | the most uncertain locations having the highest uncertainty score.
291 | """
292 | assert logits.shape[1] == 1
293 | gt_class_logits = logits.clone()
294 | return -(torch.abs(gt_class_logits))
295 |
296 | def calculate_uncertainty_seg(self, sem_seg_logits):
297 | """ Calculates the uncertainty when using a CE loss.
298 | Defined according to PointRend: https://arxiv.org/abs/1912.08193
299 | """
300 | top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
301 | return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1) # [B, 1, P]
302 |
303 | def loss_ce(
304 | self,
305 | outputs: torch.Tensor,
306 | targets: Dict,
307 | indices: Optional[torch.Tensor] = None,
308 | masks: Optional[torch.Tensor] = None,
309 | ) -> torch.Tensor:
310 |
311 | if indices is not None:
312 | src_masks = outputs
313 | masks = [t["masks"] for t in targets]
314 | src_idx = self._get_src_permutation_idx(indices)
315 | target_masks = torch.cat([t[idx[1]] for t, idx in zip(masks, indices)])
316 | target_masks = target_masks.bool()
317 |
318 | h, w = target_masks.shape[-2:]
319 | ce_targets = torch.full((len(targets), h, w), self.ignore_label, dtype=torch.int64, device=src_masks.device)
320 | for mask_idx, (x, y) in enumerate(zip(*src_idx)):
321 | mask_i = target_masks[mask_idx]
322 | ce_targets[x, mask_i] = y
323 | targets = ce_targets
324 |
325 | if masks is not None:
326 | targets[~masks[:, 0].bool()] = self.ignore_label
327 |
328 | with torch.no_grad():
329 | # sample point_coords
330 | if self.oversample_ratio > 0:
331 | point_coords = get_uncertain_point_coords_with_randomness(
332 | outputs,
333 | lambda logits: self.calculate_uncertainty_seg(logits),
334 | self.num_points,
335 | self.oversample_ratio,
336 | self.importance_sample_ratio,
337 | )
338 | else:
339 | point_coords = torch.rand(outputs.shape[0], self.num_points, 2, device=outputs.device)
340 |
341 | # get gt labels
342 | point_labels = point_sample(
343 | targets.unsqueeze(1).float(),
344 | point_coords,
345 | mode='nearest',
346 | align_corners=False,
347 | ).squeeze(1).to(torch.long)
348 |
349 | point_logits = point_sample(
350 | outputs,
351 | point_coords,
352 | align_corners=False,
353 | )
354 |
355 | ce_loss = F.cross_entropy(
356 | point_logits / self.temperature,
357 | point_labels,
358 | reduction="mean",
359 | ignore_index=self.ignore_label,
360 | )
361 |
362 | return ce_loss
363 |
364 | def point_loss(
365 | self,
366 | outputs: torch.Tensor,
367 | targets: torch.Tensor,
368 | do_matching: bool = False,
369 | masks: Optional[torch.Tensor] = None,
370 | ) -> Dict[str, torch.Tensor]:
371 | """
373 | We use point-sampled losses to quantify the quality of the reconstruction.
374 | The overall loss function consists of the following terms:
375 | - Cross-entropy loss with uncertainty-based point sampling
376 | - BCE + Dice loss with uncertainty-based point sampling
377 |
378 | Point sampling is based on the PointRend paper: https://arxiv.org/abs/1912.08193
379 | It combines a CE loss with a BCE + Dice loss.
380 | This also works with only a vanilla CE loss, but this implementation trains much faster.
380 | """
381 |
382 | indices = None
383 | if do_matching:
384 | targets, _ = self.prepare_targets(targets, ignore_label=self.ignore_label)
385 | indices = self.matcher(outputs, targets)
386 |
387 | # (1) ce loss on uncertain regions
388 | ce_loss = self.loss_ce(outputs, targets, indices=indices, masks=masks)
389 |
390 | # (2) bce + dice loss for uncertain regions per object mask
391 | mask_loss = self.loss_masks(outputs, targets, indices=indices)
392 | losses = {'ce': ce_loss, 'mask': mask_loss}
393 |
394 | return losses
395 |
396 | @torch.no_grad()
397 | def prepare_targets(
398 | self,
399 | targets,
400 | ignore_label=0,
401 | ):
402 |
403 | """
404 | Function to convert targets to the format expected by the loss
405 |
406 | Args:
407 | targets: list[Dict]
408 | ignore_label: int
409 | """
410 | new_targets = []
411 | instance_ids = []
412 |
413 | for idx_t, target in enumerate(targets):
414 | unique_classes = torch.unique(target)
415 | masks = []
416 |
417 | # 0. target masks
418 | for idx in unique_classes:
419 | if idx == ignore_label:
420 | continue
421 | binary_target = torch.where(target == idx, 1, 0).to(torch.float32)
422 | masks.append(binary_target)
423 |
424 | unique_classes_excl = unique_classes[unique_classes != ignore_label]
425 | instance_ids.append(
426 | (
427 | unique_classes_excl.cpu(),
428 | torch.arange(len(unique_classes_excl), dtype=torch.int64)
429 | )
430 | )
431 |
432 | new_targets.append(
433 | {
434 | "labels": torch.full(
435 | (len(masks),), 0, dtype=torch.int64, device=target.device) if len(masks) > 0 else None,
436 | "masks": torch.stack(masks) if len(masks) > 0 else None,
437 | }
438 | )
439 | return new_targets, instance_ids
440 |
--------------------------------------------------------------------------------
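A minimal sketch of how these point-sampled losses could be exercised on toy data. The shapes and class count are made up; the default ignore label 0 and the returned {'ce', 'mask'} dictionary follow the code above, and the full environment (including detectron2, which ldmseg.utils pulls in at import time) is assumed to be installed.

# Illustrative sketch only; toy shapes and data.
import torch
from ldmseg.trainers.losses import SegmentationLosses

criterion = SegmentationLosses(num_points=1024, ignore_label=0)

logits = torch.randn(2, 32, 64, 64, requires_grad=True)  # (B, channels/segments, H, W)
targets = torch.randint(0, 5, (2, 64, 64))               # integer segment ids; 0 is ignored

losses = criterion.point_loss(logits, targets)           # {'ce': ..., 'mask': ...}
(losses['ce'] + losses['mask']).backward()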
/ldmseg/trainers/optim.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | # File with optimizers for training
5 | # Parts of this file are based on detectron2 (Apache License): https://github.com/facebookresearch/detectron2
6 | """
7 |
8 | import torch
9 | from termcolor import colored
10 | from typing import Any, Dict
11 | from typing import List, Tuple, Optional, Callable, Set
12 | import copy
13 | from collections import defaultdict
14 |
15 |
16 | def get_optim(
17 | model: torch.nn.Module,
18 | base_lr: float = 0.0001,
19 | weight_decay: float = 0.0,
20 | weight_decay_norm: float = 0.0,
21 | betas: Tuple = (0.9, 0.999),
22 | lr_factor_func: Optional[Callable] = None,
23 | zero_redundancy: bool = False,
24 | verbose: bool = True,
25 | save_optim: Optional[bool] = None,
26 | ) -> Tuple[torch.optim.Optimizer, bool]:
27 |
28 | # TODO: make `overrides` an optional argument
29 |
30 | ret = get_optimizer_params(
31 | model, weight_decay=None, weight_decay_norm=weight_decay_norm, base_lr=base_lr,
32 | lr_factor_func=lr_factor_func,
33 | overrides={
34 | 'module.backbone.image_encoder.net.pos_embed': {'weight_decay': 0.0},
35 | 'module.backbone_image.image_encoder.net.pos_embed': {'weight_decay': 0.0},
36 | },
37 | verbose=verbose,
38 | )
39 | if zero_redundancy:
40 | from torch.distributed.optim import ZeroRedundancyOptimizer
41 |
42 | optim = ZeroRedundancyOptimizer(
43 | ret, optimizer_class=torch.optim.AdamW,
44 | lr=base_lr, weight_decay=weight_decay, betas=betas,
45 | )
46 | save_optim = False if save_optim is None else save_optim
47 | else:
48 | optim = torch.optim.AdamW(ret, lr=base_lr, weight_decay=weight_decay, betas=betas)
49 | save_optim = True if save_optim is None else save_optim
50 | return optim, save_optim
51 |
52 |
53 | def get_optim_unet(
54 | model: torch.nn.Module,
55 | base_lr: float = 0.0001,
56 | weight_decay: float = 0.0,
57 | weight_decay_norm: float = 0.0,
58 | betas: Tuple = (0.9, 0.999),
59 | lr_factor_func: Optional[Callable] = None,
60 | zero_redundancy: bool = False,
61 | verbose: bool = True,
62 | save_optim: Optional[bool] = None,
63 | ) -> Tuple[torch.optim.Optimizer, bool]:
64 |
65 | ret = get_optimizer_params(
66 | model, weight_decay=None, weight_decay_norm=weight_decay_norm, base_lr=base_lr,
67 | lr_factor_func=lr_factor_func,
68 | overrides={'module.object_queries.weight': {'weight_decay': 0.0}},
69 | verbose=verbose,
70 | )
71 | if zero_redundancy:
72 | from torch.distributed.optim import ZeroRedundancyOptimizer
73 |
74 | optim = ZeroRedundancyOptimizer(
75 | ret, optimizer_class=torch.optim.AdamW,
76 | lr=base_lr, weight_decay=weight_decay, betas=betas,
77 | )
78 | save_optim = False if save_optim is None else save_optim
79 | else:
80 | optim = torch.optim.AdamW(ret, lr=base_lr, weight_decay=weight_decay, betas=betas)
81 | save_optim = True if save_optim is None else save_optim
82 | return optim, save_optim
83 |
84 |
85 | def get_optim_general(
86 | params: Dict[str, Any],
87 | name: str,
88 | p: Dict[str, Any],
89 | zero_redundancy: bool = False,
90 | save_optim: Optional[bool] = None,
91 | ) -> Tuple[torch.optim.Optimizer, bool]:
92 |
93 | """ Returns the optimizer to be used for training
94 | """
95 | if name != 'adamw8bit' and not zero_redundancy:
96 | print(colored('Most likely not enough memory for training', 'red'))
97 | print(colored(f'Please use zero redundancy optimizer with {name}', 'red'))
98 |
99 | if 'weight_decay_norm' in p.keys():
100 | del p['weight_decay_norm']
101 |
102 | if zero_redundancy:
103 | # reinitialize optimizer with ZeroRedundancyOptimizer
104 | from torch.distributed.optim import ZeroRedundancyOptimizer
105 | if name == 'adamw':
106 | optim_cls = torch.optim.AdamW
107 | elif name == 'adamw8bit':
108 | import bitsandbytes as bnb
109 | optim_cls = bnb.optim.AdamW8bit
110 | elif name == 'sgd':
111 | optim_cls = torch.optim.SGD
112 | if 'momentum' not in p.keys():
113 | p['momentum'] = 0.9
114 | if 'betas' in p.keys():
115 | del p['betas']
116 | else:
117 | raise NotImplementedError(f'Optimizer {name} not implemented with zero redundancy')
118 |
119 | optimizer = ZeroRedundancyOptimizer(
120 | params,
121 | optimizer_class=optim_cls,
122 | **p,
123 | )
124 | save_optim = False if save_optim is None else save_optim
125 | print(colored('Using ZeroRedundancyOptimizer w/o storing optim to save memory', 'red'))
126 | return optimizer, save_optim
127 |
128 | if name == 'adamw':
129 | optimizer = torch.optim.AdamW(params, **p)
130 | elif name == 'adamw8bit':
131 | import bitsandbytes as bnb
132 | optimizer = bnb.optim.AdamW8bit(params, **p)
133 | elif name == 'adam':
134 | optimizer = torch.optim.Adam(params, **p)
135 | elif name == 'sgd':
136 | if 'momentum' not in p.keys():
137 | p['momentum'] = 0.9
138 | if 'betas' in p.keys():
139 | del p['betas']
140 | optimizer = torch.optim.SGD(params, **p)
141 | else:
142 | raise NotImplementedError(f'Unknown optimizer {name}')
143 |
144 | save_optim = True if save_optim is None else save_optim
145 | return optimizer, save_optim
146 |
147 |
148 | def get_optimizer_params(
149 | model: torch.nn.Module,
150 | base_lr: Optional[float] = None,
151 | weight_decay: Optional[float] = None,
152 | weight_decay_norm: Optional[float] = None,
153 | bias_lr_factor: Optional[float] = 1.0,
154 | weight_decay_bias: Optional[float] = None,
155 | lr_factor_func: Optional[Callable] = None,
156 | overrides: Optional[Dict[str, Dict[str, float]]] = None,
157 | verbose: bool = False,
158 | ) -> List[Dict[str, Any]]:
159 |
160 | # Based on the implementation from detectron2
161 | # TODO: clean this up
162 |
163 | if overrides is None:
164 | overrides = {}
165 | defaults = {}
166 | if base_lr is not None:
167 | defaults["lr"] = base_lr
168 | if weight_decay is not None:
169 | defaults["weight_decay"] = weight_decay
170 | bias_overrides = {}
171 | if bias_lr_factor is not None and bias_lr_factor != 1.0:
172 | if base_lr is None:
173 | raise ValueError("bias_lr_factor requires base_lr")
174 | bias_overrides["lr"] = base_lr * bias_lr_factor
175 | if weight_decay_bias is not None:
176 | bias_overrides["weight_decay"] = weight_decay_bias
177 | if len(bias_overrides):
178 | if "bias" in overrides:
179 | raise ValueError("Conflicting overrides for 'bias'")
180 | overrides["bias"] = bias_overrides
181 | if lr_factor_func is not None:
182 | if base_lr is None:
183 | raise ValueError("lr_factor_func requires base_lr")
184 | norm_module_types = (
185 | torch.nn.BatchNorm1d,
186 | torch.nn.BatchNorm2d,
187 | torch.nn.BatchNorm3d,
188 | torch.nn.SyncBatchNorm,
189 | torch.nn.GroupNorm,
190 | torch.nn.InstanceNorm1d,
191 | torch.nn.InstanceNorm2d,
192 | torch.nn.InstanceNorm3d,
193 | torch.nn.LayerNorm,
194 | torch.nn.LocalResponseNorm,
195 | )
196 | params: List[Dict[str, Any]] = []
197 | memo: Set[torch.nn.parameter.Parameter] = set()
198 | for module_name, module in model.named_modules():
199 | for module_param_name, value in module.named_parameters(recurse=False):
200 | if not value.requires_grad:
201 | continue
202 | # Avoid duplicating parameters
203 | if value in memo:
204 | continue
205 | memo.add(value)
206 |
207 | hyperparams = copy.copy(defaults)
208 | if isinstance(module, norm_module_types) and weight_decay_norm is not None:
209 | hyperparams["weight_decay"] = weight_decay_norm
210 | if lr_factor_func is not None:
211 | hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}")
212 |
213 | hyperparams.update(overrides.get(f"{module_name}.{module_param_name}", {}))
214 | params.append({"params": [value], **hyperparams})
215 | if verbose:
216 | print(f'Adding {module_name}.{module_param_name} to optimizer with {hyperparams}')
217 | return reduce_param_groups(params)
218 |
219 |
220 | def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
221 | # copied from detectron2
222 | ret = defaultdict(dict)
223 | for item in params:
224 | assert "params" in item
225 | cur_params = {x: y for x, y in item.items() if x != "params"}
226 | for param in item["params"]:
227 | ret[param].update({"params": [param], **cur_params})
228 | return list(ret.values())
229 |
230 |
231 | def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
232 | # copied from detectron2
233 | params = _expand_param_groups(params)
234 | groups = defaultdict(list) # re-group all parameter groups by their hyperparams
235 | for item in params:
236 | cur_params = tuple((x, y) for x, y in item.items() if x != "params")
237 | groups[cur_params].extend(item["params"])
238 | ret = []
239 | for param_keys, param_values in groups.items():
240 | cur = {kv[0]: kv[1] for kv in param_keys}
241 | cur["params"] = param_values
242 | ret.append(cur)
243 | return ret
244 |
--------------------------------------------------------------------------------
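To illustrate how the helpers above fit together, a toy sketch with an arbitrary model and hyperparameters (the full ldmseg environment is assumed to be installed): `get_optimizer_params` builds per-parameter groups, which `get_optim_general` then hands to the chosen optimizer.

# Illustrative sketch only; toy model and hyperparameters.
import torch.nn as nn
from ldmseg.trainers.optim import get_optimizer_params, get_optim_general

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(4, 8), nn.Conv2d(8, 8, 3))
param_groups = get_optimizer_params(model, base_lr=1e-4, weight_decay_norm=0.0)

# the helper prints a memory warning that only matters for full-scale runs
optimizer, save_optim = get_optim_general(
    param_groups, name='adamw',
    p={'lr': 1e-4, 'weight_decay': 0.05, 'betas': (0.9, 0.999)},
)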
/ldmseg/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .config import prepare_config
3 |
--------------------------------------------------------------------------------
/ldmseg/utils/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | Functions to handle configuration files
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import os
9 | import yaml
10 | import errno
11 | from easydict import EasyDict
12 |
13 |
14 | def mkdir_if_missing(directory):
15 | if not os.path.exists(directory):
16 | try:
17 | os.makedirs(directory)
18 | except OSError as e:
19 | if e.errno != errno.EEXIST:
20 | raise
21 |
22 |
23 | def load_config(config_file_exp):
24 | with open(config_file_exp, 'r') as stream:
25 | config = yaml.safe_load(stream)
26 |
27 | cfg = EasyDict()
28 |
29 | for k, v in config.items():
30 | cfg[k] = v
31 |
32 | return cfg
33 |
34 |
35 | def create_config(config_file_env, config_file_exp, run_idx=None):
36 | # Config for environment path
37 | with open(config_file_env, 'r') as stream:
38 | root_dir = yaml.safe_load(stream)['root_dir']
39 |
40 | with open(config_file_exp, 'r') as stream:
41 | config = yaml.safe_load(stream)
42 |
43 | cfg = EasyDict()
44 |
45 | # Copy
46 | for k, v in config.items():
47 | cfg[k] = v
48 |
49 | # Check the dataset name
50 | if cfg['train_db_name'].lower() not in ['coco', 'coco_panoptic', 'cityscapes', 'cityscapes+coco']:
51 | raise ValueError('Invalid train db name {}'.format(cfg['train_db_name']))
52 |
53 | # Paths
54 | subdir = os.path.join(root_dir, config_file_exp.split('/')[-2])
55 | mkdir_if_missing(subdir)
56 | output_dir = os.path.join(subdir, os.path.basename(config_file_exp).split('.')[0])
57 | mkdir_if_missing(output_dir)
58 |
59 | if run_idx is not None:
60 | output_dir = os.path.join(output_dir, 'run_{}'.format(run_idx))
61 | mkdir_if_missing(output_dir)
62 |
63 | cfg['output_dir'] = output_dir
64 | cfg['checkpoint'] = os.path.join(cfg['output_dir'], 'checkpoint.pth.tar')
65 | cfg['best_model'] = os.path.join(cfg['output_dir'], 'best_model.pth.tar')
66 | cfg['save_dir'] = os.path.join(cfg['output_dir'], 'predictions')
67 | mkdir_if_missing(cfg['save_dir'])
68 | cfg['log_file'] = os.path.join(cfg['output_dir'], 'logger.txt')
69 |
70 | return cfg
71 |
72 |
73 | def prepare_config(cfg, root_dir, data_dir='', run_idx=None):
74 | # Check the dataset name
75 | if cfg['train_db_name'].lower() not in ['coco', 'coco_panoptic', 'cityscapes', 'cityscapes+coco']:
76 | raise ValueError('Invalid train db name {}'.format(cfg['train_db_name']))
77 |
78 | # Paths
79 | output_dir = os.path.join(root_dir, cfg['train_db_name'])
80 | mkdir_if_missing(output_dir)
81 |
82 | # create a unique identifier for the run based on the current time
83 | if isinstance(run_idx, int) and run_idx < 0:
84 | from datetime import datetime
85 | from termcolor import colored
86 | now = datetime.now()
87 | run_idx = now.strftime("%Y%m%d_%H%M%S")
88 | print(colored('Using current time as run identifier: {}'.format(run_idx), 'red'))
89 | output_dir = os.path.join(output_dir, 'run_{}'.format(run_idx))
90 | mkdir_if_missing(output_dir)
91 |
92 | cfg['data_dir'] = data_dir
93 | cfg['output_dir'] = output_dir
94 | cfg['save_dir'] = os.path.join(cfg['output_dir'], 'predictions')
95 | mkdir_if_missing(cfg['save_dir'])
96 | cfg['log_file'] = os.path.join(cfg['output_dir'], 'logger.txt')
97 |
98 | return cfg, run_idx
99 |
--------------------------------------------------------------------------------
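A small sketch of `prepare_config` on a hand-written config; the yaml snippet and the /tmp paths are placeholders (the real entry points under tools/ load the files in tools/configs), and importing ldmseg.utils assumes detectron2 is installed.

# Illustrative sketch only; placeholder config and paths.
import yaml
from easydict import EasyDict
from ldmseg.utils.config import prepare_config

cfg = EasyDict(yaml.safe_load('train_db_name: coco'))
cfg, run_idx = prepare_config(cfg, root_dir='/tmp/ldmseg_outputs', data_dir='/tmp/data', run_idx=0)
print(cfg['output_dir'])   # /tmp/ldmseg_outputs/coco/run_0, with predictions/ and logger.txt paths set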
/ldmseg/utils/detectron2_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | import torch.nn.functional as F
4 | # directly copied from detectron2 (prevents importing detectron2 here)
5 |
6 |
7 | def cat(tensors: List[torch.Tensor], dim: int = 0):
8 | """
9 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list
10 | """
11 | assert isinstance(tensors, (list, tuple))
12 | if len(tensors) == 1:
13 | return tensors[0]
14 | return torch.cat(tensors, dim)
15 |
16 |
17 | def get_uncertain_point_coords_with_randomness(
18 | coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
19 | ):
20 | """
21 | Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty. The uncertainties
22 | are calculated for each point using 'uncertainty_func' function that takes point's logit
23 | prediction as input.
24 | See PointRend paper for details.
25 |
26 | Args:
27 | coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
28 | class-specific or class-agnostic prediction.
29 | uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
30 | contains logit predictions for P points and returns their uncertainties as a Tensor of
31 | shape (N, 1, P).
32 | num_points (int): The number of points P to sample.
33 | oversample_ratio (int): Oversampling parameter.
34 | importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.
35 |
36 | Returns:
37 | point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
38 | sampled points.
39 | """
40 | assert oversample_ratio >= 1
41 | assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
42 | num_boxes = coarse_logits.shape[0]
43 | num_sampled = int(num_points * oversample_ratio)
44 | point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
45 | point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
46 | # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
47 | # Calculating uncertainties of the coarse predictions first and sampling them for points leads
48 | # to incorrect results.
49 | # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
50 | # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
51 | # However, if we calculate uncertainties for the coarse predictions first,
52 | # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
53 | point_uncertainties = uncertainty_func(point_logits)
54 | num_uncertain_points = int(importance_sample_ratio * num_points)
55 | num_random_points = num_points - num_uncertain_points
56 | idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
57 | shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
58 | idx += shift[:, None]
59 | point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
60 | num_boxes, num_uncertain_points, 2
61 | )
62 | if num_random_points > 0:
63 | point_coords = cat(
64 | [
65 | point_coords,
66 | torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
67 | ],
68 | dim=1,
69 | )
70 | return point_coords
71 |
72 |
73 | def point_sample(input, point_coords, **kwargs):
74 | """
75 | A wrapper around :function:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
76 | Unlike :function:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
77 | [0, 1] x [0, 1] square.
78 |
79 | Args:
80 | input (Tensor): A tensor of shape (N, C, H, W) that contains features map on a H x W grid.
81 | point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
82 | [0, 1] x [0, 1] normalized point coordinates.
83 |
84 | Returns:
85 | output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
86 | features for points in `point_coords`. The features are obtained via bilinear
87 | interpolation from `input` the same way as :function:`torch.nn.functional.grid_sample`.
88 | """
89 | add_dim = False
90 | if point_coords.dim() == 3:
91 | add_dim = True
92 | point_coords = point_coords.unsqueeze(2)
93 | output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
94 | if add_dim:
95 | output = output.squeeze(3)
96 | return output
97 |
98 |
--------------------------------------------------------------------------------
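A toy sketch of the two sampling utilities above (shapes are arbitrary; `point_sample` expects normalized coordinates in [0, 1] x [0, 1], and importing ldmseg.utils assumes detectron2 is installed):

# Illustrative sketch only; toy shapes.
import torch
from ldmseg.utils.detectron2_utils import point_sample, get_uncertain_point_coords_with_randomness

feats = torch.randn(1, 2, 32, 32)                           # (N, C, H, W)
coords = torch.rand(1, 16, 2)                               # (N, P, 2) normalized coordinates
vals = point_sample(feats, coords, align_corners=False)     # -> (1, 2, 16)

mask_logits = torch.randn(4, 1, 32, 32)                     # class-agnostic mask logits
uncertainty = lambda logits: -logits.abs()                  # small |logit| = uncertain
pts = get_uncertain_point_coords_with_randomness(mask_logits, uncertainty, 16, 3, 0.75)
print(vals.shape, pts.shape)                                # (1, 2, 16) and (4, 16, 2)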
/ldmseg/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | File with helper functions
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 |
7 | """
8 |
9 | import os
10 | import sys
11 | import errno
12 | from typing import Optional
13 | from collections import OrderedDict
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import torch.distributed as dist
19 |
20 | from detectron2.utils.visualizer import (
21 | Visualizer, _PanopticPrediction,
22 | ColorMode, _OFF_WHITE, _create_text_labels
23 | )
24 |
25 |
26 | class OutputDict(OrderedDict):
27 | # TODO: change to NamedTuple
28 | # keep setitem and setattr in sync
29 | def __setitem__(self, key, value):
30 | super().__setitem__(key, value)
31 | super().__setattr__(key, value)
32 |
33 |
34 | def mkdir_if_missing(directory):
35 | if not os.path.exists(directory):
36 | try:
37 | os.makedirs(directory)
38 | except OSError as e:
39 | if e.errno != errno.EEXIST:
40 | raise
41 |
42 |
43 | # helper functions for distributed training
44 | def is_dist_avail_and_initialized() -> bool:
45 | if not dist.is_available():
46 | return False
47 | if not dist.is_initialized():
48 | return False
49 | return True
50 |
51 |
52 | def get_world_size() -> int:
53 | if not is_dist_avail_and_initialized():
54 | return 1
55 | return dist.get_world_size()
56 |
57 |
58 | def get_rank() -> int:
59 | if not is_dist_avail_and_initialized():
60 | return 0
61 | return dist.get_rank()
62 |
63 |
64 | def is_main_process() -> bool:
65 | return get_rank() == 0
66 |
67 |
68 | def has_batchnorms(model: nn.Module):
69 | bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
70 | for _, module in model.named_modules():
71 | if isinstance(module, bn_types):
72 | return True
73 | return False
74 |
75 |
76 | def gpu_gather(tensor: torch.Tensor) -> torch.Tensor:
77 | if tensor.ndim == 0:
78 | tensor = tensor.clone()[None]
79 | output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
80 | dist.all_gather(output_tensors, tensor)
81 | return torch.cat(output_tensors, dim=0)
82 |
83 |
84 | def cosine_scheduler(
85 | base_value: float,
86 | final_value: float,
87 | epochs: int,
88 | niter_per_ep: int,
89 | start_warmup_value: int = 0,
90 | warmup_iters: Optional[int] = None,
91 | ) -> np.ndarray:
92 | """ Cosine scheduler with warmup.
93 | """
94 |
95 | warmup_schedule = np.array([])
96 | if warmup_iters is None:
97 | warmup_iters = 0
98 | if warmup_iters > 0:
99 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
100 |
101 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
102 | schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
103 |
104 | schedule = np.concatenate((warmup_schedule, schedule))
105 | assert len(schedule) == epochs * niter_per_ep
106 | return schedule
107 |
108 |
109 | def warmup_scheduler(
110 | base_value: float,
111 | final_value: float,
112 | epochs: int,
113 | niter_per_ep: int,
114 | start_warmup_value: int = 0,
115 | warmup_iters: Optional[int] = None,
116 | ) -> np.ndarray:
117 | """ Linear warmup scheduler.
118 | """
119 |
120 | warmup_schedule = np.array([])
121 | if warmup_iters is None:
122 | warmup_iters = 0
123 | if warmup_iters > 0:
124 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
125 |
126 | schedule = np.ones(epochs * niter_per_ep - warmup_iters) * base_value
127 | schedule = np.concatenate((warmup_schedule, schedule))
128 | assert len(schedule) == epochs * niter_per_ep
129 | return schedule
130 |
131 |
132 | def step_scheduler(
133 | base_value: float,
134 | final_value: float,
135 | epochs: int,
136 | niter_per_ep: int,
137 | decay_epochs: list = [20, 40],
138 | decay_rate: float = 0.1,
139 | start_warmup_value: int = 0,
140 | warmup_iters: Optional[int] = None,
141 | ) -> np.ndarray:
142 | """ Step scheduler with warmup.
143 | """
144 | assert isinstance(decay_epochs, list), "decay_epochs must be a list"
145 |
146 | warmup_schedule = np.array([])
147 | if warmup_iters is None:
148 | warmup_iters = 0
149 | if warmup_iters > 0:
150 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
151 |
152 | schedule = np.ones(epochs * niter_per_ep - warmup_iters) * base_value
153 | for decay_epoch in decay_epochs:
154 | schedule[int(decay_epoch * niter_per_ep - warmup_iters):] *= decay_rate
155 | schedule = np.concatenate((warmup_schedule, schedule))
156 | assert len(schedule) == epochs * niter_per_ep
157 | return schedule
158 |
159 |
160 | class AverageMeter(object):
161 | """Computes and stores the average and current value"""
162 |
163 | def __init__(self, name, fmt=":f"):
164 | self.name = name
165 | self.fmt = fmt
166 | self.reset()
167 |
168 | def reset(self):
169 | self.val = 0
170 | self.avg = 0
171 | self.sum = 0
172 | self.count = 0
173 |
174 | def update(self, val, n=1):
175 | self.val = val
176 | self.sum += val * n
177 | self.count += n
178 | self.avg = self.sum / self.count
179 |
180 | def __str__(self):
181 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
182 | return fmtstr.format(**self.__dict__)
183 |
184 |
185 | class ProgressMeter(object):
186 | def __init__(self, num_batches, meters, prefix=""):
187 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
188 | self.meters = meters
189 | self.prefix = prefix
190 |
191 | def display(self, batch):
192 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
193 | entries += [str(meter) for meter in self.meters]
194 | print("\t".join(entries))
195 |
196 | def _get_batch_fmtstr(self, num_batches):
197 | num_digits = len(str(num_batches // 1))
198 | fmt = "{:" + str(num_digits) + "d}"
199 | return "[" + fmt + "/" + fmt.format(num_batches) + "]"
200 |
201 |
202 | class Logger(object):
203 | def __init__(self, fpath=None):
204 | self.console = sys.stdout
205 | self.file = None
206 | self.fpath = fpath
207 | if fpath is not None:
208 | if not os.path.exists(os.path.dirname(fpath)):
209 | os.makedirs(os.path.dirname(fpath))
210 | # append to an existing log file, otherwise create a new one
211 | mode = 'a' if os.path.exists(fpath) else 'w'
212 | self.file = open(fpath, mode)
213 |
214 | def __del__(self):
215 | self.close()
216 |
217 | def __enter__(self):
218 | pass
219 |
220 | def __exit__(self, *args):
221 | self.close()
222 |
223 | def write(self, msg):
224 | self.console.write(msg)
225 | if self.file is not None:
226 | self.file.write(msg)
227 |
228 | def flush(self):
229 | self.console.flush()
230 | if self.file is not None:
231 | self.file.flush()
232 | os.fsync(self.file.fileno())
233 |
234 | def close(self):
235 | self.console.close()
236 | if self.file is not None:
237 | self.file.close()
238 |
239 |
240 | def color_map(N: int = 256, normalized: bool = False):
241 | def bitget(byteval, idx):
242 | return ((byteval & (1 << idx)) != 0)
243 |
244 | dtype = 'float32' if normalized else 'uint8'
245 | cmap = np.zeros((N, 3), dtype=dtype)
246 | for i in range(N):
247 | r = g = b = 0
248 | c = i
249 | for j in range(8):
250 | r = r | (bitget(c, 0) << 7-j)
251 | g = g | (bitget(c, 1) << 7-j)
252 | b = b | (bitget(c, 2) << 7-j)
253 | c = c >> 3
254 |
255 | cmap[i] = np.array([r, g, b])
256 |
257 | cmap = cmap/255 if normalized else cmap
258 | return cmap
259 |
260 |
261 | def collate_fn(batch: dict):
262 | # TODO: make general
263 | images = torch.stack([d['image'] for d in batch])
264 | semseg = torch.stack([d['semseg'] for d in batch])
265 | image_semseg = torch.stack([d['image_semseg'] for d in batch])
266 | tokens = mask = inpainting_mask = text = meta = None
267 | if 'tokens' in batch[0]:
268 | tokens = torch.stack([d['tokens'] for d in batch])
269 | if 'mask' in batch[0]:
270 | mask = torch.stack([d['mask'] for d in batch])
271 | if 'inpainting_mask' in batch[0]:
272 | inpainting_mask = torch.stack([d['inpainting_mask'] for d in batch])
273 | if 'text' in batch[0]:
274 | text = [d['text'] for d in batch]
275 | if 'meta' in batch[0]:
276 | meta = [d['meta'] for d in batch]
277 | return {
278 | 'image': images,
279 | 'semseg': semseg,
280 | 'meta': meta,
281 | 'text': text,
282 | 'tokens': tokens,
283 | 'mask': mask,
284 | 'inpainting_mask': inpainting_mask,
285 | 'image_semseg': image_semseg
286 | }
287 |
288 |
289 | class MyVisualizer(Visualizer):
290 | def draw_panoptic_seg(
291 | self,
292 | panoptic_seg,
293 | segments_info,
294 | area_threshold=None,
295 | alpha=0.7,
296 | random_colors=False,
297 | suppress_thing_labels=False,
298 | ):
299 | """
300 | Only minor changes to the original function from detectron2.utils.visualizer
301 | """
302 |
303 | pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata)
304 |
305 | if self._instance_mode == ColorMode.IMAGE_BW:
306 | self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask()))
307 |
308 | # draw mask for all semantic segments first i.e. "stuff"
309 | for mask, sinfo in pred.semantic_masks():
310 | category_idx = sinfo["category_id"]
311 | try:
312 | mask_color = [x / 255 for x in self.metadata["stuff_colors"][category_idx]]
313 | except AttributeError:
314 | mask_color = None
315 |
316 | text = self.metadata["stuff_classes"][category_idx]
317 | self.draw_binary_mask(
318 | mask,
319 | color=mask_color,
320 | edge_color=_OFF_WHITE,
321 | text=text,
322 | alpha=alpha,
323 | area_threshold=area_threshold,
324 | )
325 |
326 | # draw mask for all instances second
327 | all_instances = list(pred.instance_masks())
328 | if len(all_instances) == 0:
329 | return self.output
330 | masks, sinfo = list(zip(*all_instances))
331 | category_ids = [x["category_id"] for x in sinfo]
332 |
333 | try:
334 | scores = [x["score"] for x in sinfo]
335 | except KeyError:
336 | scores = None
337 | class_names = self.metadata["thing_classes"] if not suppress_thing_labels else ["object"] * 2
338 | labels = _create_text_labels(
339 | category_ids, scores, class_names, [x.get("iscrowd", 0) for x in sinfo]
340 | )
341 |
342 | try:
343 | colors = [
344 | self._jitter([x / 255 for x in self.metadata["thing_colors"][c]]) for c in category_ids
345 | ]
346 | except AttributeError:
347 | colors = None
348 |
349 | if random_colors:
350 | colors = None
351 | self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha)
352 |
353 | return self.output
354 |
355 |
356 | def get_imagenet_stats():
357 | stats = {
358 | 'mean': torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(1, 3, 1, 1),
359 | 'std': torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(1, 3, 1, 1),
360 | 'mean_clip': torch.tensor([0.48145466, 0.4578275, 0.40821073], dtype=torch.float32).view(1, 3, 1, 1),
361 | 'std_clip': torch.tensor([0.26862954, 0.26130258, 0.27577711], dtype=torch.float32).view(1, 3, 1, 1),
362 | }
363 | return stats
364 |
--------------------------------------------------------------------------------
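Usage note (annotation, not part of the repository): collate_fn and get_imagenet_stats above are the pieces a PyTorch DataLoader needs. The sketch below wires them together with a hypothetical stand-in dataset (DummySegDataset), since the real COCO dataset class lives in ldmseg/data/coco.py and its constructor is not shown here; the import path is assumed from the file tree, and the ImageNet normalization assumes float RGB tensors in [0, 1].

import torch
from torch.utils.data import DataLoader, Dataset

# assumed import path for the helpers defined in the module above
from ldmseg.utils.utils import collate_fn, get_imagenet_stats


class DummySegDataset(Dataset):
    """Stand-in dataset; the real one lives in ldmseg/data/coco.py."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {
            'image': torch.rand(3, 512, 512),                   # RGB in [0, 1]
            'semseg': torch.zeros(512, 512, dtype=torch.long),  # segmentation ids
            'image_semseg': torch.rand(3, 512, 512),            # RGB rendering of the segmentation
        }


loader = DataLoader(DummySegDataset(), batch_size=2, collate_fn=collate_fn)
stats = get_imagenet_stats()

batch = next(iter(loader))
images = (batch['image'] - stats['mean']) / stats['std']  # ImageNet normalization
print(images.shape)                                       # torch.Size([2, 3, 512, 512])

--------------------------------------------------------------------------------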
/ldmseg/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.0.0'
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from ldmseg.version import __version__
3 |
4 | setup(
5 | name='ldmseg',
6 | packages=find_packages(),
7 | version=__version__,
8 | license='Attribution-NonCommercial 4.0 International',
9 | description='A simple latent diffusion framework for panoptic segmentation and mask inpainting',
10 | author='Wouter Van Gansbeke',
11 | url='https://github.com/segments-ai/latent-diffusion-segmentation',
12 | long_description_content_type='text/markdown',
13 | keywords=[
14 | 'artificial intelligence',
15 | 'computer vision',
16 | 'generative models',
17 | 'segmentation',
18 | ],
19 | install_requires=[
20 | 'torch',
21 | 'torchvision',
22 | 'einops',
23 | 'diffusers',
24 | 'transformers',
25 | 'xformers',
26 | 'accelerate',
27 | 'timm',
28 | 'scipy',
29 | 'opencv-python',
30 | 'pyyaml',
31 | 'easydict',
32 | 'hydra-core',
33 | 'termcolor',
34 | 'wandb',
35 | ],
36 | classifiers=[
37 | 'Development Status :: 4 - Beta',
38 | 'Intended Audience :: Developers',
39 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
40 | 'Programming Language :: Python :: 3.11',
41 | ],
42 | )
43 |
--------------------------------------------------------------------------------
/tools/configs/base/base.yaml:
--------------------------------------------------------------------------------
1 | # pretrained path
2 | pretrained_model_path: "CompVis/stable-diffusion-v1-4"
3 |
4 | # Logging
5 | wandb: False
6 |
7 | # Evaluate only
8 | eval_only: False
9 | load_path:
10 |
11 | # vae model path and kwargs
12 | image_scaling_factor: 0.18215 # standard SD scaling factor
13 | shared_vae_encoder: False
14 | vae_model_kwargs:
15 | in_channels: 7 # 7 input channels when using bit encoding (cf. num_bits below)
16 | int_channels: 256
17 | out_channels: 128
18 | block_out_channels: [32, 64, 128, 256]
19 | latent_channels: 4
20 | num_latents: 2
21 | num_upscalers: 2
22 | upscale_channels: 256
23 | norm_num_groups: 32
24 | scaling_factor: 0.2
25 | parametrization: 'gaussian'
26 | act_fn: 'none'
27 | clamp_output: False
28 | freeze_codebook: False
29 | num_mid_blocks: 0
30 | fuse_rgb: False
31 | resize_input: False
32 | skip_encoder: False
33 | pretrained_path:
34 |
35 | # Model
36 | backbone: unet
37 | model_kwargs:
38 | in_channels: 8
39 | init_mode_seg: copy
40 | init_mode_image: zero
41 | cond_channels: 0
42 | separate_conv: False
43 | separate_encoder: False
44 | add_adaptor: False
45 | init_mode_adaptor: random
46 |
47 | # Noise scheduler
48 | noise_scheduler_kwargs:
49 | prediction_type: epsilon
50 | beta_schedule: scaled_linear
51 | num_train_timesteps: 1000
52 | beta_start: 0.00085
53 | beta_end: 0.012
54 | steps_offset: 1
55 | clip_sample: False
56 | set_alpha_to_one: False
57 | thresholding: False
58 | dynamic_thresholding_ratio: 0.995
59 | clip_sample_range: 1.0
60 | sample_max_value: 1.0
61 | weight: none
62 | max_snr: 5.0
63 |
64 | # Training parameters
65 | train_kwargs:
66 | dropout: 0.0
67 | inpaint_mask_size: [64, 64]
68 | type_mask: ignore
69 | latent_mask: False
70 | encoding_mode: bits
71 | image_descriptors: remove
72 | caption_type: none
73 | caption_dropout: 1.0 # always drop captions by default
74 | prob_train_on_pred: 0.0
75 | prob_inpainting: 0.0
76 | min_noise_level: 0
77 | rgb_noise_level: 0
78 | cond_noise_level: 0
79 | self_condition: False
80 | sample_posterior: False
81 | sample_posterior_rgb: False
82 | remap_seg: True
83 | train_num_steps: 35000
84 | batch_size: 8
85 | accumulate: 1
86 | num_workers: 2
87 | loss: l2
88 | ohem_ratio: 1.
89 | cudnn: False
90 | fp16: False
91 | weight_dtype: float32
92 | use_xformers: False
93 | gradient_as_bucket_view: False
94 | clip_grad: 3.0
95 | allow_tf32: False
96 | freeze_layers: ['time_embedding']
97 | find_unused_parameters: False
98 | gradient_checkpointing: False
99 |
100 | # Loss weights
101 | loss_weights:
102 | mask: 1.0
103 | ce: 1.0
104 | kl: 0.0
105 |
106 | # Loss kwargs
107 | loss_kwargs:
108 | num_points: 12544
109 | oversample_ratio: 3
110 | importance_sample_ratio: 0.75
111 | cost_mask: 1.0
112 | cost_class: 1.0
113 | temperature: 1.0
114 |
115 | # Sampling parameters
116 | sampling_kwargs:
117 | num_inference_steps: 50
118 | guidance_scale: 7.5 # only applicable when using guidance with image descriptors
119 | seed: 0
120 | block_size: 2
121 | prob_mask: 0.5
122 |
123 | # Evaluation parameters / Visualization
124 | eval_kwargs:
125 | mask_th: 0.5
126 | count_th: 512
127 | overlap_th: 0.5
128 | batch_size: 16
129 | num_workers: 2
130 | vis_every: 5000
131 | print_freq: 100
132 |
133 | # Optimizer
134 | optimizer_name: adamw
135 | optimizer_kwargs:
136 | lr: 1.0e-4
137 | betas: [0.9, 0.999]
138 | weight_decay: 0.0
139 | weight_decay_norm: 0.0
140 | optimizer_zero_redundancy: False
141 | optimizer_backbone_multiplier: 1.0
142 | optimizer_save_optim: False
143 |
144 | # EMA (not used)
145 | ema_on: False
146 | ema_kwargs:
147 | decay: 0.9999
148 | device: cuda
149 |
150 | # LR scheduler
151 | lr_scheduler_name: warmup
152 | lr_scheduler_kwargs:
153 | final_lr: 0.000001 # N/A if scheduler_type is not cosine
154 | warmup_iters: 200
155 |
156 | # Transformations
157 | transformation_kwargs:
158 | type: crop_resize_pil
159 | size: 512 # size of the image during training
160 | size_rgb: 512 # resize rgb to this size
161 | max_size: 512 # max size of the image during eval
162 | scales: [352, 384, 416, 448, 480, 512, 544, 576, 608, 640]
163 | min_scale: 0.5
164 | max_scale: 1.5
165 | pad_value: 0
166 | flip: True
167 | normalize: False
168 | normalize_params:
169 | mean: [0.485, 0.456, 0.406]
170 | std: [0.229, 0.224, 0.225]
171 |
172 | # dataset information: will be overwritten by the dataset config
173 | train_db_name:
174 | val_db_name:
175 | split: 'val'
176 | num_classes: 128 # max number of instances we can detect
177 | num_bits: 7 # if we want to detect 128 instances, we need 7 bits
178 | has_bg: False
179 | ignore_label: 0
180 | fill_value: 0.5
181 | inpainting_strength: 0.0
182 |
--------------------------------------------------------------------------------
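Annotation (not part of the repository): encoding_mode: bits with num_bits: 7 and num_classes: 128 means every instance id in [0, 128) can be written as 7 binary channels (2^7 = 128), which is why the segmentation VAE takes in_channels: 7. A minimal, hypothetical sketch of such a bit encoding, purely to make the numbers concrete; the repository's actual implementation may differ in bit order or value range.

import torch


def encode_bits(semseg: torch.Tensor, num_bits: int = 7) -> torch.Tensor:
    """Encode an (H, W) map of ids in [0, 2**num_bits) into num_bits binary channels."""
    bits = [(semseg >> b) & 1 for b in range(num_bits)]  # least-significant bit first
    return torch.stack(bits, dim=0).float()              # (num_bits, H, W)


def decode_bits(bit_map: torch.Tensor) -> torch.Tensor:
    """Inverse of encode_bits: (num_bits, H, W) binary channels back to integer ids."""
    weights = 2 ** torch.arange(bit_map.shape[0]).view(-1, 1, 1)
    return (bit_map.round().long() * weights).sum(dim=0)


ids = torch.randint(0, 128, (4, 4))  # toy 4x4 panoptic id map
assert torch.equal(decode_bits(encode_bits(ids)), ids)

--------------------------------------------------------------------------------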
/tools/configs/config.yaml:
--------------------------------------------------------------------------------
1 | # default yaml
2 | #
3 |
4 | setup: simple_diffusion
5 | run_idx: -1
6 | debug: False
7 | defaults:
8 | - _self_
9 | - base: base
10 | - datasets: coco
11 | - env: root_paths
12 | - distributed: local
13 |
14 | hydra:
15 | output_subdir: null
16 |
--------------------------------------------------------------------------------
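Annotation: the defaults list above is what lets the shell scripts under tools/scripts/ override any field from the command line (datasets=coco, base.train_kwargs.batch_size=$BS, ...). The same composition can be reproduced programmatically with Hydra's compose API; a sketch, assuming it is run from the tools/ directory (config_path is resolved relative to the caller):

from hydra import compose, initialize
from omegaconf import OmegaConf

# "configs" assumes this snippet lives next to tools/main_ldm.py
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="config",
        overrides=["datasets=coco", "base.train_kwargs.batch_size=8"],
    )

cfg = OmegaConf.to_object(cfg)               # plain python dict, as in tools/main_ldm.py
merged = cfg['base'] | cfg['datasets']       # dataset config overwrites base config
print(merged['train_kwargs']['batch_size'])  # 8

--------------------------------------------------------------------------------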
/tools/configs/datasets/coco.yaml:
--------------------------------------------------------------------------------
1 | # Dataset
2 | train_db_name: coco
3 | val_db_name: coco
4 | split: 'train'
5 | num_classes: 128
6 | has_bg: False
7 | ignore_label: 0
8 | fill_value: 0.5
9 | inpainting_strength: 0.0
10 |
--------------------------------------------------------------------------------
/tools/configs/distributed/local.yaml:
--------------------------------------------------------------------------------
1 | # distributed setting
2 |
3 | world_size: 1
4 | rank: 0
5 | dist_url: tcp://127.0.0.1:54286
6 | dist_backend: nccl
7 | multiprocessing_distributed: True
8 |
--------------------------------------------------------------------------------
/tools/configs/env/root_paths.yaml:
--------------------------------------------------------------------------------
1 | root_dir: '/scratch/wouter_vangansbeke/datasets/outputs/'
2 | data_dir: '/scratch/wouter_vangansbeke/datasets/'
3 |
--------------------------------------------------------------------------------
/tools/main_ae.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | Main file for training auto-encoders and VAEs
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import os
9 | import sys
10 | import json
11 | import hydra
12 | import wandb
13 | import builtins
14 | from termcolor import colored
15 | from datetime import datetime
16 | from omegaconf import OmegaConf, DictConfig
17 | from typing import Dict, Any
18 |
19 | import torch
20 | import torch.multiprocessing as mp
21 | import torch.distributed as dist
22 | from diffusers import AutoencoderKL
23 |
24 | from ldmseg.models import GeneralVAESeg
25 | from ldmseg.trainers import TrainerAE
26 | from ldmseg.utils import prepare_config, Logger, is_main_process
27 |
28 |
29 | @hydra.main(version_base=None, config_path="configs/", config_name="config")
30 | def main(cfg: DictConfig) -> None:
31 | """
32 | Main function that calls multiple workers for distributed training
33 |
34 | args:
35 | cfg (DictConfig): configuration file as a dictionary
36 | """
37 |
38 | # define configs
39 | cfg = OmegaConf.to_object(cfg)
40 | wandb.config = cfg
41 | cfg_dist = cfg['distributed']
42 | cfg_dataset = cfg['datasets']
43 | cfg_base = cfg['base']
44 | project_dir = cfg['setup']
45 | cfg_dataset = cfg_base | cfg_dataset # overwrite base configs with dataset specific configs
46 | root_dir = os.path.join(cfg['env']['root_dir'], project_dir)
47 | data_dir = cfg['env']['data_dir']
48 | cfg_dataset, project_name = prepare_config(cfg_dataset, root_dir, data_dir, run_idx=cfg['run_idx'])
49 | project_name = f"{cfg_dataset['train_db_name']}_{project_name}"
50 | print(colored(f"Project name: {project_name}", 'red'))
51 |
52 | if cfg_dist['dist_url'] == "env://" and cfg_dist['world_size'] == -1:
53 | cfg_dist['world_size'] = int(os.environ["WORLD_SIZE"])
54 | cfg_dist['distributed'] = cfg_dist['world_size'] > 1 or cfg_dist['multiprocessing_distributed']
55 | # handle debug mode
56 | if cfg['debug']:
57 | print(colored("Running in debug mode!", "red"))
58 | cfg_dist['world_size'] = 1
59 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
60 | cfg_dataset['train_kwargs']['num_workers'] = 0
61 | cfg_dataset['eval_kwargs']['num_workers'] = 0
62 | ngpus_per_node = torch.cuda.device_count()
63 | # Since we have ngpus_per_node processes per node, the total world_size
64 | # needs to be adjusted accordingly
65 | cfg_dist['world_size'] = ngpus_per_node * cfg_dist['world_size']
66 | print(colored(f"World size: {cfg_dist['world_size']}", 'blue'))
67 | # main_worker process function
68 | if cfg['debug']:
69 | main_worker(0, ngpus_per_node, cfg_dist, cfg_dataset, project_name)
70 | else:
71 | # Use torch.multiprocessing.spawn to launch distributed processes
72 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, cfg_dist, cfg_dataset, project_name))
73 |
74 |
75 | def main_worker(
76 | gpu: int,
77 | ngpus_per_node: int,
78 | cfg_dist: Dict[str, Any],
79 | p: Dict[str, Any],
80 | name: str = 'segmentation_diffusion'
81 | ) -> None:
82 | """
83 | A single worker for distributed training
84 |
85 | args:
86 | gpu (int): local index of the gpu per node
87 | ngpus_per_node (int): number of gpus per node
88 | cfg_dist (Dict): arguments for distributed training
89 | p (Dict): arguments for training
90 | name (str): name of the run
91 | """
92 |
93 | cfg_dist['gpu'] = gpu
94 | cfg_dist['ngpus_per_node'] = ngpus_per_node
95 | if cfg_dist['multiprocessing_distributed'] and cfg_dist['gpu'] != 0:
96 | def print_pass(*args):
97 | pass
98 | builtins.print = print_pass
99 |
100 | if cfg_dist['gpu'] is not None:
101 | print("Use GPU: {} for printing".format(cfg_dist['gpu']))
102 |
103 | if cfg_dist['distributed']:
104 | if cfg_dist['dist_url'] == "env://" and cfg_dist['rank'] == -1:
105 | cfg_dist['rank'] = int(os.environ["RANK"])
106 | if cfg_dist['multiprocessing_distributed']:
107 | # For multiprocessing distributed training, rank needs to be the
108 | # global rank among all the processes
109 | cfg_dist['rank'] = cfg_dist['rank'] * ngpus_per_node + gpu
110 | print(colored("Starting distributed training", "yellow"))
111 | dist.init_process_group(backend=cfg_dist['dist_backend'],
112 | init_method=cfg_dist['dist_url'],
113 | world_size=cfg_dist['world_size'],
114 | rank=cfg_dist['rank'])
115 |
116 | print(colored("Initialized distributed training", "yellow"))
117 | # For multiprocessing distributed, DistributedDataParallel constructor
118 | # should always set the single device scope, otherwise,
119 | # DistributedDataParallel will use all available devices.
120 | torch.cuda.set_device(cfg_dist['gpu'])
121 |
122 | # logging and printing with wandb
123 | if p['wandb'] and is_main_process():
124 | wandb.init(name=name, project="ddm")
125 | sys.stdout = Logger(os.path.join(p['output_dir'], f'log_file_gpu_{cfg_dist["gpu"]}.txt'))
126 |
127 | # print configs
128 | p['name'] = name
129 | readable_p = json.dumps(p, indent=4, sort_keys=True)
130 | print(colored(readable_p, 'red'))
131 | print(colored(datetime.now(), 'yellow'))
132 |
133 | # get param dicts
134 | train_params, eval_params = p['train_kwargs'], p['eval_kwargs']
135 |
136 | # define model
137 | pretrained_model_path = p['pretrained_model_path']
138 | print('pretrained_model_path', pretrained_model_path)
139 | vae_encoder = None
140 | if p['shared_vae_encoder']:
141 | vae_image = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
142 | vae_encoder = torch.nn.Sequential(vae_image.encoder, vae_image.quant_conv)
143 | vae = GeneralVAESeg(**p['vae_model_kwargs'], encoder=vae_encoder)
144 | print(vae)
145 |
146 | # gradient checkpointing
147 | if p['train_kwargs']['gradient_checkpointing']:
148 | vae.enable_gradient_checkpointing()
149 |
150 | # load model on gpu
151 | vae = vae.to(cfg_dist['gpu'])
152 |
153 | # print number of trainable parameters
154 | print(colored(
155 | f"Number of trainable parameters: {sum(p.numel() for p in vae.parameters() if p.requires_grad) / 1e6 :.2f}M",
156 | 'yellow'))
157 |
158 | if cfg_dist['distributed']:
159 | # For multiprocessing distributed, DistributedDataParallel constructor
160 | # should always set the single device scope, otherwise,
161 | # DistributedDataParallel will use all available devices.
162 | assert cfg_dist['gpu'] is not None
163 | vae = torch.nn.parallel.DistributedDataParallel(
164 | vae, device_ids=[cfg_dist['gpu']],
165 | find_unused_parameters=train_params['find_unused_parameters'],
166 | gradient_as_bucket_view=train_params['gradient_as_bucket_view'],
167 | )
168 | print(colored(f"Batch size per process is {train_params['batch_size']}", 'yellow'))
169 | print(colored(f"Workers per process is {train_params['num_workers']}", 'yellow'))
170 | else:
171 | raise NotImplementedError("Only DistributedDataParallel is supported.")
172 |
173 | # define trainer
174 | trainer = TrainerAE(
175 | p,
176 | vae,
177 | args=cfg_dist, # distributed training args
178 | results_folder=p['output_dir'], # output directory
179 | save_and_sample_every=eval_params['vis_every'], # save and sample every n iters
180 | cudnn_on=train_params['cudnn'], # turn on cudnn
181 | fp16=train_params['fp16'], # turn on floating point 16
182 | )
183 |
184 | # resume training from checkpoint
185 | trainer.resume() # looks for model at run_idx (automatically determined)
186 | if p['load_path'] is not None:
187 | trainer.load(model_path=p['load_path']) # looks for model at load_path
188 |
189 | if p['eval_only']:
190 | trainer.compute_metrics(['miou', 'pq'])
191 | return
192 |
193 | # train
194 | trainer.train_loop()
195 |
196 | # terminate wandb
197 | if p['wandb'] and is_main_process():
198 | wandb.finish()
199 |
200 |
201 | if __name__ == "__main__":
202 | main()
203 |
--------------------------------------------------------------------------------
/tools/main_ldm.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | Main file for training diffusion models
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import os
9 | import sys
10 | import json
11 | import wandb
12 | import builtins
13 | import hydra
14 | from typing import Dict, Any
15 | from termcolor import colored
16 | from datetime import datetime
17 | from omegaconf import OmegaConf, DictConfig
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.multiprocessing as mp
22 | import torch.distributed as dist
23 |
24 | from ldmseg.trainers import TrainerDiffusion
25 | from ldmseg.models import UNet, GeneralVAESeg, GeneralVAEImage, get_image_descriptor_model
26 | from ldmseg.schedulers import DDIMNoiseScheduler
27 | from ldmseg.utils import prepare_config, Logger, is_main_process
28 |
29 |
30 | @hydra.main(version_base=None, config_path="configs/", config_name="config")
31 | def main(cfg: DictConfig) -> None:
32 | """ Main function that calls multiple workers for distributed training
33 | """
34 |
35 | # define configs
36 | cfg = OmegaConf.to_object(cfg)
37 | wandb.config = cfg
38 | cfg_dist = cfg['distributed']
39 | cfg_dataset = cfg['datasets']
40 | cfg_base = cfg['base']
41 | project_dir = cfg['setup']
42 | cfg_dataset = cfg_base | cfg_dataset # overwrite base configs with dataset specific configs
43 | root_dir = os.path.join(cfg['env']['root_dir'], project_dir)
44 | data_dir = cfg['env']['data_dir']
45 | cfg_dataset, project_name = prepare_config(cfg_dataset, root_dir, data_dir, run_idx=cfg['run_idx'])
46 | project_name = f"{cfg_dataset['train_db_name']}_{project_name}"
47 | print(colored(f"Project name: {project_name}", 'red'))
48 |
49 | if cfg_dist['dist_url'] == "env://" and cfg_dist['world_size'] == -1:
50 | cfg_dist['world_size'] = int(os.environ["WORLD_SIZE"])
51 | cfg_dist['distributed'] = cfg_dist['world_size'] > 1 or cfg_dist['multiprocessing_distributed']
52 | # handle debug mode
53 | if cfg['debug']:
54 | print(colored("Running in debug mode!", "red"))
55 | cfg_dist['world_size'] = 1
56 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
57 | cfg_dataset['train_kwargs']['num_workers'] = 0
58 | cfg_dataset['eval_kwargs']['num_workers'] = 0
59 | ngpus_per_node = torch.cuda.device_count()
60 | # Since we have ngpus_per_node processes per node, the total world_size
61 | # needs to be adjusted accordingly
62 | cfg_dist['world_size'] = ngpus_per_node * cfg_dist['world_size']
63 | print(colored(f"World size: {cfg_dist['world_size']}", 'blue'))
64 | # main_worker process function
65 | if cfg['debug']:
66 | main_worker(0, ngpus_per_node, cfg_dist, cfg_dataset, project_name)
67 | else:
68 | # Use torch.multiprocessing.spawn to launch distributed processes
69 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, cfg_dist, cfg_dataset, project_name))
70 |
71 |
72 | def main_worker(
73 | gpu: int,
74 | ngpus_per_node: int,
75 | cfg_dist: Dict[str, Any],
76 | p: Dict[str, Any],
77 | name: str = 'simple_diffusion'
78 | ) -> None:
79 | """
80 | A single worker for distributed training
81 |
82 | args:
83 | gpu (int): gpu id
84 | ngpus_per_node (int): number of gpus per node
85 | cfg_dist (Dict[str, Any]): distributed configs
86 | p (Dict[str, Any]): training configs
87 | name (str): name of the experiment
88 | """
89 |
90 | cfg_dist['gpu'] = gpu
91 | cfg_dist['ngpus_per_node'] = ngpus_per_node
92 | if cfg_dist['multiprocessing_distributed'] and cfg_dist['gpu'] != 0:
93 | def print_pass(*args):
94 | pass
95 | builtins.print = print_pass
96 |
97 | if cfg_dist['gpu'] is not None:
98 | print("Use GPU: {} for printing".format(cfg_dist['gpu']))
99 |
100 | if cfg_dist['distributed']:
101 | if cfg_dist['dist_url'] == "env://" and cfg_dist['rank'] == -1:
102 | cfg_dist['rank'] = int(os.environ["RANK"])
103 | if cfg_dist['multiprocessing_distributed']:
104 | # For multiprocessing distributed training, rank needs to be the
105 | # global rank among all the processes
106 | cfg_dist['rank'] = cfg_dist['rank'] * ngpus_per_node + gpu
107 | print(colored("Starting distributed training", "yellow"))
108 | dist.init_process_group(backend=cfg_dist['dist_backend'],
109 | init_method=cfg_dist['dist_url'],
110 | world_size=cfg_dist['world_size'],
111 | rank=cfg_dist['rank'])
112 |
113 | print(colored("Initialized distributed training", "yellow"))
114 | # For multiprocessing distributed, DistributedDataParallel constructor
115 | # should always set the single device scope, otherwise,
116 | # DistributedDataParallel will use all available devices.
117 | torch.cuda.set_device(cfg_dist['gpu'])
118 |
119 | # logging and printing with wandb
120 | if p['wandb'] and is_main_process():
121 | wandb.init(name=name, project="ddm")
122 | sys.stdout = Logger(os.path.join(p['output_dir'], f'log_file_gpu_{cfg_dist["gpu"]}.txt'))
123 |
124 | # print configs
125 | p['name'] = name
126 | readable_p = json.dumps(p, indent=4, sort_keys=True)
127 | print(colored(readable_p, 'red'))
128 | print(colored(datetime.now(), 'yellow'))
129 |
130 | # get model params
131 | train_params, eval_params = p['train_kwargs'], p['eval_kwargs']
132 | pretrained_model_path = p['pretrained_model_path']
133 | print('Pretrained model path:', pretrained_model_path)
134 |
135 | semseg_vae_encoder = None
136 | cache_dir = os.path.join(p['data_dir'], 'cache')
137 | vae_image = GeneralVAEImage.from_pretrained(pretrained_model_path, subfolder="vae", cache_dir=cache_dir)
138 | vae_image.decoder = nn.Identity() # remove the decoder
139 | vae_image.set_scaling_factor(p['image_scaling_factor'])
140 | if p['shared_vae_encoder']:
141 | semseg_vae_encoder = torch.nn.Sequential(vae_image.encoder, vae_image.quant_conv) # NOTE: deepcopy is safer
142 | vae_semseg = GeneralVAESeg(**p['vae_model_kwargs'], encoder=semseg_vae_encoder)
143 | print(f"Scaling factors: image {vae_image.scaling_factor}, semseg {vae_semseg.scaling_factor}")
144 |
145 | # load the pretrained UNet model
146 | unet = UNet.from_pretrained(pretrained_model_path, subfolder="unet", cache_dir=cache_dir)
147 |
148 | # gradient checkpointing
149 | if p['train_kwargs']['gradient_checkpointing']:
150 | print(colored('Gradient checkpointing is enabled', 'yellow'))
151 | unet.enable_gradient_checkpointing()
152 |
153 | # load the pretrained image descriptor model
154 | # leverage prior knowledge by using descriptors: CLIP, DINO, MAE, or learnable
155 | image_descriptor_model, text_encoder, tokenizer = get_image_descriptor_model(
156 | train_params['image_descriptors'], pretrained_model_path, unet
157 | )
158 | # modify the encoder of the UNet and freeze layers
159 | unet.modify_encoder(**p['model_kwargs'])
160 | unet.freeze_layers(p['train_kwargs']['freeze_layers'])
161 | print('time embedding frozen')
162 |
163 | # define weight dtype
164 | assert train_params['weight_dtype'] in ['float32', 'float16']
165 | weight_dtype = torch.float32 if train_params['weight_dtype'] == 'float32' else torch.float16
166 | vae_image = vae_image.to(cfg_dist['gpu'], dtype=weight_dtype)
167 | vae_semseg = vae_semseg.to(cfg_dist['gpu'], dtype=torch.float32)
168 | unet = unet.to(cfg_dist['gpu'], dtype=torch.float32) # unet is always fp32
169 | if image_descriptor_model is not None:
170 | image_descriptor_model = image_descriptor_model.to(cfg_dist['gpu'], dtype=weight_dtype)
171 | if text_encoder is not None:
172 | text_encoder = text_encoder.to(cfg_dist['gpu'], dtype=weight_dtype)
173 |
174 | # define noise scheduler
175 | noise_scheduler = DDIMNoiseScheduler(**p['noise_scheduler_kwargs'], device=cfg_dist['gpu'])
176 | print(noise_scheduler)
177 |
178 | # print number of trainable parameters
179 | print(colored(
180 | f"Number of trainable parameters: {sum(p.numel() for p in unet.parameters() if p.requires_grad) / 1e6:.2f}M",
181 | 'yellow'))
182 |
183 | if cfg_dist['distributed']:
184 | # For multiprocessing distributed, DistributedDataParallel constructor
185 | # should always set the single device scope, otherwise,
186 | # DistributedDataParallel will use all available devices.
187 | assert cfg_dist['gpu'] is not None
188 | unet = torch.nn.parallel.DistributedDataParallel(
189 | unet, device_ids=[cfg_dist['gpu']],
190 | find_unused_parameters=p['train_kwargs']['find_unused_parameters'],
191 | gradient_as_bucket_view=p['train_kwargs']['gradient_as_bucket_view'],
192 | )
193 | print(colored(f"Batch size per process is {train_params['batch_size']}", 'yellow'))
194 | print(colored(f"Workers per process is {train_params['num_workers']}", 'yellow'))
195 | else:
196 | raise NotImplementedError("Only DistributedDataParallel is supported.")
197 |
198 | # define trainer
199 | trainer = TrainerDiffusion(
200 | p,
201 | vae_image, vae_semseg, unet, tokenizer, text_encoder, # models
202 | noise_scheduler, # noise scheduler for diffusion
203 | image_descriptor_model=image_descriptor_model, # image descriptor model
204 | ema_unet=None, # ema model for unet
205 | args=cfg_dist, # distributed training args
206 | results_folder=p['output_dir'], # output directory
207 | save_and_sample_every=eval_params['vis_every'], # save and sample every n iters
208 | cudnn_on=train_params['cudnn'], # turn on cudnn
209 | fp16=train_params['fp16'], # turn on floating point 16
210 | ema_on=p['ema_on'], # turn on exponential moving average
211 | weight_dtype=weight_dtype, # weight dtype
212 | )
213 |
214 | # resume training from checkpoint
215 | trainer.resume(load_vae=True) # looks for model at run_idx (automatically determined)
216 | trainer.load(model_path=p['load_path'], load_vae=True) # looks for model at load_path
217 |
218 | # only evaluate
219 | if p['eval_only']:
220 | max_iter = None
221 | trainer.compute_metrics(
222 | metrics=['pq'],
223 | dataloader=trainer.dl_val,
224 | threshold_output=True,
225 | save_images=True,
226 | seed=42,
227 | max_iter=max_iter,
228 | models_to_eval=[trainer.unet_model],
229 | num_inference_steps=trainer.num_inference_steps,
230 | )
231 | dist.barrier()
232 | return
233 |
234 | # train
235 | trainer.train_loop()
236 |
237 | # terminate wandb
238 | if p['wandb'] and is_main_process():
239 | wandb.finish()
240 |
241 |
242 | if __name__ == "__main__":
243 | main()
244 |
--------------------------------------------------------------------------------
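Annotation: the per-process batch size printed above is not the global batch size. Assuming gradient accumulation is applied multiplicatively in the trainer (TrainerDiffusion is not shown in this excerpt), the effective batch size for the default configs works out as in this sketch:

import torch

batch_size_per_process = 8                         # base.train_kwargs.batch_size
accumulate = 1                                     # base.train_kwargs.accumulate
nodes = 1                                          # distributed/local.yaml world_size
gpus_per_node = max(1, torch.cuda.device_count())  # one process per GPU via mp.spawn
world_size = nodes * gpus_per_node

effective_batch_size = batch_size_per_process * world_size * accumulate
print(effective_batch_size)                        # e.g. 64 on a single 8-GPU node

--------------------------------------------------------------------------------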
/tools/main_ldm_slurm.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Wouter Van Gansbeke
3 |
4 | Main file for training diffusion models
5 | Licensed under the CC BY-NC 4.0 license (https://creativecommons.org/licenses/by-nc/4.0/)
6 | """
7 |
8 | import os
9 | import sys
10 | import json
11 | import wandb
12 | import builtins
13 | import hydra
14 | from typing import Dict, Any
15 | from termcolor import colored
16 | from datetime import datetime
17 | from omegaconf import OmegaConf, DictConfig
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.distributed as dist
22 |
23 | from ldmseg.trainers import TrainerDiffusion
24 | from ldmseg.models import UNet, GeneralVAESeg, GeneralVAEImage, get_image_descriptor_model
25 | from ldmseg.schedulers import DDIMNoiseScheduler
26 | from ldmseg.utils import prepare_config, Logger, is_main_process
27 |
28 |
29 | @hydra.main(version_base=None, config_path="configs/", config_name="config")
30 | def main(cfg: DictConfig) -> None:
31 | """ Main function that calls multiple workers for distributed training
32 | """
33 |
34 | # define configs
35 | cfg = OmegaConf.to_object(cfg)
36 | wandb.config = cfg
37 | cfg_dist = cfg['distributed']
38 | cfg_dataset = cfg['datasets']
39 | cfg_base = cfg['base']
40 | project_dir = cfg['setup']
41 | cfg_dataset = cfg_base | cfg_dataset # overwrite base configs with dataset specific configs
42 | root_dir = os.path.join(cfg['env']['root_dir'], project_dir)
43 | data_dir = cfg['env']['data_dir']
44 | cfg_dataset, project_name = prepare_config(cfg_dataset, root_dir, data_dir, run_idx=cfg['run_idx'])
45 | project_name = f"{cfg_dataset['train_db_name']}_{project_name}"
46 | print(colored(f"Project name: {project_name}", 'red'))
47 | cfg_dist['distributed'] = True
48 |
49 | ngpus_per_node = torch.cuda.device_count()
50 | # Since we have ngpus_per_node processes per node, the total world_size
51 | # needs to be adjusted accordingly
52 | cfg_dist['rank'] = int(os.environ["SLURM_PROCID"])
53 | cfg_dist['world_size'] = int(os.environ["SLURM_NNODES"]) * int(
54 | os.environ["SLURM_TASKS_PER_NODE"][0])  # NOTE: reads only the first character; assumes a single-digit task count per node
55 | print(colored(f"World size: {cfg_dist['world_size']}", 'blue'))
56 | cfg_dist['gpu'] = cfg_dist['rank'] % torch.cuda.device_count()
57 | print(colored(f"GPU: {cfg_dist['gpu']}", 'blue'))
58 | cfg_dist['ngpus_per_node'] = ngpus_per_node
59 |
60 | # main_worker process function
61 | main_worker(cfg_dist, cfg_dataset, project_name)
62 | return
63 |
64 |
65 | def main_worker(
66 | cfg_dist: Dict[str, Any],
67 | p: Dict[str, Any],
68 | name: str = 'simple_diffusion'
69 | ) -> None:
70 | """
71 | A single worker for distributed training
72 |
73 | args:
74 | cfg_dist (Dict[str, Any]): distributed configs
75 | p (Dict[str, Any]): training configs
76 | name (str): name of the experiment
77 | """
78 |
79 | if cfg_dist['multiprocessing_distributed'] and cfg_dist['gpu'] != 0:
80 | def print_pass(*args):
81 | pass
82 | builtins.print = print_pass
83 |
84 | if cfg_dist['gpu'] is not None:
85 | print("Use GPU: {} for printing".format(cfg_dist['gpu']))
86 |
87 | if cfg_dist['distributed']:
88 | assert cfg_dist['multiprocessing_distributed']
89 | # cfg_dist['rank'] = cfg_dist['rank'] * ngpus_per_node + gpu
90 | print(colored("Starting distributed training", "yellow"))
91 | dist.init_process_group(backend=cfg_dist['dist_backend'],
92 | init_method=cfg_dist['dist_url'],
93 | world_size=cfg_dist['world_size'],
94 | rank=cfg_dist['rank'])
95 |
96 | print(colored("Initialized distributed training", "yellow"))
97 | # For multiprocessing distributed, DistributedDataParallel constructor
98 | # should always set the single device scope, otherwise,
99 | # DistributedDataParallel will use all available devices.
100 | torch.cuda.set_device(cfg_dist['gpu'])
101 |
102 | # logging and printing with wandb
103 | if p['wandb'] and is_main_process():
104 | wandb.init(name=name, project="ddm")
105 | sys.stdout = Logger(os.path.join(p['output_dir'], f'log_file_gpu_{cfg_dist["gpu"]}.txt'))
106 |
107 | # print configs
108 | p['name'] = name
109 | readable_p = json.dumps(p, indent=4, sort_keys=True)
110 | print(colored(readable_p, 'red'))
111 | print(colored(datetime.now(), 'yellow'))
112 |
113 | # get model params
114 | train_params, eval_params = p['train_kwargs'], p['eval_kwargs']
115 | pretrained_model_path = p['pretrained_model_path']
116 | print('Pretrained model path:', pretrained_model_path)
117 |
118 | semseg_vae_encoder = None
119 | cache_dir = os.path.join(p['data_dir'], 'cache')
120 | vae_image = GeneralVAEImage.from_pretrained(
121 | pretrained_model_path, subfolder="vae", cache_dir=cache_dir)
122 | vae_image.decoder = nn.Identity() # remove the decoder
123 | vae_image.set_scaling_factor(p['image_scaling_factor'])
124 | if p['shared_vae_encoder']:
125 | semseg_vae_encoder = torch.nn.Sequential(vae_image.encoder, vae_image.quant_conv) # NOTE: deepcopy is safer
126 | vae_semseg = GeneralVAESeg(**p['vae_model_kwargs'], encoder=semseg_vae_encoder)
127 | print(f"Scaling factors: image {vae_image.scaling_factor}, semseg {vae_semseg.scaling_factor}")
128 |
129 | # load the pretrained UNet model
130 | unet = UNet.from_pretrained(pretrained_model_path, subfolder="unet", cache_dir=cache_dir)
131 |
132 | # gradient checkpointing
133 | if p['train_kwargs']['gradient_checkpointing']:
134 | print(colored('Gradient checkpointing is enabled', 'yellow'))
135 | unet.enable_gradient_checkpointing()
136 |
137 | # load the pretrained image descriptor model
138 | # leverage prior knowledge by using descriptors: CLIP, DINO, MAE, or learnable
139 | image_descriptor_model, text_encoder, tokenizer = get_image_descriptor_model(
140 | train_params['image_descriptors'], pretrained_model_path, unet
141 | )
142 | # modify the encoder of the UNet and freeze layers
143 | unet.modify_encoder(**p['model_kwargs'])
144 | unet.freeze_layers(p['train_kwargs']['freeze_layers'])
145 | print('time embedding frozen')
146 |
147 | # define weight dtype
148 | assert train_params['weight_dtype'] in ['float32', 'float16']
149 | weight_dtype = torch.float32 if train_params['weight_dtype'] == 'float32' else torch.float16
150 | vae_image = vae_image.to(cfg_dist['gpu'], dtype=weight_dtype)
151 | vae_semseg = vae_semseg.to(cfg_dist['gpu'], dtype=torch.float32)
152 | unet = unet.to(cfg_dist['gpu'], dtype=torch.float32) # unet is always fp32
153 | if image_descriptor_model is not None:
154 | image_descriptor_model = image_descriptor_model.to(cfg_dist['gpu'], dtype=weight_dtype)
155 | if text_encoder is not None:
156 | text_encoder = text_encoder.to(cfg_dist['gpu'], dtype=weight_dtype)
157 |
158 | # define noise scheduler
159 | noise_scheduler = DDIMNoiseScheduler(**p['noise_scheduler_kwargs'], device=cfg_dist['gpu'])
160 | print(noise_scheduler)
161 |
162 | # print number of trainable parameters
163 | print(colored(
164 | f"Number of trainable parameters: {sum(p.numel() for p in unet.parameters() if p.requires_grad) / 1e6:.2f}M",
165 | 'yellow'))
166 |
167 | if cfg_dist['distributed']:
168 | # For multiprocessing distributed, DistributedDataParallel constructor
169 | # should always set the single device scope, otherwise,
170 | # DistributedDataParallel will use all available devices.
171 | assert cfg_dist['gpu'] is not None
172 | unet = torch.nn.parallel.DistributedDataParallel(
173 | unet, device_ids=[cfg_dist['gpu']],
174 | find_unused_parameters=p['train_kwargs']['find_unused_parameters'],
175 | gradient_as_bucket_view=p['train_kwargs']['gradient_as_bucket_view'],
176 | )
177 | print(colored(f"Batch size per process is {train_params['batch_size']}", 'yellow'))
178 | print(colored(f"Workers per process is {train_params['num_workers']}", 'yellow'))
179 | else:
180 | raise NotImplementedError("Only DistributedDataParallel is supported.")
181 |
182 | # define trainer
183 | trainer = TrainerDiffusion(
184 | p,
185 | vae_image, vae_semseg, unet, tokenizer, text_encoder, # models
186 | noise_scheduler, # noise scheduler for diffusion
187 | image_descriptor_model=image_descriptor_model, # image descriptor model
188 | ema_unet=None, # ema model for unet
189 | args=cfg_dist, # distributed training args
190 | results_folder=p['output_dir'], # output directory
191 | save_and_sample_every=eval_params['vis_every'], # save and sample every n iters
192 | cudnn_on=train_params['cudnn'], # turn on cudnn
193 | fp16=train_params['fp16'], # turn on floating point 16
194 | ema_on=p['ema_on'], # turn on exponential moving average
195 | weight_dtype=weight_dtype, # weight dtype
196 | )
197 |
198 | # resume training from checkpoint
199 | trainer.resume(load_vae=True) # looks for model at run_idx (automatically determined)
200 | trainer.load(model_path=p['load_path'], load_vae=True) # looks for model at load_path
201 |
202 | # only evaluate
203 | if p['eval_only']:
204 | max_iter = None
205 | trainer.compute_metrics(
206 | metrics=['pq'],
207 | dataloader=trainer.dl_val,
208 | threshold_output=True,
209 | save_images=True,
210 | seed=42,
211 | max_iter=max_iter,
212 | models_to_eval=[trainer.unet_model],
213 | num_inference_steps=trainer.num_inference_steps,
214 | )
215 | dist.barrier()
216 | return
217 |
218 | # train
219 | trainer.train_loop()
220 |
221 | # terminate wandb
222 | if p['wandb'] and is_main_process():
223 | wandb.finish()
224 |
225 |
226 | if __name__ == "__main__":
227 | main()
228 |
--------------------------------------------------------------------------------
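Annotation: the world-size computation above indexes os.environ["SLURM_TASKS_PER_NODE"][0], i.e. it reads only the first character of a value that SLURM may report as "4", "4(x2)" or "2(x3),1", so it assumes a single-digit, homogeneous task count per node. A more defensive parse could look like this sketch (not part of the repository):

import os
import re


def tasks_on_first_node(env=os.environ) -> int:
    """Parse SLURM_TASKS_PER_NODE, e.g. '4', '4(x2)' or '2(x3),1'; return the first node's task count."""
    value = env.get("SLURM_TASKS_PER_NODE", "1")
    match = re.match(r"(\d+)", value)
    return int(match.group(1)) if match else 1


# world size = number of nodes * tasks per node (assuming a homogeneous allocation), e.g.:
# world_size = int(os.environ["SLURM_NNODES"]) * tasks_on_first_node()

--------------------------------------------------------------------------------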
/tools/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | BS=${1-32}
3 |
4 | # make sure the pretrained model is available at 'pretrained/ldmseg.pt'
5 |
6 | export OMP_NUM_THREADS=1
7 | python -W ignore tools/main_ldm.py \
8 | datasets=coco \
9 | base.train_kwargs.batch_size=$BS \
10 | base.train_kwargs.weight_dtype=float16 \
11 | base.vae_model_kwargs.scaling_factor=0.18215 \
12 | base.transformation_kwargs.type=crop_resize_pil \
13 | base.transformation_kwargs.size=512 \
14 | base.eval_kwargs.count_th=512 \
15 | base.sampling_kwargs.num_inference_steps=50 \
16 | base.train_kwargs.self_condition=True \
17 | base.model_kwargs.cond_channels=4 \
18 | base.load_path='pretrained/ldmseg.pt' \
19 | base.eval_only=True \
20 |
--------------------------------------------------------------------------------
/tools/scripts/install_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | miniconda3_path=${1-'/home/ubuntu/datasets/miniconda3'}
3 |
4 | # check if miniconda3_path exists
5 | if [ ! -d "$miniconda3_path" ]; then
6 | echo "Error: miniconda3_path does not exist."
7 | exit 1
8 | fi
9 |
10 | # check that we are running this from the ldmseg root directory
11 | if [ ! -d "tools" ]; then
12 | echo "Error: tools directory does not exist. You're not in the repo's root directory."
13 | exit 1
14 | fi
15 |
16 | # prints
17 | echo "Installing environment, using miniconda path ..."
18 | echo ${miniconda3_path}/etc/profile.d/conda.sh
19 |
20 | # install environment
21 | . ${miniconda3_path}/etc/profile.d/conda.sh # initialize conda
22 | conda create -n ldmseg_env python -y
23 | conda activate ldmseg_env
24 |
25 | # install packages
26 | python -m pip install -e .
27 | pip install git+https://github.com/facebookresearch/detectron2.git # detectron2
28 | pip install git+https://github.com/cocodataset/panopticapi.git # pq evaluation (optional)
29 | pip install flake8 black # code formatting (optional)
30 |
--------------------------------------------------------------------------------
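Annotation: after running the script, a quick smoke test is to import the main dependencies inside the new ldmseg_env environment; a sketch (panopticapi is optional and omitted here):

# run with the ldmseg_env environment activated
import torch
import diffusers       # noqa: F401
import transformers    # noqa: F401
import detectron2      # noqa: F401
from ldmseg.version import __version__

print("torch", torch.__version__, "cuda available:", torch.cuda.is_available())
print("ldmseg", __version__)

--------------------------------------------------------------------------------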
/tools/scripts/install_env_manual.sh:
--------------------------------------------------------------------------------
1 | # Example script to install environment:
2 | # adapt paths accordingly
3 |
4 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
5 | sh Miniconda3-latest-Linux-x86_64.sh
6 |
7 | conda create -n pytorch_cuda12 python=3.11
8 | conda activate pytorch_cuda12
9 | conda install -c "nvidia/label/cuda-12.1.0" cuda-toolkit
10 | pip install torch torchvision
11 | pip install xformers einops diffusers transformers accelerate timm scipy opencv-python
12 | pip install pyyaml easydict hydra-core termcolor wandb
13 |
14 | python -m pip install -e .
15 |
16 | conda install -c conda-forge gcc=10.4
17 | conda install -c conda-forge gxx=10.4
18 | export CC=/scratch/wouter_vangansbeke/miniconda3/envs/pytorch_cuda12/bin/gcc
19 | export CXX=/scratch/wouter_vangansbeke/miniconda3/envs/pytorch_cuda12/bin/g++
20 |
21 | pip install git+https://github.com/facebookresearch/detectron2.git
22 | pip install git+https://github.com/cocodataset/panopticapi.git
23 |
--------------------------------------------------------------------------------
/tools/scripts/train_ae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | BS=${1-8}
3 |
4 | export OMP_NUM_THREADS=1
5 | python -W ignore tools/main_ae.py \
6 | datasets=coco \
7 | debug=${DEBUG:-False} \
8 | base.wandb=False \
9 | base.train_kwargs.batch_size=$BS \
10 | base.train_kwargs.accumulate=1 \
11 | base.train_kwargs.train_num_steps=90000 \
12 | base.train_kwargs.fp16=True \
13 | base.vae_model_kwargs.num_mid_blocks=0 \
14 | base.vae_model_kwargs.num_upscalers=2 \
15 | base.optimizer_name='adamw' \
16 | base.optimizer_kwargs.lr=1e-4 \
17 | base.optimizer_kwargs.weight_decay=0.05 \
18 | base.transformation_kwargs.type=crop_resize_pil \
19 | base.transformation_kwargs.size=512 \
20 | base.eval_kwargs.mask_th=0.8 \
21 | base.train_kwargs.prob_inpainting=0.0 \
22 | base.vae_model_kwargs.parametrization=gaussian \
23 |
--------------------------------------------------------------------------------
/tools/scripts/train_diffusion.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | BS=${1-32}
3 | ITERS=${2-90000}
4 |
5 | export OMP_NUM_THREADS=1
6 | python -W ignore tools/main_ldm.py \
7 | datasets=coco \
8 | debug=False \
9 | base.wandb=False \
10 | base.train_kwargs.train_num_steps=$ITERS \
11 | base.train_kwargs.batch_size=$BS \
12 | base.train_kwargs.accumulate=1 \
13 | base.eval_kwargs.vis_every=500 \
14 | base.train_kwargs.gradient_checkpointing=True \
15 | base.train_kwargs.weight_dtype=float16 \
16 | base.train_kwargs.fp16=True \
17 | base.vae_model_kwargs.pretrained_path='pretrained/ae.pt' \
18 | base.vae_model_kwargs.parametrization=gaussian \
19 | base.vae_model_kwargs.num_upscalers=2 \
20 | base.vae_model_kwargs.num_mid_blocks=0 \
21 | base.noise_scheduler_kwargs.prediction_type='epsilon' \
22 | base.noise_scheduler_kwargs.weight='max_clamp_snr' \
23 | base.noise_scheduler_kwargs.max_snr=2.0 \
24 | base.vae_model_kwargs.scaling_factor=0.18215 \
25 | base.train_kwargs.ohem_ratio=1.0 \
26 | base.optimizer_name='adamw' \
27 | base.optimizer_zero_redundancy=True \
28 | base.optimizer_kwargs.lr=1.0e-4 \
29 | base.optimizer_kwargs.weight_decay=0.05 \
30 | base.train_kwargs.clip_grad=1.0 \
31 | base.transformation_kwargs.type=crop_resize_pil \
32 | base.transformation_kwargs.size=512 \
33 | base.transformation_kwargs.max_size=512 \
34 | "base.train_kwargs.freeze_layers=['time_embedding']" \
35 | base.eval_kwargs.mask_th=0.9 \
36 | base.eval_kwargs.overlap_th=0.9 \
37 | base.eval_kwargs.count_th=512 \
38 | base.sampling_kwargs.num_inference_steps=50 \
39 | base.train_kwargs.self_condition=True \
40 | base.model_kwargs.cond_channels=4 \
41 | base.lr_scheduler_name=cosine \
42 |
--------------------------------------------------------------------------------