├── .gitignore ├── .style.yapf ├── INSTALL.md ├── LICENSE.txt ├── Makefile ├── README.md ├── assets ├── comp_effic.png ├── data_for_diff_stage.jpg ├── i2v_res.png ├── logo.png ├── t2v_res.jpg ├── vben_vs_sota.png ├── video_dit_arch.jpg └── video_vae_res.jpg ├── examples ├── flf2v_input_first_frame.png ├── flf2v_input_last_frame.png ├── girl.png ├── i2v_input.JPG └── snake.png ├── generate.py ├── gradio ├── fl2v_14B_singleGPU.py ├── i2v_14B_singleGPU.py ├── t2i_14B_singleGPU.py ├── t2v_1.3B_singleGPU.py ├── t2v_14B_singleGPU.py └── vace.py ├── pyproject.toml ├── requirements.txt ├── tests ├── README.md └── test.sh └── wan ├── __init__.py ├── configs ├── __init__.py ├── shared_config.py ├── wan_i2v_14B.py ├── wan_t2v_14B.py └── wan_t2v_1_3B.py ├── distributed ├── __init__.py ├── fsdp.py └── xdit_context_parallel.py ├── first_last_frame2video.py ├── image2video.py ├── modules ├── __init__.py ├── attention.py ├── clip.py ├── model.py ├── t5.py ├── tokenizers.py ├── vace_model.py ├── vae.py └── xlm_roberta.py ├── text2video.py ├── utils ├── __init__.py ├── fm_solvers.py ├── fm_solvers_unipc.py ├── prompt_extend.py ├── qwen_vl_utils.py ├── utils.py └── vace_processor.py └── vace.py /.gitignore: -------------------------------------------------------------------------------- 1 | .* 2 | *.py[cod] 3 | # *.jpg 4 | *.jpeg 5 | # *.png 6 | *.gif 7 | *.bmp 8 | *.mp4 9 | *.mov 10 | *.mkv 11 | *.log 12 | *.zip 13 | *.pt 14 | *.pth 15 | *.ckpt 16 | *.safetensors 17 | *.json 18 | # *.txt 19 | *.backup 20 | *.pkl 21 | *.html 22 | *.pdf 23 | *.whl 24 | cache 25 | __pycache__/ 26 | storage/ 27 | samples/ 28 | !.gitignore 29 | !requirements.txt 30 | .DS_Store 31 | *DS_Store 32 | google/ 33 | Wan2.1-T2V-14B/ 34 | Wan2.1-T2V-1.3B/ 35 | Wan2.1-I2V-14B-480P/ 36 | Wan2.1-I2V-14B-720P/ 37 | poetry.lock -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | # Align closing bracket with visual indentation. 3 | align_closing_bracket_with_visual_indent=False 4 | 5 | # Allow dictionary keys to exist on multiple lines. For example: 6 | # 7 | # x = { 8 | # ('this is the first element of a tuple', 9 | # 'this is the second element of a tuple'): 10 | # value, 11 | # } 12 | allow_multiline_dictionary_keys=False 13 | 14 | # Allow lambdas to be formatted on more than one line. 15 | allow_multiline_lambdas=False 16 | 17 | # Allow splitting before a default / named assignment in an argument list. 18 | allow_split_before_default_or_named_assigns=False 19 | 20 | # Allow splits before the dictionary value. 21 | allow_split_before_dict_value=True 22 | 23 | # Let spacing indicate operator precedence. For example: 24 | # 25 | # a = 1 * 2 + 3 / 4 26 | # b = 1 / 2 - 3 * 4 27 | # c = (1 + 2) * (3 - 4) 28 | # d = (1 - 2) / (3 + 4) 29 | # e = 1 * 2 - 3 30 | # f = 1 + 2 + 3 + 4 31 | # 32 | # will be formatted as follows to indicate precedence: 33 | # 34 | # a = 1*2 + 3/4 35 | # b = 1/2 - 3*4 36 | # c = (1+2) * (3-4) 37 | # d = (1-2) / (3+4) 38 | # e = 1*2 - 3 39 | # f = 1 + 2 + 3 + 4 40 | # 41 | arithmetic_precedence_indication=False 42 | 43 | # Number of blank lines surrounding top-level function and class 44 | # definitions. 45 | blank_lines_around_top_level_definition=2 46 | 47 | # Insert a blank line before a class-level docstring. 48 | blank_line_before_class_docstring=False 49 | 50 | # Insert a blank line before a module docstring. 
51 | blank_line_before_module_docstring=False 52 | 53 | # Insert a blank line before a 'def' or 'class' immediately nested 54 | # within another 'def' or 'class'. For example: 55 | # 56 | # class Foo: 57 | # # <------ this blank line 58 | # def method(): 59 | # ... 60 | blank_line_before_nested_class_or_def=True 61 | 62 | # Do not split consecutive brackets. Only relevant when 63 | # dedent_closing_brackets is set. For example: 64 | # 65 | # call_func_that_takes_a_dict( 66 | # { 67 | # 'key1': 'value1', 68 | # 'key2': 'value2', 69 | # } 70 | # ) 71 | # 72 | # would reformat to: 73 | # 74 | # call_func_that_takes_a_dict({ 75 | # 'key1': 'value1', 76 | # 'key2': 'value2', 77 | # }) 78 | coalesce_brackets=False 79 | 80 | # The column limit. 81 | column_limit=80 82 | 83 | # The style for continuation alignment. Possible values are: 84 | # 85 | # - SPACE: Use spaces for continuation alignment. This is default behavior. 86 | # - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns 87 | # (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or 88 | # CONTINUATION_INDENT_WIDTH spaces) for continuation alignment. 89 | # - VALIGN-RIGHT: Vertically align continuation lines to multiple of 90 | # INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if 91 | # cannot vertically align continuation lines with indent characters. 92 | continuation_align_style=SPACE 93 | 94 | # Indent width used for line continuations. 95 | continuation_indent_width=4 96 | 97 | # Put closing brackets on a separate line, dedented, if the bracketed 98 | # expression can't fit in a single line. Applies to all kinds of brackets, 99 | # including function definitions and calls. For example: 100 | # 101 | # config = { 102 | # 'key1': 'value1', 103 | # 'key2': 'value2', 104 | # } # <--- this bracket is dedented and on a separate line 105 | # 106 | # time_series = self.remote_client.query_entity_counters( 107 | # entity='dev3246.region1', 108 | # key='dns.query_latency_tcp', 109 | # transform=Transformation.AVERAGE(window=timedelta(seconds=60)), 110 | # start_ts=now()-timedelta(days=3), 111 | # end_ts=now(), 112 | # ) # <--- this bracket is dedented and on a separate line 113 | dedent_closing_brackets=False 114 | 115 | # Disable the heuristic which places each list element on a separate line 116 | # if the list is comma-terminated. 117 | disable_ending_comma_heuristic=False 118 | 119 | # Place each dictionary entry onto its own line. 120 | each_dict_entry_on_separate_line=True 121 | 122 | # Require multiline dictionary even if it would normally fit on one line. 123 | # For example: 124 | # 125 | # config = { 126 | # 'key1': 'value1' 127 | # } 128 | force_multiline_dict=False 129 | 130 | # The regex for an i18n comment. The presence of this comment stops 131 | # reformatting of that line, because the comments are required to be 132 | # next to the string they translate. 133 | i18n_comment=#\..* 134 | 135 | # The i18n function call names. The presence of this function stops 136 | # reformattting on that line, because the string it has cannot be moved 137 | # away from the i18n comment. 138 | i18n_function_call=N_, _ 139 | 140 | # Indent blank lines. 141 | indent_blank_lines=False 142 | 143 | # Put closing brackets on a separate line, indented, if the bracketed 144 | # expression can't fit in a single line. Applies to all kinds of brackets, 145 | # including function definitions and calls. 
For example: 146 | # 147 | # config = { 148 | # 'key1': 'value1', 149 | # 'key2': 'value2', 150 | # } # <--- this bracket is indented and on a separate line 151 | # 152 | # time_series = self.remote_client.query_entity_counters( 153 | # entity='dev3246.region1', 154 | # key='dns.query_latency_tcp', 155 | # transform=Transformation.AVERAGE(window=timedelta(seconds=60)), 156 | # start_ts=now()-timedelta(days=3), 157 | # end_ts=now(), 158 | # ) # <--- this bracket is indented and on a separate line 159 | indent_closing_brackets=False 160 | 161 | # Indent the dictionary value if it cannot fit on the same line as the 162 | # dictionary key. For example: 163 | # 164 | # config = { 165 | # 'key1': 166 | # 'value1', 167 | # 'key2': value1 + 168 | # value2, 169 | # } 170 | indent_dictionary_value=True 171 | 172 | # The number of columns to use for indentation. 173 | indent_width=4 174 | 175 | # Join short lines into one line. E.g., single line 'if' statements. 176 | join_multiple_lines=False 177 | 178 | # Do not include spaces around selected binary operators. For example: 179 | # 180 | # 1 + 2 * 3 - 4 / 5 181 | # 182 | # will be formatted as follows when configured with "*,/": 183 | # 184 | # 1 + 2*3 - 4/5 185 | no_spaces_around_selected_binary_operators= 186 | 187 | # Use spaces around default or named assigns. 188 | spaces_around_default_or_named_assign=False 189 | 190 | # Adds a space after the opening '{' and before the ending '}' dict delimiters. 191 | # 192 | # {1: 2} 193 | # 194 | # will be formatted as: 195 | # 196 | # { 1: 2 } 197 | spaces_around_dict_delimiters=False 198 | 199 | # Adds a space after the opening '[' and before the ending ']' list delimiters. 200 | # 201 | # [1, 2] 202 | # 203 | # will be formatted as: 204 | # 205 | # [ 1, 2 ] 206 | spaces_around_list_delimiters=False 207 | 208 | # Use spaces around the power operator. 209 | spaces_around_power_operator=False 210 | 211 | # Use spaces around the subscript / slice operator. For example: 212 | # 213 | # my_list[1 : 10 : 2] 214 | spaces_around_subscript_colon=False 215 | 216 | # Adds a space after the opening '(' and before the ending ')' tuple delimiters. 217 | # 218 | # (1, 2, 3) 219 | # 220 | # will be formatted as: 221 | # 222 | # ( 1, 2, 3 ) 223 | spaces_around_tuple_delimiters=False 224 | 225 | # The number of spaces required before a trailing comment. 226 | # This can be a single value (representing the number of spaces 227 | # before each trailing comment) or list of values (representing 228 | # alignment column values; trailing comments within a block will 229 | # be aligned to the first column value that is greater than the maximum 230 | # line length within the block). 
For example: 231 | # 232 | # With spaces_before_comment=5: 233 | # 234 | # 1 + 1 # Adding values 235 | # 236 | # will be formatted as: 237 | # 238 | # 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment 239 | # 240 | # With spaces_before_comment=15, 20: 241 | # 242 | # 1 + 1 # Adding values 243 | # two + two # More adding 244 | # 245 | # longer_statement # This is a longer statement 246 | # short # This is a shorter statement 247 | # 248 | # a_very_long_statement_that_extends_beyond_the_final_column # Comment 249 | # short # This is a shorter statement 250 | # 251 | # will be formatted as: 252 | # 253 | # 1 + 1 # Adding values <-- end of line comments in block aligned to col 15 254 | # two + two # More adding 255 | # 256 | # longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20 257 | # short # This is a shorter statement 258 | # 259 | # a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length 260 | # short # This is a shorter statement 261 | # 262 | spaces_before_comment=2 263 | 264 | # Insert a space between the ending comma and closing bracket of a list, 265 | # etc. 266 | space_between_ending_comma_and_closing_bracket=False 267 | 268 | # Use spaces inside brackets, braces, and parentheses. For example: 269 | # 270 | # method_call( 1 ) 271 | # my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ] 272 | # my_set = { 1, 2, 3 } 273 | space_inside_brackets=False 274 | 275 | # Split before arguments 276 | split_all_comma_separated_values=False 277 | 278 | # Split before arguments, but do not split all subexpressions recursively 279 | # (unless needed). 280 | split_all_top_level_comma_separated_values=False 281 | 282 | # Split before arguments if the argument list is terminated by a 283 | # comma. 284 | split_arguments_when_comma_terminated=False 285 | 286 | # Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@' 287 | # rather than after. 288 | split_before_arithmetic_operator=False 289 | 290 | # Set to True to prefer splitting before '&', '|' or '^' rather than 291 | # after. 292 | split_before_bitwise_operator=False 293 | 294 | # Split before the closing bracket if a list or dict literal doesn't fit on 295 | # a single line. 296 | split_before_closing_bracket=True 297 | 298 | # Split before a dictionary or set generator (comp_for). For example, note 299 | # the split before the 'for': 300 | # 301 | # foo = { 302 | # variable: 'Hello world, have a nice day!' 303 | # for variable in bar if variable != 42 304 | # } 305 | split_before_dict_set_generator=False 306 | 307 | # Split before the '.' if we need to split a longer expression: 308 | # 309 | # foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d)) 310 | # 311 | # would reformat to something like: 312 | # 313 | # foo = ('This is a really long string: {}, {}, {}, {}' 314 | # .format(a, b, c, d)) 315 | split_before_dot=False 316 | 317 | # Split after the opening paren which surrounds an expression if it doesn't 318 | # fit on a single line. 319 | split_before_expression_after_opening_paren=True 320 | 321 | # If an argument / parameter list is going to be split, then split before 322 | # the first argument. 323 | split_before_first_argument=False 324 | 325 | # Set to True to prefer splitting before 'and' or 'or' rather than 326 | # after. 327 | split_before_logical_operator=False 328 | 329 | # Split named assignments onto individual lines. 
330 | split_before_named_assigns=True 331 | 332 | # Set to True to split list comprehensions and generators that have 333 | # non-trivial expressions and multiple clauses before each of these 334 | # clauses. For example: 335 | # 336 | # result = [ 337 | # a_long_var + 100 for a_long_var in xrange(1000) 338 | # if a_long_var % 10] 339 | # 340 | # would reformat to something like: 341 | # 342 | # result = [ 343 | # a_long_var + 100 344 | # for a_long_var in xrange(1000) 345 | # if a_long_var % 10] 346 | split_complex_comprehension=True 347 | 348 | # The penalty for splitting right after the opening bracket. 349 | split_penalty_after_opening_bracket=300 350 | 351 | # The penalty for splitting the line after a unary operator. 352 | split_penalty_after_unary_operator=10000 353 | 354 | # The penalty of splitting the line around the '+', '-', '*', '/', '//', 355 | # ``%``, and '@' operators. 356 | split_penalty_arithmetic_operator=300 357 | 358 | # The penalty for splitting right before an if expression. 359 | split_penalty_before_if_expr=0 360 | 361 | # The penalty of splitting the line around the '&', '|', and '^' 362 | # operators. 363 | split_penalty_bitwise_operator=300 364 | 365 | # The penalty for splitting a list comprehension or generator 366 | # expression. 367 | split_penalty_comprehension=2100 368 | 369 | # The penalty for characters over the column limit. 370 | split_penalty_excess_character=7000 371 | 372 | # The penalty incurred by adding a line split to the unwrapped line. The 373 | # more line splits added the higher the penalty. 374 | split_penalty_for_added_line_split=30 375 | 376 | # The penalty of splitting a list of "import as" names. For example: 377 | # 378 | # from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, 379 | # long_argument_2, 380 | # long_argument_3) 381 | # 382 | # would reformat to something like: 383 | # 384 | # from a_very_long_or_indented_module_name_yada_yad import ( 385 | # long_argument_1, long_argument_2, long_argument_3) 386 | split_penalty_import_names=0 387 | 388 | # The penalty of splitting the line around the 'and' and 'or' 389 | # operators. 390 | split_penalty_logical_operator=300 391 | 392 | # Use the Tab character for indentation. 393 | use_tabs=False -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation Guide 2 | 3 | ## Install with pip 4 | 5 | ```bash 6 | pip install . 7 | pip install .[dev] # Also installs the dev tools 8 | ``` 9 | 10 | ## Install with Poetry 11 | 12 | Ensure you have [Poetry](https://python-poetry.org/docs/#installation) installed on your system. 13 | 14 | To install all dependencies: 15 | 16 | ```bash 17 | poetry install 18 | ``` 19 | 20 | ### Handling `flash-attn` Installation Issues 21 | 22 | If `flash-attn` fails due to **PEP 517 build issues**, you can try one of the following fixes.
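Before applying either fix, it may help to confirm that a CUDA-enabled build of PyTorch is already importable inside the Poetry environment, since `flash-attn` compiles against the installed `torch`. A minimal sanity check (assuming the Poetry environment has already been created):

```bash
# Print the installed torch version and whether CUDA is available
# before attempting to build flash-attn.
poetry run python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```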
23 | 24 | #### No-Build-Isolation Installation (Recommended) 25 | ```bash 26 | poetry run pip install --upgrade pip setuptools wheel 27 | poetry run pip install flash-attn --no-build-isolation 28 | poetry install 29 | ``` 30 | 31 | #### Install from Git (Alternative) 32 | ```bash 33 | poetry run pip install git+https://github.com/Dao-AILab/flash-attention.git 34 | ``` 35 | 36 | --- 37 | 38 | ### Running the Model 39 | 40 | Once the installation is complete, you can run **Wan2.1** using: 41 | 42 | ```bash 43 | poetry run python generate.py --task t2v-14B --size '1280x720' --ckpt_dir ./Wan2.1-T2V-14B --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 44 | ``` 45 | 46 | #### Test 47 | ```bash 48 | pytest tests/ 49 | ``` 50 | #### Format 51 | ```bash 52 | black . 53 | isort . 54 | ``` 55 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: format 2 | 3 | format: 4 | isort generate.py gradio wan 5 | yapf -i -r *.py generate.py gradio wan 6 | -------------------------------------------------------------------------------- /assets/comp_effic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/comp_effic.png -------------------------------------------------------------------------------- /assets/data_for_diff_stage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/data_for_diff_stage.jpg -------------------------------------------------------------------------------- /assets/i2v_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/i2v_res.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/logo.png -------------------------------------------------------------------------------- /assets/t2v_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/t2v_res.jpg -------------------------------------------------------------------------------- /assets/vben_vs_sota.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/vben_vs_sota.png -------------------------------------------------------------------------------- /assets/video_dit_arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/video_dit_arch.jpg -------------------------------------------------------------------------------- /assets/video_vae_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/assets/video_vae_res.jpg -------------------------------------------------------------------------------- /examples/flf2v_input_first_frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/flf2v_input_first_frame.png -------------------------------------------------------------------------------- /examples/flf2v_input_last_frame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/flf2v_input_last_frame.png -------------------------------------------------------------------------------- /examples/girl.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/girl.png -------------------------------------------------------------------------------- /examples/i2v_input.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/i2v_input.JPG -------------------------------------------------------------------------------- /examples/snake.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/examples/snake.png -------------------------------------------------------------------------------- /gradio/fl2v_14B_singleGPU.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import gc 4 | import os 5 | import os.path as osp 6 | import sys 7 | import warnings 8 | 9 | import gradio as gr 10 | 11 | warnings.filterwarnings('ignore') 12 | 13 | # Model 14 | sys.path.insert( 15 | 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) 16 | import wan 17 | from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS 18 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander 19 | from wan.utils.utils import cache_video 20 | 21 | # Global Var 22 | prompt_expander = None 23 | wan_flf2v_720P = None 24 | 25 | 26 | # Button Func 27 | def load_model(value): 28 | global wan_flf2v_720P 29 | 30 | if value == '------': 31 | print("No model loaded") 32 | return '------' 33 | 34 | if value == '720P': 35 | if args.ckpt_dir_720p is None: 36 | print("Please specify the checkpoint directory for 720P model") 37 | return '------' 38 | if wan_flf2v_720P is not None: 39 | pass 40 | else: 41 | gc.collect() 42 | 43 | print("load 14B-720P flf2v model...", end='', flush=True) 44 | cfg = WAN_CONFIGS['flf2v-14B'] 45 | wan_flf2v_720P = wan.WanFLF2V( 46 | config=cfg, 47 | checkpoint_dir=args.ckpt_dir_720p, 48 | device_id=0, 49 | rank=0, 50 | t5_fsdp=False, 51 | dit_fsdp=False, 52 | use_usp=False, 53 | ) 54 | print("done", flush=True) 55 | return '720P' 56 | return value 57 | 58 | 59 | def prompt_enc(prompt, img_first, img_last, tar_lang): 60 | print('prompt extend...') 61 | if img_first is None or img_last is None: 62 | print('Please upload the first and last frames') 63 | return prompt 64 | global prompt_expander 65 | prompt_output = prompt_expander( 66 | prompt, image=[img_first, img_last], tar_lang=tar_lang.lower()) 67 | if prompt_output.status == False: 68 | return prompt 69 | else: 70 | return prompt_output.prompt 71 | 72 | 73 | def flf2v_generation(flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, 74 | resolution, sd_steps, guide_scale, shift_scale, seed, 75 | n_prompt): 76 | 77 | if resolution == '------': 78 | print( 79 | 'Please specify the resolution ckpt dir or specify the resolution') 80 | return None 81 | 82 | else: 83 | if resolution == '720P': 84 | global wan_flf2v_720P 85 | video = wan_flf2v_720P.generate( 86 | flf2vid_prompt, 87 | flf2vid_image_first, 88 | flf2vid_image_last, 89 | max_area=MAX_AREA_CONFIGS['720*1280'], 90 | shift=shift_scale, 91 | sampling_steps=sd_steps, 92 | guide_scale=guide_scale, 93 | n_prompt=n_prompt, 94 | seed=seed, 95 | offload_model=True) 96 | pass 97 | else: 98 | print('Sorry, currently only 720P is supported.') 99 | return None 100 
| 101 | cache_video( 102 | tensor=video[None], 103 | save_file="example.mp4", 104 | fps=16, 105 | nrow=1, 106 | normalize=True, 107 | value_range=(-1, 1)) 108 | 109 | return "example.mp4" 110 | 111 | 112 | # Interface 113 | def gradio_interface(): 114 | with gr.Blocks() as demo: 115 | gr.Markdown(""" 116 | <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;"> 117 | Wan2.1 (FLF2V-14B) 118 | </div> 119 | <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;"> 120 | Wan: Open and Advanced Large-Scale Video Generative Models. 121 | </div> 122 | """) 123 | 124 | with gr.Row(): 125 | with gr.Column(): 126 | resolution = gr.Dropdown( 127 | label='Resolution', 128 | choices=['------', '720P'], 129 | value='------') 130 | flf2vid_image_first = gr.Image( 131 | type="pil", 132 | label="Upload First Frame", 133 | elem_id="image_upload", 134 | ) 135 | flf2vid_image_last = gr.Image( 136 | type="pil", 137 | label="Upload Last Frame", 138 | elem_id="image_upload", 139 | ) 140 | flf2vid_prompt = gr.Textbox( 141 | label="Prompt", 142 | placeholder="Describe the video you want to generate", 143 | ) 144 | tar_lang = gr.Radio( 145 | choices=["ZH", "EN"], 146 | label="Target language of prompt enhance", 147 | value="ZH") 148 | run_p_button = gr.Button(value="Prompt Enhance") 149 | 150 | with gr.Accordion("Advanced Options", open=True): 151 | with gr.Row(): 152 | sd_steps = gr.Slider( 153 | label="Diffusion steps", 154 | minimum=1, 155 | maximum=1000, 156 | value=50, 157 | step=1) 158 | guide_scale = gr.Slider( 159 | label="Guide scale", 160 | minimum=0, 161 | maximum=20, 162 | value=5.0, 163 | step=1) 164 | with gr.Row(): 165 | shift_scale = gr.Slider( 166 | label="Shift scale", 167 | minimum=0, 168 | maximum=20, 169 | value=5.0, 170 | step=1) 171 | seed = gr.Slider( 172 | label="Seed", 173 | minimum=-1, 174 | maximum=2147483647, 175 | step=1, 176 | value=-1) 177 | n_prompt = gr.Textbox( 178 | label="Negative Prompt", 179 | placeholder="Describe the negative prompt you want to add" 180 | ) 181 | 182 | run_flf2v_button = gr.Button("Generate Video") 183 | 184 | with gr.Column(): 185 | result_gallery = gr.Video( 186 | label='Generated Video', interactive=False, height=600) 187 | 188 | resolution.input( 189 | fn=load_model, inputs=[resolution], outputs=[resolution]) 190 | 191 | run_p_button.click( 192 | fn=prompt_enc, 193 | inputs=[ 194 | flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, 195 | tar_lang 196 | ], 197 | outputs=[flf2vid_prompt]) 198 | 199 | run_flf2v_button.click( 200 | fn=flf2v_generation, 201 | inputs=[ 202 | flf2vid_prompt, flf2vid_image_first, flf2vid_image_last, 203 | resolution, sd_steps, guide_scale, shift_scale, seed, n_prompt 204 | ], 205 | outputs=[result_gallery], 206 | ) 207 | 208 | return demo 209 | 210 | 211 | # Main 212 | def _parse_args(): 213 | parser = argparse.ArgumentParser( 214 | description="Generate a video from a text prompt or image using Gradio") 215 | parser.add_argument( 216 | "--ckpt_dir_720p", 217 | type=str, 218 | default=None, 219 | help="The path to the checkpoint directory.") 220 | parser.add_argument( 221 | "--prompt_extend_method", 222 | type=str, 223 | default="local_qwen", 224 | choices=["dashscope", "local_qwen"], 225 | help="The prompt extend method to use.") 226 | parser.add_argument( 227 | "--prompt_extend_model", 228 | type=str, 229 | default=None, 230 | help="The prompt extend model to use.") 231 | 232 | args = parser.parse_args() 233 | assert args.ckpt_dir_720p is not None, 
"Please specify the checkpoint directory." 234 | 235 | return args 236 | 237 | 238 | if __name__ == '__main__': 239 | args = _parse_args() 240 | 241 | print("Step1: Init prompt_expander...", end='', flush=True) 242 | if args.prompt_extend_method == "dashscope": 243 | prompt_expander = DashScopePromptExpander( 244 | model_name=args.prompt_extend_model, is_vl=True) 245 | elif args.prompt_extend_method == "local_qwen": 246 | prompt_expander = QwenPromptExpander( 247 | model_name=args.prompt_extend_model, is_vl=True, device=0) 248 | else: 249 | raise NotImplementedError( 250 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") 251 | print("done", flush=True) 252 | 253 | demo = gradio_interface() 254 | demo.launch(server_name="0.0.0.0", share=False, server_port=7860) 255 | -------------------------------------------------------------------------------- /gradio/i2v_14B_singleGPU.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import gc 4 | import os 5 | import os.path as osp 6 | import sys 7 | import warnings 8 | 9 | import gradio as gr 10 | 11 | warnings.filterwarnings('ignore') 12 | 13 | # Model 14 | sys.path.insert( 15 | 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) 16 | import wan 17 | from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS 18 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander 19 | from wan.utils.utils import cache_video 20 | 21 | # Global Var 22 | prompt_expander = None 23 | wan_i2v_480P = None 24 | wan_i2v_720P = None 25 | 26 | 27 | # Button Func 28 | def load_model(value): 29 | global wan_i2v_480P, wan_i2v_720P 30 | 31 | if value == '------': 32 | print("No model loaded") 33 | return '------' 34 | 35 | if value == '720P': 36 | if args.ckpt_dir_720p is None: 37 | print("Please specify the checkpoint directory for 720P model") 38 | return '------' 39 | if wan_i2v_720P is not None: 40 | pass 41 | else: 42 | del wan_i2v_480P 43 | gc.collect() 44 | wan_i2v_480P = None 45 | 46 | print("load 14B-720P i2v model...", end='', flush=True) 47 | cfg = WAN_CONFIGS['i2v-14B'] 48 | wan_i2v_720P = wan.WanI2V( 49 | config=cfg, 50 | checkpoint_dir=args.ckpt_dir_720p, 51 | device_id=0, 52 | rank=0, 53 | t5_fsdp=False, 54 | dit_fsdp=False, 55 | use_usp=False, 56 | ) 57 | print("done", flush=True) 58 | return '720P' 59 | 60 | if value == '480P': 61 | if args.ckpt_dir_480p is None: 62 | print("Please specify the checkpoint directory for 480P model") 63 | return '------' 64 | if wan_i2v_480P is not None: 65 | pass 66 | else: 67 | del wan_i2v_720P 68 | gc.collect() 69 | wan_i2v_720P = None 70 | 71 | print("load 14B-480P i2v model...", end='', flush=True) 72 | cfg = WAN_CONFIGS['i2v-14B'] 73 | wan_i2v_480P = wan.WanI2V( 74 | config=cfg, 75 | checkpoint_dir=args.ckpt_dir_480p, 76 | device_id=0, 77 | rank=0, 78 | t5_fsdp=False, 79 | dit_fsdp=False, 80 | use_usp=False, 81 | ) 82 | print("done", flush=True) 83 | return '480P' 84 | return value 85 | 86 | 87 | def prompt_enc(prompt, img, tar_lang): 88 | print('prompt extend...') 89 | if img is None: 90 | print('Please upload an image') 91 | return prompt 92 | global prompt_expander 93 | prompt_output = prompt_expander( 94 | prompt, image=img, tar_lang=tar_lang.lower()) 95 | if prompt_output.status == False: 96 | return prompt 97 | else: 98 | return prompt_output.prompt 99 | 100 | 101 | def i2v_generation(img2vid_prompt, img2vid_image, resolution, 
sd_steps, 102 | guide_scale, shift_scale, seed, n_prompt): 103 | # print(f"{img2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}") 104 | 105 | if resolution == '------': 106 | print( 107 | 'Please specify at least one resolution ckpt dir or specify the resolution' 108 | ) 109 | return None 110 | 111 | else: 112 | if resolution == '720P': 113 | global wan_i2v_720P 114 | video = wan_i2v_720P.generate( 115 | img2vid_prompt, 116 | img2vid_image, 117 | max_area=MAX_AREA_CONFIGS['720*1280'], 118 | shift=shift_scale, 119 | sampling_steps=sd_steps, 120 | guide_scale=guide_scale, 121 | n_prompt=n_prompt, 122 | seed=seed, 123 | offload_model=True) 124 | else: 125 | global wan_i2v_480P 126 | video = wan_i2v_480P.generate( 127 | img2vid_prompt, 128 | img2vid_image, 129 | max_area=MAX_AREA_CONFIGS['480*832'], 130 | shift=shift_scale, 131 | sampling_steps=sd_steps, 132 | guide_scale=guide_scale, 133 | n_prompt=n_prompt, 134 | seed=seed, 135 | offload_model=True) 136 | 137 | cache_video( 138 | tensor=video[None], 139 | save_file="example.mp4", 140 | fps=16, 141 | nrow=1, 142 | normalize=True, 143 | value_range=(-1, 1)) 144 | 145 | return "example.mp4" 146 | 147 | 148 | # Interface 149 | def gradio_interface(): 150 | with gr.Blocks() as demo: 151 | gr.Markdown(""" 152 | <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;"> 153 | Wan2.1 (I2V-14B) 154 | </div> 155 | <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;"> 156 | Wan: Open and Advanced Large-Scale Video Generative Models. 157 | </div> 158 | """) 159 | 160 | with gr.Row(): 161 | with gr.Column(): 162 | resolution = gr.Dropdown( 163 | label='Resolution', 164 | choices=['------', '720P', '480P'], 165 | value='------') 166 | 167 | img2vid_image = gr.Image( 168 | type="pil", 169 | label="Upload Input Image", 170 | elem_id="image_upload", 171 | ) 172 | img2vid_prompt = gr.Textbox( 173 | label="Prompt", 174 | placeholder="Describe the video you want to generate", 175 | ) 176 | tar_lang = gr.Radio( 177 | choices=["ZH", "EN"], 178 | label="Target language of prompt enhance", 179 | value="ZH") 180 | run_p_button = gr.Button(value="Prompt Enhance") 181 | 182 | with gr.Accordion("Advanced Options", open=True): 183 | with gr.Row(): 184 | sd_steps = gr.Slider( 185 | label="Diffusion steps", 186 | minimum=1, 187 | maximum=1000, 188 | value=50, 189 | step=1) 190 | guide_scale = gr.Slider( 191 | label="Guide scale", 192 | minimum=0, 193 | maximum=20, 194 | value=5.0, 195 | step=1) 196 | with gr.Row(): 197 | shift_scale = gr.Slider( 198 | label="Shift scale", 199 | minimum=0, 200 | maximum=10, 201 | value=5.0, 202 | step=1) 203 | seed = gr.Slider( 204 | label="Seed", 205 | minimum=-1, 206 | maximum=2147483647, 207 | step=1, 208 | value=-1) 209 | n_prompt = gr.Textbox( 210 | label="Negative Prompt", 211 | placeholder="Describe the negative prompt you want to add" 212 | ) 213 | 214 | run_i2v_button = gr.Button("Generate Video") 215 | 216 | with gr.Column(): 217 | result_gallery = gr.Video( 218 | label='Generated Video', interactive=False, height=600) 219 | 220 | resolution.input( 221 | fn=load_model, inputs=[resolution], outputs=[resolution]) 222 | 223 | run_p_button.click( 224 | fn=prompt_enc, 225 | inputs=[img2vid_prompt, img2vid_image, tar_lang], 226 | outputs=[img2vid_prompt]) 227 | 228 | run_i2v_button.click( 229 | fn=i2v_generation, 230 | inputs=[ 231 | img2vid_prompt, img2vid_image, resolution, sd_steps, 232 | guide_scale, shift_scale, seed, n_prompt 
233 | ], 234 | outputs=[result_gallery], 235 | ) 236 | 237 | return demo 238 | 239 | 240 | # Main 241 | def _parse_args(): 242 | parser = argparse.ArgumentParser( 243 | description="Generate a video from a text prompt or image using Gradio") 244 | parser.add_argument( 245 | "--ckpt_dir_720p", 246 | type=str, 247 | default=None, 248 | help="The path to the checkpoint directory.") 249 | parser.add_argument( 250 | "--ckpt_dir_480p", 251 | type=str, 252 | default=None, 253 | help="The path to the checkpoint directory.") 254 | parser.add_argument( 255 | "--prompt_extend_method", 256 | type=str, 257 | default="local_qwen", 258 | choices=["dashscope", "local_qwen"], 259 | help="The prompt extend method to use.") 260 | parser.add_argument( 261 | "--prompt_extend_model", 262 | type=str, 263 | default=None, 264 | help="The prompt extend model to use.") 265 | 266 | args = parser.parse_args() 267 | assert args.ckpt_dir_720p is not None or args.ckpt_dir_480p is not None, "Please specify at least one checkpoint directory." 268 | 269 | return args 270 | 271 | 272 | if __name__ == '__main__': 273 | args = _parse_args() 274 | 275 | print("Step1: Init prompt_expander...", end='', flush=True) 276 | if args.prompt_extend_method == "dashscope": 277 | prompt_expander = DashScopePromptExpander( 278 | model_name=args.prompt_extend_model, is_vl=True) 279 | elif args.prompt_extend_method == "local_qwen": 280 | prompt_expander = QwenPromptExpander( 281 | model_name=args.prompt_extend_model, is_vl=True, device=0) 282 | else: 283 | raise NotImplementedError( 284 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") 285 | print("done", flush=True) 286 | 287 | demo = gradio_interface() 288 | demo.launch(server_name="0.0.0.0", share=False, server_port=7860) 289 | -------------------------------------------------------------------------------- /gradio/t2i_14B_singleGPU.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import argparse 3 | import os 4 | import os.path as osp 5 | import sys 6 | import warnings 7 | 8 | import gradio as gr 9 | 10 | warnings.filterwarnings('ignore') 11 | 12 | # Model 13 | sys.path.insert( 14 | 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) 15 | import wan 16 | from wan.configs import WAN_CONFIGS 17 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander 18 | from wan.utils.utils import cache_image 19 | 20 | # Global Var 21 | prompt_expander = None 22 | wan_t2i = None 23 | 24 | 25 | # Button Func 26 | def prompt_enc(prompt, tar_lang): 27 | global prompt_expander 28 | prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower()) 29 | if prompt_output.status == False: 30 | return prompt 31 | else: 32 | return prompt_output.prompt 33 | 34 | 35 | def t2i_generation(txt2img_prompt, resolution, sd_steps, guide_scale, 36 | shift_scale, seed, n_prompt): 37 | global wan_t2i 38 | # print(f"{txt2img_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}") 39 | 40 | W = int(resolution.split("*")[0]) 41 | H = int(resolution.split("*")[1]) 42 | video = wan_t2i.generate( 43 | txt2img_prompt, 44 | size=(W, H), 45 | frame_num=1, 46 | shift=shift_scale, 47 | sampling_steps=sd_steps, 48 | guide_scale=guide_scale, 49 | n_prompt=n_prompt, 50 | seed=seed, 51 | offload_model=True) 52 | 53 | cache_image( 54 | tensor=video.squeeze(1)[None], 55 | save_file="example.png", 56 | nrow=1, 57 | normalize=True, 58 | value_range=(-1, 1)) 59 | 60 | return "example.png" 61 | 62 | 63 | # Interface 64 | def gradio_interface(): 65 | with gr.Blocks() as demo: 66 | gr.Markdown(""" 67 | <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;"> 68 | Wan2.1 (T2I-14B) 69 | </div> 70 | <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;"> 71 | Wan: Open and Advanced Large-Scale Video Generative Models. 
72 | </div> 73 | """) 74 | 75 | with gr.Row(): 76 | with gr.Column(): 77 | txt2img_prompt = gr.Textbox( 78 | label="Prompt", 79 | placeholder="Describe the image you want to generate", 80 | ) 81 | tar_lang = gr.Radio( 82 | choices=["ZH", "EN"], 83 | label="Target language of prompt enhance", 84 | value="ZH") 85 | run_p_button = gr.Button(value="Prompt Enhance") 86 | 87 | with gr.Accordion("Advanced Options", open=True): 88 | resolution = gr.Dropdown( 89 | label='Resolution(Width*Height)', 90 | choices=[ 91 | '720*1280', '1280*720', '960*960', '1088*832', 92 | '832*1088', '480*832', '832*480', '624*624', 93 | '704*544', '544*704' 94 | ], 95 | value='720*1280') 96 | 97 | with gr.Row(): 98 | sd_steps = gr.Slider( 99 | label="Diffusion steps", 100 | minimum=1, 101 | maximum=1000, 102 | value=50, 103 | step=1) 104 | guide_scale = gr.Slider( 105 | label="Guide scale", 106 | minimum=0, 107 | maximum=20, 108 | value=5.0, 109 | step=1) 110 | with gr.Row(): 111 | shift_scale = gr.Slider( 112 | label="Shift scale", 113 | minimum=0, 114 | maximum=10, 115 | value=5.0, 116 | step=1) 117 | seed = gr.Slider( 118 | label="Seed", 119 | minimum=-1, 120 | maximum=2147483647, 121 | step=1, 122 | value=-1) 123 | n_prompt = gr.Textbox( 124 | label="Negative Prompt", 125 | placeholder="Describe the negative prompt you want to add" 126 | ) 127 | 128 | run_t2i_button = gr.Button("Generate Image") 129 | 130 | with gr.Column(): 131 | result_gallery = gr.Image( 132 | label='Generated Image', interactive=False, height=600) 133 | 134 | run_p_button.click( 135 | fn=prompt_enc, 136 | inputs=[txt2img_prompt, tar_lang], 137 | outputs=[txt2img_prompt]) 138 | 139 | run_t2i_button.click( 140 | fn=t2i_generation, 141 | inputs=[ 142 | txt2img_prompt, resolution, sd_steps, guide_scale, shift_scale, 143 | seed, n_prompt 144 | ], 145 | outputs=[result_gallery], 146 | ) 147 | 148 | return demo 149 | 150 | 151 | # Main 152 | def _parse_args(): 153 | parser = argparse.ArgumentParser( 154 | description="Generate a image from a text prompt or image using Gradio") 155 | parser.add_argument( 156 | "--ckpt_dir", 157 | type=str, 158 | default="cache", 159 | help="The path to the checkpoint directory.") 160 | parser.add_argument( 161 | "--prompt_extend_method", 162 | type=str, 163 | default="local_qwen", 164 | choices=["dashscope", "local_qwen"], 165 | help="The prompt extend method to use.") 166 | parser.add_argument( 167 | "--prompt_extend_model", 168 | type=str, 169 | default=None, 170 | help="The prompt extend model to use.") 171 | 172 | args = parser.parse_args() 173 | 174 | return args 175 | 176 | 177 | if __name__ == '__main__': 178 | args = _parse_args() 179 | 180 | print("Step1: Init prompt_expander...", end='', flush=True) 181 | if args.prompt_extend_method == "dashscope": 182 | prompt_expander = DashScopePromptExpander( 183 | model_name=args.prompt_extend_model, is_vl=False) 184 | elif args.prompt_extend_method == "local_qwen": 185 | prompt_expander = QwenPromptExpander( 186 | model_name=args.prompt_extend_model, is_vl=False, device=0) 187 | else: 188 | raise NotImplementedError( 189 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") 190 | print("done", flush=True) 191 | 192 | print("Step2: Init 14B t2i model...", end='', flush=True) 193 | cfg = WAN_CONFIGS['t2i-14B'] 194 | wan_t2i = wan.WanT2V( 195 | config=cfg, 196 | checkpoint_dir=args.ckpt_dir, 197 | device_id=0, 198 | rank=0, 199 | t5_fsdp=False, 200 | dit_fsdp=False, 201 | use_usp=False, 202 | ) 203 | print("done", flush=True) 204 | 205 | demo = 
gradio_interface() 206 | demo.launch(server_name="0.0.0.0", share=False, server_port=7860) 207 | -------------------------------------------------------------------------------- /gradio/t2v_1.3B_singleGPU.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import sys 6 | import warnings 7 | 8 | import gradio as gr 9 | 10 | warnings.filterwarnings('ignore') 11 | 12 | # Model 13 | sys.path.insert( 14 | 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) 15 | import wan 16 | from wan.configs import WAN_CONFIGS 17 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander 18 | from wan.utils.utils import cache_video 19 | 20 | # Global Var 21 | prompt_expander = None 22 | wan_t2v = None 23 | 24 | 25 | # Button Func 26 | def prompt_enc(prompt, tar_lang): 27 | global prompt_expander 28 | prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower()) 29 | if prompt_output.status == False: 30 | return prompt 31 | else: 32 | return prompt_output.prompt 33 | 34 | 35 | def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale, 36 | shift_scale, seed, n_prompt): 37 | global wan_t2v 38 | # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}") 39 | 40 | W = int(resolution.split("*")[0]) 41 | H = int(resolution.split("*")[1]) 42 | video = wan_t2v.generate( 43 | txt2vid_prompt, 44 | size=(W, H), 45 | shift=shift_scale, 46 | sampling_steps=sd_steps, 47 | guide_scale=guide_scale, 48 | n_prompt=n_prompt, 49 | seed=seed, 50 | offload_model=True) 51 | 52 | cache_video( 53 | tensor=video[None], 54 | save_file="example.mp4", 55 | fps=16, 56 | nrow=1, 57 | normalize=True, 58 | value_range=(-1, 1)) 59 | 60 | return "example.mp4" 61 | 62 | 63 | # Interface 64 | def gradio_interface(): 65 | with gr.Blocks() as demo: 66 | gr.Markdown(""" 67 | <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;"> 68 | Wan2.1 (T2V-1.3B) 69 | </div> 70 | <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;"> 71 | Wan: Open and Advanced Large-Scale Video Generative Models. 
72 | </div> 73 | """) 74 | 75 | with gr.Row(): 76 | with gr.Column(): 77 | txt2vid_prompt = gr.Textbox( 78 | label="Prompt", 79 | placeholder="Describe the video you want to generate", 80 | ) 81 | tar_lang = gr.Radio( 82 | choices=["ZH", "EN"], 83 | label="Target language of prompt enhance", 84 | value="ZH") 85 | run_p_button = gr.Button(value="Prompt Enhance") 86 | 87 | with gr.Accordion("Advanced Options", open=True): 88 | resolution = gr.Dropdown( 89 | label='Resolution(Width*Height)', 90 | choices=[ 91 | '480*832', 92 | '832*480', 93 | '624*624', 94 | '704*544', 95 | '544*704', 96 | ], 97 | value='480*832') 98 | 99 | with gr.Row(): 100 | sd_steps = gr.Slider( 101 | label="Diffusion steps", 102 | minimum=1, 103 | maximum=1000, 104 | value=50, 105 | step=1) 106 | guide_scale = gr.Slider( 107 | label="Guide scale", 108 | minimum=0, 109 | maximum=20, 110 | value=6.0, 111 | step=1) 112 | with gr.Row(): 113 | shift_scale = gr.Slider( 114 | label="Shift scale", 115 | minimum=0, 116 | maximum=20, 117 | value=8.0, 118 | step=1) 119 | seed = gr.Slider( 120 | label="Seed", 121 | minimum=-1, 122 | maximum=2147483647, 123 | step=1, 124 | value=-1) 125 | n_prompt = gr.Textbox( 126 | label="Negative Prompt", 127 | placeholder="Describe the negative prompt you want to add" 128 | ) 129 | 130 | run_t2v_button = gr.Button("Generate Video") 131 | 132 | with gr.Column(): 133 | result_gallery = gr.Video( 134 | label='Generated Video', interactive=False, height=600) 135 | 136 | run_p_button.click( 137 | fn=prompt_enc, 138 | inputs=[txt2vid_prompt, tar_lang], 139 | outputs=[txt2vid_prompt]) 140 | 141 | run_t2v_button.click( 142 | fn=t2v_generation, 143 | inputs=[ 144 | txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale, 145 | seed, n_prompt 146 | ], 147 | outputs=[result_gallery], 148 | ) 149 | 150 | return demo 151 | 152 | 153 | # Main 154 | def _parse_args(): 155 | parser = argparse.ArgumentParser( 156 | description="Generate a video from a text prompt or image using Gradio") 157 | parser.add_argument( 158 | "--ckpt_dir", 159 | type=str, 160 | default="cache", 161 | help="The path to the checkpoint directory.") 162 | parser.add_argument( 163 | "--prompt_extend_method", 164 | type=str, 165 | default="local_qwen", 166 | choices=["dashscope", "local_qwen"], 167 | help="The prompt extend method to use.") 168 | parser.add_argument( 169 | "--prompt_extend_model", 170 | type=str, 171 | default=None, 172 | help="The prompt extend model to use.") 173 | 174 | args = parser.parse_args() 175 | 176 | return args 177 | 178 | 179 | if __name__ == '__main__': 180 | args = _parse_args() 181 | 182 | print("Step1: Init prompt_expander...", end='', flush=True) 183 | if args.prompt_extend_method == "dashscope": 184 | prompt_expander = DashScopePromptExpander( 185 | model_name=args.prompt_extend_model, is_vl=False) 186 | elif args.prompt_extend_method == "local_qwen": 187 | prompt_expander = QwenPromptExpander( 188 | model_name=args.prompt_extend_model, is_vl=False, device=0) 189 | else: 190 | raise NotImplementedError( 191 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") 192 | print("done", flush=True) 193 | 194 | print("Step2: Init 1.3B t2v model...", end='', flush=True) 195 | cfg = WAN_CONFIGS['t2v-1.3B'] 196 | wan_t2v = wan.WanT2V( 197 | config=cfg, 198 | checkpoint_dir=args.ckpt_dir, 199 | device_id=0, 200 | rank=0, 201 | t5_fsdp=False, 202 | dit_fsdp=False, 203 | use_usp=False, 204 | ) 205 | print("done", flush=True) 206 | 207 | demo = gradio_interface() 208 | demo.launch(server_name="0.0.0.0", 
share=False, server_port=7860) 209 | -------------------------------------------------------------------------------- /gradio/t2v_14B_singleGPU.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import sys 6 | import warnings 7 | 8 | import gradio as gr 9 | 10 | warnings.filterwarnings('ignore') 11 | 12 | # Model 13 | sys.path.insert( 14 | 0, os.path.sep.join(osp.realpath(__file__).split(os.path.sep)[:-2])) 15 | import wan 16 | from wan.configs import WAN_CONFIGS 17 | from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander 18 | from wan.utils.utils import cache_video 19 | 20 | # Global Var 21 | prompt_expander = None 22 | wan_t2v = None 23 | 24 | 25 | # Button Func 26 | def prompt_enc(prompt, tar_lang): 27 | global prompt_expander 28 | prompt_output = prompt_expander(prompt, tar_lang=tar_lang.lower()) 29 | if prompt_output.status == False: 30 | return prompt 31 | else: 32 | return prompt_output.prompt 33 | 34 | 35 | def t2v_generation(txt2vid_prompt, resolution, sd_steps, guide_scale, 36 | shift_scale, seed, n_prompt): 37 | global wan_t2v 38 | # print(f"{txt2vid_prompt},{resolution},{sd_steps},{guide_scale},{shift_scale},{seed},{n_prompt}") 39 | 40 | W = int(resolution.split("*")[0]) 41 | H = int(resolution.split("*")[1]) 42 | video = wan_t2v.generate( 43 | txt2vid_prompt, 44 | size=(W, H), 45 | shift=shift_scale, 46 | sampling_steps=sd_steps, 47 | guide_scale=guide_scale, 48 | n_prompt=n_prompt, 49 | seed=seed, 50 | offload_model=True) 51 | 52 | cache_video( 53 | tensor=video[None], 54 | save_file="example.mp4", 55 | fps=16, 56 | nrow=1, 57 | normalize=True, 58 | value_range=(-1, 1)) 59 | 60 | return "example.mp4" 61 | 62 | 63 | # Interface 64 | def gradio_interface(): 65 | with gr.Blocks() as demo: 66 | gr.Markdown(""" 67 | <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;"> 68 | Wan2.1 (T2V-14B) 69 | </div> 70 | <div style="text-align: center; font-size: 16px; font-weight: normal; margin-bottom: 20px;"> 71 | Wan: Open and Advanced Large-Scale Video Generative Models. 
72 | </div> 73 | """) 74 | 75 | with gr.Row(): 76 | with gr.Column(): 77 | txt2vid_prompt = gr.Textbox( 78 | label="Prompt", 79 | placeholder="Describe the video you want to generate", 80 | ) 81 | tar_lang = gr.Radio( 82 | choices=["ZH", "EN"], 83 | label="Target language of prompt enhance", 84 | value="ZH") 85 | run_p_button = gr.Button(value="Prompt Enhance") 86 | 87 | with gr.Accordion("Advanced Options", open=True): 88 | resolution = gr.Dropdown( 89 | label='Resolution(Width*Height)', 90 | choices=[ 91 | '720*1280', '1280*720', '960*960', '1088*832', 92 | '832*1088', '480*832', '832*480', '624*624', 93 | '704*544', '544*704' 94 | ], 95 | value='720*1280') 96 | 97 | with gr.Row(): 98 | sd_steps = gr.Slider( 99 | label="Diffusion steps", 100 | minimum=1, 101 | maximum=1000, 102 | value=50, 103 | step=1) 104 | guide_scale = gr.Slider( 105 | label="Guide scale", 106 | minimum=0, 107 | maximum=20, 108 | value=5.0, 109 | step=1) 110 | with gr.Row(): 111 | shift_scale = gr.Slider( 112 | label="Shift scale", 113 | minimum=0, 114 | maximum=10, 115 | value=5.0, 116 | step=1) 117 | seed = gr.Slider( 118 | label="Seed", 119 | minimum=-1, 120 | maximum=2147483647, 121 | step=1, 122 | value=-1) 123 | n_prompt = gr.Textbox( 124 | label="Negative Prompt", 125 | placeholder="Describe the negative prompt you want to add" 126 | ) 127 | 128 | run_t2v_button = gr.Button("Generate Video") 129 | 130 | with gr.Column(): 131 | result_gallery = gr.Video( 132 | label='Generated Video', interactive=False, height=600) 133 | 134 | run_p_button.click( 135 | fn=prompt_enc, 136 | inputs=[txt2vid_prompt, tar_lang], 137 | outputs=[txt2vid_prompt]) 138 | 139 | run_t2v_button.click( 140 | fn=t2v_generation, 141 | inputs=[ 142 | txt2vid_prompt, resolution, sd_steps, guide_scale, shift_scale, 143 | seed, n_prompt 144 | ], 145 | outputs=[result_gallery], 146 | ) 147 | 148 | return demo 149 | 150 | 151 | # Main 152 | def _parse_args(): 153 | parser = argparse.ArgumentParser( 154 | description="Generate a video from a text prompt or image using Gradio") 155 | parser.add_argument( 156 | "--ckpt_dir", 157 | type=str, 158 | default="cache", 159 | help="The path to the checkpoint directory.") 160 | parser.add_argument( 161 | "--prompt_extend_method", 162 | type=str, 163 | default="local_qwen", 164 | choices=["dashscope", "local_qwen"], 165 | help="The prompt extend method to use.") 166 | parser.add_argument( 167 | "--prompt_extend_model", 168 | type=str, 169 | default=None, 170 | help="The prompt extend model to use.") 171 | 172 | args = parser.parse_args() 173 | 174 | return args 175 | 176 | 177 | if __name__ == '__main__': 178 | args = _parse_args() 179 | 180 | print("Step1: Init prompt_expander...", end='', flush=True) 181 | if args.prompt_extend_method == "dashscope": 182 | prompt_expander = DashScopePromptExpander( 183 | model_name=args.prompt_extend_model, is_vl=False) 184 | elif args.prompt_extend_method == "local_qwen": 185 | prompt_expander = QwenPromptExpander( 186 | model_name=args.prompt_extend_model, is_vl=False, device=0) 187 | else: 188 | raise NotImplementedError( 189 | f"Unsupport prompt_extend_method: {args.prompt_extend_method}") 190 | print("done", flush=True) 191 | 192 | print("Step2: Init 14B t2v model...", end='', flush=True) 193 | cfg = WAN_CONFIGS['t2v-14B'] 194 | wan_t2v = wan.WanT2V( 195 | config=cfg, 196 | checkpoint_dir=args.ckpt_dir, 197 | device_id=0, 198 | rank=0, 199 | t5_fsdp=False, 200 | dit_fsdp=False, 201 | use_usp=False, 202 | ) 203 | print("done", flush=True) 204 | 205 | demo = 
gradio_interface() 206 | demo.launch(server_name="0.0.0.0", share=False, server_port=7860) 207 | -------------------------------------------------------------------------------- /gradio/vace.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | 4 | import argparse 5 | import datetime 6 | import os 7 | import sys 8 | 9 | import imageio 10 | import numpy as np 11 | import torch 12 | 13 | import gradio as gr 14 | 15 | sys.path.insert( 16 | 0, os.path.sep.join(os.path.realpath(__file__).split(os.path.sep)[:-2])) 17 | import wan 18 | from wan import WanVace, WanVaceMP 19 | from wan.configs import SIZE_CONFIGS, WAN_CONFIGS 20 | 21 | 22 | class FixedSizeQueue: 23 | 24 | def __init__(self, max_size): 25 | self.max_size = max_size 26 | self.queue = [] 27 | 28 | def add(self, item): 29 | self.queue.insert(0, item) 30 | if len(self.queue) > self.max_size: 31 | self.queue.pop() 32 | 33 | def get(self): 34 | return self.queue 35 | 36 | def __repr__(self): 37 | return str(self.queue) 38 | 39 | 40 | class VACEInference: 41 | 42 | def __init__(self, 43 | cfg, 44 | skip_load=False, 45 | gallery_share=True, 46 | gallery_share_limit=5): 47 | self.cfg = cfg 48 | self.save_dir = cfg.save_dir 49 | self.gallery_share = gallery_share 50 | self.gallery_share_data = FixedSizeQueue(max_size=gallery_share_limit) 51 | if not skip_load: 52 | if not args.mp: 53 | self.pipe = WanVace( 54 | config=WAN_CONFIGS[cfg.model_name], 55 | checkpoint_dir=cfg.ckpt_dir, 56 | device_id=0, 57 | rank=0, 58 | t5_fsdp=False, 59 | dit_fsdp=False, 60 | use_usp=False, 61 | ) 62 | else: 63 | self.pipe = WanVaceMP( 64 | config=WAN_CONFIGS[cfg.model_name], 65 | checkpoint_dir=cfg.ckpt_dir, 66 | use_usp=True, 67 | ulysses_size=cfg.ulysses_size, 68 | ring_size=cfg.ring_size) 69 | 70 | def create_ui(self, *args, **kwargs): 71 | gr.Markdown(""" 72 | <div style="text-align: center; font-size: 24px; font-weight: bold; margin-bottom: 15px;"> 73 | <a href="https://ali-vilab.github.io/VACE-Page/" style="text-decoration: none; color: inherit;">VACE-WAN Demo</a> 74 | </div> 75 | """) 76 | with gr.Row(variant='panel', equal_height=True): 77 | with gr.Column(scale=1, min_width=0): 78 | self.src_video = gr.Video( 79 | label="src_video", 80 | sources=['upload'], 81 | value=None, 82 | interactive=True) 83 | with gr.Column(scale=1, min_width=0): 84 | self.src_mask = gr.Video( 85 | label="src_mask", 86 | sources=['upload'], 87 | value=None, 88 | interactive=True) 89 | # 90 | with gr.Row(variant='panel', equal_height=True): 91 | with gr.Column(scale=1, min_width=0): 92 | with gr.Row(equal_height=True): 93 | self.src_ref_image_1 = gr.Image( 94 | label='src_ref_image_1', 95 | height=200, 96 | interactive=True, 97 | type='filepath', 98 | image_mode='RGB', 99 | sources=['upload'], 100 | elem_id="src_ref_image_1", 101 | format='png') 102 | self.src_ref_image_2 = gr.Image( 103 | label='src_ref_image_2', 104 | height=200, 105 | interactive=True, 106 | type='filepath', 107 | image_mode='RGB', 108 | sources=['upload'], 109 | elem_id="src_ref_image_2", 110 | format='png') 111 | self.src_ref_image_3 = gr.Image( 112 | label='src_ref_image_3', 113 | height=200, 114 | interactive=True, 115 | type='filepath', 116 | image_mode='RGB', 117 | sources=['upload'], 118 | elem_id="src_ref_image_3", 119 | format='png') 120 | with gr.Row(variant='panel', equal_height=True): 121 | with gr.Column(scale=1): 122 | self.prompt = gr.Textbox( 123 | show_label=False, 124 | 
placeholder="positive_prompt_input", 125 | elem_id='positive_prompt', 126 | container=True, 127 | autofocus=True, 128 | elem_classes='type_row', 129 | visible=True, 130 | lines=2) 131 | self.negative_prompt = gr.Textbox( 132 | show_label=False, 133 | value=self.pipe.config.sample_neg_prompt, 134 | placeholder="negative_prompt_input", 135 | elem_id='negative_prompt', 136 | container=True, 137 | autofocus=False, 138 | elem_classes='type_row', 139 | visible=True, 140 | interactive=True, 141 | lines=1) 142 | # 143 | with gr.Row(variant='panel', equal_height=True): 144 | with gr.Column(scale=1, min_width=0): 145 | with gr.Row(equal_height=True): 146 | self.shift_scale = gr.Slider( 147 | label='shift_scale', 148 | minimum=0.0, 149 | maximum=100.0, 150 | step=1.0, 151 | value=16.0, 152 | interactive=True) 153 | self.sample_steps = gr.Slider( 154 | label='sample_steps', 155 | minimum=1, 156 | maximum=100, 157 | step=1, 158 | value=25, 159 | interactive=True) 160 | self.context_scale = gr.Slider( 161 | label='context_scale', 162 | minimum=0.0, 163 | maximum=2.0, 164 | step=0.1, 165 | value=1.0, 166 | interactive=True) 167 | self.guide_scale = gr.Slider( 168 | label='guide_scale', 169 | minimum=1, 170 | maximum=10, 171 | step=0.5, 172 | value=5.0, 173 | interactive=True) 174 | self.infer_seed = gr.Slider( 175 | minimum=-1, maximum=10000000, value=2025, label="Seed") 176 | # 177 | with gr.Accordion(label="Usable without source video", open=False): 178 | with gr.Row(equal_height=True): 179 | self.output_height = gr.Textbox( 180 | label='resolutions_height', 181 | # value=480, 182 | value=720, 183 | interactive=True) 184 | self.output_width = gr.Textbox( 185 | label='resolutions_width', 186 | # value=832, 187 | value=1280, 188 | interactive=True) 189 | self.frame_rate = gr.Textbox( 190 | label='frame_rate', value=16, interactive=True) 191 | self.num_frames = gr.Textbox( 192 | label='num_frames', value=81, interactive=True) 193 | # 194 | with gr.Row(equal_height=True): 195 | with gr.Column(scale=5): 196 | self.generate_button = gr.Button( 197 | value='Run', 198 | elem_classes='type_row', 199 | elem_id='generate_button', 200 | visible=True) 201 | with gr.Column(scale=1): 202 | self.refresh_button = gr.Button(value='\U0001f504') # 🔄 203 | # 204 | self.output_gallery = gr.Gallery( 205 | label="output_gallery", 206 | value=[], 207 | interactive=False, 208 | allow_preview=True, 209 | preview=True) 210 | 211 | def generate(self, output_gallery, src_video, src_mask, src_ref_image_1, 212 | src_ref_image_2, src_ref_image_3, prompt, negative_prompt, 213 | shift_scale, sample_steps, context_scale, guide_scale, 214 | infer_seed, output_height, output_width, frame_rate, 215 | num_frames): 216 | output_height, output_width, frame_rate, num_frames = int( 217 | output_height), int(output_width), int(frame_rate), int(num_frames) 218 | src_ref_images = [ 219 | x for x in [src_ref_image_1, src_ref_image_2, src_ref_image_3] 220 | if x is not None 221 | ] 222 | src_video, src_mask, src_ref_images = self.pipe.prepare_source( 223 | [src_video], [src_mask], [src_ref_images], 224 | num_frames=num_frames, 225 | image_size=SIZE_CONFIGS[f"{output_width}*{output_height}"], 226 | device=self.pipe.device) 227 | video = self.pipe.generate( 228 | prompt, 229 | src_video, 230 | src_mask, 231 | src_ref_images, 232 | size=(output_width, output_height), 233 | context_scale=context_scale, 234 | shift=shift_scale, 235 | sampling_steps=sample_steps, 236 | guide_scale=guide_scale, 237 | n_prompt=negative_prompt, 238 | seed=infer_seed, 239 | 
offload_model=True) 240 | 241 | name = '{0:%Y%m%d%-H%M%S}'.format(datetime.datetime.now()) 242 | video_path = os.path.join(self.save_dir, f'cur_gallery_{name}.mp4') 243 | video_frames = ( 244 | torch.clamp(video / 2 + 0.5, min=0.0, max=1.0).permute(1, 2, 3, 0) * 245 | 255).cpu().numpy().astype(np.uint8) 246 | 247 | try: 248 | writer = imageio.get_writer( 249 | video_path, 250 | fps=frame_rate, 251 | codec='libx264', 252 | quality=8, 253 | macro_block_size=1) 254 | for frame in video_frames: 255 | writer.append_data(frame) 256 | writer.close() 257 | print(video_path) 258 | except Exception as e: 259 | raise gr.Error(f"Video save error: {e}") 260 | 261 | if self.gallery_share: 262 | self.gallery_share_data.add(video_path) 263 | return self.gallery_share_data.get() 264 | else: 265 | return [video_path] 266 | 267 | def set_callbacks(self, **kwargs): 268 | self.gen_inputs = [ 269 | self.output_gallery, self.src_video, self.src_mask, 270 | self.src_ref_image_1, self.src_ref_image_2, self.src_ref_image_3, 271 | self.prompt, self.negative_prompt, self.shift_scale, 272 | self.sample_steps, self.context_scale, self.guide_scale, 273 | self.infer_seed, self.output_height, self.output_width, 274 | self.frame_rate, self.num_frames 275 | ] 276 | self.gen_outputs = [self.output_gallery] 277 | self.generate_button.click( 278 | self.generate, 279 | inputs=self.gen_inputs, 280 | outputs=self.gen_outputs, 281 | queue=True) 282 | self.refresh_button.click( 283 | lambda x: self.gallery_share_data.get() 284 | if self.gallery_share else x, 285 | inputs=[self.output_gallery], 286 | outputs=[self.output_gallery]) 287 | 288 | 289 | if __name__ == '__main__': 290 | parser = argparse.ArgumentParser( 291 | description='Argparser for VACE-WAN Demo:\n') 292 | parser.add_argument( 293 | '--server_port', dest='server_port', help='', type=int, default=7860) 294 | parser.add_argument( 295 | '--server_name', dest='server_name', help='', default='0.0.0.0') 296 | parser.add_argument('--root_path', dest='root_path', help='', default=None) 297 | parser.add_argument('--save_dir', dest='save_dir', help='', default='cache') 298 | parser.add_argument( 299 | "--mp", 300 | action="store_true", 301 | help="Use Multi-GPUs", 302 | ) 303 | parser.add_argument( 304 | "--model_name", 305 | type=str, 306 | default="vace-14B", 307 | choices=list(WAN_CONFIGS.keys()), 308 | help="The model name to run.") 309 | parser.add_argument( 310 | "--ulysses_size", 311 | type=int, 312 | default=1, 313 | help="The size of the ulysses parallelism in DiT.") 314 | parser.add_argument( 315 | "--ring_size", 316 | type=int, 317 | default=1, 318 | help="The size of the ring attention parallelism in DiT.") 319 | parser.add_argument( 320 | "--ckpt_dir", 321 | type=str, 322 | # default='models/VACE-Wan2.1-1.3B-Preview', 323 | default='models/Wan2.1-VACE-14B/', 324 | help="The path to the checkpoint directory.", 325 | ) 326 | parser.add_argument( 327 | "--offload_to_cpu", 328 | action="store_true", 329 | help="Offloading unnecessary computations to CPU.", 330 | ) 331 | 332 | args = parser.parse_args() 333 | 334 | if not os.path.exists(args.save_dir): 335 | os.makedirs(args.save_dir, exist_ok=True) 336 | 337 | with gr.Blocks() as demo: 338 | infer_gr = VACEInference( 339 | args, skip_load=False, gallery_share=True, gallery_share_limit=5) 340 | infer_gr.create_ui() 341 | infer_gr.set_callbacks() 342 | allowed_paths = [args.save_dir] 343 | demo.queue(status_update_rate=1).launch( 344 | server_name=args.server_name, 345 | server_port=args.server_port, 346 | 
root_path=args.root_path, 347 | allowed_paths=allowed_paths, 348 | show_error=True, 349 | debug=True) 350 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "wan" 7 | version = "2.1.0" 8 | description = "Wan: Open and Advanced Large-Scale Video Generative Models" 9 | authors = [ 10 | { name = "Wan Team", email = "wan.ai@alibabacloud.com" } 11 | ] 12 | license = { file = "LICENSE.txt" } 13 | readme = "README.md" 14 | requires-python = ">=3.10,<4.0" 15 | dependencies = [ 16 | "torch>=2.4.0", 17 | "torchvision>=0.19.0", 18 | "opencv-python>=4.9.0.80", 19 | "diffusers>=0.31.0", 20 | "transformers>=4.49.0", 21 | "tokenizers>=0.20.3", 22 | "accelerate>=1.1.1", 23 | "tqdm", 24 | "imageio", 25 | "easydict", 26 | "ftfy", 27 | "dashscope", 28 | "imageio-ffmpeg", 29 | "flash_attn", 30 | "gradio>=5.0.0", 31 | "numpy>=1.23.5,<2" 32 | ] 33 | 34 | [project.optional-dependencies] 35 | dev = [ 36 | "pytest", 37 | "black", 38 | "flake8", 39 | "isort", 40 | "mypy", 41 | "huggingface-hub[cli]" 42 | ] 43 | 44 | [project.urls] 45 | homepage = "https://wanxai.com" 46 | documentation = "https://github.com/Wan-Video/Wan2.1" 47 | repository = "https://github.com/Wan-Video/Wan2.1" 48 | huggingface = "https://huggingface.co/Wan-AI/" 49 | modelscope = "https://modelscope.cn/organization/Wan-AI" 50 | discord = "https://discord.gg/p5XbdQV7" 51 | 52 | [tool.setuptools] 53 | packages = ["wan"] 54 | 55 | [tool.setuptools.package-data] 56 | "wan" = ["**/*.py"] 57 | 58 | [tool.black] 59 | line-length = 88 60 | 61 | [tool.isort] 62 | profile = "black" 63 | 64 | [tool.mypy] 65 | strict = true 66 | 67 | 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.4.0 2 | torchvision>=0.19.0 3 | opencv-python>=4.9.0.80 4 | diffusers>=0.31.0 5 | transformers>=4.49.0 6 | tokenizers>=0.20.3 7 | accelerate>=1.1.1 8 | tqdm 9 | imageio 10 | easydict 11 | ftfy 12 | dashscope 13 | imageio-ffmpeg 14 | flash_attn 15 | gradio>=5.0.0 16 | numpy>=1.23.5,<2 17 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | 2 | Put all your models (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P) in a folder and specify the max GPU number you want to use. 
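For example, if the checkpoints are stored under `/data/wan_models` (a placeholder path) and up to 4 GPUs should be used, run `bash ./test.sh /data/wan_models 4`. The general form is: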
3 | 4 | ```bash 5 | bash ./test.sh <local model dir> <gpu number> 6 | ``` 7 | -------------------------------------------------------------------------------- /tests/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | if [ "$#" -eq 2 ]; then 5 | MODEL_DIR=$(realpath "$1") 6 | GPUS=$2 7 | else 8 | echo "Usage: $0 <local model dir> <gpu number>" 9 | exit 1 10 | fi 11 | 12 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" 13 | REPO_ROOT="$(dirname "$SCRIPT_DIR")" 14 | cd "$REPO_ROOT" || exit 1 15 | 16 | PY_FILE=./generate.py 17 | 18 | 19 | function t2v_1_3B() { 20 | T2V_1_3B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-1.3B" 21 | 22 | # 1-GPU Test 23 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B 1-GPU Test: " 24 | python $PY_FILE --task t2v-1.3B --size 480*832 --ckpt_dir $T2V_1_3B_CKPT_DIR 25 | 26 | # Multiple GPU Test 27 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU Test: " 28 | torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS 29 | 30 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend local_qwen: " 31 | torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en" 32 | 33 | if [ -n "${DASH_API_KEY+x}" ]; then 34 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_1_3B Multiple GPU, prompt extend dashscope: " 35 | torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-1.3B --ckpt_dir $T2V_1_3B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope" 36 | else 37 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test." 
38 | fi 39 | } 40 | 41 | function t2v_14B() { 42 | T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B" 43 | 44 | # 1-GPU Test 45 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B 1-GPU Test: " 46 | python $PY_FILE --task t2v-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR 47 | 48 | # Multiple GPU Test 49 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU Test: " 50 | torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS 51 | 52 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2v_14B Multiple GPU, prompt extend local_qwen: " 53 | torchrun --nproc_per_node=$GPUS $PY_FILE --task t2v-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en" 54 | } 55 | 56 | 57 | 58 | function t2i_14B() { 59 | T2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-T2V-14B" 60 | 61 | # 1-GPU Test 62 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B 1-GPU Test: " 63 | python $PY_FILE --task t2i-14B --size 480*832 --ckpt_dir $T2V_14B_CKPT_DIR 64 | 65 | # Multiple GPU Test 66 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU Test: " 67 | torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS 68 | 69 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> t2i_14B Multiple GPU, prompt extend local_qwen: " 70 | torchrun --nproc_per_node=$GPUS $PY_FILE --task t2i-14B --ckpt_dir $T2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-3B-Instruct" --prompt_extend_target_lang "en" 71 | } 72 | 73 | 74 | function i2v_14B_480p() { 75 | I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-480P" 76 | 77 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: " 78 | python $PY_FILE --task i2v-14B --size 832*480 --ckpt_dir $I2V_14B_CKPT_DIR 79 | 80 | # Multiple GPU Test 81 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: " 82 | torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS 83 | 84 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend local_qwen: " 85 | torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_model "Qwen/Qwen2.5-VL-3B-Instruct" --prompt_extend_target_lang "en" 86 | 87 | if [ -n "${DASH_API_KEY+x}" ]; then 88 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU, prompt extend dashscope: " 89 | torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 832*480 --dit_fsdp --t5_fsdp --ulysses_size $GPUS --use_prompt_extend --prompt_extend_method "dashscope" 90 | else 91 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> No DASH_API_KEY found, skip the dashscope extend test." 
92 | fi 93 | } 94 | 95 | 96 | function i2v_14B_720p() { 97 | I2V_14B_CKPT_DIR="$MODEL_DIR/Wan2.1-I2V-14B-720P" 98 | 99 | # 1-GPU Test 100 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B 1-GPU Test: " 101 | python $PY_FILE --task i2v-14B --size 720*1280 --ckpt_dir $I2V_14B_CKPT_DIR 102 | 103 | # Multiple GPU Test 104 | echo -e "\n\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> i2v_14B Multiple GPU Test: " 105 | torchrun --nproc_per_node=$GPUS $PY_FILE --task i2v-14B --ckpt_dir $I2V_14B_CKPT_DIR --size 720*1280 --dit_fsdp --t5_fsdp --ulysses_size $GPUS 106 | } 107 | 108 | function vace_1_3B() { 109 | VACE_1_3B_CKPT_DIR="$MODEL_DIR/VACE-Wan2.1-1.3B-Preview/" 110 | torchrun --nproc_per_node=$GPUS $PY_FILE --ulysses_size $GPUS --task vace-1.3B --size 480*832 --ckpt_dir $VACE_1_3B_CKPT_DIR 111 | 112 | } 113 | 114 | 115 | t2i_14B 116 | t2v_1_3B 117 | t2v_14B 118 | i2v_14B_480p 119 | i2v_14B_720p 120 | vace_1_3B 121 | -------------------------------------------------------------------------------- /wan/__init__.py: -------------------------------------------------------------------------------- 1 | from . import configs, distributed, modules 2 | from .first_last_frame2video import WanFLF2V 3 | from .image2video import WanI2V 4 | from .text2video import WanT2V 5 | from .vace import WanVace, WanVaceMP 6 | -------------------------------------------------------------------------------- /wan/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import copy 3 | import os 4 | 5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 6 | 7 | from .wan_i2v_14B import i2v_14B 8 | from .wan_t2v_1_3B import t2v_1_3B 9 | from .wan_t2v_14B import t2v_14B 10 | 11 | # the config of t2i_14B is the same as t2v_14B 12 | t2i_14B = copy.deepcopy(t2v_14B) 13 | t2i_14B.__name__ = 'Config: Wan T2I 14B' 14 | 15 | # the config of flf2v_14B is the same as i2v_14B 16 | flf2v_14B = copy.deepcopy(i2v_14B) 17 | flf2v_14B.__name__ = 'Config: Wan FLF2V 14B' 18 | flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt 19 | 20 | WAN_CONFIGS = { 21 | 't2v-14B': t2v_14B, 22 | 't2v-1.3B': t2v_1_3B, 23 | 'i2v-14B': i2v_14B, 24 | 't2i-14B': t2i_14B, 25 | 'flf2v-14B': flf2v_14B, 26 | 'vace-1.3B': t2v_1_3B, 27 | 'vace-14B': t2v_14B, 28 | } 29 | 30 | SIZE_CONFIGS = { 31 | '720*1280': (720, 1280), 32 | '1280*720': (1280, 720), 33 | '480*832': (480, 832), 34 | '832*480': (832, 480), 35 | '1024*1024': (1024, 1024), 36 | } 37 | 38 | MAX_AREA_CONFIGS = { 39 | '720*1280': 720 * 1280, 40 | '1280*720': 1280 * 720, 41 | '480*832': 480 * 832, 42 | '832*480': 832 * 480, 43 | } 44 | 45 | SUPPORTED_SIZES = { 46 | 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 47 | 't2v-1.3B': ('480*832', '832*480'), 48 | 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 49 | 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 50 | 't2i-14B': tuple(SIZE_CONFIGS.keys()), 51 | 'vace-1.3B': ('480*832', '832*480'), 52 | 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480') 53 | } 54 | -------------------------------------------------------------------------------- /wan/configs/shared_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import torch 3 | from easydict import EasyDict 4 | 5 | #------------------------ Wan shared config ------------------------# 6 | wan_shared_cfg = EasyDict() 7 | 8 | # t5 9 | wan_shared_cfg.t5_model = 'umt5_xxl' 10 | wan_shared_cfg.t5_dtype = torch.bfloat16 11 | wan_shared_cfg.text_len = 512 12 | 13 | # transformer 14 | wan_shared_cfg.param_dtype = torch.bfloat16 15 | 16 | # inference 17 | wan_shared_cfg.num_train_timesteps = 1000 18 | wan_shared_cfg.sample_fps = 16 19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' 20 | -------------------------------------------------------------------------------- /wan/configs/wan_i2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | from easydict import EasyDict 4 | 5 | from .shared_config import wan_shared_cfg 6 | 7 | #------------------------ Wan I2V 14B ------------------------# 8 | 9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') 10 | i2v_14B.update(wan_shared_cfg) 11 | i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt 12 | 13 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 14 | i2v_14B.t5_tokenizer = 'google/umt5-xxl' 15 | 16 | # clip 17 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' 18 | i2v_14B.clip_dtype = torch.float16 19 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' 20 | i2v_14B.clip_tokenizer = 'xlm-roberta-large' 21 | 22 | # vae 23 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 24 | i2v_14B.vae_stride = (4, 8, 8) 25 | 26 | # transformer 27 | i2v_14B.patch_size = (1, 2, 2) 28 | i2v_14B.dim = 5120 29 | i2v_14B.ffn_dim = 13824 30 | i2v_14B.freq_dim = 256 31 | i2v_14B.num_heads = 40 32 | i2v_14B.num_layers = 40 33 | i2v_14B.window_size = (-1, -1) 34 | i2v_14B.qk_norm = True 35 | i2v_14B.cross_attn_norm = True 36 | i2v_14B.eps = 1e-6 37 | -------------------------------------------------------------------------------- /wan/configs/wan_t2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 14B ------------------------# 7 | 8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') 9 | t2v_14B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_14B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_14B.patch_size = (1, 2, 2) 21 | t2v_14B.dim = 5120 22 | t2v_14B.ffn_dim = 13824 23 | t2v_14B.freq_dim = 256 24 | t2v_14B.num_heads = 40 25 | t2v_14B.num_layers = 40 26 | t2v_14B.window_size = (-1, -1) 27 | t2v_14B.qk_norm = True 28 | t2v_14B.cross_attn_norm = True 29 | t2v_14B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /wan/configs/wan_t2v_1_3B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 1.3B ------------------------# 7 | 8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') 9 | t2v_1_3B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_1_3B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_1_3B.patch_size = (1, 2, 2) 21 | t2v_1_3B.dim = 1536 22 | t2v_1_3B.ffn_dim = 8960 23 | t2v_1_3B.freq_dim = 256 24 | t2v_1_3B.num_heads = 12 25 | t2v_1_3B.num_layers = 30 26 | t2v_1_3B.window_size = (-1, -1) 27 | t2v_1_3B.qk_norm = True 28 | t2v_1_3B.cross_attn_norm = True 29 | t2v_1_3B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /wan/distributed/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wan-Video/Wan2.1/261ca43e67ccaf68dbb5c0d8d1f971efde267573/wan/distributed/__init__.py -------------------------------------------------------------------------------- /wan/distributed/fsdp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | from functools import partial 4 | 5 | import torch 6 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 7 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 8 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy 9 | from torch.distributed.utils import _free_storage 10 | 11 | 12 | def shard_model( 13 | model, 14 | device_id, 15 | param_dtype=torch.bfloat16, 16 | reduce_dtype=torch.float32, 17 | buffer_dtype=torch.float32, 18 | process_group=None, 19 | sharding_strategy=ShardingStrategy.FULL_SHARD, 20 | sync_module_states=True, 21 | ): 22 | model = FSDP( 23 | module=model, 24 | process_group=process_group, 25 | sharding_strategy=sharding_strategy, 26 | auto_wrap_policy=partial( 27 | lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), 28 | mixed_precision=MixedPrecision( 29 | param_dtype=param_dtype, 30 | reduce_dtype=reduce_dtype, 31 | buffer_dtype=buffer_dtype), 32 | device_id=device_id, 33 | sync_module_states=sync_module_states) 34 | return model 35 | 36 | 37 | def free_model(model): 38 | for m in model.modules(): 39 | if isinstance(m, FSDP): 40 | _free_storage(m._handle.flat_param.data) 41 | del model 42 | gc.collect() 43 | torch.cuda.empty_cache() 44 | -------------------------------------------------------------------------------- /wan/distributed/xdit_context_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
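# Sequence-parallel (USP / xDiT) helpers for the Wan DiT. rope_apply pads the
# rotary frequencies to the full padded sequence and slices out the chunk owned
# by the current sequence-parallel rank; usp_dit_forward and usp_attn_forward
# are bound onto the model with types.MethodType by the pipeline classes so
# self-attention runs through xFuserLongContextAttention and the hidden states
# are all-gathered across ranks before unpatchify.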
2 | import torch 3 | import torch.cuda.amp as amp 4 | from xfuser.core.distributed import ( 5 | get_sequence_parallel_rank, 6 | get_sequence_parallel_world_size, 7 | get_sp_group, 8 | ) 9 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention 10 | 11 | from ..modules.model import sinusoidal_embedding_1d 12 | 13 | 14 | def pad_freqs(original_tensor, target_len): 15 | seq_len, s1, s2 = original_tensor.shape 16 | pad_size = target_len - seq_len 17 | padding_tensor = torch.ones( 18 | pad_size, 19 | s1, 20 | s2, 21 | dtype=original_tensor.dtype, 22 | device=original_tensor.device) 23 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) 24 | return padded_tensor 25 | 26 | 27 | @amp.autocast(enabled=False) 28 | def rope_apply(x, grid_sizes, freqs): 29 | """ 30 | x: [B, L, N, C]. 31 | grid_sizes: [B, 3]. 32 | freqs: [M, C // 2]. 33 | """ 34 | s, n, c = x.size(1), x.size(2), x.size(3) // 2 35 | # split freqs 36 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 37 | 38 | # loop over samples 39 | output = [] 40 | for i, (f, h, w) in enumerate(grid_sizes.tolist()): 41 | seq_len = f * h * w 42 | 43 | # precompute multipliers 44 | x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape( 45 | s, n, -1, 2)) 46 | freqs_i = torch.cat([ 47 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 48 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 49 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) 50 | ], 51 | dim=-1).reshape(seq_len, 1, -1) 52 | 53 | # apply rotary embedding 54 | sp_size = get_sequence_parallel_world_size() 55 | sp_rank = get_sequence_parallel_rank() 56 | freqs_i = pad_freqs(freqs_i, s * sp_size) 57 | s_per_rank = s 58 | freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) * 59 | s_per_rank), :, :] 60 | x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) 61 | x_i = torch.cat([x_i, x[i, s:]]) 62 | 63 | # append to collection 64 | output.append(x_i) 65 | return torch.stack(output).float() 66 | 67 | 68 | def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs): 69 | # embeddings 70 | c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] 71 | c = [u.flatten(2).transpose(1, 2) for u in c] 72 | c = torch.cat([ 73 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) 74 | for u in c 75 | ]) 76 | 77 | # arguments 78 | new_kwargs = dict(x=x) 79 | new_kwargs.update(kwargs) 80 | 81 | # Context Parallel 82 | c = torch.chunk( 83 | c, get_sequence_parallel_world_size(), 84 | dim=1)[get_sequence_parallel_rank()] 85 | 86 | hints = [] 87 | for block in self.vace_blocks: 88 | c, c_skip = block(c, **new_kwargs) 89 | hints.append(c_skip) 90 | return hints 91 | 92 | 93 | def usp_dit_forward( 94 | self, 95 | x, 96 | t, 97 | context, 98 | seq_len, 99 | vace_context=None, 100 | vace_context_scale=1.0, 101 | clip_fea=None, 102 | y=None, 103 | ): 104 | """ 105 | x: A list of videos each with shape [C, T, H, W]. 106 | t: [B]. 107 | context: A list of text embeddings each with shape [L, C]. 
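    seq_len: Padded sequence length shared by all samples; each sequence is
        zero-padded to this length before being chunked across
        sequence-parallel ranks.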
108 | """ 109 | if self.model_type == 'i2v': 110 | assert clip_fea is not None and y is not None 111 | # params 112 | device = self.patch_embedding.weight.device 113 | if self.freqs.device != device: 114 | self.freqs = self.freqs.to(device) 115 | 116 | if self.model_type != 'vace' and y is not None: 117 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 118 | 119 | # embeddings 120 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 121 | grid_sizes = torch.stack( 122 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 123 | x = [u.flatten(2).transpose(1, 2) for u in x] 124 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 125 | assert seq_lens.max() <= seq_len 126 | x = torch.cat([ 127 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) 128 | for u in x 129 | ]) 130 | 131 | # time embeddings 132 | with amp.autocast(dtype=torch.float32): 133 | e = self.time_embedding( 134 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 135 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 136 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 137 | 138 | # context 139 | context_lens = None 140 | context = self.text_embedding( 141 | torch.stack([ 142 | torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 143 | for u in context 144 | ])) 145 | 146 | if self.model_type != 'vace' and clip_fea is not None: 147 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 148 | context = torch.concat([context_clip, context], dim=1) 149 | 150 | # arguments 151 | kwargs = dict( 152 | e=e0, 153 | seq_lens=seq_lens, 154 | grid_sizes=grid_sizes, 155 | freqs=self.freqs, 156 | context=context, 157 | context_lens=context_lens) 158 | 159 | # Context Parallel 160 | x = torch.chunk( 161 | x, get_sequence_parallel_world_size(), 162 | dim=1)[get_sequence_parallel_rank()] 163 | 164 | if self.model_type == 'vace': 165 | hints = self.forward_vace(x, vace_context, seq_len, kwargs) 166 | kwargs['hints'] = hints 167 | kwargs['context_scale'] = vace_context_scale 168 | 169 | for block in self.blocks: 170 | x = block(x, **kwargs) 171 | 172 | # head 173 | x = self.head(x, e) 174 | 175 | # Context Parallel 176 | x = get_sp_group().all_gather(x, dim=1) 177 | 178 | # unpatchify 179 | x = self.unpatchify(x, grid_sizes) 180 | return [u.float() for u in x] 181 | 182 | 183 | def usp_attn_forward(self, 184 | x, 185 | seq_lens, 186 | grid_sizes, 187 | freqs, 188 | dtype=torch.bfloat16): 189 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 190 | half_dtypes = (torch.float16, torch.bfloat16) 191 | 192 | def half(x): 193 | return x if x.dtype in half_dtypes else x.to(dtype) 194 | 195 | # query, key, value function 196 | def qkv_fn(x): 197 | q = self.norm_q(self.q(x)).view(b, s, n, d) 198 | k = self.norm_k(self.k(x)).view(b, s, n, d) 199 | v = self.v(x).view(b, s, n, d) 200 | return q, k, v 201 | 202 | q, k, v = qkv_fn(x) 203 | q = rope_apply(q, grid_sizes, freqs) 204 | k = rope_apply(k, grid_sizes, freqs) 205 | 206 | # TODO: We should use unpaded q,k,v for attention. 
207 | # k_lens = seq_lens // get_sequence_parallel_world_size() 208 | # if k_lens is not None: 209 | # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) 210 | # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) 211 | # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) 212 | 213 | x = xFuserLongContextAttention()( 214 | None, 215 | query=half(q), 216 | key=half(k), 217 | value=half(v), 218 | window_size=self.window_size) 219 | 220 | # TODO: padding after attention. 221 | # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) 222 | 223 | # output 224 | x = x.flatten(2) 225 | x = self.o(x) 226 | return x 227 | -------------------------------------------------------------------------------- /wan/first_last_frame2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import numpy as np 13 | import torch 14 | import torch.cuda.amp as amp 15 | import torch.distributed as dist 16 | import torchvision.transforms.functional as TF 17 | from tqdm import tqdm 18 | 19 | from .distributed.fsdp import shard_model 20 | from .modules.clip import CLIPModel 21 | from .modules.model import WanModel 22 | from .modules.t5 import T5EncoderModel 23 | from .modules.vae import WanVAE 24 | from .utils.fm_solvers import ( 25 | FlowDPMSolverMultistepScheduler, 26 | get_sampling_sigmas, 27 | retrieve_timesteps, 28 | ) 29 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 30 | 31 | 32 | class WanFLF2V: 33 | 34 | def __init__( 35 | self, 36 | config, 37 | checkpoint_dir, 38 | device_id=0, 39 | rank=0, 40 | t5_fsdp=False, 41 | dit_fsdp=False, 42 | use_usp=False, 43 | t5_cpu=False, 44 | init_on_cpu=True, 45 | ): 46 | r""" 47 | Initializes the image-to-video generation model components. 48 | 49 | Args: 50 | config (EasyDict): 51 | Object containing model parameters initialized from config.py 52 | checkpoint_dir (`str`): 53 | Path to directory containing model checkpoints 54 | device_id (`int`, *optional*, defaults to 0): 55 | Id of target GPU device 56 | rank (`int`, *optional*, defaults to 0): 57 | Process rank for distributed training 58 | t5_fsdp (`bool`, *optional*, defaults to False): 59 | Enable FSDP sharding for T5 model 60 | dit_fsdp (`bool`, *optional*, defaults to False): 61 | Enable FSDP sharding for DiT model 62 | use_usp (`bool`, *optional*, defaults to False): 63 | Enable distribution strategy of USP. 64 | t5_cpu (`bool`, *optional*, defaults to False): 65 | Whether to place T5 model on CPU. Only works without t5_fsdp. 66 | init_on_cpu (`bool`, *optional*, defaults to True): 67 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 
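        Example (illustrative sketch; `ckpt_dir` is a placeholder checkpoint
        path and `first_img` / `last_img` are `PIL.Image` frames loaded by the
        caller):
            >>> from wan.configs import WAN_CONFIGS
            >>> pipe = WanFLF2V(WAN_CONFIGS['flf2v-14B'], checkpoint_dir=ckpt_dir)
            >>> video = pipe.generate('two dancers spin in the rain',
            ...                       first_img, last_img)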
68 | """ 69 | self.device = torch.device(f"cuda:{device_id}") 70 | self.config = config 71 | self.rank = rank 72 | self.use_usp = use_usp 73 | self.t5_cpu = t5_cpu 74 | 75 | self.num_train_timesteps = config.num_train_timesteps 76 | self.param_dtype = config.param_dtype 77 | 78 | shard_fn = partial(shard_model, device_id=device_id) 79 | self.text_encoder = T5EncoderModel( 80 | text_len=config.text_len, 81 | dtype=config.t5_dtype, 82 | device=torch.device('cpu'), 83 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 84 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 85 | shard_fn=shard_fn if t5_fsdp else None, 86 | ) 87 | 88 | self.vae_stride = config.vae_stride 89 | self.patch_size = config.patch_size 90 | self.vae = WanVAE( 91 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 92 | device=self.device) 93 | 94 | self.clip = CLIPModel( 95 | dtype=config.clip_dtype, 96 | device=self.device, 97 | checkpoint_path=os.path.join(checkpoint_dir, 98 | config.clip_checkpoint), 99 | tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) 100 | 101 | logging.info(f"Creating WanModel from {checkpoint_dir}") 102 | self.model = WanModel.from_pretrained(checkpoint_dir) 103 | self.model.eval().requires_grad_(False) 104 | 105 | if t5_fsdp or dit_fsdp or use_usp: 106 | init_on_cpu = False 107 | 108 | if use_usp: 109 | from xfuser.core.distributed import get_sequence_parallel_world_size 110 | 111 | from .distributed.xdit_context_parallel import ( 112 | usp_attn_forward, 113 | usp_dit_forward, 114 | ) 115 | for block in self.model.blocks: 116 | block.self_attn.forward = types.MethodType( 117 | usp_attn_forward, block.self_attn) 118 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 119 | self.sp_size = get_sequence_parallel_world_size() 120 | else: 121 | self.sp_size = 1 122 | 123 | if dist.is_initialized(): 124 | dist.barrier() 125 | if dit_fsdp: 126 | self.model = shard_fn(self.model) 127 | else: 128 | if not init_on_cpu: 129 | self.model.to(self.device) 130 | 131 | self.sample_neg_prompt = config.sample_neg_prompt 132 | 133 | def generate(self, 134 | input_prompt, 135 | first_frame, 136 | last_frame, 137 | max_area=720 * 1280, 138 | frame_num=81, 139 | shift=16, 140 | sample_solver='unipc', 141 | sampling_steps=50, 142 | guide_scale=5.5, 143 | n_prompt="", 144 | seed=-1, 145 | offload_model=True): 146 | r""" 147 | Generates video frames from input first-last frame and text prompt using diffusion process. 148 | 149 | Args: 150 | input_prompt (`str`): 151 | Text prompt for content generation. 152 | first_frame (PIL.Image.Image): 153 | Input image tensor. Shape: [3, H, W] 154 | last_frame (PIL.Image.Image): 155 | Input image tensor. Shape: [3, H, W] 156 | [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized 157 | to match first_frame. 158 | max_area (`int`, *optional*, defaults to 720*1280): 159 | Maximum pixel area for latent space calculation. Controls video resolution scaling 160 | frame_num (`int`, *optional*, defaults to 81): 161 | How many frames to sample from a video. The number should be 4n+1 162 | shift (`float`, *optional*, defaults to 5.0): 163 | Noise schedule shift parameter. Affects temporal dynamics 164 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0. 165 | sample_solver (`str`, *optional*, defaults to 'unipc'): 166 | Solver used to sample the video. 
167 | sampling_steps (`int`, *optional*, defaults to 50): 168 | Number of diffusion sampling steps. Higher values improve quality but slow generation 169 | guide_scale (`float`, *optional*, defaults to 5.5): 170 | Classifier-free guidance scale. Controls prompt adherence vs. creativity 171 | n_prompt (`str`, *optional*, defaults to ""): 172 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` 173 | seed (`int`, *optional*, defaults to -1): 174 | Random seed for noise generation. If -1, use random seed 175 | offload_model (`bool`, *optional*, defaults to True): 176 | If True, offloads models to CPU during generation to save VRAM 177 | 178 | Returns: 179 | torch.Tensor: 180 | Generated video frames tensor. Dimensions: (C, N, H, W) where: 181 | - C: Color channels (3 for RGB) 182 | - N: Number of frames (81) 183 | - H: Frame height (from max_area) 184 | - W: Frame width (from max_area) 185 | """ 186 | first_frame_size = first_frame.size 187 | last_frame_size = last_frame.size 188 | first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to( 189 | self.device) 190 | last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to( 191 | self.device) 192 | 193 | F = frame_num 194 | first_frame_h, first_frame_w = first_frame.shape[1:] 195 | aspect_ratio = first_frame_h / first_frame_w 196 | lat_h = round( 197 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] // 198 | self.patch_size[1] * self.patch_size[1]) 199 | lat_w = round( 200 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] // 201 | self.patch_size[2] * self.patch_size[2]) 202 | first_frame_h = lat_h * self.vae_stride[1] 203 | first_frame_w = lat_w * self.vae_stride[2] 204 | if first_frame_size != last_frame_size: 205 | # 1. resize 206 | last_frame_resize_ratio = max( 207 | first_frame_size[0] / last_frame_size[0], 208 | first_frame_size[1] / last_frame_size[1]) 209 | last_frame_size = [ 210 | round(last_frame_size[0] * last_frame_resize_ratio), 211 | round(last_frame_size[1] * last_frame_resize_ratio), 212 | ] 213 | # 2.
center crop 214 | last_frame = TF.center_crop(last_frame, last_frame_size) 215 | 216 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // ( 217 | self.patch_size[1] * self.patch_size[2]) 218 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size 219 | 220 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize) 221 | seed_g = torch.Generator(device=self.device) 222 | seed_g.manual_seed(seed) 223 | noise = torch.randn( 224 | 16, (F - 1) // 4 + 1, 225 | lat_h, 226 | lat_w, 227 | dtype=torch.float32, 228 | generator=seed_g, 229 | device=self.device) 230 | 231 | msk = torch.ones(1, 81, lat_h, lat_w, device=self.device) 232 | msk[:, 1:-1] = 0 233 | msk = torch.concat([ 234 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:] 235 | ], 236 | dim=1) 237 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) 238 | msk = msk.transpose(1, 2)[0] 239 | 240 | if n_prompt == "": 241 | n_prompt = self.sample_neg_prompt 242 | 243 | # preprocess 244 | if not self.t5_cpu: 245 | self.text_encoder.model.to(self.device) 246 | context = self.text_encoder([input_prompt], self.device) 247 | context_null = self.text_encoder([n_prompt], self.device) 248 | if offload_model: 249 | self.text_encoder.model.cpu() 250 | else: 251 | context = self.text_encoder([input_prompt], torch.device('cpu')) 252 | context_null = self.text_encoder([n_prompt], torch.device('cpu')) 253 | context = [t.to(self.device) for t in context] 254 | context_null = [t.to(self.device) for t in context_null] 255 | 256 | self.clip.model.to(self.device) 257 | clip_context = self.clip.visual( 258 | [first_frame[:, None, :, :], last_frame[:, None, :, :]]) 259 | if offload_model: 260 | self.clip.model.cpu() 261 | 262 | y = self.vae.encode([ 263 | torch.concat([ 264 | torch.nn.functional.interpolate( 265 | first_frame[None].cpu(), 266 | size=(first_frame_h, first_frame_w), 267 | mode='bicubic').transpose(0, 1), 268 | torch.zeros(3, F - 2, first_frame_h, first_frame_w), 269 | torch.nn.functional.interpolate( 270 | last_frame[None].cpu(), 271 | size=(first_frame_h, first_frame_w), 272 | mode='bicubic').transpose(0, 1), 273 | ], 274 | dim=1).to(self.device) 275 | ])[0] 276 | y = torch.concat([msk, y]) 277 | 278 | @contextmanager 279 | def noop_no_sync(): 280 | yield 281 | 282 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 283 | 284 | # evaluation mode 285 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 286 | 287 | if sample_solver == 'unipc': 288 | sample_scheduler = FlowUniPCMultistepScheduler( 289 | num_train_timesteps=self.num_train_timesteps, 290 | shift=1, 291 | use_dynamic_shifting=False) 292 | sample_scheduler.set_timesteps( 293 | sampling_steps, device=self.device, shift=shift) 294 | timesteps = sample_scheduler.timesteps 295 | elif sample_solver == 'dpm++': 296 | sample_scheduler = FlowDPMSolverMultistepScheduler( 297 | num_train_timesteps=self.num_train_timesteps, 298 | shift=1, 299 | use_dynamic_shifting=False) 300 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 301 | timesteps, _ = retrieve_timesteps( 302 | sample_scheduler, 303 | device=self.device, 304 | sigmas=sampling_sigmas) 305 | else: 306 | raise NotImplementedError("Unsupported solver.") 307 | 308 | # sample videos 309 | latent = noise 310 | 311 | arg_c = { 312 | 'context': [context[0]], 313 | 'clip_fea': clip_context, 314 | 'seq_len': max_seq_len, 315 | 'y': [y], 316 | } 317 | 318 | arg_null = { 319 | 'context': context_null, 320 | 'clip_fea': clip_context, 321 | 'seq_len': 
max_seq_len, 322 | 'y': [y], 323 | } 324 | 325 | if offload_model: 326 | torch.cuda.empty_cache() 327 | 328 | self.model.to(self.device) 329 | for _, t in enumerate(tqdm(timesteps)): 330 | latent_model_input = [latent.to(self.device)] 331 | timestep = [t] 332 | 333 | timestep = torch.stack(timestep).to(self.device) 334 | 335 | noise_pred_cond = self.model( 336 | latent_model_input, t=timestep, **arg_c)[0].to( 337 | torch.device('cpu') if offload_model else self.device) 338 | if offload_model: 339 | torch.cuda.empty_cache() 340 | noise_pred_uncond = self.model( 341 | latent_model_input, t=timestep, **arg_null)[0].to( 342 | torch.device('cpu') if offload_model else self.device) 343 | if offload_model: 344 | torch.cuda.empty_cache() 345 | noise_pred = noise_pred_uncond + guide_scale * ( 346 | noise_pred_cond - noise_pred_uncond) 347 | 348 | latent = latent.to( 349 | torch.device('cpu') if offload_model else self.device) 350 | 351 | temp_x0 = sample_scheduler.step( 352 | noise_pred.unsqueeze(0), 353 | t, 354 | latent.unsqueeze(0), 355 | return_dict=False, 356 | generator=seed_g)[0] 357 | latent = temp_x0.squeeze(0) 358 | 359 | x0 = [latent.to(self.device)] 360 | del latent_model_input, timestep 361 | 362 | if offload_model: 363 | self.model.cpu() 364 | torch.cuda.empty_cache() 365 | 366 | if self.rank == 0: 367 | videos = self.vae.decode(x0) 368 | 369 | del noise, latent 370 | del sample_scheduler 371 | if offload_model: 372 | gc.collect() 373 | torch.cuda.synchronize() 374 | if dist.is_initialized(): 375 | dist.barrier() 376 | 377 | return videos[0] if self.rank == 0 else None 378 | -------------------------------------------------------------------------------- /wan/image2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import numpy as np 13 | import torch 14 | import torch.cuda.amp as amp 15 | import torch.distributed as dist 16 | import torchvision.transforms.functional as TF 17 | from tqdm import tqdm 18 | 19 | from .distributed.fsdp import shard_model 20 | from .modules.clip import CLIPModel 21 | from .modules.model import WanModel 22 | from .modules.t5 import T5EncoderModel 23 | from .modules.vae import WanVAE 24 | from .utils.fm_solvers import ( 25 | FlowDPMSolverMultistepScheduler, 26 | get_sampling_sigmas, 27 | retrieve_timesteps, 28 | ) 29 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 30 | 31 | 32 | class WanI2V: 33 | 34 | def __init__( 35 | self, 36 | config, 37 | checkpoint_dir, 38 | device_id=0, 39 | rank=0, 40 | t5_fsdp=False, 41 | dit_fsdp=False, 42 | use_usp=False, 43 | t5_cpu=False, 44 | init_on_cpu=True, 45 | ): 46 | r""" 47 | Initializes the image-to-video generation model components. 
48 | 49 | Args: 50 | config (EasyDict): 51 | Object containing model parameters initialized from config.py 52 | checkpoint_dir (`str`): 53 | Path to directory containing model checkpoints 54 | device_id (`int`, *optional*, defaults to 0): 55 | Id of target GPU device 56 | rank (`int`, *optional*, defaults to 0): 57 | Process rank for distributed training 58 | t5_fsdp (`bool`, *optional*, defaults to False): 59 | Enable FSDP sharding for T5 model 60 | dit_fsdp (`bool`, *optional*, defaults to False): 61 | Enable FSDP sharding for DiT model 62 | use_usp (`bool`, *optional*, defaults to False): 63 | Enable distribution strategy of USP. 64 | t5_cpu (`bool`, *optional*, defaults to False): 65 | Whether to place T5 model on CPU. Only works without t5_fsdp. 66 | init_on_cpu (`bool`, *optional*, defaults to True): 67 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP. 68 | """ 69 | self.device = torch.device(f"cuda:{device_id}") 70 | self.config = config 71 | self.rank = rank 72 | self.use_usp = use_usp 73 | self.t5_cpu = t5_cpu 74 | 75 | self.num_train_timesteps = config.num_train_timesteps 76 | self.param_dtype = config.param_dtype 77 | 78 | shard_fn = partial(shard_model, device_id=device_id) 79 | self.text_encoder = T5EncoderModel( 80 | text_len=config.text_len, 81 | dtype=config.t5_dtype, 82 | device=torch.device('cpu'), 83 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), 84 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), 85 | shard_fn=shard_fn if t5_fsdp else None, 86 | ) 87 | 88 | self.vae_stride = config.vae_stride 89 | self.patch_size = config.patch_size 90 | self.vae = WanVAE( 91 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), 92 | device=self.device) 93 | 94 | self.clip = CLIPModel( 95 | dtype=config.clip_dtype, 96 | device=self.device, 97 | checkpoint_path=os.path.join(checkpoint_dir, 98 | config.clip_checkpoint), 99 | tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer)) 100 | 101 | logging.info(f"Creating WanModel from {checkpoint_dir}") 102 | self.model = WanModel.from_pretrained(checkpoint_dir) 103 | self.model.eval().requires_grad_(False) 104 | 105 | if t5_fsdp or dit_fsdp or use_usp: 106 | init_on_cpu = False 107 | 108 | if use_usp: 109 | from xfuser.core.distributed import get_sequence_parallel_world_size 110 | 111 | from .distributed.xdit_context_parallel import ( 112 | usp_attn_forward, 113 | usp_dit_forward, 114 | ) 115 | for block in self.model.blocks: 116 | block.self_attn.forward = types.MethodType( 117 | usp_attn_forward, block.self_attn) 118 | self.model.forward = types.MethodType(usp_dit_forward, self.model) 119 | self.sp_size = get_sequence_parallel_world_size() 120 | else: 121 | self.sp_size = 1 122 | 123 | if dist.is_initialized(): 124 | dist.barrier() 125 | if dit_fsdp: 126 | self.model = shard_fn(self.model) 127 | else: 128 | if not init_on_cpu: 129 | self.model.to(self.device) 130 | 131 | self.sample_neg_prompt = config.sample_neg_prompt 132 | 133 | def generate(self, 134 | input_prompt, 135 | img, 136 | max_area=720 * 1280, 137 | frame_num=81, 138 | shift=5.0, 139 | sample_solver='unipc', 140 | sampling_steps=40, 141 | guide_scale=5.0, 142 | n_prompt="", 143 | seed=-1, 144 | offload_model=True): 145 | r""" 146 | Generates video frames from input image and text prompt using diffusion process. 147 | 148 | Args: 149 | input_prompt (`str`): 150 | Text prompt for content generation. 151 | img (PIL.Image.Image): 152 | Input image tensor. 
Shape: [3, H, W]
153 | max_area (`int`, *optional*, defaults to 720*1280):
154 | Maximum pixel area for latent space calculation. Controls video resolution scaling
155 | frame_num (`int`, *optional*, defaults to 81):
156 | How many frames to sample from a video. The number should be 4n+1
157 | shift (`float`, *optional*, defaults to 5.0):
158 | Noise schedule shift parameter. Affects temporal dynamics
159 | [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
160 | sample_solver (`str`, *optional*, defaults to 'unipc'):
161 | Solver used to sample the video.
162 | sampling_steps (`int`, *optional*, defaults to 40):
163 | Number of diffusion sampling steps. Higher values improve quality but slow generation
164 | guide_scale (`float`, *optional*, defaults to 5.0):
165 | Classifier-free guidance scale. Controls prompt adherence vs. creativity
166 | n_prompt (`str`, *optional*, defaults to ""):
167 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
168 | seed (`int`, *optional*, defaults to -1):
169 | Random seed for noise generation. If -1, use random seed
170 | offload_model (`bool`, *optional*, defaults to True):
171 | If True, offloads models to CPU during generation to save VRAM
172 |
173 | Returns:
174 | torch.Tensor:
175 | Generated video frames tensor. Dimensions: (C, N, H, W) where:
176 | - C: Color channels (3 for RGB)
177 | - N: Number of frames (81)
178 | - H: Frame height (from max_area)
179 | - W: Frame width (from max_area)
180 | """
181 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
182 |
183 | F = frame_num
184 | h, w = img.shape[1:]
185 | aspect_ratio = h / w
186 | lat_h = round(
187 | np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
188 | self.patch_size[1] * self.patch_size[1])
189 | lat_w = round(
190 | np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
191 | self.patch_size[2] * self.patch_size[2])
192 | h = lat_h * self.vae_stride[1]
193 | w = lat_w * self.vae_stride[2]
194 |
195 | max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
196 | self.patch_size[1] * self.patch_size[2])
197 | max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
198 |
199 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
200 | seed_g = torch.Generator(device=self.device)
201 | seed_g.manual_seed(seed)
202 | noise = torch.randn(
203 | 16, (F - 1) // 4 + 1,
204 | lat_h,
205 | lat_w,
206 | dtype=torch.float32,
207 | generator=seed_g,
208 | device=self.device)
209 |
210 | msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
211 | msk[:, 1:] = 0
212 | msk = torch.concat([
213 | torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
214 | ],
215 | dim=1)
216 | msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
217 | msk = msk.transpose(1, 2)[0]
218 |
219 | if n_prompt == "":
220 | n_prompt = self.sample_neg_prompt
221 |
222 | # preprocess
223 | if not self.t5_cpu:
224 | self.text_encoder.model.to(self.device)
225 | context = self.text_encoder([input_prompt], self.device)
226 | context_null = self.text_encoder([n_prompt], self.device)
227 | if offload_model:
228 | self.text_encoder.model.cpu()
229 | else:
230 | context = self.text_encoder([input_prompt], torch.device('cpu'))
231 | context_null = self.text_encoder([n_prompt], torch.device('cpu'))
232 | context = [t.to(self.device) for t in context]
233 | context_null = [t.to(self.device) for t in context_null]
234 |
235 | self.clip.model.to(self.device)
236 | clip_context =
self.clip.visual([img[:, None, :, :]]) 237 | if offload_model: 238 | self.clip.model.cpu() 239 | 240 | y = self.vae.encode([ 241 | torch.concat([ 242 | torch.nn.functional.interpolate( 243 | img[None].cpu(), size=(h, w), mode='bicubic').transpose( 244 | 0, 1), 245 | torch.zeros(3, F - 1, h, w) 246 | ], 247 | dim=1).to(self.device) 248 | ])[0] 249 | y = torch.concat([msk, y]) 250 | 251 | @contextmanager 252 | def noop_no_sync(): 253 | yield 254 | 255 | no_sync = getattr(self.model, 'no_sync', noop_no_sync) 256 | 257 | # evaluation mode 258 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): 259 | 260 | if sample_solver == 'unipc': 261 | sample_scheduler = FlowUniPCMultistepScheduler( 262 | num_train_timesteps=self.num_train_timesteps, 263 | shift=1, 264 | use_dynamic_shifting=False) 265 | sample_scheduler.set_timesteps( 266 | sampling_steps, device=self.device, shift=shift) 267 | timesteps = sample_scheduler.timesteps 268 | elif sample_solver == 'dpm++': 269 | sample_scheduler = FlowDPMSolverMultistepScheduler( 270 | num_train_timesteps=self.num_train_timesteps, 271 | shift=1, 272 | use_dynamic_shifting=False) 273 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) 274 | timesteps, _ = retrieve_timesteps( 275 | sample_scheduler, 276 | device=self.device, 277 | sigmas=sampling_sigmas) 278 | else: 279 | raise NotImplementedError("Unsupported solver.") 280 | 281 | # sample videos 282 | latent = noise 283 | 284 | arg_c = { 285 | 'context': [context[0]], 286 | 'clip_fea': clip_context, 287 | 'seq_len': max_seq_len, 288 | 'y': [y], 289 | } 290 | 291 | arg_null = { 292 | 'context': context_null, 293 | 'clip_fea': clip_context, 294 | 'seq_len': max_seq_len, 295 | 'y': [y], 296 | } 297 | 298 | if offload_model: 299 | torch.cuda.empty_cache() 300 | 301 | self.model.to(self.device) 302 | for _, t in enumerate(tqdm(timesteps)): 303 | latent_model_input = [latent.to(self.device)] 304 | timestep = [t] 305 | 306 | timestep = torch.stack(timestep).to(self.device) 307 | 308 | noise_pred_cond = self.model( 309 | latent_model_input, t=timestep, **arg_c)[0].to( 310 | torch.device('cpu') if offload_model else self.device) 311 | if offload_model: 312 | torch.cuda.empty_cache() 313 | noise_pred_uncond = self.model( 314 | latent_model_input, t=timestep, **arg_null)[0].to( 315 | torch.device('cpu') if offload_model else self.device) 316 | if offload_model: 317 | torch.cuda.empty_cache() 318 | noise_pred = noise_pred_uncond + guide_scale * ( 319 | noise_pred_cond - noise_pred_uncond) 320 | 321 | latent = latent.to( 322 | torch.device('cpu') if offload_model else self.device) 323 | 324 | temp_x0 = sample_scheduler.step( 325 | noise_pred.unsqueeze(0), 326 | t, 327 | latent.unsqueeze(0), 328 | return_dict=False, 329 | generator=seed_g)[0] 330 | latent = temp_x0.squeeze(0) 331 | 332 | x0 = [latent.to(self.device)] 333 | del latent_model_input, timestep 334 | 335 | if offload_model: 336 | self.model.cpu() 337 | torch.cuda.empty_cache() 338 | 339 | if self.rank == 0: 340 | videos = self.vae.decode(x0) 341 | 342 | del noise, latent 343 | del sample_scheduler 344 | if offload_model: 345 | gc.collect() 346 | torch.cuda.synchronize() 347 | if dist.is_initialized(): 348 | dist.barrier() 349 | 350 | return videos[0] if self.rank == 0 else None 351 | -------------------------------------------------------------------------------- /wan/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import flash_attention 2 | from .model 
import WanModel 3 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model 4 | from .tokenizers import HuggingfaceTokenizer 5 | from .vace_model import VaceWanModel 6 | from .vae import WanVAE 7 | 8 | __all__ = [ 9 | 'WanVAE', 10 | 'WanModel', 11 | 'VaceWanModel', 12 | 'T5Model', 13 | 'T5Encoder', 14 | 'T5Decoder', 15 | 'T5EncoderModel', 16 | 'HuggingfaceTokenizer', 17 | 'flash_attention', 18 | ] 19 | -------------------------------------------------------------------------------- /wan/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | 4 | try: 5 | import flash_attn_interface 6 | FLASH_ATTN_3_AVAILABLE = True 7 | except ModuleNotFoundError: 8 | FLASH_ATTN_3_AVAILABLE = False 9 | 10 | try: 11 | import flash_attn 12 | FLASH_ATTN_2_AVAILABLE = True 13 | except ModuleNotFoundError: 14 | FLASH_ATTN_2_AVAILABLE = False 15 | 16 | import warnings 17 | 18 | __all__ = [ 19 | 'flash_attention', 20 | 'attention', 21 | ] 22 | 23 | 24 | def flash_attention( 25 | q, 26 | k, 27 | v, 28 | q_lens=None, 29 | k_lens=None, 30 | dropout_p=0., 31 | softmax_scale=None, 32 | q_scale=None, 33 | causal=False, 34 | window_size=(-1, -1), 35 | deterministic=False, 36 | dtype=torch.bfloat16, 37 | version=None, 38 | ): 39 | """ 40 | q: [B, Lq, Nq, C1]. 41 | k: [B, Lk, Nk, C1]. 42 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. 43 | q_lens: [B]. 44 | k_lens: [B]. 45 | dropout_p: float. Dropout probability. 46 | softmax_scale: float. The scaling of QK^T before applying softmax. 47 | causal: bool. Whether to apply causal attention mask. 48 | window_size: (left right). If not (-1, -1), apply sliding window local attention. 49 | deterministic: bool. If True, slightly slower and uses more memory. 50 | dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. 51 | """ 52 | half_dtypes = (torch.float16, torch.bfloat16) 53 | assert dtype in half_dtypes 54 | assert q.device.type == 'cuda' and q.size(-1) <= 256 55 | 56 | # params 57 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype 58 | 59 | def half(x): 60 | return x if x.dtype in half_dtypes else x.to(dtype) 61 | 62 | # preprocess query 63 | if q_lens is None: 64 | q = half(q.flatten(0, 1)) 65 | q_lens = torch.tensor( 66 | [lq] * b, dtype=torch.int32).to( 67 | device=q.device, non_blocking=True) 68 | else: 69 | q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) 70 | 71 | # preprocess key, value 72 | if k_lens is None: 73 | k = half(k.flatten(0, 1)) 74 | v = half(v.flatten(0, 1)) 75 | k_lens = torch.tensor( 76 | [lk] * b, dtype=torch.int32).to( 77 | device=k.device, non_blocking=True) 78 | else: 79 | k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) 80 | v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) 81 | 82 | q = q.to(v.dtype) 83 | k = k.to(v.dtype) 84 | 85 | if q_scale is not None: 86 | q = q * q_scale 87 | 88 | if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: 89 | warnings.warn( 90 | 'Flash attention 3 is not available, use flash attention 2 instead.' 91 | ) 92 | 93 | # apply attention 94 | if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: 95 | # Note: dropout_p, window_size are not supported in FA3 now. 
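# Editorial note (illustrative, not part of the original source): both varlen
# kernels below operate on *packed* tensors. After the flatten/concat above,
# q has shape (sum(q_lens), Nq, C1); with b = 2 and q_lens = [3, 5], the
# cumulative sum passed as cu_seqlens_q is [0, 3, 8], so sequence i occupies
# rows cu_seqlens_q[i]:cu_seqlens_q[i+1] of the packed tensor.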
96 | x = flash_attn_interface.flash_attn_varlen_func( 97 | q=q, 98 | k=k, 99 | v=v, 100 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 101 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 102 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 103 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 104 | seqused_q=None, 105 | seqused_k=None, 106 | max_seqlen_q=lq, 107 | max_seqlen_k=lk, 108 | softmax_scale=softmax_scale, 109 | causal=causal, 110 | deterministic=deterministic)[0].unflatten(0, (b, lq)) 111 | else: 112 | assert FLASH_ATTN_2_AVAILABLE 113 | x = flash_attn.flash_attn_varlen_func( 114 | q=q, 115 | k=k, 116 | v=v, 117 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 118 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 119 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 120 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 121 | max_seqlen_q=lq, 122 | max_seqlen_k=lk, 123 | dropout_p=dropout_p, 124 | softmax_scale=softmax_scale, 125 | causal=causal, 126 | window_size=window_size, 127 | deterministic=deterministic).unflatten(0, (b, lq)) 128 | 129 | # output 130 | return x.type(out_dtype) 131 | 132 | 133 | def attention( 134 | q, 135 | k, 136 | v, 137 | q_lens=None, 138 | k_lens=None, 139 | dropout_p=0., 140 | softmax_scale=None, 141 | q_scale=None, 142 | causal=False, 143 | window_size=(-1, -1), 144 | deterministic=False, 145 | dtype=torch.bfloat16, 146 | fa_version=None, 147 | ): 148 | if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: 149 | return flash_attention( 150 | q=q, 151 | k=k, 152 | v=v, 153 | q_lens=q_lens, 154 | k_lens=k_lens, 155 | dropout_p=dropout_p, 156 | softmax_scale=softmax_scale, 157 | q_scale=q_scale, 158 | causal=causal, 159 | window_size=window_size, 160 | deterministic=deterministic, 161 | dtype=dtype, 162 | version=fa_version, 163 | ) 164 | else: 165 | if q_lens is not None or k_lens is not None: 166 | warnings.warn( 167 | 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' 168 | ) 169 | attn_mask = None 170 | 171 | q = q.transpose(1, 2).to(dtype) 172 | k = k.transpose(1, 2).to(dtype) 173 | v = v.transpose(1, 2).to(dtype) 174 | 175 | out = torch.nn.functional.scaled_dot_product_attention( 176 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) 177 | 178 | out = out.transpose(1, 2).contiguous() 179 | return out 180 | -------------------------------------------------------------------------------- /wan/modules/t5.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.t5.modeling_t5 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
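# Editorial note: this file implements a self-contained umT5 encoder/decoder in
# plain PyTorch, plus the `T5EncoderModel` wrapper (defined at the bottom) that
# the pipelines use to embed prompts. A minimal usage sketch, assuming the
# checkpoint and tokenizer paths are placeholders supplied by the caller:
#
#     enc = T5EncoderModel(
#         text_len=512,
#         dtype=torch.bfloat16,
#         device=torch.device('cuda'),
#         checkpoint_path='<umt5-xxl encoder .pth>',
#         tokenizer_path='google/umt5-xxl')
#     context = enc(["a cat playing the piano"], torch.device('cuda'))
#     # -> list with one tensor of shape [seq_len_i, 4096], trimmed to the
#     #    non-padded token count of the prompt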
3 | import logging 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .tokenizers import HuggingfaceTokenizer 11 | 12 | __all__ = [ 13 | 'T5Model', 14 | 'T5Encoder', 15 | 'T5Decoder', 16 | 'T5EncoderModel', 17 | ] 18 | 19 | 20 | def fp16_clamp(x): 21 | if x.dtype == torch.float16 and torch.isinf(x).any(): 22 | clamp = torch.finfo(x.dtype).max - 1000 23 | x = torch.clamp(x, min=-clamp, max=clamp) 24 | return x 25 | 26 | 27 | def init_weights(m): 28 | if isinstance(m, T5LayerNorm): 29 | nn.init.ones_(m.weight) 30 | elif isinstance(m, T5Model): 31 | nn.init.normal_(m.token_embedding.weight, std=1.0) 32 | elif isinstance(m, T5FeedForward): 33 | nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) 34 | nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) 35 | nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) 36 | elif isinstance(m, T5Attention): 37 | nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) 38 | nn.init.normal_(m.k.weight, std=m.dim**-0.5) 39 | nn.init.normal_(m.v.weight, std=m.dim**-0.5) 40 | nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) 41 | elif isinstance(m, T5RelativeEmbedding): 42 | nn.init.normal_( 43 | m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) 44 | 45 | 46 | class GELU(nn.Module): 47 | 48 | def forward(self, x): 49 | return 0.5 * x * (1.0 + torch.tanh( 50 | math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 51 | 52 | 53 | class T5LayerNorm(nn.Module): 54 | 55 | def __init__(self, dim, eps=1e-6): 56 | super(T5LayerNorm, self).__init__() 57 | self.dim = dim 58 | self.eps = eps 59 | self.weight = nn.Parameter(torch.ones(dim)) 60 | 61 | def forward(self, x): 62 | x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + 63 | self.eps) 64 | if self.weight.dtype in [torch.float16, torch.bfloat16]: 65 | x = x.type_as(self.weight) 66 | return self.weight * x 67 | 68 | 69 | class T5Attention(nn.Module): 70 | 71 | def __init__(self, dim, dim_attn, num_heads, dropout=0.1): 72 | assert dim_attn % num_heads == 0 73 | super(T5Attention, self).__init__() 74 | self.dim = dim 75 | self.dim_attn = dim_attn 76 | self.num_heads = num_heads 77 | self.head_dim = dim_attn // num_heads 78 | 79 | # layers 80 | self.q = nn.Linear(dim, dim_attn, bias=False) 81 | self.k = nn.Linear(dim, dim_attn, bias=False) 82 | self.v = nn.Linear(dim, dim_attn, bias=False) 83 | self.o = nn.Linear(dim_attn, dim, bias=False) 84 | self.dropout = nn.Dropout(dropout) 85 | 86 | def forward(self, x, context=None, mask=None, pos_bias=None): 87 | """ 88 | x: [B, L1, C]. 89 | context: [B, L2, C] or None. 90 | mask: [B, L2] or [B, L1, L2] or None. 
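Returns a tensor of shape [B, L1, C].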
91 | """ 92 | # check inputs 93 | context = x if context is None else context 94 | b, n, c = x.size(0), self.num_heads, self.head_dim 95 | 96 | # compute query, key, value 97 | q = self.q(x).view(b, -1, n, c) 98 | k = self.k(context).view(b, -1, n, c) 99 | v = self.v(context).view(b, -1, n, c) 100 | 101 | # attention bias 102 | attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) 103 | if pos_bias is not None: 104 | attn_bias += pos_bias 105 | if mask is not None: 106 | assert mask.ndim in [2, 3] 107 | mask = mask.view(b, 1, 1, 108 | -1) if mask.ndim == 2 else mask.unsqueeze(1) 109 | attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) 110 | 111 | # compute attention (T5 does not use scaling) 112 | attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias 113 | attn = F.softmax(attn.float(), dim=-1).type_as(attn) 114 | x = torch.einsum('bnij,bjnc->binc', attn, v) 115 | 116 | # output 117 | x = x.reshape(b, -1, n * c) 118 | x = self.o(x) 119 | x = self.dropout(x) 120 | return x 121 | 122 | 123 | class T5FeedForward(nn.Module): 124 | 125 | def __init__(self, dim, dim_ffn, dropout=0.1): 126 | super(T5FeedForward, self).__init__() 127 | self.dim = dim 128 | self.dim_ffn = dim_ffn 129 | 130 | # layers 131 | self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) 132 | self.fc1 = nn.Linear(dim, dim_ffn, bias=False) 133 | self.fc2 = nn.Linear(dim_ffn, dim, bias=False) 134 | self.dropout = nn.Dropout(dropout) 135 | 136 | def forward(self, x): 137 | x = self.fc1(x) * self.gate(x) 138 | x = self.dropout(x) 139 | x = self.fc2(x) 140 | x = self.dropout(x) 141 | return x 142 | 143 | 144 | class T5SelfAttention(nn.Module): 145 | 146 | def __init__(self, 147 | dim, 148 | dim_attn, 149 | dim_ffn, 150 | num_heads, 151 | num_buckets, 152 | shared_pos=True, 153 | dropout=0.1): 154 | super(T5SelfAttention, self).__init__() 155 | self.dim = dim 156 | self.dim_attn = dim_attn 157 | self.dim_ffn = dim_ffn 158 | self.num_heads = num_heads 159 | self.num_buckets = num_buckets 160 | self.shared_pos = shared_pos 161 | 162 | # layers 163 | self.norm1 = T5LayerNorm(dim) 164 | self.attn = T5Attention(dim, dim_attn, num_heads, dropout) 165 | self.norm2 = T5LayerNorm(dim) 166 | self.ffn = T5FeedForward(dim, dim_ffn, dropout) 167 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding( 168 | num_buckets, num_heads, bidirectional=True) 169 | 170 | def forward(self, x, mask=None, pos_bias=None): 171 | e = pos_bias if self.shared_pos else self.pos_embedding( 172 | x.size(1), x.size(1)) 173 | x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) 174 | x = fp16_clamp(x + self.ffn(self.norm2(x))) 175 | return x 176 | 177 | 178 | class T5CrossAttention(nn.Module): 179 | 180 | def __init__(self, 181 | dim, 182 | dim_attn, 183 | dim_ffn, 184 | num_heads, 185 | num_buckets, 186 | shared_pos=True, 187 | dropout=0.1): 188 | super(T5CrossAttention, self).__init__() 189 | self.dim = dim 190 | self.dim_attn = dim_attn 191 | self.dim_ffn = dim_ffn 192 | self.num_heads = num_heads 193 | self.num_buckets = num_buckets 194 | self.shared_pos = shared_pos 195 | 196 | # layers 197 | self.norm1 = T5LayerNorm(dim) 198 | self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) 199 | self.norm2 = T5LayerNorm(dim) 200 | self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) 201 | self.norm3 = T5LayerNorm(dim) 202 | self.ffn = T5FeedForward(dim, dim_ffn, dropout) 203 | self.pos_embedding = None if shared_pos else T5RelativeEmbedding( 204 | num_buckets, num_heads, bidirectional=False) 205 
| 206 | def forward(self, 207 | x, 208 | mask=None, 209 | encoder_states=None, 210 | encoder_mask=None, 211 | pos_bias=None): 212 | e = pos_bias if self.shared_pos else self.pos_embedding( 213 | x.size(1), x.size(1)) 214 | x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) 215 | x = fp16_clamp(x + self.cross_attn( 216 | self.norm2(x), context=encoder_states, mask=encoder_mask)) 217 | x = fp16_clamp(x + self.ffn(self.norm3(x))) 218 | return x 219 | 220 | 221 | class T5RelativeEmbedding(nn.Module): 222 | 223 | def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): 224 | super(T5RelativeEmbedding, self).__init__() 225 | self.num_buckets = num_buckets 226 | self.num_heads = num_heads 227 | self.bidirectional = bidirectional 228 | self.max_dist = max_dist 229 | 230 | # layers 231 | self.embedding = nn.Embedding(num_buckets, num_heads) 232 | 233 | def forward(self, lq, lk): 234 | device = self.embedding.weight.device 235 | # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ 236 | # torch.arange(lq).unsqueeze(1).to(device) 237 | rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ 238 | torch.arange(lq, device=device).unsqueeze(1) 239 | rel_pos = self._relative_position_bucket(rel_pos) 240 | rel_pos_embeds = self.embedding(rel_pos) 241 | rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( 242 | 0) # [1, N, Lq, Lk] 243 | return rel_pos_embeds.contiguous() 244 | 245 | def _relative_position_bucket(self, rel_pos): 246 | # preprocess 247 | if self.bidirectional: 248 | num_buckets = self.num_buckets // 2 249 | rel_buckets = (rel_pos > 0).long() * num_buckets 250 | rel_pos = torch.abs(rel_pos) 251 | else: 252 | num_buckets = self.num_buckets 253 | rel_buckets = 0 254 | rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) 255 | 256 | # embeddings for small and large positions 257 | max_exact = num_buckets // 2 258 | rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / 259 | math.log(self.max_dist / max_exact) * 260 | (num_buckets - max_exact)).long() 261 | rel_pos_large = torch.min( 262 | rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) 263 | rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) 264 | return rel_buckets 265 | 266 | 267 | class T5Encoder(nn.Module): 268 | 269 | def __init__(self, 270 | vocab, 271 | dim, 272 | dim_attn, 273 | dim_ffn, 274 | num_heads, 275 | num_layers, 276 | num_buckets, 277 | shared_pos=True, 278 | dropout=0.1): 279 | super(T5Encoder, self).__init__() 280 | self.dim = dim 281 | self.dim_attn = dim_attn 282 | self.dim_ffn = dim_ffn 283 | self.num_heads = num_heads 284 | self.num_layers = num_layers 285 | self.num_buckets = num_buckets 286 | self.shared_pos = shared_pos 287 | 288 | # layers 289 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ 290 | else nn.Embedding(vocab, dim) 291 | self.pos_embedding = T5RelativeEmbedding( 292 | num_buckets, num_heads, bidirectional=True) if shared_pos else None 293 | self.dropout = nn.Dropout(dropout) 294 | self.blocks = nn.ModuleList([ 295 | T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, 296 | shared_pos, dropout) for _ in range(num_layers) 297 | ]) 298 | self.norm = T5LayerNorm(dim) 299 | 300 | # initialize weights 301 | self.apply(init_weights) 302 | 303 | def forward(self, ids, mask=None): 304 | x = self.token_embedding(ids) 305 | x = self.dropout(x) 306 | e = self.pos_embedding(x.size(1), 307 | x.size(1)) if self.shared_pos else None 308 | for block in self.blocks: 309 | x = block(x, 
mask, pos_bias=e) 310 | x = self.norm(x) 311 | x = self.dropout(x) 312 | return x 313 | 314 | 315 | class T5Decoder(nn.Module): 316 | 317 | def __init__(self, 318 | vocab, 319 | dim, 320 | dim_attn, 321 | dim_ffn, 322 | num_heads, 323 | num_layers, 324 | num_buckets, 325 | shared_pos=True, 326 | dropout=0.1): 327 | super(T5Decoder, self).__init__() 328 | self.dim = dim 329 | self.dim_attn = dim_attn 330 | self.dim_ffn = dim_ffn 331 | self.num_heads = num_heads 332 | self.num_layers = num_layers 333 | self.num_buckets = num_buckets 334 | self.shared_pos = shared_pos 335 | 336 | # layers 337 | self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ 338 | else nn.Embedding(vocab, dim) 339 | self.pos_embedding = T5RelativeEmbedding( 340 | num_buckets, num_heads, bidirectional=False) if shared_pos else None 341 | self.dropout = nn.Dropout(dropout) 342 | self.blocks = nn.ModuleList([ 343 | T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, 344 | shared_pos, dropout) for _ in range(num_layers) 345 | ]) 346 | self.norm = T5LayerNorm(dim) 347 | 348 | # initialize weights 349 | self.apply(init_weights) 350 | 351 | def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): 352 | b, s = ids.size() 353 | 354 | # causal mask 355 | if mask is None: 356 | mask = torch.tril(torch.ones(1, s, s).to(ids.device)) 357 | elif mask.ndim == 2: 358 | mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) 359 | 360 | # layers 361 | x = self.token_embedding(ids) 362 | x = self.dropout(x) 363 | e = self.pos_embedding(x.size(1), 364 | x.size(1)) if self.shared_pos else None 365 | for block in self.blocks: 366 | x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) 367 | x = self.norm(x) 368 | x = self.dropout(x) 369 | return x 370 | 371 | 372 | class T5Model(nn.Module): 373 | 374 | def __init__(self, 375 | vocab_size, 376 | dim, 377 | dim_attn, 378 | dim_ffn, 379 | num_heads, 380 | encoder_layers, 381 | decoder_layers, 382 | num_buckets, 383 | shared_pos=True, 384 | dropout=0.1): 385 | super(T5Model, self).__init__() 386 | self.vocab_size = vocab_size 387 | self.dim = dim 388 | self.dim_attn = dim_attn 389 | self.dim_ffn = dim_ffn 390 | self.num_heads = num_heads 391 | self.encoder_layers = encoder_layers 392 | self.decoder_layers = decoder_layers 393 | self.num_buckets = num_buckets 394 | 395 | # layers 396 | self.token_embedding = nn.Embedding(vocab_size, dim) 397 | self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, 398 | num_heads, encoder_layers, num_buckets, 399 | shared_pos, dropout) 400 | self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, 401 | num_heads, decoder_layers, num_buckets, 402 | shared_pos, dropout) 403 | self.head = nn.Linear(dim, vocab_size, bias=False) 404 | 405 | # initialize weights 406 | self.apply(init_weights) 407 | 408 | def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): 409 | x = self.encoder(encoder_ids, encoder_mask) 410 | x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) 411 | x = self.head(x) 412 | return x 413 | 414 | 415 | def _t5(name, 416 | encoder_only=False, 417 | decoder_only=False, 418 | return_tokenizer=False, 419 | tokenizer_kwargs={}, 420 | dtype=torch.float32, 421 | device='cpu', 422 | **kwargs): 423 | # sanity check 424 | assert not (encoder_only and decoder_only) 425 | 426 | # params 427 | if encoder_only: 428 | model_cls = T5Encoder 429 | kwargs['vocab'] = kwargs.pop('vocab_size') 430 | kwargs['num_layers'] = kwargs.pop('encoder_layers') 431 | _ = 
kwargs.pop('decoder_layers') 432 | elif decoder_only: 433 | model_cls = T5Decoder 434 | kwargs['vocab'] = kwargs.pop('vocab_size') 435 | kwargs['num_layers'] = kwargs.pop('decoder_layers') 436 | _ = kwargs.pop('encoder_layers') 437 | else: 438 | model_cls = T5Model 439 | 440 | # init model 441 | with torch.device(device): 442 | model = model_cls(**kwargs) 443 | 444 | # set device 445 | model = model.to(dtype=dtype, device=device) 446 | 447 | # init tokenizer 448 | if return_tokenizer: 449 | from .tokenizers import HuggingfaceTokenizer 450 | tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) 451 | return model, tokenizer 452 | else: 453 | return model 454 | 455 | 456 | def umt5_xxl(**kwargs): 457 | cfg = dict( 458 | vocab_size=256384, 459 | dim=4096, 460 | dim_attn=4096, 461 | dim_ffn=10240, 462 | num_heads=64, 463 | encoder_layers=24, 464 | decoder_layers=24, 465 | num_buckets=32, 466 | shared_pos=False, 467 | dropout=0.1) 468 | cfg.update(**kwargs) 469 | return _t5('umt5-xxl', **cfg) 470 | 471 | 472 | class T5EncoderModel: 473 | 474 | def __init__( 475 | self, 476 | text_len, 477 | dtype=torch.bfloat16, 478 | device=torch.cuda.current_device(), 479 | checkpoint_path=None, 480 | tokenizer_path=None, 481 | shard_fn=None, 482 | ): 483 | self.text_len = text_len 484 | self.dtype = dtype 485 | self.device = device 486 | self.checkpoint_path = checkpoint_path 487 | self.tokenizer_path = tokenizer_path 488 | 489 | # init model 490 | model = umt5_xxl( 491 | encoder_only=True, 492 | return_tokenizer=False, 493 | dtype=dtype, 494 | device=device).eval().requires_grad_(False) 495 | logging.info(f'loading {checkpoint_path}') 496 | model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) 497 | self.model = model 498 | if shard_fn is not None: 499 | self.model = shard_fn(self.model, sync_module_states=False) 500 | else: 501 | self.model.to(self.device) 502 | # init tokenizer 503 | self.tokenizer = HuggingfaceTokenizer( 504 | name=tokenizer_path, seq_len=text_len, clean='whitespace') 505 | 506 | def __call__(self, texts, device): 507 | ids, mask = self.tokenizer( 508 | texts, return_mask=True, add_special_tokens=True) 509 | ids = ids.to(device) 510 | mask = mask.to(device) 511 | seq_lens = mask.gt(0).sum(dim=1).long() 512 | context = self.model(ids, mask) 513 | return [u[:v] for u, v in zip(context, seq_lens)] 514 | -------------------------------------------------------------------------------- /wan/modules/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
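# Editorial note: a minimal usage sketch of the wrapper defined below (the model
# name is an illustrative placeholder; any HuggingFace tokenizer id works):
#
#     tok = HuggingfaceTokenizer(name='google/umt5-xxl', seq_len=512,
#                                clean='whitespace')
#     ids, mask = tok(["a cat playing the piano"], return_mask=True,
#                     add_special_tokens=True)
#     # ids, mask: LongTensors of shape [1, 512], padded/truncated to seq_len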
2 | import html 3 | import string 4 | 5 | import ftfy 6 | import regex as re 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ['HuggingfaceTokenizer'] 10 | 11 | 12 | def basic_clean(text): 13 | text = ftfy.fix_text(text) 14 | text = html.unescape(html.unescape(text)) 15 | return text.strip() 16 | 17 | 18 | def whitespace_clean(text): 19 | text = re.sub(r'\s+', ' ', text) 20 | text = text.strip() 21 | return text 22 | 23 | 24 | def canonicalize(text, keep_punctuation_exact_string=None): 25 | text = text.replace('_', ' ') 26 | if keep_punctuation_exact_string: 27 | text = keep_punctuation_exact_string.join( 28 | part.translate(str.maketrans('', '', string.punctuation)) 29 | for part in text.split(keep_punctuation_exact_string)) 30 | else: 31 | text = text.translate(str.maketrans('', '', string.punctuation)) 32 | text = text.lower() 33 | text = re.sub(r'\s+', ' ', text) 34 | return text.strip() 35 | 36 | 37 | class HuggingfaceTokenizer: 38 | 39 | def __init__(self, name, seq_len=None, clean=None, **kwargs): 40 | assert clean in (None, 'whitespace', 'lower', 'canonicalize') 41 | self.name = name 42 | self.seq_len = seq_len 43 | self.clean = clean 44 | 45 | # init tokenizer 46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) 47 | self.vocab_size = self.tokenizer.vocab_size 48 | 49 | def __call__(self, sequence, **kwargs): 50 | return_mask = kwargs.pop('return_mask', False) 51 | 52 | # arguments 53 | _kwargs = {'return_tensors': 'pt'} 54 | if self.seq_len is not None: 55 | _kwargs.update({ 56 | 'padding': 'max_length', 57 | 'truncation': True, 58 | 'max_length': self.seq_len 59 | }) 60 | _kwargs.update(**kwargs) 61 | 62 | # tokenization 63 | if isinstance(sequence, str): 64 | sequence = [sequence] 65 | if self.clean: 66 | sequence = [self._clean(u) for u in sequence] 67 | ids = self.tokenizer(sequence, **_kwargs) 68 | 69 | # output 70 | if return_mask: 71 | return ids.input_ids, ids.attention_mask 72 | else: 73 | return ids.input_ids 74 | 75 | def _clean(self, text): 76 | if self.clean == 'whitespace': 77 | text = whitespace_clean(basic_clean(text)) 78 | elif self.clean == 'lower': 79 | text = whitespace_clean(basic_clean(text)).lower() 80 | elif self.clean == 'canonicalize': 81 | text = canonicalize(basic_clean(text)) 82 | return text 83 | -------------------------------------------------------------------------------- /wan/modules/vace_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
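# Editorial note: in this module the `vace_blocks` run alongside the base
# transformer blocks and emit one "hint" per VACE layer; `BaseWanAttentionBlock`
# then adds `hints[block_id] * context_scale` to its output. With the defaults
# (vace_layers=None, num_layers=32) the VACE layers are the even-indexed blocks
# 0, 2, ..., 30, and `vace_layers_mapping` maps block index -> hint index,
# e.g. {0: 0, 2: 1, 4: 2, ...}.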
2 | import torch 3 | import torch.cuda.amp as amp 4 | import torch.nn as nn 5 | from diffusers.configuration_utils import register_to_config 6 | 7 | from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d 8 | 9 | 10 | class VaceWanAttentionBlock(WanAttentionBlock): 11 | 12 | def __init__(self, 13 | cross_attn_type, 14 | dim, 15 | ffn_dim, 16 | num_heads, 17 | window_size=(-1, -1), 18 | qk_norm=True, 19 | cross_attn_norm=False, 20 | eps=1e-6, 21 | block_id=0): 22 | super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, 23 | qk_norm, cross_attn_norm, eps) 24 | self.block_id = block_id 25 | if block_id == 0: 26 | self.before_proj = nn.Linear(self.dim, self.dim) 27 | nn.init.zeros_(self.before_proj.weight) 28 | nn.init.zeros_(self.before_proj.bias) 29 | self.after_proj = nn.Linear(self.dim, self.dim) 30 | nn.init.zeros_(self.after_proj.weight) 31 | nn.init.zeros_(self.after_proj.bias) 32 | 33 | def forward(self, c, x, **kwargs): 34 | if self.block_id == 0: 35 | c = self.before_proj(c) + x 36 | 37 | c = super().forward(c, **kwargs) 38 | c_skip = self.after_proj(c) 39 | return c, c_skip 40 | 41 | 42 | class BaseWanAttentionBlock(WanAttentionBlock): 43 | 44 | def __init__(self, 45 | cross_attn_type, 46 | dim, 47 | ffn_dim, 48 | num_heads, 49 | window_size=(-1, -1), 50 | qk_norm=True, 51 | cross_attn_norm=False, 52 | eps=1e-6, 53 | block_id=None): 54 | super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, 55 | qk_norm, cross_attn_norm, eps) 56 | self.block_id = block_id 57 | 58 | def forward(self, x, hints, context_scale=1.0, **kwargs): 59 | x = super().forward(x, **kwargs) 60 | if self.block_id is not None: 61 | x = x + hints[self.block_id] * context_scale 62 | return x 63 | 64 | 65 | class VaceWanModel(WanModel): 66 | 67 | @register_to_config 68 | def __init__(self, 69 | vace_layers=None, 70 | vace_in_dim=None, 71 | model_type='vace', 72 | patch_size=(1, 2, 2), 73 | text_len=512, 74 | in_dim=16, 75 | dim=2048, 76 | ffn_dim=8192, 77 | freq_dim=256, 78 | text_dim=4096, 79 | out_dim=16, 80 | num_heads=16, 81 | num_layers=32, 82 | window_size=(-1, -1), 83 | qk_norm=True, 84 | cross_attn_norm=True, 85 | eps=1e-6): 86 | super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, 87 | freq_dim, text_dim, out_dim, num_heads, num_layers, 88 | window_size, qk_norm, cross_attn_norm, eps) 89 | 90 | self.vace_layers = [i for i in range(0, self.num_layers, 2) 91 | ] if vace_layers is None else vace_layers 92 | self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim 93 | 94 | assert 0 in self.vace_layers 95 | self.vace_layers_mapping = { 96 | i: n for n, i in enumerate(self.vace_layers) 97 | } 98 | 99 | # blocks 100 | self.blocks = nn.ModuleList([ 101 | BaseWanAttentionBlock( 102 | 't2v_cross_attn', 103 | self.dim, 104 | self.ffn_dim, 105 | self.num_heads, 106 | self.window_size, 107 | self.qk_norm, 108 | self.cross_attn_norm, 109 | self.eps, 110 | block_id=self.vace_layers_mapping[i] 111 | if i in self.vace_layers else None) 112 | for i in range(self.num_layers) 113 | ]) 114 | 115 | # vace blocks 116 | self.vace_blocks = nn.ModuleList([ 117 | VaceWanAttentionBlock( 118 | 't2v_cross_attn', 119 | self.dim, 120 | self.ffn_dim, 121 | self.num_heads, 122 | self.window_size, 123 | self.qk_norm, 124 | self.cross_attn_norm, 125 | self.eps, 126 | block_id=i) for i in self.vace_layers 127 | ]) 128 | 129 | # vace patch embeddings 130 | self.vace_patch_embedding = nn.Conv3d( 131 | self.vace_in_dim, 132 | self.dim, 133 | 
kernel_size=self.patch_size, 134 | stride=self.patch_size) 135 | 136 | def forward_vace(self, x, vace_context, seq_len, kwargs): 137 | # embeddings 138 | c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] 139 | c = [u.flatten(2).transpose(1, 2) for u in c] 140 | c = torch.cat([ 141 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 142 | dim=1) for u in c 143 | ]) 144 | 145 | # arguments 146 | new_kwargs = dict(x=x) 147 | new_kwargs.update(kwargs) 148 | 149 | hints = [] 150 | for block in self.vace_blocks: 151 | c, c_skip = block(c, **new_kwargs) 152 | hints.append(c_skip) 153 | return hints 154 | 155 | def forward( 156 | self, 157 | x, 158 | t, 159 | vace_context, 160 | context, 161 | seq_len, 162 | vace_context_scale=1.0, 163 | clip_fea=None, 164 | y=None, 165 | ): 166 | r""" 167 | Forward pass through the diffusion model 168 | 169 | Args: 170 | x (List[Tensor]): 171 | List of input video tensors, each with shape [C_in, F, H, W] 172 | t (Tensor): 173 | Diffusion timesteps tensor of shape [B] 174 | context (List[Tensor]): 175 | List of text embeddings each with shape [L, C] 176 | seq_len (`int`): 177 | Maximum sequence length for positional encoding 178 | clip_fea (Tensor, *optional*): 179 | CLIP image features for image-to-video mode 180 | y (List[Tensor], *optional*): 181 | Conditional video inputs for image-to-video mode, same shape as x 182 | 183 | Returns: 184 | List[Tensor]: 185 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] 186 | """ 187 | # if self.model_type == 'i2v': 188 | # assert clip_fea is not None and y is not None 189 | # params 190 | device = self.patch_embedding.weight.device 191 | if self.freqs.device != device: 192 | self.freqs = self.freqs.to(device) 193 | 194 | # if y is not None: 195 | # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 196 | 197 | # embeddings 198 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 199 | grid_sizes = torch.stack( 200 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 201 | x = [u.flatten(2).transpose(1, 2) for u in x] 202 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 203 | assert seq_lens.max() <= seq_len 204 | x = torch.cat([ 205 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], 206 | dim=1) for u in x 207 | ]) 208 | 209 | # time embeddings 210 | with amp.autocast(dtype=torch.float32): 211 | e = self.time_embedding( 212 | sinusoidal_embedding_1d(self.freq_dim, t).float()) 213 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 214 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 215 | 216 | # context 217 | context_lens = None 218 | context = self.text_embedding( 219 | torch.stack([ 220 | torch.cat( 221 | [u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 222 | for u in context 223 | ])) 224 | 225 | # if clip_fea is not None: 226 | # context_clip = self.img_emb(clip_fea) # bs x 257 x dim 227 | # context = torch.concat([context_clip, context], dim=1) 228 | 229 | # arguments 230 | kwargs = dict( 231 | e=e0, 232 | seq_lens=seq_lens, 233 | grid_sizes=grid_sizes, 234 | freqs=self.freqs, 235 | context=context, 236 | context_lens=context_lens) 237 | 238 | hints = self.forward_vace(x, vace_context, seq_len, kwargs) 239 | kwargs['hints'] = hints 240 | kwargs['context_scale'] = vace_context_scale 241 | 242 | for block in self.blocks: 243 | x = block(x, **kwargs) 244 | 245 | # head 246 | x = self.head(x, e) 247 | 248 | # unpatchify 249 | x = self.unpatchify(x, grid_sizes) 250 | return 
[u.float() for u in x] 251 | -------------------------------------------------------------------------------- /wan/modules/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['XLMRoberta', 'xlm_roberta_large'] 8 | 9 | 10 | class SelfAttention(nn.Module): 11 | 12 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): 13 | assert dim % num_heads == 0 14 | super().__init__() 15 | self.dim = dim 16 | self.num_heads = num_heads 17 | self.head_dim = dim // num_heads 18 | self.eps = eps 19 | 20 | # layers 21 | self.q = nn.Linear(dim, dim) 22 | self.k = nn.Linear(dim, dim) 23 | self.v = nn.Linear(dim, dim) 24 | self.o = nn.Linear(dim, dim) 25 | self.dropout = nn.Dropout(dropout) 26 | 27 | def forward(self, x, mask): 28 | """ 29 | x: [B, L, C]. 30 | """ 31 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 32 | 33 | # compute query, key, value 34 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 35 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 36 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 37 | 38 | # compute attention 39 | p = self.dropout.p if self.training else 0.0 40 | x = F.scaled_dot_product_attention(q, k, v, mask, p) 41 | x = x.permute(0, 2, 1, 3).reshape(b, s, c) 42 | 43 | # output 44 | x = self.o(x) 45 | x = self.dropout(x) 46 | return x 47 | 48 | 49 | class AttentionBlock(nn.Module): 50 | 51 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): 52 | super().__init__() 53 | self.dim = dim 54 | self.num_heads = num_heads 55 | self.post_norm = post_norm 56 | self.eps = eps 57 | 58 | # layers 59 | self.attn = SelfAttention(dim, num_heads, dropout, eps) 60 | self.norm1 = nn.LayerNorm(dim, eps=eps) 61 | self.ffn = nn.Sequential( 62 | nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), 63 | nn.Dropout(dropout)) 64 | self.norm2 = nn.LayerNorm(dim, eps=eps) 65 | 66 | def forward(self, x, mask): 67 | if self.post_norm: 68 | x = self.norm1(x + self.attn(x, mask)) 69 | x = self.norm2(x + self.ffn(x)) 70 | else: 71 | x = x + self.attn(self.norm1(x), mask) 72 | x = x + self.ffn(self.norm2(x)) 73 | return x 74 | 75 | 76 | class XLMRoberta(nn.Module): 77 | """ 78 | XLMRobertaModel with no pooler and no LM head. 
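forward(ids) returns the final hidden states with shape [B, L, dim].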
79 | """ 80 | 81 | def __init__(self, 82 | vocab_size=250002, 83 | max_seq_len=514, 84 | type_size=1, 85 | pad_id=1, 86 | dim=1024, 87 | num_heads=16, 88 | num_layers=24, 89 | post_norm=True, 90 | dropout=0.1, 91 | eps=1e-5): 92 | super().__init__() 93 | self.vocab_size = vocab_size 94 | self.max_seq_len = max_seq_len 95 | self.type_size = type_size 96 | self.pad_id = pad_id 97 | self.dim = dim 98 | self.num_heads = num_heads 99 | self.num_layers = num_layers 100 | self.post_norm = post_norm 101 | self.eps = eps 102 | 103 | # embeddings 104 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) 105 | self.type_embedding = nn.Embedding(type_size, dim) 106 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) 107 | self.dropout = nn.Dropout(dropout) 108 | 109 | # blocks 110 | self.blocks = nn.ModuleList([ 111 | AttentionBlock(dim, num_heads, post_norm, dropout, eps) 112 | for _ in range(num_layers) 113 | ]) 114 | 115 | # norm layer 116 | self.norm = nn.LayerNorm(dim, eps=eps) 117 | 118 | def forward(self, ids): 119 | """ 120 | ids: [B, L] of torch.LongTensor. 121 | """ 122 | b, s = ids.shape 123 | mask = ids.ne(self.pad_id).long() 124 | 125 | # embeddings 126 | x = self.token_embedding(ids) + \ 127 | self.type_embedding(torch.zeros_like(ids)) + \ 128 | self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) 129 | if self.post_norm: 130 | x = self.norm(x) 131 | x = self.dropout(x) 132 | 133 | # blocks 134 | mask = torch.where( 135 | mask.view(b, 1, 1, s).gt(0), 0.0, 136 | torch.finfo(x.dtype).min) 137 | for block in self.blocks: 138 | x = block(x, mask) 139 | 140 | # output 141 | if not self.post_norm: 142 | x = self.norm(x) 143 | return x 144 | 145 | 146 | def xlm_roberta_large(pretrained=False, 147 | return_tokenizer=False, 148 | device='cpu', 149 | **kwargs): 150 | """ 151 | XLMRobertaLarge adapted from Huggingface. 152 | """ 153 | # params 154 | cfg = dict( 155 | vocab_size=250002, 156 | max_seq_len=514, 157 | type_size=1, 158 | pad_id=1, 159 | dim=1024, 160 | num_heads=16, 161 | num_layers=24, 162 | post_norm=True, 163 | dropout=0.1, 164 | eps=1e-5) 165 | cfg.update(**kwargs) 166 | 167 | # init a model on device 168 | with torch.device(device): 169 | model = XLMRoberta(**cfg) 170 | return model 171 | -------------------------------------------------------------------------------- /wan/text2video.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import gc 3 | import logging 4 | import math 5 | import os 6 | import random 7 | import sys 8 | import types 9 | from contextlib import contextmanager 10 | from functools import partial 11 | 12 | import torch 13 | import torch.cuda.amp as amp 14 | import torch.distributed as dist 15 | from tqdm import tqdm 16 | 17 | from .distributed.fsdp import shard_model 18 | from .modules.model import WanModel 19 | from .modules.t5 import T5EncoderModel 20 | from .modules.vae import WanVAE 21 | from .utils.fm_solvers import ( 22 | FlowDPMSolverMultistepScheduler, 23 | get_sampling_sigmas, 24 | retrieve_timesteps, 25 | ) 26 | from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler 27 | 28 | 29 | class WanT2V: 30 | 31 | def __init__( 32 | self, 33 | config, 34 | checkpoint_dir, 35 | device_id=0, 36 | rank=0, 37 | t5_fsdp=False, 38 | dit_fsdp=False, 39 | use_usp=False, 40 | t5_cpu=False, 41 | ): 42 | r""" 43 | Initializes the Wan text-to-video generation model components. 
44 |
45 | Args:
46 | config (EasyDict):
47 | Object containing model parameters initialized from config.py
48 | checkpoint_dir (`str`):
49 | Path to directory containing model checkpoints
50 | device_id (`int`, *optional*, defaults to 0):
51 | Id of target GPU device
52 | rank (`int`, *optional*, defaults to 0):
53 | Process rank for distributed training
54 | t5_fsdp (`bool`, *optional*, defaults to False):
55 | Enable FSDP sharding for T5 model
56 | dit_fsdp (`bool`, *optional*, defaults to False):
57 | Enable FSDP sharding for DiT model
58 | use_usp (`bool`, *optional*, defaults to False):
59 | Enable distribution strategy of USP.
60 | t5_cpu (`bool`, *optional*, defaults to False):
61 | Whether to place T5 model on CPU. Only works without t5_fsdp.
62 | """
63 | self.device = torch.device(f"cuda:{device_id}")
64 | self.config = config
65 | self.rank = rank
66 | self.t5_cpu = t5_cpu
67 |
68 | self.num_train_timesteps = config.num_train_timesteps
69 | self.param_dtype = config.param_dtype
70 |
71 | shard_fn = partial(shard_model, device_id=device_id)
72 | self.text_encoder = T5EncoderModel(
73 | text_len=config.text_len,
74 | dtype=config.t5_dtype,
75 | device=torch.device('cpu'),
76 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
77 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
78 | shard_fn=shard_fn if t5_fsdp else None)
79 |
80 | self.vae_stride = config.vae_stride
81 | self.patch_size = config.patch_size
82 | self.vae = WanVAE(
83 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
84 | device=self.device)
85 |
86 | logging.info(f"Creating WanModel from {checkpoint_dir}")
87 | self.model = WanModel.from_pretrained(checkpoint_dir)
88 | self.model.eval().requires_grad_(False)
89 |
90 | if use_usp:
91 | from xfuser.core.distributed import get_sequence_parallel_world_size
92 |
93 | from .distributed.xdit_context_parallel import (
94 | usp_attn_forward,
95 | usp_dit_forward,
96 | )
97 | for block in self.model.blocks:
98 | block.self_attn.forward = types.MethodType(
99 | usp_attn_forward, block.self_attn)
100 | self.model.forward = types.MethodType(usp_dit_forward, self.model)
101 | self.sp_size = get_sequence_parallel_world_size()
102 | else:
103 | self.sp_size = 1
104 |
105 | if dist.is_initialized():
106 | dist.barrier()
107 | if dit_fsdp:
108 | self.model = shard_fn(self.model)
109 | else:
110 | self.model.to(self.device)
111 |
112 | self.sample_neg_prompt = config.sample_neg_prompt
113 |
114 | def generate(self,
115 | input_prompt,
116 | size=(1280, 720),
117 | frame_num=81,
118 | shift=5.0,
119 | sample_solver='unipc',
120 | sampling_steps=50,
121 | guide_scale=5.0,
122 | n_prompt="",
123 | seed=-1,
124 | offload_model=True):
125 | r"""
126 | Generates video frames from text prompt using diffusion process.
127 |
128 | Args:
129 | input_prompt (`str`):
130 | Text prompt for content generation
131 | size (tuple[`int`], *optional*, defaults to (1280,720)):
132 | Controls video resolution, (width,height).
133 | frame_num (`int`, *optional*, defaults to 81):
134 | How many frames to sample from a video. The number should be 4n+1
135 | shift (`float`, *optional*, defaults to 5.0):
136 | Noise schedule shift parameter. Affects temporal dynamics
137 | sample_solver (`str`, *optional*, defaults to 'unipc'):
138 | Solver used to sample the video.
139 | sampling_steps (`int`, *optional*, defaults to 50):
140 | Number of diffusion sampling steps.
Higher values improve quality but slow generation
141 | guide_scale (`float`, *optional*, defaults to 5.0):
142 | Classifier-free guidance scale. Controls prompt adherence vs. creativity
143 | n_prompt (`str`, *optional*, defaults to ""):
144 | Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
145 | seed (`int`, *optional*, defaults to -1):
146 | Random seed for noise generation. If -1, use random seed.
147 | offload_model (`bool`, *optional*, defaults to True):
148 | If True, offloads models to CPU during generation to save VRAM
149 |
150 | Returns:
151 | torch.Tensor:
152 | Generated video frames tensor. Dimensions: (C, N, H, W) where:
153 | - C: Color channels (3 for RGB)
154 | - N: Number of frames (81)
155 | - H: Frame height (from size)
156 | - W: Frame width (from size)
157 | """
158 | # preprocess
159 | F = frame_num
160 | target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
161 | size[1] // self.vae_stride[1],
162 | size[0] // self.vae_stride[2])
163 |
164 | seq_len = math.ceil((target_shape[2] * target_shape[3]) /
165 | (self.patch_size[1] * self.patch_size[2]) *
166 | target_shape[1] / self.sp_size) * self.sp_size
167 |
168 | if n_prompt == "":
169 | n_prompt = self.sample_neg_prompt
170 | seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
171 | seed_g = torch.Generator(device=self.device)
172 | seed_g.manual_seed(seed)
173 |
174 | if not self.t5_cpu:
175 | self.text_encoder.model.to(self.device)
176 | context = self.text_encoder([input_prompt], self.device)
177 | context_null = self.text_encoder([n_prompt], self.device)
178 | if offload_model:
179 | self.text_encoder.model.cpu()
180 | else:
181 | context = self.text_encoder([input_prompt], torch.device('cpu'))
182 | context_null = self.text_encoder([n_prompt], torch.device('cpu'))
183 | context = [t.to(self.device) for t in context]
184 | context_null = [t.to(self.device) for t in context_null]
185 |
186 | noise = [
187 | torch.randn(
188 | target_shape[0],
189 | target_shape[1],
190 | target_shape[2],
191 | target_shape[3],
192 | dtype=torch.float32,
193 | device=self.device,
194 | generator=seed_g)
195 | ]
196 |
197 | @contextmanager
198 | def noop_no_sync():
199 | yield
200 |
201 | no_sync = getattr(self.model, 'no_sync', noop_no_sync)
202 |
203 | # evaluation mode
204 | with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():
205 |
206 | if sample_solver == 'unipc':
207 | sample_scheduler = FlowUniPCMultistepScheduler(
208 | num_train_timesteps=self.num_train_timesteps,
209 | shift=1,
210 | use_dynamic_shifting=False)
211 | sample_scheduler.set_timesteps(
212 | sampling_steps, device=self.device, shift=shift)
213 | timesteps = sample_scheduler.timesteps
214 | elif sample_solver == 'dpm++':
215 | sample_scheduler = FlowDPMSolverMultistepScheduler(
216 | num_train_timesteps=self.num_train_timesteps,
217 | shift=1,
218 | use_dynamic_shifting=False)
219 | sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
220 | timesteps, _ = retrieve_timesteps(
221 | sample_scheduler,
222 | device=self.device,
223 | sigmas=sampling_sigmas)
224 | else:
225 | raise NotImplementedError("Unsupported solver.")
226 |
227 | # sample videos
228 | latents = noise
229 |
230 | arg_c = {'context': context, 'seq_len': seq_len}
231 | arg_null = {'context': context_null, 'seq_len': seq_len}
232 |
233 | for _, t in enumerate(tqdm(timesteps)):
234 | latent_model_input = latents
235 | timestep = [t]
236 |
237 | timestep = torch.stack(timestep)
238 |
239 |
self.model.to(self.device) 240 | noise_pred_cond = self.model( 241 | latent_model_input, t=timestep, **arg_c)[0] 242 | noise_pred_uncond = self.model( 243 | latent_model_input, t=timestep, **arg_null)[0] 244 | 245 | noise_pred = noise_pred_uncond + guide_scale * ( 246 | noise_pred_cond - noise_pred_uncond) 247 | 248 | temp_x0 = sample_scheduler.step( 249 | noise_pred.unsqueeze(0), 250 | t, 251 | latents[0].unsqueeze(0), 252 | return_dict=False, 253 | generator=seed_g)[0] 254 | latents = [temp_x0.squeeze(0)] 255 | 256 | x0 = latents 257 | if offload_model: 258 | self.model.cpu() 259 | torch.cuda.empty_cache() 260 | if self.rank == 0: 261 | videos = self.vae.decode(x0) 262 | 263 | del noise, latents 264 | del sample_scheduler 265 | if offload_model: 266 | gc.collect() 267 | torch.cuda.synchronize() 268 | if dist.is_initialized(): 269 | dist.barrier() 270 | 271 | return videos[0] if self.rank == 0 else None 272 | -------------------------------------------------------------------------------- /wan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .fm_solvers import ( 2 | FlowDPMSolverMultistepScheduler, 3 | get_sampling_sigmas, 4 | retrieve_timesteps, 5 | ) 6 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler 7 | from .vace_processor import VaceVideoProcessor 8 | 9 | __all__ = [ 10 | 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 11 | 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler', 12 | 'VaceVideoProcessor' 13 | ] 14 | -------------------------------------------------------------------------------- /wan/utils/qwen_vl_utils.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/kq-chen/qwen-vl-utils 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
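# Editorial note: the helpers below resize images and videos to the pixel and
# frame budgets expected by Qwen-VL-style models. A minimal usage sketch (the
# file path is an illustrative placeholder):
#
#     h, w = smart_resize(1080, 1920)   # -> (1092, 1932): both divisible by 28,
#                                       #    pixel count kept inside the budget
#     img = fetch_image({"image": "file:///tmp/frame.png",
#                        "max_pixels": 1280 * 28 * 28})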
3 | from __future__ import annotations 4 | 5 | import base64 6 | import logging 7 | import math 8 | import os 9 | import sys 10 | import time 11 | import warnings 12 | from functools import lru_cache 13 | from io import BytesIO 14 | 15 | import requests 16 | import torch 17 | import torchvision 18 | from packaging import version 19 | from PIL import Image 20 | from torchvision import io, transforms 21 | from torchvision.transforms import InterpolationMode 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | IMAGE_FACTOR = 28 26 | MIN_PIXELS = 4 * 28 * 28 27 | MAX_PIXELS = 16384 * 28 * 28 28 | MAX_RATIO = 200 29 | 30 | VIDEO_MIN_PIXELS = 128 * 28 * 28 31 | VIDEO_MAX_PIXELS = 768 * 28 * 28 32 | VIDEO_TOTAL_PIXELS = 24576 * 28 * 28 33 | FRAME_FACTOR = 2 34 | FPS = 2.0 35 | FPS_MIN_FRAMES = 4 36 | FPS_MAX_FRAMES = 768 37 | 38 | 39 | def round_by_factor(number: int, factor: int) -> int: 40 | """Returns the closest integer to 'number' that is divisible by 'factor'.""" 41 | return round(number / factor) * factor 42 | 43 | 44 | def ceil_by_factor(number: int, factor: int) -> int: 45 | """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" 46 | return math.ceil(number / factor) * factor 47 | 48 | 49 | def floor_by_factor(number: int, factor: int) -> int: 50 | """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" 51 | return math.floor(number / factor) * factor 52 | 53 | 54 | def smart_resize(height: int, 55 | width: int, 56 | factor: int = IMAGE_FACTOR, 57 | min_pixels: int = MIN_PIXELS, 58 | max_pixels: int = MAX_PIXELS) -> tuple[int, int]: 59 | """ 60 | Rescales the image so that the following conditions are met: 61 | 62 | 1. Both dimensions (height and width) are divisible by 'factor'. 63 | 64 | 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 65 | 66 | 3. The aspect ratio of the image is maintained as closely as possible. 
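For example (illustrative), smart_resize(1080, 1920) returns (1092, 1932) with the default factor of 28.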
67 | """ 68 | if max(height, width) / min(height, width) > MAX_RATIO: 69 | raise ValueError( 70 | f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" 71 | ) 72 | h_bar = max(factor, round_by_factor(height, factor)) 73 | w_bar = max(factor, round_by_factor(width, factor)) 74 | if h_bar * w_bar > max_pixels: 75 | beta = math.sqrt((height * width) / max_pixels) 76 | h_bar = floor_by_factor(height / beta, factor) 77 | w_bar = floor_by_factor(width / beta, factor) 78 | elif h_bar * w_bar < min_pixels: 79 | beta = math.sqrt(min_pixels / (height * width)) 80 | h_bar = ceil_by_factor(height * beta, factor) 81 | w_bar = ceil_by_factor(width * beta, factor) 82 | return h_bar, w_bar 83 | 84 | 85 | def fetch_image(ele: dict[str, str | Image.Image], 86 | size_factor: int = IMAGE_FACTOR) -> Image.Image: 87 | if "image" in ele: 88 | image = ele["image"] 89 | else: 90 | image = ele["image_url"] 91 | image_obj = None 92 | if isinstance(image, Image.Image): 93 | image_obj = image 94 | elif image.startswith("http://") or image.startswith("https://"): 95 | image_obj = Image.open(requests.get(image, stream=True).raw) 96 | elif image.startswith("file://"): 97 | image_obj = Image.open(image[7:]) 98 | elif image.startswith("data:image"): 99 | if "base64," in image: 100 | _, base64_data = image.split("base64,", 1) 101 | data = base64.b64decode(base64_data) 102 | image_obj = Image.open(BytesIO(data)) 103 | else: 104 | image_obj = Image.open(image) 105 | if image_obj is None: 106 | raise ValueError( 107 | f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}" 108 | ) 109 | image = image_obj.convert("RGB") 110 | ## resize 111 | if "resized_height" in ele and "resized_width" in ele: 112 | resized_height, resized_width = smart_resize( 113 | ele["resized_height"], 114 | ele["resized_width"], 115 | factor=size_factor, 116 | ) 117 | else: 118 | width, height = image.size 119 | min_pixels = ele.get("min_pixels", MIN_PIXELS) 120 | max_pixels = ele.get("max_pixels", MAX_PIXELS) 121 | resized_height, resized_width = smart_resize( 122 | height, 123 | width, 124 | factor=size_factor, 125 | min_pixels=min_pixels, 126 | max_pixels=max_pixels, 127 | ) 128 | image = image.resize((resized_width, resized_height)) 129 | 130 | return image 131 | 132 | 133 | def smart_nframes( 134 | ele: dict, 135 | total_frames: int, 136 | video_fps: int | float, 137 | ) -> int: 138 | """calculate the number of frames for video used for model inputs. 139 | 140 | Args: 141 | ele (dict): a dict contains the configuration of video. 142 | support either `fps` or `nframes`: 143 | - nframes: the number of frames to extract for model inputs. 144 | - fps: the fps to extract frames for model inputs. 145 | - min_frames: the minimum number of frames of the video, only used when fps is provided. 146 | - max_frames: the maximum number of frames of the video, only used when fps is provided. 147 | total_frames (int): the original total number of frames of the video. 148 | video_fps (int | float): the original fps of the video. 149 | 150 | Raises: 151 | ValueError: nframes should in interval [FRAME_FACTOR, total_frames]. 152 | 153 | Returns: 154 | int: the number of frames for video used for model inputs. 
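For example (illustrative), a 300-frame clip at video_fps=30 with the default fps=2.0 yields nframes = 20.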
155 |     """
156 |     assert not ("fps" in ele and
157 |                 "nframes" in ele), "Provide either `fps` or `nframes`, not both"
158 |     if "nframes" in ele:
159 |         nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
160 |     else:
161 |         fps = ele.get("fps", FPS)
162 |         min_frames = ceil_by_factor(
163 |             ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
164 |         max_frames = floor_by_factor(
165 |             ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)),
166 |             FRAME_FACTOR)
167 |         nframes = total_frames / video_fps * fps
168 |         nframes = min(max(nframes, min_frames), max_frames)
169 |         nframes = round_by_factor(nframes, FRAME_FACTOR)
170 |     if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
171 |         raise ValueError(
172 |             f"nframes should be in the interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
173 |         )
174 |     return nframes
175 | 
176 | 
177 | def _read_video_torchvision(ele: dict,) -> torch.Tensor:
178 |     """Read a video using torchvision.io.read_video.
179 | 
180 |     Args:
181 |         ele (dict): a dict containing the video configuration.
182 |         Supported keys:
183 |             - video: the path of the video; supports "file://", "http://", "https://" and local paths.
184 |             - video_start: the start time of the video.
185 |             - video_end: the end time of the video.
186 |     Returns:
187 |         torch.Tensor: the video tensor with shape (T, C, H, W).
188 |     """
189 |     video_path = ele["video"]
190 |     if version.parse(torchvision.__version__) < version.parse("0.19.0"):
191 |         if "http://" in video_path or "https://" in video_path:
192 |             warnings.warn(
193 |                 "torchvision < 0.19.0 does not support http/https video paths, please upgrade to 0.19.0 or newer."
194 |             )
195 |         if "file://" in video_path:
196 |             video_path = video_path[7:]
197 |     st = time.time()
198 |     video, audio, info = io.read_video(
199 |         video_path,
200 |         start_pts=ele.get("video_start", 0.0),
201 |         end_pts=ele.get("video_end", None),
202 |         pts_unit="sec",
203 |         output_format="TCHW",
204 |     )
205 |     total_frames, video_fps = video.size(0), info["video_fps"]
206 |     logger.info(
207 |         f"torchvision: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s"
208 |     )
209 |     nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
210 |     idx = torch.linspace(0, total_frames - 1, nframes).round().long()
211 |     video = video[idx]
212 |     return video
213 | 
214 | 
215 | def is_decord_available() -> bool:
216 |     import importlib.util
217 | 
218 |     return importlib.util.find_spec("decord") is not None
219 | 
220 | 
221 | def _read_video_decord(ele: dict,) -> torch.Tensor:
222 |     """Read a video using decord.VideoReader.
223 | 
224 |     Args:
225 |         ele (dict): a dict containing the video configuration.
226 |         Supported keys:
227 |             - video: the path of the video; supports "file://", "http://", "https://" and local paths.
228 |             - video_start: the start time of the video.
229 |             - video_end: the end time of the video.
230 |     Returns:
231 |         torch.Tensor: the video tensor with shape (T, C, H, W).
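
    Example (illustrative): for ele = {"video": "/path/to/clip.mp4", "nframes": 16},
    16 evenly spaced frames are sampled and returned as a uint8 tensor of shape
    (16, C, H, W), assuming the clip contains at least 16 frames.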
232 | """ 233 | import decord 234 | video_path = ele["video"] 235 | st = time.time() 236 | vr = decord.VideoReader(video_path) 237 | # TODO: support start_pts and end_pts 238 | if 'video_start' in ele or 'video_end' in ele: 239 | raise NotImplementedError( 240 | "not support start_pts and end_pts in decord for now.") 241 | total_frames, video_fps = len(vr), vr.get_avg_fps() 242 | logger.info( 243 | f"decord: {video_path=}, {total_frames=}, {video_fps=}, time={time.time() - st:.3f}s" 244 | ) 245 | nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps) 246 | idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist() 247 | video = vr.get_batch(idx).asnumpy() 248 | video = torch.tensor(video).permute(0, 3, 1, 2) # Convert to TCHW format 249 | return video 250 | 251 | 252 | VIDEO_READER_BACKENDS = { 253 | "decord": _read_video_decord, 254 | "torchvision": _read_video_torchvision, 255 | } 256 | 257 | FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) 258 | 259 | 260 | @lru_cache(maxsize=1) 261 | def get_video_reader_backend() -> str: 262 | if FORCE_QWENVL_VIDEO_READER is not None: 263 | video_reader_backend = FORCE_QWENVL_VIDEO_READER 264 | elif is_decord_available(): 265 | video_reader_backend = "decord" 266 | else: 267 | video_reader_backend = "torchvision" 268 | print( 269 | f"qwen-vl-utils using {video_reader_backend} to read video.", 270 | file=sys.stderr) 271 | return video_reader_backend 272 | 273 | 274 | def fetch_video( 275 | ele: dict, 276 | image_factor: int = IMAGE_FACTOR) -> torch.Tensor | list[Image.Image]: 277 | if isinstance(ele["video"], str): 278 | video_reader_backend = get_video_reader_backend() 279 | video = VIDEO_READER_BACKENDS[video_reader_backend](ele) 280 | nframes, _, height, width = video.shape 281 | 282 | min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS) 283 | total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS) 284 | max_pixels = max( 285 | min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), 286 | int(min_pixels * 1.05)) 287 | max_pixels = ele.get("max_pixels", max_pixels) 288 | if "resized_height" in ele and "resized_width" in ele: 289 | resized_height, resized_width = smart_resize( 290 | ele["resized_height"], 291 | ele["resized_width"], 292 | factor=image_factor, 293 | ) 294 | else: 295 | resized_height, resized_width = smart_resize( 296 | height, 297 | width, 298 | factor=image_factor, 299 | min_pixels=min_pixels, 300 | max_pixels=max_pixels, 301 | ) 302 | video = transforms.functional.resize( 303 | video, 304 | [resized_height, resized_width], 305 | interpolation=InterpolationMode.BICUBIC, 306 | antialias=True, 307 | ).float() 308 | return video 309 | else: 310 | assert isinstance(ele["video"], (list, tuple)) 311 | process_info = ele.copy() 312 | process_info.pop("type", None) 313 | process_info.pop("video", None) 314 | images = [ 315 | fetch_image({ 316 | "image": video_element, 317 | **process_info 318 | }, 319 | size_factor=image_factor) 320 | for video_element in ele["video"] 321 | ] 322 | nframes = ceil_by_factor(len(images), FRAME_FACTOR) 323 | if len(images) < nframes: 324 | images.extend([images[-1]] * (nframes - len(images))) 325 | return images 326 | 327 | 328 | def extract_vision_info( 329 | conversations: list[dict] | list[list[dict]]) -> list[dict]: 330 | vision_infos = [] 331 | if isinstance(conversations[0], dict): 332 | conversations = [conversations] 333 | for conversation in conversations: 334 | for message in conversation: 335 | if isinstance(message["content"], 
list): 336 | for ele in message["content"]: 337 | if ("image" in ele or "image_url" in ele or 338 | "video" in ele or 339 | ele["type"] in ("image", "image_url", "video")): 340 | vision_infos.append(ele) 341 | return vision_infos 342 | 343 | 344 | def process_vision_info( 345 | conversations: list[dict] | list[list[dict]], 346 | ) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | 347 | None]: 348 | vision_infos = extract_vision_info(conversations) 349 | ## Read images or videos 350 | image_inputs = [] 351 | video_inputs = [] 352 | for vision_info in vision_infos: 353 | if "image" in vision_info or "image_url" in vision_info: 354 | image_inputs.append(fetch_image(vision_info)) 355 | elif "video" in vision_info: 356 | video_inputs.append(fetch_video(vision_info)) 357 | else: 358 | raise ValueError("image, image_url or video should in content.") 359 | if len(image_inputs) == 0: 360 | image_inputs = None 361 | if len(video_inputs) == 0: 362 | video_inputs = None 363 | return image_inputs, video_inputs 364 | -------------------------------------------------------------------------------- /wan/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import binascii 4 | import os 5 | import os.path as osp 6 | 7 | import imageio 8 | import torch 9 | import torchvision 10 | 11 | __all__ = ['cache_video', 'cache_image', 'str2bool'] 12 | 13 | 14 | def rand_name(length=8, suffix=''): 15 | name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') 16 | if suffix: 17 | if not suffix.startswith('.'): 18 | suffix = '.' + suffix 19 | name += suffix 20 | return name 21 | 22 | 23 | def cache_video(tensor, 24 | save_file=None, 25 | fps=30, 26 | suffix='.mp4', 27 | nrow=8, 28 | normalize=True, 29 | value_range=(-1, 1), 30 | retry=5): 31 | # cache file 32 | cache_file = osp.join('/tmp', rand_name( 33 | suffix=suffix)) if save_file is None else save_file 34 | 35 | # save to cache 36 | error = None 37 | for _ in range(retry): 38 | try: 39 | # preprocess 40 | tensor = tensor.clamp(min(value_range), max(value_range)) 41 | tensor = torch.stack([ 42 | torchvision.utils.make_grid( 43 | u, nrow=nrow, normalize=normalize, value_range=value_range) 44 | for u in tensor.unbind(2) 45 | ], 46 | dim=1).permute(1, 2, 3, 0) 47 | tensor = (tensor * 255).type(torch.uint8).cpu() 48 | 49 | # write video 50 | writer = imageio.get_writer( 51 | cache_file, fps=fps, codec='libx264', quality=8) 52 | for frame in tensor.numpy(): 53 | writer.append_data(frame) 54 | writer.close() 55 | return cache_file 56 | except Exception as e: 57 | error = e 58 | continue 59 | else: 60 | print(f'cache_video failed, error: {error}', flush=True) 61 | return None 62 | 63 | 64 | def cache_image(tensor, 65 | save_file, 66 | nrow=8, 67 | normalize=True, 68 | value_range=(-1, 1), 69 | retry=5): 70 | # cache file 71 | suffix = osp.splitext(save_file)[1] 72 | if suffix.lower() not in [ 73 | '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' 74 | ]: 75 | suffix = '.png' 76 | 77 | # save to cache 78 | error = None 79 | for _ in range(retry): 80 | try: 81 | tensor = tensor.clamp(min(value_range), max(value_range)) 82 | torchvision.utils.save_image( 83 | tensor, 84 | save_file, 85 | nrow=nrow, 86 | normalize=normalize, 87 | value_range=value_range) 88 | return save_file 89 | except Exception as e: 90 | error = e 91 | continue 92 | 93 | 94 | def str2bool(v): 95 | """ 96 | Convert a string to a boolean. 
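    For example, str2bool('yes') and str2bool('1') both return True, while
    str2bool('No') returns False; matching is case-insensitive.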
97 | 98 | Supported true values: 'yes', 'true', 't', 'y', '1' 99 | Supported false values: 'no', 'false', 'f', 'n', '0' 100 | 101 | Args: 102 | v (str): String to convert. 103 | 104 | Returns: 105 | bool: Converted boolean value. 106 | 107 | Raises: 108 | argparse.ArgumentTypeError: If the value cannot be converted to boolean. 109 | """ 110 | if isinstance(v, bool): 111 | return v 112 | v_lower = v.lower() 113 | if v_lower in ('yes', 'true', 't', 'y', '1'): 114 | return True 115 | elif v_lower in ('no', 'false', 'f', 'n', '0'): 116 | return False 117 | else: 118 | raise argparse.ArgumentTypeError('Boolean value expected (True/False)') 119 | -------------------------------------------------------------------------------- /wan/utils/vace_processor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision.transforms.functional as TF 6 | from PIL import Image 7 | 8 | 9 | class VaceImageProcessor(object): 10 | 11 | def __init__(self, downsample=None, seq_len=None): 12 | self.downsample = downsample 13 | self.seq_len = seq_len 14 | 15 | def _pillow_convert(self, image, cvt_type='RGB'): 16 | if image.mode != cvt_type: 17 | if image.mode == 'P': 18 | image = image.convert(f'{cvt_type}A') 19 | if image.mode == f'{cvt_type}A': 20 | bg = Image.new( 21 | cvt_type, 22 | size=(image.width, image.height), 23 | color=(255, 255, 255)) 24 | bg.paste(image, (0, 0), mask=image) 25 | image = bg 26 | else: 27 | image = image.convert(cvt_type) 28 | return image 29 | 30 | def _load_image(self, img_path): 31 | if img_path is None or img_path == '': 32 | return None 33 | img = Image.open(img_path) 34 | img = self._pillow_convert(img) 35 | return img 36 | 37 | def _resize_crop(self, img, oh, ow, normalize=True): 38 | """ 39 | Resize, center crop, convert to tensor, and normalize. 
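
        When normalize=True the result is a float tensor of shape (C, 1, oh, ow)
        with values in [-1, 1]; otherwise the resized and center-cropped
        PIL.Image is returned as-is.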
40 | """ 41 | # resize and crop 42 | iw, ih = img.size 43 | if iw != ow or ih != oh: 44 | # resize 45 | scale = max(ow / iw, oh / ih) 46 | img = img.resize((round(scale * iw), round(scale * ih)), 47 | resample=Image.Resampling.LANCZOS) 48 | assert img.width >= ow and img.height >= oh 49 | 50 | # center crop 51 | x1 = (img.width - ow) // 2 52 | y1 = (img.height - oh) // 2 53 | img = img.crop((x1, y1, x1 + ow, y1 + oh)) 54 | 55 | # normalize 56 | if normalize: 57 | img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1) 58 | return img 59 | 60 | def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs): 61 | return self._resize_crop(img, oh, ow, normalize) 62 | 63 | def load_image(self, data_key, **kwargs): 64 | return self.load_image_batch(data_key, **kwargs) 65 | 66 | def load_image_pair(self, data_key, data_key2, **kwargs): 67 | return self.load_image_batch(data_key, data_key2, **kwargs) 68 | 69 | def load_image_batch(self, 70 | *data_key_batch, 71 | normalize=True, 72 | seq_len=None, 73 | **kwargs): 74 | seq_len = self.seq_len if seq_len is None else seq_len 75 | imgs = [] 76 | for data_key in data_key_batch: 77 | img = self._load_image(data_key) 78 | imgs.append(img) 79 | w, h = imgs[0].size 80 | dh, dw = self.downsample[1:] 81 | 82 | # compute output size 83 | scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw)))) 84 | oh = int(h * scale) // dh * dh 85 | ow = int(w * scale) // dw * dw 86 | assert (oh // dh) * (ow // dw) <= seq_len 87 | imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs] 88 | return *imgs, (oh, ow) 89 | 90 | 91 | class VaceVideoProcessor(object): 92 | 93 | def __init__(self, downsample, min_area, max_area, min_fps, max_fps, 94 | zero_start, seq_len, keep_last, **kwargs): 95 | self.downsample = downsample 96 | self.min_area = min_area 97 | self.max_area = max_area 98 | self.min_fps = min_fps 99 | self.max_fps = max_fps 100 | self.zero_start = zero_start 101 | self.keep_last = keep_last 102 | self.seq_len = seq_len 103 | assert seq_len >= min_area / (self.downsample[1] * self.downsample[2]) 104 | 105 | def set_area(self, area): 106 | self.min_area = area 107 | self.max_area = area 108 | 109 | def set_seq_len(self, seq_len): 110 | self.seq_len = seq_len 111 | 112 | @staticmethod 113 | def resize_crop(video: torch.Tensor, oh: int, ow: int): 114 | """ 115 | Resize, center crop and normalize for decord loaded video (torch.Tensor type) 116 | 117 | Parameters: 118 | video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C) 119 | oh - target height (int) 120 | ow - target width (int) 121 | 122 | Returns: 123 | The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W) 124 | 125 | Raises: 126 | """ 127 | # permute ([t, h, w, c] -> [t, c, h, w]) 128 | video = video.permute(0, 3, 1, 2) 129 | 130 | # resize and crop 131 | ih, iw = video.shape[2:] 132 | if ih != oh or iw != ow: 133 | # resize 134 | scale = max(ow / iw, oh / ih) 135 | video = F.interpolate( 136 | video, 137 | size=(round(scale * ih), round(scale * iw)), 138 | mode='bicubic', 139 | antialias=True) 140 | assert video.size(3) >= ow and video.size(2) >= oh 141 | 142 | # center crop 143 | x1 = (video.size(3) - ow) // 2 144 | y1 = (video.size(2) - oh) // 2 145 | video = video[:, :, y1:y1 + oh, x1:x1 + ow] 146 | 147 | # permute ([t, c, h, w] -> [c, t, h, w]) and normalize 148 | video = video.transpose(0, 1).float().div_(127.5).sub_(1.) 
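        # at this point `video` is a float tensor in [-1, 1] with shape (C, T, oh, ow)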
149 | return video 150 | 151 | def _video_preprocess(self, video, oh, ow): 152 | return self.resize_crop(video, oh, ow) 153 | 154 | def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, 155 | rng): 156 | target_fps = min(fps, self.max_fps) 157 | duration = frame_timestamps[-1].mean() 158 | x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box 159 | h, w = y2 - y1, x2 - x1 160 | ratio = h / w 161 | df, dh, dw = self.downsample 162 | 163 | area_z = min(self.seq_len, self.max_area / (dh * dw), 164 | (h // dh) * (w // dw)) 165 | of = min((int(duration * target_fps) - 1) // df + 1, 166 | int(self.seq_len / area_z)) 167 | 168 | # deduce target shape of the [latent video] 169 | target_area_z = min(area_z, int(self.seq_len / of)) 170 | oh = round(np.sqrt(target_area_z * ratio)) 171 | ow = int(target_area_z / oh) 172 | of = (of - 1) * df + 1 173 | oh *= dh 174 | ow *= dw 175 | 176 | # sample frame ids 177 | target_duration = of / target_fps 178 | begin = 0. if self.zero_start else rng.uniform( 179 | 0, duration - target_duration) 180 | timestamps = np.linspace(begin, begin + target_duration, of) 181 | frame_ids = np.argmax( 182 | np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], 183 | timestamps[:, None] < frame_timestamps[None, :, 1]), 184 | axis=1).tolist() 185 | return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps 186 | 187 | def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, 188 | crop_box, rng): 189 | duration = frame_timestamps[-1].mean() 190 | x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box 191 | h, w = y2 - y1, x2 - x1 192 | ratio = h / w 193 | df, dh, dw = self.downsample 194 | 195 | area_z = min(self.seq_len, self.max_area / (dh * dw), 196 | (h // dh) * (w // dw)) 197 | of = min((len(frame_timestamps) - 1) // df + 1, 198 | int(self.seq_len / area_z)) 199 | 200 | # deduce target shape of the [latent video] 201 | target_area_z = min(area_z, int(self.seq_len / of)) 202 | oh = round(np.sqrt(target_area_z * ratio)) 203 | ow = int(target_area_z / oh) 204 | of = (of - 1) * df + 1 205 | oh *= dh 206 | ow *= dw 207 | 208 | # sample frame ids 209 | target_duration = duration 210 | target_fps = of / target_duration 211 | timestamps = np.linspace(0., target_duration, of) 212 | frame_ids = np.argmax( 213 | np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], 214 | timestamps[:, None] <= frame_timestamps[None, :, 1]), 215 | axis=1).tolist() 216 | # print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids)) 217 | return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps 218 | 219 | def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng): 220 | if self.keep_last: 221 | return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, 222 | w, crop_box, rng) 223 | else: 224 | return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, 225 | crop_box, rng) 226 | 227 | def load_video(self, data_key, crop_box=None, seed=2024, **kwargs): 228 | return self.load_video_batch( 229 | data_key, crop_box=crop_box, seed=seed, **kwargs) 230 | 231 | def load_video_pair(self, 232 | data_key, 233 | data_key2, 234 | crop_box=None, 235 | seed=2024, 236 | **kwargs): 237 | return self.load_video_batch( 238 | data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs) 239 | 240 | def load_video_batch(self, 241 | *data_key_batch, 242 | crop_box=None, 243 | seed=2024, 244 | **kwargs): 245 | rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000) 246 | # read video 
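        # decord is imported lazily; setting the torch bridge below makes
        # reader.get_batch() return torch tensors directly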
247 | import decord 248 | decord.bridge.set_bridge('torch') 249 | readers = [] 250 | for data_k in data_key_batch: 251 | reader = decord.VideoReader(data_k) 252 | readers.append(reader) 253 | 254 | fps = readers[0].get_avg_fps() 255 | length = min([len(r) for r in readers]) 256 | frame_timestamps = [ 257 | readers[0].get_frame_timestamp(i) for i in range(length) 258 | ] 259 | frame_timestamps = np.array(frame_timestamps, dtype=np.float32) 260 | h, w = readers[0].next().shape[:2] 261 | frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox( 262 | fps, frame_timestamps, h, w, crop_box, rng) 263 | 264 | # preprocess video 265 | videos = [ 266 | reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] 267 | for reader in readers 268 | ] 269 | videos = [self._video_preprocess(video, oh, ow) for video in videos] 270 | return *videos, frame_ids, (oh, ow), fps 271 | # return videos if len(videos) > 1 else videos[0] 272 | 273 | 274 | def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, 275 | device): 276 | for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)): 277 | if sub_src_video is None and sub_src_mask is None: 278 | src_video[i] = torch.zeros( 279 | (3, num_frames, image_size[0], image_size[1]), device=device) 280 | src_mask[i] = torch.ones( 281 | (1, num_frames, image_size[0], image_size[1]), device=device) 282 | for i, ref_images in enumerate(src_ref_images): 283 | if ref_images is not None: 284 | for j, ref_img in enumerate(ref_images): 285 | if ref_img is not None and ref_img.shape[-2:] != image_size: 286 | canvas_height, canvas_width = image_size 287 | ref_height, ref_width = ref_img.shape[-2:] 288 | white_canvas = torch.ones( 289 | (3, 1, canvas_height, canvas_width), 290 | device=device) # [-1, 1] 291 | scale = min(canvas_height / ref_height, 292 | canvas_width / ref_width) 293 | new_height = int(ref_height * scale) 294 | new_width = int(ref_width * scale) 295 | resized_image = F.interpolate( 296 | ref_img.squeeze(1).unsqueeze(0), 297 | size=(new_height, new_width), 298 | mode='bilinear', 299 | align_corners=False).squeeze(0).unsqueeze(1) 300 | top = (canvas_height - new_height) // 2 301 | left = (canvas_width - new_width) // 2 302 | white_canvas[:, :, top:top + new_height, 303 | left:left + new_width] = resized_image 304 | src_ref_images[i][j] = white_canvas 305 | return src_video, src_mask, src_ref_images 306 | --------------------------------------------------------------------------------
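
A minimal usage sketch for `str2bool` and `prepare_source` (illustrative only, not part of the repository; the flag name, frame count and image size below are arbitrary example values, and the snippet assumes the `wan` package is importable):

import argparse

import torch

from wan.utils.utils import str2bool
from wan.utils.vace_processor import prepare_source

# parse a boolean command-line flag via str2bool
parser = argparse.ArgumentParser()
parser.add_argument("--offload_model", type=str2bool, default=False)
args = parser.parse_args(["--offload_model", "yes"])
assert args.offload_model is True

# fill missing source video/mask slots with neutral tensors on the chosen device
src_video, src_mask, src_ref_images = prepare_source(
    src_video=[None],
    src_mask=[None],
    src_ref_images=[None],
    num_frames=81,
    image_size=(480, 832),
    device=torch.device("cpu"))
print(src_video[0].shape)  # torch.Size([3, 81, 480, 832])
print(src_mask[0].shape)   # torch.Size([1, 81, 480, 832])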