├── .gitignore
├── .style.yapf
├── INSTALL.md
├── LICENSE.txt
├── Makefile
├── README.md
├── assets
    ├── comp_effic.png
    ├── data_for_diff_stage.jpg
    ├── i2v_res.png
    ├── logo.png
    ├── t2v_res.jpg
    ├── vben_vs_sota.png
    ├── video_dit_arch.jpg
    └── video_vae_res.jpg
├── examples
    ├── flf2v_input_first_frame.png
    ├── flf2v_input_last_frame.png
    ├── girl.png
    ├── i2v_input.JPG
    └── snake.png
├── generate.py
├── gradio
    ├── fl2v_14B_singleGPU.py
    ├── i2v_14B_singleGPU.py
    ├── t2i_14B_singleGPU.py
    ├── t2v_1.3B_singleGPU.py
    ├── t2v_14B_singleGPU.py
    └── vace.py
├── pyproject.toml
├── requirements.txt
├── tests
    ├── README.md
    └── test.sh
└── wan
    ├── __init__.py
    ├── configs
        ├── __init__.py
        ├── shared_config.py
        ├── wan_i2v_14B.py
        ├── wan_t2v_14B.py
        └── wan_t2v_1_3B.py
    ├── distributed
        ├── __init__.py
        ├── fsdp.py
        └── xdit_context_parallel.py
    ├── first_last_frame2video.py
    ├── image2video.py
    ├── modules
        ├── __init__.py
        ├── attention.py
        ├── clip.py
        ├── model.py
        ├── t5.py
        ├── tokenizers.py
        ├── vace_model.py
        ├── vae.py
        └── xlm_roberta.py
    ├── text2video.py
    ├── utils
        ├── __init__.py
        ├── fm_solvers.py
        ├── fm_solvers_unipc.py
        ├── prompt_extend.py
        ├── qwen_vl_utils.py
        ├── utils.py
        └── vace_processor.py
    └── vace.py


/.gitignore:
--------------------------------------------------------------------------------
 1 | .*
 2 | *.py[cod]
 3 | # *.jpg
 4 | *.jpeg
 5 | # *.png
 6 | *.gif
 7 | *.bmp
 8 | *.mp4
 9 | *.mov
10 | *.mkv
11 | *.log
12 | *.zip
13 | *.pt
14 | *.pth
15 | *.ckpt
16 | *.safetensors
17 | *.json
18 | # *.txt
19 | *.backup
20 | *.pkl
21 | *.html
22 | *.pdf
23 | *.whl
24 | cache
25 | __pycache__/
26 | storage/
27 | samples/
28 | !.gitignore
29 | !requirements.txt
30 | .DS_Store
31 | *DS_Store
32 | google/
33 | Wan2.1-T2V-14B/
34 | Wan2.1-T2V-1.3B/
35 | Wan2.1-I2V-14B-480P/
36 | Wan2.1-I2V-14B-720P/
37 | poetry.lock


--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
  1 | [style]
  2 | # Align closing bracket with visual indentation.
  3 | align_closing_bracket_with_visual_indent=False
  4 | 
  5 | # Allow dictionary keys to exist on multiple lines. For example:
  6 | #
  7 | #   x = {
  8 | #       ('this is the first element of a tuple',
  9 | #        'this is the second element of a tuple'):
 10 | #            value,
 11 | #   }
 12 | allow_multiline_dictionary_keys=False
 13 | 
 14 | # Allow lambdas to be formatted on more than one line.
 15 | allow_multiline_lambdas=False
 16 | 
 17 | # Allow splitting before a default / named assignment in an argument list.
 18 | allow_split_before_default_or_named_assigns=False
 19 | 
 20 | # Allow splits before the dictionary value.
 21 | allow_split_before_dict_value=True
 22 | 
 23 | #   Let spacing indicate operator precedence. For example:
 24 | #
 25 | #     a = 1 * 2 + 3 / 4
 26 | #     b = 1 / 2 - 3 * 4
 27 | #     c = (1 + 2) * (3 - 4)
 28 | #     d = (1 - 2) / (3 + 4)
 29 | #     e = 1 * 2 - 3
 30 | #     f = 1 + 2 + 3 + 4
 31 | #
 32 | # will be formatted as follows to indicate precedence:
 33 | #
 34 | #     a = 1*2 + 3/4
 35 | #     b = 1/2 - 3*4
 36 | #     c = (1+2) * (3-4)
 37 | #     d = (1-2) / (3+4)
 38 | #     e = 1*2 - 3
 39 | #     f = 1 + 2 + 3 + 4
 40 | #
 41 | arithmetic_precedence_indication=False
 42 | 
 43 | # Number of blank lines surrounding top-level function and class
 44 | # definitions.
 45 | blank_lines_around_top_level_definition=2
 46 | 
 47 | # Insert a blank line before a class-level docstring.
 48 | blank_line_before_class_docstring=False
 49 | 
 50 | # Insert a blank line before a module docstring.
 51 | blank_line_before_module_docstring=False
 52 | 
 53 | # Insert a blank line before a 'def' or 'class' immediately nested
 54 | # within another 'def' or 'class'. For example:
 55 | #
 56 | #   class Foo:
 57 | #                      # <------ this blank line
 58 | #     def method():
 59 | #       ...
 60 | blank_line_before_nested_class_or_def=True
 61 | 
 62 | # Do not split consecutive brackets. Only relevant when
 63 | # dedent_closing_brackets is set. For example:
 64 | #
 65 | #    call_func_that_takes_a_dict(
 66 | #        {
 67 | #            'key1': 'value1',
 68 | #            'key2': 'value2',
 69 | #        }
 70 | #    )
 71 | #
 72 | # would reformat to:
 73 | #
 74 | #    call_func_that_takes_a_dict({
 75 | #        'key1': 'value1',
 76 | #        'key2': 'value2',
 77 | #    })
 78 | coalesce_brackets=False
 79 | 
 80 | # The column limit.
 81 | column_limit=80
 82 | 
 83 | # The style for continuation alignment. Possible values are:
 84 | #
 85 | # - SPACE: Use spaces for continuation alignment. This is default behavior.
 86 | # - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns
 87 | #   (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or
 88 | #   CONTINUATION_INDENT_WIDTH spaces) for continuation alignment.
 89 | # - VALIGN-RIGHT: Vertically align continuation lines to multiple of
 90 | #   INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if
 91 | #   cannot vertically align continuation lines with indent characters.
 92 | continuation_align_style=SPACE
 93 | 
 94 | # Indent width used for line continuations.
 95 | continuation_indent_width=4
 96 | 
 97 | # Put closing brackets on a separate line, dedented, if the bracketed
 98 | # expression can't fit in a single line. Applies to all kinds of brackets,
 99 | # including function definitions and calls. For example:
100 | #
101 | #   config = {
102 | #       'key1': 'value1',
103 | #       'key2': 'value2',
104 | #   }        # <--- this bracket is dedented and on a separate line
105 | #
106 | #   time_series = self.remote_client.query_entity_counters(
107 | #       entity='dev3246.region1',
108 | #       key='dns.query_latency_tcp',
109 | #       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
110 | #       start_ts=now()-timedelta(days=3),
111 | #       end_ts=now(),
112 | #   )        # <--- this bracket is dedented and on a separate line
113 | dedent_closing_brackets=False
114 | 
115 | # Disable the heuristic which places each list element on a separate line
116 | # if the list is comma-terminated.
117 | disable_ending_comma_heuristic=False
118 | 
119 | # Place each dictionary entry onto its own line.
120 | each_dict_entry_on_separate_line=True
121 | 
122 | # Require multiline dictionary even if it would normally fit on one line.
123 | # For example:
124 | #
125 | #   config = {
126 | #       'key1': 'value1'
127 | #   }
128 | force_multiline_dict=False
129 | 
130 | # The regex for an i18n comment. The presence of this comment stops
131 | # reformatting of that line, because the comments are required to be
132 | # next to the string they translate.
133 | i18n_comment=#\..*
134 | 
135 | # The i18n function call names. The presence of this function stops
136 | # reformatting on that line, because the string it has cannot be moved
137 | # away from the i18n comment.
138 | i18n_function_call=N_, _
139 | 
140 | # Indent blank lines.
141 | indent_blank_lines=False
142 | 
143 | # Put closing brackets on a separate line, indented, if the bracketed
144 | # expression can't fit in a single line. Applies to all kinds of brackets,
145 | # including function definitions and calls. For example:
146 | #
147 | #   config = {
148 | #       'key1': 'value1',
149 | #       'key2': 'value2',
150 | #       }        # <--- this bracket is indented and on a separate line
151 | #
152 | #   time_series = self.remote_client.query_entity_counters(
153 | #       entity='dev3246.region1',
154 | #       key='dns.query_latency_tcp',
155 | #       transform=Transformation.AVERAGE(window=timedelta(seconds=60)),
156 | #       start_ts=now()-timedelta(days=3),
157 | #       end_ts=now(),
158 | #       )        # <--- this bracket is indented and on a separate line
159 | indent_closing_brackets=False
160 | 
161 | # Indent the dictionary value if it cannot fit on the same line as the
162 | # dictionary key. For example:
163 | #
164 | #   config = {
165 | #       'key1':
166 | #           'value1',
167 | #       'key2': value1 +
168 | #               value2,
169 | #   }
170 | indent_dictionary_value=True
171 | 
172 | # The number of columns to use for indentation.
173 | indent_width=4
174 | 
175 | # Join short lines into one line. E.g., single line 'if' statements.
176 | join_multiple_lines=False
177 | 
178 | # Do not include spaces around selected binary operators. For example:
179 | #
180 | #   1 + 2 * 3 - 4 / 5
181 | #
182 | # will be formatted as follows when configured with "*,/":
183 | #
184 | #   1 + 2*3 - 4/5
185 | no_spaces_around_selected_binary_operators=
186 | 
187 | # Use spaces around default or named assigns.
188 | spaces_around_default_or_named_assign=False
189 | 
190 | # Adds a space after the opening '{' and before the ending '}' dict delimiters.
191 | #
192 | #   {1: 2}
193 | #
194 | # will be formatted as:
195 | #
196 | #   { 1: 2 }
197 | spaces_around_dict_delimiters=False
198 | 
199 | # Adds a space after the opening '[' and before the ending ']' list delimiters.
200 | #
201 | #   [1, 2]
202 | #
203 | # will be formatted as:
204 | #
205 | #   [ 1, 2 ]
206 | spaces_around_list_delimiters=False
207 | 
208 | # Use spaces around the power operator.
209 | spaces_around_power_operator=False
210 | 
211 | # Use spaces around the subscript / slice operator.  For example:
212 | #
213 | #   my_list[1 : 10 : 2]
214 | spaces_around_subscript_colon=False
215 | 
216 | # Adds a space after the opening '(' and before the ending ')' tuple delimiters.
217 | #
218 | #   (1, 2, 3)
219 | #
220 | # will be formatted as:
221 | #
222 | #   ( 1, 2, 3 )
223 | spaces_around_tuple_delimiters=False
224 | 
225 | # The number of spaces required before a trailing comment.
226 | # This can be a single value (representing the number of spaces
227 | # before each trailing comment) or list of values (representing
228 | # alignment column values; trailing comments within a block will
229 | # be aligned to the first column value that is greater than the maximum
230 | # line length within the block). For example:
231 | #
232 | # With spaces_before_comment=5:
233 | #
234 | #   1 + 1 # Adding values
235 | #
236 | # will be formatted as:
237 | #
238 | #   1 + 1     # Adding values <-- 5 spaces between the end of the statement and comment
239 | #
240 | # With spaces_before_comment=15, 20:
241 | #
242 | #   1 + 1 # Adding values
243 | #   two + two # More adding
244 | #
245 | #   longer_statement # This is a longer statement
246 | #   short # This is a shorter statement
247 | #
248 | #   a_very_long_statement_that_extends_beyond_the_final_column # Comment
249 | #   short # This is a shorter statement
250 | #
251 | # will be formatted as:
252 | #
253 | #   1 + 1          # Adding values <-- end of line comments in block aligned to col 15
254 | #   two + two      # More adding
255 | #
256 | #   longer_statement    # This is a longer statement <-- end of line comments in block aligned to col 20
257 | #   short               # This is a shorter statement
258 | #
259 | #   a_very_long_statement_that_extends_beyond_the_final_column  # Comment <-- the end of line comments are aligned based on the line length
260 | #   short                                                       # This is a shorter statement
261 | #
262 | spaces_before_comment=2
263 | 
264 | # Insert a space between the ending comma and closing bracket of a list,
265 | # etc.
266 | space_between_ending_comma_and_closing_bracket=False
267 | 
268 | # Use spaces inside brackets, braces, and parentheses.  For example:
269 | #
270 | #   method_call( 1 )
271 | #   my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ]
272 | #   my_set = { 1, 2, 3 }
273 | space_inside_brackets=False
274 | 
275 | # Split before arguments
276 | split_all_comma_separated_values=False
277 | 
278 | # Split before arguments, but do not split all subexpressions recursively
279 | # (unless needed).
280 | split_all_top_level_comma_separated_values=False
281 | 
282 | # Split before arguments if the argument list is terminated by a
283 | # comma.
284 | split_arguments_when_comma_terminated=False
285 | 
286 | # Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@'
287 | # rather than after.
288 | split_before_arithmetic_operator=False
289 | 
290 | # Set to True to prefer splitting before '&', '|' or '^' rather than
291 | # after.
292 | split_before_bitwise_operator=False
293 | 
294 | # Split before the closing bracket if a list or dict literal doesn't fit on
295 | # a single line.
296 | split_before_closing_bracket=True
297 | 
298 | # Split before a dictionary or set generator (comp_for). For example, note
299 | # the split before the 'for':
300 | #
301 | #   foo = {
302 | #       variable: 'Hello world, have a nice day!'
303 | #       for variable in bar if variable != 42
304 | #   }
305 | split_before_dict_set_generator=False
306 | 
307 | # Split before the '.' if we need to split a longer expression:
308 | #
309 | #   foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d))
310 | #
311 | # would reformat to something like:
312 | #
313 | #   foo = ('This is a really long string: {}, {}, {}, {}'
314 | #          .format(a, b, c, d))
315 | split_before_dot=False
316 | 
317 | # Split after the opening paren which surrounds an expression if it doesn't
318 | # fit on a single line.
319 | split_before_expression_after_opening_paren=True
320 | 
321 | # If an argument / parameter list is going to be split, then split before
322 | # the first argument.
323 | split_before_first_argument=False
324 | 
325 | # Set to True to prefer splitting before 'and' or 'or' rather than
326 | # after.
327 | split_before_logical_operator=False
328 | 
329 | # Split named assignments onto individual lines.
330 | split_before_named_assigns=True
331 | 
332 | # Set to True to split list comprehensions and generators that have
333 | # non-trivial expressions and multiple clauses before each of these
334 | # clauses. For example:
335 | #
336 | #   result = [
337 | #       a_long_var + 100 for a_long_var in xrange(1000)
338 | #       if a_long_var % 10]
339 | #
340 | # would reformat to something like:
341 | #
342 | #   result = [
343 | #       a_long_var + 100
344 | #       for a_long_var in xrange(1000)
345 | #       if a_long_var % 10]
346 | split_complex_comprehension=True
347 | 
348 | # The penalty for splitting right after the opening bracket.
349 | split_penalty_after_opening_bracket=300
350 | 
351 | # The penalty for splitting the line after a unary operator.
352 | split_penalty_after_unary_operator=10000
353 | 
354 | # The penalty of splitting the line around the '+', '-', '*', '/', '//',
355 | # ``%``, and '@' operators.
356 | split_penalty_arithmetic_operator=300
357 | 
358 | # The penalty for splitting right before an if expression.
359 | split_penalty_before_if_expr=0
360 | 
361 | # The penalty of splitting the line around the '&', '|', and '^'
362 | # operators.
363 | split_penalty_bitwise_operator=300
364 | 
365 | # The penalty for splitting a list comprehension or generator
366 | # expression.
367 | split_penalty_comprehension=2100
368 | 
369 | # The penalty for characters over the column limit.
370 | split_penalty_excess_character=7000
371 | 
372 | # The penalty incurred by adding a line split to the unwrapped line. The
373 | # more line splits added the higher the penalty.
374 | split_penalty_for_added_line_split=30
375 | 
376 | # The penalty of splitting a list of "import as" names. For example:
377 | #
378 | #   from a_very_long_or_indented_module_name_yada_yad import (long_argument_1,
379 | #                                                             long_argument_2,
380 | #                                                             long_argument_3)
381 | #
382 | # would reformat to something like:
383 | #
384 | #   from a_very_long_or_indented_module_name_yada_yad import (
385 | #       long_argument_1, long_argument_2, long_argument_3)
386 | split_penalty_import_names=0
387 | 
388 | # The penalty of splitting the line around the 'and' and 'or'
389 | # operators.
390 | split_penalty_logical_operator=300
391 | 
392 | # Use the Tab character for indentation.
393 | use_tabs=False


--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
 1 | # Installation Guide
 2 | 
 3 | ## Install with pip
 4 | 
 5 | ```bash
 6 | pip install .
 7 | pip install .[dev]  # Also installs the development tools
 8 | ```
 9 | 
10 | ## Install with Poetry
11 | 
12 | Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system.
13 | 
14 | To install all dependencies:
15 | 
16 | ```bash
17 | poetry install
18 | ```
19 | 
20 | ### Handling `flash-attn` Installation Issues
21 | 
22 | If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
23 | 
24 | #### No-Build-Isolation Installation (Recommended)
25 | ```bash
26 | poetry run pip install --upgrade pip setuptools wheel
27 | poetry run pip install flash-attn --no-build-isolation
28 | poetry install
29 | ```
30 | 
31 | #### Install from Git (Alternative)
32 | ```bash
33 | poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git
34 | ```
35 | 
36 | ---
37 | 
38 | ### Running the Model
39 | 
40 | Once the installation is complete, you can run **Wan2.1** using:
41 | 
42 | ```bash
43 | poetry run python generate.py --task t2v-14B --size '1280x720' --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
44 | ```
45 | 
46 | #### Test
47 | ```bash
48 | pytest tests/
49 | ```
50 | #### Format
51 | ```bash
52 | black .
53 | isort .
54 | ```
55 | 
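56 | #### Programmatic Usage
57 | 
58 | The CLI above and the Gradio demos in `gradio/` wrap the same Python API. The following is a minimal sketch of a single-GPU image-to-video call, modeled on `gradio/i2v_14B_singleGPU.py`; the checkpoint path, prompt, input image, and parameter values are illustrative assumptions, not required settings.
59 | 
60 | ```python
61 | from PIL import Image
62 | 
63 | import wan
64 | from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
65 | from wan.utils.utils import cache_video
66 | 
67 | # Build the 14B image-to-video pipeline on a single GPU (no FSDP / USP),
68 | # mirroring gradio/i2v_14B_singleGPU.py. The checkpoint path is an example.
69 | pipe = wan.WanI2V(
70 |     config=WAN_CONFIGS['i2v-14B'],
71 |     checkpoint_dir='./Wan2.1-I2V-14B-480P',
72 |     device_id=0,
73 |     rank=0,
74 |     t5_fsdp=False,
75 |     dit_fsdp=False,
76 |     use_usp=False,
77 | )
78 | 
79 | # Generate a clip from a text prompt plus a reference image.
80 | video = pipe.generate(
81 |     "Two anthropomorphic cats in comfy boxing gear fight on a stage.",
82 |     Image.open("examples/i2v_input.JPG"),
83 |     max_area=MAX_AREA_CONFIGS['480*832'],
84 |     shift=5.0,
85 |     sampling_steps=50,
86 |     guide_scale=5.0,
87 |     n_prompt="",
88 |     seed=-1,
89 |     offload_model=True)
90 | 
91 | # Write the result to disk as an MP4, exactly as the demos do.
92 | cache_video(
93 |     tensor=video[None],
94 |     save_file="i2v_example.mp4",
95 |     fps=16,
96 |     nrow=1,
97 |     normalize=True,
98 |     value_range=(-1, 1))
99 | ```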


--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
  1 |                                  Apache License
  2 |                            Version 2.0, January 2004
  3 |                         http://www.apache.org/licenses/
  4 | 
  5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
  6 | 
  7 |    1. Definitions.
  8 | 
  9 |       "License" shall mean the terms and conditions for use, reproduction,
 10 |       and distribution as defined by Sections 1 through 9 of this document.
 11 | 
 12 |       "Licensor" shall mean the copyright owner or entity authorized by
 13 |       the copyright owner that is granting the License.
 14 | 
 15 |       "Legal Entity" shall mean the union of the acting entity and all
 16 |       other entities that control, are controlled by, or are under common
 17 |       control with that entity. For the purposes of this definition,
 18 |       "control" means (i) the power, direct or indirect, to cause the
 19 |       direction or management of such entity, whether by contract or
 20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
 21 |       outstanding shares, or (iii) beneficial ownership of such entity.
 22 | 
 23 |       "You" (or "Your") shall mean an individual or Legal Entity
 24 |       exercising permissions granted by this License.
 25 | 
 26 |       "Source" form shall mean the preferred form for making modifications,
 27 |       including but not limited to software source code, documentation
 28 |       source, and configuration files.
 29 | 
 30 |       "Object" form shall mean any form resulting from mechanical
 31 |       transformation or translation of a Source form, including but
 32 |       not limited to compiled object code, generated documentation,
 33 |       and conversions to other media types.
 34 | 
 35 |       "Work" shall mean the work of authorship, whether in Source or
 36 |       Object form, made available under the License, as indicated by a
 37 |       copyright notice that is included in or attached to the work
 38 |       (an example is provided in the Appendix below).
 39 | 
 40 |       "Derivative Works" shall mean any work, whether in Source or Object
 41 |       form, that is based on (or derived from) the Work and for which the
 42 |       editorial revisions, annotations, elaborations, or other modifications
 43 |       represent, as a whole, an original work of authorship. For the purposes
 44 |       of this License, Derivative Works shall not include works that remain
 45 |       separable from, or merely link (or bind by name) to the interfaces of,
 46 |       the Work and Derivative Works thereof.
 47 | 
 48 |       "Contribution" shall mean any work of authorship, including
 49 |       the original version of the Work and any modifications or additions
 50 |       to that Work or Derivative Works thereof, that is intentionally
 51 |       submitted to Licensor for inclusion in the Work by the copyright owner
 52 |       or by an individual or Legal Entity authorized to submit on behalf of
 53 |       the copyright owner. For the purposes of this definition, "submitted"
 54 |       means any form of electronic, verbal, or written communication sent
 55 |       to the Licensor or its representatives, including but not limited to
 56 |       communication on electronic mailing lists, source code control systems,
 57 |       and issue tracking systems that are managed by, or on behalf of, the
 58 |       Licensor for the purpose of discussing and improving the Work, but
 59 |       excluding communication that is conspicuously marked or otherwise
 60 |       designated in writing by the copyright owner as "Not a Contribution."
 61 | 
 62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
 63 |       on behalf of whom a Contribution has been received by Licensor and
 64 |       subsequently incorporated within the Work.
 65 | 
 66 |    2. Grant of Copyright License. Subject to the terms and conditions of
 67 |       this License, each Contributor hereby grants to You a perpetual,
 68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 69 |       copyright license to reproduce, prepare Derivative Works of,
 70 |       publicly display, publicly perform, sublicense, and distribute the
 71 |       Work and such Derivative Works in Source or Object form.
 72 | 
 73 |    3. Grant of Patent License. Subject to the terms and conditions of
 74 |       this License, each Contributor hereby grants to You a perpetual,
 75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 76 |       (except as stated in this section) patent license to make, have made,
 77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
 78 |       where such license applies only to those patent claims licensable
 79 |       by such Contributor that are necessarily infringed by their
 80 |       Contribution(s) alone or by combination of their Contribution(s)
 81 |       with the Work to which such Contribution(s) was submitted. If You
 82 |       institute patent litigation against any entity (including a
 83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
 84 |       or a Contribution incorporated within the Work constitutes direct
 85 |       or contributory patent infringement, then any patent licenses
 86 |       granted to You under this License for that Work shall terminate
 87 |       as of the date such litigation is filed.
 88 | 
 89 |    4. Redistribution. You may reproduce and distribute copies of the
 90 |       Work or Derivative Works thereof in any medium, with or without
 91 |       modifications, and in Source or Object form, provided that You
 92 |       meet the following conditions:
 93 | 
 94 |       (a) You must give any other recipients of the Work or
 95 |           Derivative Works a copy of this License; and
 96 | 
 97 |       (b) You must cause any modified files to carry prominent notices
 98 |           stating that You changed the files; and
 99 | 
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 | 
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 | 
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 | 
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 | 
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!)  The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 


--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: format
2 | 
3 | format:
4 | 	isort generate.py gradio wan
5 | 	yapf -i -r *.py generate.py gradio wan
6 | 


--------------------------------------------------------------------------------
/assets/comp_effic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/comp_effic.png


--------------------------------------------------------------------------------
/assets/data_for_diff_stage.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/data_for_diff_stage.jpg


--------------------------------------------------------------------------------
/assets/i2v_res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/i2v_res.png


--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/logo.png


--------------------------------------------------------------------------------
/assets/t2v_res.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/t2v_res.jpg


--------------------------------------------------------------------------------
/assets/vben_vs_sota.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/vben_vs_sota.png


--------------------------------------------------------------------------------
/assets/video_dit_arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/video_dit_arch.jpg


--------------------------------------------------------------------------------
/assets/video_vae_res.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/video_vae_res.jpg


--------------------------------------------------------------------------------
/examples/flf2v_input_first_frame.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/flf2v_input_first_frame.png


--------------------------------------------------------------------------------
/examples/flf2v_input_last_frame.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/flf2v_input_last_frame.png


--------------------------------------------------------------------------------
/examples/girl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/girl.png


--------------------------------------------------------------------------------
/examples/i2v_input.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/i2v_input.JPG


--------------------------------------------------------------------------------
/examples/snake.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/snake.png


--------------------------------------------------------------------------------
/gradio/fl2v_14B_singleGPU.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import argparse
  3 | import gc
  4 | import os
  5 | import os.path as osp
  6 | import sys
  7 | import warnings
  8 | 
  9 | import gradio as gr
 10 | 
 11 | warnings.filterwarnings('ignore')
 12 | 
 13 | # Model
 14 | sys.path.insert(
 15 |     0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 16 | import wan
 17 | from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
 18 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 19 | from wan.utils.utils import cache_video
 20 | 
 21 | # Global Var
 22 | prompt_expander = None
 23 | wan_flf2v_720P = None
 24 | 
 25 | 
 26 | # Button Func
 27 | def load_model(value):
 28 |     global wan_flf2v_720P
 29 | 
 30 |     if value == '------':
 31 |         print("No model loaded")
 32 |         return '------'
 33 | 
 34 |     if value == '720P':
 35 |         if args.ckpt_dir_720p is None:
 36 |             print("Please specify the checkpoint directory for 720P model")
 37 |             return '------'
 38 |         if wan_flf2v_720P is not None:
 39 |             pass
 40 |         else:
 41 |             gc.collect()
 42 | 
 43 |             print("load 14B-720P flf2v model...", end='', flush=True)
 44 |             cfg = WAN_CONFIGS['flf2v-14B']
 45 |             wan_flf2v_720P = wan.WanFLF2V(
 46 |                 config=cfg,
 47 |                 checkpoint_dir=args.ckpt_dir_720p,
 48 |                 device_id=0,
 49 |                 rank=0,
 50 |                 t5_fsdp=False,
 51 |                 dit_fsdp=False,
 52 |                 use_usp=False,
 53 |             )
 54 |             print("done", flush=True)
 55 |             return '720P'
 56 |     return value
 57 | 
 58 | 
 59 | def prompt_enc(prompt, img_first, img_last, tar_lang):
 60 |     print('prompt extend...')
 61 |     if img_first is None or img_last is None:
 62 |         print('Please upload the first and last frames')
 63 |         return prompt
 64 |     global prompt_expander
 65 |     prompt_output = prompt_expander(
 66 |         prompt, image=[img_first, img_last], tar_lang=tar_lang.lower())
 67 |     if prompt_output.status == False:
 68 |         return prompt
 69 |     else:
 70 |         return prompt_output.prompt
 71 | 
 72 | 
 73 | def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
 74 |                      resolution, sd_steps, guide_scale, shift_scale, seed,
 75 |                      n_prompt):
 76 | 
 77 |     if resolution == '------':
 78 |         print(
 79 |             'Please select a resolution and provide its checkpoint directory')
 80 |         return None
 81 | 
 82 |     else:
 83 |         if resolution == '720P':
 84 |             global wan_flf2v_720P
 85 |             video = wan_flf2v_720P.generate(
 86 |                 flf2vid_prompt,
 87 |                 flf2vid_image_first,
 88 |                 flf2vid_image_last,
 89 |                 max_area=MAX_AREA_CONFIGS['720*1280'],
 90 |                 shift=shift_scale,
 91 |                 sampling_steps=sd_steps,
 92 |                 guide_scale=guide_scale,
 93 |                 n_prompt=n_prompt,
 94 |                 seed=seed,
 95 |                 offload_model=True)
 96 |             pass
 97 |         else:
 98 |             print('Sorry, currently only 720P is supported.')
 99 |             return None
100 | 
101 |         cache_video(
102 |             tensor=video[None],
103 |             save_file="example.mp4",
104 |             fps=16,
105 |             nrow=1,
106 |             normalize=True,
107 |             value_range=(-1, 1))
108 | 
109 |         return "example.mp4"
110 | 
111 | 
112 | # Interface
113 | def gradio_interface():
114 |     with gr.Blocks() as demo:
115 |         gr.Markdown("""
116 |                     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
117 |                         Wan2.1 (FLF2V-14B)
118 |                     </div>
119 |                     <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
120 |                         Wan: Open and Advanced Large-Scale Video Generative Models.
121 |                     </div>
122 |                     """)
123 | 
124 |         with gr.Row():
125 |             with gr.Column():
126 |                 resolution = gr.Dropdown(
127 |                     label='Resolution',
128 |                     choices=['------', '720P'],
129 |                     value='------')
130 |                 flf2vid_image_first = gr.Image(
131 |                     type="pil",
132 |                     label="Upload First Frame",
133 |                     elem_id="image_upload",
134 |                 )
135 |                 flf2vid_image_last = gr.Image(
136 |                     type="pil",
137 |                     label="Upload Last Frame",
138 |                     elem_id="image_upload",
139 |                 )
140 |                 flf2vid_prompt = gr.Textbox(
141 |                     label="Prompt",
142 |                     placeholder="Describe the video you want to generate",
143 |                 )
144 |                 tar_lang = gr.Radio(
145 |                     choices=["ZH", "EN"],
146 |                     label="Target language of prompt enhance",
147 |                     value="ZH")
148 |                 run_p_button = gr.Button(value="Prompt Enhance")
149 | 
150 |                 with gr.Accordion("Advanced Options", open=True):
151 |                     with gr.Row():
152 |                         sd_steps = gr.Slider(
153 |                             label="Diffusion steps",
154 |                             minimum=1,
155 |                             maximum=1000,
156 |                             value=50,
157 |                             step=1)
158 |                         guide_scale = gr.Slider(
159 |                             label="Guide scale",
160 |                             minimum=0,
161 |                             maximum=20,
162 |                             value=5.0,
163 |                             step=1)
164 |                     with gr.Row():
165 |                         shift_scale = gr.Slider(
166 |                             label="Shift scale",
167 |                             minimum=0,
168 |                             maximum=20,
169 |                             value=5.0,
170 |                             step=1)
171 |                         seed = gr.Slider(
172 |                             label="Seed",
173 |                             minimum=-1,
174 |                             maximum=2147483647,
175 |                             step=1,
176 |                             value=-1)
177 |                     n_prompt = gr.Textbox(
178 |                         label="Negative Prompt",
179 |                         placeholder="Describe the negative prompt you want to add"
180 |                     )
181 | 
182 |                 run_flf2v_button = gr.Button("Generate Video")
183 | 
184 |             with gr.Column():
185 |                 result_gallery = gr.Video(
186 |                     label='Generated Video', interactive=False, height=600)
187 | 
188 |         resolution.input(
189 |             fn=load_model, inputs=[resolution], outputs=[resolution])
190 | 
191 |         run_p_button.click(
192 |             fn=prompt_enc,
193 |             inputs=[
194 |                 flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
195 |                 tar_lang
196 |             ],
197 |             outputs=[flf2vid_prompt])
198 | 
199 |         run_flf2v_button.click(
200 |             fn=flf2v_generation,
201 |             inputs=[
202 |                 flf2vid_prompt, flf2vid_image_first, flf2vid_image_last,
203 |                 resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt
204 |             ],
205 |             outputs=[result_gallery],
206 |         )
207 | 
208 |     return demo
209 | 
210 | 
211 | # Main
212 | def _parse_args():
213 |     parser = argparse.ArgumentParser(
214 |         description="Generate a video from a text prompt or image using Gradio")
215 |     parser.add_argument(
216 |         "--ckpt_dir_720p",
217 |         type=str,
218 |         default=None,
219 |         help="The path to the checkpoint directory.")
220 |     parser.add_argument(
221 |         "--prompt_extend_method",
222 |         type=str,
223 |         default="local_qwen",
224 |         choices=["dashscope", "local_qwen"],
225 |         help="The prompt extend method to use.")
226 |     parser.add_argument(
227 |         "--prompt_extend_model",
228 |         type=str,
229 |         default=None,
230 |         help="The prompt extend model to use.")
231 | 
232 |     args = parser.parse_args()
233 |     assert args.ckpt_dir_720p is not None, "Please specify the checkpoint directory."
234 | 
235 |     return args
236 | 
237 | 
238 | if __name__ == '__main__':
239 |     args = _parse_args()
240 | 
241 |     print("Step1: Init prompt_expander...", end='', flush=True)
242 |     if args.prompt_extend_method == "dashscope":
243 |         prompt_expander = DashScopePromptExpander(
244 |             model_name=args.prompt_extend_model, is_vl=True)
245 |     elif args.prompt_extend_method == "local_qwen":
246 |         prompt_expander = QwenPromptExpander(
247 |             model_name=args.prompt_extend_model, is_vl=True, device=0)
248 |     else:
249 |         raise NotImplementedError(
250 |             f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
251 |     print("done", flush=True)
252 | 
253 |     demo = gradio_interface()
254 |     demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
255 | 


--------------------------------------------------------------------------------
/gradio/i2v_14B_singleGPU.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import argparse
  3 | import gc
  4 | import os
  5 | import os.path as osp
  6 | import sys
  7 | import warnings
  8 | 
  9 | import gradio as gr
 10 | 
 11 | warnings.filterwarnings('ignore')
 12 | 
 13 | # Model
 14 | sys.path.insert(
 15 |     0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 16 | import wan
 17 | from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS
 18 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 19 | from wan.utils.utils import cache_video
 20 | 
 21 | # Global Var
 22 | prompt_expander = None
 23 | wan_i2v_480P = None
 24 | wan_i2v_720P = None
 25 | 
 26 | 
 27 | # Button Func
 28 | def load_model(value):
 29 |     global wan_i2v_480P, wan_i2v_720P
 30 | 
 31 |     if value == '------':
 32 |         print("No model loaded")
 33 |         return '------'
 34 | 
 35 |     if value == '720P':
 36 |         if args.ckpt_dir_720p is None:
 37 |             print("Please specify the checkpoint directory for 720P model")
 38 |             return '------'
 39 |         if wan_i2v_720P is not None:
 40 |             pass
 41 |         else:
 42 |             del wan_i2v_480P
 43 |             gc.collect()
 44 |             wan_i2v_480P = None
 45 | 
 46 |             print("load 14B-720P i2v model...", end='', flush=True)
 47 |             cfg = WAN_CONFIGS['i2v-14B']
 48 |             wan_i2v_720P = wan.WanI2V(
 49 |                 config=cfg,
 50 |                 checkpoint_dir=args.ckpt_dir_720p,
 51 |                 device_id=0,
 52 |                 rank=0,
 53 |                 t5_fsdp=False,
 54 |                 dit_fsdp=False,
 55 |                 use_usp=False,
 56 |             )
 57 |             print("done", flush=True)
 58 |             return '720P'
 59 | 
 60 |     if value == '480P':
 61 |         if args.ckpt_dir_480p is None:
 62 |             print("Please specify the checkpoint directory for 480P model")
 63 |             return '------'
 64 |         if wan_i2v_480P is not None:
 65 |             pass
 66 |         else:
 67 |             del wan_i2v_720P
 68 |             gc.collect()
 69 |             wan_i2v_720P = None
 70 | 
 71 |             print("load 14B-480P i2v model...", end='', flush=True)
 72 |             cfg = WAN_CONFIGS['i2v-14B']
 73 |             wan_i2v_480P = wan.WanI2V(
 74 |                 config=cfg,
 75 |                 checkpoint_dir=args.ckpt_dir_480p,
 76 |                 device_id=0,
 77 |                 rank=0,
 78 |                 t5_fsdp=False,
 79 |                 dit_fsdp=False,
 80 |                 use_usp=False,
 81 |             )
 82 |             print("done", flush=True)
 83 |             return '480P'
 84 |     return value
 85 | 
 86 | 
 87 | def prompt_enc(prompt, img, tar_lang):
 88 |     print('prompt extend...')
 89 |     if img is None:
 90 |         print('Please upload an image')
 91 |         return prompt
 92 |     global prompt_expander
 93 |     prompt_output = prompt_expander(
 94 |         prompt, image=img, tar_lang=tar_lang.lower())
 95 |     if prompt_output.status == False:
 96 |         return prompt
 97 |     else:
 98 |         return prompt_output.prompt
 99 | 
100 | 
101 | def i2v_generation(img2vid_prompt, img2vid_image, resolution, sd_steps,
102 |                    guide_scale, shift_scale, seed, n_prompt):
103 |     # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
104 | 
105 |     if resolution == '------':
106 |         print(
107 |             'Please select a resolution and provide its checkpoint directory'
108 |         )
109 |         return None
110 | 
111 |     else:
112 |         if resolution == '720P':
113 |             global wan_i2v_720P
114 |             video = wan_i2v_720P.generate(
115 |                 img2vid_prompt,
116 |                 img2vid_image,
117 |                 max_area=MAX_AREA_CONFIGS['720*1280'],
118 |                 shift=shift_scale,
119 |                 sampling_steps=sd_steps,
120 |                 guide_scale=guide_scale,
121 |                 n_prompt=n_prompt,
122 |                 seed=seed,
123 |                 offload_model=True)
124 |         else:
125 |             global wan_i2v_480P
126 |             video = wan_i2v_480P.generate(
127 |                 img2vid_prompt,
128 |                 img2vid_image,
129 |                 max_area=MAX_AREA_CONFIGS['480*832'],
130 |                 shift=shift_scale,
131 |                 sampling_steps=sd_steps,
132 |                 guide_scale=guide_scale,
133 |                 n_prompt=n_prompt,
134 |                 seed=seed,
135 |                 offload_model=True)
136 | 
137 |         cache_video(
138 |             tensor=video[None],
139 |             save_file="example.mp4",
140 |             fps=16,
141 |             nrow=1,
142 |             normalize=True,
143 |             value_range=(-1, 1))
144 | 
145 |         return "example.mp4"
146 | 
147 | 
148 | # Interface
149 | def gradio_interface():
150 |     with gr.Blocks() as demo:
151 |         gr.Markdown("""
152 |                     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
153 |                         Wan2.1 (I2V-14B)
154 |                     </div>
155 |                     <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
156 |                         Wan: Open and Advanced Large-Scale Video Generative Models.
157 |                     </div>
158 |                     """)
159 | 
160 |         with gr.Row():
161 |             with gr.Column():
162 |                 resolution = gr.Dropdown(
163 |                     label='Resolution',
164 |                     choices=['------', '720P', '480P'],
165 |                     value='------')
166 | 
167 |                 img2vid_image = gr.Image(
168 |                     type="pil",
169 |                     label="Upload Input Image",
170 |                     elem_id="image_upload",
171 |                 )
172 |                 img2vid_prompt = gr.Textbox(
173 |                     label="Prompt",
174 |                     placeholder="Describe the video you want to generate",
175 |                 )
176 |                 tar_lang = gr.Radio(
177 |                     choices=["ZH", "EN"],
178 |                     label="Target language of prompt enhance",
179 |                     value="ZH")
180 |                 run_p_button = gr.Button(value="Prompt Enhance")
181 | 
182 |                 with gr.Accordion("Advanced Options", open=True):
183 |                     with gr.Row():
184 |                         sd_steps = gr.Slider(
185 |                             label="Diffusion steps",
186 |                             minimum=1,
187 |                             maximum=1000,
188 |                             value=50,
189 |                             step=1)
190 |                         guide_scale = gr.Slider(
191 |                             label="Guide scale",
192 |                             minimum=0,
193 |                             maximum=20,
194 |                             value=5.0,
195 |                             step=1)
196 |                     with gr.Row():
197 |                         shift_scale = gr.Slider(
198 |                             label="Shift scale",
199 |                             minimum=0,
200 |                             maximum=10,
201 |                             value=5.0,
202 |                             step=1)
203 |                         seed = gr.Slider(
204 |                             label="Seed",
205 |                             minimum=-1,
206 |                             maximum=2147483647,
207 |                             step=1,
208 |                             value=-1)
209 |                     n_prompt = gr.Textbox(
210 |                         label="Negative Prompt",
211 |                         placeholder="Describe the negative prompt you want to add"
212 |                     )
213 | 
214 |                 run_i2v_button = gr.Button("Generate Video")
215 | 
216 |             with gr.Column():
217 |                 result_gallery = gr.Video(
218 |                     label='Generated Video', interactive=False, height=600)
219 | 
220 |         resolution.input(
221 |             fn=load_model, inputs=[resolution], outputs=[resolution])
222 | 
223 |         run_p_button.click(
224 |             fn=prompt_enc,
225 |             inputs=[img2vid_prompt, img2vid_image, tar_lang],
226 |             outputs=[img2vid_prompt])
227 | 
228 |         run_i2v_button.click(
229 |             fn=i2v_generation,
230 |             inputs=[
231 |                 img2vid_prompt, img2vid_image, resolution, sd_steps,
232 |                 guide_scale, shift_scale, seed, n_prompt
233 |             ],
234 |             outputs=[result_gallery],
235 |         )
236 | 
237 |     return demo
238 | 
239 | 
240 | # Main
241 | def _parse_args():
242 |     parser = argparse.ArgumentParser(
243 |         description="Generate a video from a text prompt or image using Gradio")
244 |     parser.add_argument(
245 |         "--ckpt_dir_720p",
246 |         type=str,
247 |         default=None,
248 |         help="The path to the checkpoint directory.")
249 |     parser.add_argument(
250 |         "--ckpt_dir_480p",
251 |         type=str,
252 |         default=None,
253 |         help="The path to the checkpoint directory.")
254 |     parser.add_argument(
255 |         "--prompt_extend_method",
256 |         type=str,
257 |         default="local_qwen",
258 |         choices=["dashscope", "local_qwen"],
259 |         help="The prompt extend method to use.")
260 |     parser.add_argument(
261 |         "--prompt_extend_model",
262 |         type=str,
263 |         default=None,
264 |         help="The prompt extend model to use.")
265 | 
266 |     args = parser.parse_args()
267 |     assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory."
268 | 
269 |     return args
270 | 
271 | 
272 | if __name__ == '__main__':
273 |     args = _parse_args()
274 | 
275 |     print("Step1: Init prompt_expander...", end='', flush=True)
276 |     if args.prompt_extend_method == "dashscope":
277 |         prompt_expander = DashScopePromptExpander(
278 |             model_name=args.prompt_extend_model, is_vl=True)
279 |     elif args.prompt_extend_method == "local_qwen":
280 |         prompt_expander = QwenPromptExpander(
281 |             model_name=args.prompt_extend_model, is_vl=True, device=0)
282 |     else:
283 |         raise NotImplementedError(
284 |             f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
285 |     print("done", flush=True)
286 | 
287 |     demo = gradio_interface()
288 |     demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
289 | 


--------------------------------------------------------------------------------
/gradio/t2i_14B_singleGPU.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import argparse
  3 | import os
  4 | import os.path as osp
  5 | import sys
  6 | import warnings
  7 | 
  8 | import gradio as gr
  9 | 
 10 | warnings.filterwarnings('ignore')
 11 | 
 12 | # Model
 13 | sys.path.insert(
 14 |     0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 15 | import wan
 16 | from wan.configs import WAN_CONFIGS
 17 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 18 | from wan.utils.utils import cache_image
 19 | 
 20 | # Global Var
 21 | prompt_expander = None
 22 | wan_t2i = None
 23 | 
 24 | 
 25 | # Button Func
 26 | def prompt_enc(prompt, tar_lang):
 27 |     global prompt_expander
 28 |     prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
 29 |     if not prompt_output.status:
 30 |         return prompt
 31 |     else:
 32 |         return prompt_output.prompt
 33 | 
 34 | 
 35 | def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale,
 36 |                    shift_scale, seed, n_prompt):
 37 |     global wan_t2i
 38 |     # print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
 39 | 
 40 |     W = int(resolution.split("*")[0])
 41 |     H = int(resolution.split("*")[1])
 42 |     video = wan_t2i.generate(
 43 |         txt2img_prompt,
 44 |         size=(W, H),
 45 |         frame_num=1,
 46 |         shift=shift_scale,
 47 |         sampling_steps=sd_steps,
 48 |         guide_scale=guide_scale,
 49 |         n_prompt=n_prompt,
 50 |         seed=seed,
 51 |         offload_model=True)
 52 | 
 53 |     cache_image(
 54 |         tensor=video.squeeze(1)[None],
 55 |         save_file="example.png",
 56 |         nrow=1,
 57 |         normalize=True,
 58 |         value_range=(-1, 1))
 59 | 
 60 |     return "example.png"
 61 | 
 62 | 
 63 | # Interface
 64 | def gradio_interface():
 65 |     with gr.Blocks() as demo:
 66 |         gr.Markdown("""
 67 |                     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
 68 |                         Wan2.1 (T2I-14B)
 69 |                     </div>
 70 |                     <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
 71 |                         Wan: Open and Advanced Large-Scale Video Generative Models.
 72 |                     </div>
 73 |                     """)
 74 | 
 75 |         with gr.Row():
 76 |             with gr.Column():
 77 |                 txt2img_prompt = gr.Textbox(
 78 |                     label="Prompt",
 79 |                     placeholder="Describe the image you want to generate",
 80 |                 )
 81 |                 tar_lang = gr.Radio(
 82 |                     choices=["ZH", "EN"],
 83 |                     label="Target language for prompt enhancement",
 84 |                     value="ZH")
 85 |                 run_p_button = gr.Button(value="Prompt Enhance")
 86 | 
 87 |                 with gr.Accordion("Advanced Options", open=True):
 88 |                     resolution = gr.Dropdown(
 89 |                         label='Resolution (Width*Height)',
 90 |                         choices=[
 91 |                             '720*1280', '1280*720', '960*960', '1088*832',
 92 |                             '832*1088', '480*832', '832*480', '624*624',
 93 |                             '704*544', '544*704'
 94 |                         ],
 95 |                         value='720*1280')
 96 | 
 97 |                     with gr.Row():
 98 |                         sd_steps = gr.Slider(
 99 |                             label="Diffusion steps",
100 |                             minimum=1,
101 |                             maximum=1000,
102 |                             value=50,
103 |                             step=1)
104 |                         guide_scale = gr.Slider(
105 |                             label="Guide scale",
106 |                             minimum=0,
107 |                             maximum=20,
108 |                             value=5.0,
109 |                             step=1)
110 |                     with gr.Row():
111 |                         shift_scale = gr.Slider(
112 |                             label="Shift scale",
113 |                             minimum=0,
114 |                             maximum=10,
115 |                             value=5.0,
116 |                             step=1)
117 |                         seed = gr.Slider(
118 |                             label="Seed",
119 |                             minimum=-1,
120 |                             maximum=2147483647,
121 |                             step=1,
122 |                             value=-1)
123 |                     n_prompt = gr.Textbox(
124 |                         label="Negative Prompt",
125 |                         placeholder="Describe the negative prompt you want to add"
126 |                     )
127 | 
128 |                 run_t2i_button = gr.Button("Generate Image")
129 | 
130 |             with gr.Column():
131 |                 result_gallery = gr.Image(
132 |                     label='Generated Image', interactive=False, height=600)
133 | 
134 |         run_p_button.click(
135 |             fn=prompt_enc,
136 |             inputs=[txt2img_prompt, tar_lang],
137 |             outputs=[txt2img_prompt])
138 | 
139 |         run_t2i_button.click(
140 |             fn=t2i_generation,
141 |             inputs=[
142 |                 txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale,
143 |                 seed, n_prompt
144 |             ],
145 |             outputs=[result_gallery],
146 |         )
147 | 
148 |     return demo
149 | 
150 | 
151 | # Main
152 | def _parse_args():
153 |     parser = argparse.ArgumentParser(
154 |         description="Generate an image from a text prompt using Gradio")
155 |     parser.add_argument(
156 |         "--ckpt_dir",
157 |         type=str,
158 |         default="cache",
159 |         help="The path to the checkpoint directory.")
160 |     parser.add_argument(
161 |         "--prompt_extend_method",
162 |         type=str,
163 |         default="local_qwen",
164 |         choices=["dashscope", "local_qwen"],
165 |         help="The prompt extend method to use.")
166 |     parser.add_argument(
167 |         "--prompt_extend_model",
168 |         type=str,
169 |         default=None,
170 |         help="The prompt extend model to use.")
171 | 
172 |     args = parser.parse_args()
173 | 
174 |     return args
175 | 
176 | 
177 | if __name__ == '__main__':
178 |     args = _parse_args()
179 | 
180 |     print("Step1: Init prompt_expander...", end='', flush=True)
181 |     if args.prompt_extend_method == "dashscope":
182 |         prompt_expander = DashScopePromptExpander(
183 |             model_name=args.prompt_extend_model, is_vl=False)
184 |     elif args.prompt_extend_method == "local_qwen":
185 |         prompt_expander = QwenPromptExpander(
186 |             model_name=args.prompt_extend_model, is_vl=False, device=0)
187 |     else:
188 |         raise NotImplementedError(
189 |             f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
190 |     print("done", flush=True)
191 | 
192 |     print("Step2: Init 14B t2i model...", end='', flush=True)
193 |     cfg = WAN_CONFIGS['t2i-14B']
194 |     wan_t2i = wan.WanT2V(
195 |         config=cfg,
196 |         checkpoint_dir=args.ckpt_dir,
197 |         device_id=0,
198 |         rank=0,
199 |         t5_fsdp=False,
200 |         dit_fsdp=False,
201 |         use_usp=False,
202 |     )
203 |     print("done", flush=True)
204 | 
205 |     demo = gradio_interface()
206 |     demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
207 | 
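Text-to-image here is just the T2V pipeline run with `frame_num=1`; the single frame is then reshaped for `cache_image`. A small shape sketch of that `squeeze(1)[None]` step, assuming the generator returns a `(C, F, H, W)` tensor as the call above implies:

```python
import torch

# Dummy single-frame "video" with the assumed (C, F, H, W) layout, F == 1.
video = torch.randn(3, 1, 1280, 720)

# squeeze(1) drops the frame axis, [None] adds a batch axis -> (1, C, H, W),
# which is the layout an image-saving utility expects.
image_batch = video.squeeze(1)[None]
assert image_batch.shape == (1, 3, 1280, 720)
```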


--------------------------------------------------------------------------------
/gradio/t2v_1.3B_singleGPU.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import argparse
  3 | import os
  4 | import os.path as osp
  5 | import sys
  6 | import warnings
  7 | 
  8 | import gradio as gr
  9 | 
 10 | warnings.filterwarnings('ignore')
 11 | 
 12 | # Model
 13 | sys.path.insert(
 14 |     0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 15 | import wan
 16 | from wan.configs import WAN_CONFIGS
 17 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 18 | from wan.utils.utils import cache_video
 19 | 
 20 | # Global Var
 21 | prompt_expander = None
 22 | wan_t2v = None
 23 | 
 24 | 
 25 | # Button Func
 26 | def prompt_enc(prompt, tar_lang):
 27 |     global prompt_expander
 28 |     prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
 29 |     if not prompt_output.status:
 30 |         return prompt
 31 |     else:
 32 |         return prompt_output.prompt
 33 | 
 34 | 
 35 | def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
 36 |                    shift_scale, seed, n_prompt):
 37 |     global wan_t2v
 38 |     # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
 39 | 
 40 |     W = int(resolution.split("*")[0])
 41 |     H = int(resolution.split("*")[1])
 42 |     video = wan_t2v.generate(
 43 |         txt2vid_prompt,
 44 |         size=(W, H),
 45 |         shift=shift_scale,
 46 |         sampling_steps=sd_steps,
 47 |         guide_scale=guide_scale,
 48 |         n_prompt=n_prompt,
 49 |         seed=seed,
 50 |         offload_model=True)
 51 | 
 52 |     cache_video(
 53 |         tensor=video[None],
 54 |         save_file="example.mp4",
 55 |         fps=16,
 56 |         nrow=1,
 57 |         normalize=True,
 58 |         value_range=(-1, 1))
 59 | 
 60 |     return "example.mp4"
 61 | 
 62 | 
 63 | # Interface
 64 | def gradio_interface():
 65 |     with gr.Blocks() as demo:
 66 |         gr.Markdown("""
 67 |                     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
 68 |                         Wan2.1 (T2V-1.3B)
 69 |                     </div>
 70 |                     <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
 71 |                         Wan: Open and Advanced Large-Scale Video Generative Models.
 72 |                     </div>
 73 |                     """)
 74 | 
 75 |         with gr.Row():
 76 |             with gr.Column():
 77 |                 txt2vid_prompt = gr.Textbox(
 78 |                     label="Prompt",
 79 |                     placeholder="Describe the video you want to generate",
 80 |                 )
 81 |                 tar_lang = gr.Radio(
 82 |                     choices=["ZH", "EN"],
 83 |                     label="Target language for prompt enhancement",
 84 |                     value="ZH")
 85 |                 run_p_button = gr.Button(value="Prompt Enhance")
 86 | 
 87 |                 with gr.Accordion("Advanced Options", open=True):
 88 |                     resolution = gr.Dropdown(
 89 |                         label='Resolution (Width*Height)',
 90 |                         choices=[
 91 |                             '480*832',
 92 |                             '832*480',
 93 |                             '624*624',
 94 |                             '704*544',
 95 |                             '544*704',
 96 |                         ],
 97 |                         value='480*832')
 98 | 
 99 |                     with gr.Row():
100 |                         sd_steps = gr.Slider(
101 |                             label="Diffusion steps",
102 |                             minimum=1,
103 |                             maximum=1000,
104 |                             value=50,
105 |                             step=1)
106 |                         guide_scale = gr.Slider(
107 |                             label="Guide scale",
108 |                             minimum=0,
109 |                             maximum=20,
110 |                             value=6.0,
111 |                             step=1)
112 |                     with gr.Row():
113 |                         shift_scale = gr.Slider(
114 |                             label="Shift scale",
115 |                             minimum=0,
116 |                             maximum=20,
117 |                             value=8.0,
118 |                             step=1)
119 |                         seed = gr.Slider(
120 |                             label="Seed",
121 |                             minimum=-1,
122 |                             maximum=2147483647,
123 |                             step=1,
124 |                             value=-1)
125 |                     n_prompt = gr.Textbox(
126 |                         label="Negative Prompt",
127 |                         placeholder="Describe the negative prompt you want to add"
128 |                     )
129 | 
130 |                 run_t2v_button = gr.Button("Generate Video")
131 | 
132 |             with gr.Column():
133 |                 result_gallery = gr.Video(
134 |                     label='Generated Video', interactive=False, height=600)
135 | 
136 |         run_p_button.click(
137 |             fn=prompt_enc,
138 |             inputs=[txt2vid_prompt, tar_lang],
139 |             outputs=[txt2vid_prompt])
140 | 
141 |         run_t2v_button.click(
142 |             fn=t2v_generation,
143 |             inputs=[
144 |                 txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
145 |                 seed, n_prompt
146 |             ],
147 |             outputs=[result_gallery],
148 |         )
149 | 
150 |     return demo
151 | 
152 | 
153 | # Main
154 | def _parse_args():
155 |     parser = argparse.ArgumentParser(
156 |         description="Generate a video from a text prompt using Gradio")
157 |     parser.add_argument(
158 |         "--ckpt_dir",
159 |         type=str,
160 |         default="cache",
161 |         help="The path to the checkpoint directory.")
162 |     parser.add_argument(
163 |         "--prompt_extend_method",
164 |         type=str,
165 |         default="local_qwen",
166 |         choices=["dashscope", "local_qwen"],
167 |         help="The prompt extend method to use.")
168 |     parser.add_argument(
169 |         "--prompt_extend_model",
170 |         type=str,
171 |         default=None,
172 |         help="The prompt extend model to use.")
173 | 
174 |     args = parser.parse_args()
175 | 
176 |     return args
177 | 
178 | 
179 | if __name__ == '__main__':
180 |     args = _parse_args()
181 | 
182 |     print("Step1: Init prompt_expander...", end='', flush=True)
183 |     if args.prompt_extend_method == "dashscope":
184 |         prompt_expander = DashScopePromptExpander(
185 |             model_name=args.prompt_extend_model, is_vl=False)
186 |     elif args.prompt_extend_method == "local_qwen":
187 |         prompt_expander = QwenPromptExpander(
188 |             model_name=args.prompt_extend_model, is_vl=False, device=0)
189 |     else:
190 |         raise NotImplementedError(
191 |             f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
192 |     print("done", flush=True)
193 | 
194 |     print("Step2: Init 1.3B t2v model...", end='', flush=True)
195 |     cfg = WAN_CONFIGS['t2v-1.3B']
196 |     wan_t2v = wan.WanT2V(
197 |         config=cfg,
198 |         checkpoint_dir=args.ckpt_dir,
199 |         device_id=0,
200 |         rank=0,
201 |         t5_fsdp=False,
202 |         dit_fsdp=False,
203 |         use_usp=False,
204 |     )
205 |     print("done", flush=True)
206 | 
207 |     demo = gradio_interface()
208 |     demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
209 | 
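`t2v_generation` above derives width and height by splitting the dropdown string on `*`, matching the `Resolution (Width*Height)` label. A tiny worked example of that parsing:

```python
def parse_resolution(res: str) -> tuple[int, int]:
    # "W*H" -> (W, H); e.g. the 1.3B default "480*832" is width 480, height 832
    w, h = (int(v) for v in res.split("*"))
    return w, h

assert parse_resolution("480*832") == (480, 832)
assert parse_resolution("832*480") == (832, 480)
```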


--------------------------------------------------------------------------------
/gradio/t2v_14B_singleGPU.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import argparse
  3 | import os
  4 | import os.path as osp
  5 | import sys
  6 | import warnings
  7 | 
  8 | import gradio as gr
  9 | 
 10 | warnings.filterwarnings('ignore')
 11 | 
 12 | # Model
 13 | sys.path.insert(
 14 |     0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2]))
 15 | import wan
 16 | from wan.configs import WAN_CONFIGS
 17 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
 18 | from wan.utils.utils import cache_video
 19 | 
 20 | # Global Var
 21 | prompt_expander = None
 22 | wan_t2v = None
 23 | 
 24 | 
 25 | # Button Func
 26 | def prompt_enc(prompt, tar_lang):
 27 |     global prompt_expander
 28 |     prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower())
 29 |     if not prompt_output.status:
 30 |         return prompt
 31 |     else:
 32 |         return prompt_output.prompt
 33 | 
 34 | 
 35 | def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale,
 36 |                    shift_scale, seed, n_prompt):
 37 |     global wan_t2v
 38 |     # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}")
 39 | 
 40 |     W = int(resolution.split("*")[0])
 41 |     H = int(resolution.split("*")[1])
 42 |     video = wan_t2v.generate(
 43 |         txt2vid_prompt,
 44 |         size=(W, H),
 45 |         shift=shift_scale,
 46 |         sampling_steps=sd_steps,
 47 |         guide_scale=guide_scale,
 48 |         n_prompt=n_prompt,
 49 |         seed=seed,
 50 |         offload_model=True)
 51 | 
 52 |     cache_video(
 53 |         tensor=video[None],
 54 |         save_file="example.mp4",
 55 |         fps=16,
 56 |         nrow=1,
 57 |         normalize=True,
 58 |         value_range=(-1, 1))
 59 | 
 60 |     return "example.mp4"
 61 | 
 62 | 
 63 | # Interface
 64 | def gradio_interface():
 65 |     with gr.Blocks() as demo:
 66 |         gr.Markdown("""
 67 |                     <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
 68 |                         Wan2.1 (T2V-14B)
 69 |                     </div>
 70 |                     <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;">
 71 |                         Wan: Open and Advanced Large-Scale Video Generative Models.
 72 |                     </div>
 73 |                     """)
 74 | 
 75 |         with gr.Row():
 76 |             with gr.Column():
 77 |                 txt2vid_prompt = gr.Textbox(
 78 |                     label="Prompt",
 79 |                     placeholder="Describe the video you want to generate",
 80 |                 )
 81 |                 tar_lang = gr.Radio(
 82 |                     choices=["ZH", "EN"],
 83 |                     label="Target language for prompt enhancement",
 84 |                     value="ZH")
 85 |                 run_p_button = gr.Button(value="Prompt Enhance")
 86 | 
 87 |                 with gr.Accordion("Advanced Options", open=True):
 88 |                     resolution = gr.Dropdown(
 89 |                         label='Resolution (Width*Height)',
 90 |                         choices=[
 91 |                             '720*1280', '1280*720', '960*960', '1088*832',
 92 |                             '832*1088', '480*832', '832*480', '624*624',
 93 |                             '704*544', '544*704'
 94 |                         ],
 95 |                         value='720*1280')
 96 | 
 97 |                     with gr.Row():
 98 |                         sd_steps = gr.Slider(
 99 |                             label="Diffusion steps",
100 |                             minimum=1,
101 |                             maximum=1000,
102 |                             value=50,
103 |                             step=1)
104 |                         guide_scale = gr.Slider(
105 |                             label="Guide scale",
106 |                             minimum=0,
107 |                             maximum=20,
108 |                             value=5.0,
109 |                             step=1)
110 |                     with gr.Row():
111 |                         shift_scale = gr.Slider(
112 |                             label="Shift scale",
113 |                             minimum=0,
114 |                             maximum=10,
115 |                             value=5.0,
116 |                             step=1)
117 |                         seed = gr.Slider(
118 |                             label="Seed",
119 |                             minimum=-1,
120 |                             maximum=2147483647,
121 |                             step=1,
122 |                             value=-1)
123 |                     n_prompt = gr.Textbox(
124 |                         label="Negative Prompt",
125 |                         placeholder="Describe the negative prompt you want to add"
126 |                     )
127 | 
128 |                 run_t2v_button = gr.Button("Generate Video")
129 | 
130 |             with gr.Column():
131 |                 result_gallery = gr.Video(
132 |                     label='Generated Video', interactive=False, height=600)
133 | 
134 |         run_p_button.click(
135 |             fn=prompt_enc,
136 |             inputs=[txt2vid_prompt, tar_lang],
137 |             outputs=[txt2vid_prompt])
138 | 
139 |         run_t2v_button.click(
140 |             fn=t2v_generation,
141 |             inputs=[
142 |                 txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale,
143 |                 seed, n_prompt
144 |             ],
145 |             outputs=[result_gallery],
146 |         )
147 | 
148 |     return demo
149 | 
150 | 
151 | # Main
152 | def _parse_args():
153 |     parser = argparse.ArgumentParser(
154 |         description="Generate a video from a text prompt using Gradio")
155 |     parser.add_argument(
156 |         "--ckpt_dir",
157 |         type=str,
158 |         default="cache",
159 |         help="The path to the checkpoint directory.")
160 |     parser.add_argument(
161 |         "--prompt_extend_method",
162 |         type=str,
163 |         default="local_qwen",
164 |         choices=["dashscope", "local_qwen"],
165 |         help="The prompt extend method to use.")
166 |     parser.add_argument(
167 |         "--prompt_extend_model",
168 |         type=str,
169 |         default=None,
170 |         help="The prompt extend model to use.")
171 | 
172 |     args = parser.parse_args()
173 | 
174 |     return args
175 | 
176 | 
177 | if __name__ == '__main__':
178 |     args = _parse_args()
179 | 
180 |     print("Step1: Init prompt_expander...", end='', flush=True)
181 |     if args.prompt_extend_method == "dashscope":
182 |         prompt_expander = DashScopePromptExpander(
183 |             model_name=args.prompt_extend_model, is_vl=False)
184 |     elif args.prompt_extend_method == "local_qwen":
185 |         prompt_expander = QwenPromptExpander(
186 |             model_name=args.prompt_extend_model, is_vl=False, device=0)
187 |     else:
188 |         raise NotImplementedError(
189 |             f"Unsupported prompt_extend_method: {args.prompt_extend_method}")
190 |     print("done", flush=True)
191 | 
192 |     print("Step2: Init 14B t2v model...", end='', flush=True)
193 |     cfg = WAN_CONFIGS['t2v-14B']
194 |     wan_t2v = wan.WanT2V(
195 |         config=cfg,
196 |         checkpoint_dir=args.ckpt_dir,
197 |         device_id=0,
198 |         rank=0,
199 |         t5_fsdp=False,
200 |         dit_fsdp=False,
201 |         use_usp=False,
202 |     )
203 |     print("done", flush=True)
204 | 
205 |     demo = gradio_interface()
206 |     demo.launch(server_name="0.0.0.0", share=False, server_port=7860)
207 | 
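The "Prompt Enhance" button above writes the expanded prompt back into the same Textbox it reads from, and `prompt_enc` falls back to the user's original text when expansion fails. A minimal sketch of that round-trip with a stand-in expander (the expander below is a fake placeholder, not the DashScope/Qwen one):

```python
import gradio as gr

def fake_expand(prompt, tar_lang):
    # Stand-in for prompt_expander: return the original text if "expansion" fails.
    if not prompt.strip():
        return prompt
    return f"({tar_lang}) {prompt}, cinematic lighting, rich detail"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    lang = gr.Radio(choices=["ZH", "EN"], value="ZH", label="Target language")
    gr.Button("Prompt Enhance").click(
        fn=fake_expand, inputs=[prompt, lang], outputs=[prompt])

# demo.launch()
```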


--------------------------------------------------------------------------------
/gradio/vace.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | # Copyright (c) Alibaba, Inc. and its affiliates.
  3 | 
  4 | import argparse
  5 | import datetime
  6 | import os
  7 | import sys
  8 | 
  9 | import imageio
 10 | import numpy as np
 11 | import torch
 12 | 
 13 | import gradio as gr
 14 | 
 15 | sys.path.insert(
 16 |     0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2]))
 17 | import wan
 18 | from wan import WanVace, WanVaceMP
 19 | from wan.configs import SIZE_CONFIGS, WAN_CONFIGS
 20 | 
 21 | 
 22 | class FixedSizeQueue:
 23 | 
 24 |     def __init__(self, max_size):
 25 |         self.max_size = max_size
 26 |         self.queue = []
 27 | 
 28 |     def add(self, item):
 29 |         self.queue.insert(0, item)
 30 |         if len(self.queue) > self.max_size:
 31 |             self.queue.pop()
 32 | 
 33 |     def get(self):
 34 |         return self.queue
 35 | 
 36 |     def __repr__(self):
 37 |         return str(self.queue)
 38 | 
 39 | 
 40 | class VACEInference:
 41 | 
 42 |     def __init__(self,
 43 |                  cfg,
 44 |                  skip_load=False,
 45 |                  gallery_share=True,
 46 |                  gallery_share_limit=5):
 47 |         self.cfg = cfg
 48 |         self.save_dir = cfg.save_dir
 49 |         self.gallery_share = gallery_share
 50 |         self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit)
 51 |         if not skip_load:
 52 |             if not cfg.mp:
 53 |                 self.pipe = WanVace(
 54 |                     config=WAN_CONFIGS[cfg.model_name],
 55 |                     checkpoint_dir=cfg.ckpt_dir,
 56 |                     device_id=0,
 57 |                     rank=0,
 58 |                     t5_fsdp=False,
 59 |                     dit_fsdp=False,
 60 |                     use_usp=False,
 61 |                 )
 62 |             else:
 63 |                 self.pipe = WanVaceMP(
 64 |                     config=WAN_CONFIGS[cfg.model_name],
 65 |                     checkpoint_dir=cfg.ckpt_dir,
 66 |                     use_usp=True,
 67 |                     ulysses_size=cfg.ulysses_size,
 68 |                     ring_size=cfg.ring_size)
 69 | 
 70 |     def create_ui(self, *args, **kwargs):
 71 |         gr.Markdown("""
 72 |                     <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;">
 73 |                         <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a>
 74 |                     </div>
 75 |                     """)
 76 |         with gr.Row(variant='panel', equal_height=True):
 77 |             with gr.Column(scale=1, min_width=0):
 78 |                 self.src_video = gr.Video(
 79 |                     label="src_video",
 80 |                     sources=['upload'],
 81 |                     value=None,
 82 |                     interactive=True)
 83 |             with gr.Column(scale=1, min_width=0):
 84 |                 self.src_mask = gr.Video(
 85 |                     label="src_mask",
 86 |                     sources=['upload'],
 87 |                     value=None,
 88 |                     interactive=True)
 89 |         #
 90 |         with gr.Row(variant='panel', equal_height=True):
 91 |             with gr.Column(scale=1, min_width=0):
 92 |                 with gr.Row(equal_height=True):
 93 |                     self.src_ref_image_1 = gr.Image(
 94 |                         label='src_ref_image_1',
 95 |                         height=200,
 96 |                         interactive=True,
 97 |                         type='filepath',
 98 |                         image_mode='RGB',
 99 |                         sources=['upload'],
100 |                         elem_id="src_ref_image_1",
101 |                         format='png')
102 |                     self.src_ref_image_2 = gr.Image(
103 |                         label='src_ref_image_2',
104 |                         height=200,
105 |                         interactive=True,
106 |                         type='filepath',
107 |                         image_mode='RGB',
108 |                         sources=['upload'],
109 |                         elem_id="src_ref_image_2",
110 |                         format='png')
111 |                     self.src_ref_image_3 = gr.Image(
112 |                         label='src_ref_image_3',
113 |                         height=200,
114 |                         interactive=True,
115 |                         type='filepath',
116 |                         image_mode='RGB',
117 |                         sources=['upload'],
118 |                         elem_id="src_ref_image_3",
119 |                         format='png')
120 |         with gr.Row(variant='panel', equal_height=True):
121 |             with gr.Column(scale=1):
122 |                 self.prompt = gr.Textbox(
123 |                     show_label=False,
124 |                     placeholder="positive_prompt_input",
125 |                     elem_id='positive_prompt',
126 |                     container=True,
127 |                     autofocus=True,
128 |                     elem_classes='type_row',
129 |                     visible=True,
130 |                     lines=2)
131 |                 self.negative_prompt = gr.Textbox(
132 |                     show_label=False,
133 |                     value=self.pipe.config.sample_neg_prompt,
134 |                     placeholder="negative_prompt_input",
135 |                     elem_id='negative_prompt',
136 |                     container=True,
137 |                     autofocus=False,
138 |                     elem_classes='type_row',
139 |                     visible=True,
140 |                     interactive=True,
141 |                     lines=1)
142 |         #
143 |         with gr.Row(variant='panel', equal_height=True):
144 |             with gr.Column(scale=1, min_width=0):
145 |                 with gr.Row(equal_height=True):
146 |                     self.shift_scale = gr.Slider(
147 |                         label='shift_scale',
148 |                         minimum=0.0,
149 |                         maximum=100.0,
150 |                         step=1.0,
151 |                         value=16.0,
152 |                         interactive=True)
153 |                     self.sample_steps = gr.Slider(
154 |                         label='sample_steps',
155 |                         minimum=1,
156 |                         maximum=100,
157 |                         step=1,
158 |                         value=25,
159 |                         interactive=True)
160 |                     self.context_scale = gr.Slider(
161 |                         label='context_scale',
162 |                         minimum=0.0,
163 |                         maximum=2.0,
164 |                         step=0.1,
165 |                         value=1.0,
166 |                         interactive=True)
167 |                     self.guide_scale = gr.Slider(
168 |                         label='guide_scale',
169 |                         minimum=1,
170 |                         maximum=10,
171 |                         step=0.5,
172 |                         value=5.0,
173 |                         interactive=True)
174 |                     self.infer_seed = gr.Slider(
175 |                         minimum=-1, maximum=10000000, value=2025, label="Seed")
176 |         #
177 |         with gr.Accordion(label="Usable without source video", open=False):
178 |             with gr.Row(equal_height=True):
179 |                 self.output_height = gr.Textbox(
180 |                     label='resolutions_height',
181 |                     # value=480,
182 |                     value=720,
183 |                     interactive=True)
184 |                 self.output_width = gr.Textbox(
185 |                     label='resolutions_width',
186 |                     # value=832,
187 |                     value=1280,
188 |                     interactive=True)
189 |                 self.frame_rate = gr.Textbox(
190 |                     label='frame_rate', value=16, interactive=True)
191 |                 self.num_frames = gr.Textbox(
192 |                     label='num_frames', value=81, interactive=True)
193 |         #
194 |         with gr.Row(equal_height=True):
195 |             with gr.Column(scale=5):
196 |                 self.generate_button = gr.Button(
197 |                     value='Run',
198 |                     elem_classes='type_row',
199 |                     elem_id='generate_button',
200 |                     visible=True)
201 |             with gr.Column(scale=1):
202 |                 self.refresh_button = gr.Button(value='\U0001f504')  # 🔄
203 |         #
204 |         self.output_gallery = gr.Gallery(
205 |             label="output_gallery",
206 |             value=[],
207 |             interactive=False,
208 |             allow_preview=True,
209 |             preview=True)
210 | 
211 |     def generate(self, output_gallery, src_video, src_mask, src_ref_image_1,
212 |                  src_ref_image_2, src_ref_image_3, prompt, negative_prompt,
213 |                  shift_scale, sample_steps, context_scale, guide_scale,
214 |                  infer_seed, output_height, output_width, frame_rate,
215 |                  num_frames):
216 |         output_height, output_width, frame_rate, num_frames = int(
217 |             output_height), int(output_width), int(frame_rate), int(num_frames)
218 |         src_ref_images = [
219 |             x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3]
220 |             if x is not None
221 |         ]
222 |         src_video, src_mask, src_ref_images = self.pipe.prepare_source(
223 |             [src_video], [src_mask], [src_ref_images],
224 |             num_frames=num_frames,
225 |             image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"],
226 |             device=self.pipe.device)
227 |         video = self.pipe.generate(
228 |             prompt,
229 |             src_video,
230 |             src_mask,
231 |             src_ref_images,
232 |             size=(output_width, output_height),
233 |             context_scale=context_scale,
234 |             shift=shift_scale,
235 |             sampling_steps=sample_steps,
236 |             guide_scale=guide_scale,
237 |             n_prompt=negative_prompt,
238 |             seed=infer_seed,
239 |             offload_model=True)
240 | 
241 |         name = '{0:%Y%m%d-%H%M%S}'.format(datetime.datetime.now())
242 |         video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4')
243 |         video_frames = (
244 |             torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
245 |             255).cpu().numpy().astype(np.uint8)
246 | 
247 |         try:
248 |             writer = imageio.get_writer(
249 |                 video_path,
250 |                 fps=frame_rate,
251 |                 codec='libx264',
252 |                 quality=8,
253 |                 macro_block_size=1)
254 |             for frame in video_frames:
255 |                 writer.append_data(frame)
256 |             writer.close()
257 |             print(video_path)
258 |         except Exception as e:
259 |             raise gr.Error(f"Video save error: {e}")
260 | 
261 |         if self.gallery_share:
262 |             self.gallery_share_data.add(video_path)
263 |             return self.gallery_share_data.get()
264 |         else:
265 |             return [video_path]
266 | 
267 |     def set_callbacks(self, **kwargs):
268 |         self.gen_inputs = [
269 |             self.output_gallery, self.src_video, self.src_mask,
270 |             self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3,
271 |             self.prompt, self.negative_prompt, self.shift_scale,
272 |             self.sample_steps, self.context_scale, self.guide_scale,
273 |             self.infer_seed, self.output_height, self.output_width,
274 |             self.frame_rate, self.num_frames
275 |         ]
276 |         self.gen_outputs = [self.output_gallery]
277 |         self.generate_button.click(
278 |             self.generate,
279 |             inputs=self.gen_inputs,
280 |             outputs=self.gen_outputs,
281 |             queue=True)
282 |         self.refresh_button.click(
283 |             lambda x: self.gallery_share_data.get()
284 |             if self.gallery_share else x,
285 |             inputs=[self.output_gallery],
286 |             outputs=[self.output_gallery])
287 | 
288 | 
289 | if __name__ == '__main__':
290 |     parser = argparse.ArgumentParser(
291 |         description='Argument parser for the VACE-WAN demo.')
292 |     parser.add_argument(
293 |         '--server_port', dest='server_port', help='Gradio server port.', type=int, default=7860)
294 |     parser.add_argument(
295 |         '--server_name', dest='server_name', help='Gradio server host address.', default='0.0.0.0')
296 |     parser.add_argument('--root_path', dest='root_path', help='Root path when serving behind a reverse proxy.', default=None)
297 |     parser.add_argument('--save_dir', dest='save_dir', help='Directory in which generated videos are saved.', default='cache')
298 |     parser.add_argument(
299 |         "--mp",
300 |         action="store_true",
301 |         help="Use Multi-GPUs",
302 |     )
303 |     parser.add_argument(
304 |         "--model_name",
305 |         type=str,
306 |         default="vace-14B",
307 |         choices=list(WAN_CONFIGS.keys()),
308 |         help="The model name to run.")
309 |     parser.add_argument(
310 |         "--ulysses_size",
311 |         type=int,
312 |         default=1,
313 |         help="The size of the Ulysses parallelism in DiT.")
314 |     parser.add_argument(
315 |         "--ring_size",
316 |         type=int,
317 |         default=1,
318 |         help="The size of the ring attention parallelism in DiT.")
319 |     parser.add_argument(
320 |         "--ckpt_dir",
321 |         type=str,
322 |         # default='models/VACE-Wan2.1-1.3B-Preview',
323 |         default='models/Wan2.1-VACE-14B/',
324 |         help="The path to the checkpoint directory.",
325 |     )
326 |     parser.add_argument(
327 |         "--offload_to_cpu",
328 |         action="store_true",
329 |         help="Offload unnecessary computations to the CPU.",
330 |     )
331 | 
332 |     args = parser.parse_args()
333 | 
334 |     if not os.path.exists(args.save_dir):
335 |         os.makedirs(args.save_dir, exist_ok=True)
336 | 
337 |     with gr.Blocks() as demo:
338 |         infer_gr = VACEInference(
339 |             args, skip_load=False, gallery_share=True, gallery_share_limit=5)
340 |         infer_gr.create_ui()
341 |         infer_gr.set_callbacks()
342 |         allowed_paths = [args.save_dir]
343 |         demo.queue(status_update_rate=1).launch(
344 |             server_name=args.server_name,
345 |             server_port=args.server_port,
346 |             root_path=args.root_path,
347 |             allowed_paths=allowed_paths,
348 |             show_error=True,
349 |             debug=True)
350 | 
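The save path in `generate` above maps the `[-1, 1]` video tensor to `uint8` frames in `(F, H, W, C)` order before handing them to imageio. A standalone sketch of just that conversion, using a random tensor in place of the model output:

```python
import numpy as np
import torch

video = torch.rand(3, 4, 64, 64) * 2 - 1  # fake (C, F, H, W) output in [-1, 1]
frames = (torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) *
          255).cpu().numpy().astype(np.uint8)
assert frames.shape == (4, 64, 64, 3) and frames.dtype == np.uint8
```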


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [build-system]
 2 | requires = ["setuptools>=61.0"]
 3 | build-backend = "setuptools.build_meta"
 4 | 
 5 | [project]
 6 | name = "wan"
 7 | version = "2.1.0"
 8 | description = "Wan: Open and Advanced Large-Scale Video Generative Models"
 9 | authors = [
10 |     { name = "Wan Team", email = "wan.ai@alibabacloud.com" }
11 | ]
12 | license = { file = "LICENSE.txt" }
13 | readme = "README.md"
14 | requires-python = ">=3.10,<4.0"
15 | dependencies = [
16 |     "torch>=2.4.0",
17 |     "torchvision>=0.19.0",
18 |     "opencv-python>=4.9.0.80",
19 |     "diffusers>=0.31.0",
20 |     "transformers>=4.49.0",
21 |     "tokenizers>=0.20.3",
22 |     "accelerate>=1.1.1",
23 |     "tqdm",
24 |     "imageio",
25 |     "easydict",
26 |     "ftfy",
27 |     "dashscope",
28 |     "imageio-ffmpeg",
29 |     "flash_attn",
30 |     "gradio>=5.0.0",
31 |     "numpy>=1.23.5,<2"
32 | ]
33 | 
34 | [project.optional-dependencies]
35 | dev = [
36 |     "pytest",
37 |     "black",
38 |     "flake8",
39 |     "isort",
40 |     "mypy",
41 |     "huggingface-hub[cli]"
42 | ]
43 | 
44 | [project.urls]
45 | homepage = "https://wanxai.com"
46 | documentation = "https://github.com/Wan-Video/Wan2.1"
47 | repository = "https://github.com/Wan-Video/Wan2.1"
48 | huggingface = "https://huggingface.co/Wan-AI/"
49 | modelscope = "https://modelscope.cn/organization/Wan-AI"
50 | discord = "https://discord.gg/p5XbdQV7"
51 | 
52 | [tool.setuptools]
53 | packages = ["wan"]
54 | 
55 | [tool.setuptools.package-data]
56 | "wan" = ["**/*.py"]
57 | 
58 | [tool.black]
59 | line-length = 88
60 | 
61 | [tool.isort]
62 | profile = "black"
63 | 
64 | [tool.mypy]
65 | strict = true
66 | 
67 | 
68 | 
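An optional runtime sketch for double-checking the floor versions pinned above; it assumes the `packaging` helper is available in the environment (it is not one of the declared dependencies):

```python
from importlib.metadata import version

from packaging.version import Version  # assumed available alongside pip/setuptools

FLOORS = {"torch": "2.4.0", "diffusers": "0.31.0", "transformers": "4.49.0"}
for pkg, floor in FLOORS.items():
    assert Version(version(pkg)) >= Version(floor), f"{pkg} is older than {floor}"
```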


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | torch>=2.4.0
 2 | torchvision>=0.19.0
 3 | opencv-python>=4.9.0.80
 4 | diffusers>=0.31.0
 5 | transformers>=4.49.0
 6 | tokenizers>=0.20.3
 7 | accelerate>=1.1.1
 8 | tqdm
 9 | imageio
10 | easydict
11 | ftfy
12 | dashscope
13 | imageio-ffmpeg
14 | flash_attn
15 | gradio>=5.0.0
16 | numpy>=1.23.5,<2
17 | 


--------------------------------------------------------------------------------
/tests/README.md:
--------------------------------------------------------------------------------
1 | 
2 | Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in one folder and specify the maximum number of GPUs you want to use.
3 | 
4 | ```bash
5 | bash ./test.sh <local model dir> <gpu number>
6 | ```
7 | 


--------------------------------------------------------------------------------
/tests/test.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | 
  3 | 
  4 | if [ "$#" -eq 2 ]; then
  5 |   MODEL_DIR=$(realpath "$1")
  6 |   GPUS=$2
  7 | else
  8 |   echo "Usage: $0 <local model dir> <gpu number>"
  9 |   exit 1
 10 | fi
 11 | 
 12 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
 13 | REPO_ROOT="$(dirname "$SCRIPT_DIR")"
 14 | cd "$REPO_ROOT" || exit 1
 15 | 
 16 | PY_FILE=./generate.py
 17 | 
 18 | 
 19 | function t2v_1_3B() {
 20 |     T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B"
 21 | 
 22 |     # 1-GPU Test
 23 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: "
 24 |     python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR
 25 | 
 26 |     # Multiple GPU Test
 27 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: "
 28 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
 29 | 
 30 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: "
 31 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
 32 | 
 33 |     if [ -n "${DASH_API_KEY+x}" ]; then
 34 |         echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: "
 35 |         torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
 36 |     else
 37 |         echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
 38 |     fi
 39 | }
 40 | 
 41 | function t2v_14B() {
 42 |     T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
 43 | 
 44 |     # 1-GPU Test
 45 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: "
 46 |     python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
 47 | 
 48 |     # Multiple GPU Test
 49 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: "
 50 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
 51 | 
 52 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: "
 53 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
 54 | }
 55 | 
 56 | 
 57 | 
 58 | function t2i_14B() {
 59 |     T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B"
 60 | 
 61 |     # 1-GPU Test
 62 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: "
 63 |     python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR
 64 | 
 65 |     # Multiple GPU Test
 66 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: "
 67 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
 68 | 
 69 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: "
 70 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en"
 71 | }
 72 | 
 73 | 
 74 | function i2v_14B_480p() {
 75 |     I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P"
 76 | 
 77 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
 78 |     python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR
 79 | 
 80 |     # Multiple GPU Test
 81 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
 82 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
 83 | 
 84 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: "
 85 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en"
 86 | 
 87 |     if [ -n "${DASH_API_KEY+x}" ]; then
 88 |         echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: "
 89 |         torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope"
 90 |     else
 91 |         echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test."
 92 |     fi
 93 | }
 94 | 
 95 | 
 96 | function i2v_14B_720p() {
 97 |     I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P"
 98 | 
 99 |     # 1-GPU Test
100 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: "
101 |     python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR
102 | 
103 |     # Multiple GPU Test
104 |     echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: "
105 |     torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS
106 | }
107 | 
108 | function vace_1_3B() {
109 |     VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/"
110 |     torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR
111 | 
112 | }
113 | 
114 | 
115 | t2i_14B
116 | t2v_1_3B
117 | t2v_14B
118 | i2v_14B_480p
119 | i2v_14B_720p
120 | vace_1_3B
121 | 


--------------------------------------------------------------------------------
/wan/__init__.py:
--------------------------------------------------------------------------------
1 | from . import configs, distributed, modules
2 | from .first_last_frame2video import WanFLF2V
3 | from .image2video import WanI2V
4 | from .text2video import WanT2V
5 | from .vace import WanVace, WanVaceMP
6 | 
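These are the package's public entry points. A sketch of minimal programmatic usage, mirroring the parameters the single-GPU Gradio scripts above pass in (the checkpoint path is a placeholder; real weights and a CUDA device are required to actually run it):

```python
import wan
from wan.configs import WAN_CONFIGS

cfg = WAN_CONFIGS['t2v-1.3B']
pipe = wan.WanT2V(
    config=cfg,
    checkpoint_dir='./Wan2.1-T2V-1.3B',  # placeholder path to downloaded weights
    device_id=0,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_usp=False,
)
video = pipe.generate(
    'a corgi running on the beach at sunset',
    size=(832, 480),
    shift=8.0,
    sampling_steps=50,
    guide_scale=6.0,
    n_prompt='',
    seed=-1,
    offload_model=True,
)
```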


--------------------------------------------------------------------------------
/wan/configs/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | import copy
 3 | import os
 4 | 
 5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 6 | 
 7 | from .wan_i2v_14B import i2v_14B
 8 | from .wan_t2v_1_3B import t2v_1_3B
 9 | from .wan_t2v_14B import t2v_14B
10 | 
11 | # the config of t2i_14B is the same as t2v_14B
12 | t2i_14B = copy.deepcopy(t2v_14B)
13 | t2i_14B.__name__ = 'Config: Wan T2I 14B'
14 | 
15 | # the config of flf2v_14B is the same as i2v_14B
16 | flf2v_14B = copy.deepcopy(i2v_14B)
17 | flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
18 | flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt
19 | 
20 | WAN_CONFIGS = {
21 |     't2v-14B': t2v_14B,
22 |     't2v-1.3B': t2v_1_3B,
23 |     'i2v-14B': i2v_14B,
24 |     't2i-14B': t2i_14B,
25 |     'flf2v-14B': flf2v_14B,
26 |     'vace-1.3B': t2v_1_3B,
27 |     'vace-14B': t2v_14B,
28 | }
29 | 
30 | SIZE_CONFIGS = {
31 |     '720*1280': (720, 1280),
32 |     '1280*720': (1280, 720),
33 |     '480*832': (480, 832),
34 |     '832*480': (832, 480),
35 |     '1024*1024': (1024, 1024),
36 | }
37 | 
38 | MAX_AREA_CONFIGS = {
39 |     '720*1280': 720 * 1280,
40 |     '1280*720': 1280 * 720,
41 |     '480*832': 480 * 832,
42 |     '832*480': 832 * 480,
43 | }
44 | 
45 | SUPPORTED_SIZES = {
46 |     't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
47 |     't2v-1.3B': ('480*832', '832*480'),
48 |     'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
49 |     'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
50 |     't2i-14B': tuple(SIZE_CONFIGS.keys()),
51 |     'vace-1.3B': ('480*832', '832*480'),
52 |     'vace-14B': ('720*1280', '1280*720', '480*832', '832*480')
53 | }
54 | 
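A small consistency sketch showing how these tables fit together: every task advertised in SUPPORTED_SIZES has a config in WAN_CONFIGS, and every size string it lists resolves through SIZE_CONFIGS to a (width, height) pair, per the width*height convention used by the Gradio dropdowns above:

```python
from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS

for task, sizes in SUPPORTED_SIZES.items():
    assert task in WAN_CONFIGS, f"missing config for {task}"
    for size in sizes:
        width, height = SIZE_CONFIGS[size]  # e.g. '832*480' -> (832, 480)
        assert size == f"{width}*{height}"
```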


--------------------------------------------------------------------------------
/wan/configs/shared_config.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | import torch
 3 | from easydict import EasyDict
 4 | 
 5 | #------------------------ Wan shared config ------------------------#
 6 | wan_shared_cfg = EasyDict()
 7 | 
 8 | # t5
 9 | wan_shared_cfg.t5_model = 'umt5_xxl'
10 | wan_shared_cfg.t5_dtype = torch.bfloat16
11 | wan_shared_cfg.text_len = 512
12 | 
13 | # transformer
14 | wan_shared_cfg.param_dtype = torch.bfloat16
15 | 
16 | # inference
17 | wan_shared_cfg.num_train_timesteps = 1000
18 | wan_shared_cfg.sample_fps = 16
19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
20 | 
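The shared negative prompt is Chinese; it roughly reads "garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, frame, motionless, overall gray, worst quality, low quality, JPEG compression artifacts, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards". Each task config copies this EasyDict and then adds or overrides fields; a tiny sketch of that pattern (the values below are illustrative only):

```python
from easydict import EasyDict

base = EasyDict(sample_fps=16, text_len=512)     # stands in for wan_shared_cfg
task = EasyDict(__name__='Config: demo task')
task.update(base)                                # inherit the shared defaults
task.dim = 1536                                  # then set task-specific fields
print(task.sample_fps, task.text_len, task.dim)  # 16 512 1536
```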


--------------------------------------------------------------------------------
/wan/configs/wan_i2v_14B.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | import torch
 3 | from easydict import EasyDict
 4 | 
 5 | from .shared_config import wan_shared_cfg
 6 | 
 7 | #------------------------ Wan I2V 14B ------------------------#
 8 | 
 9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
10 | i2v_14B.update(wan_shared_cfg)
11 | i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt
12 | 
13 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
14 | i2v_14B.t5_tokenizer = 'google/umt5-xxl'
15 | 
16 | # clip
17 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
18 | i2v_14B.clip_dtype = torch.float16
19 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
20 | i2v_14B.clip_tokenizer = 'xlm-roberta-large'
21 | 
22 | # vae
23 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
24 | i2v_14B.vae_stride = (4, 8, 8)
25 | 
26 | # transformer
27 | i2v_14B.patch_size = (1, 2, 2)
28 | i2v_14B.dim = 5120
29 | i2v_14B.ffn_dim = 13824
30 | i2v_14B.freq_dim = 256
31 | i2v_14B.num_heads = 40
32 | i2v_14B.num_layers = 40
33 | i2v_14B.window_size = (-1, -1)
34 | i2v_14B.qk_norm = True
35 | i2v_14B.cross_attn_norm = True
36 | i2v_14B.eps = 1e-6
37 | 
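A worked example, under the usual reading of these tuples as (temporal, height, width) factors, of how vae_stride and patch_size set the per-frame token grid for an 832*480 input:

```python
vae_stride = (4, 8, 8)    # assumed (temporal, height, width) VAE downsampling
patch_size = (1, 2, 2)    # assumed (temporal, height, width) DiT patchification
height, width = 480, 832  # the '832*480' option, width*height convention

lat_h, lat_w = height // vae_stride[1], width // vae_stride[2]            # 60, 104
tokens_per_latent_frame = (lat_h // patch_size[1]) * (lat_w // patch_size[2])
print(lat_h, lat_w, tokens_per_latent_frame)                              # 60 104 1560
```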


--------------------------------------------------------------------------------
/wan/configs/wan_t2v_14B.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | from easydict import EasyDict
 3 | 
 4 | from .shared_config import wan_shared_cfg
 5 | 
 6 | #------------------------ Wan T2V 14B ------------------------#
 7 | 
 8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
 9 | t2v_14B.update(wan_shared_cfg)
10 | 
11 | # t5
12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl'
14 | 
15 | # vae
16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17 | t2v_14B.vae_stride = (4, 8, 8)
18 | 
19 | # transformer
20 | t2v_14B.patch_size = (1, 2, 2)
21 | t2v_14B.dim = 5120
22 | t2v_14B.ffn_dim = 13824
23 | t2v_14B.freq_dim = 256
24 | t2v_14B.num_heads = 40
25 | t2v_14B.num_layers = 40
26 | t2v_14B.window_size = (-1, -1)
27 | t2v_14B.qk_norm = True
28 | t2v_14B.cross_attn_norm = True
29 | t2v_14B.eps = 1e-6
30 | 


--------------------------------------------------------------------------------
/wan/configs/wan_t2v_1_3B.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | from easydict import EasyDict
 3 | 
 4 | from .shared_config import wan_shared_cfg
 5 | 
 6 | #------------------------ Wan T2V 1.3B ------------------------#
 7 | 
 8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
 9 | t2v_1_3B.update(wan_shared_cfg)
10 | 
11 | # t5
12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'
14 | 
15 | # vae
16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
17 | t2v_1_3B.vae_stride = (4, 8, 8)
18 | 
19 | # transformer
20 | t2v_1_3B.patch_size = (1, 2, 2)
21 | t2v_1_3B.dim = 1536
22 | t2v_1_3B.ffn_dim = 8960
23 | t2v_1_3B.freq_dim = 256
24 | t2v_1_3B.num_heads = 12
25 | t2v_1_3B.num_layers = 30
26 | t2v_1_3B.window_size = (-1, -1)
27 | t2v_1_3B.qk_norm = True
28 | t2v_1_3B.cross_attn_norm = True
29 | t2v_1_3B.eps = 1e-6
30 | 
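A small arithmetic check on the attention shapes implied by the configs above: both the 14B and the 1.3B DiT variants use a per-head dimension of 128.

```python
# dim / num_heads for the configs defined above
assert 5120 // 40 == 128   # t2v_14B and i2v_14B
assert 1536 // 12 == 128   # t2v_1_3B
```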


--------------------------------------------------------------------------------
/wan/distributed/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/wan/distributed/__init__.py


--------------------------------------------------------------------------------
/wan/distributed/fsdp.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | import gc
 3 | from functools import partial
 4 | 
 5 | import torch
 6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 7 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 8 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
 9 | from torch.distributed.utils import _free_storage
10 | 
11 | 
12 | def shard_model(
13 |     model,
14 |     device_id,
15 |     param_dtype=torch.bfloat16,
16 |     reduce_dtype=torch.float32,
17 |     buffer_dtype=torch.float32,
18 |     process_group=None,
19 |     sharding_strategy=ShardingStrategy.FULL_SHARD,
20 |     sync_module_states=True,
21 | ):
22 |     model = FSDP(
23 |         module=model,
24 |         process_group=process_group,
25 |         sharding_strategy=sharding_strategy,
26 |         auto_wrap_policy=partial(
27 |             lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
28 |         mixed_precision=MixedPrecision(
29 |             param_dtype=param_dtype,
30 |             reduce_dtype=reduce_dtype,
31 |             buffer_dtype=buffer_dtype),
32 |         device_id=device_id,
33 |         sync_module_states=sync_module_states)
34 |     return model
35 | 
36 | 
37 | def free_model(model):
38 |     for m in model.modules():
39 |         if isinstance(m, FSDP):
40 |             _free_storage(m._handle.flat_param.data)
41 |     del model
42 |     gc.collect()
43 |     torch.cuda.empty_cache()
44 | 
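
`shard_model` above wraps the whole model in a single FSDP instance and relies on `lambda_auto_wrap_policy` to turn every transformer block in `model.blocks` into its own FSDP unit, keeping parameters in bf16 while gradient reduction and buffers stay in fp32. A hedged sketch of how it might be called under `torchrun`; the process-group setup is standard torch.distributed boilerplate, and `build_model` is a hypothetical stand-in for any module exposing a `.blocks` ModuleList.

```python
import os

import torch
import torch.distributed as dist

from wan.distributed.fsdp import free_model, shard_model


def main():
    # Standard torchrun initialization (assumes one GPU per process).
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = build_model()  # hypothetical: any nn.Module with a `.blocks` ModuleList
    model = shard_model(model, device_id=local_rank)

    # ... run inference ...

    # Release the flattened parameter storage once the model is no longer needed.
    free_model(model)


if __name__ == '__main__':
    main()
```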


--------------------------------------------------------------------------------
/wan/distributed/xdit_context_parallel.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import torch
  3 | import torch.cuda.amp as amp
  4 | from xfuser.core.distributed import (
  5 |     get_sequence_parallel_rank,
  6 |     get_sequence_parallel_world_size,
  7 |     get_sp_group,
  8 | )
  9 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention
 10 | 
 11 | from ..modules.model import sinusoidal_embedding_1d
 12 | 
 13 | 
 14 | def pad_freqs(original_tensor, target_len):
 15 |     seq_len, s1, s2 = original_tensor.shape
 16 |     pad_size = target_len - seq_len
 17 |     padding_tensor = torch.ones(
 18 |         pad_size,
 19 |         s1,
 20 |         s2,
 21 |         dtype=original_tensor.dtype,
 22 |         device=original_tensor.device)
 23 |     padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
 24 |     return padded_tensor
 25 | 
 26 | 
 27 | @amp.autocast(enabled=False)
 28 | def rope_apply(x, grid_sizes, freqs):
 29 |     """
 30 |     x:          [B, L, N, C].
 31 |     grid_sizes: [B, 3].
 32 |     freqs:      [M, C // 2].
 33 |     """
 34 |     s, n, c = x.size(1), x.size(2), x.size(3) // 2
 35 |     # split freqs
 36 |     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
 37 | 
 38 |     # loop over samples
 39 |     output = []
 40 |     for i, (f, h, w) in enumerate(grid_sizes.tolist()):
 41 |         seq_len = f * h * w
 42 | 
 43 |         # precompute multipliers
 44 |         x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
 45 |             s, n, -1, 2))
 46 |         freqs_i = torch.cat([
 47 |             freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
 48 |             freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
 49 |             freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
 50 |         ],
 51 |                             dim=-1).reshape(seq_len, 1, -1)
 52 | 
 53 |         # apply rotary embedding
 54 |         sp_size = get_sequence_parallel_world_size()
 55 |         sp_rank = get_sequence_parallel_rank()
 56 |         freqs_i = pad_freqs(freqs_i, s * sp_size)
 57 |         s_per_rank = s
 58 |         freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
 59 |                                                        s_per_rank), :, :]
 60 |         x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
 61 |         x_i = torch.cat([x_i, x[i, s:]])
 62 | 
 63 |         # append to collection
 64 |         output.append(x_i)
 65 |     return torch.stack(output).float()
 66 | 
 67 | 
 68 | def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
 69 |     # embeddings
 70 |     c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
 71 |     c = [u.flatten(2).transpose(1, 2) for u in c]
 72 |     c = torch.cat([
 73 |         torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
 74 |         for u in c
 75 |     ])
 76 | 
 77 |     # arguments
 78 |     new_kwargs = dict(x=x)
 79 |     new_kwargs.update(kwargs)
 80 | 
 81 |     # Context Parallel
 82 |     c = torch.chunk(
 83 |         c, get_sequence_parallel_world_size(),
 84 |         dim=1)[get_sequence_parallel_rank()]
 85 | 
 86 |     hints = []
 87 |     for block in self.vace_blocks:
 88 |         c, c_skip = block(c, **new_kwargs)
 89 |         hints.append(c_skip)
 90 |     return hints
 91 | 
 92 | 
 93 | def usp_dit_forward(
 94 |     self,
 95 |     x,
 96 |     t,
 97 |     context,
 98 |     seq_len,
 99 |     vace_context=None,
100 |     vace_context_scale=1.0,
101 |     clip_fea=None,
102 |     y=None,
103 | ):
104 |     """
105 |     x:              A list of videos each with shape [C, T, H, W].
106 |     t:              [B].
107 |     context:        A list of text embeddings each with shape [L, C].
108 |     """
109 |     if self.model_type == 'i2v':
110 |         assert clip_fea is not None and y is not None
111 |     # params
112 |     device = self.patch_embedding.weight.device
113 |     if self.freqs.device != device:
114 |         self.freqs = self.freqs.to(device)
115 | 
116 |     if self.model_type != 'vace' and y is not None:
117 |         x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
118 | 
119 |     # embeddings
120 |     x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
121 |     grid_sizes = torch.stack(
122 |         [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
123 |     x = [u.flatten(2).transpose(1, 2) for u in x]
124 |     seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
125 |     assert seq_lens.max() <= seq_len
126 |     x = torch.cat([
127 |         torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
128 |         for u in x
129 |     ])
130 | 
131 |     # time embeddings
132 |     with amp.autocast(dtype=torch.float32):
133 |         e = self.time_embedding(
134 |             sinusoidal_embedding_1d(self.freq_dim, t).float())
135 |         e0 = self.time_projection(e).unflatten(1, (6, self.dim))
136 |         assert e.dtype == torch.float32 and e0.dtype == torch.float32
137 | 
138 |     # context
139 |     context_lens = None
140 |     context = self.text_embedding(
141 |         torch.stack([
142 |             torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
143 |             for u in context
144 |         ]))
145 | 
146 |     if self.model_type != 'vace' and clip_fea is not None:
147 |         context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
148 |         context = torch.concat([context_clip, context], dim=1)
149 | 
150 |     # arguments
151 |     kwargs = dict(
152 |         e=e0,
153 |         seq_lens=seq_lens,
154 |         grid_sizes=grid_sizes,
155 |         freqs=self.freqs,
156 |         context=context,
157 |         context_lens=context_lens)
158 | 
159 |     # Context Parallel
160 |     x = torch.chunk(
161 |         x, get_sequence_parallel_world_size(),
162 |         dim=1)[get_sequence_parallel_rank()]
163 | 
164 |     if self.model_type == 'vace':
165 |         hints = self.forward_vace(x, vace_context, seq_len, kwargs)
166 |         kwargs['hints'] = hints
167 |         kwargs['context_scale'] = vace_context_scale
168 | 
169 |     for block in self.blocks:
170 |         x = block(x, **kwargs)
171 | 
172 |     # head
173 |     x = self.head(x, e)
174 | 
175 |     # Context Parallel
176 |     x = get_sp_group().all_gather(x, dim=1)
177 | 
178 |     # unpatchify
179 |     x = self.unpatchify(x, grid_sizes)
180 |     return [u.float() for u in x]
181 | 
182 | 
183 | def usp_attn_forward(self,
184 |                      x,
185 |                      seq_lens,
186 |                      grid_sizes,
187 |                      freqs,
188 |                      dtype=torch.bfloat16):
189 |     b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
190 |     half_dtypes = (torch.float16, torch.bfloat16)
191 | 
192 |     def half(x):
193 |         return x if x.dtype in half_dtypes else x.to(dtype)
194 | 
195 |     # query, key, value function
196 |     def qkv_fn(x):
197 |         q = self.norm_q(self.q(x)).view(b, s, n, d)
198 |         k = self.norm_k(self.k(x)).view(b, s, n, d)
199 |         v = self.v(x).view(b, s, n, d)
200 |         return q, k, v
201 | 
202 |     q, k, v = qkv_fn(x)
203 |     q = rope_apply(q, grid_sizes, freqs)
204 |     k = rope_apply(k, grid_sizes, freqs)
205 | 
206 |     # TODO: We should use unpadded q,k,v for attention.
207 |     # k_lens = seq_lens // get_sequence_parallel_world_size()
208 |     # if k_lens is not None:
209 |     #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
210 |     #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
211 |     #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
212 | 
213 |     x = xFuserLongContextAttention()(
214 |         None,
215 |         query=half(q),
216 |         key=half(k),
217 |         value=half(v),
218 |         window_size=self.window_size)
219 | 
220 |     # TODO: padding after attention.
221 |     # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
222 | 
223 |     # output
224 |     x = x.flatten(2)
225 |     x = self.o(x)
226 |     return x
227 | 
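
In `rope_apply` above, each sequence-parallel rank holds only a 1/sp_size chunk of length `s` of the padded sequence, so the per-sample frequency table is padded to `s * sp_size` rows and rank `r` applies rows `[r*s, (r+1)*s)`. The sketch below reproduces just that slicing arithmetic in isolation, without the xfuser dependency; shapes and values are illustrative assumptions.

```python
import torch


def local_freq_slice(freqs_i: torch.Tensor, s: int, sp_rank: int, sp_size: int) -> torch.Tensor:
    """Pad the frequency table to s * sp_size rows, then keep this rank's rows."""
    pad = s * sp_size - freqs_i.size(0)
    if pad > 0:
        # pad_freqs() in the file above pads with ones in the same way.
        freqs_i = torch.cat([freqs_i, freqs_i.new_ones(pad, *freqs_i.shape[1:])], dim=0)
    return freqs_i[sp_rank * s:(sp_rank + 1) * s]


freqs = torch.randn(100, 1, 64)  # assumed shape: [seq_len, 1, C // 2]
chunk = local_freq_slice(freqs, s=32, sp_rank=1, sp_size=4)
print(chunk.shape)  # torch.Size([32, 1, 64])
```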


--------------------------------------------------------------------------------
/wan/first_last_frame2video.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import gc
  3 | import logging
  4 | import math
  5 | import os
  6 | import random
  7 | import sys
  8 | import types
  9 | from contextlib import contextmanager
 10 | from functools import partial
 11 | 
 12 | import numpy as np
 13 | import torch
 14 | import torch.cuda.amp as amp
 15 | import torch.distributed as dist
 16 | import torchvision.transforms.functional as TF
 17 | from tqdm import tqdm
 18 | 
 19 | from .distributed.fsdp import shard_model
 20 | from .modules.clip import CLIPModel
 21 | from .modules.model import WanModel
 22 | from .modules.t5 import T5EncoderModel
 23 | from .modules.vae import WanVAE
 24 | from .utils.fm_solvers import (
 25 |     FlowDPMSolverMultistepScheduler,
 26 |     get_sampling_sigmas,
 27 |     retrieve_timesteps,
 28 | )
 29 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 30 | 
 31 | 
 32 | class WanFLF2V:
 33 | 
 34 |     def __init__(
 35 |         self,
 36 |         config,
 37 |         checkpoint_dir,
 38 |         device_id=0,
 39 |         rank=0,
 40 |         t5_fsdp=False,
 41 |         dit_fsdp=False,
 42 |         use_usp=False,
 43 |         t5_cpu=False,
 44 |         init_on_cpu=True,
 45 |     ):
 46 |         r"""
 47 |         Initializes the first-last-frame-to-video generation model components.
 48 | 
 49 |         Args:
 50 |             config (EasyDict):
 51 |                 Object containing model parameters initialized from config.py
 52 |             checkpoint_dir (`str`):
 53 |                 Path to directory containing model checkpoints
 54 |             device_id (`int`,  *optional*, defaults to 0):
 55 |                 Id of target GPU device
 56 |             rank (`int`,  *optional*, defaults to 0):
 57 |                 Process rank for distributed training
 58 |             t5_fsdp (`bool`, *optional*, defaults to False):
 59 |                 Enable FSDP sharding for T5 model
 60 |             dit_fsdp (`bool`, *optional*, defaults to False):
 61 |                 Enable FSDP sharding for DiT model
 62 |             use_usp (`bool`, *optional*, defaults to False):
 63 |                 Enable distribution strategy of USP.
 64 |             t5_cpu (`bool`, *optional*, defaults to False):
 65 |                 Whether to place T5 model on CPU. Only works without t5_fsdp.
 66 |             init_on_cpu (`bool`, *optional*, defaults to True):
 67 |                 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
 68 |         """
 69 |         self.device = torch.device(f"cuda:{device_id}")
 70 |         self.config = config
 71 |         self.rank = rank
 72 |         self.use_usp = use_usp
 73 |         self.t5_cpu = t5_cpu
 74 | 
 75 |         self.num_train_timesteps = config.num_train_timesteps
 76 |         self.param_dtype = config.param_dtype
 77 | 
 78 |         shard_fn = partial(shard_model, device_id=device_id)
 79 |         self.text_encoder = T5EncoderModel(
 80 |             text_len=config.text_len,
 81 |             dtype=config.t5_dtype,
 82 |             device=torch.device('cpu'),
 83 |             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
 84 |             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
 85 |             shard_fn=shard_fn if t5_fsdp else None,
 86 |         )
 87 | 
 88 |         self.vae_stride = config.vae_stride
 89 |         self.patch_size = config.patch_size
 90 |         self.vae = WanVAE(
 91 |             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 92 |             device=self.device)
 93 | 
 94 |         self.clip = CLIPModel(
 95 |             dtype=config.clip_dtype,
 96 |             device=self.device,
 97 |             checkpoint_path=os.path.join(checkpoint_dir,
 98 |                                          config.clip_checkpoint),
 99 |             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
100 | 
101 |         logging.info(f"Creating WanModel from {checkpoint_dir}")
102 |         self.model = WanModel.from_pretrained(checkpoint_dir)
103 |         self.model.eval().requires_grad_(False)
104 | 
105 |         if t5_fsdp or dit_fsdp or use_usp:
106 |             init_on_cpu = False
107 | 
108 |         if use_usp:
109 |             from xfuser.core.distributed import get_sequence_parallel_world_size
110 | 
111 |             from .distributed.xdit_context_parallel import (
112 |                 usp_attn_forward,
113 |                 usp_dit_forward,
114 |             )
115 |             for block in self.model.blocks:
116 |                 block.self_attn.forward = types.MethodType(
117 |                     usp_attn_forward, block.self_attn)
118 |             self.model.forward = types.MethodType(usp_dit_forward, self.model)
119 |             self.sp_size = get_sequence_parallel_world_size()
120 |         else:
121 |             self.sp_size = 1
122 | 
123 |         if dist.is_initialized():
124 |             dist.barrier()
125 |         if dit_fsdp:
126 |             self.model = shard_fn(self.model)
127 |         else:
128 |             if not init_on_cpu:
129 |                 self.model.to(self.device)
130 | 
131 |         self.sample_neg_prompt = config.sample_neg_prompt
132 | 
133 |     def generate(self,
134 |                  input_prompt,
135 |                  first_frame,
136 |                  last_frame,
137 |                  max_area=720 * 1280,
138 |                  frame_num=81,
139 |                  shift=16,
140 |                  sample_solver='unipc',
141 |                  sampling_steps=50,
142 |                  guide_scale=5.5,
143 |                  n_prompt="",
144 |                  seed=-1,
145 |                  offload_model=True):
146 |         r"""
147 |         Generates video frames from input first-last frame and text prompt using diffusion process.
148 | 
149 |         Args:
150 |             input_prompt (`str`):
151 |                 Text prompt for content generation.
152 |             first_frame (PIL.Image.Image):
153 |                 First frame of the output video; converted internally to a [3, H, W] tensor.
154 |             last_frame (PIL.Image.Image):
155 |                 Last frame of the output video; converted internally to a [3, H, W] tensor.
156 |                 [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized
157 |                 to match first_frame.
158 |             max_area (`int`, *optional*, defaults to 720*1280):
159 |                 Maximum pixel area for latent space calculation. Controls video resolution scaling
160 |             frame_num (`int`, *optional*, defaults to 81):
161 |                 How many frames to sample from a video. The number should be 4n+1
162 |             shift (`float`, *optional*, defaults to 16):
163 |                 Noise schedule shift parameter. Affects temporal dynamics
164 |                 [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
165 |             sample_solver (`str`, *optional*, defaults to 'unipc'):
166 |                 Solver used to sample the video.
167 |             sampling_steps (`int`, *optional*, defaults to 50):
168 |                 Number of diffusion sampling steps. Higher values improve quality but slow generation
169 |             guide_scale (`float`, *optional*, defaults to 5.5):
170 |                 Classifier-free guidance scale. Controls prompt adherence vs. creativity
171 |             n_prompt (`str`, *optional*, defaults to ""):
172 |                 Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
173 |             seed (`int`, *optional*, defaults to -1):
174 |                 Random seed for noise generation. If -1, use random seed
175 |             offload_model (`bool`, *optional*, defaults to True):
176 |                 If True, offloads models to CPU during generation to save VRAM
177 | 
178 |         Returns:
179 |             torch.Tensor:
180 |                 Generated video frames tensor. Dimensions: (C, N, H, W) where:
181 |                 - C: Color channels (3 for RGB)
182 |                 - N: Number of frames (81)
183 |                 - H: Frame height (from max_area)
184 |                 - W: Frame width (from max_area)
185 |         """
186 |         first_frame_size = first_frame.size
187 |         last_frame_size = last_frame.size
188 |         first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
189 |             self.device)
190 |         last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
191 |             self.device)
192 | 
193 |         F = frame_num
194 |         first_frame_h, first_frame_w = first_frame.shape[1:]
195 |         aspect_ratio = first_frame_h / first_frame_w
196 |         lat_h = round(
197 |             np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
198 |             self.patch_size[1] * self.patch_size[1])
199 |         lat_w = round(
200 |             np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
201 |             self.patch_size[2] * self.patch_size[2])
202 |         first_frame_h = lat_h * self.vae_stride[1]
203 |         first_frame_w = lat_w * self.vae_stride[2]
204 |         if first_frame_size != last_frame_size:
205 |             # 1. resize
206 |             last_frame_resize_ratio = max(
207 |                 first_frame_size[0] / last_frame_size[0],
208 |                 first_frame_size[1] / last_frame_size[1])
209 |             last_frame_size = [
210 |                 round(last_frame_size[0] * last_frame_resize_ratio),
211 |                 round(last_frame_size[1] * last_frame_resize_ratio),
212 |             ]
213 |             # 2. center crop
214 |             last_frame = TF.center_crop(last_frame, last_frame_size)
215 | 
216 |         max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
217 |             self.patch_size[1] * self.patch_size[2])
218 |         max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
219 | 
220 |         seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
221 |         seed_g = torch.Generator(device=self.device)
222 |         seed_g.manual_seed(seed)
223 |         noise = torch.randn(
224 |             16, (F - 1) // 4 + 1,
225 |             lat_h,
226 |             lat_w,
227 |             dtype=torch.float32,
228 |             generator=seed_g,
229 |             device=self.device)
230 | 
231 |         msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
232 |         msk[:, 1:-1] = 0
233 |         msk = torch.concat([
234 |             torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
235 |         ],
236 |                            dim=1)
237 |         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
238 |         msk = msk.transpose(1, 2)[0]
239 | 
240 |         if n_prompt == "":
241 |             n_prompt = self.sample_neg_prompt
242 | 
243 |         # preprocess
244 |         if not self.t5_cpu:
245 |             self.text_encoder.model.to(self.device)
246 |             context = self.text_encoder([input_prompt], self.device)
247 |             context_null = self.text_encoder([n_prompt], self.device)
248 |             if offload_model:
249 |                 self.text_encoder.model.cpu()
250 |         else:
251 |             context = self.text_encoder([input_prompt], torch.device('cpu'))
252 |             context_null = self.text_encoder([n_prompt], torch.device('cpu'))
253 |             context = [t.to(self.device) for t in context]
254 |             context_null = [t.to(self.device) for t in context_null]
255 | 
256 |         self.clip.model.to(self.device)
257 |         clip_context = self.clip.visual(
258 |             [first_frame[:, None, :, :], last_frame[:, None, :, :]])
259 |         if offload_model:
260 |             self.clip.model.cpu()
261 | 
262 |         y = self.vae.encode([
263 |             torch.concat([
264 |                 torch.nn.functional.interpolate(
265 |                     first_frame[None].cpu(),
266 |                     size=(first_frame_h, first_frame_w),
267 |                     mode='bicubic').transpose(0, 1),
268 |                 torch.zeros(3, F - 2, first_frame_h, first_frame_w),
269 |                 torch.nn.functional.interpolate(
270 |                     last_frame[None].cpu(),
271 |                     size=(first_frame_h, first_frame_w),
272 |                     mode='bicubic').transpose(0, 1),
273 |             ],
274 |                          dim=1).to(self.device)
275 |         ])[0]
276 |         y = torch.concat([msk, y])
277 | 
278 |         @contextmanager
279 |         def noop_no_sync():
280 |             yield
281 | 
282 |         no_sync = getattr(self.model, 'no_sync', noop_no_sync)
283 | 
284 |         # evaluation mode
285 |         with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
286 | 
287 |             if sample_solver == 'unipc':
288 |                 sample_scheduler = FlowUniPCMultistepScheduler(
289 |                     num_train_timesteps=self.num_train_timesteps,
290 |                     shift=1,
291 |                     use_dynamic_shifting=False)
292 |                 sample_scheduler.set_timesteps(
293 |                     sampling_steps, device=self.device, shift=shift)
294 |                 timesteps = sample_scheduler.timesteps
295 |             elif sample_solver == 'dpm++':
296 |                 sample_scheduler = FlowDPMSolverMultistepScheduler(
297 |                     num_train_timesteps=self.num_train_timesteps,
298 |                     shift=1,
299 |                     use_dynamic_shifting=False)
300 |                 sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
301 |                 timesteps, _ = retrieve_timesteps(
302 |                     sample_scheduler,
303 |                     device=self.device,
304 |                     sigmas=sampling_sigmas)
305 |             else:
306 |                 raise NotImplementedError("Unsupported solver.")
307 | 
308 |             # sample videos
309 |             latent = noise
310 | 
311 |             arg_c = {
312 |                 'context': [context[0]],
313 |                 'clip_fea': clip_context,
314 |                 'seq_len': max_seq_len,
315 |                 'y': [y],
316 |             }
317 | 
318 |             arg_null = {
319 |                 'context': context_null,
320 |                 'clip_fea': clip_context,
321 |                 'seq_len': max_seq_len,
322 |                 'y': [y],
323 |             }
324 | 
325 |             if offload_model:
326 |                 torch.cuda.empty_cache()
327 | 
328 |             self.model.to(self.device)
329 |             for _, t in enumerate(tqdm(timesteps)):
330 |                 latent_model_input = [latent.to(self.device)]
331 |                 timestep = [t]
332 | 
333 |                 timestep = torch.stack(timestep).to(self.device)
334 | 
335 |                 noise_pred_cond = self.model(
336 |                     latent_model_input, t=timestep, **arg_c)[0].to(
337 |                         torch.device('cpu') if offload_model else self.device)
338 |                 if offload_model:
339 |                     torch.cuda.empty_cache()
340 |                 noise_pred_uncond = self.model(
341 |                     latent_model_input, t=timestep, **arg_null)[0].to(
342 |                         torch.device('cpu') if offload_model else self.device)
343 |                 if offload_model:
344 |                     torch.cuda.empty_cache()
345 |                 noise_pred = noise_pred_uncond + guide_scale * (
346 |                     noise_pred_cond - noise_pred_uncond)
347 | 
348 |                 latent = latent.to(
349 |                     torch.device('cpu') if offload_model else self.device)
350 | 
351 |                 temp_x0 = sample_scheduler.step(
352 |                     noise_pred.unsqueeze(0),
353 |                     t,
354 |                     latent.unsqueeze(0),
355 |                     return_dict=False,
356 |                     generator=seed_g)[0]
357 |                 latent = temp_x0.squeeze(0)
358 | 
359 |                 x0 = [latent.to(self.device)]
360 |                 del latent_model_input, timestep
361 | 
362 |             if offload_model:
363 |                 self.model.cpu()
364 |                 torch.cuda.empty_cache()
365 | 
366 |             if self.rank == 0:
367 |                 videos = self.vae.decode(x0)
368 | 
369 |         del noise, latent
370 |         del sample_scheduler
371 |         if offload_model:
372 |             gc.collect()
373 |             torch.cuda.synchronize()
374 |         if dist.is_initialized():
375 |             dist.barrier()
376 | 
377 |         return videos[0] if self.rank == 0 else None
378 | 
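
`WanFLF2V.generate` conditions the diffusion model on both endpoint frames: CLIP features of the first and last frame, plus a VAE-encoded sequence in which every in-between frame is zeroed and masked out. A hedged end-to-end usage sketch follows; the checkpoint directory is a placeholder, and the `WAN_CONFIGS['flf2v-14B']` lookup is an assumption about `wan/configs` rather than something shown in this file.

```python
from PIL import Image

from wan.configs import WAN_CONFIGS  # assumption: the configs package exposes this mapping
from wan.first_last_frame2video import WanFLF2V

pipe = WanFLF2V(
    config=WAN_CONFIGS['flf2v-14B'],           # assumed config key
    checkpoint_dir='./Wan2.1-FLF2V-14B-720P',  # placeholder path
)
video = pipe.generate(
    'a slow dolly shot across a rainy street at night',
    first_frame=Image.open('examples/flf2v_input_first_frame.png'),
    last_frame=Image.open('examples/flf2v_input_last_frame.png'),
    max_area=720 * 1280,
    frame_num=81,
)
# On rank 0, `video` is a (C, N, H, W) float tensor; other ranks receive None.
```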


--------------------------------------------------------------------------------
/wan/image2video.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import gc
  3 | import logging
  4 | import math
  5 | import os
  6 | import random
  7 | import sys
  8 | import types
  9 | from contextlib import contextmanager
 10 | from functools import partial
 11 | 
 12 | import numpy as np
 13 | import torch
 14 | import torch.cuda.amp as amp
 15 | import torch.distributed as dist
 16 | import torchvision.transforms.functional as TF
 17 | from tqdm import tqdm
 18 | 
 19 | from .distributed.fsdp import shard_model
 20 | from .modules.clip import CLIPModel
 21 | from .modules.model import WanModel
 22 | from .modules.t5 import T5EncoderModel
 23 | from .modules.vae import WanVAE
 24 | from .utils.fm_solvers import (
 25 |     FlowDPMSolverMultistepScheduler,
 26 |     get_sampling_sigmas,
 27 |     retrieve_timesteps,
 28 | )
 29 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 30 | 
 31 | 
 32 | class WanI2V:
 33 | 
 34 |     def __init__(
 35 |         self,
 36 |         config,
 37 |         checkpoint_dir,
 38 |         device_id=0,
 39 |         rank=0,
 40 |         t5_fsdp=False,
 41 |         dit_fsdp=False,
 42 |         use_usp=False,
 43 |         t5_cpu=False,
 44 |         init_on_cpu=True,
 45 |     ):
 46 |         r"""
 47 |         Initializes the image-to-video generation model components.
 48 | 
 49 |         Args:
 50 |             config (EasyDict):
 51 |                 Object containing model parameters initialized from config.py
 52 |             checkpoint_dir (`str`):
 53 |                 Path to directory containing model checkpoints
 54 |             device_id (`int`,  *optional*, defaults to 0):
 55 |                 Id of target GPU device
 56 |             rank (`int`,  *optional*, defaults to 0):
 57 |                 Process rank for distributed training
 58 |             t5_fsdp (`bool`, *optional*, defaults to False):
 59 |                 Enable FSDP sharding for T5 model
 60 |             dit_fsdp (`bool`, *optional*, defaults to False):
 61 |                 Enable FSDP sharding for DiT model
 62 |             use_usp (`bool`, *optional*, defaults to False):
 63 |                 Enable distribution strategy of USP.
 64 |             t5_cpu (`bool`, *optional*, defaults to False):
 65 |                 Whether to place T5 model on CPU. Only works without t5_fsdp.
 66 |             init_on_cpu (`bool`, *optional*, defaults to True):
 67 |                 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
 68 |         """
 69 |         self.device = torch.device(f"cuda:{device_id}")
 70 |         self.config = config
 71 |         self.rank = rank
 72 |         self.use_usp = use_usp
 73 |         self.t5_cpu = t5_cpu
 74 | 
 75 |         self.num_train_timesteps = config.num_train_timesteps
 76 |         self.param_dtype = config.param_dtype
 77 | 
 78 |         shard_fn = partial(shard_model, device_id=device_id)
 79 |         self.text_encoder = T5EncoderModel(
 80 |             text_len=config.text_len,
 81 |             dtype=config.t5_dtype,
 82 |             device=torch.device('cpu'),
 83 |             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
 84 |             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
 85 |             shard_fn=shard_fn if t5_fsdp else None,
 86 |         )
 87 | 
 88 |         self.vae_stride = config.vae_stride
 89 |         self.patch_size = config.patch_size
 90 |         self.vae = WanVAE(
 91 |             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 92 |             device=self.device)
 93 | 
 94 |         self.clip = CLIPModel(
 95 |             dtype=config.clip_dtype,
 96 |             device=self.device,
 97 |             checkpoint_path=os.path.join(checkpoint_dir,
 98 |                                          config.clip_checkpoint),
 99 |             tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))
100 | 
101 |         logging.info(f"Creating WanModel from {checkpoint_dir}")
102 |         self.model = WanModel.from_pretrained(checkpoint_dir)
103 |         self.model.eval().requires_grad_(False)
104 | 
105 |         if t5_fsdp or dit_fsdp or use_usp:
106 |             init_on_cpu = False
107 | 
108 |         if use_usp:
109 |             from xfuser.core.distributed import get_sequence_parallel_world_size
110 | 
111 |             from .distributed.xdit_context_parallel import (
112 |                 usp_attn_forward,
113 |                 usp_dit_forward,
114 |             )
115 |             for block in self.model.blocks:
116 |                 block.self_attn.forward = types.MethodType(
117 |                     usp_attn_forward, block.self_attn)
118 |             self.model.forward = types.MethodType(usp_dit_forward, self.model)
119 |             self.sp_size = get_sequence_parallel_world_size()
120 |         else:
121 |             self.sp_size = 1
122 | 
123 |         if dist.is_initialized():
124 |             dist.barrier()
125 |         if dit_fsdp:
126 |             self.model = shard_fn(self.model)
127 |         else:
128 |             if not init_on_cpu:
129 |                 self.model.to(self.device)
130 | 
131 |         self.sample_neg_prompt = config.sample_neg_prompt
132 | 
133 |     def generate(self,
134 |                  input_prompt,
135 |                  img,
136 |                  max_area=720 * 1280,
137 |                  frame_num=81,
138 |                  shift=5.0,
139 |                  sample_solver='unipc',
140 |                  sampling_steps=40,
141 |                  guide_scale=5.0,
142 |                  n_prompt="",
143 |                  seed=-1,
144 |                  offload_model=True):
145 |         r"""
146 |         Generates video frames from input image and text prompt using diffusion process.
147 | 
148 |         Args:
149 |             input_prompt (`str`):
150 |                 Text prompt for content generation.
151 |             img (PIL.Image.Image):
152 |                 Input image; converted internally to a [3, H, W] tensor.
153 |             max_area (`int`, *optional*, defaults to 720*1280):
154 |                 Maximum pixel area for latent space calculation. Controls video resolution scaling
155 |             frame_num (`int`, *optional*, defaults to 81):
156 |                 How many frames to sample from a video. The number should be 4n+1
157 |             shift (`float`, *optional*, defaults to 5.0):
158 |                 Noise schedule shift parameter. Affects temporal dynamics
159 |                 [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
160 |             sample_solver (`str`, *optional*, defaults to 'unipc'):
161 |                 Solver used to sample the video.
162 |             sampling_steps (`int`, *optional*, defaults to 40):
163 |                 Number of diffusion sampling steps. Higher values improve quality but slow generation
164 |             guide_scale (`float`, *optional*, defaults to 5.0):
165 |                 Classifier-free guidance scale. Controls prompt adherence vs. creativity
166 |             n_prompt (`str`, *optional*, defaults to ""):
167 |                 Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
168 |             seed (`int`, *optional*, defaults to -1):
169 |                 Random seed for noise generation. If -1, use random seed
170 |             offload_model (`bool`, *optional*, defaults to True):
171 |                 If True, offloads models to CPU during generation to save VRAM
172 | 
173 |         Returns:
174 |             torch.Tensor:
175 |                 Generated video frames tensor. Dimensions: (C, N, H, W) where:
176 |                 - C: Color channels (3 for RGB)
177 |                 - N: Number of frames (81)
178 |                 - H: Frame height (from max_area)
179 |                 - W: Frame width (from max_area)
180 |         """
181 |         img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
182 | 
183 |         F = frame_num
184 |         h, w = img.shape[1:]
185 |         aspect_ratio = h / w
186 |         lat_h = round(
187 |             np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
188 |             self.patch_size[1] * self.patch_size[1])
189 |         lat_w = round(
190 |             np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
191 |             self.patch_size[2] * self.patch_size[2])
192 |         h = lat_h * self.vae_stride[1]
193 |         w = lat_w * self.vae_stride[2]
194 | 
195 |         max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
196 |             self.patch_size[1] * self.patch_size[2])
197 |         max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
198 | 
199 |         seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
200 |         seed_g = torch.Generator(device=self.device)
201 |         seed_g.manual_seed(seed)
202 |         noise = torch.randn(
203 |             16, (F - 1) // 4 + 1,
204 |             lat_h,
205 |             lat_w,
206 |             dtype=torch.float32,
207 |             generator=seed_g,
208 |             device=self.device)
209 | 
210 |         msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
211 |         msk[:, 1:] = 0
212 |         msk = torch.concat([
213 |             torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
214 |         ],
215 |                            dim=1)
216 |         msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
217 |         msk = msk.transpose(1, 2)[0]
218 | 
219 |         if n_prompt == "":
220 |             n_prompt = self.sample_neg_prompt
221 | 
222 |         # preprocess
223 |         if not self.t5_cpu:
224 |             self.text_encoder.model.to(self.device)
225 |             context = self.text_encoder([input_prompt], self.device)
226 |             context_null = self.text_encoder([n_prompt], self.device)
227 |             if offload_model:
228 |                 self.text_encoder.model.cpu()
229 |         else:
230 |             context = self.text_encoder([input_prompt], torch.device('cpu'))
231 |             context_null = self.text_encoder([n_prompt], torch.device('cpu'))
232 |             context = [t.to(self.device) for t in context]
233 |             context_null = [t.to(self.device) for t in context_null]
234 | 
235 |         self.clip.model.to(self.device)
236 |         clip_context = self.clip.visual([img[:, None, :, :]])
237 |         if offload_model:
238 |             self.clip.model.cpu()
239 | 
240 |         y = self.vae.encode([
241 |             torch.concat([
242 |                 torch.nn.functional.interpolate(
243 |                     img[None].cpu(), size=(h, w), mode='bicubic').transpose(
244 |                         0, 1),
245 |                 torch.zeros(3, F - 1, h, w)
246 |             ],
247 |                          dim=1).to(self.device)
248 |         ])[0]
249 |         y = torch.concat([msk, y])
250 | 
251 |         @contextmanager
252 |         def noop_no_sync():
253 |             yield
254 | 
255 |         no_sync = getattr(self.model, 'no_sync', noop_no_sync)
256 | 
257 |         # evaluation mode
258 |         with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
259 | 
260 |             if sample_solver == 'unipc':
261 |                 sample_scheduler = FlowUniPCMultistepScheduler(
262 |                     num_train_timesteps=self.num_train_timesteps,
263 |                     shift=1,
264 |                     use_dynamic_shifting=False)
265 |                 sample_scheduler.set_timesteps(
266 |                     sampling_steps, device=self.device, shift=shift)
267 |                 timesteps = sample_scheduler.timesteps
268 |             elif sample_solver == 'dpm++':
269 |                 sample_scheduler = FlowDPMSolverMultistepScheduler(
270 |                     num_train_timesteps=self.num_train_timesteps,
271 |                     shift=1,
272 |                     use_dynamic_shifting=False)
273 |                 sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
274 |                 timesteps, _ = retrieve_timesteps(
275 |                     sample_scheduler,
276 |                     device=self.device,
277 |                     sigmas=sampling_sigmas)
278 |             else:
279 |                 raise NotImplementedError("Unsupported solver.")
280 | 
281 |             # sample videos
282 |             latent = noise
283 | 
284 |             arg_c = {
285 |                 'context': [context[0]],
286 |                 'clip_fea': clip_context,
287 |                 'seq_len': max_seq_len,
288 |                 'y': [y],
289 |             }
290 | 
291 |             arg_null = {
292 |                 'context': context_null,
293 |                 'clip_fea': clip_context,
294 |                 'seq_len': max_seq_len,
295 |                 'y': [y],
296 |             }
297 | 
298 |             if offload_model:
299 |                 torch.cuda.empty_cache()
300 | 
301 |             self.model.to(self.device)
302 |             for _, t in enumerate(tqdm(timesteps)):
303 |                 latent_model_input = [latent.to(self.device)]
304 |                 timestep = [t]
305 | 
306 |                 timestep = torch.stack(timestep).to(self.device)
307 | 
308 |                 noise_pred_cond = self.model(
309 |                     latent_model_input, t=timestep, **arg_c)[0].to(
310 |                         torch.device('cpu') if offload_model else self.device)
311 |                 if offload_model:
312 |                     torch.cuda.empty_cache()
313 |                 noise_pred_uncond = self.model(
314 |                     latent_model_input, t=timestep, **arg_null)[0].to(
315 |                         torch.device('cpu') if offload_model else self.device)
316 |                 if offload_model:
317 |                     torch.cuda.empty_cache()
318 |                 noise_pred = noise_pred_uncond + guide_scale * (
319 |                     noise_pred_cond - noise_pred_uncond)
320 | 
321 |                 latent = latent.to(
322 |                     torch.device('cpu') if offload_model else self.device)
323 | 
324 |                 temp_x0 = sample_scheduler.step(
325 |                     noise_pred.unsqueeze(0),
326 |                     t,
327 |                     latent.unsqueeze(0),
328 |                     return_dict=False,
329 |                     generator=seed_g)[0]
330 |                 latent = temp_x0.squeeze(0)
331 | 
332 |                 x0 = [latent.to(self.device)]
333 |                 del latent_model_input, timestep
334 | 
335 |             if offload_model:
336 |                 self.model.cpu()
337 |                 torch.cuda.empty_cache()
338 | 
339 |             if self.rank == 0:
340 |                 videos = self.vae.decode(x0)
341 | 
342 |         del noise, latent
343 |         del sample_scheduler
344 |         if offload_model:
345 |             gc.collect()
346 |             torch.cuda.synchronize()
347 |         if dist.is_initialized():
348 |             dist.barrier()
349 | 
350 |         return videos[0] if self.rank == 0 else None
351 | 
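
Both `generate` methods above derive the latent grid and the transformer sequence length from the input resolution: the spatial size is scaled so the pixel area stays near `max_area`, snapped to multiples of the VAE stride and patch size, and the resulting token count is rounded up to a multiple of the sequence-parallel world size. The standalone re-derivation below mirrors that arithmetic for illustration; the function name and defaults are assumptions.

```python
import math

import numpy as np


def latent_dims(h, w, max_area=720 * 1280, frame_num=81,
                vae_stride=(4, 8, 8), patch_size=(1, 2, 2), sp_size=1):
    # Mirrors the lat_h / lat_w / max_seq_len computation in WanI2V.generate.
    aspect_ratio = h / w
    lat_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] //
                  patch_size[1] * patch_size[1])
    lat_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] //
                  patch_size[2] * patch_size[2])
    lat_t = (frame_num - 1) // vae_stride[0] + 1
    seq_len = lat_t * lat_h * lat_w // (patch_size[1] * patch_size[2])
    return lat_h, lat_w, lat_t, int(math.ceil(seq_len / sp_size)) * sp_size


print(latent_dims(720, 1280))  # (90, 160, 21, 75600) for a 720x1280 input
```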


--------------------------------------------------------------------------------
/wan/modules/__init__.py:
--------------------------------------------------------------------------------
 1 | from .attention import flash_attention
 2 | from .model import WanModel
 3 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
 4 | from .tokenizers import HuggingfaceTokenizer
 5 | from .vace_model import VaceWanModel
 6 | from .vae import WanVAE
 7 | 
 8 | __all__ = [
 9 |     'WanVAE',
10 |     'WanModel',
11 |     'VaceWanModel',
12 |     'T5Model',
13 |     'T5Encoder',
14 |     'T5Decoder',
15 |     'T5EncoderModel',
16 |     'HuggingfaceTokenizer',
17 |     'flash_attention',
18 | ]
19 | 


--------------------------------------------------------------------------------
/wan/modules/attention.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import torch
  3 | 
  4 | try:
  5 |     import flash_attn_interface
  6 |     FLASH_ATTN_3_AVAILABLE = True
  7 | except ModuleNotFoundError:
  8 |     FLASH_ATTN_3_AVAILABLE = False
  9 | 
 10 | try:
 11 |     import flash_attn
 12 |     FLASH_ATTN_2_AVAILABLE = True
 13 | except ModuleNotFoundError:
 14 |     FLASH_ATTN_2_AVAILABLE = False
 15 | 
 16 | import warnings
 17 | 
 18 | __all__ = [
 19 |     'flash_attention',
 20 |     'attention',
 21 | ]
 22 | 
 23 | 
 24 | def flash_attention(
 25 |     q,
 26 |     k,
 27 |     v,
 28 |     q_lens=None,
 29 |     k_lens=None,
 30 |     dropout_p=0.,
 31 |     softmax_scale=None,
 32 |     q_scale=None,
 33 |     causal=False,
 34 |     window_size=(-1, -1),
 35 |     deterministic=False,
 36 |     dtype=torch.bfloat16,
 37 |     version=None,
 38 | ):
 39 |     """
 40 |     q:              [B, Lq, Nq, C1].
 41 |     k:              [B, Lk, Nk, C1].
 42 |     v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
 43 |     q_lens:         [B].
 44 |     k_lens:         [B].
 45 |     dropout_p:      float. Dropout probability.
 46 |     softmax_scale:  float. The scaling of QK^T before applying softmax.
 47 |     causal:         bool. Whether to apply causal attention mask.
 48 |     window_size:    (left, right). If not (-1, -1), apply sliding window local attention.
 49 |     deterministic:  bool. If True, slightly slower and uses more memory.
 50 |     dtype:          torch.dtype. The dtype q/k/v are cast to when they are not already float16/bfloat16.
 51 |     """
 52 |     half_dtypes = (torch.float16, torch.bfloat16)
 53 |     assert dtype in half_dtypes
 54 |     assert q.device.type == 'cuda' and q.size(-1) <= 256
 55 | 
 56 |     # params
 57 |     b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
 58 | 
 59 |     def half(x):
 60 |         return x if x.dtype in half_dtypes else x.to(dtype)
 61 | 
 62 |     # preprocess query
 63 |     if q_lens is None:
 64 |         q = half(q.flatten(0, 1))
 65 |         q_lens = torch.tensor(
 66 |             [lq] * b, dtype=torch.int32).to(
 67 |                 device=q.device, non_blocking=True)
 68 |     else:
 69 |         q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
 70 | 
 71 |     # preprocess key, value
 72 |     if k_lens is None:
 73 |         k = half(k.flatten(0, 1))
 74 |         v = half(v.flatten(0, 1))
 75 |         k_lens = torch.tensor(
 76 |             [lk] * b, dtype=torch.int32).to(
 77 |                 device=k.device, non_blocking=True)
 78 |     else:
 79 |         k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
 80 |         v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
 81 | 
 82 |     q = q.to(v.dtype)
 83 |     k = k.to(v.dtype)
 84 | 
 85 |     if q_scale is not None:
 86 |         q = q * q_scale
 87 | 
 88 |     if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
 89 |         warnings.warn(
 90 |             'Flash attention 3 is not available, falling back to flash attention 2.'
 91 |         )
 92 | 
 93 |     # apply attention
 94 |     if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
 95 |         # Note: dropout_p, window_size are not supported in FA3 now.
 96 |         x = flash_attn_interface.flash_attn_varlen_func(
 97 |             q=q,
 98 |             k=k,
 99 |             v=v,
100 |             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
101 |                 0, dtype=torch.int32).to(q.device, non_blocking=True),
102 |             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
103 |                 0, dtype=torch.int32).to(q.device, non_blocking=True),
104 |             seqused_q=None,
105 |             seqused_k=None,
106 |             max_seqlen_q=lq,
107 |             max_seqlen_k=lk,
108 |             softmax_scale=softmax_scale,
109 |             causal=causal,
110 |             deterministic=deterministic)[0].unflatten(0, (b, lq))
111 |     else:
112 |         assert FLASH_ATTN_2_AVAILABLE
113 |         x = flash_attn.flash_attn_varlen_func(
114 |             q=q,
115 |             k=k,
116 |             v=v,
117 |             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
118 |                 0, dtype=torch.int32).to(q.device, non_blocking=True),
119 |             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
120 |                 0, dtype=torch.int32).to(q.device, non_blocking=True),
121 |             max_seqlen_q=lq,
122 |             max_seqlen_k=lk,
123 |             dropout_p=dropout_p,
124 |             softmax_scale=softmax_scale,
125 |             causal=causal,
126 |             window_size=window_size,
127 |             deterministic=deterministic).unflatten(0, (b, lq))
128 | 
129 |     # output
130 |     return x.type(out_dtype)
131 | 
132 | 
133 | def attention(
134 |     q,
135 |     k,
136 |     v,
137 |     q_lens=None,
138 |     k_lens=None,
139 |     dropout_p=0.,
140 |     softmax_scale=None,
141 |     q_scale=None,
142 |     causal=False,
143 |     window_size=(-1, -1),
144 |     deterministic=False,
145 |     dtype=torch.bfloat16,
146 |     fa_version=None,
147 | ):
148 |     if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
149 |         return flash_attention(
150 |             q=q,
151 |             k=k,
152 |             v=v,
153 |             q_lens=q_lens,
154 |             k_lens=k_lens,
155 |             dropout_p=dropout_p,
156 |             softmax_scale=softmax_scale,
157 |             q_scale=q_scale,
158 |             causal=causal,
159 |             window_size=window_size,
160 |             deterministic=deterministic,
161 |             dtype=dtype,
162 |             version=fa_version,
163 |         )
164 |     else:
165 |         if q_lens is not None or k_lens is not None:
166 |             warnings.warn(
167 |                 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
168 |             )
169 |         attn_mask = None
170 | 
171 |         q = q.transpose(1, 2).to(dtype)
172 |         k = k.transpose(1, 2).to(dtype)
173 |         v = v.transpose(1, 2).to(dtype)
174 | 
175 |         out = torch.nn.functional.scaled_dot_product_attention(
176 |             q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
177 | 
178 |         out = out.transpose(1, 2).contiguous()
179 |         return out
180 | 
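
`attention` dispatches to FlashAttention 3 or 2 when either is importable and otherwise falls back to torch's `scaled_dot_product_attention`, in which case the `q_lens`/`k_lens` padding masks are ignored. A small usage sketch with the [B, L, N, C] layout from the docstring; it assumes a CUDA device, since `flash_attention` itself asserts one.

```python
import torch

from wan.modules.attention import attention

b, lq, lk, n, c = 1, 128, 128, 8, 64
q = torch.randn(b, lq, n, c, device='cuda', dtype=torch.bfloat16)
k = torch.randn(b, lk, n, c, device='cuda', dtype=torch.bfloat16)
v = torch.randn(b, lk, n, c, device='cuda', dtype=torch.bfloat16)

out = attention(q, k, v, causal=False)
print(out.shape)  # torch.Size([1, 128, 8, 64]) -> [B, Lq, N, C]
```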


--------------------------------------------------------------------------------
/wan/modules/t5.py:
--------------------------------------------------------------------------------
  1 | # Modified from transformers.models.t5.modeling_t5
  2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  3 | import logging
  4 | import math
  5 | 
  6 | import torch
  7 | import torch.nn as nn
  8 | import torch.nn.functional as F
  9 | 
 10 | from .tokenizers import HuggingfaceTokenizer
 11 | 
 12 | __all__ = [
 13 |     'T5Model',
 14 |     'T5Encoder',
 15 |     'T5Decoder',
 16 |     'T5EncoderModel',
 17 | ]
 18 | 
 19 | 
 20 | def fp16_clamp(x):
 21 |     if x.dtype == torch.float16 and torch.isinf(x).any():
 22 |         clamp = torch.finfo(x.dtype).max - 1000
 23 |         x = torch.clamp(x, min=-clamp, max=clamp)
 24 |     return x
 25 | 
 26 | 
 27 | def init_weights(m):
 28 |     if isinstance(m, T5LayerNorm):
 29 |         nn.init.ones_(m.weight)
 30 |     elif isinstance(m, T5Model):
 31 |         nn.init.normal_(m.token_embedding.weight, std=1.0)
 32 |     elif isinstance(m, T5FeedForward):
 33 |         nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
 34 |         nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
 35 |         nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
 36 |     elif isinstance(m, T5Attention):
 37 |         nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
 38 |         nn.init.normal_(m.k.weight, std=m.dim**-0.5)
 39 |         nn.init.normal_(m.v.weight, std=m.dim**-0.5)
 40 |         nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
 41 |     elif isinstance(m, T5RelativeEmbedding):
 42 |         nn.init.normal_(
 43 |             m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
 44 | 
 45 | 
 46 | class GELU(nn.Module):
 47 | 
 48 |     def forward(self, x):
 49 |         return 0.5 * x * (1.0 + torch.tanh(
 50 |             math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
 51 | 
 52 | 
 53 | class T5LayerNorm(nn.Module):
 54 | 
 55 |     def __init__(self, dim, eps=1e-6):
 56 |         super(T5LayerNorm, self).__init__()
 57 |         self.dim = dim
 58 |         self.eps = eps
 59 |         self.weight = nn.Parameter(torch.ones(dim))
 60 | 
 61 |     def forward(self, x):
 62 |         x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
 63 |                             self.eps)
 64 |         if self.weight.dtype in [torch.float16, torch.bfloat16]:
 65 |             x = x.type_as(self.weight)
 66 |         return self.weight * x
 67 | 
 68 | 
 69 | class T5Attention(nn.Module):
 70 | 
 71 |     def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
 72 |         assert dim_attn % num_heads == 0
 73 |         super(T5Attention, self).__init__()
 74 |         self.dim = dim
 75 |         self.dim_attn = dim_attn
 76 |         self.num_heads = num_heads
 77 |         self.head_dim = dim_attn // num_heads
 78 | 
 79 |         # layers
 80 |         self.q = nn.Linear(dim, dim_attn, bias=False)
 81 |         self.k = nn.Linear(dim, dim_attn, bias=False)
 82 |         self.v = nn.Linear(dim, dim_attn, bias=False)
 83 |         self.o = nn.Linear(dim_attn, dim, bias=False)
 84 |         self.dropout = nn.Dropout(dropout)
 85 | 
 86 |     def forward(self, x, context=None, mask=None, pos_bias=None):
 87 |         """
 88 |         x:          [B, L1, C].
 89 |         context:    [B, L2, C] or None.
 90 |         mask:       [B, L2] or [B, L1, L2] or None.
 91 |         """
 92 |         # check inputs
 93 |         context = x if context is None else context
 94 |         b, n, c = x.size(0), self.num_heads, self.head_dim
 95 | 
 96 |         # compute query, key, value
 97 |         q = self.q(x).view(b, -1, n, c)
 98 |         k = self.k(context).view(b, -1, n, c)
 99 |         v = self.v(context).view(b, -1, n, c)
100 | 
101 |         # attention bias
102 |         attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
103 |         if pos_bias is not None:
104 |             attn_bias += pos_bias
105 |         if mask is not None:
106 |             assert mask.ndim in [2, 3]
107 |             mask = mask.view(b, 1, 1,
108 |                              -1) if mask.ndim == 2 else mask.unsqueeze(1)
109 |             attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
110 | 
111 |         # compute attention (T5 does not use scaling)
112 |         attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
113 |         attn = F.softmax(attn.float(), dim=-1).type_as(attn)
114 |         x = torch.einsum('bnij,bjnc->binc', attn, v)
115 | 
116 |         # output
117 |         x = x.reshape(b, -1, n * c)
118 |         x = self.o(x)
119 |         x = self.dropout(x)
120 |         return x
121 | 
122 | 
123 | class T5FeedForward(nn.Module):
124 | 
125 |     def __init__(self, dim, dim_ffn, dropout=0.1):
126 |         super(T5FeedForward, self).__init__()
127 |         self.dim = dim
128 |         self.dim_ffn = dim_ffn
129 | 
130 |         # layers
131 |         self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
132 |         self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
133 |         self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
134 |         self.dropout = nn.Dropout(dropout)
135 | 
136 |     def forward(self, x):
137 |         x = self.fc1(x) * self.gate(x)
138 |         x = self.dropout(x)
139 |         x = self.fc2(x)
140 |         x = self.dropout(x)
141 |         return x
142 | 
143 | 
144 | class T5SelfAttention(nn.Module):
145 | 
146 |     def __init__(self,
147 |                  dim,
148 |                  dim_attn,
149 |                  dim_ffn,
150 |                  num_heads,
151 |                  num_buckets,
152 |                  shared_pos=True,
153 |                  dropout=0.1):
154 |         super(T5SelfAttention, self).__init__()
155 |         self.dim = dim
156 |         self.dim_attn = dim_attn
157 |         self.dim_ffn = dim_ffn
158 |         self.num_heads = num_heads
159 |         self.num_buckets = num_buckets
160 |         self.shared_pos = shared_pos
161 | 
162 |         # layers
163 |         self.norm1 = T5LayerNorm(dim)
164 |         self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
165 |         self.norm2 = T5LayerNorm(dim)
166 |         self.ffn = T5FeedForward(dim, dim_ffn, dropout)
167 |         self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
168 |             num_buckets, num_heads, bidirectional=True)
169 | 
170 |     def forward(self, x, mask=None, pos_bias=None):
171 |         e = pos_bias if self.shared_pos else self.pos_embedding(
172 |             x.size(1), x.size(1))
173 |         x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
174 |         x = fp16_clamp(x + self.ffn(self.norm2(x)))
175 |         return x
176 | 
177 | 
178 | class T5CrossAttention(nn.Module):
179 | 
180 |     def __init__(self,
181 |                  dim,
182 |                  dim_attn,
183 |                  dim_ffn,
184 |                  num_heads,
185 |                  num_buckets,
186 |                  shared_pos=True,
187 |                  dropout=0.1):
188 |         super(T5CrossAttention, self).__init__()
189 |         self.dim = dim
190 |         self.dim_attn = dim_attn
191 |         self.dim_ffn = dim_ffn
192 |         self.num_heads = num_heads
193 |         self.num_buckets = num_buckets
194 |         self.shared_pos = shared_pos
195 | 
196 |         # layers
197 |         self.norm1 = T5LayerNorm(dim)
198 |         self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
199 |         self.norm2 = T5LayerNorm(dim)
200 |         self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
201 |         self.norm3 = T5LayerNorm(dim)
202 |         self.ffn = T5FeedForward(dim, dim_ffn, dropout)
203 |         self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
204 |             num_buckets, num_heads, bidirectional=False)
205 | 
206 |     def forward(self,
207 |                 x,
208 |                 mask=None,
209 |                 encoder_states=None,
210 |                 encoder_mask=None,
211 |                 pos_bias=None):
212 |         e = pos_bias if self.shared_pos else self.pos_embedding(
213 |             x.size(1), x.size(1))
214 |         x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
215 |         x = fp16_clamp(x + self.cross_attn(
216 |             self.norm2(x), context=encoder_states, mask=encoder_mask))
217 |         x = fp16_clamp(x + self.ffn(self.norm3(x)))
218 |         return x
219 | 
220 | 
221 | class T5RelativeEmbedding(nn.Module):
222 | 
223 |     def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
224 |         super(T5RelativeEmbedding, self).__init__()
225 |         self.num_buckets = num_buckets
226 |         self.num_heads = num_heads
227 |         self.bidirectional = bidirectional
228 |         self.max_dist = max_dist
229 | 
230 |         # layers
231 |         self.embedding = nn.Embedding(num_buckets, num_heads)
232 | 
233 |     def forward(self, lq, lk):
234 |         device = self.embedding.weight.device
235 |         # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
236 |         #     torch.arange(lq).unsqueeze(1).to(device)
237 |         rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
238 |             torch.arange(lq, device=device).unsqueeze(1)
239 |         rel_pos = self._relative_position_bucket(rel_pos)
240 |         rel_pos_embeds = self.embedding(rel_pos)
241 |         rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
242 |             0)  # [1, N, Lq, Lk]
243 |         return rel_pos_embeds.contiguous()
244 | 
245 |     def _relative_position_bucket(self, rel_pos):
246 |         # preprocess
247 |         if self.bidirectional:
248 |             num_buckets = self.num_buckets // 2
249 |             rel_buckets = (rel_pos > 0).long() * num_buckets
250 |             rel_pos = torch.abs(rel_pos)
251 |         else:
252 |             num_buckets = self.num_buckets
253 |             rel_buckets = 0
254 |             rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
255 | 
256 |         # embeddings for small and large positions
257 |         max_exact = num_buckets // 2
258 |         rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
259 |                                      math.log(self.max_dist / max_exact) *
260 |                                      (num_buckets - max_exact)).long()
261 |         rel_pos_large = torch.min(
262 |             rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
263 |         rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
264 |         return rel_buckets
265 | 
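    # Worked example (editor note, derived from _relative_position_bucket): with
    # the umt5-xxl defaults num_buckets=32, bidirectional=True, max_dist=128,
    # each direction gets 16 buckets (offset 0 when the key is at or before the
    # query, 16 when it comes after). Absolute distances 0..7 map one-to-one onto
    # the first 8 buckets of a direction, distances 8..127 are spread
    # logarithmically over the remaining 8, and anything >= 128 is clamped into
    # the last bucket of its direction.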
266 | 
267 | class T5Encoder(nn.Module):
268 | 
269 |     def __init__(self,
270 |                  vocab,
271 |                  dim,
272 |                  dim_attn,
273 |                  dim_ffn,
274 |                  num_heads,
275 |                  num_layers,
276 |                  num_buckets,
277 |                  shared_pos=True,
278 |                  dropout=0.1):
279 |         super(T5Encoder, self).__init__()
280 |         self.dim = dim
281 |         self.dim_attn = dim_attn
282 |         self.dim_ffn = dim_ffn
283 |         self.num_heads = num_heads
284 |         self.num_layers = num_layers
285 |         self.num_buckets = num_buckets
286 |         self.shared_pos = shared_pos
287 | 
288 |         # layers
289 |         self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
290 |             else nn.Embedding(vocab, dim)
291 |         self.pos_embedding = T5RelativeEmbedding(
292 |             num_buckets, num_heads, bidirectional=True) if shared_pos else None
293 |         self.dropout = nn.Dropout(dropout)
294 |         self.blocks = nn.ModuleList([
295 |             T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
296 |                             shared_pos, dropout) for _ in range(num_layers)
297 |         ])
298 |         self.norm = T5LayerNorm(dim)
299 | 
300 |         # initialize weights
301 |         self.apply(init_weights)
302 | 
303 |     def forward(self, ids, mask=None):
304 |         x = self.token_embedding(ids)
305 |         x = self.dropout(x)
306 |         e = self.pos_embedding(x.size(1),
307 |                                x.size(1)) if self.shared_pos else None
308 |         for block in self.blocks:
309 |             x = block(x, mask, pos_bias=e)
310 |         x = self.norm(x)
311 |         x = self.dropout(x)
312 |         return x
313 | 
314 | 
315 | class T5Decoder(nn.Module):
316 | 
317 |     def __init__(self,
318 |                  vocab,
319 |                  dim,
320 |                  dim_attn,
321 |                  dim_ffn,
322 |                  num_heads,
323 |                  num_layers,
324 |                  num_buckets,
325 |                  shared_pos=True,
326 |                  dropout=0.1):
327 |         super(T5Decoder, self).__init__()
328 |         self.dim = dim
329 |         self.dim_attn = dim_attn
330 |         self.dim_ffn = dim_ffn
331 |         self.num_heads = num_heads
332 |         self.num_layers = num_layers
333 |         self.num_buckets = num_buckets
334 |         self.shared_pos = shared_pos
335 | 
336 |         # layers
337 |         self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
338 |             else nn.Embedding(vocab, dim)
339 |         self.pos_embedding = T5RelativeEmbedding(
340 |             num_buckets, num_heads, bidirectional=False) if shared_pos else None
341 |         self.dropout = nn.Dropout(dropout)
342 |         self.blocks = nn.ModuleList([
343 |             T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
344 |                              shared_pos, dropout) for _ in range(num_layers)
345 |         ])
346 |         self.norm = T5LayerNorm(dim)
347 | 
348 |         # initialize weights
349 |         self.apply(init_weights)
350 | 
351 |     def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
352 |         b, s = ids.size()
353 | 
354 |         # causal mask
355 |         if mask is None:
356 |             mask = torch.tril(torch.ones(1, s, s).to(ids.device))
357 |         elif mask.ndim == 2:
358 |             mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
359 | 
360 |         # layers
361 |         x = self.token_embedding(ids)
362 |         x = self.dropout(x)
363 |         e = self.pos_embedding(x.size(1),
364 |                                x.size(1)) if self.shared_pos else None
365 |         for block in self.blocks:
366 |             x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
367 |         x = self.norm(x)
368 |         x = self.dropout(x)
369 |         return x
370 | 
371 | 
372 | class T5Model(nn.Module):
373 | 
374 |     def __init__(self,
375 |                  vocab_size,
376 |                  dim,
377 |                  dim_attn,
378 |                  dim_ffn,
379 |                  num_heads,
380 |                  encoder_layers,
381 |                  decoder_layers,
382 |                  num_buckets,
383 |                  shared_pos=True,
384 |                  dropout=0.1):
385 |         super(T5Model, self).__init__()
386 |         self.vocab_size = vocab_size
387 |         self.dim = dim
388 |         self.dim_attn = dim_attn
389 |         self.dim_ffn = dim_ffn
390 |         self.num_heads = num_heads
391 |         self.encoder_layers = encoder_layers
392 |         self.decoder_layers = decoder_layers
393 |         self.num_buckets = num_buckets
394 | 
395 |         # layers
396 |         self.token_embedding = nn.Embedding(vocab_size, dim)
397 |         self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
398 |                                  num_heads, encoder_layers, num_buckets,
399 |                                  shared_pos, dropout)
400 |         self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
401 |                                  num_heads, decoder_layers, num_buckets,
402 |                                  shared_pos, dropout)
403 |         self.head = nn.Linear(dim, vocab_size, bias=False)
404 | 
405 |         # initialize weights
406 |         self.apply(init_weights)
407 | 
408 |     def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
409 |         x = self.encoder(encoder_ids, encoder_mask)
410 |         x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
411 |         x = self.head(x)
412 |         return x
413 | 
414 | 
415 | def _t5(name,
416 |         encoder_only=False,
417 |         decoder_only=False,
418 |         return_tokenizer=False,
419 |         tokenizer_kwargs={},
420 |         dtype=torch.float32,
421 |         device='cpu',
422 |         **kwargs):
423 |     # sanity check
424 |     assert not (encoder_only and decoder_only)
425 | 
426 |     # params
427 |     if encoder_only:
428 |         model_cls = T5Encoder
429 |         kwargs['vocab'] = kwargs.pop('vocab_size')
430 |         kwargs['num_layers'] = kwargs.pop('encoder_layers')
431 |         _ = kwargs.pop('decoder_layers')
432 |     elif decoder_only:
433 |         model_cls = T5Decoder
434 |         kwargs['vocab'] = kwargs.pop('vocab_size')
435 |         kwargs['num_layers'] = kwargs.pop('decoder_layers')
436 |         _ = kwargs.pop('encoder_layers')
437 |     else:
438 |         model_cls = T5Model
439 | 
440 |     # init model
441 |     with torch.device(device):
442 |         model = model_cls(**kwargs)
443 | 
444 |     # set device
445 |     model = model.to(dtype=dtype, device=device)
446 | 
447 |     # init tokenizer
448 |     if return_tokenizer:
449 |         from .tokenizers import HuggingfaceTokenizer
450 |         tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
451 |         return model, tokenizer
452 |     else:
453 |         return model
454 | 
455 | 
456 | def umt5_xxl(**kwargs):
457 |     cfg = dict(
458 |         vocab_size=256384,
459 |         dim=4096,
460 |         dim_attn=4096,
461 |         dim_ffn=10240,
462 |         num_heads=64,
463 |         encoder_layers=24,
464 |         decoder_layers=24,
465 |         num_buckets=32,
466 |         shared_pos=False,
467 |         dropout=0.1)
468 |     cfg.update(**kwargs)
469 |     return _t5('umt5-xxl', **cfg)
470 | 
471 | 
472 | class T5EncoderModel:
473 | 
474 |     def __init__(
475 |         self,
476 |         text_len,
477 |         dtype=torch.bfloat16,
478 |         device=torch.cuda.current_device(),
479 |         checkpoint_path=None,
480 |         tokenizer_path=None,
481 |         shard_fn=None,
482 |     ):
483 |         self.text_len = text_len
484 |         self.dtype = dtype
485 |         self.device = device
486 |         self.checkpoint_path = checkpoint_path
487 |         self.tokenizer_path = tokenizer_path
488 | 
489 |         # init model
490 |         model = umt5_xxl(
491 |             encoder_only=True,
492 |             return_tokenizer=False,
493 |             dtype=dtype,
494 |             device=device).eval().requires_grad_(False)
495 |         logging.info(f'loading {checkpoint_path}')
496 |         model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
497 |         self.model = model
498 |         if shard_fn is not None:
499 |             self.model = shard_fn(self.model, sync_module_states=False)
500 |         else:
501 |             self.model.to(self.device)
502 |         # init tokenizer
503 |         self.tokenizer = HuggingfaceTokenizer(
504 |             name=tokenizer_path, seq_len=text_len, clean='whitespace')
505 | 
506 |     def __call__(self, texts, device):
507 |         ids, mask = self.tokenizer(
508 |             texts, return_mask=True, add_special_tokens=True)
509 |         ids = ids.to(device)
510 |         mask = mask.to(device)
511 |         seq_lens = mask.gt(0).sum(dim=1).long()
512 |         context = self.model(ids, mask)
513 |         return [u[:v] for u, v in zip(context, seq_lens)]
514 | 
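A minimal usage sketch for T5EncoderModel above (editor addition; the checkpoint and tokenizer paths are placeholders that follow the naming used by the configs in wan/configs, and a CUDA device with the downloaded weights is assumed):

import torch
from wan.modules.t5 import T5EncoderModel

# Placeholder paths -- substitute the actual downloaded checkpoint directory.
text_encoder = T5EncoderModel(
    text_len=512,
    dtype=torch.bfloat16,
    device=torch.device('cuda'),
    checkpoint_path='./Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth',
    tokenizer_path='./Wan2.1-T2V-1.3B/google/umt5-xxl')

# One [L_i, 4096] tensor per prompt, trimmed to each prompt's true token length.
context = text_encoder(['a cat surfing a wave at sunset'], torch.device('cuda'))
print(context[0].shape)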


--------------------------------------------------------------------------------
/wan/modules/tokenizers.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
 2 | import html
 3 | import string
 4 | 
 5 | import ftfy
 6 | import regex as re
 7 | from transformers import AutoTokenizer
 8 | 
 9 | __all__ = ['HuggingfaceTokenizer']
10 | 
11 | 
12 | def basic_clean(text):
13 |     text = ftfy.fix_text(text)
14 |     text = html.unescape(html.unescape(text))
15 |     return text.strip()
16 | 
17 | 
18 | def whitespace_clean(text):
19 |     text = re.sub(r'\s+', ' ', text)
20 |     text = text.strip()
21 |     return text
22 | 
23 | 
24 | def canonicalize(text, keep_punctuation_exact_string=None):
25 |     text = text.replace('_', ' ')
26 |     if keep_punctuation_exact_string:
27 |         text = keep_punctuation_exact_string.join(
28 |             part.translate(str.maketrans('', '', string.punctuation))
29 |             for part in text.split(keep_punctuation_exact_string))
30 |     else:
31 |         text = text.translate(str.maketrans('', '', string.punctuation))
32 |     text = text.lower()
33 |     text = re.sub(r'\s+', ' ', text)
34 |     return text.strip()
35 | 
36 | 
37 | class HuggingfaceTokenizer:
38 | 
39 |     def __init__(self, name, seq_len=None, clean=None, **kwargs):
40 |         assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41 |         self.name = name
42 |         self.seq_len = seq_len
43 |         self.clean = clean
44 | 
45 |         # init tokenizer
46 |         self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47 |         self.vocab_size = self.tokenizer.vocab_size
48 | 
49 |     def __call__(self, sequence, **kwargs):
50 |         return_mask = kwargs.pop('return_mask', False)
51 | 
52 |         # arguments
53 |         _kwargs = {'return_tensors': 'pt'}
54 |         if self.seq_len is not None:
55 |             _kwargs.update({
56 |                 'padding': 'max_length',
57 |                 'truncation': True,
58 |                 'max_length': self.seq_len
59 |             })
60 |         _kwargs.update(**kwargs)
61 | 
62 |         # tokenization
63 |         if isinstance(sequence, str):
64 |             sequence = [sequence]
65 |         if self.clean:
66 |             sequence = [self._clean(u) for u in sequence]
67 |         ids = self.tokenizer(sequence, **_kwargs)
68 | 
69 |         # output
70 |         if return_mask:
71 |             return ids.input_ids, ids.attention_mask
72 |         else:
73 |             return ids.input_ids
74 | 
75 |     def _clean(self, text):
76 |         if self.clean == 'whitespace':
77 |             text = whitespace_clean(basic_clean(text))
78 |         elif self.clean == 'lower':
79 |             text = whitespace_clean(basic_clean(text)).lower()
80 |         elif self.clean == 'canonicalize':
81 |             text = canonicalize(basic_clean(text))
82 |         return text
83 | 
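A quick sketch of how HuggingfaceTokenizer is used (editor addition; 'google/umt5-xxl' stands in for whatever tokenizer name or local path AutoTokenizer can resolve):

from wan.modules.tokenizers import HuggingfaceTokenizer

tokenizer = HuggingfaceTokenizer(
    name='google/umt5-xxl', seq_len=512, clean='whitespace')

# Inputs are cleaned, then padded/truncated to seq_len; the mask marks real tokens.
ids, mask = tokenizer(['A  messy   prompt\nwith odd whitespace'], return_mask=True)
print(ids.shape, int(mask.sum()))   # torch.Size([1, 512]) and the unpadded length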


--------------------------------------------------------------------------------
/wan/modules/vace_model.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import torch
  3 | import torch.cuda.amp as amp
  4 | import torch.nn as nn
  5 | from diffusers.configuration_utils import register_to_config
  6 | 
  7 | from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d
  8 | 
  9 | 
 10 | class VaceWanAttentionBlock(WanAttentionBlock):
 11 | 
 12 |     def __init__(self,
 13 |                  cross_attn_type,
 14 |                  dim,
 15 |                  ffn_dim,
 16 |                  num_heads,
 17 |                  window_size=(-1, -1),
 18 |                  qk_norm=True,
 19 |                  cross_attn_norm=False,
 20 |                  eps=1e-6,
 21 |                  block_id=0):
 22 |         super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
 23 |                          qk_norm, cross_attn_norm, eps)
 24 |         self.block_id = block_id
 25 |         if block_id == 0:
 26 |             self.before_proj = nn.Linear(self.dim, self.dim)
 27 |             nn.init.zeros_(self.before_proj.weight)
 28 |             nn.init.zeros_(self.before_proj.bias)
 29 |         self.after_proj = nn.Linear(self.dim, self.dim)
 30 |         nn.init.zeros_(self.after_proj.weight)
 31 |         nn.init.zeros_(self.after_proj.bias)
 32 | 
 33 |     def forward(self, c, x, **kwargs):
 34 |         if self.block_id == 0:
 35 |             c = self.before_proj(c) + x
 36 | 
 37 |         c = super().forward(c, **kwargs)
 38 |         c_skip = self.after_proj(c)
 39 |         return c, c_skip
 40 | 
 41 | 
 42 | class BaseWanAttentionBlock(WanAttentionBlock):
 43 | 
 44 |     def __init__(self,
 45 |                  cross_attn_type,
 46 |                  dim,
 47 |                  ffn_dim,
 48 |                  num_heads,
 49 |                  window_size=(-1, -1),
 50 |                  qk_norm=True,
 51 |                  cross_attn_norm=False,
 52 |                  eps=1e-6,
 53 |                  block_id=None):
 54 |         super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
 55 |                          qk_norm, cross_attn_norm, eps)
 56 |         self.block_id = block_id
 57 | 
 58 |     def forward(self, x, hints, context_scale=1.0, **kwargs):
 59 |         x = super().forward(x, **kwargs)
 60 |         if self.block_id is not None:
 61 |             x = x + hints[self.block_id] * context_scale
 62 |         return x
 63 | 
 64 | 
 65 | class VaceWanModel(WanModel):
 66 | 
 67 |     @register_to_config
 68 |     def __init__(self,
 69 |                  vace_layers=None,
 70 |                  vace_in_dim=None,
 71 |                  model_type='vace',
 72 |                  patch_size=(1, 2, 2),
 73 |                  text_len=512,
 74 |                  in_dim=16,
 75 |                  dim=2048,
 76 |                  ffn_dim=8192,
 77 |                  freq_dim=256,
 78 |                  text_dim=4096,
 79 |                  out_dim=16,
 80 |                  num_heads=16,
 81 |                  num_layers=32,
 82 |                  window_size=(-1, -1),
 83 |                  qk_norm=True,
 84 |                  cross_attn_norm=True,
 85 |                  eps=1e-6):
 86 |         super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim,
 87 |                          freq_dim, text_dim, out_dim, num_heads, num_layers,
 88 |                          window_size, qk_norm, cross_attn_norm, eps)
 89 | 
 90 |         self.vace_layers = [i for i in range(0, self.num_layers, 2)
 91 |                            ] if vace_layers is None else vace_layers
 92 |         self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
 93 | 
 94 |         assert 0 in self.vace_layers
 95 |         self.vace_layers_mapping = {
 96 |             i: n for n, i in enumerate(self.vace_layers)
 97 |         }
 98 | 
 99 |         # blocks
100 |         self.blocks = nn.ModuleList([
101 |             BaseWanAttentionBlock(
102 |                 't2v_cross_attn',
103 |                 self.dim,
104 |                 self.ffn_dim,
105 |                 self.num_heads,
106 |                 self.window_size,
107 |                 self.qk_norm,
108 |                 self.cross_attn_norm,
109 |                 self.eps,
110 |                 block_id=self.vace_layers_mapping[i]
111 |                 if i in self.vace_layers else None)
112 |             for i in range(self.num_layers)
113 |         ])
114 | 
115 |         # vace blocks
116 |         self.vace_blocks = nn.ModuleList([
117 |             VaceWanAttentionBlock(
118 |                 't2v_cross_attn',
119 |                 self.dim,
120 |                 self.ffn_dim,
121 |                 self.num_heads,
122 |                 self.window_size,
123 |                 self.qk_norm,
124 |                 self.cross_attn_norm,
125 |                 self.eps,
126 |                 block_id=i) for i in self.vace_layers
127 |         ])
128 | 
129 |         # vace patch embeddings
130 |         self.vace_patch_embedding = nn.Conv3d(
131 |             self.vace_in_dim,
132 |             self.dim,
133 |             kernel_size=self.patch_size,
134 |             stride=self.patch_size)
135 | 
136 |     def forward_vace(self, x, vace_context, seq_len, kwargs):
137 |         # embeddings
138 |         c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
139 |         c = [u.flatten(2).transpose(1, 2) for u in c]
140 |         c = torch.cat([
141 |             torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
142 |                       dim=1) for u in c
143 |         ])
144 | 
145 |         # arguments
146 |         new_kwargs = dict(x=x)
147 |         new_kwargs.update(kwargs)
148 | 
149 |         hints = []
150 |         for block in self.vace_blocks:
151 |             c, c_skip = block(c, **new_kwargs)
152 |             hints.append(c_skip)
153 |         return hints
154 | 
155 |     def forward(
156 |         self,
157 |         x,
158 |         t,
159 |         vace_context,
160 |         context,
161 |         seq_len,
162 |         vace_context_scale=1.0,
163 |         clip_fea=None,
164 |         y=None,
165 |     ):
166 |         r"""
167 |         Forward pass through the diffusion model
168 | 
169 |         Args:
170 |             x (List[Tensor]):
171 |                 List of input video tensors, each with shape [C_in, F, H, W]
172 |             t (Tensor):
173 |                 Diffusion timesteps tensor of shape [B]
174 |             context (List[Tensor]):
175 |                 List of text embeddings each with shape [L, C]
176 |             seq_len (`int`):
177 |                 Maximum sequence length for positional encoding
178 |             clip_fea (Tensor, *optional*):
179 |                 CLIP image features for image-to-video mode
180 |             y (List[Tensor], *optional*):
181 |                 Conditional video inputs for image-to-video mode, same shape as x
182 | 
183 |         Returns:
184 |             List[Tensor]:
185 |                 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
186 |         """
187 |         # if self.model_type == 'i2v':
188 |         #     assert clip_fea is not None and y is not None
189 |         # params
190 |         device = self.patch_embedding.weight.device
191 |         if self.freqs.device != device:
192 |             self.freqs = self.freqs.to(device)
193 | 
194 |         # if y is not None:
195 |         #     x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
196 | 
197 |         # embeddings
198 |         x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
199 |         grid_sizes = torch.stack(
200 |             [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
201 |         x = [u.flatten(2).transpose(1, 2) for u in x]
202 |         seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
203 |         assert seq_lens.max() <= seq_len
204 |         x = torch.cat([
205 |             torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
206 |                       dim=1) for u in x
207 |         ])
208 | 
209 |         # time embeddings
210 |         with amp.autocast(dtype=torch.float32):
211 |             e = self.time_embedding(
212 |                 sinusoidal_embedding_1d(self.freq_dim, t).float())
213 |             e0 = self.time_projection(e).unflatten(1, (6, self.dim))
214 |             assert e.dtype == torch.float32 and e0.dtype == torch.float32
215 | 
216 |         # context
217 |         context_lens = None
218 |         context = self.text_embedding(
219 |             torch.stack([
220 |                 torch.cat(
221 |                     [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
222 |                 for u in context
223 |             ]))
224 | 
225 |         # if clip_fea is not None:
226 |         #     context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
227 |         #     context = torch.concat([context_clip, context], dim=1)
228 | 
229 |         # arguments
230 |         kwargs = dict(
231 |             e=e0,
232 |             seq_lens=seq_lens,
233 |             grid_sizes=grid_sizes,
234 |             freqs=self.freqs,
235 |             context=context,
236 |             context_lens=context_lens)
237 | 
238 |         hints = self.forward_vace(x, vace_context, seq_len, kwargs)
239 |         kwargs['hints'] = hints
240 |         kwargs['context_scale'] = vace_context_scale
241 | 
242 |         for block in self.blocks:
243 |             x = block(x, **kwargs)
244 | 
245 |         # head
246 |         x = self.head(x, e)
247 | 
248 |         # unpatchify
249 |         x = self.unpatchify(x, grid_sizes)
250 |         return [u.float() for u in x]
251 | 
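A small sketch of the hint wiring above (editor addition, using the constructor default num_layers=32): every other backbone block is paired with a VACE block, and vace_layers_mapping converts a backbone layer index into an index into the hints list produced by forward_vace.

num_layers = 32
vace_layers = list(range(0, num_layers, 2))                  # [0, 2, 4, ..., 30]
vace_layers_mapping = {i: n for n, i in enumerate(vace_layers)}

# Backbone block 4 adds hints[2] * vace_context_scale; block 5 receives no hint.
print(vace_layers_mapping[4])     # 2
print(5 in vace_layers_mapping)   # False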


--------------------------------------------------------------------------------
/wan/modules/xlm_roberta.py:
--------------------------------------------------------------------------------
  1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
  2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  3 | import torch
  4 | import torch.nn as nn
  5 | import torch.nn.functional as F
  6 | 
  7 | __all__ = ['XLMRoberta', 'xlm_roberta_large']
  8 | 
  9 | 
 10 | class SelfAttention(nn.Module):
 11 | 
 12 |     def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
 13 |         assert dim % num_heads == 0
 14 |         super().__init__()
 15 |         self.dim = dim
 16 |         self.num_heads = num_heads
 17 |         self.head_dim = dim // num_heads
 18 |         self.eps = eps
 19 | 
 20 |         # layers
 21 |         self.q = nn.Linear(dim, dim)
 22 |         self.k = nn.Linear(dim, dim)
 23 |         self.v = nn.Linear(dim, dim)
 24 |         self.o = nn.Linear(dim, dim)
 25 |         self.dropout = nn.Dropout(dropout)
 26 | 
 27 |     def forward(self, x, mask):
 28 |         """
 29 |         x:   [B, L, C]; mask: additive attention bias of shape [B, 1, 1, L].
 30 |         """
 31 |         b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
 32 | 
 33 |         # compute query, key, value
 34 |         q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
 35 |         k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
 36 |         v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
 37 | 
 38 |         # compute attention
 39 |         p = self.dropout.p if self.training else 0.0
 40 |         x = F.scaled_dot_product_attention(q, k, v, mask, p)
 41 |         x = x.permute(0, 2, 1, 3).reshape(b, s, c)
 42 | 
 43 |         # output
 44 |         x = self.o(x)
 45 |         x = self.dropout(x)
 46 |         return x
 47 | 
 48 | 
 49 | class AttentionBlock(nn.Module):
 50 | 
 51 |     def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
 52 |         super().__init__()
 53 |         self.dim = dim
 54 |         self.num_heads = num_heads
 55 |         self.post_norm = post_norm
 56 |         self.eps = eps
 57 | 
 58 |         # layers
 59 |         self.attn = SelfAttention(dim, num_heads, dropout, eps)
 60 |         self.norm1 = nn.LayerNorm(dim, eps=eps)
 61 |         self.ffn = nn.Sequential(
 62 |             nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
 63 |             nn.Dropout(dropout))
 64 |         self.norm2 = nn.LayerNorm(dim, eps=eps)
 65 | 
 66 |     def forward(self, x, mask):
 67 |         if self.post_norm:
 68 |             x = self.norm1(x + self.attn(x, mask))
 69 |             x = self.norm2(x + self.ffn(x))
 70 |         else:
 71 |             x = x + self.attn(self.norm1(x), mask)
 72 |             x = x + self.ffn(self.norm2(x))
 73 |         return x
 74 | 
 75 | 
 76 | class XLMRoberta(nn.Module):
 77 |     """
 78 |     XLMRobertaModel with no pooler and no LM head.
 79 |     """
 80 | 
 81 |     def __init__(self,
 82 |                  vocab_size=250002,
 83 |                  max_seq_len=514,
 84 |                  type_size=1,
 85 |                  pad_id=1,
 86 |                  dim=1024,
 87 |                  num_heads=16,
 88 |                  num_layers=24,
 89 |                  post_norm=True,
 90 |                  dropout=0.1,
 91 |                  eps=1e-5):
 92 |         super().__init__()
 93 |         self.vocab_size = vocab_size
 94 |         self.max_seq_len = max_seq_len
 95 |         self.type_size = type_size
 96 |         self.pad_id = pad_id
 97 |         self.dim = dim
 98 |         self.num_heads = num_heads
 99 |         self.num_layers = num_layers
100 |         self.post_norm = post_norm
101 |         self.eps = eps
102 | 
103 |         # embeddings
104 |         self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105 |         self.type_embedding = nn.Embedding(type_size, dim)
106 |         self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107 |         self.dropout = nn.Dropout(dropout)
108 | 
109 |         # blocks
110 |         self.blocks = nn.ModuleList([
111 |             AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112 |             for _ in range(num_layers)
113 |         ])
114 | 
115 |         # norm layer
116 |         self.norm = nn.LayerNorm(dim, eps=eps)
117 | 
118 |     def forward(self, ids):
119 |         """
120 |         ids: [B, L] of torch.LongTensor.
121 |         """
122 |         b, s = ids.shape
123 |         mask = ids.ne(self.pad_id).long()
124 | 
125 |         # embeddings
126 |         x = self.token_embedding(ids) + \
127 |             self.type_embedding(torch.zeros_like(ids)) + \
128 |             self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129 |         if self.post_norm:
130 |             x = self.norm(x)
131 |         x = self.dropout(x)
132 | 
133 |         # blocks
134 |         mask = torch.where(
135 |             mask.view(b, 1, 1, s).gt(0), 0.0,
136 |             torch.finfo(x.dtype).min)
137 |         for block in self.blocks:
138 |             x = block(x, mask)
139 | 
140 |         # output
141 |         if not self.post_norm:
142 |             x = self.norm(x)
143 |         return x
144 | 
145 | 
146 | def xlm_roberta_large(pretrained=False,
147 |                       return_tokenizer=False,
148 |                       device='cpu',
149 |                       **kwargs):
150 |     """
151 |     XLMRobertaLarge adapted from Huggingface.
152 |     """
153 |     # params
154 |     cfg = dict(
155 |         vocab_size=250002,
156 |         max_seq_len=514,
157 |         type_size=1,
158 |         pad_id=1,
159 |         dim=1024,
160 |         num_heads=16,
161 |         num_layers=24,
162 |         post_norm=True,
163 |         dropout=0.1,
164 |         eps=1e-5)
165 |     cfg.update(**kwargs)
166 | 
167 |     # init a model on device
168 |     with torch.device(device):
169 |         model = XLMRoberta(**cfg)
170 |     return model
171 | 
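A smoke-test sketch for the XLM-R encoder above (editor addition; random token ids, CPU, randomly initialized weights):

import torch
from wan.modules.xlm_roberta import xlm_roberta_large

model = xlm_roberta_large(device='cpu').eval()
ids = torch.randint(5, 250002, (2, 16))   # [B, L], avoiding the pad id (1)
with torch.no_grad():
    features = model(ids)
print(features.shape)                     # torch.Size([2, 16, 1024])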


--------------------------------------------------------------------------------
/wan/text2video.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import gc
  3 | import logging
  4 | import math
  5 | import os
  6 | import random
  7 | import sys
  8 | import types
  9 | from contextlib import contextmanager
 10 | from functools import partial
 11 | 
 12 | import torch
 13 | import torch.cuda.amp as amp
 14 | import torch.distributed as dist
 15 | from tqdm import tqdm
 16 | 
 17 | from .distributed.fsdp import shard_model
 18 | from .modules.model import WanModel
 19 | from .modules.t5 import T5EncoderModel
 20 | from .modules.vae import WanVAE
 21 | from .utils.fm_solvers import (
 22 |     FlowDPMSolverMultistepScheduler,
 23 |     get_sampling_sigmas,
 24 |     retrieve_timesteps,
 25 | )
 26 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 27 | 
 28 | 
 29 | class WanT2V:
 30 | 
 31 |     def __init__(
 32 |         self,
 33 |         config,
 34 |         checkpoint_dir,
 35 |         device_id=0,
 36 |         rank=0,
 37 |         t5_fsdp=False,
 38 |         dit_fsdp=False,
 39 |         use_usp=False,
 40 |         t5_cpu=False,
 41 |     ):
 42 |         r"""
 43 |         Initializes the Wan text-to-video generation model components.
 44 | 
 45 |         Args:
 46 |             config (EasyDict):
 47 |                 Object containing model parameters initialized from config.py
 48 |             checkpoint_dir (`str`):
 49 |                 Path to directory containing model checkpoints
 50 |             device_id (`int`,  *optional*, defaults to 0):
 51 |                 Id of target GPU device
 52 |             rank (`int`,  *optional*, defaults to 0):
 53 |                 Process rank for distributed training
 54 |             t5_fsdp (`bool`, *optional*, defaults to False):
 55 |                 Enable FSDP sharding for T5 model
 56 |             dit_fsdp (`bool`, *optional*, defaults to False):
 57 |                 Enable FSDP sharding for DiT model
 58 |             use_usp (`bool`, *optional*, defaults to False):
 59 |                 Enable the USP sequence-parallel distribution strategy (via xDiT).
 60 |             t5_cpu (`bool`, *optional*, defaults to False):
 61 |                 Whether to place T5 model on CPU. Only works without t5_fsdp.
 62 |         """
 63 |         self.device = torch.device(f"cuda:{device_id}")
 64 |         self.config = config
 65 |         self.rank = rank
 66 |         self.t5_cpu = t5_cpu
 67 | 
 68 |         self.num_train_timesteps = config.num_train_timesteps
 69 |         self.param_dtype = config.param_dtype
 70 | 
 71 |         shard_fn = partial(shard_model, device_id=device_id)
 72 |         self.text_encoder = T5EncoderModel(
 73 |             text_len=config.text_len,
 74 |             dtype=config.t5_dtype,
 75 |             device=torch.device('cpu'),
 76 |             checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
 77 |             tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
 78 |             shard_fn=shard_fn if t5_fsdp else None)
 79 | 
 80 |         self.vae_stride = config.vae_stride
 81 |         self.patch_size = config.patch_size
 82 |         self.vae = WanVAE(
 83 |             vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
 84 |             device=self.device)
 85 | 
 86 |         logging.info(f"Creating WanModel from {checkpoint_dir}")
 87 |         self.model = WanModel.from_pretrained(checkpoint_dir)
 88 |         self.model.eval().requires_grad_(False)
 89 | 
 90 |         if use_usp:
 91 |             from xfuser.core.distributed import get_sequence_parallel_world_size
 92 | 
 93 |             from .distributed.xdit_context_parallel import (
 94 |                 usp_attn_forward,
 95 |                 usp_dit_forward,
 96 |             )
 97 |             for block in self.model.blocks:
 98 |                 block.self_attn.forward = types.MethodType(
 99 |                     usp_attn_forward, block.self_attn)
100 |             self.model.forward = types.MethodType(usp_dit_forward, self.model)
101 |             self.sp_size = get_sequence_parallel_world_size()
102 |         else:
103 |             self.sp_size = 1
104 | 
105 |         if dist.is_initialized():
106 |             dist.barrier()
107 |         if dit_fsdp:
108 |             self.model = shard_fn(self.model)
109 |         else:
110 |             self.model.to(self.device)
111 | 
112 |         self.sample_neg_prompt = config.sample_neg_prompt
113 | 
114 |     def generate(self,
115 |                  input_prompt,
116 |                  size=(1280, 720),
117 |                  frame_num=81,
118 |                  shift=5.0,
119 |                  sample_solver='unipc',
120 |                  sampling_steps=50,
121 |                  guide_scale=5.0,
122 |                  n_prompt="",
123 |                  seed=-1,
124 |                  offload_model=True):
125 |         r"""
126 |         Generates video frames from text prompt using diffusion process.
127 | 
128 |         Args:
129 |             input_prompt (`str`):
130 |                 Text prompt for content generation
131 |             size (tuple[`int`], *optional*, defaults to (1280,720)):
132 |                 Controls video resolution, (width, height).
133 |             frame_num (`int`, *optional*, defaults to 81):
134 |                 How many frames to sample from a video. The number should be 4n+1
135 |             shift (`float`, *optional*, defaults to 5.0):
136 |                 Noise schedule shift parameter. Affects temporal dynamics
137 |             sample_solver (`str`, *optional*, defaults to 'unipc'):
138 |                 Solver used to sample the video.
139 |             sampling_steps (`int`, *optional*, defaults to 50):
140 |                 Number of diffusion sampling steps. Higher values improve quality but slow generation
141 |             guide_scale (`float`, *optional*, defaults to 5.0):
142 |                 Classifier-free guidance scale. Controls prompt adherence vs. creativity
143 |             n_prompt (`str`, *optional*, defaults to ""):
144 |                 Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
145 |             seed (`int`, *optional*, defaults to -1):
146 |                 Random seed for noise generation. If -1, use random seed.
147 |             offload_model (`bool`, *optional*, defaults to True):
148 |                 If True, offloads models to CPU during generation to save VRAM
149 | 
150 |         Returns:
151 |             torch.Tensor:
152 |                 Generated video frames tensor. Dimensions: (C, N, H, W) where:
153 |                 - C: Color channels (3 for RGB)
154 |                 - N: Number of frames (frame_num)
155 |                 - H: Frame height (from size)
156 |                 - W: Frame width (from size)
157 |         """
158 |         # preprocess
159 |         F = frame_num
160 |         target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
161 |                         size[1] // self.vae_stride[1],
162 |                         size[0] // self.vae_stride[2])
163 | 
164 |         seq_len = math.ceil((target_shape[2] * target_shape[3]) /
165 |                             (self.patch_size[1] * self.patch_size[2]) *
166 |                             target_shape[1] / self.sp_size) * self.sp_size
167 | 
168 |         if n_prompt == "":
169 |             n_prompt = self.sample_neg_prompt
170 |         seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
171 |         seed_g = torch.Generator(device=self.device)
172 |         seed_g.manual_seed(seed)
173 | 
174 |         if not self.t5_cpu:
175 |             self.text_encoder.model.to(self.device)
176 |             context = self.text_encoder([input_prompt], self.device)
177 |             context_null = self.text_encoder([n_prompt], self.device)
178 |             if offload_model:
179 |                 self.text_encoder.model.cpu()
180 |         else:
181 |             context = self.text_encoder([input_prompt], torch.device('cpu'))
182 |             context_null = self.text_encoder([n_prompt], torch.device('cpu'))
183 |             context = [t.to(self.device) for t in context]
184 |             context_null = [t.to(self.device) for t in context_null]
185 | 
186 |         noise = [
187 |             torch.randn(
188 |                 target_shape[0],
189 |                 target_shape[1],
190 |                 target_shape[2],
191 |                 target_shape[3],
192 |                 dtype=torch.float32,
193 |                 device=self.device,
194 |                 generator=seed_g)
195 |         ]
196 | 
197 |         @contextmanager
198 |         def noop_no_sync():
199 |             yield
200 | 
201 |         no_sync = getattr(self.model, 'no_sync', noop_no_sync)
202 | 
203 |         # evaluation mode
204 |         with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
205 | 
206 |             if sample_solver == 'unipc':
207 |                 sample_scheduler = FlowUniPCMultistepScheduler(
208 |                     num_train_timesteps=self.num_train_timesteps,
209 |                     shift=1,
210 |                     use_dynamic_shifting=False)
211 |                 sample_scheduler.set_timesteps(
212 |                     sampling_steps, device=self.device, shift=shift)
213 |                 timesteps = sample_scheduler.timesteps
214 |             elif sample_solver == 'dpm++':
215 |                 sample_scheduler = FlowDPMSolverMultistepScheduler(
216 |                     num_train_timesteps=self.num_train_timesteps,
217 |                     shift=1,
218 |                     use_dynamic_shifting=False)
219 |                 sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
220 |                 timesteps, _ = retrieve_timesteps(
221 |                     sample_scheduler,
222 |                     device=self.device,
223 |                     sigmas=sampling_sigmas)
224 |             else:
225 |                 raise NotImplementedError("Unsupported solver.")
226 | 
227 |             # sample videos
228 |             latents = noise
229 | 
230 |             arg_c = {'context': context, 'seq_len': seq_len}
231 |             arg_null = {'context': context_null, 'seq_len': seq_len}
232 | 
233 |             for t in tqdm(timesteps):
234 |                 latent_model_input = latents
235 |                 timestep = [t]
236 | 
237 |                 timestep = torch.stack(timestep)
238 | 
239 |                 self.model.to(self.device)
240 |                 noise_pred_cond = self.model(
241 |                     latent_model_input, t=timestep, **arg_c)[0]
242 |                 noise_pred_uncond = self.model(
243 |                     latent_model_input, t=timestep, **arg_null)[0]
244 | 
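                # Classifier-free guidance (editor note): shift the unconditional
                # prediction toward the prompt-conditioned one by guide_scale.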
245 |                 noise_pred = noise_pred_uncond + guide_scale * (
246 |                     noise_pred_cond - noise_pred_uncond)
247 | 
248 |                 temp_x0 = sample_scheduler.step(
249 |                     noise_pred.unsqueeze(0),
250 |                     t,
251 |                     latents[0].unsqueeze(0),
252 |                     return_dict=False,
253 |                     generator=seed_g)[0]
254 |                 latents = [temp_x0.squeeze(0)]
255 | 
256 |             x0 = latents
257 |             if offload_model:
258 |                 self.model.cpu()
259 |                 torch.cuda.empty_cache()
260 |             if self.rank == 0:
261 |                 videos = self.vae.decode(x0)
262 | 
263 |         del noise, latents
264 |         del sample_scheduler
265 |         if offload_model:
266 |             gc.collect()
267 |             torch.cuda.synchronize()
268 |         if dist.is_initialized():
269 |             dist.barrier()
270 | 
271 |         return videos[0] if self.rank == 0 else None
272 | 
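An end-to-end sketch of the pipeline above (editor addition, not a verbatim excerpt: the WAN_CONFIGS registry key, the sample fps, and the cache_video call follow the patterns used by generate.py in this repo but are assumptions; the checkpoint directory is a placeholder):

import torch
from wan.configs import WAN_CONFIGS            # assumed registry name
from wan.text2video import WanT2V
from wan.utils.utils import cache_video

cfg = WAN_CONFIGS['t2v-1.3B']                  # assumed task key
pipe = WanT2V(config=cfg, checkpoint_dir='./Wan2.1-T2V-1.3B', device_id=0)

video = pipe.generate(
    'A corgi running along a beach at sunset',
    size=(832, 480),
    frame_num=81,
    sampling_steps=50,
    seed=42)                                   # (C, N, H, W), roughly in [-1, 1]

cache_video(video[None], save_file='t2v_sample.mp4', fps=16,   # assumed sample fps
            normalize=True, value_range=(-1, 1))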


--------------------------------------------------------------------------------
/wan/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | from .fm_solvers import (
 2 |     FlowDPMSolverMultistepScheduler,
 3 |     get_sampling_sigmas,
 4 |     retrieve_timesteps,
 5 | )
 6 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler
 7 | from .vace_processor import VaceVideoProcessor
 8 | 
 9 | __all__ = [
10 |     'get_sampling_sigmas', 'retrieve_timesteps',
11 |     'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler',
12 |     'VaceVideoProcessor'
13 | ]
14 | 


--------------------------------------------------------------------------------
/wan/utils/qwen_vl_utils.py:
--------------------------------------------------------------------------------
  1 | # Copied from https://github.com/kq-chen/qwen-vl-utils
  2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  3 | from __future__ import annotations
  4 | 
  5 | import base64
  6 | import logging
  7 | import math
  8 | import os
  9 | import sys
 10 | import time
 11 | import warnings
 12 | from functools import lru_cache
 13 | from io import BytesIO
 14 | 
 15 | import requests
 16 | import torch
 17 | import torchvision
 18 | from packaging import version
 19 | from PIL import Image
 20 | from torchvision import io, transforms
 21 | from torchvision.transforms import InterpolationMode
 22 | 
 23 | logger = logging.getLogger(__name__)
 24 | 
 25 | IMAGE_FACTOR = 28
 26 | MIN_PIXELS = 4 * 28 * 28
 27 | MAX_PIXELS = 16384 * 28 * 28
 28 | MAX_RATIO = 200
 29 | 
 30 | VIDEO_MIN_PIXELS = 128 * 28 * 28
 31 | VIDEO_MAX_PIXELS = 768 * 28 * 28
 32 | VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
 33 | FRAME_FACTOR = 2
 34 | FPS = 2.0
 35 | FPS_MIN_FRAMES = 4
 36 | FPS_MAX_FRAMES = 768
 37 | 
 38 | 
 39 | def round_by_factor(number: int, factor: int) -> int:
 40 |     """Returns the closest integer to 'number' that is divisible by 'factor'."""
 41 |     return round(number / factor) * factor
 42 | 
 43 | 
 44 | def ceil_by_factor(number: int, factor: int) -> int:
 45 |     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
 46 |     return math.ceil(number / factor) * factor
 47 | 
 48 | 
 49 | def floor_by_factor(number: int, factor: int) -> int:
 50 |     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
 51 |     return math.floor(number / factor) * factor
 52 | 
 53 | 
 54 | def smart_resize(height: int,
 55 |                  width: int,
 56 |                  factor: int = IMAGE_FACTOR,
 57 |                  min_pixels: int = MIN_PIXELS,
 58 |                  max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
 59 |     """
 60 |     Rescales the image so that the following conditions are met:
 61 | 
 62 |     1. Both dimensions (height and width) are divisible by 'factor'.
 63 | 
 64 |     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
 65 | 
 66 |     3. The aspect ratio of the image is maintained as closely as possible.
 67 |     """
 68 |     if max(height, width) / min(height, width) > MAX_RATIO:
 69 |         raise ValueError(
 70 |             f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
 71 |         )
 72 |     h_bar = max(factor, round_by_factor(height, factor))
 73 |     w_bar = max(factor, round_by_factor(width, factor))
 74 |     if h_bar * w_bar > max_pixels:
 75 |         beta = math.sqrt((height * width) / max_pixels)
 76 |         h_bar = floor_by_factor(height / beta, factor)
 77 |         w_bar = floor_by_factor(width / beta, factor)
 78 |     elif h_bar * w_bar < min_pixels:
 79 |         beta = math.sqrt(min_pixels / (height * width))
 80 |         h_bar = ceil_by_factor(height * beta, factor)
 81 |         w_bar = ceil_by_factor(width * beta, factor)
 82 |     return h_bar, w_bar
 83 | 
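# Worked example (editor note): smart_resize(1000, 2000) with the defaults
# (factor=28, min_pixels=4*28*28, max_pixels=16384*28*28) first rounds each
# side to the nearest multiple of 28, giving (1008, 1988); since
# 1008 * 1988 = 2,003,904 pixels already falls inside [min_pixels, max_pixels],
# no beta rescaling is needed and (1008, 1988) is returned.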
 84 | 
 85 | def fetch_image(ele: dict[str, str | Image.Image],
 86 |                 size_factor: int = IMAGE_FACTOR) -> Image.Image:
 87 |     if "image" in ele:
 88 |         image = ele["image"]
 89 |     else:
 90 |         image = ele["image_url"]
 91 |     image_obj = None
 92 |     if isinstance(image, Image.Image):
 93 |         image_obj = image
 94 |     elif image.startswith("http://") or image.startswith("https://"):
 95 |         image_obj = Image.open(requests.get(image, stream=True).raw)
 96 |     elif image.startswith("file://"):
 97 |         image_obj = Image.open(image[7:])
 98 |     elif image.startswith("data:image"):
 99 |         if "base64," in image:
100 |             _, base64_data = image.split("base64,", 1)
101 |             data = base64.b64decode(base64_data)
102 |             image_obj = Image.open(BytesIO(data))
103 |     else:
104 |         image_obj = Image.open(image)
105 |     if image_obj is None:
106 |         raise ValueError(
107 |             f"Unrecognized image input; supported types are local path, http url, base64 and PIL.Image, got {image}"
108 |         )
109 |     image = image_obj.convert("RGB")
110 |     ## resize
111 |     if "resized_height" in ele and "resized_width" in ele:
112 |         resized_height, resized_width = smart_resize(
113 |             ele["resized_height"],
114 |             ele["resized_width"],
115 |             factor=size_factor,
116 |         )
117 |     else:
118 |         width, height = image.size
119 |         min_pixels = ele.get("min_pixels", MIN_PIXELS)
120 |         max_pixels = ele.get("max_pixels", MAX_PIXELS)
121 |         resized_height, resized_width = smart_resize(
122 |             height,
123 |             width,
124 |             factor=size_factor,
125 |             min_pixels=min_pixels,
126 |             max_pixels=max_pixels,
127 |         )
128 |     image = image.resize((resized_width, resized_height))
129 | 
130 |     return image
131 | 
132 | 
133 | def smart_nframes(
134 |     ele: dict,
135 |     total_frames: int,
136 |     video_fps: int | float,
137 | ) -> int:
138 |     """calculate the number of frames for video used for model inputs.
139 | 
140 |     Args:
141 |         ele (dict): a dict containing the video configuration.
142 |             supports either `fps` or `nframes`:
143 |                 - nframes: the number of frames to extract for model inputs.
144 |                 - fps: the fps to extract frames for model inputs.
145 |                     - min_frames: the minimum number of frames of the video, only used when fps is provided.
146 |                     - max_frames: the maximum number of frames of the video, only used when fps is provided.
147 |         total_frames (int): the original total number of frames of the video.
148 |         video_fps (int | float): the original fps of the video.
149 | 
150 |     Raises:
151 |         ValueError: if nframes is not in the interval [FRAME_FACTOR, total_frames].
152 | 
153 |     Returns:
154 |         int: the number of frames for video used for model inputs.
155 |     """
156 |     assert not ("fps" in ele and
157 |                 "nframes" in ele), "Only accept either `fps` or `nframes`"
158 |     if "nframes" in ele:
159 |         nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
160 |     else:
161 |         fps = ele.get("fps", FPS)
162 |         min_frames = ceil_by_factor(
163 |             ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
164 |         max_frames = floor_by_factor(
165 |             ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
166 |             FRAME_FACTOR)
167 |         nframes = total_frames / video_fps * fps
168 |         nframes = min(max(nframes, min_frames), max_frames)
169 |         nframes = round_by_factor(nframes, FRAME_FACTOR)
170 |     if not (FRAME_FACTOR <= nframes <= total_frames):
171 |         raise ValueError(
172 |             f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
173 |         )
174 |     return nframes
175 | 
176 | 
177 | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
178 |     """read video using torchvision.io.read_video
179 | 
180 |     Args:
181 |         ele (dict): a dict containing the video configuration.
182 |         supported keys:
183 |             - video: the path of video. support "file://", "http://", "https://" and local path.
184 |             - video_start: the start time of video.
185 |             - video_end: the end time of video.
186 |     Returns:
187 |         torch.Tensor: the video tensor with shape (T, C, H, W).
188 |     """
189 |     video_path = ele["video"]
190 |     if version.parse(torchvision.__version__) < version.parse("0.19.0"):
191 |         if "http://" in video_path or "https://" in video_path:
192 |             warnings.warn(
193 |                 "torchvision < 0.19.0 does not support http/https video paths; please upgrade to 0.19.0 or later."
194 |             )
195 |         if "file://" in video_path:
196 |             video_path = video_path[7:]
197 |     st = time.time()
198 |     video, audio, info = io.read_video(
199 |         video_path,
200 |         start_pts=ele.get("video_start", 0.0),
201 |         end_pts=ele.get("video_end", None),
202 |         pts_unit="sec",
203 |         output_format="TCHW",
204 |     )
205 |     total_frames, video_fps = video.size(0), info["video_fps"]
206 |     logger.info(
207 |         f"torchvision:  {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
208 |     )
209 |     nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
210 |     idx = torch.linspace(0, total_frames - 1, nframes).round().long()
211 |     video = video[idx]
212 |     return video
213 | 
214 | 
215 | def is_decord_available() -> bool:
216 |     import importlib.util
217 | 
218 |     return importlib.util.find_spec("decord") is not None
219 | 
220 | 
221 | def _read_video_decord(ele: dict,) -> torch.Tensor:
222 |     """read video using decord.VideoReader
223 | 
224 |     Args:
225 |         ele (dict): a dict containing the video configuration.
226 |         supported keys:
227 |             - video: the path of video. support "file://", "http://", "https://" and local path.
228 |             - video_start: the start time of video.
229 |             - video_end: the end time of video.
230 |     Returns:
231 |         torch.Tensor: the video tensor with shape (T, C, H, W).
232 |     """
233 |     import decord
234 |     video_path = ele["video"]
235 |     st = time.time()
236 |     vr = decord.VideoReader(video_path)
237 |     # TODO: support start_pts and end_pts
238 |     if 'video_start' in ele or 'video_end' in ele:
239 |         raise NotImplementedError(
240 |             "start_pts and end_pts are not yet supported by the decord backend.")
241 |     total_frames, video_fps = len(vr), vr.get_avg_fps()
242 |     logger.info(
243 |         f"decord:  {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
244 |     )
245 |     nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
246 |     idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
247 |     video = vr.get_batch(idx).asnumpy()
248 |     video = torch.tensor(video).permute(0, 3, 1, 2)  # Convert to TCHW format
249 |     return video
250 | 
251 | 
252 | VIDEO_READER_BACKENDS = {
253 |     "decord": _read_video_decord,
254 |     "torchvision": _read_video_torchvision,
255 | }
256 | 
257 | FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
258 | 
259 | 
260 | @lru_cache(maxsize=1)
261 | def get_video_reader_backend() -> str:
262 |     if FORCE_QWENVL_VIDEO_READER is not None:
263 |         video_reader_backend = FORCE_QWENVL_VIDEO_READER
264 |     elif is_decord_available():
265 |         video_reader_backend = "decord"
266 |     else:
267 |         video_reader_backend = "torchvision"
268 |     print(
269 |         f"qwen-vl-utils using {video_reader_backend} to read video.",
270 |         file=sys.stderr)
271 |     return video_reader_backend
272 | 
273 | 
274 | def fetch_video(
275 |         ele: dict,
276 |         image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]:
277 |     if isinstance(ele["video"], str):
278 |         video_reader_backend = get_video_reader_backend()
279 |         video = VIDEO_READER_BACKENDS[video_reader_backend](ele)
280 |         nframes, _, height, width = video.shape
281 | 
282 |         min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
283 |         total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
284 |         max_pixels = max(
285 |             min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
286 |             int(min_pixels * 1.05))
287 |         max_pixels = ele.get("max_pixels", max_pixels)
288 |         if "resized_height" in ele and "resized_width" in ele:
289 |             resized_height, resized_width = smart_resize(
290 |                 ele["resized_height"],
291 |                 ele["resized_width"],
292 |                 factor=image_factor,
293 |             )
294 |         else:
295 |             resized_height, resized_width = smart_resize(
296 |                 height,
297 |                 width,
298 |                 factor=image_factor,
299 |                 min_pixels=min_pixels,
300 |                 max_pixels=max_pixels,
301 |             )
302 |         video = transforms.functional.resize(
303 |             video,
304 |             [resized_height, resized_width],
305 |             interpolation=InterpolationMode.BICUBIC,
306 |             antialias=True,
307 |         ).float()
308 |         return video
309 |     else:
310 |         assert isinstance(ele["video"], (list, tuple))
311 |         process_info = ele.copy()
312 |         process_info.pop("type", None)
313 |         process_info.pop("video", None)
314 |         images = [
315 |             fetch_image({
316 |                 "image": video_element,
317 |                 **process_info
318 |             },
319 |                         size_factor=image_factor)
320 |             for video_element in ele["video"]
321 |         ]
322 |         nframes = ceil_by_factor(len(images), FRAME_FACTOR)
323 |         if len(images) < nframes:
324 |             images.extend([images[-1]] * (nframes - len(images)))
325 |         return images
326 | 
327 | 
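For orientation, `fetch_video` accepts either a path (decoded by the selected backend and resized to the pixel budget) or an explicit list of frame images. The sketch below uses placeholder paths and assumes, as in the upstream qwen-vl-utils helpers, that `smart_nframes` honors an explicit `nframes` key:

```python
# Illustrative only; file paths are placeholders and the "nframes" key is
# assumed to be honored by smart_nframes, as in upstream qwen-vl-utils.
from wan.utils.qwen_vl_utils import fetch_video

# A path input goes through the decord/torchvision reader and comes back as a
# float tensor of shape (T, C, H', W') after smart resizing.
clip = fetch_video({"video": "clip.mp4", "nframes": 16})

# A list of frame images is loaded with fetch_image and padded so the frame
# count becomes a multiple of FRAME_FACTOR.
frames = fetch_video({"video": ["frame_0.png", "frame_1.png", "frame_2.png"]})
```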
328 | def extract_vision_info(
329 |         conversations: list[dict] | list[list[dict]]) -> list[dict]:
330 |     vision_infos = []
331 |     if isinstance(conversations[0], dict):
332 |         conversations = [conversations]
333 |     for conversation in conversations:
334 |         for message in conversation:
335 |             if isinstance(message["content"], list):
336 |                 for ele in message["content"]:
337 |                     if ("image" in ele or "image_url" in ele or
338 |                             "video" in ele or
339 |                             ele["type"] in ("image", "image_url", "video")):
340 |                         vision_infos.append(ele)
341 |     return vision_infos
342 | 
343 | 
344 | def process_vision_info(
345 |     conversations: list[dict] | list[list[dict]],
346 | ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] |
347 |            None]:
348 |     vision_infos = extract_vision_info(conversations)
349 |     ## Read images or videos
350 |     image_inputs = []
351 |     video_inputs = []
352 |     for vision_info in vision_infos:
353 |         if "image" in vision_info or "image_url" in vision_info:
354 |             image_inputs.append(fetch_image(vision_info))
355 |         elif "video" in vision_info:
356 |             video_inputs.append(fetch_video(vision_info))
357 |         else:
358 |             raise ValueError("image, image_url or video should be in content.")
359 |     if len(image_inputs) == 0:
360 |         image_inputs = None
361 |     if len(video_inputs) == 0:
362 |         video_inputs = None
363 |     return image_inputs, video_inputs
364 | 
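A conversation in the structure `process_vision_info` expects looks roughly like the sketch below; the media path and prompt text are placeholders:

```python
# Placeholder content; only the dict structure matters here.
from wan.utils.qwen_vl_utils import process_vision_info

conversation = [{
    "role": "user",
    "content": [
        {"type": "video", "video": "demo.mp4"},
        {"type": "text", "text": "Describe this clip."},
    ],
}]

image_inputs, video_inputs = process_vision_info(conversation)
# image_inputs -> None (no image entries in the conversation)
# video_inputs -> [tensor of shape (T, C, H, W)]
```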


--------------------------------------------------------------------------------
/wan/utils/utils.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import argparse
  3 | import binascii
  4 | import os
  5 | import os.path as osp
  6 | 
  7 | import imageio
  8 | import torch
  9 | import torchvision
 10 | 
 11 | __all__ = ['cache_video', 'cache_image', 'str2bool']
 12 | 
 13 | 
 14 | def rand_name(length=8, suffix=''):
 15 |     name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
 16 |     if suffix:
 17 |         if not suffix.startswith('.'):
 18 |             suffix = '.' + suffix
 19 |         name += suffix
 20 |     return name
 21 | 
 22 | 
 23 | def cache_video(tensor,
 24 |                 save_file=None,
 25 |                 fps=30,
 26 |                 suffix='.mp4',
 27 |                 nrow=8,
 28 |                 normalize=True,
 29 |                 value_range=(-1, 1),
 30 |                 retry=5):
 31 |     # cache file
 32 |     cache_file = osp.join('/tmp', rand_name(
 33 |         suffix=suffix)) if save_file is None else save_file
 34 | 
 35 |     # save to cache
 36 |     error = None
 37 |     for _ in range(retry):
 38 |         try:
 39 |             # preprocess
 40 |             tensor = tensor.clamp(min(value_range), max(value_range))
 41 |             tensor = torch.stack([
 42 |                 torchvision.utils.make_grid(
 43 |                     u, nrow=nrow, normalize=normalize, value_range=value_range)
 44 |                 for u in tensor.unbind(2)
 45 |             ],
 46 |                                  dim=1).permute(1, 2, 3, 0)
 47 |             tensor = (tensor * 255).type(torch.uint8).cpu()
 48 | 
 49 |             # write video
 50 |             writer = imageio.get_writer(
 51 |                 cache_file, fps=fps, codec='libx264', quality=8)
 52 |             for frame in tensor.numpy():
 53 |                 writer.append_data(frame)
 54 |             writer.close()
 55 |             return cache_file
 56 |         except Exception as e:
 57 |             error = e
 58 |             continue
 59 |     else:
 60 |         print(f'cache_video failed, error: {error}', flush=True)
 61 |         return None
 62 | 
 63 | 
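A quick way to sanity-check `cache_video` is to feed it a random batch in the expected `(B, C, T, H, W)` layout and value range; the output file name below is arbitrary:

```python
# Dummy data only; cache_video clamps to value_range (default [-1, 1]) and
# expects a tensor shaped (batch, channels, frames, height, width).
import torch

from wan.utils.utils import cache_video

fake = torch.rand(1, 3, 16, 128, 128) * 2 - 1
out = cache_video(fake, save_file="demo.mp4", fps=16)
print(out)  # "demo.mp4" on success, None if every retry failed
```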
 64 | def cache_image(tensor,
 65 |                 save_file,
 66 |                 nrow=8,
 67 |                 normalize=True,
 68 |                 value_range=(-1, 1),
 69 |                 retry=5):
 70 |     # cache file
 71 |     suffix = osp.splitext(save_file)[1]
 72 |     if suffix.lower() not in [
 73 |             '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp'
 74 |     ]:
 75 |         suffix = '.png'
 76 | 
 77 |     # save to cache
 78 |     error = None
 79 |     for _ in range(retry):
 80 |         try:
 81 |             tensor = tensor.clamp(min(value_range), max(value_range))
 82 |             torchvision.utils.save_image(
 83 |                 tensor,
 84 |                 save_file,
 85 |                 nrow=nrow,
 86 |                 normalize=normalize,
 87 |                 value_range=value_range)
 88 |             return save_file
 89 |         except Exception as e:
 90 |             error = e
 91 |             continue
 92 | 
 93 | 
 94 | def str2bool(v):
 95 |     """
 96 |     Convert a string to a boolean.
 97 | 
 98 |     Supported true values: 'yes', 'true', 't', 'y', '1'
 99 |     Supported false values: 'no', 'false', 'f', 'n', '0'
100 | 
101 |     Args:
102 |         v (str): String to convert.
103 | 
104 |     Returns:
105 |         bool: Converted boolean value.
106 | 
107 |     Raises:
108 |         argparse.ArgumentTypeError: If the value cannot be converted to boolean.
109 |     """
110 |     if isinstance(v, bool):
111 |         return v
112 |     v_lower = v.lower()
113 |     if v_lower in ('yes', 'true', 't', 'y', '1'):
114 |         return True
115 |     elif v_lower in ('no', 'false', 'f', 'n', '0'):
116 |         return False
117 |     else:
118 |         raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
119 | 
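`str2bool` is meant to be plugged into `argparse` as a `type` converter; the flag name below is only an example:

```python
# Example flag; any boolean-style CLI option is wired the same way.
import argparse

from wan.utils.utils import str2bool

parser = argparse.ArgumentParser()
parser.add_argument("--offload_model", type=str2bool, default=False)

args = parser.parse_args(["--offload_model", "yes"])
assert args.offload_model is True
```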


--------------------------------------------------------------------------------
/wan/utils/vace_processor.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
  2 | import numpy as np
  3 | import torch
  4 | import torch.nn.functional as F
  5 | import torchvision.transforms.functional as TF
  6 | from PIL import Image
  7 | 
  8 | 
  9 | class VaceImageProcessor(object):
 10 | 
 11 |     def __init__(self, downsample=None, seq_len=None):
 12 |         self.downsample = downsample
 13 |         self.seq_len = seq_len
 14 | 
 15 |     def _pillow_convert(self, image, cvt_type='RGB'):
 16 |         if image.mode != cvt_type:
 17 |             if image.mode == 'P':
 18 |                 image = image.convert(f'{cvt_type}A')
 19 |             if image.mode == f'{cvt_type}A':
 20 |                 bg = Image.new(
 21 |                     cvt_type,
 22 |                     size=(image.width, image.height),
 23 |                     color=(255, 255, 255))
 24 |                 bg.paste(image, (0, 0), mask=image)
 25 |                 image = bg
 26 |             else:
 27 |                 image = image.convert(cvt_type)
 28 |         return image
 29 | 
 30 |     def _load_image(self, img_path):
 31 |         if img_path is None or img_path == '':
 32 |             return None
 33 |         img = Image.open(img_path)
 34 |         img = self._pillow_convert(img)
 35 |         return img
 36 | 
 37 |     def _resize_crop(self, img, oh, ow, normalize=True):
 38 |         """
 39 |         Resize, center crop, convert to tensor, and normalize.
 40 |         """
 41 |         # resize and crop
 42 |         iw, ih = img.size
 43 |         if iw != ow or ih != oh:
 44 |             # resize
 45 |             scale = max(ow / iw, oh / ih)
 46 |             img = img.resize((round(scale * iw), round(scale * ih)),
 47 |                              resample=Image.Resampling.LANCZOS)
 48 |             assert img.width >= ow and img.height >= oh
 49 | 
 50 |             # center crop
 51 |             x1 = (img.width - ow) // 2
 52 |             y1 = (img.height - oh) // 2
 53 |             img = img.crop((x1, y1, x1 + ow, y1 + oh))
 54 | 
 55 |         # normalize
 56 |         if normalize:
 57 |             img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
 58 |         return img
 59 | 
 60 |     def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
 61 |         return self._resize_crop(img, oh, ow, normalize)
 62 | 
 63 |     def load_image(self, data_key, **kwargs):
 64 |         return self.load_image_batch(data_key, **kwargs)
 65 | 
 66 |     def load_image_pair(self, data_key, data_key2, **kwargs):
 67 |         return self.load_image_batch(data_key, data_key2, **kwargs)
 68 | 
 69 |     def load_image_batch(self,
 70 |                          *data_key_batch,
 71 |                          normalize=True,
 72 |                          seq_len=None,
 73 |                          **kwargs):
 74 |         seq_len = self.seq_len if seq_len is None else seq_len
 75 |         imgs = []
 76 |         for data_key in data_key_batch:
 77 |             img = self._load_image(data_key)
 78 |             imgs.append(img)
 79 |         w, h = imgs[0].size
 80 |         dh, dw = self.downsample[1:]
 81 | 
 82 |         # compute output size
 83 |         scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
 84 |         oh = int(h * scale) // dh * dh
 85 |         ow = int(w * scale) // dw * dw
 86 |         assert (oh // dh) * (ow // dw) <= seq_len
 87 |         imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
 88 |         return *imgs, (oh, ow)
 89 | 
 90 | 
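A rough usage sketch of `VaceImageProcessor`; the `downsample` and `seq_len` values are illustrative rather than repository defaults, and the image path is a placeholder:

```python
# Illustrative values: downsample = (temporal, dh, dw); seq_len caps
# (oh // dh) * (ow // dw) for the resized image.
from wan.utils.vace_processor import VaceImageProcessor

proc = VaceImageProcessor(downsample=(4, 16, 16), seq_len=1024)

img, (oh, ow) = proc.load_image("reference.png")
# img: tensor of shape (C, 1, oh, ow), normalized to [-1, 1]
# (oh // 16) * (ow // 16) <= 1024 by construction
```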
 91 | class VaceVideoProcessor(object):
 92 | 
 93 |     def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
 94 |                  zero_start, seq_len, keep_last, **kwargs):
 95 |         self.downsample = downsample
 96 |         self.min_area = min_area
 97 |         self.max_area = max_area
 98 |         self.min_fps = min_fps
 99 |         self.max_fps = max_fps
100 |         self.zero_start = zero_start
101 |         self.keep_last = keep_last
102 |         self.seq_len = seq_len
103 |         assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
104 | 
105 |     def set_area(self, area):
106 |         self.min_area = area
107 |         self.max_area = area
108 | 
109 |     def set_seq_len(self, seq_len):
110 |         self.seq_len = seq_len
111 | 
112 |     @staticmethod
113 |     def resize_crop(video: torch.Tensor, oh: int, ow: int):
114 |         """
115 |         Resize, center crop and normalize a decord-loaded video (torch.Tensor).
116 | 
117 |         Parameters:
118 |           video - video to process (torch.Tensor): tensor from `reader.get_batch(frame_ids)`, with shape (T, H, W, C)
119 |           oh - target height (int)
120 |           ow - target width (int)
121 | 
122 |         Returns:
123 |             The processed video (torch.Tensor): normalized to [-1, 1], with shape (C, T, H, W)
124 | 
125 |         Raises:
126 |         """
127 |         # permute ([t, h, w, c] -> [t, c, h, w])
128 |         video = video.permute(0, 3, 1, 2)
129 | 
130 |         # resize and crop
131 |         ih, iw = video.shape[2:]
132 |         if ih != oh or iw != ow:
133 |             # resize
134 |             scale = max(ow / iw, oh / ih)
135 |             video = F.interpolate(
136 |                 video,
137 |                 size=(round(scale * ih), round(scale * iw)),
138 |                 mode='bicubic',
139 |                 antialias=True)
140 |             assert video.size(3) >= ow and video.size(2) >= oh
141 | 
142 |             # center crop
143 |             x1 = (video.size(3) - ow) // 2
144 |             y1 = (video.size(2) - oh) // 2
145 |             video = video[:, :, y1:y1 + oh, x1:x1 + ow]
146 | 
147 |         # permute ([t, c, h, w] -> [c, t, h, w]) and normalize
148 |         video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
149 |         return video
150 | 
151 |     def _video_preprocess(self, video, oh, ow):
152 |         return self.resize_crop(video, oh, ow)
153 | 
154 |     def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
155 |                                   rng):
156 |         target_fps = min(fps, self.max_fps)
157 |         duration = frame_timestamps[-1].mean()
158 |         x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
159 |         h, w = y2 - y1, x2 - x1
160 |         ratio = h / w
161 |         df, dh, dw = self.downsample
162 | 
163 |         area_z = min(self.seq_len, self.max_area / (dh * dw),
164 |                      (h // dh) * (w // dw))
165 |         of = min((int(duration * target_fps) - 1) // df + 1,
166 |                  int(self.seq_len / area_z))
167 | 
168 |         # deduce target shape of the [latent video]
169 |         target_area_z = min(area_z, int(self.seq_len / of))
170 |         oh = round(np.sqrt(target_area_z * ratio))
171 |         ow = int(target_area_z / oh)
172 |         of = (of - 1) * df + 1
173 |         oh *= dh
174 |         ow *= dw
175 | 
176 |         # sample frame ids
177 |         target_duration = of / target_fps
178 |         begin = 0. if self.zero_start else rng.uniform(
179 |             0, duration - target_duration)
180 |         timestamps = np.linspace(begin, begin + target_duration, of)
181 |         frame_ids = np.argmax(
182 |             np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
183 |                            timestamps[:, None] < frame_timestamps[None, :, 1]),
184 |             axis=1).tolist()
185 |         return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
186 | 
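To make the shape deduction above concrete, here is the same arithmetic with illustrative inputs (a 1280x720, 5 s, 30 fps clip capped at `max_fps=24`, `downsample=(4, 8, 8)`, `seq_len=32760`, `max_area=720*1280`); none of these values are claimed to be repository defaults:

```python
# Mirrors _get_frameid_bbox_default step by step with illustrative inputs.
import numpy as np

seq_len, max_area = 32760, 720 * 1280
df, dh, dw = 4, 8, 8
h, w = 720, 1280
duration, target_fps = 5.0, 24  # clip length (s), min(fps, max_fps)

area_z = min(seq_len, max_area / (dh * dw), (h // dh) * (w // dw))           # 14400
of = min((int(duration * target_fps) - 1) // df + 1, int(seq_len / area_z))  # 2
target_area_z = min(area_z, int(seq_len / of))                               # 14400
oh = round(np.sqrt(target_area_z * (h / w)))                                 # 90
ow = int(target_area_z / oh)                                                 # 160
of = (of - 1) * df + 1                                                       # 5 pixel-space frames
oh, ow = oh * dh, ow * dw                                                    # 720 x 1280 pixels
```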
187 |     def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
188 |                                       crop_box, rng):
189 |         duration = frame_timestamps[-1].mean()
190 |         x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
191 |         h, w = y2 - y1, x2 - x1
192 |         ratio = h / w
193 |         df, dh, dw = self.downsample
194 | 
195 |         area_z = min(self.seq_len, self.max_area / (dh * dw),
196 |                      (h // dh) * (w // dw))
197 |         of = min((len(frame_timestamps) - 1) // df + 1,
198 |                  int(self.seq_len / area_z))
199 | 
200 |         # deduce target shape of the [latent video]
201 |         target_area_z = min(area_z, int(self.seq_len / of))
202 |         oh = round(np.sqrt(target_area_z * ratio))
203 |         ow = int(target_area_z / oh)
204 |         of = (of - 1) * df + 1
205 |         oh *= dh
206 |         ow *= dw
207 | 
208 |         # sample frame ids
209 |         target_duration = duration
210 |         target_fps = of / target_duration
211 |         timestamps = np.linspace(0., target_duration, of)
212 |         frame_ids = np.argmax(
213 |             np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
214 |                            timestamps[:, None] <= frame_timestamps[None, :, 1]),
215 |             axis=1).tolist()
216 |         # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
217 |         return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
218 | 
219 |     def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
220 |         if self.keep_last:
221 |             return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
222 |                                                       w, crop_box, rng)
223 |         else:
224 |             return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
225 |                                                   crop_box, rng)
226 | 
227 |     def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
228 |         return self.load_video_batch(
229 |             data_key, crop_box=crop_box, seed=seed, **kwargs)
230 | 
231 |     def load_video_pair(self,
232 |                         data_key,
233 |                         data_key2,
234 |                         crop_box=None,
235 |                         seed=2024,
236 |                         **kwargs):
237 |         return self.load_video_batch(
238 |             data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
239 | 
240 |     def load_video_batch(self,
241 |                          *data_key_batch,
242 |                          crop_box=None,
243 |                          seed=2024,
244 |                          **kwargs):
245 |         rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
246 |         # read video
247 |         import decord
248 |         decord.bridge.set_bridge('torch')
249 |         readers = []
250 |         for data_k in data_key_batch:
251 |             reader = decord.VideoReader(data_k)
252 |             readers.append(reader)
253 | 
254 |         fps = readers[0].get_avg_fps()
255 |         length = min([len(r) for r in readers])
256 |         frame_timestamps = [
257 |             readers[0].get_frame_timestamp(i) for i in range(length)
258 |         ]
259 |         frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
260 |         h, w = readers[0].next().shape[:2]
261 |         frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
262 |             fps, frame_timestamps, h, w, crop_box, rng)
263 | 
264 |         # preprocess video
265 |         videos = [
266 |             reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
267 |             for reader in readers
268 |         ]
269 |         videos = [self._video_preprocess(video, oh, ow) for video in videos]
270 |         return *videos, frame_ids, (oh, ow), fps
271 |         # return videos if len(videos) > 1 else videos[0]
272 | 
273 | 
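Constructing the processor and loading a single clip looks roughly like this; every constructor argument below is an illustrative value rather than a documented default, and the video path is a placeholder:

```python
# Illustrative configuration; see the constructor above for argument meanings.
from wan.utils.vace_processor import VaceVideoProcessor

vproc = VaceVideoProcessor(
    downsample=(4, 8, 8),
    min_area=480 * 832,
    max_area=480 * 832,
    min_fps=16,
    max_fps=24,
    zero_start=True,
    seq_len=32760,
    keep_last=True)

video, frame_ids, (oh, ow), fps = vproc.load_video("input.mp4")
# video: tensor of shape (C, T, oh, ow) in [-1, 1]
# frame_ids: indices of the sampled source frames
```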
274 | def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
275 |                    device):
276 |     for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
277 |         if sub_src_video is None and sub_src_mask is None:
278 |             src_video[i] = torch.zeros(
279 |                 (3, num_frames, image_size[0], image_size[1]), device=device)
280 |             src_mask[i] = torch.ones(
281 |                 (1, num_frames, image_size[0], image_size[1]), device=device)
282 |     for i, ref_images in enumerate(src_ref_images):
283 |         if ref_images is not None:
284 |             for j, ref_img in enumerate(ref_images):
285 |                 if ref_img is not None and ref_img.shape[-2:] != image_size:
286 |                     canvas_height, canvas_width = image_size
287 |                     ref_height, ref_width = ref_img.shape[-2:]
288 |                     white_canvas = torch.ones(
289 |                         (3, 1, canvas_height, canvas_width),
290 |                         device=device)  # [-1, 1]
291 |                     scale = min(canvas_height / ref_height,
292 |                                 canvas_width / ref_width)
293 |                     new_height = int(ref_height * scale)
294 |                     new_width = int(ref_width * scale)
295 |                     resized_image = F.interpolate(
296 |                         ref_img.squeeze(1).unsqueeze(0),
297 |                         size=(new_height, new_width),
298 |                         mode='bilinear',
299 |                         align_corners=False).squeeze(0).unsqueeze(1)
300 |                     top = (canvas_height - new_height) // 2
301 |                     left = (canvas_width - new_width) // 2
302 |                     white_canvas[:, :, top:top + new_height,
303 |                                  left:left + new_width] = resized_image
304 |                     src_ref_images[i][j] = white_canvas
305 |     return src_video, src_mask, src_ref_images
306 | 
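A minimal sketch of `prepare_source`: a `None` video/mask slot is replaced by zero/one tensors, and an off-size reference image is letterboxed onto a white canvas. The shapes, frame count, and target size below are illustrative only:

```python
# Dummy tensors only; prepare_source expects reference images shaped
# (3, 1, H, W) in [-1, 1] and replaces None video/mask entries with zeros/ones.
import torch

from wan.utils.vace_processor import prepare_source

src_video = [None]
src_mask = [None]
src_ref_images = [[torch.rand(3, 1, 240, 320) * 2 - 1]]

src_video, src_mask, src_ref_images = prepare_source(
    src_video, src_mask, src_ref_images,
    num_frames=81, image_size=(480, 832), device="cpu")

print(src_video[0].shape)          # (3, 81, 480, 832), all zeros
print(src_ref_images[0][0].shape)  # (3, 1, 480, 832), white-padded
```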


--------------------------------------------------------------------------------