├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── configs └── lvcd.yaml ├── examples └── lvcd_test_example_01.json ├── inference ├── model_hack.py ├── sample_func.py └── test │ └── sample_1 │ └── sample_1.mp4 ├── models ├── csvd.py └── layers.py ├── nodes.py ├── requirements.txt └── sgm ├── __init__.py ├── models ├── __init__.py ├── autoencoder.py └── diffusion.py ├── modules ├── __init__.py ├── attention.py ├── autoencoding │ ├── __init__.py │ ├── losses │ │ ├── __init__.py │ │ ├── discriminator_loss.py │ │ └── lpips.py │ ├── lpips │ │ ├── __init__.py │ │ ├── loss │ │ │ ├── .gitignore │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── lpips.py │ │ ├── model │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── model.py │ │ ├── util.py │ │ └── vqperceptual.py │ ├── regularizers │ │ ├── __init__.py │ │ ├── base.py │ │ └── quantize.py │ └── temporal_ae.py ├── diffusionmodules │ ├── __init__.py │ ├── denoiser.py │ ├── denoiser_scaling.py │ ├── denoiser_weighting.py │ ├── discretizer.py │ ├── guiders.py │ ├── loss.py │ ├── loss_weighting.py │ ├── model.py │ ├── openaimodel.py │ ├── sampling.py │ ├── sampling_utils.py │ ├── sigma_sampling.py │ ├── util.py │ ├── video_model.py │ └── wrappers.py ├── distributions │ ├── __init__.py │ └── distributions.py ├── ema.py ├── encoders │ ├── __init__.py │ └── modules.py └── video_attention.py └── util.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | 3 | training/ 4 | lightning_logs/ 5 | image_log/ 6 | 7 | *.pth 8 | *.pt 9 | *.ckpt 10 | *.safetensors 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | 143 | *.safetensors 144 | *.ckpt 145 | 146 | checkpoints -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI wrapper nodes for LVCD 2 | 3 | 4 | https://github.com/user-attachments/assets/6d4d5afd-f8b1-499b-8872-d5d2f6929da4 5 | 6 | Requires an SVD model; it seems to work best with the original one, but it also runs with 1.1 and XT. This is loaded normally from ComfyUI/models/checkpoints: 7 | 8 | https://huggingface.co/stabilityai/stable-video-diffusion-img2vid 9 | 10 | fp16 version of the SVD model: 11 | 12 | https://huggingface.co/Kijai/LVCD-pruned/blob/main/svd-fp16.safetensors 13 | 14 | The LVCD model itself goes to ComfyUI/models/lvcd (autodownloaded if it doesn't exist): 15 | 16 | https://huggingface.co/Kijai/LVCD-pruned/blob/main/lvcd-fp16.safetensors 17 | 18 | Original repo: 19 | 20 | https://github.com/luckyhzt/LVCD 21 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /configs/lvcd.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 3 | target: .models.csvd.VideoDiffusionEngine 4 | params: 5 | scale_factor: 0.18215 6 | disable_first_stage_autocast: True 7 | ckpt_path: checkpoints/svd.safetensors 8 | control_model_path: Null 9 | init_from_unet: True 10 | sd_locked: False 11 | drop_first_stage_model: True 12 | 13 | denoiser_config: 14 | target: .sgm.modules.diffusionmodules.denoiser.Denoiser 15 | params: 16 | scaling_config: 17 | target: .sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise 18 | 19 | network_config: 20 | target: .models.csvd.ControlledVideoUNet 21 | params: 22 | adm_in_channels: 768 23 | num_classes: sequential 24 | use_checkpoint: True 25 | in_channels: 8 26 | out_channels: 4 27 | model_channels: 320 28 | attention_resolutions: [4, 2, 1] 29 | num_res_blocks: 2 30 | channel_mult: [1, 2, 4, 4] 31 | num_head_channels: 64 32 | use_linear_in_transformer: True 33 | transformer_depth: 1 34 | context_dim: 1024 35 | spatial_transformer_attn_type: softmax-xformers 36 | extra_ff_mix_layer: True 37 | use_spatial_context: True 38 | merge_strategy: learned_with_images 39 | video_kernel_size: [3, 1, 1] 40 | temporal_attn_type: .models.layers.TemporalAttention_Masked 41 | spatial_self_attn_type: .models.layers.ReferenceAttention 42 | conv3d_type: .models.layers.Conv3d_Masked 43 | trainable_layers: ['TemporalAttention_Masked', 'ReferenceAttention'] 44 | 45 | controlnet_config: 46 | target: .models.csvd.ControlNet 47 | params: 48 | adm_in_channels: 768 49 | num_classes: sequential 50 | use_checkpoint: True 51 | in_channels: 8 52 | model_channels: 320 53 | hint_channels: 3 54 | attention_resolutions: [4, 2, 1] 55 | num_res_blocks: 2 56 | channel_mult: [1, 2, 4, 4] 57 | num_head_channels: 64 58 | use_linear_in_transformer: True 59 | transformer_depth: 1 60 | context_dim: 1024 61 | spatial_transformer_attn_type: softmax-xformers 62 | extra_ff_mix_layer: True 63 | use_spatial_context: True 64 | merge_strategy: learned_with_images 65 | video_kernel_size: [3, 1, 1] 66 | temporal_attn_type: .models.layers.TemporalAttention_Masked 67 | spatial_self_attn_type: .models.layers.ReferenceAttention 68 | 
conv3d_type: .models.layers.Conv3d_Masked 69 | 70 | conditioner_config: 71 | target: .sgm.modules.GeneralConditioner 72 | params: 73 | emb_models: 74 | - is_trainable: False 75 | input_key: cond_frames_without_noise 76 | target: .sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder 77 | params: 78 | n_cond_frames: 1 79 | n_copies: 1 80 | open_clip_embedding_config: 81 | target: .sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder 82 | params: 83 | freeze: True 84 | init_device : cuda:0 85 | 86 | - input_key: fps_id 87 | is_trainable: False 88 | target: .sgm.modules.encoders.modules.ConcatTimestepEmbedderND 89 | params: 90 | outdim: 256 91 | 92 | - input_key: motion_bucket_id 93 | is_trainable: False 94 | target: .sgm.modules.encoders.modules.ConcatTimestepEmbedderND 95 | params: 96 | outdim: 256 97 | 98 | - input_key: cond_frames 99 | is_trainable: False 100 | target: .sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder 101 | params: 102 | disable_encoder_autocast: True 103 | n_cond_frames: 1 104 | n_copies: 1 105 | is_ae: True 106 | encoder_config: 107 | target: .sgm.models.autoencoder.AutoencoderKLModeOnly 108 | params: 109 | embed_dim: 4 110 | monitor: val/rec_loss 111 | ddconfig: 112 | attn_type: vanilla-xformers 113 | double_z: True 114 | z_channels: 4 115 | resolution: 256 116 | in_channels: 3 117 | out_ch: 3 118 | ch: 128 119 | ch_mult: [1, 2, 4, 4] 120 | num_res_blocks: 2 121 | attn_resolutions: [] 122 | dropout: 0.0 123 | lossconfig: 124 | target: torch.nn.Identity 125 | 126 | - input_key: cond_aug 127 | is_trainable: False 128 | target: .sgm.modules.encoders.modules.ConcatTimestepEmbedderND 129 | params: 130 | outdim: 256 131 | 132 | first_stage_config: 133 | target: .sgm.models.autoencoder.AutoencodingEngine 134 | params: 135 | loss_config: 136 | target: torch.nn.Identity 137 | regularizer_config: 138 | target: .sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer 139 | encoder_config: 140 | target: .sgm.modules.diffusionmodules.model.Encoder 141 | params: 142 | attn_type: vanilla 143 | double_z: True 144 | z_channels: 4 145 | resolution: 256 146 | in_channels: 3 147 | out_ch: 3 148 | ch: 128 149 | ch_mult: [1, 2, 4, 4] 150 | num_res_blocks: 2 151 | attn_resolutions: [] 152 | dropout: 0.0 153 | decoder_config: 154 | target: .sgm.modules.autoencoding.temporal_ae.VideoDecoder 155 | params: 156 | attn_type: vanilla 157 | double_z: True 158 | z_channels: 4 159 | resolution: 256 160 | in_channels: 3 161 | out_ch: 3 162 | ch: 128 163 | ch_mult: [1, 2, 4, 4] 164 | num_res_blocks: 2 165 | attn_resolutions: [] 166 | dropout: 0.0 167 | video_kernel_size: [3, 1, 1] 168 | 169 | sampler_config: 170 | target: .sgm.modules.diffusionmodules.sampling.EulerEDMSampler 171 | params: 172 | num_steps: 25 173 | 174 | discretization_config: 175 | target: .sgm.modules.diffusionmodules.discretizer.EDMDiscretization 176 | params: 177 | sigma_max: 700.0 178 | 179 | guider_config: 180 | target: .sgm.modules.diffusionmodules.guiders.LinearPredictionGuider 181 | params: 182 | num_frames: 14 183 | max_scale: 2.5 184 | min_scale: 1.0 185 | additional_cond_keys: ['control_hint'] 186 | 187 | loss_fn_config: 188 | target: .sgm.modules.diffusionmodules.loss.StandardDiffusionLoss 189 | params: 190 | batch2model_keys: ['num_video_frames', 'image_only_indicator'] 191 | additional_cond_keys: ['control_hint', 'crossattn_scale', 'concat_scale'] 192 | 193 | loss_weighting_config: 194 | target: .sgm.modules.diffusionmodules.loss_weighting.EDMWeighting 195 | params: 196 | 
sigma_data: 1.0 197 | 198 | sigma_sampler_config: 199 | target: .sgm.modules.diffusionmodules.sigma_sampling.EDMSampling 200 | params: 201 | p_mean: 1.0 202 | p_std: 1.6 203 | 204 | 205 | lightning: 206 | modelcheckpoint: 207 | params: 208 | every_n_train_steps: 1500 209 | save_last: False 210 | save_top_k: -1 211 | filename: '{epoch:04d}-{global_step:06.0f}' 212 | 213 | strategy: 214 | params: 215 | process_group_backend: gloo 216 | 217 | trainer: 218 | devices: 4,5,6,7, 219 | benchmark: True 220 | num_sanity_val_steps: 0 221 | accumulate_grad_batches: 4 222 | max_epochs: 100 223 | precision: 16-mixed 224 | 225 | 226 | data: 227 | target: .sgm.data.my_dataset.DataModuleFromConfig 228 | params: 229 | batch_size: 2 230 | num_workers: 16 231 | 232 | train: 233 | target: models.dataset.AnimeVideoDataset 234 | params: 235 | data_root: /data0/zhitong/datasets/animation_dataset 236 | size: [320, 576] 237 | motion_bucket_id: 160 238 | fps_id: 6 239 | num_frames: 15 240 | cond_aug: False 241 | nframe_range: [15, 200] 242 | uncond_prob: 0.0 243 | sketch_type: 'draw' 244 | train_clips: 'train_clips_hist' 245 | missing_controls: Null 246 | sample_stride: 1 247 | 248 | 249 | 250 | 251 | -------------------------------------------------------------------------------- /inference/model_hack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange 4 | import torch.nn.functional as F 5 | import warnings 6 | import platform 7 | from torch.nn.attention import SDPBackend, sdpa_kernel 8 | backends = [] 9 | 10 | def get_sdpa_settings(): 11 | if torch.cuda.is_available(): 12 | old_gpu = torch.cuda.get_device_properties(0).major < 7 13 | # only use Flash Attention on Ampere (8.0) or newer GPUs 14 | use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 and platform.system() == 'Linux' 15 | if not use_flash_attn: 16 | warnings.warn( 17 | "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", 18 | category=UserWarning, 19 | stacklevel=2, 20 | ) 21 | # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only 22 | # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) 23 | pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) 24 | if pytorch_version < (2, 2): 25 | warnings.warn( 26 | f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. 
" 27 | "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", 28 | category=UserWarning, 29 | stacklevel=2, 30 | ) 31 | math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn 32 | else: 33 | old_gpu = True 34 | use_flash_attn = False 35 | math_kernel_on = True 36 | 37 | return old_gpu, use_flash_attn, math_kernel_on 38 | 39 | OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() 40 | backends.append(SDPBackend.EFFICIENT_ATTENTION) 41 | if USE_FLASH_ATTN: 42 | backends.append(SDPBackend.FLASH_ATTENTION) 43 | if MATH_KERNEL_ON: 44 | backends.append(SDPBackend.MATH) 45 | 46 | def remove_all_hooks(model: torch.nn.Module) -> None: 47 | for child in model.children(): 48 | if hasattr(child, "_forward_hooks"): 49 | child._forward_hooks.clear() 50 | if hasattr(child, "_backward_hooks"): 51 | child._backward_hooks.clear() 52 | remove_all_hooks(child) 53 | 54 | 55 | class Hacked_model(nn.Module): 56 | def __init__(self, model, **kwargs): 57 | super().__init__() 58 | self.operator = Reference(model, **kwargs) 59 | 60 | def forward(self, model, step, x_in, c_noise, cond_in, **additional_model_inputs): 61 | # Register hooks 62 | self.operator.register_hooks(model) 63 | self.operator.setup(step) 64 | # Model forward 65 | out = model.apply_model(x_in, c_noise, cond_in, **additional_model_inputs) 66 | # Remove hooks 67 | self.operator.remove_hooks() 68 | 69 | return out 70 | 71 | def clear_storage(self): 72 | self.operator.clear_storage() 73 | 74 | 75 | class Operator(): 76 | def __init__(self, model): 77 | self.hook_handles = [] 78 | self.layers = self.get_hook_layers(model) 79 | 80 | def get_hook_layers(self, model): 81 | raise NotImplementedError 82 | 83 | def hook(self, module, inputs, outputs): 84 | raise NotImplementedError 85 | 86 | def setup(self, step, branch, opt): 87 | raise NotImplementedError 88 | 89 | def clear_storage(self): 90 | self.storage.clear() 91 | 92 | def register_hooks(self, model): 93 | for m in model.modules(): 94 | index = id(m) 95 | if index in self.layers.keys(): 96 | handle = m.register_forward_hook(self.hook) 97 | self.hook_handles.append(handle) 98 | 99 | def remove_hooks(self): 100 | while len(self.hook_handles) > 0: 101 | self.hook_handles[0].remove() 102 | self.hook_handles.pop(0) 103 | 104 | 105 | class Reference(Operator): 106 | def __init__(self, model, **kwargs): 107 | self.storage = nn.ParameterDict() 108 | self.overlap = kwargs['overlap'] 109 | self.nframes = kwargs['nframes'] 110 | self.refattn_amp = kwargs['refattn_amplify'] 111 | self.refattn_hook = kwargs['refattn_hook'] 112 | self.prev_steps = kwargs['prev_steps'] 113 | super().__init__(model) 114 | 115 | def setup(self, step): 116 | self.step = step 117 | 118 | def get_hook_layers(self, model): 119 | layers = dict() 120 | if self.refattn_hook: 121 | # Hook ref attention layers in ControlNet 122 | layer_name = model.control_model.spatial_self_attn_type.split('.')[-1] 123 | i = 0 124 | for name, module in model.control_model.named_modules(): 125 | if module.__class__.__name__ == layer_name and '.time_stack' not in name and '.attn1' in name: 126 | layers[id(module)] = f'cnet-refcond-{i}' 127 | i += 1 128 | # Hook ref attention layers in UNet 129 | layer_name = model.model.diffusion_model.spatial_self_attn_type.split('.')[-1] 130 | i = 0 131 | for name, module in model.model.diffusion_model.named_modules(): 132 | if module.__class__.__name__ == layer_name and '.time_stack' not in name and '.attn1' in name: 133 | layers[id(module)] = f'unet-refcond-{i}' 134 | i += 1 
135 | return layers 136 | 137 | @torch.no_grad() 138 | def hook(self, module, inputs, outputs): 139 | layer = self.layers[id(module)] 140 | if 'refcond' in layer: 141 | out = self.reference_attn_forward(module, inputs, outputs) 142 | return out 143 | 144 | 145 | def reference_attn_forward(self, module, inputs, outputs): 146 | overlap = self.overlap 147 | T = self.nframes 148 | h = module.heads 149 | 150 | index = id(module) 151 | layer = self.layers[index] 152 | layer_ind = int(layer.split('-')[-1]) 153 | 154 | q = module.to_q(inputs[0]) 155 | k = module.to_k(inputs[0]) 156 | v = module.to_v(inputs[0]) 157 | 158 | olap = 3 159 | 160 | if self.mode == 'normal': 161 | indices = [ 162 | list(range(0, T)), 163 | list(range(0, overlap+1)) + [0]*(T-overlap-1), 164 | ] 165 | elif self.mode == 'prevref': 166 | if self.step < self.prev_steps: 167 | indices = [ 168 | list(range(0, 2*overlap+1)) + list(range(2*overlap+1-olap, T-olap)), 169 | list(range(0, overlap+1)) + list(range(1, overlap+1)) + [0]*(T-2*overlap-1), 170 | ] 171 | else: 172 | '''indices = [ 173 | list(range(0, 2*overlap+1)) + list(range(2*overlap+1-olap, T-olap)), 174 | list(range(0, overlap+1)) + [0]*overlap + [0]*(T-2*overlap-1), 175 | ]''' 176 | indices = [ 177 | list(range(0, T)), 178 | list(range(0, overlap+1)) + [0]*(T-overlap-1), 179 | ] 180 | elif self.mode == 'normal1': 181 | indices = [ 182 | list(range(0, T)), 183 | list(range(0, overlap+1)) + [overlap] + [0]*(T-overlap-2), 184 | ] 185 | '''elif self.mode == 'tempref': 186 | if self.step < self.prev_steps: 187 | indices = [ 188 | list(range(0, 2*overlap+1)) + [2*overlap]*olap + list(range(2*overlap+1, T-olap)), 189 | list(range(0, overlap+1)) + list(range(1, overlap+1)) + [0]*(T-2*overlap-1), 190 | ] 191 | else: 192 | indices = [ 193 | list(range(0, 2*overlap+1)) + [2*overlap]*olap + list(range(2*overlap+1, T-olap)), 194 | list(range(0, overlap+1)) + [0]*overlap + [0]*(T-2*overlap-1), 195 | ]''' 196 | 197 | 198 | k = rearrange(k, '(b t) ... -> b t ...', t=T) 199 | v = rearrange(v, '(b t) ... -> b t ...', t=T) 200 | k = torch.cat([k[:, indices[i]] for i in range(len(indices))], dim=2).clone() 201 | v = torch.cat([v[:, indices[i]] for i in range(len(indices))], dim=2).clone() 202 | k = rearrange(k, 'b t ... -> (b t) ...') 203 | v = rearrange(v, 'b t ... 
-> (b t) ...') 204 | 205 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 206 | # Attention 207 | N = q.shape[-2] 208 | with sdpa_kernel(backends): 209 | if layer_ind > 12 or self.mode == 'normal': 210 | attn_bias = None 211 | else: 212 | attn_bias = torch.zeros([T, 1, N, 2*N], device=q.device, dtype=torch.float32) 213 | amplify = torch.tensor(self.refattn_amp).to(attn_bias) 214 | amplify = rearrange(amplify, 'b t -> t 1 1 b') 215 | amplify = amplify.log() 216 | attn_bias[:, :, :, :N] = amplify[:, :, :, [0]] 217 | attn_bias[:, :, :, N:] = amplify[:, :, :, [1]] 218 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias) 219 | del q, k, v 220 | 221 | out = rearrange(out, "b h n d -> b n (h d)", h=h) 222 | return module.to_out(out) 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | -------------------------------------------------------------------------------- /inference/sample_func.py: -------------------------------------------------------------------------------- 1 | # @title Sampling function 2 | import math 3 | from typing import Optional 4 | import copy 5 | 6 | import torch 7 | from einops import rearrange, repeat 8 | 9 | from ..sgm.util import append_dims 10 | 11 | from .model_hack import Hacked_model, remove_all_hooks 12 | from comfy.utils import ProgressBar 13 | 14 | @torch.no_grad() 15 | def sample_video(model, device, inp, arg, verbose=True): 16 | 17 | def get_indices(n_samples, overlap): 18 | indices = [] 19 | for n in range(n_samples): 20 | if n == 0: 21 | start = 1 22 | first_ref = 0 23 | second_refs = [0] * overlap 24 | else: 25 | start = end - overlap 26 | first_ref = 0 27 | second_refs = list(range(start, start+overlap)) 28 | 29 | end = start + arg.num_frames - overlap - 1 30 | frame_ind = [first_ref] + second_refs + list(range(start, end)) 31 | 32 | ref_ind = 0 33 | 34 | blend_ind = [0] + list(range(-overlap, 0)) 35 | 36 | indices.append([frame_ind, ref_ind, blend_ind]) 37 | 38 | return indices 39 | 40 | remove_all_hooks(model) 41 | overlap = arg.overlap 42 | prev_attn_steps = arg.prev_attn_steps 43 | 44 | n_samples = (len(inp.skts)-(arg.num_frames-overlap)) // (arg.num_frames-2*overlap-1) + 1 45 | blend_indices = [ list(range(0, arg.num_frames)), 46 | [0] + list(range(1, overlap+1))*2 + [overlap]*(arg.num_frames-2*overlap-1)] 47 | blend_steps = [0]*(overlap+1) + [25]*(overlap) + [0]*(arg.num_frames-2*overlap-1) 48 | indices = get_indices(n_samples=n_samples, overlap=overlap) 49 | 50 | # Initialization 51 | H, W = inp.imgs[0].shape[2:] 52 | shape = (arg.num_frames, 4, H // 8, W // 8) 53 | torch.manual_seed(arg.seed) 54 | x_T = torch.randn(shape, dtype=torch.float32, device="cpu").to(device) 55 | 56 | hacked = Hacked_model( 57 | model, overlap=overlap, nframes=arg.num_frames, 58 | refattn_hook=True, prev_steps=prev_attn_steps, 59 | refattn_amplify = [ 60 | [1.0]*(overlap+1) + [1.0]*overlap + [1.0]*3 + [1.0]*7, # Self-attention 61 | [1.0]*(overlap+1) + [10.0]*overlap + [1.0]*3 + [1.0]*7, # Ref-attention 62 | ] 63 | ) 64 | 65 | first_cond = model.encode_first_stage(inp.imgs[0].to(device)) / model.scale_factor 66 | first_conds = repeat(first_cond, 'b ... 
-> (b t) ...', t=arg.num_frames-overlap-1) 67 | 68 | for i, index in enumerate(indices): 69 | frame_ind, ref_ind, blend_ind = index 70 | input_img = inp.imgs[ref_ind].to(device) 71 | sketches = torch.cat([inp.skts[i] for i in frame_ind]).to(device) 72 | if i == 0: 73 | hacked.operator.mode = 'normal' 74 | add_conds = None 75 | intermediates = {'xt': None, 'denoised': None, 'x0': None} 76 | else: 77 | hacked.operator.mode = arg.ref_mode 78 | prev_conds = x0[-overlap:] / model.scale_factor 79 | add_conds = {'concat': { 80 | 'cond': torch.cat([ first_cond, prev_conds, first_conds ]), 81 | } } 82 | for k in intermediates['xt'].keys(): 83 | intermediates['denoised'][k] = intermediates['denoised'][k][blend_ind].clone() 84 | 85 | x0, intermediates = sample( 86 | model=model, device=device, x_T=x_T, input_img=input_img, 87 | additional_conditions=add_conds, controls=sketches, hacked=hacked, 88 | blend_x0=intermediates['denoised'], blend_ind=blend_indices, blend_steps=blend_steps, 89 | return_intermediate=True, **vars(arg), verbose=True, 90 | ) 91 | 92 | if i == 0: 93 | outputs = torch.cat([first_cond*model.scale_factor, x0[-14:]]).cpu() 94 | else: 95 | outputs = torch.cat([outputs[:-overlap], x0[-14:].cpu()]) 96 | 97 | old_xT = x_T.clone() 98 | x_T = torch.cat([ old_xT[[0]], old_xT[-overlap:], old_xT[-overlap:], old_xT[overlap+1:-overlap], ]) 99 | 100 | return outputs 101 | 102 | 103 | @torch.no_grad() 104 | def decode_video(model, device, latents, arg): 105 | model.en_and_decode_n_samples_a_time = arg.decoding_t 106 | 107 | N = latents.shape[0] 108 | B = arg.decoding_t 109 | olap = arg.decoding_olap 110 | f = arg.decoding_first 111 | 112 | end = 0 113 | 114 | i = 0 115 | comfy_pbar = ProgressBar(N) 116 | with torch.autocast('cuda'): 117 | while end < N: 118 | start = i * (B - f - olap) + f 119 | end = min( start + B - f, N) 120 | 121 | indices = [0]*f + list(range(start, end)) 122 | 123 | inputs = latents[indices] 124 | out = model.decode_first_stage(inputs.to(device)).cpu() 125 | out = torch.clamp(out, min=-1.0, max=1.0) 126 | if i == 0: 127 | outputs = out.clone() 128 | else: 129 | outputs = torch.cat([ outputs, out[f+olap:] ]) 130 | i += 1 131 | comfy_pbar.update(1) 132 | 133 | return outputs 134 | 135 | 136 | 137 | 138 | def sample( 139 | model, 140 | device: str, 141 | input_img: torch.Tensor, 142 | hacked = None, 143 | x_T: torch.Tensor = None, 144 | num_frames: Optional[int] = None, 145 | num_steps: Optional[int] = None, 146 | palette: Optional[torch.Tensor] = None, 147 | anchor: Optional[torch.Tensor] = None, 148 | fps_id: int = 6, 149 | motion_bucket_id: int = 127, 150 | cond_aug: float = 0.02, 151 | seed: int = 23, 152 | decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. 153 | output_folder: Optional[str] = "/content/outputs", 154 | verbose: bool = True, 155 | controls: torch.Tensor = None, 156 | blend_ind = None, 157 | blend_x0: torch.Tensor = None, 158 | scale = [1.0, 1.0], 159 | return_intermediate: bool = False, 160 | input_latent: torch.Tensor = None, 161 | first_control: torch.Tensor = None, 162 | blend_steps = None, 163 | gamma = 0.0, 164 | additional_conditions = None, 165 | starting_conditions = None, 166 | cfg_combine_forward = True, 167 | **kwargs, 168 | ): 169 | """ 170 | Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each 171 | image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`. 
172 | """ 173 | seed_everything(seed) 174 | 175 | if True: 176 | H, W = input_img.shape[2:] 177 | assert input_img.shape[1] == 3 178 | F = 8 179 | C = 4 180 | shape = (num_frames, C, H // F, W // F) 181 | 182 | if motion_bucket_id > 255: 183 | print("WARNING: High motion bucket! This may lead to suboptimal performance.") 184 | if fps_id < 5: 185 | print("WARNING: Small fps value! This may lead to suboptimal performance.") 186 | if fps_id > 30: 187 | print("WARNING: Large fps value! This may lead to suboptimal performance.") 188 | 189 | value_dict = {} 190 | value_dict["motion_bucket_id"] = motion_bucket_id 191 | value_dict["fps_id"] = fps_id 192 | value_dict["cond_aug"] = cond_aug 193 | value_dict["cond_frames_without_noise"] = input_img 194 | value_dict["cond_frames"] = input_img + cond_aug * torch.randn_like(input_img) 195 | value_dict["cond_aug"] = cond_aug 196 | model.sampler.verbose = verbose 197 | model.sampler.device = device 198 | 199 | with torch.no_grad(): 200 | with torch.autocast('cuda'): 201 | # Prepare conditions 202 | print("Preparing conditions...") 203 | c, uc, additional_model_inputs = get_conditioning( 204 | model, 205 | get_unique_embedder_keys_from_conditioner(model.conditioner), 206 | value_dict, 207 | [1, num_frames], 208 | T=num_frames, 209 | input_latent=input_latent, 210 | device=device, 211 | controls=controls, palette=palette, anchor=anchor, first_control=first_control, 212 | additional_conditions=additional_conditions, 213 | ) 214 | print("conditions prepared") 215 | # Initial noise 216 | if x_T is None: 217 | randn = torch.randn(shape, dtype=torch.float32, device="cpu").to(device) 218 | else: 219 | randn = x_T.clone() 220 | 221 | # Prepare for swapping conditions 222 | if starting_conditions is not None: 223 | original_c = copy.deepcopy(c) 224 | 225 | '''Sampling''' 226 | intermediate = {'xt': {}, 'denoised': {},} 227 | with torch.no_grad(): 228 | x = randn.clone() 229 | sigmas = model.sampler.discretization(num_steps, device=device).to(torch.float32) 230 | x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) 231 | num_sigmas = len(sigmas) 232 | comfy_pbar = ProgressBar(num_sigmas) 233 | for i in model.sampler.get_sigma_gen(num_sigmas): 234 | 235 | # Blending 236 | if blend_steps is not None and blend_ind is not None: 237 | blend = (i < max(blend_steps)) 238 | target_ind = [] 239 | source_ind = [] 240 | for k, b in enumerate(blend_steps): 241 | if i < b: 242 | target_ind.append(blend_ind[0][k]) 243 | source_ind.append(blend_ind[1][k]) 244 | else: 245 | blend = False 246 | 247 | if return_intermediate: 248 | intermediate['xt'][i] = x.clone() 249 | 250 | if starting_conditions is not None: 251 | if i < starting_conditions['step']: 252 | c = copy.deepcopy(original_c) 253 | for k in starting_conditions['cond'].keys(): 254 | c[k] = starting_conditions['cond'][k] 255 | else: 256 | c = original_c 257 | 258 | if True: 259 | # Prepare sigma 260 | s_ones = x.new_ones([x.shape[0]], dtype=torch.float32) 261 | sigma = s_ones * sigmas[i] 262 | next_sigma = s_ones * sigmas[i+1] 263 | sigma_hat = sigma * (gamma + 1.0) 264 | # Denoising 265 | denoised = denoise( 266 | model, hacked, i, x, c, uc, additional_model_inputs, 267 | sigma_hat, scale, cfg_combine_forward, 268 | ) 269 | # CFG guidance 270 | denoised = guidance(denoised, scale, num_frames) 271 | if return_intermediate: 272 | intermediate['denoised'][i] = denoised.clone() 273 | 274 | # x0 blending 275 | if blend and blend_x0 is not None: 276 | #denoised[target_ind] = blend_x0[num_steps-1][source_ind] 277 | denoised[target_ind] = 
blend_x0[i][source_ind] 278 | 279 | # Euler step 280 | d = (x - denoised) / append_dims(sigma_hat, x.ndim) 281 | dt = append_dims(next_sigma - sigma_hat, x.ndim) 282 | x = x + dt * d 283 | comfy_pbar.update(1) 284 | 285 | samples_z = x.clone().to(dtype=model.first_stage_model.dtype) 286 | 287 | if return_intermediate: 288 | return samples_z, intermediate 289 | else: 290 | return samples_z, None 291 | 292 | 293 | def get_unique_embedder_keys_from_conditioner(conditioner): 294 | return list(set([x.input_key for x in conditioner.embedders])) 295 | 296 | def get_conditioning(model, keys, value_dict, N, T, device, input_latent, additional_conditions, dtype=None, **kwargs): 297 | batch = {} 298 | batch_uc = {} 299 | 300 | for key in keys: 301 | if key == "fps_id": 302 | batch[key] = ( 303 | torch.tensor([value_dict["fps_id"]]) 304 | .to(device, dtype=dtype) 305 | .repeat(int(math.prod(N))) 306 | ) 307 | elif key == "motion_bucket_id": 308 | batch[key] = ( 309 | torch.tensor([value_dict["motion_bucket_id"]]) 310 | .to(device, dtype=dtype) 311 | .repeat(int(math.prod(N))) 312 | ) 313 | elif key == "cond_aug": 314 | batch[key] = repeat( 315 | torch.tensor([value_dict["cond_aug"]]).to(device, dtype=dtype), 316 | "1 -> b", 317 | b=math.prod(N), 318 | ) 319 | elif key == "cond_frames": 320 | batch[key] = torch.cat([ value_dict["cond_frames"] ]*N[0]) 321 | elif key == "cond_frames_without_noise": 322 | batch[key] = torch.cat([ value_dict["cond_frames_without_noise"] ]*N[0]) 323 | else: 324 | batch[key] = value_dict[key] 325 | 326 | if T is not None: 327 | batch["num_video_frames"] = T 328 | 329 | for key in batch.keys(): 330 | if key not in batch_uc and isinstance(batch[key], torch.Tensor): 331 | batch_uc[key] = torch.clone(batch[key]) 332 | 333 | c, uc = model.conditioner.get_unconditional_conditioning( 334 | batch, 335 | batch_uc=batch_uc, 336 | force_uc_zero_embeddings=[ 337 | "cond_frames", 338 | "cond_frames_without_noise", 339 | ], 340 | ) 341 | if input_latent is not None: 342 | c['concat'] = input_latent.clone() / 0.18215 343 | 344 | # from here, dtype is fp16 345 | for k in ["crossattn", "concat"]: 346 | uc[k] = repeat(uc[k], "b ... -> b t ...", t=T) 347 | uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T) 348 | c[k] = repeat(c[k], "b ... -> b t ...", t=T) 349 | c[k] = rearrange(c[k], "b t ... 
-> (b t) ...", t=T) 350 | for k in uc.keys(): 351 | uc[k] = uc[k].to(dtype=torch.float32) 352 | c[k] = c[k].to(dtype=torch.float32) 353 | 354 | if 'controls' in kwargs and kwargs['controls'] is not None: 355 | uc['control_hint'] = kwargs['controls'].to(torch.float32) 356 | c['control_hint'] = kwargs['controls'].to(torch.float32) 357 | if 'first_control' in kwargs and kwargs['first_control'] is not None: 358 | c['first_control'] = kwargs['first_control'].to(torch.float32) 359 | uc['first_control'] = torch.zeros_like(c['first_control']) 360 | if 'palette' in kwargs and kwargs['palette'] is not None: 361 | uc['palette'] = kwargs['palette'].to(torch.float32) 362 | c['palette'] = kwargs['palette'].to(torch.float32) 363 | if 'anchor' in kwargs and kwargs['anchor'] is not None: 364 | uc['anchor'] = kwargs['anchor'].to(torch.float32) 365 | c['anchor'] = kwargs['anchor'].to(torch.float32) 366 | 367 | if additional_conditions is not None: 368 | for k in additional_conditions.keys(): 369 | c[k] = additional_conditions[k]['cond'].to(torch.float32) 370 | if 'uncond' in additional_conditions[k].keys(): 371 | uc[k] = additional_conditions[k]['uncond'].to(torch.float32) 372 | else: 373 | uc[k] = additional_conditions[k]['cond'].to(torch.float32) 374 | 375 | additional_model_inputs = {} 376 | additional_model_inputs["image_only_indicator"] = torch.zeros(1, T).to(device) 377 | additional_model_inputs["num_video_frames"] = batch["num_video_frames"] 378 | 379 | for k in additional_model_inputs: 380 | if isinstance(additional_model_inputs[k], torch.Tensor): 381 | additional_model_inputs[k] = additional_model_inputs[k].to(dtype=torch.float32) 382 | 383 | return c, uc, additional_model_inputs 384 | 385 | def denoise( 386 | model, hacked, step, x, 387 | c, uc, additional_model_inputs, 388 | sigma_hat, scale, cfg_combine_forward, 389 | ): 390 | # Prepare model input 391 | if scale[1] != 1.0 and cfg_combine_forward: 392 | cond_in = dict() 393 | if additional_model_inputs['image_only_indicator'].shape[0] == 1: 394 | additional_model_inputs["image_only_indicator"] = additional_model_inputs["image_only_indicator"].repeat(2, 1) 395 | for k in c: 396 | if k in ["vector", "crossattn", "concat"] + model.sampler.guider.additional_cond_keys: 397 | cond_in[k] = torch.cat((uc[k], c[k]), 0) 398 | else: 399 | assert c[k] == uc[k] 400 | cond_in[k] = c[k] 401 | x_in = torch.cat([x] * 2) 402 | s_in = torch.cat([sigma_hat] * 2) 403 | else: 404 | cond_in = c 405 | x_in = x 406 | s_in = sigma_hat 407 | 408 | if hacked is not None: 409 | model_forward = lambda inp, c_noise, cond, **add: hacked(model, step, inp, c_noise, cond, **add) 410 | else: 411 | model_forward = model.apply_model 412 | 413 | denoised = model.denoiser(model_forward, x_in, s_in, cond_in, **additional_model_inputs) 414 | 415 | if not cfg_combine_forward and scale[1] != 1.0: 416 | uc_denoised = model.denoiser(model_forward, x_in, s_in, uc, **additional_model_inputs) 417 | denoised = torch.cat([uc_denoised, denoised]) 418 | 419 | if denoised.shape[0] < x_in.shape[0]: 420 | denoised = rearrange(denoised, '(b t) ... -> b t ...', t=additional_model_inputs["num_video_frames"]-1) 421 | denoised = torch.cat([denoised[:, [0]], denoised], dim=1) 422 | denoised = rearrange(denoised, 'b t ... -> (b t) ...') 423 | 424 | return denoised 425 | 426 | def guidance(denoised, scale, num_frames): 427 | if scale[1] != 1.0: 428 | x_u, x_c = denoised.chunk(2) 429 | x_u = rearrange(x_u, "(b t) ... -> b t ...", t=num_frames) 430 | x_c = rearrange(x_c, "(b t) ... 
-> b t ...", t=num_frames) 431 | scales = torch.linspace(scale[0], scale[1], num_frames).unsqueeze(0) 432 | scales = repeat(scales, "1 t -> b t", b=x_u.shape[0]) 433 | scales = append_dims(scales, x_u.ndim).to(x_u.device) 434 | denoised = rearrange(x_u + scales * (x_c - x_u), "b t ... -> (b t) ...") 435 | 436 | return denoised 437 | 438 | def seed_everything(seed: int): 439 | import random, os 440 | import numpy as np 441 | import torch 442 | 443 | random.seed(seed) 444 | os.environ['PYTHONHASHSEED'] = str(seed) 445 | np.random.seed(seed) 446 | torch.manual_seed(seed) 447 | torch.cuda.manual_seed(seed) 448 | torch.backends.cudnn.deterministic = True 449 | torch.backends.cudnn.benchmark = True -------------------------------------------------------------------------------- /inference/test/sample_1/sample_1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LVCDWrapper/081c8180029b1b5eb8f416e079456311ff467c83/inference/test/sample_1/sample_1.mp4 -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from inspect import isfunction 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from packaging import version 8 | from torch import nn 9 | 10 | import comfy.ops 11 | ops = comfy.ops.manual_cast 12 | 13 | logpy = logging.getLogger(__name__) 14 | 15 | backends = [] 16 | 17 | if version.parse(torch.__version__) >= version.parse("2.0.0"): 18 | SDP_IS_AVAILABLE = True 19 | from torch.nn.attention import SDPBackend, sdpa_kernel 20 | 21 | BACKEND_MAP = { 22 | SDPBackend.MATH: [SDPBackend.MATH], 23 | SDPBackend.FLASH_ATTENTION: [SDPBackend.FLASH_ATTENTION], 24 | SDPBackend.EFFICIENT_ATTENTION: [SDPBackend.EFFICIENT_ATTENTION], 25 | None: [SDPBackend.MATH, SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] 26 | } 27 | else: 28 | from contextlib import nullcontext 29 | 30 | SDP_IS_AVAILABLE = False 31 | sdp_kernel = nullcontext 32 | BACKEND_MAP = {} 33 | logpy.warn( 34 | f"No SDP backend available, likely because you are running in pytorch " 35 | f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. " 36 | f"You might want to consider upgrading." 37 | ) 38 | 39 | try: 40 | import xformers 41 | import xformers.ops 42 | 43 | XFORMERS_IS_AVAILABLE = True 44 | except: 45 | XFORMERS_IS_AVAILABLE = False 46 | logpy.warn("no module 'xformers'. 
Processing without...") 47 | 48 | 49 | '''This temporal attention replace the original one in SVD to disable the temporal 50 | attentions between the first frame (reference path) and the remaining 14 frames (video path).''' 51 | class TemporalAttention_Masked(nn.Module): 52 | def __init__( 53 | self, 54 | query_dim, 55 | context_dim=None, 56 | heads=8, 57 | dim_head=64, 58 | dropout=0.0, 59 | backend=None, 60 | ): 61 | super().__init__() 62 | inner_dim = dim_head * heads 63 | context_dim = default(context_dim, query_dim) 64 | 65 | self.scale = dim_head**-0.5 66 | self.heads = heads 67 | 68 | self.to_q = ops.Linear(query_dim, inner_dim, bias=False) 69 | self.to_k = ops.Linear(context_dim, inner_dim, bias=False) 70 | self.to_v = ops.Linear(context_dim, inner_dim, bias=False) 71 | 72 | self.to_out = nn.Sequential( 73 | ops.Linear(inner_dim, query_dim), nn.Dropout(dropout) 74 | ) 75 | self.backend = backend 76 | 77 | def forward( 78 | self, 79 | x, 80 | context=None, 81 | mask=None, 82 | additional_tokens=None, 83 | n_times_crossframe_attn_in_self=0, 84 | ): 85 | if hasattr(self, '_forward_hooks') and len(self._forward_hooks) > 0: 86 | # If hooked do nothing 87 | return x 88 | else: 89 | return self._forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self) 90 | 91 | def _forward( 92 | self, 93 | x, 94 | context=None, 95 | mask=None, 96 | additional_tokens=None, 97 | n_times_crossframe_attn_in_self=0, 98 | ): 99 | h = self.heads 100 | 101 | if mask is None: 102 | T = x.shape[-2] 103 | dt = T - 14 104 | mask = torch.ones(T, T).to(x) 105 | mask[:, :dt] = 0.0 106 | mask[:dt, :] = 0.0 107 | inds = [t for t in range(dt)] 108 | mask[inds, inds] = 1.0 109 | mask = rearrange(mask, 'h w -> 1 1 h w') 110 | mask = mask.bool() 111 | 112 | if additional_tokens is not None: 113 | # get the number of masked tokens at the beginning of the output sequence 114 | n_tokens_to_mask = additional_tokens.shape[1] 115 | # add additional token 116 | x = torch.cat([additional_tokens, x], dim=1) 117 | 118 | q = self.to_q(x) 119 | context = default(context, x) 120 | k = self.to_k(context) 121 | v = self.to_v(context) 122 | 123 | if n_times_crossframe_attn_in_self: 124 | # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 125 | assert x.shape[0] % n_times_crossframe_attn_in_self == 0 126 | n_cp = x.shape[0] // n_times_crossframe_attn_in_self 127 | k = repeat( 128 | k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp 129 | ) 130 | v = repeat( 131 | v[::n_times_crossframe_attn_in_self], "b ... 
-> (b n) ...", n=n_cp 132 | ) 133 | 134 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 135 | 136 | backends.extend(BACKEND_MAP[self.backend]) 137 | 138 | with sdpa_kernel(backends): 139 | # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) 140 | out = F.scaled_dot_product_attention( 141 | q, k, v, attn_mask=mask 142 | ) # scale is dim_head ** -0.5 per default 143 | 144 | del q, k, v 145 | out = rearrange(out, "b h n d -> b n (h d)", h=h) 146 | 147 | if additional_tokens is not None: 148 | # remove additional token 149 | out = out[:, n_tokens_to_mask:] 150 | return self.to_out(out) 151 | 152 | 153 | '''The reference attention which replace the original spatial self-attention layers in SVD.''' 154 | class ReferenceAttention(nn.Module): 155 | def __init__( 156 | self, 157 | query_dim, 158 | context_dim=None, 159 | heads=8, 160 | dim_head=64, 161 | dropout=0.0, 162 | backend=None, 163 | ): 164 | super().__init__() 165 | inner_dim = dim_head * heads 166 | context_dim = default(context_dim, query_dim) 167 | 168 | self.scale = dim_head**-0.5 169 | self.heads = heads 170 | 171 | self.to_q = ops.Linear(query_dim, inner_dim, bias=False) 172 | self.to_k = ops.Linear(context_dim, inner_dim, bias=False) 173 | self.to_v = ops.Linear(context_dim, inner_dim, bias=False) 174 | 175 | self.to_out = nn.Sequential( 176 | ops.Linear(inner_dim, query_dim), nn.Dropout(dropout) 177 | ) 178 | self.backend = backend 179 | 180 | def forward( 181 | self, 182 | x, 183 | context=None, 184 | mask=None, 185 | additional_tokens=None, 186 | n_times_crossframe_attn_in_self=0, 187 | ): 188 | if hasattr(self, '_forward_hooks') and len(self._forward_hooks) > 0: 189 | # If hooked do nothing 190 | return x 191 | else: 192 | return self._forward(x, context, mask, additional_tokens, n_times_crossframe_attn_in_self) 193 | 194 | def _forward( 195 | self, 196 | x, 197 | context=None, 198 | mask=None, 199 | additional_tokens=None, 200 | n_times_crossframe_attn_in_self=0, 201 | ): 202 | B = x.shape[0] // 14 203 | T = x.shape[0] // B 204 | h = self.heads 205 | 206 | if additional_tokens is not None: 207 | # get the number of masked tokens at the beginning of the output sequence 208 | n_tokens_to_mask = additional_tokens.shape[1] 209 | # add additional token 210 | x = torch.cat([additional_tokens, x], dim=1) 211 | 212 | q = self.to_q(x) 213 | context = default(context, x) 214 | k = self.to_k(context) 215 | v = self.to_v(context) 216 | # Refconcat: Q [K, K0] [V, V0] 217 | k0 = rearrange(k, '(b t) ... -> b t ...', t=T)[:, [0]] 218 | k0 = repeat(k0, 'b t0 ... -> b (t t0) ...', t=T) 219 | k0 = rearrange(k0, 'b t ... -> (b t) ...') 220 | v0 = rearrange(v, '(b t) ... -> b t ...', t=T)[:, [0]] 221 | v0 = repeat(v0, 'b t0 ... -> b (t t0) ...', t=T) 222 | v0 = rearrange(v0, 'b t ... -> (b t) ...') 223 | k = torch.cat([k, k0], dim=1) 224 | v = torch.cat([v, v0], dim=1) 225 | 226 | if n_times_crossframe_attn_in_self: 227 | # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 228 | assert x.shape[0] % n_times_crossframe_attn_in_self == 0 229 | n_cp = x.shape[0] // n_times_crossframe_attn_in_self 230 | k = repeat( 231 | k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp 232 | ) 233 | v = repeat( 234 | v[::n_times_crossframe_attn_in_self], "b ... 
-> (b n) ...", n=n_cp 235 | ) 236 | 237 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 238 | 239 | backends.extend(BACKEND_MAP[self.backend]) 240 | 241 | with sdpa_kernel(backends): 242 | # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) 243 | out = F.scaled_dot_product_attention( 244 | q, k, v, attn_mask=mask 245 | ) # scale is dim_head ** -0.5 per default 246 | 247 | del q, k, v 248 | out = rearrange(out, "b h n d -> b n (h d)", h=h) 249 | 250 | if additional_tokens is not None: 251 | # remove additional token 252 | out = out[:, n_tokens_to_mask:] 253 | return self.to_out(out) 254 | 255 | 256 | '''The 3D convolutional layers which disables the interactions between the 257 | first frame (reference path) and the remaining 14 frames (video path).''' 258 | class Conv3d_Masked(nn.Module): 259 | def __init__(self, in_channels, out_channels, kernel_size, padding): 260 | super().__init__() 261 | self.padding = padding 262 | self.weight = nn.Parameter( torch.zeros([out_channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2]]) ) 263 | self.bias = nn.Parameter( torch.zeros([out_channels]) ) 264 | 265 | def forward(self, x): 266 | dt = x.shape[2] - 14 267 | 268 | zeros_pad = torch.zeros_like(x[:, :, [0]]) 269 | 270 | xs = [] 271 | for i in range(dt): 272 | xs.append( x[:, :, [i]] ) 273 | xs.append( zeros_pad ) 274 | xs.append( x[:, :, dt:] ) 275 | x = torch.cat(xs, dim=2) 276 | 277 | x = torch.nn.functional.conv3d( 278 | input=x, 279 | weight=self.weight, 280 | bias=self.bias, 281 | padding=self.padding, 282 | ) 283 | 284 | out_ind = [2*i for i in range(dt)] 285 | 286 | x = torch.cat([ x[:, :, out_ind], x[:, :, 2*dt:] ], dim=2) 287 | 288 | return x 289 | 290 | 291 | 292 | def default(val, d): 293 | if exists(val): 294 | return val 295 | return d() if isfunction(d) else d 296 | 297 | def exists(val): 298 | return val is not None -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import folder_paths 4 | import comfy.model_management as mm 5 | 6 | import argparse 7 | from omegaconf import OmegaConf 8 | import logging 9 | from .sgm.util import instantiate_from_config 10 | from .inference.sample_func import sample_video, decode_video 11 | 12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 13 | log = logging.getLogger(__name__) 14 | 15 | script_directory = os.path.dirname(os.path.abspath(__file__)) 16 | 17 | class LoadLVCDModel: 18 | @classmethod 19 | def INPUT_TYPES(s): 20 | return { 21 | "required": { 22 | "model": (folder_paths.get_filename_list("checkpoints"), {"tooltip": "Normal SVD model, default is the normal very first non-XT SVD"} ), 23 | "use_xformers": ("BOOLEAN", {"default": False}), 24 | }, 25 | "optional": { 26 | "precision": (["fp16", "fp32", "bf16"], 27 | {"default": "fp16"} 28 | ), 29 | } 30 | } 31 | 32 | RETURN_TYPES = ("LVCDPIPE",) 33 | RETURN_NAMES = ("LVCD_pipe", ) 34 | FUNCTION = "loadmodel" 35 | CATEGORY = "ComfyUI-LVCDWrapper" 36 | 37 | def loadmodel(self, model, precision, use_xformers): 38 | 39 | device = mm.get_torch_device() 40 | offload_device = mm.unet_offload_device() 41 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 42 | mm.soft_empty_cache() 43 | 44 | svd_model_path = folder_paths.get_full_path_or_raise("checkpoints", model) 45 | 
download_path = os.path.join(folder_paths.models_dir, "lvcd") 46 | lvcd_path = os.path.join(download_path, "lvcd-fp16.safetensors") 47 | 48 | if not os.path.exists(lvcd_path): 49 | log.info(f"Downloading LVCD model to: {lvcd_path}") 50 | from huggingface_hub import snapshot_download 51 | 52 | snapshot_download( 53 | repo_id="Kijai/LVCD-pruned", 54 | local_dir=download_path, 55 | local_dir_use_symlinks=False, 56 | ) 57 | 58 | config_path = os.path.join(script_directory, "configs", "lvcd.yaml") 59 | config = OmegaConf.load(config_path) 60 | config.model.params.drop_first_stage_model = False 61 | config.model.params.init_from_unet = False 62 | 63 | if use_xformers: 64 | config.model.params.network_config.params.spatial_transformer_attn_type = 'softmax-xformers' 65 | config.model.params.controlnet_config.params.spatial_transformer_attn_type = 'softmax-xformers' 66 | config.model.params.conditioner_config.params.emb_models[3].params.encoder_config.params.ddconfig.attn_type = 'vanilla-xformers' 67 | else: 68 | config.model.params.network_config.params.spatial_transformer_attn_type = 'softmax' 69 | config.model.params.controlnet_config.params.spatial_transformer_attn_type = 'softmax' 70 | config.model.params.conditioner_config.params.emb_models[3].params.encoder_config.params.ddconfig.attn_type = 'vanilla' 71 | 72 | config.model.params.ckpt_path = svd_model_path 73 | config.model.params.control_model_path = lvcd_path 74 | 75 | with torch.device(device): 76 | model = instantiate_from_config(config.model).to(device).eval().requires_grad_(False) 77 | 78 | model.model.to(dtype) 79 | model.control_model.to(dtype) 80 | model.eval() 81 | model = model.requires_grad_(False) 82 | 83 | lvcd_pipe = { 84 | "model": model, 85 | "dtype": dtype, 86 | } 87 | 88 | return (lvcd_pipe,) 89 | 90 | class LVCDSampler: 91 | @classmethod 92 | def INPUT_TYPES(s): 93 | return { 94 | "required": { 95 | "LVCD_pipe": ("LVCDPIPE",), 96 | "ref_images": ("IMAGE",), 97 | "sketch_images": ("IMAGE",), 98 | "num_frames": ("INT", {"default": 19, "min": 15, "max": 100, "step": 1}), 99 | "num_steps": ("INT", {"default": 25, "min": 1, "max": 100, "step": 1}), 100 | "fps_id": ("INT", {"default": 6, "min": 1, "max": 100, "step": 1}), 101 | "motion_bucket_id": ("INT", {"default": 160, "min": 0, "max": 1000, "step": 1}), 102 | "cond_aug": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 103 | "overlap": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1}), 104 | "prev_attn_steps": ("INT", {"default": 25, "min": 1, "max": 100, "step": 1}), 105 | "seed": ("INT", {"default": 123, "min": 0, "max": 2**32, "step": 1}), 106 | "keep_model_loaded": ("BOOLEAN", {"default": False}), 107 | }, 108 | } 109 | 110 | RETURN_TYPES = ("LVCDPIPE", "SVDSAMPLES",) 111 | RETURN_NAMES = ("LVCD_pipe", "samples",) 112 | FUNCTION = "loadmodel" 113 | CATEGORY = "ComfyUI-LVCDWrapper" 114 | 115 | def loadmodel(self, LVCD_pipe, ref_images, sketch_images, num_frames, num_steps, fps_id, motion_bucket_id, cond_aug, overlap, 116 | prev_attn_steps, seed, keep_model_loaded): 117 | 118 | device = mm.get_torch_device() 119 | offload_device = mm.unet_offload_device() 120 | mm.soft_empty_cache() 121 | 122 | model = LVCD_pipe["model"] 123 | 124 | inp = argparse.ArgumentParser() 125 | B, H, W, C = ref_images.shape 126 | inp.resolution = [H, W] 127 | 128 | inp.imgs = [] 129 | inp.skts = [] 130 | 131 | ref_images = ref_images.permute(0, 3, 1, 2).to(device) * 2 - 1 132 | for ref_img in ref_images: 133 | print(ref_img.shape) 134 | 
inp.imgs.append(ref_img.unsqueeze(0)) 135 | sketch_images = sketch_images.permute(0, 3, 1, 2).to(device) 136 | for skt in sketch_images: 137 | inp.skts.append(skt.unsqueeze(0)) 138 | 139 | arg = argparse.ArgumentParser() 140 | 141 | arg.ref_mode = 'prevref' 142 | arg.num_frames = num_frames 143 | arg.num_steps = num_steps 144 | arg.overlap = overlap 145 | arg.prev_attn_steps = prev_attn_steps 146 | arg.scale = [1.0, 1.0] 147 | arg.seed = seed 148 | arg.fps_id = fps_id 149 | arg.motion_bucket_id = motion_bucket_id 150 | arg.cond_aug = cond_aug 151 | 152 | model.to(device) 153 | model.control_model.to(device) 154 | samples = sample_video(model, device, inp, arg, verbose=True) 155 | if not keep_model_loaded: 156 | model.to(offload_device) 157 | model.control_model.to(offload_device) 158 | 159 | return (LVCD_pipe, samples) 160 | 161 | class LVCDDecoder: 162 | @classmethod 163 | def INPUT_TYPES(s): 164 | return { 165 | "required": { 166 | "LVCD_pipe": ("LVCDPIPE",), 167 | "samples": ("SVDSAMPLES",), 168 | "decoding_t": ("INT", {"default": 10, "min": 1, "max": 100, "step": 1}), 169 | "decoding_olap": ("INT", {"default": 3, "min": 0, "max": 100, "step": 1}), 170 | "decoding_first": ("INT", {"default": 1, "min": 0, "max": 100, "step": 1}), 171 | }, 172 | } 173 | 174 | RETURN_TYPES = ("IMAGE",) 175 | RETURN_NAMES = ("images", ) 176 | FUNCTION = "loadmodel" 177 | CATEGORY = "ComfyUI-LVCDWrapper" 178 | 179 | def loadmodel(self, LVCD_pipe, samples, decoding_t, decoding_olap, decoding_first): 180 | 181 | device = mm.get_torch_device() 182 | offload_device = mm.unet_offload_device() 183 | mm.soft_empty_cache() 184 | 185 | model = LVCD_pipe["model"] 186 | 187 | arg = argparse.ArgumentParser() 188 | 189 | arg.decoding_t = decoding_t 190 | arg.decoding_olap = decoding_olap 191 | arg.decoding_first = decoding_first 192 | 193 | model.first_stage_model.to(device) 194 | frames = decode_video(model, device, samples, arg) 195 | model.first_stage_model.to(offload_device) 196 | 197 | min_value = frames.min() 198 | max_value = frames.max() 199 | 200 | frames = (frames - min_value) / (max_value - min_value) 201 | frames = frames.permute(0, 2, 3, 1).cpu().float() 202 | 203 | 204 | return (frames,) 205 | 206 | NODE_CLASS_MAPPINGS = { 207 | "LoadLVCDModel": LoadLVCDModel, 208 | "LVCDSampler": LVCDSampler, 209 | "LVCDDecoder": LVCDDecoder, 210 | } 211 | NODE_DISPLAY_NAME_MAPPINGS = { 212 | "LoadLVCDModel": "Load LVCD Model", 213 | "LVCDSampler": "LVCD Sampler", 214 | "LVCDDecoder": "LVCD Decoder", 215 | } 216 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | omegaconf>=2.3.0 2 | clip @ git+https://github.com/openai/CLIP.git 3 | open-clip-torch>=2.20.0 4 | pytorch-lightning>=2.0.1 5 | timm>=0.9.2 -------------------------------------------------------------------------------- /sgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import AutoencodingEngine, DiffusionEngine 2 | from .util import get_configs_path, instantiate_from_config 3 | 4 | __version__ = "0.1.0" 5 | -------------------------------------------------------------------------------- /sgm/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoencoder import AutoencodingEngine 2 | from .diffusion import DiffusionEngine 3 | -------------------------------------------------------------------------------- 
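The three wrapper nodes defined in nodes.py above (LoadLVCDModel, LVCDSampler, LVCDDecoder) are normally chained inside a ComfyUI graph, but their loadmodel methods can also be called directly from Python for scripting or debugging. The sketch below is illustrative only: the import path, the checkpoint name svd.safetensors, and the two .pt tensor files are placeholder assumptions rather than part of this repository, and a working ComfyUI install is assumed so that folder_paths and comfy.model_management resolve.

import torch

# Hypothetical import path; adjust to wherever this custom node package is installed.
from custom_nodes.ComfyUI_LVCDWrapper.nodes import LoadLVCDModel, LVCDSampler, LVCDDecoder

# 1) Load SVD plus the LVCD control model (Kijai/LVCD-pruned is downloaded on first use).
(pipe,) = LoadLVCDModel().loadmodel(model="svd.safetensors", precision="fp16", use_xformers=False)

# 2) Sample latents from one reference frame and a stack of sketch frames.
#    Both tensors use ComfyUI's IMAGE layout: (N, H, W, 3), float, values in [0, 1].
ref_images = torch.load("ref_frames.pt")        # placeholder source
sketch_images = torch.load("sketch_frames.pt")  # placeholder source
pipe, samples = LVCDSampler().loadmodel(
    pipe, ref_images, sketch_images,
    num_frames=19, num_steps=25, fps_id=6, motion_bucket_id=160,
    cond_aug=0.0, overlap=4, prev_attn_steps=25, seed=123,
    keep_model_loaded=False,
)

# 3) Decode the latents in chunks of decoding_t frames; the decoder returns images
#    normalized to [0, 1] in (T, H, W, 3) layout.
(frames,) = LVCDDecoder().loadmodel(pipe, samples, decoding_t=10, decoding_olap=3, decoding_first=1)
print(frames.shape)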
/sgm/models/diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | from contextlib import contextmanager 3 | from typing import Any, Dict, List, Optional, Tuple, Union 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | from omegaconf import ListConfig, OmegaConf 8 | from safetensors.torch import load_file as load_safetensors 9 | from torch.optim.lr_scheduler import LambdaLR 10 | from einops import rearrange 11 | 12 | from ..modules import UNCONDITIONAL_CONFIG 13 | from ..modules.autoencoding.temporal_ae import VideoDecoder 14 | from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER 15 | from ..modules.ema import LitEma 16 | from ..util import (default, disabled_train, get_obj_from_str, 17 | instantiate_from_config, log_txt_as_img) 18 | 19 | 20 | class DiffusionEngine(pl.LightningModule): 21 | def __init__( 22 | self, 23 | network_config, 24 | denoiser_config, 25 | first_stage_config, 26 | conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, 27 | sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, 28 | optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, 29 | scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, 30 | loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, 31 | network_wrapper: Union[None, str] = None, 32 | ckpt_path: Union[None, str] = None, 33 | use_ema: bool = False, 34 | ema_decay_rate: float = 0.9999, 35 | scale_factor: float = 1.0, 36 | disable_first_stage_autocast=False, 37 | input_key: str = "jpg", 38 | log_keys: Union[List, None] = None, 39 | no_cond_log: bool = False, 40 | compile_model: bool = False, 41 | en_and_decode_n_samples_a_time: Optional[int] = None, 42 | ): 43 | super().__init__() 44 | self.log_keys = log_keys 45 | self.input_key = input_key 46 | self.optimizer_config = default( 47 | optimizer_config, {"target": "torch.optim.AdamW"} 48 | ) 49 | model = instantiate_from_config(network_config) 50 | self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( 51 | model, compile_model=compile_model 52 | ) 53 | 54 | self.denoiser = instantiate_from_config(denoiser_config) 55 | self.sampler = ( 56 | instantiate_from_config(sampler_config) 57 | if sampler_config is not None 58 | else None 59 | ) 60 | self.conditioner = instantiate_from_config( 61 | default(conditioner_config, UNCONDITIONAL_CONFIG) 62 | ) 63 | self.scheduler_config = scheduler_config 64 | self._init_first_stage(first_stage_config) 65 | 66 | self.loss_fn = ( 67 | instantiate_from_config(loss_fn_config) 68 | if loss_fn_config is not None 69 | else None 70 | ) 71 | 72 | self.use_ema = use_ema 73 | if self.use_ema: 74 | self.model_ema = LitEma(self.model, decay=ema_decay_rate) 75 | print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") 76 | 77 | self.scale_factor = scale_factor 78 | self.disable_first_stage_autocast = disable_first_stage_autocast 79 | self.no_cond_log = no_cond_log 80 | 81 | if ckpt_path is not None: 82 | self.init_from_ckpt(ckpt_path) 83 | 84 | self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time 85 | 86 | def init_from_ckpt( 87 | self, 88 | path: str, 89 | ) -> None: 90 | if path.endswith("ckpt"): 91 | sd = torch.load(path, map_location="cpu")["state_dict"] 92 | elif path.endswith("safetensors"): 93 | sd = load_safetensors(path) 94 | else: 95 | raise NotImplementedError 96 | 97 | missing, unexpected = self.load_state_dict(sd, strict=False) 98 | # print( 99 | # f"Restored from {path} with 
{len(missing)} missing and {len(unexpected)} unexpected keys" 100 | # ) 101 | # if len(missing) > 0: 102 | # print(f"Missing Keys: {missing}") 103 | # if len(unexpected) > 0: 104 | # print(f"Unexpected Keys: {unexpected}") 105 | 106 | def _init_first_stage(self, config): 107 | model = instantiate_from_config(config)#.eval() 108 | # Train function is overwritten by a empty function to ensure no training 109 | model.train = disabled_train 110 | for param in model.parameters(): 111 | param.requires_grad = False 112 | self.first_stage_model = model 113 | 114 | def get_input(self, batch): 115 | # assuming unified data format, dataloader returns a dict. 116 | # image tensors should be scaled to -1 ... 1 and in bchw format 117 | if 'num_video_frames' in batch.keys(): 118 | for k in batch.keys(): 119 | if k not in ['num_video_frames', 'image_only_indicator', 'first_sigma']: 120 | batch[k] = rearrange(batch[k], 'b t ... -> (b t) ...') 121 | elif k in ['num_video_frames', 'first_sigma']: 122 | batch[k] = batch[k].detach().cpu().numpy()[0] 123 | return batch[self.input_key] 124 | 125 | @torch.no_grad() 126 | def decode_first_stage(self, z): 127 | z = 1.0 / self.scale_factor * z 128 | n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) 129 | 130 | n_rounds = math.ceil(z.shape[0] / n_samples) 131 | all_out = [] 132 | with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): 133 | for n in range(n_rounds): 134 | if isinstance(self.first_stage_model.decoder, VideoDecoder): 135 | kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} 136 | else: 137 | kwargs = {} 138 | out = self.first_stage_model.decode( 139 | z[n * n_samples : (n + 1) * n_samples], **kwargs 140 | ) 141 | all_out.append(out) 142 | out = torch.cat(all_out, dim=0) 143 | return out 144 | 145 | @torch.no_grad() 146 | def encode_first_stage(self, x): 147 | n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) 148 | n_samples = 1 149 | n_rounds = math.ceil(x.shape[0] / n_samples) 150 | all_out = [] 151 | with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): 152 | for n in range(n_rounds): 153 | out = self.first_stage_model.encode( x[n * n_samples : (n + 1) * n_samples] ) 154 | self.first_stage_model.zero_grad(set_to_none=True) 155 | torch.cuda.empty_cache() 156 | all_out.append(out) 157 | 158 | z = torch.cat(all_out, dim=0) 159 | z = self.scale_factor * z 160 | return z 161 | 162 | def forward(self, x, batch): 163 | loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch) 164 | loss_mean = loss.mean() 165 | loss_dict = {"loss": loss_mean} 166 | return loss_mean, loss_dict 167 | 168 | def shared_step(self, batch: Dict) -> Any: 169 | x = self.get_input(batch) 170 | x = self.encode_first_stage(x) 171 | batch["global_step"] = self.global_step 172 | loss, loss_dict = self(x, batch) 173 | return loss, loss_dict 174 | 175 | def training_step(self, batch, batch_idx): 176 | loss, loss_dict = self.shared_step(batch) 177 | 178 | self.log_dict( 179 | loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False 180 | ) 181 | 182 | self.log( 183 | "global_step", 184 | self.global_step, 185 | prog_bar=True, 186 | logger=True, 187 | on_step=True, 188 | on_epoch=False, 189 | ) 190 | 191 | if self.scheduler_config is not None: 192 | lr = self.optimizers().param_groups[0]["lr"] 193 | self.log( 194 | "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False 195 | ) 196 | 197 | return loss 198 | 199 | def on_train_start(self, *args, **kwargs): 
200 | if self.sampler is None or self.loss_fn is None: 201 | raise ValueError("Sampler and loss function need to be set for training.") 202 | 203 | def on_train_batch_end(self, *args, **kwargs): 204 | if self.use_ema: 205 | self.model_ema(self.model) 206 | 207 | @contextmanager 208 | def ema_scope(self, context=None): 209 | if self.use_ema: 210 | self.model_ema.store(self.model.parameters()) 211 | self.model_ema.copy_to(self.model) 212 | if context is not None: 213 | print(f"{context}: Switched to EMA weights") 214 | try: 215 | yield None 216 | finally: 217 | if self.use_ema: 218 | self.model_ema.restore(self.model.parameters()) 219 | if context is not None: 220 | print(f"{context}: Restored training weights") 221 | 222 | def instantiate_optimizer_from_config(self, params, lr, cfg): 223 | return get_obj_from_str(cfg["target"])( 224 | params, lr=lr, **cfg.get("params", dict()) 225 | ) 226 | 227 | def configure_optimizers(self): 228 | lr = self.learning_rate 229 | params = list(self.model.parameters()) 230 | for embedder in self.conditioner.embedders: 231 | if embedder.is_trainable: 232 | params = params + list(embedder.parameters()) 233 | opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) 234 | if self.scheduler_config is not None: 235 | scheduler = instantiate_from_config(self.scheduler_config) 236 | print("Setting up LambdaLR scheduler...") 237 | scheduler = [ 238 | { 239 | "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), 240 | "interval": "step", 241 | "frequency": 1, 242 | } 243 | ] 244 | return [opt], scheduler 245 | return opt 246 | 247 | @torch.no_grad() 248 | def sample( 249 | self, 250 | cond: Dict, 251 | uc: Union[Dict, None] = None, 252 | batch_size: int = 16, 253 | shape: Union[None, Tuple, List] = None, 254 | **kwargs, 255 | ): 256 | randn = torch.randn(batch_size, *shape).to(self.device) 257 | 258 | denoiser = lambda input, sigma, c: self.denoiser( 259 | self.model, input, sigma, c, **kwargs 260 | ) 261 | samples = self.sampler(denoiser, randn, cond, uc=uc) 262 | return samples 263 | 264 | @torch.no_grad() 265 | def log_conditionings(self, batch: Dict, n: int) -> Dict: 266 | """ 267 | Defines heuristics to log different conditionings. 268 | These can be lists of strings (text-to-image), tensors, ints, ... 
269 | """ 270 | image_h, image_w = batch[self.input_key].shape[2:] 271 | log = dict() 272 | 273 | for embedder in self.conditioner.embedders: 274 | if ( 275 | (self.log_keys is None) or (embedder.input_key in self.log_keys) 276 | ) and not self.no_cond_log: 277 | x = batch[embedder.input_key][:n] 278 | if isinstance(x, torch.Tensor): 279 | if x.dim() == 1: 280 | # class-conditional, convert integer to string 281 | x = [str(x[i].detach().cpu().numpy()) for i in range(x.shape[0])] 282 | xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) 283 | elif x.dim() == 2: 284 | # size and crop cond and the like 285 | x = [ 286 | "x".join([str(xx) for xx in x[i].tolist()]) 287 | for i in range(x.shape[0]) 288 | ] 289 | xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) 290 | else: 291 | raise NotImplementedError() 292 | elif isinstance(x, (List, ListConfig)): 293 | if isinstance(x[0], str): 294 | # strings 295 | xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) 296 | else: 297 | raise NotImplementedError() 298 | else: 299 | raise NotImplementedError() 300 | log[embedder.input_key] = xc 301 | return log 302 | 303 | @torch.no_grad() 304 | def log_images( 305 | self, 306 | batch: Dict, 307 | N: int = 8, 308 | sample: bool = True, 309 | ucg_keys: List[str] = None, 310 | **kwargs, 311 | ) -> Dict: 312 | conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] 313 | if ucg_keys: 314 | assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( 315 | "Each defined ucg key for sampling must be in the provided conditioner input keys," 316 | f"but we have {ucg_keys} vs. {conditioner_input_keys}" 317 | ) 318 | else: 319 | ucg_keys = conditioner_input_keys 320 | log = dict() 321 | 322 | torch.cuda.empty_cache() 323 | 324 | x = self.get_input(batch) 325 | 326 | c, uc = self.conditioner.get_unconditional_conditioning( 327 | batch, 328 | force_uc_zero_embeddings=ucg_keys 329 | if len(self.conditioner.embedders) > 0 330 | else [], 331 | ) 332 | 333 | sampling_kwargs = {} 334 | 335 | N = min(x.shape[0], N) 336 | x = x.to(self.device)[:N] 337 | log["inputs"] = x 338 | 339 | torch.cuda.empty_cache() 340 | 341 | z = self.encode_first_stage(x) 342 | log["reconstructions"] = self.decode_first_stage(z) 343 | log.update(self.log_conditionings(batch, N)) 344 | 345 | for k in c: 346 | if isinstance(c[k], torch.Tensor): 347 | c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) 348 | 349 | if sample: 350 | with self.ema_scope("Plotting"): 351 | samples = self.sample( 352 | c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs 353 | ) 354 | samples = self.decode_first_stage(samples) 355 | log["samples"] = samples 356 | return log 357 | -------------------------------------------------------------------------------- /sgm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoders.modules import GeneralConditioner 2 | 3 | UNCONDITIONAL_CONFIG = { 4 | "target": ".sgm.modules.GeneralConditioner", 5 | "params": {"emb_models": []}, 6 | } 7 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LVCDWrapper/081c8180029b1b5eb8f416e079456311ff467c83/sgm/modules/autoencoding/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/__init__.py: 
-------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "GeneralLPIPSWithDiscriminator", 3 | "LatentLPIPS", 4 | ] 5 | 6 | from .discriminator_loss import GeneralLPIPSWithDiscriminator 7 | from .lpips import LatentLPIPS 8 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/discriminator_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterator, List, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torchvision 7 | from einops import rearrange 8 | from matplotlib import colormaps 9 | from matplotlib import pyplot as plt 10 | 11 | from ....util import default, instantiate_from_config 12 | from ..lpips.loss.lpips import LPIPS 13 | from ..lpips.model.model import weights_init 14 | from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss 15 | 16 | 17 | class GeneralLPIPSWithDiscriminator(nn.Module): 18 | def __init__( 19 | self, 20 | disc_start: int, 21 | logvar_init: float = 0.0, 22 | disc_num_layers: int = 3, 23 | disc_in_channels: int = 3, 24 | disc_factor: float = 1.0, 25 | disc_weight: float = 1.0, 26 | perceptual_weight: float = 1.0, 27 | disc_loss: str = "hinge", 28 | scale_input_to_tgt_size: bool = False, 29 | dims: int = 2, 30 | learn_logvar: bool = False, 31 | regularization_weights: Union[None, Dict[str, float]] = None, 32 | additional_log_keys: Optional[List[str]] = None, 33 | discriminator_config: Optional[Dict] = None, 34 | ): 35 | super().__init__() 36 | self.dims = dims 37 | if self.dims > 2: 38 | print( 39 | f"running with dims={dims}. This means that for perceptual loss " 40 | f"calculation, the LPIPS loss will be applied to each frame " 41 | f"independently." 
42 | ) 43 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 44 | assert disc_loss in ["hinge", "vanilla"] 45 | self.perceptual_loss = LPIPS().eval() 46 | self.perceptual_weight = perceptual_weight 47 | # output log variance 48 | self.logvar = nn.Parameter( 49 | torch.full((), logvar_init), requires_grad=learn_logvar 50 | ) 51 | self.learn_logvar = learn_logvar 52 | 53 | discriminator_config = default( 54 | discriminator_config, 55 | { 56 | "target": ".sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator", 57 | "params": { 58 | "input_nc": disc_in_channels, 59 | "n_layers": disc_num_layers, 60 | "use_actnorm": False, 61 | }, 62 | }, 63 | ) 64 | 65 | self.discriminator = instantiate_from_config(discriminator_config).apply( 66 | weights_init 67 | ) 68 | self.discriminator_iter_start = disc_start 69 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 70 | self.disc_factor = disc_factor 71 | self.discriminator_weight = disc_weight 72 | self.regularization_weights = default(regularization_weights, {}) 73 | 74 | self.forward_keys = [ 75 | "optimizer_idx", 76 | "global_step", 77 | "last_layer", 78 | "split", 79 | "regularization_log", 80 | ] 81 | 82 | self.additional_log_keys = set(default(additional_log_keys, [])) 83 | self.additional_log_keys.update(set(self.regularization_weights.keys())) 84 | 85 | def get_trainable_parameters(self) -> Iterator[nn.Parameter]: 86 | return self.discriminator.parameters() 87 | 88 | def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]: 89 | if self.learn_logvar: 90 | yield self.logvar 91 | yield from () 92 | 93 | @torch.no_grad() 94 | def log_images( 95 | self, inputs: torch.Tensor, reconstructions: torch.Tensor 96 | ) -> Dict[str, torch.Tensor]: 97 | # calc logits of real/fake 98 | logits_real = self.discriminator(inputs.contiguous().detach()) 99 | if len(logits_real.shape) < 4: 100 | # Non patch-discriminator 101 | return dict() 102 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 103 | # -> (b, 1, h, w) 104 | 105 | # parameters for colormapping 106 | high = max(logits_fake.abs().max(), logits_real.abs().max()).item() 107 | cmap = colormaps["PiYG"] # diverging colormap 108 | 109 | def to_colormap(logits: torch.Tensor) -> torch.Tensor: 110 | """(b, 1, ...) -> (b, 3, ...)""" 111 | logits = (logits + high) / (2 * high) 112 | logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel 113 | # -> (b, 1, ..., 3) 114 | logits = torch.from_numpy(logits_np).to(logits.device) 115 | return rearrange(logits, "b 1 ... 
c -> b c ...") 116 | 117 | logits_real = torch.nn.functional.interpolate( 118 | logits_real, 119 | size=inputs.shape[-2:], 120 | mode="nearest", 121 | antialias=False, 122 | ) 123 | logits_fake = torch.nn.functional.interpolate( 124 | logits_fake, 125 | size=reconstructions.shape[-2:], 126 | mode="nearest", 127 | antialias=False, 128 | ) 129 | 130 | # alpha value of logits for overlay 131 | alpha_real = torch.abs(logits_real) / high 132 | alpha_fake = torch.abs(logits_fake) / high 133 | # -> (b, 1, h, w) in range [0, 0.5] 134 | # alpha value of lines don't really matter, since the values are the same 135 | # for both images and logits anyway 136 | grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4) 137 | grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4) 138 | grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1) 139 | # -> (1, h, w) 140 | # blend logits and images together 141 | 142 | # prepare logits for plotting 143 | logits_real = to_colormap(logits_real) 144 | logits_fake = to_colormap(logits_fake) 145 | # resize logits 146 | # -> (b, 3, h, w) 147 | 148 | # make some grids 149 | # add all logits to one plot 150 | logits_real = torchvision.utils.make_grid(logits_real, nrow=4) 151 | logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4) 152 | # I just love how torchvision calls the number of columns `nrow` 153 | grid_logits = torch.cat((logits_real, logits_fake), dim=1) 154 | # -> (3, h, w) 155 | 156 | grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4) 157 | grid_images_fake = torchvision.utils.make_grid( 158 | 0.5 * reconstructions + 0.5, nrow=4 159 | ) 160 | grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1) 161 | # -> (3, h, w) in range [0, 1] 162 | 163 | grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images 164 | 165 | # Create labeled colorbar 166 | dpi = 100 167 | height = 128 / dpi 168 | width = grid_logits.shape[2] / dpi 169 | fig, ax = plt.subplots(figsize=(width, height), dpi=dpi) 170 | img = ax.imshow(np.array([[-high, high]]), cmap=cmap) 171 | plt.colorbar( 172 | img, 173 | cax=ax, 174 | orientation="horizontal", 175 | fraction=0.9, 176 | aspect=width / height, 177 | pad=0.0, 178 | ) 179 | img.set_visible(False) 180 | fig.tight_layout() 181 | fig.canvas.draw() 182 | # manually convert figure to numpy 183 | cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 184 | cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 185 | cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0 186 | cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device) 187 | 188 | # Add colorbar to plot 189 | annotated_grid = torch.cat((grid_logits, cbar), dim=1) 190 | blended_grid = torch.cat((grid_blend, cbar), dim=1) 191 | return { 192 | "vis_logits": 2 * annotated_grid[None, ...] - 1, 193 | "vis_logits_blended": 2 * blended_grid[None, ...] 
- 1, 194 | } 195 | 196 | def calculate_adaptive_weight( 197 | self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor 198 | ) -> torch.Tensor: 199 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 200 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 201 | 202 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 203 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 204 | d_weight = d_weight * self.discriminator_weight 205 | return d_weight 206 | 207 | def forward( 208 | self, 209 | inputs: torch.Tensor, 210 | reconstructions: torch.Tensor, 211 | *, # added because I changed the order here 212 | regularization_log: Dict[str, torch.Tensor], 213 | optimizer_idx: int, 214 | global_step: int, 215 | last_layer: torch.Tensor, 216 | split: str = "train", 217 | weights: Union[None, float, torch.Tensor] = None, 218 | ) -> Tuple[torch.Tensor, dict]: 219 | if self.scale_input_to_tgt_size: 220 | inputs = torch.nn.functional.interpolate( 221 | inputs, reconstructions.shape[2:], mode="bicubic", antialias=True 222 | ) 223 | 224 | if self.dims > 2: 225 | inputs, reconstructions = map( 226 | lambda x: rearrange(x, "b c t h w -> (b t) c h w"), 227 | (inputs, reconstructions), 228 | ) 229 | 230 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 231 | if self.perceptual_weight > 0: 232 | p_loss = self.perceptual_loss( 233 | inputs.contiguous(), reconstructions.contiguous() 234 | ) 235 | rec_loss = rec_loss + self.perceptual_weight * p_loss 236 | 237 | nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) 238 | 239 | # now the GAN part 240 | if optimizer_idx == 0: 241 | # generator update 242 | if global_step >= self.discriminator_iter_start or not self.training: 243 | logits_fake = self.discriminator(reconstructions.contiguous()) 244 | g_loss = -torch.mean(logits_fake) 245 | if self.training: 246 | d_weight = self.calculate_adaptive_weight( 247 | nll_loss, g_loss, last_layer=last_layer 248 | ) 249 | else: 250 | d_weight = torch.tensor(1.0) 251 | else: 252 | d_weight = torch.tensor(0.0) 253 | g_loss = torch.tensor(0.0, requires_grad=True) 254 | 255 | loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss 256 | log = dict() 257 | for k in regularization_log: 258 | if k in self.regularization_weights: 259 | loss = loss + self.regularization_weights[k] * regularization_log[k] 260 | if k in self.additional_log_keys: 261 | log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() 262 | 263 | log.update( 264 | { 265 | f"{split}/loss/total": loss.clone().detach().mean(), 266 | f"{split}/loss/nll": nll_loss.detach().mean(), 267 | f"{split}/loss/rec": rec_loss.detach().mean(), 268 | f"{split}/loss/g": g_loss.detach().mean(), 269 | f"{split}/scalars/logvar": self.logvar.detach(), 270 | f"{split}/scalars/d_weight": d_weight.detach(), 271 | } 272 | ) 273 | 274 | return loss, log 275 | elif optimizer_idx == 1: 276 | # second pass for discriminator update 277 | logits_real = self.discriminator(inputs.contiguous().detach()) 278 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 279 | 280 | if global_step >= self.discriminator_iter_start or not self.training: 281 | d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) 282 | else: 283 | d_loss = torch.tensor(0.0, requires_grad=True) 284 | 285 | log = { 286 | f"{split}/loss/disc": d_loss.clone().detach().mean(), 287 | f"{split}/logits/real": logits_real.detach().mean(), 288 | 
f"{split}/logits/fake": logits_fake.detach().mean(), 289 | } 290 | return d_loss, log 291 | else: 292 | raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}") 293 | 294 | def get_nll_loss( 295 | self, 296 | rec_loss: torch.Tensor, 297 | weights: Optional[Union[float, torch.Tensor]] = None, 298 | ) -> Tuple[torch.Tensor, torch.Tensor]: 299 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 300 | weighted_nll_loss = nll_loss 301 | if weights is not None: 302 | weighted_nll_loss = weights * nll_loss 303 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 304 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 305 | 306 | return nll_loss, weighted_nll_loss 307 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/losses/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ....util import default, instantiate_from_config 5 | from ..lpips.loss.lpips import LPIPS 6 | 7 | 8 | class LatentLPIPS(nn.Module): 9 | def __init__( 10 | self, 11 | decoder_config, 12 | perceptual_weight=1.0, 13 | latent_weight=1.0, 14 | scale_input_to_tgt_size=False, 15 | scale_tgt_to_input_size=False, 16 | perceptual_weight_on_inputs=0.0, 17 | ): 18 | super().__init__() 19 | self.scale_input_to_tgt_size = scale_input_to_tgt_size 20 | self.scale_tgt_to_input_size = scale_tgt_to_input_size 21 | self.init_decoder(decoder_config) 22 | self.perceptual_loss = LPIPS().eval() 23 | self.perceptual_weight = perceptual_weight 24 | self.latent_weight = latent_weight 25 | self.perceptual_weight_on_inputs = perceptual_weight_on_inputs 26 | 27 | def init_decoder(self, config): 28 | self.decoder = instantiate_from_config(config) 29 | if hasattr(self.decoder, "encoder"): 30 | del self.decoder.encoder 31 | 32 | def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): 33 | log = dict() 34 | loss = (latent_inputs - latent_predictions) ** 2 35 | log[f"{split}/latent_l2_loss"] = loss.mean().detach() 36 | image_reconstructions = None 37 | if self.perceptual_weight > 0.0: 38 | image_reconstructions = self.decoder.decode(latent_predictions) 39 | image_targets = self.decoder.decode(latent_inputs) 40 | perceptual_loss = self.perceptual_loss( 41 | image_targets.contiguous(), image_reconstructions.contiguous() 42 | ) 43 | loss = ( 44 | self.latent_weight * loss.mean() 45 | + self.perceptual_weight * perceptual_loss.mean() 46 | ) 47 | log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() 48 | 49 | if self.perceptual_weight_on_inputs > 0.0: 50 | image_reconstructions = default( 51 | image_reconstructions, self.decoder.decode(latent_predictions) 52 | ) 53 | if self.scale_input_to_tgt_size: 54 | image_inputs = torch.nn.functional.interpolate( 55 | image_inputs, 56 | image_reconstructions.shape[2:], 57 | mode="bicubic", 58 | antialias=True, 59 | ) 60 | elif self.scale_tgt_to_input_size: 61 | image_reconstructions = torch.nn.functional.interpolate( 62 | image_reconstructions, 63 | image_inputs.shape[2:], 64 | mode="bicubic", 65 | antialias=True, 66 | ) 67 | 68 | perceptual_loss2 = self.perceptual_loss( 69 | image_inputs.contiguous(), image_reconstructions.contiguous() 70 | ) 71 | loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() 72 | log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() 73 | return loss, log 74 | 
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LVCDWrapper/081c8180029b1b5eb8f416e079456311ff467c83/sgm/modules/autoencoding/lpips/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/.gitignore: -------------------------------------------------------------------------------- 1 | vgg.pth -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LVCDWrapper/081c8180029b1b5eb8f416e079456311ff467c83/sgm/modules/autoencoding/lpips/loss/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/loss/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | from collections import namedtuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from ..util import get_ckpt_path 10 | 11 | 12 | class LPIPS(nn.Module): 13 | # Learned perceptual metric 14 | def __init__(self, use_dropout=True): 15 | super().__init__() 16 | self.scaling_layer = ScalingLayer() 17 | self.chns = [64, 128, 256, 512, 512] # vg16 features 18 | self.net = vgg16(pretrained=True, requires_grad=False) 19 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 20 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 21 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 22 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 23 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 24 | self.load_from_pretrained() 25 | for param in self.parameters(): 26 | param.requires_grad = False 27 | 28 | def load_from_pretrained(self, name="vgg_lpips"): 29 | ckpt = get_ckpt_path(name, ".sgm/modules/autoencoding/lpips/loss") 30 | self.load_state_dict( 31 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 32 | ) 33 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 34 | 35 | @classmethod 36 | def from_pretrained(cls, name="vgg_lpips"): 37 | if name != "vgg_lpips": 38 | raise NotImplementedError 39 | model = cls() 40 | ckpt = get_ckpt_path(name) 41 | model.load_state_dict( 42 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 43 | ) 44 | return model 45 | 46 | def forward(self, input, target): 47 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 48 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 49 | feats0, feats1, diffs = {}, {}, {} 50 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 51 | for kk in range(len(self.chns)): 52 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 53 | outs1[kk] 54 | ) 55 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 56 | 57 | res = [ 58 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 59 | for kk in range(len(self.chns)) 60 | ] 61 | val = res[0] 62 | for l in range(1, len(self.chns)): 63 | val += res[l] 64 | return val 65 | 66 | 67 | class ScalingLayer(nn.Module): 68 | def __init__(self): 69 | super(ScalingLayer, self).__init__() 70 | self.register_buffer( 71 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 72 | ) 73 | self.register_buffer( 74 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 75 | ) 76 | 77 | def forward(self, inp): 78 | return (inp - self.shift) / self.scale 79 | 80 | 81 | class NetLinLayer(nn.Module): 82 | """A single linear layer which does a 1x1 conv""" 83 | 84 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 85 | super(NetLinLayer, self).__init__() 86 | layers = ( 87 | [ 88 | nn.Dropout(), 89 | ] 90 | if 
(use_dropout) 91 | else [] 92 | ) 93 | layers += [ 94 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 95 | ] 96 | self.model = nn.Sequential(*layers) 97 | 98 | 99 | class vgg16(torch.nn.Module): 100 | def __init__(self, requires_grad=False, pretrained=True): 101 | super(vgg16, self).__init__() 102 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 103 | self.slice1 = torch.nn.Sequential() 104 | self.slice2 = torch.nn.Sequential() 105 | self.slice3 = torch.nn.Sequential() 106 | self.slice4 = torch.nn.Sequential() 107 | self.slice5 = torch.nn.Sequential() 108 | self.N_slices = 5 109 | for x in range(4): 110 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(4, 9): 112 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(9, 16): 114 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(16, 23): 116 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 117 | for x in range(23, 30): 118 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 119 | if not requires_grad: 120 | for param in self.parameters(): 121 | param.requires_grad = False 122 | 123 | def forward(self, X): 124 | h = self.slice1(X) 125 | h_relu1_2 = h 126 | h = self.slice2(h) 127 | h_relu2_2 = h 128 | h = self.slice3(h) 129 | h_relu3_3 = h 130 | h = self.slice4(h) 131 | h_relu4_3 = h 132 | h = self.slice5(h) 133 | h_relu5_3 = h 134 | vgg_outputs = namedtuple( 135 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 136 | ) 137 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 138 | return out 139 | 140 | 141 | def normalize_tensor(x, eps=1e-10): 142 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 143 | return x / (norm_factor + eps) 144 | 145 | 146 | def spatial_average(x, keepdim=True): 147 | return x.mean([2, 3], keepdim=keepdim) 148 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
24 | 25 | 26 | --------------------------- LICENSE FOR pix2pix -------------------------------- 27 | BSD License 28 | 29 | For pix2pix software 30 | Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu 31 | All rights reserved. 32 | 33 | Redistribution and use in source and binary forms, with or without 34 | modification, are permitted provided that the following conditions are met: 35 | 36 | * Redistributions of source code must retain the above copyright notice, this 37 | list of conditions and the following disclaimer. 38 | 39 | * Redistributions in binary form must reproduce the above copyright notice, 40 | this list of conditions and the following disclaimer in the documentation 41 | and/or other materials provided with the distribution. 42 | 43 | ----------------------------- LICENSE FOR DCGAN -------------------------------- 44 | BSD License 45 | 46 | For dcgan.torch software 47 | 48 | Copyright (c) 2015, Facebook, Inc. All rights reserved. 49 | 50 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 51 | 52 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 53 | 54 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 55 | 56 | Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 57 | 58 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LVCDWrapper/081c8180029b1b5eb8f416e079456311ff467c83/sgm/modules/autoencoding/lpips/model/__init__.py -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/model/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | from ..util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find("BatchNorm") != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | 22 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 23 | """Construct a PatchGAN discriminator 24 | Parameters: 25 | input_nc (int) -- the number of channels in input images 26 | ndf (int) -- the number of filters in the last conv layer 27 | n_layers (int) -- the number of conv layers in the discriminator 28 | norm_layer -- normalization layer 29 | """ 30 | super(NLayerDiscriminator, self).__init__() 31 | if not use_actnorm: 32 | norm_layer = nn.BatchNorm2d 33 | else: 34 | norm_layer = ActNorm 35 | if ( 36 | type(norm_layer) == functools.partial 37 | ): # no need to use bias as BatchNorm2d has affine parameters 38 | use_bias = norm_layer.func != nn.BatchNorm2d 39 | else: 40 | use_bias = norm_layer != nn.BatchNorm2d 41 | 42 | kw = 4 43 | padw = 1 44 | sequence = [ 45 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 46 | nn.LeakyReLU(0.2, True), 47 | ] 48 | nf_mult = 1 49 | nf_mult_prev = 1 50 | for n in range(1, n_layers): # gradually increase the number of filters 51 | nf_mult_prev = nf_mult 52 | nf_mult = min(2**n, 8) 53 | sequence += [ 54 | nn.Conv2d( 55 | ndf * nf_mult_prev, 56 | ndf * nf_mult, 57 | kernel_size=kw, 58 | stride=2, 59 | padding=padw, 60 | bias=use_bias, 61 | ), 62 | norm_layer(ndf * nf_mult), 63 | nn.LeakyReLU(0.2, True), 64 | ] 65 | 66 | nf_mult_prev = nf_mult 67 | nf_mult = min(2**n_layers, 8) 68 | sequence += [ 69 | nn.Conv2d( 70 | ndf * nf_mult_prev, 71 | ndf * nf_mult, 72 | kernel_size=kw, 73 | stride=1, 74 | padding=padw, 75 | bias=use_bias, 76 | ), 77 | norm_layer(ndf * nf_mult), 78 | nn.LeakyReLU(0.2, True), 79 | ] 80 | 81 | sequence += [ 82 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 83 | ] # output 1 channel prediction map 84 | self.main = nn.Sequential(*sequence) 85 | 86 | def forward(self, input): 87 | """Standard forward.""" 88 | return self.main(input) 89 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | 4 | import requests 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | 9 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 10 | 11 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 12 | 13 | MD5_MAP = 
{"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 14 | 15 | 16 | def download(url, local_path, chunk_size=1024): 17 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 18 | with requests.get(url, stream=True) as r: 19 | total_size = int(r.headers.get("content-length", 0)) 20 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 21 | with open(local_path, "wb") as f: 22 | for data in r.iter_content(chunk_size=chunk_size): 23 | if data: 24 | f.write(data) 25 | pbar.update(chunk_size) 26 | 27 | 28 | def md5_hash(path): 29 | with open(path, "rb") as f: 30 | content = f.read() 31 | return hashlib.md5(content).hexdigest() 32 | 33 | 34 | def get_ckpt_path(name, root, check=False): 35 | assert name in URL_MAP 36 | path = os.path.join(root, CKPT_MAP[name]) 37 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 38 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 39 | download(URL_MAP[name], path) 40 | md5 = md5_hash(path) 41 | assert md5 == MD5_MAP[name], md5 42 | return path 43 | 44 | 45 | class ActNorm(nn.Module): 46 | def __init__( 47 | self, num_features, logdet=False, affine=True, allow_reverse_init=False 48 | ): 49 | assert affine 50 | super().__init__() 51 | self.logdet = logdet 52 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 53 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 54 | self.allow_reverse_init = allow_reverse_init 55 | 56 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 57 | 58 | def initialize(self, input): 59 | with torch.no_grad(): 60 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 61 | mean = ( 62 | flatten.mean(1) 63 | .unsqueeze(1) 64 | .unsqueeze(2) 65 | .unsqueeze(3) 66 | .permute(1, 0, 2, 3) 67 | ) 68 | std = ( 69 | flatten.std(1) 70 | .unsqueeze(1) 71 | .unsqueeze(2) 72 | .unsqueeze(3) 73 | .permute(1, 0, 2, 3) 74 | ) 75 | 76 | self.loc.data.copy_(-mean) 77 | self.scale.data.copy_(1 / (std + 1e-6)) 78 | 79 | def forward(self, input, reverse=False): 80 | if reverse: 81 | return self.reverse(input) 82 | if len(input.shape) == 2: 83 | input = input[:, :, None, None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | _, _, height, width = input.shape 89 | 90 | if self.training and self.initialized.item() == 0: 91 | self.initialize(input) 92 | self.initialized.fill_(1) 93 | 94 | h = self.scale * (input + self.loc) 95 | 96 | if squeeze: 97 | h = h.squeeze(-1).squeeze(-1) 98 | 99 | if self.logdet: 100 | log_abs = torch.log(torch.abs(self.scale)) 101 | logdet = height * width * torch.sum(log_abs) 102 | logdet = logdet * torch.ones(input.shape[0]).to(input) 103 | return h, logdet 104 | 105 | return h 106 | 107 | def reverse(self, output): 108 | if self.training and self.initialized.item() == 0: 109 | if not self.allow_reverse_init: 110 | raise RuntimeError( 111 | "Initializing ActNorm in reverse direction is " 112 | "disabled by default. Use allow_reverse_init=True to enable." 
113 | ) 114 | else: 115 | self.initialize(output) 116 | self.initialized.fill_(1) 117 | 118 | if len(output.shape) == 2: 119 | output = output[:, :, None, None] 120 | squeeze = True 121 | else: 122 | squeeze = False 123 | 124 | h = output / self.scale - self.loc 125 | 126 | if squeeze: 127 | h = h.squeeze(-1).squeeze(-1) 128 | return h 129 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/lpips/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1.0 - logits_real)) 7 | loss_fake = torch.mean(F.relu(1.0 + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | 12 | def vanilla_d_loss(logits_real, logits_fake): 13 | d_loss = 0.5 * ( 14 | torch.mean(torch.nn.functional.softplus(-logits_real)) 15 | + torch.mean(torch.nn.functional.softplus(logits_fake)) 16 | ) 17 | return d_loss 18 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/__init__.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from ....modules.distributions.distributions import \ 9 | DiagonalGaussianDistribution 10 | from .base import AbstractRegularizer 11 | 12 | 13 | class DiagonalGaussianRegularizer(AbstractRegularizer): 14 | def __init__(self, sample: bool = True): 15 | super().__init__() 16 | self.sample = sample 17 | 18 | def get_trainable_parameters(self) -> Any: 19 | yield from () 20 | 21 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 22 | log = dict() 23 | posterior = DiagonalGaussianDistribution(z) 24 | if self.sample: 25 | z = posterior.sample() 26 | else: 27 | z = posterior.mode() 28 | kl_loss = posterior.kl() 29 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 30 | log["kl_loss"] = kl_loss 31 | return z, log 32 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/regularizers/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class AbstractRegularizer(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 14 | raise NotImplementedError() 15 | 16 | @abstractmethod 17 | def get_trainable_parameters(self) -> Any: 18 | raise NotImplementedError() 19 | 20 | 21 | class IdentityRegularizer(AbstractRegularizer): 22 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: 23 | return z, dict() 24 | 25 | def get_trainable_parameters(self) -> Any: 26 | yield from () 27 | 28 | 29 | def measure_perplexity( 30 | predicted_indices: torch.Tensor, num_centroids: int 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 33 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 34 | encodings = ( 35 | F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) 36 | ) 37 | avg_probs = encodings.mean(0) 38 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 39 | cluster_use = torch.sum(avg_probs > 0) 40 | return perplexity, cluster_use 41 | -------------------------------------------------------------------------------- /sgm/modules/autoencoding/temporal_ae.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Iterable, Union 2 | 3 | import torch 4 | from einops import rearrange, repeat 5 | 6 | from ...modules.diffusionmodules.model import ( 7 | XFORMERS_IS_AVAILABLE, 8 | AttnBlock, 9 | Decoder, 10 | MemoryEfficientAttnBlock, 11 | ResnetBlock, 12 | ) 13 | from ...modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding 14 | from ...modules.video_attention import VideoTransformerBlock 15 | from ...util import partialclass 16 | 17 | import comfy.ops 18 | ops = comfy.ops.manual_cast 19 | class VideoResBlock(ResnetBlock): 20 | def __init__( 21 | self, 22 | out_channels, 23 | *args, 24 | dropout=0.0, 25 | video_kernel_size=3, 26 | alpha=0.0, 27 | merge_strategy="learned", 28 | **kwargs, 29 | ): 30 | super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) 31 | if video_kernel_size is None: 32 | video_kernel_size = [3, 1, 1] 33 | self.time_stack = ResBlock( 34 | channels=out_channels, 35 | emb_channels=0, 36 | dropout=dropout, 37 | dims=3, 38 | use_scale_shift_norm=False, 39 | use_conv=False, 40 | up=False, 41 | down=False, 42 | kernel_size=video_kernel_size, 43 | use_checkpoint=False, 44 | skip_t_emb=True, 45 | ) 46 | 47 | self.merge_strategy = merge_strategy 48 | if self.merge_strategy == "fixed": 49 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 50 | elif self.merge_strategy == "learned": 51 | self.register_parameter( 52 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 53 | ) 54 | else: 55 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 56 | 57 | def get_alpha(self, bs): 58 | if self.merge_strategy == "fixed": 59 | return self.mix_factor 60 | elif self.merge_strategy == "learned": 61 | return torch.sigmoid(self.mix_factor) 62 | else: 63 | raise NotImplementedError() 64 | 65 | def forward(self, x, temb, skip_video=False, timesteps=None): 66 | if timesteps is None: 67 | timesteps = self.timesteps 68 | 69 | b, c, h, w = x.shape 70 | 71 | x = super().forward(x, temb) 72 | 73 | if not skip_video: 74 | x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 75 | 76 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 77 | 78 | x = self.time_stack(x, temb) 79 | 80 | alpha = self.get_alpha(bs=b // timesteps) 81 | x = alpha * x + (1.0 - alpha) * x_mix 82 | 83 | x = rearrange(x, "b c t h w -> (b t) c h w") 84 | return x 85 | 86 | 87 | class AE3DConv(torch.nn.Conv2d): 88 | def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): 89 | super().__init__(in_channels, out_channels, *args, **kwargs) 90 | if isinstance(video_kernel_size, Iterable): 91 | padding = [int(k // 2) for k in video_kernel_size] 92 | else: 93 | padding = int(video_kernel_size // 2) 94 | 95 | self.time_mix_conv = torch.nn.Conv3d( 96 | in_channels=out_channels, 97 | out_channels=out_channels, 98 | kernel_size=video_kernel_size, 99 | padding=padding, 100 | ) 101 | 102 | def forward(self, input, timesteps, 
skip_video=False): 103 | x = super().forward(input) 104 | if skip_video: 105 | return x 106 | x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) 107 | x = self.time_mix_conv(x) 108 | return rearrange(x, "b c t h w -> (b t) c h w") 109 | 110 | 111 | class VideoBlock(AttnBlock): 112 | def __init__( 113 | self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" 114 | ): 115 | super().__init__(in_channels) 116 | # no context, single headed, as in base class 117 | self.time_mix_block = VideoTransformerBlock( 118 | dim=in_channels, 119 | n_heads=1, 120 | d_head=in_channels, 121 | checkpoint=False, 122 | ff_in=True, 123 | attn_mode="softmax", 124 | ) 125 | 126 | time_embed_dim = self.in_channels * 4 127 | self.video_time_embed = torch.nn.Sequential( 128 | ops.Linear(self.in_channels, time_embed_dim), 129 | torch.nn.SiLU(), 130 | ops.Linear(time_embed_dim, self.in_channels), 131 | ) 132 | 133 | self.merge_strategy = merge_strategy 134 | if self.merge_strategy == "fixed": 135 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 136 | elif self.merge_strategy == "learned": 137 | self.register_parameter( 138 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 139 | ) 140 | else: 141 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 142 | 143 | def forward(self, x, timesteps, skip_video=False): 144 | if skip_video: 145 | return super().forward(x) 146 | 147 | x_in = x 148 | x = self.attention(x) 149 | h, w = x.shape[2:] 150 | x = rearrange(x, "b c h w -> b (h w) c") 151 | 152 | x_mix = x 153 | num_frames = torch.arange(timesteps, device=x.device) 154 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 155 | num_frames = rearrange(num_frames, "b t -> (b t)") 156 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) 157 | emb = self.video_time_embed(t_emb) # b, n_channels 158 | emb = emb[:, None, :] 159 | x_mix = x_mix + emb 160 | 161 | alpha = self.get_alpha() 162 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 163 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 164 | 165 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 166 | x = self.proj_out(x) 167 | 168 | return x_in + x 169 | 170 | def get_alpha( 171 | self, 172 | ): 173 | if self.merge_strategy == "fixed": 174 | return self.mix_factor 175 | elif self.merge_strategy == "learned": 176 | return torch.sigmoid(self.mix_factor) 177 | else: 178 | raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") 179 | 180 | 181 | class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): 182 | def __init__( 183 | self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" 184 | ): 185 | super().__init__(in_channels) 186 | # no context, single headed, as in base class 187 | self.time_mix_block = VideoTransformerBlock( 188 | dim=in_channels, 189 | n_heads=1, 190 | d_head=in_channels, 191 | checkpoint=False, 192 | ff_in=True, 193 | attn_mode="softmax-xformers", 194 | ) 195 | 196 | time_embed_dim = self.in_channels * 4 197 | self.video_time_embed = torch.nn.Sequential( 198 | ops.Linear(self.in_channels, time_embed_dim), 199 | torch.nn.SiLU(), 200 | ops.Linear(time_embed_dim, self.in_channels), 201 | ) 202 | 203 | self.merge_strategy = merge_strategy 204 | if self.merge_strategy == "fixed": 205 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 206 | elif self.merge_strategy == "learned": 207 | self.register_parameter( 208 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 209 | ) 210 | else: 211 | raise 
ValueError(f"unknown merge strategy {self.merge_strategy}") 212 | 213 | def forward(self, x, timesteps, skip_time_block=False): 214 | if skip_time_block: 215 | return super().forward(x) 216 | 217 | x_in = x 218 | x = self.attention(x) 219 | h, w = x.shape[2:] 220 | x = rearrange(x, "b c h w -> b (h w) c") 221 | 222 | x_mix = x 223 | num_frames = torch.arange(timesteps, device=x.device) 224 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 225 | num_frames = rearrange(num_frames, "b t -> (b t)") 226 | t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) 227 | emb = self.video_time_embed(t_emb) # b, n_channels 228 | emb = emb[:, None, :] 229 | x_mix = x_mix + emb 230 | 231 | alpha = self.get_alpha() 232 | x_mix = self.time_mix_block(x_mix, timesteps=timesteps) 233 | x = alpha * x + (1.0 - alpha) * x_mix # alpha merge 234 | 235 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 236 | x = self.proj_out(x) 237 | 238 | return x_in + x 239 | 240 | def get_alpha( 241 | self, 242 | ): 243 | if self.merge_strategy == "fixed": 244 | return self.mix_factor 245 | elif self.merge_strategy == "learned": 246 | return torch.sigmoid(self.mix_factor) 247 | else: 248 | raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") 249 | 250 | 251 | def make_time_attn( 252 | in_channels, 253 | attn_type="vanilla", 254 | attn_kwargs=None, 255 | alpha: float = 0, 256 | merge_strategy: str = "learned", 257 | ): 258 | assert attn_type in [ 259 | "vanilla", 260 | "vanilla-xformers", 261 | ], f"attn_type {attn_type} not supported for spatio-temporal attention" 262 | print( 263 | f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" 264 | ) 265 | if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": 266 | print( 267 | f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " 268 | f"This is not a problem in Pytorch >= 2.0. 
FYI, you are running with PyTorch version {torch.__version__}" 269 | ) 270 | attn_type = "vanilla" 271 | 272 | if attn_type == "vanilla": 273 | assert attn_kwargs is None 274 | return partialclass( 275 | VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy 276 | ) 277 | elif attn_type == "vanilla-xformers": 278 | print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") 279 | return partialclass( 280 | MemoryEfficientVideoBlock, 281 | in_channels, 282 | alpha=alpha, 283 | merge_strategy=merge_strategy, 284 | ) 285 | else: 286 | return NotImplementedError() 287 | 288 | 289 | class Conv2DWrapper(torch.nn.Conv2d): 290 | def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: 291 | return super().forward(input) 292 | 293 | 294 | class VideoDecoder(Decoder): 295 | available_time_modes = ["all", "conv-only", "attn-only"] 296 | 297 | def __init__( 298 | self, 299 | *args, 300 | video_kernel_size: Union[int, list] = 3, 301 | alpha: float = 0.0, 302 | merge_strategy: str = "learned", 303 | time_mode: str = "conv-only", 304 | **kwargs, 305 | ): 306 | self.video_kernel_size = video_kernel_size 307 | self.alpha = alpha 308 | self.merge_strategy = merge_strategy 309 | self.time_mode = time_mode 310 | assert ( 311 | self.time_mode in self.available_time_modes 312 | ), f"time_mode parameter has to be in {self.available_time_modes}" 313 | super().__init__(*args, **kwargs) 314 | 315 | def get_last_layer(self, skip_time_mix=False, **kwargs): 316 | if self.time_mode == "attn-only": 317 | raise NotImplementedError("TODO") 318 | else: 319 | return ( 320 | self.conv_out.time_mix_conv.weight 321 | if not skip_time_mix 322 | else self.conv_out.weight 323 | ) 324 | 325 | def _make_attn(self) -> Callable: 326 | if self.time_mode not in ["conv-only", "only-last-conv"]: 327 | return partialclass( 328 | make_time_attn, 329 | alpha=self.alpha, 330 | merge_strategy=self.merge_strategy, 331 | ) 332 | else: 333 | return super()._make_attn() 334 | 335 | def _make_conv(self) -> Callable: 336 | if self.time_mode != "attn-only": 337 | return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) 338 | else: 339 | return Conv2DWrapper 340 | 341 | def _make_resblock(self) -> Callable: 342 | if self.time_mode not in ["attn-only", "only-last-conv"]: 343 | return partialclass( 344 | VideoResBlock, 345 | video_kernel_size=self.video_kernel_size, 346 | alpha=self.alpha, 347 | merge_strategy=self.merge_strategy, 348 | ) 349 | else: 350 | return super()._make_resblock() 351 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LVCDWrapper/081c8180029b1b5eb8f416e079456311ff467c83/sgm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...util import append_dims, instantiate_from_config 7 | from .denoiser_scaling import DenoiserScaling 8 | from .discretizer import Discretization 9 | 10 | 11 | class Denoiser(nn.Module): 12 | def __init__(self, scaling_config: Dict): 13 | super().__init__() 14 | 15 | self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) 16 | 17 | def 
possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 18 | return sigma 19 | 20 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 21 | return c_noise 22 | 23 | def forward( 24 | self, 25 | network, 26 | input: torch.Tensor, 27 | sigma: torch.Tensor, 28 | cond: Dict, 29 | **additional_model_inputs, 30 | ) -> torch.Tensor: 31 | sigma = self.possibly_quantize_sigma(sigma) 32 | sigma_shape = sigma.shape 33 | sigma = append_dims(sigma, input.ndim) 34 | c_skip, c_out, c_in, c_noise = self.scaling(sigma) 35 | c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) 36 | return ( 37 | network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out 38 | + input * c_skip 39 | ) 40 | 41 | 42 | class DiscreteDenoiser(Denoiser): 43 | def __init__( 44 | self, 45 | scaling_config: Dict, 46 | num_idx: int, 47 | discretization_config: Dict, 48 | do_append_zero: bool = False, 49 | quantize_c_noise: bool = True, 50 | flip: bool = True, 51 | ): 52 | super().__init__(scaling_config) 53 | self.discretization: Discretization = instantiate_from_config( 54 | discretization_config 55 | ) 56 | sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) 57 | self.register_buffer("sigmas", sigmas) 58 | self.quantize_c_noise = quantize_c_noise 59 | self.num_idx = num_idx 60 | 61 | def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: 62 | dists = sigma - self.sigmas[:, None] 63 | return dists.abs().argmin(dim=0).view(sigma.shape) 64 | 65 | def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: 66 | return self.sigmas[idx] 67 | 68 | def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: 69 | return self.idx_to_sigma(self.sigma_to_idx(sigma)) 70 | 71 | def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: 72 | if self.quantize_c_noise: 73 | return self.sigma_to_idx(c_noise) 74 | else: 75 | return c_noise 76 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_scaling.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple 3 | 4 | import torch 5 | 6 | 7 | class DenoiserScaling(ABC): 8 | @abstractmethod 9 | def __call__( 10 | self, sigma: torch.Tensor 11 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 12 | pass 13 | 14 | 15 | class EDMScaling: 16 | def __init__(self, sigma_data: float = 0.5): 17 | self.sigma_data = sigma_data 18 | 19 | def __call__( 20 | self, sigma: torch.Tensor 21 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 22 | c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) 23 | c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 24 | c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 25 | c_noise = 0.25 * sigma.log() 26 | return c_skip, c_out, c_in, c_noise 27 | 28 | 29 | class EpsScaling: 30 | def __call__( 31 | self, sigma: torch.Tensor 32 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 33 | c_skip = torch.ones_like(sigma, device=sigma.device) 34 | c_out = -sigma 35 | c_in = 1 / (sigma**2 + 1.0) ** 0.5 36 | c_noise = sigma.clone() 37 | return c_skip, c_out, c_in, c_noise 38 | 39 | 40 | class VScaling: 41 | def __call__( 42 | self, sigma: torch.Tensor 43 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 44 | c_skip = 1.0 / (sigma**2 + 1.0) 45 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 46 
| c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 47 | c_noise = sigma.clone() 48 | return c_skip, c_out, c_in, c_noise 49 | 50 | 51 | class VScalingWithEDMcNoise(DenoiserScaling): 52 | def __call__( 53 | self, sigma: torch.Tensor 54 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 55 | c_skip = 1.0 / (sigma**2 + 1.0) 56 | c_out = -sigma / (sigma**2 + 1.0) ** 0.5 57 | c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 58 | c_noise = 0.25 * sigma.log() 59 | return c_skip, c_out, c_in, c_noise 60 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/denoiser_weighting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class UnitWeighting: 5 | def __call__(self, sigma): 6 | return torch.ones_like(sigma, device=sigma.device) 7 | 8 | 9 | class EDMWeighting: 10 | def __init__(self, sigma_data=0.5): 11 | self.sigma_data = sigma_data 12 | 13 | def __call__(self, sigma): 14 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 15 | 16 | 17 | class VWeighting(EDMWeighting): 18 | def __init__(self): 19 | super().__init__(sigma_data=1.0) 20 | 21 | 22 | class EpsWeighting: 23 | def __call__(self, sigma): 24 | return sigma**-2.0 25 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/discretizer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from functools import partial 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ...modules.diffusionmodules.util import make_beta_schedule 8 | from ...util import append_zero 9 | 10 | 11 | def generate_roughly_equally_spaced_steps( 12 | num_substeps: int, max_step: int 13 | ) -> np.ndarray: 14 | return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] 15 | 16 | 17 | class Discretization: 18 | def __call__(self, n, do_append_zero=True, device="cpu", flip=False): 19 | sigmas = self.get_sigmas(n, device=device) 20 | sigmas = append_zero(sigmas) if do_append_zero else sigmas 21 | return sigmas if not flip else torch.flip(sigmas, (0,)) 22 | 23 | @abstractmethod 24 | def get_sigmas(self, n, device): 25 | pass 26 | 27 | 28 | class EDMDiscretization(Discretization): 29 | def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): 30 | self.sigma_min = sigma_min 31 | self.sigma_max = sigma_max 32 | self.rho = rho 33 | 34 | def get_sigmas(self, n, device="cpu"): 35 | ramp = torch.linspace(0, 1, n, device=device) 36 | min_inv_rho = self.sigma_min ** (1 / self.rho) 37 | max_inv_rho = self.sigma_max ** (1 / self.rho) 38 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho 39 | return sigmas 40 | 41 | 42 | class LegacyDDPMDiscretization(Discretization): 43 | def __init__( 44 | self, 45 | linear_start=0.00085, 46 | linear_end=0.0120, 47 | num_timesteps=1000, 48 | ): 49 | super().__init__() 50 | self.num_timesteps = num_timesteps 51 | betas = make_beta_schedule( 52 | "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end 53 | ) 54 | alphas = 1.0 - betas 55 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 56 | self.to_torch = partial(torch.tensor, dtype=torch.float32) 57 | 58 | def get_sigmas(self, n, device="cpu"): 59 | if n < self.num_timesteps: 60 | timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) 61 | alphas_cumprod = self.alphas_cumprod[timesteps] 62 | elif n == self.num_timesteps: 63 | alphas_cumprod = 
self.alphas_cumprod 64 | else: 65 | raise ValueError 66 | 67 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 68 | sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 69 | return torch.flip(sigmas, (0,)) 70 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/guiders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | from typing import Dict, List, Optional, Tuple, Union 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | 8 | from ...util import append_dims, default 9 | 10 | logpy = logging.getLogger(__name__) 11 | 12 | 13 | class Guider(ABC): 14 | @abstractmethod 15 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 16 | pass 17 | 18 | def prepare_inputs( 19 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 20 | ) -> Tuple[torch.Tensor, float, Dict]: 21 | pass 22 | 23 | 24 | class VanillaCFG(Guider): 25 | def __init__(self, scale: float): 26 | self.scale = scale 27 | 28 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 29 | x_u, x_c = x.chunk(2) 30 | x_pred = x_u + self.scale * (x_c - x_u) 31 | return x_pred 32 | 33 | def prepare_inputs(self, x, s, c, uc): 34 | c_out = dict() 35 | 36 | for k in c: 37 | if k in ["vector", "crossattn", "concat"]: 38 | c_out[k] = torch.cat((uc[k], c[k]), 0) 39 | else: 40 | assert c[k] == uc[k] 41 | c_out[k] = c[k] 42 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 43 | 44 | 45 | class IdentityGuider(Guider): 46 | def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: 47 | return x 48 | 49 | def prepare_inputs( 50 | self, x: torch.Tensor, s: float, c: Dict, uc: Dict 51 | ) -> Tuple[torch.Tensor, float, Dict]: 52 | c_out = dict() 53 | 54 | for k in c: 55 | c_out[k] = c[k] 56 | 57 | return x, s, c_out 58 | 59 | 60 | class LinearPredictionGuider(Guider): 61 | def __init__( 62 | self, 63 | max_scale: float, 64 | num_frames: int, 65 | min_scale: float = 1.0, 66 | additional_cond_keys: Optional[Union[List[str], str]] = None, 67 | ): 68 | self.min_scale = min_scale 69 | self.max_scale = max_scale 70 | self.num_frames = num_frames 71 | self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) 72 | 73 | additional_cond_keys = default(additional_cond_keys, []) 74 | if isinstance(additional_cond_keys, str): 75 | additional_cond_keys = [additional_cond_keys] 76 | self.additional_cond_keys = additional_cond_keys 77 | 78 | def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: 79 | x_u, x_c = x.chunk(2) 80 | 81 | x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) 82 | x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) 83 | scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) 84 | scale = append_dims(scale, x_u.ndim).to(x_u.device) 85 | 86 | return rearrange(x_u + scale * (x_c - x_u), "b t ... 
-> (b t) ...") 87 | 88 | def prepare_inputs( 89 | self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict 90 | ) -> Tuple[torch.Tensor, torch.Tensor, dict]: 91 | c_out = dict() 92 | 93 | for k in c: 94 | if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: 95 | c_out[k] = torch.cat((uc[k], c[k]), 0) 96 | else: 97 | assert c[k] == uc[k] 98 | c_out[k] = c[k] 99 | return torch.cat([x] * 2), torch.cat([s] * 2), c_out 100 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ...modules.autoencoding.lpips.loss.lpips import LPIPS 7 | from ...modules.encoders.modules import GeneralConditioner 8 | from ...util import append_dims, instantiate_from_config 9 | from .denoiser import Denoiser 10 | from einops import rearrange, repeat 11 | 12 | 13 | class StandardDiffusionLoss(nn.Module): 14 | def __init__( 15 | self, 16 | sigma_sampler_config: dict, 17 | loss_weighting_config: dict, 18 | loss_type: str = "l2", 19 | offset_noise_level: float = 0.0, 20 | batch2model_keys: Optional[Union[str, List[str]]] = None, 21 | additional_cond_keys: Optional[Union[str, List[str]]] = None, 22 | ): 23 | super().__init__() 24 | 25 | assert loss_type in ["l2", "l1", "lpips"] 26 | 27 | self.sigma_sampler = instantiate_from_config(sigma_sampler_config) 28 | self.loss_weighting = instantiate_from_config(loss_weighting_config) 29 | 30 | self.loss_type = loss_type 31 | self.offset_noise_level = offset_noise_level 32 | 33 | if loss_type == "lpips": 34 | self.lpips = LPIPS().eval() 35 | 36 | if not batch2model_keys: 37 | batch2model_keys = [] 38 | 39 | if isinstance(batch2model_keys, str): 40 | batch2model_keys = [batch2model_keys] 41 | 42 | self.batch2model_keys = set(batch2model_keys) 43 | self.additional_cond_keys = set(additional_cond_keys) 44 | 45 | def get_noised_input( 46 | self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor 47 | ) -> torch.Tensor: 48 | noised_input = input + noise * sigmas_bc 49 | return noised_input 50 | 51 | def forward( 52 | self, 53 | network, 54 | denoiser: Denoiser, 55 | conditioner: GeneralConditioner, 56 | input: torch.Tensor, 57 | batch: Dict, 58 | ) -> torch.Tensor: 59 | cond = conditioner(batch) 60 | 61 | if 'num_video_frames' in batch.keys(): 62 | for k in ["crossattn", "concat"]: 63 | cond[k] = repeat(cond[k], "b ... -> (b t) ...", t=batch['num_video_frames']) 64 | 65 | cond_keys = self.additional_cond_keys.intersection(batch) 66 | for k in cond_keys: 67 | if k in ['crossattn_scale', 'concat_scale', 'prev_frame']: 68 | cond[k] = repeat(batch[k], "b ... -> (b t) ...", t=batch['num_video_frames']) 69 | else: 70 | cond[k] = batch[k] 71 | 72 | return self._forward(network, denoiser, cond, input, batch) 73 | 74 | def _forward( 75 | self, 76 | network, 77 | denoiser: Denoiser, 78 | cond: Dict, 79 | input: torch.Tensor, 80 | batch: Dict, 81 | ) -> Tuple[torch.Tensor, Dict]: 82 | additional_model_inputs = { 83 | key: batch[key] for key in self.batch2model_keys.intersection(batch) 84 | } 85 | if 'num_video_frames' in batch.keys(): 86 | b = input.shape[0] // batch['num_video_frames'] 87 | sigmas = self.sigma_sampler(b).to(input) 88 | sigmas = repeat(sigmas, "b ... -> b t ...", t=batch['num_video_frames']).contiguous() 89 | sigmas = rearrange(sigmas, "b t ... 
-> (b t) ...", t=batch['num_video_frames']) 90 | else: 91 | sigmas = self.sigma_sampler(input.shape[0]).to(input) 92 | 93 | noise = torch.randn_like(input) 94 | if self.offset_noise_level > 0.0: 95 | offset_shape = ( 96 | (input.shape[0], 1, input.shape[2]) 97 | if self.n_frames is not None 98 | else (input.shape[0], input.shape[1]) 99 | ) 100 | noise = noise + self.offset_noise_level * append_dims( 101 | torch.randn(offset_shape, device=input.device), 102 | input.ndim, 103 | ) 104 | sigmas_bc = append_dims(sigmas, input.ndim) 105 | noised_input = self.get_noised_input(sigmas_bc, noise, input) 106 | 107 | model_output = denoiser( 108 | network, noised_input, sigmas, cond, **additional_model_inputs 109 | ) 110 | w = append_dims(self.loss_weighting(sigmas), input.ndim) 111 | return self.get_loss(model_output, input, w) 112 | 113 | def get_loss(self, model_output, target, w): 114 | if self.loss_type == "l2": 115 | return torch.mean( 116 | (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 117 | ) 118 | elif self.loss_type == "l1": 119 | return torch.mean( 120 | (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 121 | ) 122 | elif self.loss_type == "lpips": 123 | loss = self.lpips(model_output, target).reshape(-1) 124 | return loss 125 | else: 126 | raise NotImplementedError(f"Unknown loss type {self.loss_type}") 127 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/loss_weighting.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class DiffusionLossWeighting(ABC): 7 | @abstractmethod 8 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 9 | pass 10 | 11 | 12 | class UnitWeighting(DiffusionLossWeighting): 13 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 14 | return torch.ones_like(sigma, device=sigma.device) 15 | 16 | 17 | class EDMWeighting(DiffusionLossWeighting): 18 | def __init__(self, sigma_data: float = 0.5): 19 | self.sigma_data = sigma_data 20 | 21 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 22 | return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 23 | 24 | 25 | class VWeighting(EDMWeighting): 26 | def __init__(self): 27 | super().__init__(sigma_data=1.0) 28 | 29 | 30 | class EpsWeighting(DiffusionLossWeighting): 31 | def __call__(self, sigma: torch.Tensor) -> torch.Tensor: 32 | return sigma**-2.0 33 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling.py: -------------------------------------------------------------------------------- 1 | """ 2 | Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py 3 | """ 4 | 5 | 6 | from typing import Dict, Union 7 | 8 | import torch 9 | from omegaconf import ListConfig, OmegaConf 10 | from tqdm import tqdm 11 | 12 | from ...modules.diffusionmodules.sampling_utils import (get_ancestral_step, 13 | linear_multistep_coeff, 14 | to_d, to_neg_log_sigma, 15 | to_sigma) 16 | from ...util import append_dims, default, instantiate_from_config 17 | 18 | DEFAULT_GUIDER = {"target": ".sgm.modules.diffusionmodules.guiders.IdentityGuider"} 19 | 20 | 21 | class BaseDiffusionSampler: 22 | def __init__( 23 | self, 24 | discretization_config: Union[Dict, ListConfig, OmegaConf], 25 | num_steps: Union[int, None] = None, 26 | guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, 27 | verbose: 
bool = False, 28 | device: str = "cuda", 29 | ): 30 | self.num_steps = num_steps 31 | self.discretization = instantiate_from_config(discretization_config) 32 | self.guider = instantiate_from_config( 33 | default( 34 | guider_config, 35 | DEFAULT_GUIDER, 36 | ) 37 | ) 38 | self.verbose = verbose 39 | self.device = device 40 | 41 | def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): 42 | sigmas = self.discretization( 43 | self.num_steps if num_steps is None else num_steps, device=self.device 44 | ) 45 | uc = default(uc, cond) 46 | 47 | x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) 48 | num_sigmas = len(sigmas) 49 | 50 | s_in = x.new_ones([x.shape[0]]) 51 | 52 | return x, s_in, sigmas, num_sigmas, cond, uc 53 | 54 | def denoise(self, x, denoiser, sigma, cond, uc): 55 | denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) 56 | denoised = self.guider(denoised, sigma) 57 | return denoised 58 | 59 | def get_sigma_gen(self, num_sigmas): 60 | sigma_generator = range(num_sigmas - 1) 61 | if self.verbose: 62 | #print("#" * 30, " Sampling setting ", "#" * 30) 63 | #print(f"Sampler: {self.__class__.__name__}") 64 | #print(f"Discretization: {self.discretization.__class__.__name__}") 65 | #print(f"Guider: {self.guider.__class__.__name__}") 66 | sigma_generator = tqdm( 67 | sigma_generator, 68 | total=num_sigmas, 69 | desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", 70 | ) 71 | return sigma_generator 72 | 73 | 74 | class SingleStepDiffusionSampler(BaseDiffusionSampler): 75 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): 76 | raise NotImplementedError 77 | 78 | def euler_step(self, x, d, dt): 79 | return x + dt * d 80 | 81 | 82 | class EDMSampler(SingleStepDiffusionSampler): 83 | def __init__( 84 | self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs 85 | ): 86 | super().__init__(*args, **kwargs) 87 | 88 | self.s_churn = s_churn 89 | self.s_tmin = s_tmin 90 | self.s_tmax = s_tmax 91 | self.s_noise = s_noise 92 | 93 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): 94 | sigma_hat = sigma * (gamma + 1.0) 95 | if gamma > 0: 96 | eps = torch.randn_like(x) * self.s_noise 97 | x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 98 | 99 | denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) 100 | d = to_d(x, sigma_hat, denoised) 101 | dt = append_dims(next_sigma - sigma_hat, x.ndim) 102 | 103 | euler_step = self.euler_step(x, d, dt) 104 | x = self.possible_correction_step( 105 | euler_step, x, d, dt, next_sigma, denoiser, cond, uc 106 | ) 107 | return x 108 | 109 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None): 110 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 111 | x, cond, uc, num_steps 112 | ) 113 | 114 | for i in self.get_sigma_gen(num_sigmas): 115 | gamma = ( 116 | min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) 117 | if self.s_tmin <= sigmas[i] <= self.s_tmax 118 | else 0.0 119 | ) 120 | x = self.sampler_step( 121 | s_in * sigmas[i], 122 | s_in * sigmas[i + 1], 123 | denoiser, 124 | x, 125 | cond, 126 | uc, 127 | gamma, 128 | ) 129 | 130 | return x 131 | 132 | 133 | class AncestralSampler(SingleStepDiffusionSampler): 134 | def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): 135 | super().__init__(*args, **kwargs) 136 | 137 | self.eta = eta 138 | self.s_noise = s_noise 139 | self.noise_sampler = lambda x: torch.randn_like(x) 140 | 141 | def ancestral_euler_step(self, x, denoised, sigma, 
sigma_down): 142 | d = to_d(x, sigma, denoised) 143 | dt = append_dims(sigma_down - sigma, x.ndim) 144 | 145 | return self.euler_step(x, d, dt) 146 | 147 | def ancestral_step(self, x, sigma, next_sigma, sigma_up): 148 | x = torch.where( 149 | append_dims(next_sigma, x.ndim) > 0.0, 150 | x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), 151 | x, 152 | ) 153 | return x 154 | 155 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None): 156 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 157 | x, cond, uc, num_steps 158 | ) 159 | 160 | for i in self.get_sigma_gen(num_sigmas): 161 | x = self.sampler_step( 162 | s_in * sigmas[i], 163 | s_in * sigmas[i + 1], 164 | denoiser, 165 | x, 166 | cond, 167 | uc, 168 | ) 169 | 170 | return x 171 | 172 | 173 | class LinearMultistepSampler(BaseDiffusionSampler): 174 | def __init__( 175 | self, 176 | order=4, 177 | *args, 178 | **kwargs, 179 | ): 180 | super().__init__(*args, **kwargs) 181 | 182 | self.order = order 183 | 184 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): 185 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 186 | x, cond, uc, num_steps 187 | ) 188 | 189 | ds = [] 190 | sigmas_cpu = sigmas.detach().cpu().numpy() 191 | for i in self.get_sigma_gen(num_sigmas): 192 | sigma = s_in * sigmas[i] 193 | denoised = denoiser( 194 | *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs 195 | ) 196 | denoised = self.guider(denoised, sigma) 197 | d = to_d(x, sigma, denoised) 198 | ds.append(d) 199 | if len(ds) > self.order: 200 | ds.pop(0) 201 | cur_order = min(i + 1, self.order) 202 | coeffs = [ 203 | linear_multistep_coeff(cur_order, sigmas_cpu, i, j) 204 | for j in range(cur_order) 205 | ] 206 | x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) 207 | 208 | return x 209 | 210 | 211 | class EulerEDMSampler(EDMSampler): 212 | def possible_correction_step( 213 | self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc 214 | ): 215 | return euler_step 216 | 217 | 218 | class HeunEDMSampler(EDMSampler): 219 | def possible_correction_step( 220 | self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc 221 | ): 222 | if torch.sum(next_sigma) < 1e-14: 223 | # Save a network evaluation if all noise levels are 0 224 | return euler_step 225 | else: 226 | denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) 227 | d_new = to_d(euler_step, next_sigma, denoised) 228 | d_prime = (d + d_new) / 2.0 229 | 230 | # apply correction if noise level is not 0 231 | x = torch.where( 232 | append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step 233 | ) 234 | return x 235 | 236 | 237 | class EulerAncestralSampler(AncestralSampler): 238 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): 239 | sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) 240 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 241 | x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) 242 | x = self.ancestral_step(x, sigma, next_sigma, sigma_up) 243 | 244 | return x 245 | 246 | 247 | class DPMPP2SAncestralSampler(AncestralSampler): 248 | def get_variables(self, sigma, sigma_down): 249 | t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] 250 | h = t_next - t 251 | s = t + 0.5 * h 252 | return h, s, t, t_next 253 | 254 | def get_mult(self, h, s, t, t_next): 255 | mult1 = to_sigma(s) / to_sigma(t) 256 | mult2 = (-0.5 * h).expm1() 257 | mult3 = to_sigma(t_next) / to_sigma(t) 258 | mult4 = 
(-h).expm1() 259 | 260 | return mult1, mult2, mult3, mult4 261 | 262 | def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): 263 | sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) 264 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 265 | x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) 266 | 267 | if torch.sum(sigma_down) < 1e-14: 268 | # Save a network evaluation if all noise levels are 0 269 | x = x_euler 270 | else: 271 | h, s, t, t_next = self.get_variables(sigma, sigma_down) 272 | mult = [ 273 | append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) 274 | ] 275 | 276 | x2 = mult[0] * x - mult[1] * denoised 277 | denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) 278 | x_dpmpp2s = mult[2] * x - mult[3] * denoised2 279 | 280 | # apply correction if noise level is not 0 281 | x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) 282 | 283 | x = self.ancestral_step(x, sigma, next_sigma, sigma_up) 284 | return x 285 | 286 | 287 | class DPMPP2MSampler(BaseDiffusionSampler): 288 | def get_variables(self, sigma, next_sigma, previous_sigma=None): 289 | t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] 290 | h = t_next - t 291 | 292 | if previous_sigma is not None: 293 | h_last = t - to_neg_log_sigma(previous_sigma) 294 | r = h_last / h 295 | return h, r, t, t_next 296 | else: 297 | return h, None, t, t_next 298 | 299 | def get_mult(self, h, r, t, t_next, previous_sigma): 300 | mult1 = to_sigma(t_next) / to_sigma(t) 301 | mult2 = (-h).expm1() 302 | 303 | if previous_sigma is not None: 304 | mult3 = 1 + 1 / (2 * r) 305 | mult4 = 1 / (2 * r) 306 | return mult1, mult2, mult3, mult4 307 | else: 308 | return mult1, mult2 309 | 310 | def sampler_step( 311 | self, 312 | old_denoised, 313 | previous_sigma, 314 | sigma, 315 | next_sigma, 316 | denoiser, 317 | x, 318 | cond, 319 | uc=None, 320 | ): 321 | denoised = self.denoise(x, denoiser, sigma, cond, uc) 322 | 323 | h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) 324 | mult = [ 325 | append_dims(mult, x.ndim) 326 | for mult in self.get_mult(h, r, t, t_next, previous_sigma) 327 | ] 328 | 329 | x_standard = mult[0] * x - mult[1] * denoised 330 | if old_denoised is None or torch.sum(next_sigma) < 1e-14: 331 | # Save a network evaluation if all noise levels are 0 or on the first step 332 | return x_standard, denoised 333 | else: 334 | denoised_d = mult[2] * denoised - mult[3] * old_denoised 335 | x_advanced = mult[0] * x - mult[1] * denoised_d 336 | 337 | # apply correction if noise level is not 0 and not first step 338 | x = torch.where( 339 | append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard 340 | ) 341 | 342 | return x, denoised 343 | 344 | def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): 345 | x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( 346 | x, cond, uc, num_steps 347 | ) 348 | 349 | old_denoised = None 350 | for i in self.get_sigma_gen(num_sigmas): 351 | x, old_denoised = self.sampler_step( 352 | old_denoised, 353 | None if i == 0 else s_in * sigmas[i - 1], 354 | s_in * sigmas[i], 355 | s_in * sigmas[i + 1], 356 | denoiser, 357 | x, 358 | cond, 359 | uc=uc, 360 | ) 361 | 362 | return x 363 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy 
import integrate 3 | 4 | from ...util import append_dims 5 | 6 | 7 | def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): 8 | if order - 1 > i: 9 | raise ValueError(f"Order {order} too high for step {i}") 10 | 11 | def fn(tau): 12 | prod = 1.0 13 | for k in range(order): 14 | if j == k: 15 | continue 16 | prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) 17 | return prod 18 | 19 | return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] 20 | 21 | 22 | def get_ancestral_step(sigma_from, sigma_to, eta=1.0): 23 | if not eta: 24 | return sigma_to, 0.0 25 | sigma_up = torch.minimum( 26 | sigma_to, 27 | eta 28 | * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, 29 | ) 30 | sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 31 | return sigma_down, sigma_up 32 | 33 | 34 | def to_d(x, sigma, denoised): 35 | return (x - denoised) / append_dims(sigma, x.ndim) 36 | 37 | 38 | def to_neg_log_sigma(sigma): 39 | return sigma.log().neg() 40 | 41 | 42 | def to_sigma(neg_log_sigma): 43 | return neg_log_sigma.neg().exp() 44 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/sigma_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ...util import default, instantiate_from_config 4 | 5 | 6 | class EDMSampling: 7 | def __init__(self, p_mean=-1.2, p_std=1.2): 8 | self.p_mean = p_mean 9 | self.p_std = p_std 10 | 11 | def __call__(self, n_samples, rand=None): 12 | log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) 13 | return log_sigma.exp() 14 | 15 | 16 | class DiscreteSampling: 17 | def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): 18 | self.num_idx = num_idx 19 | self.sigmas = instantiate_from_config(discretization_config)( 20 | num_idx, do_append_zero=do_append_zero, flip=flip 21 | ) 22 | 23 | def idx_to_sigma(self, idx): 24 | return self.sigmas[idx] 25 | 26 | def __call__(self, n_samples, rand=None): 27 | idx = default( 28 | rand, 29 | torch.randint(0, self.num_idx, (n_samples,)), 30 | ) 31 | return self.idx_to_sigma(idx) 32 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | partially adopted from 3 | https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 4 | and 5 | https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 6 | and 7 | https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 8 | 9 | thanks! 
10 | """ 11 | 12 | import math 13 | from typing import Optional 14 | 15 | import torch 16 | import torch.nn as nn 17 | from einops import rearrange, repeat 18 | 19 | 20 | def make_beta_schedule( 21 | schedule, 22 | n_timestep, 23 | linear_start=1e-4, 24 | linear_end=2e-2, 25 | ): 26 | if schedule == "linear": 27 | betas = ( 28 | torch.linspace( 29 | linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 30 | ) 31 | ** 2 32 | ) 33 | return betas.numpy() 34 | 35 | 36 | def extract_into_tensor(a, t, x_shape): 37 | b, *_ = t.shape 38 | out = a.gather(-1, t) 39 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 40 | 41 | 42 | def mixed_checkpoint(func, inputs: dict, params, flag): 43 | """ 44 | Evaluate a function without caching intermediate activations, allowing for 45 | reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function 46 | borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that 47 | it also works with non-tensor inputs 48 | :param func: the function to evaluate. 49 | :param inputs: the argument dictionary to pass to `func`. 50 | :param params: a sequence of parameters `func` depends on but does not 51 | explicitly take as arguments. 52 | :param flag: if False, disable gradient checkpointing. 53 | """ 54 | if flag: 55 | tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] 56 | tensor_inputs = [ 57 | inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) 58 | ] 59 | non_tensor_keys = [ 60 | key for key in inputs if not isinstance(inputs[key], torch.Tensor) 61 | ] 62 | non_tensor_inputs = [ 63 | inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) 64 | ] 65 | args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) 66 | return MixedCheckpointFunction.apply( 67 | func, 68 | len(tensor_inputs), 69 | len(non_tensor_inputs), 70 | tensor_keys, 71 | non_tensor_keys, 72 | *args, 73 | ) 74 | else: 75 | return func(**inputs) 76 | 77 | 78 | class MixedCheckpointFunction(torch.autograd.Function): 79 | @staticmethod 80 | def forward( 81 | ctx, 82 | run_function, 83 | length_tensors, 84 | length_non_tensors, 85 | tensor_keys, 86 | non_tensor_keys, 87 | *args, 88 | ): 89 | ctx.end_tensors = length_tensors 90 | ctx.end_non_tensors = length_tensors + length_non_tensors 91 | ctx.gpu_autocast_kwargs = { 92 | "enabled": torch.is_autocast_enabled(), 93 | "dtype": torch.get_autocast_gpu_dtype(), 94 | "cache_enabled": torch.is_autocast_cache_enabled(), 95 | } 96 | assert ( 97 | len(tensor_keys) == length_tensors 98 | and len(non_tensor_keys) == length_non_tensors 99 | ) 100 | 101 | ctx.input_tensors = { 102 | key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) 103 | } 104 | ctx.input_non_tensors = { 105 | key: val 106 | for (key, val) in zip( 107 | non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) 108 | ) 109 | } 110 | ctx.run_function = run_function 111 | ctx.input_params = list(args[ctx.end_non_tensors :]) 112 | 113 | with torch.no_grad(): 114 | output_tensors = ctx.run_function( 115 | **ctx.input_tensors, **ctx.input_non_tensors 116 | ) 117 | return output_tensors 118 | 119 | @staticmethod 120 | def backward(ctx, *output_grads): 121 | # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} 122 | ctx.input_tensors = { 123 | key: 
ctx.input_tensors[key].detach().requires_grad_(True) 124 | for key in ctx.input_tensors 125 | } 126 | 127 | with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 128 | # Fixes a bug where the first op in run_function modifies the 129 | # Tensor storage in place, which is not allowed for detach()'d 130 | # Tensors. 131 | shallow_copies = { 132 | key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) 133 | for key in ctx.input_tensors 134 | } 135 | # shallow_copies.update(additional_args) 136 | output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) 137 | input_grads = torch.autograd.grad( 138 | output_tensors, 139 | list(ctx.input_tensors.values()) + ctx.input_params, 140 | output_grads, 141 | allow_unused=True, 142 | ) 143 | del ctx.input_tensors 144 | del ctx.input_params 145 | del output_tensors 146 | return ( 147 | (None, None, None, None, None) 148 | + input_grads[: ctx.end_tensors] 149 | + (None,) * (ctx.end_non_tensors - ctx.end_tensors) 150 | + input_grads[ctx.end_tensors :] 151 | ) 152 | 153 | 154 | def checkpoint(func, inputs, params, flag): 155 | """ 156 | Evaluate a function without caching intermediate activations, allowing for 157 | reduced memory at the expense of extra compute in the backward pass. 158 | :param func: the function to evaluate. 159 | :param inputs: the argument sequence to pass to `func`. 160 | :param params: a sequence of parameters `func` depends on but does not 161 | explicitly take as arguments. 162 | :param flag: if False, disable gradient checkpointing. 163 | """ 164 | if flag: 165 | args = tuple(inputs) + tuple(params) 166 | return CheckpointFunction.apply(func, len(inputs), *args) 167 | else: 168 | return func(*inputs) 169 | 170 | 171 | class CheckpointFunction(torch.autograd.Function): 172 | @staticmethod 173 | def forward(ctx, run_function, length, *args): 174 | ctx.run_function = run_function 175 | ctx.input_tensors = list(args[:length]) 176 | ctx.input_params = list(args[length:]) 177 | ctx.gpu_autocast_kwargs = { 178 | "enabled": torch.is_autocast_enabled(), 179 | "dtype": torch.get_autocast_gpu_dtype(), 180 | "cache_enabled": torch.is_autocast_cache_enabled(), 181 | } 182 | with torch.no_grad(): 183 | output_tensors = ctx.run_function(*ctx.input_tensors) 184 | return output_tensors 185 | 186 | @staticmethod 187 | def backward(ctx, *output_grads): 188 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 189 | with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 190 | # Fixes a bug where the first op in run_function modifies the 191 | # Tensor storage in place, which is not allowed for detach()'d 192 | # Tensors. 193 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 194 | output_tensors = ctx.run_function(*shallow_copies) 195 | input_grads = torch.autograd.grad( 196 | output_tensors, 197 | ctx.input_tensors + ctx.input_params, 198 | output_grads, 199 | allow_unused=True, 200 | ) 201 | del ctx.input_tensors 202 | del ctx.input_params 203 | del output_tensors 204 | return (None, None) + input_grads 205 | 206 | 207 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 208 | """ 209 | Create sinusoidal timestep embeddings. 210 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 211 | These may be fractional. 212 | :param dim: the dimension of the output. 213 | :param max_period: controls the minimum frequency of the embeddings. 214 | :return: an [N x dim] Tensor of positional embeddings. 
215 | """ 216 | if not repeat_only: 217 | half = dim // 2 218 | freqs = torch.exp( 219 | -math.log(max_period) 220 | * torch.arange(start=0, end=half, dtype=torch.float32) 221 | / half 222 | ).to(device=timesteps.device) 223 | args = timesteps[:, None].float() * freqs[None] 224 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 225 | if dim % 2: 226 | embedding = torch.cat( 227 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 228 | ) 229 | else: 230 | embedding = repeat(timesteps, "b -> b d", d=dim) 231 | return embedding 232 | 233 | 234 | def zero_module(module): 235 | """ 236 | Zero out the parameters of a module and return it. 237 | """ 238 | for p in module.parameters(): 239 | p.detach().zero_() 240 | return module 241 | 242 | 243 | def scale_module(module, scale): 244 | """ 245 | Scale the parameters of a module and return it. 246 | """ 247 | for p in module.parameters(): 248 | p.detach().mul_(scale) 249 | return module 250 | 251 | 252 | def mean_flat(tensor): 253 | """ 254 | Take the mean over all non-batch dimensions. 255 | """ 256 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 257 | 258 | 259 | def normalization(channels, affine=True): 260 | """ 261 | Make a standard normalization layer. 262 | :param channels: number of input channels. 263 | :return: an nn.Module for normalization. 264 | """ 265 | return GroupNorm32(32, channels, affine=affine) 266 | 267 | def group_normalization(groups, channels, affine=True): 268 | return GroupNorm(groups, channels, affine=affine) 269 | 270 | 271 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 272 | class SiLU(nn.Module): 273 | def forward(self, x): 274 | return x * torch.sigmoid(x) 275 | 276 | 277 | class GroupNorm32(nn.GroupNorm): 278 | def forward(self, x): 279 | return super().forward(x.float()).type(x.dtype) 280 | 281 | class GroupNorm(nn.GroupNorm): 282 | def forward(self, x): 283 | return super().forward(x.float()).type(x.dtype) 284 | 285 | 286 | import comfy.ops 287 | ops = comfy.ops.manual_cast 288 | 289 | def conv_nd(dims, *args, **kwargs): 290 | """ 291 | Create a 1D, 2D, or 3D convolution module. 292 | """ 293 | if dims == 1: 294 | return ops.Conv1d(*args, **kwargs) 295 | elif dims == 2: 296 | return ops.Conv2d(*args, **kwargs) 297 | elif dims == 3: 298 | return ops.Conv3d(*args, **kwargs) 299 | raise ValueError(f"unsupported dimensions: {dims}") 300 | 301 | 302 | def linear(*args, **kwargs): 303 | """ 304 | Create a linear module. 305 | """ 306 | return ops.Linear(*args, **kwargs) 307 | 308 | 309 | def avg_pool_nd(dims, *args, **kwargs): 310 | """ 311 | Create a 1D, 2D, or 3D average pooling module. 
312 | """ 313 | if dims == 1: 314 | return nn.AvgPool1d(*args, **kwargs) 315 | elif dims == 2: 316 | return nn.AvgPool2d(*args, **kwargs) 317 | elif dims == 3: 318 | return nn.AvgPool3d(*args, **kwargs) 319 | raise ValueError(f"unsupported dimensions: {dims}") 320 | 321 | 322 | class AlphaBlender(nn.Module): 323 | strategies = ["learned", "fixed", "learned_with_images"] 324 | 325 | def __init__( 326 | self, 327 | alpha: float, 328 | merge_strategy: str = "learned_with_images", 329 | rearrange_pattern: str = "b t -> (b t) 1 1", 330 | ): 331 | super().__init__() 332 | self.merge_strategy = merge_strategy 333 | self.rearrange_pattern = rearrange_pattern 334 | 335 | assert ( 336 | merge_strategy in self.strategies 337 | ), f"merge_strategy needs to be in {self.strategies}" 338 | 339 | if self.merge_strategy == "fixed": 340 | self.register_buffer("mix_factor", torch.Tensor([alpha])) 341 | elif ( 342 | self.merge_strategy == "learned" 343 | or self.merge_strategy == "learned_with_images" 344 | ): 345 | self.register_parameter( 346 | "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) 347 | ) 348 | else: 349 | raise ValueError(f"unknown merge strategy {self.merge_strategy}") 350 | 351 | def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: 352 | if self.merge_strategy == "fixed": 353 | alpha = self.mix_factor 354 | elif self.merge_strategy == "learned": 355 | alpha = torch.sigmoid(self.mix_factor) 356 | elif self.merge_strategy == "learned_with_images": 357 | assert image_only_indicator is not None, "need image_only_indicator ..." 358 | alpha = torch.where( 359 | image_only_indicator.bool(), 360 | torch.ones(1, 1, device=image_only_indicator.device), 361 | rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), 362 | ) 363 | alpha = rearrange(alpha, self.rearrange_pattern) 364 | else: 365 | raise NotImplementedError 366 | return alpha 367 | 368 | def forward( 369 | self, 370 | x_spatial: torch.Tensor, 371 | x_temporal: torch.Tensor, 372 | image_only_indicator: Optional[torch.Tensor] = None, 373 | ) -> torch.Tensor: 374 | alpha = self.get_alpha(image_only_indicator) 375 | a = 0 376 | x = ( 377 | alpha.to(x_spatial.dtype) * x_spatial 378 | + (1.0 - alpha).to(x_spatial.dtype) * x_temporal 379 | ) 380 | return x 381 | -------------------------------------------------------------------------------- /sgm/modules/diffusionmodules/wrappers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from packaging import version 4 | 5 | OPENAIUNETWRAPPER = ".sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" 6 | 7 | 8 | class IdentityWrapper(nn.Module): 9 | def __init__(self, diffusion_model, compile_model: bool = False): 10 | super().__init__() 11 | compile = ( 12 | torch.compile 13 | if (version.parse(torch.__version__) >= version.parse("2.0.0")) 14 | and compile_model 15 | else lambda x: x 16 | ) 17 | self.diffusion_model = compile(diffusion_model) 18 | 19 | def forward(self, *args, **kwargs): 20 | return self.diffusion_model(*args, **kwargs) 21 | 22 | 23 | class OpenAIWrapper(IdentityWrapper): 24 | def forward( 25 | self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs 26 | ) -> torch.Tensor: 27 | x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) 28 | return self.diffusion_model( 29 | x, 30 | timesteps=t, 31 | context=c.get("crossattn", None), 32 | y=c.get("vector", None), 33 | **kwargs, 34 | ) 35 | 
-------------------------------------------------------------------------------- /sgm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LVCDWrapper/081c8180029b1b5eb8f416e079456311ff467c83/sgm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /sgm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to( 34 | device=self.parameters.device 35 | ) 36 | 37 | def sample(self): 38 | x = self.mean + self.std * torch.randn(self.mean.shape).to( 39 | device=self.parameters.device 40 | ) 41 | return x 42 | 43 | def kl(self, other=None): 44 | if self.deterministic: 45 | return torch.Tensor([0.0]) 46 | else: 47 | if other is None: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 50 | dim=[1, 2, 3], 51 | ) 52 | else: 53 | return 0.5 * torch.sum( 54 | torch.pow(self.mean - other.mean, 2) / other.var 55 | + self.var / other.var 56 | - 1.0 57 | - self.logvar 58 | + other.logvar, 59 | dim=[1, 2, 3], 60 | ) 61 | 62 | def nll(self, sample, dims=[1, 2, 3]): 63 | if self.deterministic: 64 | return torch.Tensor([0.0]) 65 | logtwopi = np.log(2.0 * np.pi) 66 | return 0.5 * torch.sum( 67 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 68 | dim=dims, 69 | ) 70 | 71 | def mode(self): 72 | return self.mean 73 | 74 | 75 | def normal_kl(mean1, logvar1, mean2, logvar2): 76 | """ 77 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 78 | Compute the KL divergence between two gaussians. 79 | Shapes are automatically broadcasted, so batches can be compared to 80 | scalars, among other use cases. 81 | """ 82 | tensor = None 83 | for obj in (mean1, logvar1, mean2, logvar2): 84 | if isinstance(obj, torch.Tensor): 85 | tensor = obj 86 | break 87 | assert tensor is not None, "at least one argument must be a Tensor" 88 | 89 | # Force variances to be Tensors. Broadcasting helps convert scalars to 90 | # Tensors, but it does not work for torch.exp(). 
91 | logvar1, logvar2 = [ 92 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 93 | for x in (logvar1, logvar2) 94 | ] 95 | 96 | return 0.5 * ( 97 | -1.0 98 | + logvar2 99 | - logvar1 100 | + torch.exp(logvar1 - logvar2) 101 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 102 | ) 103 | -------------------------------------------------------------------------------- /sgm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) 16 | if use_num_upates 17 | else torch.tensor(-1, dtype=torch.int), 18 | ) 19 | 20 | for name, p in model.named_parameters(): 21 | if p.requires_grad: 22 | # remove as '.'-character is not allowed in buffers 23 | s_name = name.replace(".", "") 24 | self.m_name2s_name.update({name: s_name}) 25 | self.register_buffer(s_name, p.clone().detach().data) 26 | 27 | self.collected_params = [] 28 | 29 | def reset_num_updates(self): 30 | del self.num_updates 31 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 32 | 33 | def forward(self, model): 34 | decay = self.decay 35 | 36 | if self.num_updates >= 0: 37 | self.num_updates += 1 38 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 39 | 40 | one_minus_decay = 1.0 - decay 41 | 42 | with torch.no_grad(): 43 | m_param = dict(model.named_parameters()) 44 | shadow_params = dict(self.named_buffers()) 45 | 46 | for key in m_param: 47 | if m_param[key].requires_grad: 48 | sname = self.m_name2s_name[key] 49 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 50 | shadow_params[sname].sub_( 51 | one_minus_decay * (shadow_params[sname] - m_param[key]) 52 | ) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 
84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) 87 | -------------------------------------------------------------------------------- /sgm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kijai/ComfyUI-LVCDWrapper/081c8180029b1b5eb8f416e079456311ff467c83/sgm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /sgm/modules/video_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..modules.attention import * 4 | from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding 5 | 6 | 7 | class TimeMixSequential(nn.Sequential): 8 | def forward(self, x, context=None, timesteps=None): 9 | for layer in self: 10 | x = layer(x, context, timesteps) 11 | 12 | return x 13 | 14 | 15 | class VideoTransformerBlock(nn.Module): 16 | ATTENTION_MODES = { 17 | "softmax": CrossAttention, 18 | "softmax-masked": CrossAttention_Masked, 19 | "softmax-xformers": MemoryEfficientCrossAttention, 20 | } 21 | 22 | def __init__( 23 | self, 24 | dim, 25 | n_heads, 26 | d_head, 27 | dropout=0.0, 28 | context_dim=None, 29 | gated_ff=True, 30 | checkpoint=True, 31 | timesteps=None, 32 | ff_in=False, 33 | inner_dim=None, 34 | attn_mode="softmax", 35 | temporal_attn_mode=None, 36 | disable_self_attn=False, 37 | disable_temporal_crossattention=False, 38 | switch_temporal_ca_to_sa=False, 39 | ): 40 | super().__init__() 41 | 42 | attn_cls = self.ATTENTION_MODES[attn_mode] 43 | if temporal_attn_mode is None: 44 | temp_attn_cls = attn_cls 45 | elif isinstance(temporal_attn_mode, str): 46 | temp_attn_cls = self.ATTENTION_MODES[temporal_attn_mode] 47 | else: 48 | temp_attn_cls = temporal_attn_mode 49 | 50 | self.ff_in = ff_in or inner_dim is not None 51 | if inner_dim is None: 52 | inner_dim = dim 53 | 54 | assert int(n_heads * d_head) == inner_dim 55 | 56 | self.is_res = inner_dim == dim 57 | 58 | if self.ff_in: 59 | self.norm_in = nn.LayerNorm(dim) 60 | self.ff_in = FeedForward( 61 | dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff 62 | ) 63 | 64 | self.timesteps = timesteps 65 | self.disable_self_attn = disable_self_attn 66 | if self.disable_self_attn: 67 | self.attn1 = attn_cls( 68 | query_dim=inner_dim, 69 | heads=n_heads, 70 | dim_head=d_head, 71 | context_dim=context_dim, 72 | dropout=dropout, 73 | ) # is a cross-attention 74 | else: 75 | self.attn1 = temp_attn_cls( 76 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout 77 | ) # is a self-attention 78 | 79 | self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) 80 | 81 | if disable_temporal_crossattention: 82 | if switch_temporal_ca_to_sa: 83 | raise ValueError 84 | else: 85 | self.attn2 = None 86 | else: 87 | self.norm2 = nn.LayerNorm(inner_dim) 88 | if switch_temporal_ca_to_sa: 89 | self.attn2 = temp_attn_cls( 90 | query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout 91 | ) # is a self-attention 92 | else: 93 | self.attn2 = attn_cls( 94 | query_dim=inner_dim, 95 | context_dim=context_dim, 96 | heads=n_heads, 97 | dim_head=d_head, 98 | dropout=dropout, 99 | ) # is self-attn if context is none 100 | 101 | self.norm1 = nn.LayerNorm(inner_dim) 102 | self.norm3 = nn.LayerNorm(inner_dim) 103 | self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa 104 | 105 | self.checkpoint = checkpoint 106 | #if 
self.checkpoint: 107 | #print(f"{self.__class__.__name__} is using checkpointing") 108 | 109 | def forward( 110 | self, x: torch.Tensor, 111 | pos_emb: torch.Tensor, 112 | context: torch.Tensor = None, 113 | timesteps: int = None 114 | ) -> torch.Tensor: 115 | if hasattr(self, '_forward_hooks') and len(self._forward_hooks) > 0: 116 | # If hooked do nothing 117 | self.timesteps = timesteps 118 | return x 119 | else: 120 | if self.checkpoint: 121 | return checkpoint(self._forward, x, pos_emb, context, timesteps) 122 | else: 123 | return self._forward(x, pos_emb, context=context, timesteps=timesteps) 124 | 125 | def _forward(self, x, pos_emb, context=None, timesteps=None): 126 | assert self.timesteps or timesteps 127 | assert not (self.timesteps and timesteps) or self.timesteps == timesteps 128 | timesteps = self.timesteps or timesteps 129 | B, S, C = x.shape 130 | x = x + pos_emb 131 | x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) 132 | 133 | if self.ff_in: 134 | x_skip = x 135 | x = self.ff_in(self.norm_in(x)) 136 | if self.is_res: 137 | x += x_skip 138 | 139 | if self.disable_self_attn: 140 | x = self.attn1(self.norm1(x), context=context) + x 141 | else: 142 | x = self.attn1(self.norm1(x)) + x 143 | 144 | if self.attn2 is not None: 145 | if self.switch_temporal_ca_to_sa: 146 | x = self.attn2(self.norm2(x)) + x 147 | else: 148 | x = self.attn2(self.norm2(x), context=context) + x 149 | x_skip = x 150 | x = self.ff(self.norm3(x)) 151 | if self.is_res: 152 | x += x_skip 153 | 154 | x = rearrange( 155 | x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps 156 | ) 157 | return x 158 | 159 | def get_last_layer(self): 160 | return self.ff.net[-1].weight 161 | 162 | 163 | class SpatialVideoTransformer(SpatialTransformer): 164 | def __init__( 165 | self, 166 | in_channels, 167 | n_heads, 168 | d_head, 169 | depth=1, 170 | dropout=0.0, 171 | use_linear=False, 172 | context_dim=None, 173 | use_spatial_context=False, 174 | timesteps=None, 175 | merge_strategy: str = "fixed", 176 | merge_factor: float = 0.5, 177 | time_context_dim=None, 178 | ff_in=False, 179 | checkpoint=False, 180 | time_depth=1, 181 | attn_mode="softmax", 182 | temporal_attn_mode=None, 183 | disable_self_attn=False, 184 | disable_temporal_crossattention=False, 185 | max_time_embed_period: int = 10000, 186 | additional_attn_cond = False, 187 | spatial_self_attn_type: str = None, 188 | ): 189 | super().__init__( 190 | in_channels, 191 | n_heads, 192 | d_head, 193 | depth=depth, 194 | dropout=dropout, 195 | attn_type=attn_mode, 196 | use_checkpoint=checkpoint, 197 | context_dim=context_dim, 198 | use_linear=use_linear, 199 | disable_self_attn=disable_self_attn, 200 | additional_attn_cond=additional_attn_cond, 201 | spatial_self_attn_type=spatial_self_attn_type, 202 | ) 203 | self.time_depth = time_depth 204 | self.depth = depth 205 | self.max_time_embed_period = max_time_embed_period 206 | 207 | time_mix_d_head = d_head 208 | n_time_mix_heads = n_heads 209 | 210 | time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) 211 | 212 | inner_dim = n_heads * d_head 213 | if use_spatial_context: 214 | time_context_dim = context_dim 215 | 216 | self.time_stack = nn.ModuleList( 217 | [ 218 | VideoTransformerBlock( 219 | inner_dim, 220 | n_time_mix_heads, 221 | time_mix_d_head, 222 | dropout=dropout, 223 | context_dim=time_context_dim, 224 | timesteps=timesteps, 225 | checkpoint=checkpoint, 226 | ff_in=ff_in, 227 | inner_dim=time_mix_inner_dim, 228 | attn_mode=attn_mode, 229 | temporal_attn_mode=temporal_attn_mode, 
230 | disable_self_attn=disable_self_attn, 231 | disable_temporal_crossattention=disable_temporal_crossattention, 232 | ) 233 | for _ in range(self.depth) 234 | ] 235 | ) 236 | 237 | assert len(self.time_stack) == len(self.transformer_blocks) 238 | 239 | self.use_spatial_context = use_spatial_context 240 | self.in_channels = in_channels 241 | 242 | time_embed_dim = self.in_channels * 4 243 | self.time_pos_embed = nn.Sequential( 244 | linear(self.in_channels, time_embed_dim), 245 | nn.SiLU(), 246 | linear(time_embed_dim, self.in_channels), 247 | ) 248 | 249 | self.time_mixer = AlphaBlender( 250 | alpha=merge_factor, merge_strategy=merge_strategy 251 | ) 252 | 253 | def forward( 254 | self, 255 | x: torch.Tensor, 256 | context: Optional[torch.Tensor] = None, 257 | time_context: Optional[torch.Tensor] = None, 258 | attn_cond: Optional[torch.Tensor] = None, 259 | timesteps: Optional[int] = None, 260 | image_only_indicator: Optional[torch.Tensor] = None, 261 | ) -> torch.Tensor: 262 | _, _, h, w = x.shape 263 | x_in = x 264 | spatial_context = None 265 | if exists(context): 266 | spatial_context = context 267 | 268 | if self.use_spatial_context: 269 | assert ( 270 | context.ndim == 3 271 | ), f"n dims of spatial context should be 3 but are {context.ndim}" 272 | 273 | time_context = context 274 | time_context_first_timestep = time_context[::timesteps] 275 | time_context = repeat( 276 | time_context_first_timestep, "b ... -> (b n) ...", n=h * w 277 | ) 278 | elif time_context is not None and not self.use_spatial_context: 279 | time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) 280 | if time_context.ndim == 2: 281 | time_context = rearrange(time_context, "b c -> b 1 c") 282 | 283 | x = self.norm(x) 284 | if not self.use_linear: 285 | x = self.proj_in(x) 286 | x = rearrange(x, "b c h w -> b (h w) c") 287 | if self.use_linear: 288 | x = self.proj_in(x) 289 | 290 | num_frames = torch.arange(timesteps, device=x.device) 291 | num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) 292 | num_frames = rearrange(num_frames, "b t -> (b t)") 293 | t_emb = timestep_embedding( 294 | num_frames, 295 | self.in_channels, 296 | repeat_only=False, 297 | max_period=self.max_time_embed_period, 298 | ) 299 | if timesteps > 14: 300 | dt = timesteps - 14 301 | t_emb = rearrange(t_emb, '(b t) ... -> b t ...', t=timesteps) 302 | t_embs = [ t_emb[:,[0]] ] * dt 303 | t_embs += [ t_emb[:, :-dt] ] 304 | t_emb = torch.cat(t_embs, dim=1) 305 | t_emb = rearrange(t_emb, 'b t ... 
-> (b t) ...') 306 | emb = self.time_pos_embed(t_emb) 307 | emb = emb[:, None, :] 308 | 309 | for it_, (block, mix_block) in enumerate( 310 | zip(self.transformer_blocks, self.time_stack) 311 | ): 312 | x = block( 313 | x, 314 | context=spatial_context, 315 | attn_cond=attn_cond, 316 | ) 317 | 318 | x_mix = x 319 | 320 | x_mix = mix_block(x_mix, emb, time_context, timesteps=timesteps) 321 | x = self.time_mixer( 322 | x_spatial=x, 323 | x_temporal=x_mix, 324 | image_only_indicator=image_only_indicator, 325 | ) 326 | if self.use_linear: 327 | x = self.proj_out(x) 328 | x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) 329 | if not self.use_linear: 330 | x = self.proj_out(x) 331 | out = x + x_in 332 | return out 333 | -------------------------------------------------------------------------------- /sgm/util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import importlib 3 | import os 4 | from functools import partial 5 | from inspect import isfunction 6 | 7 | import fsspec 8 | import numpy as np 9 | import torch 10 | from PIL import Image, ImageDraw, ImageFont 11 | from safetensors.torch import load_file as load_safetensors 12 | 13 | 14 | def disabled_train(self, mode=True): 15 | """Overwrite model.train with this function to make sure train/eval mode 16 | does not change anymore.""" 17 | return self 18 | 19 | 20 | def get_string_from_tuple(s): 21 | try: 22 | # Check if the string starts and ends with parentheses 23 | if s[0] == "(" and s[-1] == ")": 24 | # Convert the string to a tuple 25 | t = eval(s) 26 | # Check if the type of t is tuple 27 | if type(t) == tuple: 28 | return t[0] 29 | else: 30 | pass 31 | except: 32 | pass 33 | return s 34 | 35 | 36 | def is_power_of_two(n): 37 | """ 38 | chat.openai.com/chat 39 | Return True if n is a power of 2, otherwise return False. 40 | 41 | The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. 42 | The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. 43 | If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. 44 | Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. 
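    Examples:
        is_power_of_two(1)   -> True   (2**0)
        is_power_of_two(16)  -> True
        is_power_of_two(12)  -> False
        is_power_of_two(0)   -> False  (non-positive inputs are rejected)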
45 | 46 | """ 47 | if n <= 0: 48 | return False 49 | return (n & (n - 1)) == 0 50 | 51 | 52 | def autocast(f, enabled=True): 53 | def do_autocast(*args, **kwargs): 54 | with torch.cuda.amp.autocast( 55 | enabled=enabled, 56 | dtype=torch.get_autocast_gpu_dtype(), 57 | cache_enabled=torch.is_autocast_cache_enabled(), 58 | ): 59 | return f(*args, **kwargs) 60 | 61 | return do_autocast 62 | 63 | 64 | def load_partial_from_config(config): 65 | return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) 66 | 67 | 68 | def log_txt_as_img(wh, xc, size=10): 69 | # wh a tuple of (width, height) 70 | # xc a list of captions to plot 71 | b = len(xc) 72 | txts = list() 73 | for bi in range(b): 74 | txt = Image.new("RGB", wh, color="white") 75 | draw = ImageDraw.Draw(txt) 76 | font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) 77 | nc = int(40 * (wh[0] / 256)) 78 | if isinstance(xc[bi], list): 79 | text_seq = xc[bi][0] 80 | else: 81 | text_seq = xc[bi] 82 | lines = "\n".join( 83 | text_seq[start : start + nc] for start in range(0, len(text_seq), nc) 84 | ) 85 | 86 | try: 87 | draw.text((0, 0), lines, fill="black", font=font) 88 | except UnicodeEncodeError: 89 | print("Cant encode string for logging. Skipping.") 90 | 91 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 92 | txts.append(txt) 93 | txts = np.stack(txts) 94 | txts = torch.tensor(txts) 95 | return txts 96 | 97 | 98 | def partialclass(cls, *args, **kwargs): 99 | class NewCls(cls): 100 | __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) 101 | 102 | return NewCls 103 | 104 | 105 | def make_path_absolute(path): 106 | fs, p = fsspec.core.url_to_fs(path) 107 | if fs.protocol == "file": 108 | return os.path.abspath(p) 109 | return path 110 | 111 | 112 | def ismap(x): 113 | if not isinstance(x, torch.Tensor): 114 | return False 115 | return (len(x.shape) == 4) and (x.shape[1] > 3) 116 | 117 | 118 | def isimage(x): 119 | if not isinstance(x, torch.Tensor): 120 | return False 121 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 122 | 123 | 124 | def isheatmap(x): 125 | if not isinstance(x, torch.Tensor): 126 | return False 127 | 128 | return x.ndim == 2 129 | 130 | 131 | def isneighbors(x): 132 | if not isinstance(x, torch.Tensor): 133 | return False 134 | return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) 135 | 136 | 137 | def exists(x): 138 | return x is not None 139 | 140 | 141 | def expand_dims_like(x, y): 142 | while x.dim() != y.dim(): 143 | x = x.unsqueeze(-1) 144 | return x 145 | 146 | 147 | def default(val, d): 148 | if exists(val): 149 | return val 150 | return d() if isfunction(d) else d 151 | 152 | 153 | def mean_flat(tensor): 154 | """ 155 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 156 | Take the mean over all non-batch dimensions. 
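    E.g. a tensor of shape (B, C, H, W) is reduced to a tensor of shape (B,),
    averaging over the C, H and W dimensions.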
157 | """ 158 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 159 | 160 | 161 | def count_params(model, verbose=False): 162 | total_params = sum(p.numel() for p in model.parameters()) 163 | if verbose: 164 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 165 | return total_params 166 | 167 | 168 | def instantiate_from_config(config): 169 | if not "target" in config: 170 | if config == "__is_first_stage__": 171 | return None 172 | elif config == "__is_unconditional__": 173 | return None 174 | raise KeyError("Expected key `target` to instantiate.") 175 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 176 | 177 | 178 | def get_obj_from_str(string, reload=False, invalidate_cache=True): 179 | package_directory_name = os.path.basename(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 180 | module, cls = string.rsplit(".", 1) 181 | if invalidate_cache: 182 | importlib.invalidate_caches() 183 | if reload: 184 | module_imp = importlib.import_module(module) 185 | importlib.reload(module_imp) 186 | return getattr(importlib.import_module(module, package=package_directory_name), cls) 187 | 188 | 189 | def append_zero(x): 190 | return torch.cat([x, x.new_zeros([1])]) 191 | 192 | 193 | def append_dims(x, target_dims): 194 | """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" 195 | dims_to_append = target_dims - x.ndim 196 | if dims_to_append < 0: 197 | raise ValueError( 198 | f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" 199 | ) 200 | return x[(...,) + (None,) * dims_to_append] 201 | 202 | 203 | def load_model_from_config(config, ckpt, verbose=True, freeze=True): 204 | print(f"Loading model from {ckpt}") 205 | if ckpt.endswith("ckpt"): 206 | pl_sd = torch.load(ckpt, map_location="cpu") 207 | if "global_step" in pl_sd: 208 | print(f"Global Step: {pl_sd['global_step']}") 209 | sd = pl_sd["state_dict"] 210 | elif ckpt.endswith("safetensors"): 211 | sd = load_safetensors(ckpt) 212 | else: 213 | raise NotImplementedError 214 | 215 | model = instantiate_from_config(config.model) 216 | 217 | m, u = model.load_state_dict(sd, strict=False) 218 | 219 | if len(m) > 0 and verbose: 220 | print("missing keys:") 221 | print(m) 222 | if len(u) > 0 and verbose: 223 | print("unexpected keys:") 224 | print(u) 225 | 226 | if freeze: 227 | for param in model.parameters(): 228 | param.requires_grad = False 229 | 230 | model.eval() 231 | return model 232 | 233 | 234 | def get_configs_path() -> str: 235 | """ 236 | Get the `configs` directory. 237 | For a working copy, this is the one in the root of the repository, 238 | but for an installed copy, it's in the `sgm` package (see pyproject.toml). 239 | """ 240 | this_dir = os.path.dirname(__file__) 241 | candidates = ( 242 | os.path.join(this_dir, "configs"), 243 | os.path.join(this_dir, "..", "configs"), 244 | ) 245 | for candidate in candidates: 246 | candidate = os.path.abspath(candidate) 247 | if os.path.isdir(candidate): 248 | return candidate 249 | raise FileNotFoundError(f"Could not find SGM configs in {candidates}") 250 | 251 | 252 | def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): 253 | """ 254 | Will return the result of a recursive get attribute call. 255 | E.g.: 256 | a.b.c 257 | = getattr(getattr(a, "b"), "c") 258 | = get_nested_attribute(a, "b.c") 259 | If any part of the attribute call is an integer x with current obj a, will 260 | try to call a[x] instead of a.x first. 
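    The same works with integer components (illustrative attribute names):
        get_nested_attribute(a, "layers.0.weight")
        = getattr(getattr(a, "layers")[0], "weight")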
261 | """ 262 | attributes = attribute_path.split(".") 263 | if depth is not None and depth > 0: 264 | attributes = attributes[:depth] 265 | assert len(attributes) > 0, "At least one attribute should be selected" 266 | current_attribute = obj 267 | current_key = None 268 | for level, attribute in enumerate(attributes): 269 | current_key = ".".join(attributes[: level + 1]) 270 | try: 271 | id_ = int(attribute) 272 | current_attribute = current_attribute[id_] 273 | except ValueError: 274 | current_attribute = getattr(current_attribute, attribute) 275 | 276 | return (current_attribute, current_key) if return_key else current_attribute 277 | --------------------------------------------------------------------------------
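
A minimal usage sketch for the config-driven construction helpers defined in
sgm/util.py above (instantiate_from_config and get_obj_from_str): a config maps
"target" to a dotted import path and "params" to constructor kwargs. torch.nn.Linear
is used here only as an illustrative stand-in target, and the sketch assumes the
repository root is on sys.path so that the sgm package is importable.

    from sgm.util import get_obj_from_str, instantiate_from_config

    # "target" names the class by its dotted path, "params" its constructor kwargs.
    # torch.nn.Linear is a stand-in target used purely for illustration.
    config = {
        "target": "torch.nn.Linear",
        "params": {"in_features": 8, "out_features": 4},
    }
    layer = instantiate_from_config(config)    # equivalent to torch.nn.Linear(8, 4)

    # get_obj_from_str resolves the class without instantiating it.
    linear_cls = get_obj_from_str("torch.nn.Linear")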