├── LICENSE.md ├── README.md ├── assets ├── method.jpg ├── result.jpg ├── results_vary_target_pose.jpg └── teaser.jpg ├── configs └── train_co3d_concept.yaml ├── main.py ├── requirements.txt ├── sample.py └── sgm ├── __init__.py ├── data ├── __init__.py └── data_co3d.py ├── lr_scheduler.py ├── models ├── __init__.py ├── autoencoder.py └── diffusion.py ├── modules ├── __init__.py ├── attention.py ├── autoencoding │ ├── lpips │ │ ├── __init__.py │ │ ├── loss │ │ │ ├── __init__.py │ │ │ └── lpips.py │ │ ├── model │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── model.py │ │ ├── util.py │ │ └── vqperceptual.py │ └── regularizers │ │ ├── __init__.py │ │ ├── base.py │ │ └── quantize.py ├── diffusionmodules │ ├── __init__.py │ ├── denoiser.py │ ├── denoiser_scaling.py │ ├── denoiser_weighting.py │ ├── discretizer.py │ ├── guiders.py │ ├── loss.py │ ├── loss_weighting.py │ ├── model.py │ ├── openaimodel.py │ ├── sampling.py │ ├── sampling_utils.py │ ├── sigma_sampling.py │ ├── util.py │ └── wrappers.py ├── distributions │ ├── __init__.py │ ├── distributions.py │ └── distributions1.py ├── ema.py ├── encoders │ ├── __init__.py │ └── modules.py ├── nerfsd_pytorch3d.py └── utils_cameraray.py └── util.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | # Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 4 | 5 | ### Using Creative Commons Public Licenses 6 | 7 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 8 | 9 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 10 | 11 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. 
If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 12 | 13 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 14 | 15 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 16 | 17 | ### Section 1 – Definitions. 18 | 19 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 20 | 21 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 22 | 23 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 24 | 25 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 26 | 27 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 28 | 29 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 30 | 31 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. 
The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 32 | 33 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 34 | 35 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 36 | 37 | j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 38 | 39 | k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 40 | 41 | l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 42 | 43 | m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 44 | 45 | n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 46 | 47 | ### Section 2 – Scope. 48 | 49 | a. ___License grant.___ 50 | 51 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 52 | 53 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 54 | 55 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 56 | 57 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 58 | 59 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 60 | 61 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 62 | 63 | 5. __Downstream recipients.__ 64 | 65 | A. 
__Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 66 | 67 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 68 | 69 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 70 | 71 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 72 | 73 | b. ___Other rights.___ 74 | 75 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 76 | 77 | 2. Patent and trademark rights are not licensed under this Public License. 78 | 79 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 80 | 81 | ### Section 3 – License Conditions. 82 | 83 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 84 | 85 | a. ___Attribution.___ 86 | 87 | 1. If You Share the Licensed Material (including in modified form), You must: 88 | 89 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 90 | 91 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 92 | 93 | ii. a copyright notice; 94 | 95 | iii. a notice that refers to this Public License; 96 | 97 | iv. a notice that refers to the disclaimer of warranties; 98 | 99 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 100 | 101 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 102 | 103 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 104 | 105 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 106 | 107 | 3. 
If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 108 | 109 | b. ___ShareAlike.___ 110 | 111 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 112 | 113 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 114 | 115 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 116 | 117 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 118 | 119 | ### Section 4 – Sui Generis Database Rights. 120 | 121 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 122 | 123 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 124 | 125 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 126 | 127 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 128 | 129 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 130 | 131 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 132 | 133 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 134 | 135 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 136 | 137 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 
138 | 139 | ### Section 6 – Term and Termination. 140 | 141 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 142 | 143 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 144 | 145 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 146 | 147 | 2. upon express reinstatement by the Licensor. 148 | 149 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 150 | 151 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 152 | 153 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 154 | 155 | ### Section 7 – Other Terms and Conditions. 156 | 157 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 158 | 159 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 160 | 161 | ### Section 8 – Interpretation. 162 | 163 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 164 | 165 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 166 | 167 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 168 | 169 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 170 | 171 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. 
For the avoidance of doubt, this paragraph does not form part of the public licenses. 172 | > 173 | > Creative Commons may be contacted at creativecommons.org -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Custom Diffusion 360 2 | 3 | ### [website](https://customdiffusion360.github.io) | [paper](http://arxiv.org/abs/2404.12333) 4 | 5 | 6 | https://github.com/customdiffusion360/custom-diffusion360/assets/167265500/67b30422-5b82-4ee2-95a0-26c2e74154f8 7 | 8 | 9 | [Custom Diffusion 360](https://customdiffusion360.github.io) allows you to control the new custom object's viewpoint in generated images by text-to-image diffusion models, such as [Stable Diffusion](https://github.com/Stability-AI/generative-models). Given a 360-degree multiview dataset (~50 images), we fine-tune FeatureNeRF blocks in the intermediate feature space of the diffusion model to condition the generation on a target camera pose. 10 | 11 | **Customizing Text-to-Image Diffusion with Object Viewpoint Control**
12 | [Nupur Kumari](https://nupurkmr9.github.io/)*, [Grace Su](https://graceduansu.github.io/)*, [Richard Zhang](https://richzhang.github.io/), [Taesung Park](https://taesung.me/), [Eli Shechtman](https://research.adobe.com/person/eli-shechtman/), [Jun-Yan Zhu](https://www.cs.cmu.edu/~junyanz/)
13 |
14 | 15 | 16 | ## Results 17 | 18 | All of our results are based on the [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model. 19 | We customize the model on various categories of multiview images, e.g., car, teddybear, chair, toy, motorcycle. For more generations and comparisons with baselines, please refer to our [webpage](https://customdiffusion360.github.io). 20 | 21 | ### Comparison to baselines 22 | 23 |
24 |

25 | 26 |

27 |
28 | 29 | ### Generations with different target camera pose 30 | 31 | 32 |
33 |

34 | 35 |

36 |
37 | 38 | 39 | 40 | ## Method Details 41 | 42 | 43 |
44 |

45 | 46 |

47 |
48 | 49 | 50 | Given multi-view images of an object with their camera poses, our method customizes a text-to-image diffusion model on that concept, with the target camera pose as an additional condition. We modify a subset of transformer layers to be pose-conditioned. This is done by adding a new FeatureNeRF block in the intermediate feature space of each such transformer layer. We fine-tune the new weights on the multiview dataset while keeping the pre-trained model weights frozen. Similar to previous model customization methods, we add a new modifier token V* in front of the category name, e.g., V* car. 51 | 52 | 53 | ## Getting Started 54 | 55 | ``` 56 | git clone https://github.com/customdiffusion360/custom-diffusion360.git 57 | cd custom-diffusion360 58 | conda create -n pose python=3.8 59 | conda activate pose 60 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu118 61 | pip install -r requirements.txt 62 | ``` 63 | We also use `pytorch3d` in our code. Please follow the installation instructions [here](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md), or install it from source with the steps below: 64 | 65 | ``` 66 | conda install -c conda-forge cudatoolkit-dev -y 67 | export CUDA_HOME=$CONDA_PREFIX/pkgs/cuda-toolkit/ 68 | pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable" 69 | ``` 70 | 71 | Download the Stable Diffusion XL base and VAE checkpoints: 72 | ``` 73 | mkdir pretrained-models 74 | cd pretrained-models 75 | wget https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors 76 | wget https://huggingface.co/stabilityai/sdxl-vae/resolve/main/sdxl_vae.safetensors 77 | ``` 78 | 79 | ## Inference with provided models 80 | 81 | **Download pretrained models:** 82 | ``` 83 | gdown 1LM3Yc7gYXuNmFwr0s1Z-fnH0Ik8ttY8k -O pretrained-models/car0.tar 84 | tar -xvf pretrained-models/car0.tar -C pretrained-models/ 85 | ``` 86 | We provide all customized models [here](https://drive.google.com/drive/folders/17EfDutJzme_4JG-KWzxp6LYNfbqh4LmW?usp=sharing). 87 | 88 | **Sample images:** 89 | 90 | ``` 91 | python sample.py --custom_model_dir pretrained-models/car0 --output_dir outputs --prompt "a car beside a field of blooming sunflowers." 92 | ``` 93 | 94 | ## Training 95 | 96 | **Dataset:** 97 | 98 | We share the 14 concepts (part of [CO3Dv2](https://github.com/facebookresearch/co3d) and [NAVI](https://navidataset.github.io)) that we used in our paper for easy experimentation. The datasets are redistributed under the same licenses as the original works. 99 | 100 | ``` 101 | gdown 1GRnkm4xp89bnYAPnp01UMVlCbmdR7SeG 102 | tar -xvzf data.tar.gz 103 | ``` 104 | 105 | **Train:** 106 | 107 | ``` 108 | python main.py --base configs/train_co3d_concept.yaml --name car0 --resume_from_checkpoint_custom pretrained-models/sd_xl_base_1.0.safetensors --no_date --set_from_main --data_category car --data_single_id 0 109 | ``` 110 | 111 | **Your own multi-view images + COLMAP:** 112 | to be released soon. 113 | 114 | ## Evaluation: to be released 115 | 116 | 117 | ## Referenced GitHub repos 118 | Thanks to the following for releasing their code. Our code builds upon these.
119 | 120 | **[Stable Diffusion-XL](https://github.com/Stability-AI/generative-models)** 121 | **[Relpose-plus-plus](https://github.com/amyxlase/relpose-plus-plus/tree/main)** 122 | **[GBT](https://github.com/mayankgrwl97/gbt)** 123 | 124 | 125 | ## Bibliography 126 | 127 | ``` 128 | @inproceedings{kumari2024customdiffusion360, 129 | title={Customizing Text-to-Image Diffusion with Object Viewpoint Control}, 130 | author={Kumari, Nupur and Su, Grace and Zhang, Richard and Park, Taesung and Shechtman, Eli and Zhu, Jun-Yan}, 131 | booktitle = {SIGGRAPH Asia}, 132 | year = {2024} 133 | } 134 | ``` 135 | 136 | ## Acknowledgments 137 | We are thankful to Kangle Deng, Sheng-Yu Wang, and Gaurav Parmar for their helpful comments and discussion and to Sean Liu, Ruihan Gao, Yufei Ye, and Bharath Raj for proofreading the draft. This work was partly done by Nupur Kumari during the Adobe internship. The work is partly supported by Adobe Research, the Packard Fellowship, the Amazon Faculty Research Award, and NSF IIS-2239076. Grace Su is supported by the NSF Graduate Research Fellowship (Grant No. DGE2140739). 138 | -------------------------------------------------------------------------------- /assets/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/assets/method.jpg -------------------------------------------------------------------------------- /assets/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/assets/result.jpg -------------------------------------------------------------------------------- /assets/results_vary_target_pose.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/assets/results_vary_target_pose.jpg -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/assets/teaser.jpg -------------------------------------------------------------------------------- /configs/train_co3d_concept.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-4 3 | target: sgm.models.diffusion.DiffusionEngine 4 | params: 5 | scale_factor: 0.13025 6 | disable_first_stage_autocast: True 7 | trainkeys: pose 8 | multiplier: 0.05 9 | loss_rgb_lambda: 5 10 | loss_fg_lambda: 10 11 | loss_bg_lambda: 10 12 | log_keys: 13 | - txt 14 | 15 | denoiser_config: 16 | target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser 17 | params: 18 | num_idx: 1000 19 | 20 | weighting_config: 21 | target: sgm.modules.diffusionmodules.denoiser_weighting.EpsWeighting 22 | scaling_config: 23 | target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling 24 | discretization_config: 25 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 26 | 27 | network_config: 28 | target: sgm.modules.diffusionmodules.openaimodel.UNetModel 29 | params: 30 | adm_in_channels: 2816 31 | num_classes: sequential 32 | use_checkpoint: False 33 | in_channels: 4 34 | 
out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [4, 2] 37 | num_res_blocks: 2 38 | channel_mult: [1, 2, 4] 39 | num_head_channels: 64 40 | use_linear_in_transformer: True 41 | transformer_depth: [1, 2, 10] 42 | context_dim: 2048 43 | spatial_transformer_attn_type: softmax-xformers 44 | image_cross_blocks: [0, 2, 4, 6, 8, 10] 45 | rgb: True 46 | far: 2 47 | num_samples: 24 48 | not_add_context_in_triplane: False 49 | rgb_predict: True 50 | add_lora: False 51 | average: False 52 | use_prev_weights_imp_sample: True 53 | stratified: True 54 | imp_sampling_percent: 0.9 55 | 56 | conditioner_config: 57 | target: sgm.modules.GeneralConditioner 58 | params: 59 | emb_models: 60 | # crossattn cond 61 | - is_trainable: False 62 | input_keys: txt,txt_ref 63 | target: sgm.modules.encoders.modules.FrozenCLIPEmbedder 64 | params: 65 | layer: hidden 66 | layer_idx: 11 67 | modifier_token: 68 | # crossattn and vector cond 69 | - is_trainable: False 70 | input_keys: txt,txt_ref 71 | target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder 72 | params: 73 | arch: ViT-bigG-14 74 | version: laion2b_s39b_b160k 75 | layer: penultimate 76 | always_return_pooled: True 77 | legacy: False 78 | modifier_token: 79 | # vector cond 80 | - is_trainable: False 81 | input_keys: original_size_as_tuple,original_size_as_tuple_ref 82 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 83 | params: 84 | outdim: 256 # multiplied by two 85 | # vector cond 86 | - is_trainable: False 87 | input_keys: crop_coords_top_left,crop_coords_top_left_ref 88 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 89 | params: 90 | outdim: 256 # multiplied by two 91 | # vector cond 92 | - is_trainable: False 93 | input_keys: target_size_as_tuple,target_size_as_tuple_ref 94 | target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND 95 | params: 96 | outdim: 256 # multiplied by two 97 | 98 | first_stage_config: 99 | target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper 100 | params: 101 | ckpt_path: pretrained-models/sdxl_vae.safetensors 102 | embed_dim: 4 103 | monitor: val/rec_loss 104 | ddconfig: 105 | attn_type: vanilla-xformers 106 | double_z: true 107 | z_channels: 4 108 | resolution: 256 109 | in_channels: 3 110 | out_ch: 3 111 | ch: 128 112 | ch_mult: [1, 2, 4, 4] 113 | num_res_blocks: 2 114 | attn_resolutions: [] 115 | dropout: 0.0 116 | lossconfig: 117 | target: torch.nn.Identity 118 | 119 | loss_fn_config: 120 | target: sgm.modules.diffusionmodules.loss.StandardDiffusionLossImgRef 121 | params: 122 | sigma_sampler_config: 123 | target: sgm.modules.diffusionmodules.sigma_sampling.CubicSampling 124 | params: 125 | num_idx: 1000 126 | discretization_config: 127 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 128 | sigma_sampler_config_ref: 129 | target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling 130 | params: 131 | num_idx: 50 132 | 133 | discretization_config: 134 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 135 | 136 | sampler_config: 137 | target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler 138 | params: 139 | num_steps: 50 140 | 141 | discretization_config: 142 | target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization 143 | 144 | guider_config: 145 | target: sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef 146 | params: 147 | scale: 7.5 148 | 149 | data: 150 | target: sgm.data.data_co3d.CustomDataDictLoader 151 | params: 152 | batch_size: 1 153 | num_workers: 4 154 | 
category: teddybear 155 | img_size: 512 156 | skip: 2 157 | num_images: 5 158 | mask_images: True 159 | single_id: 0 160 | bbox: True 161 | addreg: True 162 | drop_ratio: 0.25 163 | drop_txt: 0.1 164 | modifier_token: 165 | 166 | lightning: 167 | modelcheckpoint: 168 | params: 169 | every_n_train_steps: 1600 170 | save_top_k: -1 171 | save_on_train_epoch_end: False 172 | 173 | callbacks: 174 | metrics_over_trainsteps_checkpoint: 175 | params: 176 | every_n_train_steps: 25000 177 | 178 | image_logger: 179 | target: main.ImageLogger 180 | params: 181 | disabled: False 182 | enable_autocast: False 183 | batch_frequency: 5000 184 | max_images: 8 185 | increase_log_steps: False 186 | log_first_step: False 187 | log_images_kwargs: 188 | use_ema_scope: False 189 | N: 1 190 | n_rows: 2 191 | 192 | trainer: 193 | devices: 0,1,2,3 194 | benchmark: True 195 | num_sanity_val_steps: 0 196 | accumulate_grad_batches: 1 197 | max_steps: 1610 198 | # val_check_interval: 400 199 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf 2 | einops 3 | fire 4 | tqdm 5 | pillow 6 | numpy 7 | webdataset>=0.2.33 8 | ninja 9 | matplotlib 10 | torchmetrics 11 | opencv-python==4.6.0.66 12 | fairscale 13 | pytorch-lightning==2.0.1 14 | fire 15 | fsspec 16 | kornia==0.6.9 17 | natsort 18 | open-clip-torch 19 | chardet==5.1.0 20 | tensorboardx==2.6 21 | pandas 22 | pudb 23 | pyyaml 24 | urllib3<1.27,>=1.25.4 25 | scipy 26 | streamlit>=0.73.1 27 | timm 28 | tokenizers==0.12.1 29 | transformers==4.19.1 30 | triton==2.1.0 31 | torchdata==0.7.0 32 | wandb 33 | invisible-watermark 34 | xformers 35 | gdown 36 | -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 37 | -e git+https://github.com/openai/CLIP.git@main#egg=clip 38 | -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import glob 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | import numpy as np 9 | import torch 10 | from einops import rearrange 11 | from omegaconf import OmegaConf 12 | from PIL import Image 13 | from pytorch3d.renderer.camera_utils import join_cameras_as_batch 14 | from pytorch_lightning import seed_everything 15 | 16 | sys.path.append('./') 17 | from sgm.modules.utils_cameraray import ( 18 | interpolate_translate_interpolate_xaxis, 19 | interpolate_translate_interpolate_yaxis, 20 | interpolate_translate_interpolate_zaxis, 21 | interpolatefocal, 22 | ) 23 | from sgm.util import load_model_from_config 24 | 25 | choices = [] 26 | 27 | 28 | def get_unique_embedder_keys_from_conditioner(conditioner): 29 | p = [x.input_keys for x in conditioner.embedders] 30 | return list(set([item for sublist in p for item in sublist])) + ['jpg_ref'] 31 | 32 | 33 | def customforward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None, prev_weights=None, timesteps=None): 34 | if not isinstance(context, list): 35 | context = [context] 36 | b, c, h, w = x.shape 37 | x_in = x 38 | fg_masks = [] 39 | alphas = [] 40 | rgbs = [] 41 | 42 | x = self.norm(x) 43 | 44 | if not self.use_linear: 45 | x = self.proj_in(x) 46 | 47 | x = rearrange(x, "b c h w -> b (h w) c").contiguous() 48 | if self.use_linear: 49 | x = self.proj_in(x) 50 | 51 | 
prev_weights = None 52 | counter = 0 53 | for i, block in enumerate(self.transformer_blocks): 54 | if i > 0 and len(context) == 1: 55 | i = 0 # use same context for each block 56 | if self.image_cross and (counter % self.poscontrol_interval == 0): 57 | x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=x, pose=pose, mask_ref=mask_ref, prev_weights=prev_weights) 58 | prev_weights = weights 59 | fg_masks.append(fg_mask) 60 | if alpha is not None: 61 | alphas.append(alpha) 62 | if rgb is not None: 63 | rgbs.append(rgb) 64 | else: 65 | x, _, _, _, _ = block(x, context=context[i]) 66 | counter += 1 67 | if self.use_linear: 68 | x = self.proj_out(x) 69 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() 70 | if not self.use_linear: 71 | x = self.proj_out(x) 72 | if len(fg_masks) > 0: 73 | if len(rgbs) <= 0: 74 | rgbs = None 75 | if len(alphas) <= 0: 76 | alphas = None 77 | return x + x_in, None, fg_masks, prev_weights, alphas, rgbs 78 | else: 79 | return x + x_in, None, None, prev_weights, None, None 80 | 81 | 82 | def _customforward( 83 | self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 84 | ): 85 | if context_ref is not None: 86 | global choices 87 | batch_size = x.size(0) 88 | # IP2P like sampling or default sampling 89 | if batch_size % 3 == 0: 90 | batch_size = batch_size // 3 91 | context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1) 92 | context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref, context_ref], dim=0) 93 | else: 94 | batch_size = batch_size // 2 95 | context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1) 96 | context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1), context_ref], dim=0) 97 | 98 | fg_mask = None 99 | weights = None 100 | alphas = None 101 | predicted_rgb = None 102 | 103 | x = ( 104 | self.attn1( 105 | self.norm1(x), 106 | context=context if self.disable_self_attn else None, 107 | additional_tokens=additional_tokens, 108 | n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self 109 | if not self.disable_self_attn 110 | else 0, 111 | ) 112 | + x 113 | ) 114 | 115 | x = ( 116 | self.attn2( 117 | self.norm2(x), context=context, additional_tokens=additional_tokens, 118 | ) 119 | + x 120 | ) 121 | 122 | if context_ref is not None: 123 | if self.rendered_feat is not None: 124 | x = self.pose_emb_layers(torch.cat([x, self.rendered_feat], dim=-1)) 125 | else: 126 | xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x, 127 | context_ref, 128 | context, 129 | pose, 130 | prev_weights, 131 | mask_ref) 132 | self.rendered_feat = xref 133 | x = self.pose_emb_layers(torch.cat([x, xref], -1)) 134 | 135 | x = self.ff(self.norm3(x)) + x 136 | return x, fg_mask, weights, alphas, predicted_rgb 137 | 138 | 139 | def log_images( 140 | model, 141 | batch, 142 | N: int = 1, 143 | noise=None, 144 | scale_im=3.5, 145 | num_steps: int = 10, 146 | ucg_keys: List[str] = None, 147 | **kwargs, 148 | ): 149 | 150 | log = dict() 151 | conditioner_input_keys = [e.input_keys for e in model.conditioner.embedders] 152 | ucg_keys = conditioner_input_keys 153 | pose = batch['pose'] 154 | 155 | c, uc = model.conditioner.get_unconditional_conditioning( 156 | batch, 157 | force_uc_zero_embeddings=ucg_keys 158 | if 
len(model.conditioner.embedders) > 0 159 | else [], 160 | force_ref_zero_embeddings=True 161 | ) 162 | 163 | _, n = 1, len(pose)-1 164 | sampling_kwargs = {} 165 | 166 | if scale_im > 0: 167 | if uc is not None: 168 | if isinstance(pose, list): 169 | pose = pose[:N]*3 170 | else: 171 | pose = torch.cat([pose[:N]] * 3) 172 | else: 173 | if uc is not None: 174 | if isinstance(pose, list): 175 | pose = pose[:N]*2 176 | else: 177 | pose = torch.cat([pose[:N]] * 2) 178 | 179 | sampling_kwargs['pose'] = pose 180 | sampling_kwargs['drop_im'] = None 181 | sampling_kwargs['mask_ref'] = None 182 | 183 | for k in c: 184 | if isinstance(c[k], torch.Tensor): 185 | c[k], uc[k] = map(lambda y: y[k][:(n+1)*N].to('cuda'), (c, uc)) 186 | 187 | import time 188 | st = time.time() 189 | with model.ema_scope("Plotting"): 190 | samples = model.sample( 191 | c, shape=noise.shape[1:], uc=uc, batch_size=N, num_steps=num_steps, noise=noise, **sampling_kwargs 192 | ) 193 | model.clear_rendered_feat() 194 | samples = model.decode_first_stage(samples) 195 | print("Time taken for sampling", time.time() - st) 196 | log["samples"] = samples.cpu() 197 | 198 | return log 199 | 200 | 201 | def sample(config, 202 | ckpt=None, 203 | delta_ckpt=None, 204 | camera_path=None, 205 | num_images=6, 206 | prompt_list=None, 207 | scale=7.5, 208 | num_ref=8, 209 | num_steps=50, 210 | output_dir='', 211 | scale_im=None, 212 | max_images=20, 213 | seed=30, 214 | specific_id='', 215 | interp_start=-0.2, 216 | interp_end=0.21, 217 | interp_step=0.4, 218 | translateY=False, 219 | translateZ=False, 220 | translateX=False, 221 | translate_focal=False, 222 | resolution=512, 223 | random_render_path=None, 224 | allround_render=False, 225 | equidistant=False, 226 | ): 227 | 228 | config = OmegaConf.load(config) 229 | 230 | # setup guider config 231 | if scale_im > 0: 232 | guider_config = {'target': 'sgm.modules.diffusionmodules.guiders.ScheduledCFGImgTextRef', 233 | 'params': {'scale': scale, 'scale_im': scale_im} 234 | } 235 | config.model.params.sampler_config.params.guider_config = guider_config 236 | else: 237 | guider_config = {'target': 'sgm.modules.diffusionmodules.guiders.VanillaCFGImgRef', 238 | 'params': {'scale': scale} 239 | } 240 | config.model.params.sampler_config.params.guider_config = guider_config 241 | 242 | # load model 243 | model = load_model_from_config(config, ckpt, delta_ckpt) 244 | model = model.cuda() 245 | 246 | # change forward methods to store rendered features from first step and use the pre-calculated reference features 247 | def register_recr(net_): 248 | if net_.__class__.__name__ == 'SpatialTransformer': 249 | bound_method = customforward.__get__(net_, net_.__class__) 250 | setattr(net_, 'forward', bound_method) 251 | return 252 | elif hasattr(net_, 'children'): 253 | for net__ in net_.children(): 254 | register_recr(net__) 255 | return 256 | 257 | def register_recr2(net_): 258 | if net_.__class__.__name__ == 'BasicTransformerBlock': 259 | bound_method = _customforward.__get__(net_, net_.__class__) 260 | setattr(net_, 'forward', bound_method) 261 | return 262 | elif hasattr(net_, 'children'): 263 | for net__ in net_.children(): 264 | register_recr2(net__) 265 | return 266 | 267 | sub_nets = model.model.diffusion_model.named_children() 268 | for net in sub_nets: 269 | register_recr(net[1]) 270 | register_recr2(net[1]) 271 | 272 | # load cameras 273 | cameras_val, cameras_train = torch.load(camera_path) 274 | global choices 275 | num_ref = 8 276 | max_diff = len(cameras_train)/num_ref 277 | choices = [int(x) for 
x in torch.linspace(0, len(cameras_train) - max_diff, num_ref)] 278 | cameras_train_final = [cameras_train[i] for i in choices] 279 | 280 | # start sampling 281 | model.clear_rendered_feat() 282 | seedeval_counter = seed 283 | counter = 0 284 | for _, prompt in enumerate(prompt_list): 285 | curent_seed = seedeval_counter 286 | seed_everything(curent_seed) 287 | 288 | if translateZ or translateY or translateX or translate_focal: 289 | interp_reps = len(np.arange(interp_start, interp_end, interp_step)) 290 | noise = torch.randn(1, 4, resolution // 8, resolution // 8).to('cuda').repeat(num_images*interp_reps, 1, 1, 1) 291 | else: 292 | noise = torch.randn(1, 4, resolution // 8, resolution // 8).to('cuda').repeat(num_images, 1, 1, 1) 293 | 294 | # random sample camera poses 295 | pose_ids = np.random.choice(len(cameras_val), num_images, replace=False) 296 | print(pose_ids) 297 | pose = [cameras_val[i] for i in pose_ids] 298 | 299 | # prepare batches [if translating then call required functions on the target pose] 300 | batches = [] 301 | for i in range(num_images): 302 | batch = {'pose': [pose[i]] + cameras_train_final, 303 | "original_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2), 304 | "target_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2), 305 | "crop_coords_top_left": torch.tensor([0, 0]).reshape(-1, 2), 306 | "original_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2), 307 | "target_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2), 308 | "crop_coords_top_left_ref": torch.tensor([0, 0]).reshape(-1, 2), 309 | } 310 | if translateZ or translateY or translateX or translate_focal: 311 | cameras = [] 312 | if translateY: 313 | cameras += interpolate_translate_interpolate_yaxis(batch["pose"][0], interp_start, interp_end, interp_step) 314 | elif translateZ: 315 | cameras += interpolate_translate_interpolate_zaxis(batch["pose"][0], interp_start, interp_end, interp_step) 316 | elif translateX: 317 | cameras += interpolate_translate_interpolate_xaxis(batch["pose"][0], interp_start, interp_end, interp_step) 318 | else: 319 | cameras += interpolatefocal(batch["pose"][0], interp_start, interp_end, interp_step) 320 | for j in range(len(cameras)): 321 | batch_ = copy.deepcopy(batch) 322 | batch_["pose"][0] = cameras[j] 323 | batch_["pose"] = [join_cameras_as_batch(batch_["pose"])] 324 | batches.append(batch_) 325 | else: 326 | batch["pose"] = [join_cameras_as_batch(batch["pose"])] 327 | batches.append(batch) 328 | 329 | print(f'len batches: {len(batches)}') 330 | with torch.no_grad(): 331 | for batch in batches: 332 | for key in batch.keys(): 333 | if isinstance(batch[key], torch.Tensor): 334 | batch[key] = batch[key].to('cuda') 335 | elif 'pose' in key: 336 | batch[key] = [x.to('cuda') for x in batch[key]] 337 | else: 338 | pass 339 | 340 | batch["txt"] = [prompt for _ in range(1)] 341 | batch["txt_ref"] = [prompt for _ in range(len(batch["pose"])-1)] 342 | print(batch["txt"]) 343 | 344 | N = 1 345 | log_ = log_images(model, batch, N=N, noise=noise.clone()[:N], num_steps=50, scale_im=scale_im) 346 | im = Image.fromarray((torch.clip(log_["samples"][0].permute(1, 2, 0)*0.5+0.5, 0, 1.).cpu().numpy()*255).astype(np.uint8)) 347 | prompt_ = prompt.replace(' ', '_') 348 | im.save(f'{output_dir}/sample_{counter}_{prompt_}_{seedeval_counter}.png') 349 | counter += 1 350 | torch.cuda.empty_cache() 351 | return 352 | 353 | 354 | def get_parser(): 355 | parser = argparse.ArgumentParser() 356 | parser.add_argument("--ckpt", type=str, 
default='pretrained-models/sd_xl_base_1.0.safetensors') 357 | parser.add_argument("--custom_model_dir", type=str, default=None) 358 | parser.add_argument("--translateY", action="store_true") 359 | parser.add_argument("--translateZ", action="store_true") 360 | parser.add_argument("--translateX", action="store_true") 361 | parser.add_argument("--translate_focal", action="store_true") 362 | parser.add_argument("--num_images", type=int, default=5) 363 | parser.add_argument("--num_steps", type=int, default=50) 364 | parser.add_argument("--seed", type=int, default=30) 365 | parser.add_argument("--num_ref", type=int, default=8) 366 | parser.add_argument("--prompt", type=str, default="") 367 | parser.add_argument("--scale", type=float, default=7.5) 368 | parser.add_argument("--scale_im", type=float, default=3.5) 369 | parser.add_argument("--output_dir", type=str, default='') 370 | parser.add_argument("--interp_start", type=float, default=-0.2) 371 | parser.add_argument("--interp_end", type=float, default=0.21) 372 | parser.add_argument("--interp_step", type=float, default=0.4) 373 | parser.add_argument("--allround_render", action="store_true") 374 | return parser 375 | 376 | 377 | if __name__ == "__main__": 378 | parser = get_parser() 379 | args = parser.parse_args() 380 | seed = args.seed 381 | seed_everything(seed) 382 | 383 | args.delta_ckpt = os.path.join(args.custom_model_dir, 'checkpoints', 'step=000001600.ckpt') 384 | args.config = sorted(glob.glob(os.path.join(args.custom_model_dir, "configs/*.yaml")))[-1] 385 | args.camera_path = os.path.join(args.custom_model_dir, 'camera.bin') 386 | sample(args.config, 387 | ckpt=args.ckpt, 388 | delta_ckpt=args.delta_ckpt, 389 | camera_path=args.camera_path, 390 | num_images=args.num_images, 391 | prompt_list=[args.prompt], 392 | scale=args.scale, 393 | num_ref=args.num_ref, 394 | num_steps=args.num_steps, 395 | output_dir=args.output_dir, 396 | scale_im=args.scale_im, 397 | seed=args.seed, 398 | interp_start=args.interp_start, 399 | interp_end=args.interp_end, 400 | interp_step=args.interp_step, 401 | translateY=args.translateY, 402 | translateZ=args.translateZ, 403 | translateX=args.translateX, 404 | translate_focal=args.translate_focal, 405 | allround_render=args.allround_render, 406 | ) 407 | -------------------------------------------------------------------------------- /sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine, DiffusionEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /sgm/data/__init__.py: -------------------------------------------------------------------------------- 1 | # from .dataset import StableDataModuleFromConfig 2 | -------------------------------------------------------------------------------- /sgm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | 9 | def __init__( 10 | self, 11 | warm_up_steps, 12 | lr_min, 13 | lr_max, 14 | lr_start, 15 | max_decay_steps, 16 | verbosity_interval=0, 17 | ): 18 | self.lr_warm_up_steps = warm_up_steps 19 | self.lr_start = lr_start 20 | self.lr_min = lr_min 21 | self.lr_max = lr_max 22 | self.lr_max_decay_steps = max_decay_steps 23 | self.last_lr = 0.0 24 | self.verbosity_interval = 
verbosity_interval 25 | 26 | def schedule(self, n, **kwargs): 27 | if self.verbosity_interval > 0: 28 | if n % self.verbosity_interval == 0: 29 | print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 30 | if n < self.lr_warm_up_steps: 31 | lr = ( 32 | self.lr_max - self.lr_start 33 | ) / self.lr_warm_up_steps * n + self.lr_start 34 | self.last_lr = lr 35 | return lr 36 | else: 37 | t = (n - self.lr_warm_up_steps) / ( 38 | self.lr_max_decay_steps - self.lr_warm_up_steps 39 | ) 40 | t = min(t, 1.0) 41 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 42 | 1 + np.cos(t * np.pi) 43 | ) 44 | self.last_lr = lr 45 | return lr 46 | 47 | def __call__(self, n, **kwargs): 48 | return self.schedule(n, **kwargs) 49 | 50 | 51 | class LambdaWarmUpCosineScheduler2: 52 | """ 53 | supports repeated iterations, configurable via lists 54 | note: use with a base_lr of 1.0. 55 | """ 56 | 57 | def __init__( 58 | self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 59 | ): 60 | assert ( 61 | len(warm_up_steps) 62 | == len(f_min) 63 | == len(f_max) 64 | == len(f_start) 65 | == len(cycle_lengths) 66 | ) 67 | self.lr_warm_up_steps = warm_up_steps 68 | self.f_start = f_start 69 | self.f_min = f_min 70 | self.f_max = f_max 71 | self.cycle_lengths = cycle_lengths 72 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 73 | self.last_f = 0.0 74 | self.verbosity_interval = verbosity_interval 75 | 76 | def find_in_interval(self, n): 77 | interval = 0 78 | for cl in self.cum_cycles[1:]: 79 | if n <= cl: 80 | return interval 81 | interval += 1 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: 88 | print( 89 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 90 | f"current cycle {cycle}" 91 | ) 92 | if n < self.lr_warm_up_steps[cycle]: 93 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 94 | cycle 95 | ] * n + self.f_start[cycle] 96 | self.last_f = f 97 | return f 98 | else: 99 | t = (n - self.lr_warm_up_steps[cycle]) / ( 100 | self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] 101 | ) 102 | t = min(t, 1.0) 103 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 104 | 1 + np.cos(t * np.pi) 105 | ) 106 | self.last_f = f 107 | return f 108 | 109 | def __call__(self, n, **kwargs): 110 | return self.schedule(n, **kwargs) 111 | 112 | 113 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 114 | def schedule(self, n, **kwargs): 115 | cycle = self.find_in_interval(n) 116 | n = n - self.cum_cycles[cycle] 117 | if self.verbosity_interval > 0: 118 | if n % self.verbosity_interval == 0: 119 | print( 120 | f"current step: {n}, recent lr-multiplier: {self.last_f}, " 121 | f"current cycle {cycle}" 122 | ) 123 | 124 | if n < self.lr_warm_up_steps[cycle]: 125 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ 126 | cycle 127 | ] * n + self.f_start[cycle] 128 | self.last_f = f 129 | return f 130 | else: 131 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( 132 | self.cycle_lengths[cycle] - n 133 | ) / (self.cycle_lengths[cycle]) 134 | self.last_f = f 135 | return f 136 | -------------------------------------------------------------------------------- /sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import 
DiffusionEngine 3 | -------------------------------------------------------------------------------- /sgm/models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import re 2 | from abc import abstractmethod 3 | from contextlib import contextmanager 4 | from typing import Any, Dict, Tuple, Union 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | from omegaconf import ListConfig 9 | from packaging import version 10 | from safetensors.torch import load_file as load_safetensors 11 | 12 | from ..modules.diffusionmodules.model import Decoder, Encoder 13 | from ..modules.distributions.distributions import DiagonalGaussianDistribution 14 | from ..modules.ema import LitEma 15 | from ..util import default, get_obj_from_str, instantiate_from_config 16 | 17 | 18 | class AbstractAutoencoder(pl.LightningModule): 19 | """ 20 | This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, 21 | unCLIP models, etc. Hence, it is fairly general, and specific features 22 | (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | ema_decay: Union[None, float] = None, 28 | monitor: Union[None, str] = None, 29 | input_key: str = "jpg", 30 | ckpt_path: Union[None, str] = None, 31 | ignore_keys: Union[Tuple, list, ListConfig] = (), 32 | ): 33 | super().__init__() 34 | self.input_key = input_key 35 | self.use_ema = ema_decay is not None 36 | if monitor is not None: 37 | self.monitor = monitor 38 | 39 | if self.use_ema: 40 | self.model_ema = LitEma(self, decay=ema_decay) 41 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 42 | 43 | if ckpt_path is not None: 44 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 45 | 46 | if version.parse(torch.__version__) >= version.parse("2.0.0"): 47 | self.automatic_optimization = False 48 | 49 | def init_from_ckpt( 50 | self, path: str, ignore_keys: Union[Tuple, list, ListConfig] = tuple() 51 | ) -> None: 52 | if path.endswith("ckpt"): 53 | sd = torch.load(path, map_location="cpu")["state_dict"] 54 | elif path.endswith("safetensors"): 55 | sd = load_safetensors(path) 56 | else: 57 | raise NotImplementedError 58 | 59 | keys = list(sd.keys()) 60 | for k in keys: 61 | for ik in ignore_keys: 62 | if re.match(ik, k): 63 | print("Deleting key {} from state_dict.".format(k)) 64 | del sd[k] 65 | missing, unexpected = self.load_state_dict(sd, strict=False) 66 | print( 67 | f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" 68 | ) 69 | if len(missing) > 0: 70 | print(f"Missing Keys: {missing}") 71 | if len(unexpected) > 0: 72 | print(f"Unexpected Keys: {unexpected}") 73 | 74 | @abstractmethod 75 | def get_input(self, batch) -> Any: 76 | raise NotImplementedError() 77 | 78 | def on_train_batch_end(self, *args, **kwargs): 79 | # for EMA computation 80 | if self.use_ema: 81 | self.model_ema(self) 82 | 83 | @contextmanager 84 | def ema_scope(self, context=None): 85 | if self.use_ema: 86 | self.model_ema.store(self.parameters()) 87 | self.model_ema.copy_to(self) 88 | if context is not None: 89 | print(f"{context}: Switched to EMA weights") 90 | try: 91 | yield None 92 | finally: 93 | if self.use_ema: 94 | self.model_ema.restore(self.parameters()) 95 | if context is not None: 96 | print(f"{context}: Restored training weights") 97 | 98 | @abstractmethod 99 | def encode(self, *args, **kwargs) -> torch.Tensor: 100 | raise 
NotImplementedError("encode()-method of abstract base class called") 101 | 102 | @abstractmethod 103 | def decode(self, *args, **kwargs) -> torch.Tensor: 104 | raise NotImplementedError("decode()-method of abstract base class called") 105 | 106 | def instantiate_optimizer_from_config(self, params, lr, cfg): 107 | print(f"loading >>> {cfg['target']} <<< optimizer from config") 108 | return get_obj_from_str(cfg["target"])( 109 | params, lr=lr, **cfg.get("params", dict()) 110 | ) 111 | 112 | def configure_optimizers(self) -> Any: 113 | raise NotImplementedError() 114 | 115 | 116 | class AutoencodingEngine(AbstractAutoencoder): 117 | """ 118 | Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL 119 | (we also restore them explicitly as special cases for legacy reasons). 120 | Regularizations such as KL or VQ are moved to the regularizer class. 121 | """ 122 | 123 | def __init__( 124 | self, 125 | *args, 126 | encoder_config: Dict, 127 | decoder_config: Dict, 128 | loss_config: Dict, 129 | regularizer_config: Dict, 130 | optimizer_config: Union[Dict, None] = None, 131 | lr_g_factor: float = 1.0, 132 | **kwargs, 133 | ): 134 | super().__init__(*args, **kwargs) 135 | # todo: add options to freeze encoder/decoder 136 | self.encoder = instantiate_from_config(encoder_config) 137 | self.decoder = instantiate_from_config(decoder_config) 138 | self.loss = instantiate_from_config(loss_config) 139 | self.regularization = instantiate_from_config(regularizer_config) 140 | self.optimizer_config = default( 141 | optimizer_config, {"target": "torch.optim.Adam"} 142 | ) 143 | self.lr_g_factor = lr_g_factor 144 | 145 | def get_input(self, batch: Dict) -> torch.Tensor: 146 | # assuming unified data format, dataloader returns a dict. 147 | # image tensors should be scaled to -1 ... 
1 and in channels-first format (e.g., bchw instead if bhwc) 148 | return batch[self.input_key] 149 | 150 | def get_autoencoder_params(self) -> list: 151 | params = ( 152 | list(self.encoder.parameters()) 153 | + list(self.decoder.parameters()) 154 | + list(self.regularization.get_trainable_parameters()) 155 | + list(self.loss.get_trainable_autoencoder_parameters()) 156 | ) 157 | return params 158 | 159 | def get_discriminator_params(self) -> list: 160 | params = list(self.loss.get_trainable_parameters()) # e.g., discriminator 161 | return params 162 | 163 | def get_last_layer(self): 164 | return self.decoder.get_last_layer() 165 | 166 | def encode(self, x: Any, return_reg_log: bool = False) -> Any: 167 | z = self.encoder(x) 168 | z, reg_log = self.regularization(z) 169 | if return_reg_log: 170 | return z, reg_log 171 | return z 172 | 173 | def decode(self, z: Any) -> torch.Tensor: 174 | x = self.decoder(z) 175 | return x 176 | 177 | def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 178 | z, reg_log = self.encode(x, return_reg_log=True) 179 | dec = self.decode(z) 180 | return z, dec, reg_log 181 | 182 | def training_step(self, batch, batch_idx, optimizer_idx) -> Any: 183 | x = self.get_input(batch) 184 | z, xrec, regularization_log = self(x) 185 | 186 | if optimizer_idx == 0: 187 | # autoencode 188 | aeloss, log_dict_ae = self.loss( 189 | regularization_log, 190 | x, 191 | xrec, 192 | optimizer_idx, 193 | self.global_step, 194 | last_layer=self.get_last_layer(), 195 | split="train", 196 | ) 197 | 198 | self.log_dict( 199 | log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True 200 | ) 201 | return aeloss 202 | 203 | if optimizer_idx == 1: 204 | # discriminator 205 | discloss, log_dict_disc = self.loss( 206 | regularization_log, 207 | x, 208 | xrec, 209 | optimizer_idx, 210 | self.global_step, 211 | last_layer=self.get_last_layer(), 212 | split="train", 213 | ) 214 | self.log_dict( 215 | log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True 216 | ) 217 | return discloss 218 | 219 | def validation_step(self, batch, batch_idx) -> Dict: 220 | log_dict = self._validation_step(batch, batch_idx) 221 | with self.ema_scope(): 222 | log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") 223 | log_dict.update(log_dict_ema) 224 | return log_dict 225 | 226 | def _validation_step(self, batch, batch_idx, postfix="") -> Dict: 227 | x = self.get_input(batch) 228 | 229 | z, xrec, regularization_log = self(x) 230 | aeloss, log_dict_ae = self.loss( 231 | regularization_log, 232 | x, 233 | xrec, 234 | 0, 235 | self.global_step, 236 | last_layer=self.get_last_layer(), 237 | split="val" + postfix, 238 | ) 239 | 240 | discloss, log_dict_disc = self.loss( 241 | regularization_log, 242 | x, 243 | xrec, 244 | 1, 245 | self.global_step, 246 | last_layer=self.get_last_layer(), 247 | split="val" + postfix, 248 | ) 249 | self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) 250 | log_dict_ae.update(log_dict_disc) 251 | self.log_dict(log_dict_ae) 252 | return log_dict_ae 253 | 254 | def configure_optimizers(self) -> Any: 255 | ae_params = self.get_autoencoder_params() 256 | disc_params = self.get_discriminator_params() 257 | 258 | opt_ae = self.instantiate_optimizer_from_config( 259 | ae_params, 260 | default(self.lr_g_factor, 1.0) * self.learning_rate, 261 | self.optimizer_config, 262 | ) 263 | opt_disc = self.instantiate_optimizer_from_config( 264 | disc_params, self.learning_rate, self.optimizer_config 265 | ) 266 | 267 | 
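        # Two optimizers and no LR schedulers are returned: optimizer index 0 updates
        # the autoencoder parameters (encoder, decoder, regularizer, and any trainable
        # autoencoder parts of the loss), while index 1 updates the discriminator
        # owned by the loss, matching the optimizer_idx branches in training_step.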
return [opt_ae, opt_disc], [] 268 | 269 | @torch.no_grad() 270 | def log_images(self, batch: Dict, **kwargs) -> Dict: 271 | log = dict() 272 | x = self.get_input(batch) 273 | _, xrec, _ = self(x) 274 | log["inputs"] = x 275 | log["reconstructions"] = xrec 276 | with self.ema_scope(): 277 | _, xrec_ema, _ = self(x) 278 | log["reconstructions_ema"] = xrec_ema 279 | return log 280 | 281 | 282 | class AutoencoderKL(AutoencodingEngine): 283 | def __init__(self, embed_dim: int, **kwargs): 284 | ddconfig = kwargs.pop("ddconfig") 285 | ckpt_path = kwargs.pop("ckpt_path", None) 286 | ignore_keys = kwargs.pop("ignore_keys", ()) 287 | super().__init__( 288 | encoder_config={"target": "torch.nn.Identity"}, 289 | decoder_config={"target": "torch.nn.Identity"}, 290 | regularizer_config={"target": "torch.nn.Identity"}, 291 | loss_config=kwargs.pop("lossconfig"), 292 | **kwargs, 293 | ) 294 | assert ddconfig["double_z"] 295 | self.encoder = Encoder(**ddconfig) 296 | self.decoder = Decoder(**ddconfig) 297 | self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) 298 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 299 | self.embed_dim = embed_dim 300 | 301 | if ckpt_path is not None: 302 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 303 | 304 | def encode(self, x): 305 | assert ( 306 | not self.training 307 | ), f"{self.__class__.__name__} only supports inference currently" 308 | h = self.encoder(x) 309 | moments = self.quant_conv(h) 310 | posterior = DiagonalGaussianDistribution(moments) 311 | return posterior 312 | 313 | def decode(self, z, **decoder_kwargs): 314 | z = self.post_quant_conv(z) 315 | dec = self.decoder(z, **decoder_kwargs) 316 | return dec 317 | 318 | 319 | class AutoencoderKLInferenceWrapper(AutoencoderKL): 320 | def encode(self, x): 321 | return super().encode(x).sample() 322 | 323 | 324 | class IdentityFirstStage(AbstractAutoencoder): 325 | def __init__(self, *args, **kwargs): 326 | super().__init__(*args, **kwargs) 327 | 328 | def get_input(self, x: Any) -> Any: 329 | return x 330 | 331 | def encode(self, x: Any, *args, **kwargs) -> Any: 332 | return x 333 | 334 | def decode(self, x: Any, *args, **kwargs) -> Any: 335 | return x 336 | -------------------------------------------------------------------------------- /sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": "sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") 30 | self.load_state_dict( 31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 32 | ) 33 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 34 | 35 | @classmethod 36 | def from_pretrained(cls, name="vgg_lpips"): 37 | if name != "vgg_lpips": 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict( 42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 43 | ) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 48 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 49 | feats0, feats1, diffs = {}, {}, {} 50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 51 | for kk in range(len(self.chns)): 52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 53 | outs1[kk] 54 | ) 55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 56 | 57 | res = [ 58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 59 | for kk in range(len(self.chns)) 60 | ] 61 | val = res[0] 62 | for l in range(1, len(self.chns)): 63 | val += res[l] 64 | return val 65 | 66 | 67 | class ScalingLayer(nn.Module): 68 | def __init__(self): 69 | super(ScalingLayer, self).__init__() 70 | self.register_buffer( 71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 72 | ) 73 | self.register_buffer( 74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 75 | ) 76 | 77 | def forward(self, inp): 78 | return (inp - self.shift) / self.scale 79 | 80 | 81 | class NetLinLayer(nn.Module): 82 | """A single linear layer which does a 1x1 conv""" 83 | 84 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 85 | super(NetLinLayer, self).__init__() 86 | layers = ( 87 | [ 88 | nn.Dropout(), 89 | ] 90 | if (use_dropout) 91 | else [] 92 | ) 93 | layers += [ 94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 95 | ] 96 | self.model = nn.Sequential(*layers) 97 | 98 | 99 | class vgg16(torch.nn.Module): 100 | def __init__(self, requires_grad=False, pretrained=True): 101 | super(vgg16, self).__init__() 102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 103 | self.slice1 = torch.nn.Sequential() 104 | self.slice2 = torch.nn.Sequential() 105 | self.slice3 = torch.nn.Sequential() 106 | self.slice4 = torch.nn.Sequential() 107 | 
self.slice5 = torch.nn.Sequential() 108 | self.N_slices = 5 109 | for x in range(4): 110 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(4, 9): 112 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(9, 16): 114 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(16, 23): 116 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 117 | for x in range(23, 30): 118 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 119 | if not requires_grad: 120 | for param in self.parameters(): 121 | param.requires_grad = False 122 | 123 | def forward(self, X): 124 | h = self.slice1(X) 125 | h_relu1_2 = h 126 | h = self.slice2(h) 127 | h_relu2_2 = h 128 | h = self.slice3(h) 129 | h_relu3_3 = h 130 | h = self.slice4(h) 131 | h_relu4_3 = h 132 | h = self.slice5(h) 133 | h_relu5_3 = h 134 | vgg_outputs = namedtuple( 135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 136 | ) 137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 138 | return out 139 | 140 | 141 | def normalize_tensor(x, eps=1e-10): 142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 143 | return x / (norm_factor + eps) 144 | 145 | 146 | def spatial_average(x, keepdim=True): 147 | return x.mean([2, 3], keepdim=keepdim) 148 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 
38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
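A minimal usage sketch of the LPIPS perceptual metric defined in sgm/modules/autoencoding/lpips/loss/lpips.py above; it is illustrative only and assumes the sgm package is importable from the repository root, that the pretrained VGG-16 weights and the "vgg_lpips" linear-head checkpoint can be downloaded, and that inputs are float tensors in [-1, 1] with shape (B, 3, H, W), as used elsewhere in this codebase:

import torch

from sgm.modules.autoencoding.lpips.loss.lpips import LPIPS

# Build the metric once; all parameters are frozen in __init__ and the linear
# heads are restored from the downloaded "vgg_lpips" checkpoint.
perceptual = LPIPS().eval()

with torch.no_grad():
    recon = torch.rand(2, 3, 256, 256) * 2 - 1    # stand-in reconstructions in [-1, 1]
    target = torch.rand(2, 3, 256, 256) * 2 - 1   # stand-in targets in [-1, 1]
    dist = perceptual(recon, target)              # shape (B, 1, 1, 1); larger means more dissimilar

print(dist.flatten())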
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return self.main(input) 89 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | 
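# URL_MAP, CKPT_MAP and MD5_MAP together drive get_ckpt_path below: the
# "vgg_lpips" checkpoint is downloaded from URL_MAP into
# os.path.join(root, CKPT_MAP[name]) and its MD5 digest is compared against
# MD5_MAP after any download (or up front when check=True).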
MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__( 47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 48 | ): 49 | assert affine 50 | super().__init__() 51 | self.logdet = logdet 52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 54 | self.allow_reverse_init = allow_reverse_init 55 | 56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 57 | 58 | def initialize(self, input): 59 | with torch.no_grad(): 60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 61 | mean = ( 62 | flatten.mean(1) 63 | .unsqueeze(1) 64 | .unsqueeze(2) 65 | .unsqueeze(3) 66 | .permute(1, 0, 2, 3) 67 | ) 68 | std = ( 69 | flatten.std(1) 70 | .unsqueeze(1) 71 | .unsqueeze(2) 72 | .unsqueeze(3) 73 | .permute(1, 0, 2, 3) 74 | ) 75 | 76 | self.loc.data.copy_(-mean) 77 | self.scale.data.copy_(1 / (std + 1e-6)) 78 | 79 | def forward(self, input, reverse=False): 80 | if reverse: 81 | return self.reverse(input) 82 | if len(input.shape) == 2: 83 | input = input[:, :, None, None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | _, _, height, width = input.shape 89 | 90 | if self.training and self.initialized.item() == 0: 91 | self.initialize(input) 92 | self.initialized.fill_(1) 93 | 94 | h = self.scale * (input + self.loc) 95 | 96 | if squeeze: 97 | h = h.squeeze(-1).squeeze(-1) 98 | 99 | if self.logdet: 100 | log_abs = torch.log(torch.abs(self.scale)) 101 | logdet = height * width * torch.sum(log_abs) 102 | logdet = logdet * torch.ones(input.shape[0]).to(input) 103 | return h, logdet 104 | 105 | return h 106 | 107 | def reverse(self, output): 108 | if self.training and self.initialized.item() == 0: 109 | if not self.allow_reverse_init: 110 | raise RuntimeError( 111 | "Initializing ActNorm in reverse direction is " 112 | "disabled by default. Use allow_reverse_init=True to enable." 
113 | ) 114 | else: 115 | self.initialize(output) 116 | self.initialized.fill_(1) 117 | 118 | if len(output.shape) == 2: 119 | output = output[:, :, None, None] 120 | squeeze = True 121 | else: 122 | squeeze = False 123 | 124 | h = output / self.scale - self.loc 125 | 126 | if squeeze: 127 | h = h.squeeze(-1).squeeze(-1) 128 | return h 129 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) 15 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 16 | ) 17 | return d_loss 18 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import \ 9 | DiagonalGaussianDistribution 10 | from .base import AbstractRegularizer 11 | 12 | 13 | class DiagonalGaussianRegularizer(AbstractRegularizer): 14 | def __init__(self, sample: bool = True): 15 | super().__init__() 16 | self.sample = sample 17 | 18 | def get_trainable_parameters(self) -> Any: 19 | yield from () 20 | 21 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 22 | log = dict() 23 | posterior = DiagonalGaussianDistribution(z) 24 | if self.sample: 25 | z = posterior.sample() 26 | else: 27 | z = posterior.mode() 28 | kl_loss = posterior.kl() 29 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 30 | log["kl_loss"] = kl_loss 31 | return z, log 32 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def get_trainable_parameters(self) -> Any: 18 | raise NotImplementedError() 19 | 20 | 21 | class IdentityRegularizer(AbstractRegularizer): 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | return z, dict() 24 | 25 | def get_trainable_parameters(self) -> Any: 26 | yield from () 27 | 28 | 29 | def measure_perplexity( 30 | predicted_indices: torch.Tensor, num_centroids: int 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 33 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 34 | encodings = ( 35 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 36 | ) 37 | avg_probs = encodings.mean(0) 38 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 39 | cluster_use = torch.sum(avg_probs > 0) 40 | return perplexity, cluster_use 41 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/quantize.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import abstractmethod 3 | from typing import Dict, Iterator, Literal, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from einops import rearrange 10 | from torch import einsum 11 | 12 | from .base import AbstractRegularizer, measure_perplexity 13 | 14 | logpy = logging.getLogger(__name__) 15 | 16 | 17 | class AbstractQuantizer(AbstractRegularizer): 18 | def __init__(self): 19 | super().__init__() 20 | # Define these in your init 21 | # shape (N,) 22 | self.used: Optional[torch.Tensor] 23 | self.re_embed: int 24 | self.unknown_index: Union[Literal["random"], int] 25 | 26 | def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor: 27 | assert self.used is not None, "You need to define used indices for remap" 28 | ishape = inds.shape 29 | assert len(ishape) > 1 30 | inds = inds.reshape(ishape[0], -1) 31 | used = self.used.to(inds) 32 | match = (inds[:, :, None] == used[None, None, ...]).long() 33 | new = match.argmax(-1) 34 | unknown = match.sum(2) < 1 35 | if self.unknown_index == "random": 36 | new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( 37 | device=new.device 38 | ) 39 | else: 40 | new[unknown] = self.unknown_index 41 | return new.reshape(ishape) 42 | 43 | def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor: 44 | assert self.used is not None, "You need to define used indices for remap" 45 | ishape = inds.shape 46 | assert len(ishape) > 1 47 | inds = inds.reshape(ishape[0], -1) 48 | used = self.used.to(inds) 49 | if self.re_embed > self.used.shape[0]: # extra token 50 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 51 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 52 | return back.reshape(ishape) 53 | 54 | @abstractmethod 55 | def get_codebook_entry( 56 | self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None 57 | ) -> torch.Tensor: 58 | raise NotImplementedError() 59 | 60 | def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: 61 | yield from self.parameters() 62 | 63 | 64 | class GumbelQuantizer(AbstractQuantizer): 65 | """ 66 | credit to @karpathy: 67 | https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) 68 | Gumbel Softmax trick quantizer 69 | Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 70 | https://arxiv.org/abs/1611.01144 71 | """ 72 | 73 | def __init__( 74 | self, 75 | num_hiddens: int, 76 | embedding_dim: int, 77 | n_embed: int, 78 | straight_through: bool = True, 79 | kl_weight: float = 5e-4, 80 | temp_init: float = 1.0, 81 | remap: Optional[str] = None, 82 | unknown_index: str = "random", 83 | loss_key: str = "loss/vq", 84 | ) -> None: 85 | super().__init__() 86 | 87 | self.loss_key = loss_key 88 | self.embedding_dim = embedding_dim 89 | self.n_embed = n_embed 90 | 91 | self.straight_through = straight_through 92 | self.temperature = temp_init 93 | self.kl_weight = kl_weight 94 | 95 | self.proj = nn.Conv2d(num_hiddens, n_embed, 1) 96 | self.embed = nn.Embedding(n_embed, embedding_dim) 97 | 98 | self.remap = remap 99 | if self.remap is not None: 100 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 101 | self.re_embed = self.used.shape[0] 102 | else: 103 | self.used = None 104 | self.re_embed = n_embed 105 | if unknown_index == "extra": 106 | self.unknown_index = self.re_embed 107 | self.re_embed = self.re_embed + 1 108 | else: 109 | assert unknown_index == "random" or isinstance( 110 | unknown_index, int 111 | ), "unknown index needs to be 'random', 'extra' or any integer" 112 | self.unknown_index = unknown_index # "random" or "extra" or integer 113 | if self.remap is not None: 114 | logpy.info( 115 | f"Remapping {self.n_embed} indices to {self.re_embed} indices. " 116 | f"Using {self.unknown_index} for unknown indices." 117 | ) 118 | 119 | def forward( 120 | self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False 121 | ) -> Tuple[torch.Tensor, Dict]: 122 | # force hard = True when we are in eval mode, as we must quantize. 123 | # actually, always true seems to work 124 | hard = self.straight_through if self.training else True 125 | temp = self.temperature if temp is None else temp 126 | out_dict = {} 127 | logits = self.proj(z) 128 | if self.remap is not None: 129 | # continue only with used logits 130 | full_zeros = torch.zeros_like(logits) 131 | logits = logits[:, self.used, ...] 132 | 133 | soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) 134 | if self.remap is not None: 135 | # go back to all entries but unused set to zero 136 | full_zeros[:, self.used, ...] 
= soft_one_hot 137 | soft_one_hot = full_zeros 138 | z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) 139 | 140 | # + kl divergence to the prior loss 141 | qy = F.softmax(logits, dim=1) 142 | diff = ( 143 | self.kl_weight 144 | * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() 145 | ) 146 | out_dict[self.loss_key] = diff 147 | 148 | ind = soft_one_hot.argmax(dim=1) 149 | out_dict["indices"] = ind 150 | if self.remap is not None: 151 | ind = self.remap_to_used(ind) 152 | 153 | if return_logits: 154 | out_dict["logits"] = logits 155 | 156 | return z_q, out_dict 157 | 158 | def get_codebook_entry(self, indices, shape): 159 | # TODO: shape not yet optional 160 | b, h, w, c = shape 161 | assert b * h * w == indices.shape[0] 162 | indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) 163 | if self.remap is not None: 164 | indices = self.unmap_to_all(indices) 165 | one_hot = ( 166 | F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() 167 | ) 168 | z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) 169 | return z_q 170 | 171 | 172 | class VectorQuantizer(AbstractQuantizer): 173 | """ 174 | ____________________________________________ 175 | Discretization bottleneck part of the VQ-VAE. 176 | Inputs: 177 | - n_e : number of embeddings 178 | - e_dim : dimension of embedding 179 | - beta : commitment cost used in loss term, 180 | beta * ||z_e(x)-sg[e]||^2 181 | _____________________________________________ 182 | """ 183 | 184 | def __init__( 185 | self, 186 | n_e: int, 187 | e_dim: int, 188 | beta: float = 0.25, 189 | remap: Optional[str] = None, 190 | unknown_index: str = "random", 191 | sane_index_shape: bool = False, 192 | log_perplexity: bool = False, 193 | embedding_weight_norm: bool = False, 194 | loss_key: str = "loss/vq", 195 | ): 196 | super().__init__() 197 | self.n_e = n_e 198 | self.e_dim = e_dim 199 | self.beta = beta 200 | self.loss_key = loss_key 201 | 202 | if not embedding_weight_norm: 203 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 204 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 205 | else: 206 | self.embedding = torch.nn.utils.weight_norm( 207 | nn.Embedding(self.n_e, self.e_dim), dim=1 208 | ) 209 | 210 | self.remap = remap 211 | if self.remap is not None: 212 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 213 | self.re_embed = self.used.shape[0] 214 | else: 215 | self.used = None 216 | self.re_embed = n_e 217 | if unknown_index == "extra": 218 | self.unknown_index = self.re_embed 219 | self.re_embed = self.re_embed + 1 220 | else: 221 | assert unknown_index == "random" or isinstance( 222 | unknown_index, int 223 | ), "unknown index needs to be 'random', 'extra' or any integer" 224 | self.unknown_index = unknown_index # "random" or "extra" or integer 225 | if self.remap is not None: 226 | logpy.info( 227 | f"Remapping {self.n_e} indices to {self.re_embed} indices. " 228 | f"Using {self.unknown_index} for unknown indices." 
229 | ) 230 | 231 | self.sane_index_shape = sane_index_shape 232 | self.log_perplexity = log_perplexity 233 | 234 | def forward( 235 | self, 236 | z: torch.Tensor, 237 | ) -> Tuple[torch.Tensor, Dict]: 238 | do_reshape = z.ndim == 4 239 | if do_reshape: 240 | # # reshape z -> (batch, height, width, channel) and flatten 241 | z = rearrange(z, "b c h w -> b h w c").contiguous() 242 | 243 | else: 244 | assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined" 245 | z = z.contiguous() 246 | 247 | z_flattened = z.view(-1, self.e_dim) 248 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 249 | 250 | d = ( 251 | torch.sum(z_flattened**2, dim=1, keepdim=True) 252 | + torch.sum(self.embedding.weight**2, dim=1) 253 | - 2 254 | * torch.einsum( 255 | "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") 256 | ) 257 | ) 258 | 259 | min_encoding_indices = torch.argmin(d, dim=1) 260 | z_q = self.embedding(min_encoding_indices).view(z.shape) 261 | loss_dict = {} 262 | if self.log_perplexity: 263 | perplexity, cluster_usage = measure_perplexity( 264 | min_encoding_indices.detach(), self.n_e 265 | ) 266 | loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage}) 267 | 268 | # compute loss for embedding 269 | loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( 270 | (z_q - z.detach()) ** 2 271 | ) 272 | loss_dict[self.loss_key] = loss 273 | 274 | # preserve gradients 275 | z_q = z + (z_q - z).detach() 276 | 277 | # reshape back to match original input shape 278 | if do_reshape: 279 | z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() 280 | 281 | if self.remap is not None: 282 | min_encoding_indices = min_encoding_indices.reshape( 283 | z.shape[0], -1 284 | ) # add batch axis 285 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 286 | min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten 287 | 288 | if self.sane_index_shape: 289 | if do_reshape: 290 | min_encoding_indices = min_encoding_indices.reshape( 291 | z_q.shape[0], z_q.shape[2], z_q.shape[3] 292 | ) 293 | else: 294 | min_encoding_indices = rearrange( 295 | min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0] 296 | ) 297 | 298 | loss_dict["min_encoding_indices"] = min_encoding_indices 299 | 300 | return z_q, loss_dict 301 | 302 | def get_codebook_entry( 303 | self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None 304 | ) -> torch.Tensor: 305 | # shape specifying (batch, height, width, channel) 306 | if self.remap is not None: 307 | assert shape is not None, "Need to give shape for remap" 308 | indices = indices.reshape(shape[0], -1) # add batch axis 309 | indices = self.unmap_to_all(indices) 310 | indices = indices.reshape(-1) # flatten again 311 | 312 | # get quantized latent vectors 313 | z_q = self.embedding(indices) 314 | 315 | if shape is not None: 316 | z_q = z_q.view(shape) 317 | # reshape back to match original input shape 318 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 319 | 320 | return z_q 321 | 322 | 323 | class EmbeddingEMA(nn.Module): 324 | def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): 325 | super().__init__() 326 | self.decay = decay 327 | self.eps = eps 328 | weight = torch.randn(num_tokens, codebook_dim) 329 | self.weight = nn.Parameter(weight, requires_grad=False) 330 | self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) 331 | self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) 332 | self.update = True 333 | 334 | def 
forward(self, embed_id): 335 | return F.embedding(embed_id, self.weight) 336 | 337 | def cluster_size_ema_update(self, new_cluster_size): 338 | self.cluster_size.data.mul_(self.decay).add_( 339 | new_cluster_size, alpha=1 - self.decay 340 | ) 341 | 342 | def embed_avg_ema_update(self, new_embed_avg): 343 | self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) 344 | 345 | def weight_update(self, num_tokens): 346 | n = self.cluster_size.sum() 347 | smoothed_cluster_size = ( 348 | (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n 349 | ) 350 | # normalize embedding average with smoothed cluster size 351 | embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) 352 | self.weight.data.copy_(embed_normalized) 353 | 354 | 355 | class EMAVectorQuantizer(AbstractQuantizer): 356 | def __init__( 357 | self, 358 | n_embed: int, 359 | embedding_dim: int, 360 | beta: float, 361 | decay: float = 0.99, 362 | eps: float = 1e-5, 363 | remap: Optional[str] = None, 364 | unknown_index: str = "random", 365 | loss_key: str = "loss/vq", 366 | ): 367 | super().__init__() 368 | self.codebook_dim = embedding_dim 369 | self.num_tokens = n_embed 370 | self.beta = beta 371 | self.loss_key = loss_key 372 | 373 | self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) 374 | 375 | self.remap = remap 376 | if self.remap is not None: 377 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 378 | self.re_embed = self.used.shape[0] 379 | else: 380 | self.used = None 381 | self.re_embed = n_embed 382 | if unknown_index == "extra": 383 | self.unknown_index = self.re_embed 384 | self.re_embed = self.re_embed + 1 385 | else: 386 | assert unknown_index == "random" or isinstance( 387 | unknown_index, int 388 | ), "unknown index needs to be 'random', 'extra' or any integer" 389 | self.unknown_index = unknown_index # "random" or "extra" or integer 390 | if self.remap is not None: 391 | logpy.info( 392 | f"Remapping {self.n_embed} indices to {self.re_embed} indices. " 393 | f"Using {self.unknown_index} for unknown indices." 
394 | ) 395 | 396 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: 397 | # reshape z -> (batch, height, width, channel) and flatten 398 | # z, 'b c h w -> b h w c' 399 | z = rearrange(z, "b c h w -> b h w c") 400 | z_flattened = z.reshape(-1, self.codebook_dim) 401 | 402 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 403 | d = ( 404 | z_flattened.pow(2).sum(dim=1, keepdim=True) 405 | + self.embedding.weight.pow(2).sum(dim=1) 406 | - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) 407 | ) # 'n d -> d n' 408 | 409 | encoding_indices = torch.argmin(d, dim=1) 410 | 411 | z_q = self.embedding(encoding_indices).view(z.shape) 412 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 413 | avg_probs = torch.mean(encodings, dim=0) 414 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 415 | 416 | if self.training and self.embedding.update: 417 | # EMA cluster size 418 | encodings_sum = encodings.sum(0) 419 | self.embedding.cluster_size_ema_update(encodings_sum) 420 | # EMA embedding average 421 | embed_sum = encodings.transpose(0, 1) @ z_flattened 422 | self.embedding.embed_avg_ema_update(embed_sum) 423 | # normalize embed_avg and update weight 424 | self.embedding.weight_update(self.num_tokens) 425 | 426 | # compute loss for embedding 427 | loss = self.beta * F.mse_loss(z_q.detach(), z) 428 | 429 | # preserve gradients 430 | z_q = z + (z_q - z).detach() 431 | 432 | # reshape back to match original input shape 433 | # z_q, 'b h w c -> b c h w' 434 | z_q = rearrange(z_q, "b h w c -> b c h w") 435 | 436 | out_dict = { 437 | self.loss_key: loss, 438 | "encodings": encodings, 439 | "encoding_indices": encoding_indices, 440 | "perplexity": perplexity, 441 | } 442 | 443 | return z_q, out_dict 444 | 445 | 446 | class VectorQuantizerWithInputProjection(VectorQuantizer): 447 | def __init__( 448 | self, 449 | input_dim: int, 450 | n_codes: int, 451 | codebook_dim: int, 452 | beta: float = 1.0, 453 | output_dim: Optional[int] = None, 454 | **kwargs, 455 | ): 456 | super().__init__(n_codes, codebook_dim, beta, **kwargs) 457 | self.proj_in = nn.Linear(input_dim, codebook_dim) 458 | self.output_dim = output_dim 459 | if output_dim is not None: 460 | self.proj_out = nn.Linear(codebook_dim, output_dim) 461 | else: 462 | self.proj_out = nn.Identity() 463 | 464 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: 465 | rearr = False 466 | in_shape = z.shape 467 | 468 | if z.ndim > 3: 469 | rearr = self.output_dim is not None 470 | z = rearrange(z, "b c ... -> b (...) c") 471 | z = self.proj_in(z) 472 | z_q, loss_dict = super().forward(z) 473 | 474 | z_q = self.proj_out(z_q) 475 | if rearr: 476 | if len(in_shape) == 4: 477 | z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1]) 478 | elif len(in_shape) == 5: 479 | z_q = rearrange( 480 | z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2] 481 | ) 482 | else: 483 | raise NotImplementedError( 484 | f"rearranging not available for {len(in_shape)}-dimensional input." 
485 | ) 486 | 487 | return z_q, loss_dict 488 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/sgm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from ...util import append_dims, instantiate_from_config 4 | 5 | 6 | class Denoiser(nn.Module): 7 | def __init__(self, weighting_config, scaling_config): 8 | super().__init__() 9 | 10 | self.weighting = instantiate_from_config(weighting_config) 11 | self.scaling = instantiate_from_config(scaling_config) 12 | 13 | def possibly_quantize_sigma(self, sigma): 14 | return sigma 15 | 16 | def possibly_quantize_c_noise(self, c_noise): 17 | return c_noise 18 | 19 | def w(self, sigma): 20 | return self.weighting(sigma) 21 | 22 | def __call__(self, network, input, sigma, cond, sigmas_ref=None, **kwargs): 23 | sigma = self.possibly_quantize_sigma(sigma) 24 | sigma_shape = sigma.shape 25 | sigma = append_dims(sigma, input.ndim) 26 | if sigmas_ref is not None: 27 | if kwargs is not None: 28 | kwargs['sigmas_ref'] = sigmas_ref 29 | else: 30 | kwargs = {'sigmas_ref': sigmas_ref} 31 | 32 | if kwargs['input_ref'] is not None: 33 | noise = torch.randn_like(kwargs['input_ref']) 34 | kwargs['input_ref'] = kwargs['input_ref'] + noise * append_dims(sigmas_ref, kwargs['input_ref'].ndim) 35 | 36 | if 'input_ref' in kwargs and kwargs['input_ref'] is not None and 'sigmas_ref' in kwargs: 37 | _, _, c_in_ref, c_noise_ref = self.scaling(append_dims(kwargs['sigmas_ref'], kwargs['input_ref'].ndim)) 38 | kwargs['input_ref'] = kwargs['input_ref']*c_in_ref 39 | kwargs['sigmas_ref'] = self.possibly_quantize_c_noise(kwargs['sigmas_ref']) 40 | 41 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 42 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 43 | predict, fg_mask_list, alphas_list, rgb_list = network(input * c_in, c_noise, cond, **kwargs) 44 | return predict * c_out + input * c_skip, fg_mask_list, alphas_list, rgb_list 45 | 46 | 47 | class DiscreteDenoiser(Denoiser): 48 | def __init__( 49 | self, 50 | weighting_config, 51 | scaling_config, 52 | num_idx, 53 | discretization_config, 54 | do_append_zero=False, 55 | quantize_c_noise=True, 56 | flip=True, 57 | ): 58 | super().__init__(weighting_config, scaling_config) 59 | sigmas = instantiate_from_config(discretization_config)( 60 | num_idx, do_append_zero=do_append_zero, flip=flip 61 | ) 62 | self.register_buffer("sigmas", sigmas) 63 | self.quantize_c_noise = quantize_c_noise 64 | 65 | def sigma_to_idx(self, sigma): 66 | dists = sigma - self.sigmas[:, None] 67 | return dists.abs().argmin(dim=0).view(sigma.shape) 68 | 69 | def idx_to_sigma(self, idx): 70 | return self.sigmas[idx] 71 | 72 | def possibly_quantize_sigma(self, sigma): 73 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 74 | 75 | def possibly_quantize_c_noise(self, c_noise): 76 | if self.quantize_c_noise: 77 | return self.sigma_to_idx(c_noise) 78 | else: 79 | return c_noise 80 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_scaling.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | from typing import Tuple 4 | 5 | 6 | class DenoiserScaling(ABC): 7 | @abstractmethod 8 | def __call__( 9 | self, sigma: torch.Tensor 10 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 11 | pass 12 | 13 | 14 | class EDMScaling: 15 | def __init__(self, sigma_data=0.5): 16 | self.sigma_data = sigma_data 17 | 18 | def __call__(self, sigma): 19 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 20 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 21 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 22 | c_noise = 0.25 * sigma.log() 23 | return c_skip, c_out, c_in, c_noise 24 | 25 | 26 | class EpsScaling: 27 | def __call__(self, sigma): 28 | c_skip = torch.ones_like(sigma, device=sigma.device) 29 | c_out = -sigma 30 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 31 | c_noise = sigma.clone() 32 | return c_skip, c_out, c_in, c_noise 33 | 34 | 35 | class VScaling: 36 | def __call__(self, sigma): 37 | c_skip = 1.0 / (sigma**2 + 1.0) 38 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 39 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 40 | c_noise = sigma.clone() 41 | return c_skip, c_out, c_in, c_noise 42 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps( 12 | num_substeps: int, max_step: int 13 | ) -> np.ndarray: 14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 19 | sigmas = self.get_sigmas(n, device=device) 20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 21 | return sigmas if not flip else torch.flip(sigmas, (0,)) 22 | 23 | @abstractmethod 24 | def get_sigmas(self, n, device): 25 | pass 26 | 27 | 28 | class EDMDiscretization(Discretization): 29 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 30 | self.sigma_min = sigma_min 31 | self.sigma_max = sigma_max 32 | self.rho = rho 33 | 34 | def get_sigmas(self, n, device="cpu"): 35 | ramp = torch.linspace(0, 1, n, device=device) 36 | min_inv_rho = self.sigma_min ** (1 / self.rho) 37 | max_inv_rho = self.sigma_max ** (1 / self.rho) 38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 
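        # This is the Karras et al. (2022) rho-schedule: interpolate linearly in
        # sigma**(1/rho) space from sigma_max (ramp = 0) down to sigma_min (ramp = 1),
        # then raise back to the rho power; larger rho concentrates more of the n
        # steps at small noise levels.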
39 | return sigmas 40 | 41 | 42 | class LegacyDDPMDiscretization(Discretization): 43 | def __init__( 44 | self, 45 | linear_start=0.00085, 46 | linear_end=0.0120, 47 | num_timesteps=1000, 48 | ): 49 | super().__init__() 50 | self.num_timesteps = num_timesteps 51 | betas = make_beta_schedule( 52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 53 | ) 54 | alphas = 1.0 - betas 55 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 56 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 57 | 58 | def get_sigmas(self, n, device="cpu"): 59 | if n < self.num_timesteps: 60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 61 | alphas_cumprod = self.alphas_cumprod[timesteps] 62 | elif n == self.num_timesteps: 63 | alphas_cumprod = self.alphas_cumprod 64 | else: 65 | raise ValueError 66 | 67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 69 | return torch.flip(sigmas, (0,)) 70 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Dict, List, Optional, Tuple, Union 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | from ...util import append_dims, default 9 | 10 | logpy = logging.getLogger(__name__) 11 | 12 | 13 | class Guider(ABC): 14 | @abstractmethod 15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 16 | pass 17 | 18 | def prepare_inputs( 19 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 20 | ) -> Tuple[torch.Tensor, float, Dict]: 21 | pass 22 | 23 | 24 | class VanillaCFG(Guider): 25 | def __init__(self, scale: float): 26 | self.scale = scale 27 | 28 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 29 | x_u, x_c = x.chunk(2) 30 | x_pred = x_u + self.scale * (x_c - x_u) 31 | return x_pred 32 | 33 | def prepare_inputs(self, x, s, c, uc): 34 | c_out = dict() 35 | 36 | for k in c: 37 | if k in ["vector", "crossattn", "concat"]: 38 | c_out[k] = torch.cat((uc[k], c[k]), 0) 39 | else: 40 | assert c[k] == uc[k] 41 | c_out[k] = c[k] 42 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 43 | 44 | 45 | class IdentityGuider(Guider): 46 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 47 | return x 48 | 49 | def prepare_inputs( 50 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 51 | ) -> Tuple[torch.Tensor, float, Dict]: 52 | c_out = dict() 53 | 54 | for k in c: 55 | c_out[k] = c[k] 56 | 57 | return x, s, c_out 58 | 59 | 60 | class LinearPredictionGuider(Guider): 61 | def __init__( 62 | self, 63 | max_scale: float, 64 | num_frames: int, 65 | min_scale: float = 1.0, 66 | additional_cond_keys: Optional[Union[List[str], str]] = None, 67 | ): 68 | self.min_scale = min_scale 69 | self.max_scale = max_scale 70 | self.num_frames = num_frames 71 | self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) 72 | 73 | additional_cond_keys = default(additional_cond_keys, []) 74 | if isinstance(additional_cond_keys, str): 75 | additional_cond_keys = [additional_cond_keys] 76 | self.additional_cond_keys = additional_cond_keys 77 | 78 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 79 | x_u, x_c = x.chunk(2) 80 | 81 | x_u = rearrange(x_u, "(b t) ... 
-> b t ...", t=self.num_frames) 82 | x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) 83 | scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) 84 | scale = append_dims(scale, x_u.ndim).to(x_u.device) 85 | 86 | return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") 87 | 88 | def prepare_inputs( 89 | self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict 90 | ) -> Tuple[torch.Tensor, torch.Tensor, dict]: 91 | c_out = dict() 92 | 93 | for k in c: 94 | if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: 95 | c_out[k] = torch.cat((uc[k], c[k]), 0) 96 | else: 97 | assert c[k] == uc[k] 98 | c_out[k] = c[k] 99 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 100 | 101 | 102 | class ScheduledCFGImgTextRef(Guider): 103 | """ 104 | From InstructPix2Pix 105 | """ 106 | 107 | def __init__(self, scale: float, scale_im: float): 108 | self.scale = scale 109 | self.scale_im = scale_im 110 | 111 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 112 | x_u, x_ic, x_c = x.chunk(3) 113 | x_pred = x_u + self.scale * (x_c - x_ic) + self.scale_im*(x_ic - x_u) 114 | return x_pred 115 | 116 | def prepare_inputs(self, x, s, c, uc): 117 | c_out = dict() 118 | 119 | for k in c: 120 | if k in ["vector", "crossattn", "concat"]: 121 | b = uc[k].shape[0] 122 | if k == "crossattn": 123 | uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)]) 124 | c1, c2 = c[k].split([x.size(0), b - x.size(0)]) 125 | c_out[k] = torch.cat((uc1, uc1, c1, uc2, c2, c2), 0) 126 | else: 127 | uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)]) 128 | c1, c2 = c[k].split([x.size(0), b - x.size(0)]) 129 | c_out[k] = torch.cat((uc1, uc1, c1, uc2, c2, c2), 0) 130 | else: 131 | assert c[k] == uc[k] 132 | c_out[k] = c[k] 133 | return torch.cat([x] * 3), torch.cat([s] * 3), c_out 134 | 135 | 136 | class VanillaCFGImgRef(Guider): 137 | """ 138 | implements parallelized CFG 139 | """ 140 | 141 | def __init__(self, scale: float): 142 | self.scale = scale 143 | 144 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 145 | x_u, x_c = x.chunk(2) 146 | x_pred = x_u + self.scale * (x_c - x_u) 147 | return x_pred 148 | 149 | def prepare_inputs(self, x, s, c, uc): 150 | c_out = dict() 151 | 152 | for k in c: 153 | if k in ["vector", "crossattn", "concat"]: 154 | b = uc[k].shape[0] 155 | if k == "crossattn": 156 | uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)]) 157 | c1, c2 = c[k].split([x.size(0), b - x.size(0)]) 158 | c_out[k] = torch.cat((uc1, c1, uc2, c2), 0) 159 | else: 160 | uc1, uc2 = uc[k].split([x.size(0), b - x.size(0)]) 161 | c1, c2 = c[k].split([x.size(0), b - x.size(0)]) 162 | c_out[k] = torch.cat((uc1, c1, uc2, c2), 0) 163 | else: 164 | assert c[k] == uc[k] 165 | c_out[k] = c[k] 166 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 167 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 7 | from ...modules.encoders.modules import GeneralConditioner 8 | from ...util import append_dims, instantiate_from_config 9 | from .denoiser import Denoiser 10 | 11 | 12 | class StandardDiffusionLoss(nn.Module): 13 | def __init__( 14 | self, 15 | sigma_sampler_config: dict, 16 | loss_weighting_config: dict, 17 | loss_type: str 
= "l2", 18 | offset_noise_level: float = 0.0, 19 | batch2model_keys: Optional[Union[str, List[str]]] = None, 20 | ): 21 | super().__init__() 22 | 23 | assert loss_type in ["l2", "l1", "lpips"] 24 | 25 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 26 | self.loss_weighting = instantiate_from_config(loss_weighting_config) 27 | 28 | self.loss_type = loss_type 29 | self.offset_noise_level = offset_noise_level 30 | 31 | if loss_type == "lpips": 32 | self.lpips = LPIPS().eval() 33 | 34 | if not batch2model_keys: 35 | batch2model_keys = [] 36 | 37 | if isinstance(batch2model_keys, str): 38 | batch2model_keys = [batch2model_keys] 39 | 40 | self.batch2model_keys = set(batch2model_keys) 41 | 42 | def get_noised_input( 43 | self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor 44 | ) -> torch.Tensor: 45 | noised_input = input + noise * sigmas_bc 46 | return noised_input 47 | 48 | def forward( 49 | self, 50 | network: nn.Module, 51 | denoiser: Denoiser, 52 | conditioner: GeneralConditioner, 53 | input: torch.Tensor, 54 | batch: Dict, 55 | ) -> torch.Tensor: 56 | cond = conditioner(batch) 57 | return self._forward(network, denoiser, cond, input, batch) 58 | 59 | def _forward( 60 | self, 61 | network: nn.Module, 62 | denoiser: Denoiser, 63 | cond: Dict, 64 | input: torch.Tensor, 65 | batch: Dict, 66 | ) -> Tuple[torch.Tensor, Dict]: 67 | additional_model_inputs = { 68 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 69 | } 70 | sigmas = self.sigma_sampler(input.shape[0]).to(input) 71 | 72 | noise = torch.randn_like(input) 73 | if self.offset_noise_level > 0.0: 74 | offset_shape = ( 75 | (input.shape[0], 1, input.shape[2]) 76 | if self.n_frames is not None 77 | else (input.shape[0], input.shape[1]) 78 | ) 79 | noise = noise + self.offset_noise_level * append_dims( 80 | torch.randn(offset_shape, device=input.device), 81 | input.ndim, 82 | ) 83 | sigmas_bc = append_dims(sigmas, input.ndim) 84 | noised_input = self.get_noised_input(sigmas_bc, noise, input) 85 | 86 | model_output = denoiser( 87 | network, noised_input, sigmas, cond, **additional_model_inputs 88 | ) 89 | w = append_dims(self.loss_weighting(sigmas), input.ndim) 90 | return self.get_loss(model_output, input, w) 91 | 92 | def get_loss(self, model_output, target, w): 93 | if self.loss_type == "l2": 94 | return torch.mean( 95 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 96 | ) 97 | elif self.loss_type == "l1": 98 | return torch.mean( 99 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 100 | ) 101 | elif self.loss_type == "lpips": 102 | loss = self.lpips(model_output, target).reshape(-1) 103 | return loss 104 | else: 105 | raise NotImplementedError(f"Unknown loss type {self.loss_type}") 106 | 107 | 108 | class StandardDiffusionLossImgRef(nn.Module): 109 | def __init__( 110 | self, 111 | sigma_sampler_config: dict, 112 | sigma_sampler_config_ref: dict, 113 | type: str = "l2", 114 | offset_noise_level: float = 0.0, 115 | batch2model_keys: Optional[Union[str, List[str]]] = None, 116 | ): 117 | super().__init__() 118 | 119 | assert type in ["l2", "l1", "lpips"] 120 | 121 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 122 | self.sigma_sampler_ref = None 123 | if sigma_sampler_config_ref is not None: 124 | self.sigma_sampler_ref = instantiate_from_config(sigma_sampler_config_ref) 125 | 126 | self.type = type 127 | self.offset_noise_level = offset_noise_level 128 | 129 | if type == "lpips": 130 | self.lpips = LPIPS().eval() 131 
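        # Note: unlike StandardDiffusionLoss above, this variant can also noise the reference
        # images with their own schedule (sigma_sampler_ref), so the reference conditioning
        # may be partially noised during training. Illustrative config sketch only — the
        # target class exists in this repo, but the values are assumptions, not the
        # training defaults:
        #   sigma_sampler_config_ref:
        #     target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
        #     params: {num_idx: 1000, discretization_config: {...}}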
| 132 | if not batch2model_keys: 133 | batch2model_keys = [] 134 | 135 | if isinstance(batch2model_keys, str): 136 | batch2model_keys = [batch2model_keys] 137 | 138 | self.batch2model_keys = set(batch2model_keys) 139 | 140 | def __call__(self, network, denoiser, conditioner, input, input_rgb, input_ref, pose, mask, mask_ref, opacity, batch): 141 | cond = conditioner(batch) 142 | additional_model_inputs = { 143 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 144 | } 145 | 146 | sigmas = self.sigma_sampler(input.shape[0]).to(input.device) 147 | noise = torch.randn_like(input) 148 | if self.offset_noise_level > 0.0: 149 | noise = noise + self.offset_noise_level * append_dims( 150 | torch.randn(input.shape[0], device=input.device), input.ndim 151 | ) 152 | 153 | additional_model_inputs['pose'] = pose 154 | additional_model_inputs['mask_ref'] = mask_ref 155 | 156 | noised_input = input + noise * append_dims(sigmas, input.ndim) 157 | if self.sigma_sampler_ref is not None: 158 | sigmas_ref = self.sigma_sampler_ref(input.shape[0]).to(input.device) 159 | if input_ref is not None: 160 | noise = torch.randn_like(input_ref) 161 | if self.offset_noise_level > 0.0: 162 | noise = noise + self.offset_noise_level * append_dims( 163 | torch.randn(input_ref.shape[0], device=input_ref.device), input_ref.ndim 164 | ) 165 | input_ref = input_ref + noise * append_dims(sigmas_ref, input_ref.ndim) 166 | additional_model_inputs['sigmas_ref'] = sigmas_ref 167 | 168 | additional_model_inputs['input_ref'] = input_ref 169 | 170 | model_output, fg_mask_list, alphas, predicted_rgb_list = denoiser( 171 | network, noised_input, sigmas, cond, **additional_model_inputs 172 | ) 173 | 174 | w = append_dims(denoiser.w(sigmas), input.ndim) 175 | return self.get_loss(model_output, fg_mask_list, predicted_rgb_list, input, input_rgb, w, mask, mask_ref, opacity, alphas) 176 | 177 | def get_loss(self, model_output, fg_mask_list, predicted_rgb_list, target, target_rgb, w, mask, mask_ref, opacity, alphas_list): 178 | loss_rgb = [] 179 | loss_fg = [] 180 | loss_bg = [] 181 | with torch.amp.autocast(device_type='cuda', dtype=torch.float32): 182 | if self.type == "l2": 183 | loss = (w * (model_output - target) ** 2) 184 | if mask is not None: 185 | loss_l2 = (loss*mask).sum([1, 2, 3])/(mask.sum([1, 2, 3]) + 1e-6) 186 | else: 187 | loss_l2 = torch.mean(loss.reshape(target.shape[0], -1), 1) 188 | if len(fg_mask_list) > 0 and len(alphas_list) > 0: 189 | for fg_mask, alphas in zip(fg_mask_list, alphas_list): 190 | size = int(math.sqrt(fg_mask.size(1))) 191 | opacity = torch.nn.functional.interpolate(opacity, size=size, antialias=True, mode='bilinear').detach() 192 | fg_mask = torch.clamp(fg_mask.reshape(-1, size*size), 0., 1.) 
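                        # The two auxiliary terms below supervise the intermediate rendered masks:
                        # loss_fg_ pulls the predicted foreground mask toward the (downsampled)
                        # object opacity, while loss_bg_ penalizes per-sample alphas wherever the
                        # opacity map marks background (opacity < 0.1), pushing those samples
                        # toward empty space.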
193 | loss_fg_ = ((fg_mask - opacity.reshape(-1, size*size))**2).mean(1) #torch.nn.functional.binary_cross_entropy(rgb, torch.clip(mask.reshape(-1, size*size), 0., 1.), reduce=False) 194 | loss_bg_ = (alphas - opacity.reshape(-1, size*size, 1, 1)).abs()*(1-opacity.reshape(-1, size*size, 1, 1)) #alpahs : b hw d 1 195 | loss_bg_ = (loss_bg_*((opacity.reshape(-1, size*size, 1, 1) < 0.1)*1)).mean([1, 2, 3]) 196 | loss_fg.append(loss_fg_) 197 | loss_bg.append(loss_bg_) 198 | loss_fg = torch.stack(loss_fg, 1) 199 | loss_bg = torch.stack(loss_bg, 1) 200 | 201 | if len(predicted_rgb_list) > 0: 202 | for rgb in predicted_rgb_list: 203 | size = int(math.sqrt(rgb.size(1))) 204 | mask_ = torch.nn.functional.interpolate(mask, size=size, antialias=True, mode='bilinear').detach() 205 | loss_rgb_ = ((torch.nn.functional.interpolate(target_rgb*0.5+0.5, size=size, antialias=True, mode='bilinear').detach() - rgb.reshape(-1, size, size, 3).permute(0, 3, 1, 2)) ** 2) 206 | loss_rgb.append((loss_rgb_*mask_).sum([1, 2, 3])/(mask.sum([1, 2, 3]) + 1e-6)) 207 | loss_rgb = torch.stack(loss_rgb, 1) 208 | # print(loss_l2, loss_fg, loss_bg, loss_rgb) 209 | return loss_l2, loss_fg, loss_bg, loss_rgb 210 | elif self.type == "l1": 211 | return torch.mean( 212 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 213 | ), loss_rgb 214 | elif self.type == "lpips": 215 | loss = self.lpips(model_output, target).reshape(-1) 216 | return loss, loss_rgb 217 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss_weighting.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class DiffusionLossWeighting(ABC): 7 | @abstractmethod 8 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 9 | pass 10 | 11 | 12 | class UnitWeighting(DiffusionLossWeighting): 13 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 14 | return torch.ones_like(sigma, device=sigma.device) 15 | 16 | 17 | class EDMWeighting(DiffusionLossWeighting): 18 | def __init__(self, sigma_data: float = 0.5): 19 | self.sigma_data = sigma_data 20 | 21 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 22 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 23 | 24 | 25 | class VWeighting(EDMWeighting): 26 | def __init__(self): 27 | super().__init__(sigma_data=1.0) 28 | 29 | 30 | class EpsWeighting(DiffusionLossWeighting): 31 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 32 | return sigma**-2.0 33 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py 3 | """ 4 | 5 | 6 | from typing import Dict, Union 7 | 8 | import torch 9 | from omegaconf import ListConfig, OmegaConf 10 | from tqdm import tqdm 11 | 12 | from ...modules.diffusionmodules.sampling_utils import ( 13 | get_ancestral_step, 14 | linear_multistep_coeff, 15 | to_d, 16 | to_neg_log_sigma, 17 | to_sigma, 18 | ) 19 | from ...util import append_dims, default, instantiate_from_config 20 | 21 | DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} 22 | 23 | 24 | class BaseDiffusionSampler: 25 | def __init__( 26 | self, 27 | discretization_config: Union[Dict, ListConfig, OmegaConf], 28 | num_steps: 
Union[int, None] = None, 29 | guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, 30 | verbose: bool = False, 31 | device: str = "cuda", 32 | ): 33 | self.num_steps = num_steps 34 | self.discretization = instantiate_from_config(discretization_config) 35 | self.guider = instantiate_from_config( 36 | default( 37 | guider_config, 38 | DEFAULT_GUIDER, 39 | ) 40 | ) 41 | self.verbose = verbose 42 | self.device = device 43 | 44 | def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): 45 | sigmas = self.discretization( 46 | self.num_steps if num_steps is None else num_steps, device=self.device 47 | ) 48 | uc = default(uc, cond) 49 | 50 | x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) 51 | num_sigmas = len(sigmas) 52 | 53 | s_in = x.new_ones([x.shape[0]]) 54 | 55 | return x, s_in, sigmas, num_sigmas, cond, uc 56 | 57 | def denoise(self, x, denoiser, sigma, cond, uc): 58 | denoised, _, _, rgb_list = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) 59 | denoised = self.guider(denoised, sigma) 60 | return denoised, rgb_list 61 | 62 | def get_sigma_gen(self, num_sigmas): 63 | sigma_generator = range(num_sigmas - 1) 64 | if self.verbose: 65 | print("#" * 30, " Sampling setting ", "#" * 30) 66 | print(f"Sampler: {self.__class__.__name__}") 67 | print(f"Discretization: {self.discretization.__class__.__name__}") 68 | print(f"Guider: {self.guider.__class__.__name__}") 69 | sigma_generator = tqdm( 70 | sigma_generator, 71 | total=num_sigmas, 72 | desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", 73 | ) 74 | return sigma_generator 75 | 76 | 77 | class SingleStepDiffusionSampler(BaseDiffusionSampler): 78 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): 79 | raise NotImplementedError 80 | 81 | def euler_step(self, x, d, dt): 82 | return x + dt * d 83 | 84 | 85 | class EDMSampler(SingleStepDiffusionSampler): 86 | def __init__( 87 | self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs 88 | ): 89 | super().__init__(*args, **kwargs) 90 | 91 | self.s_churn = s_churn 92 | self.s_tmin = s_tmin 93 | self.s_tmax = s_tmax 94 | self.s_noise = s_noise 95 | 96 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): 97 | sigma_hat = sigma * (gamma + 1.0) 98 | if gamma > 0: 99 | eps = torch.randn_like(x) * self.s_noise 100 | x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 101 | 102 | denoised, rgb_list = self.denoise(x, denoiser, sigma_hat, cond, uc) 103 | d = to_d(x, sigma_hat, denoised) 104 | dt = append_dims(next_sigma - sigma_hat, x.ndim) 105 | 106 | euler_step = self.euler_step(x, d, dt) 107 | x = self.possible_correction_step( 108 | euler_step, x, d, dt, next_sigma, denoiser, cond, uc 109 | ) 110 | return x, rgb_list 111 | 112 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, mask=None, init_im=None): 113 | return self.forward(denoiser, x, cond, uc=uc, num_steps=num_steps, mask=mask, init_im=init_im) 114 | 115 | def forward(self, denoiser, x, cond, uc=None, num_steps=None, mask=None, init_im=None): 116 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 117 | x, cond, uc, num_steps 118 | ) 119 | for i in self.get_sigma_gen(num_sigmas): 120 | gamma = ( 121 | min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) 122 | if self.s_tmin <= sigmas[i] <= self.s_tmax 123 | else 0.0 124 | ) 125 | x_new, rgb_list = self.sampler_step( 126 | s_in * sigmas[i], 127 | s_in * sigmas[i + 1], 128 | denoiser, 129 | x, 130 | cond, 131 | uc, 132 | gamma, 133 
| ) 134 | x = x_new 135 | 136 | return x, rgb_list 137 | 138 | 139 | def get_views(panorama_height, panorama_width, window_size=64, stride=48): 140 | # panorama_height /= 8 141 | # panorama_width /= 8 142 | num_blocks_height = (panorama_height - window_size) // stride + 1 143 | num_blocks_width = (panorama_width - window_size) // stride + 1 144 | total_num_blocks = int(num_blocks_height * num_blocks_width) 145 | views = [] 146 | for i in range(total_num_blocks): 147 | h_start = int((i // num_blocks_width) * stride) 148 | h_end = h_start + window_size 149 | w_start = int((i % num_blocks_width) * stride) 150 | w_end = w_start + window_size 151 | views.append((h_start, h_end, w_start, w_end)) 152 | return views 153 | 154 | 155 | class EDMMultidiffusionSampler(SingleStepDiffusionSampler): 156 | def __init__( 157 | self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs 158 | ): 159 | super().__init__(*args, **kwargs) 160 | 161 | self.s_churn = s_churn 162 | self.s_tmin = s_tmin 163 | self.s_tmax = s_tmax 164 | self.s_noise = s_noise 165 | 166 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): 167 | sigma_hat = sigma * (gamma + 1.0) 168 | if gamma > 0: 169 | eps = torch.randn_like(x) * self.s_noise 170 | x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 171 | 172 | denoised, rgb_list = self.denoise(x, denoiser, sigma_hat, cond, uc) 173 | d = to_d(x, sigma_hat, denoised) 174 | dt = append_dims(next_sigma - sigma_hat, x.ndim) 175 | 176 | euler_step = self.euler_step(x, d, dt) 177 | x = self.possible_correction_step( 178 | euler_step, x, d, dt, next_sigma, denoiser, cond, uc 179 | ) 180 | return x, rgb_list 181 | 182 | def __call__(self, denoiser, model, x, cond, uc=None, num_steps=None, multikwargs=None): 183 | return self.forward(denoiser, model, x, cond, uc=uc, num_steps=num_steps, multikwargs=multikwargs) 184 | 185 | def forward(self, denoiser, model, x, cond, uc=None, num_steps=None, multikwargs=None): 186 | views = get_views(x.shape[-2], 48*(len(multikwargs)+1)) 187 | shape = x.shape 188 | x = torch.randn(shape[0], shape[1], shape[2], 48*(len(multikwargs)+1)).to(x.device) 189 | count = torch.zeros_like(x, device=x.device) 190 | value = torch.zeros_like(x, device=x.device) 191 | 192 | x, s_in, sigmas, num_sigmas, cond_, uc = self.prepare_sampling_loop( 193 | x, cond[0], uc[0], num_steps 194 | ) 195 | 196 | for i in self.get_sigma_gen(num_sigmas): 197 | gamma = ( 198 | min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) 199 | if self.s_tmin <= sigmas[i] <= self.s_tmax 200 | else 0.0 201 | ) 202 | count.zero_() 203 | value.zero_() 204 | 205 | for j, (h_start, h_end, w_start, w_end) in enumerate(views): 206 | # TODO we can support batches, and pass multiple views at once to the unet 207 | latent_view = x[:, :, h_start:h_end, w_start:w_end] 208 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
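                # MultiDiffusion-style blending: each sliding-window view is denoised with its
                # own camera pose, the per-view results are accumulated into `value`, and
                # `count` records how many windows cover each latent position so overlapping
                # regions can be averaged after the loop (see the torch.where step below).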
209 | kwargs = {'pose': multikwargs[j]['pose'], 'mask_ref': None, 'drop_im': j} 210 | x_new, rgb_list = self.sampler_step( 211 | s_in * sigmas[i], 212 | s_in * sigmas[i + 1], 213 | lambda input, sigma, c: denoiser( 214 | model, input, sigma, c, **kwargs 215 | ), 216 | latent_view, 217 | cond[j], 218 | uc, 219 | gamma, 220 | ) 221 | # compute the denoising step with the reference model 222 | value[:, :, h_start:h_end, w_start:w_end] += x_new 223 | count[:, :, h_start:h_end, w_start:w_end] += 1 224 | 225 | # take the MultiDiffusion step 226 | x = torch.where(count > 0, value / count, value) 227 | 228 | return x, rgb_list 229 | 230 | def possible_correction_step( 231 | self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc 232 | ): 233 | return euler_step 234 | 235 | 236 | class AncestralSampler(SingleStepDiffusionSampler): 237 | def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): 238 | super().__init__(*args, **kwargs) 239 | 240 | self.eta = eta 241 | self.s_noise = s_noise 242 | self.noise_sampler = lambda x: torch.randn_like(x) 243 | 244 | def ancestral_euler_step(self, x, denoised, sigma, sigma_down): 245 | d = to_d(x, sigma, denoised) 246 | dt = append_dims(sigma_down - sigma, x.ndim) 247 | 248 | return self.euler_step(x, d, dt) 249 | 250 | def ancestral_step(self, x, sigma, next_sigma, sigma_up): 251 | x = torch.where( 252 | append_dims(next_sigma, x.ndim) > 0.0, 253 | x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), 254 | x, 255 | ) 256 | return x 257 | 258 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None): 259 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 260 | x, cond, uc, num_steps 261 | ) 262 | 263 | for i in self.get_sigma_gen(num_sigmas): 264 | x = self.sampler_step( 265 | s_in * sigmas[i], 266 | s_in * sigmas[i + 1], 267 | denoiser, 268 | x, 269 | cond, 270 | uc, 271 | ) 272 | 273 | return x 274 | 275 | 276 | class LinearMultistepSampler(BaseDiffusionSampler): 277 | def __init__( 278 | self, 279 | order=4, 280 | *args, 281 | **kwargs, 282 | ): 283 | super().__init__(*args, **kwargs) 284 | 285 | self.order = order 286 | 287 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): 288 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 289 | x, cond, uc, num_steps 290 | ) 291 | 292 | ds = [] 293 | sigmas_cpu = sigmas.detach().cpu().numpy() 294 | for i in self.get_sigma_gen(num_sigmas): 295 | sigma = s_in * sigmas[i] 296 | denoised, _ = denoiser( 297 | *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs 298 | ) 299 | denoised = self.guider(denoised, sigma) 300 | d = to_d(x, sigma, denoised) 301 | ds.append(d) 302 | if len(ds) > self.order: 303 | ds.pop(0) 304 | cur_order = min(i + 1, self.order) 305 | coeffs = [ 306 | linear_multistep_coeff(cur_order, sigmas_cpu, i, j) 307 | for j in range(cur_order) 308 | ] 309 | x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) 310 | 311 | return x 312 | 313 | 314 | class EulerEDMSampler(EDMSampler): 315 | def possible_correction_step( 316 | self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc 317 | ): 318 | return euler_step 319 | 320 | 321 | class HeunEDMSampler(EDMSampler): 322 | def possible_correction_step( 323 | self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc 324 | ): 325 | if torch.sum(next_sigma) < 1e-14: 326 | # Save a network evaluation if all noise levels are 0 327 | return euler_step 328 | else: 329 | denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) 330 | d_new = 
to_d(euler_step, next_sigma, denoised) 331 | d_prime = (d + d_new) / 2.0 332 | 333 | # apply correction if noise level is not 0 334 | x = torch.where( 335 | append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step 336 | ) 337 | return x 338 | 339 | 340 | class EulerAncestralSampler(AncestralSampler): 341 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): 342 | sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) 343 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 344 | x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) 345 | x = self.ancestral_step(x, sigma, next_sigma, sigma_up) 346 | 347 | return x 348 | 349 | 350 | class DPMPP2SAncestralSampler(AncestralSampler): 351 | def get_variables(self, sigma, sigma_down): 352 | t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] 353 | h = t_next - t 354 | s = t + 0.5 * h 355 | return h, s, t, t_next 356 | 357 | def get_mult(self, h, s, t, t_next): 358 | mult1 = to_sigma(s) / to_sigma(t) 359 | mult2 = (-0.5 * h).expm1() 360 | mult3 = to_sigma(t_next) / to_sigma(t) 361 | mult4 = (-h).expm1() 362 | 363 | return mult1, mult2, mult3, mult4 364 | 365 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): 366 | sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) 367 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 368 | x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) 369 | 370 | if torch.sum(sigma_down) < 1e-14: 371 | # Save a network evaluation if all noise levels are 0 372 | x = x_euler 373 | else: 374 | h, s, t, t_next = self.get_variables(sigma, sigma_down) 375 | mult = [ 376 | append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) 377 | ] 378 | 379 | x2 = mult[0] * x - mult[1] * denoised 380 | denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) 381 | x_dpmpp2s = mult[2] * x - mult[3] * denoised2 382 | 383 | # apply correction if noise level is not 0 384 | x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) 385 | 386 | x = self.ancestral_step(x, sigma, next_sigma, sigma_up) 387 | return x 388 | 389 | 390 | class DPMPP2MSampler(BaseDiffusionSampler): 391 | def get_variables(self, sigma, next_sigma, previous_sigma=None): 392 | t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] 393 | h = t_next - t 394 | 395 | if previous_sigma is not None: 396 | h_last = t - to_neg_log_sigma(previous_sigma) 397 | r = h_last / h 398 | return h, r, t, t_next 399 | else: 400 | return h, None, t, t_next 401 | 402 | def get_mult(self, h, r, t, t_next, previous_sigma): 403 | mult1 = to_sigma(t_next) / to_sigma(t) 404 | mult2 = (-h).expm1() 405 | 406 | if previous_sigma is not None: 407 | mult3 = 1 + 1 / (2 * r) 408 | mult4 = 1 / (2 * r) 409 | return mult1, mult2, mult3, mult4 410 | else: 411 | return mult1, mult2 412 | 413 | def sampler_step( 414 | self, 415 | old_denoised, 416 | previous_sigma, 417 | sigma, 418 | next_sigma, 419 | denoiser, 420 | x, 421 | cond, 422 | uc=None, 423 | ): 424 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 425 | 426 | h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) 427 | mult = [ 428 | append_dims(mult, x.ndim) 429 | for mult in self.get_mult(h, r, t, t_next, previous_sigma) 430 | ] 431 | 432 | x_standard = mult[0] * x - mult[1] * denoised 433 | if old_denoised is None or torch.sum(next_sigma) < 1e-14: 434 | # Save a network evaluation if all noise levels are 0 or on the first step 435 | return 
x_standard, denoised 436 | else: 437 | denoised_d = mult[2] * denoised - mult[3] * old_denoised 438 | x_advanced = mult[0] * x - mult[1] * denoised_d 439 | 440 | # apply correction if noise level is not 0 and not first step 441 | x = torch.where( 442 | append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard 443 | ) 444 | 445 | return x, denoised 446 | 447 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): 448 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 449 | x, cond, uc, num_steps 450 | ) 451 | 452 | old_denoised = None 453 | for i in self.get_sigma_gen(num_sigmas): 454 | x, old_denoised = self.sampler_step( 455 | old_denoised, 456 | None if i == 0 else s_in * sigmas[i - 1], 457 | s_in * sigmas[i], 458 | s_in * sigmas[i + 1], 459 | denoiser, 460 | x, 461 | cond, 462 | uc=uc, 463 | ) 464 | 465 | return x 466 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | class NoDynamicThresholding: 8 | def __call__(self, uncond, cond, scale): 9 | return uncond + scale * (cond - uncond) 10 | 11 | 12 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 13 | if order - 1 > i: 14 | raise ValueError(f"Order {order} too high for step {i}") 15 | 16 | def fn(tau): 17 | prod = 1.0 18 | for k in range(order): 19 | if j == k: 20 | continue 21 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 22 | return prod 23 | 24 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 25 | 26 | 27 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 28 | if not eta: 29 | return sigma_to, 0.0 30 | sigma_up = torch.minimum( 31 | sigma_to, 32 | eta 33 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 34 | ) 35 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 36 | return sigma_down, sigma_up 37 | 38 | 39 | def to_d(x, sigma, denoised): 40 | return (x - denoised) / append_dims(sigma, x.ndim) 41 | 42 | 43 | def to_neg_log_sigma(sigma): 44 | return sigma.log().neg() 45 | 46 | 47 | def to_sigma(neg_log_sigma): 48 | return neg_log_sigma.neg().exp() 49 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...util import default, instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, num_idx_start=0, do_append_zero=False, flip=True): 18 | self.num_idx = num_idx 19 | self.num_idx_start = num_idx_start 20 | self.sigmas = instantiate_from_config(discretization_config)( 21 | num_idx, do_append_zero=do_append_zero, flip=flip 22 | ) 23 | 24 | def idx_to_sigma(self, idx): 25 | return self.sigmas[idx] 26 | 27 | def __call__(self, n_samples, rand=None): 28 | idx = default( 29 | rand, 30 | torch.randint(self.num_idx_start, self.num_idx, (n_samples,)), 31 | ) 32 | return self.idx_to_sigma(idx) 33 | 34 | 35 | class CubicSampling: 36 | def __init__(self, 
discretization_config, num_idx, do_append_zero=False, flip=True): 37 | self.num_idx = num_idx 38 | self.sigmas = instantiate_from_config(discretization_config)( 39 | num_idx, do_append_zero=do_append_zero, flip=flip 40 | ) 41 | 42 | def idx_to_sigma(self, idx): 43 | return self.sigmas[idx] 44 | 45 | def __call__(self, n_samples, rand=None): 46 | t = torch.rand((n_samples,)) 47 | t = (1 - t ** 3) * (self.num_idx-1) 48 | t = t.long() 49 | idx = default( 50 | rand, 51 | t, 52 | ) 53 | return self.idx_to_sigma(idx) 54 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | adopted from 3 | https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 4 | and 5 | https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 6 | and 7 | https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 8 | 9 | thanks! 10 | """ 11 | 12 | import math 13 | 14 | import torch 15 | import torch.nn as nn 16 | from einops import repeat 17 | 18 | 19 | def make_beta_schedule( 20 | schedule, 21 | n_timestep, 22 | linear_start=1e-4, 23 | linear_end=2e-2, 24 | ): 25 | if schedule == "linear": 26 | betas = ( 27 | torch.linspace( 28 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 29 | ) 30 | ** 2 31 | ) 32 | return betas.numpy() 33 | 34 | 35 | def extract_into_tensor(a, t, x_shape): 36 | b, *_ = t.shape 37 | out = a.gather(-1, t) 38 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 39 | 40 | 41 | def mixed_checkpoint(func, inputs: dict, params, flag): 42 | """ 43 | Evaluate a function without caching intermediate activations, allowing for 44 | reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function 45 | borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that 46 | it also works with non-tensor inputs 47 | :param func: the function to evaluate. 48 | :param inputs: the argument dictionary to pass to `func`. 49 | :param params: a sequence of parameters `func` depends on but does not 50 | explicitly take as arguments. 51 | :param flag: if False, disable gradient checkpointing. 
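    Illustrative call (the names below are placeholders, not part of this module):
        out = mixed_checkpoint(block, {"x": x, "emb": emb, "use_attn": True},
                               list(block.parameters()), flag=True)
    Tensor entries of `inputs` receive gradients through the checkpoint, while
    non-tensor entries (here `use_attn`) are passed through unchanged.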
52 | """ 53 | if flag: 54 | tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] 55 | tensor_inputs = [ 56 | inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) 57 | ] 58 | non_tensor_keys = [ 59 | key for key in inputs if not isinstance(inputs[key], torch.Tensor) 60 | ] 61 | non_tensor_inputs = [ 62 | inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) 63 | ] 64 | args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) 65 | return MixedCheckpointFunction.apply( 66 | func, 67 | len(tensor_inputs), 68 | len(non_tensor_inputs), 69 | tensor_keys, 70 | non_tensor_keys, 71 | *args, 72 | ) 73 | else: 74 | return func(**inputs) 75 | 76 | 77 | class MixedCheckpointFunction(torch.autograd.Function): 78 | @staticmethod 79 | def forward( 80 | ctx, 81 | run_function, 82 | length_tensors, 83 | length_non_tensors, 84 | tensor_keys, 85 | non_tensor_keys, 86 | *args, 87 | ): 88 | ctx.end_tensors = length_tensors 89 | ctx.end_non_tensors = length_tensors + length_non_tensors 90 | ctx.gpu_autocast_kwargs = { 91 | "enabled": torch.is_autocast_enabled(), 92 | "dtype": torch.get_autocast_gpu_dtype(), 93 | "cache_enabled": torch.is_autocast_cache_enabled(), 94 | } 95 | assert ( 96 | len(tensor_keys) == length_tensors 97 | and len(non_tensor_keys) == length_non_tensors 98 | ) 99 | 100 | ctx.input_tensors = { 101 | key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) 102 | } 103 | ctx.input_non_tensors = { 104 | key: val 105 | for (key, val) in zip( 106 | non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) 107 | ) 108 | } 109 | ctx.run_function = run_function 110 | ctx.input_params = list(args[ctx.end_non_tensors :]) 111 | 112 | with torch.no_grad(): 113 | output_tensors = ctx.run_function( 114 | **ctx.input_tensors, **ctx.input_non_tensors 115 | ) 116 | return output_tensors 117 | 118 | @staticmethod 119 | def backward(ctx, *output_grads): 120 | # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} 121 | ctx.input_tensors = { 122 | key: ctx.input_tensors[key].detach().requires_grad_(True) 123 | for key in ctx.input_tensors 124 | } 125 | 126 | with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 127 | # Fixes a bug where the first op in run_function modifies the 128 | # Tensor storage in place, which is not allowed for detach()'d 129 | # Tensors. 130 | shallow_copies = { 131 | key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) 132 | for key in ctx.input_tensors 133 | } 134 | # shallow_copies.update(additional_args) 135 | output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) 136 | input_grads = torch.autograd.grad( 137 | output_tensors, 138 | list(ctx.input_tensors.values()) + ctx.input_params, 139 | output_grads, 140 | allow_unused=True, 141 | ) 142 | del ctx.input_tensors 143 | del ctx.input_params 144 | del output_tensors 145 | return ( 146 | (None, None, None, None, None) 147 | + input_grads[: ctx.end_tensors] 148 | + (None,) * (ctx.end_non_tensors - ctx.end_tensors) 149 | + input_grads[ctx.end_tensors :] 150 | ) 151 | 152 | 153 | def checkpoint(func, inputs, params, flag): 154 | """ 155 | Evaluate a function without caching intermediate activations, allowing for 156 | reduced memory at the expense of extra compute in the backward pass. 157 | :param func: the function to evaluate. 158 | :param inputs: the argument sequence to pass to `func`. 
159 | :param params: a sequence of parameters `func` depends on but does not 160 | explicitly take as arguments. 161 | :param flag: if False, disable gradient checkpointing. 162 | """ 163 | if flag: 164 | args = tuple(inputs) + tuple(params) 165 | return CheckpointFunction.apply(func, len(inputs), *args) 166 | else: 167 | return func(*inputs) 168 | 169 | 170 | class CheckpointFunction(torch.autograd.Function): 171 | @staticmethod 172 | def forward(ctx, run_function, length, *args): 173 | ctx.run_function = run_function 174 | ctx.input_tensors = list(args[:length]) 175 | ctx.input_params = list(args[length:]) 176 | ctx.gpu_autocast_kwargs = { 177 | "enabled": torch.is_autocast_enabled(), 178 | "dtype": torch.get_autocast_gpu_dtype(), 179 | "cache_enabled": torch.is_autocast_cache_enabled(), 180 | } 181 | with torch.no_grad(): 182 | output_tensors = ctx.run_function(*ctx.input_tensors) 183 | return output_tensors 184 | 185 | @staticmethod 186 | def backward(ctx, *output_grads): 187 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 188 | with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 189 | # Fixes a bug where the first op in run_function modifies the 190 | # Tensor storage in place, which is not allowed for detach()'d 191 | # Tensors. 192 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 193 | output_tensors = ctx.run_function(*shallow_copies) 194 | input_grads = torch.autograd.grad( 195 | output_tensors, 196 | ctx.input_tensors + ctx.input_params, 197 | output_grads, 198 | allow_unused=True, 199 | ) 200 | del ctx.input_tensors 201 | del ctx.input_params 202 | del output_tensors 203 | return (None, None) + input_grads 204 | 205 | 206 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 207 | """ 208 | Create sinusoidal timestep embeddings. 209 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 210 | These may be fractional. 211 | :param dim: the dimension of the output. 212 | :param max_period: controls the minimum frequency of the embeddings. 213 | :return: an [N x dim] Tensor of positional embeddings. 214 | """ 215 | if not repeat_only: 216 | half = dim // 2 217 | freqs = torch.exp( 218 | -math.log(max_period) 219 | * torch.arange(start=0, end=half, dtype=torch.float32) 220 | / half 221 | ).to(device=timesteps.device) 222 | args = timesteps[:, None].float() * freqs[None] 223 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 224 | if dim % 2: 225 | embedding = torch.cat( 226 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 227 | ) 228 | else: 229 | embedding = repeat(timesteps, "b -> b d", d=dim) 230 | return embedding 231 | 232 | 233 | def timestep_embedding_pose(timesteps, dim, max_period=10000, repeat_only=False): 234 | """ 235 | Create sinusoidal timestep embeddings. 236 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 237 | These may be fractional. 238 | :param dim: the dimension of the output. 239 | :param max_period: controls the minimum frequency of the embeddings. 240 | :return: an [N x dim] Tensor of positional embeddings. 
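    Concretely, with half = dim // 2 and frequencies f_i = exp(-ln(max_period) * i / half),
    the embedding of a timestep t is
    [cos(t*f_0), ..., cos(t*f_{half-1}), sin(t*f_0), ..., sin(t*f_{half-1})]
    (plus a zero pad when dim is odd), identical to `timestep_embedding` above.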
241 | """ 242 | if not repeat_only: 243 | half = dim // 2 244 | freqs = torch.exp( 245 | -math.log(max_period) 246 | * torch.arange(start=0, end=half, dtype=torch.float32) 247 | / half 248 | ).to(device=timesteps.device) 249 | args = timesteps[:, None].float() * freqs[None] 250 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 251 | if dim % 2: 252 | embedding = torch.cat( 253 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 254 | ) 255 | else: 256 | embedding = repeat(timesteps, "b -> b d", d=dim) 257 | return embedding 258 | 259 | 260 | def zero_module(module): 261 | """ 262 | Zero out the parameters of a module and return it. 263 | """ 264 | for p in module.parameters(): 265 | p.detach().zero_() 266 | return module 267 | 268 | 269 | def ones_module(module): 270 | """ 271 | Zero out the parameters of a module and return it. 272 | """ 273 | for p in module.parameters(): 274 | p.detach().data.fill_(1.) 275 | return module 276 | 277 | 278 | def scale_module(module, scale): 279 | """ 280 | Scale the parameters of a module and return it. 281 | """ 282 | for p in module.parameters(): 283 | p.detach().mul_(scale) 284 | return module 285 | 286 | 287 | def mean_flat(tensor): 288 | """ 289 | Take the mean over all non-batch dimensions. 290 | """ 291 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 292 | 293 | 294 | def normalization(channels): 295 | """ 296 | Make a standard normalization layer. 297 | :param channels: number of input channels. 298 | :return: an nn.Module for normalization. 299 | """ 300 | return GroupNorm32(32, channels) 301 | 302 | 303 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 304 | class SiLU(nn.Module): 305 | def forward(self, x): 306 | return x * torch.sigmoid(x) 307 | 308 | 309 | class GroupNorm32(nn.GroupNorm): 310 | def forward(self, x): 311 | return super().forward(x.float()).type(x.dtype) 312 | 313 | 314 | def conv_nd(dims, *args, **kwargs): 315 | """ 316 | Create a 1D, 2D, or 3D convolution module. 317 | """ 318 | if dims == 1: 319 | return nn.Conv1d(*args, **kwargs) 320 | elif dims == 2: 321 | return nn.Conv2d(*args, **kwargs) 322 | elif dims == 3: 323 | return nn.Conv3d(*args, **kwargs) 324 | raise ValueError(f"unsupported dimensions: {dims}") 325 | 326 | 327 | def linear(*args, **kwargs): 328 | """ 329 | Create a linear module. 330 | """ 331 | return nn.Linear(*args, **kwargs) 332 | 333 | 334 | def avg_pool_nd(dims, *args, **kwargs): 335 | """ 336 | Create a 1D, 2D, or 3D average pooling module. 
337 | """ 338 | if dims == 1: 339 | return nn.AvgPool1d(*args, **kwargs) 340 | elif dims == 2: 341 | return nn.AvgPool2d(*args, **kwargs) 342 | elif dims == 3: 343 | return nn.AvgPool3d(*args, **kwargs) 344 | raise ValueError(f"unsupported dimensions: {dims}") 345 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) 14 | and compile_model 15 | else lambda x: x 16 | ) 17 | self.diffusion_model = compile(diffusion_model) 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward( 25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 26 | ) -> torch.Tensor: 27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 28 | return self.diffusion_model( 29 | x, 30 | timesteps=t, 31 | context=c.get("crossattn", None), 32 | y=c.get("vector", None), 33 | **kwargs 34 | ) 35 | 36 | -------------------------------------------------------------------------------- /sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 
3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 
77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/customdiffusion360/custom-diffusion360/1a23f972274e7275fdeaa3197f5d22118aa228bb/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /sgm/modules/nerfsd_pytorch3d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | from einops import rearrange 6 | from ..modules.utils_cameraray import ( 7 | get_patch_rays, 8 | get_plucker_parameterization, 9 | positional_encoding, 10 | convert_to_view_space, 11 | convert_to_view_space_points, 12 | convert_to_target_space, 13 | ) 14 | 15 | 16 | from pytorch3d.renderer import ray_bundle_to_ray_points 17 | from pytorch3d.renderer.implicit.raysampling import RayBundle as RayBundle 18 | from pytorch3d import _C 19 | 20 | from ..modules.diffusionmodules.util import zero_module 21 | 22 | 23 | class FeatureNeRFEncoding(nn.Module): 24 | def __init__( 25 | self, 26 | in_channels, 27 | out_channels, 28 | far_plane: float = 2.0, 29 | rgb_predict=False, 30 | average=False, 31 | num_freqs=16, 32 | ) -> None: 33 | super().__init__() 34 | 35 | self.far_plane = far_plane 36 | self.rgb_predict = rgb_predict 37 | self.average = average 38 | self.num_freqs = num_freqs 39 | dim = 3 40 | self.plane_coefs = nn.Sequential( 41 | nn.Linear(in_channels + self.num_freqs * dim * 4 + 2 * dim, out_channels), 42 | nn.SiLU(), 43 | nn.Linear(out_channels, out_channels), 44 | ) 45 | if not self.average: 46 | self.nviews = nn.Linear( 47 | in_channels + self.num_freqs * dim * 4 + 2 * dim, 1 48 | ) 49 | self.decoder = zero_module( 50 | nn.Linear(out_channels, 1 + (3 if rgb_predict else 0), bias=False) 51 | ) 52 | 53 | def forward(self, pose, xref, ray_points, rays, mask_ref): 54 | # xref : [b, n, hw, c] 55 | # ray_points: [b, n+1, hw, d, 3] 56 | # rays: [b, n+1, hw, 6] 57 | 58 | b, n, hw, c = xref.shape 59 | d = ray_points.shape[3] 60 | res = int(math.sqrt(hw)) 61 | if mask_ref is not None: 62 | mask_ref = torch.nn.functional.interpolate( 63 | rearrange( 64 | mask_ref, 65 | "b n ... -> (b n) ...", 66 | ), 67 | size=[res, res], 68 | mode="nearest", 69 | ).reshape(b, n, -1, 1) 70 | xref = xref * mask_ref 71 | 72 | volume = [] 73 | for i, cam in enumerate(pose): 74 | volume.append( 75 | cam.transform_points_ndc(ray_points[i, 0].reshape(-1, 3)).reshape(n + 1, hw, d, 3)[..., :2] 76 | ) 77 | volume = torch.stack(volume) 78 | 79 | plane_features = F.grid_sample( 80 | rearrange( 81 | xref, 82 | "b n (h w) c -> (b n) c h w", 83 | b=b, 84 | h=int(math.sqrt(hw)), 85 | w=int(math.sqrt(hw)), 86 | c=c, 87 | n=n, 88 | ), 89 | torch.clip( 90 | torch.nan_to_num( 91 | rearrange(-1 * volume[:, 1:].detach(), "b n ... 
-> (b n) ...") 92 | ), 93 | -1.2, 94 | 1.2, 95 | ), 96 | align_corners=True, 97 | padding_mode="zeros", 98 | ) # [bn, c, hw, d] 99 | 100 | plane_features = rearrange(plane_features, "(b n) ... -> b n ...", b=b, n=n) 101 | 102 | xyz_grid_features_inviewframe = convert_to_view_space_points(pose, ray_points[:, 0]) 103 | xyz_grid_features_inviewframe_encoding = positional_encoding(xyz_grid_features_inviewframe, self.num_freqs) 104 | camera_features_inviewframe = ( 105 | convert_to_view_space(pose, rays[:, 0])[:, 1:] 106 | .reshape(b, n, hw, 1, -1) 107 | .expand(-1, -1, -1, d, -1) 108 | ) 109 | camera_features_inviewframe_encoding = positional_encoding( 110 | get_plucker_parameterization(camera_features_inviewframe), 111 | self.num_freqs // 2, 112 | ) 113 | xyz_grid_features = xyz_grid_features_inviewframe_encoding[:, :1].expand( 114 | -1, n, -1, -1, -1 115 | ) 116 | camera_features = ( 117 | (convert_to_target_space(pose, rays[:, 1:])[..., :3]) 118 | .reshape(b, n, hw, 1, -1) 119 | .expand(-1, -1, -1, d, -1) 120 | ) 121 | camera_features_encoding = positional_encoding( 122 | camera_features, self.num_freqs 123 | ) 124 | plane_features_final = self.plane_coefs( 125 | torch.cat( 126 | [ 127 | plane_features.permute(0, 1, 3, 4, 2), 128 | xyz_grid_features_inviewframe_encoding[:, 1:], 129 | xyz_grid_features_inviewframe[:, 1:], 130 | camera_features_inviewframe_encoding, 131 | camera_features_inviewframe[..., 3:], 132 | ], 133 | dim=-1, 134 | ) 135 | ) # b, n, hw, d, c 136 | 137 | # plane_features = torch.cat([plane_features, xyz_grid_features, camera_features], dim=1) 138 | if not self.average: 139 | plane_features_attn = nn.functional.softmax( 140 | self.nviews( 141 | torch.cat( 142 | [ 143 | plane_features.permute(0, 1, 3, 4, 2), 144 | xyz_grid_features, 145 | xyz_grid_features_inviewframe[:, :1].expand(-1, n, -1, -1, -1), 146 | camera_features, 147 | camera_features_encoding, 148 | ], 149 | dim=-1, 150 | ) 151 | ), 152 | dim=1, 153 | ) # b, n, hw, d, 1 154 | 155 | plane_features_final = (plane_features_final * plane_features_attn).sum(1) 156 | else: 157 | plane_features_final = plane_features_final.mean(1) 158 | plane_features_attn = None 159 | 160 | out = self.decoder(plane_features_final) 161 | return torch.cat([plane_features_final, out], dim=-1), plane_features_attn 162 | 163 | 164 | class VolRender(nn.Module): 165 | def __init__( 166 | self, 167 | ): 168 | super().__init__() 169 | 170 | def get_weights(self, densities, deltas): 171 | """Return weights based on predicted densities 172 | 173 | Args: 174 | densities: Predicted densities for samples along ray 175 | 176 | Returns: 177 | Weights for each sample 178 | """ 179 | delta_density = deltas * densities # [b, hw, "num_samples", 1] 180 | alphas = 1 - torch.exp(-delta_density) 181 | transmittance = torch.cumsum(delta_density[..., :-1, :], dim=-2) 182 | transmittance = torch.cat( 183 | [ 184 | torch.zeros((*transmittance.shape[:2], 1, 1), device=densities.device), 185 | transmittance, 186 | ], 187 | dim=-2, 188 | ) 189 | transmittance = torch.exp(-transmittance) # [b, hw, "num_samples", 1] 190 | 191 | weights = alphas * transmittance # [b, hw, "num_samples", 1] 192 | weights = torch.nan_to_num(weights) 193 | # opacities = 1.0 - torch.prod(1.0 - alphas, dim=-2, keepdim=True) 194 | return weights, alphas, transmittance 195 | 196 | def forward( 197 | self, 198 | features, 199 | densities, 200 | dists=None, 201 | return_weight=False, 202 | densities_uniform=None, 203 | dists_uniform=None, 204 | return_weights_uniform=False, 205 | rgb=None 206 
| ): 207 | alphas = None 208 | fg_mask = None 209 | if dists is not None: 210 | weights, alphas, transmittance = self.get_weights(densities, dists) 211 | fg_mask = torch.sum(weights, -2) 212 | else: 213 | weights = densities # used when we have a pretraind nerf with direct weights as output 214 | 215 | rendered_feats = torch.sum(weights * features, dim=-2) + torch.sum( 216 | (1 - weights) * torch.zeros_like(features), dim=-2 217 | ) 218 | if rgb is not None: 219 | rgb = torch.sum(weights * rgb, dim=-2) + torch.sum( 220 | (1 - weights) * torch.zeros_like(rgb), dim=-2 221 | ) 222 | # print("RENDER", fg_mask.shape, weights.shape) 223 | weights_uniform = None 224 | if return_weight: 225 | return rendered_feats, fg_mask, alphas, weights, rgb 226 | elif return_weights_uniform: 227 | if densities_uniform is not None: 228 | weights_uniform, _, transmittance = self.get_weights(densities_uniform, dists_uniform) 229 | return rendered_feats, fg_mask, alphas, weights_uniform, rgb 230 | else: 231 | return rendered_feats, fg_mask, alphas, None, rgb 232 | 233 | 234 | class Raymarcher(nn.Module): 235 | def __init__( 236 | self, 237 | num_samples=32, 238 | far_plane=2.0, 239 | stratified=False, 240 | training=True, 241 | imp_sampling_percent=0.9, 242 | near_plane=0., 243 | ): 244 | super().__init__() 245 | self.num_samples = num_samples 246 | self.far_plane = far_plane 247 | self.near_plane = near_plane 248 | u_max = 1. / (self.num_samples) 249 | u = torch.linspace(0, 1 - u_max, self.num_samples, device="cuda") 250 | self.register_buffer("u", u) 251 | lengths = torch.linspace(self.near_plane, self.near_plane+self.far_plane, self.num_samples+1, device="cuda") 252 | # u = (u[..., :-1] + u[..., 1:]) / 2.0 253 | lengths_center = (lengths[..., 1:] + lengths[..., :-1]) / 2.0 254 | lengths_upper = torch.cat([lengths_center, lengths[..., -1:]], -1) 255 | lengths_lower = torch.cat([lengths[..., :1], lengths_center], -1) 256 | self.register_buffer("lengths", lengths) 257 | self.register_buffer("lengths_center", lengths_center) 258 | self.register_buffer("lengths_upper", lengths_upper) 259 | self.register_buffer("lengths_lower", lengths_lower) 260 | self.stratified = stratified 261 | self.training = training 262 | self.imp_sampling_percent = imp_sampling_percent 263 | 264 | @torch.no_grad() 265 | def importance_sampling(self, cdf, num_rays, num_samples, device): 266 | # sample target rays for each reference view 267 | cdf = cdf[..., 0] + 0.01 268 | if cdf.shape[1] != num_rays: 269 | size = int(math.sqrt(num_rays)) 270 | size_ = int(math.sqrt(cdf.size(1))) 271 | cdf = rearrange( 272 | torch.nn.functional.interpolate( 273 | rearrange( 274 | cdf.permute(0, 2, 1), "... (h w) -> ... h w", h=size_, w=size_ 275 | ), 276 | size=[size, size], 277 | antialias=True, 278 | mode="bilinear", 279 | ), 280 | "... h w -> ... (h w)", 281 | h=size, 282 | w=size, 283 | ).permute(0, 2, 1) 284 | 285 | lengths = self.lengths[None, None, :].expand(cdf.shape[0], num_rays, -1) 286 | 287 | cdf_sum = torch.sum(cdf, dim=-1, keepdim=True) 288 | padding = torch.relu(1e-5 - cdf_sum) 289 | cdf = cdf + padding / cdf.shape[-1] 290 | cdf_sum += padding 291 | 292 | pdf = cdf / cdf_sum 293 | 294 | # sample_pdf function 295 | u_max = 1. 
/ (num_samples) 296 | u = self.u[None, None, :].expand(cdf.shape[0], num_rays, -1) 297 | if self.stratified and self.training: 298 | u += torch.rand((cdf.shape[0], num_rays, num_samples), dtype=cdf.dtype, device=cdf.device,) * u_max 299 | 300 | _C.sample_pdf( 301 | lengths.reshape(-1, num_samples + 1), 302 | pdf.reshape(-1, num_samples), 303 | u.reshape(-1, num_samples), 304 | 1e-5, 305 | ) 306 | return u, torch.cat([u[..., 1:] - u[..., :-1], lengths[..., -1:] - u[..., -1:]], -1) 307 | 308 | @torch.no_grad() 309 | def stratified_sampling(self, num_rays, device, uniform=False): 310 | lengths_uniform = self.lengths[None, None, :].expand(-1, num_rays, -1) 311 | 312 | if uniform: 313 | return ( 314 | (lengths_uniform[..., 1:] + lengths_uniform[..., :-1]) / 2.0, 315 | lengths_uniform[..., 1:] - lengths_uniform[..., :-1], 316 | ) 317 | if self.stratified and self.training: 318 | t_rand = torch.rand( 319 | (num_rays, self.num_samples + 1), 320 | dtype=lengths_uniform.dtype, 321 | device=lengths_uniform.device, 322 | ) 323 | jittered = self.lengths_lower[None, None, :].expand(-1, num_rays, -1) + \ 324 | (self.lengths_upper[None, None, :].expand(-1, num_rays, -1) - self.lengths_lower[None, None, :].expand(-1, num_rays, -1)) * t_rand 325 | return ((jittered[..., :-1] + jittered[..., 1:])/2., jittered[..., 1:] - jittered[..., :-1]) 326 | else: 327 | return ( 328 | (lengths_uniform[..., 1:] + lengths_uniform[..., :-1]) / 2.0, 329 | lengths_uniform[..., 1:] - lengths_uniform[..., :-1], 330 | ) 331 | 332 | @torch.no_grad() 333 | def forward(self, pose, resolution, weights, imp_sample_next_step=False, device='cuda', pytorch3d=True): 334 | input_patch_rays, xys = get_patch_rays( 335 | pose, 336 | num_patches_x=resolution, 337 | num_patches_y=resolution, 338 | device=device, 339 | return_xys=True, 340 | stratified=self.stratified and self.training, 341 | ) # (b, n, h*w, 6) 342 | 343 | num_rays = resolution**2 344 | # sample target rays for each reference view 345 | if weights is not None: 346 | if self.imp_sampling_percent <= 0: 347 | lengths, dists = self.stratified_sampling(num_rays, device) 348 | elif (torch.rand(1) < (1.-self.imp_sampling_percent)) and self.training: 349 | lengths, dists = self.stratified_sampling(num_rays, device) 350 | else: 351 | lengths, dists = self.importance_sampling( 352 | weights, num_rays, self.num_samples, device=device 353 | ) 354 | else: 355 | lengths, dists = self.stratified_sampling(num_rays, device) 356 | 357 | dists_uniform = None 358 | ray_points_uniform = None 359 | if imp_sample_next_step: 360 | lengths_uniform, dists_uniform = self.stratified_sampling( 361 | num_rays, device, uniform=True 362 | ) 363 | 364 | target_patch_raybundle_uniform = RayBundle( 365 | origins=input_patch_rays[:, :1, :, :3], 366 | directions=input_patch_rays[:, :1, :, 3:], 367 | lengths=lengths_uniform, 368 | xys=xys.to(device), 369 | ) 370 | ray_points_uniform = ray_bundle_to_ray_points(target_patch_raybundle_uniform).detach() 371 | dists_uniform = dists_uniform.detach() 372 | 373 | # print( 374 | # "SAMPLING", 375 | # lengths.shape, 376 | # lengths_uniform.shape, 377 | # dists.shape, 378 | # dists_uniform.shape, 379 | # input_patch_rays.shape, 380 | # ) 381 | target_patch_raybundle = RayBundle( 382 | origins=input_patch_rays[:, :1, :, :3], 383 | directions=input_patch_rays[:, :1, :, 3:], 384 | lengths=lengths.to(device), 385 | xys=xys.to(device), 386 | ) 387 | ray_points = ray_bundle_to_ray_points(target_patch_raybundle) 388 | return ( 389 | input_patch_rays.detach(), 390 | 
ray_points.detach(), 391 | dists.detach(), 392 | ray_points_uniform, 393 | dists_uniform, 394 | ) 395 | 396 | 397 | class NerfSDModule(nn.Module): 398 | def __init__( 399 | self, 400 | mode="feature-nerf", 401 | out_channels=None, 402 | far_plane=2.0, 403 | num_samples=32, 404 | rgb_predict=False, 405 | average=False, 406 | num_freqs=16, 407 | stratified=False, 408 | imp_sampling_percent=0.9, 409 | near_plane=0. 410 | ): 411 | MODES = { 412 | "feature-nerf": FeatureNeRFEncoding, # ampere 413 | } 414 | super().__init__() 415 | self.rgb_predict = rgb_predict 416 | 417 | self.raymarcher = Raymarcher( 418 | num_samples=num_samples, 419 | far_plane=near_plane + far_plane, 420 | stratified=stratified, 421 | imp_sampling_percent=imp_sampling_percent, 422 | near_plane=near_plane, 423 | ) 424 | model_class = MODES[mode] 425 | self.model = model_class( 426 | out_channels, 427 | out_channels, 428 | far_plane=near_plane + far_plane, 429 | rgb_predict=rgb_predict, 430 | average=average, 431 | num_freqs=num_freqs, 432 | ) 433 | 434 | def forward(self, pose, xref=None, mask_ref=None, prev_weights=None, imp_sample_next_step=False,): 435 | # xref: b n h w c or b n hw c 436 | # pose: a list of pytorch3d cameras 437 | # mask_ref: mask corresponding to black regions because of padding non square images. 438 | rgb = None 439 | dists_uniform = None 440 | weights_uniform = None 441 | resolution = (int(math.sqrt(xref.size(2))) if len(xref.shape) == 4 else xref.size(3)) 442 | input_patch_rays, ray_points, dists, ray_points_uniform, dists_uniform = (self.raymarcher(pose, resolution, weights=prev_weights, device=xref.device)) 443 | output, plane_features_attn = self.model(pose, xref, ray_points, input_patch_rays, mask_ref) 444 | weights = output[..., -1:] 445 | features = output[..., :-1] 446 | if self.rgb_predict: 447 | rgb = features[..., -3:] 448 | features = features[..., :-3] 449 | dists = dists.unsqueeze(-1) 450 | with torch.no_grad(): 451 | if ray_points_uniform is not None: 452 | output_uniform, _ = self.model(pose, xref, ray_points_uniform, input_patch_rays, mask_ref) 453 | weights_uniform = output_uniform[..., -1:] 454 | dists_uniform = dists_uniform.unsqueeze(-1) 455 | 456 | return ( 457 | features, 458 | weights, 459 | dists, 460 | plane_features_attn, 461 | rgb, 462 | weights_uniform, 463 | dists_uniform, 464 | ) 465 | -------------------------------------------------------------------------------- /sgm/modules/utils_cameraray.py: -------------------------------------------------------------------------------- 1 | #### Code taken from: https://github.com/mayankgrwl97/gbt 2 | """Utils for ray manipulation""" 3 | 4 | import numpy as np 5 | import torch 6 | from pytorch3d.renderer.implicit.raysampling import RayBundle as RayBundle 7 | from pytorch3d.renderer.camera_utils import join_cameras_as_batch 8 | from pytorch3d.renderer.cameras import PerspectiveCameras 9 | 10 | 11 | ############################# RAY BUNDLE UTILITIES ############################# 12 | 13 | def is_scalar(x): 14 | """Returns True if the provided variable is a scalar 15 | 16 | Args: 17 | x: scalar or array-like (numpy array or torch tensor) 18 | 19 | Returns: 20 | bool: True if x is of the type scalar, or array-like with 0 dimension. 
False, otherwise 21 | 22 | """ 23 | if isinstance(x, float) or isinstance(x, int): 24 | return True 25 | 26 | if isinstance(x, np.ndarray) and np.ndim(x) == 0: 27 | return True 28 | 29 | if isinstance(x, torch.Tensor) and x.dim() == 0: 30 | return True 31 | 32 | return False 33 | 34 | 35 | def transform_rays(reference_R, reference_T, rays): 36 | """ 37 | PyTorch3D Convention is used: X_cam = X_world @ R + T 38 | 39 | Args: 40 | reference_R: world2cam rotation matrix for reference camera (B, 3, 3) 41 | reference_T: world2cam translation vector for reference camera (B, 3) 42 | rays: (origin, direction) defined in world reference frame (B, V, N, 6) 43 | Returns: 44 | torch.Tensor: Transformed rays w.r.t. reference camera (B, V, N, 6) 45 | """ 46 | batch, num_views, num_rays, ray_dim = rays.shape 47 | assert ( 48 | ray_dim == 6 49 | ), "First 3 dimensions should be origin; Last 3 dimensions should be direction" 50 | 51 | rays = rays.reshape(batch, num_views * num_rays, ray_dim) 52 | rays_out = rays.clone() 53 | rays_out[..., :3] = torch.bmm(rays[..., :3], reference_R) + reference_T.unsqueeze( 54 | -2 55 | ) 56 | rays_out[..., 3:] = torch.bmm(rays[..., 3:], reference_R) 57 | rays_out = rays_out.reshape(batch, num_views, num_rays, ray_dim) 58 | return rays_out 59 | 60 | 61 | def get_directional_raybundle(cameras, x_pos_ndc, y_pos_ndc, max_depth=1): 62 | if is_scalar(x_pos_ndc): 63 | x_pos_ndc = [x_pos_ndc] 64 | if is_scalar(y_pos_ndc): 65 | y_pos_ndc = [y_pos_ndc] 66 | assert is_scalar(max_depth) 67 | 68 | if not isinstance(x_pos_ndc, torch.Tensor): 69 | x_pos_ndc = torch.tensor(x_pos_ndc) # (N, ) 70 | if not isinstance(y_pos_ndc, torch.Tensor): 71 | y_pos_ndc = torch.tensor(y_pos_ndc) # (N, ) 72 | 73 | xy_depth = torch.stack( 74 | (x_pos_ndc, y_pos_ndc, torch.ones_like(x_pos_ndc) * max_depth), dim=-1 75 | ) # (N, 3) 76 | 77 | num_points = xy_depth.shape[0] 78 | 79 | unprojected = cameras.unproject_points( 80 | xy_depth.to(cameras.device), world_coordinates=True, from_ndc=True 81 | ) # (N, 3) 82 | unprojected = unprojected.unsqueeze(0).to("cpu") # (B, N, 3) 83 | 84 | origins = ( 85 | cameras.get_camera_center()[:, None, :].expand(-1, num_points, -1).to("cpu") 86 | ) # (B, N, 3) 87 | directions = unprojected - origins # (B, N, 3) 88 | directions = directions / directions.norm(dim=-1).unsqueeze(-1) # (B, N, 3) 89 | lengths = ( 90 | torch.tensor([[0, 3]]).unsqueeze(0).expand(-1, num_points, -1).to("cpu") 91 | ) # (B, N, 2) 92 | xys = xy_depth[:, :2].unsqueeze(0).to("cpu") # (B, N, 2) 93 | 94 | raybundle = RayBundle( 95 | origins=origins.to("cpu"), 96 | directions=directions.to("cpu"), 97 | lengths=lengths.to("cpu"), 98 | xys=xys.to("cpu"), 99 | ) 100 | return raybundle 101 | 102 | 103 | def get_patch_raybundle( 104 | cameras, num_patches_x, num_patches_y, max_depth=1, stratified=False 105 | ): 106 | horizontal_patch_edges = torch.linspace(1, -1, num_patches_x + 1) 107 | # horizontal_positions = horizontal_patch_edges[:-1] # (num_patches_x,): Top left corner of patch 108 | 109 | vertical_patch_edges = torch.linspace(1, -1, num_patches_y + 1) 110 | # vertical_positions = vertical_patch_edges[:-1] # (num_patches_y,): Top left corner of patch 111 | if stratified: 112 | horizontal_patch_edges_center = ( 113 | horizontal_patch_edges[..., 1:] + horizontal_patch_edges[..., :-1] 114 | ) / 2.0 115 | horizontal_patch_edges_upper = torch.cat( 116 | [horizontal_patch_edges_center, horizontal_patch_edges[..., -1:]], -1 117 | ) 118 | horizontal_patch_edges_lower = torch.cat( 119 | [horizontal_patch_edges[..., 
:1], horizontal_patch_edges_center], -1 120 | ) 121 | horizontal_positions = ( 122 | horizontal_patch_edges_lower 123 | + (horizontal_patch_edges_upper - horizontal_patch_edges_lower) 124 | * torch.rand_like(horizontal_patch_edges_lower) 125 | )[..., :-1] 126 | 127 | vertical_patch_edges_center = ( 128 | vertical_patch_edges[..., 1:] + vertical_patch_edges[..., :-1] 129 | ) / 2.0 130 | vertical_patch_edges_upper = torch.cat( 131 | [vertical_patch_edges_center, vertical_patch_edges[..., -1:]], -1 132 | ) 133 | vertical_patch_edges_lower = torch.cat( 134 | [vertical_patch_edges[..., :1], vertical_patch_edges_center], -1 135 | ) 136 | vertical_positions = ( 137 | vertical_patch_edges_lower 138 | + (vertical_patch_edges_upper - vertical_patch_edges_lower) 139 | * torch.rand_like(vertical_patch_edges_lower) 140 | )[..., :-1] 141 | else: 142 | horizontal_positions = ( 143 | horizontal_patch_edges[:-1] + horizontal_patch_edges[1:] 144 | ) / 2 # (num_patches_x, ) # Center of patch 145 | vertical_positions = ( 146 | vertical_patch_edges[:-1] + vertical_patch_edges[1:] 147 | ) / 2 # (num_patches_y, ) # Center of patch 148 | 149 | h_pos, v_pos = torch.meshgrid( 150 | horizontal_positions, vertical_positions, indexing='xy' 151 | ) # (num_patches_y, num_patches_x), (num_patches_y, num_patches_x) 152 | h_pos = h_pos.reshape(-1) # (num_patches_y * num_patches_x) 153 | v_pos = v_pos.reshape(-1) # (num_patches_y * num_patches_x) 154 | 155 | raybundle = get_directional_raybundle( 156 | cameras=cameras, x_pos_ndc=h_pos, y_pos_ndc=v_pos, max_depth=max_depth 157 | ) 158 | return raybundle 159 | 160 | 161 | def get_patch_rays( 162 | cameras_list, 163 | num_patches_x, 164 | num_patches_y, 165 | device, 166 | return_xys=False, 167 | stratified=False, 168 | ): 169 | """Returns patch rays given the camera viewpoints 170 | 171 | Args: 172 | cameras_list(list[pytorch3d.renderer.cameras.BaseCameras]): List of list of cameras (len (batch_size, num_input_views,)) 173 | num_patches_x: Number of patches in the x-direction (horizontal) 174 | num_patches_y: Number of patches in the y-direction (vertical) 175 | 176 | Returns: 177 | torch.tensor: Patch rays of shape (batch_size, num_views, num_patches, 6) 178 | """ 179 | batch, numviews = len(cameras_list), len(cameras_list[0]) 180 | cameras_list = join_cameras_as_batch([cam for cam_batch in cameras_list for cam in cam_batch]) 181 | patch_rays = get_patch_raybundle( 182 | cameras_list, 183 | num_patches_y=num_patches_y, 184 | num_patches_x=num_patches_x, 185 | stratified=stratified, 186 | ) 187 | if return_xys: 188 | xys = patch_rays.xys 189 | 190 | patch_rays = torch.cat((patch_rays.origins.unsqueeze(0), patch_rays.directions), dim=-1) 191 | patch_rays = patch_rays.reshape( 192 | batch, numviews, num_patches_x * num_patches_y, 6 193 | ).to(device) 194 | if return_xys: 195 | return patch_rays, xys 196 | return patch_rays 197 | 198 | ############################ RAY PARAMETERIZATION ############################## 199 | 200 | 201 | def get_plucker_parameterization(ray): 202 | """Returns the plucker representation of the rays given the (origin, direction) representation 203 | 204 | Args: 205 | ray(torch.Tensor): Tensor of shape (..., 6) with the (origin, direction) representation 206 | 207 | Returns: 208 | torch.Tensor: Tensor of shape (..., 6) with the plucker (D, OxD) representation 209 | """ 210 | ray = ray.clone() # Create a clone 211 | ray_origins = ray[..., :3] 212 | ray_directions = ray[..., 3:] 213 | ray_directions = ray_directions / 
ray_directions.norm(dim=-1).unsqueeze( 214 | -1 215 | ) # Normalize ray directions to unit vectors 216 | plucker_normal = torch.cross(ray_origins, ray_directions, dim=-1) 217 | plucker_parameterization = torch.cat([ray_directions, plucker_normal], dim=-1) 218 | 219 | return plucker_parameterization 220 | 221 | 222 | def positional_encoding(ray, n_freqs=10, start_freq=0): 223 | """ 224 | Positional Embeddings. For more details see Section 5.1 of 225 | NeRFs: https://arxiv.org/pdf/2003.08934.pdf 226 | 227 | Args: 228 | ray: (B,P,d) 229 | n_freqs: num of frequency bands 230 | parameterize(str|None): Parameterization used for rays. Recommended: use 'plucker'. Default=None. 231 | 232 | Returns: 233 | pos_embeddings: Mapping input ray from R to R^{2*n_freqs}. 234 | """ 235 | start_freq = -1 * (n_freqs / 2) 236 | freq_bands = 2.0 ** torch.arange(start_freq, start_freq + n_freqs) * np.pi 237 | sin_encodings = [torch.sin(ray * freq) for freq in freq_bands] 238 | cos_encodings = [torch.cos(ray * freq) for freq in freq_bands] 239 | pos_embeddings = torch.cat( 240 | sin_encodings + cos_encodings, dim=-1 241 | ) # B, P, d * 2n_freqs 242 | return pos_embeddings 243 | 244 | 245 | def convert_to_target_space(input_cameras, input_rays): 246 | input_rays_transformed = [] 247 | # input_cameras: b, N 248 | # input_rays: b, N, hw, 6 249 | # return: b, N, hw, 6 250 | for i in range(len(input_cameras[0])): 251 | reference_cameras = [cameras[0] for cameras in input_cameras] 252 | reference_R = [ 253 | camera.R.to(input_rays.device) for camera in reference_cameras 254 | ] # List (length=batch_size) of Rs(shape: 1, 3, 3) 255 | reference_R = torch.cat(reference_R, dim=0) # (B, 3, 3) 256 | reference_T = [ 257 | camera.T.to(input_rays.device) for camera in reference_cameras 258 | ] # List (length=batch_size) of Ts(shape: 1, 3) 259 | reference_T = torch.cat(reference_T, dim=0) # (B, 3) 260 | input_rays_transformed.append( 261 | transform_rays( 262 | reference_R=reference_R, 263 | reference_T=reference_T, 264 | rays=input_rays[:, i: i + 1], 265 | ) 266 | ) 267 | return torch.cat(input_rays_transformed, 1) 268 | 269 | 270 | def convert_to_view_space(input_cameras, input_rays): 271 | input_rays_transformed = [] 272 | # input_cameras: b, N 273 | # input_rays: b, hw, 6 274 | # return: b, n, hw, 6 275 | for i in range(len(input_cameras[0])): 276 | reference_cameras = [cameras[i] for cameras in input_cameras] 277 | reference_R = [ 278 | camera.R.to(input_rays.device) for camera in reference_cameras 279 | ] # List (length=batch_size) of Rs(shape: 1, 3, 3) 280 | reference_R = torch.cat(reference_R, dim=0) # (B, 3, 3) 281 | reference_T = [ 282 | camera.T.to(input_rays.device) for camera in reference_cameras 283 | ] # List (length=batch_size) of Ts(shape: 1, 3) 284 | reference_T = torch.cat(reference_T, dim=0) # (B, 3) 285 | input_rays_transformed.append( 286 | transform_rays( 287 | reference_R=reference_R, 288 | reference_T=reference_T, 289 | rays=input_rays.unsqueeze(1), 290 | ) 291 | ) 292 | return torch.cat(input_rays_transformed, 1) 293 | 294 | 295 | def convert_to_view_space_points(input_cameras, input_points): 296 | input_rays_transformed = [] 297 | # input_cameras: b, N 298 | # ipput_points: b, hw, d, 3 299 | # returns: b, N, hw, d, 3 [target points transformed in the reference view frame] 300 | for i in range(len(input_cameras[0])): 301 | reference_cameras = [cameras[i] for cameras in input_cameras] 302 | reference_R = [ 303 | camera.R.to(input_points.device) for camera in reference_cameras 304 | ] # List 
(length=batch_size) of Rs(shape: 1, 3, 3) 305 | reference_R = torch.cat(reference_R, dim=0) # (B, 3, 3) 306 | reference_T = [ 307 | camera.T.to(input_points.device) for camera in reference_cameras 308 | ] # List (length=batch_size) of Ts(shape: 1, 3) 309 | reference_T = torch.cat(reference_T, dim=0) # (B, 3) 310 | input_points_clone = torch.einsum( 311 | "bsdj,bjk->bsdk", input_points, reference_R 312 | ) + reference_T.reshape(-1, 1, 1, 3) 313 | input_rays_transformed.append(input_points_clone.unsqueeze(1)) 314 | return torch.cat(input_rays_transformed, dim=1) 315 | 316 | 317 | def interpolate_translate_interpolate_xaxis(cam1, interp_start, interp_end, interp_step): 318 | cameras = [] 319 | for i in np.arange(interp_start, interp_end, interp_step): 320 | viewtoworld = cam1.get_world_to_view_transform().inverse() 321 | 322 | x_axis = torch.from_numpy(np.array([i, 0., 0.0])).reshape(1, 3).float().to(cam1.device) 323 | newc = viewtoworld.transform_points(x_axis) 324 | rt = cam1.R[0] 325 | # t = cam1.T 326 | new_t = -rt.T@newc.T 327 | 328 | cameras.append(PerspectiveCameras(R=cam1.R, 329 | T=new_t.T, 330 | focal_length=cam1.focal_length, 331 | principal_point=cam1.principal_point, 332 | image_size=512, 333 | ) 334 | ) 335 | return cameras 336 | 337 | 338 | def interpolate_translate_interpolate_yaxis(cam1, interp_start, interp_end, interp_step): 339 | cameras = [] 340 | for i in np.arange(interp_start, interp_end, interp_step): 341 | # i = np.clip(i, -0.2, 0.18) 342 | viewtoworld = cam1.get_world_to_view_transform().inverse() 343 | 344 | x_axis = torch.from_numpy(np.array([0, i, 0.0])).reshape(1, 3).float().to(cam1.device) 345 | newc = viewtoworld.transform_points(x_axis) 346 | rt = cam1.R[0] 347 | # t = cam1.T 348 | new_t = -rt.T@newc.T 349 | 350 | cameras.append(PerspectiveCameras(R=cam1.R, 351 | T=new_t.T, 352 | focal_length=cam1.focal_length, 353 | principal_point=cam1.principal_point, 354 | image_size=512, 355 | ) 356 | ) 357 | return cameras 358 | 359 | 360 | def interpolate_translate_interpolate_zaxis(cam1, interp_start, interp_end, interp_step): 361 | cameras = [] 362 | for i in np.arange(interp_start, interp_end, interp_step): 363 | viewtoworld = cam1.get_world_to_view_transform().inverse() 364 | 365 | x_axis = torch.from_numpy(np.array([0, 0., i])).reshape(1, 3).float().to(cam1.device) 366 | newc = viewtoworld.transform_points(x_axis) 367 | rt = cam1.R[0] 368 | # t = cam1.T 369 | new_t = -rt.T@newc.T 370 | 371 | cameras.append(PerspectiveCameras(R=cam1.R, 372 | T=new_t.T, 373 | focal_length=cam1.focal_length, 374 | principal_point=cam1.principal_point, 375 | image_size=512, 376 | ) 377 | ) 378 | return cameras 379 | 380 | 381 | def interpolatefocal(cam1, interp_start, interp_end, interp_step): 382 | cameras = [] 383 | for i in np.arange(interp_start, interp_end, interp_step): 384 | cameras.append(PerspectiveCameras(R=cam1.R, 385 | T=cam1.T, 386 | focal_length=cam1.focal_length*i, 387 | principal_point=cam1.principal_point, 388 | image_size=512, 389 | ) 390 | ) 391 | return cameras 392 | -------------------------------------------------------------------------------- /sgm/util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import importlib 3 | import os 4 | from functools import partial 5 | from inspect import isfunction 6 | 7 | import fsspec 8 | import numpy as np 9 | import torch 10 | from PIL import Image, ImageDraw, ImageFont 11 | from safetensors.torch import load_file as load_safetensors 12 | 13 | 14 | def 
disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | 20 | def get_string_from_tuple(s): 21 | try: 22 | # Check if the string starts and ends with parentheses 23 | if s[0] == "(" and s[-1] == ")": 24 | # Convert the string to a tuple 25 | t = eval(s) 26 | # Check if the type of t is tuple 27 | if type(t) == tuple: 28 | return t[0] 29 | else: 30 | pass 31 | except: 32 | pass 33 | return s 34 | 35 | 36 | def is_power_of_two(n): 37 | """ 38 | chat.openai.com/chat 39 | Return True if n is a power of 2, otherwise return False. 40 | 41 | The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. 42 | The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. 43 | If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. 44 | Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. 45 | 46 | """ 47 | if n <= 0: 48 | return False 49 | return (n & (n - 1)) == 0 50 | 51 | 52 | def autocast(f, enabled=True): 53 | def do_autocast(*args, **kwargs): 54 | with torch.cuda.amp.autocast( 55 | enabled=enabled, 56 | dtype=torch.get_autocast_gpu_dtype(), 57 | cache_enabled=torch.is_autocast_cache_enabled(), 58 | ): 59 | return f(*args, **kwargs) 60 | 61 | return do_autocast 62 | 63 | 64 | def load_partial_from_config(config): 65 | return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) 66 | 67 | 68 | def log_txt_as_img(wh, xc, size=10): 69 | # wh a tuple of (width, height) 70 | # xc a list of captions to plot 71 | b = len(xc) 72 | txts = list() 73 | for bi in range(b): 74 | txt = Image.new("RGB", wh, color="white") 75 | draw = ImageDraw.Draw(txt) 76 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 77 | nc = int(40 * (wh[0] / 256)) 78 | if isinstance(xc[bi], list): 79 | text_seq = xc[bi][0] 80 | else: 81 | text_seq = xc[bi] 82 | lines = "\n".join( 83 | text_seq[start : start + nc] for start in range(0, len(text_seq), nc) 84 | ) 85 | 86 | try: 87 | draw.text((0, 0), lines, fill="black", font=font) 88 | except UnicodeEncodeError: 89 | print("Cant encode string for logging. 
Skipping.") 90 | 91 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 92 | txts.append(txt) 93 | txts = np.stack(txts) 94 | txts = torch.tensor(txts) 95 | return txts 96 | 97 | 98 | def partialclass(cls, *args, **kwargs): 99 | class NewCls(cls): 100 | __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) 101 | 102 | return NewCls 103 | 104 | 105 | def make_path_absolute(path): 106 | fs, p = fsspec.core.url_to_fs(path) 107 | if fs.protocol == "file": 108 | return os.path.abspath(p) 109 | return path 110 | 111 | 112 | def ismap(x): 113 | if not isinstance(x, torch.Tensor): 114 | return False 115 | return (len(x.shape) == 4) and (x.shape[1] > 3) 116 | 117 | 118 | def isimage(x): 119 | if not isinstance(x, torch.Tensor): 120 | return False 121 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 122 | 123 | 124 | def isheatmap(x): 125 | if not isinstance(x, torch.Tensor): 126 | return False 127 | 128 | return x.ndim == 2 129 | 130 | 131 | def isneighbors(x): 132 | if not isinstance(x, torch.Tensor): 133 | return False 134 | return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) 135 | 136 | 137 | def exists(x): 138 | return x is not None 139 | 140 | 141 | def expand_dims_like(x, y): 142 | while x.dim() != y.dim(): 143 | x = x.unsqueeze(-1) 144 | return x 145 | 146 | 147 | def default(val, d): 148 | if exists(val): 149 | return val 150 | return d() if isfunction(d) else d 151 | 152 | 153 | def mean_flat(tensor): 154 | """ 155 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 156 | Take the mean over all non-batch dimensions. 157 | """ 158 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 159 | 160 | 161 | def count_params(model, verbose=False): 162 | total_params = sum(p.numel() for p in model.parameters()) 163 | if verbose: 164 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 165 | return total_params 166 | 167 | 168 | def instantiate_from_config(config): 169 | if not "target" in config: 170 | if config == "__is_first_stage__": 171 | return None 172 | elif config == "__is_unconditional__": 173 | return None 174 | raise KeyError("Expected key `target` to instantiate.") 175 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 176 | 177 | 178 | def get_obj_from_str(string, reload=False, invalidate_cache=True): 179 | module, cls = string.rsplit(".", 1) 180 | if invalidate_cache: 181 | importlib.invalidate_caches() 182 | if reload: 183 | module_imp = importlib.import_module(module) 184 | importlib.reload(module_imp) 185 | return getattr(importlib.import_module(module, package=None), cls) 186 | 187 | 188 | def append_zero(x): 189 | return torch.cat([x, x.new_zeros([1])]) 190 | 191 | 192 | def append_dims(x, target_dims): 193 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 194 | dims_to_append = target_dims - x.ndim 195 | if dims_to_append < 0: 196 | raise ValueError( 197 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 198 | ) 199 | return x[(...,) + (None,) * dims_to_append] 200 | 201 | 202 | def load_model_from_config(config, ckpt, delta_ckpt=None, verbose=True, freeze=True): 203 | print(f"Loading model from {ckpt}") 204 | if ckpt.endswith("ckpt"): 205 | pl_sd = torch.load(ckpt, map_location="cpu") 206 | if "global_step" in pl_sd: 207 | print(f"Global Step: {pl_sd['global_step']}") 208 | sd = pl_sd["state_dict"] 209 | elif ckpt.endswith("safetensors"): 210 | sd = 
load_safetensors(ckpt) 211 | else: 212 | raise NotImplementedError 213 | 214 | model = instantiate_from_config(config.model) 215 | 216 | if delta_ckpt is not None: 217 | token_weights1 = sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight'] 218 | token_weights2 = sd['conditioner.embedders.1.model.token_embedding.weight'] 219 | del sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight'] 220 | del sd['conditioner.embedders.1.model.token_embedding.weight'] 221 | 222 | m, u = model.load_state_dict(sd, strict=False) 223 | 224 | ## Load delta ckpt 225 | if delta_ckpt is not None: 226 | pl_sd_delta = torch.load(delta_ckpt, map_location="cpu") 227 | sd_delta = pl_sd_delta["delta_state_dict"] 228 | model.conditioner.embedders[0].transformer.text_model.embeddings.token_embedding.weight.data = torch.cat([token_weights1, sd_delta['embed'][0]], 0).to(model.device) 229 | model.conditioner.embedders[1].model.token_embedding.weight.data = torch.cat([token_weights2, sd_delta['embed'][1]], 0).to(model.device) 230 | del sd_delta['embed'] 231 | for name, module in model.model.diffusion_model.named_modules(): 232 | if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks': 233 | if hasattr(module, 'pose_emb_layers'): 234 | module.register_buffer('references', sd_delta[f'model.diffusion_model.{name}.references']) 235 | del sd_delta[f'model.diffusion_model.{name}.references'] 236 | 237 | m, u = model.load_state_dict(sd_delta, strict=False) 238 | 239 | if len(m) > 0 and verbose: 240 | print("missing keys:") 241 | print(m) 242 | if len(u) > 0 and verbose: 243 | print("unexpected keys:") 244 | print(u) 245 | 246 | if freeze: 247 | for param in model.parameters(): 248 | param.requires_grad = False 249 | 250 | model.eval() 251 | return model 252 | 253 | 254 | def get_configs_path() -> str: 255 | """ 256 | Get the `configs` directory. 257 | For a working copy, this is the one in the root of the repository, 258 | but for an installed copy, it's in the `sgm` package (see pyproject.toml). 259 | """ 260 | this_dir = os.path.dirname(__file__) 261 | candidates = ( 262 | os.path.join(this_dir, "configs"), 263 | os.path.join(this_dir, "..", "configs"), 264 | ) 265 | for candidate in candidates: 266 | candidate = os.path.abspath(candidate) 267 | if os.path.isdir(candidate): 268 | return candidate 269 | raise FileNotFoundError(f"Could not find SGM configs in {candidates}") 270 | 271 | 272 | def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): 273 | """ 274 | Will return the result of a recursive get attribute call. 275 | E.g.: 276 | a.b.c 277 | = getattr(getattr(a, "b"), "c") 278 | = get_nested_attribute(a, "b.c") 279 | If any part of the attribute call is an integer x with current obj a, will 280 | try to call a[x] instead of a.x first. 
281 | """ 282 | attributes = attribute_path.split(".") 283 | if depth is not None and depth > 0: 284 | attributes = attributes[:depth] 285 | assert len(attributes) > 0, "At least one attribute should be selected" 286 | current_attribute = obj 287 | current_key = None 288 | for level, attribute in enumerate(attributes): 289 | current_key = ".".join(attributes[: level + 1]) 290 | try: 291 | id_ = int(attribute) 292 | current_attribute = current_attribute[id_] 293 | except ValueError: 294 | current_attribute = getattr(current_attribute, attribute) 295 | 296 | return (current_attribute, current_key) if return_key else current_attribute 297 | --------------------------------------------------------------------------------