├── .gitignore ├── LICENSE ├── README.md ├── __assets__ ├── Van_Gogh_flower.gif ├── flower_bloom.gif ├── ice_cube.gif ├── long_video.gif ├── rainbow_forming1.gif ├── rainbow_forming2.gif ├── rainbow_forming3.gif ├── rainbow_forming4.gif ├── river_freezing.gif ├── teaser.png ├── volcano_eruption1.gif ├── volcano_eruption2.gif ├── volcano_eruption3.gif ├── volcano_eruption4.gif └── yellow_flower.gif ├── configs ├── A_thunderstorm_developing_over_a_sea.yaml ├── a_rainbow_is_forming.yaml ├── flowers.yaml └── volcano_eruption.yaml ├── freebloom ├── models │ ├── attention.py │ ├── resnet.py │ ├── unet.py │ └── unet_blocks.py ├── pipelines │ └── pipeline_spatio_temporal.py ├── prompt_attention │ ├── attention_util.py │ ├── ptp_utils.py │ └── seq_aligner.py └── util.py ├── interp.py ├── main.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .empty 3 | outputs -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Free-Bloom 2 | 3 | This repository is the official implementation of [Free-Bloom](https://arxiv.org/abs/2309.14494). 
4 | 5 | **[Free-Bloom: Zero-Shot Text-to-Video Generator with LLM Director and LDM Animator](https://arxiv.org/abs/2309.14494)** 6 | 7 | [Hanzhuo Huang∗](), [Yufan Feng∗](), [Cheng Shi](https://chengshiest.github.io/), [Lan Xu](https://www.xu-lan.com/), [Jingyi Yu](https://vic.shanghaitech.edu.cn/vrvc/en/people/jingyi-yu/), [Sibei Yang†](https://faculty.sist.shanghaitech.edu.cn/yangsibei/) 8 | 9 | *Equal contribution; †Corresponding Author 10 | 11 | [![arXiv](https://img.shields.io/badge/arXiv-FreeBloom-b31b1b.svg)](https://arxiv.org/abs/2309.14494) ![Pytorch](https://img.shields.io/badge/PyTorch->=1.10.0-Red?logo=pytorch) 12 | 13 | 14 | ![teaser](__assets__/teaser.png) 15 | 16 | ## Setup 17 | 18 | ### Requirements 19 | ```cmd 20 | conda create -n fb python=3.10 21 | conda activate fb 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | Installing [xformers](https://github.com/facebookresearch/xformers) is highly recommended for better memory efficiency and speed on GPU. 26 | To enable xformers, set `enable_xformers_memory_efficient_attention=True` (the default). 27 | 28 | 29 | 30 | 31 | ## Usage 32 | 33 | ### Generate 34 | ```cmd 35 | python main.py --config configs/flowers.yaml 36 | ``` 37 | 38 | Set the `pretrained_model_path` key in the config YAML file to the local path of your own diffusion model weights (see the example after the Results section for one way to download them). 39 | 40 | 41 | 42 | 43 | 44 | ## Results 45 | 46 | **A Flower is blooming** 47 | 48 | *(result GIFs: see `__assets__/flower_bloom.gif` and `__assets__/yellow_flower.gif`)* 49 |
56 | 57 | 58 | 59 | **Volcano eruption** 60 | 61 | *(result GIFs: see `__assets__/volcano_eruption1.gif`–`__assets__/volcano_eruption4.gif`)* 62 |
69 | 70 | **A rainbow is forming** 71 | 72 | *(result GIFs: see `__assets__/rainbow_forming1.gif`–`__assets__/rainbow_forming4.gif`)* 73 |
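The `pretrained_model_path` key in each config must point to a local, diffusers-format copy of Stable Diffusion v1-5 (see Usage above). As a hedged sketch — the checkpoint id `runwayml/stable-diffusion-v1-5` and the use of `huggingface_hub` below are illustrative assumptions, not part of this repository — the weights can be fetched like this:

```python
# Illustrative only: download (or reuse a cached copy of) a diffusers-format
# Stable Diffusion v1-5 checkpoint and print its local directory, which can
# then be used as `pretrained_model_path` in the config YAML files.
from huggingface_hub import snapshot_download

local_path = snapshot_download(repo_id="runwayml/stable-diffusion-v1-5")
print(local_path)
```

Any other diffusers-format Stable Diffusion 1.x checkpoint directory should work the same way.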
79 | 80 | 81 | ## Citation 82 | 83 | ``` 84 | @article{freebloom, 85 | title={Free-Bloom: Zero-Shot Text-to-Video Generator with LLM Director and LDM Animator}, 86 | author={Huang, Hanzhuo and Feng, Yufan and Shi, Cheng and Xu, Lan and Yu, Jingyi and Yang, Sibei}, 87 | journal={arXiv preprint arXiv:2309.14494}, 88 | year={2023} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /__assets__/Van_Gogh_flower.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/Van_Gogh_flower.gif -------------------------------------------------------------------------------- /__assets__/flower_bloom.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/flower_bloom.gif -------------------------------------------------------------------------------- /__assets__/ice_cube.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/ice_cube.gif -------------------------------------------------------------------------------- /__assets__/long_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/long_video.gif -------------------------------------------------------------------------------- /__assets__/rainbow_forming1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/rainbow_forming1.gif -------------------------------------------------------------------------------- /__assets__/rainbow_forming2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/rainbow_forming2.gif -------------------------------------------------------------------------------- /__assets__/rainbow_forming3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/rainbow_forming3.gif -------------------------------------------------------------------------------- /__assets__/rainbow_forming4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/rainbow_forming4.gif -------------------------------------------------------------------------------- /__assets__/river_freezing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/river_freezing.gif -------------------------------------------------------------------------------- /__assets__/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/teaser.png 
-------------------------------------------------------------------------------- /__assets__/volcano_eruption1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/volcano_eruption1.gif -------------------------------------------------------------------------------- /__assets__/volcano_eruption2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/volcano_eruption2.gif -------------------------------------------------------------------------------- /__assets__/volcano_eruption3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/volcano_eruption3.gif -------------------------------------------------------------------------------- /__assets__/volcano_eruption4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/volcano_eruption4.gif -------------------------------------------------------------------------------- /__assets__/yellow_flower.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/yellow_flower.gif -------------------------------------------------------------------------------- /configs/A_thunderstorm_developing_over_a_sea.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "/data/diffusion_wights/stable-diffusion-v1-5" 2 | output_dir: "./outputs/A_thunderstorm_developing_over_a_sea" 3 | 4 | inference_config: 5 | diversity_rand_ratio: 0.1 6 | 7 | validation_data: 8 | prompts: 9 | - "A serene view of the sea under clear skies, with a few fluffy white clouds in the distance." 10 | - "Darker clouds gathering at the horizon as the sea starts to become choppy with rising waves.;" 11 | - "Lightning streaks across the sky, illuminating the dark clouds, and raindrops begin to fall onto the sea's surface." 12 | - "The thunderstorm intensifies with heavy rain pouring down, and strong winds whip up the sea, creating large waves." 13 | - "Lightning strikes become more frequent, illuminating the turbulent sea with flashes of light." 14 | - "The storm reaches its peak, with menacing dark clouds covering the entire sky, and the sea becomes a tumultuous mass of crashing waves and heavy rainfall." 
15 | 16 | width: 512 17 | height: 512 18 | num_inference_steps: 50 19 | guidance_scale: 12.5 20 | use_inv_latent: True 21 | num_inv_steps: 50 22 | negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic" 23 | interpolate_k: 0 24 | attention_type_former: ["self", "former", "first"] 25 | attention_type_latter: ["self"] 26 | attention_adapt_step: 30 # 0~50 27 | 28 | seed: 3243241 # as you like 29 | mixed_precision: fp16 30 | enable_xformers_memory_efficient_attention: True 31 | -------------------------------------------------------------------------------- /configs/a_rainbow_is_forming.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "/data/diffusion_wights/stable-diffusion-v1-5" 2 | output_dir: "./outputs/a_rainbow_is_forming" 3 | 4 | inference_config: 5 | diversity_rand_ratio: 0.1 6 | 7 | validation_data: 8 | prompts: 9 | - "The sky, partially cloudy, with faint hints of colors starting to emerge." 10 | - "A faint arch of colors becomes visible, stretching across the sky." 11 | - "The rainbow gains intensity as the colors become more vibrant and defined." 12 | - "The rainbow is now fully formed, displaying its classic arc shape." 13 | - "The colors of the rainbow shine brilliantly against the backdrop of the sky." 14 | - "The rainbow remains steady, its colors vivid and captivating as the rainbow decorates the sky." 15 | 16 | width: 512 17 | height: 512 18 | num_inference_steps: 50 19 | guidance_scale: 12.5 20 | use_inv_latent: True 21 | num_inv_steps: 50 22 | negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic" 23 | interpolate_k: 0 24 | attention_type_former: ["self", "former", "first"] 25 | attention_type_latter: ["self"] 26 | attention_adapt_step: 50 # 0~50 27 | 28 | seed: 324423 # 42 # as you like 29 | mixed_precision: fp16 30 | enable_xformers_memory_efficient_attention: True 31 | -------------------------------------------------------------------------------- /configs/flowers.yaml: -------------------------------------------------------------------------------- 1 | # A cluster of flowers blooms 2 | pretrained_model_path: "/data/diffusion_wights/stable-diffusion-v1-5" 3 | output_dir: "./outputs/flowers" 4 | 5 | inference_config: 6 | diversity_rand_ratio: 0.1 7 | 8 | validation_data: 9 | prompts: 10 | - "A group of closed buds can be seen on the stem of a plant." 11 | - "The buds begin to slowly unfurl, revealing small petals." 12 | - "The petals continue to unfurl, revealing more of the flower's center." 13 | - "The petals are now fully opened, and the center of the flower is visible." 14 | - "The flower's stamen and pistil become more prominent, and the petals start to curve outward." 15 | - "The fully bloomed flowers are in full view, with their petals open wide and displaying their vibrant colors."
16 | 17 | 18 | width: 512 19 | height: 512 20 | num_inference_steps: 50 21 | guidance_scale: 12.5 22 | use_inv_latent: True 23 | num_inv_steps: 50 24 | negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic" 25 | interpolate_k: 0 26 | attention_type_former: [ "self", "first", "former" ] 27 | attention_type_latter: [ "self" ] 28 | attention_adapt_step: 20 29 | 30 | seed: 42 31 | mixed_precision: fp16 32 | enable_xformers_memory_efficient_attention: True 33 | -------------------------------------------------------------------------------- /configs/volcano_eruption.yaml: -------------------------------------------------------------------------------- 1 | pretrained_model_path: "/data/diffusion_wights/stable-diffusion-v1-5" 2 | output_dir: "./outputs/volcano_eruption_3" 3 | 4 | inference_config: 5 | diversity_rand_ratio: 0.1 6 | 7 | validation_data: 8 | prompts: 9 | - "A towering volcano stands against a backdrop of clear blue skies, with no visible signs of activity." 10 | - "Suddenly, a plume of thick smoke and ash erupts from the volcano's summit, the plume billowing high into the air." 11 | - "Molten lava begins to flow down the volcano's slopes, the lava glowing brightly with intense heat and leaving a trail of destruction in its path." 12 | - "Explosions rock the volcano as fiery projectiles shoot into the sky, the projectiles scattering debris and ash in all directions." 13 | - "The eruption intensifies, with a massive column of smoke and ash ascending into the atmosphere, the column darkening the surrounding area and creating a dramatic spectacle." 14 | - "As the eruption reaches its peak, a pyroclastic flow cascades down the volcano's sides, the flow engulfing everything in its path with a deadly combination of hot gases, ash, and volcanic material."
15 | 16 | 17 | 18 | width: 512 19 | height: 512 20 | num_inference_steps: 50 21 | guidance_scale: 12.5 22 | use_inv_latent: True 23 | num_inv_steps: 50 24 | negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic" 25 | interpolate_k: 0 26 | attention_type_former: ["self", "former", "first"] 27 | attention_type_latter: ["self"] 28 | attention_adapt_step: 40 # 0~50 29 | 30 | seed: 17 #133 # as you like 31 | mixed_precision: fp16 32 | enable_xformers_memory_efficient_attention: True 33 | -------------------------------------------------------------------------------- /freebloom/models/attention.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | 10 | from diffusers.configuration_utils import ConfigMixin, register_to_config 11 | from diffusers.modeling_utils import ModelMixin 12 | from diffusers.utils import BaseOutput 13 | from diffusers.utils.import_utils import is_xformers_available 14 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm 15 | 16 | from einops import rearrange, repeat 17 | 18 | 19 | @dataclass 20 | class Transformer3DModelOutput(BaseOutput): 21 | sample: torch.FloatTensor 22 | 23 | 24 | if is_xformers_available(): 25 | import xformers 26 | import xformers.ops 27 | else: 28 | xformers = None 29 | 30 | 31 | class Transformer3DModel(ModelMixin, ConfigMixin): 32 | @register_to_config 33 | def __init__( 34 | self, 35 | num_attention_heads: int = 16, 36 | attention_head_dim: int = 88, 37 | in_channels: Optional[int] = None, 38 | num_layers: int = 1, 39 | dropout: float = 0.0, 40 | norm_num_groups: int = 32, 41 | cross_attention_dim: Optional[int] = None, 42 | attention_bias: bool = False, 43 | activation_fn: str = "geglu", 44 | num_embeds_ada_norm: Optional[int] = None, 45 | use_linear_projection: bool = False, 46 | only_cross_attention: bool = False, 47 | upcast_attention: bool = False, 48 | ): 49 | super().__init__() 50 | self.use_linear_projection = use_linear_projection 51 | self.num_attention_heads = num_attention_heads 52 | self.attention_head_dim = attention_head_dim 53 | inner_dim = num_attention_heads * attention_head_dim 54 | 55 | # Define input layers 56 | self.in_channels = in_channels 57 | 58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 59 | if use_linear_projection: 60 | self.proj_in = nn.Linear(in_channels, inner_dim) 61 | else: 62 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 63 | 64 | # Define transformers blocks 65 | self.transformer_blocks = nn.ModuleList( 66 | [ 67 | BasicTransformerBlock( 68 | inner_dim, 69 | num_attention_heads, 70 | attention_head_dim, 71 | dropout=dropout, 72 | cross_attention_dim=cross_attention_dim, 73 | activation_fn=activation_fn, 74 | num_embeds_ada_norm=num_embeds_ada_norm, 75 | attention_bias=attention_bias, 76 | only_cross_attention=only_cross_attention, 77 | upcast_attention=upcast_attention, 78 | ) 79 | for d in range(num_layers) 80 | ] 81 | ) 82 | 83 | # 4. 
Define output layers 84 | if use_linear_projection: 85 | self.proj_out = nn.Linear(in_channels, inner_dim) 86 | else: 87 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 88 | 89 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): 90 | # Input 91 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 92 | video_length = hidden_states.shape[2] 93 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 94 | # encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 95 | 96 | batch, channel, height, weight = hidden_states.shape 97 | residual = hidden_states 98 | 99 | hidden_states = self.norm(hidden_states) 100 | if not self.use_linear_projection: 101 | hidden_states = self.proj_in(hidden_states) 102 | inner_dim = hidden_states.shape[1] 103 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 104 | else: 105 | inner_dim = hidden_states.shape[1] 106 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 107 | hidden_states = self.proj_in(hidden_states) 108 | 109 | # Blocks 110 | for block in self.transformer_blocks: 111 | hidden_states = block( 112 | hidden_states, 113 | encoder_hidden_states=encoder_hidden_states, 114 | timestep=timestep, 115 | video_length=video_length 116 | ) 117 | 118 | # Output 119 | if not self.use_linear_projection: 120 | hidden_states = ( 121 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 122 | ) 123 | hidden_states = self.proj_out(hidden_states) 124 | else: 125 | hidden_states = self.proj_out(hidden_states) 126 | hidden_states = ( 127 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 128 | ) 129 | 130 | output = hidden_states + residual 131 | 132 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 133 | if not return_dict: 134 | return (output,) 135 | 136 | return Transformer3DModelOutput(sample=output) 137 | 138 | 139 | class BasicTransformerBlock(nn.Module): 140 | def __init__( 141 | self, 142 | dim: int, 143 | num_attention_heads: int, 144 | attention_head_dim: int, 145 | dropout=0.0, 146 | cross_attention_dim: Optional[int] = None, 147 | activation_fn: str = "geglu", 148 | num_embeds_ada_norm: Optional[int] = None, 149 | attention_bias: bool = False, 150 | only_cross_attention: bool = False, 151 | upcast_attention: bool = False, 152 | ): 153 | super().__init__() 154 | self.only_cross_attention = only_cross_attention 155 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 156 | 157 | # SC-Attn 158 | self.attn1 = SparseCausalAttention( 159 | query_dim=dim, 160 | heads=num_attention_heads, 161 | dim_head=attention_head_dim, 162 | dropout=dropout, 163 | bias=attention_bias, 164 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 165 | upcast_attention=upcast_attention, 166 | ) 167 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 168 | 169 | # Cross-Attn 170 | if cross_attention_dim is not None: 171 | self.attn2 = CrossAttention( 172 | query_dim=dim, 173 | cross_attention_dim=cross_attention_dim, 174 | heads=num_attention_heads, 175 | dim_head=attention_head_dim, 176 | dropout=dropout, 177 | bias=attention_bias, 178 | upcast_attention=upcast_attention, 179 | ) 180 | else: 181 | self.attn2 = None 182 | 
183 | if cross_attention_dim is not None: 184 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 185 | else: 186 | self.norm2 = None 187 | 188 | # Feed-forward 189 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) 190 | self.norm3 = nn.LayerNorm(dim) 191 | 192 | # Temp-Attn 193 | self.attn_temp = CrossAttention( 194 | query_dim=dim, 195 | heads=num_attention_heads, 196 | dim_head=attention_head_dim, 197 | dropout=dropout, 198 | bias=attention_bias, 199 | upcast_attention=upcast_attention, 200 | ) 201 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data) 202 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 203 | 204 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): 205 | if not is_xformers_available(): 206 | print("Here is how to install it") 207 | raise ModuleNotFoundError( 208 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 209 | " xformers", 210 | name="xformers", 211 | ) 212 | elif not torch.cuda.is_available(): 213 | raise ValueError( 214 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 215 | " available for GPU " 216 | ) 217 | else: 218 | try: 219 | # Make sure we can run the memory efficient attention 220 | _ = xformers.ops.memory_efficient_attention( 221 | torch.randn((1, 2, 40), device="cuda"), 222 | torch.randn((1, 2, 40), device="cuda"), 223 | torch.randn((1, 2, 40), device="cuda"), 224 | ) 225 | except Exception as e: 226 | raise e 227 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 228 | if self.attn2 is not None: 229 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 230 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 231 | 232 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): 233 | # SparseCausal-Attention 234 | norm_hidden_states = ( 235 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 236 | ) 237 | 238 | if self.only_cross_attention: 239 | hidden_states = ( 240 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states 241 | ) 242 | else: 243 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, 244 | video_length=video_length) + hidden_states 245 | 246 | if self.attn2 is not None: 247 | # Cross-Attention 248 | norm_hidden_states = ( 249 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 250 | ) 251 | # if encoder_hidden_states.shape[0] != norm_hidden_states.shape[0]: 252 | # encoder_hidden_states = repeat(encoder_hidden_states, 'a b c -> (a f) b c', f=video_length) 253 | hidden_states = ( 254 | self.attn2( 255 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, 256 | attention_mask=attention_mask 257 | ) 258 | + hidden_states 259 | ) 260 | 261 | 262 | # Feed-forward 263 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 264 | 265 | # Temporal-Attention 266 | # d = hidden_states.shape[1] 267 | # hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 268 | # norm_hidden_states = ( 269 | # self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm 
else self.norm_temp(hidden_states) 270 | # ) 271 | # hidden_states = self.attn_temp(norm_hidden_states) + hidden_states 272 | # hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 273 | 274 | return hidden_states 275 | 276 | 277 | class SparseCausalAttention(CrossAttention): 278 | 279 | # def __init__(self, query_dim: int, 280 | # cross_attention_dim: Optional[int] = None, 281 | # heads: int = 8, 282 | # dim_head: int = 64, 283 | # dropout: float = 0.0, 284 | # bias=False, 285 | # upcast_attention: bool = False, 286 | # upcast_softmax: bool = False, 287 | # added_kv_proj_dim: Optional[int] = None, 288 | # norm_num_groups: Optional[int] = None, ): 289 | # super(SparseCausalAttention, self).__init__(query_dim, 290 | # cross_attention_dim, 291 | # heads, 292 | # dim_head, 293 | # dropout, 294 | # bias, 295 | # upcast_attention, 296 | # upcast_softmax, 297 | # added_kv_proj_dim, 298 | # norm_num_groups) 299 | # inner_dim = dim_head * heads 300 | # self.to_q2 = nn.Linear(query_dim, inner_dim, bias=bias) 301 | 302 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 303 | batch_size, sequence_length, _ = hidden_states.shape 304 | 305 | encoder_hidden_states = encoder_hidden_states 306 | 307 | if self.group_norm is not None: 308 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 309 | 310 | query = self.to_q(hidden_states) 311 | dim = query.shape[-1] 312 | query = self.reshape_heads_to_batch_dim(query) 313 | 314 | if self.added_kv_proj_dim is not None: 315 | raise NotImplementedError 316 | 317 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 318 | key = self.to_k(encoder_hidden_states) 319 | value = self.to_v(encoder_hidden_states) 320 | 321 | former_frame_index = torch.arange(video_length) - 1 322 | former_frame_index[0] = 0 323 | 324 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length) 325 | key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) 326 | key = rearrange(key, "b f d c -> (b f) d c") 327 | 328 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length) 329 | value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) 330 | value = rearrange(value, "b f d c -> (b f) d c") 331 | 332 | key = self.reshape_heads_to_batch_dim(key) 333 | value = self.reshape_heads_to_batch_dim(value) 334 | 335 | if attention_mask is not None: 336 | if attention_mask.shape[-1] != query.shape[1]: 337 | target_length = query.shape[1] 338 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 339 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 340 | 341 | # attention, what we cannot get enough of 342 | if self._use_memory_efficient_attention_xformers: 343 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 344 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 345 | hidden_states = hidden_states.to(query.dtype) 346 | else: 347 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 348 | hidden_states = self._attention(query, key, value, attention_mask) 349 | else: 350 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 351 | 352 | # linear proj 353 | hidden_states = self.to_out[0](hidden_states) 354 | 355 | # dropout 356 | hidden_states = self.to_out[1](hidden_states) 357 | return hidden_states 358 | 
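The frame-index bookkeeping in `SparseCausalAttention.forward` above is the core of the sparse-causal attention: each frame's queries attend to keys/values gathered from the first frame and from the immediately preceding frame. The following standalone toy script (not part of the repository; tiny tensor sizes chosen only for illustration) reproduces that gather so the resulting key layout can be inspected directly:

```python
import torch
from einops import rearrange

# Toy sizes: 1 video (b = 1), 4 frames (f), 2 spatial tokens per frame (d), 1 channel (c).
video_length, d, c = 4, 2, 1

# Tag every token with the index of the frame it belongs to; shape ((b f), d, c).
key = torch.arange(video_length, dtype=torch.float32).repeat_interleave(d * c).view(video_length, d, c)

# Same index arithmetic as in SparseCausalAttention.forward.
former_frame_index = torch.arange(video_length) - 1
former_frame_index[0] = 0  # frame 0 has no predecessor, so it reuses itself

key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
key = torch.cat([key[:, [0] * video_length],   # tokens of the first frame, repeated for every frame
                 key[:, former_frame_index]],  # tokens of each frame's previous frame
                dim=2)                         # concatenated along the token dimension
key = rearrange(key, "b f d c -> (b f) d c")

print(key.squeeze(-1))
# tensor([[0., 0., 0., 0.],   # frame 0 attends to keys from frames {0, 0}
#         [0., 0., 0., 0.],   # frame 1 attends to keys from frames {0, 0}
#         [0., 0., 1., 1.],   # frame 2 attends to keys from frames {0, 1}
#         [0., 0., 2., 2.]])  # frame 3 attends to keys from frames {0, 2}
```

Because only two frames' worth of tokens ever enter the key/value bank, the attention cost per frame stays constant in the clip length while the appearance of the first frame is still propagated to every later frame.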
-------------------------------------------------------------------------------- /freebloom/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | from torch import Tensor 9 | 10 | 11 | class InflatedGroupNorm(nn.GroupNorm): 12 | def forward(self, x: Tensor) -> Tensor: 13 | f = x.shape[2] 14 | 15 | x = rearrange(x, 'b c f h w -> (b f) c h w') 16 | x = super().forward(x) 17 | x = rearrange(x, '(b f) c h w -> b c f h w', f=f) 18 | 19 | return x 20 | 21 | 22 | class InflatedConv3d(nn.Conv2d): 23 | def __init__(self, in_channels, out_channels, kernel_size, temporal_kernel_size=None, **kwargs): 24 | super(InflatedConv3d, self).__init__(in_channels, out_channels, kernel_size, **kwargs) 25 | if temporal_kernel_size is None: 26 | temporal_kernel_size = kernel_size 27 | 28 | self.conv_temp = ( 29 | nn.Conv1d( 30 | out_channels, 31 | out_channels, 32 | kernel_size=temporal_kernel_size, 33 | padding=temporal_kernel_size // 2, 34 | ) 35 | if kernel_size > 1 36 | else None 37 | ) 38 | 39 | if self.conv_temp is not None: 40 | nn.init.dirac_(self.conv_temp.weight.data) # initialized to be identity 41 | nn.init.zeros_(self.conv_temp.bias.data) 42 | 43 | def forward(self, x): 44 | b = x.shape[0] 45 | 46 | is_video = x.ndim == 5 47 | if is_video: 48 | x = rearrange(x, "b c f h w -> (b f) c h w") 49 | 50 | x = super().forward(x) 51 | 52 | if is_video: 53 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b) 54 | 55 | if self.conv_temp is None or not is_video: 56 | return x 57 | 58 | *_, h, w = x.shape 59 | 60 | x = rearrange(x, "b c f h w -> (b h w) c f") 61 | 62 | x = self.conv_temp(x) 63 | 64 | x = rearrange(x, "(b h w) c f -> b c f h w", h=h, w=w) 65 | 66 | return x 67 | 68 | 69 | class Upsample3D(nn.Module): 70 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 71 | super().__init__() 72 | self.channels = channels 73 | self.out_channels = out_channels or channels 74 | self.use_conv = use_conv 75 | self.use_conv_transpose = use_conv_transpose 76 | self.name = name 77 | 78 | conv = None 79 | if use_conv_transpose: 80 | raise NotImplementedError 81 | elif use_conv: 82 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 83 | 84 | if name == "conv": 85 | self.conv = conv 86 | else: 87 | self.Conv2d_0 = conv 88 | 89 | def forward(self, hidden_states, output_size=None): 90 | assert hidden_states.shape[1] == self.channels 91 | 92 | if self.use_conv_transpose: 93 | raise NotImplementedError 94 | 95 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 96 | dtype = hidden_states.dtype 97 | if dtype == torch.bfloat16: 98 | hidden_states = hidden_states.to(torch.float32) 99 | 100 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 101 | if hidden_states.shape[0] >= 64: 102 | hidden_states = hidden_states.contiguous() 103 | 104 | # if `output_size` is passed we force the interpolation output 105 | # size and do not make use of `scale_factor=2` 106 | if output_size is None: 107 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 108 | else: 109 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 110 | 111 | # If the input is bfloat16, we cast back to bfloat16 112 | if dtype == torch.bfloat16: 113 | hidden_states = hidden_states.to(dtype) 114 | 115 | if self.use_conv: 116 | if self.name == "conv": 117 | hidden_states = self.conv(hidden_states) 118 | else: 119 | hidden_states = self.Conv2d_0(hidden_states) 120 | 121 | return hidden_states 122 | 123 | 124 | class Downsample3D(nn.Module): 125 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 126 | super().__init__() 127 | self.channels = channels 128 | self.out_channels = out_channels or channels 129 | self.use_conv = use_conv 130 | self.padding = padding 131 | stride = 2 132 | self.name = name 133 | 134 | if use_conv: 135 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 136 | else: 137 | raise NotImplementedError 138 | 139 | if name == "conv": 140 | self.Conv2d_0 = conv 141 | self.conv = conv 142 | elif name == "Conv2d_0": 143 | self.conv = conv 144 | else: 145 | self.conv = conv 146 | 147 | def forward(self, hidden_states): 148 | assert hidden_states.shape[1] == self.channels 149 | if self.use_conv and self.padding == 0: 150 | raise NotImplementedError 151 | 152 | assert hidden_states.shape[1] == self.channels 153 | hidden_states = self.conv(hidden_states) 154 | 155 | return hidden_states 156 | 157 | 158 | class ResnetBlock3D(nn.Module): 159 | def __init__( 160 | self, 161 | *, 162 | in_channels, 163 | out_channels=None, 164 | conv_shortcut=False, 165 | dropout=0.0, 166 | temb_channels=512, 167 | groups=32, 168 | groups_out=None, 169 | pre_norm=True, 170 | eps=1e-6, 171 | non_linearity="swish", 172 | time_embedding_norm="default", 173 | output_scale_factor=1.0, 174 | use_in_shortcut=None, 175 | ): 176 | super().__init__() 177 | self.pre_norm = pre_norm 178 | self.pre_norm = True 179 | self.in_channels = in_channels 180 | out_channels = in_channels if out_channels is None else out_channels 181 | self.out_channels = out_channels 182 | self.use_conv_shortcut = conv_shortcut 183 | self.time_embedding_norm = time_embedding_norm 184 | self.output_scale_factor = output_scale_factor 185 | 186 | if groups_out is None: 187 | groups_out = groups 188 | 189 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 190 | # self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 191 | 192 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 193 | 194 | if temb_channels is not None: 195 | if self.time_embedding_norm == "default": 196 | time_emb_proj_out_channels = out_channels 197 | elif self.time_embedding_norm == "scale_shift": 198 | time_emb_proj_out_channels = out_channels * 2 199 | else: 200 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 201 | 202 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 203 | else: 204 | self.time_emb_proj = None 205 | 206 | self.norm2 = 
InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 207 | # self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 208 | 209 | self.dropout = torch.nn.Dropout(dropout) 210 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 211 | 212 | if non_linearity == "swish": 213 | self.nonlinearity = lambda x: F.silu(x) 214 | elif non_linearity == "mish": 215 | self.nonlinearity = Mish() 216 | elif non_linearity == "silu": 217 | self.nonlinearity = nn.SiLU() 218 | 219 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 220 | 221 | self.conv_shortcut = None 222 | if self.use_in_shortcut: 223 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 224 | 225 | def forward(self, input_tensor, temb): 226 | hidden_states = input_tensor # [b,c,f,h.w] 227 | 228 | hidden_states = self.norm1(hidden_states) 229 | hidden_states = self.nonlinearity(hidden_states) 230 | 231 | hidden_states = self.conv1(hidden_states) 232 | 233 | if temb is not None: 234 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 235 | 236 | if temb is not None and self.time_embedding_norm == "default": 237 | hidden_states = hidden_states + temb 238 | 239 | hidden_states = self.norm2(hidden_states) 240 | 241 | if temb is not None and self.time_embedding_norm == "scale_shift": 242 | scale, shift = torch.chunk(temb, 2, dim=1) 243 | hidden_states = hidden_states * (1 + scale) + shift 244 | 245 | hidden_states = self.nonlinearity(hidden_states) 246 | 247 | hidden_states = self.dropout(hidden_states) 248 | hidden_states = self.conv2(hidden_states) 249 | 250 | if self.conv_shortcut is not None: 251 | input_tensor = self.conv_shortcut(input_tensor) 252 | 253 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 254 | 255 | return output_tensor 256 | 257 | 258 | class Mish(torch.nn.Module): 259 | def forward(self, hidden_states): 260 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 261 | -------------------------------------------------------------------------------- /freebloom/models/unet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import os 7 | import json 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.checkpoint 12 | 13 | from diffusers.configuration_utils import ConfigMixin, register_to_config 14 | from diffusers.modeling_utils import ModelMixin 15 | from diffusers.utils import BaseOutput, logging 16 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 17 | from .unet_blocks import ( 18 | CrossAttnDownBlock3D, 19 | CrossAttnUpBlock3D, 20 | DownBlock3D, 21 | UNetMidBlock3DCrossAttn, 22 | UpBlock3D, 23 | get_down_block, 24 | get_up_block, 25 | ) 26 | from .resnet import InflatedConv3d, InflatedGroupNorm 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | @dataclass 32 | class UNet3DConditionOutput(BaseOutput): 33 | sample: torch.FloatTensor 34 | 35 | 36 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 37 | _supports_gradient_checkpointing = True 38 | 39 | @register_to_config 40 | def __init__( 41 | self, 
42 | sample_size: Optional[int] = None, 43 | in_channels: int = 4, 44 | out_channels: int = 4, 45 | center_input_sample: bool = False, 46 | flip_sin_to_cos: bool = True, 47 | freq_shift: int = 0, 48 | down_block_types: Tuple[str] = ( 49 | "CrossAttnDownBlock3D", 50 | "CrossAttnDownBlock3D", 51 | "CrossAttnDownBlock3D", 52 | "DownBlock3D", 53 | ), 54 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 55 | up_block_types: Tuple[str] = ( 56 | "UpBlock3D", 57 | "CrossAttnUpBlock3D", 58 | "CrossAttnUpBlock3D", 59 | "CrossAttnUpBlock3D" 60 | ), 61 | only_cross_attention: Union[bool, Tuple[bool]] = False, 62 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 63 | layers_per_block: int = 2, 64 | downsample_padding: int = 1, 65 | mid_block_scale_factor: float = 1, 66 | act_fn: str = "silu", 67 | norm_num_groups: int = 32, 68 | norm_eps: float = 1e-5, 69 | cross_attention_dim: int = 1280, 70 | attention_head_dim: Union[int, Tuple[int]] = 8, 71 | dual_cross_attention: bool = False, 72 | use_linear_projection: bool = False, 73 | class_embed_type: Optional[str] = None, 74 | num_class_embeds: Optional[int] = None, 75 | upcast_attention: bool = False, 76 | resnet_time_scale_shift: str = "default", 77 | ): 78 | super().__init__() 79 | 80 | self.sample_size = sample_size 81 | time_embed_dim = block_out_channels[0] * 4 82 | 83 | # input 84 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) 85 | 86 | # time 87 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 88 | timestep_input_dim = block_out_channels[0] 89 | 90 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 91 | 92 | # class embedding 93 | if class_embed_type is None and num_class_embeds is not None: 94 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 95 | elif class_embed_type == "timestep": 96 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 97 | elif class_embed_type == "identity": 98 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 99 | else: 100 | self.class_embedding = None 101 | 102 | self.down_blocks = nn.ModuleList([]) 103 | self.mid_block = None 104 | self.up_blocks = nn.ModuleList([]) 105 | 106 | if isinstance(only_cross_attention, bool): 107 | only_cross_attention = [only_cross_attention] * len(down_block_types) 108 | 109 | if isinstance(attention_head_dim, int): 110 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 111 | 112 | # down 113 | output_channel = block_out_channels[0] 114 | for i, down_block_type in enumerate(down_block_types): 115 | input_channel = output_channel 116 | output_channel = block_out_channels[i] 117 | is_final_block = i == len(block_out_channels) - 1 118 | 119 | down_block = get_down_block( 120 | down_block_type, 121 | num_layers=layers_per_block, 122 | in_channels=input_channel, 123 | out_channels=output_channel, 124 | temb_channels=time_embed_dim, 125 | add_downsample=not is_final_block, 126 | resnet_eps=norm_eps, 127 | resnet_act_fn=act_fn, 128 | resnet_groups=norm_num_groups, 129 | cross_attention_dim=cross_attention_dim, 130 | attn_num_head_channels=attention_head_dim[i], 131 | downsample_padding=downsample_padding, 132 | dual_cross_attention=dual_cross_attention, 133 | use_linear_projection=use_linear_projection, 134 | only_cross_attention=only_cross_attention[i], 135 | upcast_attention=upcast_attention, 136 | resnet_time_scale_shift=resnet_time_scale_shift, 137 | ) 138 | self.down_blocks.append(down_block) 
139 | 140 | # mid 141 | if mid_block_type == "UNetMidBlock3DCrossAttn": 142 | self.mid_block = UNetMidBlock3DCrossAttn( 143 | in_channels=block_out_channels[-1], 144 | temb_channels=time_embed_dim, 145 | resnet_eps=norm_eps, 146 | resnet_act_fn=act_fn, 147 | output_scale_factor=mid_block_scale_factor, 148 | resnet_time_scale_shift=resnet_time_scale_shift, 149 | cross_attention_dim=cross_attention_dim, 150 | attn_num_head_channels=attention_head_dim[-1], 151 | resnet_groups=norm_num_groups, 152 | dual_cross_attention=dual_cross_attention, 153 | use_linear_projection=use_linear_projection, 154 | upcast_attention=upcast_attention, 155 | ) 156 | else: 157 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 158 | 159 | # count how many layers upsample the videos 160 | self.num_upsamplers = 0 161 | 162 | # up 163 | reversed_block_out_channels = list(reversed(block_out_channels)) 164 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 165 | only_cross_attention = list(reversed(only_cross_attention)) 166 | output_channel = reversed_block_out_channels[0] 167 | for i, up_block_type in enumerate(up_block_types): 168 | is_final_block = i == len(block_out_channels) - 1 169 | 170 | prev_output_channel = output_channel 171 | output_channel = reversed_block_out_channels[i] 172 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 173 | 174 | # add upsample block for all BUT final layer 175 | if not is_final_block: 176 | add_upsample = True 177 | self.num_upsamplers += 1 178 | else: 179 | add_upsample = False 180 | 181 | up_block = get_up_block( 182 | up_block_type, 183 | num_layers=layers_per_block + 1, 184 | in_channels=input_channel, 185 | out_channels=output_channel, 186 | prev_output_channel=prev_output_channel, 187 | temb_channels=time_embed_dim, 188 | add_upsample=add_upsample, 189 | resnet_eps=norm_eps, 190 | resnet_act_fn=act_fn, 191 | resnet_groups=norm_num_groups, 192 | cross_attention_dim=cross_attention_dim, 193 | attn_num_head_channels=reversed_attention_head_dim[i], 194 | dual_cross_attention=dual_cross_attention, 195 | use_linear_projection=use_linear_projection, 196 | only_cross_attention=only_cross_attention[i], 197 | upcast_attention=upcast_attention, 198 | resnet_time_scale_shift=resnet_time_scale_shift, 199 | ) 200 | self.up_blocks.append(up_block) 201 | prev_output_channel = output_channel 202 | 203 | # out 204 | self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, 205 | eps=norm_eps) 206 | # self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) 207 | 208 | self.conv_act = nn.SiLU() 209 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) 210 | 211 | def set_attention_slice(self, slice_size): 212 | r""" 213 | Enable sliced attention computation. 214 | 215 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 216 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 217 | 218 | Args: 219 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 220 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 221 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 222 | provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim` 223 | must be a multiple of `slice_size`. 224 | """ 225 | sliceable_head_dims = [] 226 | 227 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 228 | if hasattr(module, "set_attention_slice"): 229 | sliceable_head_dims.append(module.sliceable_head_dim) 230 | 231 | for child in module.children(): 232 | fn_recursive_retrieve_slicable_dims(child) 233 | 234 | # retrieve number of attention layers 235 | for module in self.children(): 236 | fn_recursive_retrieve_slicable_dims(module) 237 | 238 | num_slicable_layers = len(sliceable_head_dims) 239 | 240 | if slice_size == "auto": 241 | # half the attention head size is usually a good trade-off between 242 | # speed and memory 243 | slice_size = [dim // 2 for dim in sliceable_head_dims] 244 | elif slice_size == "max": 245 | # make smallest slice possible 246 | slice_size = num_slicable_layers * [1] 247 | 248 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 249 | 250 | if len(slice_size) != len(sliceable_head_dims): 251 | raise ValueError( 252 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 253 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 254 | ) 255 | 256 | for i in range(len(slice_size)): 257 | size = slice_size[i] 258 | dim = sliceable_head_dims[i] 259 | if size is not None and size > dim: 260 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 261 | 262 | # Recursively walk through all the children. 263 | # Any children which exposes the set_attention_slice method 264 | # gets the message 265 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 266 | if hasattr(module, "set_attention_slice"): 267 | module.set_attention_slice(slice_size.pop()) 268 | 269 | for child in module.children(): 270 | fn_recursive_set_attention_slice(child, slice_size) 271 | 272 | reversed_slice_size = list(reversed(slice_size)) 273 | for module in self.children(): 274 | fn_recursive_set_attention_slice(module, reversed_slice_size) 275 | 276 | def _set_gradient_checkpointing(self, module, value=False): 277 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): 278 | module.gradient_checkpointing = value 279 | 280 | def forward( 281 | self, 282 | sample: torch.FloatTensor, 283 | timestep: Union[torch.Tensor, float, int], 284 | encoder_hidden_states: torch.Tensor, 285 | class_labels: Optional[torch.Tensor] = None, 286 | attention_mask: Optional[torch.Tensor] = None, 287 | return_dict: bool = True, 288 | ) -> Union[UNet3DConditionOutput, Tuple]: 289 | r""" 290 | Args: 291 | sample (`torch.FloatTensor`): (batch, channel, f, height, width) noisy inputs tensor 292 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 293 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 294 | return_dict (`bool`, *optional*, defaults to `True`): 295 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 296 | 297 | Returns: 298 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 299 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 300 | returning a tuple, the first element is the sample tensor. 301 | """ 302 | # By default samples have to be AT least a multiple of the overall upsampling factor. 
303 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 304 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 305 | # on the fly if necessary. 306 | default_overall_up_factor = 2 ** self.num_upsamplers 307 | 308 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 309 | forward_upsample_size = False 310 | upsample_size = None 311 | 312 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 313 | logger.info("Forward upsample size to force interpolation output size.") 314 | forward_upsample_size = True 315 | 316 | # prepare attention_mask 317 | if attention_mask is not None: 318 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 319 | attention_mask = attention_mask.unsqueeze(1) 320 | 321 | # center input if necessary 322 | if self.config.center_input_sample: 323 | sample = 2 * sample - 1.0 324 | 325 | # time 326 | timesteps = timestep 327 | if not torch.is_tensor(timesteps): 328 | # This would be a good case for the `match` statement (Python 3.10+) 329 | is_mps = sample.device.type == "mps" 330 | if isinstance(timestep, float): 331 | dtype = torch.float32 if is_mps else torch.float64 332 | else: 333 | dtype = torch.int32 if is_mps else torch.int64 334 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 335 | elif len(timesteps.shape) == 0: 336 | timesteps = timesteps[None].to(sample.device) 337 | 338 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 339 | timesteps = timesteps.expand(sample.shape[0]) 340 | 341 | t_emb = self.time_proj(timesteps) 342 | 343 | # timesteps does not contain any weights and will always return f32 tensors 344 | # but time_embedding might actually be running in fp16. so we need to cast here. 345 | # there might be better ways to encapsulate this.
346 | t_emb = t_emb.to(dtype=self.dtype) 347 | emb = self.time_embedding(t_emb) 348 | 349 | if self.class_embedding is not None: 350 | if class_labels is None: 351 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 352 | 353 | if self.config.class_embed_type == "timestep": 354 | class_labels = self.time_proj(class_labels) 355 | 356 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 357 | emb = emb + class_emb 358 | 359 | # pre-process 360 | sample = self.conv_in(sample) 361 | 362 | # down 363 | down_block_res_samples = (sample,) 364 | for downsample_block in self.down_blocks: 365 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 366 | sample, res_samples = downsample_block( 367 | hidden_states=sample, 368 | temb=emb, 369 | encoder_hidden_states=encoder_hidden_states, 370 | attention_mask=attention_mask, 371 | ) 372 | else: 373 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 374 | 375 | down_block_res_samples += res_samples 376 | 377 | # mid 378 | sample = self.mid_block( 379 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 380 | ) 381 | # if sample.shape[2] == 7 and timestep > 500: 382 | # for i in range(5): 383 | # lambda_1 = (5 - i) / 5 384 | # lambda_2 = i / 5 385 | # sample[:, :, i + 1] = lambda_1 * sample[:, :, 0] + lambda_2 * sample[:, :, -1] 386 | # sample[:, :, 1] = 0.5 * sample[:, :, 0] + 0.5 * sample[:, :, 2] 387 | 388 | # up 389 | for i, upsample_block in enumerate(self.up_blocks): 390 | is_final_block = i == len(self.up_blocks) - 1 391 | 392 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 393 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 394 | 395 | # if we have not reached the final block and need to forward the 396 | # upsample size, we do it here 397 | if not is_final_block and forward_upsample_size: 398 | upsample_size = down_block_res_samples[-1].shape[2:] 399 | 400 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 401 | sample = upsample_block( 402 | hidden_states=sample, 403 | temb=emb, 404 | res_hidden_states_tuple=res_samples, 405 | encoder_hidden_states=encoder_hidden_states, 406 | upsample_size=upsample_size, 407 | attention_mask=attention_mask, 408 | ) 409 | else: 410 | sample = upsample_block( 411 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 412 | ) 413 | # post-process 414 | sample = self.conv_norm_out(sample) 415 | sample = self.conv_act(sample) 416 | sample = self.conv_out(sample) 417 | 418 | if not return_dict: 419 | return (sample,) 420 | 421 | return UNet3DConditionOutput(sample=sample) 422 | 423 | @classmethod 424 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): 425 | if subfolder is not None: 426 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 427 | 428 | config_file = os.path.join(pretrained_model_path, 'config.json') 429 | if not os.path.isfile(config_file): 430 | raise RuntimeError(f"{config_file} does not exist") 431 | with open(config_file, "r") as f: 432 | config = json.load(f) 433 | config["_class_name"] = cls.__name__ 434 | config["down_block_types"] = [ 435 | "CrossAttnDownBlock3D", 436 | "CrossAttnDownBlock3D", 437 | "CrossAttnDownBlock3D", 438 | "DownBlock3D" 439 | ] 440 | config["up_block_types"] = [ 441 | "UpBlock3D", 442 | "CrossAttnUpBlock3D", 443 | "CrossAttnUpBlock3D", 444 | 
"CrossAttnUpBlock3D" 445 | ] 446 | 447 | config['mid_block_type'] = 'UNetMidBlock3DCrossAttn' 448 | 449 | from diffusers.utils import WEIGHTS_NAME 450 | model = cls.from_config(config) 451 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 452 | if not os.path.isfile(model_file): 453 | raise RuntimeError(f"{model_file} does not exist") 454 | state_dict = torch.load(model_file, map_location="cpu") 455 | # origin_state_dict = torch.load('/root/code/Tune-A-Video/checkpoints/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin', map_location='cpu') 456 | for k, v in model.state_dict().items(): 457 | if '_temp.' in k: 458 | state_dict.update({k: v}) 459 | 460 | # for k, v in origin_state_dict.items(): 461 | # if '.to_q' in k and 'attn1' in k: 462 | # state_dict.update({k.replace('to_q', 'to_q2'): v}) 463 | 464 | model.load_state_dict(state_dict) 465 | 466 | return model 467 | -------------------------------------------------------------------------------- /freebloom/models/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .attention import Transformer3DModel 7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 8 | 9 | 10 | def get_down_block( 11 | down_block_type, 12 | num_layers, 13 | in_channels, 14 | out_channels, 15 | temb_channels, 16 | add_downsample, 17 | resnet_eps, 18 | resnet_act_fn, 19 | attn_num_head_channels, 20 | resnet_groups=None, 21 | cross_attention_dim=None, 22 | downsample_padding=None, 23 | dual_cross_attention=False, 24 | use_linear_projection=False, 25 | only_cross_attention=False, 26 | upcast_attention=False, 27 | resnet_time_scale_shift="default", 28 | ): 29 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 30 | if down_block_type == "DownBlock3D": 31 | return DownBlock3D( 32 | num_layers=num_layers, 33 | in_channels=in_channels, 34 | out_channels=out_channels, 35 | temb_channels=temb_channels, 36 | add_downsample=add_downsample, 37 | resnet_eps=resnet_eps, 38 | resnet_act_fn=resnet_act_fn, 39 | resnet_groups=resnet_groups, 40 | downsample_padding=downsample_padding, 41 | resnet_time_scale_shift=resnet_time_scale_shift, 42 | ) 43 | elif down_block_type == "CrossAttnDownBlock3D": 44 | if cross_attention_dim is None: 45 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 46 | return CrossAttnDownBlock3D( 47 | num_layers=num_layers, 48 | in_channels=in_channels, 49 | out_channels=out_channels, 50 | temb_channels=temb_channels, 51 | add_downsample=add_downsample, 52 | resnet_eps=resnet_eps, 53 | resnet_act_fn=resnet_act_fn, 54 | resnet_groups=resnet_groups, 55 | downsample_padding=downsample_padding, 56 | cross_attention_dim=cross_attention_dim, 57 | attn_num_head_channels=attn_num_head_channels, 58 | dual_cross_attention=dual_cross_attention, 59 | use_linear_projection=use_linear_projection, 60 | only_cross_attention=only_cross_attention, 61 | upcast_attention=upcast_attention, 62 | resnet_time_scale_shift=resnet_time_scale_shift, 63 | ) 64 | raise ValueError(f"{down_block_type} does not exist.") 65 | 66 | 67 | def get_up_block( 68 | up_block_type, 69 | num_layers, 70 | in_channels, 71 | out_channels, 72 | prev_output_channel, 73 | temb_channels, 74 | add_upsample, 75 | resnet_eps, 76 | resnet_act_fn, 77 | attn_num_head_channels, 78 | 
resnet_groups=None, 79 | cross_attention_dim=None, 80 | dual_cross_attention=False, 81 | use_linear_projection=False, 82 | only_cross_attention=False, 83 | upcast_attention=False, 84 | resnet_time_scale_shift="default", 85 | ): 86 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 87 | if up_block_type == "UpBlock3D": 88 | return UpBlock3D( 89 | num_layers=num_layers, 90 | in_channels=in_channels, 91 | out_channels=out_channels, 92 | prev_output_channel=prev_output_channel, 93 | temb_channels=temb_channels, 94 | add_upsample=add_upsample, 95 | resnet_eps=resnet_eps, 96 | resnet_act_fn=resnet_act_fn, 97 | resnet_groups=resnet_groups, 98 | resnet_time_scale_shift=resnet_time_scale_shift, 99 | ) 100 | elif up_block_type == "CrossAttnUpBlock3D": 101 | if cross_attention_dim is None: 102 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") 103 | return CrossAttnUpBlock3D( 104 | num_layers=num_layers, 105 | in_channels=in_channels, 106 | out_channels=out_channels, 107 | prev_output_channel=prev_output_channel, 108 | temb_channels=temb_channels, 109 | add_upsample=add_upsample, 110 | resnet_eps=resnet_eps, 111 | resnet_act_fn=resnet_act_fn, 112 | resnet_groups=resnet_groups, 113 | cross_attention_dim=cross_attention_dim, 114 | attn_num_head_channels=attn_num_head_channels, 115 | dual_cross_attention=dual_cross_attention, 116 | use_linear_projection=use_linear_projection, 117 | only_cross_attention=only_cross_attention, 118 | upcast_attention=upcast_attention, 119 | resnet_time_scale_shift=resnet_time_scale_shift, 120 | ) 121 | raise ValueError(f"{up_block_type} does not exist.") 122 | 123 | 124 | class UNetMidBlock3DCrossAttn(nn.Module): 125 | def __init__( 126 | self, 127 | in_channels: int, 128 | temb_channels: int, 129 | dropout: float = 0.0, 130 | num_layers: int = 1, 131 | resnet_eps: float = 1e-6, 132 | resnet_time_scale_shift: str = "default", 133 | resnet_act_fn: str = "swish", 134 | resnet_groups: int = 32, 135 | resnet_pre_norm: bool = True, 136 | attn_num_head_channels=1, 137 | output_scale_factor=1.0, 138 | cross_attention_dim=1280, 139 | dual_cross_attention=False, 140 | use_linear_projection=False, 141 | upcast_attention=False, 142 | ): 143 | super().__init__() 144 | 145 | self.has_cross_attention = True 146 | self.attn_num_head_channels = attn_num_head_channels 147 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 148 | 149 | # there is always at least one resnet 150 | resnets = [ 151 | ResnetBlock3D( 152 | in_channels=in_channels, 153 | out_channels=in_channels, 154 | temb_channels=temb_channels, 155 | eps=resnet_eps, 156 | groups=resnet_groups, 157 | dropout=dropout, 158 | time_embedding_norm=resnet_time_scale_shift, 159 | non_linearity=resnet_act_fn, 160 | output_scale_factor=output_scale_factor, 161 | pre_norm=resnet_pre_norm, 162 | ) 163 | ] 164 | attentions = [] 165 | 166 | for _ in range(num_layers): 167 | if dual_cross_attention: 168 | raise NotImplementedError 169 | attentions.append( 170 | Transformer3DModel( 171 | attn_num_head_channels, 172 | in_channels // attn_num_head_channels, 173 | in_channels=in_channels, 174 | num_layers=1, 175 | cross_attention_dim=cross_attention_dim, 176 | norm_num_groups=resnet_groups, 177 | use_linear_projection=use_linear_projection, 178 | upcast_attention=upcast_attention, 179 | ) 180 | ) 181 | resnets.append( 182 | ResnetBlock3D( 183 | in_channels=in_channels, 184 | out_channels=in_channels, 185 | temb_channels=temb_channels, 
186 | eps=resnet_eps, 187 | groups=resnet_groups, 188 | dropout=dropout, 189 | time_embedding_norm=resnet_time_scale_shift, 190 | non_linearity=resnet_act_fn, 191 | output_scale_factor=output_scale_factor, 192 | pre_norm=resnet_pre_norm, 193 | ) 194 | ) 195 | 196 | self.attentions = nn.ModuleList(attentions) 197 | self.resnets = nn.ModuleList(resnets) 198 | 199 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 200 | hidden_states = self.resnets[0](hidden_states, temb) 201 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 202 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 203 | hidden_states = resnet(hidden_states, temb) 204 | 205 | return hidden_states 206 | 207 | 208 | class CrossAttnDownBlock3D(nn.Module): 209 | def __init__( 210 | self, 211 | in_channels: int, 212 | out_channels: int, 213 | temb_channels: int, 214 | dropout: float = 0.0, 215 | num_layers: int = 1, 216 | resnet_eps: float = 1e-6, 217 | resnet_time_scale_shift: str = "default", 218 | resnet_act_fn: str = "swish", 219 | resnet_groups: int = 32, 220 | resnet_pre_norm: bool = True, 221 | attn_num_head_channels=1, 222 | cross_attention_dim=1280, 223 | output_scale_factor=1.0, 224 | downsample_padding=1, 225 | add_downsample=True, 226 | dual_cross_attention=False, 227 | use_linear_projection=False, 228 | only_cross_attention=False, 229 | upcast_attention=False, 230 | ): 231 | super().__init__() 232 | resnets = [] 233 | attentions = [] 234 | 235 | self.has_cross_attention = True 236 | self.attn_num_head_channels = attn_num_head_channels 237 | 238 | for i in range(num_layers): 239 | in_channels = in_channels if i == 0 else out_channels 240 | resnets.append( 241 | ResnetBlock3D( 242 | in_channels=in_channels, 243 | out_channels=out_channels, 244 | temb_channels=temb_channels, 245 | eps=resnet_eps, 246 | groups=resnet_groups, 247 | dropout=dropout, 248 | time_embedding_norm=resnet_time_scale_shift, 249 | non_linearity=resnet_act_fn, 250 | output_scale_factor=output_scale_factor, 251 | pre_norm=resnet_pre_norm, 252 | ) 253 | ) 254 | if dual_cross_attention: 255 | raise NotImplementedError 256 | attentions.append( 257 | Transformer3DModel( 258 | attn_num_head_channels, 259 | out_channels // attn_num_head_channels, 260 | in_channels=out_channels, 261 | num_layers=1, 262 | cross_attention_dim=cross_attention_dim, 263 | norm_num_groups=resnet_groups, 264 | use_linear_projection=use_linear_projection, 265 | only_cross_attention=only_cross_attention, 266 | upcast_attention=upcast_attention, 267 | ) 268 | ) 269 | self.attentions = nn.ModuleList(attentions) 270 | self.resnets = nn.ModuleList(resnets) 271 | 272 | if add_downsample: 273 | self.downsamplers = nn.ModuleList( 274 | [ 275 | Downsample3D( 276 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 277 | ) 278 | ] 279 | ) 280 | else: 281 | self.downsamplers = None 282 | 283 | self.gradient_checkpointing = False 284 | 285 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 286 | output_states = () 287 | 288 | for resnet, attn in zip(self.resnets, self.attentions): 289 | if self.training and self.gradient_checkpointing: 290 | 291 | def create_custom_forward(module, return_dict=None): 292 | def custom_forward(*inputs): 293 | if return_dict is not None: 294 | return module(*inputs, return_dict=return_dict) 295 | else: 296 | return module(*inputs) 297 | 298 | return custom_forward 299 | 300 | 
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 301 | hidden_states = torch.utils.checkpoint.checkpoint( 302 | create_custom_forward(attn, return_dict=False), 303 | hidden_states, 304 | encoder_hidden_states, 305 | )[0] 306 | else: 307 | hidden_states = resnet(hidden_states, temb) 308 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 309 | 310 | output_states += (hidden_states,) 311 | 312 | if self.downsamplers is not None: 313 | for downsampler in self.downsamplers: 314 | hidden_states = downsampler(hidden_states) 315 | 316 | output_states += (hidden_states,) 317 | 318 | return hidden_states, output_states 319 | 320 | 321 | class DownBlock3D(nn.Module): 322 | def __init__( 323 | self, 324 | in_channels: int, 325 | out_channels: int, 326 | temb_channels: int, 327 | dropout: float = 0.0, 328 | num_layers: int = 1, 329 | resnet_eps: float = 1e-6, 330 | resnet_time_scale_shift: str = "default", 331 | resnet_act_fn: str = "swish", 332 | resnet_groups: int = 32, 333 | resnet_pre_norm: bool = True, 334 | output_scale_factor=1.0, 335 | add_downsample=True, 336 | downsample_padding=1, 337 | ): 338 | super().__init__() 339 | resnets = [] 340 | 341 | for i in range(num_layers): 342 | in_channels = in_channels if i == 0 else out_channels 343 | resnets.append( 344 | ResnetBlock3D( 345 | in_channels=in_channels, 346 | out_channels=out_channels, 347 | temb_channels=temb_channels, 348 | eps=resnet_eps, 349 | groups=resnet_groups, 350 | dropout=dropout, 351 | time_embedding_norm=resnet_time_scale_shift, 352 | non_linearity=resnet_act_fn, 353 | output_scale_factor=output_scale_factor, 354 | pre_norm=resnet_pre_norm, 355 | ) 356 | ) 357 | 358 | self.resnets = nn.ModuleList(resnets) 359 | 360 | if add_downsample: 361 | self.downsamplers = nn.ModuleList( 362 | [ 363 | Downsample3D( 364 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 365 | ) 366 | ] 367 | ) 368 | else: 369 | self.downsamplers = None 370 | 371 | self.gradient_checkpointing = False 372 | 373 | def forward(self, hidden_states, temb=None): 374 | output_states = () 375 | 376 | for resnet in self.resnets: 377 | if self.training and self.gradient_checkpointing: 378 | 379 | def create_custom_forward(module): 380 | def custom_forward(*inputs): 381 | return module(*inputs) 382 | 383 | return custom_forward 384 | 385 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 386 | else: 387 | hidden_states = resnet(hidden_states, temb) 388 | 389 | output_states += (hidden_states,) 390 | 391 | if self.downsamplers is not None: 392 | for downsampler in self.downsamplers: 393 | hidden_states = downsampler(hidden_states) 394 | 395 | output_states += (hidden_states,) 396 | 397 | return hidden_states, output_states 398 | 399 | 400 | class CrossAttnUpBlock3D(nn.Module): 401 | def __init__( 402 | self, 403 | in_channels: int, 404 | out_channels: int, 405 | prev_output_channel: int, 406 | temb_channels: int, 407 | dropout: float = 0.0, 408 | num_layers: int = 1, 409 | resnet_eps: float = 1e-6, 410 | resnet_time_scale_shift: str = "default", 411 | resnet_act_fn: str = "swish", 412 | resnet_groups: int = 32, 413 | resnet_pre_norm: bool = True, 414 | attn_num_head_channels=1, 415 | cross_attention_dim=1280, 416 | output_scale_factor=1.0, 417 | add_upsample=True, 418 | dual_cross_attention=False, 419 | use_linear_projection=False, 420 | only_cross_attention=False, 421 | 
upcast_attention=False, 422 | ): 423 | super().__init__() 424 | resnets = [] 425 | attentions = [] 426 | 427 | self.has_cross_attention = True 428 | self.attn_num_head_channels = attn_num_head_channels 429 | 430 | for i in range(num_layers): 431 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 432 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 433 | 434 | resnets.append( 435 | ResnetBlock3D( 436 | in_channels=resnet_in_channels + res_skip_channels, 437 | out_channels=out_channels, 438 | temb_channels=temb_channels, 439 | eps=resnet_eps, 440 | groups=resnet_groups, 441 | dropout=dropout, 442 | time_embedding_norm=resnet_time_scale_shift, 443 | non_linearity=resnet_act_fn, 444 | output_scale_factor=output_scale_factor, 445 | pre_norm=resnet_pre_norm, 446 | ) 447 | ) 448 | if dual_cross_attention: 449 | raise NotImplementedError 450 | attentions.append( 451 | Transformer3DModel( 452 | attn_num_head_channels, 453 | out_channels // attn_num_head_channels, 454 | in_channels=out_channels, 455 | num_layers=1, 456 | cross_attention_dim=cross_attention_dim, 457 | norm_num_groups=resnet_groups, 458 | use_linear_projection=use_linear_projection, 459 | only_cross_attention=only_cross_attention, 460 | upcast_attention=upcast_attention, 461 | ) 462 | ) 463 | 464 | self.attentions = nn.ModuleList(attentions) 465 | self.resnets = nn.ModuleList(resnets) 466 | 467 | if add_upsample: 468 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 469 | else: 470 | self.upsamplers = None 471 | 472 | self.gradient_checkpointing = False 473 | 474 | def forward( 475 | self, 476 | hidden_states, 477 | res_hidden_states_tuple, 478 | temb=None, 479 | encoder_hidden_states=None, 480 | upsample_size=None, 481 | attention_mask=None, 482 | ): 483 | for resnet, attn in zip(self.resnets, self.attentions): 484 | # pop res hidden states 485 | res_hidden_states = res_hidden_states_tuple[-1] 486 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 487 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 488 | 489 | if self.training and self.gradient_checkpointing: 490 | 491 | def create_custom_forward(module, return_dict=None): 492 | def custom_forward(*inputs): 493 | if return_dict is not None: 494 | return module(*inputs, return_dict=return_dict) 495 | else: 496 | return module(*inputs) 497 | 498 | return custom_forward 499 | 500 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 501 | hidden_states = torch.utils.checkpoint.checkpoint( 502 | create_custom_forward(attn, return_dict=False), 503 | hidden_states, 504 | encoder_hidden_states, 505 | )[0] 506 | else: 507 | hidden_states = resnet(hidden_states, temb) 508 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample 509 | 510 | if self.upsamplers is not None: 511 | for upsampler in self.upsamplers: 512 | hidden_states = upsampler(hidden_states, upsample_size) 513 | 514 | return hidden_states 515 | 516 | 517 | class UpBlock3D(nn.Module): 518 | def __init__( 519 | self, 520 | in_channels: int, 521 | prev_output_channel: int, 522 | out_channels: int, 523 | temb_channels: int, 524 | dropout: float = 0.0, 525 | num_layers: int = 1, 526 | resnet_eps: float = 1e-6, 527 | resnet_time_scale_shift: str = "default", 528 | resnet_act_fn: str = "swish", 529 | resnet_groups: int = 32, 530 | resnet_pre_norm: bool = True, 531 | output_scale_factor=1.0, 532 | add_upsample=True, 533 | ): 
):
534 | super().__init__() 535 | resnets = [] 536 | 537 | for i in range(num_layers): 538 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 539 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 540 | 541 | resnets.append( 542 | ResnetBlock3D( 543 | in_channels=resnet_in_channels + res_skip_channels, 544 | out_channels=out_channels, 545 | temb_channels=temb_channels, 546 | eps=resnet_eps, 547 | groups=resnet_groups, 548 | dropout=dropout, 549 | time_embedding_norm=resnet_time_scale_shift, 550 | non_linearity=resnet_act_fn, 551 | output_scale_factor=output_scale_factor, 552 | pre_norm=resnet_pre_norm, 553 | ) 554 | ) 555 | 556 | self.resnets = nn.ModuleList(resnets) 557 | 558 | if add_upsample: 559 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 560 | else: 561 | self.upsamplers = None 562 | 563 | self.gradient_checkpointing = False 564 | 565 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 566 | for resnet in self.resnets: 567 | # pop res hidden states 568 | res_hidden_states = res_hidden_states_tuple[-1] 569 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 570 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 571 | 572 | if self.training and self.gradient_checkpointing: 573 | 574 | def create_custom_forward(module): 575 | def custom_forward(*inputs): 576 | return module(*inputs) 577 | 578 | return custom_forward 579 | 580 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) 581 | else: 582 | hidden_states = resnet(hidden_states, temb) 583 | 584 | if self.upsamplers is not None: 585 | for upsampler in self.upsamplers: 586 | hidden_states = upsampler(hidden_states, upsample_size) 587 | 588 | return hidden_states 589 | -------------------------------------------------------------------------------- /freebloom/pipelines/pipeline_spatio_temporal.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py 2 | import copy 3 | import inspect 4 | import os.path 5 | from dataclasses import dataclass 6 | from typing import Callable, List, Optional, Union 7 | 8 | import numpy as np 9 | import torch 10 | import torchvision.utils 11 | from diffusers.configuration_utils import FrozenDict 12 | from diffusers.models import AutoencoderKL 13 | from diffusers.pipeline_utils import DiffusionPipeline 14 | from diffusers.schedulers import ( 15 | DDIMScheduler, 16 | DPMSolverMultistepScheduler, 17 | EulerAncestralDiscreteScheduler, 18 | EulerDiscreteScheduler, 19 | LMSDiscreteScheduler, 20 | PNDMScheduler, 21 | ) 22 | from diffusers.utils import deprecate, logging, BaseOutput 23 | from diffusers.utils import is_accelerate_available 24 | from einops import rearrange, repeat 25 | from packaging import version 26 | from transformers import CLIPTextModel, CLIPTokenizer 27 | 28 | from ..models.unet import UNet3DConditionModel 29 | from ..prompt_attention import attention_util, ptp_utils 30 | 31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 32 | 33 | 34 | @dataclass 35 | class SpatioTemporalPipelineOutput(BaseOutput): 36 | videos: Union[torch.Tensor, np.ndarray] 37 | 38 | 39 | class SpatioTemporalPipeline(DiffusionPipeline): 40 | _optional_components = [] 41 | 42 | def __init__( 43 | self, 44 | vae: 
AutoencoderKL, 45 | text_encoder: CLIPTextModel, 46 | tokenizer: CLIPTokenizer, 47 | unet: UNet3DConditionModel, 48 | scheduler: Union[ 49 | DDIMScheduler, 50 | PNDMScheduler, 51 | LMSDiscreteScheduler, 52 | EulerDiscreteScheduler, 53 | EulerAncestralDiscreteScheduler, 54 | DPMSolverMultistepScheduler, 55 | ], 56 | noise_scheduler: Union[ 57 | DDIMScheduler, 58 | PNDMScheduler, 59 | LMSDiscreteScheduler, 60 | EulerDiscreteScheduler, 61 | EulerAncestralDiscreteScheduler, 62 | DPMSolverMultistepScheduler, 63 | ] = None, 64 | disk_store: bool = False, 65 | config=None 66 | ): 67 | super().__init__() 68 | 69 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 70 | deprecation_message = ( 71 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" 72 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 73 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 74 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 75 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 76 | " file" 77 | ) 78 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 79 | new_config = dict(scheduler.config) 80 | new_config["steps_offset"] = 1 81 | scheduler._internal_dict = FrozenDict(new_config) 82 | 83 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 84 | deprecation_message = ( 85 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 86 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 87 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 88 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 89 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 90 | ) 91 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 92 | new_config = dict(scheduler.config) 93 | new_config["clip_sample"] = False 94 | scheduler._internal_dict = FrozenDict(new_config) 95 | 96 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 97 | version.parse(unet.config._diffusers_version).base_version 98 | ) < version.parse("0.9.0.dev0") 99 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 100 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 101 | deprecation_message = ( 102 | "The configuration file of the unet has set the default `sample_size` to smaller than" 103 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 104 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 105 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 106 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 107 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" 108 | " in the config might lead to incorrect results in future versions. 
If you have downloaded this" 109 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 110 | " the `unet/config.json` file" 111 | ) 112 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 113 | new_config = dict(unet.config) 114 | new_config["sample_size"] = 64 115 | unet._internal_dict = FrozenDict(new_config) 116 | 117 | self.register_modules( 118 | vae=vae, 119 | text_encoder=text_encoder, 120 | tokenizer=tokenizer, 121 | unet=unet, 122 | scheduler=scheduler, 123 | noise_scheduler=noise_scheduler 124 | ) 125 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 126 | 127 | self.controller = attention_util.AttentionTest(disk_store=disk_store, config=config) 128 | self.hyper_config = config 129 | 130 | def enable_vae_slicing(self): 131 | self.vae.enable_slicing() 132 | 133 | def disable_vae_slicing(self): 134 | self.vae.disable_slicing() 135 | 136 | def enable_sequential_cpu_offload(self, gpu_id=0): 137 | if is_accelerate_available(): 138 | from accelerate import cpu_offload 139 | else: 140 | raise ImportError("Please install accelerate via `pip install accelerate`") 141 | 142 | device = torch.device(f"cuda:{gpu_id}") 143 | 144 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 145 | if cpu_offloaded_model is not None: 146 | cpu_offload(cpu_offloaded_model, device) 147 | 148 | @property 149 | def _execution_device(self): 150 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): 151 | return self.device 152 | for module in self.unet.modules(): 153 | if ( 154 | hasattr(module, "_hf_hook") 155 | and hasattr(module._hf_hook, "execution_device") 156 | and module._hf_hook.execution_device is not None 157 | ): 158 | return torch.device(module._hf_hook.execution_device) 159 | return self.device 160 | 161 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): 162 | batch_size = len(prompt) if isinstance(prompt, list) else 1 163 | 164 | text_inputs = self.tokenizer( 165 | prompt, 166 | padding="max_length", 167 | max_length=self.tokenizer.model_max_length, 168 | truncation=True, 169 | return_tensors="pt", 170 | ) 171 | text_input_ids = text_inputs.input_ids 172 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 173 | 174 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): 175 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]) 176 | logger.warning( 177 | "The following part of your input was truncated because CLIP can only handle sequences up to" 178 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 179 | ) 180 | 181 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 182 | attention_mask = text_inputs.attention_mask.to(device) 183 | else: 184 | attention_mask = None 185 | 186 | text_embeddings = self.text_encoder( 187 | text_input_ids.to(device), 188 | attention_mask=attention_mask, 189 | ) 190 | text_embeddings = text_embeddings[0] 191 | 192 | # duplicate text embeddings for each generation per prompt, using mps friendly method 193 | bs_embed, seq_len, _ = text_embeddings.shape 194 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) 195 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) 196 | 197 | # get 
unconditional embeddings for classifier free guidance 198 | if do_classifier_free_guidance: 199 | uncond_tokens: List[str] 200 | if negative_prompt is None: 201 | uncond_tokens = [""] * batch_size 202 | elif type(prompt) is not type(negative_prompt): 203 | raise TypeError( 204 | f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=" 205 | f" {type(prompt)}." 206 | ) 207 | elif isinstance(negative_prompt, str): 208 | uncond_tokens = [negative_prompt] 209 | elif batch_size != len(negative_prompt): 210 | raise ValueError( 211 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 212 | f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches" 213 | " the batch size of `prompt`." 214 | ) 215 | else: 216 | uncond_tokens = negative_prompt 217 | 218 | max_length = text_input_ids.shape[-1] 219 | uncond_input = self.tokenizer( 220 | uncond_tokens, 221 | padding="max_length", 222 | max_length=max_length, 223 | truncation=True, 224 | return_tensors="pt", 225 | ) 226 | 227 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 228 | attention_mask = uncond_input.attention_mask.to(device) 229 | else: 230 | attention_mask = None 231 | 232 | uncond_embeddings = self.text_encoder( 233 | uncond_input.input_ids.to(device), 234 | attention_mask=attention_mask, 235 | ) 236 | uncond_embeddings = uncond_embeddings[0] 237 | 238 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 239 | seq_len = uncond_embeddings.shape[1] 240 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) 241 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) 242 | 243 | # For classifier free guidance, we need to do two forward passes. 244 | # Here we concatenate the unconditional and text embeddings into a single batch 245 | # to avoid doing two forward passes 246 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 247 | 248 | return text_embeddings 249 | 250 | def decode_latents(self, latents, return_tensor=False): 251 | video_length = latents.shape[2] 252 | latents = 1 / 0.18215 * latents 253 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 254 | video = self.vae.decode(latents).sample 255 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 256 | video = (video / 2 + 0.5).clamp(0, 1) 257 | if return_tensor: 258 | return video 259 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 260 | video = video.cpu().float().numpy() 261 | return video 262 | 263 | def prepare_extra_step_kwargs(self, generator, eta): 264 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 265 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
266 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 267 | # and should be between [0, 1] 268 | 269 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 270 | extra_step_kwargs = {} 271 | if accepts_eta: 272 | extra_step_kwargs["eta"] = eta 273 | 274 | # check if the scheduler accepts generator 275 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 276 | if accepts_generator: 277 | extra_step_kwargs["generator"] = generator 278 | return extra_step_kwargs 279 | 280 | def check_inputs(self, prompt, height, width, callback_steps, latents): 281 | if not isinstance(prompt, str) and not isinstance(prompt, list): 282 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 283 | 284 | if height % 8 != 0 or width % 8 != 0: 285 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 286 | 287 | if (callback_steps is None) or ( 288 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 289 | ): 290 | raise ValueError( 291 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 292 | f" {type(callback_steps)}." 293 | ) 294 | if isinstance(prompt, list) and latents is not None and len(prompt) != latents.shape[2]: 295 | raise ValueError( 296 | f"`prompt` is a list but does not match all frames. frames length: {latents.shape[2]}, prompt length: {len(prompt)}") 297 | 298 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, 299 | latents=None, store_attention=True, frame_same_noise=False): 300 | if store_attention and self.controller is not None: 301 | ptp_utils.register_attention_control(self, self.controller) 302 | 303 | shape = ( 304 | batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, 305 | width // self.vae_scale_factor) 306 | if frame_same_noise: 307 | shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor, 308 | width // self.vae_scale_factor) 309 | if isinstance(generator, list) and len(generator) != batch_size: 310 | raise ValueError( 311 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 312 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
313 | ) 314 | 315 | if latents is None: 316 | rand_device = "cpu" if device.type == "mps" else device 317 | 318 | if isinstance(generator, list): 319 | shape = (1,) + shape[1:] 320 | latents = [ 321 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) 322 | for i in range(batch_size) 323 | ] 324 | if frame_same_noise: 325 | latents = [repeat(latent, 'b c 1 h w -> b c f h w', f=video_length) for latent in latents] 326 | latents = torch.cat(latents, dim=0).to(device) 327 | else: 328 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) 329 | if frame_same_noise: 330 | latents = repeat(latents, 'b c 1 h w -> b c f h w', f=video_length) 331 | 332 | else: 333 | if latents.shape != shape: 334 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") 335 | latents = latents.to(device) 336 | 337 | # scale the initial noise by the standard deviation required by the scheduler 338 | latents = latents * self.scheduler.init_noise_sigma 339 | return latents 340 | 341 | @torch.no_grad() 342 | def __call__( 343 | self, 344 | prompt: Union[str, List[str]], 345 | video_length: Optional[int], 346 | height: Optional[int] = None, 347 | width: Optional[int] = None, 348 | num_inference_steps: int = 50, 349 | guidance_scale: float = 7.5, 350 | negative_prompt: Optional[Union[str, List[str]]] = None, 351 | num_videos_per_prompt: Optional[int] = 1, 352 | eta: float = 0.0, 353 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 354 | latents: Optional[torch.FloatTensor] = None, 355 | output_type: Optional[str] = "tensor", 356 | return_dict: bool = True, 357 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 358 | callback_steps: Optional[int] = 1, 359 | fixed_latents: Optional[dict[torch.FloatTensor]] = None, 360 | fixed_latents_idx: list = None, 361 | inner_idx: list = None, 362 | init_text_embedding: torch.Tensor = None, 363 | mask: Optional[dict] = None, 364 | save_vis_inner=False, 365 | return_tensor=False, 366 | return_text_embedding=False, 367 | output_dir=None, 368 | **kwargs, 369 | ): 370 | if self.controller is not None: 371 | self.controller.reset() 372 | self.controller.batch_size = video_length 373 | 374 | # Default height and width to unet 375 | height = height or self.unet.config.sample_size * self.vae_scale_factor 376 | width = width or self.unet.config.sample_size * self.vae_scale_factor 377 | 378 | # Check inputs. Raise error if not correct 379 | self.check_inputs(prompt, height, width, callback_steps, latents) 380 | 381 | # Define call parameters 382 | # batch_size = 1 if isinstance(prompt, str) else len(prompt) 383 | batch_size = 1 384 | device = self._execution_device 385 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 386 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 387 | # corresponds to doing no classifier free guidance. 
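# Note: concretely, the combination applied later in the denoising loop is
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# so guidance_scale > 1.0 pushes the prediction towards the text-conditioned branch.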
388 | do_classifier_free_guidance = guidance_scale > 1.0 389 | 390 | # Encode input prompt 391 | text_embeddings = self._encode_prompt( 392 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt 393 | ) 394 | 395 | if init_text_embedding is not None: 396 | text_embeddings[video_length:] = init_text_embedding 397 | 398 | # Prepare timesteps 399 | self.scheduler.set_timesteps(num_inference_steps, device=device) 400 | timesteps = self.scheduler.timesteps 401 | 402 | # Prepare latent variables 403 | num_channels_latents = self.unet.in_channels 404 | latents = self.prepare_latents( 405 | batch_size * num_videos_per_prompt, 406 | num_channels_latents, 407 | video_length, 408 | height, 409 | width, 410 | text_embeddings.dtype, 411 | device, 412 | generator, 413 | latents, 414 | ) 415 | latents_dtype = latents.dtype 416 | 417 | # Prepare extra step kwargs. 418 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 419 | 420 | # Denoising loop 421 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 422 | with self.progress_bar(total=num_inference_steps) as progress_bar: 423 | for i, t in enumerate(timesteps): 424 | # expand the latents if we are doing classifier free guidance 425 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 426 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 427 | 428 | # predict the noise residual 429 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to( 430 | dtype=latents_dtype) 431 | 432 | # perform guidance 433 | if do_classifier_free_guidance: 434 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 435 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 436 | 437 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 438 | 439 | if mask is not None and i < 10: 440 | f_1 = latents[:, :, :1] 441 | latents = mask[1] * f_1 + mask[0] * latents 442 | latents = repeat(latents, 'b c 1 h w -> b c f h w', f=video_length) 443 | latents = latents.to(latents_dtype) 444 | 445 | # call the callback, if provided 446 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 447 | progress_bar.update() 448 | if callback is not None and i % callback_steps == 0: 449 | callback(i, t, latents) 450 | 451 | if self.controller is not None: 452 | latents_old = latents 453 | dtype = latents.dtype 454 | latents_new = self.controller.step_callback(latents, inner_idx) 455 | latents = latents_new.to(dtype) 456 | 457 | if fixed_latents is not None: 458 | latents[:, :, fixed_latents_idx] = fixed_latents[i + 1] 459 | 460 | self.controller.empty_cache() 461 | 462 | if save_vis_inner: 463 | video = self.decode_latents(latents) 464 | video = torch.from_numpy(video) 465 | video = rearrange(video, 'b c f h w -> (b f) c h w') 466 | if os.path.exists(output_dir): 467 | os.makedirs(f'{output_dir}/inner', exist_ok=True) 468 | torchvision.utils.save_image(video, f'{output_dir}/inner/{i}.jpg') 469 | 470 | # Post-processing 471 | video = self.decode_latents(latents, return_tensor) 472 | 473 | # Convert to tensor 474 | if output_type == "tensor" and isinstance(video, np.ndarray): 475 | video = torch.from_numpy(video) 476 | 477 | if not return_dict: 478 | return video 479 | if return_text_embedding: 480 | return SpatioTemporalPipelineOutput(videos=video), text_embeddings[video_length:] 481 | 482 | return 
SpatioTemporalPipelineOutput(videos=video) 483 | 484 | def forward_to_t(self, latents_0, t, noise, text_embeddings): 485 | noisy_latents = self.noise_scheduler.add_noise(latents_0, noise, t) 486 | 487 | # Predict the noise residual and compute loss 488 | model_pred = self.unet(noisy_latents, t, text_embeddings).sample 489 | return model_pred 490 | 491 | def interpolate_between_two_frames( 492 | self, 493 | samples, 494 | x_noise, 495 | rand_ratio, 496 | where, 497 | prompt: Union[str, List[str]], 498 | k=3, 499 | height: Optional[int] = None, 500 | width: Optional[int] = None, 501 | num_inference_steps: int = 50, 502 | guidance_scale: float = 7.5, 503 | negative_prompt: Optional[Union[str, List[str]]] = None, 504 | num_videos_per_prompt: Optional[int] = 1, 505 | eta: float = 0.0, 506 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 507 | output_type: Optional[str] = "tensor", 508 | return_dict: bool = True, 509 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 510 | callback_steps: Optional[int] = 1, 511 | base_noise=None, 512 | text_embedding=None, 513 | **kwargs): 514 | if text_embedding is None: 515 | text_embeddings = self._encode_prompt(prompt, x_noise.device, num_videos_per_prompt, 0, None) 516 | else: 517 | text_embeddings = text_embedding 518 | prompt_temp = copy.deepcopy(prompt) 519 | 520 | times = 0 521 | k = k 522 | fixed_original_idx = where 523 | while times < k: 524 | if not isinstance(samples, torch.Tensor): 525 | samples = samples[0] 526 | 527 | inner_step = self.controller.latents_store 528 | if times == 0: 529 | for key, value in inner_step.items(): 530 | inner_step[key] = value[:, :, fixed_original_idx] 531 | 532 | b, c, f, _, __ = samples.shape 533 | dist_list = [] 534 | for i in range(f - 1): 535 | dist_list.append(torch.nn.functional.mse_loss(samples[:, :, i + 1], samples[:, :, i])) 536 | dist = torch.tensor(dist_list) 537 | num_insert = 1 538 | value, idx = torch.topk(dist, k=num_insert) 539 | tgt_video_length = num_insert + f 540 | 541 | idx = sorted(idx) 542 | tgt_idx = [id + 1 + i for i, id in enumerate(idx)] 543 | 544 | original_idx = [i for i in range(tgt_video_length) if i not in tgt_idx] 545 | prompt_after_inner = [] 546 | x_temp = [] 547 | text_embeddings_temp = [] 548 | prompt_i = 0 549 | for i in range(tgt_video_length): 550 | if i in tgt_idx: 551 | prompt_after_inner.append('') 552 | search_idx = sorted(original_idx + [i]) 553 | find = search_idx.index(i) 554 | pre_idx, next_idx = search_idx[find - 1], search_idx[find + 1] 555 | length = next_idx - pre_idx 556 | 557 | x_temp.append(np.cos(rand_ratio * np.pi / 2) * base_noise[:, :, 0] + np.sin( 558 | rand_ratio * np.pi / 2) * torch.randn_like(base_noise[:, :, 0])) 559 | text_embeddings_temp.append( 560 | (next_idx - i) / length * text_embeddings[prompt_i - 1] + (i - pre_idx) / length * 561 | text_embeddings[prompt_i]) 562 | else: 563 | prompt_after_inner.append(prompt_temp[prompt_i]) 564 | x_temp.append(x_noise[:, :, prompt_i]) 565 | text_embeddings_temp.append(text_embeddings[prompt_i]) 566 | prompt_i += 1 567 | 568 | inner_idx = tgt_idx 569 | prompt_temp = prompt_after_inner 570 | 571 | x_noise = torch.stack(x_temp, dim=2) 572 | text_embeddings = torch.stack(text_embeddings_temp, dim=0) 573 | 574 | samples, prompt_embedding = self(prompt_temp, 575 | tgt_video_length, 576 | height, 577 | width, 578 | num_inference_steps, 579 | guidance_scale, 580 | [negative_prompt] * tgt_video_length, 581 | num_videos_per_prompt, 582 | eta, 583 | generator, 584 | 
x_noise, 585 | output_type, 586 | return_dict, 587 | callback, 588 | callback_steps, 589 | init_text_embedding=text_embeddings, 590 | inner_idx=inner_idx, 591 | return_text_embedding=True, 592 | fixed_latents=copy.deepcopy(inner_step), 593 | fixed_latents_idx=original_idx, 594 | **kwargs) 595 | 596 | times += 1 597 | 598 | return samples[0] 599 | -------------------------------------------------------------------------------- /freebloom/prompt_attention/attention_util.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import copy 3 | import os 4 | from typing import Union, Tuple, Dict, Optional, List 5 | 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | from freebloom.prompt_attention import ptp_utils, seq_aligner 11 | from freebloom.prompt_attention.ptp_utils import get_time_string 12 | 13 | 14 | class AttentionControl(abc.ABC): 15 | 16 | def step_callback(self, x_t): 17 | self.cur_att_layer = 0 18 | self.cur_step += 1 19 | self.between_steps() 20 | return x_t 21 | 22 | def between_steps(self): 23 | return 24 | 25 | @property 26 | def num_uncond_att_layers(self): 27 | return self.num_att_layers if self.LOW_RESOURCE else 0 28 | 29 | @abc.abstractmethod 30 | def forward(self, attn, is_cross: bool, place_in_unet: str): 31 | raise NotImplementedError 32 | 33 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 34 | if self.cur_att_layer >= self.num_uncond_att_layers: 35 | if self.LOW_RESOURCE: 36 | attn = self.forward(attn, is_cross, place_in_unet) 37 | else: 38 | h = attn.shape[0] 39 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 40 | self.cur_att_layer += 1 41 | 42 | return attn 43 | 44 | def reset(self): 45 | self.cur_step = 0 46 | self.cur_att_layer = 0 47 | 48 | def __init__(self): 49 | self.LOW_RESOURCE = False 50 | self.cur_step = 0 51 | self.num_att_layers = -1 52 | self.cur_att_layer = 0 53 | 54 | 55 | class EmptyControl(AttentionControl): 56 | 57 | def forward(self, attn, is_cross: bool, place_in_unet: str): 58 | return attn 59 | 60 | 61 | class AttentionStore(AttentionControl): 62 | 63 | def step_callback(self, x_t): 64 | x_t = super().step_callback(x_t) 65 | self.latents_store[self.cur_step] = x_t 66 | return x_t 67 | 68 | @staticmethod 69 | def get_empty_store(): 70 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 71 | "down_self": [], "mid_self": [], "up_self": []} 72 | 73 | def forward(self, attn, is_cross: bool, place_in_unet: str): 74 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 75 | h = attn.shape[0] // self.batch_size 76 | if attn.shape[1] <= 32 ** 2 and is_cross: # avoid memory overhead 77 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 78 | self.step_store[key].append(attn) 79 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 80 | return attn 81 | 82 | def between_steps(self): 83 | if len(self.attention_store) == 0: 84 | self.attention_store = self.step_store 85 | else: 86 | for key in self.attention_store: 87 | for i in range(len(self.attention_store[key])): 88 | self.attention_store[key][i] += self.step_store[key][i] 89 | 90 | if self.disk_store: 91 | path = self.store_dir + f'/{self.cur_step:03d}_attn.pt' 92 | torch.save(copy.deepcopy(self.step_store), path) 93 | 94 | self.step_store = self.get_empty_store() 95 | 96 | def get_average_attention(self): 97 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in 98 | self.attention_store} 99 | return average_attention 
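# Note: `between_steps` accumulates each step's maps from `step_store` into `attention_store`,
# so dividing by `cur_step` here yields attention averaged over the denoising steps seen so far.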
100 | 101 | def reset(self): 102 | super(AttentionStore, self).reset() 103 | self.step_store = self.get_empty_store() 104 | self.attention_store = {} 105 | 106 | def empty_cache(self): 107 | self.step_store = self.get_empty_store() 108 | self.attention_store = {} 109 | 110 | def __init__(self, disk_store=False, config=None): 111 | super(AttentionStore, self).__init__() 112 | self.disk_store = disk_store 113 | if self.disk_store: 114 | time_string = get_time_string() 115 | path = f'./.temp/attention_cache_{time_string}' 116 | os.makedirs(path, exist_ok=True) 117 | self.store_dir = path 118 | else: 119 | self.store_dir = None 120 | 121 | if config: 122 | self.config = config 123 | 124 | self.latents_store = {} 125 | 126 | self.step_store = self.get_empty_store() 127 | self.attention_store = {} 128 | self.attention_type_former = config["validation_data"]["attention_type_former"] 129 | self.attention_type_latter = config["validation_data"]["attention_type_latter"] 130 | self.attention_adapt_step = config["validation_data"]["attention_adapt_step"] 131 | 132 | 133 | class AttentionControlEdit(AttentionStore, abc.ABC): 134 | 135 | def step_callback(self, x_t): 136 | if self.local_blend is not None: 137 | x_t = self.local_blend(x_t, self.attention_store) 138 | return x_t 139 | 140 | def replace_self_attention(self, attn_base, att_replace): 141 | if att_replace.shape[2] <= 16 ** 2: 142 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 143 | else: 144 | return att_replace 145 | 146 | @abc.abstractmethod 147 | def replace_cross_attention(self, attn_base, att_replace): 148 | raise NotImplementedError 149 | 150 | def forward(self, attn, is_cross: bool, place_in_unet: str): 151 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 152 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 153 | h = attn.shape[0] // (self.batch_size) 154 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 155 | attn_base, attn_repalce = attn[0], attn[1:] 156 | if is_cross: 157 | alpha_words = self.cross_replace_alpha[self.cur_step] 158 | attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + ( 159 | 1 - alpha_words) * attn_repalce 160 | attn[1:] = attn_repalce_new 161 | else: 162 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce) 163 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 164 | return attn 165 | 166 | def __init__(self, prompts, num_steps: int, 167 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 168 | self_replace_steps: Union[float, Tuple[float, float]], 169 | local_blend: None, 170 | tokenizer=None, 171 | device=torch.device('cuda') if torch.cuda.is_available() else torch.device( 172 | 'cpu')): # Optional[LocalBlend]): 173 | super(AttentionControlEdit, self).__init__() 174 | self.batch_size = len(prompts) 175 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, 176 | tokenizer).to(device) 177 | if type(self_replace_steps) is float: 178 | self_replace_steps = 0, self_replace_steps 179 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 180 | self.local_blend = local_blend 181 | 182 | 183 | class AttentionReplace(AttentionControlEdit): 184 | 185 | def replace_cross_attention(self, attn_base, att_replace): 186 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) 187 | 188 | def __init__(self, prompts, 
num_steps: int, cross_replace_steps: float, self_replace_steps: float, 189 | local_blend=None, tokenizer=None, 190 | device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')): 191 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, 192 | tokenizer=tokenizer, device=device) 193 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) 194 | 195 | 196 | class AttentionRefine(AttentionControlEdit): 197 | 198 | def replace_cross_attention(self, attn_base, att_replace): 199 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) 200 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 201 | return attn_replace 202 | 203 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 204 | local_blend=None, 205 | tokenizer=None, 206 | device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')): 207 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) 208 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) 209 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) 210 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 211 | 212 | 213 | class AttentionReweight(AttentionControlEdit): 214 | 215 | def replace_cross_attention(self, attn_base, att_replace): 216 | if self.prev_controller is not None: 217 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) 218 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 219 | return attn_replace 220 | 221 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, 222 | local_blend=None, controller: Optional[AttentionControlEdit] = None, 223 | device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')): 224 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, 225 | local_blend) 226 | self.equalizer = equalizer.to(device) 227 | self.prev_controller = controller 228 | 229 | 230 | class AttentionTest(AttentionStore): 231 | 232 | def step_callback(self, x_t, inner_idx=None): 233 | x_t = super(AttentionTest, self).step_callback(x_t) 234 | 235 | if inner_idx is None: 236 | return x_t 237 | 238 | b, c, f, h, w = x_t.shape 239 | 240 | momentum = 0.1 if self.cur_step <= self.config['inference_config']['interpolation_step'] else 1.0 241 | original_idx = [i for i in range(f) if i not in inner_idx] 242 | for idx in inner_idx: 243 | search_idx = sorted(original_idx + [idx]) 244 | find = search_idx.index(idx) 245 | pre_idx, next_idx = search_idx[find - 1], search_idx[find + 1] 246 | length = next_idx - pre_idx 247 | 248 | alpha = (idx - pre_idx) / length 249 | x_t[:, :, idx] = (1 - momentum) * ((next_idx - idx) / length * x_t[:, :, pre_idx] + ( 250 | idx - pre_idx) / length * x_t[:, :, next_idx]) + momentum * x_t[:, :, idx] 251 | 252 | return x_t 253 | 254 | def forward(self, attn, is_cross: bool, place_in_unet: str): 255 | super(AttentionTest, self).forward(attn, is_cross, place_in_unet) 256 | if is_cross: 257 | h = attn.shape[0] // (self.batch_size) 258 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 259 | attn_base, attn_repalce, attn_show = attn[0], attn[1:2], attn[2:3] 260 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 261 | return attn 262 | 
263 | 264 | def get_equalizer(text: str, 265 | word_select: Union[int, Tuple[int, ...]], 266 | values: Union[List[float], Tuple[float, ...]], 267 | tokenizer=None): 268 | if type(word_select) is int or type(word_select) is str: 269 | word_select = (word_select,) 270 | equalizer = torch.ones(len(values), 77) 271 | values = torch.tensor(values, dtype=torch.float32) 272 | for word in word_select: 273 | inds = ptp_utils.get_word_inds(text, word, tokenizer) 274 | equalizer[:, inds] = values 275 | return equalizer 276 | 277 | 278 | def aggregate_attention(prompts, 279 | attention_store: AttentionStore, 280 | res: int, 281 | from_where: List[str], 282 | is_cross: bool, 283 | select: int): 284 | out = [] 285 | attention_maps = attention_store.get_average_attention() 286 | num_pixels = res ** 2 287 | for location in from_where: 288 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 289 | if item.shape[1] == num_pixels: 290 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 291 | out.append(cross_maps) 292 | out = torch.cat(out, dim=0) 293 | out = out.sum(0) / out.shape[0] 294 | return out.cpu() 295 | 296 | 297 | def show_cross_attention(tokenizer, 298 | prompts, 299 | attention_store: AttentionStore, res: int, from_where: List[str], 300 | select: int = 0): 301 | tokens = tokenizer.encode(prompts[select]) 302 | decoder = tokenizer.decode 303 | attention_maps = aggregate_attention(attention_store, res, from_where, True, select) 304 | images = [] 305 | for i in range(len(tokens)): 306 | image = attention_maps[:, :, i] 307 | image = 255 * image / image.max() 308 | image = image.unsqueeze(-1).expand(*image.shape, 3) 309 | image = image.numpy().astype(np.uint8) 310 | image = np.array(Image.fromarray(image).resize((256, 256))) 311 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 312 | images.append(image) 313 | ptp_utils.view_images(np.stack(images, axis=0)) 314 | 315 | 316 | def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], 317 | max_com=10, select: int = 0): 318 | attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape( 319 | (res ** 2, res ** 2)) 320 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) 321 | images = [] 322 | for i in range(max_com): 323 | image = vh[i].reshape(res, res) 324 | image = image - image.min() 325 | image = 255 * image / image.max() 326 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) 327 | image = Image.fromarray(image).resize((256, 256)) 328 | image = np.array(image) 329 | images.append(image) 330 | ptp_utils.view_images(np.concatenate(images, axis=1)) 331 | -------------------------------------------------------------------------------- /freebloom/prompt_attention/ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | 16 | import cv2 17 | import datetime 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional as F 21 | import torchvision 22 | from PIL import Image 23 | from einops import rearrange, repeat 24 | from tqdm.notebook import tqdm 25 | from typing import Optional, Union, Tuple, List, Dict 26 | 27 | 28 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 29 | h, w, c = image.shape 30 | offset = int(h * .2) 31 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 32 | font = cv2.FONT_HERSHEY_SIMPLEX 33 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 34 | img[:h] = image 35 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 36 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 37 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 38 | return img 39 | 40 | 41 | def view_images(images, num_rows=1, offset_ratio=0.02, save_path=None): 42 | if type(images) is list: 43 | num_empty = len(images) % num_rows 44 | elif images.ndim == 4: 45 | num_empty = images.shape[0] % num_rows 46 | else: 47 | images = [images] 48 | num_empty = 0 49 | 50 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 51 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 52 | num_items = len(images) 53 | 54 | h, w, c = images[0].shape 55 | offset = int(h * offset_ratio) 56 | num_cols = num_items // num_rows 57 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 58 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 59 | for i in range(num_rows): 60 | for j in range(num_cols): 61 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 62 | i * num_cols + j] 63 | 64 | if save_path is not None: 65 | pil_img = Image.fromarray(image_) 66 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 67 | pil_img.save(f'{save_path}/{now}.png') 68 | # display(pil_img) 69 | 70 | 71 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False): 72 | if low_resource: 73 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 74 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 75 | else: 76 | latents_input = torch.cat([latents] * 2) 77 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 78 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 79 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 80 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 81 | latents = controller.step_callback(latents) 82 | return latents 83 | 84 | 85 | def latent2image(vae, latents): 86 | latents = 1 / 0.18215 * latents 87 | image = vae.decode(latents)['sample'] 88 | image = (image / 2 + 0.5).clamp(0, 1) 89 | image = image.cpu().permute(0, 2, 3, 1).numpy() 90 | image = (image * 255).astype(np.uint8) 91 | return image 92 | 93 | 94 | def init_latent(latent, model, height, width, generator, batch_size): 95 | if latent is None: 96 | latent = torch.randn( 97 | (1, model.unet.in_channels, height // 8, width // 8), 98 | generator=generator, 99 | ) 100 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 101 | 
return latent, latents 102 | 103 | 104 | @torch.no_grad() 105 | def text2image_ldm( 106 | model, 107 | prompt: List[str], 108 | controller, 109 | num_inference_steps: int = 50, 110 | guidance_scale: Optional[float] = 7., 111 | generator: Optional[torch.Generator] = None, 112 | latent: Optional[torch.FloatTensor] = None, 113 | ): 114 | register_attention_control(model, controller) 115 | height = width = 256 116 | batch_size = len(prompt) 117 | 118 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 119 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 120 | 121 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 122 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 123 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 124 | context = torch.cat([uncond_embeddings, text_embeddings]) 125 | 126 | model.scheduler.set_timesteps(num_inference_steps) 127 | for t in tqdm(model.scheduler.timesteps): 128 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 129 | 130 | image = latent2image(model.vqvae, latents) 131 | 132 | return image, latent 133 | 134 | 135 | @torch.no_grad() 136 | def text2image_ldm_stable( 137 | model, 138 | prompt: List[str], 139 | controller, 140 | num_inference_steps: int = 50, 141 | guidance_scale: float = 7.5, 142 | generator: Optional[torch.Generator] = None, 143 | latent: Optional[torch.FloatTensor] = None, 144 | low_resource: bool = False, 145 | ): 146 | register_attention_control(model, controller) 147 | height = width = 512 148 | batch_size = len(prompt) 149 | 150 | text_input = model.tokenizer( 151 | prompt, 152 | padding="max_length", 153 | max_length=model.tokenizer.model_max_length, 154 | truncation=True, 155 | return_tensors="pt", 156 | ) 157 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 158 | max_length = text_input.input_ids.shape[-1] 159 | uncond_input = model.tokenizer( 160 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 161 | ) 162 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 163 | 164 | context = [uncond_embeddings, text_embeddings] 165 | if not low_resource: 166 | context = torch.cat(context) 167 | latent, latents = init_latent(latent, model, height, width, generator, batch_size) 168 | 169 | # set timesteps 170 | extra_set_kwargs = {"offset": 1} 171 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 172 | for t in tqdm(model.scheduler.timesteps): 173 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 174 | 175 | image = latent2image(model.vae, latents) 176 | 177 | return image, latent 178 | 179 | 180 | def register_attention_control(model, controller): 181 | def ca_forward(self, place_in_unet, attention_type): 182 | to_out = self.to_out 183 | if type(to_out) is torch.nn.modules.container.ModuleList: 184 | to_out = self.to_out[0] 185 | else: 186 | to_out = self.to_out 187 | 188 | def _attention(query, key, value, is_cross, attention_mask=None): 189 | 190 | if self.upcast_attention: 191 | query = query.float() 192 | key = key.float() 193 | 194 | attention_scores = torch.baddbmm( 195 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), 196 | query, 197 | key.transpose(-1, -2), 198 | beta=0, 199 | alpha=self.scale, 200 | ) 
201 | 202 | if attention_mask is not None: 203 | attention_scores = attention_scores + attention_mask 204 | 205 | if self.upcast_softmax: 206 | attention_scores = attention_scores.float() 207 | 208 | attention_probs = attention_scores.softmax(dim=-1) 209 | 210 | # cast back to the original dtype 211 | attention_probs = attention_probs.to(value.dtype) 212 | 213 | attention_probs = controller(attention_probs, is_cross, place_in_unet) 214 | 215 | # compute attention output 216 | hidden_states = torch.bmm(attention_probs, value) 217 | 218 | # reshape hidden_states 219 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 220 | return hidden_states 221 | 222 | def reshape_heads_to_batch_dim(tensor): 223 | batch_size, seq_len, dim = tensor.shape 224 | head_size = self.heads 225 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 226 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) 227 | return tensor 228 | 229 | def reshape_batch_dim_to_heads(tensor): 230 | batch_size, seq_len, dim = tensor.shape 231 | head_size = self.heads 232 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 233 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 234 | return tensor 235 | 236 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None): 237 | is_cross = encoder_hidden_states is not None 238 | 239 | batch_size, sequence_length, _ = hidden_states.shape 240 | 241 | encoder_hidden_states = encoder_hidden_states 242 | 243 | if self.group_norm is not None: 244 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 245 | 246 | query = self.to_q(hidden_states) 247 | dim = query.shape[-1] 248 | query = self.reshape_heads_to_batch_dim(query) 249 | 250 | if self.added_kv_proj_dim is not None: 251 | key = self.to_k(hidden_states) 252 | value = self.to_v(hidden_states) 253 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) 254 | encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) 255 | 256 | key = self.reshape_heads_to_batch_dim(key) 257 | value = self.reshape_heads_to_batch_dim(value) 258 | encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) 259 | encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) 260 | 261 | key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) 262 | value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) 263 | else: 264 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 265 | key = self.to_k(encoder_hidden_states) 266 | value = self.to_v(encoder_hidden_states) 267 | 268 | key = self.reshape_heads_to_batch_dim(key) 269 | value = self.reshape_heads_to_batch_dim(value) 270 | 271 | if attention_mask is not None: 272 | if attention_mask.shape[-1] != query.shape[1]: 273 | target_length = query.shape[1] 274 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 275 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 276 | 277 | # attention, what we cannot get enough of 278 | self._use_memory_efficient_attention_xformers = False 279 | if self._use_memory_efficient_attention_xformers: 280 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 281 | # Some versions of xformers return output in fp32, cast it back to the dtype of 
the input 282 | hidden_states = hidden_states.to(query.dtype) 283 | else: 284 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 285 | hidden_states = _attention(query, key, value, is_cross, attention_mask) 286 | else: 287 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 288 | 289 | # linear proj 290 | hidden_states = self.to_out[0](hidden_states) 291 | 292 | # dropout 293 | hidden_states = self.to_out[1](hidden_states) 294 | return hidden_states 295 | 296 | def sca_forward(hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): 297 | is_cross = encoder_hidden_states is not None 298 | 299 | batch_size, sequence_length, _ = hidden_states.shape 300 | 301 | encoder_hidden_states = encoder_hidden_states 302 | 303 | if self.group_norm is not None: 304 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 305 | 306 | query = self.to_q(hidden_states) 307 | dim = query.shape[-1] 308 | query = self.reshape_heads_to_batch_dim(query) # [(b f), d ,c] 309 | 310 | if self.added_kv_proj_dim is not None: 311 | raise NotImplementedError 312 | 313 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 314 | key = self.to_k(encoder_hidden_states) 315 | value = self.to_v(encoder_hidden_states) 316 | 317 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length) 318 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length) 319 | 320 | attention_frame_index = [] 321 | 322 | attention_type = controller.attention_type_former if controller.cur_step <= controller.attention_adapt_step \ 323 | else controller.attention_type_latter 324 | 325 | # self only 326 | if "self" in attention_type: 327 | attention_frame_index.append(torch.arange(video_length)) 328 | 329 | # first only 330 | if "first" in attention_type: 331 | attention_frame_index.append([0] * video_length) 332 | 333 | # former 334 | if "former" in attention_type: 335 | attention_frame_index.append(torch.clamp(torch.arange(video_length) - 1, min=0)) 336 | # former end 337 | 338 | # all 339 | # for i in range(video_length): 340 | # attention_frame_index.append([i] * video_length) 341 | # all end 342 | 343 | key = torch.cat([key[:, frame_index] for frame_index in attention_frame_index], dim=2) 344 | value = torch.cat([value[:, frame_index] for frame_index in attention_frame_index], dim=2) 345 | 346 | b, f, k, c = key.shape 347 | 348 | key = rearrange(key, "b f d c -> (b f) d c") 349 | value = rearrange(value, "b f d c -> (b f) d c") 350 | 351 | key = self.reshape_heads_to_batch_dim(key) 352 | value = self.reshape_heads_to_batch_dim(value) 353 | 354 | if attention_mask is not None: 355 | if attention_mask.shape[-1] != query.shape[1]: 356 | target_length = query.shape[1] 357 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 358 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) 359 | 360 | # attention, what we cannot get enough of 361 | # self._use_memory_efficient_attention_xformers = False 362 | if self._use_memory_efficient_attention_xformers: 363 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) 364 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input 365 | hidden_states = hidden_states.to(query.dtype) 366 | else: 367 | if self._slice_size is None or query.shape[0] // self._slice_size == 1: 368 | hidden_states = _attention(query, key, value, is_cross, 
attention_mask) 369 | else: 370 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) 371 | 372 | # linear proj 373 | hidden_states = self.to_out[0](hidden_states) 374 | 375 | # dropout 376 | hidden_states = self.to_out[1](hidden_states) 377 | return hidden_states 378 | 379 | if attention_type == 'CrossAttention': 380 | return forward 381 | elif attention_type == "SparseCausalAttention": 382 | return sca_forward 383 | 384 | class DummyController: 385 | 386 | def __call__(self, *args): 387 | return args[0] 388 | 389 | def __init__(self): 390 | self.num_att_layers = 0 391 | 392 | if controller is None: 393 | controller = DummyController() 394 | 395 | def register_recr(net_, count, place_in_unet): 396 | if net_.__class__.__name__ == 'SparseCausalAttention': 397 | net_.forward = ca_forward(net_, place_in_unet, net_.__class__.__name__) 398 | return count + 1 399 | elif hasattr(net_, 'children'): 400 | for net__ in net_.children(): 401 | count = register_recr(net__, count, place_in_unet) 402 | return count 403 | 404 | cross_att_count = 0 405 | sub_nets = model.unet.named_children() 406 | for net in sub_nets: 407 | if "down" in net[0]: 408 | cross_att_count += register_recr(net[1], 0, "down") 409 | elif "up" in net[0]: 410 | cross_att_count += register_recr(net[1], 0, "up") 411 | elif "mid" in net[0]: 412 | cross_att_count += register_recr(net[1], 0, "mid") 413 | 414 | controller.num_att_layers = cross_att_count 415 | 416 | 417 | def get_word_inds(text: str, word_place: int, tokenizer): 418 | split_text = text.split(" ") 419 | if type(word_place) is str: 420 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 421 | elif type(word_place) is int: 422 | word_place = [word_place] 423 | out = [] 424 | if len(word_place) > 0: 425 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 426 | cur_len, ptr = 0, 0 427 | 428 | for i in range(len(words_encode)): 429 | cur_len += len(words_encode[i]) 430 | if ptr in word_place: 431 | out.append(i + 1) 432 | if cur_len >= len(split_text[ptr]): 433 | ptr += 1 434 | cur_len = 0 435 | return np.array(out) 436 | 437 | 438 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 439 | word_inds: Optional[torch.Tensor] = None): 440 | if type(bounds) is float: 441 | bounds = 0, bounds 442 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 443 | if word_inds is None: 444 | word_inds = torch.arange(alpha.shape[2]) 445 | alpha[: start, prompt_ind, word_inds] = 0 446 | alpha[start: end, prompt_ind, word_inds] = 1 447 | alpha[end:, prompt_ind, word_inds] = 0 448 | return alpha 449 | 450 | 451 | def get_time_words_attention_alpha(prompts, num_steps, 452 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 453 | tokenizer, max_num_words=77): 454 | if type(cross_replace_steps) is not dict: 455 | cross_replace_steps = {"default_": cross_replace_steps} 456 | if "default_" not in cross_replace_steps: 457 | cross_replace_steps["default_"] = (0., 1.) 
458 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 459 | for i in range(len(prompts) - 1): 460 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 461 | i) 462 | for key, item in cross_replace_steps.items(): 463 | if key != "default_": 464 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 465 | for i, ind in enumerate(inds): 466 | if len(ind) > 0: 467 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 468 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 469 | return alpha_time_words 470 | 471 | 472 | def get_time_string() -> str: 473 | x = datetime.datetime.now() 474 | return f"{(x.year - 2000):02d}{x.month:02d}{x.day:02d}-{x.hour:02d}{x.minute:02d}{x.second:02d}" 475 | -------------------------------------------------------------------------------- /freebloom/prompt_attention/seq_aligner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import numpy as np 16 | 17 | 18 | class ScoreParams: 19 | 20 | def __init__(self, gap, match, mismatch): 21 | self.gap = gap 22 | self.match = match 23 | self.mismatch = mismatch 24 | 25 | def mis_match_char(self, x, y): 26 | if x != y: 27 | return self.mismatch 28 | else: 29 | return self.match 30 | 31 | 32 | def get_matrix(size_x, size_y, gap): 33 | matrix = [] 34 | for i in range(len(size_x) + 1): 35 | sub_matrix = [] 36 | for j in range(len(size_y) + 1): 37 | sub_matrix.append(0) 38 | matrix.append(sub_matrix) 39 | for j in range(1, len(size_y) + 1): 40 | matrix[0][j] = j * gap 41 | for i in range(1, len(size_x) + 1): 42 | matrix[i][0] = i * gap 43 | return matrix 44 | 45 | 46 | def get_matrix(size_x, size_y, gap): 47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 50 | return matrix 51 | 52 | 53 | def get_traceback_matrix(size_x, size_y): 54 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 55 | matrix[0, 1:] = 1 56 | matrix[1:, 0] = 2 57 | matrix[0, 0] = 4 58 | return matrix 59 | 60 | 61 | def global_align(x, y, score): 62 | matrix = get_matrix(len(x), len(y), score.gap) 63 | trace_back = get_traceback_matrix(len(x), len(y)) 64 | for i in range(1, len(x) + 1): 65 | for j in range(1, len(y) + 1): 66 | left = matrix[i, j - 1] + score.gap 67 | up = matrix[i - 1, j] + score.gap 68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 69 | matrix[i, j] = max(left, up, diag) 70 | if matrix[i, j] == left: 71 | trace_back[i, j] = 1 72 | elif matrix[i, j] == up: 73 | trace_back[i, j] = 2 74 | else: 75 | trace_back[i, j] = 3 76 | return matrix, trace_back 77 | 78 | 79 | def get_aligned_sequences(x, y, trace_back): 80 | x_seq = [] 81 | y_seq = [] 82 | i = len(x) 
83 | j = len(y) 84 | mapper_y_to_x = [] 85 | while i > 0 or j > 0: 86 | if trace_back[i, j] == 3: 87 | x_seq.append(x[i - 1]) 88 | y_seq.append(y[j - 1]) 89 | i = i - 1 90 | j = j - 1 91 | mapper_y_to_x.append((j, i)) 92 | elif trace_back[i][j] == 1: 93 | x_seq.append('-') 94 | y_seq.append(y[j - 1]) 95 | j = j - 1 96 | mapper_y_to_x.append((j, -1)) 97 | elif trace_back[i][j] == 2: 98 | x_seq.append(x[i - 1]) 99 | y_seq.append('-') 100 | i = i - 1 101 | elif trace_back[i][j] == 4: 102 | break 103 | mapper_y_to_x.reverse() 104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 105 | 106 | 107 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 108 | x_seq = tokenizer.encode(x) 109 | y_seq = tokenizer.encode(y) 110 | score = ScoreParams(0, 1, -1) 111 | matrix, trace_back = global_align(x_seq, y_seq, score) 112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 113 | alphas = torch.ones(max_len) 114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 115 | mapper = torch.zeros(max_len, dtype=torch.int64) 116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1] 117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) 118 | return mapper, alphas 119 | 120 | 121 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 122 | x_seq = prompts[0] 123 | mappers, alphas = [], [] 124 | for i in range(1, len(prompts)): 125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 126 | mappers.append(mapper) 127 | alphas.append(alpha) 128 | return torch.stack(mappers), torch.stack(alphas) 129 | 130 | 131 | def get_word_inds(text: str, word_place: int, tokenizer): 132 | split_text = text.split(" ") 133 | if type(word_place) is str: 134 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 135 | elif type(word_place) is int: 136 | word_place = [word_place] 137 | out = [] 138 | if len(word_place) > 0: 139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 140 | cur_len, ptr = 0, 0 141 | 142 | for i in range(len(words_encode)): 143 | cur_len += len(words_encode[i]) 144 | if ptr in word_place: 145 | out.append(i + 1) 146 | if cur_len >= len(split_text[ptr]): 147 | ptr += 1 148 | cur_len = 0 149 | return np.array(out) 150 | 151 | 152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 153 | words_x = x.split(' ') 154 | words_y = y.split(' ') 155 | if len(words_x) != len(words_y): 156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" 157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") 158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 161 | mapper = np.zeros((max_len, max_len)) 162 | i = j = 0 163 | cur_inds = 0 164 | while i < max_len and j < max_len: 165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 167 | if len(inds_source_) == len(inds_target_): 168 | mapper[inds_source_, inds_target_] = 1 169 | else: 170 | ratio = 1 / len(inds_target_) 171 | for i_t in inds_target_: 172 | mapper[inds_source_, i_t] = ratio 173 | cur_inds += 1 174 | i += len(inds_source_) 175 | j += len(inds_target_) 176 | elif cur_inds < len(inds_source): 177 | mapper[i, j] = 1 178 | i += 1 179 
| j += 1 180 | else: 181 | mapper[j, j] = 1 182 | i += 1 183 | j += 1 184 | 185 | return torch.from_numpy(mapper).float() 186 | 187 | 188 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 189 | x_seq = prompts[0] 190 | mappers = [] 191 | for i in range(1, len(prompts)): 192 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 193 | mappers.append(mapper) 194 | return torch.stack(mappers) 195 | -------------------------------------------------------------------------------- /freebloom/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from PIL import Image 5 | from typing import Union 6 | 7 | import torch 8 | import torchvision 9 | 10 | from tqdm import tqdm 11 | from einops import rearrange 12 | 13 | 14 | def save_tensor_img(img, save_path): 15 | """ 16 | :param img c, h, w , -1~1 17 | """ 18 | # img = (img + 1.0) / 2.0 19 | img = Image.fromarray(img.mul(255).byte().numpy().transpose(1, 2, 0)) 20 | img.save(save_path) 21 | 22 | 23 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8): 24 | videos = rearrange(videos, "b c t h w -> t b c h w") 25 | outputs = [] 26 | for x in videos: 27 | x = torchvision.utils.make_grid(x, nrow=n_rows) 28 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 29 | if rescale: 30 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 31 | x = (x * 255).numpy().astype(np.uint8) 32 | outputs.append(x) 33 | 34 | os.makedirs(os.path.dirname(path), exist_ok=True) 35 | imageio.mimsave(path, outputs, fps=fps) 36 | 37 | 38 | def save_videos_per_frames_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4): 39 | os.makedirs(path, exist_ok=True) 40 | for i, video in enumerate(videos): 41 | video = rearrange(video, "c f h w -> f c h w") 42 | x = torchvision.utils.make_grid(video, nrow=n_rows) 43 | if rescale: 44 | x = (x + 1.0) / 2.0 45 | # x = (x * 255).numpy().astype(np.int8) 46 | torchvision.utils.save_image(x, f'{path}/{i}_all.jpg') 47 | for j, img in enumerate(video): 48 | save_tensor_img(img, f'{path}/{j}.jpg') 49 | 50 | # DDIM Inversion 51 | @torch.no_grad() 52 | def init_prompt(prompt, pipeline): 53 | uncond_input = pipeline.tokenizer( 54 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 55 | return_tensors="pt" 56 | ) 57 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 58 | text_input = pipeline.tokenizer( 59 | [prompt], 60 | padding="max_length", 61 | max_length=pipeline.tokenizer.model_max_length, 62 | truncation=True, 63 | return_tensors="pt", 64 | ) 65 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 66 | context = torch.cat([uncond_embeddings, text_embeddings]) 67 | 68 | return context 69 | 70 | 71 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 72 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): 73 | timestep, next_timestep = min( 74 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 75 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 76 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 77 | beta_prod_t = 1 - alpha_prod_t 78 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 79 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 80 | next_sample = 
alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 81 | return next_sample 82 | 83 | 84 | def get_noise_pred_single(latents, t, context, unet): 85 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 86 | return noise_pred 87 | 88 | 89 | @torch.no_grad() 90 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 91 | context = init_prompt(prompt, pipeline) 92 | uncond_embeddings, cond_embeddings = context.chunk(2) 93 | 94 | all_latent = [latent] 95 | latent = latent.clone().detach() 96 | for i in tqdm(range(num_inv_steps)): 97 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 98 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 99 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 100 | all_latent.append(latent) 101 | return all_latent 102 | 103 | 104 | @torch.no_grad() 105 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 106 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 107 | return ddim_latents 108 | 109 | 110 | def invert(image_path, weight_dtype, pipeline, ddim_scheduler, num_inv_steps, prompt=""): 111 | image_gt = load_512(image_path) 112 | image = torch.from_numpy(image_gt).type(weight_dtype) / 127.5 - 1 113 | image = image.permute(2, 0, 1).unsqueeze(0).to(pipeline.vae.device) 114 | latent = pipeline.vae.encode(image)['latent_dist'].mean.unsqueeze(2) 115 | latent = latent * 0.18215 # pipeline.vae.config.scaling_factor 116 | latents = ddim_inversion(pipeline, ddim_scheduler, latent, num_inv_steps, prompt=prompt) 117 | return latents 118 | 119 | 120 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 121 | if type(image_path) is str: 122 | image = np.array(Image.open(image_path))[:, :, :3] 123 | else: 124 | image = image_path 125 | h, w, c = image.shape 126 | left = min(left, w - 1) 127 | right = min(right, w - left - 1) 128 | top = min(top, h - left - 1) 129 | bottom = min(bottom, h - top - 1) 130 | image = image[top:h - bottom, left:w - right] 131 | h, w, c = image.shape 132 | if h < w: 133 | offset = (w - h) // 2 134 | image = image[:, offset:offset + h] 135 | elif w < h: 136 | offset = (h - w) // 2 137 | image = image[offset:offset + w] 138 | image = np.array(Image.fromarray(image).resize((512, 512))) 139 | return image 140 | 141 | 142 | @torch.no_grad() 143 | def latent2image(vae, latents): 144 | latents = 1 / vae.config.scaling_factor * latents 145 | image = vae.decode(latents)['sample'] 146 | image = (image / 2 + 0.5).clamp(0, 1) 147 | image = image.cpu().permute(0, 2, 3, 1).numpy() 148 | image = (image * 255).astype(np.uint8) 149 | return image 150 | 151 | 152 | @torch.no_grad() 153 | def image2latent(vae, image, device, dtype=torch.float32): 154 | with torch.no_grad(): 155 | if type(image) is Image: 156 | image = np.array(image) 157 | if type(image) is torch.Tensor and image.dim() == 4: 158 | latents = image 159 | else: 160 | image = torch.from_numpy(image).type(dtype) / 127.5 - 1 161 | image = image.permute(2, 0, 1).unsqueeze(0).to(device) 162 | latents = vae.encode(image)['latent_dist'].mean 163 | latents = latents * vae.config.scaling_factor 164 | return latents 165 | -------------------------------------------------------------------------------- /interp.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import inspect 4 | import logging 5 | import os 6 | from typing import Dict, Optional 
7 | 8 | import diffusers 9 | import numpy as np 10 | import torch 11 | import torch.utils.checkpoint 12 | import transformers 13 | from accelerate import Accelerator 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import set_seed 16 | from diffusers import AutoencoderKL, DDIMScheduler 17 | from diffusers.utils.import_utils import is_xformers_available 18 | from omegaconf import OmegaConf 19 | from transformers import CLIPTextModel, CLIPTokenizer 20 | 21 | from freebloom.models.unet import UNet3DConditionModel 22 | from freebloom.pipelines.pipeline_spatio_temporal import SpatioTemporalPipeline 23 | from freebloom.util import save_videos_grid, save_videos_per_frames_grid 24 | 25 | logger = get_logger(__name__, log_level="INFO") 26 | 27 | 28 | def main( 29 | pretrained_model_path: str, 30 | output_dir: str, 31 | validation_data: Dict, 32 | mixed_precision: Optional[str] = "fp16", 33 | enable_xformers_memory_efficient_attention: bool = True, 34 | seed: Optional[int] = None, 35 | inference_config: Dict = None, 36 | ): 37 | *_, config = inspect.getargvalues(inspect.currentframe()) 38 | 39 | accelerator = Accelerator( 40 | mixed_precision=mixed_precision, 41 | project_dir=f"{output_dir}/acc_log" 42 | ) 43 | 44 | # Make one log on every process with the configuration for debugging. 45 | logging.basicConfig( 46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 47 | datefmt="%m/%d/%Y %H:%M:%S", 48 | level=logging.INFO, 49 | ) 50 | logger.info(accelerator.state, main_process_only=False) 51 | if accelerator.is_local_main_process: 52 | transformers.utils.logging.set_verbosity_warning() 53 | diffusers.utils.logging.set_verbosity_info() 54 | else: 55 | transformers.utils.logging.set_verbosity_error() 56 | diffusers.utils.logging.set_verbosity_error() 57 | 58 | # If passed along, set the training seed now. 59 | if seed is not None: 60 | set_seed(seed) 61 | 62 | # Handle the output folder creation 63 | if accelerator.is_main_process: 64 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 65 | output_dir = os.path.join(output_dir, now) 66 | os.makedirs(output_dir, exist_ok=True) 67 | # OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) 68 | 69 | # Load scheduler, tokenizer and models. 70 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 71 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 72 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 73 | unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet") 74 | 75 | # Freeze vae and text_encoder 76 | vae.requires_grad_(False) 77 | text_encoder.requires_grad_(False) 78 | unet.requires_grad_(False) 79 | 80 | if enable_xformers_memory_efficient_attention: 81 | if is_xformers_available(): 82 | unet.enable_xformers_memory_efficient_attention() 83 | else: 84 | raise ValueError("xformers is not available. Make sure it is installed correctly") 85 | 86 | # Get the validation pipeline 87 | validation_pipeline = SpatioTemporalPipeline( 88 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, 89 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"), 90 | disk_store=False, 91 | config=config 92 | ) 93 | validation_pipeline.enable_vae_slicing() 94 | validation_pipeline.scheduler.set_timesteps(validation_data.num_inv_steps) 95 | 96 | # Prepare everything with our `accelerator`. 
97 | unet = accelerator.prepare(unet) 98 | 99 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 100 | # as these models are only used for inference, keeping weights in full precision is not required. 101 | weight_dtype = torch.float32 102 | if accelerator.mixed_precision == "fp16": 103 | weight_dtype = torch.float16 104 | elif accelerator.mixed_precision == "bf16": 105 | weight_dtype = torch.bfloat16 106 | 107 | # Move text_encode and vae to gpu and cast to weight_dtype 108 | text_encoder.to(accelerator.device, dtype=weight_dtype) 109 | vae.to(accelerator.device, dtype=weight_dtype) 110 | unet.to(accelerator.device, dtype=weight_dtype) 111 | 112 | # We need to initialize the trackers we use, and also store our configuration. 113 | # The trackers initializes automatically on the main process. 114 | if accelerator.is_main_process: 115 | accelerator.init_trackers("SpatioTemporal") 116 | 117 | text_encoder.eval() 118 | vae.eval() 119 | unet.eval() 120 | 121 | generator = torch.Generator(device=unet.device) 122 | generator.manual_seed(seed) 123 | 124 | samples = [] 125 | 126 | prompt = list(validation_data.prompts) 127 | negative_prompt = config['validation_data']['negative_prompt'] 128 | negative_prompt = [negative_prompt] * len(prompt) 129 | 130 | with (torch.no_grad()): 131 | x_base = validation_pipeline.prepare_latents(batch_size=1, 132 | num_channels_latents=4, 133 | video_length=len(prompt), 134 | height=512, 135 | width=512, 136 | dtype=weight_dtype, 137 | device=unet.device, 138 | generator=generator, 139 | store_attention=True, 140 | frame_same_noise=True) 141 | 142 | x_res = validation_pipeline.prepare_latents(batch_size=1, 143 | num_channels_latents=4, 144 | video_length=len(prompt), 145 | height=512, 146 | width=512, 147 | dtype=weight_dtype, 148 | device=unet.device, 149 | generator=generator, 150 | store_attention=True, 151 | frame_same_noise=False) 152 | 153 | x_T = np.cos(inference_config['diversity_rand_ratio'] * np.pi / 2) * x_base + np.sin( 154 | inference_config['diversity_rand_ratio'] * np.pi / 2) * x_res 155 | 156 | validation_data.pop('negative_prompt') 157 | # key frame 158 | key_frames, text_embedding = validation_pipeline(prompt, video_length=len(prompt), generator=generator, 159 | latents=x_T.type(weight_dtype), 160 | negative_prompt=negative_prompt, 161 | output_dir=output_dir, 162 | return_text_embedding=True, 163 | **validation_data) 164 | 165 | idx = [args.interp_idx, args.interp_idx + 1] 166 | config["inference_config"]["interpolation_step"] = args.interp_step 167 | interp = validation_pipeline.interpolate_between_two_frames(key_frames[0][:, :, idx], x_T[:, :, idx], 168 | rand_ratio=inference_config['diversity_rand_ratio'], 169 | where=idx, 170 | prompt=[prompt[i] for i in idx], 171 | k=args.interp_num, 172 | generator=generator, 173 | negative_prompt=negative_prompt[0], 174 | text_embedding=text_embedding[idx], 175 | base_noise=x_base[:, :, idx], 176 | **validation_data) 177 | sample = torch.cat([key_frames[0][:, :, :idx[0]], interp, key_frames[0][:, :, idx[1] + 1:]], dim=2) 178 | torch.cuda.empty_cache() 179 | 180 | samples.append(sample) 181 | samples = torch.concat(samples) 182 | save_videos_per_frames_grid(samples, f'{output_dir}/interp_img_samples', n_rows=6) 183 | logger.info(f"Saved samples to {output_dir}") 184 | 185 | 186 | if __name__ == "__main__": 187 | parser = argparse.ArgumentParser() 188 | 189 | parser.add_argument("--config", type=str, default="./configs/flowers.yaml") 190 | 
parser.add_argument("--interp_idx", type=int, required=True, help="where to interpolate") 191 | parser.add_argument("--interp_num", type=int, default=2) 192 | parser.add_argument("--interp_step", type=int, default=10) 193 | args = parser.parse_args() 194 | 195 | conf = OmegaConf.load(args.config) 196 | main(**conf) 197 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import inspect 4 | import logging 5 | import os 6 | from typing import Dict, Optional 7 | 8 | import diffusers 9 | import numpy as np 10 | import torch 11 | import torch.utils.checkpoint 12 | import transformers 13 | from accelerate import Accelerator 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import set_seed 16 | from diffusers import AutoencoderKL, DDIMScheduler 17 | from diffusers.utils.import_utils import is_xformers_available 18 | from omegaconf import OmegaConf 19 | from transformers import CLIPTextModel, CLIPTokenizer 20 | 21 | from freebloom.models.unet import UNet3DConditionModel 22 | from freebloom.pipelines.pipeline_spatio_temporal import SpatioTemporalPipeline 23 | from freebloom.util import save_videos_grid, save_videos_per_frames_grid 24 | 25 | logger = get_logger(__name__, log_level="INFO") 26 | 27 | 28 | def main( 29 | pretrained_model_path: str, 30 | output_dir: str, 31 | validation_data: Dict, 32 | mixed_precision: Optional[str] = "fp16", 33 | enable_xformers_memory_efficient_attention: bool = True, 34 | seed: Optional[int] = None, 35 | inference_config: Dict = None, 36 | ): 37 | *_, config = inspect.getargvalues(inspect.currentframe()) 38 | 39 | accelerator = Accelerator( 40 | mixed_precision=mixed_precision, 41 | project_dir=f"{output_dir}/acc_log" 42 | ) 43 | 44 | # Make one log on every process with the configuration for debugging. 45 | logging.basicConfig( 46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 47 | datefmt="%m/%d/%Y %H:%M:%S", 48 | level=logging.INFO, 49 | ) 50 | logger.info(accelerator.state, main_process_only=False) 51 | if accelerator.is_local_main_process: 52 | transformers.utils.logging.set_verbosity_warning() 53 | diffusers.utils.logging.set_verbosity_info() 54 | else: 55 | transformers.utils.logging.set_verbosity_error() 56 | diffusers.utils.logging.set_verbosity_error() 57 | 58 | # If passed along, set the training seed now. 59 | if seed is not None: 60 | set_seed(seed) 61 | 62 | # Handle the output folder creation 63 | if accelerator.is_main_process: 64 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 65 | output_dir = os.path.join(output_dir, now) 66 | os.makedirs(output_dir, exist_ok=True) 67 | os.makedirs(f"{output_dir}/samples", exist_ok=True) 68 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) 69 | 70 | # Load scheduler, tokenizer and models. 
71 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 72 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 73 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 74 | unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet") 75 | 76 | # Freeze vae and text_encoder 77 | vae.requires_grad_(False) 78 | text_encoder.requires_grad_(False) 79 | unet.requires_grad_(False) 80 | 81 | if enable_xformers_memory_efficient_attention: 82 | if is_xformers_available(): 83 | unet.enable_xformers_memory_efficient_attention() 84 | else: 85 | raise ValueError("xformers is not available. Make sure it is installed correctly") 86 | 87 | # Get the validation pipeline 88 | validation_pipeline = SpatioTemporalPipeline( 89 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, 90 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"), 91 | disk_store=False, 92 | config=config 93 | ) 94 | validation_pipeline.enable_vae_slicing() 95 | validation_pipeline.scheduler.set_timesteps(validation_data.num_inv_steps) 96 | 97 | # Prepare everything with our `accelerator`. 98 | unet = accelerator.prepare(unet) 99 | 100 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 101 | # as these models are only used for inference, keeping weights in full precision is not required. 102 | weight_dtype = torch.float32 103 | if accelerator.mixed_precision == "fp16": 104 | weight_dtype = torch.float16 105 | elif accelerator.mixed_precision == "bf16": 106 | weight_dtype = torch.bfloat16 107 | 108 | # Move text_encode and vae to gpu and cast to weight_dtype 109 | text_encoder.to(accelerator.device, dtype=weight_dtype) 110 | vae.to(accelerator.device, dtype=weight_dtype) 111 | unet.to(accelerator.device, dtype=weight_dtype) 112 | 113 | # We need to initialize the trackers we use, and also store our configuration. 114 | # The trackers initializes automatically on the main process. 
115 | if accelerator.is_main_process: 116 | accelerator.init_trackers("SpatioTemporal") 117 | 118 | text_encoder.eval() 119 | vae.eval() 120 | unet.eval() 121 | 122 | generator = torch.Generator(device=unet.device) 123 | generator.manual_seed(seed) 124 | 125 | samples = [] 126 | 127 | prompt = list(validation_data.prompts) 128 | negative_prompt = config['validation_data']['negative_prompt'] 129 | negative_prompt = [negative_prompt] * len(prompt) 130 | 131 | with (torch.no_grad()): 132 | x_base = validation_pipeline.prepare_latents(batch_size=1, 133 | num_channels_latents=4, 134 | video_length=len(prompt), 135 | height=512, 136 | width=512, 137 | dtype=weight_dtype, 138 | device=unet.device, 139 | generator=generator, 140 | store_attention=True, 141 | frame_same_noise=True) 142 | 143 | x_res = validation_pipeline.prepare_latents(batch_size=1, 144 | num_channels_latents=4, 145 | video_length=len(prompt), 146 | height=512, 147 | width=512, 148 | dtype=weight_dtype, 149 | device=unet.device, 150 | generator=generator, 151 | store_attention=True, 152 | frame_same_noise=False) 153 | 154 | x_T = np.cos(inference_config['diversity_rand_ratio'] * np.pi / 2) * x_base + np.sin( 155 | inference_config['diversity_rand_ratio'] * np.pi / 2) * x_res 156 | 157 | validation_data.pop('negative_prompt') 158 | # key frame 159 | key_frames, text_embedding = validation_pipeline(prompt, video_length=len(prompt), generator=generator, 160 | latents=x_T.type(weight_dtype), 161 | negative_prompt=negative_prompt, 162 | output_dir=output_dir, 163 | return_text_embedding=True, 164 | **validation_data) 165 | torch.cuda.empty_cache() 166 | 167 | samples.append(key_frames[0]) 168 | samples = torch.concat(samples) 169 | save_path = f"{output_dir}/samples/sample.gif" 170 | save_videos_grid(samples, save_path, n_rows=6) 171 | save_videos_per_frames_grid(samples, f'{output_dir}/img_samples', n_rows=6) 172 | logger.info(f"Saved samples to {save_path}") 173 | 174 | 175 | if __name__ == "__main__": 176 | parser = argparse.ArgumentParser() 177 | 178 | parser.add_argument("--config", type=str, default="./configs/flowers.yaml") 179 | args = parser.parse_args() 180 | 181 | conf = OmegaConf.load(args.config) 182 | main(**conf) 183 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.17.1 2 | decord==0.6.0 3 | diffusers==0.11.1 4 | einops==0.7.0 5 | imageio==2.26.0 6 | numpy==1.24.2 7 | omegaconf==2.3.0 8 | opencv_python==4.7.0.72 9 | packaging==23.2 10 | Pillow==10.1.0 11 | torch==1.13.1 12 | torchvision==0.14.1 13 | tqdm==4.65.0 14 | transformers==4.27.1 15 | xformers==0.0.16 16 | --------------------------------------------------------------------------------
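Usage (a minimal sketch inferred from the argument parsers defined in main.py and interp.py above; the config path shown is simply the scripts' default, and the --interp_idx value is an arbitrary example that must point at a valid pair of adjacent key frames):

    python main.py --config ./configs/flowers.yaml
    python interp.py --config ./configs/flowers.yaml --interp_idx 2 --interp_num 2 --interp_step 10

main.py writes the generated key frames to a timestamped folder under output_dir (samples/sample.gif plus per-frame images under img_samples/); interp.py additionally inserts interpolated frames between key frames --interp_idx and --interp_idx + 1 (the count is controlled by --interp_num, the interpolation step by --interp_step) and saves the resulting sequence under interp_img_samples/.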