├── .gitignore
├── LICENSE
├── README.md
├── __assets__
│   ├── Van_Gogh_flower.gif
│   ├── flower_bloom.gif
│   ├── ice_cube.gif
│   ├── long_video.gif
│   ├── rainbow_forming1.gif
│   ├── rainbow_forming2.gif
│   ├── rainbow_forming3.gif
│   ├── rainbow_forming4.gif
│   ├── river_freezing.gif
│   ├── teaser.png
│   ├── volcano_eruption1.gif
│   ├── volcano_eruption2.gif
│   ├── volcano_eruption3.gif
│   ├── volcano_eruption4.gif
│   └── yellow_flower.gif
├── configs
│   ├── A_thunderstorm_developing_over_a_sea.yaml
│   ├── a_rainbow_is_forming.yaml
│   ├── flowers.yaml
│   └── volcano_eruption.yaml
├── freebloom
│   ├── models
│   │   ├── attention.py
│   │   ├── resnet.py
│   │   ├── unet.py
│   │   └── unet_blocks.py
│   ├── pipelines
│   │   └── pipeline_spatio_temporal.py
│   ├── prompt_attention
│   │   ├── attention_util.py
│   │   ├── ptp_utils.py
│   │   └── seq_aligner.py
│   └── util.py
├── interp.py
├── main.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .empty
3 | outputs
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Free-Bloom
2 |
3 | This repository is the official implementation of [Free-Bloom](https://arxiv.org/abs/2309.14494).
4 |
5 | **[Free-Bloom: Zero-Shot Text-to-Video Generator with LLM Director and LDM Animator](https://arxiv.org/abs/2309.14494)**
6 |
7 | [Hanzhuo Huang∗]() , [Yufan Feng∗](), [Cheng Shi](https://chengshiest.github.io/), [Lan Xu](https://www.xu-lan.com/), [Jingyi Yu](https://vic.shanghaitech.edu.cn/vrvc/en/people/jingyi-yu/), [Sibei Yang†](https://faculty.sist.shanghaitech.edu.cn/yangsibei/)
8 |
9 | *Equal contribution; †Corresponding Author
10 |
11 | [arXiv:2309.14494](https://arxiv.org/abs/2309.14494)
12 |
13 |
14 | ![teaser](__assets__/teaser.png)
15 |
16 | ## Setup
17 |
18 | ### Requirements
19 | ```cmd
20 | conda create -n fb python=3.10
21 | conda activate fb
22 | pip install -r requirements.txt
23 | ```
24 |
25 | Installing [xformers](https://github.com/facebookresearch/xformers) is highly recommended for better memory efficiency and speed on GPUs.
26 | To enable it, set `enable_xformers_memory_efficient_attention: True` in the config YAML (this is the default in the provided configs).
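
If you want to confirm that xformers is visible to diffusers before launching a run, a minimal check (reusing the `is_xformers_available` helper that `freebloom/models/attention.py` already imports) looks like this:

```python
# Optional sanity check: can diffusers see xformers?
from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    print("xformers found; memory-efficient attention can be enabled.")
else:
    print("xformers not found; install it or set "
          "`enable_xformers_memory_efficient_attention: False` in your config.")
```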
27 |
28 |
29 |
30 |
31 | ## Usage
32 |
33 | ### Generate
34 | ```cmd
35 | python main.py --config configs/flowers.yaml
36 | ```
37 |
38 | Set the `pretrained_model_path` key in the config YAML file to the path of your own Stable Diffusion (v1-5) weights before running.
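
You can optionally sanity-check that the configured path loads as a Stable Diffusion checkpoint with plain diffusers first (a minimal sketch; the path below is a placeholder for your local copy):

```python
# Optional: verify that `pretrained_model_path` points at a loadable SD v1-5 checkpoint.
from diffusers import StableDiffusionPipeline

pretrained_model_path = "/path/to/stable-diffusion-v1-5"  # same value as in your config
pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_path)
print("Loaded UNet with", sum(p.numel() for p in pipe.unet.parameters()), "parameters.")
```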
39 |
40 |
41 |
42 |
43 |
44 | ## Results
45 |
46 | **A Flower is blooming**
47 |
48 |
49 |
50 | ![](__assets__/flower_bloom.gif)
51 | ![](__assets__/yellow_flower.gif)
52 | ![](__assets__/Van_Gogh_flower.gif)
54 |
55 |
56 |
57 |
58 |
59 | **Volcano eruption**
60 |
61 |
62 |
63 | ![](__assets__/volcano_eruption1.gif)
64 | ![](__assets__/volcano_eruption2.gif)
65 | ![](__assets__/volcano_eruption3.gif)
66 | ![](__assets__/volcano_eruption4.gif)
67 |
68 |
69 |
70 | **A rainbow is forming**
71 |
72 |
73 | ![](__assets__/rainbow_forming1.gif)
74 | ![](__assets__/rainbow_forming2.gif)
75 | ![](__assets__/rainbow_forming3.gif)
76 | ![](__assets__/rainbow_forming4.gif)
77 |
78 |
79 |
80 |
81 | ## Citation
82 |
83 | ```
84 | @article{freebloom,
85 | title={Free-Bloom: Zero-Shot Text-to-Video Generator with LLM Director and LDM Animator},
86 | author={Huang, Hanzhuo and Feng, Yufan and Shi, Cheng and Xu, Lan and Yu, Jingyi and Yang, Sibei},
87 | journal={arXiv preprint arXiv:2309.14494},
88 | year={2023}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/__assets__/Van_Gogh_flower.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/Van_Gogh_flower.gif
--------------------------------------------------------------------------------
/__assets__/flower_bloom.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/flower_bloom.gif
--------------------------------------------------------------------------------
/__assets__/ice_cube.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/ice_cube.gif
--------------------------------------------------------------------------------
/__assets__/long_video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/long_video.gif
--------------------------------------------------------------------------------
/__assets__/rainbow_forming1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/rainbow_forming1.gif
--------------------------------------------------------------------------------
/__assets__/rainbow_forming2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/rainbow_forming2.gif
--------------------------------------------------------------------------------
/__assets__/rainbow_forming3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/rainbow_forming3.gif
--------------------------------------------------------------------------------
/__assets__/rainbow_forming4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/rainbow_forming4.gif
--------------------------------------------------------------------------------
/__assets__/river_freezing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/river_freezing.gif
--------------------------------------------------------------------------------
/__assets__/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/teaser.png
--------------------------------------------------------------------------------
/__assets__/volcano_eruption1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/volcano_eruption1.gif
--------------------------------------------------------------------------------
/__assets__/volcano_eruption2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/volcano_eruption2.gif
--------------------------------------------------------------------------------
/__assets__/volcano_eruption3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/volcano_eruption3.gif
--------------------------------------------------------------------------------
/__assets__/volcano_eruption4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/volcano_eruption4.gif
--------------------------------------------------------------------------------
/__assets__/yellow_flower.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SooLab/Free-Bloom/4498b463cb4f1e652912db2144cf8d8983e6e141/__assets__/yellow_flower.gif
--------------------------------------------------------------------------------
/configs/A_thunderstorm_developing_over_a_sea.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_path: "/data/diffusion_wights/stable-diffusion-v1-5"
2 | output_dir: "./outputs/A_thunderstorm_developing_over_a_sea"
3 |
4 | inference_config:
5 |   diversity_rand_ratio: 0.1
6 |
7 | validation_data:
8 |   prompts:
9 |     - "A serene view of the sea under clear skies, with a few fluffy white clouds in the distance."
10 |     - "Darker clouds gathering at the horizon as the sea starts to become choppy with rising waves."
11 |     - "Lightning streaks across the sky, illuminating the dark clouds, and raindrops begin to fall onto the sea's surface."
12 |     - "The thunderstorm intensifies with heavy rain pouring down, and strong winds whip up the sea, creating large waves."
13 |     - "Lightning strikes become more frequent, illuminating the turbulent sea with flashes of light."
14 |     - "The storm reaches its peak, with menacing dark clouds covering the entire sky, and the sea becomes a tumultuous mass of crashing waves and heavy rainfall."
15 |
16 |   width: 512
17 |   height: 512
18 |   num_inference_steps: 50
19 |   guidance_scale: 12.5
20 |   use_inv_latent: True
21 |   num_inv_steps: 50
22 |   negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
23 |   interpolate_k: 0
24 |   attention_type_former: ["self", "former", "first"]
25 |   attention_type_latter: ["self"]
26 |   attention_adapt_step: 30 # 0~50
27 |
28 | seed: 3243241 # as you like
29 | mixed_precision: fp16
30 | enable_xformers_memory_efficient_attention: True
31 |
--------------------------------------------------------------------------------
/configs/a_rainbow_is_forming.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_path: "/data/diffusion_wights/stable-diffusion-v1-5"
2 | output_dir: "./outputs/a_rainbow_is_forming"
3 |
4 | inference_config:
5 |   diversity_rand_ratio: 0.1
6 |
7 | validation_data:
8 |   prompts:
9 |     - "The sky, partially cloudy, with faint hints of colors starting to emerge."
10 |     - "A faint arch of colors becomes visible, stretching across the sky."
11 |     - "The rainbow gains intensity as the colors become more vibrant and defined."
12 |     - "The rainbow is now fully formed, displaying its classic arc shape."
13 |     - "The colors of the rainbow shine brilliantly against the backdrop of the sky."
14 |     - "The rainbow remains steady, its colors vivid and captivating as the rainbow decorates the sky."
15 |
16 |   width: 512
17 |   height: 512
18 |   num_inference_steps: 50
19 |   guidance_scale: 12.5
20 |   use_inv_latent: True
21 |   num_inv_steps: 50
22 |   negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
23 |   interpolate_k: 0
24 |   attention_type_former: ["self", "former", "first"]
25 |   attention_type_latter: ["self"]
26 |   attention_adapt_step: 50 # 0~50
27 |
28 | seed: 324423 # 42 # as you like
29 | mixed_precision: fp16
30 | enable_xformers_memory_efficient_attention: True
31 |
--------------------------------------------------------------------------------
/configs/flowers.yaml:
--------------------------------------------------------------------------------
1 | # A cluster of flowers blooms
2 | pretrained_model_path: "/data/diffusion_wights/stable-diffusion-v1-5"
3 | output_dir: "./outputs/flowers"
4 |
5 | inference_config:
6 |   diversity_rand_ratio: 0.1
7 |
8 | validation_data:
9 |   prompts:
10 |     - "A group of closed buds can be seen on the stem of a plant."
11 |     - "The buds begin to slowly unfurl, revealing small petals."
12 |     - "The petals continue to unfurl, revealing more of the flower's center."
13 |     - "The petals are now fully opened, and the center of the flower is visible."
14 |     - "The flower's stamen and pistil become more prominent, and the petals start to curve outward."
15 |     - "The fully bloomed flowers are in full view, with their petals open wide and displaying their vibrant colors."
16 |
17 |
18 |   width: 512
19 |   height: 512
20 |   num_inference_steps: 50
21 |   guidance_scale: 12.5
22 |   use_inv_latent: True
23 |   num_inv_steps: 50
24 |   negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
25 |   interpolate_k: 0
26 |   attention_type_former: [ "self", "first", "former" ]
27 |   attention_type_latter: [ "self" ]
28 |   attention_adapt_step: 20
29 |
30 | seed: 42
31 | mixed_precision: fp16
32 | enable_xformers_memory_efficient_attention: True
33 |
--------------------------------------------------------------------------------
/configs/volcano_eruption.yaml:
--------------------------------------------------------------------------------
1 | pretrained_model_path: "/data/diffusion_wights/stable-diffusion-v1-5"
2 | output_dir: "./outputs/volcano_eruption_3"
3 |
4 | inference_config:
5 |   diversity_rand_ratio: 0.1
6 |
7 | validation_data:
8 |   prompts:
9 |     - "A towering volcano stands against a backdrop of clear blue skies, with no visible signs of activity."
10 |     - "Suddenly, a plume of thick smoke and ash erupts from the volcano's summit, the plume billowing high into the air."
11 |     - "Molten lava begins to flow down the volcano's slopes, the lava glowing brightly with intense heat and leaving a trail of destruction in its path."
12 |     - "Explosions rock the volcano as fiery projectiles shoot into the sky, the projectiles scattering debris and ash in all directions."
13 |     - "The eruption intensifies, with a massive column of smoke and ash ascending into the atmosphere, the column darkening the surrounding area and creating a dramatic spectacle."
14 |     - "As the eruption reaches its peak, a pyroclastic flow cascades down the volcano's sides, the flow engulfing everything in its path with a deadly combination of hot gases, ash, and volcanic material."
15 |
16 |
17 |
18 |   width: 512
19 |   height: 512
20 |   num_inference_steps: 50
21 |   guidance_scale: 12.5
22 |   use_inv_latent: True
23 |   num_inv_steps: 50
24 |   negative_prompt: "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic"
25 |   interpolate_k: 0
26 |   attention_type_former: ["self", "former", "first"]
27 |   attention_type_latter: ["self"]
28 |   attention_adapt_step: 40 # 0~50
29 |
30 | seed: 17 #133 # as you like
31 | mixed_precision: fp16
32 | enable_xformers_memory_efficient_attention: True
33 |
--------------------------------------------------------------------------------
/freebloom/models/attention.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2 |
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 |
10 | from diffusers.configuration_utils import ConfigMixin, register_to_config
11 | from diffusers.modeling_utils import ModelMixin
12 | from diffusers.utils import BaseOutput
13 | from diffusers.utils.import_utils import is_xformers_available
14 | from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
15 |
16 | from einops import rearrange, repeat
17 |
18 |
19 | @dataclass
20 | class Transformer3DModelOutput(BaseOutput):
21 | sample: torch.FloatTensor
22 |
23 |
24 | if is_xformers_available():
25 | import xformers
26 | import xformers.ops
27 | else:
28 | xformers = None
29 |
30 |
31 | class Transformer3DModel(ModelMixin, ConfigMixin):
32 | @register_to_config
33 | def __init__(
34 | self,
35 | num_attention_heads: int = 16,
36 | attention_head_dim: int = 88,
37 | in_channels: Optional[int] = None,
38 | num_layers: int = 1,
39 | dropout: float = 0.0,
40 | norm_num_groups: int = 32,
41 | cross_attention_dim: Optional[int] = None,
42 | attention_bias: bool = False,
43 | activation_fn: str = "geglu",
44 | num_embeds_ada_norm: Optional[int] = None,
45 | use_linear_projection: bool = False,
46 | only_cross_attention: bool = False,
47 | upcast_attention: bool = False,
48 | ):
49 | super().__init__()
50 | self.use_linear_projection = use_linear_projection
51 | self.num_attention_heads = num_attention_heads
52 | self.attention_head_dim = attention_head_dim
53 | inner_dim = num_attention_heads * attention_head_dim
54 |
55 | # Define input layers
56 | self.in_channels = in_channels
57 |
58 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
59 | if use_linear_projection:
60 | self.proj_in = nn.Linear(in_channels, inner_dim)
61 | else:
62 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
63 |
64 | # Define transformers blocks
65 | self.transformer_blocks = nn.ModuleList(
66 | [
67 | BasicTransformerBlock(
68 | inner_dim,
69 | num_attention_heads,
70 | attention_head_dim,
71 | dropout=dropout,
72 | cross_attention_dim=cross_attention_dim,
73 | activation_fn=activation_fn,
74 | num_embeds_ada_norm=num_embeds_ada_norm,
75 | attention_bias=attention_bias,
76 | only_cross_attention=only_cross_attention,
77 | upcast_attention=upcast_attention,
78 | )
79 | for d in range(num_layers)
80 | ]
81 | )
82 |
83 | # 4. Define output layers
84 | if use_linear_projection:
85 | self.proj_out = nn.Linear(in_channels, inner_dim)
86 | else:
87 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
88 |
89 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
90 | # Input
91 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
92 | video_length = hidden_states.shape[2]
93 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
94 | # encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
95 |
96 | batch, channel, height, weight = hidden_states.shape
97 | residual = hidden_states
98 |
99 | hidden_states = self.norm(hidden_states)
100 | if not self.use_linear_projection:
101 | hidden_states = self.proj_in(hidden_states)
102 | inner_dim = hidden_states.shape[1]
103 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
104 | else:
105 | inner_dim = hidden_states.shape[1]
106 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
107 | hidden_states = self.proj_in(hidden_states)
108 |
109 | # Blocks
110 | for block in self.transformer_blocks:
111 | hidden_states = block(
112 | hidden_states,
113 | encoder_hidden_states=encoder_hidden_states,
114 | timestep=timestep,
115 | video_length=video_length
116 | )
117 |
118 | # Output
119 | if not self.use_linear_projection:
120 | hidden_states = (
121 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
122 | )
123 | hidden_states = self.proj_out(hidden_states)
124 | else:
125 | hidden_states = self.proj_out(hidden_states)
126 | hidden_states = (
127 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
128 | )
129 |
130 | output = hidden_states + residual
131 |
132 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
133 | if not return_dict:
134 | return (output,)
135 |
136 | return Transformer3DModelOutput(sample=output)
137 |
138 |
139 | class BasicTransformerBlock(nn.Module):
140 | def __init__(
141 | self,
142 | dim: int,
143 | num_attention_heads: int,
144 | attention_head_dim: int,
145 | dropout=0.0,
146 | cross_attention_dim: Optional[int] = None,
147 | activation_fn: str = "geglu",
148 | num_embeds_ada_norm: Optional[int] = None,
149 | attention_bias: bool = False,
150 | only_cross_attention: bool = False,
151 | upcast_attention: bool = False,
152 | ):
153 | super().__init__()
154 | self.only_cross_attention = only_cross_attention
155 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
156 |
157 | # SC-Attn
158 | self.attn1 = SparseCausalAttention(
159 | query_dim=dim,
160 | heads=num_attention_heads,
161 | dim_head=attention_head_dim,
162 | dropout=dropout,
163 | bias=attention_bias,
164 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
165 | upcast_attention=upcast_attention,
166 | )
167 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
168 |
169 | # Cross-Attn
170 | if cross_attention_dim is not None:
171 | self.attn2 = CrossAttention(
172 | query_dim=dim,
173 | cross_attention_dim=cross_attention_dim,
174 | heads=num_attention_heads,
175 | dim_head=attention_head_dim,
176 | dropout=dropout,
177 | bias=attention_bias,
178 | upcast_attention=upcast_attention,
179 | )
180 | else:
181 | self.attn2 = None
182 |
183 | if cross_attention_dim is not None:
184 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
185 | else:
186 | self.norm2 = None
187 |
188 | # Feed-forward
189 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
190 | self.norm3 = nn.LayerNorm(dim)
191 |
192 | # Temp-Attn
193 | self.attn_temp = CrossAttention(
194 | query_dim=dim,
195 | heads=num_attention_heads,
196 | dim_head=attention_head_dim,
197 | dropout=dropout,
198 | bias=attention_bias,
199 | upcast_attention=upcast_attention,
200 | )
201 | nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
202 | self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
203 |
204 | def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
205 | if not is_xformers_available():
206 | print("Here is how to install it")
207 | raise ModuleNotFoundError(
208 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
209 | " xformers",
210 | name="xformers",
211 | )
212 | elif not torch.cuda.is_available():
213 | raise ValueError(
214 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
215 | " available for GPU "
216 | )
217 | else:
218 | try:
219 | # Make sure we can run the memory efficient attention
220 | _ = xformers.ops.memory_efficient_attention(
221 | torch.randn((1, 2, 40), device="cuda"),
222 | torch.randn((1, 2, 40), device="cuda"),
223 | torch.randn((1, 2, 40), device="cuda"),
224 | )
225 | except Exception as e:
226 | raise e
227 | self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
228 | if self.attn2 is not None:
229 | self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
230 | # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
231 |
232 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
233 | # SparseCausal-Attention
234 | norm_hidden_states = (
235 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
236 | )
237 |
238 | if self.only_cross_attention:
239 | hidden_states = (
240 | self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
241 | )
242 | else:
243 | hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask,
244 | video_length=video_length) + hidden_states
245 |
246 | if self.attn2 is not None:
247 | # Cross-Attention
248 | norm_hidden_states = (
249 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
250 | )
251 | # if encoder_hidden_states.shape[0] != norm_hidden_states.shape[0]:
252 | # encoder_hidden_states = repeat(encoder_hidden_states, 'a b c -> (a f) b c', f=video_length)
253 | hidden_states = (
254 | self.attn2(
255 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states,
256 | attention_mask=attention_mask
257 | )
258 | + hidden_states
259 | )
260 |
261 |
262 | # Feed-forward
263 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
264 |
265 | # Temporal-Attention
266 | # d = hidden_states.shape[1]
267 | # hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
268 | # norm_hidden_states = (
269 | # self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
270 | # )
271 | # hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
272 | # hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
273 |
274 | return hidden_states
275 |
276 |
277 | class SparseCausalAttention(CrossAttention):
278 |
279 | # def __init__(self, query_dim: int,
280 | # cross_attention_dim: Optional[int] = None,
281 | # heads: int = 8,
282 | # dim_head: int = 64,
283 | # dropout: float = 0.0,
284 | # bias=False,
285 | # upcast_attention: bool = False,
286 | # upcast_softmax: bool = False,
287 | # added_kv_proj_dim: Optional[int] = None,
288 | # norm_num_groups: Optional[int] = None, ):
289 | # super(SparseCausalAttention, self).__init__(query_dim,
290 | # cross_attention_dim,
291 | # heads,
292 | # dim_head,
293 | # dropout,
294 | # bias,
295 | # upcast_attention,
296 | # upcast_softmax,
297 | # added_kv_proj_dim,
298 | # norm_num_groups)
299 | # inner_dim = dim_head * heads
300 | # self.to_q2 = nn.Linear(query_dim, inner_dim, bias=bias)
301 |
302 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
303 | batch_size, sequence_length, _ = hidden_states.shape
304 |
305 | encoder_hidden_states = encoder_hidden_states
306 |
307 | if self.group_norm is not None:
308 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
309 |
310 | query = self.to_q(hidden_states)
311 | dim = query.shape[-1]
312 | query = self.reshape_heads_to_batch_dim(query)
313 |
314 | if self.added_kv_proj_dim is not None:
315 | raise NotImplementedError
316 |
317 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
318 | key = self.to_k(encoder_hidden_states)
319 | value = self.to_v(encoder_hidden_states)
320 |
321 | former_frame_index = torch.arange(video_length) - 1  # index of each frame's previous frame
322 | former_frame_index[0] = 0  # the first frame attends to itself
323 |
324 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
325 | key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)  # sparse-causal: keys come from the first frame and the previous frame
326 | key = rearrange(key, "b f d c -> (b f) d c")
327 |
328 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
329 | value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)  # values likewise come from the first and previous frames
330 | value = rearrange(value, "b f d c -> (b f) d c")
331 |
332 | key = self.reshape_heads_to_batch_dim(key)
333 | value = self.reshape_heads_to_batch_dim(value)
334 |
335 | if attention_mask is not None:
336 | if attention_mask.shape[-1] != query.shape[1]:
337 | target_length = query.shape[1]
338 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
339 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
340 |
341 | # attention, what we cannot get enough of
342 | if self._use_memory_efficient_attention_xformers:
343 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
344 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
345 | hidden_states = hidden_states.to(query.dtype)
346 | else:
347 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
348 | hidden_states = self._attention(query, key, value, attention_mask)
349 | else:
350 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
351 |
352 | # linear proj
353 | hidden_states = self.to_out[0](hidden_states)
354 |
355 | # dropout
356 | hidden_states = self.to_out[1](hidden_states)
357 | return hidden_states
358 |
--------------------------------------------------------------------------------
/freebloom/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange
8 | from torch import Tensor
9 |
10 |
11 | class InflatedGroupNorm(nn.GroupNorm):
12 | def forward(self, x: Tensor) -> Tensor:
13 | f = x.shape[2]
14 |
15 | x = rearrange(x, 'b c f h w -> (b f) c h w')
16 | x = super().forward(x)
17 | x = rearrange(x, '(b f) c h w -> b c f h w', f=f)
18 |
19 | return x
20 |
21 |
22 | class InflatedConv3d(nn.Conv2d):
23 | def __init__(self, in_channels, out_channels, kernel_size, temporal_kernel_size=None, **kwargs):
24 | super(InflatedConv3d, self).__init__(in_channels, out_channels, kernel_size, **kwargs)
25 | if temporal_kernel_size is None:
26 | temporal_kernel_size = kernel_size
27 |
28 | self.conv_temp = (
29 | nn.Conv1d(
30 | out_channels,
31 | out_channels,
32 | kernel_size=temporal_kernel_size,
33 | padding=temporal_kernel_size // 2,
34 | )
35 | if kernel_size > 1
36 | else None
37 | )
38 |
39 | if self.conv_temp is not None:
40 | nn.init.dirac_(self.conv_temp.weight.data) # initialized to be identity
41 | nn.init.zeros_(self.conv_temp.bias.data)
42 |
43 | def forward(self, x):
44 | b = x.shape[0]
45 |
46 | is_video = x.ndim == 5
47 | if is_video:
48 | x = rearrange(x, "b c f h w -> (b f) c h w")
49 |
50 | x = super().forward(x)
51 |
52 | if is_video:
53 | x = rearrange(x, "(b f) c h w -> b c f h w", b=b)
54 |
55 | if self.conv_temp is None or not is_video:
56 | return x
57 |
58 | *_, h, w = x.shape
59 |
60 | x = rearrange(x, "b c f h w -> (b h w) c f")
61 |
62 | x = self.conv_temp(x)
63 |
64 | x = rearrange(x, "(b h w) c f -> b c f h w", h=h, w=w)
65 |
66 | return x
67 |
68 |
69 | class Upsample3D(nn.Module):
70 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
71 | super().__init__()
72 | self.channels = channels
73 | self.out_channels = out_channels or channels
74 | self.use_conv = use_conv
75 | self.use_conv_transpose = use_conv_transpose
76 | self.name = name
77 |
78 | conv = None
79 | if use_conv_transpose:
80 | raise NotImplementedError
81 | elif use_conv:
82 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
83 |
84 | if name == "conv":
85 | self.conv = conv
86 | else:
87 | self.Conv2d_0 = conv
88 |
89 | def forward(self, hidden_states, output_size=None):
90 | assert hidden_states.shape[1] == self.channels
91 |
92 | if self.use_conv_transpose:
93 | raise NotImplementedError
94 |
95 | # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
96 | dtype = hidden_states.dtype
97 | if dtype == torch.bfloat16:
98 | hidden_states = hidden_states.to(torch.float32)
99 |
100 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
101 | if hidden_states.shape[0] >= 64:
102 | hidden_states = hidden_states.contiguous()
103 |
104 | # if `output_size` is passed we force the interpolation output
105 | # size and do not make use of `scale_factor=2`
106 | if output_size is None:
107 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
108 | else:
109 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
110 |
111 | # If the input is bfloat16, we cast back to bfloat16
112 | if dtype == torch.bfloat16:
113 | hidden_states = hidden_states.to(dtype)
114 |
115 | if self.use_conv:
116 | if self.name == "conv":
117 | hidden_states = self.conv(hidden_states)
118 | else:
119 | hidden_states = self.Conv2d_0(hidden_states)
120 |
121 | return hidden_states
122 |
123 |
124 | class Downsample3D(nn.Module):
125 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
126 | super().__init__()
127 | self.channels = channels
128 | self.out_channels = out_channels or channels
129 | self.use_conv = use_conv
130 | self.padding = padding
131 | stride = 2
132 | self.name = name
133 |
134 | if use_conv:
135 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
136 | else:
137 | raise NotImplementedError
138 |
139 | if name == "conv":
140 | self.Conv2d_0 = conv
141 | self.conv = conv
142 | elif name == "Conv2d_0":
143 | self.conv = conv
144 | else:
145 | self.conv = conv
146 |
147 | def forward(self, hidden_states):
148 | assert hidden_states.shape[1] == self.channels
149 | if self.use_conv and self.padding == 0:
150 | raise NotImplementedError
151 |
152 | assert hidden_states.shape[1] == self.channels
153 | hidden_states = self.conv(hidden_states)
154 |
155 | return hidden_states
156 |
157 |
158 | class ResnetBlock3D(nn.Module):
159 | def __init__(
160 | self,
161 | *,
162 | in_channels,
163 | out_channels=None,
164 | conv_shortcut=False,
165 | dropout=0.0,
166 | temb_channels=512,
167 | groups=32,
168 | groups_out=None,
169 | pre_norm=True,
170 | eps=1e-6,
171 | non_linearity="swish",
172 | time_embedding_norm="default",
173 | output_scale_factor=1.0,
174 | use_in_shortcut=None,
175 | ):
176 | super().__init__()
177 | self.pre_norm = pre_norm
178 | self.pre_norm = True
179 | self.in_channels = in_channels
180 | out_channels = in_channels if out_channels is None else out_channels
181 | self.out_channels = out_channels
182 | self.use_conv_shortcut = conv_shortcut
183 | self.time_embedding_norm = time_embedding_norm
184 | self.output_scale_factor = output_scale_factor
185 |
186 | if groups_out is None:
187 | groups_out = groups
188 |
189 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
190 | # self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
191 |
192 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
193 |
194 | if temb_channels is not None:
195 | if self.time_embedding_norm == "default":
196 | time_emb_proj_out_channels = out_channels
197 | elif self.time_embedding_norm == "scale_shift":
198 | time_emb_proj_out_channels = out_channels * 2
199 | else:
200 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
201 |
202 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
203 | else:
204 | self.time_emb_proj = None
205 |
206 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
207 | # self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
208 |
209 | self.dropout = torch.nn.Dropout(dropout)
210 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
211 |
212 | if non_linearity == "swish":
213 | self.nonlinearity = lambda x: F.silu(x)
214 | elif non_linearity == "mish":
215 | self.nonlinearity = Mish()
216 | elif non_linearity == "silu":
217 | self.nonlinearity = nn.SiLU()
218 |
219 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
220 |
221 | self.conv_shortcut = None
222 | if self.use_in_shortcut:
223 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
224 |
225 | def forward(self, input_tensor, temb):
226 | hidden_states = input_tensor # [b,c,f,h.w]
227 |
228 | hidden_states = self.norm1(hidden_states)
229 | hidden_states = self.nonlinearity(hidden_states)
230 |
231 | hidden_states = self.conv1(hidden_states)
232 |
233 | if temb is not None:
234 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
235 |
236 | if temb is not None and self.time_embedding_norm == "default":
237 | hidden_states = hidden_states + temb
238 |
239 | hidden_states = self.norm2(hidden_states)
240 |
241 | if temb is not None and self.time_embedding_norm == "scale_shift":
242 | scale, shift = torch.chunk(temb, 2, dim=1)
243 | hidden_states = hidden_states * (1 + scale) + shift
244 |
245 | hidden_states = self.nonlinearity(hidden_states)
246 |
247 | hidden_states = self.dropout(hidden_states)
248 | hidden_states = self.conv2(hidden_states)
249 |
250 | if self.conv_shortcut is not None:
251 | input_tensor = self.conv_shortcut(input_tensor)
252 |
253 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
254 |
255 | return output_tensor
256 |
257 |
258 | class Mish(torch.nn.Module):
259 | def forward(self, hidden_states):
260 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
261 |
--------------------------------------------------------------------------------
/freebloom/models/unet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import os
7 | import json
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.utils.checkpoint
12 |
13 | from diffusers.configuration_utils import ConfigMixin, register_to_config
14 | from diffusers.modeling_utils import ModelMixin
15 | from diffusers.utils import BaseOutput, logging
16 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17 | from .unet_blocks import (
18 | CrossAttnDownBlock3D,
19 | CrossAttnUpBlock3D,
20 | DownBlock3D,
21 | UNetMidBlock3DCrossAttn,
22 | UpBlock3D,
23 | get_down_block,
24 | get_up_block,
25 | )
26 | from .resnet import InflatedConv3d, InflatedGroupNorm
27 |
28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29 |
30 |
31 | @dataclass
32 | class UNet3DConditionOutput(BaseOutput):
33 | sample: torch.FloatTensor
34 |
35 |
36 | class UNet3DConditionModel(ModelMixin, ConfigMixin):
37 | _supports_gradient_checkpointing = True
38 |
39 | @register_to_config
40 | def __init__(
41 | self,
42 | sample_size: Optional[int] = None,
43 | in_channels: int = 4,
44 | out_channels: int = 4,
45 | center_input_sample: bool = False,
46 | flip_sin_to_cos: bool = True,
47 | freq_shift: int = 0,
48 | down_block_types: Tuple[str] = (
49 | "CrossAttnDownBlock3D",
50 | "CrossAttnDownBlock3D",
51 | "CrossAttnDownBlock3D",
52 | "DownBlock3D",
53 | ),
54 | mid_block_type: str = "UNetMidBlock3DCrossAttn",
55 | up_block_types: Tuple[str] = (
56 | "UpBlock3D",
57 | "CrossAttnUpBlock3D",
58 | "CrossAttnUpBlock3D",
59 | "CrossAttnUpBlock3D"
60 | ),
61 | only_cross_attention: Union[bool, Tuple[bool]] = False,
62 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
63 | layers_per_block: int = 2,
64 | downsample_padding: int = 1,
65 | mid_block_scale_factor: float = 1,
66 | act_fn: str = "silu",
67 | norm_num_groups: int = 32,
68 | norm_eps: float = 1e-5,
69 | cross_attention_dim: int = 1280,
70 | attention_head_dim: Union[int, Tuple[int]] = 8,
71 | dual_cross_attention: bool = False,
72 | use_linear_projection: bool = False,
73 | class_embed_type: Optional[str] = None,
74 | num_class_embeds: Optional[int] = None,
75 | upcast_attention: bool = False,
76 | resnet_time_scale_shift: str = "default",
77 | ):
78 | super().__init__()
79 |
80 | self.sample_size = sample_size
81 | time_embed_dim = block_out_channels[0] * 4
82 |
83 | # input
84 | self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
85 |
86 | # time
87 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
88 | timestep_input_dim = block_out_channels[0]
89 |
90 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
91 |
92 | # class embedding
93 | if class_embed_type is None and num_class_embeds is not None:
94 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
95 | elif class_embed_type == "timestep":
96 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
97 | elif class_embed_type == "identity":
98 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
99 | else:
100 | self.class_embedding = None
101 |
102 | self.down_blocks = nn.ModuleList([])
103 | self.mid_block = None
104 | self.up_blocks = nn.ModuleList([])
105 |
106 | if isinstance(only_cross_attention, bool):
107 | only_cross_attention = [only_cross_attention] * len(down_block_types)
108 |
109 | if isinstance(attention_head_dim, int):
110 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
111 |
112 | # down
113 | output_channel = block_out_channels[0]
114 | for i, down_block_type in enumerate(down_block_types):
115 | input_channel = output_channel
116 | output_channel = block_out_channels[i]
117 | is_final_block = i == len(block_out_channels) - 1
118 |
119 | down_block = get_down_block(
120 | down_block_type,
121 | num_layers=layers_per_block,
122 | in_channels=input_channel,
123 | out_channels=output_channel,
124 | temb_channels=time_embed_dim,
125 | add_downsample=not is_final_block,
126 | resnet_eps=norm_eps,
127 | resnet_act_fn=act_fn,
128 | resnet_groups=norm_num_groups,
129 | cross_attention_dim=cross_attention_dim,
130 | attn_num_head_channels=attention_head_dim[i],
131 | downsample_padding=downsample_padding,
132 | dual_cross_attention=dual_cross_attention,
133 | use_linear_projection=use_linear_projection,
134 | only_cross_attention=only_cross_attention[i],
135 | upcast_attention=upcast_attention,
136 | resnet_time_scale_shift=resnet_time_scale_shift,
137 | )
138 | self.down_blocks.append(down_block)
139 |
140 | # mid
141 | if mid_block_type == "UNetMidBlock3DCrossAttn":
142 | self.mid_block = UNetMidBlock3DCrossAttn(
143 | in_channels=block_out_channels[-1],
144 | temb_channels=time_embed_dim,
145 | resnet_eps=norm_eps,
146 | resnet_act_fn=act_fn,
147 | output_scale_factor=mid_block_scale_factor,
148 | resnet_time_scale_shift=resnet_time_scale_shift,
149 | cross_attention_dim=cross_attention_dim,
150 | attn_num_head_channels=attention_head_dim[-1],
151 | resnet_groups=norm_num_groups,
152 | dual_cross_attention=dual_cross_attention,
153 | use_linear_projection=use_linear_projection,
154 | upcast_attention=upcast_attention,
155 | )
156 | else:
157 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
158 |
159 | # count how many layers upsample the videos
160 | self.num_upsamplers = 0
161 |
162 | # up
163 | reversed_block_out_channels = list(reversed(block_out_channels))
164 | reversed_attention_head_dim = list(reversed(attention_head_dim))
165 | only_cross_attention = list(reversed(only_cross_attention))
166 | output_channel = reversed_block_out_channels[0]
167 | for i, up_block_type in enumerate(up_block_types):
168 | is_final_block = i == len(block_out_channels) - 1
169 |
170 | prev_output_channel = output_channel
171 | output_channel = reversed_block_out_channels[i]
172 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
173 |
174 | # add upsample block for all BUT final layer
175 | if not is_final_block:
176 | add_upsample = True
177 | self.num_upsamplers += 1
178 | else:
179 | add_upsample = False
180 |
181 | up_block = get_up_block(
182 | up_block_type,
183 | num_layers=layers_per_block + 1,
184 | in_channels=input_channel,
185 | out_channels=output_channel,
186 | prev_output_channel=prev_output_channel,
187 | temb_channels=time_embed_dim,
188 | add_upsample=add_upsample,
189 | resnet_eps=norm_eps,
190 | resnet_act_fn=act_fn,
191 | resnet_groups=norm_num_groups,
192 | cross_attention_dim=cross_attention_dim,
193 | attn_num_head_channels=reversed_attention_head_dim[i],
194 | dual_cross_attention=dual_cross_attention,
195 | use_linear_projection=use_linear_projection,
196 | only_cross_attention=only_cross_attention[i],
197 | upcast_attention=upcast_attention,
198 | resnet_time_scale_shift=resnet_time_scale_shift,
199 | )
200 | self.up_blocks.append(up_block)
201 | prev_output_channel = output_channel
202 |
203 | # out
204 | self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups,
205 | eps=norm_eps)
206 | # self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
207 |
208 | self.conv_act = nn.SiLU()
209 | self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
210 |
211 | def set_attention_slice(self, slice_size):
212 | r"""
213 | Enable sliced attention computation.
214 |
215 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention
216 | in several steps. This is useful to save some memory in exchange for a small speed decrease.
217 |
218 | Args:
219 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
220 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
221 | `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
222 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
223 | must be a multiple of `slice_size`.
224 | """
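# Illustrative usage (not from the original source): given a constructed UNet3DConditionModel `unet`,
#   unet.set_attention_slice("auto")   # halve the input to each attention head, per the docstring above
#   unet.set_attention_slice("max")    # run one slice at a time for maximum memory savings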
225 | sliceable_head_dims = []
226 |
227 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
228 | if hasattr(module, "set_attention_slice"):
229 | sliceable_head_dims.append(module.sliceable_head_dim)
230 |
231 | for child in module.children():
232 | fn_recursive_retrieve_slicable_dims(child)
233 |
234 | # retrieve number of attention layers
235 | for module in self.children():
236 | fn_recursive_retrieve_slicable_dims(module)
237 |
238 | num_slicable_layers = len(sliceable_head_dims)
239 |
240 | if slice_size == "auto":
241 | # half the attention head size is usually a good trade-off between
242 | # speed and memory
243 | slice_size = [dim // 2 for dim in sliceable_head_dims]
244 | elif slice_size == "max":
245 | # make smallest slice possible
246 | slice_size = num_slicable_layers * [1]
247 |
248 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
249 |
250 | if len(slice_size) != len(sliceable_head_dims):
251 | raise ValueError(
252 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
253 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
254 | )
255 |
256 | for i in range(len(slice_size)):
257 | size = slice_size[i]
258 | dim = sliceable_head_dims[i]
259 | if size is not None and size > dim:
260 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
261 |
262 | # Recursively walk through all the children.
263 |         # Any child that exposes the set_attention_slice method
264 |         # receives the message.
265 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
266 | if hasattr(module, "set_attention_slice"):
267 | module.set_attention_slice(slice_size.pop())
268 |
269 | for child in module.children():
270 | fn_recursive_set_attention_slice(child, slice_size)
271 |
272 | reversed_slice_size = list(reversed(slice_size))
273 | for module in self.children():
274 | fn_recursive_set_attention_slice(module, reversed_slice_size)
275 |
276 | def _set_gradient_checkpointing(self, module, value=False):
277 | if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
278 | module.gradient_checkpointing = value
279 |
280 | def forward(
281 | self,
282 | sample: torch.FloatTensor,
283 | timestep: Union[torch.Tensor, float, int],
284 | encoder_hidden_states: torch.Tensor,
285 | class_labels: Optional[torch.Tensor] = None,
286 | attention_mask: Optional[torch.Tensor] = None,
287 | return_dict: bool = True,
288 | ) -> Union[UNet3DConditionOutput, Tuple]:
289 | r"""
290 | Args:
291 |             sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy input tensor
292 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
293 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
294 | return_dict (`bool`, *optional*, defaults to `True`):
295 |                 Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
296 |
297 | Returns:
298 |             [`UNet3DConditionOutput`] or `tuple`:
299 |             [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
300 | returning a tuple, the first element is the sample tensor.
301 | """
302 |         # By default, samples have to be at least a multiple of the overall upsampling factor.
303 |         # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
304 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
305 | # on the fly if necessary.
306 | default_overall_up_factor = 2 ** self.num_upsamplers
307 |
308 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
309 | forward_upsample_size = False
310 | upsample_size = None
311 |
312 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
313 | logger.info("Forward upsample size to force interpolation output size.")
314 | forward_upsample_size = True
315 |
316 | # prepare attention_mask
317 | if attention_mask is not None:
318 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
319 | attention_mask = attention_mask.unsqueeze(1)
320 |
321 | # center input if necessary
322 | if self.config.center_input_sample:
323 | sample = 2 * sample - 1.0
324 |
325 | # time
326 | timesteps = timestep
327 | if not torch.is_tensor(timesteps):
328 | # This would be a good case for the `match` statement (Python 3.10+)
329 | is_mps = sample.device.type == "mps"
330 | if isinstance(timestep, float):
331 | dtype = torch.float32 if is_mps else torch.float64
332 | else:
333 | dtype = torch.int32 if is_mps else torch.int64
334 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
335 | elif len(timesteps.shape) == 0:
336 | timesteps = timesteps[None].to(sample.device)
337 |
338 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
339 | timesteps = timesteps.expand(sample.shape[0])
340 |
341 | t_emb = self.time_proj(timesteps)
342 |
343 | # timesteps does not contain any weights and will always return f32 tensors
344 | # but time_embedding might actually be running in fp16. so we need to cast here.
345 | # there might be better ways to encapsulate this.
346 | t_emb = t_emb.to(dtype=self.dtype)
347 | emb = self.time_embedding(t_emb)
348 |
349 | if self.class_embedding is not None:
350 | if class_labels is None:
351 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
352 |
353 | if self.config.class_embed_type == "timestep":
354 | class_labels = self.time_proj(class_labels)
355 |
356 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
357 | emb = emb + class_emb
358 |
359 | # pre-process
360 | sample = self.conv_in(sample)
361 |
362 | # down
363 | down_block_res_samples = (sample,)
364 | for downsample_block in self.down_blocks:
365 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
366 | sample, res_samples = downsample_block(
367 | hidden_states=sample,
368 | temb=emb,
369 | encoder_hidden_states=encoder_hidden_states,
370 | attention_mask=attention_mask,
371 | )
372 | else:
373 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
374 |
375 | down_block_res_samples += res_samples
376 |
377 | # mid
378 | sample = self.mid_block(
379 | sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
380 | )
381 | # if sample.shape[2] == 7 and timestep > 500:
382 | # for i in range(5):
383 | # lambda_1 = (5 - i) / 5
384 | # lambda_2 = i / 5
385 | # sample[:, :, i + 1] = lambda_1 * sample[:, :, 0] + lambda_2 * sample[:, :, -1]
386 | # sample[:, :, 1] = 0.5 * sample[:, :, 0] + 0.5 * sample[:, :, 2]
387 |
388 | # up
389 | for i, upsample_block in enumerate(self.up_blocks):
390 | is_final_block = i == len(self.up_blocks) - 1
391 |
392 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
393 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
394 |
395 | # if we have not reached the final block and need to forward the
396 | # upsample size, we do it here
397 | if not is_final_block and forward_upsample_size:
398 | upsample_size = down_block_res_samples[-1].shape[2:]
399 |
400 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
401 | sample = upsample_block(
402 | hidden_states=sample,
403 | temb=emb,
404 | res_hidden_states_tuple=res_samples,
405 | encoder_hidden_states=encoder_hidden_states,
406 | upsample_size=upsample_size,
407 | attention_mask=attention_mask,
408 | )
409 | else:
410 | sample = upsample_block(
411 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
412 | )
413 | # post-process
414 | sample = self.conv_norm_out(sample)
415 | sample = self.conv_act(sample)
416 | sample = self.conv_out(sample)
417 |
418 | if not return_dict:
419 | return (sample,)
420 |
421 | return UNet3DConditionOutput(sample=sample)
422 |
423 | @classmethod
424 | def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
425 | if subfolder is not None:
426 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
427 |
428 | config_file = os.path.join(pretrained_model_path, 'config.json')
429 | if not os.path.isfile(config_file):
430 | raise RuntimeError(f"{config_file} does not exist")
431 | with open(config_file, "r") as f:
432 | config = json.load(f)
433 | config["_class_name"] = cls.__name__
434 | config["down_block_types"] = [
435 | "CrossAttnDownBlock3D",
436 | "CrossAttnDownBlock3D",
437 | "CrossAttnDownBlock3D",
438 | "DownBlock3D"
439 | ]
440 | config["up_block_types"] = [
441 | "UpBlock3D",
442 | "CrossAttnUpBlock3D",
443 | "CrossAttnUpBlock3D",
444 | "CrossAttnUpBlock3D"
445 | ]
446 |
447 | config['mid_block_type'] = 'UNetMidBlock3DCrossAttn'
448 |
449 | from diffusers.utils import WEIGHTS_NAME
450 | model = cls.from_config(config)
451 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
452 | if not os.path.isfile(model_file):
453 | raise RuntimeError(f"{model_file} does not exist")
454 | state_dict = torch.load(model_file, map_location="cpu")
455 | # origin_state_dict = torch.load('/root/code/Tune-A-Video/checkpoints/stable-diffusion-v1-4/unet/diffusion_pytorch_model.bin', map_location='cpu')
456 | for k, v in model.state_dict().items():
457 | if '_temp.' in k:
458 | state_dict.update({k: v})
459 |
460 | # for k, v in origin_state_dict.items():
461 | # if '.to_q' in k and 'attn1' in k:
462 | # state_dict.update({k.replace('to_q', 'to_q2'): v})
463 |
464 | model.load_state_dict(state_dict)
465 |
466 | return model
467 |
--------------------------------------------------------------------------------
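
For orientation, the following is a minimal usage sketch (not part of the repository) of how `UNet3DConditionModel.from_pretrained_2d` is typically driven: it reads the 2D Stable Diffusion UNet config and weights from a local checkpoint directory, swaps the block types for their 3D counterparts, and leaves the newly added `_temp.` parameters at their random initialization. The checkpoint path, device, and tensor shapes below are illustrative assumptions.

import torch

from freebloom.models.unet import UNet3DConditionModel

# Hypothetical local Stable Diffusion 1.x checkpoint (contains unet/config.json and
# unet/diffusion_pytorch_model.bin); the path is a placeholder.
unet = UNet3DConditionModel.from_pretrained_2d("checkpoints/stable-diffusion-v1-4", subfolder="unet")
unet = unet.to("cuda").eval()

# forward() consumes 5-D video latents: (batch, channel, frames, height, width).
sample = torch.randn(1, 4, 6, 64, 64, device="cuda")
timestep = torch.tensor([500], device="cuda")
text_states = torch.randn(1, 77, 768, device="cuda")  # CLIP text hidden states (SD 1.x: dim 768)

with torch.no_grad():
    out = unet(sample, timestep, text_states).sample  # same 5-D shape as `sample`
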
/freebloom/models/unet_blocks.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from .attention import Transformer3DModel
7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8 |
9 |
10 | def get_down_block(
11 | down_block_type,
12 | num_layers,
13 | in_channels,
14 | out_channels,
15 | temb_channels,
16 | add_downsample,
17 | resnet_eps,
18 | resnet_act_fn,
19 | attn_num_head_channels,
20 | resnet_groups=None,
21 | cross_attention_dim=None,
22 | downsample_padding=None,
23 | dual_cross_attention=False,
24 | use_linear_projection=False,
25 | only_cross_attention=False,
26 | upcast_attention=False,
27 | resnet_time_scale_shift="default",
28 | ):
29 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
30 | if down_block_type == "DownBlock3D":
31 | return DownBlock3D(
32 | num_layers=num_layers,
33 | in_channels=in_channels,
34 | out_channels=out_channels,
35 | temb_channels=temb_channels,
36 | add_downsample=add_downsample,
37 | resnet_eps=resnet_eps,
38 | resnet_act_fn=resnet_act_fn,
39 | resnet_groups=resnet_groups,
40 | downsample_padding=downsample_padding,
41 | resnet_time_scale_shift=resnet_time_scale_shift,
42 | )
43 | elif down_block_type == "CrossAttnDownBlock3D":
44 | if cross_attention_dim is None:
45 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
46 | return CrossAttnDownBlock3D(
47 | num_layers=num_layers,
48 | in_channels=in_channels,
49 | out_channels=out_channels,
50 | temb_channels=temb_channels,
51 | add_downsample=add_downsample,
52 | resnet_eps=resnet_eps,
53 | resnet_act_fn=resnet_act_fn,
54 | resnet_groups=resnet_groups,
55 | downsample_padding=downsample_padding,
56 | cross_attention_dim=cross_attention_dim,
57 | attn_num_head_channels=attn_num_head_channels,
58 | dual_cross_attention=dual_cross_attention,
59 | use_linear_projection=use_linear_projection,
60 | only_cross_attention=only_cross_attention,
61 | upcast_attention=upcast_attention,
62 | resnet_time_scale_shift=resnet_time_scale_shift,
63 | )
64 | raise ValueError(f"{down_block_type} does not exist.")
65 |
66 |
67 | def get_up_block(
68 | up_block_type,
69 | num_layers,
70 | in_channels,
71 | out_channels,
72 | prev_output_channel,
73 | temb_channels,
74 | add_upsample,
75 | resnet_eps,
76 | resnet_act_fn,
77 | attn_num_head_channels,
78 | resnet_groups=None,
79 | cross_attention_dim=None,
80 | dual_cross_attention=False,
81 | use_linear_projection=False,
82 | only_cross_attention=False,
83 | upcast_attention=False,
84 | resnet_time_scale_shift="default",
85 | ):
86 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
87 | if up_block_type == "UpBlock3D":
88 | return UpBlock3D(
89 | num_layers=num_layers,
90 | in_channels=in_channels,
91 | out_channels=out_channels,
92 | prev_output_channel=prev_output_channel,
93 | temb_channels=temb_channels,
94 | add_upsample=add_upsample,
95 | resnet_eps=resnet_eps,
96 | resnet_act_fn=resnet_act_fn,
97 | resnet_groups=resnet_groups,
98 | resnet_time_scale_shift=resnet_time_scale_shift,
99 | )
100 | elif up_block_type == "CrossAttnUpBlock3D":
101 | if cross_attention_dim is None:
102 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
103 | return CrossAttnUpBlock3D(
104 | num_layers=num_layers,
105 | in_channels=in_channels,
106 | out_channels=out_channels,
107 | prev_output_channel=prev_output_channel,
108 | temb_channels=temb_channels,
109 | add_upsample=add_upsample,
110 | resnet_eps=resnet_eps,
111 | resnet_act_fn=resnet_act_fn,
112 | resnet_groups=resnet_groups,
113 | cross_attention_dim=cross_attention_dim,
114 | attn_num_head_channels=attn_num_head_channels,
115 | dual_cross_attention=dual_cross_attention,
116 | use_linear_projection=use_linear_projection,
117 | only_cross_attention=only_cross_attention,
118 | upcast_attention=upcast_attention,
119 | resnet_time_scale_shift=resnet_time_scale_shift,
120 | )
121 | raise ValueError(f"{up_block_type} does not exist.")
122 |
123 |
124 | class UNetMidBlock3DCrossAttn(nn.Module):
125 | def __init__(
126 | self,
127 | in_channels: int,
128 | temb_channels: int,
129 | dropout: float = 0.0,
130 | num_layers: int = 1,
131 | resnet_eps: float = 1e-6,
132 | resnet_time_scale_shift: str = "default",
133 | resnet_act_fn: str = "swish",
134 | resnet_groups: int = 32,
135 | resnet_pre_norm: bool = True,
136 | attn_num_head_channels=1,
137 | output_scale_factor=1.0,
138 | cross_attention_dim=1280,
139 | dual_cross_attention=False,
140 | use_linear_projection=False,
141 | upcast_attention=False,
142 | ):
143 | super().__init__()
144 |
145 | self.has_cross_attention = True
146 | self.attn_num_head_channels = attn_num_head_channels
147 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
148 |
149 | # there is always at least one resnet
150 | resnets = [
151 | ResnetBlock3D(
152 | in_channels=in_channels,
153 | out_channels=in_channels,
154 | temb_channels=temb_channels,
155 | eps=resnet_eps,
156 | groups=resnet_groups,
157 | dropout=dropout,
158 | time_embedding_norm=resnet_time_scale_shift,
159 | non_linearity=resnet_act_fn,
160 | output_scale_factor=output_scale_factor,
161 | pre_norm=resnet_pre_norm,
162 | )
163 | ]
164 | attentions = []
165 |
166 | for _ in range(num_layers):
167 | if dual_cross_attention:
168 | raise NotImplementedError
169 | attentions.append(
170 | Transformer3DModel(
171 | attn_num_head_channels,
172 | in_channels // attn_num_head_channels,
173 | in_channels=in_channels,
174 | num_layers=1,
175 | cross_attention_dim=cross_attention_dim,
176 | norm_num_groups=resnet_groups,
177 | use_linear_projection=use_linear_projection,
178 | upcast_attention=upcast_attention,
179 | )
180 | )
181 | resnets.append(
182 | ResnetBlock3D(
183 | in_channels=in_channels,
184 | out_channels=in_channels,
185 | temb_channels=temb_channels,
186 | eps=resnet_eps,
187 | groups=resnet_groups,
188 | dropout=dropout,
189 | time_embedding_norm=resnet_time_scale_shift,
190 | non_linearity=resnet_act_fn,
191 | output_scale_factor=output_scale_factor,
192 | pre_norm=resnet_pre_norm,
193 | )
194 | )
195 |
196 | self.attentions = nn.ModuleList(attentions)
197 | self.resnets = nn.ModuleList(resnets)
198 |
199 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
200 | hidden_states = self.resnets[0](hidden_states, temb)
201 | for attn, resnet in zip(self.attentions, self.resnets[1:]):
202 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
203 | hidden_states = resnet(hidden_states, temb)
204 |
205 | return hidden_states
206 |
207 |
208 | class CrossAttnDownBlock3D(nn.Module):
209 | def __init__(
210 | self,
211 | in_channels: int,
212 | out_channels: int,
213 | temb_channels: int,
214 | dropout: float = 0.0,
215 | num_layers: int = 1,
216 | resnet_eps: float = 1e-6,
217 | resnet_time_scale_shift: str = "default",
218 | resnet_act_fn: str = "swish",
219 | resnet_groups: int = 32,
220 | resnet_pre_norm: bool = True,
221 | attn_num_head_channels=1,
222 | cross_attention_dim=1280,
223 | output_scale_factor=1.0,
224 | downsample_padding=1,
225 | add_downsample=True,
226 | dual_cross_attention=False,
227 | use_linear_projection=False,
228 | only_cross_attention=False,
229 | upcast_attention=False,
230 | ):
231 | super().__init__()
232 | resnets = []
233 | attentions = []
234 |
235 | self.has_cross_attention = True
236 | self.attn_num_head_channels = attn_num_head_channels
237 |
238 | for i in range(num_layers):
239 | in_channels = in_channels if i == 0 else out_channels
240 | resnets.append(
241 | ResnetBlock3D(
242 | in_channels=in_channels,
243 | out_channels=out_channels,
244 | temb_channels=temb_channels,
245 | eps=resnet_eps,
246 | groups=resnet_groups,
247 | dropout=dropout,
248 | time_embedding_norm=resnet_time_scale_shift,
249 | non_linearity=resnet_act_fn,
250 | output_scale_factor=output_scale_factor,
251 | pre_norm=resnet_pre_norm,
252 | )
253 | )
254 | if dual_cross_attention:
255 | raise NotImplementedError
256 | attentions.append(
257 | Transformer3DModel(
258 | attn_num_head_channels,
259 | out_channels // attn_num_head_channels,
260 | in_channels=out_channels,
261 | num_layers=1,
262 | cross_attention_dim=cross_attention_dim,
263 | norm_num_groups=resnet_groups,
264 | use_linear_projection=use_linear_projection,
265 | only_cross_attention=only_cross_attention,
266 | upcast_attention=upcast_attention,
267 | )
268 | )
269 | self.attentions = nn.ModuleList(attentions)
270 | self.resnets = nn.ModuleList(resnets)
271 |
272 | if add_downsample:
273 | self.downsamplers = nn.ModuleList(
274 | [
275 | Downsample3D(
276 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
277 | )
278 | ]
279 | )
280 | else:
281 | self.downsamplers = None
282 |
283 | self.gradient_checkpointing = False
284 |
285 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
286 | output_states = ()
287 |
288 | for resnet, attn in zip(self.resnets, self.attentions):
289 | if self.training and self.gradient_checkpointing:
290 |
291 | def create_custom_forward(module, return_dict=None):
292 | def custom_forward(*inputs):
293 | if return_dict is not None:
294 | return module(*inputs, return_dict=return_dict)
295 | else:
296 | return module(*inputs)
297 |
298 | return custom_forward
299 |
300 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
301 | hidden_states = torch.utils.checkpoint.checkpoint(
302 | create_custom_forward(attn, return_dict=False),
303 | hidden_states,
304 | encoder_hidden_states,
305 | )[0]
306 | else:
307 | hidden_states = resnet(hidden_states, temb)
308 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
309 |
310 | output_states += (hidden_states,)
311 |
312 | if self.downsamplers is not None:
313 | for downsampler in self.downsamplers:
314 | hidden_states = downsampler(hidden_states)
315 |
316 | output_states += (hidden_states,)
317 |
318 | return hidden_states, output_states
319 |
320 |
321 | class DownBlock3D(nn.Module):
322 | def __init__(
323 | self,
324 | in_channels: int,
325 | out_channels: int,
326 | temb_channels: int,
327 | dropout: float = 0.0,
328 | num_layers: int = 1,
329 | resnet_eps: float = 1e-6,
330 | resnet_time_scale_shift: str = "default",
331 | resnet_act_fn: str = "swish",
332 | resnet_groups: int = 32,
333 | resnet_pre_norm: bool = True,
334 | output_scale_factor=1.0,
335 | add_downsample=True,
336 | downsample_padding=1,
337 | ):
338 | super().__init__()
339 | resnets = []
340 |
341 | for i in range(num_layers):
342 | in_channels = in_channels if i == 0 else out_channels
343 | resnets.append(
344 | ResnetBlock3D(
345 | in_channels=in_channels,
346 | out_channels=out_channels,
347 | temb_channels=temb_channels,
348 | eps=resnet_eps,
349 | groups=resnet_groups,
350 | dropout=dropout,
351 | time_embedding_norm=resnet_time_scale_shift,
352 | non_linearity=resnet_act_fn,
353 | output_scale_factor=output_scale_factor,
354 | pre_norm=resnet_pre_norm,
355 | )
356 | )
357 |
358 | self.resnets = nn.ModuleList(resnets)
359 |
360 | if add_downsample:
361 | self.downsamplers = nn.ModuleList(
362 | [
363 | Downsample3D(
364 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
365 | )
366 | ]
367 | )
368 | else:
369 | self.downsamplers = None
370 |
371 | self.gradient_checkpointing = False
372 |
373 | def forward(self, hidden_states, temb=None):
374 | output_states = ()
375 |
376 | for resnet in self.resnets:
377 | if self.training and self.gradient_checkpointing:
378 |
379 | def create_custom_forward(module):
380 | def custom_forward(*inputs):
381 | return module(*inputs)
382 |
383 | return custom_forward
384 |
385 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
386 | else:
387 | hidden_states = resnet(hidden_states, temb)
388 |
389 | output_states += (hidden_states,)
390 |
391 | if self.downsamplers is not None:
392 | for downsampler in self.downsamplers:
393 | hidden_states = downsampler(hidden_states)
394 |
395 | output_states += (hidden_states,)
396 |
397 | return hidden_states, output_states
398 |
399 |
400 | class CrossAttnUpBlock3D(nn.Module):
401 | def __init__(
402 | self,
403 | in_channels: int,
404 | out_channels: int,
405 | prev_output_channel: int,
406 | temb_channels: int,
407 | dropout: float = 0.0,
408 | num_layers: int = 1,
409 | resnet_eps: float = 1e-6,
410 | resnet_time_scale_shift: str = "default",
411 | resnet_act_fn: str = "swish",
412 | resnet_groups: int = 32,
413 | resnet_pre_norm: bool = True,
414 | attn_num_head_channels=1,
415 | cross_attention_dim=1280,
416 | output_scale_factor=1.0,
417 | add_upsample=True,
418 | dual_cross_attention=False,
419 | use_linear_projection=False,
420 | only_cross_attention=False,
421 | upcast_attention=False,
422 | ):
423 | super().__init__()
424 | resnets = []
425 | attentions = []
426 |
427 | self.has_cross_attention = True
428 | self.attn_num_head_channels = attn_num_head_channels
429 |
430 | for i in range(num_layers):
431 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
432 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
433 |
434 | resnets.append(
435 | ResnetBlock3D(
436 | in_channels=resnet_in_channels + res_skip_channels,
437 | out_channels=out_channels,
438 | temb_channels=temb_channels,
439 | eps=resnet_eps,
440 | groups=resnet_groups,
441 | dropout=dropout,
442 | time_embedding_norm=resnet_time_scale_shift,
443 | non_linearity=resnet_act_fn,
444 | output_scale_factor=output_scale_factor,
445 | pre_norm=resnet_pre_norm,
446 | )
447 | )
448 | if dual_cross_attention:
449 | raise NotImplementedError
450 | attentions.append(
451 | Transformer3DModel(
452 | attn_num_head_channels,
453 | out_channels // attn_num_head_channels,
454 | in_channels=out_channels,
455 | num_layers=1,
456 | cross_attention_dim=cross_attention_dim,
457 | norm_num_groups=resnet_groups,
458 | use_linear_projection=use_linear_projection,
459 | only_cross_attention=only_cross_attention,
460 | upcast_attention=upcast_attention,
461 | )
462 | )
463 |
464 | self.attentions = nn.ModuleList(attentions)
465 | self.resnets = nn.ModuleList(resnets)
466 |
467 | if add_upsample:
468 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
469 | else:
470 | self.upsamplers = None
471 |
472 | self.gradient_checkpointing = False
473 |
474 | def forward(
475 | self,
476 | hidden_states,
477 | res_hidden_states_tuple,
478 | temb=None,
479 | encoder_hidden_states=None,
480 | upsample_size=None,
481 | attention_mask=None,
482 | ):
483 | for resnet, attn in zip(self.resnets, self.attentions):
484 | # pop res hidden states
485 | res_hidden_states = res_hidden_states_tuple[-1]
486 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
487 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
488 |
489 | if self.training and self.gradient_checkpointing:
490 |
491 | def create_custom_forward(module, return_dict=None):
492 | def custom_forward(*inputs):
493 | if return_dict is not None:
494 | return module(*inputs, return_dict=return_dict)
495 | else:
496 | return module(*inputs)
497 |
498 | return custom_forward
499 |
500 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
501 | hidden_states = torch.utils.checkpoint.checkpoint(
502 | create_custom_forward(attn, return_dict=False),
503 | hidden_states,
504 | encoder_hidden_states,
505 | )[0]
506 | else:
507 | hidden_states = resnet(hidden_states, temb)
508 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
509 |
510 | if self.upsamplers is not None:
511 | for upsampler in self.upsamplers:
512 | hidden_states = upsampler(hidden_states, upsample_size)
513 |
514 | return hidden_states
515 |
516 |
517 | class UpBlock3D(nn.Module):
518 | def __init__(
519 | self,
520 | in_channels: int,
521 | prev_output_channel: int,
522 | out_channels: int,
523 | temb_channels: int,
524 | dropout: float = 0.0,
525 | num_layers: int = 1,
526 | resnet_eps: float = 1e-6,
527 | resnet_time_scale_shift: str = "default",
528 | resnet_act_fn: str = "swish",
529 | resnet_groups: int = 32,
530 | resnet_pre_norm: bool = True,
531 | output_scale_factor=1.0,
532 | add_upsample=True,
533 | ):
534 | super().__init__()
535 | resnets = []
536 |
537 | for i in range(num_layers):
538 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
539 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
540 |
541 | resnets.append(
542 | ResnetBlock3D(
543 | in_channels=resnet_in_channels + res_skip_channels,
544 | out_channels=out_channels,
545 | temb_channels=temb_channels,
546 | eps=resnet_eps,
547 | groups=resnet_groups,
548 | dropout=dropout,
549 | time_embedding_norm=resnet_time_scale_shift,
550 | non_linearity=resnet_act_fn,
551 | output_scale_factor=output_scale_factor,
552 | pre_norm=resnet_pre_norm,
553 | )
554 | )
555 |
556 | self.resnets = nn.ModuleList(resnets)
557 |
558 | if add_upsample:
559 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
560 | else:
561 | self.upsamplers = None
562 |
563 | self.gradient_checkpointing = False
564 |
565 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
566 | for resnet in self.resnets:
567 | # pop res hidden states
568 | res_hidden_states = res_hidden_states_tuple[-1]
569 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
570 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
571 |
572 | if self.training and self.gradient_checkpointing:
573 |
574 | def create_custom_forward(module):
575 | def custom_forward(*inputs):
576 | return module(*inputs)
577 |
578 | return custom_forward
579 |
580 | hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
581 | else:
582 | hidden_states = resnet(hidden_states, temb)
583 |
584 | if self.upsamplers is not None:
585 | for upsampler in self.upsamplers:
586 | hidden_states = upsampler(hidden_states, upsample_size)
587 |
588 | return hidden_states
589 |
--------------------------------------------------------------------------------
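
For reference, the factory functions above are what `UNet3DConditionModel` loops over when assembling its encoder and decoder. Below is a minimal, self-contained sketch of building and running one block; the channel counts and frame count are illustrative assumptions, and the cross-attention variants additionally take `encoder_hidden_states` in their forward call, as `CrossAttnDownBlock3D.forward` shows.

import torch

from freebloom.models.unet_blocks import get_down_block

# Build a plain (attention-free) 3D down block; the sizes below are illustrative only.
block = get_down_block(
    "DownBlock3D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="swish",
    attn_num_head_channels=None,  # unused by DownBlock3D
    resnet_groups=32,
    downsample_padding=1,
)

hidden = torch.randn(1, 320, 4, 64, 64)  # (batch, channels, frames, height, width)
temb = torch.randn(1, 1280)              # projected timestep embedding
hidden, res_samples = block(hidden, temb=temb)
# `res_samples` holds the per-layer features that the matching up block later consumes
# through its `res_hidden_states_tuple` argument.
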
/freebloom/pipelines/pipeline_spatio_temporal.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
2 | import copy
3 | import inspect
4 | import os.path
5 | from dataclasses import dataclass
6 | from typing import Callable, Dict, List, Optional, Union
7 |
8 | import numpy as np
9 | import torch
10 | import torchvision.utils
11 | from diffusers.configuration_utils import FrozenDict
12 | from diffusers.models import AutoencoderKL
13 | from diffusers.pipeline_utils import DiffusionPipeline
14 | from diffusers.schedulers import (
15 | DDIMScheduler,
16 | DPMSolverMultistepScheduler,
17 | EulerAncestralDiscreteScheduler,
18 | EulerDiscreteScheduler,
19 | LMSDiscreteScheduler,
20 | PNDMScheduler,
21 | )
22 | from diffusers.utils import deprecate, logging, BaseOutput
23 | from diffusers.utils import is_accelerate_available
24 | from einops import rearrange, repeat
25 | from packaging import version
26 | from transformers import CLIPTextModel, CLIPTokenizer
27 |
28 | from ..models.unet import UNet3DConditionModel
29 | from ..prompt_attention import attention_util, ptp_utils
30 |
31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32 |
33 |
34 | @dataclass
35 | class SpatioTemporalPipelineOutput(BaseOutput):
36 | videos: Union[torch.Tensor, np.ndarray]
37 |
38 |
39 | class SpatioTemporalPipeline(DiffusionPipeline):
40 | _optional_components = []
41 |
42 | def __init__(
43 | self,
44 | vae: AutoencoderKL,
45 | text_encoder: CLIPTextModel,
46 | tokenizer: CLIPTokenizer,
47 | unet: UNet3DConditionModel,
48 | scheduler: Union[
49 | DDIMScheduler,
50 | PNDMScheduler,
51 | LMSDiscreteScheduler,
52 | EulerDiscreteScheduler,
53 | EulerAncestralDiscreteScheduler,
54 | DPMSolverMultistepScheduler,
55 | ],
56 | noise_scheduler: Union[
57 | DDIMScheduler,
58 | PNDMScheduler,
59 | LMSDiscreteScheduler,
60 | EulerDiscreteScheduler,
61 | EulerAncestralDiscreteScheduler,
62 | DPMSolverMultistepScheduler,
63 | ] = None,
64 | disk_store: bool = False,
65 | config=None
66 | ):
67 | super().__init__()
68 |
69 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
70 | deprecation_message = (
71 | f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
72 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
73 |                 "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
74 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
75 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
76 | " file"
77 | )
78 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
79 | new_config = dict(scheduler.config)
80 | new_config["steps_offset"] = 1
81 | scheduler._internal_dict = FrozenDict(new_config)
82 |
83 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
84 | deprecation_message = (
85 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
86 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
87 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
88 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
89 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
90 | )
91 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
92 | new_config = dict(scheduler.config)
93 | new_config["clip_sample"] = False
94 | scheduler._internal_dict = FrozenDict(new_config)
95 |
96 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
97 | version.parse(unet.config._diffusers_version).base_version
98 | ) < version.parse("0.9.0.dev0")
99 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
100 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
101 | deprecation_message = (
102 | "The configuration file of the unet has set the default `sample_size` to smaller than"
103 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
104 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
105 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
106 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
107 | " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
108 | " in the config might lead to incorrect results in future versions. If you have downloaded this"
109 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
110 | " the `unet/config.json` file"
111 | )
112 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
113 | new_config = dict(unet.config)
114 | new_config["sample_size"] = 64
115 | unet._internal_dict = FrozenDict(new_config)
116 |
117 | self.register_modules(
118 | vae=vae,
119 | text_encoder=text_encoder,
120 | tokenizer=tokenizer,
121 | unet=unet,
122 | scheduler=scheduler,
123 | noise_scheduler=noise_scheduler
124 | )
125 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
126 |
127 | self.controller = attention_util.AttentionTest(disk_store=disk_store, config=config)
128 | self.hyper_config = config
129 |
130 | def enable_vae_slicing(self):
131 | self.vae.enable_slicing()
132 |
133 | def disable_vae_slicing(self):
134 | self.vae.disable_slicing()
135 |
136 | def enable_sequential_cpu_offload(self, gpu_id=0):
137 | if is_accelerate_available():
138 | from accelerate import cpu_offload
139 | else:
140 | raise ImportError("Please install accelerate via `pip install accelerate`")
141 |
142 | device = torch.device(f"cuda:{gpu_id}")
143 |
144 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
145 | if cpu_offloaded_model is not None:
146 | cpu_offload(cpu_offloaded_model, device)
147 |
148 | @property
149 | def _execution_device(self):
150 | if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
151 | return self.device
152 | for module in self.unet.modules():
153 | if (
154 | hasattr(module, "_hf_hook")
155 | and hasattr(module._hf_hook, "execution_device")
156 | and module._hf_hook.execution_device is not None
157 | ):
158 | return torch.device(module._hf_hook.execution_device)
159 | return self.device
160 |
161 | def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
162 | batch_size = len(prompt) if isinstance(prompt, list) else 1
163 |
164 | text_inputs = self.tokenizer(
165 | prompt,
166 | padding="max_length",
167 | max_length=self.tokenizer.model_max_length,
168 | truncation=True,
169 | return_tensors="pt",
170 | )
171 | text_input_ids = text_inputs.input_ids
172 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
173 |
174 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
175 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1: -1])
176 | logger.warning(
177 | "The following part of your input was truncated because CLIP can only handle sequences up to"
178 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
179 | )
180 |
181 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
182 | attention_mask = text_inputs.attention_mask.to(device)
183 | else:
184 | attention_mask = None
185 |
186 | text_embeddings = self.text_encoder(
187 | text_input_ids.to(device),
188 | attention_mask=attention_mask,
189 | )
190 | text_embeddings = text_embeddings[0]
191 |
192 | # duplicate text embeddings for each generation per prompt, using mps friendly method
193 | bs_embed, seq_len, _ = text_embeddings.shape
194 | text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
195 | text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
196 |
197 | # get unconditional embeddings for classifier free guidance
198 | if do_classifier_free_guidance:
199 | uncond_tokens: List[str]
200 | if negative_prompt is None:
201 | uncond_tokens = [""] * batch_size
202 | elif type(prompt) is not type(negative_prompt):
203 | raise TypeError(
204 |                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
205 | f" {type(prompt)}."
206 | )
207 | elif isinstance(negative_prompt, str):
208 | uncond_tokens = [negative_prompt]
209 | elif batch_size != len(negative_prompt):
210 | raise ValueError(
211 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
212 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
213 | " the batch size of `prompt`."
214 | )
215 | else:
216 | uncond_tokens = negative_prompt
217 |
218 | max_length = text_input_ids.shape[-1]
219 | uncond_input = self.tokenizer(
220 | uncond_tokens,
221 | padding="max_length",
222 | max_length=max_length,
223 | truncation=True,
224 | return_tensors="pt",
225 | )
226 |
227 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
228 | attention_mask = uncond_input.attention_mask.to(device)
229 | else:
230 | attention_mask = None
231 |
232 | uncond_embeddings = self.text_encoder(
233 | uncond_input.input_ids.to(device),
234 | attention_mask=attention_mask,
235 | )
236 | uncond_embeddings = uncond_embeddings[0]
237 |
238 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
239 | seq_len = uncond_embeddings.shape[1]
240 | uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
241 | uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
242 |
243 | # For classifier free guidance, we need to do two forward passes.
244 | # Here we concatenate the unconditional and text embeddings into a single batch
245 | # to avoid doing two forward passes
246 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
247 |
248 | return text_embeddings
249 |
250 | def decode_latents(self, latents, return_tensor=False):
251 | video_length = latents.shape[2]
252 | latents = 1 / 0.18215 * latents
253 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
254 | video = self.vae.decode(latents).sample
255 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
256 | video = (video / 2 + 0.5).clamp(0, 1)
257 | if return_tensor:
258 | return video
259 |         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
260 | video = video.cpu().float().numpy()
261 | return video
262 |
263 | def prepare_extra_step_kwargs(self, generator, eta):
264 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
265 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
266 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
267 | # and should be between [0, 1]
268 |
269 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
270 | extra_step_kwargs = {}
271 | if accepts_eta:
272 | extra_step_kwargs["eta"] = eta
273 |
274 | # check if the scheduler accepts generator
275 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
276 | if accepts_generator:
277 | extra_step_kwargs["generator"] = generator
278 | return extra_step_kwargs
279 |
280 | def check_inputs(self, prompt, height, width, callback_steps, latents):
281 | if not isinstance(prompt, str) and not isinstance(prompt, list):
282 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
283 |
284 | if height % 8 != 0 or width % 8 != 0:
285 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
286 |
287 | if (callback_steps is None) or (
288 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
289 | ):
290 | raise ValueError(
291 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
292 | f" {type(callback_steps)}."
293 | )
294 | if isinstance(prompt, list) and latents is not None and len(prompt) != latents.shape[2]:
295 | raise ValueError(
296 |                 f"`prompt` is a list but does not match the number of frames. frames length: {latents.shape[2]}, prompt length: {len(prompt)}")
297 |
298 | def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator,
299 | latents=None, store_attention=True, frame_same_noise=False):
300 | if store_attention and self.controller is not None:
301 | ptp_utils.register_attention_control(self, self.controller)
302 |
303 | shape = (
304 | batch_size, num_channels_latents, video_length, height // self.vae_scale_factor,
305 | width // self.vae_scale_factor)
306 | if frame_same_noise:
307 | shape = (batch_size, num_channels_latents, 1, height // self.vae_scale_factor,
308 | width // self.vae_scale_factor)
309 | if isinstance(generator, list) and len(generator) != batch_size:
310 | raise ValueError(
311 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
312 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
313 | )
314 |
315 | if latents is None:
316 | rand_device = "cpu" if device.type == "mps" else device
317 |
318 | if isinstance(generator, list):
319 | shape = (1,) + shape[1:]
320 | latents = [
321 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
322 | for i in range(batch_size)
323 | ]
324 | if frame_same_noise:
325 | latents = [repeat(latent, 'b c 1 h w -> b c f h w', f=video_length) for latent in latents]
326 | latents = torch.cat(latents, dim=0).to(device)
327 | else:
328 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
329 | if frame_same_noise:
330 | latents = repeat(latents, 'b c 1 h w -> b c f h w', f=video_length)
331 |
332 | else:
333 | if latents.shape != shape:
334 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
335 | latents = latents.to(device)
336 |
337 | # scale the initial noise by the standard deviation required by the scheduler
338 | latents = latents * self.scheduler.init_noise_sigma
339 | return latents
340 |
341 | @torch.no_grad()
342 | def __call__(
343 | self,
344 | prompt: Union[str, List[str]],
345 | video_length: Optional[int],
346 | height: Optional[int] = None,
347 | width: Optional[int] = None,
348 | num_inference_steps: int = 50,
349 | guidance_scale: float = 7.5,
350 | negative_prompt: Optional[Union[str, List[str]]] = None,
351 | num_videos_per_prompt: Optional[int] = 1,
352 | eta: float = 0.0,
353 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
354 | latents: Optional[torch.FloatTensor] = None,
355 | output_type: Optional[str] = "tensor",
356 | return_dict: bool = True,
357 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
358 | callback_steps: Optional[int] = 1,
359 |         fixed_latents: Optional[Dict[int, torch.FloatTensor]] = None,
360 | fixed_latents_idx: list = None,
361 | inner_idx: list = None,
362 | init_text_embedding: torch.Tensor = None,
363 | mask: Optional[dict] = None,
364 | save_vis_inner=False,
365 | return_tensor=False,
366 | return_text_embedding=False,
367 | output_dir=None,
368 | **kwargs,
369 | ):
370 | if self.controller is not None:
371 | self.controller.reset()
372 | self.controller.batch_size = video_length
373 |
374 | # Default height and width to unet
375 | height = height or self.unet.config.sample_size * self.vae_scale_factor
376 | width = width or self.unet.config.sample_size * self.vae_scale_factor
377 |
378 | # Check inputs. Raise error if not correct
379 | self.check_inputs(prompt, height, width, callback_steps, latents)
380 |
381 | # Define call parameters
382 | # batch_size = 1 if isinstance(prompt, str) else len(prompt)
383 | batch_size = 1
384 | device = self._execution_device
385 |         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
386 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
387 | # corresponds to doing no classifier free guidance.
388 | do_classifier_free_guidance = guidance_scale > 1.0
389 |
390 | # Encode input prompt
391 | text_embeddings = self._encode_prompt(
392 | prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
393 | )
394 |
395 | if init_text_embedding is not None:
396 | text_embeddings[video_length:] = init_text_embedding
397 |
398 | # Prepare timesteps
399 | self.scheduler.set_timesteps(num_inference_steps, device=device)
400 | timesteps = self.scheduler.timesteps
401 |
402 | # Prepare latent variables
403 | num_channels_latents = self.unet.in_channels
404 | latents = self.prepare_latents(
405 | batch_size * num_videos_per_prompt,
406 | num_channels_latents,
407 | video_length,
408 | height,
409 | width,
410 | text_embeddings.dtype,
411 | device,
412 | generator,
413 | latents,
414 | )
415 | latents_dtype = latents.dtype
416 |
417 | # Prepare extra step kwargs.
418 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
419 |
420 | # Denoising loop
421 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
422 | with self.progress_bar(total=num_inference_steps) as progress_bar:
423 | for i, t in enumerate(timesteps):
424 | # expand the latents if we are doing classifier free guidance
425 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
426 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
427 |
428 | # predict the noise residual
429 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(
430 | dtype=latents_dtype)
431 |
432 | # perform guidance
433 | if do_classifier_free_guidance:
434 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
435 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
436 |
437 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
438 |
439 | if mask is not None and i < 10:
440 | f_1 = latents[:, :, :1]
441 | latents = mask[1] * f_1 + mask[0] * latents
442 | latents = repeat(latents, 'b c 1 h w -> b c f h w', f=video_length)
443 | latents = latents.to(latents_dtype)
444 |
445 | # call the callback, if provided
446 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
447 | progress_bar.update()
448 | if callback is not None and i % callback_steps == 0:
449 | callback(i, t, latents)
450 |
451 | if self.controller is not None:
452 | latents_old = latents
453 | dtype = latents.dtype
454 | latents_new = self.controller.step_callback(latents, inner_idx)
455 | latents = latents_new.to(dtype)
456 |
457 | if fixed_latents is not None:
458 | latents[:, :, fixed_latents_idx] = fixed_latents[i + 1]
459 |
460 | self.controller.empty_cache()
461 |
462 | if save_vis_inner:
463 | video = self.decode_latents(latents)
464 | video = torch.from_numpy(video)
465 | video = rearrange(video, 'b c f h w -> (b f) c h w')
466 |                         # Always create the output directory so the per-step frames can be written.
467 |                         os.makedirs(f'{output_dir}/inner', exist_ok=True)
468 |                         torchvision.utils.save_image(video, f'{output_dir}/inner/{i}.jpg')
469 |
470 | # Post-processing
471 | video = self.decode_latents(latents, return_tensor)
472 |
473 | # Convert to tensor
474 | if output_type == "tensor" and isinstance(video, np.ndarray):
475 | video = torch.from_numpy(video)
476 |
477 | if not return_dict:
478 | return video
479 | if return_text_embedding:
480 | return SpatioTemporalPipelineOutput(videos=video), text_embeddings[video_length:]
481 |
482 | return SpatioTemporalPipelineOutput(videos=video)
483 |
484 | def forward_to_t(self, latents_0, t, noise, text_embeddings):
485 | noisy_latents = self.noise_scheduler.add_noise(latents_0, noise, t)
486 |
487 | # Predict the noise residual and compute loss
488 | model_pred = self.unet(noisy_latents, t, text_embeddings).sample
489 | return model_pred
490 |
491 | def interpolate_between_two_frames(
492 | self,
493 | samples,
494 | x_noise,
495 | rand_ratio,
496 | where,
497 | prompt: Union[str, List[str]],
498 | k=3,
499 | height: Optional[int] = None,
500 | width: Optional[int] = None,
501 | num_inference_steps: int = 50,
502 | guidance_scale: float = 7.5,
503 | negative_prompt: Optional[Union[str, List[str]]] = None,
504 | num_videos_per_prompt: Optional[int] = 1,
505 | eta: float = 0.0,
506 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
507 | output_type: Optional[str] = "tensor",
508 | return_dict: bool = True,
509 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
510 | callback_steps: Optional[int] = 1,
511 | base_noise=None,
512 | text_embedding=None,
513 | **kwargs):
514 | if text_embedding is None:
515 | text_embeddings = self._encode_prompt(prompt, x_noise.device, num_videos_per_prompt, 0, None)
516 | else:
517 | text_embeddings = text_embedding
518 | prompt_temp = copy.deepcopy(prompt)
519 |
520 | times = 0
521 |         # `k` rounds of frame interpolation are performed below.
522 | fixed_original_idx = where
523 | while times < k:
524 | if not isinstance(samples, torch.Tensor):
525 | samples = samples[0]
526 |
527 | inner_step = self.controller.latents_store
528 | if times == 0:
529 | for key, value in inner_step.items():
530 | inner_step[key] = value[:, :, fixed_original_idx]
531 |
532 | b, c, f, _, __ = samples.shape
533 | dist_list = []
534 | for i in range(f - 1):
535 | dist_list.append(torch.nn.functional.mse_loss(samples[:, :, i + 1], samples[:, :, i]))
536 | dist = torch.tensor(dist_list)
537 | num_insert = 1
538 | value, idx = torch.topk(dist, k=num_insert)
539 | tgt_video_length = num_insert + f
540 |
541 | idx = sorted(idx)
542 | tgt_idx = [id + 1 + i for i, id in enumerate(idx)]
543 |
544 | original_idx = [i for i in range(tgt_video_length) if i not in tgt_idx]
545 | prompt_after_inner = []
546 | x_temp = []
547 | text_embeddings_temp = []
548 | prompt_i = 0
549 | for i in range(tgt_video_length):
550 | if i in tgt_idx:
551 | prompt_after_inner.append('')
552 | search_idx = sorted(original_idx + [i])
553 | find = search_idx.index(i)
554 | pre_idx, next_idx = search_idx[find - 1], search_idx[find + 1]
555 | length = next_idx - pre_idx
556 |
557 | x_temp.append(np.cos(rand_ratio * np.pi / 2) * base_noise[:, :, 0] + np.sin(
558 | rand_ratio * np.pi / 2) * torch.randn_like(base_noise[:, :, 0]))
559 | text_embeddings_temp.append(
560 | (next_idx - i) / length * text_embeddings[prompt_i - 1] + (i - pre_idx) / length *
561 | text_embeddings[prompt_i])
562 | else:
563 | prompt_after_inner.append(prompt_temp[prompt_i])
564 | x_temp.append(x_noise[:, :, prompt_i])
565 | text_embeddings_temp.append(text_embeddings[prompt_i])
566 | prompt_i += 1
567 |
568 | inner_idx = tgt_idx
569 | prompt_temp = prompt_after_inner
570 |
571 | x_noise = torch.stack(x_temp, dim=2)
572 | text_embeddings = torch.stack(text_embeddings_temp, dim=0)
573 |
574 | samples, prompt_embedding = self(prompt_temp,
575 | tgt_video_length,
576 | height,
577 | width,
578 | num_inference_steps,
579 | guidance_scale,
580 | [negative_prompt] * tgt_video_length,
581 | num_videos_per_prompt,
582 | eta,
583 | generator,
584 | x_noise,
585 | output_type,
586 | return_dict,
587 | callback,
588 | callback_steps,
589 | init_text_embedding=text_embeddings,
590 | inner_idx=inner_idx,
591 | return_text_embedding=True,
592 | fixed_latents=copy.deepcopy(inner_step),
593 | fixed_latents_idx=original_idx,
594 | **kwargs)
595 |
596 | times += 1
597 |
598 | return samples[0]
599 |
--------------------------------------------------------------------------------
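
To make the call flow above concrete, here is a hedged sketch of assembling and running `SpatioTemporalPipeline`. The checkpoint path, the config loading via OmegaConf, and the prompt list are illustrative assumptions; the repository's `main.py` together with the YAMLs under `configs/` is the authoritative entry point. Note that the attention controller created in `__init__` reads `validation_data.attention_type_former`, `validation_data.attention_type_latter`, and `validation_data.attention_adapt_step`, so the config passed in must provide those keys.

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf  # assumed loader for the YAMLs under configs/
from transformers import CLIPTextModel, CLIPTokenizer

from freebloom.models.unet import UNet3DConditionModel
from freebloom.pipelines.pipeline_spatio_temporal import SpatioTemporalPipeline

model_path = "checkpoints/stable-diffusion-v1-4"  # hypothetical local SD 1.x checkpoint
config = OmegaConf.load("configs/flowers.yaml")   # assumed to carry the validation_data.* keys

pipe = SpatioTemporalPipeline(
    vae=AutoencoderKL.from_pretrained(model_path, subfolder="vae"),
    text_encoder=CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer"),
    unet=UNet3DConditionModel.from_pretrained_2d(model_path, subfolder="unet"),
    scheduler=DDIMScheduler.from_pretrained(model_path, subfolder="scheduler"),
    config=config,
).to("cuda")

# One prompt per frame; `video_length` is set to the number of prompts.
prompts = [
    "a yellow flower bud in a garden",
    "the yellow flower bud begins to open",
    "the yellow flower is in full bloom",
]
out = pipe(prompt=prompts, video_length=len(prompts), height=512, width=512,
           num_inference_steps=50, guidance_scale=7.5,
           generator=torch.Generator("cuda").manual_seed(33))
video = out.videos  # (batch, channels, frames, height, width), values in [0, 1]
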
/freebloom/prompt_attention/attention_util.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import copy
3 | import os
4 | from typing import Union, Tuple, Dict, Optional, List
5 |
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 |
10 | from freebloom.prompt_attention import ptp_utils, seq_aligner
11 | from freebloom.prompt_attention.ptp_utils import get_time_string
12 |
13 |
14 | class AttentionControl(abc.ABC):
15 |
16 | def step_callback(self, x_t):
17 | self.cur_att_layer = 0
18 | self.cur_step += 1
19 | self.between_steps()
20 | return x_t
21 |
22 | def between_steps(self):
23 | return
24 |
25 | @property
26 | def num_uncond_att_layers(self):
27 | return self.num_att_layers if self.LOW_RESOURCE else 0
28 |
29 | @abc.abstractmethod
30 | def forward(self, attn, is_cross: bool, place_in_unet: str):
31 | raise NotImplementedError
32 |
33 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
34 | if self.cur_att_layer >= self.num_uncond_att_layers:
35 | if self.LOW_RESOURCE:
36 | attn = self.forward(attn, is_cross, place_in_unet)
37 | else:
38 | h = attn.shape[0]
39 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
40 | self.cur_att_layer += 1
41 |
42 | return attn
43 |
44 | def reset(self):
45 | self.cur_step = 0
46 | self.cur_att_layer = 0
47 |
48 | def __init__(self):
49 | self.LOW_RESOURCE = False
50 | self.cur_step = 0
51 | self.num_att_layers = -1
52 | self.cur_att_layer = 0
53 |
54 |
55 | class EmptyControl(AttentionControl):
56 |
57 | def forward(self, attn, is_cross: bool, place_in_unet: str):
58 | return attn
59 |
60 |
61 | class AttentionStore(AttentionControl):
62 |
63 | def step_callback(self, x_t):
64 | x_t = super().step_callback(x_t)
65 | self.latents_store[self.cur_step] = x_t
66 | return x_t
67 |
68 | @staticmethod
69 | def get_empty_store():
70 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
71 | "down_self": [], "mid_self": [], "up_self": []}
72 |
73 | def forward(self, attn, is_cross: bool, place_in_unet: str):
74 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
75 | h = attn.shape[0] // self.batch_size
76 | if attn.shape[1] <= 32 ** 2 and is_cross: # avoid memory overhead
77 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
78 | self.step_store[key].append(attn)
79 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
80 | return attn
81 |
82 | def between_steps(self):
83 | if len(self.attention_store) == 0:
84 | self.attention_store = self.step_store
85 | else:
86 | for key in self.attention_store:
87 | for i in range(len(self.attention_store[key])):
88 | self.attention_store[key][i] += self.step_store[key][i]
89 |
90 | if self.disk_store:
91 | path = self.store_dir + f'/{self.cur_step:03d}_attn.pt'
92 | torch.save(copy.deepcopy(self.step_store), path)
93 |
94 | self.step_store = self.get_empty_store()
95 |
96 | def get_average_attention(self):
97 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
98 | self.attention_store}
99 | return average_attention
100 |
101 | def reset(self):
102 | super(AttentionStore, self).reset()
103 | self.step_store = self.get_empty_store()
104 | self.attention_store = {}
105 |
106 | def empty_cache(self):
107 | self.step_store = self.get_empty_store()
108 | self.attention_store = {}
109 |
110 | def __init__(self, disk_store=False, config=None):
111 | super(AttentionStore, self).__init__()
112 | self.disk_store = disk_store
113 | if self.disk_store:
114 | time_string = get_time_string()
115 | path = f'./.temp/attention_cache_{time_string}'
116 | os.makedirs(path, exist_ok=True)
117 | self.store_dir = path
118 | else:
119 | self.store_dir = None
120 |
121 | if config:
122 | self.config = config
123 |
124 | self.latents_store = {}
125 |
126 | self.step_store = self.get_empty_store()
127 | self.attention_store = {}
128 | self.attention_type_former = config["validation_data"]["attention_type_former"]
129 | self.attention_type_latter = config["validation_data"]["attention_type_latter"]
130 | self.attention_adapt_step = config["validation_data"]["attention_adapt_step"]
131 |
132 |
133 | class AttentionControlEdit(AttentionStore, abc.ABC):
134 |
135 | def step_callback(self, x_t):
136 | if self.local_blend is not None:
137 | x_t = self.local_blend(x_t, self.attention_store)
138 | return x_t
139 |
140 | def replace_self_attention(self, attn_base, att_replace):
141 | if att_replace.shape[2] <= 16 ** 2:
142 | return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
143 | else:
144 | return att_replace
145 |
146 | @abc.abstractmethod
147 | def replace_cross_attention(self, attn_base, att_replace):
148 | raise NotImplementedError
149 |
150 | def forward(self, attn, is_cross: bool, place_in_unet: str):
151 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
152 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
153 | h = attn.shape[0] // (self.batch_size)
154 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
155 |             attn_base, attn_replace = attn[0], attn[1:]
156 |             if is_cross:
157 |                 alpha_words = self.cross_replace_alpha[self.cur_step]
158 |                 attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (
159 |                         1 - alpha_words) * attn_replace
160 |                 attn[1:] = attn_replace_new
161 |             else:
162 |                 attn[1:] = self.replace_self_attention(attn_base, attn_replace)
163 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
164 | return attn
165 |
166 | def __init__(self, prompts, num_steps: int,
167 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
168 | self_replace_steps: Union[float, Tuple[float, float]],
169 | local_blend: None,
170 | tokenizer=None,
171 | device=torch.device('cuda') if torch.cuda.is_available() else torch.device(
172 | 'cpu')): # Optional[LocalBlend]):
173 | super(AttentionControlEdit, self).__init__()
174 | self.batch_size = len(prompts)
175 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
176 | tokenizer).to(device)
177 | if type(self_replace_steps) is float:
178 | self_replace_steps = 0, self_replace_steps
179 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
180 | self.local_blend = local_blend
181 |
182 |
183 | class AttentionReplace(AttentionControlEdit):
184 |
185 | def replace_cross_attention(self, attn_base, att_replace):
186 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
187 |
188 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
189 | local_blend=None, tokenizer=None,
190 | device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
191 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend,
192 | tokenizer=tokenizer, device=device)
193 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
194 |
195 |
196 | class AttentionRefine(AttentionControlEdit):
197 |
198 | def replace_cross_attention(self, attn_base, att_replace):
199 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
200 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
201 | return attn_replace
202 |
203 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
204 | local_blend=None,
205 | tokenizer=None,
206 | device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
207 |         super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, tokenizer=tokenizer, device=device)
208 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
209 | self.mapper, alphas = self.mapper.to(device), alphas.to(device)
210 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
211 |
212 |
213 | class AttentionReweight(AttentionControlEdit):
214 |
215 | def replace_cross_attention(self, attn_base, att_replace):
216 | if self.prev_controller is not None:
217 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
218 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
219 | return attn_replace
220 |
221 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
222 | local_blend=None, controller: Optional[AttentionControlEdit] = None,
223 | device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
224 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
225 |                                                 local_blend, device=device)
226 | self.equalizer = equalizer.to(device)
227 | self.prev_controller = controller
228 |
229 |
230 | class AttentionTest(AttentionStore):
231 |
232 | def step_callback(self, x_t, inner_idx=None):
233 | x_t = super(AttentionTest, self).step_callback(x_t)
234 |
235 | if inner_idx is None:
236 | return x_t
237 |
238 | b, c, f, h, w = x_t.shape
239 |
240 | momentum = 0.1 if self.cur_step <= self.config['inference_config']['interpolation_step'] else 1.0
241 | original_idx = [i for i in range(f) if i not in inner_idx]
242 | for idx in inner_idx:
243 | search_idx = sorted(original_idx + [idx])
244 | find = search_idx.index(idx)
245 | pre_idx, next_idx = search_idx[find - 1], search_idx[find + 1]
246 | length = next_idx - pre_idx
247 |
248 | alpha = (idx - pre_idx) / length
249 | x_t[:, :, idx] = (1 - momentum) * ((next_idx - idx) / length * x_t[:, :, pre_idx] + (
250 | idx - pre_idx) / length * x_t[:, :, next_idx]) + momentum * x_t[:, :, idx]
251 |
252 | return x_t
253 |
254 | def forward(self, attn, is_cross: bool, place_in_unet: str):
255 | super(AttentionTest, self).forward(attn, is_cross, place_in_unet)
256 | if is_cross:
257 | h = attn.shape[0] // (self.batch_size)
258 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
259 |             attn_base, attn_replace, attn_show = attn[0], attn[1:2], attn[2:3]
260 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
261 | return attn
262 |
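A toy, purely numeric walk-through of the latent blending in AttentionTest.step_callback: while cur_step is still within interpolation_step, an inserted frame at index idx is pulled toward the linear interpolation of its nearest original neighbours (the values below are illustrative only):

pre_idx, next_idx, idx = 2, 5, 3          # nearest original frames and the inserted frame
length = next_idx - pre_idx
momentum = 0.1                            # used while cur_step <= interpolation_step; 1.0 afterwards
w_pre = (next_idx - idx) / length         # 2/3: weight of the earlier neighbour
w_next = (idx - pre_idx) / length         # 1/3: weight of the later neighbour
x_pre, x_next, x_idx = 0.0, 3.0, 10.0     # toy scalar "latents"
blended = (1 - momentum) * (w_pre * x_pre + w_next * x_next) + momentum * x_idx
print(blended)                            # 0.9 * 1.0 + 0.1 * 10.0 = 1.9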
263 |
264 | def get_equalizer(text: str,
265 | word_select: Union[int, Tuple[int, ...]],
266 | values: Union[List[float], Tuple[float, ...]],
267 | tokenizer=None):
268 | if type(word_select) is int or type(word_select) is str:
269 | word_select = (word_select,)
270 | equalizer = torch.ones(len(values), 77)
271 | values = torch.tensor(values, dtype=torch.float32)
272 | for word in word_select:
273 | inds = ptp_utils.get_word_inds(text, word, tokenizer)
274 | equalizer[:, inds] = values
275 | return equalizer
276 |
277 |
278 | def aggregate_attention(prompts,
279 | attention_store: AttentionStore,
280 | res: int,
281 | from_where: List[str],
282 | is_cross: bool,
283 | select: int):
284 | out = []
285 | attention_maps = attention_store.get_average_attention()
286 | num_pixels = res ** 2
287 | for location in from_where:
288 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
289 |             if item.shape[-2] == num_pixels:  # stored maps are (batch, heads, pixels, tokens)
290 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
291 | out.append(cross_maps)
292 | out = torch.cat(out, dim=0)
293 | out = out.sum(0) / out.shape[0]
294 | return out.cpu()
295 |
296 |
297 | def show_cross_attention(tokenizer,
298 | prompts,
299 | attention_store: AttentionStore, res: int, from_where: List[str],
300 | select: int = 0):
301 | tokens = tokenizer.encode(prompts[select])
302 | decoder = tokenizer.decode
303 |     attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
304 | images = []
305 | for i in range(len(tokens)):
306 | image = attention_maps[:, :, i]
307 | image = 255 * image / image.max()
308 | image = image.unsqueeze(-1).expand(*image.shape, 3)
309 | image = image.numpy().astype(np.uint8)
310 | image = np.array(Image.fromarray(image).resize((256, 256)))
311 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
312 | images.append(image)
313 | ptp_utils.view_images(np.stack(images, axis=0))
314 |
315 |
316 | def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
317 |                              max_com=10, select: int = 0):
318 |     attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape(
319 | (res ** 2, res ** 2))
320 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
321 | images = []
322 | for i in range(max_com):
323 | image = vh[i].reshape(res, res)
324 | image = image - image.min()
325 | image = 255 * image / image.max()
326 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
327 | image = Image.fromarray(image).resize((256, 256))
328 | image = np.array(image)
329 | images.append(image)
330 | ptp_utils.view_images(np.concatenate(images, axis=1))
331 |
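A minimal, self-contained sketch of the controller protocol implemented by the classes above: each patched attention layer passes its probabilities through the controller once per layer, and the sampler calls step_callback once per denoising step. Shapes, config values, and the manually assigned batch_size are illustrative assumptions; in the project the wiring is done by register_attention_control in ptp_utils.py and by the pipeline.

import torch

from freebloom.prompt_attention.attention_util import AttentionStore

config = {"validation_data": {"attention_type_former": "self_first_former",
                              "attention_type_latter": "self_former",
                              "attention_adapt_step": 25}}          # illustrative values
store = AttentionStore(disk_store=False, config=config)
store.batch_size = 2        # normally set by the pipeline / AttentionControlEdit subclasses
store.num_att_layers = 1    # normally set by register_attention_control

attn = torch.rand(2 * 8, 16 * 16, 77)                       # (batch * heads, pixels, tokens)
attn = store(attn, is_cross=True, place_in_unet="up")       # conditional half is stored
x_t = store.step_callback(torch.zeros(2, 4, 16, 64, 64))    # latents are cached per step

print(len(store.attention_store["up_cross"]))               # 1 stored cross-attention map
print(sorted(store.latents_store.keys()))                   # [1]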
--------------------------------------------------------------------------------
/freebloom/prompt_attention/ptp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 |
16 | import cv2
17 | import datetime
18 | import numpy as np
19 | import torch
20 | import torch.nn.functional as F
21 | import torchvision
22 | from PIL import Image
23 | from einops import rearrange, repeat
24 | from tqdm.notebook import tqdm
25 | from typing import Optional, Union, Tuple, List, Dict
26 |
27 |
28 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
29 | h, w, c = image.shape
30 | offset = int(h * .2)
31 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
32 | font = cv2.FONT_HERSHEY_SIMPLEX
33 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
34 | img[:h] = image
35 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
36 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
37 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
38 | return img
39 |
40 |
41 | def view_images(images, num_rows=1, offset_ratio=0.02, save_path=None):
42 | if type(images) is list:
43 | num_empty = len(images) % num_rows
44 | elif images.ndim == 4:
45 | num_empty = images.shape[0] % num_rows
46 | else:
47 | images = [images]
48 | num_empty = 0
49 |
50 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
51 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
52 | num_items = len(images)
53 |
54 | h, w, c = images[0].shape
55 | offset = int(h * offset_ratio)
56 | num_cols = num_items // num_rows
57 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
58 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
59 | for i in range(num_rows):
60 | for j in range(num_cols):
61 |             image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
62 |                 i * num_cols + j]
63 |
64 | if save_path is not None:
65 | pil_img = Image.fromarray(image_)
66 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
67 | pil_img.save(f'{save_path}/{now}.png')
68 | # display(pil_img)
69 |
70 |
71 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
72 | if low_resource:
73 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
74 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
75 | else:
76 | latents_input = torch.cat([latents] * 2)
77 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
78 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
79 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
80 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
81 | latents = controller.step_callback(latents)
82 | return latents
83 |
84 |
85 | def latent2image(vae, latents):
86 | latents = 1 / 0.18215 * latents
87 | image = vae.decode(latents)['sample']
88 | image = (image / 2 + 0.5).clamp(0, 1)
89 | image = image.cpu().permute(0, 2, 3, 1).numpy()
90 | image = (image * 255).astype(np.uint8)
91 | return image
92 |
93 |
94 | def init_latent(latent, model, height, width, generator, batch_size):
95 | if latent is None:
96 | latent = torch.randn(
97 | (1, model.unet.in_channels, height // 8, width // 8),
98 | generator=generator,
99 | )
100 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
101 | return latent, latents
102 |
103 |
104 | @torch.no_grad()
105 | def text2image_ldm(
106 | model,
107 | prompt: List[str],
108 | controller,
109 | num_inference_steps: int = 50,
110 | guidance_scale: Optional[float] = 7.,
111 | generator: Optional[torch.Generator] = None,
112 | latent: Optional[torch.FloatTensor] = None,
113 | ):
114 | register_attention_control(model, controller)
115 | height = width = 256
116 | batch_size = len(prompt)
117 |
118 | uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
119 | uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
120 |
121 | text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
122 | text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
123 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
124 | context = torch.cat([uncond_embeddings, text_embeddings])
125 |
126 | model.scheduler.set_timesteps(num_inference_steps)
127 | for t in tqdm(model.scheduler.timesteps):
128 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale)
129 |
130 | image = latent2image(model.vqvae, latents)
131 |
132 | return image, latent
133 |
134 |
135 | @torch.no_grad()
136 | def text2image_ldm_stable(
137 | model,
138 | prompt: List[str],
139 | controller,
140 | num_inference_steps: int = 50,
141 | guidance_scale: float = 7.5,
142 | generator: Optional[torch.Generator] = None,
143 | latent: Optional[torch.FloatTensor] = None,
144 | low_resource: bool = False,
145 | ):
146 | register_attention_control(model, controller)
147 | height = width = 512
148 | batch_size = len(prompt)
149 |
150 | text_input = model.tokenizer(
151 | prompt,
152 | padding="max_length",
153 | max_length=model.tokenizer.model_max_length,
154 | truncation=True,
155 | return_tensors="pt",
156 | )
157 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
158 | max_length = text_input.input_ids.shape[-1]
159 | uncond_input = model.tokenizer(
160 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
161 | )
162 | uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
163 |
164 | context = [uncond_embeddings, text_embeddings]
165 | if not low_resource:
166 | context = torch.cat(context)
167 | latent, latents = init_latent(latent, model, height, width, generator, batch_size)
168 |
169 | # set timesteps
170 | extra_set_kwargs = {"offset": 1}
171 | model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
172 | for t in tqdm(model.scheduler.timesteps):
173 | latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
174 |
175 | image = latent2image(model.vae, latents)
176 |
177 | return image, latent
178 |
179 |
180 | def register_attention_control(model, controller):
181 | def ca_forward(self, place_in_unet, attention_type):
182 | to_out = self.to_out
183 | if type(to_out) is torch.nn.modules.container.ModuleList:
184 | to_out = self.to_out[0]
185 | else:
186 | to_out = self.to_out
187 |
188 | def _attention(query, key, value, is_cross, attention_mask=None):
189 |
190 | if self.upcast_attention:
191 | query = query.float()
192 | key = key.float()
193 |
194 | attention_scores = torch.baddbmm(
195 | torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
196 | query,
197 | key.transpose(-1, -2),
198 | beta=0,
199 | alpha=self.scale,
200 | )
201 |
202 | if attention_mask is not None:
203 | attention_scores = attention_scores + attention_mask
204 |
205 | if self.upcast_softmax:
206 | attention_scores = attention_scores.float()
207 |
208 | attention_probs = attention_scores.softmax(dim=-1)
209 |
210 | # cast back to the original dtype
211 | attention_probs = attention_probs.to(value.dtype)
212 |
213 | attention_probs = controller(attention_probs, is_cross, place_in_unet)
214 |
215 | # compute attention output
216 | hidden_states = torch.bmm(attention_probs, value)
217 |
218 | # reshape hidden_states
219 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
220 | return hidden_states
221 |
222 | def reshape_heads_to_batch_dim(tensor):
223 | batch_size, seq_len, dim = tensor.shape
224 | head_size = self.heads
225 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
226 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
227 | return tensor
228 |
229 | def reshape_batch_dim_to_heads(tensor):
230 | batch_size, seq_len, dim = tensor.shape
231 | head_size = self.heads
232 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
233 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
234 | return tensor
235 |
236 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None):
237 | is_cross = encoder_hidden_states is not None
238 |
239 | batch_size, sequence_length, _ = hidden_states.shape
240 |
241 | encoder_hidden_states = encoder_hidden_states
242 |
243 | if self.group_norm is not None:
244 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
245 |
246 | query = self.to_q(hidden_states)
247 | dim = query.shape[-1]
248 | query = self.reshape_heads_to_batch_dim(query)
249 |
250 | if self.added_kv_proj_dim is not None:
251 | key = self.to_k(hidden_states)
252 | value = self.to_v(hidden_states)
253 | encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
254 | encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
255 |
256 | key = self.reshape_heads_to_batch_dim(key)
257 | value = self.reshape_heads_to_batch_dim(value)
258 | encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
259 | encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
260 |
261 | key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
262 | value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
263 | else:
264 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
265 | key = self.to_k(encoder_hidden_states)
266 | value = self.to_v(encoder_hidden_states)
267 |
268 | key = self.reshape_heads_to_batch_dim(key)
269 | value = self.reshape_heads_to_batch_dim(value)
270 |
271 | if attention_mask is not None:
272 | if attention_mask.shape[-1] != query.shape[1]:
273 | target_length = query.shape[1]
274 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
275 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
276 |
277 | # attention, what we cannot get enough of
278 | self._use_memory_efficient_attention_xformers = False
279 | if self._use_memory_efficient_attention_xformers:
280 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
281 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
282 | hidden_states = hidden_states.to(query.dtype)
283 | else:
284 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
285 | hidden_states = _attention(query, key, value, is_cross, attention_mask)
286 | else:
287 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
288 |
289 | # linear proj
290 | hidden_states = self.to_out[0](hidden_states)
291 |
292 | # dropout
293 | hidden_states = self.to_out[1](hidden_states)
294 | return hidden_states
295 |
296 | def sca_forward(hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
297 | is_cross = encoder_hidden_states is not None
298 |
299 | batch_size, sequence_length, _ = hidden_states.shape
300 |
301 | encoder_hidden_states = encoder_hidden_states
302 |
303 | if self.group_norm is not None:
304 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
305 |
306 | query = self.to_q(hidden_states)
307 | dim = query.shape[-1]
308 | query = self.reshape_heads_to_batch_dim(query) # [(b f), d ,c]
309 |
310 | if self.added_kv_proj_dim is not None:
311 | raise NotImplementedError
312 |
313 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
314 | key = self.to_k(encoder_hidden_states)
315 | value = self.to_v(encoder_hidden_states)
316 |
317 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
318 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
319 |
320 | attention_frame_index = []
321 |
322 | attention_type = controller.attention_type_former if controller.cur_step <= controller.attention_adapt_step \
323 | else controller.attention_type_latter
324 |
325 | # self only
326 | if "self" in attention_type:
327 | attention_frame_index.append(torch.arange(video_length))
328 |
329 | # first only
330 | if "first" in attention_type:
331 | attention_frame_index.append([0] * video_length)
332 |
333 | # former
334 | if "former" in attention_type:
335 | attention_frame_index.append(torch.clamp(torch.arange(video_length) - 1, min=0))
336 | # former end
337 |
338 | # all
339 | # for i in range(video_length):
340 | # attention_frame_index.append([i] * video_length)
341 | # all end
342 |
343 | key = torch.cat([key[:, frame_index] for frame_index in attention_frame_index], dim=2)
344 | value = torch.cat([value[:, frame_index] for frame_index in attention_frame_index], dim=2)
345 |
346 | b, f, k, c = key.shape
347 |
348 | key = rearrange(key, "b f d c -> (b f) d c")
349 | value = rearrange(value, "b f d c -> (b f) d c")
350 |
351 | key = self.reshape_heads_to_batch_dim(key)
352 | value = self.reshape_heads_to_batch_dim(value)
353 |
354 | if attention_mask is not None:
355 | if attention_mask.shape[-1] != query.shape[1]:
356 | target_length = query.shape[1]
357 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
358 | attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
359 |
360 | # attention, what we cannot get enough of
361 | # self._use_memory_efficient_attention_xformers = False
362 | if self._use_memory_efficient_attention_xformers:
363 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
364 | # Some versions of xformers return output in fp32, cast it back to the dtype of the input
365 | hidden_states = hidden_states.to(query.dtype)
366 | else:
367 | if self._slice_size is None or query.shape[0] // self._slice_size == 1:
368 | hidden_states = _attention(query, key, value, is_cross, attention_mask)
369 | else:
370 | hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
371 |
372 | # linear proj
373 | hidden_states = self.to_out[0](hidden_states)
374 |
375 | # dropout
376 | hidden_states = self.to_out[1](hidden_states)
377 | return hidden_states
378 |
379 | if attention_type == 'CrossAttention':
380 | return forward
381 | elif attention_type == "SparseCausalAttention":
382 | return sca_forward
383 |
384 | class DummyController:
385 |
386 | def __call__(self, *args):
387 | return args[0]
388 |
389 | def __init__(self):
390 | self.num_att_layers = 0
391 |
392 | if controller is None:
393 | controller = DummyController()
394 |
395 | def register_recr(net_, count, place_in_unet):
396 | if net_.__class__.__name__ == 'SparseCausalAttention':
397 | net_.forward = ca_forward(net_, place_in_unet, net_.__class__.__name__)
398 | return count + 1
399 | elif hasattr(net_, 'children'):
400 | for net__ in net_.children():
401 | count = register_recr(net__, count, place_in_unet)
402 | return count
403 |
404 | cross_att_count = 0
405 | sub_nets = model.unet.named_children()
406 | for net in sub_nets:
407 | if "down" in net[0]:
408 | cross_att_count += register_recr(net[1], 0, "down")
409 | elif "up" in net[0]:
410 | cross_att_count += register_recr(net[1], 0, "up")
411 | elif "mid" in net[0]:
412 | cross_att_count += register_recr(net[1], 0, "mid")
413 |
414 | controller.num_att_layers = cross_att_count
415 |
416 |
417 | def get_word_inds(text: str, word_place: int, tokenizer):
418 | split_text = text.split(" ")
419 | if type(word_place) is str:
420 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
421 | elif type(word_place) is int:
422 | word_place = [word_place]
423 | out = []
424 | if len(word_place) > 0:
425 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
426 | cur_len, ptr = 0, 0
427 |
428 | for i in range(len(words_encode)):
429 | cur_len += len(words_encode[i])
430 | if ptr in word_place:
431 | out.append(i + 1)
432 | if cur_len >= len(split_text[ptr]):
433 | ptr += 1
434 | cur_len = 0
435 | return np.array(out)
436 |
437 |
438 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
439 | word_inds: Optional[torch.Tensor] = None):
440 | if type(bounds) is float:
441 | bounds = 0, bounds
442 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
443 | if word_inds is None:
444 | word_inds = torch.arange(alpha.shape[2])
445 | alpha[: start, prompt_ind, word_inds] = 0
446 | alpha[start: end, prompt_ind, word_inds] = 1
447 | alpha[end:, prompt_ind, word_inds] = 0
448 | return alpha
449 |
450 |
451 | def get_time_words_attention_alpha(prompts, num_steps,
452 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
453 | tokenizer, max_num_words=77):
454 | if type(cross_replace_steps) is not dict:
455 | cross_replace_steps = {"default_": cross_replace_steps}
456 | if "default_" not in cross_replace_steps:
457 | cross_replace_steps["default_"] = (0., 1.)
458 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
459 | for i in range(len(prompts) - 1):
460 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
461 | i)
462 | for key, item in cross_replace_steps.items():
463 | if key != "default_":
464 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
465 | for i, ind in enumerate(inds):
466 | if len(ind) > 0:
467 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
468 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
469 | return alpha_time_words
470 |
471 |
472 | def get_time_string() -> str:
473 | x = datetime.datetime.now()
474 | return f"{(x.year - 2000):02d}{x.month:02d}{x.day:02d}-{x.hour:02d}{x.minute:02d}{x.second:02d}"
475 |
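sca_forward above chooses, for every query frame, which frames supply keys and values; the choice is driven by substring flags ("self", "first", "former") in the attention_type string taken from the config. A stand-alone illustration of the resulting per-frame source indices for a hypothetical five-frame clip:

import torch

video_length = 5
attention_type = "self_first_former"      # hypothetical config value; only the substrings matter
attention_frame_index = []
if "self" in attention_type:
    attention_frame_index.append(torch.arange(video_length))                          # itself
if "first" in attention_type:
    attention_frame_index.append(torch.zeros(video_length, dtype=torch.long))         # first frame
if "former" in attention_type:
    attention_frame_index.append(torch.clamp(torch.arange(video_length) - 1, min=0))  # previous frame

for frame in range(video_length):
    print(frame, [int(idx[frame]) for idx in attention_frame_index])
# 0 [0, 0, 0]
# 1 [1, 0, 0]
# 2 [2, 0, 1]
# 3 [3, 0, 2]
# 4 [4, 0, 3]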
--------------------------------------------------------------------------------
/freebloom/prompt_attention/seq_aligner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import numpy as np
16 |
17 |
18 | class ScoreParams:
19 |
20 | def __init__(self, gap, match, mismatch):
21 | self.gap = gap
22 | self.match = match
23 | self.mismatch = mismatch
24 |
25 | def mis_match_char(self, x, y):
26 | if x != y:
27 | return self.mismatch
28 | else:
29 | return self.match
30 |
31 |
32 | def get_matrix(size_x, size_y, gap):
33 | matrix = []
34 | for i in range(len(size_x) + 1):
35 | sub_matrix = []
36 | for j in range(len(size_y) + 1):
37 | sub_matrix.append(0)
38 | matrix.append(sub_matrix)
39 | for j in range(1, len(size_y) + 1):
40 | matrix[0][j] = j * gap
41 | for i in range(1, len(size_x) + 1):
42 | matrix[i][0] = i * gap
43 | return matrix
44 |
45 |
46 | def get_matrix(size_x, size_y, gap):
47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap
49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap
50 | return matrix
51 |
52 |
53 | def get_traceback_matrix(size_x, size_y):
54 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
55 | matrix[0, 1:] = 1
56 | matrix[1:, 0] = 2
57 | matrix[0, 0] = 4
58 | return matrix
59 |
60 |
61 | def global_align(x, y, score):
62 | matrix = get_matrix(len(x), len(y), score.gap)
63 | trace_back = get_traceback_matrix(len(x), len(y))
64 | for i in range(1, len(x) + 1):
65 | for j in range(1, len(y) + 1):
66 | left = matrix[i, j - 1] + score.gap
67 | up = matrix[i - 1, j] + score.gap
68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
69 | matrix[i, j] = max(left, up, diag)
70 | if matrix[i, j] == left:
71 | trace_back[i, j] = 1
72 | elif matrix[i, j] == up:
73 | trace_back[i, j] = 2
74 | else:
75 | trace_back[i, j] = 3
76 | return matrix, trace_back
77 |
78 |
79 | def get_aligned_sequences(x, y, trace_back):
80 | x_seq = []
81 | y_seq = []
82 | i = len(x)
83 | j = len(y)
84 | mapper_y_to_x = []
85 | while i > 0 or j > 0:
86 | if trace_back[i, j] == 3:
87 | x_seq.append(x[i - 1])
88 | y_seq.append(y[j - 1])
89 | i = i - 1
90 | j = j - 1
91 | mapper_y_to_x.append((j, i))
92 | elif trace_back[i][j] == 1:
93 | x_seq.append('-')
94 | y_seq.append(y[j - 1])
95 | j = j - 1
96 | mapper_y_to_x.append((j, -1))
97 | elif trace_back[i][j] == 2:
98 | x_seq.append(x[i - 1])
99 | y_seq.append('-')
100 | i = i - 1
101 | elif trace_back[i][j] == 4:
102 | break
103 | mapper_y_to_x.reverse()
104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
105 |
106 |
107 | def get_mapper(x: str, y: str, tokenizer, max_len=77):
108 | x_seq = tokenizer.encode(x)
109 | y_seq = tokenizer.encode(y)
110 | score = ScoreParams(0, 1, -1)
111 | matrix, trace_back = global_align(x_seq, y_seq, score)
112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
113 | alphas = torch.ones(max_len)
114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
115 | mapper = torch.zeros(max_len, dtype=torch.int64)
116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
118 | return mapper, alphas
119 |
120 |
121 | def get_refinement_mapper(prompts, tokenizer, max_len=77):
122 | x_seq = prompts[0]
123 | mappers, alphas = [], []
124 | for i in range(1, len(prompts)):
125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
126 | mappers.append(mapper)
127 | alphas.append(alpha)
128 | return torch.stack(mappers), torch.stack(alphas)
129 |
130 |
131 | def get_word_inds(text: str, word_place: int, tokenizer):
132 | split_text = text.split(" ")
133 | if type(word_place) is str:
134 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
135 | elif type(word_place) is int:
136 | word_place = [word_place]
137 | out = []
138 | if len(word_place) > 0:
139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
140 | cur_len, ptr = 0, 0
141 |
142 | for i in range(len(words_encode)):
143 | cur_len += len(words_encode[i])
144 | if ptr in word_place:
145 | out.append(i + 1)
146 | if cur_len >= len(split_text[ptr]):
147 | ptr += 1
148 | cur_len = 0
149 | return np.array(out)
150 |
151 |
152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
153 | words_x = x.split(' ')
154 | words_y = y.split(' ')
155 | if len(words_x) != len(words_y):
156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
161 | mapper = np.zeros((max_len, max_len))
162 | i = j = 0
163 | cur_inds = 0
164 | while i < max_len and j < max_len:
165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
167 | if len(inds_source_) == len(inds_target_):
168 | mapper[inds_source_, inds_target_] = 1
169 | else:
170 | ratio = 1 / len(inds_target_)
171 | for i_t in inds_target_:
172 | mapper[inds_source_, i_t] = ratio
173 | cur_inds += 1
174 | i += len(inds_source_)
175 | j += len(inds_target_)
176 | elif cur_inds < len(inds_source):
177 | mapper[i, j] = 1
178 | i += 1
179 | j += 1
180 | else:
181 | mapper[j, j] = 1
182 | i += 1
183 | j += 1
184 |
185 | return torch.from_numpy(mapper).float()
186 |
187 |
188 | def get_replacement_mapper(prompts, tokenizer, max_len=77):
189 | x_seq = prompts[0]
190 | mappers = []
191 | for i in range(1, len(prompts)):
192 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
193 | mappers.append(mapper)
194 | return torch.stack(mappers)
195 |
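get_mapper builds on a plain Needleman-Wunsch-style global alignment. A toy run on hand-made token-id lists (no tokenizer required) shows the (target index, source index) pairs it produces, with -1 marking tokens that exist only in the target prompt:

from freebloom.prompt_attention.seq_aligner import ScoreParams, global_align, get_aligned_sequences

x = [101, 7, 3, 9, 102]          # toy token ids for the source prompt
y = [101, 7, 3, 5, 9, 102]       # target prompt with one inserted token (5)
score = ScoreParams(gap=0, match=1, mismatch=-1)   # same scores as get_mapper above
_, trace_back = global_align(x, y, score)
_, _, mapper_y_to_x = get_aligned_sequences(x, y, trace_back)
print(mapper_y_to_x.tolist())
# [[0, 0], [1, 1], [2, 2], [3, -1], [4, 3], [5, 4]]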
--------------------------------------------------------------------------------
/freebloom/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | import numpy as np
4 | from PIL import Image
5 | from typing import Union
6 |
7 | import torch
8 | import torchvision
9 |
10 | from tqdm import tqdm
11 | from einops import rearrange
12 |
13 |
14 | def save_tensor_img(img, save_path):
15 |     """
16 |     :param img: tensor of shape (c, h, w), values in [0, 1]
17 |     """
18 | # img = (img + 1.0) / 2.0
19 | img = Image.fromarray(img.mul(255).byte().numpy().transpose(1, 2, 0))
20 | img.save(save_path)
21 |
22 |
23 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4, fps=8):
24 | videos = rearrange(videos, "b c t h w -> t b c h w")
25 | outputs = []
26 | for x in videos:
27 | x = torchvision.utils.make_grid(x, nrow=n_rows)
28 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
29 | if rescale:
30 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1
31 | x = (x * 255).numpy().astype(np.uint8)
32 | outputs.append(x)
33 |
34 | os.makedirs(os.path.dirname(path), exist_ok=True)
35 | imageio.mimsave(path, outputs, fps=fps)
36 |
37 |
38 | def save_videos_per_frames_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=4):
39 | os.makedirs(path, exist_ok=True)
40 | for i, video in enumerate(videos):
41 | video = rearrange(video, "c f h w -> f c h w")
42 | x = torchvision.utils.make_grid(video, nrow=n_rows)
43 | if rescale:
44 | x = (x + 1.0) / 2.0
45 | # x = (x * 255).numpy().astype(np.int8)
46 | torchvision.utils.save_image(x, f'{path}/{i}_all.jpg')
47 | for j, img in enumerate(video):
48 | save_tensor_img(img, f'{path}/{j}.jpg')
49 |
50 | # DDIM Inversion
51 | @torch.no_grad()
52 | def init_prompt(prompt, pipeline):
53 | uncond_input = pipeline.tokenizer(
54 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
55 | return_tensors="pt"
56 | )
57 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
58 | text_input = pipeline.tokenizer(
59 | [prompt],
60 | padding="max_length",
61 | max_length=pipeline.tokenizer.model_max_length,
62 | truncation=True,
63 | return_tensors="pt",
64 | )
65 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
66 | context = torch.cat([uncond_embeddings, text_embeddings])
67 |
68 | return context
69 |
70 |
71 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
72 | sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
73 | timestep, next_timestep = min(
74 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
75 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
76 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
77 | beta_prod_t = 1 - alpha_prod_t
78 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
79 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
80 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
81 | return next_sample
82 |
83 |
84 | def get_noise_pred_single(latents, t, context, unet):
85 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
86 | return noise_pred
87 |
88 |
89 | @torch.no_grad()
90 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
91 | context = init_prompt(prompt, pipeline)
92 | uncond_embeddings, cond_embeddings = context.chunk(2)
93 |
94 | all_latent = [latent]
95 | latent = latent.clone().detach()
96 | for i in tqdm(range(num_inv_steps)):
97 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
98 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
99 | latent = next_step(noise_pred, t, latent, ddim_scheduler)
100 | all_latent.append(latent)
101 | return all_latent
102 |
103 |
104 | @torch.no_grad()
105 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
106 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
107 | return ddim_latents
108 |
109 |
110 | def invert(image_path, weight_dtype, pipeline, ddim_scheduler, num_inv_steps, prompt=""):
111 | image_gt = load_512(image_path)
112 | image = torch.from_numpy(image_gt).type(weight_dtype) / 127.5 - 1
113 | image = image.permute(2, 0, 1).unsqueeze(0).to(pipeline.vae.device)
114 | latent = pipeline.vae.encode(image)['latent_dist'].mean.unsqueeze(2)
115 | latent = latent * 0.18215 # pipeline.vae.config.scaling_factor
116 | latents = ddim_inversion(pipeline, ddim_scheduler, latent, num_inv_steps, prompt=prompt)
117 | return latents
118 |
119 |
120 | def load_512(image_path, left=0, right=0, top=0, bottom=0):
121 | if type(image_path) is str:
122 | image = np.array(Image.open(image_path))[:, :, :3]
123 | else:
124 | image = image_path
125 | h, w, c = image.shape
126 | left = min(left, w - 1)
127 | right = min(right, w - left - 1)
128 |     top = min(top, h - 1)
129 | bottom = min(bottom, h - top - 1)
130 | image = image[top:h - bottom, left:w - right]
131 | h, w, c = image.shape
132 | if h < w:
133 | offset = (w - h) // 2
134 | image = image[:, offset:offset + h]
135 | elif w < h:
136 | offset = (h - w) // 2
137 | image = image[offset:offset + w]
138 | image = np.array(Image.fromarray(image).resize((512, 512)))
139 | return image
140 |
141 |
142 | @torch.no_grad()
143 | def latent2image(vae, latents):
144 | latents = 1 / vae.config.scaling_factor * latents
145 | image = vae.decode(latents)['sample']
146 | image = (image / 2 + 0.5).clamp(0, 1)
147 | image = image.cpu().permute(0, 2, 3, 1).numpy()
148 | image = (image * 255).astype(np.uint8)
149 | return image
150 |
151 |
152 | @torch.no_grad()
153 | def image2latent(vae, image, device, dtype=torch.float32):
154 | with torch.no_grad():
155 |         if isinstance(image, Image.Image):
156 | image = np.array(image)
157 | if type(image) is torch.Tensor and image.dim() == 4:
158 | latents = image
159 | else:
160 | image = torch.from_numpy(image).type(dtype) / 127.5 - 1
161 | image = image.permute(2, 0, 1).unsqueeze(0).to(device)
162 | latents = vae.encode(image)['latent_dist'].mean
163 | latents = latents * vae.config.scaling_factor
164 | return latents
165 |
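The saving helpers above expect videos laid out as (batch, channels, frames, height, width). A quick sketch with random data (output path and sizes are arbitrary) shows the expected value range and call pattern:

import torch

from freebloom.util import save_videos_grid

videos = torch.rand(2, 3, 8, 64, 64)       # two 8-frame RGB clips, values already in [0, 1]
save_videos_grid(videos, "outputs/demo/sample.gif", n_rows=2, fps=8)
# pass rescale=True instead when the frames are in [-1, 1]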
--------------------------------------------------------------------------------
/interp.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import inspect
4 | import logging
5 | import os
6 | from typing import Dict, Optional
7 |
8 | import diffusers
9 | import numpy as np
10 | import torch
11 | import torch.utils.checkpoint
12 | import transformers
13 | from accelerate import Accelerator
14 | from accelerate.logging import get_logger
15 | from accelerate.utils import set_seed
16 | from diffusers import AutoencoderKL, DDIMScheduler
17 | from diffusers.utils.import_utils import is_xformers_available
18 | from omegaconf import OmegaConf
19 | from transformers import CLIPTextModel, CLIPTokenizer
20 |
21 | from freebloom.models.unet import UNet3DConditionModel
22 | from freebloom.pipelines.pipeline_spatio_temporal import SpatioTemporalPipeline
23 | from freebloom.util import save_videos_grid, save_videos_per_frames_grid
24 |
25 | logger = get_logger(__name__, log_level="INFO")
26 |
27 |
28 | def main(
29 | pretrained_model_path: str,
30 | output_dir: str,
31 | validation_data: Dict,
32 | mixed_precision: Optional[str] = "fp16",
33 | enable_xformers_memory_efficient_attention: bool = True,
34 | seed: Optional[int] = None,
35 | inference_config: Dict = None,
36 | ):
37 | *_, config = inspect.getargvalues(inspect.currentframe())
38 |
39 | accelerator = Accelerator(
40 | mixed_precision=mixed_precision,
41 | project_dir=f"{output_dir}/acc_log"
42 | )
43 |
44 | # Make one log on every process with the configuration for debugging.
45 | logging.basicConfig(
46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
47 | datefmt="%m/%d/%Y %H:%M:%S",
48 | level=logging.INFO,
49 | )
50 | logger.info(accelerator.state, main_process_only=False)
51 | if accelerator.is_local_main_process:
52 | transformers.utils.logging.set_verbosity_warning()
53 | diffusers.utils.logging.set_verbosity_info()
54 | else:
55 | transformers.utils.logging.set_verbosity_error()
56 | diffusers.utils.logging.set_verbosity_error()
57 |
58 | # If passed along, set the training seed now.
59 | if seed is not None:
60 | set_seed(seed)
61 |
62 | # Handle the output folder creation
63 | if accelerator.is_main_process:
64 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
65 | output_dir = os.path.join(output_dir, now)
66 | os.makedirs(output_dir, exist_ok=True)
67 | # OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
68 |
69 | # Load scheduler, tokenizer and models.
70 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
71 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
72 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
73 | unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet")
74 |
75 | # Freeze vae and text_encoder
76 | vae.requires_grad_(False)
77 | text_encoder.requires_grad_(False)
78 | unet.requires_grad_(False)
79 |
80 | if enable_xformers_memory_efficient_attention:
81 | if is_xformers_available():
82 | unet.enable_xformers_memory_efficient_attention()
83 | else:
84 | raise ValueError("xformers is not available. Make sure it is installed correctly")
85 |
86 | # Get the validation pipeline
87 | validation_pipeline = SpatioTemporalPipeline(
88 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
89 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"),
90 | disk_store=False,
91 | config=config
92 | )
93 | validation_pipeline.enable_vae_slicing()
94 | validation_pipeline.scheduler.set_timesteps(validation_data.num_inv_steps)
95 |
96 | # Prepare everything with our `accelerator`.
97 | unet = accelerator.prepare(unet)
98 |
99 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
100 | # as these models are only used for inference, keeping weights in full precision is not required.
101 | weight_dtype = torch.float32
102 | if accelerator.mixed_precision == "fp16":
103 | weight_dtype = torch.float16
104 | elif accelerator.mixed_precision == "bf16":
105 | weight_dtype = torch.bfloat16
106 |
107 |     # Move text_encoder and vae to gpu and cast to weight_dtype
108 | text_encoder.to(accelerator.device, dtype=weight_dtype)
109 | vae.to(accelerator.device, dtype=weight_dtype)
110 | unet.to(accelerator.device, dtype=weight_dtype)
111 |
112 | # We need to initialize the trackers we use, and also store our configuration.
113 |     # The trackers initialize automatically on the main process.
114 | if accelerator.is_main_process:
115 | accelerator.init_trackers("SpatioTemporal")
116 |
117 | text_encoder.eval()
118 | vae.eval()
119 | unet.eval()
120 |
121 | generator = torch.Generator(device=unet.device)
122 | generator.manual_seed(seed)
123 |
124 | samples = []
125 |
126 | prompt = list(validation_data.prompts)
127 | negative_prompt = config['validation_data']['negative_prompt']
128 | negative_prompt = [negative_prompt] * len(prompt)
129 |
130 |     with torch.no_grad():
131 | x_base = validation_pipeline.prepare_latents(batch_size=1,
132 | num_channels_latents=4,
133 | video_length=len(prompt),
134 | height=512,
135 | width=512,
136 | dtype=weight_dtype,
137 | device=unet.device,
138 | generator=generator,
139 | store_attention=True,
140 | frame_same_noise=True)
141 |
142 | x_res = validation_pipeline.prepare_latents(batch_size=1,
143 | num_channels_latents=4,
144 | video_length=len(prompt),
145 | height=512,
146 | width=512,
147 | dtype=weight_dtype,
148 | device=unet.device,
149 | generator=generator,
150 | store_attention=True,
151 | frame_same_noise=False)
152 |
153 | x_T = np.cos(inference_config['diversity_rand_ratio'] * np.pi / 2) * x_base + np.sin(
154 | inference_config['diversity_rand_ratio'] * np.pi / 2) * x_res
155 |
156 | validation_data.pop('negative_prompt')
157 | # key frame
158 | key_frames, text_embedding = validation_pipeline(prompt, video_length=len(prompt), generator=generator,
159 | latents=x_T.type(weight_dtype),
160 | negative_prompt=negative_prompt,
161 | output_dir=output_dir,
162 | return_text_embedding=True,
163 | **validation_data)
164 |
165 | idx = [args.interp_idx, args.interp_idx + 1]
166 | config["inference_config"]["interpolation_step"] = args.interp_step
167 | interp = validation_pipeline.interpolate_between_two_frames(key_frames[0][:, :, idx], x_T[:, :, idx],
168 | rand_ratio=inference_config['diversity_rand_ratio'],
169 | where=idx,
170 | prompt=[prompt[i] for i in idx],
171 | k=args.interp_num,
172 | generator=generator,
173 | negative_prompt=negative_prompt[0],
174 | text_embedding=text_embedding[idx],
175 | base_noise=x_base[:, :, idx],
176 | **validation_data)
177 | sample = torch.cat([key_frames[0][:, :, :idx[0]], interp, key_frames[0][:, :, idx[1] + 1:]], dim=2)
178 | torch.cuda.empty_cache()
179 |
180 | samples.append(sample)
181 | samples = torch.concat(samples)
182 | save_videos_per_frames_grid(samples, f'{output_dir}/interp_img_samples', n_rows=6)
183 | logger.info(f"Saved samples to {output_dir}")
184 |
185 |
186 | if __name__ == "__main__":
187 | parser = argparse.ArgumentParser()
188 |
189 | parser.add_argument("--config", type=str, default="./configs/flowers.yaml")
190 | parser.add_argument("--interp_idx", type=int, required=True, help="where to interpolate")
191 | parser.add_argument("--interp_num", type=int, default=2)
192 | parser.add_argument("--interp_step", type=int, default=10)
193 | args = parser.parse_args()
194 |
195 | conf = OmegaConf.load(args.config)
196 | main(**conf)
197 |
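interp.py mixes a shared-noise latent bank (frame_same_noise=True) with a per-frame bank using cos/sin weights of diversity_rand_ratio * pi / 2: ratio 0 keeps every frame on the same noise, ratio 1 makes the frames independent, and the mix stays unit-variance either way when both banks are standard Gaussian. A small numpy check of that variance property (array size is arbitrary):

import numpy as np

rng = np.random.default_rng(0)
x_base = rng.standard_normal(1_000_000)
x_res = rng.standard_normal(1_000_000)
for ratio in (0.0, 0.3, 0.7, 1.0):
    a = ratio * np.pi / 2
    x_T = np.cos(a) * x_base + np.sin(a) * x_res
    print(f"ratio={ratio}: std={x_T.std():.4f}")   # ~1.0 in every case, since cos(a)**2 + sin(a)**2 = 1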
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import inspect
4 | import logging
5 | import os
6 | from typing import Dict, Optional
7 |
8 | import diffusers
9 | import numpy as np
10 | import torch
11 | import torch.utils.checkpoint
12 | import transformers
13 | from accelerate import Accelerator
14 | from accelerate.logging import get_logger
15 | from accelerate.utils import set_seed
16 | from diffusers import AutoencoderKL, DDIMScheduler
17 | from diffusers.utils.import_utils import is_xformers_available
18 | from omegaconf import OmegaConf
19 | from transformers import CLIPTextModel, CLIPTokenizer
20 |
21 | from freebloom.models.unet import UNet3DConditionModel
22 | from freebloom.pipelines.pipeline_spatio_temporal import SpatioTemporalPipeline
23 | from freebloom.util import save_videos_grid, save_videos_per_frames_grid
24 |
25 | logger = get_logger(__name__, log_level="INFO")
26 |
27 |
28 | def main(
29 | pretrained_model_path: str,
30 | output_dir: str,
31 | validation_data: Dict,
32 | mixed_precision: Optional[str] = "fp16",
33 | enable_xformers_memory_efficient_attention: bool = True,
34 | seed: Optional[int] = None,
35 | inference_config: Dict = None,
36 | ):
37 | *_, config = inspect.getargvalues(inspect.currentframe())
38 |
39 | accelerator = Accelerator(
40 | mixed_precision=mixed_precision,
41 | project_dir=f"{output_dir}/acc_log"
42 | )
43 |
44 | # Make one log on every process with the configuration for debugging.
45 | logging.basicConfig(
46 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
47 | datefmt="%m/%d/%Y %H:%M:%S",
48 | level=logging.INFO,
49 | )
50 | logger.info(accelerator.state, main_process_only=False)
51 | if accelerator.is_local_main_process:
52 | transformers.utils.logging.set_verbosity_warning()
53 | diffusers.utils.logging.set_verbosity_info()
54 | else:
55 | transformers.utils.logging.set_verbosity_error()
56 | diffusers.utils.logging.set_verbosity_error()
57 |
58 | # If passed along, set the training seed now.
59 | if seed is not None:
60 | set_seed(seed)
61 |
62 | # Handle the output folder creation
63 | if accelerator.is_main_process:
64 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
65 | output_dir = os.path.join(output_dir, now)
66 | os.makedirs(output_dir, exist_ok=True)
67 | os.makedirs(f"{output_dir}/samples", exist_ok=True)
68 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))
69 |
70 | # Load scheduler, tokenizer and models.
71 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
72 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
73 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
74 | unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet")
75 |
76 | # Freeze vae and text_encoder
77 | vae.requires_grad_(False)
78 | text_encoder.requires_grad_(False)
79 | unet.requires_grad_(False)
80 |
81 | if enable_xformers_memory_efficient_attention:
82 | if is_xformers_available():
83 | unet.enable_xformers_memory_efficient_attention()
84 | else:
85 | raise ValueError("xformers is not available. Make sure it is installed correctly")
86 |
87 | # Get the validation pipeline
88 | validation_pipeline = SpatioTemporalPipeline(
89 | vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
90 | scheduler=DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler"),
91 | disk_store=False,
92 | config=config
93 | )
94 | validation_pipeline.enable_vae_slicing()
95 | validation_pipeline.scheduler.set_timesteps(validation_data.num_inv_steps)
96 |
97 | # Prepare everything with our `accelerator`.
98 | unet = accelerator.prepare(unet)
99 |
100 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
101 | # as these models are only used for inference, keeping weights in full precision is not required.
102 | weight_dtype = torch.float32
103 | if accelerator.mixed_precision == "fp16":
104 | weight_dtype = torch.float16
105 | elif accelerator.mixed_precision == "bf16":
106 | weight_dtype = torch.bfloat16
107 |
108 |     # Move text_encoder and vae to gpu and cast to weight_dtype
109 | text_encoder.to(accelerator.device, dtype=weight_dtype)
110 | vae.to(accelerator.device, dtype=weight_dtype)
111 | unet.to(accelerator.device, dtype=weight_dtype)
112 |
113 | # We need to initialize the trackers we use, and also store our configuration.
114 |     # The trackers initialize automatically on the main process.
115 | if accelerator.is_main_process:
116 | accelerator.init_trackers("SpatioTemporal")
117 |
118 | text_encoder.eval()
119 | vae.eval()
120 | unet.eval()
121 |
122 | generator = torch.Generator(device=unet.device)
123 | generator.manual_seed(seed)
124 |
125 | samples = []
126 |
127 | prompt = list(validation_data.prompts)
128 | negative_prompt = config['validation_data']['negative_prompt']
129 | negative_prompt = [negative_prompt] * len(prompt)
130 |
131 |     with torch.no_grad():
132 | x_base = validation_pipeline.prepare_latents(batch_size=1,
133 | num_channels_latents=4,
134 | video_length=len(prompt),
135 | height=512,
136 | width=512,
137 | dtype=weight_dtype,
138 | device=unet.device,
139 | generator=generator,
140 | store_attention=True,
141 | frame_same_noise=True)
142 |
143 | x_res = validation_pipeline.prepare_latents(batch_size=1,
144 | num_channels_latents=4,
145 | video_length=len(prompt),
146 | height=512,
147 | width=512,
148 | dtype=weight_dtype,
149 | device=unet.device,
150 | generator=generator,
151 | store_attention=True,
152 | frame_same_noise=False)
153 |
154 | x_T = np.cos(inference_config['diversity_rand_ratio'] * np.pi / 2) * x_base + np.sin(
155 | inference_config['diversity_rand_ratio'] * np.pi / 2) * x_res
156 |
157 | validation_data.pop('negative_prompt')
158 | # key frame
159 | key_frames, text_embedding = validation_pipeline(prompt, video_length=len(prompt), generator=generator,
160 | latents=x_T.type(weight_dtype),
161 | negative_prompt=negative_prompt,
162 | output_dir=output_dir,
163 | return_text_embedding=True,
164 | **validation_data)
165 | torch.cuda.empty_cache()
166 |
167 | samples.append(key_frames[0])
168 | samples = torch.concat(samples)
169 | save_path = f"{output_dir}/samples/sample.gif"
170 | save_videos_grid(samples, save_path, n_rows=6)
171 | save_videos_per_frames_grid(samples, f'{output_dir}/img_samples', n_rows=6)
172 | logger.info(f"Saved samples to {save_path}")
173 |
174 |
175 | if __name__ == "__main__":
176 | parser = argparse.ArgumentParser()
177 |
178 | parser.add_argument("--config", type=str, default="./configs/flowers.yaml")
179 | args = parser.parse_args()
180 |
181 | conf = OmegaConf.load(args.config)
182 | main(**conf)
183 |
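Both entry points are driven by an OmegaConf YAML passed via --config. The sketch below assembles a hypothetical minimal config from the keys that main() and AttentionStore read; the shipped configs/*.yaml files are the authoritative reference and may use additional or differently named validation_data fields.

from omegaconf import OmegaConf

# Hypothetical config sketch; field names are taken from the attribute/key
# accesses in main(), interp.py and AttentionStore above, not from the shipped configs.
conf = OmegaConf.create({
    "pretrained_model_path": "checkpoints/stable-diffusion",   # local diffusers-format checkpoint
    "output_dir": "outputs/flowers",
    "seed": 42,
    "mixed_precision": "fp16",
    "enable_xformers_memory_efficient_attention": True,
    "validation_data": {
        "prompts": ["a bud", "a half-open flower", "a flower in full bloom"],
        "negative_prompt": "",
        "num_inv_steps": 50,
        "attention_type_former": "self_first_former",
        "attention_type_latter": "self_former",
        "attention_adapt_step": 25,
    },
    "inference_config": {
        "diversity_rand_ratio": 0.3,
        "interpolation_step": 10,
    },
})
print(OmegaConf.to_yaml(conf))   # same nested structure that a --config file provides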
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.17.1
2 | decord==0.6.0
3 | diffusers==0.11.1
4 | einops==0.7.0
5 | imageio==2.26.0
6 | numpy==1.24.2
7 | omegaconf==2.3.0
8 | opencv_python==4.7.0.72
9 | packaging==23.2
10 | Pillow==10.1.0
11 | torch==1.13.1
12 | torchvision==0.14.1
13 | tqdm==4.65.0
14 | transformers==4.27.1
15 | xformers==0.0.16
16 |
--------------------------------------------------------------------------------