├── .github
│   ├── FUNDING.yml
│   └── ISSUE_TEMPLATE
│       ├── bug_report.yml
│       ├── config.yml
│       └── feature_request.yml
├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── CHANGELOG.md
│   ├── demo.md
│   ├── features.md
│   ├── how-to-use.md
│   └── performance.md
├── model
│   └── .gitkeep
├── motion_module.py
└── scripts
    ├── animatediff.py
    ├── animatediff_freeinit.py
    ├── animatediff_i2ibatch.py
    ├── animatediff_infotext.py
    ├── animatediff_infv2v.py
    ├── animatediff_latent.py
    ├── animatediff_logger.py
    ├── animatediff_mm.py
    ├── animatediff_output.py
    ├── animatediff_prompt.py
    ├── animatediff_settings.py
    ├── animatediff_ui.py
    ├── animatediff_utils.py
    └── animatediff_xyz.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: conrevo # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: ['https://paypal.me/conrevo', 'https://afdian.net/a/conrevo'] # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | name: Bug Report
2 | description: Create a bug report
3 | title: "[Bug]: "
4 | labels: ["bug-report"]
5 |
6 | body:
7 | - type: checkboxes
8 | attributes:
9 | label: Is there an existing issue for this?
10 | description: Please search both open and closed issues to see if an issue already exists for the bug you encountered, and check that it hasn't been fixed in a recent build/commit.
11 | options:
12 | - label: I have searched the existing issues and checked the recent builds/commits of both this extension and the webui
13 | required: true
14 | - type: checkboxes
15 | attributes:
16 | label: Have you read the FAQ in the README?
17 | description: I have collected some common questions from the original AnimateDiff repository.
18 | options:
19 | - label: I have updated WebUI and this extension to the latest version
20 | required: true
21 | - type: markdown
22 | attributes:
23 | value: |
24 | **Please fill this form with as much information as possible; don't forget to fill in "What OS..." and "What browsers", and provide screenshots if possible.**
25 | - type: textarea
26 | id: what-did
27 | attributes:
28 | label: What happened?
29 | description: Tell us what happened in a very clear and simple way
30 | validations:
31 | required: true
32 | - type: textarea
33 | id: steps
34 | attributes:
35 | label: Steps to reproduce the problem
36 | description: Please provide us with precise step by step information on how to reproduce the bug
37 | value: |
38 | 1. Go to ....
39 | 2. Press ....
40 | 3. ...
41 | validations:
42 | required: true
43 | - type: textarea
44 | id: what-should
45 | attributes:
46 | label: What should have happened?
47 | description: Tell us what you think the normal behavior should be
48 | validations:
49 | required: true
50 | - type: textarea
51 | id: commits
52 | attributes:
53 | label: Commit where the problem happens
54 | description: Which commit of the extension are you running on? Please include the commit of both the extension and the webui (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
55 | value: |
56 | webui:
57 | extension:
58 | validations:
59 | required: true
60 | - type: dropdown
61 | id: browsers
62 | attributes:
63 | label: What browsers do you use to access the UI?
64 | multiple: true
65 | options:
66 | - Mozilla Firefox
67 | - Google Chrome
68 | - Brave
69 | - Apple Safari
70 | - Microsoft Edge
71 | - type: textarea
72 | id: cmdargs
73 | attributes:
74 | label: Command Line Arguments
75 | description: Are you using any launch parameters/command line arguments (modified webui-user.bat/.sh)? If yes, please write them below. Write "No" otherwise.
76 | render: Shell
77 | validations:
78 | required: true
79 | - type: textarea
80 | id: logs
81 | attributes:
82 | label: Console logs
83 | description: Please provide the errors printed in your browser console (press F12 and go to Console) and in your terminal after the bug happened.
84 | render: Shell
85 | validations:
86 | required: true
87 | - type: textarea
88 | id: misc
89 | attributes:
90 | label: Additional information
91 | description: Please provide us with any relevant additional info or context.
92 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: true
2 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature Request
2 | description: Create a feature request
3 | title: "[Feature]: "
4 | labels: ["feature-request"]
5 |
6 | body:
7 | - type: textarea
8 | id: feature
9 | attributes:
10 | label: Expected behavior
11 | description: Please describe the feature you want.
12 | validations:
13 | required: true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | model/*.*
3 | TODO.md
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial-ShareAlike 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58 | Public License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License
63 | ("Public License"). To the extent this Public License may be
64 | interpreted as a contract, You are granted the Licensed Rights in
65 | consideration of Your acceptance of these terms and conditions, and the
66 | Licensor grants You such rights in consideration of benefits the
67 | Licensor receives from making the Licensed Material available under
68 | these terms and conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. BY-NC-SA Compatible License means a license listed at
88 | creativecommons.org/compatiblelicenses, approved by Creative
89 | Commons as essentially the equivalent of this Public License.
90 |
91 | d. Copyright and Similar Rights means copyright and/or similar rights
92 | closely related to copyright including, without limitation,
93 | performance, broadcast, sound recording, and Sui Generis Database
94 | Rights, without regard to how the rights are labeled or
95 | categorized. For purposes of this Public License, the rights
96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
97 | Rights.
98 |
99 | e. Effective Technological Measures means those measures that, in the
100 | absence of proper authority, may not be circumvented under laws
101 | fulfilling obligations under Article 11 of the WIPO Copyright
102 | Treaty adopted on December 20, 1996, and/or similar international
103 | agreements.
104 |
105 | f. Exceptions and Limitations means fair use, fair dealing, and/or
106 | any other exception or limitation to Copyright and Similar Rights
107 | that applies to Your use of the Licensed Material.
108 |
109 | g. License Elements means the license attributes listed in the name
110 | of a Creative Commons Public License. The License Elements of this
111 | Public License are Attribution, NonCommercial, and ShareAlike.
112 |
113 | h. Licensed Material means the artistic or literary work, database,
114 | or other material to which the Licensor applied this Public
115 | License.
116 |
117 | i. Licensed Rights means the rights granted to You subject to the
118 | terms and conditions of this Public License, which are limited to
119 | all Copyright and Similar Rights that apply to Your use of the
120 | Licensed Material and that the Licensor has authority to license.
121 |
122 | j. Licensor means the individual(s) or entity(ies) granting rights
123 | under this Public License.
124 |
125 | k. NonCommercial means not primarily intended for or directed towards
126 | commercial advantage or monetary compensation. For purposes of
127 | this Public License, the exchange of the Licensed Material for
128 | other material subject to Copyright and Similar Rights by digital
129 | file-sharing or similar means is NonCommercial provided there is
130 | no payment of monetary compensation in connection with the
131 | exchange.
132 |
133 | l. Share means to provide material to the public by any means or
134 | process that requires permission under the Licensed Rights, such
135 | as reproduction, public display, public performance, distribution,
136 | dissemination, communication, or importation, and to make material
137 | available to the public including in ways that members of the
138 | public may access the material from a place and at a time
139 | individually chosen by them.
140 |
141 | m. Sui Generis Database Rights means rights other than copyright
142 | resulting from Directive 96/9/EC of the European Parliament and of
143 | the Council of 11 March 1996 on the legal protection of databases,
144 | as amended and/or succeeded, as well as other essentially
145 | equivalent rights anywhere in the world.
146 |
147 | n. You means the individual or entity exercising the Licensed Rights
148 | under this Public License. Your has a corresponding meaning.
149 |
150 |
151 | Section 2 -- Scope.
152 |
153 | a. License grant.
154 |
155 | 1. Subject to the terms and conditions of this Public License,
156 | the Licensor hereby grants You a worldwide, royalty-free,
157 | non-sublicensable, non-exclusive, irrevocable license to
158 | exercise the Licensed Rights in the Licensed Material to:
159 |
160 | a. reproduce and Share the Licensed Material, in whole or
161 | in part, for NonCommercial purposes only; and
162 |
163 | b. produce, reproduce, and Share Adapted Material for
164 | NonCommercial purposes only.
165 |
166 | 2. Exceptions and Limitations. For the avoidance of doubt, where
167 | Exceptions and Limitations apply to Your use, this Public
168 | License does not apply, and You do not need to comply with
169 | its terms and conditions.
170 |
171 | 3. Term. The term of this Public License is specified in Section
172 | 6(a).
173 |
174 | 4. Media and formats; technical modifications allowed. The
175 | Licensor authorizes You to exercise the Licensed Rights in
176 | all media and formats whether now known or hereafter created,
177 | and to make technical modifications necessary to do so. The
178 | Licensor waives and/or agrees not to assert any right or
179 | authority to forbid You from making technical modifications
180 | necessary to exercise the Licensed Rights, including
181 | technical modifications necessary to circumvent Effective
182 | Technological Measures. For purposes of this Public License,
183 | simply making modifications authorized by this Section 2(a)
184 | (4) never produces Adapted Material.
185 |
186 | 5. Downstream recipients.
187 |
188 | a. Offer from the Licensor -- Licensed Material. Every
189 | recipient of the Licensed Material automatically
190 | receives an offer from the Licensor to exercise the
191 | Licensed Rights under the terms and conditions of this
192 | Public License.
193 |
194 | b. Additional offer from the Licensor -- Adapted Material.
195 | Every recipient of Adapted Material from You
196 | automatically receives an offer from the Licensor to
197 | exercise the Licensed Rights in the Adapted Material
198 | under the conditions of the Adapter's License You apply.
199 |
200 | c. No downstream restrictions. You may not offer or impose
201 | any additional or different terms or conditions on, or
202 | apply any Effective Technological Measures to, the
203 | Licensed Material if doing so restricts exercise of the
204 | Licensed Rights by any recipient of the Licensed
205 | Material.
206 |
207 | 6. No endorsement. Nothing in this Public License constitutes or
208 | may be construed as permission to assert or imply that You
209 | are, or that Your use of the Licensed Material is, connected
210 | with, or sponsored, endorsed, or granted official status by,
211 | the Licensor or others designated to receive attribution as
212 | provided in Section 3(a)(1)(A)(i).
213 |
214 | b. Other rights.
215 |
216 | 1. Moral rights, such as the right of integrity, are not
217 | licensed under this Public License, nor are publicity,
218 | privacy, and/or other similar personality rights; however, to
219 | the extent possible, the Licensor waives and/or agrees not to
220 | assert any such rights held by the Licensor to the limited
221 | extent necessary to allow You to exercise the Licensed
222 | Rights, but not otherwise.
223 |
224 | 2. Patent and trademark rights are not licensed under this
225 | Public License.
226 |
227 | 3. To the extent possible, the Licensor waives any right to
228 | collect royalties from You for the exercise of the Licensed
229 | Rights, whether directly or through a collecting society
230 | under any voluntary or waivable statutory or compulsory
231 | licensing scheme. In all other cases the Licensor expressly
232 | reserves any right to collect such royalties, including when
233 | the Licensed Material is used other than for NonCommercial
234 | purposes.
235 |
236 |
237 | Section 3 -- License Conditions.
238 |
239 | Your exercise of the Licensed Rights is expressly made subject to the
240 | following conditions.
241 |
242 | a. Attribution.
243 |
244 | 1. If You Share the Licensed Material (including in modified
245 | form), You must:
246 |
247 | a. retain the following if it is supplied by the Licensor
248 | with the Licensed Material:
249 |
250 | i. identification of the creator(s) of the Licensed
251 | Material and any others designated to receive
252 | attribution, in any reasonable manner requested by
253 | the Licensor (including by pseudonym if
254 | designated);
255 |
256 | ii. a copyright notice;
257 |
258 | iii. a notice that refers to this Public License;
259 |
260 | iv. a notice that refers to the disclaimer of
261 | warranties;
262 |
263 | v. a URI or hyperlink to the Licensed Material to the
264 | extent reasonably practicable;
265 |
266 | b. indicate if You modified the Licensed Material and
267 | retain an indication of any previous modifications; and
268 |
269 | c. indicate the Licensed Material is licensed under this
270 | Public License, and include the text of, or the URI or
271 | hyperlink to, this Public License.
272 |
273 | 2. You may satisfy the conditions in Section 3(a)(1) in any
274 | reasonable manner based on the medium, means, and context in
275 | which You Share the Licensed Material. For example, it may be
276 | reasonable to satisfy the conditions by providing a URI or
277 | hyperlink to a resource that includes the required
278 | information.
279 | 3. If requested by the Licensor, You must remove any of the
280 | information required by Section 3(a)(1)(A) to the extent
281 | reasonably practicable.
282 |
283 | b. ShareAlike.
284 |
285 | In addition to the conditions in Section 3(a), if You Share
286 | Adapted Material You produce, the following conditions also apply.
287 |
288 | 1. The Adapter's License You apply must be a Creative Commons
289 | license with the same License Elements, this version or
290 | later, or a BY-NC-SA Compatible License.
291 |
292 | 2. You must include the text of, or the URI or hyperlink to, the
293 | Adapter's License You apply. You may satisfy this condition
294 | in any reasonable manner based on the medium, means, and
295 | context in which You Share Adapted Material.
296 |
297 | 3. You may not offer or impose any additional or different terms
298 | or conditions on, or apply any Effective Technological
299 | Measures to, Adapted Material that restrict exercise of the
300 | rights granted under the Adapter's License You apply.
301 |
302 |
303 | Section 4 -- Sui Generis Database Rights.
304 |
305 | Where the Licensed Rights include Sui Generis Database Rights that
306 | apply to Your use of the Licensed Material:
307 |
308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309 | to extract, reuse, reproduce, and Share all or a substantial
310 | portion of the contents of the database for NonCommercial purposes
311 | only;
312 |
313 | b. if You include all or a substantial portion of the database
314 | contents in a database in which You have Sui Generis Database
315 | Rights, then the database in which You have Sui Generis Database
316 | Rights (but not its individual contents) is Adapted Material,
317 | including for purposes of Section 3(b); and
318 |
319 | c. You must comply with the conditions in Section 3(a) if You Share
320 | all or a substantial portion of the contents of the database.
321 |
322 | For the avoidance of doubt, this Section 4 supplements and does not
323 | replace Your obligations under this Public License where the Licensed
324 | Rights include other Copyright and Similar Rights.
325 |
326 |
327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328 |
329 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339 |
340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349 |
350 | c. The disclaimer of warranties and limitation of liability provided
351 | above shall be interpreted in a manner that, to the extent
352 | possible, most closely approximates an absolute disclaimer and
353 | waiver of all liability.
354 |
355 |
356 | Section 6 -- Term and Termination.
357 |
358 | a. This Public License applies for the term of the Copyright and
359 | Similar Rights licensed here. However, if You fail to comply with
360 | this Public License, then Your rights under this Public License
361 | terminate automatically.
362 |
363 | b. Where Your right to use the Licensed Material has terminated under
364 | Section 6(a), it reinstates:
365 |
366 | 1. automatically as of the date the violation is cured, provided
367 | it is cured within 30 days of Your discovery of the
368 | violation; or
369 |
370 | 2. upon express reinstatement by the Licensor.
371 |
372 | For the avoidance of doubt, this Section 6(b) does not affect any
373 | right the Licensor may have to seek remedies for Your violations
374 | of this Public License.
375 |
376 | c. For the avoidance of doubt, the Licensor may also offer the
377 | Licensed Material under separate terms or conditions or stop
378 | distributing the Licensed Material at any time; however, doing so
379 | will not terminate this Public License.
380 |
381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382 | License.
383 |
384 |
385 | Section 7 -- Other Terms and Conditions.
386 |
387 | a. The Licensor shall not be bound by any additional or different
388 | terms or conditions communicated by You unless expressly agreed.
389 |
390 | b. Any arrangements, understandings, or agreements regarding the
391 | Licensed Material not stated herein are separate from and
392 | independent of the terms and conditions of this Public License.
393 |
394 |
395 | Section 8 -- Interpretation.
396 |
397 | a. For the avoidance of doubt, this Public License does not, and
398 | shall not be interpreted to, reduce, limit, restrict, or impose
399 | conditions on any use of the Licensed Material that could lawfully
400 | be made without permission under this Public License.
401 |
402 | b. To the extent possible, if any provision of this Public License is
403 | deemed unenforceable, it shall be automatically reformed to the
404 | minimum extent necessary to make it enforceable. If the provision
405 | cannot be reformed, it shall be severed from this Public License
406 | without affecting the enforceability of the remaining terms and
407 | conditions.
408 |
409 | c. No term or condition of this Public License will be waived and no
410 | failure to comply consented to unless expressly agreed to by the
411 | Licensor.
412 |
413 | d. Nothing in this Public License constitutes or may be interpreted
414 | as a limitation upon, or waiver of, any privileges and immunities
415 | that apply to the Licensor or You, including from the legal
416 | processes of any jurisdiction or authority.
417 |
418 | =======================================================================
419 |
420 | Creative Commons is not a party to its public
421 | licenses. Notwithstanding, Creative Commons may elect to apply one of
422 | its public licenses to material it publishes and in those instances
423 | will be considered the “Licensor.” The text of the Creative Commons
424 | public licenses is dedicated to the public domain under the CC0 Public
425 | Domain Dedication. Except for the limited purpose of indicating that
426 | material is shared under a Creative Commons public license or as
427 | otherwise permitted by the Creative Commons policies published at
428 | creativecommons.org/policies, Creative Commons does not authorize the
429 | use of the trademark "Creative Commons" or any other trademark or logo
430 | of Creative Commons without its prior written consent including,
431 | without limitation, in connection with any unauthorized modifications
432 | to any of its public licenses or any other arrangements,
433 | understandings, or agreements concerning use of licensed material. For
434 | the avoidance of doubt, this paragraph does not form part of the
435 | public licenses.
436 |
437 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AnimateDiff for Stable Diffusion WebUI
2 |
3 | > I have recently added a non-commercial [license](https://creativecommons.org/licenses/by-nc-sa/4.0/) to this extension. If you want to use this extension for commercial purposes, please contact me via email.
4 |
5 | This extension aims to integrate [AnimateDiff](https://github.com/guoyww/AnimateDiff/) with [CLI](https://github.com/s9roll7/animatediff-cli-prompt-travel) into [AUTOMATIC1111 Stable Diffusion WebUI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) with [ControlNet](https://github.com/Mikubill/sd-webui-controlnet), forming the easiest-to-use AI video toolkit. You can generate GIFs in exactly the same way as you generate images after enabling this extension.
6 |
7 | This extension implements AnimateDiff in a different way. It inserts motion modules into UNet at runtime, so that you do not need to reload your model weights if you don't want to.
8 |
9 | You might also be interested in another extension I created: [Segment Anything for Stable Diffusion WebUI](https://github.com/continue-revolution/sd-webui-segment-anything), which could be quite useful for inpainting.
10 |
11 | [Forge](https://github.com/lllyasviel/stable-diffusion-webui-forge) users should either check out the [forge/master](https://github.com/continue-revolution/sd-webui-animatediff/tree/forge/master) branch of this repository or use [sd-forge-animatediff](https://github.com/continue-revolution/sd-forge-animatediff). They will be kept in sync.
12 |
13 |
14 | ## Table of Contents
15 | [Update](#update) | [Future Plan](#future-plan) | [Model Zoo](#model-zoo) | [Documentation](#documentation) | [Tutorial](#tutorial) | [Thanks](#thanks) | [Star History](#star-history) | [Sponsor](#sponsor)
16 |
17 |
18 | ## Update
19 | - [v2.0.0-a](https://github.com/continue-revolution/sd-webui-animatediff/tree/v2.0.0-a) in `03/02/2024`: The whole extension has been reworked to make it easier to maintain.
20 | - Prerequisite: WebUI >= 1.8.0 & ControlNet >=1.1.441 & PyTorch >= 2.0.0
21 | - New feature:
22 | - ControlNet inpaint / IP-Adapter prompt travel / SparseCtrl / ControlNet keyframe, see [ControlNet V2V](docs/features.md#controlnet-v2v)
23 | - FreeInit, see [FreeInit](docs/features.md#FreeInit)
24 | - Minor: mm filter based on sd version (click refresh button if you switch between SD1.5 and SDXL) / display extension version in infotext
25 | - Breaking change: You must use Motion LoRA, Hotshot-XL and the AnimateDiff V3 Motion Adapter from my [huggingface repo](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main).
26 | - [v2.0.1-a](https://github.com/continue-revolution/sd-webui-animatediff/tree/v2.0.1-a) in `07/12/2024`: Support [AnimateLCM](https://github.com/G-U-N/AnimateLCM) from MMLab@CUHK. See [here](docs/features.md#animatelcm) for instructions.
27 |
28 |
29 | ## Future Plan
30 | Although [OpenAI Sora](https://openai.com/sora) is far better at following complex text prompts and generating complex scenes, we believe that OpenAI will NOT open source Sora or any other products they have released recently. My current plan is to continue developing this extension until an open-source video model is released that can generate complex scenes, is easy to customize, and has a good ecosystem like SD1.5's.
31 |
32 | We will try our best to bring interesting research into both WebUI and Forge for as long as we can. Not all research will be implemented. You are welcome to submit a feature request if you find an interesting one. We are also open to learning from other equivalent software.
33 |
34 | That said, due to the notorious difficulty of maintaining [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet), we do NOT plan to implement ANY new research into WebUI if it touches "reference control", such as [Magic Animate](https://github.com/magic-research/magic-animate). Such features will be Forge only. Also, some advanced features in [ControlNet Forge Integrated](https://github.com/lllyasviel/stable-diffusion-webui-forge/tree/main/extensions-builtin/sd_forge_controlnet), such as ControlNet per-frame mask, will also be Forge only. I really wish I had the bandwidth to rework sd-webui-controlnet, but it would require a huge amount of time.
35 |
36 |
37 | ## Model Zoo
38 | I am maintaining a [huggingface repo](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main) to provide all official models in fp16 & safetensors format. You are highly recommended to use my links. You MUST use my links to download Motion LoRA, Hotshot-XL and the AnimateDiff V3 Motion Adapter. For all other models, you may still use the old links if you want.
39 |
40 | - "Official" models by [@guoyww](https://github.com/guoyww): [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI) | [HuggingFace](https://huggingface.co/guoyww/animatediff/tree/main) | [CivitAI](https://civitai.com/models/108836)
41 | - "Stabilized" community models by [@manshoety](https://huggingface.co/manshoety): [HuggingFace](https://huggingface.co/manshoety/AD_Stabilized_Motion/tree/main)
42 | - "TemporalDiff" models by [@CiaraRowles](https://huggingface.co/CiaraRowles): [HuggingFace](https://huggingface.co/CiaraRowles/TemporalDiff/tree/main)
43 |
44 |
45 | ## Documentation
46 | - [How to Use](docs/how-to-use.md) -> [Preparation](docs/how-to-use.md#preparation) | [WebUI](docs/how-to-use.md#webui) | [API](docs/how-to-use.md#api) | [Parameters](docs/how-to-use.md#parameters)
47 | - [Features](docs/features.md) -> [Img2Vid](docs/features.md#img2vid) | [Prompt Travel](docs/features.md#prompt-travel) | [ControlNet V2V](docs/features.md#controlnet-v2v) | [ [Model Spec](docs/features.md#model-spec) -> [Motion LoRA](docs/features.md#motion-lora) | [V3](docs/features.md#v3) | [SDXL](docs/features.md#sdxl) | [AnimateLCM](docs/features.md#animatelcm) ]
48 | - [Performance](docs/performance.md) -> [ [Optimizations](docs/performance.md#optimizations) -> [Attention](docs/performance.md#attention) | [FP8](docs/performance.md#fp8) | [LCM](docs/performance.md#lcm) ] | [VRAM](docs/performance.md#vram) | [Batch Size](docs/performance.md#batch-size)
49 | - [Demo](docs/demo.md) -> [Basic Usage](docs/demo.md#basic-usage) | [Motion LoRA](docs/demo.md#motion-lora) | [Prompt Travel](docs/demo.md#prompt-travel) | [AnimateDiff V3](docs/demo.md#animatediff-v3) | [AnimateDiff XL](docs/demo.md#animatediff-xl) | [ControlNet V2V](docs/demo.md#controlnet-v2v)
50 |
51 |
52 | ## Tutorial
53 | There are a lot of wonderful video tutorials on YouTube and Bilibili, and you should check those out for now. A series of updates is on the way, and I don't want to work on my own tutorial before I am satisfied; an official tutorial will come when I am satisfied with the available features.
54 |
55 |
56 | ## Thanks
57 | We thank all developers and community users who contribute to this repository in many ways, especially
58 | - [@guoyww](https://github.com/guoyww) for creating AnimateDiff
59 | - [@limbo0000](https://github.com/limbo0000) for responding to my questions about AnimateDiff
60 | - [@neggles](https://github.com/neggles) and [@s9roll7](https://github.com/s9roll7) for developing [AnimateDiff CLI Prompt Travel](https://github.com/s9roll7/animatediff-cli-prompt-travel)
61 | - [@zappityzap](https://github.com/zappityzap) for developing the majority of the [output features](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_output.py)
62 | - [@thiswinex](https://github.com/thiswinex) for developing FreeInit
63 | - [@lllyasviel](https://github.com/lllyasviel) for adding me as a collaborator of sd-webui-controlnet and offering technical support for Forge
64 | - [@KohakuBlueleaf](https://github.com/KohakuBlueleaf) for helping with FP8 and LCM development
65 | - [@TDS4874](https://github.com/TDS4874) and [@opparco](https://github.com/opparco) for resolving the grey issue, which significantly improves performance
66 | - [@streamline](https://twitter.com/kaizirod) for providing ControlNet V2V dataset and workflow. His workflow is extremely amazing and definitely worth checking out.
67 |
68 |
69 | ## Star History
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 | ## Sponsor
80 | You can sponsor me via WeChat, AliPay or [PayPal](https://paypal.me/conrevo). You can also support me via [ko-fi](https://ko-fi.com/conrevo) or [afdian](https://afdian.net/a/conrevo).
81 |
82 | | WeChat | AliPay | PayPal |
83 | | --- | --- | --- |
84 | |  |  |  |
85 |
--------------------------------------------------------------------------------
/docs/CHANGELOG.md:
--------------------------------------------------------------------------------
1 | This document backs up all previous v1.x updates.
2 | - `2023/07/20` [v1.1.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.1.0): Fix gif duration, add loop number, remove auto-download, remove xformers, remove instructions on gradio UI, refactor README, add [sponsor](#sponsor) QR code.
3 | - `2023/07/24` [v1.2.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.2.0): Fix incorrect insertion of motion modules, add option to change path to motion modules in `Settings/AnimateDiff`, fix loading different motion modules.
4 | - `2023/09/04` [v1.3.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.3.0): Support any community models with the same architecture; fix grey problem via [#63](https://github.com/continue-revolution/sd-webui-animatediff/issues/63)
5 | - `2023/09/11` [v1.4.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.4.0): Support official v2 motion module (different architecture: GroupNorm not hacked, UNet middle layer has motion module).
6 | - `2023/09/14`: [v1.4.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.4.1): Always change `betas`, `alphas_cumprod` and `alphas_cumprod_prev` to resolve the grey problem in other samplers.
7 | - `2023/09/16`: [v1.5.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.5.0): Randomize init latent to support [better img2gif](#img2gif); add other output formats and infotext output; add appending reversed frames; refactor code to ease maintaining.
8 | - `2023/09/19`: [v1.5.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.5.1): Support xformers, sdp, sub-quadratic attention optimization - [VRAM](#vram) usage decrease to 5.60GB with default setting.
9 | - `2023/09/22`: [v1.5.2](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.5.2): Option to disable xformers at `Settings/AnimateDiff` [due to a bug in xformers](https://github.com/facebookresearch/xformers/issues/845), [API support](#api), option to enable GIF palette optimization at `Settings/AnimateDiff`, gifsicle optimization moved to `Settings/AnimateDiff`.
10 | - `2023/09/25`: [v1.6.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.6.0): [Motion LoRA](https://github.com/guoyww/AnimateDiff#features) supported. See [Motion Lora](#motion-lora) for more information.
11 | - `2023/09/27`: [v1.7.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.7.0): [ControlNet](https://github.com/Mikubill/sd-webui-controlnet) supported. See [ControlNet V2V](#controlnet-v2v) for more information. [Safetensors](#model-zoo) for some motion modules are also available now.
12 | - `2023/09/29`: [v1.8.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.8.0): Infinite generation supported. See [WebUI Parameters](#webui-parameters) for more information.
13 | - `2023/10/01`: [v1.8.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.8.1): Now you can uncheck `Batch cond/uncond` in `Settings/Optimization` if you want. This will reduce your [VRAM](#vram) (5.31GB -> 4.21GB for SDP) but take longer time.
14 | - `2023/10/08`: [v1.9.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.0): Prompt travel supported. You must have ControlNet installed (you do not need to enable ControlNet) to try it. See [Prompt Travel](#prompt-travel) for how to trigger this feature.
15 | - `2023/10/11`: [v1.9.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.1): Use state_dict key to guess mm version, replace match case with if else to support python<3.10, option to save PNG to custom dir
16 | (see `Settings/AnimateDiff` for detail), move hints to js, install imageio\[ffmpeg\] automatically when MP4 save fails.
17 | - `2023/10/16`: [v1.9.2](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.2): Add context generator to completely remove any closed loop, prompt travel support closed loop, infotext fully supported including prompt travel, README refactor
18 | - `2023/10/19`: [v1.9.3](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.3): Support webp output format. See [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information.
19 | - `2023/10/21`: [v1.9.4](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.9.4): Save prompt travel to output images, `Reverse` merged to `Closed loop` (See [WebUI Parameters](#webui-parameters)), remove `TimestepEmbedSequential` hijack, remove `hints.js`, better explanation of several context-related parameters.
20 | - `2023/10/25`: [v1.10.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.10.0): Support img2img batch. You need ControlNet installed to make it work properly (you do not need to enable ControlNet). See [ControlNet V2V](#controlnet-v2v) for more information.
21 | - `2023/10/29`: [v1.11.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.0): [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) supported. See [SDXL](#sdxl) for more information.
22 | - `2023/11/06`: [v1.11.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.11.1): Optimize VRAM for ControlNet V2V, patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) so that the API can return a video, save frames to `AnimateDiff/yy-mm-dd/`, recover from assertion error, optional [request id](#api) for API.
23 | - `2023/11/10`: [v1.12.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.12.0): [AnimateDiff for SDXL](https://github.com/guoyww/AnimateDiff/tree/sdxl) supported. See [SDXL](#sdxl) for more information.
24 | - `2023/11/16`: [v1.12.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.12.1): FP8 precision and LCM sampler supported. See [Optimizations](#optimizations) for more information. You can also optionally upload videos to AWS S3 storage by configuring appropriately via `Settings/AnimateDiff AWS`.
25 | - `2023/12/19`: [v1.13.0](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.13.0): [AnimateDiff V3](https://github.com/guoyww/AnimateDiff?tab=readme-ov-file#202312-animatediff-v3-and-sparsectrl) supported. See [V3](#v3) for more information. Also: release all official models in fp16 & safetensors format [here](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main), add option to disable LCM sampler in `Settings/AnimateDiff`, remove patch [encode_pil_to_base64](https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/api/api.py#L104-L133) because A1111 [v1.7.0](https://github.com/AUTOMATIC1111/stable-diffusion-webui/tree/v1.7.0) now supports video return for API.
26 | - `2024/01/12`: [v1.13.1](https://github.com/continue-revolution/sd-webui-animatediff/releases/tag/v1.13.1): This small version update completely comes from the community. We fix mp4 encode error [#402](https://github.com/continue-revolution/sd-webui-animatediff/pull/402), support infotext copy-paste [#400](https://github.com/continue-revolution/sd-webui-animatediff/pull/400), validate prompt travel frame numbers [#401](https://github.com/continue-revolution/sd-webui-animatediff/pull/401).
--------------------------------------------------------------------------------
/docs/demo.md:
--------------------------------------------------------------------------------
1 | # Demo
2 |
3 | ## Basic Usage
4 | | AnimateDiff | Extension | img2img |
5 | | --- | --- | --- |
6 | |  | |  |
7 |
8 | ## Motion LoRA
9 | | No LoRA | PanDown | PanLeft |
10 | | --- | --- | --- |
11 | |  |  |  |
12 |
13 | ## Prompt Travel
14 | 
15 |
16 | The prompt is similar to the one [here](features.md#prompt-travel).
17 |
18 | ## AnimateDiff V3
19 | You should be able to read infotext to understand how I generated this sample.
20 | 
21 |
22 |
23 | ## AnimateDiff XL
24 | You should be able to read infotext to understand how I generated this sample.
25 |
26 |
27 |
28 | ## ControlNet V2V
29 | See [here](features.md#controlnet-v2v)
30 |
--------------------------------------------------------------------------------
/docs/features.md:
--------------------------------------------------------------------------------
1 | # Features
2 |
3 | ## Img2Vid
4 | > I believe that there are better ways to do i2v. New methods will be implemented soon, and this old and unstable approach may be removed.
5 |
6 | You need to go to img2img and submit an init frame via the A1111 panel. You can optionally submit a last frame via the extension panel.
7 |
8 | By default, your `init_latent` will be changed to
9 | ```
10 | init_alpha = (1 - frame_number ^ latent_power / latent_scale)
11 | init_latent = init_latent * init_alpha + random_tensor * (1 - init_alpha)
12 | ```
13 |
14 | If you upload a last frame, your `init_latent` will be changed in a similar way. Read [this code](https://github.com/continue-revolution/sd-webui-animatediff/tree/v1.5.0/scripts/animatediff_latent.py#L28-L65) to understand how it works.
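
To make the default blend concrete, here is a small NumPy sketch of the formula above. It is illustrative only: the parameter names mirror the formula, clamping negative alphas to zero is an assumption of this sketch, and the real implementation lives in `scripts/animatediff_latent.py`.

```python
import numpy as np

def blend_init_latent(init_latent, video_length, latent_power=1.0, latent_scale=32.0, seed=0):
    """Blend one init latent with random noise, frame by frame (sketch only)."""
    rng = np.random.default_rng(seed)
    frames = []
    for frame_number in range(video_length):
        # init_alpha = 1 - frame_number ^ latent_power / latent_scale, clamped at 0 here
        init_alpha = max(1.0 - frame_number ** latent_power / latent_scale, 0.0)
        random_tensor = rng.standard_normal(init_latent.shape)
        frames.append(init_latent * init_alpha + random_tensor * (1.0 - init_alpha))
    return np.stack(frames)  # shape: (video_length, *init_latent.shape)
```

Later frames get a smaller `init_alpha`, so they drift further from the uploaded init frame and closer to pure noise.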
15 |
16 |
17 | ## Prompt Travel
18 |
19 | Write your positive prompt following the example below.
20 |
21 | The first line is the head prompt, which is optional. You can write no, single, or multiple lines of head prompts.
22 |
23 | All following lines in the format `frame number`: `prompt` are for prompt interpolation. Your `frame number` values should be in ascending order and smaller than the total `Number of frames`. The first frame has index 0.
24 |
25 | The last line is the tail prompt, which is optional. You can write no, single, or multiple lines of tail prompts. If you don't need this feature, just write prompts in the old way.
26 | ```
27 | 1girl, yoimiya (genshin impact), origen, line, comet, wink, Masterpiece, BestQuality. UltraDetailed, , ,
28 | 0: closed mouth
29 | 8: open mouth
30 | smile
31 | ```
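
To illustrate how this text format decomposes, here is a rough parsing sketch. It only mirrors the format described above; the extension has its own parser (see `scripts/animatediff_prompt.py`), and the helper name below is made up for this example.

```python
import re

def parse_prompt_travel(text: str):
    """Split a prompt-travel prompt into head prompt, (frame, prompt) keyframes, and tail prompt."""
    head, keyframes, tail = [], [], []
    for line in text.strip().splitlines():
        match = re.match(r"^\s*(\d+)\s*:\s*(.*)$", line)
        if match:
            keyframes.append((int(match.group(1)), match.group(2).strip()))
        elif not keyframes:
            head.append(line.strip())  # lines before the first "frame: prompt" line
        else:
            tail.append(line.strip())  # lines after the last "frame: prompt" line
    return "\n".join(head), keyframes, "\n".join(tail)

head, keyframes, tail = parse_prompt_travel("1girl, smiling\n0: closed mouth\n8: open mouth\nwalking")
# head      -> "1girl, smiling"
# keyframes -> [(0, 'closed mouth'), (8, 'open mouth')]
# tail      -> "walking"
```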
32 |
33 | ## FreeInit
34 |
35 | FreeInit lets you spend more computation time to get more coherent and consistent video frames.
36 |
37 | The default parameters provide satisfactory results for most use cases. Increasing the number of iterations can yield better outcomes, but it also prolongs the processing time. If your video contains more intense or rapid motions, consider switching the filter to Gaussian. For a detailed explanation of each parameter, please refer to the documentation in the [original repository](https://github.com/TianxingWu/FreeInit).
38 |
39 | | without FreeInit | with FreeInit (default params) |
40 | | --- | --- |
41 | |  |  |
42 |
43 |
44 | ## ControlNet V2V
45 | You need to go to txt2img / img2img-batch and submit a source video or a path to frames. Each ControlNet will find control images according to this priority:
46 | 1. ControlNet `Single Image` tab or `Batch Folder` tab. Simply uploading a control image or a path to a folder of control frames is enough.
47 | 1. Img2img Batch tab `Input directory` if you are using img2img batch. If you upload a directory of control frames, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel.
48 | 1. AnimateDiff `Video Path`. If you upload a path to frames through `Video Path`, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel.
49 | 1. AnimateDiff `Video Source`. If you upload a video through `Video Source`, it will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet panel.
50 |
51 | `Number of frames` will be capped to the minimum number of images among all **folders** you provide, unless a folder has a "keyframe" parameter.
52 |
53 | **SparseCtrl**: Sparse ControlNet is for video generation with key frames. If you upload one image in the "single image" tab, it will make the following frames follow your first frame (a **probably** better way to do img2vid). If you upload a path in the "batch" tab, with a "keyframe" parameter on a new line (see below), it will attempt to do video frame interpolation. Note that I don't think this ControlNet has performance comparable to those trained by [@lllyasviel](https://github.com/lllyasviel). Use at your own risk.
54 |
55 | Example input parameter fill-ins:
56 | 1. Fill in separate control inputs for different ControlNet units.
57 | 1. Control all frames with a single control input. Exception: SparseCtrl will only control the first frame in this way.
58 | | IP-Adapter | Output |
59 | | --- | --- |
60 | |  |  |
61 | 1. Control each frame with a separate control input. You are encouraged to try multi-ControlNet.
62 | | Canny | Output |
63 | | --- | --- |
64 | |  |  |
65 | 1. ControlNet inpaint unit: You are encouraged to use my [Segment Anything](https://github.com/continue-revolution/sd-webui-segment-anything) extension to automatically draw a mask / generate masks in batch.
66 | - specify a global image and draw a mask on it, or upload a mask. The white region is where changes will be applied.
67 | - use the "mask" parameter for ControlNet inpaint in batch: type "ctrl + enter" to start a new line and fill in the "mask" parameter in the format `mask:/path/to/mask/frames/`.
68 |
69 | | single image | batch |
70 | | --- | --- |
71 | |  |  |
72 | 1. "keyframe" parameter.
73 | - **IP-Adapter**: this parameter means "IP-Adapter prompt travel". See the image below for an explanation.
74 | 
75 | You will see a terminal log like
76 | ```bash
77 | ControlNet - INFO - AnimateDiff + ControlNet ip-adapter_clip_sd15 receive the following parameters:
78 | ControlNet - INFO - batch control images: /home/conrevo/SD/dataset/upperbodydataset/mask/key-ipadapter/
79 | ControlNet - INFO - batch control keyframe index: [0, 6, 12, 18]
80 | ```
81 | ```bash
82 | ControlNet - INFO - IP-Adapter: control prompts will be traveled in the following way:
83 | ControlNet - INFO - 0: /home/conrevo/SD/dataset/upperbodydataset/mask/key-ipadapter/anime_girl_head_1.png
84 | ControlNet - INFO - 6: /home/conrevo/SD/dataset/upperbodydataset/mask/key-ipadapter/anime_girl_head_2.png
85 | ControlNet - INFO - 12: /home/conrevo/SD/dataset/upperbodydataset/mask/key-ipadapter/anime_girl_head_3.png
86 | ControlNet - INFO - 18: /home/conrevo/SD/dataset/upperbodydataset/mask/key-ipadapter/anime_girl_head_4.png
87 | ```
88 | - **SparseCtrl**: this parameter means keyframe. SparseCtrl has its own special processing for keyframe logic. Specify this parameter in the same way as for IP-Adapter above.
89 | - All other ControlNets: we insert a blank control image for you, and the control latent for that frame will be purely zero. Specify this parameter in the same way as for IP-Adapter above.
90 | 1. Specify a global `Video path` and `Mask path` and leave the ControlNet Unit `Input Directory` blank.
91 | - You can arbitrarily change the ControlNet Unit tab to `Single Image` / `Batch Folder` / `Batch Upload` as long as you leave it blank.
92 | - If you specify a global mask path, all ControlNet Units to which you do not give a `Mask Directory` will use this path.
93 | - Please provide only one of `Video source` and `Video path`. They cannot be applied at the same time.
94 | 
95 | 1. img2img batch. See the screenshot below.
96 |
97 | There are a lot of amazing demos online. Here I provide a very simple one. The dataset is from [streamline](https://twitter.com/kaizirod), but the workflow is an arbitrary setup by me. You can find many more amazing examples (and potentially available workflows / infotexts) on Reddit, Twitter, YouTube and Bilibili. The easiest way to share a workflow created with my software is to share one output frame with infotext.
98 | | input | output |
99 | | --- | --- |
100 | | | |
101 |
102 |
103 | ## Model Spec
104 | > BREAKING CHANGE: You need to use Motion LoRA, HotShot-XL and the AnimateDiff V3 Motion Adapter from [my HuggingFace repository](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main/lora) instead of the original ones.
105 |
106 | ### Motion LoRA
107 | [Download](https://huggingface.co/conrevo/AnimateDiff-A1111/tree/main/lora) and use them like any other LoRA you use (example: download a Motion LoRA to `stable-diffusion-webui/models/Lora` and add it to your positive prompt with the usual `<lora:name:weight>` syntax). **Motion LoRAs can only be applied to the V2 motion module**.
108 |
109 | ### V3
110 | AnimateDiff V3 has state dict keys identical to V1's but slightly different inference logic (GroupNorm is not hacked for V3). You may optionally use the [adapter](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) for V3, in the same way you apply a LoRA. You MUST use [my link](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/lora/mm_sd15_v3_adapter.safetensors?download=true) instead of the [official link](https://huggingface.co/guoyww/animatediff/resolve/main/v3_sd15_adapter.ckpt?download=true). The official adapter won't work for A1111 due to state dict incompatibility.
111 |
112 | ### AnimateLCM
113 | - You can download the motion module from [here](https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/motion_module/mm_sd15_AnimateLCM.safetensors?download=true). The [original weights](https://huggingface.co/wangfuyun/AnimateLCM/resolve/main/AnimateLCM_sd15_t2v.ckpt?download=true) should also work, but I recommend using my safetensors fp16 version.
114 | - You should also download Motion LoRA from [here](https://huggingface.co/wangfuyun/AnimateLCM/resolve/main/AnimateLCM_sd15_t2v_lora.safetensors?download=true) and use it like any LoRA.
115 | - You should use LCM sampler and a low CFG scale (typically 1-2).
116 |
117 | ### SDXL
118 | [AnimateDiff-XL](https://github.com/guoyww/AnimateDiff/tree/sdxl) and [HotShot-XL](https://github.com/hotshotco/Hotshot-XL) have an architecture identical to AnimateDiff-SD1.5. The only differences are
119 | - HotShot-XL is trained with 8 frames instead of 16 frames. You are recommended to set `Context batch size` to 8 for HotShot-XL.
120 | - AnimateDiff-XL is still trained with 16 frames. You do not need to change `Context batch size` for AnimateDiff-XL.
121 | - AnimateDiff-XL & HotShot-XL have fewer layers compared to AnimateDiff-SD1.5 because of SDXL.
122 | - AnimateDiff-XL is trained with higher resolution compared to HotShot-XL.
123 |
124 | Although AnimateDiff-XL & HotShot-XL have an identical structure to AnimateDiff-SD1.5, I strongly discourage you from using AnimateDiff-SD1.5 for SDXL, or using HotShot-XL / AnimateDiff-XL for SD1.5 - you will get severe artifacts if you do that. I have decided not to support that, even though it would not be hard for me to do.
125 |
126 | Technically all features available for AnimateDiff + SD1.5 are also available for (AnimateDiff / HotShot) + SDXL. However, I have not tested all of them. I have tested infinite context generation and prompt travel; I have not tested ControlNet. If you find any bug, please report it to me.
127 |
128 | Unfortunately, neither of these 2 motion modules is as good as those for SD1.5, and there is NOTHING I can do about it (they are just poorly trained). Also, there seem to be no ControlNets comparable to those [@lllyasviel](https://github.com/lllyasviel) trained for SD1.5. I strongly discourage anyone from using SDXL for video generation. You will be VERY disappointed if you do.
129 |
--------------------------------------------------------------------------------
/docs/how-to-use.md:
--------------------------------------------------------------------------------
1 | # How to Use
2 |
3 | ## Preparation
4 | 1. Update WebUI to 1.8.0 and ControlNet to v1.1.441, then install this extension via URL. I do not plan to support older versions.
5 | 1. Download motion modules and put the model weights under `stable-diffusion-webui/extensions/sd-webui-animatediff/model/` (see the download sketch after this list). If you want to use another directory to save model weights, please go to `Settings/AnimateDiff`. See the [model zoo](../README.md#model-zoo) for a list of available motion modules.
6 | 1. Enable `Pad prompt/negative prompt to be same length` in `Settings/Optimization` and click `Apply settings`. You must do this to prevent generating two separate, unrelated GIFs. Checking `Batch cond/uncond` is optional; it can improve speed but increases VRAM usage.
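
As one way to handle step 2, here is a minimal download sketch in Python. It uses the AnimateLCM motion module purely as an example (the URL is the one given in [features.md](features.md#animatelcm)); swap in whichever motion module you actually want, and note that the script assumes you run it from the directory that contains `stable-diffusion-webui`.

```python
import urllib.request
from pathlib import Path

# Example: fetch the AnimateLCM motion module into the extension's model folder.
url = ("https://huggingface.co/conrevo/AnimateDiff-A1111/resolve/main/"
       "motion_module/mm_sd15_AnimateLCM.safetensors?download=true")
model_dir = Path("stable-diffusion-webui/extensions/sd-webui-animatediff/model")
model_dir.mkdir(parents=True, exist_ok=True)
urllib.request.urlretrieve(url, str(model_dir / "mm_sd15_AnimateLCM.safetensors"))
```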
7 |
8 | ## WebUI
9 | 1. Go to txt2img if you want to try txt2vid and img2img if you want to try img2vid.
10 | 1. Choose an SD checkpoint, write prompts, set configurations such as image width/height. If you want to generate multiple GIFs at once, please [change batch number, instead of batch size](performance.md#batch-size).
11 | 1. Enable AnimateDiff extension, set up [each parameter](#parameters), then click `Generate`.
12 | 1. You should see the output GIF in the output gallery. You can access the GIF output and image frames at `stable-diffusion-webui/outputs/{txt2img or img2img}-images/AnimateDiff/{yy-mm-dd}`. You may choose to save frames for each generation into the original txt2img / img2img output directory by unchecking a checkbox inside `Settings/AnimateDiff`.
13 |
14 | ## API
15 | Using the API is quite similar to using ControlNet. The API will return a video in base64 format. In `format`, `PNG` means saving frames to your file system without returning them. If you want the API to return all frames, please add `Frame` to the `format` list. For the most up-to-date parameters, please read [here](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/scripts/animatediff_ui.py#L26).
16 | ```
17 | 'alwayson_scripts': {
18 | 'AnimateDiff': {
19 | 'args': [{
20 | 'model': 'mm_sd_v15_v2.ckpt', # Motion module
21 | 'format': ['GIF'], # Save format, 'GIF' | 'MP4' | 'PNG' | 'WEBP' | 'WEBM' | 'TXT' | 'Frame'
22 | 'enable': True, # Enable AnimateDiff
23 | 'video_length': 16, # Number of frames
24 | 'fps': 8, # FPS
25 | 'loop_number': 0, # Display loop number
26 | 'closed_loop': 'R+P', # Closed loop, 'N' | 'R-P' | 'R+P' | 'A'
27 | 'batch_size': 16, # Context batch size
28 | 'stride': 1, # Stride
29 | 'overlap': -1, # Overlap
30 | 'interp': 'Off', # Frame interpolation, 'Off' | 'FILM'
31 | 'interp_x': 10, # Interp X
32 | 'video_source': 'path/to/video.mp4', # Video source
33 | 'video_path': 'path/to/frames', # Video path
34 | 'mask_path': 'path/to/frame_masks', # Mask path
35 | 'latent_power': 1, # Latent power
36 | 'latent_scale': 32, # Latent scale
37 | 'last_frame': None, # Optional last frame
38 | 'latent_power_last': 1, # Optional latent power for last frame
39 | 'latent_scale_last': 32,# Optional latent scale for last frame
40 | 'request_id': '' # Optional request id. If provided, outputs will have request id as filename suffix
41 | }
42 | ]
43 | }
44 | },
45 | ```
46 |
47 | If you wish to specify different conditional hints for different ControlNet units, the only additional thing you need to do is to specify the `batch_images` parameter in your ControlNet JSON API parameters. The expected input format is exactly the same as [how to use ControlNet in the WebUI](features.md#controlnet-v2v).
48 |
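For reference, below is a minimal sketch (not part of this extension) of a complete API call wrapping the payload above. It assumes a locally running WebUI launched with `--api`, the standard `/sdapi/v1/txt2img` endpoint, and placeholder prompt / output filenames; adjust them to your setup.

```python
import base64
import requests

payload = {
    "prompt": "1girl, walking on the beach",   # placeholder prompt
    "steps": 20,
    "width": 512,
    "height": 512,
    "alwayson_scripts": {
        "AnimateDiff": {
            "args": [{
                "model": "mm_sd_v15_v2.ckpt",  # motion module
                "enable": True,
                "video_length": 16,
                "fps": 8,
                "format": ["GIF"],             # returned base64-encoded
            }]
        }
    },
}

# Send the request and decode every base64 payload in the response.
response = requests.post("http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload)
for i, encoded in enumerate(response.json()["images"]):
    with open(f"animatediff_output_{i}.gif", "wb") as f:
        f.write(base64.b64decode(encoded))
```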
49 |
50 | ## Parameters
51 | 1. **Save format** — Format of the output. Choose at least one of "GIF"|"MP4"|"WEBP"|"WEBM"|"PNG". Check "TXT" if you want infotext, which will live in the same directory as the output GIF. Infotext is also accessible via `stable-diffusion-webui/params.txt` and is embedded in outputs of all formats.
52 | 1. You can optimize GIF with `gifsicle` (`apt install gifsicle` required, read [#91](https://github.com/continue-revolution/sd-webui-animatediff/pull/91) for more information) and/or `palette` (read [#104](https://github.com/continue-revolution/sd-webui-animatediff/pull/104) for more information). Go to `Settings/AnimateDiff` to enable them.
53 | 1. You can set quality and lossless for WEBP via `Settings/AnimateDiff`. Read [#233](https://github.com/continue-revolution/sd-webui-animatediff/pull/233) for more information.
54 | 1. If you are using the API, adding "PNG" to `format` lets you save all frames to your file system without returning them. If you want the API to return all frames, please add `Frame` to the `format` list.
55 | 1. **Number of frames** — Choose whatever number you like.
56 |
57 | If you enter 0 (default):
58 | - If you submit a video via `Video source` / enter a video path via `Video path` / enable ANY batch ControlNet, the number of frames will be the number of frames in the video (the shortest video is used if more than one is submitted).
59 | - Otherwise, the number of frames will be your `Context batch size` described below.
60 |
61 | If you enter a non-zero number smaller than your `Context batch size`: you will get the first `Number of frames` frames as your output GIF from your whole generation. All following frames will not appear in the GIF, but will be saved as PNGs as usual. Do not set `Number of frames` to a non-zero value smaller than `Context batch size` because of [#213](https://github.com/continue-revolution/sd-webui-animatediff/issues/213).
62 | 1. **FPS** — Frames per second, which is how many frames (images) are shown every second. If 16 frames are generated at 8 frames per second, your GIF’s duration is 2 seconds. If you submit a source video, your FPS will be the same as the source video.
63 | 1. **Display loop number** — How many times the GIF is played. A value of `0` means the GIF never stops playing.
64 | 1. **Context batch size** — How many frames will be passed into the motion module at once. The SD1.5 motion modules are trained with 16 frames, so they give the best results when this is set to `16`. HotShot-XL (SDXL) motion modules are trained with 8 frames instead. Choose a value in [1, 24] for V1 / HotShot-XL motion modules and in [1, 32] for V2 / AnimateDiff-XL motion modules.
65 | 1. **Closed loop** — Closed loop means that this extension will try to make the last frame the same as the first frame.
66 | 1. When `Number of frames` > `Context batch size` (including when ControlNet is enabled, the source video frame number > `Context batch size` and `Number of frames` is 0), closed loop will be performed by the AnimateDiff infinite context generator.
67 | 1. When `Number of frames` <= `Context batch size`, the AnimateDiff infinite context generator will not be effective. Only when you choose `A` will AnimateDiff append a reversed list of frames to the original list of frames to form a closed loop.
68 |
69 | See below for an explanation of each choice:
70 |
71 | - `N` means absolutely no closed loop - this is the only available option if `Number of frames` is a non-zero value smaller than `Context batch size`.
72 | - `R-P` means that the extension will try to reduce the number of closed-loop contexts. The prompt travel will not be interpolated into a closed loop.
73 | - `R+P` means that the extension will try to reduce the number of closed-loop contexts. The prompt travel will be interpolated into a closed loop.
74 | - `A` means that the extension will aggressively try to make the last frame the same as the first frame. The prompt travel will be interpolated into a closed loop.
75 | 1. **Stride** — Max motion stride as a power of 2 (default: 1).
76 | 1. Due to the limitation of the infinite context generator, this parameter is effective only when `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0.
77 | 1. "Absolutely no closed loop" is only possible when `Stride` is 1.
78 | 1. For each 1 <= $2^i$ <= `Stride`, the infinite context generator will try to make frames $2^i$ apart temporally consistent. For example, if `Stride` is 4 and `Number of frames` is 8, it will make the following groups of frames temporally consistent (see the sketch at the end of this section for an illustration):
79 | - `Stride` == 1: [0, 1, 2, 3, 4, 5, 6, 7]
80 | - `Stride` == 2: [0, 2, 4, 6], [1, 3, 5, 7]
81 | - `Stride` == 4: [0, 4], [1, 5], [2, 6], [3, 7]
82 | 1. **Overlap** — Number of frames to overlap between contexts. If overlap is -1 (default), the overlap will be `Context batch size` // 4.
83 | 1. Due to the limitation of the infinite context generator, this parameter is effective only when `Number of frames` > `Context batch size`, including when ControlNet is enabled and the source video frame number > `Context batch size` and `Number of frames` is 0.
84 | 1. **Frame Interpolation** — Interpolate between frames with Deforum's FILM implementation. Requires Deforum extension. [#128](https://github.com/continue-revolution/sd-webui-animatediff/pull/128)
85 | 1. **Interp X** — Replace each input frame with X interpolated output frames. [#128](https://github.com/continue-revolution/sd-webui-animatediff/pull/128).
86 | 1. **Video source** — [Optional] Video source file for [ControlNet V2V](features.md#controlnet-v2v). You MUST enable ControlNet. It will be the source control for ALL ControlNet units that you enable without submitting a control image to the `Single Image` tab or a path to the `Batch Folder` tab in the ControlNet panel. You can of course submit one control image via the `Single Image` tab or an input directory via the `Batch Folder` tab, which will override this video source input and work as usual.
87 | 1. **Video path** — [Optional] Folder of source frames for [ControlNet V2V](features.md#controlnet-v2v), with higher priority than `Video source`. You MUST enable ControlNet. It will be the source control for ALL ControlNet units that you enable without submitting a control image or a path to ControlNet. You can of course submit one control image via the `Single Image` tab or an input directory via the `Batch Folder` tab, which will override this video path input and work as usual.
88 | 1. **FreeInit** — [Optional] Use FreeInit to improve the temporal consistency of your videos.
89 | 1. The default parameters provide satisfactory results for most use cases.
90 | 1. Use the "Gaussian" filter when your motion is intense.
91 | 1. See the [original FreeInit repo](https://github.com/TianxingWu/FreeInit) for more parameter settings.
92 |
93 | See [ControlNet V2V](features.md#controlnet-v2v) for an example parameter fill-in and more explanation.
94 |
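The sketch referenced under the `Stride` parameter: a simplified, illustrative grouping of frame indices per stride level. This is not the extension's actual scheduler (which lives in `scripts/animatediff_infv2v.py` and also handles `Overlap`, closed loop and context batching); it only reproduces the grouping example listed above, using the hypothetical helper name `stride_groups`.

```python
def stride_groups(num_frames: int, stride: int = 1) -> list[list[int]]:
    """For every power-of-two step <= stride, group frame indices that are
    exactly that many frames apart (illustration only)."""
    groups = []
    step = 1
    while step <= stride:
        for offset in range(step):
            groups.append(list(range(offset, num_frames, step)))
        step *= 2
    return groups

# `Stride` = 4, `Number of frames` = 8 reproduces the lists shown above:
# [0..7], then [0, 2, 4, 6], [1, 3, 5, 7], then [0, 4], [1, 5], [2, 6], [3, 7]
for group in stride_groups(8, stride=4):
    print(group)
```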
--------------------------------------------------------------------------------
/docs/performance.md:
--------------------------------------------------------------------------------
1 | # Performance
2 |
3 | ## Optimizations
4 |
5 | Optimizations can be significantly helpful if you want to improve speed and reduce VRAM usage.
6 |
7 | ### Attention
8 | We will always apply scaled dot product attention from PyTorch.
9 |
10 | ### FP8
11 | FP8 requires torch >= 2.1.0. Go to `Settings/Optimizations` and select `Enable` for `FP8 weight`. Don't forget to click the `Apply settings` button.
12 |
13 | ### LCM
14 | [Latent Consistency Model](https://github.com/luosiallen/latent-consistency-model) is a recent breakthrough in the Stable Diffusion community. You can generate images / videos within 6-8 steps if you:
15 | - select the `LCM` / `Euler A` / `Euler` / `DDIM` sampler
16 | - apply [LCM LoRA](https://civitai.com/models/195519/lcm-lora-weights-stable-diffusion-acceleration-module)
17 | - use a low CFG scale (1-2 is recommended)
18 |
19 | I have [submitted a PR](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14583) for this sampler to Stable Diffusion WebUI, so you no longer need this extension to have the LCM sampler. I have removed the LCM sampler from this repository.
20 |
21 |
22 | ## VRAM
23 | Actual VRAM usage depends on your image size and context batch size. You can try to reduce the image size to reduce VRAM usage. You are discouraged from changing the context batch size, because this conflicts with the training specification.
24 |
25 | The following data are SD1.5 + AnimateDiff, tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=W=512, frame=16 (default setting). `w/`/`w/o` means `Batch cond/uncond` in `Settings/Optimization` is checked/unchecked.
26 | | Optimization | VRAM w/ | VRAM w/o |
27 | | --- | --- | --- |
28 | | No optimization | 12.13GB | |
29 | | xformers/sdp | 5.60GB | 4.21GB |
30 | | sub-quadratic | 10.39GB | |
31 |
32 | For SDXL + HotShot + SDP, tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=W=512, frame=8 (default setting), you need 8.66GB VRAM.
33 |
34 | For SDXL + AnimateDiff + SDP, tested on Ubuntu 22.04, NVIDIA 4090, torch 2.0.1+cu117, H=1024, W=768, frame=16, you need 13.87GB VRAM.
35 |
36 |
37 | ## Batch Size
38 | Batch size on WebUI will be replaced by the GIF frame number internally: 1 full GIF is generated in 1 batch. If you want to generate multiple GIFs at once, please change the batch number.
39 |
40 | Batch number is NOT the same as batch size. In the A1111 WebUI, batch number sits above batch size. Batch number means the number of sequential runs, while batch size means the number of parallel samples. You do not have to worry too much when you increase the batch number, but you do need to worry about your VRAM when you increase the batch size (which, in this extension, is the video frame number). You do not need to change the batch size at all when using this extension.
41 |
42 | We might develop an approach to support batch size on WebUI, but this has very low priority and we cannot commit to a specific date.
43 |
--------------------------------------------------------------------------------
/model/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/continue-revolution/sd-webui-animatediff/a88e88912bcbae0531caccfc50fd639f6ea83fd0/model/.gitkeep
--------------------------------------------------------------------------------
/motion_module.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from typing import Optional
3 |
4 | import math
5 | import torch
6 | from torch import nn
7 | from einops import rearrange
8 |
9 | import torch.nn as disable_weight_init  # plain torch.nn acts as the default `operations` namespace
10 | from ldm.modules.attention import FeedForward
11 |
12 |
13 | class MotionModuleType(Enum):
14 | AnimateDiffV1 = "AnimateDiff V1, Yuwei Guo, Shanghai AI Lab"
15 | AnimateDiffV2 = "AnimateDiff V2, Yuwei Guo, Shanghai AI Lab"
16 | AnimateDiffV3 = "AnimateDiff V3, Yuwei Guo, Shanghai AI Lab"
17 | AnimateDiffXL = "AnimateDiff SDXL, Yuwei Guo, Shanghai AI Lab"
18 | AnimateLCM = "AnimateLCM, Fu-Yun Wang, MMLab@CUHK"
19 | SparseCtrl = "SparseCtrl, Yuwei Guo, Shanghai AI Lab"
20 | HotShotXL = "HotShot-XL, John Mullan, Natural Synthetics Inc"
21 |
22 |
23 | @staticmethod
24 | def get_mm_type(state_dict: dict[str, torch.Tensor]):
25 | keys = list(state_dict.keys())
26 | if any(["mid_block" in k for k in keys]):
27 | if not any(["pe" in k for k in keys]):
28 | return MotionModuleType.AnimateLCM
29 | return MotionModuleType.AnimateDiffV2
30 | elif any(["down_blocks.3" in k for k in keys]):
31 | if 32 in next((state_dict[key] for key in state_dict if 'pe' in key), None).shape:
32 | return MotionModuleType.AnimateDiffV3
33 | else:
34 | return MotionModuleType.AnimateDiffV1
35 | else:
36 | if 32 in next((state_dict[key] for key in state_dict if 'pe' in key), None).shape:
37 | return MotionModuleType.AnimateDiffXL
38 | else:
39 | return MotionModuleType.HotShotXL
40 |
41 |
42 | def zero_module(module):
43 | # Zero out the parameters of a module and return it.
44 | for p in module.parameters():
45 | p.detach().zero_()
46 | return module
47 |
48 |
49 | class MotionWrapper(nn.Module):
50 | def __init__(self, mm_name: str, mm_hash: str, mm_type: MotionModuleType, operations = disable_weight_init):
51 | super().__init__()
52 | self.mm_name = mm_name
53 | self.mm_type = mm_type
54 | self.mm_hash = mm_hash
55 | max_len = 64 if mm_type == MotionModuleType.AnimateLCM else (24 if self.enable_gn_hack() else 32)
56 | in_channels = (320, 640, 1280) if self.is_xl else (320, 640, 1280, 1280)
57 | self.down_blocks = nn.ModuleList([])
58 | self.up_blocks = nn.ModuleList([])
59 | for c in in_channels:
60 | if mm_type in [MotionModuleType.SparseCtrl]:
61 | self.down_blocks.append(MotionModule(c, num_mm=2, max_len=max_len, attention_block_types=("Temporal_Self", ), operations=operations))
62 | else:
63 | self.down_blocks.append(MotionModule(c, num_mm=2, max_len=max_len, operations=operations))
64 | self.up_blocks.insert(0,MotionModule(c, num_mm=3, max_len=max_len, operations=operations))
65 | if self.is_v2:
66 | self.mid_block = MotionModule(1280, num_mm=1, max_len=max_len, operations=operations)
67 |
68 |
69 | def enable_gn_hack(self):
70 | return self.mm_type in [MotionModuleType.AnimateDiffV1, MotionModuleType.HotShotXL]
71 |
72 |
73 | @property
74 | def is_xl(self):
75 | return self.mm_type in [MotionModuleType.AnimateDiffXL, MotionModuleType.HotShotXL]
76 |
77 |
78 | @property
79 | def is_adxl(self):
80 | return self.mm_type == MotionModuleType.AnimateDiffXL
81 |
82 | @property
83 | def is_hotshot(self):
84 | return self.mm_type == MotionModuleType.HotShotXL
85 |
86 |
87 | @property
88 | def is_v2(self):
89 | return self.mm_type in [MotionModuleType.AnimateDiffV2, MotionModuleType.AnimateLCM]
90 |
91 |
92 | class MotionModule(nn.Module):
93 | def __init__(self, in_channels, num_mm, max_len, attention_block_types=("Temporal_Self", "Temporal_Self"), operations = disable_weight_init):
94 | super().__init__()
95 | self.motion_modules = nn.ModuleList([
96 | VanillaTemporalModule(
97 | in_channels=in_channels,
98 | temporal_position_encoding_max_len=max_len,
99 | attention_block_types=attention_block_types,
100 | operations=operations,)
101 | for _ in range(num_mm)])
102 |
103 |
104 | def forward(self, x: torch.Tensor):
105 | for mm in self.motion_modules:
106 | x = mm(x)
107 | return x
108 |
109 |
110 | class VanillaTemporalModule(nn.Module):
111 | def __init__(
112 | self,
113 | in_channels,
114 | num_attention_heads = 8,
115 | num_transformer_block = 1,
116 | attention_block_types =( "Temporal_Self", "Temporal_Self" ),
117 | temporal_position_encoding_max_len = 24,
118 | temporal_attention_dim_div = 1,
119 | zero_initialize = True,
120 | operations = disable_weight_init,
121 | ):
122 | super().__init__()
123 |
124 | self.temporal_transformer = TemporalTransformer3DModel(
125 | in_channels=in_channels,
126 | num_attention_heads=num_attention_heads,
127 | attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
128 | num_layers=num_transformer_block,
129 | attention_block_types=attention_block_types,
130 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
131 | operations=operations,
132 | )
133 |
134 | if zero_initialize:
135 | self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
136 |
137 |
138 | def forward(self, x: torch.Tensor):
139 | return self.temporal_transformer(x)
140 |
141 |
142 | class TemporalTransformer3DModel(nn.Module):
143 | def __init__(
144 | self,
145 | in_channels,
146 | num_attention_heads,
147 | attention_head_dim,
148 |
149 | num_layers,
150 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
151 | dropout = 0.0,
152 | norm_num_groups = 32,
153 | activation_fn = "geglu",
154 | attention_bias = False,
155 | upcast_attention = False,
156 |
157 | temporal_position_encoding_max_len = 24,
158 |
159 | operations = disable_weight_init,
160 | ):
161 | super().__init__()
162 |
163 | inner_dim = num_attention_heads * attention_head_dim
164 |
165 | self.norm = operations.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
166 | self.proj_in = operations.Linear(in_channels, inner_dim)
167 |
168 | self.transformer_blocks = nn.ModuleList(
169 | [
170 | TemporalTransformerBlock(
171 | dim=inner_dim,
172 | num_attention_heads=num_attention_heads,
173 | attention_head_dim=attention_head_dim,
174 | attention_block_types=attention_block_types,
175 | dropout=dropout,
176 | activation_fn=activation_fn,
177 | attention_bias=attention_bias,
178 | upcast_attention=upcast_attention,
179 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
180 | operations=operations,
181 | )
182 | for _ in range(num_layers)
183 | ]
184 | )
185 | self.proj_out = operations.Linear(inner_dim, in_channels)
186 |
187 | def forward(self, hidden_states: torch.Tensor):
188 | _, _, height, _ = hidden_states.shape
189 | residual = hidden_states
190 |
191 | hidden_states = self.norm(hidden_states).type(hidden_states.dtype)
192 | hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
193 | hidden_states = self.proj_in(hidden_states)
194 |
195 | # Transformer Blocks
196 | for block in self.transformer_blocks:
197 | hidden_states = block(hidden_states)
198 |
199 | # output
200 | hidden_states = self.proj_out(hidden_states)
201 | hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height)
202 |
203 | output = hidden_states + residual
204 | return output
205 |
206 |
207 | class TemporalTransformerBlock(nn.Module):
208 | def __init__(
209 | self,
210 | dim,
211 | num_attention_heads,
212 | attention_head_dim,
213 | attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
214 | dropout = 0.0,
215 | activation_fn = "geglu",
216 | attention_bias = False,
217 | upcast_attention = False,
218 | temporal_position_encoding_max_len = 24,
219 | operations = disable_weight_init,
220 | ):
221 | super().__init__()
222 |
223 | attention_blocks = []
224 | norms = []
225 |
226 | for _ in attention_block_types:
227 | attention_blocks.append(
228 | VersatileAttention(
229 | query_dim=dim,
230 | heads=num_attention_heads,
231 | dim_head=attention_head_dim,
232 | dropout=dropout,
233 | bias=attention_bias,
234 | upcast_attention=upcast_attention,
235 | temporal_position_encoding_max_len=temporal_position_encoding_max_len,
236 | operations=operations,
237 | )
238 | )
239 | norms.append(operations.LayerNorm(dim))
240 |
241 | self.attention_blocks = nn.ModuleList(attention_blocks)
242 | self.norms = nn.ModuleList(norms)
243 |
244 | self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn=='geglu'))
245 | self.ff_norm = operations.LayerNorm(dim)
246 |
247 |
248 | def forward(self, hidden_states: torch.Tensor):
249 | for attention_block, norm in zip(self.attention_blocks, self.norms):
250 | norm_hidden_states = norm(hidden_states).type(hidden_states.dtype)
251 | hidden_states = attention_block(norm_hidden_states) + hidden_states
252 |
253 | hidden_states = self.ff(self.ff_norm(hidden_states).type(hidden_states.dtype)) + hidden_states
254 |
255 | output = hidden_states
256 | return output
257 |
258 |
259 | class PositionalEncoding(nn.Module):
260 | def __init__(
261 | self,
262 | d_model,
263 | dropout = 0.,
264 | max_len = 24,
265 | ):
266 | super().__init__()
267 | self.dropout = nn.Dropout(p=dropout)
268 | position = torch.arange(max_len).unsqueeze(1)
269 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
270 | pe = torch.zeros(1, max_len, d_model)
271 | pe[0, :, 0::2] = torch.sin(position * div_term)
272 | pe[0, :, 1::2] = torch.cos(position * div_term)
273 | self.register_buffer('pe', pe)
274 |
275 | def forward(self, x):
276 | x = x + self.pe[:, :x.size(1)].to(x)
277 | return self.dropout(x)
278 |
279 |
280 | class CrossAttention(nn.Module):
281 | r"""
282 | A cross attention layer.
283 |
284 | Parameters:
285 | query_dim (`int`): The number of channels in the query.
286 | cross_attention_dim (`int`, *optional*):
287 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
288 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
289 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
290 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
291 | bias (`bool`, *optional*, defaults to False):
292 | Set to `True` for the query, key, and value linear layers to contain a bias parameter.
293 | """
294 |
295 | def __init__(
296 | self,
297 | query_dim: int,
298 | cross_attention_dim: Optional[int] = None,
299 | heads: int = 8,
300 | dim_head: int = 64,
301 | dropout: float = 0.0,
302 | bias=False,
303 | upcast_attention: bool = False,
304 | upcast_softmax: bool = False,
305 | operations = disable_weight_init,
306 | ):
307 | super().__init__()
308 | inner_dim = dim_head * heads
309 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
310 | self.upcast_attention = upcast_attention
311 | self.upcast_softmax = upcast_softmax
312 | self.scale = dim_head**-0.5
313 | self.heads = heads
314 |
315 | self.to_q = operations.Linear(query_dim, inner_dim, bias=bias)
316 | self.to_k = operations.Linear(cross_attention_dim, inner_dim, bias=bias)
317 | self.to_v = operations.Linear(cross_attention_dim, inner_dim, bias=bias)
318 |
319 | self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim), nn.Dropout(dropout))
320 |
321 |
322 | class VersatileAttention(CrossAttention):
323 | def __init__(
324 | self,
325 | temporal_position_encoding_max_len = 24,
326 | *args, **kwargs
327 | ):
328 | super().__init__(*args, **kwargs)
329 |
330 | self.pos_encoder = PositionalEncoding(
331 | kwargs["query_dim"],
332 | max_len=temporal_position_encoding_max_len)
333 |
334 |
335 | def forward(self, x: torch.Tensor):
336 | from scripts.animatediff_mm import mm_animatediff
337 | video_length = mm_animatediff.ad_params.batch_size
338 |
339 | d = x.shape[1]
340 | x = rearrange(x, "(b f) d c -> (b d) f c", f=video_length)
341 | x = self.pos_encoder(x)
342 |
343 | q = self.to_q(x)
344 | k = self.to_k(x)
345 | v = self.to_v(x)
346 |
347 | q, k, v = map(lambda t: rearrange(t, 'b s (h d) -> (b h) s d', h=self.heads), (q, k, v))
348 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
349 | x = rearrange(x, '(b h) s d -> b s (h d)', h=self.heads)
350 |
351 | x = self.to_out(x) # linear proj and dropout
352 | x = rearrange(x, "(b d) f c -> (b f) d c", d=d)
353 |
354 | return x
355 |
--------------------------------------------------------------------------------
/scripts/animatediff.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import gradio as gr
4 |
5 | from modules import script_callbacks, scripts
6 | from modules.processing import (Processed, StableDiffusionProcessing,
7 | StableDiffusionProcessingImg2Img)
8 | from modules.scripts import PostprocessBatchListArgs, PostprocessImageArgs
9 |
10 | from scripts.animatediff_infv2v import AnimateDiffInfV2V
11 | from scripts.animatediff_latent import AnimateDiffI2VLatent
12 | from scripts.animatediff_logger import logger_animatediff as logger
13 | from scripts.animatediff_mm import mm_animatediff as motion_module
14 | from scripts.animatediff_prompt import AnimateDiffPromptSchedule
15 | from scripts.animatediff_output import AnimateDiffOutput
16 | from scripts.animatediff_xyz import patch_xyz, xyz_attrs
17 | from scripts.animatediff_ui import AnimateDiffProcess, AnimateDiffUiGroup
18 | from scripts.animatediff_settings import on_ui_settings
19 | from scripts.animatediff_infotext import update_infotext, infotext_pasted
20 | from scripts.animatediff_utils import get_animatediff_arg
21 | from scripts.animatediff_i2ibatch import * # this is necessary for CN to find the function
22 | from scripts.animatediff_freeinit import AnimateDiffFreeInit
23 |
24 | script_dir = scripts.basedir()
25 | motion_module.set_script_dir(script_dir)
26 |
27 |
28 | class AnimateDiffScript(scripts.Script):
29 |
30 | def __init__(self):
31 | self.hacked = False
32 | self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
33 | self.paste_field_names: List[str] = []
34 |
35 |
36 | def title(self):
37 | return "AnimateDiff"
38 |
39 |
40 | def show(self, is_img2img):
41 | return scripts.AlwaysVisible
42 |
43 |
44 | def ui(self, is_img2img):
45 | unit = AnimateDiffUiGroup().render(
46 | is_img2img,
47 | self.infotext_fields,
48 | self.paste_field_names
49 | )
50 | return (unit,)
51 |
52 |
53 | def before_process(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
54 | if p.is_api:
55 | params = get_animatediff_arg(p)
56 | motion_module.set_ad_params(params)
57 |
58 | # apply XYZ settings
59 | params.apply_xyz()
60 | xyz_attrs.clear()
61 |
62 | if params.enable:
63 | logger.info("AnimateDiff process start.")
64 | motion_module.inject(p.sd_model, params.model)
65 | params.set_p(p)
66 | params.prompt_scheduler = AnimateDiffPromptSchedule(p, params)
67 | update_infotext(p, params)
68 | if params.freeinit_enable:
69 | self.freeinit_hacker = AnimateDiffFreeInit(params)
70 | self.freeinit_hacker.hack(p, params)
71 | self.hacked = True
72 | elif self.hacked:
73 | motion_module.restore(p.sd_model)
74 | self.hacked = False
75 |
76 |
77 | def before_process_batch(self, p: StableDiffusionProcessing, params: AnimateDiffProcess, **kwargs):
78 | if params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and not params.is_i2i_batch:
79 | AnimateDiffI2VLatent().randomize(p, params)
80 |
81 |
82 | def postprocess_batch_list(self, p: StableDiffusionProcessing, pp: PostprocessBatchListArgs, params: AnimateDiffProcess, **kwargs):
83 | if params.enable:
84 | params.prompt_scheduler.save_infotext_img(p)
85 |
86 |
87 | def postprocess_image(self, p: StableDiffusionProcessing, pp: PostprocessImageArgs, params: AnimateDiffProcess, *args):
88 | if params.enable and isinstance(p, StableDiffusionProcessingImg2Img) and hasattr(p, '_animatediff_paste_to_full'):
89 | p.paste_to = p._animatediff_paste_to_full[p.batch_index]
90 |
91 |
92 | def postprocess(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess):
93 | if params.enable:
94 | params.prompt_scheduler.save_infotext_txt(res)
95 | motion_module.restore(p.sd_model)
96 | self.hacked = False
97 | AnimateDiffOutput().output(p, res, params)
98 | logger.info("AnimateDiff process end.")
99 |
100 |
101 | patch_xyz()
102 |
103 | script_callbacks.on_ui_settings(on_ui_settings)
104 | script_callbacks.on_after_component(AnimateDiffUiGroup.on_after_component)
105 | script_callbacks.on_cfg_denoiser(AnimateDiffInfV2V.animatediff_on_cfg_denoiser)
106 | script_callbacks.on_infotext_pasted(infotext_pasted)
107 |
--------------------------------------------------------------------------------
/scripts/animatediff_freeinit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.fft as fft
3 | import math
4 |
5 | from modules import sd_models, shared, sd_samplers, devices
6 | from modules.processing import StableDiffusionProcessing, opt_C, opt_f, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, decode_latent_batch
7 | from types import MethodType
8 |
9 | from scripts.animatediff_ui import AnimateDiffProcess
10 |
11 |
12 | def ddim_add_noise(
13 | original_samples: torch.FloatTensor,
14 | noise: torch.FloatTensor,
15 | timesteps: torch.IntTensor,
16 | ) -> torch.FloatTensor:
17 |
18 | alphas_cumprod = shared.sd_model.alphas_cumprod
19 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
20 | alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
21 | timesteps = timesteps.to(original_samples.device)
22 |
23 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
24 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
25 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
26 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
27 |
28 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
29 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
30 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
31 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
32 |
33 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
34 | return noisy_samples
35 |
36 |
37 |
38 | class AnimateDiffFreeInit:
39 | def __init__(self, params):
40 | self.num_iters = params.freeinit_iters
41 | self.method = params.freeinit_filter
42 | self.d_s = params.freeinit_ds
43 | self.d_t = params.freeinit_dt
44 |
45 | @torch.no_grad()
46 | def init_filter(self, video_length, height, width, filter_params):
47 | # initialize frequency filter for noise reinitialization
48 | batch_size = 1
49 | filter_shape = [
50 | batch_size,
51 | opt_C,
52 | video_length,
53 | height // opt_f,
54 | width // opt_f
55 | ]
56 | self.freq_filter = get_freq_filter(filter_shape, device=devices.device, params=filter_params)
57 |
58 |
59 | def hack(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
60 | # init filter
61 | filter_params = {
62 | 'method': self.method,
63 | 'd_s': self.d_s,
64 | 'd_t': self.d_t,
65 | }
66 | self.init_filter(params.video_length, p.height, p.width, filter_params)
67 |
68 |
69 | def sample_t2i(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
70 | self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
71 |
72 | # hack total progress bar (works in an ugly way)
73 | setattr(self.sampler, 'freeinit_num_iters', self.num_freeinit_iters)
74 | setattr(self.sampler, 'freeinit_num_iter', 0)
75 |
76 | def callback_hack(self, d):
77 | step = d['i'] // self.freeinit_num_iters + self.freeinit_num_iter * (shared.state.sampling_steps // self.freeinit_num_iters)
78 |
79 | if self.stop_at is not None and step > self.stop_at:
80 | raise InterruptedException
81 |
82 | shared.state.sampling_step = step
83 |
84 | if d['i'] % self.freeinit_num_iters == 0:
85 | shared.total_tqdm.update()
86 |
87 | self.sampler.callback_state = MethodType(callback_hack, self.sampler)
88 |
89 | # Sampling with FreeInit
90 | x = self.rng.next()
91 | x_dtype = x.dtype
92 |
93 | for iter in range(self.num_freeinit_iters):
94 | self.sampler.freeinit_num_iter = iter
95 | if iter == 0:
96 | initial_x = x.detach().clone()
97 | else:
98 | # z_0
99 | diffuse_timesteps = torch.tensor(1000 - 1)
100 | z_T = ddim_add_noise(x, initial_x, diffuse_timesteps) # [16, 4, 64, 64]
101 | # z_T
102 | # 2. create random noise z_rand for high-frequency
103 | z_T = z_T.permute(1, 0, 2, 3)[None, ...] # [bs, 4, 16, 64, 64]
104 | #z_rand = torch.randn(z_T.shape, device=devices.device)
105 | z_rand = initial_x.detach().clone().permute(1, 0, 2, 3)[None, ...]
106 | # 3. Noise Reinitialization
107 | x = freq_mix_3d(z_T.to(dtype=torch.float32), z_rand, LPF=self.freq_filter)
108 |
109 | x = x[0].permute(1, 0, 2, 3)
110 | x = x.to(x_dtype)
111 |
112 | x = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
113 | devices.torch_gc()
114 |
115 | samples = x
116 | del x
117 |
118 | if not self.enable_hr:
119 | return samples
120 |
121 | devices.torch_gc()
122 |
123 | if self.latent_scale_mode is None:
124 | decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
125 | else:
126 | decoded_samples = None
127 |
128 | with sd_models.SkipWritingToConfig():
129 | sd_models.reload_model_weights(info=self.hr_checkpoint_info)
130 |
131 | return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
132 |
133 |
134 | def sample_i2i(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
135 | x = self.rng.next()
136 | x_dtype = x.dtype
137 |
138 |
139 | if self.initial_noise_multiplier != 1.0:
140 | self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
141 | x *= self.initial_noise_multiplier
142 |
143 | for iter in range(self.num_freeinit_iters):
144 | if iter == 0:
145 | initial_x = x.detach().clone()
146 | else:
147 | # z_0
148 | diffuse_timesteps = torch.tensor(1000 - 1)
149 | z_T = ddim_add_noise(x, initial_x, diffuse_timesteps) # [16, 4, 64, 64]
150 | # z_T
151 | # 2. create random noise z_rand for high-frequency
152 | z_T = z_T.permute(1, 0, 2, 3)[None, ...] # [bs, 4, 16, 64, 64]
153 | #z_rand = torch.randn(z_T.shape, device=devices.device)
154 | z_rand = initial_x.detach().clone().permute(1, 0, 2, 3)[None, ...]
155 | # 3. Noise Reinitialization
156 | x = freq_mix_3d(z_T.to(dtype=torch.float32), z_rand, LPF=self.freq_filter)
157 |
158 | x = x[0].permute(1, 0, 2, 3)
159 | x = x.to(x_dtype)
160 |
161 | x = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
162 | samples = x
163 |
164 | if self.mask is not None:
165 | samples = samples * self.nmask + self.init_latent * self.mask
166 |
167 | del x
168 | devices.torch_gc()
169 |
170 | return samples
171 |
172 | if isinstance(p, StableDiffusionProcessingTxt2Img):
173 | p.sample = MethodType(sample_t2i, p)
174 | elif isinstance(p, StableDiffusionProcessingImg2Img):
175 | p.sample = MethodType(sample_i2i, p)
176 | else:
177 | raise NotImplementedError
178 |
179 | setattr(p, 'freq_filter', self.freq_filter)
180 | setattr(p, 'num_freeinit_iters', self.num_iters)
181 |
182 |
183 | def freq_mix_3d(x, noise, LPF):
184 | """
185 | Noise reinitialization.
186 |
187 | Args:
188 | x: diffused latent
189 | noise: randomly sampled noise
190 | LPF: low pass filter
191 | """
192 | # FFT
193 | x_freq = fft.fftn(x, dim=(-3, -2, -1))
194 | x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
195 | noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
196 | noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
197 |
198 | # frequency mix
199 | HPF = 1 - LPF
200 | x_freq_low = x_freq * LPF
201 | noise_freq_high = noise_freq * HPF
202 | x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain
203 |
204 | # IFFT
205 | x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
206 | x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
207 |
208 | return x_mixed
209 |
210 |
211 | def get_freq_filter(shape, device, params: dict):
212 | """
213 | Form the frequency filter for noise reinitialization.
214 |
215 | Args:
216 | shape: shape of latent (B, C, T, H, W)
217 | params: filter parameters
218 | """
219 | if params['method'] == "gaussian":
220 | return gaussian_low_pass_filter(shape=shape, d_s=params['d_s'], d_t=params['d_t']).to(device)
221 | elif params['method'] == "ideal":
222 | return ideal_low_pass_filter(shape=shape, d_s=params['d_s'], d_t=params['d_t']).to(device)
223 | elif params['method'] == "box":
224 | return box_low_pass_filter(shape=shape, d_s=params['d_s'], d_t=params['d_t']).to(device)
225 | elif params['method'] == "butterworth":
226 | return butterworth_low_pass_filter(shape=shape, n=4, d_s=params['d_s'], d_t=params['d_t']).to(device)
227 | else:
228 | raise NotImplementedError
229 |
230 | def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25):
231 | """
232 | Compute the gaussian low pass filter mask.
233 |
234 | Args:
235 | shape: shape of the filter (volume)
236 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
237 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
238 | """
239 | T, H, W = shape[-3], shape[-2], shape[-1]
240 | mask = torch.zeros(shape)
241 | if d_s==0 or d_t==0:
242 | return mask
243 | for t in range(T):
244 | for h in range(H):
245 | for w in range(W):
246 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
247 | mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square)
248 | return mask
249 |
250 |
251 | def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25):
252 | """
253 | Compute the butterworth low pass filter mask.
254 |
255 | Args:
256 | shape: shape of the filter (volume)
257 | n: order of the filter, larger n ~ ideal, smaller n ~ gaussian
258 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
259 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
260 | """
261 | T, H, W = shape[-3], shape[-2], shape[-1]
262 | mask = torch.zeros(shape)
263 | if d_s==0 or d_t==0:
264 | return mask
265 | for t in range(T):
266 | for h in range(H):
267 | for w in range(W):
268 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
269 | mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n)
270 | return mask
271 |
272 |
273 | def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25):
274 | """
275 | Compute the ideal low pass filter mask.
276 |
277 | Args:
278 | shape: shape of the filter (volume)
279 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
280 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
281 | """
282 | T, H, W = shape[-3], shape[-2], shape[-1]
283 | mask = torch.zeros(shape)
284 | if d_s==0 or d_t==0:
285 | return mask
286 | for t in range(T):
287 | for h in range(H):
288 | for w in range(W):
289 | d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2)
290 | mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0
291 | return mask
292 |
293 |
294 | def box_low_pass_filter(shape, d_s=0.25, d_t=0.25):
295 | """
296 | Compute the ideal low pass filter mask (approximated version).
297 |
298 | Args:
299 | shape: shape of the filter (volume)
300 | d_s: normalized stop frequency for spatial dimensions (0.0-1.0)
301 | d_t: normalized stop frequency for temporal dimension (0.0-1.0)
302 | """
303 | T, H, W = shape[-3], shape[-2], shape[-1]
304 | mask = torch.zeros(shape)
305 | if d_s==0 or d_t==0:
306 | return mask
307 |
308 | threshold_s = round(int(H // 2) * d_s)
309 | threshold_t = round(T // 2 * d_t)
310 |
311 | cframe, crow, ccol = T // 2, H // 2, W //2
312 | mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0
313 |
314 | return mask
315 |
--------------------------------------------------------------------------------
/scripts/animatediff_i2ibatch.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from types import MethodType
3 |
4 | import os
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import hashlib
9 | from PIL import Image, ImageOps, UnidentifiedImageError
10 | from modules import processing, shared, scripts, devices, masking, sd_samplers, images
11 | from modules.processing import (StableDiffusionProcessingImg2Img,
12 | process_images,
13 | create_binary_mask,
14 | create_random_tensors,
15 | images_tensor_to_samples,
16 | setup_color_correction,
17 | opt_f)
18 | from modules.shared import opts
19 | from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes
20 | from modules.sd_models import get_closet_checkpoint_match
21 |
22 | from scripts.animatediff_logger import logger_animatediff as logger
23 | from scripts.animatediff_utils import get_animatediff_arg, get_controlnet_units
24 |
25 |
26 | def animatediff_i2i_init(self, all_prompts, all_seeds, all_subseeds): # only hack this when i2i-batch with batch mask
27 | self.extra_generation_params["Denoising strength"] = self.denoising_strength
28 |
29 | self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None
30 |
31 | self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
32 | crop_regions = []
33 | paste_to = []
34 | masks_for_overlay = []
35 |
36 | image_masks = self.image_mask
37 |
38 | for idx, image_mask in enumerate(image_masks):
39 | # image_mask is passed in as RGBA by Gradio to support alpha masks,
40 | # but we still want to support binary masks.
41 | image_mask = create_binary_mask(image_mask)
42 |
43 | if self.inpainting_mask_invert:
44 | image_mask = ImageOps.invert(image_mask)
45 |
46 | if self.mask_blur_x > 0:
47 | np_mask = np.array(image_mask)
48 | kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1
49 | np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
50 | image_mask = Image.fromarray(np_mask)
51 |
52 | if self.mask_blur_y > 0:
53 | np_mask = np.array(image_mask)
54 | kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1
55 | np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
56 | image_mask = Image.fromarray(np_mask)
57 |
58 | if self.inpaint_full_res:
59 | masks_for_overlay.append(image_mask)
60 | mask = image_mask.convert('L')
61 | crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
62 | crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
63 | crop_regions.append(crop_region)
64 | x1, y1, x2, y2 = crop_region
65 |
66 | mask = mask.crop(crop_region)
67 | image_mask = images.resize_image(2, mask, self.width, self.height)
68 | paste_to.append((x1, y1, x2-x1, y2-y1))
69 | else:
70 | image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
71 | np_mask = np.array(image_mask)
72 | np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8)
73 | masks_for_overlay.append(Image.fromarray(np_mask))
74 |
75 | image_masks[idx] = image_mask
76 |
77 | self.mask_for_overlay = masks_for_overlay[0] # only for saving purpose
78 | if paste_to:
79 | self.paste_to = paste_to[0]
80 | self._animatediff_paste_to_full = paste_to
81 |
82 | self.overlay_images = []
83 | add_color_corrections = opts.img2img_color_correction and self.color_corrections is None
84 | if add_color_corrections:
85 | self.color_corrections = []
86 | imgs = []
87 | for idx, img in enumerate(self.init_images):
88 | latent_mask = (self.latent_mask[idx] if isinstance(self.latent_mask, list) else self.latent_mask) if self.latent_mask is not None else image_masks[idx]
89 | # Save init image
90 | if opts.save_init_img:
91 | self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest()
92 | images.save_image(img, path=opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False)
93 |
94 | image = images.flatten(img, opts.img2img_background_color)
95 |
96 | if not crop_regions and self.resize_mode != 3:
97 | image = images.resize_image(self.resize_mode, image, self.width, self.height)
98 |
99 | if image_masks:
100 | image_masked = Image.new('RGBa', (image.width, image.height))
101 | image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(masks_for_overlay[idx].convert('L')))
102 |
103 | self.overlay_images.append(image_masked.convert('RGBA'))
104 |
105 | # crop_region is not None if we are doing inpaint full res
106 | if crop_regions:
107 | image = image.crop(crop_regions[idx])
108 | image = images.resize_image(2, image, self.width, self.height)
109 |
110 | if image_masks:
111 | if self.inpainting_fill != 1:
112 | image = masking.fill(image, latent_mask)
113 |
114 | if add_color_corrections:
115 | self.color_corrections.append(setup_color_correction(image))
116 |
117 | image = np.array(image).astype(np.float32) / 255.0
118 | image = np.moveaxis(image, 2, 0)
119 |
120 | imgs.append(image)
121 |
122 | if len(imgs) == 1:
123 | batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
124 | if self.overlay_images is not None:
125 | self.overlay_images = self.overlay_images * self.batch_size
126 |
127 | if self.color_corrections is not None and len(self.color_corrections) == 1:
128 | self.color_corrections = self.color_corrections * self.batch_size
129 |
130 | elif len(imgs) <= self.batch_size:
131 | self.batch_size = len(imgs)
132 | batch_images = np.array(imgs)
133 | else:
134 | raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
135 |
136 | image = torch.from_numpy(batch_images)
137 | image = image.to(shared.device, dtype=devices.dtype_vae)
138 |
139 | if opts.sd_vae_encode_method != 'Full':
140 | self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
141 |
142 | self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
143 | devices.torch_gc()
144 |
145 | if self.resize_mode == 3:
146 | self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
147 |
148 | if image_masks is not None:
149 | def process_letmask(init_mask):
150 | # init_mask = latent_mask
151 | latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
152 | latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
153 | latmask = latmask[0]
154 | latmask = np.around(latmask)
155 | return np.tile(latmask[None], (4, 1, 1))
156 |
157 | if self.latent_mask is not None and not isinstance(self.latent_mask, list):
158 | latmask = process_letmask(self.latent_mask)
159 | else:
160 | if isinstance(self.latent_mask, list):
161 | latmask = [process_letmask(x) for x in self.latent_mask]
162 | else:
163 | latmask = [process_letmask(x) for x in image_masks]
164 | latmask = np.stack(latmask, axis=0)
165 |
166 | self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
167 | self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
168 |
169 | # this needs to be fixed to be done in sample() using actual seeds for batches
170 | if self.inpainting_fill == 2:
171 | self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask
172 | elif self.inpainting_fill == 3:
173 | self.init_latent = self.init_latent * self.mask
174 |
175 | self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_masks) # let's ignore this image_masks which is related to inpaint model with different arch
176 |
177 |
178 | def animatediff_i2i_batch(
179 | p: StableDiffusionProcessingImg2Img, input_dir: str, output_dir: str, inpaint_mask_dir: str,
180 | args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
181 | ad_params = get_animatediff_arg(p)
182 | assert ad_params.enable, "AnimateDiff is not enabled."
183 | if not ad_params.video_path and not ad_params.video_source:
184 | ad_params.video_path = input_dir
185 |
186 | output_dir = output_dir.strip()
187 | processing.fix_seed(p)
188 |
189 | images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))
190 |
191 | is_inpaint_batch = False
192 | if inpaint_mask_dir:
193 | inpaint_masks = shared.listfiles(inpaint_mask_dir)
194 | is_inpaint_batch = bool(inpaint_masks)
195 |
196 | if is_inpaint_batch:
197 | assert len(inpaint_masks) == 1 or len(inpaint_masks) == len(images), 'The number of masks must be 1 or equal to the number of images.'
198 | logger.info(f"[i2i batch] Inpaint batch is enabled. {len(inpaint_masks)} masks found.")
199 | if len(inpaint_masks) > 1: # batch mask
200 | p.init = MethodType(animatediff_i2i_init, p)
201 |
202 | cn_units = get_controlnet_units(p)
203 | for idx, cn_unit in enumerate(cn_units):
204 | # batch path broadcast
205 | if (cn_unit.input_mode.name == 'SIMPLE' and cn_unit.image is None) or \
206 | (cn_unit.input_mode.name == 'BATCH' and not cn_unit.batch_images) or \
207 | (cn_unit.input_mode.name == 'MERGE' and not cn_unit.batch_input_gallery):
208 | cn_unit.input_mode = cn_unit.input_mode.__class__.BATCH
209 | if "inpaint" in cn_unit.module:
210 | cn_unit.batch_images = f"{cn_unit.batch_images}\nmask:{inpaint_mask_dir}"
211 | logger.info(f"ControlNetUnit-{idx} is an inpaint unit without cond_hint specification. We have set batch_images = {cn_unit.batch_images}.")
212 |
213 | logger.info(f"[i2i batch] Will process {len(images)} images, creating {p.n_iter} new videos.")
214 |
215 | # extract "default" params to use in case getting png info fails
216 | prompt = p.prompt
217 | negative_prompt = p.negative_prompt
218 | seed = p.seed
219 | cfg_scale = p.cfg_scale
220 | sampler_name = p.sampler_name
221 | steps = p.steps
222 | override_settings = p.override_settings
223 | sd_model_checkpoint_override = get_closet_checkpoint_match(override_settings.get("sd_model_checkpoint", None))
224 | batch_results = None
225 | discard_further_results = False
226 | frame_images = []
227 | frame_masks = []
228 |
229 | for i, image in enumerate(images):
230 |
231 | try:
232 | img = Image.open(image)
233 | except UnidentifiedImageError as e:
234 | print(e)
235 | continue
236 | # Use the EXIF orientation of photos taken by smartphones.
237 | img = ImageOps.exif_transpose(img)
238 |
239 | if to_scale:
240 | p.width = int(img.width * scale_by)
241 | p.height = int(img.height * scale_by)
242 |
243 | frame_images.append(img)
244 |
245 | image_path = Path(image)
246 | if is_inpaint_batch:
247 | if len(inpaint_masks) == 1:
248 | mask_image_path = inpaint_masks[0]
249 | p.image_mask = Image.open(mask_image_path)
250 | else:
251 | # try to find corresponding mask for an image using index matching
252 | mask_image_path = inpaint_masks[i]
253 | frame_masks.append(Image.open(mask_image_path))
254 |
255 | mask_image = Image.open(mask_image_path)
256 | p.image_mask = mask_image
257 |
258 | if use_png_info:
259 | try:
260 | info_img = frame_images[0]
261 | if png_info_dir:
262 | info_img_path = os.path.join(png_info_dir, os.path.basename(image))
263 | info_img = Image.open(info_img_path)
264 | from modules import images as imgutil
265 | from modules.infotext_utils import parse_generation_parameters
266 | geninfo, _ = imgutil.read_info_from_image(info_img)
267 | parsed_parameters = parse_generation_parameters(geninfo)
268 | parsed_parameters = {k: v for k, v in parsed_parameters.items() if k in (png_info_props or {})}
269 | except Exception:
270 | parsed_parameters = {}
271 |
272 | p.prompt = prompt + (" " + parsed_parameters["Prompt"] if "Prompt" in parsed_parameters else "")
273 | p.negative_prompt = negative_prompt + (" " + parsed_parameters["Negative prompt"] if "Negative prompt" in parsed_parameters else "")
274 | p.seed = int(parsed_parameters.get("Seed", seed))
275 | p.cfg_scale = float(parsed_parameters.get("CFG scale", cfg_scale))
276 | p.sampler_name = parsed_parameters.get("Sampler", sampler_name)
277 | p.steps = int(parsed_parameters.get("Steps", steps))
278 |
279 | model_info = get_closet_checkpoint_match(parsed_parameters.get("Model hash", None))
280 | if model_info is not None:
281 | p.override_settings['sd_model_checkpoint'] = model_info.name
282 | elif sd_model_checkpoint_override:
283 | p.override_settings['sd_model_checkpoint'] = sd_model_checkpoint_override
284 | else:
285 | p.override_settings.pop("sd_model_checkpoint", None)
286 |
287 | if output_dir:
288 | p.outpath_samples = output_dir
289 | p.override_settings['save_to_dirs'] = False
290 | p.override_settings['save_images_replace_action'] = "Add number suffix"
291 | if p.n_iter > 1 or p.batch_size > 1:
292 | p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
293 | else:
294 | p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
295 |
296 | p.init_images = frame_images
297 | if len(frame_masks) > 0:
298 | p.image_mask = frame_masks
299 |
300 | proc = scripts.scripts_img2img.run(p, *args) # we should not support this, but just leave it here
301 |
302 | if proc is None:
303 | p.override_settings.pop('save_images_replace_action', None)
304 | proc = process_images(p)
305 | else:
306 | logger.warn("Warning: you are using an unsupported external script. AnimateDiff may not work properly.")
307 |
308 | if not discard_further_results and proc:
309 | if batch_results:
310 | batch_results.images.extend(proc.images)
311 | batch_results.infotexts.extend(proc.infotexts)
312 | else:
313 | batch_results = proc
314 |
315 | if 0 <= shared.opts.img2img_batch_show_results_limit < len(batch_results.images):
316 | discard_further_results = True
317 | batch_results.images = batch_results.images[:int(shared.opts.img2img_batch_show_results_limit)]
318 | batch_results.infotexts = batch_results.infotexts[:int(shared.opts.img2img_batch_show_results_limit)]
319 |
320 | return batch_results
321 |
--------------------------------------------------------------------------------
/scripts/animatediff_infotext.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from modules.paths import data_path
4 | from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingImg2Img
5 |
6 | from scripts.animatediff_ui import AnimateDiffProcess
7 | from scripts.animatediff_logger import logger_animatediff as logger
8 |
9 |
10 | def update_infotext(p: StableDiffusionProcessing, params: AnimateDiffProcess):
11 | if p.extra_generation_params is not None:
12 | p.extra_generation_params["AnimateDiff"] = params.get_dict(isinstance(p, StableDiffusionProcessingImg2Img))
13 |
14 |
15 | def write_params_txt(info: str):
16 | with open(os.path.join(data_path, "params.txt"), "w", encoding="utf8") as file:
17 | file.write(info)
18 |
19 |
20 |
21 | def infotext_pasted(infotext, results):
22 | for k, v in results.items():
23 | if not k.startswith("AnimateDiff"):
24 | continue
25 |
26 | assert isinstance(v, str), f"Expected string but got {v}."
27 | try:
28 | for items in v.split(', '):
29 | field, value = items.split(': ')
30 | results[f"AnimateDiff {field}"] = value
31 | results.pop("AnimateDiff")
32 | except Exception as e:
33 | logger.warn(f"Failed to parse infotext value:\n{v}")
34 | logger.warn(f"Exception: {e}")
35 | break
36 |
--------------------------------------------------------------------------------
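
For reference, the loop in `infotext_pasted` unpacks the comma-separated "AnimateDiff" infotext value written by `AnimateDiffProcess.get_dict` into one paste field per parameter. A standalone illustration with a made-up value:

results = {"AnimateDiff": "model: mm_sd15_v3.safetensors, video_length: 16, fps: 8"}

for item in results["AnimateDiff"].split(', '):
    field, value = item.split(': ', 1)
    results[f"AnimateDiff {field}"] = value
results.pop("AnimateDiff")

print(results)
# {'AnimateDiff model': 'mm_sd15_v3.safetensors', 'AnimateDiff video_length': '16', 'AnimateDiff fps': '8'}
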
/scripts/animatediff_infv2v.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from types import MethodType
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from modules import devices, shared
8 | from modules.script_callbacks import CFGDenoiserParams
9 | from scripts.animatediff_logger import logger_animatediff as logger
10 | from scripts.animatediff_mm import mm_animatediff as motion_module
11 |
12 |
13 | class AnimateDiffInfV2V:
14 |
15 |     # Returns a fraction whose denominator is a power of 2
16 | @staticmethod
17 | def ordered_halving(val):
18 | # get binary value, padded with 0s for 64 bits
19 | bin_str = f"{val:064b}"
20 | # flip binary value, padding included
21 | bin_flip = bin_str[::-1]
22 | # convert binary to int
23 | as_int = int(bin_flip, 2)
24 | # divide by 1 << 64, equivalent to 2**64, or 18446744073709551616,
25 |         # or 0b1 followed by 64 zeros
26 | final = as_int / (1 << 64)
27 | return final
28 |
29 |
30 |     # Generator that yields lists of latent indices to diffuse on
31 | @staticmethod
32 | def uniform(
33 | step: int,
34 | video_length: int = 0,
35 | batch_size: int = 16,
36 | stride: int = 1,
37 | overlap: int = 4,
38 | loop_setting: str = 'R-P',
39 | ):
40 | if video_length <= batch_size:
41 | yield list(range(batch_size))
42 | return
43 |
44 | closed_loop = (loop_setting == 'A')
45 | stride = min(stride, int(np.ceil(np.log2(video_length / batch_size))) + 1)
46 |
47 | for context_step in 1 << np.arange(stride):
48 | pad = int(round(video_length * AnimateDiffInfV2V.ordered_halving(step)))
49 | both_close_loop = False
50 | for j in range(
51 | int(AnimateDiffInfV2V.ordered_halving(step) * context_step) + pad,
52 | video_length + pad + (0 if closed_loop else -overlap),
53 | (batch_size * context_step - overlap),
54 | ):
55 | if loop_setting == 'N' and context_step == 1:
56 | current_context = [e % video_length for e in range(j, j + batch_size * context_step, context_step)]
57 | first_context = [e % video_length for e in range(0, batch_size * context_step, context_step)]
58 | last_context = [e % video_length for e in range(video_length - batch_size * context_step, video_length, context_step)]
59 | def get_unsorted_index(lst):
60 | for i in range(1, len(lst)):
61 | if lst[i] < lst[i-1]:
62 | return i
63 | return None
64 |                     unsorted_index = get_unsorted_index(current_context)
65 |                     if unsorted_index is None:
66 |                         yield current_context
67 |                     elif both_close_loop: # the previous and the current context are closed loops
68 |                         both_close_loop = False
69 |                         yield first_context
70 |                     elif unsorted_index < batch_size - overlap: # only the current context is a closed loop
71 |                         yield last_context
72 |                         yield first_context
73 |                     else: # the current and the next context are closed loops
74 |                         both_close_loop = True
75 |                         yield last_context
76 | else:
77 | yield [e % video_length for e in range(j, j + batch_size * context_step, context_step)]
78 |
79 |
80 | @staticmethod
81 | def animatediff_on_cfg_denoiser(cfg_params: CFGDenoiserParams):
82 | ad_params = motion_module.ad_params
83 | if ad_params is None or not ad_params.enable:
84 | return
85 |
86 |         # !Adetailer accommodation
87 | if not motion_module.mm_injected:
88 | if cfg_params.denoiser.step == 0:
89 | logger.warning(
90 | "No motion module detected, falling back to the original forward. You are most likely using !Adetailer. "
91 |                     "!Adetailer post-processes your outputs sequentially, and there will NOT be a motion module in your UNet, "
92 |                     "so there may be NO temporal consistency within the inpainted face. Use at your own risk. "
93 |                     "If you really want to inpaint with AnimateDiff inserted into the UNet, "
94 |                     "use Segment Anything to generate masks for each frame and inpaint them with AnimateDiff + ControlNet. "
95 |                     "Note that this proposal may or may not work well; do your own research to figure out the best way.")
96 | return
97 |
98 | if cfg_params.denoiser.step == 0 and getattr(cfg_params.denoiser.inner_model, 'original_forward', None) is None:
99 |
100 | # prompt travel
101 | prompt_closed_loop = (ad_params.video_length > ad_params.batch_size) and (ad_params.closed_loop in ['R+P', 'A'])
102 | ad_params.text_cond = ad_params.prompt_scheduler.multi_cond(cfg_params.text_cond, prompt_closed_loop)
103 | try:
104 | from scripts.external_code import find_cn_script
105 | cn_script = find_cn_script(cfg_params.denoiser.p.scripts)
106 |             except Exception:
107 | cn_script = None
108 |
109 | # infinite generation
110 | def mm_cn_select(context: List[int]):
111 | # take control images for current context.
112 | if cn_script and cn_script.latest_network:
113 | from scripts.hook import ControlModelType
114 | for control in cn_script.latest_network.control_params:
115 | if control.control_model_type not in [ControlModelType.IPAdapter, ControlModelType.Controlllite]:
116 | if control.hint_cond.shape[0] > len(context):
117 | control.hint_cond_backup = control.hint_cond
118 | control.hint_cond = control.hint_cond[context]
119 | control.hint_cond = control.hint_cond.to(device=devices.get_device_for("controlnet"))
120 | if control.hr_hint_cond is not None:
121 | if control.hr_hint_cond.shape[0] > len(context):
122 | control.hr_hint_cond_backup = control.hr_hint_cond
123 | control.hr_hint_cond = control.hr_hint_cond[context]
124 | control.hr_hint_cond = control.hr_hint_cond.to(device=devices.get_device_for("controlnet"))
125 | # IPAdapter and Controlllite are always on CPU.
126 | elif control.control_model_type == ControlModelType.IPAdapter and control.control_model.image_emb.cond_emb.shape[0] > len(context):
127 | from scripts.controlmodel_ipadapter import ImageEmbed
128 | if getattr(control.control_model.image_emb, "cond_emb_backup", None) is None:
129 | control.control_model.cond_emb_backup = control.control_model.image_emb.cond_emb
130 | control.control_model.image_emb = ImageEmbed(control.control_model.cond_emb_backup[context], control.control_model.image_emb.uncond_emb)
131 | elif control.control_model_type == ControlModelType.Controlllite:
132 | for module in control.control_model.modules.values():
133 | if module.cond_image.shape[0] > len(context):
134 | module.cond_image_backup = module.cond_image
135 | module.set_cond_image(module.cond_image[context])
136 |
137 | def mm_cn_restore(context: List[int]):
138 | # restore control images for next context
139 | if cn_script and cn_script.latest_network:
140 | from scripts.hook import ControlModelType
141 | for control in cn_script.latest_network.control_params:
142 | if control.control_model_type not in [ControlModelType.IPAdapter, ControlModelType.Controlllite]:
143 | if getattr(control, "hint_cond_backup", None) is not None:
144 | control.hint_cond_backup[context] = control.hint_cond.to(device="cpu")
145 | control.hint_cond = control.hint_cond_backup
146 | if control.hr_hint_cond is not None and getattr(control, "hr_hint_cond_backup", None) is not None:
147 | control.hr_hint_cond_backup[context] = control.hr_hint_cond.to(device="cpu")
148 | control.hr_hint_cond = control.hr_hint_cond_backup
149 | elif control.control_model_type == ControlModelType.Controlllite:
150 | for module in control.control_model.modules.values():
151 | if getattr(module, "cond_image_backup", None) is not None:
152 | module.set_cond_image(module.cond_image_backup)
153 |
154 | def mm_sd_forward(self, x_in, sigma_in, cond):
155 | logger.debug("Running special forward for AnimateDiff")
156 | x_out = torch.zeros_like(x_in)
157 | for context in AnimateDiffInfV2V.uniform(ad_params.step, ad_params.video_length, ad_params.batch_size, ad_params.stride, ad_params.overlap, ad_params.closed_loop):
158 | if shared.opts.batch_cond_uncond:
159 | _context = context + [c + ad_params.video_length for c in context]
160 | else:
161 | _context = context
162 | mm_cn_select(_context)
163 | out = self.original_forward(
164 | x_in[_context], sigma_in[_context],
165 | cond={k: ([v[0][_context]] if isinstance(v, list) else v[_context]) for k, v in cond.items()})
166 | x_out = x_out.to(dtype=out.dtype)
167 | x_out[_context] = out
168 | mm_cn_restore(_context)
169 | return x_out
170 |
171 | logger.info("inner model forward hooked")
172 | cfg_params.denoiser.inner_model.original_forward = cfg_params.denoiser.inner_model.forward
173 | cfg_params.denoiser.inner_model.forward = MethodType(mm_sd_forward, cfg_params.denoiser.inner_model)
174 |
175 | cfg_params.text_cond = ad_params.text_cond
176 | ad_params.step = cfg_params.denoiser.step
177 |
--------------------------------------------------------------------------------
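
The scheduler above covers a long video with overlapping windows of `batch_size` frames, offsetting the windows by a bit-reversed fraction of the step index so that different sampling steps see different window boundaries. A simplified standalone sketch for illustration only (stride fixed to 1, no closed-loop handling; `simple_uniform` is a hypothetical helper, not part of the extension):

def ordered_halving(val: int) -> float:
    # bit-reverse the step index to get a low-discrepancy offset in [0, 1)
    return int(f"{val:064b}"[::-1], 2) / (1 << 64)

def simple_uniform(step: int, video_length: int, batch_size: int = 16, overlap: int = 4):
    if video_length <= batch_size:
        yield list(range(batch_size))   # mirrors the original short-video case
        return
    pad = int(round(video_length * ordered_halving(step)))
    for j in range(int(ordered_halving(step)) + pad,
                   video_length + pad - overlap,
                   batch_size - overlap):
        yield [e % video_length for e in range(j, j + batch_size)]

for ctx in simple_uniform(step=3, video_length=32):
    print(ctx)   # three overlapping windows of 16 frame indices, wrapped modulo 32
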
/scripts/animatediff_latent.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from modules import images, shared
4 | from modules.devices import device, dtype_vae, torch_gc
5 | from modules.processing import StableDiffusionProcessingImg2Img
6 | from modules.sd_samplers_common import (approximation_indexes,
7 | images_tensor_to_samples)
8 |
9 | from scripts.animatediff_logger import logger_animatediff as logger
10 | from scripts.animatediff_ui import AnimateDiffProcess
11 |
12 |
13 | class AnimateDiffI2VLatent:
14 | def randomize(
15 | self, p: StableDiffusionProcessingImg2Img, params: AnimateDiffProcess
16 | ):
17 | # Get init_alpha
18 | init_alpha = [
19 | 1 - pow(i, params.latent_power) / params.latent_scale
20 | for i in range(params.video_length)
21 | ]
22 | logger.info(f"Randomizing init_latent according to {init_alpha}.")
23 | init_alpha = torch.tensor(init_alpha, dtype=torch.float32, device=device)[
24 | :, None, None, None
25 | ]
26 | init_alpha[init_alpha < 0] = 0
27 |
28 | if params.last_frame is not None:
29 | last_frame = params.last_frame
30 |             if isinstance(last_frame, str):
31 | from modules.api.api import decode_base64_to_image
32 | last_frame = decode_base64_to_image(last_frame)
33 | # Get last_alpha
34 | last_alpha = [
35 | 1 - pow(i, params.latent_power_last) / params.latent_scale_last
36 | for i in range(params.video_length)
37 | ]
38 | last_alpha.reverse()
39 | logger.info(f"Randomizing last_latent according to {last_alpha}.")
40 | last_alpha = torch.tensor(last_alpha, dtype=torch.float32, device=device)[
41 | :, None, None, None
42 | ]
43 | last_alpha[last_alpha < 0] = 0
44 |
45 | # Normalize alpha
46 | sum_alpha = init_alpha + last_alpha
47 | mask_alpha = sum_alpha > 1
48 | scaling_factor = 1 / sum_alpha[mask_alpha]
49 | init_alpha[mask_alpha] *= scaling_factor
50 | last_alpha[mask_alpha] *= scaling_factor
51 | init_alpha[0] = 1
52 | init_alpha[-1] = 0
53 | last_alpha[0] = 0
54 | last_alpha[-1] = 1
55 |
56 | # Calculate last_latent
57 | if p.resize_mode != 3:
58 | last_frame = images.resize_image(
59 | p.resize_mode, last_frame, p.width, p.height
60 | )
61 | last_frame = np.array(last_frame).astype(np.float32) / 255.0
62 | last_frame = np.moveaxis(last_frame, 2, 0)[None, ...]
63 | last_frame = torch.from_numpy(last_frame).to(device).to(dtype_vae)
64 | last_latent = images_tensor_to_samples(
65 | last_frame,
66 | approximation_indexes.get(shared.opts.sd_vae_encode_method),
67 | p.sd_model,
68 | )
69 | torch_gc()
70 | if p.resize_mode == 3:
71 | opt_f = 8
72 | last_latent = torch.nn.functional.interpolate(
73 | last_latent,
74 | size=(p.height // opt_f, p.width // opt_f),
75 | mode="bilinear",
76 | )
77 | # Modify init_latent
78 | p.init_latent = (
79 | p.init_latent * init_alpha
80 | + last_latent * last_alpha
81 | + p.rng.next() * (1 - init_alpha - last_alpha)
82 | )
83 | else:
84 | p.init_latent = p.init_latent * init_alpha + p.rng.next() * (1 - init_alpha)
85 |
--------------------------------------------------------------------------------
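
A sketch of the per-frame blending weights computed in `AnimateDiffI2VLatent.randomize` above: frame 0 keeps the init latent, and later frames are progressively replaced by fresh noise. The values below use the UI defaults (latent_power=1, latent_scale=32); the tensor shapes are illustrative stand-ins for `p.init_latent` and `p.rng.next()`:

import torch

video_length, latent_power, latent_scale = 16, 1, 32
init_alpha = torch.tensor(
    [1 - pow(i, latent_power) / latent_scale for i in range(video_length)]
).clamp(min=0)
print(init_alpha)   # 1.0000, 0.9688, 0.9375, ... -> weight of the init latent per frame

init_latent = torch.randn(video_length, 4, 64, 64)   # stand-in for p.init_latent
noise = torch.randn_like(init_latent)                # stand-in for p.rng.next()
alpha = init_alpha[:, None, None, None]
blended = init_latent * alpha + noise * (1 - alpha)
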
/scripts/animatediff_logger.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import sys
4 |
5 | from modules import shared
6 |
7 |
8 | class ColoredFormatter(logging.Formatter):
9 | COLORS = {
10 | "DEBUG": "\033[0;36m", # CYAN
11 | "INFO": "\033[0;32m", # GREEN
12 | "WARNING": "\033[0;33m", # YELLOW
13 | "ERROR": "\033[0;31m", # RED
14 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED
15 | "RESET": "\033[0m", # RESET COLOR
16 | }
17 |
18 | def format(self, record):
19 | colored_record = copy.copy(record)
20 | levelname = colored_record.levelname
21 | seq = self.COLORS.get(levelname, self.COLORS["RESET"])
22 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
23 | return super().format(colored_record)
24 |
25 |
26 | # Create a new logger
27 | logger_animatediff = logging.getLogger("AnimateDiff")
28 | logger_animatediff.propagate = False
29 |
30 | # Add handler if we don't have one.
31 | if not logger_animatediff.handlers:
32 | handler = logging.StreamHandler(sys.stdout)
33 | handler.setFormatter(
34 | ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
35 | )
36 | logger_animatediff.addHandler(handler)
37 |
38 | # Configure logger
39 | loglevel_string = getattr(shared.cmd_opts, "loglevel", "INFO")
40 | if not loglevel_string:
41 | loglevel_string = "INFO"
42 | loglevel = getattr(logging, loglevel_string.upper(), logging.INFO)
43 | logger_animatediff.setLevel(loglevel)
44 |
--------------------------------------------------------------------------------
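
Other modules of the extension reuse this logger by importing `logger_animatediff`; outside the extension the same logger can also be reached by name, since it is registered globally. A minimal sketch, using a plain formatter instead of the ColoredFormatter above:

import logging
import sys

logger = logging.getLogger("AnimateDiff")   # same logger instance the extension configures
if not logger.handlers:                     # mirror the guard above to avoid duplicate handlers
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
    logger.addHandler(handler)
logger.setLevel(logging.INFO)
logger.info("motion module loaded")
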
/scripts/animatediff_mm.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 |
4 | import torch
5 | from einops import rearrange
6 | from modules import hashes, shared, sd_models, devices
7 | from modules.devices import cpu, device, torch_gc
8 |
9 | from motion_module import MotionWrapper, MotionModuleType
10 | from scripts.animatediff_logger import logger_animatediff as logger
11 |
12 |
13 | class AnimateDiffMM:
14 | mm_injected = False
15 |
16 | def __init__(self):
17 | self.mm: MotionWrapper = None
18 | self.script_dir = None
19 | self.ad_params = None
20 | self.prev_alpha_cumprod = None
21 | self.prev_alpha_cumprod_original = None
22 | self.gn32_original_forward = None
23 |
24 |
25 | def set_script_dir(self, script_dir):
26 | self.script_dir = script_dir
27 |
28 |
29 | def set_ad_params(self, ad_params):
30 | self.ad_params = ad_params
31 |
32 |
33 | def get_model_dir(self):
34 | model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(self.script_dir, "model"))
35 | if not model_dir:
36 | model_dir = os.path.join(self.script_dir, "model")
37 | return model_dir
38 |
39 |
40 | def load(self, model_name: str):
41 | model_path = os.path.join(self.get_model_dir(), model_name)
42 | if not os.path.isfile(model_path):
43 |             raise RuntimeError(f"Motion module {model_path} does not exist. Please download models manually.")
44 | if self.mm is None or self.mm.mm_name != model_name:
45 | logger.info(f"Loading motion module {model_name} from {model_path}")
46 | model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}")
47 | mm_state_dict = sd_models.read_state_dict(model_path)
48 | model_type = MotionModuleType.get_mm_type(mm_state_dict)
49 | logger.info(f"Guessed {model_name} architecture: {model_type}")
50 | mm_config = dict(mm_name=model_name, mm_hash=model_hash, mm_type=model_type)
51 | self.mm = MotionWrapper(**mm_config)
52 |             self.mm.load_state_dict(mm_state_dict, strict=(model_type != MotionModuleType.AnimateLCM))
53 | self.mm.to(device).eval()
54 | if not shared.cmd_opts.no_half:
55 | self.mm.half()
56 | if getattr(devices, "fp8", False):
57 | for module in self.mm.modules():
58 | if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
59 | module.to(torch.float8_e4m3fn)
60 |
61 |
62 | def inject(self, sd_model, model_name="mm_sd15_v3.safetensors"):
63 | if AnimateDiffMM.mm_injected:
64 | logger.info("Motion module already injected. Trying to restore.")
65 | self.restore(sd_model)
66 |
67 | unet = sd_model.model.diffusion_model
68 | self.load(model_name)
69 | inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
70 | sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
71 | assert sd_model.is_sdxl == self.mm.is_xl, f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}."
72 |
73 | if self.mm.is_v2:
74 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
75 | unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
76 | elif self.mm.enable_gn_hack():
77 | logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
78 | if self.mm.is_hotshot:
79 | from sgm.modules.diffusionmodules.util import GroupNorm32
80 | else:
81 | from ldm.modules.diffusionmodules.util import GroupNorm32
82 | self.gn32_original_forward = GroupNorm32.forward
83 | gn32_original_forward = self.gn32_original_forward
84 |
85 | def groupnorm32_mm_forward(self, x):
86 | x = rearrange(x, "(b f) c h w -> b c f h w", b=2)
87 | x = gn32_original_forward(self, x)
88 | x = rearrange(x, "b c f h w -> (b f) c h w", b=2)
89 | return x
90 |
91 | GroupNorm32.forward = groupnorm32_mm_forward
92 |
93 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.")
94 | for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
95 | if inject_sdxl and mm_idx >= 6:
96 | break
97 | mm_idx0, mm_idx1 = mm_idx // 2, mm_idx % 2
98 | mm_inject = getattr(self.mm.down_blocks[mm_idx0], "motion_modules")[mm_idx1]
99 | unet.input_blocks[unet_idx].append(mm_inject)
100 |
101 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.")
102 | for unet_idx in range(12):
103 | if inject_sdxl and unet_idx >= 9:
104 | break
105 | mm_idx0, mm_idx1 = unet_idx // 3, unet_idx % 3
106 | mm_inject = getattr(self.mm.up_blocks[mm_idx0], "motion_modules")[mm_idx1]
107 | if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
108 | unet.output_blocks[unet_idx].insert(-1, mm_inject)
109 | else:
110 | unet.output_blocks[unet_idx].append(mm_inject)
111 |
112 | self._set_ddim_alpha(sd_model)
113 | self._set_layer_mapping(sd_model)
114 | AnimateDiffMM.mm_injected = True
115 | logger.info(f"Injection finished.")
116 |
117 |
118 | def restore(self, sd_model):
119 | if not AnimateDiffMM.mm_injected:
120 | logger.info("Motion module already removed.")
121 | return
122 |
123 | inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
124 | sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
125 | self._restore_ddim_alpha(sd_model)
126 | unet = sd_model.model.diffusion_model
127 |
128 | logger.info(f"Removing motion module from {sd_ver} UNet input blocks.")
129 | for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]:
130 | if inject_sdxl and unet_idx >= 9:
131 | break
132 | unet.input_blocks[unet_idx].pop(-1)
133 |
134 | logger.info(f"Removing motion module from {sd_ver} UNet output blocks.")
135 | for unet_idx in range(12):
136 | if inject_sdxl and unet_idx >= 9:
137 | break
138 | if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
139 | unet.output_blocks[unet_idx].pop(-2)
140 | else:
141 | unet.output_blocks[unet_idx].pop(-1)
142 |
143 | if self.mm.is_v2:
144 | logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
145 | unet.middle_block.pop(-2)
146 | elif self.mm.enable_gn_hack():
147 | logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
148 | if self.mm.is_hotshot:
149 | from sgm.modules.diffusionmodules.util import GroupNorm32
150 | else:
151 | from ldm.modules.diffusionmodules.util import GroupNorm32
152 | GroupNorm32.forward = self.gn32_original_forward
153 | self.gn32_original_forward = None
154 |
155 | AnimateDiffMM.mm_injected = False
156 | logger.info(f"Removal finished.")
157 | if sd_model.lowvram:
158 | self.unload()
159 |
160 |
161 | def _set_ddim_alpha(self, sd_model):
162 | logger.info(f"Setting DDIM alpha.")
163 | beta_start = 0.00085
164 | beta_end = 0.020 if self.mm.is_adxl else 0.012
165 | if self.mm.is_adxl:
166 | betas = torch.linspace(beta_start**0.5, beta_end**0.5, 1000, dtype=torch.float32, device=device) ** 2
167 | else:
168 | betas = torch.linspace(
169 | beta_start,
170 | beta_end,
171 | 1000 if sd_model.is_sdxl else sd_model.num_timesteps,
172 | dtype=torch.float32,
173 | device=device,
174 | )
175 | alphas = 1.0 - betas
176 | alphas_cumprod = torch.cumprod(alphas, dim=0)
177 | self.prev_alpha_cumprod = sd_model.alphas_cumprod
178 | self.prev_alpha_cumprod_original = sd_model.alphas_cumprod_original
179 | sd_model.alphas_cumprod = alphas_cumprod
180 | sd_model.alphas_cumprod_original = alphas_cumprod
181 |
182 |
183 | def _set_layer_mapping(self, sd_model):
184 | if hasattr(sd_model, 'network_layer_mapping'):
185 | for name, module in self.mm.named_modules():
186 | network_name = name.replace(".", "_")
187 | sd_model.network_layer_mapping[network_name] = module
188 | module.network_layer_name = network_name
189 |
190 |
191 | def _restore_ddim_alpha(self, sd_model):
192 | logger.info(f"Restoring DDIM alpha.")
193 | sd_model.alphas_cumprod = self.prev_alpha_cumprod
194 | sd_model.alphas_cumprod_original = self.prev_alpha_cumprod_original
195 | self.prev_alpha_cumprod = None
196 | self.prev_alpha_cumprod_original = None
197 |
198 |
199 | def unload(self):
200 | logger.info("Moving motion module to CPU")
201 | if self.mm is not None:
202 | self.mm.to(cpu)
203 | torch_gc()
204 | gc.collect()
205 |
206 |
207 | def remove(self):
208 | logger.info("Removing motion module from any memory")
209 | del self.mm
210 | self.mm = None
211 | torch_gc()
212 | gc.collect()
213 |
214 |
215 | mm_animatediff = AnimateDiffMM()
216 |
--------------------------------------------------------------------------------
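
For reference, `_set_ddim_alpha` swaps the checkpoint's noise schedule for the linear-beta schedule the motion modules were trained with (beta from 0.00085 to 0.012 over 1000 steps for SD1.5; the AnimateDiff-SDXL branch instead squares a linspace of the square roots with beta_end 0.020). A standalone sketch of the SD1.5 branch:

import torch

beta_start, beta_end, timesteps = 0.00085, 0.012, 1000
betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float32)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# sd_model.alphas_cumprod (and alphas_cumprod_original) are replaced with this tensor,
# and the previous values are kept so restore() can put them back.
print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())   # ~0.99915 down to a value near 0
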
/scripts/animatediff_output.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import datetime
3 | from pathlib import Path
4 | import traceback
5 |
6 | import imageio.v3 as imageio
7 | import numpy as np
8 | from PIL import Image, PngImagePlugin
9 | import PIL.features
10 | import piexif
11 | from modules import images, shared
12 | from modules.processing import Processed, StableDiffusionProcessing
13 |
14 | from scripts.animatediff_logger import logger_animatediff as logger
15 | from scripts.animatediff_ui import AnimateDiffProcess
16 |
17 |
18 |
19 | class AnimateDiffOutput:
20 | def output(self, p: StableDiffusionProcessing, res: Processed, params: AnimateDiffProcess):
21 | video_paths = []
22 | first_frames = []
23 | from_xyz = any("xyz_grid" in frame.filename for frame in traceback.extract_stack())
24 | logger.info(f"Saving output formats: {', '.join(params.format)}")
25 | date = datetime.datetime.now().strftime('%Y-%m-%d')
26 | output_dir = Path(f"{p.outpath_samples}/AnimateDiff/{date}")
27 | output_dir.mkdir(parents=True, exist_ok=True)
28 | step = params.video_length if params.video_length > params.batch_size else params.batch_size
29 | for i in range(res.index_of_first_image, len(res.images), step):
30 | if i-res.index_of_first_image >= len(res.all_seeds): break
31 | # frame interpolation replaces video_list with interpolated frames
32 | # so make a copy instead of a slice (reference), to avoid modifying res
33 | frame_list = [image.copy() for image in res.images[i : i + params.video_length]]
34 | if from_xyz:
35 | first_frames.append(res.images[i].copy())
36 |
37 | seq = images.get_next_sequence_number(output_dir, "")
38 | filename_suffix = f"-{params.request_id}" if params.request_id else ""
39 | filename = f"{seq:05}-{res.all_seeds[(i-res.index_of_first_image)]}{filename_suffix}"
40 |
41 | video_path_prefix = output_dir / filename
42 |
43 | frame_list = self._add_reverse(params, frame_list)
44 | frame_list = self._interp(p, params, frame_list, filename)
45 | video_paths += self._save(params, frame_list, video_path_prefix, res, i)
46 |
47 | if len(video_paths) == 0:
48 | return
49 |
50 | res.images = video_paths if not p.is_api else (self._encode_video_to_b64(video_paths) + (frame_list if 'Frame' in params.format else []))
51 |
52 | # replace results with first frame of each video so xyz grid draws correctly
53 | if from_xyz:
54 | res.images = first_frames
55 |
56 | if shared.opts.data.get("animatediff_frame_extract_remove", False):
57 | self._remove_frame_extract(params)
58 |
59 |
60 | def _remove_frame_extract(self, params: AnimateDiffProcess):
61 | if params.video_source and params.video_path and Path(params.video_path).exists():
62 | logger.info(f"Removing extracted frames from {params.video_path}")
63 | import shutil
64 | shutil.rmtree(params.video_path)
65 |
66 |
67 | def _add_reverse(self, params: AnimateDiffProcess, frame_list: list):
68 | if params.video_length <= params.batch_size and params.closed_loop in ['A']:
69 | frame_list_reverse = frame_list[::-1]
70 | if len(frame_list_reverse) > 0:
71 | frame_list_reverse.pop(0)
72 | if len(frame_list_reverse) > 0:
73 | frame_list_reverse.pop(-1)
74 | return frame_list + frame_list_reverse
75 | return frame_list
76 |
77 |
78 | def _interp(
79 | self,
80 | p: StableDiffusionProcessing,
81 | params: AnimateDiffProcess,
82 | frame_list: list,
83 | filename: str
84 | ):
85 | if params.interp not in ['FILM']:
86 | return frame_list
87 |
88 | try:
89 | from deforum_helpers.frame_interpolation import (
90 | calculate_frames_to_add, check_and_download_film_model)
91 | from film_interpolation.film_inference import run_film_interp_infer
92 | except ImportError:
93 | logger.error("Deforum not found. Please install: https://github.com/deforum-art/deforum-for-automatic1111-webui.git")
94 | return frame_list
95 |
96 | import glob
97 | import os
98 | import shutil
99 |
100 | import modules.paths as ph
101 |
102 | # load film model
103 | deforum_models_path = ph.models_path + '/Deforum'
104 | film_model_folder = os.path.join(deforum_models_path,'film_interpolation')
105 | film_model_name = 'film_net_fp16.pt'
106 | film_model_path = os.path.join(film_model_folder, film_model_name)
107 |         check_and_download_film_model(film_model_name, film_model_folder)
108 |
109 | film_in_between_frames_count = calculate_frames_to_add(len(frame_list), params.interp_x)
110 |
111 | # save original frames to tmp folder for deforum input
112 | tmp_folder = f"{p.outpath_samples}/AnimateDiff/tmp"
113 | input_folder = f"{tmp_folder}/input"
114 | os.makedirs(input_folder, exist_ok=True)
115 | for tmp_seq, frame in enumerate(frame_list):
116 | imageio.imwrite(f"{input_folder}/{tmp_seq:05}.png", frame)
117 |
118 | # deforum saves output frames to tmp/{filename}
119 | save_folder = f"{tmp_folder}/{filename}"
120 | os.makedirs(save_folder, exist_ok=True)
121 |
122 | run_film_interp_infer(
123 | model_path = film_model_path,
124 | input_folder = input_folder,
125 | save_folder = save_folder,
126 | inter_frames = film_in_between_frames_count)
127 |
128 | # load deforum output frames and replace video_list
129 | interp_frame_paths = sorted(glob.glob(os.path.join(save_folder, '*.png')))
130 | frame_list = []
131 | for f in interp_frame_paths:
132 | with Image.open(f) as img:
133 | img.load()
134 | frame_list.append(img)
135 |
136 | # if saving PNG, enforce saving to custom folder
137 | if "PNG" in params.format:
138 | params.force_save_to_custom = True
139 |
140 | # remove tmp folder
141 | try: shutil.rmtree(tmp_folder)
142 |         except OSError as e: logger.warning(f"Failed to remove temporary folder {tmp_folder}: {e}")
143 |
144 | return frame_list
145 |
146 |
147 | def _save(
148 | self,
149 | params: AnimateDiffProcess,
150 | frame_list: list,
151 | video_path_prefix: Path,
152 | res: Processed,
153 | index: int,
154 | ):
155 | video_paths = []
156 | video_array = [np.array(v) for v in frame_list]
157 | infotext = res.infotexts[index]
158 |         s3_enable = shared.opts.data.get("animatediff_s3_enable", False)
159 | use_infotext = shared.opts.enable_pnginfo and infotext is not None
160 | if "PNG" in params.format and (shared.opts.data.get("animatediff_save_to_custom", True) or getattr(params, "force_save_to_custom", False)):
161 | video_path_prefix.mkdir(exist_ok=True, parents=True)
162 | for i, frame in enumerate(frame_list):
163 | png_filename = video_path_prefix/f"{i:05}.png"
164 | png_info = PngImagePlugin.PngInfo()
165 | png_info.add_text('parameters', infotext)
166 | imageio.imwrite(png_filename, frame, pnginfo=png_info)
167 |
168 | if "GIF" in params.format:
169 | video_path_gif = str(video_path_prefix) + ".gif"
170 | video_paths.append(video_path_gif)
171 | if shared.opts.data.get("animatediff_optimize_gif_palette", False):
172 | try:
173 | import av
174 | except ImportError:
175 | from launch import run_pip
176 | run_pip(
177 | "install imageio[pyav]",
178 | "sd-webui-animatediff GIF palette optimization requirement: imageio[pyav]",
179 | )
180 | imageio.imwrite(
181 | video_path_gif, video_array, plugin='pyav', fps=params.fps,
182 | codec='gif', out_pixel_format='pal8',
183 | filter_graph=(
184 | {
185 | "split": ("split", ""),
186 | "palgen": ("palettegen", ""),
187 | "paluse": ("paletteuse", ""),
188 | "scale": ("scale", f"{frame_list[0].width}:{frame_list[0].height}")
189 | },
190 | [
191 | ("video_in", "scale", 0, 0),
192 | ("scale", "split", 0, 0),
193 | ("split", "palgen", 1, 0),
194 | ("split", "paluse", 0, 0),
195 | ("palgen", "paluse", 0, 1),
196 | ("paluse", "video_out", 0, 0),
197 | ]
198 | )
199 | )
200 | # imageio[pyav].imwrite doesn't support comment parameter
201 | if use_infotext:
202 | try:
203 | import exiftool
204 | except ImportError:
205 | from launch import run_pip
206 | run_pip(
207 | "install PyExifTool",
208 | "sd-webui-animatediff GIF palette optimization requirement: PyExifTool",
209 | )
210 | import exiftool
211 | finally:
212 | try:
213 | exif_tool = exiftool.ExifTool()
214 | with exif_tool:
215 | escaped_infotext = infotext.replace('\n', r'\n')
216 | exif_tool.execute("-overwrite_original", f"-Comment={escaped_infotext}", video_path_gif)
217 | except FileNotFoundError:
218 | logger.warn(
219 | "exiftool not found, required for infotext with optimized GIF palette, try: apt install libimage-exiftool-perl or https://exiftool.org/"
220 | )
221 | else:
222 | imageio.imwrite(
223 | video_path_gif,
224 | video_array,
225 | plugin='pillow',
226 | duration=(1000 / params.fps),
227 | loop=params.loop_number,
228 | comment=(infotext if use_infotext else "")
229 | )
230 | if shared.opts.data.get("animatediff_optimize_gif_gifsicle", False):
231 | self._optimize_gif(video_path_gif)
232 |
233 | if "MP4" in params.format:
234 | video_path_mp4 = str(video_path_prefix) + ".mp4"
235 | video_paths.append(video_path_mp4)
236 | try:
237 | import av
238 | except ImportError:
239 | from launch import run_pip
240 | run_pip(
241 | "install imageio[pyav]",
242 | "sd-webui-animatediff MP4 save requirement: imageio[pyav]",
243 | )
244 | import av
245 | options = {
246 | "crf": str(shared.opts.data.get("animatediff_mp4_crf", 23))
247 | }
248 | preset = shared.opts.data.get("animatediff_mp4_preset", "")
249 | if preset != "": options["preset"] = preset
250 | tune = shared.opts.data.get("animatediff_mp4_tune", "")
251 | if tune != "": options["tune"] = tune
252 | output = av.open(video_path_mp4, "w")
253 | logger.info(f"Saving {video_path_mp4}")
254 | if use_infotext:
255 | output.metadata["Comment"] = infotext
256 | stream = output.add_stream('libx264', params.fps, options=options)
257 | stream.width = frame_list[0].width
258 | stream.height = frame_list[0].height
259 | for img in video_array:
260 | frame = av.VideoFrame.from_ndarray(img)
261 | packet = stream.encode(frame)
262 | output.mux(packet)
263 | packet = stream.encode(None)
264 | output.mux(packet)
265 | output.close()
266 |
267 | if "TXT" in params.format and res.images[index].info is not None:
268 | video_path_txt = str(video_path_prefix) + ".txt"
269 | with open(video_path_txt, "w", encoding="utf8") as file:
270 | file.write(f"{infotext}\n")
271 |
272 | if "WEBP" in params.format:
273 | if PIL.features.check('webp_anim'):
274 | video_path_webp = str(video_path_prefix) + ".webp"
275 | video_paths.append(video_path_webp)
276 | exif_bytes = b''
277 | if use_infotext:
278 | exif_bytes = piexif.dump({
279 | "Exif":{
280 | piexif.ExifIFD.UserComment:piexif.helper.UserComment.dump(infotext, encoding="unicode")
281 | }})
282 | lossless = shared.opts.data.get("animatediff_webp_lossless", False)
283 | quality = shared.opts.data.get("animatediff_webp_quality", 80)
284 | logger.info(f"Saving {video_path_webp} with lossless={lossless} and quality={quality}")
285 | imageio.imwrite(video_path_webp, video_array, plugin='pillow',
286 | duration=int(1 / params.fps * 1000), loop=params.loop_number,
287 | lossless=lossless, quality=quality, exif=exif_bytes
288 | )
289 | # see additional Pillow WebP options at https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#webp
290 | else:
291 | logger.warn("WebP animation in Pillow requires system WebP library v0.5.0 or later")
292 | if "WEBM" in params.format:
293 | video_path_webm = str(video_path_prefix) + ".webm"
294 | video_paths.append(video_path_webm)
295 | logger.info(f"Saving {video_path_webm}")
296 | with imageio.imopen(video_path_webm, "w", plugin="pyav") as file:
297 | if use_infotext:
298 | file.container_metadata["Title"] = infotext
299 | file.container_metadata["Comment"] = infotext
300 | file.write(video_array, codec="vp9", fps=params.fps)
301 |
302 | if s3_enable:
303 |             for video_path in video_paths: self._save_to_s3_storage(video_path)
304 | return video_paths
305 |
306 |
307 | def _optimize_gif(self, video_path: str):
308 | try:
309 | import pygifsicle
310 | except ImportError:
311 | from launch import run_pip
312 |
313 | run_pip(
314 | "install pygifsicle",
315 | "sd-webui-animatediff GIF optimization requirement: pygifsicle",
316 | )
317 | import pygifsicle
318 | finally:
319 | try:
320 | pygifsicle.optimize(video_path)
321 | except FileNotFoundError:
322 | logger.warn("gifsicle not found, required for optimized GIFs, try: apt install gifsicle")
323 |
324 |
325 | def _encode_video_to_b64(self, paths):
326 | videos = []
327 | for v_path in paths:
328 | with open(v_path, "rb") as video_file:
329 | videos.append(base64.b64encode(video_file.read()).decode("utf-8"))
330 | return videos
331 |
332 |
333 |     def _install_requirement_if_absent(self, lib):
334 | import launch
335 | if not launch.is_installed(lib):
336 | launch.run_pip(f"install {lib}", f"animatediff requirement: {lib}")
337 |
338 |
339 |     def _exist_bucket(self, s3_client, bucketname):
340 | try:
341 | s3_client.head_bucket(Bucket=bucketname)
342 | return True
343 |         except s3_client.exceptions.ClientError as e:  # botocore ClientError is not imported at this scope
344 | if e.response['Error']['Code'] == '404':
345 | return False
346 | else:
347 | raise
348 |
349 |
350 |     def _save_to_s3_storage(self, file_path):
351 |         """
352 |         Upload a local file to S3-compatible object storage.
353 |
354 |         :type file_path: str
355 |         :param file_path: path of the local file to upload; host, port, credentials and
356 |             bucket are read from the AnimateDiff AWS settings section
357 |         """
358 | self._install_requirement_if_absent('boto3')
359 | import boto3
360 | from botocore.exceptions import ClientError
361 | import os
362 | host = shared.opts.data.get("animatediff_s3_host", '127.0.0.1')
363 | port = shared.opts.data.get("animatediff_s3_port", '9001')
364 | access_key = shared.opts.data.get("animatediff_s3_access_key", '')
365 | secret_key = shared.opts.data.get("animatediff_s3_secret_key", '')
366 | bucket = shared.opts.data.get("animatediff_s3_storge_bucket", '')
367 | client = boto3.client(
368 | service_name='s3',
369 | aws_access_key_id = access_key,
370 | aws_secret_access_key = secret_key,
371 | endpoint_url=f'http://{host}:{port}',
372 | )
373 |
374 | if not os.path.exists(file_path): return
375 | date = datetime.datetime.now().strftime('%Y-%m-%d')
376 | if not self._exist_bucket(client,bucket):
377 | client.create_bucket(Bucket=bucket)
378 |
379 | filename = os.path.split(file_path)[1]
380 | targetpath = f"{date}/{filename}"
381 | client.upload_file(file_path, bucket, targetpath)
382 | logger.info(f"{file_path} saved to s3 in bucket: {bucket}")
383 | return f"http://{host}:{port}/{bucket}/{targetpath}"
384 |
--------------------------------------------------------------------------------
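
The plain (non-palette-optimized) GIF branch above boils down to a single imageio call with Pillow as the backend. A standalone sketch with dummy frames; the file name and fps are made up for the demo:

import numpy as np
import imageio.v3 as imageio

fps, loop_number = 8, 0
frames = [np.full((64, 64, 3), i * 16, dtype=np.uint8) for i in range(16)]   # dummy gray ramp
imageio.imwrite(
    "demo.gif", frames, plugin='pillow',
    duration=(1000 / fps),   # per-frame duration in milliseconds
    loop=loop_number,        # 0 = loop forever
)
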
/scripts/animatediff_prompt.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 |
4 | from modules.processing import StableDiffusionProcessing, Processed
5 |
6 | from scripts.animatediff_logger import logger_animatediff as logger
7 | from scripts.animatediff_infotext import write_params_txt
8 | from scripts.animatediff_ui import AnimateDiffProcess
9 |
10 | class AnimateDiffPromptSchedule:
11 |
12 | def __init__(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
13 | self.prompt_map = None
14 | self.original_prompt = None
15 | self.parse_prompt(p, params)
16 |
17 |
18 | def save_infotext_img(self, p: StableDiffusionProcessing):
19 | if self.prompt_map is not None:
20 | p.prompts = [self.original_prompt for _ in range(p.batch_size)]
21 |
22 |
23 | def save_infotext_txt(self, res: Processed):
24 | if self.prompt_map is not None:
25 | parts = res.info.split('\nNegative prompt: ', 1)
26 | if len(parts) > 1:
27 | res.info = f"{self.original_prompt}\nNegative prompt: {parts[1]}"
28 | for i in range(len(res.infotexts)):
29 | parts = res.infotexts[i].split('\nNegative prompt: ', 1)
30 | if len(parts) > 1:
31 | res.infotexts[i] = f"{self.original_prompt}\nNegative prompt: {parts[1]}"
32 | write_params_txt(res.info)
33 |
34 |
35 | def parse_prompt(self, p: StableDiffusionProcessing, params: AnimateDiffProcess):
36 | if type(p.prompt) is not str:
37 |             logger.warning("Prompt is not a string, so prompt travel is not supported.")
38 | return
39 |
40 | lines = p.prompt.strip().split('\n')
41 | data = {
42 | 'head_prompts': [],
43 | 'mapp_prompts': {},
44 | 'tail_prompts': []
45 | }
46 |
47 | mode = 'head'
48 | for line in lines:
49 | if mode == 'head':
50 | if re.match(r'^\d+:', line):
51 | mode = 'mapp'
52 | else:
53 | data['head_prompts'].append(line)
54 |
55 | if mode == 'mapp':
56 | match = re.match(r'^(\d+): (.+)$', line)
57 | if match:
58 | frame, prompt = match.groups()
59 | assert int(frame) < params.video_length, \
60 | f"invalid prompt travel frame number: {int(frame)} >= number of frames ({params.video_length})"
61 | data['mapp_prompts'][int(frame)] = prompt
62 | else:
63 | mode = 'tail'
64 |
65 | if mode == 'tail':
66 | data['tail_prompts'].append(line)
67 |
68 | if data['mapp_prompts']:
69 | logger.info("You are using prompt travel.")
70 | self.prompt_map = {}
71 | prompt_list = []
72 | last_frame = 0
73 | current_prompt = ''
74 | for frame, prompt in data['mapp_prompts'].items():
75 | prompt_list += [current_prompt for _ in range(last_frame, frame)]
76 | last_frame = frame
77 | current_prompt = f"{', '.join(data['head_prompts'])}, {prompt}, {', '.join(data['tail_prompts'])}"
78 | self.prompt_map[frame] = current_prompt
79 | prompt_list += [current_prompt for _ in range(last_frame, p.batch_size)]
80 | assert len(prompt_list) == p.batch_size, f"prompt_list length {len(prompt_list)} != batch_size {p.batch_size}"
81 | self.original_prompt = p.prompt
82 | p.prompt = prompt_list * p.n_iter
83 |
84 |
85 | def single_cond(self, center_frame, video_length: int, cond: torch.Tensor, closed_loop = False):
86 | if closed_loop:
87 | key_prev = list(self.prompt_map.keys())[-1]
88 | key_next = list(self.prompt_map.keys())[0]
89 | else:
90 | key_prev = list(self.prompt_map.keys())[0]
91 | key_next = list(self.prompt_map.keys())[-1]
92 |
93 | for p in self.prompt_map.keys():
94 | if p > center_frame:
95 | key_next = p
96 | break
97 | key_prev = p
98 |
99 | dist_prev = center_frame - key_prev
100 | if dist_prev < 0:
101 | dist_prev += video_length
102 | dist_next = key_next - center_frame
103 | if dist_next < 0:
104 | dist_next += video_length
105 |
106 | if key_prev == key_next or dist_prev + dist_next == 0:
107 | return cond[key_prev] if isinstance(cond, torch.Tensor) else {k: v[key_prev] for k, v in cond.items()}
108 |
109 | rate = dist_prev / (dist_prev + dist_next)
110 | if isinstance(cond, torch.Tensor):
111 | return AnimateDiffPromptSchedule.slerp(cond[key_prev], cond[key_next], rate)
112 | else: # isinstance(cond, dict)
113 | return {
114 | k: AnimateDiffPromptSchedule.slerp(v[key_prev], v[key_next], rate)
115 | for k, v in cond.items()
116 | }
117 |
118 |
119 | def multi_cond(self, cond: torch.Tensor, closed_loop = False):
120 | if self.prompt_map is None:
121 | return cond
122 | cond_list = [] if isinstance(cond, torch.Tensor) else {k: [] for k in cond.keys()}
123 | for i in range(cond.shape[0]):
124 | single_cond = self.single_cond(i, cond.shape[0], cond, closed_loop)
125 | if isinstance(cond, torch.Tensor):
126 | cond_list.append(single_cond)
127 | else:
128 | for k, v in single_cond.items():
129 | cond_list[k].append(v)
130 | if isinstance(cond, torch.Tensor):
131 | return torch.stack(cond_list).to(cond.dtype).to(cond.device)
132 | else:
133 | from modules.prompt_parser import DictWithShape
134 | return DictWithShape({k: torch.stack(v).to(cond[k].dtype).to(cond[k].device) for k, v in cond_list.items()}, None)
135 |
136 |
137 | @staticmethod
138 | def slerp(
139 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
140 | ) -> torch.Tensor:
141 | u0 = v0 / v0.norm()
142 | u1 = v1 / v1.norm()
143 | dot = (u0 * u1).sum()
144 | if dot.abs() > DOT_THRESHOLD:
145 | return (1.0 - t) * v0 + t * v1
146 | omega = dot.acos()
147 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
148 |
--------------------------------------------------------------------------------
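
A prompt-travel prompt consists of optional head lines, `frame: prompt` lines, and optional tail lines; between two keyframes the conditionings are slerped, with the rate equal to the relative distance from the previous key. A standalone sketch of the interpolation (the example prompt and tensor shapes are illustrative only):

# Example prompt accepted by parse_prompt above:
#   masterpiece, best quality      <- head, prepended to every frame
#   0: sunrise over the sea
#   8: sunset over the sea
#   closeup                        <- tail, appended to every frame
import torch

def slerp(v0, v1, t, DOT_THRESHOLD=0.9995):
    u0, u1 = v0 / v0.norm(), v1 / v1.norm()
    dot = (u0 * u1).sum()
    if dot.abs() > DOT_THRESHOLD:   # nearly parallel: plain lerp is numerically safer
        return (1.0 - t) * v0 + t * v1
    omega = dot.acos()
    return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()

cond_frame0, cond_frame8 = torch.randn(77, 768), torch.randn(77, 768)
cond_frame4 = slerp(cond_frame0, cond_frame8, t=0.5)   # halfway between the two keyframes
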
/scripts/animatediff_settings.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from modules import shared
4 | from scripts.animatediff_ui import supported_save_formats
5 |
6 |
7 | def on_ui_settings():
8 | section = ("animatediff", "AnimateDiff")
9 |     s3_selection = ("animatediff", "AnimateDiff AWS")
10 | shared.opts.add_option(
11 | "animatediff_model_path",
12 | shared.OptionInfo(
13 | None,
14 | "Path to save AnimateDiff motion modules",
15 | gr.Textbox,
16 | {"placeholder": "Leave empty to use default path: extensions/sd-webui-animatediff/model"},
17 | section=section,
18 | ),
19 | )
20 | shared.opts.add_option(
21 | "animatediff_default_save_formats",
22 | shared.OptionInfo(
23 | ["GIF", "PNG"],
24 | "Default Save Formats",
25 | gr.CheckboxGroup,
26 | {"choices": supported_save_formats},
27 | section=section
28 | ).needs_restart()
29 | )
30 | shared.opts.add_option(
31 | "animatediff_save_to_custom",
32 | shared.OptionInfo(
33 | True,
34 |             "Save frames to stable-diffusion-webui/outputs/{ txt|img }2img-images/AnimateDiff/{date}/{gif filename}/ "
35 | "instead of stable-diffusion-webui/outputs/{ txt|img }2img-images/{date}/.",
36 | gr.Checkbox,
37 | section=section
38 | )
39 | )
40 | shared.opts.add_option(
41 | "animatediff_frame_extract_path",
42 | shared.OptionInfo(
43 | None,
44 | "Path to save extracted frames",
45 | gr.Textbox,
46 | {"placeholder": "Leave empty to use default path: tmp/animatediff-frames"},
47 | section=section
48 | )
49 | )
50 | shared.opts.add_option(
51 | "animatediff_frame_extract_remove",
52 | shared.OptionInfo(
53 | False,
54 | "Always remove extracted frames after processing",
55 | gr.Checkbox,
56 | section=section
57 | )
58 | )
59 | shared.opts.add_option(
60 | "animatediff_default_frame_extract_method",
61 | shared.OptionInfo(
62 | "ffmpeg",
63 | "Default frame extraction method",
64 | gr.Radio,
65 | {"choices": ["ffmpeg", "opencv"]},
66 | section=section
67 | )
68 | )
69 |
70 | # traditional video optimization specification
71 | shared.opts.add_option(
72 | "animatediff_optimize_gif_palette",
73 | shared.OptionInfo(
74 | False,
75 | "Calculate the optimal GIF palette, improves quality significantly, removes banding",
76 | gr.Checkbox,
77 | section=section
78 | )
79 | )
80 | shared.opts.add_option(
81 | "animatediff_optimize_gif_gifsicle",
82 | shared.OptionInfo(
83 | False,
84 | "Optimize GIFs with gifsicle, reduces file size",
85 | gr.Checkbox,
86 | section=section
87 | )
88 | )
89 | shared.opts.add_option(
90 | key="animatediff_mp4_crf",
91 | info=shared.OptionInfo(
92 | default=23,
93 | label="MP4 Quality (CRF)",
94 | component=gr.Slider,
95 | component_args={
96 | "minimum": 0,
97 | "maximum": 51,
98 | "step": 1},
99 | section=section
100 | )
101 | .link("docs", "https://trac.ffmpeg.org/wiki/Encode/H.264#crf")
102 | .info("17 for best quality, up to 28 for smaller size")
103 | )
104 | shared.opts.add_option(
105 | key="animatediff_mp4_preset",
106 | info=shared.OptionInfo(
107 | default="",
108 | label="MP4 Encoding Preset",
109 | component=gr.Dropdown,
110 | component_args={"choices": ["", 'veryslow', 'slower', 'slow', 'medium', 'fast', 'faster', 'veryfast', 'superfast', 'ultrafast']},
111 | section=section,
112 | )
113 | .link("docs", "https://trac.ffmpeg.org/wiki/Encode/H.264#Preset")
114 | .info("encoding speed, use the slowest you can tolerate")
115 | )
116 | shared.opts.add_option(
117 | key="animatediff_mp4_tune",
118 | info=shared.OptionInfo(
119 | default="",
120 | label="MP4 Tune encoding for content type",
121 | component=gr.Dropdown,
122 | component_args={"choices": ["", "film", "animation", "grain"]},
123 | section=section
124 | )
125 | .link("docs", "https://trac.ffmpeg.org/wiki/Encode/H.264#Tune")
126 | .info("optimize for specific content types")
127 | )
128 | shared.opts.add_option(
129 | "animatediff_webp_quality",
130 | shared.OptionInfo(
131 | 80,
132 |             "WebP quality (for lossy WebP; when lossless is enabled, higher values only increase compression effort and CPU usage)",
133 | gr.Slider,
134 | {
135 | "minimum": 1,
136 | "maximum": 100,
137 | "step": 1},
138 | section=section
139 | )
140 | )
141 | shared.opts.add_option(
142 | "animatediff_webp_lossless",
143 | shared.OptionInfo(
144 | False,
145 | "Save WebP in lossless format (highest quality, largest file size)",
146 | gr.Checkbox,
147 | section=section
148 | )
149 | )
150 |
151 | # s3 storage specification, most likely for some startup
152 | shared.opts.add_option(
153 | "animatediff_s3_enable",
154 | shared.OptionInfo(
155 | False,
156 |             "Store output files in object storage that supports the S3 protocol",
157 | gr.Checkbox,
158 | section=s3_selection
159 | )
160 | )
161 | shared.opts.add_option(
162 | "animatediff_s3_host",
163 | shared.OptionInfo(
164 | None,
165 | "S3 protocol host",
166 | gr.Textbox,
167 | section=s3_selection,
168 | ),
169 | )
170 | shared.opts.add_option(
171 | "animatediff_s3_port",
172 | shared.OptionInfo(
173 | None,
174 | "S3 protocol port",
175 | gr.Textbox,
176 | section=s3_selection,
177 | ),
178 | )
179 | shared.opts.add_option(
180 | "animatediff_s3_access_key",
181 | shared.OptionInfo(
182 | None,
183 | "S3 protocol access_key",
184 | gr.Textbox,
185 | section=s3_selection,
186 | ),
187 | )
188 | shared.opts.add_option(
189 | "animatediff_s3_secret_key",
190 | shared.OptionInfo(
191 | None,
192 | "S3 protocol secret_key",
193 | gr.Textbox,
194 | section=s3_selection,
195 | ),
196 | )
197 | shared.opts.add_option(
198 | "animatediff_s3_storge_bucket",
199 | shared.OptionInfo(
200 | None,
201 | "Bucket for file storage",
202 | gr.Textbox,
203 | section=s3_selection,
204 | ),
205 | )
--------------------------------------------------------------------------------
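
These options are read back elsewhere in the extension via `shared.opts.data.get`, with the same defaults used at registration. For example, the MP4 encoder options are assembled as below (mirroring animatediff_output.py; this snippet only runs inside the webui process):

from modules import shared

mp4_options = {"crf": str(shared.opts.data.get("animatediff_mp4_crf", 23))}
preset = shared.opts.data.get("animatediff_mp4_preset", "")
if preset != "":
    mp4_options["preset"] = preset
tune = shared.opts.data.get("animatediff_mp4_tune", "")
if tune != "":
    mp4_options["tune"] = tune
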
/scripts/animatediff_ui.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import os
4 | import cv2
5 | import subprocess
6 | import gradio as gr
7 |
8 | from modules import shared
9 | from modules.launch_utils import git
10 | from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingImg2Img
11 |
12 | from scripts.animatediff_mm import mm_animatediff as motion_module
13 | from scripts.animatediff_xyz import xyz_attrs
14 | from scripts.animatediff_logger import logger_animatediff as logger
15 | from scripts.animatediff_utils import get_controlnet_units, extract_frames_from_video
16 |
17 | supported_save_formats = ["GIF", "MP4", "WEBP", "WEBM", "PNG", "TXT"]
18 |
19 | class ToolButton(gr.Button, gr.components.FormComponent):
20 | """Small button with single emoji as text, fits inside gradio forms"""
21 |
22 | def __init__(self, **kwargs):
23 | super().__init__(variant="tool", **kwargs)
24 |
25 |
26 | def get_block_name(self):
27 | return "button"
28 |
29 |
30 | class AnimateDiffProcess:
31 |
32 | def __init__(
33 | self,
34 | model="mm_sd15_v3.safetensors",
35 | enable=False,
36 | video_length=0,
37 | fps=8,
38 | loop_number=0,
39 | closed_loop='R-P',
40 | batch_size=16,
41 | stride=1,
42 | overlap=-1,
43 | format=shared.opts.data.get("animatediff_default_save_formats", ["GIF", "PNG"]),
44 | interp='Off',
45 | interp_x=10,
46 | video_source=None,
47 | video_path='',
48 | mask_path='',
49 | freeinit_enable=False,
50 | freeinit_filter="butterworth",
51 | freeinit_ds=0.25,
52 | freeinit_dt=0.25,
53 | freeinit_iters=3,
54 | latent_power=1,
55 | latent_scale=32,
56 | last_frame=None,
57 | latent_power_last=1,
58 | latent_scale_last=32,
59 | request_id = '',
60 | is_i2i_batch=False,
61 | video_default=False,
62 | prompt_scheduler=None,
63 | ):
64 | self.model = model
65 | self.enable = enable
66 | self.video_length = video_length
67 | self.fps = fps
68 | self.loop_number = loop_number
69 | self.closed_loop = closed_loop
70 | self.batch_size = batch_size
71 | self.stride = stride
72 | self.overlap = overlap
73 | self.format = format
74 | self.interp = interp
75 | self.interp_x = interp_x
76 | self.video_source = video_source
77 | self.video_path = video_path
78 | self.mask_path = mask_path
79 | self.freeinit_enable = freeinit_enable
80 | self.freeinit_filter = freeinit_filter
81 | self.freeinit_ds = freeinit_ds
82 | self.freeinit_dt = freeinit_dt
83 | self.freeinit_iters = freeinit_iters
84 | self.latent_power = latent_power
85 | self.latent_scale = latent_scale
86 | self.last_frame = last_frame
87 | self.latent_power_last = latent_power_last
88 | self.latent_scale_last = latent_scale_last
89 |
90 | # non-ui states
91 | self.request_id = request_id
92 | self.video_default = video_default
93 | self.is_i2i_batch = is_i2i_batch
94 | self.prompt_scheduler = prompt_scheduler
95 |
96 |
97 | def get_list(self, is_img2img: bool):
98 | return list(vars(self).values())[:(25 if is_img2img else 20)]
99 |
100 |
101 | def get_dict(self, is_img2img: bool):
102 | infotext = {
103 | "model": self.model,
104 | "video_length": self.video_length,
105 | "fps": self.fps,
106 | "loop_number": self.loop_number,
107 | "closed_loop": self.closed_loop,
108 | "batch_size": self.batch_size,
109 | "stride": self.stride,
110 | "overlap": self.overlap,
111 | "interp": self.interp,
112 | "interp_x": self.interp_x,
113 | "freeinit_enable": self.freeinit_enable,
114 | }
115 | if self.request_id:
116 | infotext['request_id'] = self.request_id
117 | if motion_module.mm is not None and motion_module.mm.mm_hash is not None:
118 | infotext['mm_hash'] = motion_module.mm.mm_hash[:8]
119 | if is_img2img:
120 | infotext.update({
121 | "latent_power": self.latent_power,
122 | "latent_scale": self.latent_scale,
123 | "latent_power_last": self.latent_power_last,
124 | "latent_scale_last": self.latent_scale_last,
125 | })
126 |
127 | try:
128 | ad_git_tag = subprocess.check_output(
129 |                 [git, "-C", motion_module.script_dir, "describe", "--tags"],
130 | shell=False, encoding='utf8').strip()
131 | infotext['version'] = ad_git_tag
132 | except Exception as e:
133 | logger.warning(f"Failed to get git tag for AnimateDiff: {e}")
134 |
135 | infotext_str = ', '.join(f"{k}: {v}" for k, v in infotext.items())
136 | return infotext_str
137 |
138 |
139 | def get_param_names(self, is_img2img: bool):
140 | preserve = ["model", "enable", "video_length", "fps", "loop_number", "closed_loop", "batch_size", "stride", "overlap", "format", "interp", "interp_x"]
141 | if is_img2img:
142 | preserve.extend(["latent_power", "latent_power_last", "latent_scale", "latent_scale_last"])
143 |
144 | return preserve
145 |
146 |
147 | def _check(self):
148 | assert (
149 | self.video_length >= 0 and self.fps > 0
150 |         ), "Video length must be non-negative and FPS must be positive."
151 | assert not set(supported_save_formats[:-1]).isdisjoint(
152 | self.format
153 | ), "At least one saving format should be selected."
154 |
155 |
156 | def apply_xyz(self):
157 | for k, v in xyz_attrs.items():
158 | setattr(self, k, v)
159 |
160 |
161 | def set_p(self, p: StableDiffusionProcessing):
162 | self._check()
163 | if self.video_length < self.batch_size:
164 | p.batch_size = self.batch_size
165 | else:
166 | p.batch_size = self.video_length
167 | if self.video_length == 0:
168 | self.video_length = p.batch_size
169 | self.video_default = True
170 | if self.overlap == -1:
171 | self.overlap = self.batch_size // 4
172 | if "PNG" not in self.format or shared.opts.data.get("animatediff_save_to_custom", True):
173 | p.do_not_save_samples = True
174 |
175 | cn_units = get_controlnet_units(p)
176 | min_batch_in_cn = -1
177 | for cn_unit in cn_units:
178 | if not cn_unit.enabled:
179 | continue
180 |
181 | # batch path broadcast
182 | if (cn_unit.input_mode.name == 'SIMPLE' and cn_unit.image is None) or \
183 | (cn_unit.input_mode.name == 'BATCH' and not cn_unit.batch_images) or \
184 | (cn_unit.input_mode.name == 'MERGE' and not cn_unit.batch_input_gallery):
185 | if not self.video_path:
186 | extract_frames_from_video(self)
187 | cn_unit.input_mode = cn_unit.input_mode.__class__.BATCH
188 | cn_unit.batch_images = self.video_path
189 |
190 | # mask path broadcast
191 | if cn_unit.input_mode.name == 'BATCH' and self.mask_path and not getattr(cn_unit, 'batch_mask_dir', False):
192 | cn_unit.batch_mask_dir = self.mask_path
193 |
194 |             # find the minimum number of control images among CN batch units
195 | cn_unit_batch_params = cn_unit.batch_images.split('\n')
196 | if cn_unit.input_mode.name == 'BATCH':
197 | cn_unit.animatediff_batch = True # for A1111 sd-webui-controlnet
198 | if not any([cn_param.startswith("keyframe:") for cn_param in cn_unit_batch_params[1:]]):
199 | cn_unit_batch_num = len(shared.listfiles(cn_unit_batch_params[0]))
200 | if min_batch_in_cn == -1 or cn_unit_batch_num < min_batch_in_cn:
201 | min_batch_in_cn = cn_unit_batch_num
202 |
203 | if min_batch_in_cn != -1:
204 | self.fix_video_length(p, min_batch_in_cn)
205 |             def cn_batch_modifier(batch_image_files: List[str], p: StableDiffusionProcessing):
206 | return batch_image_files[:self.video_length]
207 | for cn_unit in cn_units:
208 | if cn_unit.input_mode.name == 'BATCH':
209 | cur_batch_modifier = getattr(cn_unit, "batch_modifiers", [])
210 |                     cur_batch_modifier.append(cn_batch_modifier)
211 | cn_unit.batch_modifiers = cur_batch_modifier
212 | self.post_setup_cn_for_i2i_batch(p)
213 | logger.info(f"AnimateDiff + ControlNet will generate {self.video_length} frames.")
214 |
215 |
216 | def fix_video_length(self, p: StableDiffusionProcessing, min_batch_in_cn: int):
217 |         # ensure that video_length and batch_size do not exceed the number of control images in the smallest CN batch
218 | if self.video_length > min_batch_in_cn:
219 | self.video_length = min_batch_in_cn
220 | p.batch_size = min_batch_in_cn
221 | if self.batch_size > min_batch_in_cn:
222 | self.batch_size = min_batch_in_cn
223 | if self.video_default:
224 | self.video_length = min_batch_in_cn
225 | p.batch_size = min_batch_in_cn
226 |
227 |
228 | def post_setup_cn_for_i2i_batch(self, p: StableDiffusionProcessing):
229 | if not (self.is_i2i_batch and isinstance(p, StableDiffusionProcessingImg2Img)):
230 | return
231 |
232 | if len(p.init_images) > self.video_length:
233 | p.init_images = p.init_images[:self.video_length]
234 | if p.image_mask and isinstance(p.image_mask, list) and len(p.image_mask) > self.video_length:
235 | p.image_mask = p.image_mask[:self.video_length]
236 | if len(p.init_images) < self.video_length:
237 | self.video_length = len(p.init_images)
238 | p.batch_size = len(p.init_images)
239 | if len(p.init_images) < self.batch_size:
240 | self.batch_size = len(p.init_images)
241 |
242 |
243 | class AnimateDiffUiGroup:
244 | txt2img_submit_button = None
245 | img2img_submit_button = None
246 | setting_sd_model_checkpoint = None
247 | animatediff_ui_group = []
248 |
249 | def __init__(self):
250 | self.params = AnimateDiffProcess()
251 | AnimateDiffUiGroup.animatediff_ui_group.append(self)
252 |
253 | # Free-init
254 | self.filter_type_list = [
255 | "butterworth",
256 | "gaussian",
257 | "box",
258 | "ideal"
259 | ]
260 |
261 |
262 | def get_model_list(self):
263 | model_dir = motion_module.get_model_dir()
264 | if not os.path.isdir(model_dir):
265 | os.makedirs(model_dir, exist_ok=True)
266 | def get_sd_rm_tag():
267 | if shared.sd_model.is_sdxl:
268 | return ["sd1"]
269 | elif shared.sd_model.is_sd2:
270 | return ["sd1", "xl"]
271 | elif shared.sd_model.is_sd1:
272 | return ["xl"]
273 | else:
274 | return []
275 | return sorted([
276 | os.path.relpath(os.path.join(root, filename), model_dir)
277 | for root, dirs, filenames in os.walk(model_dir)
278 | for filename in filenames
279 | if filename != ".gitkeep" and not any(tag in filename for tag in get_sd_rm_tag())
280 | ])
281 |
282 | def refresh_models(self, *inputs):
283 | new_model_list = self.get_model_list()
284 | dd = inputs[0]
285 | if dd in new_model_list:
286 | selected = dd
287 | elif len(new_model_list) > 0:
288 | selected = new_model_list[0]
289 | else:
290 | selected = None
291 | return gr.Dropdown.update(choices=new_model_list, value=selected)
292 |
293 |
294 | def render(self, is_img2img: bool, infotext_fields, paste_field_names):
295 | elemid_prefix = "img2img-ad-" if is_img2img else "txt2img-ad-"
296 | with gr.Accordion("AnimateDiff", open=False):
297 | gr.Markdown(value="Please click [this link](https://github.com/continue-revolution/sd-webui-animatediff/blob/master/docs/how-to-use.md#parameters) to read the documentation of each parameter.")
298 | with gr.Row():
299 | with gr.Row():
300 | model_list = self.get_model_list()
301 | self.params.model = gr.Dropdown(
302 | choices=model_list,
303 | value=(self.params.model if self.params.model in model_list else (model_list[0] if len(model_list) > 0 else None)),
304 | label="Motion module",
305 | type="value",
306 | elem_id=f"{elemid_prefix}motion-module",
307 | )
308 | refresh_model = ToolButton(value="\U0001f504")
309 | refresh_model.click(self.refresh_models, self.params.model, self.params.model)
310 |
311 | self.params.format = gr.CheckboxGroup(
312 | choices=supported_save_formats,
313 | label="Save format",
314 | type="value",
315 | elem_id=f"{elemid_prefix}save-format",
316 | value=self.params.format,
317 | )
318 | with gr.Row():
319 | self.params.enable = gr.Checkbox(
320 | value=self.params.enable, label="Enable AnimateDiff",
321 | elem_id=f"{elemid_prefix}enable"
322 | )
323 | self.params.video_length = gr.Number(
324 | minimum=0,
325 | value=self.params.video_length,
326 | label="Number of frames",
327 | precision=0,
328 | elem_id=f"{elemid_prefix}video-length",
329 | )
330 | self.params.fps = gr.Number(
331 | value=self.params.fps, label="FPS", precision=0,
332 | elem_id=f"{elemid_prefix}fps"
333 | )
334 | self.params.loop_number = gr.Number(
335 | minimum=0,
336 | value=self.params.loop_number,
337 | label="Display loop number",
338 | precision=0,
339 | elem_id=f"{elemid_prefix}loop-number",
340 | )
341 | with gr.Row():
342 | self.params.closed_loop = gr.Radio(
343 | choices=["N", "R-P", "R+P", "A"],
344 | value=self.params.closed_loop,
345 | label="Closed loop",
346 | elem_id=f"{elemid_prefix}closed-loop",
347 | )
348 | self.params.batch_size = gr.Slider(
349 | minimum=1,
350 | maximum=32,
351 | value=self.params.batch_size,
352 | label="Context batch size",
353 | step=1,
354 | precision=0,
355 | elem_id=f"{elemid_prefix}batch-size",
356 | )
357 | self.params.stride = gr.Number(
358 | minimum=1,
359 | value=self.params.stride,
360 | label="Stride",
361 | precision=0,
362 | elem_id=f"{elemid_prefix}stride",
363 | )
364 | self.params.overlap = gr.Number(
365 | minimum=-1,
366 | value=self.params.overlap,
367 | label="Overlap",
368 | precision=0,
369 | elem_id=f"{elemid_prefix}overlap",
370 | )
371 | with gr.Row():
372 | self.params.interp = gr.Radio(
373 | choices=["Off", "FILM"],
374 | label="Frame Interpolation",
375 | elem_id=f"{elemid_prefix}interp-choice",
376 | value=self.params.interp
377 | )
378 | self.params.interp_x = gr.Number(
379 | value=self.params.interp_x, label="Interp X", precision=0,
380 | elem_id=f"{elemid_prefix}interp-x"
381 | )
382 | with gr.Accordion("FreeInit Params", open=False):
383 | gr.Markdown(
384 | """
385 |                         Adjust these parameters to control the smoothness of the generated animation.
386 | """
387 | )
388 | self.params.freeinit_enable = gr.Checkbox(
389 | value=self.params.freeinit_enable,
390 | label="Enable FreeInit",
391 | elem_id=f"{elemid_prefix}freeinit-enable"
392 | )
393 | self.params.freeinit_filter = gr.Dropdown(
394 | value=self.params.freeinit_filter,
395 | label="Filter Type",
396 |                         info="Defaults to Butterworth. To fix large inconsistencies, consider using Gaussian.",
397 | choices=self.filter_type_list,
398 | interactive=True,
399 | elem_id=f"{elemid_prefix}freeinit-filter"
400 | )
401 | self.params.freeinit_ds = gr.Slider(
402 | value=self.params.freeinit_ds,
403 | minimum=0,
404 | maximum=1,
405 | step=0.125,
406 | label="d_s",
407 | info="Stop frequency for spatial dimensions (0.0-1.0)",
408 | elem_id=f"{elemid_prefix}freeinit-ds"
409 | )
410 | self.params.freeinit_dt = gr.Slider(
411 | value=self.params.freeinit_dt,
412 | minimum=0,
413 | maximum=1,
414 | step=0.125,
415 | label="d_t",
416 | info="Stop frequency for temporal dimension (0.0-1.0)",
417 | elem_id=f"{elemid_prefix}freeinit-dt"
418 | )
419 | self.params.freeinit_iters = gr.Slider(
420 | value=self.params.freeinit_iters,
421 | minimum=2,
422 | maximum=5,
423 | step=1,
424 | label="FreeInit Iterations",
425 |                         info="Larger values lead to smoother results and longer inference time.",
426 |                         elem_id=f"{elemid_prefix}freeinit-iters",
427 | )
428 | self.params.video_source = gr.Video(
429 | value=self.params.video_source,
430 | label="Video source",
431 | )
432 | def update_fps(video_source):
433 | if video_source is not None and video_source != '':
434 | cap = cv2.VideoCapture(video_source)
435 | fps = int(cap.get(cv2.CAP_PROP_FPS))
436 | cap.release()
437 | return fps
438 | else:
439 | return int(self.params.fps.value)
440 | self.params.video_source.change(update_fps, inputs=self.params.video_source, outputs=self.params.fps)
441 | def update_frames(video_source):
442 | if video_source is not None and video_source != '':
443 | cap = cv2.VideoCapture(video_source)
444 | frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
445 | cap.release()
446 | return frames
447 | else:
448 | return int(self.params.video_length.value)
449 | self.params.video_source.change(update_frames, inputs=self.params.video_source, outputs=self.params.video_length)
450 | with gr.Row():
451 | self.params.video_path = gr.Textbox(
452 | value=self.params.video_path,
453 | label="Video path",
454 | elem_id=f"{elemid_prefix}video-path"
455 | )
456 | self.params.mask_path = gr.Textbox(
457 | value=self.params.mask_path,
458 | label="Mask path",
459 | visible=False,
460 | elem_id=f"{elemid_prefix}mask-path"
461 | )
462 | if is_img2img:
463 | with gr.Accordion("I2V Traditional", open=False):
464 | with gr.Row():
465 | self.params.latent_power = gr.Slider(
466 | minimum=0.1,
467 | maximum=10,
468 | value=self.params.latent_power,
469 | step=0.1,
470 | label="Latent power",
471 | elem_id=f"{elemid_prefix}latent-power",
472 | )
473 | self.params.latent_scale = gr.Slider(
474 | minimum=1,
475 | maximum=128,
476 | value=self.params.latent_scale,
477 | label="Latent scale",
478 | elem_id=f"{elemid_prefix}latent-scale"
479 | )
480 | self.params.latent_power_last = gr.Slider(
481 | minimum=0.1,
482 | maximum=10,
483 | value=self.params.latent_power_last,
484 | step=0.1,
485 | label="Optional latent power for last frame",
486 | elem_id=f"{elemid_prefix}latent-power-last",
487 | )
488 | self.params.latent_scale_last = gr.Slider(
489 | minimum=1,
490 | maximum=128,
491 | value=self.params.latent_scale_last,
492 | label="Optional latent scale for last frame",
493 | elem_id=f"{elemid_prefix}latent-scale-last"
494 | )
495 | self.params.last_frame = gr.Image(
496 | label="Optional last frame. Leave it blank if you do not need one.",
497 | type="pil",
498 | )
499 | with gr.Row():
500 | unload = gr.Button(value="Move motion module to CPU (default if lowvram)")
501 | remove = gr.Button(value="Remove motion module from any memory")
502 | unload.click(fn=motion_module.unload)
503 | remove.click(fn=motion_module.remove)
504 |
505 | # Set up controls to be copy-pasted using infotext
506 | fields = self.params.get_param_names(is_img2img)
507 | infotext_fields.extend((getattr(self.params, field), f"AnimateDiff {field}") for field in fields)
508 | paste_field_names.extend(f"AnimateDiff {field}" for field in fields)
509 |
510 | return self.register_unit(is_img2img)
511 |
512 |
513 | def register_unit(self, is_img2img: bool):
514 | unit = gr.State(value=AnimateDiffProcess)
515 | (
516 | AnimateDiffUiGroup.img2img_submit_button
517 | if is_img2img
518 | else AnimateDiffUiGroup.txt2img_submit_button
519 | ).click(
520 | fn=AnimateDiffProcess,
521 | inputs=self.params.get_list(is_img2img),
522 | outputs=unit,
523 | queue=False,
524 | )
525 | return unit
526 |
527 |
528 | @staticmethod
529 | def on_after_component(component, **_kwargs):
530 | elem_id = getattr(component, "elem_id", None)
531 |
532 | if elem_id == "txt2img_generate":
533 | AnimateDiffUiGroup.txt2img_submit_button = component
534 | return
535 |
536 | if elem_id == "img2img_generate":
537 | AnimateDiffUiGroup.img2img_submit_button = component
538 | return
539 |
540 | if elem_id == "setting_sd_model_checkpoint":
541 | for group in AnimateDiffUiGroup.animatediff_ui_group:
542 |                 component.change( # NOTE: this handler does not take effect; the cause is unknown.
543 | fn=group.refresh_models,
544 | inputs=[group.params.model],
545 | outputs=[group.params.model],
546 | queue=False,
547 | )
548 | return
549 |
550 |
--------------------------------------------------------------------------------
/scripts/animatediff_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import subprocess
4 | from pathlib import Path
5 |
6 | from modules import shared
7 | from modules.paths import data_path
8 | from modules.processing import StableDiffusionProcessing
9 |
10 | from scripts.animatediff_logger import logger_animatediff as logger
11 |
12 | def generate_random_hash(length=8):
13 | import hashlib
14 | import secrets
15 |
16 |     # Generate random bytes to hash
17 | random_data = secrets.token_bytes(32) # 32 bytes of random data
18 |
19 | # Create a SHA-256 hash of the random data
20 | hash_object = hashlib.sha256(random_data)
21 | hash_hex = hash_object.hexdigest()
22 |
23 |     # Return the first `length` characters (default 8)
24 | if length > len(hash_hex):
25 | length = len(hash_hex)
26 | return hash_hex[:length]
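# Illustrative output only (values differ on every call):
#   generate_random_hash()   -> e.g. "3f9a1c2b"      (first 8 hex chars of a SHA-256 digest)
#   generate_random_hash(12) -> e.g. "3f9a1c2b7d4e"  (first 12 hex chars)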
27 |
28 |
29 | def get_animatediff_arg(p: StableDiffusionProcessing):
30 | """
31 | Get AnimateDiff argument from `p`. If it's a dict, convert it to AnimateDiffProcess.
32 | """
33 | if not p.scripts:
34 | return None
35 |
36 | for script in p.scripts.alwayson_scripts:
37 | if script.title().lower() == "animatediff":
38 | animatediff_arg = p.script_args[script.args_from]
39 | if isinstance(animatediff_arg, dict):
40 | from scripts.animatediff_ui import AnimateDiffProcess
41 | animatediff_arg = AnimateDiffProcess(**animatediff_arg)
42 | p.script_args = list(p.script_args)
43 | p.script_args[script.args_from] = animatediff_arg
44 | return animatediff_arg
45 |
46 | return None
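# Hypothetical usage sketch (this is how the extension's own scripts retrieve
# their parameters; `p` is a StableDiffusionProcessing instance and `enable`
# is a field of AnimateDiffProcess):
#   params = get_animatediff_arg(p)
#   if params is not None and params.enable:
#       ...  # proceed with the AnimateDiff pipeline using `params`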
47 |
48 | def get_controlnet_units(p: StableDiffusionProcessing):
49 | """
50 | Get controlnet arguments from `p`.
51 | """
52 | if not p.scripts:
53 | return []
54 |
55 | for script in p.scripts.alwayson_scripts:
56 | if script.title().lower() == "controlnet":
57 | cn_units = p.script_args[script.args_from:script.args_to]
58 |
59 | if p.is_api and len(cn_units) > 0 and isinstance(cn_units[0], dict):
60 | from scripts import external_code
61 | from scripts.batch_hijack import InputMode
62 | cn_units_dataclass = external_code.get_all_units_in_processing(p)
63 | for cn_unit_dict, cn_unit_dataclass in zip(cn_units, cn_units_dataclass):
64 | if cn_unit_dataclass.image is None:
65 | cn_unit_dataclass.input_mode = InputMode.BATCH
66 | cn_unit_dataclass.batch_images = cn_unit_dict.get("batch_images", None)
67 | p.script_args[script.args_from:script.args_to] = cn_units_dataclass
68 |
69 | return [x for x in cn_units if x.enabled] if not p.is_api else cn_units
70 |
71 | return []
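# Hypothetical usage sketch (mirrors how the UI script above consumes this
# helper; the frame directory is a placeholder):
#   for cn_unit in get_controlnet_units(p):
#       if cn_unit.input_mode.name == 'BATCH':
#           cn_unit.batch_images = "/path/to/extracted/frames"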
72 |
73 |
74 | def ffmpeg_extract_frames(source_video: str, output_dir: str, extract_key: bool = False):
75 | from modules.devices import device
76 | command = ["ffmpeg"]
77 | if "cuda" in str(device):
78 | command.extend(["-hwaccel", "cuda"])
79 | command.extend(["-i", source_video])
80 | if extract_key:
81 | command.extend(["-vf", "select='eq(pict_type,I)'", "-vsync", "vfr"])
82 | else:
83 | command.extend(["-filter:v", "mpdecimate=hi=64*200:lo=64*50:frac=0.33,setpts=N/FRAME_RATE/TB"])
84 | tmp_frame_dir = Path(output_dir)
85 | tmp_frame_dir.mkdir(parents=True, exist_ok=True)
86 | command.extend(["-qscale:v", "1", "-qmin", "1", "-c:a", "copy", str(tmp_frame_dir / '%09d.jpg')])
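    # For reference, with extract_key=False on a CUDA device the assembled command
    # is roughly (input/output paths are placeholders):
    #   ffmpeg -hwaccel cuda -i input.mp4 \
    #     -filter:v "mpdecimate=hi=64*200:lo=64*50:frac=0.33,setpts=N/FRAME_RATE/TB" \
    #     -qscale:v 1 -qmin 1 -c:a copy <output_dir>/%09d.jpg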
87 | logger.info(f"Attempting to extract frames via ffmpeg from {source_video} to {output_dir}")
88 | subprocess.run(command, check=True)
89 |
90 |
91 | def cv2_extract_frames(source_video: str, output_dir: str):
92 | logger.info(f"Attempting to extract frames via OpenCV from {source_video} to {output_dir}")
93 | cap = cv2.VideoCapture(source_video)
94 | frame_count = 0
95 | tmp_frame_dir = Path(output_dir)
96 | tmp_frame_dir.mkdir(parents=True, exist_ok=True)
97 | while cap.isOpened():
98 | ret, frame = cap.read()
99 | if not ret:
100 | break
101 | cv2.imwrite(f"{tmp_frame_dir}/{frame_count}.png", frame)
102 | frame_count += 1
103 | cap.release()
104 |
105 |
106 |
107 | def extract_frames_from_video(params):
108 |     assert params.video_source, "You need to specify a source video to extract ControlNet condition frames from."
109 | params.video_path = shared.opts.data.get(
110 | "animatediff_frame_extract_path",
111 | f"{data_path}/tmp/animatediff-frames")
112 | if not params.video_path:
113 | params.video_path = f"{data_path}/tmp/animatediff-frames"
114 | params.video_path = os.path.join(params.video_path, f"{Path(params.video_source).stem}-{generate_random_hash()}")
115 | try:
116 | if shared.opts.data.get("animatediff_default_frame_extract_method", "ffmpeg") == "opencv":
117 | cv2_extract_frames(params.video_source, params.video_path)
118 | else:
119 | ffmpeg_extract_frames(params.video_source, params.video_path)
120 | except Exception as e:
121 |         logger.error(f"[AnimateDiff] Frame extraction failed: {e}; falling back to OpenCV.")
122 | cv2_extract_frames(params.video_source, params.video_path)
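# Hypothetical usage sketch (mirrors the call in the UI script: frames are
# extracted into a fresh per-video directory whose path is then handed to
# ControlNet BATCH units):
#   extract_frames_from_video(params)
#   cn_unit.batch_images = params.video_path  # e.g. <data_path>/tmp/animatediff-frames/<stem>-<hash>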
123 |
--------------------------------------------------------------------------------
/scripts/animatediff_xyz.py:
--------------------------------------------------------------------------------
1 | from types import ModuleType
2 | from typing import Optional
3 |
4 | from modules import scripts
5 |
6 | from scripts.animatediff_logger import logger_animatediff as logger
7 |
8 | xyz_attrs: dict = {}
9 |
10 | def patch_xyz():
11 | xyz_module = find_xyz_module()
12 | if xyz_module is None:
13 | logger.warning("XYZ module not found.")
14 | return
15 | MODULE = "[AnimateDiff]"
16 | xyz_module.axis_options.extend([
17 | xyz_module.AxisOption(
18 | label=f"{MODULE} Enabled",
19 | type=str_to_bool,
20 | apply=apply_state("enable"),
21 | choices=choices_bool),
22 | xyz_module.AxisOption(
23 | label=f"{MODULE} Motion Module",
24 | type=str,
25 | apply=apply_state("model")),
26 | xyz_module.AxisOption(
27 | label=f"{MODULE} Video length",
28 | type=int_or_float,
29 | apply=apply_state("video_length")),
30 | xyz_module.AxisOption(
31 | label=f"{MODULE} FPS",
32 | type=int_or_float,
33 | apply=apply_state("fps")),
34 | xyz_module.AxisOption(
35 | label=f"{MODULE} Use main seed",
36 | type=str_to_bool,
37 | apply=apply_state("use_main_seed"),
38 | choices=choices_bool),
39 | xyz_module.AxisOption(
40 | label=f"{MODULE} Closed loop",
41 | type=str,
42 | apply=apply_state("closed_loop"),
43 | choices=lambda: ["N", "R-P", "R+P", "A"]),
44 | xyz_module.AxisOption(
45 | label=f"{MODULE} Batch size",
46 | type=int_or_float,
47 | apply=apply_state("batch_size")),
48 | xyz_module.AxisOption(
49 | label=f"{MODULE} Stride",
50 | type=int_or_float,
51 | apply=apply_state("stride")),
52 | xyz_module.AxisOption(
53 | label=f"{MODULE} Overlap",
54 | type=int_or_float,
55 | apply=apply_state("overlap")),
56 | xyz_module.AxisOption(
57 | label=f"{MODULE} Interp",
58 | type=str_to_bool,
59 | apply=apply_state("interp"),
60 | choices=choices_bool),
61 | xyz_module.AxisOption(
62 | label=f"{MODULE} Interp X",
63 | type=int_or_float,
64 | apply=apply_state("interp_x")),
65 | xyz_module.AxisOption(
66 | label=f"{MODULE} Video path",
67 | type=str,
68 | apply=apply_state("video_path")),
69 | xyz_module.AxisOptionImg2Img(
70 | label=f"{MODULE} Latent power",
71 | type=int_or_float,
72 | apply=apply_state("latent_power")),
73 | xyz_module.AxisOptionImg2Img(
74 | label=f"{MODULE} Latent scale",
75 | type=int_or_float,
76 | apply=apply_state("latent_scale")),
77 | xyz_module.AxisOptionImg2Img(
78 | label=f"{MODULE} Latent power last",
79 | type=int_or_float,
80 | apply=apply_state("latent_power_last")),
81 | xyz_module.AxisOptionImg2Img(
82 | label=f"{MODULE} Latent scale last",
83 | type=int_or_float,
84 | apply=apply_state("latent_scale_last")),
85 | ])
86 |
87 |
88 | def apply_state(k, key_map=None):
89 | def callback(_p, v, _vs):
90 | if key_map is not None:
91 | v = key_map[v]
92 | xyz_attrs[k] = v
93 |
94 | return callback
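# Illustration (based on the callback above): applying the "[AnimateDiff] FPS"
# axis with value "12" calls apply_state("fps")(p, 12, values), which records
# xyz_attrs["fps"] = 12; the AnimateDiff script can then read the override from
# this module-level dict.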
95 |
96 |
97 | def str_to_bool(string):
98 | string = str(string)
99 | if string in ["None", ""]:
100 | return None
101 | elif string.lower() in ["true", "1"]:
102 | return True
103 | elif string.lower() in ["false", "0"]:
104 | return False
105 | else:
106 | raise ValueError(f"Could not convert string to boolean: {string}")
107 |
108 |
109 | def int_or_float(string):
110 | try:
111 | return int(string)
112 | except ValueError:
113 | return float(string)
114 |
115 |
116 | def choices_bool():
117 | return ["False", "True"]
118 |
119 |
120 | def find_xyz_module() -> Optional[ModuleType]:
121 | for data in scripts.scripts_data:
122 | if data.script_class.__module__ in {"xyz_grid.py", "xy_grid.py", "scripts.xyz_grid", "scripts.xy_grid"} and hasattr(data, "module"):
123 | return data.module
124 |
125 | return None
126 |
--------------------------------------------------------------------------------