18 |
24 |
25 | -----------------
26 |
27 | 
29 |
30 |
31 | ### [Project Page](https://showlab.github.io/Show-1) | [arXiv](https://arxiv.org/abs/2309.15818) | [PDF](https://arxiv.org/pdf/2309.15818) | [🤗 Space](https://huggingface.co/spaces/showlab/Show-1) | [Colab](https://colab.research.google.com/github/camenduru/Show-1-colab/blob/main/Show_1_steps_colab.ipynb) | [Replicate Demo](https://replicate.com/cjwbw/show-1)
32 |
33 |
34 | ## News
35 | - [10/06/2024] Show-1 was accepted to IJCV!
36 | - [10/12/2023] Code and weights released!
37 |
38 | ## Setup
39 |
40 | ### Requirements
41 |
42 | ```shell
43 | pip install -r requirements.txt
44 | ```
45 |
46 | Note: PyTorch 2.0+ is highly recommended for better efficiency and speed on GPUs.
47 |
48 |
49 | ### Weights
50 |
51 | All model weights for Show-1 are available on [Show Lab's HuggingFace page](https://huggingface.co/showlab): Base Model ([show-1-base](https://huggingface.co/showlab/show-1-base)), Interpolation Model ([show-1-interpolation](https://huggingface.co/showlab/show-1-interpolation)), and Super-Resolution Model ([show-1-sr1](https://huggingface.co/showlab/show-1-sr1), [show-1-sr2](https://huggingface.co/showlab/show-1-sr2)).
52 |
53 | Note that our [show-1-sr1](https://huggingface.co/showlab/show-1-sr1) incorporates the image super-resolution model from DeepFloyd-IF, [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0), to upsample the first frame of the video. To obtain the respective weights, follow their [official instructions](https://huggingface.co/DeepFloyd/IF-II-L-v1.0).
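
The DeepFloyd-IF weights are distributed under a license on the Hugging Face Hub, so you may need to accept the license for [DeepFloyd/IF-II-L-v1.0](https://huggingface.co/DeepFloyd/IF-II-L-v1.0) and authenticate before they can be downloaded. A minimal sketch (the token string is a placeholder for your own Hugging Face access token):

```python
from huggingface_hub import login

# Authenticate so the DeepFloyd/IF-II-L-v1.0 weights can be downloaded.
login(token="hf_...")  # placeholder: use your own access token
```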
54 |
55 | ## Usage
56 |
57 | To generate a video from a text prompt, run the command below:
58 |
59 | ```bash
60 | python run_inference.py
61 | ```
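
The prompt, random seed, and output directory are set near the top of `run_inference.py`; edit them there to generate different videos. The relevant block looks like this:

```python
# Inference settings near the top of run_inference.py
prompt = "A burning lamborghini driving on rainbow."
negative_prompt = "low resolution, blur"
output_dir = "./outputs/example"
seed = 345
```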
62 |
63 | By default, the videos generated at each stage are saved to the `outputs` folder in GIF format. The script automatically fetches the necessary model weights from Hugging Face. If you prefer, you can manually download the weights using git lfs and then update `pretrained_model_path` to point to your local directory. Here's how:
64 |
65 | ```bash
66 | git lfs install
67 | git clone https://huggingface.co/showlab/show-1-base
68 | ```
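
After cloning, point the corresponding `pretrained_model_path` in `run_inference.py` to the local directory. A sketch, assuming the repository was cloned into `./show-1-base`:

```python
# Load the base model from the local clone instead of fetching it from the Hub.
pretrained_model_path = "./show-1-base"
pipe_base = TextToVideoIFPipeline.from_pretrained(
    pretrained_model_path,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe_base.enable_model_cpu_offload()
```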
69 |
70 | A demo is also available on the [`showlab/Show-1` 🤗 Space](https://huggingface.co/spaces/showlab/Show-1).
71 | You can also run the Gradio demo locally:
72 |
73 | ```bash
74 | python app.py
75 | ```
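
By default the demo serves on a local URL. If you need a temporary public link (for example, to test from another machine), Gradio's `share` option can be enabled; a sketch of the launch call at the end of `app.py`:

```python
# share=True asks Gradio to create a temporary public URL for the local demo.
demo.queue(max_size=12).launch(show_api=True, share=True)
```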
76 |
77 |
78 | ## Demo Video
79 | https://github.com/showlab/Show-1/assets/55792387/32242135-25a5-4757-b494-91bf314581e8
80 |
81 |
82 | ## Citation
83 | If you make use of our work, please cite our paper.
84 | ```bibtex
85 | @article{zhang2023show,
86 | title={Show-1: Marrying Pixel and Latent Diffusion Models for Text-to-Video Generation},
87 | author={Zhang, David Junhao and Wu, Jay Zhangjie and Liu, Jia-Wei and Zhao, Rui and Ran, Lingmin and Gu, Yuchao and Gao, Difei and Shou, Mike Zheng},
88 | journal={arXiv preprint arXiv:2309.15818},
89 | year={2023}
90 | }
91 | ```
92 |
93 | ## Commercial Use
94 |
95 | We are working with the university (NUS) to figure out the exact paperwork needed for approving commercial use requests. In the meantime, to speed up the process, we would like to solicit expressions of interest from the community, and we will later process these requests with high priority. If you are keen, please kindly email us at mike.zheng.shou@gmail.com and junhao.zhang@u.nus.edu and answer the following questions, if possible:
96 | - Who are you / your company?
97 | - What is your product / application?
98 | - How can Show-1 benefit your product?
99 |
100 | ## Shoutouts
101 |
102 | - This work heavily builds on [diffusers](https://github.com/huggingface/diffusers), [deep-floyd/IF](https://github.com/deep-floyd/IF), [modelscope](https://huggingface.co/damo-vilab/modelscope-damo-text-to-video-synthesis), and [zeroscope](https://huggingface.co/cerspense/zeroscope_v2_576w). Thanks for open-sourcing!
103 | - Thanks to [@camenduru](https://github.com/camenduru) for providing the Colab demo and [@chenxwh](https://github.com/chenxwh) for providing the Replicate demo.
104 |
105 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import torch
3 | from diffusers.utils import export_to_video
4 |
5 | import os
6 | from PIL import Image
7 |
8 | import torch.nn.functional as F
9 |
10 | from diffusers import IFSuperResolutionPipeline, VideoToVideoSDPipeline
11 | from diffusers.utils import export_to_video
12 | from diffusers.utils.torch_utils import randn_tensor
13 |
14 | from showone.pipelines import TextToVideoIFPipeline, TextToVideoIFInterpPipeline, TextToVideoIFSuperResolutionPipeline
15 | from showone.pipelines.pipeline_t2v_base_pixel import tensor2vid
16 | from showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond
17 |
18 |
19 | # Base Model
20 | pretrained_model_path = "showlab/show-1-base"
21 | pipe_base = TextToVideoIFPipeline.from_pretrained(
22 | pretrained_model_path,
23 | torch_dtype=torch.float16,
24 | variant="fp16"
25 | )
26 | pipe_base.enable_model_cpu_offload()
27 |
28 | # Interpolation Model
29 | pretrained_model_path = "showlab/show-1-interpolation"
30 | pipe_interp_1 = TextToVideoIFInterpPipeline.from_pretrained(
31 | pretrained_model_path,
32 | text_encoder=None,
33 | torch_dtype=torch.float16,
34 | variant="fp16"
35 | )
36 | pipe_interp_1.enable_model_cpu_offload()
37 |
38 | # Super-Resolution Model 1
39 | # Image super-resolution model from DeepFloyd https://huggingface.co/DeepFloyd/IF-II-L-v1.0
40 | pretrained_model_path = "DeepFloyd/IF-II-L-v1.0"
41 | pipe_sr_1_image = IFSuperResolutionPipeline.from_pretrained(
42 | pretrained_model_path,
43 | text_encoder=None,
44 | torch_dtype=torch.float16,
45 | variant="fp16",
46 | )
47 | pipe_sr_1_image.enable_model_cpu_offload()
48 |
49 | pretrained_model_path = "showlab/show-1-sr1"
50 | pipe_sr_1_cond = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained(
51 | pretrained_model_path,
52 | text_encoder=None,
53 | torch_dtype=torch.float16
54 | )
55 | pipe_sr_1_cond.enable_model_cpu_offload()
56 |
57 | # Super-Resolution Model 2
58 | pretrained_model_path = "showlab/show-1-sr2"
59 | pipe_sr_2 = VideoToVideoSDPipeline.from_pretrained(
60 | pretrained_model_path,
61 | torch_dtype=torch.float16
62 | )
63 | pipe_sr_2.enable_model_cpu_offload()
64 | pipe_sr_2.enable_vae_slicing()
65 |
66 | output_dir = "./outputs"
67 | os.makedirs(output_dir, exist_ok=True)
68 |
69 | def infer(prompt):
70 | print(prompt)
71 | negative_prompt = "low resolution, blur"
72 |
73 | # Text embeds
74 | prompt_embeds, negative_embeds = pipe_base.encode_prompt(prompt)
75 |
76 | # Keyframes generation (8x64x40, 2fps)
77 | video_frames = pipe_base(
78 | prompt_embeds=prompt_embeds,
79 | negative_prompt_embeds=negative_embeds,
80 | num_frames=8,
81 | height=40,
82 | width=64,
83 | num_inference_steps=75,
84 | guidance_scale=9.0,
85 | output_type="pt"
86 | ).frames
87 |
88 | # Frame interpolation (8x64x40, 2fps -> 29x64x40, 7.5fps)
89 | bsz, channel, num_frames, height, width = video_frames.shape
90 | new_num_frames = 3 * (num_frames - 1) + num_frames
91 | new_video_frames = torch.zeros((bsz, channel, new_num_frames, height, width),
92 | dtype=video_frames.dtype, device=video_frames.device)
93 | new_video_frames[:, :, torch.arange(0, new_num_frames, 4), ...] = video_frames
94 | init_noise = randn_tensor((bsz, channel, 5, height, width), dtype=video_frames.dtype,
95 | device=video_frames.device)
96 |
97 | for i in range(num_frames - 1):
98 | batch_i = torch.zeros((bsz, channel, 5, height, width), dtype=video_frames.dtype, device=video_frames.device)
99 | batch_i[:, :, 0, ...] = video_frames[:, :, i, ...]
100 | batch_i[:, :, -1, ...] = video_frames[:, :, i + 1, ...]
101 | batch_i = pipe_interp_1(
102 | pixel_values=batch_i,
103 | prompt_embeds=prompt_embeds,
104 | negative_prompt_embeds=negative_embeds,
105 | num_frames=batch_i.shape[2],
106 | height=40,
107 | width=64,
108 | num_inference_steps=50,
109 | guidance_scale=4.0,
110 | output_type="pt",
111 | init_noise=init_noise,
112 | cond_interpolation=True,
113 | ).frames
114 |
115 | new_video_frames[:, :, i * 4:i * 4 + 5, ...] = batch_i
116 |
117 | video_frames = new_video_frames
118 |
119 | # Super-resolution 1 (29x64x40 -> 29x256x160)
120 | bsz, channel, num_frames, height, width = video_frames.shape
121 | window_size, stride = 8, 7
122 | new_video_frames = torch.zeros(
123 | (bsz, channel, num_frames, height * 4, width * 4),
124 | dtype=video_frames.dtype,
125 | device=video_frames.device)
126 | for i in range(0, num_frames - window_size + 1, stride):
127 | batch_i = video_frames[:, :, i:i + window_size, ...]
128 |
129 | if i == 0:
130 | first_frame_cond = pipe_sr_1_image(
131 | image=video_frames[:, :, 0, ...],
132 | prompt_embeds=prompt_embeds,
133 | negative_prompt_embeds=negative_embeds,
134 | height=height * 4,
135 | width=width * 4,
136 | num_inference_steps=50,
137 | guidance_scale=4.0,
138 | noise_level=150,
139 | output_type="pt"
140 | ).images
141 | first_frame_cond = first_frame_cond.unsqueeze(2)
142 | else:
143 | first_frame_cond = new_video_frames[:, :, i:i + 1, ...]
144 |
145 | batch_i = pipe_sr_1_cond(
146 | image=batch_i,
147 | prompt_embeds=prompt_embeds,
148 | negative_prompt_embeds=negative_embeds,
149 | first_frame_cond=first_frame_cond,
150 | height=height * 4,
151 | width=width * 4,
152 | num_inference_steps=50,
153 | guidance_scale=7.0,
154 | noise_level=250,
155 | output_type="pt"
156 | ).frames
157 | new_video_frames[:, :, i:i + window_size, ...] = batch_i
158 |
159 | video_frames = new_video_frames
160 |
161 | # Super-resolution 2 (29x256x160 -> 29x576x320)
162 | video_frames = [Image.fromarray(frame).resize((576, 320)) for frame in tensor2vid(video_frames.clone())]
163 | video_frames = pipe_sr_2(
164 | prompt,
165 | negative_prompt=negative_prompt,
166 | video=video_frames,
167 | strength=0.8,
168 | num_inference_steps=50,
169 | ).frames
170 |
171 | video_path = export_to_video(video_frames, f"{output_dir}/{prompt[:200]}.mp4")
172 | print(video_path)
173 | return video_path
174 |
175 | css = """
176 | #col-container {max-width: 510px; margin-left: auto; margin-right: auto;}
177 | a {text-decoration-line: underline; font-weight: 600;}
178 | .animate-spin {
179 | animation: spin 1s linear infinite;
180 | }
181 |
182 | @keyframes spin {
183 | from {
184 | transform: rotate(0deg);
185 | }
186 | to {
187 | transform: rotate(360deg);
188 | }
189 | }
190 |
191 | #share-btn-container {
192 | display: flex;
193 | padding-left: 0.5rem !important;
194 | padding-right: 0.5rem !important;
195 | background-color: #000000;
196 | justify-content: center;
197 | align-items: center;
198 | border-radius: 9999px !important;
199 | max-width: 15rem;
200 | height: 36px;
201 | }
202 |
203 | div#share-btn-container > div {
204 | flex-direction: row;
205 | background: black;
206 | align-items: center;
207 | }
208 |
209 | #share-btn-container:hover {
210 | background-color: #060606;
211 | }
212 |
213 | #share-btn {
214 | all: initial;
215 | color: #ffffff;
216 | font-weight: 600;
217 | cursor:pointer;
218 | font-family: 'IBM Plex Sans', sans-serif;
219 | margin-left: 0.5rem !important;
220 | padding-top: 0.5rem !important;
221 | padding-bottom: 0.5rem !important;
222 | right:0;
223 | }
224 |
225 | #share-btn * {
226 | all: unset;
227 | }
228 |
229 | #share-btn-container div:nth-child(-n+2){
230 | width: auto !important;
231 | min-height: 0px !important;
232 | }
233 |
234 | #share-btn-container .wrap {
235 | display: none !important;
236 | }
237 |
238 | #share-btn-container.hidden {
239 | display: none!important;
240 | }
241 | img[src*='#center'] {
242 | display: inline-block;
243 | margin: unset;
244 | }
245 |
246 | .footer {
247 | margin-bottom: 45px;
248 | margin-top: 10px;
249 | text-align: center;
250 | border-bottom: 1px solid #e5e5e5;
251 | }
252 | .footer>p {
253 | font-size: .8rem;
254 | display: inline-block;
255 | padding: 0 10px;
256 | transform: translateY(10px);
257 | background: white;
258 | }
259 | .dark .footer {
260 | border-color: #303030;
261 | }
262 | .dark .footer>p {
263 | background: #0b0f19;
264 | }
265 | """
266 |
267 | with gr.Blocks(css=css) as demo:
268 | with gr.Column(elem_id="col-container"):
269 | gr.Markdown(
270 | """
271 | <h1>Show-1 Text-to-Video</h1>
272 | <p>
273 | A text-to-video generation model that marries the strength and alleviates the weakness of pixel-based and latent-based VDMs.
274 | </p>
275 |
276 | <p>
277 | <a href="https://arxiv.org/abs/2309.15818" target="_blank">Paper</a> |
278 | <a href="https://showlab.github.io/Show-1" target="_blank">Project Page</a> |
279 | <a href="https://github.com/showlab/Show-1" target="_blank">Github</a>
280 | </p>
281 |
282 | """
283 | )
284 |
285 | prompt_in = gr.Textbox(label="Prompt", placeholder="A panda taking a selfie", elem_id="prompt-in")
286 | #neg_prompt = gr.Textbox(label="Negative prompt", value="text, watermark, copyright, blurry, nsfw", elem_id="neg-prompt-in")
287 | #inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=100, step=1, value=40, interactive=False)
288 | submit_btn = gr.Button("Submit")
289 | video_result = gr.Video(label="Video Output", elem_id="video-output")
290 |
291 | gr.HTML("""
292 |
298 | """)
299 |
300 | submit_btn.click(fn=infer,
301 | inputs=[prompt_in],
302 | outputs=[video_result],
303 | api_name="show-1")
304 |
305 | demo.queue(max_size=12).launch(show_api=True)
306 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.19.3
2 | bitsandbytes==0.35.4
3 | decord==0.6.0
4 | transformers==4.29.1
5 | accelerate==0.18.0
6 | imageio==2.14.1
7 | torch==2.0.0
8 | torchvision==0.15.0
9 | beautifulsoup4
10 | tensorboard
11 | sentencepiece
12 | safetensors
13 | modelcards
14 | omegaconf
15 | pandas
16 | einops
17 | ftfy
18 |
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import imageio
3 | from PIL import Image
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | from diffusers import IFSuperResolutionPipeline, VideoToVideoSDPipeline
9 | from diffusers.utils.torch_utils import randn_tensor
10 |
11 | from showone.pipelines import TextToVideoIFPipeline, TextToVideoIFInterpPipeline, TextToVideoIFSuperResolutionPipeline
12 | from showone.pipelines.pipeline_t2v_base_pixel import tensor2vid
13 | from showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond
14 |
15 |
16 | # Base Model
17 | # When using "showlab/show-1-base-0.0", it's advisable to increase the number of inference steps (e.g., 100)
18 | # and opt for a larger guidance scale (e.g., 12.0) to enhance visual quality.
19 | pretrained_model_path = "showlab/show-1-base"
20 | pipe_base = TextToVideoIFPipeline.from_pretrained(
21 | pretrained_model_path,
22 | torch_dtype=torch.float16,
23 | variant="fp16"
24 | )
25 | pipe_base.enable_model_cpu_offload()
26 |
27 | # Interpolation Model
28 | pretrained_model_path = "showlab/show-1-interpolation"
29 | pipe_interp_1 = TextToVideoIFInterpPipeline.from_pretrained(
30 | pretrained_model_path,
31 | torch_dtype=torch.float16,
32 | variant="fp16"
33 | )
34 | pipe_interp_1.enable_model_cpu_offload()
35 |
36 | # Super-Resolution Model 1
37 | # Image super-resolution model from DeepFloyd https://huggingface.co/DeepFloyd/IF-II-L-v1.0
38 | pretrained_model_path = "DeepFloyd/IF-II-L-v1.0"
39 | pipe_sr_1_image = IFSuperResolutionPipeline.from_pretrained(
40 | pretrained_model_path,
41 | text_encoder=None,
42 | torch_dtype=torch.float16,
43 | variant="fp16"
44 | )
45 | pipe_sr_1_image.enable_model_cpu_offload()
46 |
47 | pretrained_model_path = "showlab/show-1-sr1"
48 | pipe_sr_1_cond = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained(
49 | pretrained_model_path,
50 | torch_dtype=torch.float16
51 | )
52 | pipe_sr_1_cond.enable_model_cpu_offload()
53 |
54 | # Super-Resolution Model 2
55 | pretrained_model_path = "showlab/show-1-sr2"
56 | pipe_sr_2 = VideoToVideoSDPipeline.from_pretrained(
57 | pretrained_model_path,
58 | torch_dtype=torch.float16
59 | )
60 | pipe_sr_2.enable_model_cpu_offload()
61 | pipe_sr_2.enable_vae_slicing()
62 |
63 |
64 | # Inference
65 | prompt = "A burning lamborghini driving on rainbow."
66 | output_dir = "./outputs/example"
67 | negative_prompt = "low resolution, blur"
68 |
69 | seed = 345
70 | os.makedirs(output_dir, exist_ok=True)
71 |
72 | # Text embeds
73 | prompt_embeds, negative_embeds = pipe_base.encode_prompt(prompt)
74 |
75 | # Keyframes generation (8x64x40, 2fps)
76 | video_frames = pipe_base(
77 | prompt_embeds=prompt_embeds,
78 | negative_prompt_embeds=negative_embeds,
79 | num_frames=8,
80 | height=40,
81 | width=64,
82 | num_inference_steps=75,
83 | guidance_scale=9.0,
84 | generator=torch.manual_seed(seed),
85 | output_type="pt"
86 | ).frames
87 |
88 | imageio.mimsave(f"{output_dir}/{prompt}_base.gif", tensor2vid(video_frames.clone()), fps=2)
89 |
90 | # Frame interpolation (8x64x40, 2fps -> 29x64x40, 7.5fps)
91 | bsz, channel, num_frames, height, width = video_frames.shape
92 | new_num_frames = 3 * (num_frames - 1) + num_frames
93 | new_video_frames = torch.zeros((bsz, channel, new_num_frames, height, width),
94 | dtype=video_frames.dtype, device=video_frames.device)
95 | new_video_frames[:, :, torch.arange(0, new_num_frames, 4), ...] = video_frames
96 | init_noise = randn_tensor((bsz, channel, 5, height, width), dtype=video_frames.dtype,
97 | device=video_frames.device, generator=torch.manual_seed(seed))
98 |
99 | for i in range(num_frames - 1):
100 | batch_i = torch.zeros((bsz, channel, 5, height, width), dtype=video_frames.dtype, device=video_frames.device)
101 | batch_i[:, :, 0, ...] = video_frames[:, :, i, ...]
102 | batch_i[:, :, -1, ...] = video_frames[:, :, i + 1, ...]
103 | batch_i = pipe_interp_1(
104 | pixel_values=batch_i,
105 | prompt_embeds=prompt_embeds,
106 | negative_prompt_embeds=negative_embeds,
107 | num_frames=batch_i.shape[2],
108 | height=40,
109 | width=64,
110 | num_inference_steps=75,
111 | guidance_scale=4.0,
112 | generator=torch.manual_seed(seed),
113 | output_type="pt",
114 | init_noise=init_noise,
115 | cond_interpolation=True,
116 | ).frames
117 |
118 | new_video_frames[:, :, i * 4:i * 4 + 5, ...] = batch_i
119 |
120 | video_frames = new_video_frames
121 | imageio.mimsave(f"{output_dir}/{prompt}_interp.gif", tensor2vid(video_frames.clone()), fps=8)
122 |
123 | # Super-resolution 1 (29x64x40 -> 29x256x160)
124 | bsz, channel, num_frames, height, width = video_frames.shape
125 | window_size, stride = 8, 7
126 | new_video_frames = torch.zeros(
127 | (bsz, channel, num_frames, height * 4, width * 4),
128 | dtype=video_frames.dtype,
129 | device=video_frames.device)
130 | for i in range(0, num_frames - window_size + 1, stride):
131 | batch_i = video_frames[:, :, i:i + window_size, ...]
132 | all_frame_cond = None
133 |
134 | if i == 0:
135 | first_frame_cond = pipe_sr_1_image(
136 | image=video_frames[:, :, 0, ...],
137 | prompt_embeds=prompt_embeds,
138 | negative_prompt_embeds=negative_embeds,
139 | height=height * 4,
140 | width=width * 4,
141 | num_inference_steps=70,
142 | guidance_scale=4.0,
143 | noise_level=150,
144 | generator=torch.manual_seed(seed),
145 | output_type="pt"
146 | ).images
147 | first_frame_cond = first_frame_cond.unsqueeze(2)
148 | else:
149 | first_frame_cond = new_video_frames[:, :, i:i + 1, ...]
150 |
151 | batch_i = pipe_sr_1_cond(
152 | image=batch_i,
153 | prompt_embeds=prompt_embeds,
154 | negative_prompt_embeds=negative_embeds,
155 | first_frame_cond=first_frame_cond,
156 | height=height * 4,
157 | width=width * 4,
158 | num_inference_steps=125,
159 | guidance_scale=7.0,
160 | noise_level=250,
161 | generator=torch.manual_seed(seed),
162 | output_type="pt"
163 | ).frames
164 | new_video_frames[:, :, i:i + window_size, ...] = batch_i
165 |
166 | video_frames = new_video_frames
167 | imageio.mimsave(f"{output_dir}/{prompt}_sr1.gif", tensor2vid(video_frames.clone()), fps=8)
168 |
169 | # Super-resolution 2 (29x256x160 -> 29x576x320)
170 | video_frames = [Image.fromarray(frame).resize((576, 320)) for frame in tensor2vid(video_frames.clone())]
171 | video_frames = pipe_sr_2(
172 | prompt,
173 | negative_prompt=negative_prompt,
174 | video=video_frames,
175 | strength=0.8,
176 | num_inference_steps=50,
177 | generator=torch.manual_seed(seed),
178 | output_type="pt"
179 | ).frames
180 |
181 | imageio.mimsave(f"{output_dir}/{prompt}.gif", tensor2vid(video_frames.clone()), fps=8)
182 |
--------------------------------------------------------------------------------
/showone/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .unet_3d_condition import UNet3DConditionModel
--------------------------------------------------------------------------------
/showone/models/transformer_temporal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Optional
16 |
17 | import torch
18 | from torch import nn
19 |
20 | from diffusers.configuration_utils import ConfigMixin, register_to_config
21 | from diffusers.utils import BaseOutput
22 | from diffusers.models.attention import BasicTransformerBlock
23 | from diffusers.models.modeling_utils import ModelMixin
24 |
25 |
26 | @dataclass
27 | class TransformerTemporalModelOutput(BaseOutput):
28 | """
29 | The output of [`TransformerTemporalModel`].
30 |
31 | Args:
32 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
33 | The hidden states output conditioned on `encoder_hidden_states` input.
34 | """
35 |
36 | sample: torch.FloatTensor
37 |
38 |
39 | class TransformerTemporalModel(ModelMixin, ConfigMixin):
40 | """
41 | A Transformer model for video-like data.
42 |
43 | Parameters:
44 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
45 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
46 | in_channels (`int`, *optional*):
47 | The number of channels in the input and output (specify if the input is **continuous**).
48 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
49 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
50 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
51 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
52 | This is fixed during training since it is used to learn a number of position embeddings.
53 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
54 | attention_bias (`bool`, *optional*):
55 | Configure if the `TransformerBlock` attention should contain a bias parameter.
56 | double_self_attention (`bool`, *optional*):
57 | Configure if each `TransformerBlock` should contain two self-attention layers.
58 | """
59 |
60 | @register_to_config
61 | def __init__(
62 | self,
63 | num_attention_heads: int = 16,
64 | attention_head_dim: int = 88,
65 | in_channels: Optional[int] = None,
66 | out_channels: Optional[int] = None,
67 | num_layers: int = 1,
68 | dropout: float = 0.0,
69 | norm_num_groups: int = 32,
70 | cross_attention_dim: Optional[int] = None,
71 | attention_bias: bool = False,
72 | sample_size: Optional[int] = None,
73 | activation_fn: str = "geglu",
74 | norm_elementwise_affine: bool = True,
75 | double_self_attention: bool = True,
76 | ):
77 | super().__init__()
78 | self.num_attention_heads = num_attention_heads
79 | self.attention_head_dim = attention_head_dim
80 | inner_dim = num_attention_heads * attention_head_dim
81 |
82 | self.in_channels = in_channels
83 |
84 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
85 | self.proj_in = nn.Linear(in_channels, inner_dim)
86 |
87 | # 3. Define transformers blocks
88 | self.transformer_blocks = nn.ModuleList(
89 | [
90 | BasicTransformerBlock(
91 | inner_dim,
92 | num_attention_heads,
93 | attention_head_dim,
94 | dropout=dropout,
95 | cross_attention_dim=cross_attention_dim,
96 | activation_fn=activation_fn,
97 | attention_bias=attention_bias,
98 | double_self_attention=double_self_attention,
99 | norm_elementwise_affine=norm_elementwise_affine,
100 | )
101 | for d in range(num_layers)
102 | ]
103 | )
104 |
105 | self.proj_out = nn.Linear(inner_dim, in_channels)
106 |
107 | def forward(
108 | self,
109 | hidden_states,
110 | encoder_hidden_states=None,
111 | timestep=None,
112 | class_labels=None,
113 | num_frames=1,
114 | cross_attention_kwargs=None,
115 | return_dict: bool = True,
116 | ):
117 | """
118 | The [`TransformerTemporal`] forward method.
119 |
120 | Args:
121 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
122 | Input hidden_states.
123 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
124 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
125 | self-attention.
126 | timestep ( `torch.long`, *optional*):
127 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
128 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
129 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
130 | `AdaLayerZeroNorm`.
131 | return_dict (`bool`, *optional*, defaults to `True`):
132 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
133 | tuple.
134 |
135 | Returns:
136 | [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
137 | If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
138 | returned, otherwise a `tuple` where the first element is the sample tensor.
139 | """
140 | # 1. Input
141 | batch_frames, channel, height, width = hidden_states.shape
142 | batch_size = batch_frames // num_frames
143 |
144 | residual = hidden_states
145 |
146 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
147 | hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
148 |
149 | hidden_states = self.norm(hidden_states)
150 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
151 |
152 | hidden_states = self.proj_in(hidden_states)
153 |
154 | # 2. Blocks
155 | for block in self.transformer_blocks:
156 | hidden_states = block(
157 | hidden_states,
158 | encoder_hidden_states=encoder_hidden_states,
159 | timestep=timestep,
160 | cross_attention_kwargs=cross_attention_kwargs,
161 | class_labels=class_labels,
162 | )
163 |
164 | # 3. Output
165 | hidden_states = self.proj_out(hidden_states)
166 | hidden_states = (
167 | hidden_states[None, None, :]
168 | .reshape(batch_size, height, width, channel, num_frames)
169 | .permute(0, 3, 4, 1, 2)
170 | .contiguous()
171 | )
172 | hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
173 |
174 | output = hidden_states + residual
175 |
176 | if not return_dict:
177 | return (output,)
178 |
179 | return TransformerTemporalModelOutput(sample=output)
180 |
--------------------------------------------------------------------------------
/showone/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List, Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from diffusers.utils import BaseOutput, OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
8 |
9 |
10 | @dataclass
11 | class TextToVideoPipelineOutput(BaseOutput):
12 | """
13 | Output class for text to video pipelines.
14 |
15 | Args:
16 | frames (`List[np.ndarray]` or `torch.FloatTensor`)
17 | List of denoised frames (essentially images) as NumPy arrays of shape `(height, width, num_channels)` or as
18 | a `torch` tensor. NumPy arrays represent the denoised images of the diffusion pipeline. The length of the list
19 | denotes the video length i.e., the number of frames.
20 | """
21 |
22 | frames: Union[List[np.ndarray], torch.FloatTensor]
23 |
24 |
25 | try:
26 | if not (is_transformers_available() and is_torch_available()):
27 | raise OptionalDependencyNotAvailable()
28 | except OptionalDependencyNotAvailable:
29 | from diffusers.utils.dummy_torch_and_transformers_objects import * # noqa F403
30 | else:
31 | # from .pipeline_t2v_base_latent import TextToVideoSDPipeline # noqa: F401
32 | # from .pipeline_t2v_base_latent_sdxl import TextToVideoSDXLPipeline
33 | from .pipeline_t2v_base_pixel import TextToVideoIFPipeline
34 | from .pipeline_t2v_interp_pixel import TextToVideoIFInterpPipeline
35 | # from .pipeline_t2v_sr_latent import TextToVideoSDSuperResolutionPipeline
36 | from .pipeline_t2v_sr_pixel import TextToVideoIFSuperResolutionPipeline
37 | # from .pipeline_t2v_base_latent_controlnet import TextToVideoSDControlNetPipeline
38 |
--------------------------------------------------------------------------------
/showone/pipelines/pipeline_t2v_base_pixel.py:
--------------------------------------------------------------------------------
1 | import html
2 | import inspect
3 | import re
4 | import urllib.parse as ul
5 | from typing import Any, Callable, Dict, List, Optional, Union
6 |
7 | import numpy as np
8 | import torch
9 | from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
10 |
11 | from diffusers.loaders import LoraLoaderMixin
12 | from diffusers.schedulers import DDPMScheduler
13 | from diffusers.utils import (
14 | BACKENDS_MAPPING,
15 | is_accelerate_available,
16 | is_accelerate_version,
17 | is_bs4_available,
18 | is_ftfy_available,
19 | logging,
20 | randn_tensor,
21 | )
22 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
23 |
24 | from ..models import UNet3DConditionModel
25 | from . import TextToVideoPipelineOutput
26 |
27 |
28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29 |
30 | if is_bs4_available():
31 | from bs4 import BeautifulSoup
32 |
33 | if is_ftfy_available():
34 | import ftfy
35 |
36 |
37 | def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
38 | # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
39 | # reshape to ncfhw
40 | mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
41 | std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
42 | # unnormalize back to [0,1]
43 | video = video.mul_(std).add_(mean)
44 | video.clamp_(0, 1)
45 | # prepare the final outputs
46 | i, c, f, h, w = video.shape
47 | images = video.permute(2, 3, 0, 4, 1).reshape(
48 | f, h, i * w, c
49 | ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
50 | images = images.unbind(dim=0) # prepare a list of individual (consecutive frames)
51 | images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
52 | return images
53 |
54 |
55 | class TextToVideoIFPipeline(DiffusionPipeline, LoraLoaderMixin):
56 | tokenizer: T5Tokenizer
57 | text_encoder: T5EncoderModel
58 |
59 | unet: UNet3DConditionModel
60 | scheduler: DDPMScheduler
61 |
62 | feature_extractor: Optional[CLIPImageProcessor]
63 | # safety_checker: Optional[IFSafetyChecker]
64 |
65 | # watermarker: Optional[IFWatermarker]
66 |
67 | bad_punct_regex = re.compile(
68 | r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
69 | ) # noqa
70 |
71 | _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
72 |
73 | def __init__(
74 | self,
75 | tokenizer: T5Tokenizer,
76 | text_encoder: T5EncoderModel,
77 | unet: UNet3DConditionModel,
78 | scheduler: DDPMScheduler,
79 | feature_extractor: Optional[CLIPImageProcessor],
80 | ):
81 | super().__init__()
82 |
83 | self.register_modules(
84 | tokenizer=tokenizer,
85 | text_encoder=text_encoder,
86 | unet=unet,
87 | scheduler=scheduler,
88 | feature_extractor=feature_extractor,
89 | )
90 | self.safety_checker = None
91 |
92 | def enable_sequential_cpu_offload(self, gpu_id=0):
93 | r"""
94 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
95 | models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
96 | when their specific submodule has its `forward` method called.
97 | """
98 | if is_accelerate_available():
99 | from accelerate import cpu_offload
100 | else:
101 | raise ImportError("Please install accelerate via `pip install accelerate`")
102 |
103 | device = torch.device(f"cuda:{gpu_id}")
104 |
105 | models = [
106 | self.text_encoder,
107 | self.unet,
108 | ]
109 | for cpu_offloaded_model in models:
110 | if cpu_offloaded_model is not None:
111 | cpu_offload(cpu_offloaded_model, device)
112 |
113 | if self.safety_checker is not None:
114 | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
115 |
116 | def enable_model_cpu_offload(self, gpu_id=0):
117 | r"""
118 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
119 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
120 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
121 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
122 | """
123 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
124 | from accelerate import cpu_offload_with_hook
125 | else:
126 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
127 |
128 | device = torch.device(f"cuda:{gpu_id}")
129 |
130 | self.unet.train()
131 |
132 | if self.device.type != "cpu":
133 | self.to("cpu", silence_dtype_warnings=True)
134 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
135 |
136 | hook = None
137 |
138 | if self.text_encoder is not None:
139 | _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
140 |
141 | # Accelerate will move the next model to the device _before_ calling the offload hook of the
142 | # previous model. This will cause both models to be present on the device at the same time.
143 | # IF uses T5 for its text encoder which is really large. We can manually call the offload
144 | # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
145 | # the GPU.
146 | self.text_encoder_offload_hook = hook
147 |
148 | _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
149 |
150 | # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
151 | self.unet_offload_hook = hook
152 |
153 | if self.safety_checker is not None:
154 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
155 |
156 | # We'll offload the last model manually.
157 | self.final_offload_hook = hook
158 |
159 | def remove_all_hooks(self):
160 | if is_accelerate_available():
161 | from accelerate.hooks import remove_hook_from_module
162 | else:
163 | raise ImportError("Please install accelerate via `pip install accelerate`")
164 |
165 | for model in [self.text_encoder, self.unet, self.safety_checker]:
166 | if model is not None:
167 | remove_hook_from_module(model, recurse=True)
168 |
169 | self.unet_offload_hook = None
170 | self.text_encoder_offload_hook = None
171 | self.final_offload_hook = None
172 |
173 | @property
174 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
175 | def _execution_device(self):
176 | r"""
177 | Returns the device on which the pipeline's models will be executed. After calling
178 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
179 | hooks.
180 | """
181 | if not hasattr(self.unet, "_hf_hook"):
182 | return self.device
183 | for module in self.unet.modules():
184 | if (
185 | hasattr(module, "_hf_hook")
186 | and hasattr(module._hf_hook, "execution_device")
187 | and module._hf_hook.execution_device is not None
188 | ):
189 | return torch.device(module._hf_hook.execution_device)
190 | return self.device
191 |
192 | @torch.no_grad()
193 | def encode_prompt(
194 | self,
195 | prompt,
196 | do_classifier_free_guidance=True,
197 | num_images_per_prompt=1,
198 | device=None,
199 | negative_prompt=None,
200 | prompt_embeds: Optional[torch.FloatTensor] = None,
201 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
202 | clean_caption: bool = False,
203 | ):
204 | r"""
205 | Encodes the prompt into text encoder hidden states.
206 |
207 | Args:
208 | prompt (`str` or `List[str]`, *optional*):
209 | prompt to be encoded
210 | device: (`torch.device`, *optional*):
211 | torch device to place the resulting embeddings on
212 | num_images_per_prompt (`int`, *optional*, defaults to 1):
213 | number of images that should be generated per prompt
214 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
215 | whether to use classifier free guidance or not
216 | negative_prompt (`str` or `List[str]`, *optional*):
217 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
218 | `negative_prompt_embeds` instead.
219 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
220 | prompt_embeds (`torch.FloatTensor`, *optional*):
221 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
222 | provided, text embeddings will be generated from `prompt` input argument.
223 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
224 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
225 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
226 | argument.
227 | """
228 | if prompt is not None and negative_prompt is not None:
229 | if type(prompt) is not type(negative_prompt):
230 | raise TypeError(
231 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
232 | f" {type(prompt)}."
233 | )
234 |
235 | if device is None:
236 | device = self._execution_device
237 |
238 | if prompt is not None and isinstance(prompt, str):
239 | batch_size = 1
240 | elif prompt is not None and isinstance(prompt, list):
241 | batch_size = len(prompt)
242 | else:
243 | batch_size = prompt_embeds.shape[0]
244 |
245 | # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
246 | max_length = 77
247 |
248 | if prompt_embeds is None:
249 | prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
250 | text_inputs = self.tokenizer(
251 | prompt,
252 | padding="max_length",
253 | max_length=max_length,
254 | truncation=True,
255 | add_special_tokens=True,
256 | return_tensors="pt",
257 | )
258 | text_input_ids = text_inputs.input_ids
259 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
260 |
261 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
262 | text_input_ids, untruncated_ids
263 | ):
264 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
265 | logger.warning(
266 | "The following part of your input was truncated because CLIP can only handle sequences up to"
267 | f" {max_length} tokens: {removed_text}"
268 | )
269 |
270 | attention_mask = text_inputs.attention_mask.to(device)
271 |
272 | prompt_embeds = self.text_encoder(
273 | text_input_ids.to(device),
274 | attention_mask=attention_mask,
275 | )
276 | prompt_embeds = prompt_embeds[0]
277 |
278 | if self.text_encoder is not None:
279 | dtype = self.text_encoder.dtype
280 | elif self.unet is not None:
281 | dtype = self.unet.dtype
282 | else:
283 | dtype = None
284 |
285 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
286 |
287 | bs_embed, seq_len, _ = prompt_embeds.shape
288 | # duplicate text embeddings for each generation per prompt, using mps friendly method
289 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
290 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
291 |
292 | # get unconditional embeddings for classifier free guidance
293 | if do_classifier_free_guidance and negative_prompt_embeds is None:
294 | uncond_tokens: List[str]
295 | if negative_prompt is None:
296 | uncond_tokens = [""] * batch_size
297 | elif isinstance(negative_prompt, str):
298 | uncond_tokens = [negative_prompt]
299 | elif batch_size != len(negative_prompt):
300 | raise ValueError(
301 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
302 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
303 | " the batch size of `prompt`."
304 | )
305 | else:
306 | uncond_tokens = negative_prompt
307 |
308 | uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
309 | max_length = prompt_embeds.shape[1]
310 | uncond_input = self.tokenizer(
311 | uncond_tokens,
312 | padding="max_length",
313 | max_length=max_length,
314 | truncation=True,
315 | return_attention_mask=True,
316 | add_special_tokens=True,
317 | return_tensors="pt",
318 | )
319 | attention_mask = uncond_input.attention_mask.to(device)
320 |
321 | negative_prompt_embeds = self.text_encoder(
322 | uncond_input.input_ids.to(device),
323 | attention_mask=attention_mask,
324 | )
325 | negative_prompt_embeds = negative_prompt_embeds[0]
326 |
327 | if do_classifier_free_guidance:
328 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
329 | seq_len = negative_prompt_embeds.shape[1]
330 |
331 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
332 |
333 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
334 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
335 |
336 | # For classifier free guidance, we need to do two forward passes.
337 | # Here we concatenate the unconditional and text embeddings into a single batch
338 | # to avoid doing two forward passes
339 | else:
340 | negative_prompt_embeds = None
341 |
342 | return prompt_embeds, negative_prompt_embeds
343 |
344 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
345 | def prepare_extra_step_kwargs(self, generator, eta):
346 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
347 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
348 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
349 | # and should be between [0, 1]
350 |
351 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
352 | extra_step_kwargs = {}
353 | if accepts_eta:
354 | extra_step_kwargs["eta"] = eta
355 |
356 | # check if the scheduler accepts generator
357 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
358 | if accepts_generator:
359 | extra_step_kwargs["generator"] = generator
360 | return extra_step_kwargs
361 |
362 | def check_inputs(
363 | self,
364 | prompt,
365 | callback_steps,
366 | negative_prompt=None,
367 | prompt_embeds=None,
368 | negative_prompt_embeds=None,
369 | ):
370 | if (callback_steps is None) or (
371 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
372 | ):
373 | raise ValueError(
374 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
375 | f" {type(callback_steps)}."
376 | )
377 |
378 | if prompt is not None and prompt_embeds is not None:
379 | raise ValueError(
380 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
381 | " only forward one of the two."
382 | )
383 | elif prompt is None and prompt_embeds is None:
384 | raise ValueError(
385 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
386 | )
387 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
388 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
389 |
390 | if negative_prompt is not None and negative_prompt_embeds is not None:
391 | raise ValueError(
392 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
393 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
394 | )
395 |
396 | if prompt_embeds is not None and negative_prompt_embeds is not None:
397 | if prompt_embeds.shape != negative_prompt_embeds.shape:
398 | raise ValueError(
399 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
400 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
401 | f" {negative_prompt_embeds.shape}."
402 | )
403 |
404 | def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
405 | shape = (batch_size, num_channels, num_frames, height, width)
406 | if isinstance(generator, list) and len(generator) != batch_size:
407 | raise ValueError(
408 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
409 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
410 | )
411 |
412 | intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
413 |
414 | # scale the initial noise by the standard deviation required by the scheduler
415 | intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
416 | return intermediate_images
417 |
418 | def _text_preprocessing(self, text, clean_caption=False):
419 | if clean_caption and not is_bs4_available():
420 | logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
421 | logger.warn("Setting `clean_caption` to False...")
422 | clean_caption = False
423 |
424 | if clean_caption and not is_ftfy_available():
425 | logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
426 | logger.warn("Setting `clean_caption` to False...")
427 | clean_caption = False
428 |
429 | if not isinstance(text, (tuple, list)):
430 | text = [text]
431 |
432 | def process(text: str):
433 | if clean_caption:
434 | text = self._clean_caption(text)
435 | text = self._clean_caption(text)
436 | else:
437 | text = text.lower().strip()
438 | return text
439 |
440 | return [process(t) for t in text]
441 |
442 | def _clean_caption(self, caption):
443 | caption = str(caption)
444 | caption = ul.unquote_plus(caption)
445 | caption = caption.strip().lower()
446 | caption = re.sub("<person>", "person", caption)
447 | # urls:
448 | caption = re.sub(
449 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
450 | "",
451 | caption,
452 | ) # regex for urls
453 | caption = re.sub(
454 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
455 | "",
456 | caption,
457 | ) # regex for urls
458 | # html:
459 | caption = BeautifulSoup(caption, features="html.parser").text
460 |
461 | # @
462 | caption = re.sub(r"@[\w\d]+\b", "", caption)
463 |
464 | # 31C0—31EF CJK Strokes
465 | # 31F0—31FF Katakana Phonetic Extensions
466 | # 3200—32FF Enclosed CJK Letters and Months
467 | # 3300—33FF CJK Compatibility
468 | # 3400—4DBF CJK Unified Ideographs Extension A
469 | # 4DC0—4DFF Yijing Hexagram Symbols
470 | # 4E00—9FFF CJK Unified Ideographs
471 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
472 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
473 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
474 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
475 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
476 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
477 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
478 | #######################################################
479 |
480 | # все виды тире / all types of dash --> "-"
481 | caption = re.sub(
482 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
483 | "-",
484 | caption,
485 | )
486 |
487 | # кавычки к одному стандарту
488 | caption = re.sub(r"[`´«»“”¨]", '"', caption)
489 | caption = re.sub(r"[‘’]", "'", caption)
490 |
491 | # &quot;
492 | caption = re.sub(r"&quot;?", "", caption)
493 | # &amp
494 | caption = re.sub(r"&amp", "", caption)
495 |
496 | # ip adresses:
497 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
498 |
499 | # article ids:
500 | caption = re.sub(r"\d:\d\d\s+$", "", caption)
501 |
502 | # \n
503 | caption = re.sub(r"\\n", " ", caption)
504 |
505 | # "#123"
506 | caption = re.sub(r"#\d{1,3}\b", "", caption)
507 | # "#12345.."
508 | caption = re.sub(r"#\d{5,}\b", "", caption)
509 | # "123456.."
510 | caption = re.sub(r"\b\d{6,}\b", "", caption)
511 | # filenames:
512 | caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
513 |
514 | #
515 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
516 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
517 |
518 | caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
519 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
520 |
521 | # this-is-my-cute-cat / this_is_my_cute_cat
522 | regex2 = re.compile(r"(?:\-|\_)")
523 | if len(re.findall(regex2, caption)) > 3:
524 | caption = re.sub(regex2, " ", caption)
525 |
526 | caption = ftfy.fix_text(caption)
527 | caption = html.unescape(html.unescape(caption))
528 |
529 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
530 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
531 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
532 |
533 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
534 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
535 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
536 | caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
537 | caption = re.sub(r"\bpage\s+\d+\b", "", caption)
538 |
539 | caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
540 |
541 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
542 |
543 | caption = re.sub(r"\b\s+\:\s+", r": ", caption)
544 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
545 | caption = re.sub(r"\s+", " ", caption)
546 |
547 | caption.strip()
548 |
549 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
550 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
551 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
552 | caption = re.sub(r"^\.\S+$", "", caption)
553 |
554 | return caption.strip()
555 |
556 | @torch.no_grad()
557 | def __call__(
558 | self,
559 | prompt: Union[str, List[str]] = None,
560 | num_inference_steps: int = 100,
561 | timesteps: List[int] = None,
562 | guidance_scale: float = 7.0,
563 | negative_prompt: Optional[Union[str, List[str]]] = None,
564 | num_images_per_prompt: Optional[int] = 1,
565 | height: Optional[int] = None,
566 | width: Optional[int] = None,
567 | num_frames: int = 16,
568 | eta: float = 0.0,
569 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
570 | prompt_embeds: Optional[torch.FloatTensor] = None,
571 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
572 | output_type: Optional[str] = "np",
573 | return_dict: bool = True,
574 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
575 | callback_steps: int = 1,
576 | clean_caption: bool = True,
577 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
578 | ):
579 | """
580 | Function invoked when calling the pipeline for generation.
581 |
582 | Args:
583 | prompt (`str` or `List[str]`, *optional*):
584 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
585 | instead.
586 | num_inference_steps (`int`, *optional*, defaults to 100):
587 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
588 | expense of slower inference.
589 | timesteps (`List[int]`, *optional*):
590 | Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
591 | timesteps are used. Must be in descending order.
592 | guidance_scale (`float`, *optional*, defaults to 7.0):
593 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
594 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
595 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
596 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
597 | usually at the expense of lower image quality.
598 | negative_prompt (`str` or `List[str]`, *optional*):
599 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
600 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
601 | less than `1`).
602 | num_images_per_prompt (`int`, *optional*, defaults to 1):
603 | The number of images to generate per prompt.
604 | height (`int`, *optional*, defaults to self.unet.config.sample_size):
605 | The height in pixels of the generated image.
606 | width (`int`, *optional*, defaults to self.unet.config.sample_size):
607 | The width in pixels of the generated image.
608 | eta (`float`, *optional*, defaults to 0.0):
609 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
610 | [`schedulers.DDIMScheduler`], will be ignored for others.
611 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
612 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
613 | to make generation deterministic.
614 | prompt_embeds (`torch.FloatTensor`, *optional*):
615 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
616 | provided, text embeddings will be generated from `prompt` input argument.
617 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
618 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
619 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
620 | argument.
621 |             output_type (`str`, *optional*, defaults to `"np"`):
622 |                 The output format of the generated video. Choose between `"np"` (a list of `np.ndarray` frames)
623 |                 and `"pt"` (a `torch.Tensor`).
624 | return_dict (`bool`, *optional*, defaults to `True`):
625 |                 Whether or not to return a [`TextToVideoPipelineOutput`] instead of a plain tuple.
626 | callback (`Callable`, *optional*):
627 | A function that will be called every `callback_steps` steps during inference. The function will be
628 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
629 | callback_steps (`int`, *optional*, defaults to 1):
630 | The frequency at which the `callback` function will be called. If not specified, the callback will be
631 | called at every step.
632 | clean_caption (`bool`, *optional*, defaults to `True`):
633 | Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
634 | be installed. If the dependencies are not installed, the embeddings will be created from the raw
635 | prompt.
636 | cross_attention_kwargs (`dict`, *optional*):
637 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
638 | `self.processor` in
639 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
640 |
641 | Examples:
642 |
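            A minimal, illustrative sketch; the checkpoint name, frame count and resolution are assumptions,
            and `TextToVideoIFPipeline` is assumed to be the class this docstring belongs to:

            ```py
            >>> import torch
            >>> pipe = TextToVideoIFPipeline.from_pretrained("showlab/show-1-base", torch_dtype=torch.float16)
            >>> pipe.enable_model_cpu_offload()
            >>> frames = pipe(
            ...     "A panda taking a selfie",
            ...     num_inference_steps=75,
            ...     guidance_scale=9.0,
            ...     num_frames=8,
            ...     height=40,
            ...     width=64,
            ...     output_type="pt",
            ... ).frames
            ```
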
643 | Returns:
644 |             [`TextToVideoPipelineOutput`] or `tuple`:
645 |             [`TextToVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
646 |             tuple, the first element is the generated video: a list of `np.ndarray` frames when
647 |             `output_type="np"`, or a `torch.Tensor` when `output_type="pt"`. No nsfw or watermark flags are
648 |             returned.
649 | """
650 | # 1. Check inputs. Raise error if not correct
651 | self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
652 |
653 | # 2. Define call parameters
654 | height = height or self.unet.config.sample_size
655 | width = width or self.unet.config.sample_size
656 |
657 | if prompt is not None and isinstance(prompt, str):
658 | batch_size = 1
659 | elif prompt is not None and isinstance(prompt, list):
660 | batch_size = len(prompt)
661 | else:
662 | batch_size = prompt_embeds.shape[0]
663 |
664 | device = self._execution_device
665 |
666 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
667 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
668 | # corresponds to doing no classifier free guidance.
669 | do_classifier_free_guidance = guidance_scale > 1.0
670 |
671 | # 3. Encode input prompt
672 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
673 | prompt,
674 | do_classifier_free_guidance,
675 | num_images_per_prompt=num_images_per_prompt,
676 | device=device,
677 | negative_prompt=negative_prompt,
678 | prompt_embeds=prompt_embeds,
679 | negative_prompt_embeds=negative_prompt_embeds,
680 | clean_caption=clean_caption,
681 | )
682 |
683 | if do_classifier_free_guidance:
684 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
685 |
686 | # 4. Prepare timesteps
687 | if timesteps is not None:
688 | self.scheduler.set_timesteps(timesteps=timesteps, device=device)
689 | timesteps = self.scheduler.timesteps
690 | num_inference_steps = len(timesteps)
691 | else:
692 | self.scheduler.set_timesteps(num_inference_steps, device=device)
693 | timesteps = self.scheduler.timesteps
694 |
695 | # 5. Prepare intermediate images
696 | intermediate_images = self.prepare_intermediate_images(
697 | batch_size * num_images_per_prompt,
698 | self.unet.config.in_channels,
699 | num_frames,
700 | height,
701 | width,
702 | prompt_embeds.dtype,
703 | device,
704 | generator,
705 | )
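        # Initial Gaussian noise of shape (batch_size * num_images_per_prompt, in_channels, num_frames,
        # height, width), already scaled by the scheduler's `init_noise_sigma`.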
706 |
707 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
708 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
709 |
710 | # HACK: see comment in `enable_model_cpu_offload`
711 | if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
712 | self.text_encoder_offload_hook.offload()
713 |
714 | # 7. Denoising loop
715 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
716 | with self.progress_bar(total=num_inference_steps) as progress_bar:
717 | for i, t in enumerate(timesteps):
718 | model_input = (
719 | torch.cat([intermediate_images] * 2) if do_classifier_free_guidance else intermediate_images
720 | )
721 | model_input = self.scheduler.scale_model_input(model_input, t)
722 |
723 | # predict the noise residual
724 | noise_pred = self.unet(
725 | model_input,
726 | t,
727 | encoder_hidden_states=prompt_embeds,
728 | cross_attention_kwargs=cross_attention_kwargs,
729 | ).sample
730 |
731 | # perform guidance
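                # The UNet predicts twice the input channels: a noise estimate plus a per-channel variance.
                # Split each prediction, apply classifier-free guidance to the noise part only, and re-attach
                # the conditional branch's predicted variance for schedulers with a learned variance
                # ("learned" / "learned_range"); otherwise the variance channels are dropped below.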
732 | if do_classifier_free_guidance:
733 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
734 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1], dim=1)
735 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1], dim=1)
736 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
737 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
738 |
739 | if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
740 | noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1)
741 |
742 | # reshape latents
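                # `scheduler.step` expects 4D (batch, channels, height, width) tensors, so the 5D video tensor
                # is folded frame-wise into the batch dimension here and unfolded again after the update.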
743 | bsz, channel, frames, height, width = intermediate_images.shape
744 | intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
745 | noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
746 |
747 | # compute the previous noisy sample x_t -> x_t-1
748 | intermediate_images = self.scheduler.step(
749 | noise_pred, t, intermediate_images, **extra_step_kwargs
750 | ).prev_sample
751 |
752 | # reshape latents back
753 | intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
754 |
755 | # call the callback, if provided
756 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
757 | progress_bar.update()
758 | if callback is not None and i % callback_steps == 0:
759 | callback(i, t, intermediate_images)
760 |
761 | video_tensor = intermediate_images
762 |
763 | if output_type == "pt":
764 | video = video_tensor
765 | else:
766 | video = tensor2vid(video_tensor)
767 |
768 | # Offload last model to CPU
769 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
770 | self.final_offload_hook.offload()
771 |
772 | if not return_dict:
773 | return (video,)
774 |
775 | return TextToVideoPipelineOutput(frames=video)
776 |
--------------------------------------------------------------------------------
/showone/pipelines/pipeline_t2v_interp_pixel.py:
--------------------------------------------------------------------------------
1 | import html
2 | import inspect
3 | import re
4 | import urllib.parse as ul
5 | from typing import Any, Callable, Dict, List, Optional, Union
6 |
7 | import numpy as np
8 | import torch
9 | from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
10 |
11 | from diffusers.schedulers import DDPMScheduler
12 | from diffusers.utils import (
13 | BACKENDS_MAPPING,
14 | is_accelerate_available,
15 | is_accelerate_version,
16 | is_bs4_available,
17 | is_ftfy_available,
18 | logging,
19 | randn_tensor,
20 | replace_example_docstring,
21 | )
22 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
23 |
24 | from ..models import UNet3DConditionModel
25 | from . import TextToVideoPipelineOutput
26 |
27 |
28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29 |
30 | if is_bs4_available():
31 | from bs4 import BeautifulSoup
32 |
33 | if is_ftfy_available():
34 | import ftfy
35 |
36 |
37 | def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
38 | # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
39 | # reshape to ncfhw
40 | mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
41 | std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
42 | # unnormalize back to [0,1]
43 | video = video.mul_(std).add_(mean)
44 | video.clamp_(0, 1)
45 | # prepare the final outputs
46 | i, c, f, h, w = video.shape
47 | images = video.permute(2, 3, 0, 4, 1).reshape(
48 | f, h, i * w, c
49 | ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
50 |     images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
51 | images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
52 | return images
53 |
54 |
55 | class TextToVideoIFInterpPipeline(DiffusionPipeline):
56 | tokenizer: T5Tokenizer
57 | text_encoder: T5EncoderModel
58 |
59 | unet: UNet3DConditionModel
60 | scheduler: DDPMScheduler
61 |
62 | feature_extractor: Optional[CLIPImageProcessor]
63 | # safety_checker: Optional[IFSafetyChecker]
64 |
65 | # watermarker: Optional[IFWatermarker]
66 |
67 | bad_punct_regex = re.compile(
68 | r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
69 | ) # noqa
70 |
71 | _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
72 |
73 | def __init__(
74 | self,
75 | tokenizer: T5Tokenizer,
76 | text_encoder: T5EncoderModel,
77 | unet: UNet3DConditionModel,
78 | scheduler: DDPMScheduler,
79 | feature_extractor: Optional[CLIPImageProcessor],
80 | ):
81 | super().__init__()
82 |
83 | self.register_modules(
84 | tokenizer=tokenizer,
85 | text_encoder=text_encoder,
86 | unet=unet,
87 | scheduler=scheduler,
88 | feature_extractor=feature_extractor,
89 | )
90 | self.safety_checker = None
91 |
92 | def enable_sequential_cpu_offload(self, gpu_id=0):
93 | r"""
94 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
95 |         models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
96 | when their specific submodule has its `forward` method called.
97 | """
98 | if is_accelerate_available():
99 | from accelerate import cpu_offload
100 | else:
101 | raise ImportError("Please install accelerate via `pip install accelerate`")
102 |
103 | device = torch.device(f"cuda:{gpu_id}")
104 |
105 | models = [
106 | self.text_encoder,
107 | self.unet,
108 | ]
109 | for cpu_offloaded_model in models:
110 | if cpu_offloaded_model is not None:
111 | cpu_offload(cpu_offloaded_model, device)
112 |
113 | if self.safety_checker is not None:
114 | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
115 |
116 | def enable_model_cpu_offload(self, gpu_id=0):
117 | r"""
118 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
119 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
120 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
121 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
122 | """
123 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
124 | from accelerate import cpu_offload_with_hook
125 | else:
126 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
127 |
128 | device = torch.device(f"cuda:{gpu_id}")
129 |
130 |
131 | if self.device.type != "cpu":
132 | self.to("cpu", silence_dtype_warnings=True)
133 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
134 |
135 | hook = None
136 |
137 | if self.text_encoder is not None:
138 | _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
139 |
140 | # Accelerate will move the next model to the device _before_ calling the offload hook of the
141 | # previous model. This will cause both models to be present on the device at the same time.
142 | # IF uses T5 for its text encoder which is really large. We can manually call the offload
143 | # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
144 | # the GPU.
145 | self.text_encoder_offload_hook = hook
146 |
147 | _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
148 |
149 | # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
150 | self.unet_offload_hook = hook
151 |
152 | if self.safety_checker is not None:
153 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
154 |
155 | # We'll offload the last model manually.
156 | self.final_offload_hook = hook
157 |
158 | def remove_all_hooks(self):
159 | if is_accelerate_available():
160 | from accelerate.hooks import remove_hook_from_module
161 | else:
162 | raise ImportError("Please install accelerate via `pip install accelerate`")
163 |
164 | for model in [self.text_encoder, self.unet, self.safety_checker]:
165 | if model is not None:
166 | remove_hook_from_module(model, recurse=True)
167 |
168 | self.unet_offload_hook = None
169 | self.text_encoder_offload_hook = None
170 | self.final_offload_hook = None
171 |
172 | @property
173 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
174 | def _execution_device(self):
175 | r"""
176 | Returns the device on which the pipeline's models will be executed. After calling
177 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
178 | hooks.
179 | """
180 | if not hasattr(self.unet, "_hf_hook"):
181 | return self.device
182 | for module in self.unet.modules():
183 | if (
184 | hasattr(module, "_hf_hook")
185 | and hasattr(module._hf_hook, "execution_device")
186 | and module._hf_hook.execution_device is not None
187 | ):
188 | return torch.device(module._hf_hook.execution_device)
189 | return self.device
190 |
191 | @torch.no_grad()
192 | def encode_prompt(
193 | self,
194 | prompt,
195 | do_classifier_free_guidance=True,
196 | num_images_per_prompt=1,
197 | device=None,
198 | negative_prompt=None,
199 | prompt_embeds: Optional[torch.FloatTensor] = None,
200 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
201 | clean_caption: bool = False,
202 | ):
203 | r"""
204 | Encodes the prompt into text encoder hidden states.
205 |
206 | Args:
207 | prompt (`str` or `List[str]`, *optional*):
208 | prompt to be encoded
209 | device: (`torch.device`, *optional*):
210 | torch device to place the resulting embeddings on
211 | num_images_per_prompt (`int`, *optional*, defaults to 1):
212 | number of images that should be generated per prompt
213 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
214 | whether to use classifier free guidance or not
215 | negative_prompt (`str` or `List[str]`, *optional*):
216 |                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
217 |                 `negative_prompt_embeds` instead.
218 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
219 | prompt_embeds (`torch.FloatTensor`, *optional*):
220 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
221 | provided, text embeddings will be generated from `prompt` input argument.
222 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
223 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
224 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
225 | argument.
226 | """
227 | if prompt is not None and negative_prompt is not None:
228 | if type(prompt) is not type(negative_prompt):
229 | raise TypeError(
230 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
231 | f" {type(prompt)}."
232 | )
233 |
234 | if device is None:
235 | device = self._execution_device
236 |
237 | if prompt is not None and isinstance(prompt, str):
238 | batch_size = 1
239 | elif prompt is not None and isinstance(prompt, list):
240 | batch_size = len(prompt)
241 | else:
242 | batch_size = prompt_embeds.shape[0]
243 |
244 | # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
245 | max_length = 77
246 |
247 | if prompt_embeds is None:
248 | prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
249 | text_inputs = self.tokenizer(
250 | prompt,
251 | padding="max_length",
252 | max_length=max_length,
253 | truncation=True,
254 | add_special_tokens=True,
255 | return_tensors="pt",
256 | )
257 | text_input_ids = text_inputs.input_ids
258 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
259 |
260 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
261 | text_input_ids, untruncated_ids
262 | ):
263 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
264 | logger.warning(
265 |                     "The following part of your input was truncated because the text encoder can only handle sequences up to"
266 | f" {max_length} tokens: {removed_text}"
267 | )
268 |
269 | attention_mask = text_inputs.attention_mask.to(device)
270 |
271 | prompt_embeds = self.text_encoder(
272 | text_input_ids.to(device),
273 | attention_mask=attention_mask,
274 | )
275 | prompt_embeds = prompt_embeds[0]
276 |
277 | if self.text_encoder is not None:
278 | dtype = self.text_encoder.dtype
279 | elif self.unet is not None:
280 | dtype = self.unet.dtype
281 | else:
282 | dtype = None
283 |
284 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
285 |
286 | bs_embed, seq_len, _ = prompt_embeds.shape
287 | # duplicate text embeddings for each generation per prompt, using mps friendly method
288 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
289 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
290 |
291 | # get unconditional embeddings for classifier free guidance
292 | if do_classifier_free_guidance and negative_prompt_embeds is None:
293 | uncond_tokens: List[str]
294 | if negative_prompt is None:
295 | uncond_tokens = [""] * batch_size
296 | elif isinstance(negative_prompt, str):
297 | uncond_tokens = [negative_prompt]
298 | elif batch_size != len(negative_prompt):
299 | raise ValueError(
300 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
301 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
302 | " the batch size of `prompt`."
303 | )
304 | else:
305 | uncond_tokens = negative_prompt
306 |
307 | uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
308 | max_length = prompt_embeds.shape[1]
309 | uncond_input = self.tokenizer(
310 | uncond_tokens,
311 | padding="max_length",
312 | max_length=max_length,
313 | truncation=True,
314 | return_attention_mask=True,
315 | add_special_tokens=True,
316 | return_tensors="pt",
317 | )
318 | attention_mask = uncond_input.attention_mask.to(device)
319 |
320 | negative_prompt_embeds = self.text_encoder(
321 | uncond_input.input_ids.to(device),
322 | attention_mask=attention_mask,
323 | )
324 | negative_prompt_embeds = negative_prompt_embeds[0]
325 |
326 | if do_classifier_free_guidance:
327 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
328 | seq_len = negative_prompt_embeds.shape[1]
329 |
330 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
331 |
332 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
333 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
334 |
335 | # For classifier free guidance, we need to do two forward passes.
336 | # Here we concatenate the unconditional and text embeddings into a single batch
337 | # to avoid doing two forward passes
338 | else:
339 | negative_prompt_embeds = None
340 |
341 | return prompt_embeds, negative_prompt_embeds
342 |
343 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
344 | def prepare_extra_step_kwargs(self, generator, eta):
345 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
346 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
347 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
348 | # and should be between [0, 1]
349 |
350 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
351 | extra_step_kwargs = {}
352 | if accepts_eta:
353 | extra_step_kwargs["eta"] = eta
354 |
355 | # check if the scheduler accepts generator
356 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
357 | if accepts_generator:
358 | extra_step_kwargs["generator"] = generator
359 | return extra_step_kwargs
360 |
361 | def check_inputs(
362 | self,
363 | prompt,
364 | callback_steps,
365 | negative_prompt=None,
366 | prompt_embeds=None,
367 | negative_prompt_embeds=None,
368 | ):
369 | if (callback_steps is None) or (
370 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
371 | ):
372 | raise ValueError(
373 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
374 | f" {type(callback_steps)}."
375 | )
376 |
377 | if prompt is not None and prompt_embeds is not None:
378 | raise ValueError(
379 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
380 | " only forward one of the two."
381 | )
382 | elif prompt is None and prompt_embeds is None:
383 | raise ValueError(
384 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
385 | )
386 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
387 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
388 |
389 | if negative_prompt is not None and negative_prompt_embeds is not None:
390 | raise ValueError(
391 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
392 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
393 | )
394 |
395 | if prompt_embeds is not None and negative_prompt_embeds is not None:
396 | if prompt_embeds.shape != negative_prompt_embeds.shape:
397 | raise ValueError(
398 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
399 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
400 | f" {negative_prompt_embeds.shape}."
401 | )
402 |
403 | def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
404 | shape = (batch_size, num_channels, num_frames, height, width)
405 | if isinstance(generator, list) and len(generator) != batch_size:
406 | raise ValueError(
407 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
408 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
409 | )
410 |
411 | intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
412 |
413 | # scale the initial noise by the standard deviation required by the scheduler
414 | intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
415 | return intermediate_images
416 |
417 | def _text_preprocessing(self, text, clean_caption=False):
418 | if clean_caption and not is_bs4_available():
419 | logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
420 | logger.warn("Setting `clean_caption` to False...")
421 | clean_caption = False
422 |
423 | if clean_caption and not is_ftfy_available():
424 | logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
425 | logger.warn("Setting `clean_caption` to False...")
426 | clean_caption = False
427 |
428 | if not isinstance(text, (tuple, list)):
429 | text = [text]
430 |
431 | def process(text: str):
432 | if clean_caption:
433 | text = self._clean_caption(text)
434 | text = self._clean_caption(text)
435 | else:
436 | text = text.lower().strip()
437 | return text
438 |
439 | return [process(t) for t in text]
440 |
441 | def _clean_caption(self, caption):
442 | caption = str(caption)
443 | caption = ul.unquote_plus(caption)
444 | caption = caption.strip().lower()
445 | caption = re.sub("", "person", caption)
446 | # urls:
447 | caption = re.sub(
448 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
449 | "",
450 | caption,
451 | ) # regex for urls
452 | caption = re.sub(
453 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
454 | "",
455 | caption,
456 | ) # regex for urls
457 | # html:
458 | caption = BeautifulSoup(caption, features="html.parser").text
459 |
460 | # @
461 | caption = re.sub(r"@[\w\d]+\b", "", caption)
462 |
463 | # 31C0—31EF CJK Strokes
464 | # 31F0—31FF Katakana Phonetic Extensions
465 | # 3200—32FF Enclosed CJK Letters and Months
466 | # 3300—33FF CJK Compatibility
467 | # 3400—4DBF CJK Unified Ideographs Extension A
468 | # 4DC0—4DFF Yijing Hexagram Symbols
469 | # 4E00—9FFF CJK Unified Ideographs
470 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
471 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
472 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
473 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
474 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
475 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
476 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
477 | #######################################################
478 |
479 |         # all types of dash --> "-"
480 | caption = re.sub(
481 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
482 | "-",
483 | caption,
484 | )
485 |
486 |         # unify quotation marks to a single standard
487 | caption = re.sub(r"[`´«»“”¨]", '"', caption)
488 | caption = re.sub(r"[‘’]", "'", caption)
489 |
490 |         # &quot;
491 |         caption = re.sub(r"&quot;?", "", caption)
492 |         # &amp
493 |         caption = re.sub(r"&amp", "", caption)
494 |
495 |         # ip addresses:
496 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
497 |
498 | # article ids:
499 | caption = re.sub(r"\d:\d\d\s+$", "", caption)
500 |
501 | # \n
502 | caption = re.sub(r"\\n", " ", caption)
503 |
504 | # "#123"
505 | caption = re.sub(r"#\d{1,3}\b", "", caption)
506 | # "#12345.."
507 | caption = re.sub(r"#\d{5,}\b", "", caption)
508 | # "123456.."
509 | caption = re.sub(r"\b\d{6,}\b", "", caption)
510 | # filenames:
511 | caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
512 |
513 | #
514 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
515 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
516 |
517 | caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
518 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
519 |
520 | # this-is-my-cute-cat / this_is_my_cute_cat
521 | regex2 = re.compile(r"(?:\-|\_)")
522 | if len(re.findall(regex2, caption)) > 3:
523 | caption = re.sub(regex2, " ", caption)
524 |
525 | caption = ftfy.fix_text(caption)
526 | caption = html.unescape(html.unescape(caption))
527 |
528 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
529 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
530 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
531 |
532 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
533 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
534 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
535 | caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
536 | caption = re.sub(r"\bpage\s+\d+\b", "", caption)
537 |
538 | caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
539 |
540 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
541 |
542 | caption = re.sub(r"\b\s+\:\s+", r": ", caption)
543 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
544 | caption = re.sub(r"\s+", " ", caption)
545 |
546 |         caption = caption.strip()
547 |
548 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
549 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
550 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
551 | caption = re.sub(r"^\.\S+$", "", caption)
552 |
553 | return caption.strip()
554 |
555 | @torch.no_grad()
556 | def __call__(
557 | self,
558 | pixel_values,
559 | prompt: Union[str, List[str]] = None,
560 | num_inference_steps: int = 100,
561 | timesteps: List[int] = None,
562 | guidance_scale: float = 7.0,
563 | negative_prompt: Optional[Union[str, List[str]]] = None,
564 | num_images_per_prompt: Optional[int] = 1,
565 | height: Optional[int] = None,
566 | width: Optional[int] = None,
567 | num_frames: int = 16,
568 | eta: float = 0.0,
569 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
570 | prompt_embeds: Optional[torch.FloatTensor] = None,
571 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
572 | output_type: Optional[str] = "np",
573 | return_dict: bool = True,
574 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
575 | callback_steps: int = 1,
576 | clean_caption: bool = True,
577 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
578 | init_noise = None,
579 | cond_interpolation = False,
580 | ):
581 | """
582 | Function invoked when calling the pipeline for generation.
583 |
584 | Args:
585 | prompt (`str` or `List[str]`, *optional*):
586 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
587 |                 instead.
588 |             num_inference_steps (`int`, *optional*, defaults to 100):
589 |                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
590 |                 expense of slower inference.
591 |             timesteps (`List[int]`, *optional*):
592 |                 Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
593 |                 timesteps are used. Must be in descending order.
594 |             guidance_scale (`float`, *optional*, defaults to 7.0):
595 |                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
596 |                 `guidance_scale` is defined as `w` of equation 2 of [Imagen
597 |                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
598 |                 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
599 |                 `prompt`, usually at the expense of lower image quality.
600 | negative_prompt (`str` or `List[str]`, *optional*):
601 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
602 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
603 | less than `1`).
604 | num_images_per_prompt (`int`, *optional*, defaults to 1):
605 | The number of images to generate per prompt.
606 | height (`int`, *optional*, defaults to self.unet.config.sample_size):
607 | The height in pixels of the generated image.
608 | width (`int`, *optional*, defaults to self.unet.config.sample_size):
609 | The width in pixels of the generated image.
610 | eta (`float`, *optional*, defaults to 0.0):
611 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
612 | [`schedulers.DDIMScheduler`], will be ignored for others.
613 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
614 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
615 | to make generation deterministic.
616 | prompt_embeds (`torch.FloatTensor`, *optional*):
617 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
618 | provided, text embeddings will be generated from `prompt` input argument.
619 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
620 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
621 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
622 | argument.
623 |             output_type (`str`, *optional*, defaults to `"np"`):
624 |                 The output format of the generated video. Choose between `"np"` (a list of `np.ndarray` frames)
625 |                 and `"pt"` (a `torch.Tensor`).
626 | return_dict (`bool`, *optional*, defaults to `True`):
627 |                 Whether or not to return a [`TextToVideoPipelineOutput`] instead of a plain tuple.
628 | callback (`Callable`, *optional*):
629 | A function that will be called every `callback_steps` steps during inference. The function will be
630 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
631 | callback_steps (`int`, *optional*, defaults to 1):
632 | The frequency at which the `callback` function will be called. If not specified, the callback will be
633 | called at every step.
634 | clean_caption (`bool`, *optional*, defaults to `True`):
635 | Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
636 | be installed. If the dependencies are not installed, the embeddings will be created from the raw
637 | prompt.
638 | cross_attention_kwargs (`dict`, *optional*):
639 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
640 | `self.processor` in
641 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
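            pixel_values (`torch.FloatTensor`):
                Conditioning video of shape `(batch, channels, num_frames, height, width)`. Only its first and
                last frames are used as keyframes when `cond_interpolation=True`; a binary mask marking those two
                frames is concatenated to it channel-wise. (Not part of the original IF signature; semantics
                inferred from the implementation below.)
            init_noise (`torch.FloatTensor`, *optional*):
                If provided, used directly as the initial noisy sample instead of freshly drawn Gaussian noise.
            cond_interpolation (`bool`, *optional*, defaults to `False`):
                If `True`, the conditioning frames between the two keyframes are filled in by trilinear
                interpolation; the `False` branch is not implemented here and raises an exception.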
642 |
643 | Examples:
644 |
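            A minimal, illustrative sketch (the checkpoint name, frame count and resolution are assumptions;
            `pixel_values` must already contain `num_frames` frames, of which only the first and last are used
            as keyframes when `cond_interpolation=True`):

            ```py
            >>> import torch
            >>> pipe = TextToVideoIFInterpPipeline.from_pretrained(
            ...     "showlab/show-1-interpolation", torch_dtype=torch.float16
            ... )
            >>> pipe.enable_model_cpu_offload()
            >>> keyframes = torch.randn(1, 3, 8, 40, 64, dtype=torch.float16)  # random stand-in for a real keyframe video (normally in [-1, 1])
            >>> frames = pipe(
            ...     pixel_values=keyframes,
            ...     prompt="A panda taking a selfie",
            ...     num_frames=8,
            ...     height=40,
            ...     width=64,
            ...     cond_interpolation=True,
            ...     output_type="pt",
            ... ).frames
            ```
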
645 | Returns:
646 |             [`TextToVideoPipelineOutput`] or `tuple`:
647 |             [`TextToVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a
648 |             tuple, the first element is the generated video: a list of `np.ndarray` frames when
649 |             `output_type="np"`, or a `torch.Tensor` when `output_type="pt"`. No nsfw or watermark flags are
650 |             returned.
651 | """
652 | # 1. Check inputs. Raise error if not correct
653 | self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
654 |
655 | # 2. Define call parameters
656 | height = height or self.unet.config.sample_size
657 | width = width or self.unet.config.sample_size
658 |
659 | if prompt is not None and isinstance(prompt, str):
660 | batch_size = 1
661 | elif prompt is not None and isinstance(prompt, list):
662 | batch_size = len(prompt)
663 | else:
664 | batch_size = prompt_embeds.shape[0]
665 |
666 | device = self._execution_device
667 |
668 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
669 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
670 | # corresponds to doing no classifier free guidance.
671 | do_classifier_free_guidance = guidance_scale > 1.0
672 |
673 | # 3. Encode input prompt
674 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
675 | prompt,
676 | do_classifier_free_guidance,
677 | num_images_per_prompt=num_images_per_prompt,
678 | device=device,
679 | negative_prompt=negative_prompt,
680 | prompt_embeds=prompt_embeds,
681 | negative_prompt_embeds=negative_prompt_embeds,
682 | clean_caption=clean_caption,
683 | )
684 |
685 | if do_classifier_free_guidance:
686 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
687 |
688 | # 4. Prepare timesteps
689 | if timesteps is not None:
690 | self.scheduler.set_timesteps(timesteps=timesteps, device=device)
691 | timesteps = self.scheduler.timesteps
692 | num_inference_steps = len(timesteps)
693 | else:
694 | self.scheduler.set_timesteps(num_inference_steps, device=device)
695 | timesteps = self.scheduler.timesteps
696 |
697 | # 5. Prepare intermediate images
698 | pixel_values = pixel_values.to(device)
699 | if init_noise is not None:
700 | intermediate_images = init_noise
701 | else:
702 | intermediate_images = self.prepare_intermediate_images(
703 | batch_size * num_images_per_prompt,
704 | # self.unet.config.in_channels, # mask not noise.
705 | pixel_values.shape[1],
706 | num_frames,
707 | height,
708 | width,
709 | prompt_embeds.dtype,
710 | device,
711 | generator,
712 | )
713 |
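        # Build the interpolation conditioning: a binary mask channel marks the first and last frames as the
        # given keyframes, and (with `cond_interpolation=True`) the conditioning frames in between are filled by
        # trilinear interpolation of those two keyframes. This conditioning is concatenated to the noisy sample
        # along the channel dimension at every denoising step.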
714 | bsz = intermediate_images.shape[0]
715 | interp_mask = torch.zeros(bsz, 1, *intermediate_images.shape[2:], device=device, dtype=intermediate_images.dtype)
716 | interp_mask[:, :, 0, :, :] = 1
717 | interp_mask[:, :, -1, :, :] = 1
718 |
719 | if cond_interpolation:
720 | import torch.nn.functional as F
721 | pixel_values = F.interpolate(pixel_values[:, :, [0, -1], ...], pixel_values.shape[2:],
722 | mode="trilinear", align_corners=True)
723 | else:
724 | raise Exception("apply mask to pixel_values")
725 |
726 | # intermediate_images[:, :, 0, :, :] = pixel_values[:, :, 0, :, :]
727 | # intermediate_images[:, :, -1, :, :] = pixel_values[:, :, -1, :, :]
728 | pixel_values_condition = torch.cat((pixel_values, interp_mask), dim=1)
729 |
730 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
731 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
732 |
733 | # HACK: see comment in `enable_model_cpu_offload`
734 | if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
735 | self.text_encoder_offload_hook.offload()
736 |
737 | # 7. Denoising loop
738 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
739 | with self.progress_bar(total=num_inference_steps) as progress_bar:
740 | for i, t in enumerate(timesteps):
741 | intermediate_images_input = torch.cat((intermediate_images, pixel_values_condition), dim=1)
742 | model_input = (
743 | torch.cat([intermediate_images_input] * 2) if do_classifier_free_guidance else intermediate_images_input
744 | )
745 | model_input = self.scheduler.scale_model_input(model_input, t)
746 |
747 | # predict the noise residual
748 | noise_pred = self.unet(
749 | model_input,
750 | t,
751 | encoder_hidden_states=prompt_embeds,
752 | cross_attention_kwargs=cross_attention_kwargs,
753 | ).sample
754 | # perform guidance
755 | if do_classifier_free_guidance:
756 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
757 | noise_pred_uncond, _ = noise_pred_uncond.split(intermediate_images.shape[1], dim=1)
758 | noise_pred_text, predicted_variance = noise_pred_text.split(intermediate_images.shape[1], dim=1)
759 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
760 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
761 |
762 | if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
763 | noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
764 |
765 | # reshape latents
766 |                 bsz, channel, frames, height, width = intermediate_images.shape
767 |                 intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
768 |                 noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
769 |
770 | # compute the previous noisy sample x_t -> x_t-1
771 | intermediate_images = self.scheduler.step(
772 | noise_pred, t, intermediate_images, **extra_step_kwargs
773 | ).prev_sample
774 |
775 | # reshape latents back
776 |                 intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
777 |
778 | # call the callback, if provided
779 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
780 | progress_bar.update()
781 | if callback is not None and i % callback_steps == 0:
782 | callback(i, t, intermediate_images)
783 |
784 | video_tensor = intermediate_images
785 |
786 | if output_type == "pt":
787 | video = video_tensor
788 | else:
789 | video = tensor2vid(video_tensor)
790 |
791 | # Offload last model to CPU
792 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
793 | self.final_offload_hook.offload()
794 |
795 | if not return_dict:
796 | return (video,)
797 |
798 | return TextToVideoPipelineOutput(frames=video)
--------------------------------------------------------------------------------
/showone/pipelines/pipeline_t2v_sr_pixel.py:
--------------------------------------------------------------------------------
1 | import html
2 | import inspect
3 | import re
4 | import urllib.parse as ul
5 | from typing import Any, Callable, Dict, List, Optional, Union
6 |
7 | import numpy as np
8 | from einops import rearrange
9 | import PIL
10 | import torch
11 | import torch.nn.functional as F
12 | from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
13 |
14 | from diffusers.loaders import LoraLoaderMixin
15 | from diffusers.schedulers import DDPMScheduler
16 | from diffusers.utils import (
17 | BACKENDS_MAPPING,
18 | is_accelerate_available,
19 | is_accelerate_version,
20 | is_bs4_available,
21 | is_ftfy_available,
22 | logging,
23 | randn_tensor,
24 | )
25 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
26 |
27 | from ..models import UNet3DConditionModel
28 | from . import TextToVideoPipelineOutput
29 |
30 |
31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32 |
33 | if is_bs4_available():
34 | from bs4 import BeautifulSoup
35 |
36 | if is_ftfy_available():
37 | import ftfy
38 |
39 |
40 | def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
41 | # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
42 | # reshape to ncfhw
43 | mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
44 | std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
45 | # unnormalize back to [0,1]
46 | video = video.mul_(std).add_(mean)
47 | video.clamp_(0, 1)
48 | # prepare the final outputs
49 | i, c, f, h, w = video.shape
50 | images = video.permute(2, 3, 0, 4, 1).reshape(
51 | f, h, i * w, c
52 | ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
53 |     images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
54 | images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
55 | return images
56 |
57 |
58 | class TextToVideoIFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
59 | tokenizer: T5Tokenizer
60 | text_encoder: T5EncoderModel
61 |
62 | unet: UNet3DConditionModel
63 | scheduler: DDPMScheduler
64 | image_noising_scheduler: DDPMScheduler
65 |
66 | feature_extractor: Optional[CLIPImageProcessor]
67 | # safety_checker: Optional[IFSafetyChecker]
68 |
69 | # watermarker: Optional[IFWatermarker]
70 |
71 | bad_punct_regex = re.compile(
72 | r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
73 | ) # noqa
74 |
75 | _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
76 |
77 | def __init__(
78 | self,
79 | tokenizer: T5Tokenizer,
80 | text_encoder: T5EncoderModel,
81 | unet: UNet3DConditionModel,
82 | scheduler: DDPMScheduler,
83 | image_noising_scheduler: DDPMScheduler,
84 | feature_extractor: Optional[CLIPImageProcessor],
85 | ):
86 | super().__init__()
87 |
88 | self.register_modules(
89 | tokenizer=tokenizer,
90 | text_encoder=text_encoder,
91 | unet=unet,
92 | scheduler=scheduler,
93 | image_noising_scheduler=image_noising_scheduler,
94 | feature_extractor=feature_extractor,
95 | )
96 | self.safety_checker = None
97 |
98 | def enable_sequential_cpu_offload(self, gpu_id=0):
99 | r"""
100 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
101 |         models have their state dicts saved to CPU and then are moved to a `torch.device('meta')` and loaded to GPU only
102 | when their specific submodule has its `forward` method called.
103 | """
104 | if is_accelerate_available():
105 | from accelerate import cpu_offload
106 | else:
107 | raise ImportError("Please install accelerate via `pip install accelerate`")
108 |
109 | device = torch.device(f"cuda:{gpu_id}")
110 |
111 | models = [
112 | self.text_encoder,
113 | self.unet,
114 | ]
115 | for cpu_offloaded_model in models:
116 | if cpu_offloaded_model is not None:
117 | cpu_offload(cpu_offloaded_model, device)
118 |
119 | if self.safety_checker is not None:
120 | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
121 |
122 | def enable_model_cpu_offload(self, gpu_id=0):
123 | r"""
124 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
125 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
126 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
127 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
128 | """
129 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
130 | from accelerate import cpu_offload_with_hook
131 | else:
132 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
133 |
134 | device = torch.device(f"cuda:{gpu_id}")
135 |
136 | if self.device.type != "cpu":
137 | self.to("cpu", silence_dtype_warnings=True)
138 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
139 |
140 | hook = None
141 |
142 | if self.text_encoder is not None:
143 | _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
144 |
145 | # Accelerate will move the next model to the device _before_ calling the offload hook of the
146 | # previous model. This will cause both models to be present on the device at the same time.
147 | # IF uses T5 for its text encoder which is really large. We can manually call the offload
148 | # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
149 | # the GPU.
150 | self.text_encoder_offload_hook = hook
151 |
152 | _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
153 |
154 | # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
155 | self.unet_offload_hook = hook
156 |
157 | if self.safety_checker is not None:
158 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
159 |
160 | # We'll offload the last model manually.
161 | self.final_offload_hook = hook
162 |
163 | def remove_all_hooks(self):
164 | if is_accelerate_available():
165 | from accelerate.hooks import remove_hook_from_module
166 | else:
167 | raise ImportError("Please install accelerate via `pip install accelerate`")
168 |
169 | for model in [self.text_encoder, self.unet, self.safety_checker]:
170 | if model is not None:
171 | remove_hook_from_module(model, recurse=True)
172 |
173 | self.unet_offload_hook = None
174 | self.text_encoder_offload_hook = None
175 | self.final_offload_hook = None
176 |
177 | @property
178 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
179 | def _execution_device(self):
180 | r"""
181 | Returns the device on which the pipeline's models will be executed. After calling
182 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
183 | hooks.
184 | """
185 | if not hasattr(self.unet, "_hf_hook"):
186 | return self.device
187 | for module in self.unet.modules():
188 | if (
189 | hasattr(module, "_hf_hook")
190 | and hasattr(module._hf_hook, "execution_device")
191 | and module._hf_hook.execution_device is not None
192 | ):
193 | return torch.device(module._hf_hook.execution_device)
194 | return self.device
195 |
196 | @torch.no_grad()
197 | def encode_prompt(
198 | self,
199 | prompt,
200 | do_classifier_free_guidance=True,
201 | num_images_per_prompt=1,
202 | device=None,
203 | negative_prompt=None,
204 | prompt_embeds: Optional[torch.FloatTensor] = None,
205 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
206 | clean_caption: bool = False,
207 | ):
208 | r"""
209 | Encodes the prompt into text encoder hidden states.
210 |
211 | Args:
212 | prompt (`str` or `List[str]`, *optional*):
213 | prompt to be encoded
214 | device: (`torch.device`, *optional*):
215 | torch device to place the resulting embeddings on
216 | num_images_per_prompt (`int`, *optional*, defaults to 1):
217 | number of images that should be generated per prompt
218 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
219 | whether to use classifier free guidance or not
220 | negative_prompt (`str` or `List[str]`, *optional*):
221 |                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
222 |                 `negative_prompt_embeds` instead.
223 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
224 | prompt_embeds (`torch.FloatTensor`, *optional*):
225 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
226 | provided, text embeddings will be generated from `prompt` input argument.
227 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
228 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
229 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
230 | argument.
231 | """
232 | if prompt is not None and negative_prompt is not None:
233 | if type(prompt) is not type(negative_prompt):
234 | raise TypeError(
235 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
236 | f" {type(prompt)}."
237 | )
238 |
239 | if device is None:
240 | device = self._execution_device
241 |
242 | if prompt is not None and isinstance(prompt, str):
243 | batch_size = 1
244 | elif prompt is not None and isinstance(prompt, list):
245 | batch_size = len(prompt)
246 | else:
247 | batch_size = prompt_embeds.shape[0]
248 |
249 | # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
250 | max_length = 77
251 |
252 | if prompt_embeds is None:
253 | prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
254 | text_inputs = self.tokenizer(
255 | prompt,
256 | padding="max_length",
257 | max_length=max_length,
258 | truncation=True,
259 | add_special_tokens=True,
260 | return_tensors="pt",
261 | )
262 | text_input_ids = text_inputs.input_ids
263 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
264 |
265 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
266 | text_input_ids, untruncated_ids
267 | ):
268 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
269 | logger.warning(
270 |                     "The following part of your input was truncated because the text encoder can only handle sequences up to"
271 | f" {max_length} tokens: {removed_text}"
272 | )
273 |
274 | attention_mask = text_inputs.attention_mask.to(device)
275 |
276 | prompt_embeds = self.text_encoder(
277 | text_input_ids.to(device),
278 | attention_mask=attention_mask,
279 | )
280 | prompt_embeds = prompt_embeds[0]
281 |
282 | if self.text_encoder is not None:
283 | dtype = self.text_encoder.dtype
284 | elif self.unet is not None:
285 | dtype = self.unet.dtype
286 | else:
287 | dtype = None
288 |
289 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
290 |
291 | bs_embed, seq_len, _ = prompt_embeds.shape
292 | # duplicate text embeddings for each generation per prompt, using mps friendly method
293 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
294 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
295 |
296 | # get unconditional embeddings for classifier free guidance
297 | if do_classifier_free_guidance and negative_prompt_embeds is None:
298 | uncond_tokens: List[str]
299 | if negative_prompt is None:
300 | uncond_tokens = [""] * batch_size
301 | elif isinstance(negative_prompt, str):
302 | uncond_tokens = [negative_prompt]
303 | elif batch_size != len(negative_prompt):
304 | raise ValueError(
305 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
306 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
307 | " the batch size of `prompt`."
308 | )
309 | else:
310 | uncond_tokens = negative_prompt
311 |
312 | uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
313 | max_length = prompt_embeds.shape[1]
314 | uncond_input = self.tokenizer(
315 | uncond_tokens,
316 | padding="max_length",
317 | max_length=max_length,
318 | truncation=True,
319 | return_attention_mask=True,
320 | add_special_tokens=True,
321 | return_tensors="pt",
322 | )
323 | attention_mask = uncond_input.attention_mask.to(device)
324 |
325 | negative_prompt_embeds = self.text_encoder(
326 | uncond_input.input_ids.to(device),
327 | attention_mask=attention_mask,
328 | )
329 | negative_prompt_embeds = negative_prompt_embeds[0]
330 |
331 | if do_classifier_free_guidance:
332 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
333 | seq_len = negative_prompt_embeds.shape[1]
334 |
335 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
336 |
337 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
338 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
339 |
340 | # For classifier free guidance, we need to do two forward passes.
341 | # Here we concatenate the unconditional and text embeddings into a single batch
342 | # to avoid doing two forward passes
343 | else:
344 | negative_prompt_embeds = None
345 |
346 | return prompt_embeds, negative_prompt_embeds
347 |
348 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
349 | def prepare_extra_step_kwargs(self, generator, eta):
350 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
351 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
352 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
353 | # and should be between [0, 1]
354 |
355 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
356 | extra_step_kwargs = {}
357 | if accepts_eta:
358 | extra_step_kwargs["eta"] = eta
359 |
360 | # check if the scheduler accepts generator
361 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
362 | if accepts_generator:
363 | extra_step_kwargs["generator"] = generator
364 | return extra_step_kwargs
365 |
366 | def check_inputs(
367 | self,
368 | prompt,
369 | image,
370 | batch_size,
371 | noise_level,
372 | callback_steps,
373 | negative_prompt=None,
374 | prompt_embeds=None,
375 | negative_prompt_embeds=None,
376 | ):
377 | if (callback_steps is None) or (
378 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
379 | ):
380 | raise ValueError(
381 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
382 | f" {type(callback_steps)}."
383 | )
384 |
385 | if prompt is not None and prompt_embeds is not None:
386 | raise ValueError(
387 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
388 | " only forward one of the two."
389 | )
390 | elif prompt is None and prompt_embeds is None:
391 | raise ValueError(
392 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
393 | )
394 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
395 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
396 |
397 | if negative_prompt is not None and negative_prompt_embeds is not None:
398 | raise ValueError(
399 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
400 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
401 | )
402 |
403 | if prompt_embeds is not None and negative_prompt_embeds is not None:
404 | if prompt_embeds.shape != negative_prompt_embeds.shape:
405 | raise ValueError(
406 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
407 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
408 | f" {negative_prompt_embeds.shape}."
409 | )
410 |
411 | if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
412 | raise ValueError(
413 | f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
414 | )
415 |
416 | if isinstance(image, list):
417 | check_image_type = image[0]
418 | else:
419 | check_image_type = image
420 |
421 | if (
422 | not isinstance(check_image_type, torch.Tensor)
423 | and not isinstance(check_image_type, PIL.Image.Image)
424 | and not isinstance(check_image_type, np.ndarray)
425 | ):
426 | raise ValueError(
427 | "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
428 | f" {type(check_image_type)}"
429 | )
430 |
431 | if isinstance(image, list):
432 | image_batch_size = len(image)
433 | elif isinstance(image, torch.Tensor):
434 | image_batch_size = image.shape[0]
435 | elif isinstance(image, PIL.Image.Image):
436 | image_batch_size = 1
437 | elif isinstance(image, np.ndarray):
438 | image_batch_size = image.shape[0]
439 | else:
440 | assert False
441 |
442 | if batch_size != image_batch_size:
443 | raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
444 |
445 | def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
446 | shape = (batch_size, num_channels, num_frames, height, width)
447 | if isinstance(generator, list) and len(generator) != batch_size:
448 | raise ValueError(
449 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
450 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
451 | )
452 |
453 | intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
454 |
455 | # scale the initial noise by the standard deviation required by the scheduler
456 | intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
457 | return intermediate_images
458 |
459 | def preprocess_image(self, image, num_images_per_prompt, device):
460 | if not isinstance(image, torch.Tensor) and not isinstance(image, list):
461 | image = [image]
462 |
463 | if isinstance(image[0], PIL.Image.Image):
464 | image = [np.array(i).astype(np.float32) / 255.0 for i in image]
465 |
466 | image = np.stack(image, axis=0) # to np
467 | image = torch.from_numpy(image.transpose(0, 3, 1, 2))
468 | elif isinstance(image[0], np.ndarray):
469 | image = np.stack(image, axis=0) # to np
470 | if image.ndim == 5:
471 | image = image[0]
472 |
473 | image = torch.from_numpy(image.transpose(0, 3, 1, 2))
474 | elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
475 | dims = image[0].ndim
476 |
477 | if dims == 3:
478 | image = torch.stack(image, dim=0)
479 | elif dims == 4:
480 | image = torch.concat(image, dim=0)
481 | else:
482 | raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
483 |
484 | image = image.to(device=device, dtype=self.unet.dtype)
485 |
486 | image = image.repeat_interleave(num_images_per_prompt, dim=0)
487 |
488 | return image
489 |
490 | def _text_preprocessing(self, text, clean_caption=False):
491 | if clean_caption and not is_bs4_available():
492 | logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
493 | logger.warn("Setting `clean_caption` to False...")
494 | clean_caption = False
495 |
496 | if clean_caption and not is_ftfy_available():
497 | logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
498 | logger.warn("Setting `clean_caption` to False...")
499 | clean_caption = False
500 |
501 | if not isinstance(text, (tuple, list)):
502 | text = [text]
503 |
504 | def process(text: str):
505 | if clean_caption:
506 | text = self._clean_caption(text)
507 | text = self._clean_caption(text)
508 | else:
509 | text = text.lower().strip()
510 | return text
511 |
512 | return [process(t) for t in text]
513 |
514 | def _clean_caption(self, caption):
515 | caption = str(caption)
516 | caption = ul.unquote_plus(caption)
517 | caption = caption.strip().lower()
518 | caption = re.sub("<person>", "person", caption)
519 | # urls:
520 | caption = re.sub(
521 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
522 | "",
523 | caption,
524 | ) # regex for urls
525 | caption = re.sub(
526 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
527 | "",
528 | caption,
529 | ) # regex for urls
530 | # html:
531 | caption = BeautifulSoup(caption, features="html.parser").text
532 |
533 | # @
534 | caption = re.sub(r"@[\w\d]+\b", "", caption)
535 |
536 | # 31C0—31EF CJK Strokes
537 | # 31F0—31FF Katakana Phonetic Extensions
538 | # 3200—32FF Enclosed CJK Letters and Months
539 | # 3300—33FF CJK Compatibility
540 | # 3400—4DBF CJK Unified Ideographs Extension A
541 | # 4DC0—4DFF Yijing Hexagram Symbols
542 | # 4E00—9FFF CJK Unified Ideographs
543 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
544 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
545 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
546 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
547 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
548 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
549 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
550 | #######################################################
551 |
552 | # все виды тире / all types of dash --> "-"
553 | caption = re.sub(
554 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
555 | "-",
556 | caption,
557 | )
558 |
559 | # кавычки к одному стандарту
560 | caption = re.sub(r"[`´«»“”¨]", '"', caption)
561 | caption = re.sub(r"[‘’]", "'", caption)
562 |
563 | # &quot;
564 | caption = re.sub(r"&quot;?", "", caption)
565 | # &amp
566 | caption = re.sub(r"&amp", "", caption)
567 |
568 | # IP addresses:
569 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
570 |
571 | # article ids:
572 | caption = re.sub(r"\d:\d\d\s+$", "", caption)
573 |
574 | # \n
575 | caption = re.sub(r"\\n", " ", caption)
576 |
577 | # "#123"
578 | caption = re.sub(r"#\d{1,3}\b", "", caption)
579 | # "#12345.."
580 | caption = re.sub(r"#\d{5,}\b", "", caption)
581 | # "123456.."
582 | caption = re.sub(r"\b\d{6,}\b", "", caption)
583 | # filenames:
584 | caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
585 |
586 | #
587 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
588 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
589 |
590 | caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
591 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
592 |
593 | # this-is-my-cute-cat / this_is_my_cute_cat
594 | regex2 = re.compile(r"(?:\-|\_)")
595 | if len(re.findall(regex2, caption)) > 3:
596 | caption = re.sub(regex2, " ", caption)
597 |
598 | caption = ftfy.fix_text(caption)
599 | caption = html.unescape(html.unescape(caption))
600 |
601 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
602 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
603 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
604 |
605 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
606 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
607 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
608 | caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
609 | caption = re.sub(r"\bpage\s+\d+\b", "", caption)
610 |
611 | caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
612 |
613 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
614 |
615 | caption = re.sub(r"\b\s+\:\s+", r": ", caption)
616 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
617 | caption = re.sub(r"\s+", " ", caption)
618 |
619 | caption.strip()
620 |
621 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
622 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
623 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
624 | caption = re.sub(r"^\.\S+$", "", caption)
625 |
626 | return caption.strip()
627 |
628 | @torch.no_grad()
629 | def __call__(
630 | self,
631 | prompt: Union[str, List[str]] = None,
632 | height: Optional[int] = None,
633 | width: Optional[int] = None,
634 | image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
635 | num_inference_steps: int = 50,
636 | timesteps: List[int] = None,
637 | guidance_scale: float = 4.0,
638 | negative_prompt: Optional[Union[str, List[str]]] = None,
639 | num_images_per_prompt: Optional[int] = 1,
640 | eta: float = 0.0,
641 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
642 | prompt_embeds: Optional[torch.FloatTensor] = None,
643 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
644 | output_type: Optional[str] = "np",
645 | return_dict: bool = True,
646 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
647 | callback_steps: int = 1,
648 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
649 | noise_level: int = 20,
650 | clean_caption: bool = True,
651 | ):
652 | """
653 | Function invoked when calling the pipeline for generation.
654 |
655 | Args:
656 | prompt (`str` or `List[str]`, *optional*):
657 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
658 | instead.
659 | height (`int`, *optional*, defaults to self.unet.config.sample_size):
660 | The height in pixels of the generated image.
661 | width (`int`, *optional*, defaults to self.unet.config.sample_size):
662 | The width in pixels of the generated image.
663 | image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
664 | The image to be upscaled.
665 | num_inference_steps (`int`, *optional*, defaults to 50):
666 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
667 | expense of slower inference.
668 | timesteps (`List[int]`, *optional*):
669 | Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
670 | timesteps are used. Must be in descending order.
671 | guidance_scale (`float`, *optional*, defaults to 4.0):
672 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
673 | `guidance_scale` is defined as `w` of equation 2 of [Imagen
674 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
675 | 1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
676 | usually at the expense of lower image quality.
677 | negative_prompt (`str` or `List[str]`, *optional*):
678 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
679 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
680 | less than `1`).
681 | num_images_per_prompt (`int`, *optional*, defaults to 1):
682 | The number of images to generate per prompt.
683 | eta (`float`, *optional*, defaults to 0.0):
684 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
685 | [`schedulers.DDIMScheduler`], will be ignored for others.
686 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
687 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
688 | to make generation deterministic.
689 | prompt_embeds (`torch.FloatTensor`, *optional*):
690 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
691 | provided, text embeddings will be generated from `prompt` input argument.
692 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
693 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
694 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
695 | argument.
696 | output_type (`str`, *optional*, defaults to `"np"`):
697 | The output format of the generated video. Choose between `"np"` (a list of `np.ndarray` frames, see
698 | `tensor2vid`) and `"pt"` (the raw `torch.FloatTensor`).
699 | return_dict (`bool`, *optional*, defaults to `True`):
700 | Whether or not to return a [`TextToVideoPipelineOutput`] instead of a plain tuple.
701 | callback (`Callable`, *optional*):
702 | A function that will be called every `callback_steps` steps during inference. The function will be
703 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
704 | callback_steps (`int`, *optional*, defaults to 1):
705 | The frequency at which the `callback` function will be called. If not specified, the callback will be
706 | called at every step.
707 | cross_attention_kwargs (`dict`, *optional*):
708 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
709 | `self.processor` in
710 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
711 | noise_level (`int`, *optional*, defaults to 20):
712 | The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
713 | clean_caption (`bool`, *optional*, defaults to `True`):
714 | Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
715 | be installed. If the dependencies are not installed, the embeddings will be created from the raw
716 | prompt.
717 |
718 | Examples:
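A minimal, hypothetical sketch (not part of the original docstring), assuming `pipe` is an instance of
this pipeline loaded via `from_pretrained` and `low_res_video` is a `(batch, channels, frames, height, width)`
float tensor produced by the preceding stage:

```python
frames = pipe(
    prompt="a panda dancing in the snow",
    image=low_res_video,
    height=256,
    width=256,
    num_inference_steps=50,
    output_type="np",
).frames
```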
719 |
720 | Returns:
721 | [`TextToVideoPipelineOutput`] or `tuple`:
722 | [`TextToVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
723 | returning a tuple, the first element is the generated video: a list of `np.ndarray`
724 | frames when `output_type="np"` (the default), or a `torch.FloatTensor` when
725 | `output_type="pt"`.
726 | """
727 | # 1. Check inputs. Raise error if not correct
728 |
729 | if prompt is not None and isinstance(prompt, str):
730 | batch_size = 1
731 | elif prompt is not None and isinstance(prompt, list):
732 | batch_size = len(prompt)
733 | else:
734 | batch_size = prompt_embeds.shape[0]
735 |
736 | self.check_inputs(
737 | prompt,
738 | image,
739 | batch_size,
740 | noise_level,
741 | callback_steps,
742 | negative_prompt,
743 | prompt_embeds,
744 | negative_prompt_embeds,
745 | )
746 |
747 | # 2. Define call parameters
748 |
749 | height = height or self.unet.config.sample_size
750 | width = width or self.unet.config.sample_size
751 | assert isinstance(image, torch.Tensor), f"{type(image)} is not supported."
752 | num_frames = image.shape[2]
753 |
754 | device = self._execution_device
755 |
756 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
757 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
758 | # corresponds to doing no classifier free guidance.
759 | do_classifier_free_guidance = guidance_scale > 1.0
760 |
761 | # 3. Encode input prompt
762 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
763 | prompt,
764 | do_classifier_free_guidance,
765 | num_images_per_prompt=num_images_per_prompt,
766 | device=device,
767 | negative_prompt=negative_prompt,
768 | prompt_embeds=prompt_embeds,
769 | negative_prompt_embeds=negative_prompt_embeds,
770 | clean_caption=clean_caption,
771 | )
772 |
773 | if do_classifier_free_guidance:
774 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
775 |
776 | # 4. Prepare timesteps
777 | if timesteps is not None:
778 | self.scheduler.set_timesteps(timesteps=timesteps, device=device)
779 | timesteps = self.scheduler.timesteps
780 | num_inference_steps = len(timesteps)
781 | else:
782 | self.scheduler.set_timesteps(num_inference_steps, device=device)
783 | timesteps = self.scheduler.timesteps
784 |
785 | # 5. Prepare intermediate images
786 | num_channels = self.unet.config.in_channels // 2
787 | intermediate_images = self.prepare_intermediate_images(
788 | batch_size * num_images_per_prompt,
789 | num_channels,
790 | num_frames,
791 | height,
792 | width,
793 | prompt_embeds.dtype,
794 | device,
795 | generator,
796 | )
797 |
798 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
799 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
800 |
801 | # 7. Prepare upscaled image and noise level
802 | image = self.preprocess_image(image, num_images_per_prompt, device)
803 | upscaled = rearrange(image, "b c f h w -> (b f) c h w")
804 | upscaled = F.interpolate(upscaled, (height, width), mode="bilinear", align_corners=True)
805 | upscaled = rearrange(upscaled, "(b f) c h w -> b c f h w", f=image.shape[2])
806 |
807 | noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
808 | noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
809 | upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
810 |
811 | if do_classifier_free_guidance:
812 | noise_level = torch.cat([noise_level] * 2)
813 |
814 | # HACK: see comment in `enable_model_cpu_offload`
815 | if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
816 | self.text_encoder_offload_hook.offload()
817 |
818 | # 8. Denoising loop
819 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
820 | with self.progress_bar(total=num_inference_steps) as progress_bar:
821 | for i, t in enumerate(timesteps):
822 | model_input = torch.cat([intermediate_images, upscaled], dim=1)
823 |
824 | model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
825 | model_input = self.scheduler.scale_model_input(model_input, t)
826 |
827 | # predict the noise residual
828 | noise_pred = self.unet(
829 | model_input,
830 | t,
831 | encoder_hidden_states=prompt_embeds,
832 | class_labels=noise_level,
833 | cross_attention_kwargs=cross_attention_kwargs,
834 | ).sample
835 |
836 | # perform guidance
837 | if do_classifier_free_guidance:
838 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
839 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
840 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
841 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
842 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
843 |
844 | # reshape latents
845 | bsz, channel, frames, height, width = intermediate_images.shape
846 | intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
847 | noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
848 |
849 | # compute the previous noisy sample x_t -> x_t-1
850 | intermediate_images = self.scheduler.step(
851 | noise_pred, t, intermediate_images, **extra_step_kwargs
852 | ).prev_sample
853 |
854 | # reshape latents back
855 | intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
856 |
857 | # call the callback, if provided
858 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
859 | progress_bar.update()
860 | if callback is not None and i % callback_steps == 0:
861 | callback(i, t, intermediate_images)
862 |
863 | video_tensor = intermediate_images
864 |
865 | if output_type == "pt":
866 | video = video_tensor
867 | else:
868 | video = tensor2vid(video_tensor)
869 |
870 | # Offload last model to CPU
871 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
872 | self.final_offload_hook.offload()
873 |
874 | if not return_dict:
875 | return (video,)
876 |
877 | return TextToVideoPipelineOutput(frames=video)
878 |
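The next file defines a conditioned variant of this super-resolution pipeline. As a bridge between the two
listings, here is a minimal, hypothetical end-to-end sketch of loading that class and saving its output. The
checkpoint name comes from the README; the checkpoint-to-class pairing, the dummy input shapes, and the
`imageio` GIF writer are assumptions rather than anything stated in these files (see `run_inference.py` for
the authoritative invocation).

```python
import imageio
import torch

from showone.pipelines.pipeline_t2v_sr_pixel_cond import TextToVideoIFSuperResolutionPipeline_Cond

# Hypothetical pairing of checkpoint and pipeline class.
pipe_sr = TextToVideoIFSuperResolutionPipeline_Cond.from_pretrained("showlab/show-1-sr1")
pipe_sr.enable_model_cpu_offload()  # optional memory saving, implemented in the class below

# Dummy placeholders standing in for the outputs of the earlier stages:
# a low-resolution video (batch, channels, frames, height, width) in [-1, 1], and an optional
# high-resolution first frame (batch, channels, 1, height, width) at the target size.
low_res_video = torch.rand(1, 3, 8, 64, 64) * 2 - 1
first_frame_hr = torch.rand(1, 3, 1, 256, 256) * 2 - 1

frames = pipe_sr(
    prompt="a panda dancing in the snow",
    image=low_res_video,
    first_frame_cond=first_frame_hr,
    height=256,
    width=256,
    num_inference_steps=50,
    noise_level=250,
    output_type="np",  # list of uint8 (H, W, C) arrays, see tensor2vid
).frames

# The keyword controlling the frame rate differs across imageio versions (fps vs. duration).
imageio.mimsave("outputs/sr_sample.gif", frames, fps=8)
```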
--------------------------------------------------------------------------------
/showone/pipelines/pipeline_t2v_sr_pixel_cond.py:
--------------------------------------------------------------------------------
1 | import html
2 | import inspect
3 | import re
4 | import urllib.parse as ul
5 | from typing import Any, Callable, Dict, List, Optional, Union
6 |
7 | import numpy as np
8 | import PIL
9 | import torch
10 | import torch.nn.functional as F
11 | from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
12 | from einops import rearrange
13 |
14 | from diffusers.loaders import LoraLoaderMixin
15 | from diffusers.schedulers import DDPMScheduler
16 | from diffusers.utils import (
17 | BACKENDS_MAPPING,
18 | is_accelerate_available,
19 | is_accelerate_version,
20 | is_bs4_available,
21 | is_ftfy_available,
22 | logging,
23 | randn_tensor,
24 | replace_example_docstring,
25 | )
26 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
27 |
28 | from ..models import UNet3DConditionModel
29 | from . import TextToVideoPipelineOutput
30 |
31 |
32 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33 |
34 | if is_bs4_available():
35 | from bs4 import BeautifulSoup
36 |
37 | if is_ftfy_available():
38 | import ftfy
39 |
40 |
41 | def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -> List[np.ndarray]:
42 | # This code is copied from https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78
43 | # reshape to ncfhw
44 | mean = torch.tensor(mean, device=video.device).reshape(1, -1, 1, 1, 1)
45 | std = torch.tensor(std, device=video.device).reshape(1, -1, 1, 1, 1)
46 | # unnormalize back to [0,1]
47 | video = video.mul_(std).add_(mean)
48 | video.clamp_(0, 1)
49 | # prepare the final outputs
50 | i, c, f, h, w = video.shape
51 | images = video.permute(2, 3, 0, 4, 1).reshape(
52 | f, h, i * w, c
53 | ) # 1st (frames, h, batch_size, w, c) 2nd (frames, h, batch_size * w, c)
54 | images = images.unbind(dim=0) # prepare a list of individual (consecutive) frames
55 | images = [(image.cpu().numpy() * 255).astype("uint8") for image in images] # f h w c
56 | return images
57 |
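A small sanity-check sketch of what `tensor2vid` above returns; the dummy tensor and shapes below are
illustrative only, not taken from the repository:

```python
import numpy as np
import torch

from showone.pipelines.pipeline_t2v_sr_pixel_cond import tensor2vid

# One 4-frame, 8x8 RGB video normalized to [-1, 1] (batch, channels, frames, height, width).
dummy = torch.rand(1, 3, 4, 8, 8) * 2 - 1
frames = tensor2vid(dummy.clone())  # clone() because tensor2vid denormalizes in place
assert len(frames) == 4
assert frames[0].shape == (8, 8, 3) and frames[0].dtype == np.uint8
```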
58 |
59 | class TextToVideoIFSuperResolutionPipeline_Cond(DiffusionPipeline, LoraLoaderMixin):
60 | tokenizer: T5Tokenizer
61 | text_encoder: T5EncoderModel
62 |
63 | unet: UNet3DConditionModel
64 | scheduler: DDPMScheduler
65 | image_noising_scheduler: DDPMScheduler
66 |
67 | feature_extractor: Optional[CLIPImageProcessor]
68 | # safety_checker: Optional[IFSafetyChecker]
69 |
70 | # watermarker: Optional[IFWatermarker]
71 |
72 | bad_punct_regex = re.compile(
73 | r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}"
74 | ) # noqa
75 |
76 | _optional_components = ["tokenizer", "text_encoder", "safety_checker", "feature_extractor", "watermarker"]
77 |
78 | def __init__(
79 | self,
80 | tokenizer: T5Tokenizer,
81 | text_encoder: T5EncoderModel,
82 | unet: UNet3DConditionModel,
83 | scheduler: DDPMScheduler,
84 | image_noising_scheduler: DDPMScheduler,
85 | feature_extractor: Optional[CLIPImageProcessor],
86 | ):
87 | super().__init__()
88 |
89 | self.register_modules(
90 | tokenizer=tokenizer,
91 | text_encoder=text_encoder,
92 | unet=unet,
93 | scheduler=scheduler,
94 | image_noising_scheduler=image_noising_scheduler,
95 | feature_extractor=feature_extractor,
96 | )
97 | self.safety_checker = None
98 |
99 | def enable_sequential_cpu_offload(self, gpu_id=0):
100 | r"""
101 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
102 | models have their state dicts saved to CPU and then are moved to `torch.device('meta')` and loaded to the GPU only
103 | when their specific submodule has its `forward` method called.
104 | """
105 | if is_accelerate_available():
106 | from accelerate import cpu_offload
107 | else:
108 | raise ImportError("Please install accelerate via `pip install accelerate`")
109 |
110 | device = torch.device(f"cuda:{gpu_id}")
111 |
112 | models = [
113 | self.text_encoder,
114 | self.unet,
115 | ]
116 | for cpu_offloaded_model in models:
117 | if cpu_offloaded_model is not None:
118 | cpu_offload(cpu_offloaded_model, device)
119 |
120 | if self.safety_checker is not None:
121 | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
122 |
123 | def enable_model_cpu_offload(self, gpu_id=0):
124 | r"""
125 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
126 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
127 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
128 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
129 | """
130 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
131 | from accelerate import cpu_offload_with_hook
132 | else:
133 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
134 |
135 | device = torch.device(f"cuda:{gpu_id}")
136 |
137 | if self.device.type != "cpu":
138 | self.to("cpu", silence_dtype_warnings=True)
139 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
140 |
141 | hook = None
142 |
143 | if self.text_encoder is not None:
144 | _, hook = cpu_offload_with_hook(self.text_encoder, device, prev_module_hook=hook)
145 |
146 | # Accelerate will move the next model to the device _before_ calling the offload hook of the
147 | # previous model. This will cause both models to be present on the device at the same time.
148 | # IF uses T5 for its text encoder which is really large. We can manually call the offload
149 | # hook for the text encoder to ensure it's moved to the cpu before the unet is moved to
150 | # the GPU.
151 | self.text_encoder_offload_hook = hook
152 |
153 | _, hook = cpu_offload_with_hook(self.unet, device, prev_module_hook=hook)
154 |
155 | # if the safety checker isn't called, `unet_offload_hook` will have to be called to manually offload the unet
156 | self.unet_offload_hook = hook
157 |
158 | if self.safety_checker is not None:
159 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
160 |
161 | # We'll offload the last model manually.
162 | self.final_offload_hook = hook
163 |
164 | def remove_all_hooks(self):
165 | if is_accelerate_available():
166 | from accelerate.hooks import remove_hook_from_module
167 | else:
168 | raise ImportError("Please install accelerate via `pip install accelerate`")
169 |
170 | for model in [self.text_encoder, self.unet, self.safety_checker]:
171 | if model is not None:
172 | remove_hook_from_module(model, recurse=True)
173 |
174 | self.unet_offload_hook = None
175 | self.text_encoder_offload_hook = None
176 | self.final_offload_hook = None
177 |
178 | @property
179 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
180 | def _execution_device(self):
181 | r"""
182 | Returns the device on which the pipeline's models will be executed. After calling
183 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
184 | hooks.
185 | """
186 | if not hasattr(self.unet, "_hf_hook"):
187 | return self.device
188 | for module in self.unet.modules():
189 | if (
190 | hasattr(module, "_hf_hook")
191 | and hasattr(module._hf_hook, "execution_device")
192 | and module._hf_hook.execution_device is not None
193 | ):
194 | return torch.device(module._hf_hook.execution_device)
195 | return self.device
196 |
197 | @torch.no_grad()
198 | def encode_prompt(
199 | self,
200 | prompt,
201 | do_classifier_free_guidance=True,
202 | num_images_per_prompt=1,
203 | device=None,
204 | negative_prompt=None,
205 | prompt_embeds: Optional[torch.FloatTensor] = None,
206 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
207 | clean_caption: bool = False,
208 | ):
209 | r"""
210 | Encodes the prompt into text encoder hidden states.
211 |
212 | Args:
213 | prompt (`str` or `List[str]`, *optional*):
214 | prompt to be encoded
215 | device: (`torch.device`, *optional*):
216 | torch device to place the resulting embeddings on
217 | num_images_per_prompt (`int`, *optional*, defaults to 1):
218 | number of images that should be generated per prompt
219 | do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
220 | whether to use classifier free guidance or not
221 | negative_prompt (`str` or `List[str]`, *optional*):
222 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
223 | `negative_prompt_embeds` instead.
224 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
225 | prompt_embeds (`torch.FloatTensor`, *optional*):
226 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
227 | provided, text embeddings will be generated from `prompt` input argument.
228 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
229 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
230 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
231 | argument.
232 | """
233 | if prompt is not None and negative_prompt is not None:
234 | if type(prompt) is not type(negative_prompt):
235 | raise TypeError(
236 | f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
237 | f" {type(prompt)}."
238 | )
239 |
240 | if device is None:
241 | device = self._execution_device
242 |
243 | if prompt is not None and isinstance(prompt, str):
244 | batch_size = 1
245 | elif prompt is not None and isinstance(prompt, list):
246 | batch_size = len(prompt)
247 | else:
248 | batch_size = prompt_embeds.shape[0]
249 |
250 | # while T5 can handle much longer input sequences than 77, the text encoder was trained with a max length of 77 for IF
251 | max_length = 77
252 |
253 | if prompt_embeds is None:
254 | prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
255 | text_inputs = self.tokenizer(
256 | prompt,
257 | padding="max_length",
258 | max_length=max_length,
259 | truncation=True,
260 | add_special_tokens=True,
261 | return_tensors="pt",
262 | )
263 | text_input_ids = text_inputs.input_ids
264 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
265 |
266 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
267 | text_input_ids, untruncated_ids
268 | ):
269 | removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
270 | logger.warning(
271 | "The following part of your input was truncated because the text encoder can only handle sequences up to"
272 | f" {max_length} tokens: {removed_text}"
273 | )
274 |
275 | attention_mask = text_inputs.attention_mask.to(device)
276 |
277 | prompt_embeds = self.text_encoder(
278 | text_input_ids.to(device),
279 | attention_mask=attention_mask,
280 | )
281 | prompt_embeds = prompt_embeds[0]
282 |
283 | if self.text_encoder is not None:
284 | dtype = self.text_encoder.dtype
285 | elif self.unet is not None:
286 | dtype = self.unet.dtype
287 | else:
288 | dtype = None
289 |
290 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
291 |
292 | bs_embed, seq_len, _ = prompt_embeds.shape
293 | # duplicate text embeddings for each generation per prompt, using mps friendly method
294 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
295 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
296 |
297 | # get unconditional embeddings for classifier free guidance
298 | if do_classifier_free_guidance and negative_prompt_embeds is None:
299 | uncond_tokens: List[str]
300 | if negative_prompt is None:
301 | uncond_tokens = [""] * batch_size
302 | elif isinstance(negative_prompt, str):
303 | uncond_tokens = [negative_prompt]
304 | elif batch_size != len(negative_prompt):
305 | raise ValueError(
306 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
307 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
308 | " the batch size of `prompt`."
309 | )
310 | else:
311 | uncond_tokens = negative_prompt
312 |
313 | uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
314 | max_length = prompt_embeds.shape[1]
315 | uncond_input = self.tokenizer(
316 | uncond_tokens,
317 | padding="max_length",
318 | max_length=max_length,
319 | truncation=True,
320 | return_attention_mask=True,
321 | add_special_tokens=True,
322 | return_tensors="pt",
323 | )
324 | attention_mask = uncond_input.attention_mask.to(device)
325 |
326 | negative_prompt_embeds = self.text_encoder(
327 | uncond_input.input_ids.to(device),
328 | attention_mask=attention_mask,
329 | )
330 | negative_prompt_embeds = negative_prompt_embeds[0]
331 |
332 | if do_classifier_free_guidance:
333 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
334 | seq_len = negative_prompt_embeds.shape[1]
335 |
336 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
337 |
338 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
339 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
340 |
341 | # For classifier free guidance, we need to do two forward passes.
342 | # Here we concatenate the unconditional and text embeddings into a single batch
343 | # to avoid doing two forward passes
344 | else:
345 | negative_prompt_embeds = None
346 |
347 | return prompt_embeds, negative_prompt_embeds
348 |
349 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
350 | def prepare_extra_step_kwargs(self, generator, eta):
351 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
352 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
353 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
354 | # and should be between [0, 1]
355 |
356 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
357 | extra_step_kwargs = {}
358 | if accepts_eta:
359 | extra_step_kwargs["eta"] = eta
360 |
361 | # check if the scheduler accepts generator
362 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
363 | if accepts_generator:
364 | extra_step_kwargs["generator"] = generator
365 | return extra_step_kwargs
366 |
367 | def check_inputs(
368 | self,
369 | prompt,
370 | image,
371 | batch_size,
372 | noise_level,
373 | callback_steps,
374 | negative_prompt=None,
375 | prompt_embeds=None,
376 | negative_prompt_embeds=None,
377 | ):
378 | if (callback_steps is None) or (
379 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
380 | ):
381 | raise ValueError(
382 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
383 | f" {type(callback_steps)}."
384 | )
385 |
386 | if prompt is not None and prompt_embeds is not None:
387 | raise ValueError(
388 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
389 | " only forward one of the two."
390 | )
391 | elif prompt is None and prompt_embeds is None:
392 | raise ValueError(
393 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
394 | )
395 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
396 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
397 |
398 | if negative_prompt is not None and negative_prompt_embeds is not None:
399 | raise ValueError(
400 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
401 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
402 | )
403 |
404 | if prompt_embeds is not None and negative_prompt_embeds is not None:
405 | if prompt_embeds.shape != negative_prompt_embeds.shape:
406 | raise ValueError(
407 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
408 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
409 | f" {negative_prompt_embeds.shape}."
410 | )
411 |
412 | if noise_level < 0 or noise_level >= self.image_noising_scheduler.config.num_train_timesteps:
413 | raise ValueError(
414 | f"`noise_level`: {noise_level} must be a valid timestep in `self.noising_scheduler`, [0, {self.image_noising_scheduler.config.num_train_timesteps})"
415 | )
416 |
417 | if isinstance(image, list):
418 | check_image_type = image[0]
419 | else:
420 | check_image_type = image
421 |
422 | if (
423 | not isinstance(check_image_type, torch.Tensor)
424 | and not isinstance(check_image_type, PIL.Image.Image)
425 | and not isinstance(check_image_type, np.ndarray)
426 | ):
427 | raise ValueError(
428 | "`image` has to be of type `torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, or List[...] but is"
429 | f" {type(check_image_type)}"
430 | )
431 |
432 | if isinstance(image, list):
433 | image_batch_size = len(image)
434 | elif isinstance(image, torch.Tensor):
435 | image_batch_size = image.shape[0]
436 | elif isinstance(image, PIL.Image.Image):
437 | image_batch_size = 1
438 | elif isinstance(image, np.ndarray):
439 | image_batch_size = image.shape[0]
440 | else:
441 | assert False
442 |
443 | if batch_size != image_batch_size:
444 | raise ValueError(f"image batch size: {image_batch_size} must be same as prompt batch size {batch_size}")
445 |
446 | def prepare_intermediate_images(self, batch_size, num_channels, num_frames, height, width, dtype, device, generator):
447 | shape = (batch_size, num_channels, num_frames, height, width)
448 | if isinstance(generator, list) and len(generator) != batch_size:
449 | raise ValueError(
450 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
451 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
452 | )
453 |
454 | intermediate_images = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
455 |
456 | # scale the initial noise by the standard deviation required by the scheduler
457 | intermediate_images = intermediate_images * self.scheduler.init_noise_sigma
458 | return intermediate_images
459 |
460 | def preprocess_image(self, image, num_images_per_prompt, device):
461 | if not isinstance(image, torch.Tensor) and not isinstance(image, list):
462 | image = [image]
463 |
464 | if isinstance(image[0], PIL.Image.Image):
465 | image = [np.array(i).astype(np.float32) / 255.0 for i in image]
466 |
467 | image = np.stack(image, axis=0) # to np
468 | image = torch.from_numpy(image.transpose(0, 3, 1, 2))
469 | elif isinstance(image[0], np.ndarray):
470 | image = np.stack(image, axis=0) # to np
471 | if image.ndim == 5:
472 | image = image[0]
473 |
474 | image = torch.from_numpy(image.transpose(0, 3, 1, 2))
475 | elif isinstance(image, list) and isinstance(image[0], torch.Tensor):
476 | dims = image[0].ndim
477 |
478 | if dims == 3:
479 | image = torch.stack(image, dim=0)
480 | elif dims == 4:
481 | image = torch.concat(image, dim=0)
482 | else:
483 | raise ValueError(f"Image must have 3 or 4 dimensions, instead got {dims}")
484 |
485 | image = image.to(device=device, dtype=self.unet.dtype)
486 |
487 | image = image.repeat_interleave(num_images_per_prompt, dim=0)
488 |
489 | return image
490 |
491 | def _text_preprocessing(self, text, clean_caption=False):
492 | if clean_caption and not is_bs4_available():
493 | logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
494 | logger.warn("Setting `clean_caption` to False...")
495 | clean_caption = False
496 |
497 | if clean_caption and not is_ftfy_available():
498 | logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
499 | logger.warn("Setting `clean_caption` to False...")
500 | clean_caption = False
501 |
502 | if not isinstance(text, (tuple, list)):
503 | text = [text]
504 |
505 | def process(text: str):
506 | if clean_caption:
507 | text = self._clean_caption(text)
508 | text = self._clean_caption(text)
509 | else:
510 | text = text.lower().strip()
511 | return text
512 |
513 | return [process(t) for t in text]
514 |
515 | def _clean_caption(self, caption):
516 | caption = str(caption)
517 | caption = ul.unquote_plus(caption)
518 | caption = caption.strip().lower()
519 | caption = re.sub("<person>", "person", caption)
520 | # urls:
521 | caption = re.sub(
522 | r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
523 | "",
524 | caption,
525 | ) # regex for urls
526 | caption = re.sub(
527 | r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
528 | "",
529 | caption,
530 | ) # regex for urls
531 | # html:
532 | caption = BeautifulSoup(caption, features="html.parser").text
533 |
534 | # @
535 | caption = re.sub(r"@[\w\d]+\b", "", caption)
536 |
537 | # 31C0—31EF CJK Strokes
538 | # 31F0—31FF Katakana Phonetic Extensions
539 | # 3200—32FF Enclosed CJK Letters and Months
540 | # 3300—33FF CJK Compatibility
541 | # 3400—4DBF CJK Unified Ideographs Extension A
542 | # 4DC0—4DFF Yijing Hexagram Symbols
543 | # 4E00—9FFF CJK Unified Ideographs
544 | caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
545 | caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
546 | caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
547 | caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
548 | caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
549 | caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
550 | caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
551 | #######################################################
552 |
553 | # все виды тире / all types of dash --> "-"
554 | caption = re.sub(
555 | r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
556 | "-",
557 | caption,
558 | )
559 |
560 | # кавычки к одному стандарту
561 | caption = re.sub(r"[`´«»“”¨]", '"', caption)
562 | caption = re.sub(r"[‘’]", "'", caption)
563 |
564 | # &quot;
565 | caption = re.sub(r"&quot;?", "", caption)
566 | # &amp
567 | caption = re.sub(r"&amp", "", caption)
568 |
569 | # IP addresses:
570 | caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
571 |
572 | # article ids:
573 | caption = re.sub(r"\d:\d\d\s+$", "", caption)
574 |
575 | # \n
576 | caption = re.sub(r"\\n", " ", caption)
577 |
578 | # "#123"
579 | caption = re.sub(r"#\d{1,3}\b", "", caption)
580 | # "#12345.."
581 | caption = re.sub(r"#\d{5,}\b", "", caption)
582 | # "123456.."
583 | caption = re.sub(r"\b\d{6,}\b", "", caption)
584 | # filenames:
585 | caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
586 |
587 | #
588 | caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
589 | caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
590 |
591 | caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
592 | caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
593 |
594 | # this-is-my-cute-cat / this_is_my_cute_cat
595 | regex2 = re.compile(r"(?:\-|\_)")
596 | if len(re.findall(regex2, caption)) > 3:
597 | caption = re.sub(regex2, " ", caption)
598 |
599 | caption = ftfy.fix_text(caption)
600 | caption = html.unescape(html.unescape(caption))
601 |
602 | caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
603 | caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
604 | caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
605 |
606 | caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
607 | caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
608 | caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
609 | caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
610 | caption = re.sub(r"\bpage\s+\d+\b", "", caption)
611 |
612 | caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
613 |
614 | caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
615 |
616 | caption = re.sub(r"\b\s+\:\s+", r": ", caption)
617 | caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
618 | caption = re.sub(r"\s+", " ", caption)
619 |
620 | caption.strip()
621 |
622 | caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
623 | caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
624 | caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
625 | caption = re.sub(r"^\.\S+$", "", caption)
626 |
627 | return caption.strip()
628 |
629 | @torch.no_grad()
630 | def __call__(
631 | self,
632 | prompt: Union[str, List[str]] = None,
633 | height: Optional[int] = None,
634 | width: Optional[int] = None,
635 | image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
636 | first_frame_cond: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
637 | all_frame_cond: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None,
638 | num_inference_steps: int = 50,
639 | timesteps: List[int] = None,
640 | guidance_scale: float = 4.0,
641 | negative_prompt: Optional[Union[str, List[str]]] = None,
642 | num_images_per_prompt: Optional[int] = 1,
643 | eta: float = 0.0,
644 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
645 | prompt_embeds: Optional[torch.FloatTensor] = None,
646 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
647 | output_type: Optional[str] = "np",
648 | return_dict: bool = True,
649 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
650 | callback_steps: int = 1,
651 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
652 | noise_level: int = 250,
653 | clean_caption: bool = True,
654 | ):
655 | """
656 | Function invoked when calling the pipeline for generation.
657 |
658 | Args:
659 | prompt (`str` or `List[str]`, *optional*):
660 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
661 | instead.
662 | height (`int`, *optional*, defaults to self.unet.config.sample_size):
663 | The height in pixels of the generated image.
664 | width (`int`, *optional*, defaults to self.unet.config.sample_size):
665 | The width in pixels of the generated image.
666 | image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`):
667 | The image to be upscaled.
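first_frame_cond (`torch.FloatTensor`, *optional*):
High-resolution conditioning for the first frame, of shape `(batch, channels, 1, height, width)`. If
provided, it is cast to the UNet's dtype and overwrites the first frame of the noised, upscaled video.
all_frame_cond (`torch.FloatTensor`, *optional*):
A pre-computed upscaled video of shape `(batch, channels, frames, height, width)`. If provided, it is
used directly instead of bilinearly upsampling `image`.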
668 | num_inference_steps (`int`, *optional*, defaults to 50):
669 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
670 | expense of slower inference.
671 | timesteps (`List[int]`, *optional*):
672 | Custom timesteps to use for the denoising process. If not defined, equally spaced `num_inference_steps`
673 | timesteps are used. Must be in descending order.
674 | guidance_scale (`float`, *optional*, defaults to 4.0):
675 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
676 | `guidance_scale` is defined as `w` of equation 2 of [Imagen
677 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
678 | 1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
679 | usually at the expense of lower image quality.
680 | negative_prompt (`str` or `List[str]`, *optional*):
681 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
682 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
683 | less than `1`).
684 | num_images_per_prompt (`int`, *optional*, defaults to 1):
685 | The number of images to generate per prompt.
686 | eta (`float`, *optional*, defaults to 0.0):
687 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
688 | [`schedulers.DDIMScheduler`], will be ignored for others.
689 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
690 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
691 | to make generation deterministic.
692 | prompt_embeds (`torch.FloatTensor`, *optional*):
693 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
694 | provided, text embeddings will be generated from `prompt` input argument.
695 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
696 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
697 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
698 | argument.
699 | output_type (`str`, *optional*, defaults to `"np"`):
700 | The output format of the generated video. Choose between `"np"` (a list of `np.ndarray` frames, see
701 | `tensor2vid`) and `"pt"` (the raw `torch.FloatTensor`).
702 | return_dict (`bool`, *optional*, defaults to `True`):
703 | Whether or not to return a [`TextToVideoPipelineOutput`] instead of a plain tuple.
704 | callback (`Callable`, *optional*):
705 | A function that will be called every `callback_steps` steps during inference. The function will be
706 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
707 | callback_steps (`int`, *optional*, defaults to 1):
708 | The frequency at which the `callback` function will be called. If not specified, the callback will be
709 | called at every step.
710 | cross_attention_kwargs (`dict`, *optional*):
711 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
712 | `self.processor` in
713 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
714 | noise_level (`int`, *optional*, defaults to 250):
715 | The amount of noise to add to the upscaled image. Must be in the range `[0, 1000)`
716 | clean_caption (`bool`, *optional*, defaults to `True`):
717 | Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
718 | be installed. If the dependencies are not installed, the embeddings will be created from the raw
719 | prompt.
720 |
721 | Examples:
722 |
723 | Returns:
724 | [`TextToVideoPipelineOutput`] or `tuple`:
725 | [`TextToVideoPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
726 | returning a tuple, the first element is the generated video: a list of `np.ndarray`
727 | frames when `output_type="np"` (the default), or a `torch.FloatTensor` when
728 | `output_type="pt"`.
729 | """
730 | # 1. Check inputs. Raise error if not correct
731 |
732 | if prompt is not None and isinstance(prompt, str):
733 | batch_size = 1
734 | elif prompt is not None and isinstance(prompt, list):
735 | batch_size = len(prompt)
736 | else:
737 | batch_size = prompt_embeds.shape[0]
738 |
739 | self.check_inputs(
740 | prompt,
741 | image,
742 | batch_size,
743 | noise_level,
744 | callback_steps,
745 | negative_prompt,
746 | prompt_embeds,
747 | negative_prompt_embeds,
748 | )
749 |
750 | # 2. Define call parameters
751 |
752 | height = height or self.unet.config.sample_size
753 | width = width or self.unet.config.sample_size
754 | assert isinstance(image, torch.Tensor), f"{type(image)} is not supported."
755 | num_frames = image.shape[2]
756 |
757 | device = self._execution_device
758 |
759 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
760 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
761 | # corresponds to doing no classifier free guidance.
762 | do_classifier_free_guidance = guidance_scale > 1.0
763 |
764 | # 3. Encode input prompt
765 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
766 | prompt,
767 | do_classifier_free_guidance,
768 | num_images_per_prompt=num_images_per_prompt,
769 | device=device,
770 | negative_prompt=negative_prompt,
771 | prompt_embeds=prompt_embeds,
772 | negative_prompt_embeds=negative_prompt_embeds,
773 | clean_caption=clean_caption,
774 | )
775 |
776 | if do_classifier_free_guidance:
777 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
778 |
779 | # 4. Prepare timesteps
780 | if timesteps is not None:
781 | self.scheduler.set_timesteps(timesteps=timesteps, device=device)
782 | timesteps = self.scheduler.timesteps
783 | num_inference_steps = len(timesteps)
784 | else:
785 | self.scheduler.set_timesteps(num_inference_steps, device=device)
786 | timesteps = self.scheduler.timesteps
787 |
788 | # 5. Prepare intermediate images
789 | num_channels = self.unet.config.in_channels // 2
790 | intermediate_images = self.prepare_intermediate_images(
791 | batch_size * num_images_per_prompt,
792 | num_channels,
793 | num_frames,
794 | height,
795 | width,
796 | prompt_embeds.dtype,
797 | device,
798 | generator,
799 | )
800 |
801 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
802 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
803 |
804 | # 7. Prepare upscaled image and noise level
805 | image = self.preprocess_image(image, num_images_per_prompt, device)
806 | # upscaled = F.interpolate(image, (num_frames, height, width), mode="trilinear", align_corners=True)
807 | if all_frame_cond is not None:
808 | upscaled = all_frame_cond
809 | else:
810 | upscaled = rearrange(image, "b c f h w -> (b f) c h w")
811 | upscaled = F.interpolate(upscaled, (height, width), mode="bilinear", align_corners=True)
812 | upscaled = rearrange(upscaled, "(b f) c h w -> b c f h w", f=image.shape[2])
813 |
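        # noise augmentation: perturb the conditioning video with noise at strength `noise_level`;
        # the same `noise_level` is passed to the UNet as a class label below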
814 | noise_level = torch.tensor([noise_level] * upscaled.shape[0], device=upscaled.device)
815 | noise = randn_tensor(upscaled.shape, generator=generator, device=upscaled.device, dtype=upscaled.dtype)
816 | upscaled = self.image_noising_scheduler.add_noise(upscaled, noise, timesteps=noise_level)
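        # if a conditioning first frame is provided (e.g. the super-resolved first frame), it replaces
        # frame 0 of the noise-augmented conditioning video and is itself left un-noised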
817 | if first_frame_cond is not None:
818 | first_frame_cond = first_frame_cond.to(device=device, dtype=self.unet.dtype)
819 |             upscaled[:, :, :1, :, :] = first_frame_cond
820 |
821 | if do_classifier_free_guidance:
822 | noise_level = torch.cat([noise_level] * 2)
823 |
824 | # HACK: see comment in `enable_model_cpu_offload`
825 | if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
826 | self.text_encoder_offload_hook.offload()
827 |
828 | # 8. Denoising loop
829 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
830 | with self.progress_bar(total=num_inference_steps) as progress_bar:
831 | for i, t in enumerate(timesteps):
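                # concatenate the current noisy latents with the noised conditioning video along the channel axis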
832 | model_input = torch.cat([intermediate_images, upscaled], dim=1)
833 |
834 | model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input
835 | model_input = self.scheduler.scale_model_input(model_input, t)
836 |
837 | # predict the noise residual
838 | noise_pred = self.unet(
839 | model_input,
840 | t,
841 | encoder_hidden_states=prompt_embeds,
842 | class_labels=noise_level,
843 | cross_attention_kwargs=cross_attention_kwargs,
844 | ).sample
845 |
846 | # perform guidance
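                # each batch half predicts both noise and a learned variance; split the channel dimension,
                # apply the guidance weight to the noise predictions, and keep the text branch's variance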
847 | if do_classifier_free_guidance:
848 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
849 | noise_pred_uncond, _ = noise_pred_uncond.split(model_input.shape[1] // 2, dim=1)
850 | noise_pred_text, predicted_variance = noise_pred_text.split(model_input.shape[1] // 2, dim=1)
851 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
852 | noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
853 |
854 | if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
855 | noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
856 |
857 | # reshape latents
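                # the scheduler operates on 2D images, so fold the frame axis into the batch:
                # (b, c, f, h, w) -> (b * f, c, h, w), step, then unfold back afterwards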
858 | bsz, channel, frames, height, width = intermediate_images.shape
859 | intermediate_images = intermediate_images.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, height, width)
860 | noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, -1, height, width)
861 |
862 | # compute the previous noisy sample x_t -> x_t-1
863 | intermediate_images = self.scheduler.step(
864 | noise_pred, t, intermediate_images, **extra_step_kwargs
865 | ).prev_sample
866 |
867 | # reshape latents back
868 | intermediate_images = intermediate_images[None, :].reshape(bsz, frames, channel, height, width).permute(0, 2, 1, 3, 4)
869 |
870 | # call the callback, if provided
871 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
872 | progress_bar.update()
873 | if callback is not None and i % callback_steps == 0:
874 | callback(i, t, intermediate_images)
875 |
876 | video_tensor = intermediate_images
877 |
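        # return the raw tensor for `output_type="pt"`, otherwise convert the tensor to video frames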
878 | if output_type == "pt":
879 | video = video_tensor
880 | else:
881 | video = tensor2vid(video_tensor)
882 |
883 | # Offload last model to CPU
884 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
885 | self.final_offload_hook.offload()
886 |
887 | if not return_dict:
888 | return (video,)
889 |
890 | return TextToVideoPipelineOutput(frames=video)
891 |
--------------------------------------------------------------------------------