├── utils ├── __init__.py ├── README.md ├── print_utils.py ├── distributed_utils.py ├── tf_utils.py ├── ckpt_utils.py ├── cluster_utils.py ├── detector_utils.py ├── video_utils.py ├── wandb_utils.py └── gemini_utils.py ├── algorithms ├── __init__.py ├── common │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── mlp.py │ │ └── cnn.py │ ├── README.md │ ├── base_algo.py │ └── base_pytorch_algo.py ├── wan │ ├── distributed │ │ ├── __init__.py │ │ ├── fsdp.py │ │ └── xdit_context_parallel.py │ ├── __init__.py │ ├── utils │ │ ├── __init__.py │ │ └── utils.py │ ├── modules │ │ ├── __init__.py │ │ ├── tokenizers.py │ │ ├── xlm_roberta.py │ │ └── attention.py │ ├── configs │ │ ├── shared_config.py │ │ ├── wan_t2v_14B.py │ │ ├── wan_t2v_1_3B.py │ │ ├── wan_i2v_14B.py │ │ └── __init__.py │ └── wan_i2v.py ├── cogvideo │ ├── text_encoder │ │ ├── __init__.py │ │ └── text_encoder.py │ ├── __init__.py │ ├── t5.py │ ├── pos_embed.py │ └── cogvideox_vae.py └── README.md ├── datasets ├── __init__.py ├── deprecated │ └── video_1x_wm.py ├── dummy.py ├── README.md ├── pandas.py ├── droid.py ├── agibot_world.py ├── something_something.py ├── mixture.py ├── ego4d.py └── openx_base.py ├── configurations ├── cluster │ ├── fas_low.yaml │ ├── fas_single.yaml │ ├── README.md │ ├── phase3.yaml │ ├── mit_satori.yaml │ ├── mit_supercloud.yaml │ ├── mit_vision.yaml │ ├── fas_cpu.yaml │ ├── base_slurm.yaml │ ├── fas_high.yaml │ ├── fas_boyuan.yaml │ ├── tianyuan_high_single.yaml │ └── tianyuan_requeue.yaml ├── dataset │ ├── dummy.yaml │ ├── deprecated │ │ ├── video_1x_wm.yaml │ │ ├── bc_z.yaml │ │ ├── toto.yaml │ │ ├── jaco_play.yaml │ │ ├── roboturk.yaml │ │ ├── cmu_stretch.yaml │ │ ├── taco_play.yaml │ │ ├── viola.yaml │ │ ├── fractal.yaml │ │ ├── dobbe.yaml │ │ ├── utaustin_mutex.yaml │ │ ├── berkeley_autolab.yaml │ │ ├── fmb.yaml │ │ ├── berkeley_cable.yaml │ │ ├── berkeley_fanuc.yaml │ │ ├── austin_sailor.yaml │ │ ├── austin_sirius.yaml │ │ ├── dlr_edan.yaml │ │ ├── austin_buds.yaml │ │ ├── iamlab_cmu.yaml │ │ ├── nyu_franka.yaml │ │ ├── stanford_hydra.yaml │ │ └── ucsd_kitchen.yaml │ ├── something_something.yaml │ ├── ours_test.yaml │ ├── ego4d.yaml │ ├── agibot_world.yaml │ ├── pandas.yaml │ ├── language_table.yaml │ ├── bridge.yaml │ ├── droid.yaml │ ├── mixture_robot.yaml │ ├── mixture.yaml │ ├── openx_base.yaml │ ├── epic_kitchen.yaml │ └── video_base.yaml ├── algorithm │ ├── base_algo.yaml │ ├── base_pytorch_algo.yaml │ ├── wan_toy.yaml │ ├── wan_i2v.yaml │ └── wan_t2v.yaml ├── README.md ├── experiment │ ├── base_experiment.yaml │ ├── process_data.yaml │ ├── exp_video.yaml │ └── base_pytorch.yaml ├── sweep │ └── example_sweep.yaml └── config.yaml ├── experiments ├── README.md ├── __init__.py └── exp_video.py ├── requirements.txt ├── main.py └── README.md /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /algorithms/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/algorithms/common/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /algorithms/wan/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /algorithms/cogvideo/text_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_encoder import compute_prompt_embeddings 2 | -------------------------------------------------------------------------------- /algorithms/wan/__init__.py: -------------------------------------------------------------------------------- 1 | from .wan_i2v import WanImageToVideo 2 | from .wan_t2v import WanTextToVideo 3 | -------------------------------------------------------------------------------- /algorithms/common/README.md: -------------------------------------------------------------------------------- 1 | This folder contains models / algorithms that are general-purpose and reusable across many algorithms. 2 | -------------------------------------------------------------------------------- /algorithms/cogvideo/__init__.py: -------------------------------------------------------------------------------- 1 | from .cogvideox_i2v import CogVideoXImageToVideo 2 | from .cogvideox_vae import CogVideoXVAE 3 | -------------------------------------------------------------------------------- /configurations/cluster/fas_low.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - fas_high 3 | - _self_ 4 | 5 | params: 6 | partition: kempner_requeue -------------------------------------------------------------------------------- /utils/README.md: -------------------------------------------------------------------------------- 1 | # utils 2 | 3 | This is where you can put useful utilities such as visualization, 3D conversion, logging, etc. 4 | -------------------------------------------------------------------------------- /utils/print_utils.py: -------------------------------------------------------------------------------- 1 | from colorama import Fore 2 | 3 | 4 | def cyan(x: str) -> str: 5 | return f"{Fore.CYAN}{x}{Fore.RESET}" 6 | -------------------------------------------------------------------------------- /configurations/cluster/fas_single.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - fas_low 3 | - _self_ 4 | params: 5 | num_gpus: 1 6 | num_cpus: 16 7 | memory: 64G -------------------------------------------------------------------------------- /configurations/dataset/dummy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - _self_ 4 | 5 | load_video_latent: false 6 | load_prompt_embed: true 7 | image_to_video: true 8 | -------------------------------------------------------------------------------- /configurations/dataset/deprecated/video_1x_wm.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - _self_ 4 | 5 | data_root: data/1x_world_model 6 | metadata_path: metadata.csv 7 | -------------------------------------------------------------------------------- /configurations/dataset/something_something.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - 
video_base 3 | - _self_ 4 | 5 | data_root: data/something_something_v2 6 | metadata_path: merged_metadata.csv 7 | -------------------------------------------------------------------------------- /configurations/algorithm/base_algo.yaml: -------------------------------------------------------------------------------- 1 | # This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class 2 | 3 | debug: ${debug} # inherited from configurations/config.yaml 4 | -------------------------------------------------------------------------------- /configurations/algorithm/base_pytorch_algo.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_algo # inherits from configurations/algorithm/base_algo.yaml 3 | - _self_ 4 | 5 | lr: ${experiment.training.lr} 6 | -------------------------------------------------------------------------------- /configurations/cluster/README.md: -------------------------------------------------------------------------------- 1 | This folder contains config files for launching jobs on a Slurm cluster. Not all fields in these config files will transfer directly to your cluster, so adjust them to your own setup. 2 | 3 | -------------------------------------------------------------------------------- /configurations/dataset/ours_test.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - _self_ 4 | 5 | data_root: data/ours_test 6 | metadata_path: metadata.csv 7 | 8 | test_percentage: 1.0 9 | load_prompt_embed: false 10 | 11 | filtering: 12 | disable: true 13 | -------------------------------------------------------------------------------- /configurations/README.md: -------------------------------------------------------------------------------- 1 | # configurations 2 | 3 | We use [Hydra](https://hydra.cc/docs/intro/) to manage configurations. Change or add yaml files in this folder 4 | to change the default configurations. You can also override any default configuration value by 5 | passing command line arguments. 
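For example, a run that picks existing config files and overrides individual fields from the command line might look like the sketch below (these config names exist in this repo, but the run name and override values are only illustrative):

```bash
# select dataset/algorithm/experiment yaml files, then override individual fields
python -m main +name=example_run \
    dataset=bridge algorithm=wan_i2v experiment=exp_video \
    experiment.training.batch_size=2 algorithm.model.num_layers=2 debug=true
```

Adding `cluster=<a yaml under configurations/cluster>` (e.g. `cluster=fas_high`) submits the same command as a Slurm job instead of running it locally.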
6 | 7 | -------------------------------------------------------------------------------- /configurations/dataset/deprecated/bc_z.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/bc_z 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: bc_z 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/toto.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | 5 | data_root: data/openx_embodiment/toto 6 | metadata_path: metadata.csv 7 | 8 | download: 9 | openx_name: toto 10 | openx_fps: 60 # TODO: change this according to visualization 11 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/jaco_play.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/jaco_play 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: jaco_play 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/roboturk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/roboturk 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: roboturk 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["front_rgb"] -------------------------------------------------------------------------------- /configurations/dataset/ego4d.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - _self_ 4 | 5 | data_root: data/ego4d 6 | #metadata_path: no_recaption.csv 7 | metadata_path: merged_metadata.csv 8 | load_prompt_embed: true 9 | pad_mode: slowdown 10 | 11 | filtering: 12 | n_frames: [60, 100] 13 | pad_mode: slowdown -------------------------------------------------------------------------------- /configurations/dataset/deprecated/cmu_stretch.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/cmu_stretch 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: cmu_stretch 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/taco_play.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/taco_play 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: taco_play 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["rgb_static"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/viola.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | 5 | data_root: 
data/openx_embodiment/viola 6 | metadata_path: metadata.csv 7 | 8 | download: 9 | openx_name: viola 10 | openx_fps: 60 # TODO: change this according to visualization 11 | views: ["agentview_rgb"] 12 | -------------------------------------------------------------------------------- /configurations/algorithm/wan_toy.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - wan_i2v 3 | - _self_ 4 | 5 | text_encoder: 6 | ckpt_path: null 7 | 8 | vae: 9 | ckpt_path: null 10 | 11 | clip: 12 | ckpt_path: null 13 | 14 | model: 15 | ckpt_path: null 16 | dim: 128 17 | ffn_dim: 128 18 | num_heads: 4 19 | num_layers: 2 20 | -------------------------------------------------------------------------------- /configurations/cluster/phase3.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - fas_boyuan 3 | - _self_ 4 | 5 | params: 6 | partition: kempner_h100_priority2 # e.g. kempner_h100 7 | account: kempner_sham_lab # e.g. kempner_sham_lab 8 | env_name: wm 9 | num_gpus: 4 10 | num_cpus: 48 11 | memory: 512G 12 | time: "14-00:00:00" 13 | -------------------------------------------------------------------------------- /configurations/dataset/agibot_world.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - _self_ 4 | 5 | data_root: data/agibot_beta_temp 6 | metadata_path: merged_metadata.csv 7 | 8 | filtering: # filter raw videos based on these criteria 9 | n_frames: [60, 360] # number of frames range for the videos 10 | 11 | fps_override: 60 -------------------------------------------------------------------------------- /configurations/dataset/pandas.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - _self_ 4 | 5 | data_root: data/pandas 6 | #metadata_path: no_recaption.csv 7 | metadata_path: merged_metadata.csv 8 | load_prompt_embed: true 9 | trim_mode: random_cut 10 | 11 | test_percentage: 0.01 12 | 13 | filtering: 14 | n_frames: [60, 121] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/fractal.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/fractal20220817_data 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: fractal20220817_data 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/dobbe.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/dobbe 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: dobbe 9 | openx_version: "0.0.1" 10 | openx_fps: 60 # TODO: change this according to visualization 11 | views: ["wrist_image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/utaustin_mutex.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | 5 | data_root: data/openx_embodiment/utaustin_mutex 6 | metadata_path: metadata.csv 7 | 8 | download: 9 | openx_name: utaustin_mutex 10 | openx_fps: 60 # TODO: change this according to visualization 11 | 
views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/language_table.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | 5 | data_root: data/openx_embodiment/language_table 6 | metadata_path: merged_metadata.csv 7 | 8 | download: 9 | openx_name: language_table 10 | openx_fps: 6 11 | views: ["rgb"] 12 | 13 | filtering: 14 | n_frames: [7, 35] 15 | -------------------------------------------------------------------------------- /configurations/dataset/deprecated/berkeley_autolab.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/berkeley_autolab_ur5 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: berkeley_autolab_ur5 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | from typing import Callable 3 | import torch 4 | import torch.distributed as dist 5 | from lightning.pytorch.utilities.rank_zero import rank_zero_only 6 | 7 | is_rank_zero = wandb.run is not None 8 | is_rank_zero = rank_zero_only.rank == 0 9 | 10 | rank_zero_print = rank_zero_only(print) 11 | -------------------------------------------------------------------------------- /configurations/dataset/bridge.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/bridge 5 | metadata_path: merged_metadata.csv 6 | 7 | download: 8 | openx_name: bridge 9 | openx_fps: 10 # TODO: change this according to visualization 10 | views: ["image"] 11 | 12 | filtering: 13 | n_frames: [15, 60] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/fmb.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/fmb 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: fmb 9 | openx_version: "0.0.1" 10 | openx_fps: 60 # TODO: change this according to visualization 11 | views: ["image_side_1", "image_side_2"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/berkeley_cable.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/berkeley_cable_routing 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: berkeley_cable_routing 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["wrist45_image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/berkeley_fanuc.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/berkeley_fanuc_manipulation 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: berkeley_fanuc_manipulation 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] 
-------------------------------------------------------------------------------- /configurations/experiment/base_experiment.yaml: -------------------------------------------------------------------------------- 1 | debug: ${debug} # inherited from configurations/config.yaml 2 | num_nodes: 1 # number of nodes for slurm distributed launch. ignore this if you don't specify `cluster=xxx` 3 | tasks: [main] # tasks to run sequentially, such as [training, test], useful when your project has multiple stages and you want to run only a subset of them. 4 | -------------------------------------------------------------------------------- /algorithms/wan/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .fm_solvers import (FlowDPMSolverMultistepScheduler, get_sampling_sigmas, 2 | retrieve_timesteps) 3 | from .fm_solvers_unipc import FlowUniPCMultistepScheduler 4 | 5 | __all__ = [ 6 | 'HuggingfaceTokenizer', 'get_sampling_sigmas', 'retrieve_timesteps', 7 | 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler' 8 | ] 9 | -------------------------------------------------------------------------------- /configurations/dataset/deprecated/austin_sailor.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/austin_sailor_dataset_converted_externally_to_rlds 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: austin_sailor_dataset_converted_externally_to_rlds 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/austin_sirius.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/austin_sirius_dataset_converted_externally_to_rlds 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: austin_sirius_dataset_converted_externally_to_rlds 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/dlr_edan.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/dlr_edan_shared_control_converted_externally_to_rlds 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: dlr_edan_shared_control_converted_externally_to_rlds 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/austin_buds.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | 5 | data_root: data/openx_embodiment/austin_buds_dataset_converted_externally_to_rlds 6 | metadata_path: metadata.csv 7 | 8 | download: 9 | openx_name: austin_buds_dataset_converted_externally_to_rlds 10 | openx_fps: 60 # TODO: change this according to visualization 11 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/iamlab_cmu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - 
_self_ 4 | data_root: data/openx_embodiment/iamlab_cmu_pickup_insert_converted_externally_to_rlds 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: iamlab_cmu_pickup_insert_converted_externally_to_rlds 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/nyu_franka.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/nyu_franka_play_dataset_converted_externally_to_rlds 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: nyu_franka_play_dataset_converted_externally_to_rlds 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/stanford_hydra.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | data_root: data/openx_embodiment/stanford_hydra_dataset_converted_externally_to_rlds 5 | metadata_path: metadata.csv 6 | 7 | download: 8 | openx_name: stanford_hydra_dataset_converted_externally_to_rlds 9 | openx_fps: 60 # TODO: change this according to visualization 10 | views: ["image"] -------------------------------------------------------------------------------- /configurations/dataset/deprecated/ucsd_kitchen.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | 5 | data_root: data/openx_embodiment/ucsd_kitchen_dataset_converted_externally_to_rlds 6 | metadata_path: metadata.csv 7 | 8 | download: 9 | openx_name: ucsd_kitchen_dataset_converted_externally_to_rlds 10 | openx_fps: 60 # TODO: change this according to visualization 11 | views: ["image"] -------------------------------------------------------------------------------- /algorithms/wan/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import flash_attention 2 | from .model import WanModel 3 | from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model 4 | from .tokenizers import HuggingfaceTokenizer 5 | from .vae import WanVAE 6 | 7 | __all__ = [ 8 | 'WanVAE', 9 | 'WanModel', 10 | 'T5Model', 11 | 'T5Encoder', 12 | 'T5Decoder', 13 | 'T5EncoderModel', 14 | 'HuggingfaceTokenizer', 15 | 'flash_attention', 16 | ] 17 | -------------------------------------------------------------------------------- /configurations/cluster/mit_satori.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_slurm 3 | - _self_ 4 | launch_template: | 5 | #!/bin/bash 6 | 7 | #SBATCH -J {name} 8 | #SBATCH -o {log_dir}/out_%j.out 9 | #SBATCH -e {log_dir}/error_%j.err 10 | #SBATCH --mail-user={email} 11 | #SBATCH --mail-type=FAIL 12 | #SBATCH --gres=gpu:{num_gpus} 13 | #SBATCH --cpus-per-task={num_cpus} 14 | #SBATCH --mem={memory} 15 | #SBATCH --time={time} 16 | 17 | source ~/.bashrc 18 | module load cuda/11.2 19 | conda activate {env_name} 20 | cd {project_root} 21 | python -m main {python_args} -------------------------------------------------------------------------------- /algorithms/common/base_algo.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing 
import Any, Dict, List, Optional, Tuple, Union 3 | 4 | from omegaconf import DictConfig 5 | 6 | 7 | class BaseAlgo(ABC): 8 | """ 9 | A base class for generic algorithms. 10 | """ 11 | 12 | def __init__(self, cfg: DictConfig): 13 | super().__init__() 14 | self.cfg = cfg 15 | self.debug = self.cfg.debug 16 | 17 | @abstractmethod 18 | def run(self, *args: Any, **kwargs: Any) -> Any: 19 | """ 20 | Run the algorithm. 21 | """ 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /configurations/dataset/droid.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - openx_base 3 | - _self_ 4 | 5 | data_root: data/droid 6 | #metadata_path: no_recaption.csv 7 | metadata_path: merged_metadata.csv 8 | load_prompt_embed: true 9 | pad_mode: slowdown 10 | 11 | filtering: # filter raw videos based on these criteria 12 | n_frames: [0, 300] # number of frames range for the videos 13 | height: [0, 4096] # height range for the videos 14 | width: [0, 4096] # width range for the videos 15 | fps: [0, 100] # fps range for the videos 16 | 17 | download: 18 | override_fps: 75 # x5 speedup 19 | views: ["ext1", "ext2"] #, "wrist_image_left" -------------------------------------------------------------------------------- /configurations/algorithm/wan_i2v.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - wan_t2v 3 | - _self_ 4 | 5 | text_encoder: 6 | ckpt_path: data/ckpts/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth 7 | 8 | vae: 9 | ckpt_path: data/ckpts/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth 10 | 11 | clip: 12 | ckpt_path: data/ckpts/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth 13 | compile: true 14 | 15 | model: 16 | ckpt_path: data/ckpts/Wan2.1-I2V-14B-480P 17 | tuned_ckpt_path: data/ckpts/lvp_14B.ckpt #data/ckpts/phase3_40000.ckpt 18 | model_type: i2v 19 | dim: 5120 20 | ffn_dim: 13824 21 | num_heads: 40 22 | num_layers: 40 23 | -------------------------------------------------------------------------------- /configurations/cluster/mit_supercloud.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_slurm 3 | - _self_ 4 | is_compute_node_offline: True # many slurm systems only allow internet on the login node, not the compute nodes 5 | 6 | launch_template: | 7 | #!/bin/bash 8 | 9 | #SBATCH -J {name} 10 | #SBATCH -o {log_dir}/out_%j.out 11 | #SBATCH -e {log_dir}/error_%j.err 12 | #SBATCH --mail-user={email} 13 | #SBATCH --mail-type=FAIL 14 | #SBATCH --gres=gpu:volta:{num_gpus} 15 | #SBATCH --cpus-per-task={num_cpus} 16 | #SBATCH --mem={memory} 17 | #SBATCH --time={time} 18 | 19 | cd {project_root} 20 | module load anaconda/2023a 21 | 22 | python -m main {python_args} -------------------------------------------------------------------------------- /configurations/cluster/mit_vision.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_slurm 3 | - _self_ 4 | params: 5 | partition: null # e.g. vision-sitzmann 6 | qos: null # e.g. 
vision-sitzmann-main 7 | 8 | launch_template: | 9 | #!/bin/bash 10 | 11 | #SBATCH -J {name} 12 | #SBATCH -o {log_dir}/out_%j.out 13 | #SBATCH -e {log_dir}/error_%j.err 14 | #SBATCH --mail-user={email} 15 | #SBATCH --mail-type=FAIL 16 | #SBATCH --gres=gpu:{num_gpus} 17 | #SBATCH --cpus-per-task={num_cpus} 18 | #SBATCH --mem={memory} 19 | #SBATCH --time={time} 20 | #SBATCH --partition={partition} 21 | #SBATCH --qos={qos} 22 | source ~/.bashrc 23 | conda activate {env_name} 24 | cd {project_root} 25 | python -m main {python_args} 26 | -------------------------------------------------------------------------------- /configurations/sweep/example_sweep.yaml: -------------------------------------------------------------------------------- 1 | # wandb sweep configuration 2 | # this is independent of all other configurations under configurations/ folder as this is not used by the code 3 | 4 | program: main.py 5 | method: grid # hp search method 6 | 7 | metric: 8 | goal: maximize 9 | name: validation/accuracy 10 | 11 | parameters: 12 | # Sweep params 13 | algorithm.lr: 14 | values: [1e-3, 1e-4] 15 | experiment.training.batch_size: 16 | values: [32, 64] 17 | 18 | # Default params 19 | wandb.mode: 20 | value: online 21 | 22 | command: 23 | - ${env} 24 | - python 25 | - ${program} 26 | - ${args_no_hyphens} 27 | - +name=example_lr${algorithm.lr}_batch${experiment.training.batch_size} 28 | -------------------------------------------------------------------------------- /configurations/dataset/mixture_robot.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | # - droid@subset/droid 4 | # - bridge@subset/bridge 5 | - agibot_world@subset/agibot_world 6 | - _self_ 7 | 8 | data_root: null 9 | metadata_path: null 10 | load_prompt_embed: true 11 | load_video_latent: false 12 | fps: 16 13 | 14 | training: 15 | weight_type: relative # relative weight consider the original size of the dataset, absolute weight doesn't 16 | weight: 17 | # droid: 1.0 18 | # bridge: 2.0 19 | agibot_world: 2.0 20 | validation: 21 | weight_type: absolute # relative weight consider the original size of the dataset, absolute weight doesn't 22 | weight: 23 | # droid: 1.0 24 | # bridge: 1.0 25 | agibot_world: 1.0 -------------------------------------------------------------------------------- /algorithms/wan/configs/shared_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import torch 3 | from easydict import EasyDict 4 | 5 | #------------------------ Wan shared config ------------------------# 6 | wan_shared_cfg = EasyDict() 7 | 8 | # t5 9 | wan_shared_cfg.t5_model = 'umt5_xxl' 10 | wan_shared_cfg.t5_dtype = torch.bfloat16 11 | wan_shared_cfg.text_len = 512 12 | 13 | # transformer 14 | wan_shared_cfg.param_dtype = torch.bfloat16 15 | 16 | # inference 17 | wan_shared_cfg.num_train_timesteps = 1000 18 | wan_shared_cfg.sample_fps = 16 19 | wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' 20 | -------------------------------------------------------------------------------- /configurations/experiment/process_data.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_experiment 3 | - _self_ 4 | 5 | tasks: [visualize_dataset] # add the method names you want to run, e.g. 
[cache_prompt_embed] 6 | new_data_root: null # newly created csv and files will be saved here. null defaults to the output_dir of this run. 7 | 8 | visualize_dataset: 9 | n_samples: 32 10 | disable_filtering: false 11 | use_processed: true # if true, will use processed videos from __getitem__ instead of raw files 12 | 13 | cache_prompt_embed: 14 | batch_size: 32 15 | 16 | create_gemini_caption: 17 | n_workers: 12 18 | 19 | run_hand_pose_estimation: 20 | n_workers: 12 # not used 21 | 22 | run_human_detection: 23 | total_workers: 2 # not used 24 | job_id: 0 25 | # save_dir: "outputs/" 26 | 27 | benchmark_dataloader: 28 | batch_size: 4 29 | num_workers: 8 30 | -------------------------------------------------------------------------------- /algorithms/wan/configs/wan_t2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 14B ------------------------# 7 | 8 | t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') 9 | t2v_14B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_14B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_14B.patch_size = (1, 2, 2) 21 | t2v_14B.dim = 5120 22 | t2v_14B.ffn_dim = 13824 23 | t2v_14B.freq_dim = 256 24 | t2v_14B.num_heads = 40 25 | t2v_14B.num_layers = 40 26 | t2v_14B.window_size = (-1, -1) 27 | t2v_14B.qk_norm = True 28 | t2v_14B.cross_attn_norm = True 29 | t2v_14B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /algorithms/wan/configs/wan_t2v_1_3B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from easydict import EasyDict 3 | 4 | from .shared_config import wan_shared_cfg 5 | 6 | #------------------------ Wan T2V 1.3B ------------------------# 7 | 8 | t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') 9 | t2v_1_3B.update(wan_shared_cfg) 10 | 11 | # t5 12 | t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # vae 16 | t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' 17 | t2v_1_3B.vae_stride = (4, 8, 8) 18 | 19 | # transformer 20 | t2v_1_3B.patch_size = (1, 2, 2) 21 | t2v_1_3B.dim = 1536 22 | t2v_1_3B.ffn_dim = 8960 23 | t2v_1_3B.freq_dim = 256 24 | t2v_1_3B.num_heads = 12 25 | t2v_1_3B.num_layers = 30 26 | t2v_1_3B.window_size = (-1, -1) 27 | t2v_1_3B.qk_norm = True 28 | t2v_1_3B.cross_attn_norm = True 29 | t2v_1_3B.eps = 1e-6 30 | -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | # experiments 2 | 3 | The `experiments` folder contains the code of experiments. Each file in this folder represents a certain type of 4 | benchmark specific to a project. Such an experiment can be instantiated with a certain dataset and a certain algorithm. 5 | 6 | You should create a new `.py` file for your experiment, 7 | inheriting from any suitable base class in `experiments/exp_base.py`, 8 | and then register your new experiment in `experiments/__init__.py`. 
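As a minimal sketch of what that registration looks like (the `exp_my_task` yaml name, module and class below are hypothetical; the existing entry is the one actually in this repo):

```python
# experiments/__init__.py
from .exp_video import VideoPredictionExperiment
from .exp_my_task import MyTaskExperiment  # hypothetical new experiment file

# each key must match a yaml file under configurations/experiment (without the .yaml suffix)
exp_registry = dict(
    exp_video=VideoPredictionExperiment,
    exp_my_task=MyTaskExperiment,
)
```

You can then select it from the command line with `experiment=exp_my_task`.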
9 | 10 | You run an experiment by running `python -m main [options]` in the root directory of the 11 | project. You should not log any data in this folder; instead, store outputs under the `outputs` directory in the project 12 | root. 13 | 14 | This folder is only intended to contain formal experiments. For debug code and unit tests, put them under the `debug` folder. 15 | For scripts that are not meant to be experiments, please use the `scripts` folder. 16 | -------------------------------------------------------------------------------- /configurations/cluster/fas_cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_slurm 3 | - _self_ 4 | params: 5 | partition: shared # e.g. kempner_h100 6 | # account: kempner_sham_lab # e.g. kempner_sham_lab 7 | env_name: wm 8 | num_gpus: 4 9 | num_cpus: 48 10 | memory: 128G 11 | time: "3-00:00:00" 12 | 13 | launch_template: | 14 | #!/bin/bash 15 | #SBATCH -J {name} 16 | #SBATCH -o {log_dir}/out_%j.out 17 | #SBATCH -e {log_dir}/error_%j.err 18 | #SBATCH --mail-user={email} 19 | #SBATCH --mail-type=FAIL 20 | #SBATCH --partition={partition} 21 | #SBATCH --nodes=${experiment.num_nodes} 22 | #SBATCH --cpus-per-task=12 23 | #SBATCH --mem={memory} 24 | #SBATCH --time={time} 25 | 26 | # export NCCL_DEBUG=INFO 27 | # export PYTHONFAULTHANDLER=1 28 | 29 | cd {project_root} 30 | module load Mambaforge 31 | mamba deactivate 32 | mamba activate {env_name} 33 | srun python -m main {python_args} 34 | 35 | -------------------------------------------------------------------------------- /configurations/cluster/base_slurm.yaml: -------------------------------------------------------------------------------- 1 | is_compute_node_offline: False # many slurm systems only allow internet on the login node, not the compute nodes 2 | 3 | params: 4 | env_name: template # change this to the name of your conda environment 5 | num_gpus: 1 6 | num_cpus: 32 7 | memory: 32G 8 | time: "24:0:0" # Acceptable time formats include "minutes", "minutes:seconds", "hours:minutes:seconds", "days-hours", "days-hours:minutes" and "days-hours:minutes:seconds". 
9 | email: null 10 | 11 | launch_template: | 12 | #!/bin/bash 13 | 14 | #SBATCH -J {name} 15 | #SBATCH -o {log_dir}/out_%j.out 16 | #SBATCH -e {log_dir}/error_%j.err 17 | #SBATCH --mail-user={email} 18 | #SBATCH --mail-type=FAIL 19 | #SBATCH --gres=gpu:{num_gpus} 20 | #SBATCH --cpus-per-task={num_cpus} 21 | #SBATCH --mem={memory} 22 | #SBATCH --time={time} 23 | 24 | source ~/.bashrc 25 | conda activate {env_name} 26 | cd {project_root} 27 | python -m main {python_args} 28 | -------------------------------------------------------------------------------- /algorithms/common/models/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Type, Optional 2 | 3 | import torch 4 | from torch import nn as nn 5 | 6 | 7 | class SimpleMlp(nn.Module): 8 | """ 9 | A class for very simple multi layer perceptron 10 | """ 11 | 12 | def __init__( 13 | self, 14 | in_dim=2, 15 | out_dim=1, 16 | hidden_dim=64, 17 | n_layers=2, 18 | activation: Type[nn.Module] = nn.ReLU, 19 | output_activation: Optional[Type[nn.Module]] = None, 20 | ): 21 | super(SimpleMlp, self).__init__() 22 | layers = [nn.Linear(in_dim, hidden_dim), activation()] 23 | layers.extend( 24 | [nn.Linear(hidden_dim, hidden_dim), activation()] * (n_layers - 2) 25 | ) 26 | layers.append(nn.Linear(hidden_dim, out_dim)) 27 | if output_activation: 28 | layers.append(output_activation()) 29 | self.net = nn.Sequential(*layers) 30 | 31 | def forward(self, x): 32 | return self.net(x) 33 | -------------------------------------------------------------------------------- /utils/tf_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def recursive_cast_to_numpy(obj): 5 | if isinstance(obj, tf.Tensor): 6 | if obj.dtype == tf.string: 7 | # Decode the string tensor to Python strings 8 | return obj.numpy().tolist() if obj.ndim > 0 else obj.numpy().decode("utf-8") 9 | else: 10 | # Convert non-string tensors to numpy arrays 11 | return obj.numpy() 12 | elif isinstance(obj, dict): 13 | # Recursively handle dictionary values 14 | return {key: recursive_cast_to_numpy(value) for key, value in obj.items()} 15 | elif isinstance(obj, list): 16 | # Recursively handle list elements 17 | return [recursive_cast_to_numpy(item) for item in obj] 18 | elif isinstance(obj, tuple): 19 | # Recursively handle tuple elements 20 | return tuple(recursive_cast_to_numpy(item) for item in obj) 21 | else: 22 | # Return the object as-is if it's not a tf.Tensor 23 | return obj 24 | -------------------------------------------------------------------------------- /configurations/experiment/exp_video.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_pytorch 3 | - _self_ 4 | 5 | tasks: [training] 6 | 7 | training: 8 | lr: 1e-5 9 | precision: bf16-mixed 10 | batch_size: 1 11 | max_epochs: -1 12 | max_steps: 10000000 13 | checkpointing: 14 | every_n_train_steps: 2000 15 | every_n_epochs: null 16 | save_weights_only: true 17 | filename: "latest" 18 | optim: 19 | accumulate_grad_batches: 4 20 | gradient_clip_val: null 21 | data: 22 | num_workers: 5 # number of CPU threads for data preprocessing. 23 | 24 | validation: 25 | precision: bf16-mixed 26 | val_every_n_step: 1000 27 | val_every_n_epoch: null 28 | batch_size: 1 29 | limit_batch: 1 30 | data: 31 | num_workers: 1 # number of CPU threads for data preprocessing, for validation. 
32 | 33 | test: 34 | precision: bf16-mixed 35 | limit_batch: null 36 | batch_size: 1 37 | data: 38 | num_workers: 1 # number of CPU threads for data preprocessing, for test. 39 | 40 | find_unused_parameters: False -------------------------------------------------------------------------------- /utils/ckpt_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import wandb 3 | 4 | 5 | def is_run_id(run_id: str) -> bool: 6 | """Check if a string is a run ID.""" 7 | return len(run_id) == 8 and run_id.isalnum() 8 | 9 | 10 | def version_to_int(artifact) -> int: 11 | """Convert versions of the form vX to X. For example, v12 to 12.""" 12 | return int(artifact.version[1:]) 13 | 14 | 15 | def download_latest_checkpoint(run_path: str, download_dir: Path) -> Path: 16 | api = wandb.Api() 17 | run = api.run(run_path) 18 | 19 | # Find the latest saved model checkpoint. 20 | latest = None 21 | for artifact in run.logged_artifacts(): 22 | if artifact.type != "model" or artifact.state != "COMMITTED": 23 | continue 24 | 25 | if latest is None or version_to_int(artifact) > version_to_int(latest): 26 | latest = artifact 27 | 28 | # Download the checkpoint. 29 | download_dir.mkdir(exist_ok=True, parents=True) 30 | root = download_dir / run_path 31 | latest.download(root=root) 32 | return root / "model.ckpt" 33 | -------------------------------------------------------------------------------- /algorithms/README.md: -------------------------------------------------------------------------------- 1 | # algorithms 2 | 3 | The `algorithms` folder is designed to contain implementations of algorithms or models. 4 | Content in `algorithms` can be loosely grouped components (e.g. models) or an algorithm that already has all 5 | components chained together (e.g. a Lightning Module or an RL algo). 6 | You should create a folder named after your own algorithm or baseline in it. 7 | 8 | Two examples can be found in the `examples` subfolder. 9 | 10 | The `common` subfolder is designed to contain general-purpose classes that are useful for many projects, e.g. an MLP. 11 | 12 | You should not run any `.py` file from the `algorithms` folder. 13 | Instead, write unit tests / debug Python files in `debug` and launch scripts in `experiments`. 14 | 15 | You are discouraged from putting visualization utilities in `algorithms`, as those should go to `utils` in the project root. 16 | 17 | Each algorithm class takes a DictConfig `cfg` in its `__init__`, which allows you to pass in arguments via a configuration file in `configurations/algorithm` or a [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/). 18 | -------------------------------------------------------------------------------- /configurations/cluster/fas_high.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_slurm 3 | - _self_ 4 | params: 5 | partition: kempner_h100 # e.g. kempner_h100 6 | account: kempner_sham_lab # e.g. 
kempner_sham_lab 7 | env_name: ei_world_model 8 | num_gpus: 4 9 | num_cpus: 48 10 | memory: 256G 11 | time: "3-00:00:00" 12 | 13 | launch_template: | 14 | #!/bin/bash 15 | #SBATCH -J {name} 16 | #SBATCH -o {log_dir}/out_%j.out 17 | #SBATCH -e {log_dir}/error_%j.err 18 | #SBATCH --mail-user={email} 19 | #SBATCH --mail-type=FAIL 20 | #SBATCH --account={account} 21 | #SBATCH --partition={partition} 22 | #SBATCH --nodes=${experiment.num_nodes} 23 | #SBATCH --ntasks-per-node={num_gpus} 24 | #SBATCH --gres=gpu:nvidia_h100_80gb_hbm3:{num_gpus} 25 | #SBATCH --cpus-per-task=12 26 | #SBATCH --mem={memory} 27 | #SBATCH --time={time} 28 | 29 | # export NCCL_DEBUG=INFO 30 | # export PYTHONFAULTHANDLER=1 31 | 32 | cd {project_root} 33 | module load Mambaforge 34 | mamba deactivate 35 | mamba activate {env_name} 36 | module load cuda/12.4.1-fasrc01 37 | module load gcc/9.5.0-fasrc01 38 | srun python -m main {python_args} -------------------------------------------------------------------------------- /configurations/cluster/fas_boyuan.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_slurm 3 | - _self_ 4 | params: 5 | partition: kempner_h100_priority2 # e.g. kempner_h100 6 | account: kempner_sham_lab # e.g. kempner_sham_lab 7 | env_name: wm 8 | num_gpus: 4 9 | num_cpus: 48 10 | memory: 512G 11 | time: "3-00:00:00" 12 | 13 | launch_template: | 14 | #!/bin/bash 15 | #SBATCH -J {name} 16 | #SBATCH -o {log_dir}/out_%j.out 17 | #SBATCH -e {log_dir}/error_%j.err 18 | #SBATCH --mail-user={email} 19 | #SBATCH --mail-type=FAIL 20 | #SBATCH --account={account} 21 | #SBATCH --partition={partition} 22 | #SBATCH --nodes=${experiment.num_nodes} 23 | #SBATCH --ntasks-per-node={num_gpus} 24 | #SBATCH --gres=gpu:nvidia_h100_80gb_hbm3:{num_gpus} 25 | #SBATCH --cpus-per-task=12 26 | #SBATCH --mem={memory} 27 | #SBATCH --time={time} 28 | 29 | # export NCCL_DEBUG=INFO 30 | # export PYTHONFAULTHANDLER=1 31 | 32 | cd {project_root} 33 | module load Mambaforge 34 | mamba deactivate 35 | mamba activate {env_name} 36 | module load cuda/12.4.1-fasrc01 37 | module load gcc/9.5.0-fasrc01 38 | srun python -m main {python_args} 39 | -------------------------------------------------------------------------------- /algorithms/wan/configs/wan_i2v_14B.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import torch 3 | from easydict import EasyDict 4 | 5 | from .shared_config import wan_shared_cfg 6 | 7 | #------------------------ Wan I2V 14B ------------------------# 8 | 9 | i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') 10 | i2v_14B.update(wan_shared_cfg) 11 | 12 | i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' 13 | i2v_14B.t5_tokenizer = 'google/umt5-xxl' 14 | 15 | # clip 16 | i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' 17 | i2v_14B.clip_dtype = torch.float16 18 | i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' 19 | i2v_14B.clip_tokenizer = 'xlm-roberta-large' 20 | 21 | # vae 22 | i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' 23 | i2v_14B.vae_stride = (4, 8, 8) 24 | 25 | # transformer 26 | i2v_14B.patch_size = (1, 2, 2) 27 | i2v_14B.dim = 5120 28 | i2v_14B.ffn_dim = 13824 29 | i2v_14B.freq_dim = 256 30 | i2v_14B.num_heads = 40 31 | i2v_14B.num_layers = 40 32 | i2v_14B.window_size = (-1, -1) 33 | i2v_14B.qk_norm = True 34 | i2v_14B.cross_attn_norm = True 35 | i2v_14B.eps = 1e-6 36 | -------------------------------------------------------------------------------- /configurations/cluster/tianyuan_high_single.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_slurm 3 | - _self_ 4 | params: 5 | partition: kempner_requeue # e.g. kempner_h100 6 | account: kempner_sham_lab # e.g. kempner_sham_lab 7 | env_name: ei_world_model 8 | num_gpus: 1 9 | num_cpus: 12 10 | memory: 128G 11 | time: "3-00:00:00" 12 | 13 | launch_template: | 14 | #!/bin/bash 15 | #SBATCH -J {name} 16 | #SBATCH -o {log_dir}/out_%j.out 17 | #SBATCH -e {log_dir}/error_%j.err 18 | #SBATCH --mail-user={email} 19 | #SBATCH --mail-type=FAIL 20 | #SBATCH --account={account} 21 | #SBATCH --partition={partition} 22 | #SBATCH --nodes=${experiment.num_nodes} 23 | #SBATCH --ntasks-per-node={num_gpus} 24 | #SBATCH --gres=gpu:nvidia_h100_80gb_hbm3:{num_gpus} 25 | #SBATCH --cpus-per-task=12 26 | #SBATCH --mem={memory} 27 | #SBATCH --time={time} 28 | 29 | # export NCCL_DEBUG=INFO 30 | # export PYTHONFAULTHANDLER=1 31 | 32 | cd {project_root} 33 | module load Mambaforge 34 | mamba deactivate 35 | mamba activate {env_name} 36 | module load cuda/12.4.1-fasrc01 37 | module load cudnn 38 | module load gcc/9.5.0-fasrc01 39 | srun python -m main {python_args} -------------------------------------------------------------------------------- /configurations/config.yaml: -------------------------------------------------------------------------------- 1 | # configuration parsing starts here 2 | defaults: 3 | - experiment: exp_video # experiment yaml file name in configurations/experiments folder [fixme] 4 | - dataset: mixture # dataset yaml file name in configurations/dataset folder [fixme] 5 | - algorithm: wan_i2v # algorithm yaml file name in configurations/algorithm folder [fixme] 6 | - cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute 7 | - _self_ 8 | 9 | debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm 10 | 11 | wandb: 12 | entity: # wandb account name / organization name [fixme] 13 | project: large_video_planner # wandb project name; if not provided, defaults to root folder name [fixme] 14 | mode: offline # set wandb logging to online, offline or dryrun 15 | log_model: false # whether log ckpt and upload to wandb. 
"all" is recommended but may take a lot of space 16 | 17 | resume: null # wandb run id to resume logging and loading checkpoint from 18 | load: null # wanmdb run id containing checkpoint or a path to a checkpoint file 19 | 20 | -------------------------------------------------------------------------------- /configurations/cluster/tianyuan_requeue.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_slurm 3 | - _self_ 4 | 5 | params: 6 | partition: kempner_requeue # e.g. kempner_h100 7 | account: kempner_sham_lab # e.g. kempner_sham_lab 8 | env_name: ei_world_model 9 | num_gpus: 4 10 | num_cpus: 48 11 | memory: 256G 12 | time: "3-00:00:00" 13 | 14 | launch_template: | 15 | #!/bin/bash 16 | #SBATCH -J {name} 17 | #SBATCH -o {log_dir}/out_%j.out 18 | #SBATCH -e {log_dir}/error_%j.err 19 | #SBATCH --mail-user={email} 20 | #SBATCH --mail-type=FAIL 21 | #SBATCH --account={account} 22 | #SBATCH --partition={partition} 23 | #SBATCH --nodes=${experiment.num_nodes} 24 | #SBATCH --ntasks-per-node={num_gpus} 25 | #SBATCH --gres=gpu:nvidia_h100_80gb_hbm3:{num_gpus} 26 | #SBATCH --cpus-per-task=12 27 | #SBATCH --mem={memory} 28 | #SBATCH --time={time} 29 | 30 | # export NCCL_DEBUG=INFO 31 | # export PYTHONFAULTHANDLER=1 32 | 33 | cd {project_root} 34 | module load Mambaforge 35 | mamba deactivate 36 | mamba activate {env_name} 37 | module load cuda/12.4.1-fasrc01 38 | module load cudnn 39 | module load gcc/9.5.0-fasrc01 40 | srun python -m main {python_args} -------------------------------------------------------------------------------- /algorithms/wan/distributed/fsdp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | from functools import partial 3 | 4 | import torch 5 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 6 | from torch.distributed.fsdp import MixedPrecision, ShardingStrategy 7 | from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy 8 | 9 | 10 | def shard_model( 11 | model, 12 | device_id, 13 | param_dtype=torch.bfloat16, 14 | reduce_dtype=torch.float32, 15 | buffer_dtype=torch.float32, 16 | process_group=None, 17 | sharding_strategy=ShardingStrategy.FULL_SHARD, 18 | sync_module_states=True, 19 | ): 20 | model = FSDP( 21 | module=model, 22 | process_group=process_group, 23 | sharding_strategy=sharding_strategy, 24 | auto_wrap_policy=partial( 25 | lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks), 26 | mixed_precision=MixedPrecision( 27 | param_dtype=param_dtype, 28 | reduce_dtype=reduce_dtype, 29 | buffer_dtype=buffer_dtype), 30 | device_id=device_id, 31 | sync_module_states=sync_module_states) 32 | return model 33 | -------------------------------------------------------------------------------- /algorithms/wan/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
2 | import copy 3 | import os 4 | 5 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 6 | 7 | from .wan_i2v_14B import i2v_14B 8 | from .wan_t2v_1_3B import t2v_1_3B 9 | from .wan_t2v_14B import t2v_14B 10 | 11 | # the config of t2i_14B is the same as t2v_14B 12 | t2i_14B = copy.deepcopy(t2v_14B) 13 | t2i_14B.__name__ = 'Config: Wan T2I 14B' 14 | 15 | WAN_CONFIGS = { 16 | 't2v-14B': t2v_14B, 17 | 't2v-1.3B': t2v_1_3B, 18 | 'i2v-14B': i2v_14B, 19 | 't2i-14B': t2i_14B, 20 | } 21 | 22 | SIZE_CONFIGS = { 23 | '720*1280': (720, 1280), 24 | '1280*720': (1280, 720), 25 | '480*832': (480, 832), 26 | '832*480': (832, 480), 27 | '1024*1024': (1024, 1024), 28 | } 29 | 30 | MAX_AREA_CONFIGS = { 31 | '720*1280': 720 * 1280, 32 | '1280*720': 1280 * 720, 33 | '480*832': 480 * 832, 34 | '832*480': 832 * 480, 35 | } 36 | 37 | SUPPORTED_SIZES = { 38 | 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 39 | 't2v-1.3B': ('480*832', '832*480'), 40 | 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 41 | 't2i-14B': tuple(SIZE_CONFIGS.keys()), 42 | } 43 | -------------------------------------------------------------------------------- /configurations/dataset/mixture.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - pandas@subset/pandas 4 | - epic_kitchen@subset/epic_kitchen 5 | - ego4d@subset/ego4d 6 | - droid@subset/droid 7 | - something_something@subset/something_something 8 | - bridge@subset/bridge 9 | - agibot_world@subset/agibot_world 10 | - language_table@subset/language_table 11 | - _self_ 12 | 13 | data_root: null 14 | metadata_path: null 15 | load_prompt_embed: true 16 | load_video_latent: false 17 | fps: 16 18 | 19 | training: 20 | weight_type: relative # relative weight consider the original size of the dataset, absolute weight doesn't 21 | weight: 22 | pandas: 0.5 23 | epic_kitchen: 2.0 24 | ego4d: 1.5 25 | droid: 1.0 26 | something_something: 0.5 27 | bridge: 1.0 28 | agibot_world: 1.0 # 2.5 for phase 3 and 1.0 for phase 3.5 29 | language_table: 0.05 30 | 31 | validation: 32 | weight_type: absolute # relative weight consider the original size of the dataset, absolute weight doesn't 33 | weight: 34 | pandas: 1.0 35 | epic_kitchen: 1.0 36 | ego4d: 1.0 37 | droid: 1.0 38 | something_something: 0.25 39 | bridge: 1.0 40 | agibot_world: 1.0 41 | language_table: 0.25 42 | -------------------------------------------------------------------------------- /configurations/dataset/openx_base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - _self_ 4 | # data_root: /n/holylfs06/LABS/sham_lab/Lab/eiwm_data/openx/ 5 | # metadata_path: /n/holylfs06/LABS/sham_lab/Lab/eiwm_data/openx/robot_dataset_language_table.jsonl 6 | 7 | data_root: ??? # e.g. data/openx_embodiment/bridge 8 | metadata_path: ??? # e.g. bridge.csv 9 | 10 | download: 11 | openx_name: ??? # as defined in the offical colab notebook. The name for the path gs://gresearch/robotics/{openx_name} 12 | openx_version: "0.1.0" # version number from open-x itself. only need to change for language_table and robo_net 13 | openx_fps: ??? # open-x doesn't provide fps in dataset but in the associate google sheet, so we manually define it here. See https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/ 14 | views: ??? # e.g. a list of names of the views to put into the final metadata e.g. 
["wrist_cam", "top_view"] 15 | 16 | filtering: 17 | disable: false 18 | 19 | augmentation: 20 | random_flip: null # probability of random flip, null means no random flip 21 | ratio: [1.0, 1.0] # random scaling of the aspect ratio, see torchvision.transforms.v2.RandomResizedCrop 22 | scale: [1.0, 1.0] # random crop the video, see torchvision.transforms.v2.RandomResizedCrop 23 | -------------------------------------------------------------------------------- /configurations/dataset/epic_kitchen.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - video_base 3 | - _self_ 4 | 5 | data_root: data/epic_kitchens 6 | #metadata_path: no_recaption.csv 7 | metadata_path: merged_metadata.csv 8 | load_prompt_embed: true 9 | pad_mode: slowdown 10 | 11 | filtering: 12 | n_frames: [15, 216] 13 | 14 | download: 15 | annotation_url: 16 | training: https://raw.githubusercontent.com/epic-kitchens/epic-kitchens-100-annotations/refs/heads/master/EPIC_100_train.csv 17 | validation: https://raw.githubusercontent.com/epic-kitchens/epic-kitchens-100-annotations/refs/heads/master/EPIC_100_validation.csv 18 | md5_url: https://raw.githubusercontent.com/epic-kitchens/epic-kitchens-download-scripts/refs/heads/master/data/md5.csv 19 | errata_url: https://raw.githubusercontent.com/epic-kitchens/epic-kitchens-download-scripts/refs/heads/master/data/errata.csv 20 | splits_url: 21 | epic_55: https://raw.githubusercontent.com/epic-kitchens/epic-kitchens-download-scripts/refs/heads/master/data/epic_55_splits.csv 22 | epic_100: https://raw.githubusercontent.com/epic-kitchens/epic-kitchens-download-scripts/refs/heads/master/data/epic_100_splits.csv 23 | removal_threshold: [48, 128] # when the clip is above the lower bound, trim a fraction of frames 24 | removal_rate_max: 0.75 25 | removal_front_back: [0.0, 1.0] 26 | -------------------------------------------------------------------------------- /utils/cluster_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | utils for submitting to clusters, such as slurm 3 | """ 4 | 5 | import os 6 | from omegaconf import DictConfig, OmegaConf 7 | from datetime import datetime 8 | from pathlib import Path 9 | 10 | from utils.print_utils import cyan 11 | 12 | # This is set below. 
13 | REPO_DIR = None 14 | 15 | 16 | def submit_slurm_job( 17 | cfg: DictConfig, 18 | python_args: str, 19 | project_root: Path, 20 | ): 21 | log_dir = project_root / "slurm_logs" / f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-{cfg.name}" 22 | log_dir.mkdir(exist_ok=True, parents=True) 23 | (project_root / "slurm_logs" / "latest").unlink(missing_ok=True) 24 | (project_root / "slurm_logs" / "latest").symlink_to(log_dir, target_is_directory=True) 25 | 26 | params = dict(name=cfg.name, log_dir=log_dir, project_root=project_root, python_args=python_args) 27 | params.update(cfg.cluster.params) 28 | 29 | slurm_script = cfg.cluster.launch_template.format(**params) 30 | 31 | slurm_script_path = log_dir / "job.slurm" 32 | with slurm_script_path.open("w") as f: 33 | f.write(slurm_script) 34 | 35 | os.system(f"chmod +x {slurm_script_path}") 36 | os.system(f"sbatch {slurm_script_path}") 37 | 38 | print(f"\n{cyan('script:')} {slurm_script_path}\n{cyan('slurm errors and logs:')} {log_dir}\n") 39 | 40 | return log_dir 41 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | from omegaconf import DictConfig 3 | import pathlib 4 | from lightning.pytorch.loggers.wandb import WandbLogger 5 | 6 | from .exp_base import BaseExperiment 7 | from .exp_video import VideoPredictionExperiment 8 | from .process_data import ProcessDataExperiment 9 | 10 | # each key has to be a yaml file under '[project_root]/configurations/experiment' without .yaml suffix 11 | exp_registry = dict( 12 | exp_video=VideoPredictionExperiment, 13 | process_data=ProcessDataExperiment, 14 | ) 15 | 16 | 17 | def build_experiment( 18 | cfg: DictConfig, 19 | logger: Optional[WandbLogger] = None, 20 | ckpt_path: Optional[Union[str, pathlib.Path]] = None, 21 | ) -> BaseExperiment: 22 | """ 23 | Build an experiment instance based on registry 24 | :param cfg: configuration file 25 | :param logger: optional logger for the experiment 26 | :param ckpt_path: optional checkpoint path for saving and loading 27 | :return: 28 | """ 29 | if cfg.experiment._name not in exp_registry: 30 | raise ValueError( 31 | f"Experiment {cfg.experiment._name} not found in registry {list(exp_registry.keys())}. " 32 | "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file." 
33 | ) 34 | 35 | return exp_registry[cfg.experiment._name](cfg, logger, ckpt_path) 36 | -------------------------------------------------------------------------------- /algorithms/cogvideo/t5.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | import torch 3 | 4 | from transformers import AutoTokenizer, T5EncoderModel 5 | 6 | from algorithms.common.base_pytorch_algo import BasePytorchAlgo 7 | from algorithms.cogvideo.text_encoder import compute_prompt_embeddings 8 | 9 | 10 | class T5Encoder(BasePytorchAlgo): 11 | def __init__(self, cfg): 12 | self.pretrained_cfg = cfg.pretrained 13 | self.max_text_seq_length = 226 14 | super().__init__(cfg) 15 | self._build_model() # Explicitly call _build_model after initialization 16 | 17 | def _build_model(self): 18 | self.tokenizer = AutoTokenizer.from_pretrained( 19 | self.pretrained_cfg.pretrained_model_name_or_path, 20 | subfolder="tokenizer", 21 | revision=self.pretrained_cfg.revision, 22 | ) 23 | 24 | self.text_encoder = T5EncoderModel.from_pretrained( 25 | self.pretrained_cfg.pretrained_model_name_or_path, 26 | subfolder="text_encoder", 27 | revision=self.pretrained_cfg.revision, 28 | ) 29 | self.text_encoder.requires_grad_(False) 30 | 31 | def training_step(self, *args, **kwargs): 32 | raise NotImplementedError("T5Encoder does not support training") 33 | 34 | @torch.no_grad() 35 | def predict(self, prompts: Sequence[str]): 36 | prompt_embeds = compute_prompt_embeddings( 37 | self.tokenizer, 38 | self.text_encoder, 39 | prompts, 40 | self.max_text_seq_length, 41 | self.device, 42 | torch.bfloat16, 43 | requires_grad=False, 44 | ) 45 | return prompt_embeds 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Core ML frameworks 2 | torch==2.6.0 3 | torchvision==0.21.0 4 | pytorch-lightning==2.5.0.post0 5 | lightning==2.6.0 6 | numpy==1.24.4 7 | 8 | # Language Models & Diffusion 9 | transformers==4.51.3 10 | diffusers==0.34.0 11 | accelerate==1.1.1 12 | tokenizers==0.21.4 13 | deepspeed==0.15.1 # Install with: DS_BUILD_UTILS=1 DS_BUILD_FUSED_ADAM=1 DS_BUILD_CPU_ADAM=1 pip install deepspeed 14 | #flash-attn==2.3.5 # Optional, for efficient attention. 
Install with: pip install flash-attn --no-build-isolation 15 | 16 | # Configuration & Logging 17 | hydra-core==1.3.2 18 | omegaconf==2.3.0 19 | pyyaml==6.0.2 20 | 21 | # Experiment Tracking 22 | wandb==0.19.1 23 | wandb-osh==1.2.2 24 | tensorboard==2.19.0 25 | 26 | # Media Processing 27 | opencv-python==4.9.0.80 28 | decord==0.6.0 29 | imageio==2.34.0 30 | imageio-ffmpeg==0.5.1 31 | moviepy==1.0.3 32 | av==14.0.1 # Note: Changed from pyav==14.0.1 33 | scikit-video==1.1.11 34 | 35 | # Image Processing 36 | pillow==11.1.0 37 | einops==0.8.0 38 | 39 | # Data Analysis 40 | pandas==2.1.4 41 | scikit-learn==1.5.0 42 | matplotlib==3.8.2 43 | plotly==5.18.0 44 | 45 | # Dataset Dependencies 46 | tensorflow==2.15.0 # For tensorflow-datasets 47 | tensorflow-datasets==4.9.7 48 | gcsfs==2024.12.0 # For Google Cloud Storage datasets 49 | ego4d==1.7.3 # For Ego4D dataset 50 | ijson==3.3.0 # For streaming JSON parsing 51 | 52 | # Text Processing and Tokenization 53 | sentencepiece==0.2.0 54 | ftfy==6.3.1 # For text cleaning 55 | 56 | # Hugging Face Hub 57 | huggingface-hub==0.33.4 58 | 59 | # Utilities 60 | tqdm==4.66.2 61 | colorama==0.4.6 62 | click==8.1.8 63 | easydict==1.13 64 | msgpack==1.1.2 # For message serialization 65 | pyzmq==27.1.0 # For ZeroMQ (used in serving) 66 | pydantic 67 | 68 | # Optional: Gradio for Demos 69 | gradio==5.11.0 -------------------------------------------------------------------------------- /utils/detector_utils.py: -------------------------------------------------------------------------------- 1 | # from: https://github.com/ibaiGorordo/Sapiens-Pytorch-Inference/blob/main/sapiens_inference/detector.py 2 | import time 3 | from dataclasses import dataclass 4 | import numpy as np 5 | from ultralytics import YOLO 6 | 7 | @dataclass 8 | class DetectorConfig: 9 | model_path: str = "~/models/yolov8m.pt" 10 | person_id: int = 0 11 | conf_thres: float = 0.25 12 | 13 | 14 | def draw_boxes(img, boxes, color=(0, 255, 0), thickness=2): 15 | draw_img = img.copy() 16 | for box in boxes: 17 | x1, y1, x2, y2 = box 18 | draw_img = cv2.rectangle(draw_img, (x1, y1), (x2, y2), color, thickness) 19 | return draw_img 20 | 21 | 22 | class Detector: 23 | def __init__(self, config: DetectorConfig = DetectorConfig()): 24 | model_path = config.model_path 25 | if not model_path.endswith(".pt"): 26 | model_path = model_path.split(".")[0] + ".pt" 27 | self.model = YOLO(model_path) 28 | self.person_id = config.person_id 29 | self.conf_thres = config.conf_thres 30 | 31 | def __call__(self, img: np.ndarray) -> np.ndarray: 32 | # input: np.ndarray, shape (H, W, C) 33 | # rgb or bgr? 
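        # detect() below documents its input as BGR (the cv2.imread convention used in
        # the __main__ example), so callers should pass BGR frames here.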
34 | return self.detect(img) 35 | 36 | def detect(self, img: np.ndarray) -> np.ndarray: 37 | # input: np.ndarray, shape (H, W, C) in BGR 38 | start = time.perf_counter() 39 | results = self.model(img, conf=self.conf_thres) 40 | detections = results[0].boxes.data.cpu().numpy() # (x1, y1, x2, y2, conf, cls) 41 | 42 | # Filter out only person 43 | person_detections = detections[detections[:, -1] == self.person_id] 44 | boxes = person_detections[:, :-2].astype(int) # (x1, y1, x2, y2) 45 | 46 | print(f"Detection inference took: {time.perf_counter() - start:.4f} seconds") 47 | return boxes 48 | 49 | 50 | if __name__ == "__main__": 51 | import cv2 52 | 53 | detector = Detector() 54 | img = cv2.imread("../ComfyUI_00074_.png") 55 | boxes = detector.detect(img) 56 | draw_img = draw_boxes(img, boxes) 57 | cv2.imshow("img", draw_img) 58 | cv2.waitKey(0) 59 | -------------------------------------------------------------------------------- /datasets/deprecated/video_1x_wm.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | import cv2 4 | import pandas as pd 5 | 6 | from ..video_base import VideoDataset 7 | 8 | 9 | class WorldModel1XDataset(VideoDataset): 10 | """ 11 | 1X world model challenge dataset from https://huggingface.co/datasets/1x-technologies/worldmodel_raw_data 12 | """ 13 | 14 | def download(self): 15 | from huggingface_hub import snapshot_download 16 | 17 | raw_dir = self.data_root / "raw" 18 | raw_dir.mkdir(parents=True, exist_ok=True) 19 | 20 | snapshot_download( 21 | repo_id="1x-technologies/worldmodel_raw_data", 22 | local_dir=raw_dir, 23 | repo_type="dataset", 24 | ) 25 | 26 | records = [] 27 | split_dict = { 28 | "training": list((raw_dir / "train_v2.0_raw/videos/").glob("*.mp4")), 29 | "validation": list((raw_dir / "val_v2.0_raw/").glob("*.mp4")), 30 | } 31 | for split, video_paths in split_dict.items(): 32 | for video_path in tqdm(video_paths, desc=f"Verifying {split} videos"): 33 | cap = cv2.VideoCapture(video_path) 34 | if not cap.isOpened(): 35 | continue 36 | 37 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 38 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 40 | n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 41 | cap.release() 42 | 43 | records.append( 44 | { 45 | "video_path": str(video_path.relative_to(self.data_root)), 46 | "height": height, 47 | "width": width, 48 | "fps": fps, 49 | "n_frames": n_frames, 50 | "split": split, 51 | } 52 | ) 53 | 54 | # Save as CSV 55 | metadata_path = self.data_root / self.metadata_path 56 | metadata_path.parent.mkdir(parents=True, exist_ok=True) 57 | df = pd.DataFrame.from_records(records) 58 | df.to_csv(metadata_path, index=False) 59 | print(f"Created metadata CSV with {len(records)} videos") 60 | -------------------------------------------------------------------------------- /configurations/dataset/video_base.yaml: -------------------------------------------------------------------------------- 1 | debug: ${debug} 2 | data_root: ??? # dataset folder location e.g. ~/data/something_something_v2 3 | metadata_path: ??? # a csv / json file that lists the entries for the dataset, should be a file path relative to data_root. 
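# A hypothetical example of pointing a dataset at local files from the command line
# (the chosen dataset name and paths are illustrative; any field in this file can be
# overridden the same way via Hydra):
#   python -m main dataset=something_something dataset.data_root=data/something_something_v2 dataset.metadata_path=metadata_merged.csv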
4 | auto_download: false # whether to automatically download the dataset if the data_root does not exist, proceed with caution 5 | force_download: false # whether to force download the dataset even if the data_root already exists, bypassing every check 6 | test_percentage: 0.01 # percentage of the dataset to use for testing vs training. However, if a field `split` is present in the metadata, that will be used instead 7 | height: 480 # target height for the output videos 8 | width: 832 # target width for the output videos 9 | n_frames: 49 # target number of frames for the output videos 10 | fps: 16 # target fps for the output videos 11 | id_token: null # if not null, tokenize to an id token for this dataset 12 | load_video_latent: false # whether to load a raw latent tensor instead of mp4 file. Require a field `image_latent_path` in csv 13 | load_prompt_embed: false # whether to load a raw embed tensor instead of running language model online. Require a field `prompt_embed_path` in csv 14 | check_video_path: false # whether to check if the video_path in the metadata is valid 15 | trim_mode: speedup # one of ["speedup", "random_cut"], specify how do we handle a video that's too long 16 | pad_mode: slowdown # one of ["slowdown", "pad_last", "discard"], specify how do we handle a video that's too short 17 | max_text_tokens: ${algorithm.max_text_tokens} # maximum number of tokens for the text encoder 18 | 19 | filtering: # filter raw videos based on these criteria 20 | disable: false # whether to disable filtering 21 | height: [0, 4096] # height range for the videos 22 | width: [0, 4096] # width range for the videos 23 | fps: [0, 120] # fps range for the videos 24 | n_frames: [0, 4096] # number of frames range for the videos 25 | 26 | augmentation: 27 | random_flip: null # probability of random flip, null means no random flip 28 | ratio: [0.98, 1.02] # random scaling of the aspect ratio, see torchvision.transforms.v2.RandomResizedCrop 29 | scale: [0.8, 1.0] # random crop the video, see torchvision.transforms.v2.RandomResizedCrop 30 | 31 | image_to_video: true # whether returning the first image too for I2V model 32 | # video_reshape_mode: center 33 | -------------------------------------------------------------------------------- /configurations/algorithm/wan_t2v.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base_pytorch_algo # inherits from configurations/algorithm/base_algo.yaml 3 | - _self_ 4 | 5 | lr: ${experiment.training.lr} 6 | betas: [0.9, 0.95] 7 | weight_decay: 5e-2 8 | lr_scheduler: 9 | name: constant_with_warmup 10 | num_warmup_steps: 1000 11 | 12 | load_video_latent: ${dataset.load_video_latent} # if true, load latent from disk instead of using video vae 13 | load_prompt_embed: ${dataset.load_prompt_embed} # if true, load prompt embedding from disk instead of running language model online 14 | 15 | diffusion_forcing: 16 | enabled: true 17 | mode: rand_history # independent, rand_history 18 | clean_hist_prob: 0.5 # probability of giving first frame image condition when finetuning image-to-video, overriding diffusion forcing's noise level for first frame 19 | 20 | n_frames: ${dataset.n_frames} 21 | height: ${dataset.height} 22 | width: ${dataset.width} 23 | num_train_timesteps: 1000 24 | diffusion_type: "continuous" # or "discrete" 25 | sample_solver: unipc 26 | sample_steps: 40 27 | sample_shift: 3.0 28 | lang_guidance: 3.0 29 | neg_prompt: "" 30 | hist_guidance: 2.0 #2.0 31 | sliding_hist: 1 # use 2 latent frames as 
history when extending videos 32 | gradient_checkpointing_rate: 1.0 # gradient checkpointing blocks as a ratio of total blocks 33 | max_text_tokens: 512 34 | 35 | logging: 36 | loss_freq: 1 37 | video_freq: 1000 38 | video_type: grid # grid or single 39 | fps: ${dataset.fps} 40 | 41 | serving: 42 | port: 6688 43 | 44 | text_encoder: 45 | text_len: 512 46 | text_dim: 4096 47 | compile: true 48 | name: google/umt5-xxl 49 | ckpt_path: data/ckpts/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth 50 | 51 | vae: 52 | ckpt_path: data/ckpts/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth 53 | compile: true 54 | z_dim: 16 55 | stride: [4, 8, 8] 56 | mean: [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921] 57 | std: [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160] 58 | 59 | model: 60 | ckpt_path: data/ckpts/Wan2.1-T2V-1.3B 61 | tuned_ckpt_path: null 62 | compile: false #true 63 | model_type: t2v # if i2v, this flag will let the model take in CLIP features 64 | patch_size: [1, 2, 2] 65 | in_dim: ${algorithm.vae.z_dim} 66 | dim: 1536 67 | ffn_dim: 8960 68 | freq_dim: 256 69 | out_dim: ${algorithm.vae.z_dim} 70 | num_heads: 12 71 | num_layers: 30 72 | window_size: [-1, -1] 73 | qk_norm: True 74 | cross_attn_norm: True 75 | eps: 1e-6 76 | 77 | -------------------------------------------------------------------------------- /algorithms/cogvideo/pos_embed.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from diffusers.models.embeddings import get_3d_rotary_pos_embed 5 | 6 | 7 | def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): 8 | # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid 9 | tw = tgt_width 10 | th = tgt_height 11 | h, w = src 12 | r = h / w 13 | if r > (th / tw): 14 | resize_height = th 15 | resize_width = int(round(th / h * w)) 16 | else: 17 | resize_width = tw 18 | resize_height = int(round(tw / w * h)) 19 | 20 | crop_top = int(round((th - resize_height) / 2.0)) 21 | crop_left = int(round((tw - resize_width) / 2.0)) 22 | 23 | return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) 24 | 25 | 26 | def prepare_rotary_positional_embeddings( 27 | height: int, 28 | width: int, 29 | num_frames: int, 30 | vae_scale_factor_spatial: int = 8, 31 | patch_size: int = 2, 32 | patch_size_t: Optional[int] = None, 33 | attention_head_dim: int = 64, 34 | device: Optional[torch.device] = None, 35 | base_height: int = 480, 36 | base_width: int = 720, 37 | ) -> Tuple[torch.Tensor, torch.Tensor]: 38 | grid_height = height // (vae_scale_factor_spatial * patch_size) 39 | grid_width = width // (vae_scale_factor_spatial * patch_size) 40 | base_size_width = base_width // (vae_scale_factor_spatial * patch_size) 41 | base_size_height = base_height // (vae_scale_factor_spatial * patch_size) 42 | 43 | if patch_size_t is None: 44 | # CogVideoX 1.0 45 | grid_crops_coords = get_resize_crop_region_for_grid( 46 | (grid_height, grid_width), base_size_width, base_size_height 47 | ) 48 | freqs_cos, freqs_sin = get_3d_rotary_pos_embed( 49 | embed_dim=attention_head_dim, 50 | crops_coords=grid_crops_coords, 51 | grid_size=(grid_height, grid_width), 52 | temporal_size=num_frames, 53 | ) 54 | else: 55 | # CogVideoX 1.5 56 | base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t 57 | freqs_cos, 
freqs_sin = get_3d_rotary_pos_embed( 58 | embed_dim=attention_head_dim, 59 | crops_coords=None, 60 | grid_size=(grid_height, grid_width), 61 | temporal_size=base_num_frames, 62 | grid_type="slice", 63 | max_size=(base_size_height, base_size_width), 64 | ) 65 | 66 | freqs_cos = freqs_cos.to(device=device) 67 | freqs_sin = freqs_sin.to(device=device) 68 | return freqs_cos, freqs_sin 69 | -------------------------------------------------------------------------------- /algorithms/wan/modules/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import html 3 | import string 4 | 5 | import ftfy 6 | import regex as re 7 | from transformers import AutoTokenizer 8 | 9 | __all__ = ['HuggingfaceTokenizer'] 10 | 11 | 12 | def basic_clean(text): 13 | text = ftfy.fix_text(text) 14 | text = html.unescape(html.unescape(text)) 15 | return text.strip() 16 | 17 | 18 | def whitespace_clean(text): 19 | text = re.sub(r'\s+', ' ', text) 20 | text = text.strip() 21 | return text 22 | 23 | 24 | def canonicalize(text, keep_punctuation_exact_string=None): 25 | text = text.replace('_', ' ') 26 | if keep_punctuation_exact_string: 27 | text = keep_punctuation_exact_string.join( 28 | part.translate(str.maketrans('', '', string.punctuation)) 29 | for part in text.split(keep_punctuation_exact_string)) 30 | else: 31 | text = text.translate(str.maketrans('', '', string.punctuation)) 32 | text = text.lower() 33 | text = re.sub(r'\s+', ' ', text) 34 | return text.strip() 35 | 36 | 37 | class HuggingfaceTokenizer: 38 | 39 | def __init__(self, name, seq_len=None, clean=None, **kwargs): 40 | assert clean in (None, 'whitespace', 'lower', 'canonicalize') 41 | self.name = name 42 | self.seq_len = seq_len 43 | self.clean = clean 44 | 45 | # init tokenizer 46 | self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) 47 | self.vocab_size = self.tokenizer.vocab_size 48 | 49 | def __call__(self, sequence, **kwargs): 50 | return_mask = kwargs.pop('return_mask', False) 51 | 52 | # arguments 53 | _kwargs = {'return_tensors': 'pt'} 54 | if self.seq_len is not None: 55 | _kwargs.update({ 56 | 'padding': 'max_length', 57 | 'truncation': True, 58 | 'max_length': self.seq_len 59 | }) 60 | _kwargs.update(**kwargs) 61 | 62 | # tokenization 63 | if isinstance(sequence, str): 64 | sequence = [sequence] 65 | if self.clean: 66 | sequence = [self._clean(u) for u in sequence] 67 | ids = self.tokenizer(sequence, **_kwargs) 68 | 69 | # output 70 | if return_mask: 71 | return ids.input_ids, ids.attention_mask 72 | else: 73 | return ids.input_ids 74 | 75 | def _clean(self, text): 76 | if self.clean == 'whitespace': 77 | text = whitespace_clean(basic_clean(text)) 78 | elif self.clean == 'lower': 79 | text = whitespace_clean(basic_clean(text)).lower() 80 | elif self.clean == 'canonicalize': 81 | text = canonicalize(basic_clean(text)) 82 | return text 83 | -------------------------------------------------------------------------------- /datasets/dummy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from omegaconf import DictConfig 4 | from pathlib import Path 5 | 6 | 7 | class DummyVideoDataset(Dataset): 8 | def __init__(self, cfg: DictConfig, split: str = "training") -> None: 9 | super().__init__() 10 | self.cfg = cfg 11 | self.split = split 12 | self.height = cfg.height 13 | self.width = cfg.width 14 | self.n_frames = cfg.n_frames 
15 | self.load_video_latent = cfg.load_video_latent 16 | self.load_prompt_embed = cfg.load_prompt_embed 17 | self.image_to_video = cfg.image_to_video 18 | self.max_text_tokens = cfg.max_text_tokens 19 | 20 | @property 21 | def metadata_path(self): 22 | raise ValueError("Dummy dataset does not have a metadata path") 23 | 24 | @property 25 | def data_root(self): 26 | raise ValueError("Dummy dataset does not have a data root path") 27 | 28 | def __len__(self) -> int: 29 | return 10000000 # Return fixed size of 10000000 30 | 31 | def __getitem__(self, idx: int) -> dict: 32 | # Generate dummy video tensor [T, C, H, W] 33 | videos = torch.randn(self.n_frames, 3, self.height, self.width) 34 | 35 | # Generate dummy image if needed 36 | images = videos[:1].clone() if self.image_to_video else None 37 | 38 | output = { 39 | "prompts": "A dummy video caption for debugging purposes", 40 | "videos": videos, 41 | "video_metadata": { 42 | "num_frames": self.n_frames, 43 | "height": self.height, 44 | "width": self.width, 45 | "has_caption": True, 46 | }, 47 | "has_bbox": torch.tensor([False, False]), 48 | "bbox_render": torch.zeros(2, self.height, self.width), 49 | } 50 | 51 | if images is not None: 52 | output["images"] = images 53 | 54 | if self.load_prompt_embed: 55 | # Generate dummy prompt embeddings [self.max_text_tokens, 4096] 56 | output["prompt_embeds"] = torch.randn(self.max_text_tokens, 4096) 57 | output["prompt_embed_len"] = self.max_text_tokens 58 | 59 | if self.load_video_latent: 60 | # Generate dummy latents 61 | if self.image_to_video: 62 | output["image_latents"] = torch.randn( 63 | 4, 64 | self.n_frames // 4, 65 | self.height // 8, 66 | self.width // 8, 67 | ) 68 | output["video_latents"] = torch.randn( 69 | 4, 70 | self.n_frames // 4, 71 | self.height // 8, 72 | self.width // 8, 73 | ) 74 | 75 | return output 76 | -------------------------------------------------------------------------------- /configurations/experiment/base_pytorch.yaml: -------------------------------------------------------------------------------- 1 | # inherits from base_experiment.yaml 2 | # most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html 3 | 4 | defaults: 5 | - base_experiment 6 | - _self_ 7 | 8 | tasks: [training] # tasks to run sequentially, change when your project has multiple stages and you want to run only a subset of them. 9 | num_nodes: 1 # number of gpu servers used in large scale distributed training 10 | strategy: fsdp # distributed strategy to use, options: ddp, deepspeed_stage_2, fsdp 11 | 12 | training: 13 | precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable 14 | compile: False # whether to compile the model with torch.compile 15 | lr: 0.001 # learning rate 16 | batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training 17 | max_epochs: 1000 # set to -1 to train forever 18 | max_steps: -1 # set to -1 to train forever, will override max_epochs 19 | max_time: null # set to something like "00:12:00:00" to enable 20 | data: 21 | num_workers: 8 # number of CPU threads for data preprocessing.
22 | shuffle: True # whether training data will be shuffled 23 | optim: 24 | accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop 25 | gradient_clip_val: 5.0 # clip gradients with norm above this value, set to 0 to disable 26 | checkpointing: 27 | # these are arguments to pytorch lightning's callback, `ModelCheckpoint` class 28 | every_n_train_steps: 5000 # save a checkpoint every n train steps 29 | every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval`` 30 | train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``. 31 | enable_version_counter: False # If this is ``False``, later checkpoint will be overwrite previous ones. 32 | 33 | 34 | validation: 35 | precision: 16-mixed 36 | compile: False # whether to compile the model with torch.compile 37 | inference_mode: True # whether to run in inference mode 38 | batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training 39 | val_every_n_step: 2000 # controls how frequent do we run validation, can be float (fraction of epoches) or int (steps) or null (if val_every_n_epoch is set) 40 | val_every_n_epoch: null # if you want to do validation every n epoches, requires val_every_n_step to be null. 41 | limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation. 42 | data: 43 | num_workers: 8 # number of CPU threads for data preprocessing, for validation. 44 | shuffle: False # whether validation data will be shuffled 45 | 46 | test: 47 | precision: 16-mixed 48 | compile: False # whether to compile the model with torch.compile 49 | inference_mode: True # whether to run in inference mode 50 | batch_size: 16 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training 51 | limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test. 52 | data: 53 | num_workers: 8 # number of CPU threads for data preprocessing, for test. 54 | shuffle: False # whether test data will be shuffled 55 | -------------------------------------------------------------------------------- /algorithms/cogvideo/text_encoder/text_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import torch 4 | from transformers import T5EncoderModel, T5Tokenizer 5 | 6 | 7 | def _get_t5_prompt_embeds( 8 | tokenizer: T5Tokenizer, 9 | text_encoder: T5EncoderModel, 10 | prompt: Union[str, List[str]], 11 | num_videos_per_prompt: int = 1, 12 | max_sequence_length: int = 226, 13 | device: Optional[torch.device] = None, 14 | dtype: Optional[torch.dtype] = None, 15 | text_input_ids=None, 16 | ): 17 | prompt = [prompt] if isinstance(prompt, str) else prompt 18 | batch_size = len(prompt) 19 | 20 | if tokenizer is not None: 21 | text_inputs = tokenizer( 22 | prompt, 23 | padding="max_length", 24 | max_length=max_sequence_length, 25 | truncation=True, 26 | add_special_tokens=True, 27 | return_tensors="pt", 28 | ) 29 | text_input_ids = text_inputs.input_ids 30 | else: 31 | if text_input_ids is None: 32 | raise ValueError( 33 | "`text_input_ids` must be provided when the tokenizer is not specified." 
34 | ) 35 | 36 | prompt_embeds = text_encoder(text_input_ids.to(device))[0] 37 | prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) 38 | 39 | # duplicate text embeddings for each generation per prompt, using mps friendly method 40 | _, seq_len, _ = prompt_embeds.shape 41 | prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) 42 | prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) 43 | 44 | return prompt_embeds 45 | 46 | 47 | def encode_prompt( 48 | tokenizer: T5Tokenizer, 49 | text_encoder: T5EncoderModel, 50 | prompt: Union[str, List[str]], 51 | num_videos_per_prompt: int = 1, 52 | max_sequence_length: int = 226, 53 | device: Optional[torch.device] = None, 54 | dtype: Optional[torch.dtype] = None, 55 | text_input_ids=None, 56 | ): 57 | prompt = [prompt] if isinstance(prompt, str) else prompt 58 | prompt_embeds = _get_t5_prompt_embeds( 59 | tokenizer, 60 | text_encoder, 61 | prompt=prompt, 62 | num_videos_per_prompt=num_videos_per_prompt, 63 | max_sequence_length=max_sequence_length, 64 | device=device, 65 | dtype=dtype, 66 | text_input_ids=text_input_ids, 67 | ) 68 | return prompt_embeds 69 | 70 | 71 | def compute_prompt_embeddings( 72 | tokenizer: T5Tokenizer, 73 | text_encoder: T5EncoderModel, 74 | prompt: str, 75 | max_sequence_length: int, 76 | device: torch.device, 77 | dtype: torch.dtype, 78 | requires_grad: bool = False, 79 | ): 80 | if requires_grad: 81 | prompt_embeds = encode_prompt( 82 | tokenizer, 83 | text_encoder, 84 | prompt, 85 | num_videos_per_prompt=1, 86 | max_sequence_length=max_sequence_length, 87 | device=device, 88 | dtype=dtype, 89 | ) 90 | else: 91 | with torch.no_grad(): 92 | prompt_embeds = encode_prompt( 93 | tokenizer, 94 | text_encoder, 95 | prompt, 96 | num_videos_per_prompt=1, 97 | max_sequence_length=max_sequence_length, 98 | device=device, 99 | dtype=dtype, 100 | ) 101 | return prompt_embeds 102 | -------------------------------------------------------------------------------- /utils/video_utils.py: -------------------------------------------------------------------------------- 1 | import av 2 | from pathlib import Path 3 | import io 4 | from PIL import Image 5 | 6 | 7 | def write_numpy_to_mp4(video_data, output_path, fps=30): 8 | """ 9 | Write a numpy array into a mp4 file using pyav. 10 | 11 | Args: 12 | video_data (numpy.ndarray): The video data to write. Should be of shape (num_frames, height, width, channels). 13 | output_path (str): The path to the output mp4 file. 14 | fps (int): Frames per second for the output video. 15 | """ 16 | num_frames, height, width, channels = video_data.shape 17 | if channels != 3: 18 | raise ValueError("Video data should have 3 channels (RGB).") 19 | 20 | output_dir = Path(output_path).parent 21 | if not output_dir.exists(): 22 | raise FileNotFoundError(f"The directory {output_dir} does not exist.") 23 | 24 | container = av.open(output_path, mode="w") 25 | stream = container.add_stream("h264", rate=fps) 26 | stream.width = width 27 | stream.height = height 28 | stream.pix_fmt = "yuv420p" 29 | 30 | for frame in video_data: 31 | frame = av.VideoFrame.from_ndarray(frame, format="rgb24") 32 | for packet in stream.encode(frame): 33 | container.mux(packet) 34 | 35 | # Flush the encoder 36 | for packet in stream.encode(): 37 | container.mux(packet) 38 | 39 | container.close() 40 | 41 | 42 | def numpy_to_mp4_bytes(video_data, fps=30): 43 | """ 44 | Convert a numpy array to MP4 bytes in memory using PyAV for better efficiency. 
45 | 46 | Args: 47 | video_data (numpy.ndarray): The video data to convert. Should be of shape (num_frames, height, width, channels). 48 | fps (int): Frames per second for the output video. 49 | 50 | Returns: 51 | bytes: The MP4 video data as bytes. 52 | """ 53 | if video_data.ndim != 4 or video_data.shape[-1] != 3: 54 | raise ValueError( 55 | "Video data should be of shape (num_frames, height, width, 3) for RGB video." 56 | ) 57 | 58 | num_frames, height, width, channels = video_data.shape 59 | 60 | # Check that dimensions are even (required by many players and codecs) 61 | if width % 2 != 0 or height % 2 != 0: 62 | raise ValueError( 63 | f"Video dimensions must be even. Got width={width}, height={height}" 64 | ) 65 | 66 | # Create an in-memory buffer 67 | buffer = io.BytesIO() 68 | container = av.open(buffer, mode="w", format="mp4") 69 | 70 | # Add video stream with more compatible settings 71 | stream = container.add_stream("h264", rate=fps) 72 | stream.width = width 73 | stream.height = height 74 | stream.pix_fmt = "yuv420p" 75 | 76 | # Set codec options with correct syntax for libopenh264 77 | # Note: profile must be an integer value, not a string name 78 | stream.options = { 79 | "profile": "66", # 66 = Baseline profile in H.264 80 | "level": "30", # 30 = Level 3.0 (must be integer value) 81 | "preset": "medium", 82 | "crf": "23", 83 | } 84 | 85 | # Encode frames directly from numpy array 86 | for frame_data in video_data: 87 | frame = av.VideoFrame.from_ndarray(frame_data, format="rgb24") 88 | for packet in stream.encode(frame): 89 | container.mux(packet) 90 | 91 | # Flush the encoder 92 | for packet in stream.encode(): 93 | container.mux(packet) 94 | 95 | # Close the container and get the buffer content 96 | container.close() 97 | buffer.seek(0) 98 | return buffer.getvalue() 99 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | The `datasets` folder is used to contain dataset code or environment code. 2 | Don't store actual data like images here! For those, please use the `data` folder instead of `datasets`. 3 | 4 | Create a folder for your own pytorch dataset definition. Then, update the `__init__.py` 5 | at every level to register all datasets. 6 | 7 | Each dataset class takes in a DictConfig file `cfg` in its `__init__`, which allows you to pass in arguments via a configuration file in `configurations/dataset` or [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/). 8 | 9 | 10 | 11 | ## Dataset format 12 | We train on a mixture of datasets, so we define a unified dataset format for consistency and ease of management. 13 | 14 | Each dataset includes a global metadata file, typically named `metadata_merged.csv`, which contains key information for each video clip. 15 | 16 | The file is named metadata_**merged**.csv because each video clip may have multiple recaptions. Instead of saving the captions for each video into a list within a single csv row, we just create another row in `metadata_merged.csv`. So `metadata_merged.csv` may contain multiple rows referring to the same video with different captions. For some datasets, we also provide a `cleaned_metadata.csv`, which contains a deduplicated version of the metadata (one entry per video) but excludes the additional recaptions. 17 | 18 | Important fields of the global metadata include: 19 | 1.
`video_path`: Relative path (from the metadata file) to the video clip. 20 | 2. `trim_start` and `trim_end` (optional): Specify the trimmed segment of the clip. If absent, the full video is used. 21 | 3. `gemini_caption`: Action-focused caption generated by Gemini Flash 2.0. 22 | 4. `original_caption`: Original caption from the source dataset; used when no Gemini caption is available. 23 | 5. `prompt_embed_path`: Path to precomputed T5 prompt embeddings (not released due to large size). 24 | 6. `stable_brightess` (optional): 1.0 if brightness is stable, 0.0 otherwise. We recommend removing videos with `stable_brightess == 0.0`. 25 | 7. `stable_background` (optional): Either 1.0 or 0.0. We recommend removing videos with `stable_background == 0.0`; this indicates large average optical flow magnitude, which very likely corresponds to large background motion. 26 | 8. `detected_hand_in_first_frame` (optional): 1.0 if a human hand is detected in the first frame, 0.0 otherwise. Videos with 0.0 often cause embodiment ambiguity and should be filtered out. 27 | 9. There are some other fields that can help you understand more about the clips, such as `n_frames`, `n_fps`, `height`, `width`, etc. 28 | 29 | 30 | `mixture.py` implements a dataset that performs weighted sampling from a group of datasets in the above format. 31 | 32 | 33 | 34 | 35 | 36 | ## Downloading the dataset 37 | We provide dataset-specific download scripts for AgiBot World, DROID, Ego4D, EpicKitchens, and Something-Something in their respective dataset.py files within this folder. 38 | 39 | For downloading the filtered Pandas subset, we provide the unique `youtube_key_segment` for each video clip, along with the `trim_start` and `trim_end` for each clip. To download this subset, please download the official metadata from [Panda-70M](https://snap-research.github.io/Panda-70M/), then use the `youtube_key_segment` to find the URL of the video clips and download them with your own online video downloader. 40 | 41 | 42 | 43 | For Bridge, please download from [Bridge](https://rail-berkeley.github.io/bridgedata/). 44 | 45 | For Language Table, please download from [Language Table](https://github.com/google-research/language-table). 46 | 47 | -------------------------------------------------------------------------------- /algorithms/wan/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 2 | import argparse 3 | import binascii 4 | import os 5 | import os.path as osp 6 | 7 | import imageio 8 | import torch 9 | import torchvision 10 | 11 | __all__ = ["cache_video", "cache_image", "str2bool"] 12 | 13 | 14 | def rand_name(length=8, suffix=""): 15 | name = binascii.b2a_hex(os.urandom(length)).decode("utf-8") 16 | if suffix: 17 | if not suffix.startswith("."): 18 | suffix = "."
+ suffix 19 | name += suffix 20 | return name 21 | 22 | 23 | def cache_video( 24 | tensor, 25 | save_file=None, 26 | fps=30, 27 | suffix=".mp4", 28 | nrow=8, 29 | normalize=True, 30 | value_range=(-1, 1), 31 | retry=5, 32 | ): 33 | # cache file 34 | cache_file = ( 35 | osp.join("/tmp", rand_name(suffix=suffix)) if save_file is None else save_file 36 | ) 37 | 38 | # save to cache 39 | error = None 40 | for _ in range(retry): 41 | try: 42 | # preprocess 43 | tensor = tensor.clamp(min(value_range), max(value_range)) 44 | tensor = torch.stack( 45 | [ 46 | torchvision.utils.make_grid( 47 | u, nrow=nrow, normalize=normalize, value_range=value_range 48 | ) 49 | for u in tensor.unbind(2) 50 | ], 51 | dim=1, 52 | ).permute(1, 2, 3, 0) 53 | tensor = (tensor * 255).type(torch.uint8).cpu() 54 | 55 | # write video 56 | writer = imageio.get_writer(cache_file, fps=fps, codec="libx264", quality=8) 57 | for frame in tensor.numpy(): 58 | writer.append_data(frame) 59 | writer.close() 60 | return cache_file 61 | except Exception as e: 62 | error = e 63 | continue 64 | else: 65 | print(f"cache_video failed, error: {error}", flush=True) 66 | return None 67 | 68 | 69 | def cache_image( 70 | tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1), retry=5 71 | ): 72 | # cache file 73 | suffix = osp.splitext(save_file)[1] 74 | if suffix.lower() not in [".jpg", ".jpeg", ".png", ".tiff", ".gif", ".webp"]: 75 | suffix = ".png" 76 | 77 | # save to cache 78 | error = None 79 | for _ in range(retry): 80 | try: 81 | tensor = tensor.clamp(min(value_range), max(value_range)) 82 | torchvision.utils.save_image( 83 | tensor, 84 | save_file, 85 | nrow=nrow, 86 | normalize=normalize, 87 | value_range=value_range, 88 | ) 89 | return save_file 90 | except Exception as e: 91 | error = e 92 | continue 93 | 94 | 95 | def str2bool(v): 96 | """ 97 | Convert a string to a boolean. 98 | 99 | Supported true values: 'yes', 'true', 't', 'y', '1' 100 | Supported false values: 'no', 'false', 'f', 'n', '0' 101 | 102 | Args: 103 | v (str): String to convert. 104 | 105 | Returns: 106 | bool: Converted boolean value. 107 | 108 | Raises: 109 | argparse.ArgumentTypeError: If the value cannot be converted to boolean. 
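    Example (illustrative; mirrors the accepted values listed above):
        str2bool("yes")   -> True
        str2bool("0")     -> False
        str2bool("maybe") -> raises argparse.ArgumentTypeError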
110 | """ 111 | if isinstance(v, bool): 112 | return v 113 | v_lower = v.lower() 114 | if v_lower in ("yes", "true", "t", "y", "1"): 115 | return True 116 | elif v_lower in ("no", "false", "f", "n", "0"): 117 | return False 118 | else: 119 | raise argparse.ArgumentTypeError("Boolean value expected (True/False)") 120 | -------------------------------------------------------------------------------- /experiments/exp_video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributed.fsdp import MixedPrecision 3 | from torch.distributed.fsdp.wrap import ModuleWrapPolicy 4 | 5 | # from algorithms.cogvideo import CogVideoXImageToVideo, CogVideoXVAE 6 | from algorithms.wan import WanImageToVideo, WanTextToVideo 7 | from datasets.dummy import DummyVideoDataset 8 | from datasets.openx_base import OpenXVideoDataset 9 | from datasets.droid import DroidVideoDataset 10 | from datasets.something_something import SomethingSomethingDataset 11 | from datasets.epic_kitchen import EpicKitchenDataset 12 | from datasets.pandas import PandasVideoDataset 13 | from datasets.ego4d import Ego4DVideoDataset 14 | from datasets.agibot_world import AgibotWorldDataset 15 | from datasets.mixture import MixtureDataset 16 | from datasets.video_base import SingleFrameVideoDataset 17 | from .exp_base import BaseLightningExperiment 18 | 19 | 20 | class VideoPredictionExperiment(BaseLightningExperiment): 21 | """ 22 | A video prediction experiment 23 | """ 24 | 25 | compatible_algorithms = dict( 26 | # cogvideox_i2v=CogVideoXImageToVideo, 27 | # cogvideox_vae=CogVideoXVAE, 28 | wan_i2v=WanImageToVideo, 29 | wan_t2v=WanTextToVideo, 30 | wan_toy=WanImageToVideo, 31 | ) 32 | 33 | compatible_datasets = dict( 34 | mixture=MixtureDataset, 35 | mixture_robot=MixtureDataset, 36 | dummy=DummyVideoDataset, 37 | something_something=SomethingSomethingDataset, 38 | epic_kitchen=EpicKitchenDataset, 39 | pandas=PandasVideoDataset, 40 | ego4d=Ego4DVideoDataset, 41 | bridge=OpenXVideoDataset, 42 | droid=DroidVideoDataset, 43 | agibot_world=AgibotWorldDataset, 44 | language_table=OpenXVideoDataset, 45 | ours_test=SingleFrameVideoDataset, 46 | # austin_buds=OpenXVideoDataset, 47 | # austin_sailor=OpenXVideoDataset, 48 | # austin_sirius=OpenXVideoDataset, 49 | # bc_z=OpenXVideoDataset, 50 | # berkeley_autolab=OpenXVideoDataset, 51 | # berkeley_cable=OpenXVideoDataset, 52 | # berkeley_fanuc=OpenXVideoDataset, 53 | # cmu_stretch=OpenXVideoDataset, 54 | # dlr_edan=OpenXVideoDataset, 55 | # dobbe=OpenXVideoDataset, 56 | # fmb=OpenXVideoDataset, 57 | # fractal=OpenXVideoDataset, 58 | # iamlab_cmu=OpenXVideoDataset, 59 | # jaco_play=OpenXVideoDataset, 60 | # nyu_franka=OpenXVideoDataset, 61 | # roboturk=OpenXVideoDataset, 62 | # stanford_hydra=OpenXVideoDataset, 63 | # taco_play=OpenXVideoDataset, 64 | # toto=OpenXVideoDataset, 65 | # ucsd_kitchen=OpenXVideoDataset, 66 | # utaustin_mutex=OpenXVideoDataset, 67 | # viola=OpenXVideoDataset, 68 | ) 69 | 70 | def _build_strategy(self): 71 | from lightning.pytorch.strategies.fsdp import FSDPStrategy 72 | 73 | if self.cfg.strategy == "ddp": 74 | return super()._build_strategy() 75 | elif self.cfg.strategy == "fsdp": 76 | if self.cfg.num_nodes >= 8: 77 | device_mesh = (self.cfg.num_nodes // 8, 32) 78 | else: 79 | device_mesh = (1, self.cfg.num_nodes * 4) 80 | return FSDPStrategy( 81 | mixed_precision=MixedPrecision( 82 | param_dtype=torch.bfloat16, 83 | reduce_dtype=torch.bfloat16, 84 | buffer_dtype=torch.bfloat16, 85 | ), 86 | 
auto_wrap_policy=ModuleWrapPolicy(self.algo.classes_to_shard()), 87 | # sharding_strategy="FULL_SHARD", 88 | sharding_strategy="HYBRID_SHARD", 89 | device_mesh=device_mesh, 90 | ) 91 | 92 | else: 93 | return self.cfg.strategy 94 | 95 | def download_dataset(self): 96 | dataset = self._build_dataset("training") 97 | -------------------------------------------------------------------------------- /algorithms/cogvideo/cogvideox_vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from einops import rearrange 5 | from diffusers import AutoencoderKLCogVideoX 6 | from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution 7 | 8 | from algorithms.common.base_pytorch_algo import BasePytorchAlgo 9 | 10 | 11 | class CogVideoXVAE(BasePytorchAlgo): 12 | """ 13 | Main classs for CogVideoXImageToVideo 14 | """ 15 | 16 | def __init__(self, cfg): 17 | self.pretrained_cfg = cfg.pretrained 18 | super().__init__(cfg) 19 | 20 | def configure_model(self): 21 | self.vae = AutoencoderKLCogVideoX.from_pretrained( 22 | self.pretrained_cfg.pretrained_model_name_or_path, 23 | subfolder="vae", 24 | revision=self.pretrained_cfg.revision, 25 | variant=self.pretrained_cfg.variant, 26 | ) 27 | self.criteria = nn.MSELoss() 28 | 29 | @torch.no_grad() 30 | def on_after_batch_transfer(self, batch, dataloader_idx): 31 | # data reprocessing, returned result is passed to self.training_step / self.validation_step 32 | 33 | images = batch["images"] 34 | videos = batch["videos"] 35 | batch_size = images.size(0) 36 | 37 | # Encode videos 38 | if not self.cfg.load_video_latent: 39 | images = images.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] 40 | image_noise_sigma = torch.normal( 41 | mean=-3.0, 42 | std=0.5, 43 | size=(batch_size,), 44 | device=self.device, 45 | dtype=images.dtype, 46 | ) 47 | image_noise_sigma = torch.exp(image_noise_sigma) 48 | noisy_images = ( 49 | images 50 | + torch.randn_like(images) 51 | * image_noise_sigma[:, None, None, None, None] 52 | ) 53 | if self.trainer.training: 54 | image_latent_dist = self.vae.encode(noisy_images).latent_dist 55 | else: 56 | image_latent_dist = self.vae.encode(images).latent_dist 57 | videos = videos.permute(0, 2, 1, 3, 4) # [B, C, F, H, W] 58 | latent_dist = self.vae.encode(videos).latent_dist 59 | else: 60 | image_latent_dist = DiagonalGaussianDistribution(images) 61 | latent_dist = DiagonalGaussianDistribution(videos) 62 | if self.trainer.training: 63 | image_latents = image_latent_dist.mode().clone() 64 | video_latents = latent_dist.mode().clone() 65 | else: 66 | image_latents = image_latent_dist.sample() 67 | video_latents = latent_dist.sample() 68 | 69 | batch["image_latents"] = image_latents 70 | batch["video_latents"] = video_latents 71 | 72 | return batch 73 | 74 | def training_step(self, batch, batch_idx, *args, **kwargs): 75 | image_latents = batch["image_latents"] 76 | video_latents = batch["video_latents"] 77 | 78 | # video_pred = self.vae.decode(video_latents, return_dict=False) 79 | # loss = self.criteria(video_pred, batch["videos"]) 80 | raise NotImplementedError( 81 | "VAE is inference only. 
Append experiment.tasks=[validation] to run inference" 82 | ) 83 | 84 | return loss 85 | 86 | def validation_step(self, batch, batch_idx, *args, **kwargs): 87 | image_latents = batch["image_latents"] 88 | video_latents = batch["video_latents"] 89 | 90 | video_pred, *_ = self.vae.decode(video_latents, return_dict=False) 91 | video_pred = rearrange(video_pred, "b c t h w -> b t c h w") 92 | video_gt = batch["videos"] 93 | video = torch.cat([video_gt, video_pred], dim=-1) 94 | video = video * 0.5 + 0.5 95 | video = video.cpu() 96 | video = rearrange(self.all_gather(video), "p b ... -> (p b) ...") 97 | self.log_video("validation_vis/video_pred", video) 98 | 99 | return video_pred 100 | -------------------------------------------------------------------------------- /datasets/pandas.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from typing import List, Tuple, Any, Dict 3 | import time 4 | import json 5 | from pathlib import Path 6 | import decord 7 | from .video_base import VideoDataset 8 | 9 | 10 | class PandasVideoDataset(VideoDataset): 11 | def _load_records(self) -> Tuple[List[str], List[str]]: 12 | """ 13 | Given the metadata file, loads the records as a list. 14 | Each record is a dictionary containing a datapoint's mp4 path / caption etc. 15 | Require these entries: "video_path", "caption", "height", "width", "n_frames", "fps" 16 | 17 | For pandas70m, there are one extra key "youtube_key_segment", looks like: "2NQDnwJEBeQ_segment_7". 18 | It's the key identifier for the video. 19 | 20 | Pandas 70M comes with json config file. This method will convert the json config file to a csv file and save it before using. 21 | """ 22 | if self.metadata_path.suffix == ".json": 23 | # convert a legacy json file to a csv file we need 24 | start_time = time.time() 25 | records = [] 26 | with open(self.data_root / self.metadata_path, "r") as f: 27 | for line in f: 28 | item = json.loads(line) 29 | if "mp4_path" in item: 30 | item["video_path"] = item["mp4_path"] 31 | del item["mp4_path"] 32 | if "start_frame_index" in item: 33 | item["trim_start"] = item["start_frame_index"] 34 | del item["start_frame_index"] 35 | if "end_frame_index" in item: 36 | item["trim_end"] = item["end_frame_index"] 37 | del item["end_frame_index"] 38 | if "prompt_embed_path" in item: 39 | item["prompt_embed_path"] = ( 40 | "prompt_embeds/" + item["prompt_embed_path"] + ".pt" 41 | ) 42 | if "answers_for_four_questions" in item: 43 | del item["answers_for_four_questions"] 44 | records.append(item) 45 | 46 | df = pd.DataFrame.from_records(records) 47 | csv_path = self.metadata_path.with_suffix(".csv") 48 | df.to_csv(self.data_root / csv_path, index=False) 49 | self.metadata_path = csv_path 50 | end_time = time.time() 51 | print(f"Time taken for converting records: {end_time - start_time} seconds") 52 | 53 | return super()._load_records() 54 | 55 | 56 | if __name__ == "__main__": 57 | # do debug test 58 | import torch 59 | from omegaconf import OmegaConf 60 | 61 | debug_config = { 62 | "debug": True, 63 | "data_root": "/n/holylfs06/LABS/sham_lab/Lab/eiwm_data/pandas/", 64 | "metadata_path": "pandas_filtered_human_clip_meta_gemini_1.5_flash.json", 65 | "auto_download": False, 66 | "force_download": False, 67 | "test_percentage": 0.1, 68 | "id_token": "", 69 | "resolution": [256, 256], 70 | "n_frames": 8, 71 | "fps": 30, 72 | "trim_mode": "speedup", 73 | "pad_mode": "pad_last", 74 | "filtering": { 75 | "disable": False, 76 | "height": [32, 2160], 77 | "width": [32, 
3840], 78 | "n_frames": [8, 1000], 79 | "fps": [1, 60], 80 | }, 81 | "load_video_latent": False, 82 | "load_prompt_embed": False, 83 | "augmentation": {"random_flip": 0.5, "ratio": None, "scale": None}, 84 | "image_to_video": False, 85 | "check_video_path": False, 86 | } 87 | 88 | # Convert dict to OmegaConf 89 | cfg = OmegaConf.create(debug_config) 90 | 91 | # Create dataset 92 | dataset = PandasVideoDataset(cfg=cfg, split="training") 93 | 94 | # Load one sample and print its contents 95 | sample = dataset[0] 96 | print("\nSample contents:") 97 | for key, value in sample.items(): 98 | if isinstance(value, torch.Tensor): 99 | print(f"{key}: Tensor of shape {value.shape}") 100 | elif isinstance(value, dict): 101 | print(f"{key}:") 102 | for k, v in value.items(): 103 | print(f" {k}: {v}") 104 | else: 105 | print(f"{key}: {value}") 106 | -------------------------------------------------------------------------------- /datasets/droid.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from tqdm import tqdm 3 | from pathlib import Path 4 | import decord 5 | import shutil 6 | import subprocess 7 | import json 8 | from typing import Dict, Any 9 | from .video_base import VideoDataset 10 | 11 | 12 | class DroidVideoDataset(VideoDataset): 13 | def __init__(self, cfg: Dict[str, Any], split: str = "training"): 14 | self.override_fps = cfg.download.override_fps 15 | self.views = cfg.download.views 16 | super().__init__(cfg, split) 17 | 18 | def download(self): 19 | self.data_root.mkdir(parents=True, exist_ok=True) 20 | 21 | # print("Downloading DROID dataset...") 22 | # cmd = f"gsutil -m cp -r gs://gresearch/robotics/droid_raw {self.data_root}" 23 | # subprocess.run(cmd, shell=True, check=True) 24 | # print("Download complete!") 25 | 26 | # build metadata 27 | raw_dir = self.data_root / "droid_raw" 28 | caption_file = raw_dir / "1.0.1" / "aggregated-annotations-030724.json" 29 | caption_data = json.load(open(caption_file)) 30 | records = [] 31 | for lab_dir in (raw_dir / "1.0.1").glob("*/"): 32 | print("processing", lab_dir) 33 | print("=" * 100) 34 | # Delete failure directory and its contents if it exists 35 | failure_dir = lab_dir / "failure" 36 | success_dir = lab_dir / "success" 37 | if failure_dir.exists(): 38 | shutil.rmtree(failure_dir) 39 | 40 | for date_dir in list(success_dir.glob("*")): 41 | for episode_dir in list(date_dir.glob("*")): 42 | # Rename episode directory if it contains ":" 43 | if ":" in episode_dir.name: 44 | new_name = episode_dir.name.replace(":", "_") 45 | new_path = episode_dir.parent / new_name 46 | if new_path.exists(): 47 | shutil.rmtree(episode_dir) 48 | else: 49 | episode_dir.rename(new_path) 50 | 51 | for episode_dir in tqdm(list(success_dir.glob("*/*"))): 52 | annotation_file = list(episode_dir.glob("*.json")) 53 | if not annotation_file: 54 | continue 55 | annotation_file = annotation_file[0] 56 | f = json.load(open(annotation_file)) 57 | caption = f["current_task"] 58 | uuid = f["uuid"] 59 | for views in self.views: 60 | video_path = lab_dir / f[views + "_mp4_path"].replace(":", "_") 61 | state_path = lab_dir / f["hdf5_path"].replace(":", "_") 62 | n_frames = f["trajectory_length"] 63 | 64 | if not video_path.exists(): 65 | print(f"Video file not found: {video_path}") 66 | continue 67 | 68 | try: 69 | vr = decord.VideoReader(str(video_path)) 70 | fps = self.override_fps 71 | width = 1280 # vr[0].shape[1] 72 | height = 720 # vr[0].shape[0] 73 | 74 | del vr 75 | except Exception as e: 76 | print(f"Error 
loading video {video_path}: {e}") 77 | continue 78 | 79 | video_path = video_path.relative_to(self.data_root) 80 | # state_path = state_path.relative_to(self.data_root) 81 | 82 | if uuid not in caption_data: 83 | caption = "" 84 | has_caption = False 85 | else: 86 | caption = caption_data[uuid] 87 | has_caption = True 88 | records.append( 89 | { 90 | "video_path": str(video_path), 91 | # "state_path": str(state_path), 92 | "original_caption": caption, 93 | "fps": fps, 94 | "n_frames": n_frames, 95 | "width": width, 96 | "height": height, 97 | "has_caption": has_caption, 98 | } 99 | ) 100 | metadata_path = self.data_root / self.metadata_path 101 | metadata_path.parent.mkdir(parents=True, exist_ok=True) 102 | df = pd.DataFrame(records) 103 | df.to_csv(metadata_path, index=False) 104 | print(f"Created metadata CSV with {len(records)} videos") 105 | -------------------------------------------------------------------------------- /datasets/agibot_world.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from tqdm import tqdm 3 | import cv2 4 | import shutil 5 | import json 6 | import pandas as pd 7 | import tarfile 8 | import decord 9 | import subprocess 10 | from huggingface_hub import snapshot_download 11 | from .video_base import VideoDataset 12 | 13 | 14 | class AgibotWorldDataset(VideoDataset): 15 | """ 16 | Agibot world dataset from https://huggingface.co/datasets/agibot-world/AgiBotWorld-Alpha 17 | """ 18 | 19 | def preprocess_record(self, record): 20 | record["fps"] = self.cfg.fps_override 21 | return record 22 | 23 | def download(self): 24 | 25 | raw_dir = self.data_root / "agibot_world_alpha" 26 | raw_dir.mkdir(parents=True, exist_ok=True) 27 | 28 | # snapshot_download( 29 | # repo_id="agibot-world/AgiBotWorld-Alpha", 30 | # local_dir=raw_dir, 31 | # repo_type="dataset", 32 | # ) 33 | 34 | # print("Extracting tar files...") 35 | # for task_dir in tqdm((raw_dir / "observations").glob("*")): 36 | # for tar_file in task_dir.glob("*.tar"): 37 | # tar = tarfile.open(tar_file) 38 | # tar.extractall(path=task_dir) 39 | # tar.close() 40 | # # Delete the tar file after extraction 41 | # tar_file.unlink() 42 | # for episode_dir in task_dir.glob("*/"): 43 | # depth_dir = episode_dir / "depth" 44 | # video_dir = episode_dir / "videos" 45 | # # Delete the depth directory if it exists 46 | # if depth_dir.exists(): 47 | # shutil.rmtree(depth_dir) 48 | 49 | # for video_file in video_dir.glob("*.mp4"): 50 | # if video_file.name != "head_color.mp4": 51 | # video_file.unlink() 52 | # else: 53 | # reencoded_video_path = video_file.with_name( 54 | # f"{video_file.stem}_reencoded.mp4" 55 | # ) 56 | # command = [ 57 | # "ffmpeg", 58 | # "-y", 59 | # "-i", 60 | # str(video_file), 61 | # "-c:v", 62 | # "libx264", 63 | # "-crf", 64 | # "23", 65 | # "-c:a", 66 | # "copy", 67 | # str(reencoded_video_path), 68 | # ] 69 | # print(f"Reencoding {video_file} to {reencoded_video_path}") 70 | # subprocess.run(command, check=True) 71 | 72 | print("Creating metadata CSV...") 73 | records = [] 74 | 75 | for info_file in (raw_dir / "task_info").glob("*.json"): 76 | with open(info_file, "r") as f: 77 | info = json.load(f) 78 | for episode_info in tqdm(info): 79 | episode_id = episode_info["episode_id"] 80 | task_id = episode_info["task_id"] 81 | video_path = raw_dir / ( 82 | f"observations/{task_id}/{episode_id}/videos/head_color_reencoded.mp4" 83 | ) 84 | if not video_path.exists(): 85 | print(f"Skipping {video_path} because it doesn't exist") 86 | continue 
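                    # Probe the video with decord so unreadable or corrupted files are
                    # skipped rather than ending up in the metadata CSV.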
87 | try: 88 | vr = decord.VideoReader(str(video_path)) 89 | except Exception as e: 90 | print(f"Error loading video {video_path}: {e}") 91 | continue 92 | fps = 30 93 | width = 640 94 | height = 480 95 | clips = episode_info["label_info"]["action_config"] 96 | for clip in clips: 97 | trim_start = clip["start_frame"] 98 | trim_end = clip["end_frame"] 99 | caption = clip["action_text"] 100 | 101 | records.append( 102 | { 103 | "video_path": video_path.relative_to(self.data_root), 104 | "original_caption": caption, 105 | "trim_start": trim_start, 106 | "trim_end": trim_end, 107 | "fps": fps, 108 | "width": width, 109 | "height": height, 110 | "n_frames": len(vr), 111 | } 112 | ) 113 | 114 | # Save as CSV 115 | metadata_path = self.data_root / self.metadata_path 116 | metadata_path.parent.mkdir(parents=True, exist_ok=True) 117 | df = pd.DataFrame.from_records(records) 118 | df.to_csv(metadata_path, index=False) 119 | print(f"Created metadata CSV with {len(records)} videos") 120 | -------------------------------------------------------------------------------- /algorithms/wan/modules/xlm_roberta.py: -------------------------------------------------------------------------------- 1 | # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta 2 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['XLMRoberta', 'xlm_roberta_large'] 8 | 9 | 10 | class SelfAttention(nn.Module): 11 | 12 | def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): 13 | assert dim % num_heads == 0 14 | super().__init__() 15 | self.dim = dim 16 | self.num_heads = num_heads 17 | self.head_dim = dim // num_heads 18 | self.eps = eps 19 | 20 | # layers 21 | self.q = nn.Linear(dim, dim) 22 | self.k = nn.Linear(dim, dim) 23 | self.v = nn.Linear(dim, dim) 24 | self.o = nn.Linear(dim, dim) 25 | self.dropout = nn.Dropout(dropout) 26 | 27 | def forward(self, x, mask): 28 | """ 29 | x: [B, L, C]. 30 | """ 31 | b, s, c, n, d = *x.size(), self.num_heads, self.head_dim 32 | 33 | # compute query, key, value 34 | q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 35 | k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 36 | v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3) 37 | 38 | # compute attention 39 | p = self.dropout.p if self.training else 0.0 40 | x = F.scaled_dot_product_attention(q, k, v, mask, p) 41 | x = x.permute(0, 2, 1, 3).reshape(b, s, c) 42 | 43 | # output 44 | x = self.o(x) 45 | x = self.dropout(x) 46 | return x 47 | 48 | 49 | class AttentionBlock(nn.Module): 50 | 51 | def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): 52 | super().__init__() 53 | self.dim = dim 54 | self.num_heads = num_heads 55 | self.post_norm = post_norm 56 | self.eps = eps 57 | 58 | # layers 59 | self.attn = SelfAttention(dim, num_heads, dropout, eps) 60 | self.norm1 = nn.LayerNorm(dim, eps=eps) 61 | self.ffn = nn.Sequential( 62 | nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), 63 | nn.Dropout(dropout)) 64 | self.norm2 = nn.LayerNorm(dim, eps=eps) 65 | 66 | def forward(self, x, mask): 67 | if self.post_norm: 68 | x = self.norm1(x + self.attn(x, mask)) 69 | x = self.norm2(x + self.ffn(x)) 70 | else: 71 | x = x + self.attn(self.norm1(x), mask) 72 | x = x + self.ffn(self.norm2(x)) 73 | return x 74 | 75 | 76 | class XLMRoberta(nn.Module): 77 | """ 78 | XLMRobertaModel with no pooler and no LM head. 
79 | """ 80 | 81 | def __init__(self, 82 | vocab_size=250002, 83 | max_seq_len=514, 84 | type_size=1, 85 | pad_id=1, 86 | dim=1024, 87 | num_heads=16, 88 | num_layers=24, 89 | post_norm=True, 90 | dropout=0.1, 91 | eps=1e-5): 92 | super().__init__() 93 | self.vocab_size = vocab_size 94 | self.max_seq_len = max_seq_len 95 | self.type_size = type_size 96 | self.pad_id = pad_id 97 | self.dim = dim 98 | self.num_heads = num_heads 99 | self.num_layers = num_layers 100 | self.post_norm = post_norm 101 | self.eps = eps 102 | 103 | # embeddings 104 | self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id) 105 | self.type_embedding = nn.Embedding(type_size, dim) 106 | self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id) 107 | self.dropout = nn.Dropout(dropout) 108 | 109 | # blocks 110 | self.blocks = nn.ModuleList([ 111 | AttentionBlock(dim, num_heads, post_norm, dropout, eps) 112 | for _ in range(num_layers) 113 | ]) 114 | 115 | # norm layer 116 | self.norm = nn.LayerNorm(dim, eps=eps) 117 | 118 | def forward(self, ids): 119 | """ 120 | ids: [B, L] of torch.LongTensor. 121 | """ 122 | b, s = ids.shape 123 | mask = ids.ne(self.pad_id).long() 124 | 125 | # embeddings 126 | x = self.token_embedding(ids) + \ 127 | self.type_embedding(torch.zeros_like(ids)) + \ 128 | self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask) 129 | if self.post_norm: 130 | x = self.norm(x) 131 | x = self.dropout(x) 132 | 133 | # blocks 134 | mask = torch.where( 135 | mask.view(b, 1, 1, s).gt(0), 0.0, 136 | torch.finfo(x.dtype).min) 137 | for block in self.blocks: 138 | x = block(x, mask) 139 | 140 | # output 141 | if not self.post_norm: 142 | x = self.norm(x) 143 | return x 144 | 145 | 146 | def xlm_roberta_large(pretrained=False, 147 | return_tokenizer=False, 148 | device='cpu', 149 | **kwargs): 150 | """ 151 | XLMRobertaLarge adapted from Huggingface. 
152 | """ 153 | # params 154 | cfg = dict( 155 | vocab_size=250002, 156 | max_seq_len=514, 157 | type_size=1, 158 | pad_id=1, 159 | dim=1024, 160 | num_heads=16, 161 | num_layers=24, 162 | post_norm=True, 163 | dropout=0.1, 164 | eps=1e-5) 165 | cfg.update(**kwargs) 166 | 167 | # init a model on device 168 | with torch.device(device): 169 | model = XLMRoberta(**cfg) 170 | return model 171 | -------------------------------------------------------------------------------- /datasets/something_something.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import subprocess 3 | import json 4 | import pandas as pd 5 | import zipfile 6 | import cv2 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | 10 | from .video_base import VideoDataset 11 | 12 | 13 | class SomethingSomethingDataset(VideoDataset): 14 | """ 15 | Something Something Dataset from https://arxiv.org/abs/1706.04261 16 | """ 17 | 18 | def download(self): 19 | self.data_root.mkdir(parents=True, exist_ok=True) 20 | 21 | urls = [ 22 | "https://apigwx-aws.qualcomm.com/qsc/public/v1/api/download/software/dataset/AIDataset/Something-Something-V2/20bn-something-something-v2-00", 23 | "https://apigwx-aws.qualcomm.com/qsc/public/v1/api/download/software/dataset/AIDataset/Something-Something-V2/20bn-something-something-v2-01", 24 | "https://softwarecenter.qualcomm.com/api/download/software/dataset/AIDataset/Something-Something-V2/20bn-something-something-download-package-labels.zip", 25 | ] 26 | 27 | for url in urls: 28 | filename = Path(url).name 29 | filepath = self.data_root / filename 30 | 31 | print(f"Downloading {filename}...") 32 | response = requests.get(url, stream=True) 33 | response.raise_for_status() 34 | 35 | with open(filepath, "wb") as f: 36 | for chunk in response.iter_content(chunk_size=8192): 37 | f.write(chunk) 38 | 39 | # Use shell command to concatenate and extract tar video files 40 | print("Concatenating and extracting tar files...") 41 | cmd = f"cd {self.data_root} && cat 20bn-something-something-v2-0? 
| tar -xvzf -"
42 |         subprocess.run(cmd, shell=True, check=True)
43 |         print("Deleting split archive parts for video data...")
44 |         for zip_file in self.data_root.glob("20bn-something-something-v2-0*"):
45 |             print(f"Deleting {zip_file.name}...")
46 |             zip_file.unlink()
47 | 
48 |         # Unzip the labels package
49 |         labels_zip_path = (
50 |             self.data_root / "20bn-something-something-download-package-labels.zip"
51 |         )
52 |         if labels_zip_path.exists():
53 |             print(f"Extracting {labels_zip_path.name}...")
54 |             with zipfile.ZipFile(labels_zip_path, "r") as zip_ref:
55 |                 zip_ref.extractall(self.data_root)
56 |             print(f"Deleting zip file for labels...")
57 |             labels_zip_path.unlink()
58 | 
59 |         # Create metadata CSV from labels
60 |         print("Creating metadata CSV file for Something Something Dataset")
61 | 
62 |         json_files = {
63 |             "training": "labels/train.json",
64 |             "validation": "labels/validation.json",
65 |         }
66 | 
67 |         records = []
68 |         for split, json_file in json_files.items():
69 |             with open(self.data_root / json_file, "r") as f:
70 |                 labels = json.load(f)
71 | 
72 |             total_videos = len(labels)
73 |             successful_conversions = 0
74 | 
75 |             for item in tqdm(labels, desc=f"Creating metadata for {split}"):
76 |                 webm_video_path = f"20bn-something-something-v2/{item['id']}.webm"
77 |                 mp4_video_path = f"20bn-something-something-v2/{item['id']}.mp4"
78 | 
79 |                 if (self.data_root / webm_video_path).exists():
80 |                     # Convert webm to mp4 using ffmpeg
81 |                     input_path = str(self.data_root / webm_video_path)
82 |                     output_path = str(self.data_root / mp4_video_path)
83 |                     cmd = f'ffmpeg -i {input_path} -vf "pad=ceil(iw/2)*2:ceil(ih/2)*2" -c:v libx264 -c:a aac {output_path}'
84 |                     try:
85 |                         subprocess.run(
86 |                             cmd,
87 |                             shell=True,
88 |                             check=True,
89 |                             stdout=subprocess.DEVNULL,
90 |                             stderr=subprocess.DEVNULL,
91 |                         )
92 |                         # Delete the webm file after successful conversion
93 |                         (self.data_root / webm_video_path).unlink()
94 | 
95 |                         # Get video metadata using cv2
96 |                         cap = cv2.VideoCapture(output_path)
97 |                         if not cap.isOpened():
98 |                             continue
99 | 
100 |                         width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
101 |                         height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
102 |                         fps = int(cap.get(cv2.CAP_PROP_FPS))
103 |                         n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
104 |                         cap.release()
105 | 
106 |                         caption = item["label"].replace("pretending to ", "")
107 | 
108 |                         records.append(
109 |                             {
110 |                                 "video_path": mp4_video_path,
111 |                                 "caption": caption,
112 |                                 "height": height,
113 |                                 "width": width,
114 |                                 "fps": fps,
115 |                                 "n_frames": n_frames,
116 |                                 "split": split,
117 |                             }
118 |                         )
119 |                         successful_conversions += 1
120 |                     except subprocess.CalledProcessError:
121 |                         print(f"Conversion failed for {webm_video_path}")
122 | 
123 |             conversion_rate = (successful_conversions / total_videos) * 100
124 |             print(f"Conversion success rate: {conversion_rate:.2f}%")
125 | 
126 |         # Save as CSV
127 |         metadata_path = self.data_root / self.metadata_path
128 |         metadata_path.parent.mkdir(parents=True, exist_ok=True)
129 |         df = pd.DataFrame.from_records(records)
130 |         df.to_csv(metadata_path, index=False)
131 |         print(f"Created metadata CSV with {len(records)} videos")
132 | -------------------------------------------------------------------------------- /algorithms/wan/modules/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
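# Usage sketch for this module: `attention` (defined below) dispatches to
# FlashAttention-3 or FlashAttention-2 when installed, and otherwise falls back
# to torch.nn.functional.scaled_dot_product_attention. Inputs follow
# [B, L, N, C] = (batch, sequence, heads, head_dim) and are expected on CUDA in
# fp16/bf16 for the flash paths. The concrete sizes here are illustrative
# assumptions only:
#
#   q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
#   k = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
#   v = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.bfloat16)
#   out = attention(q, k, v)  # -> [2, 1024, 16, 64]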
2 | import torch 3 | 4 | try: 5 | import flash_attn_interface 6 | FLASH_ATTN_3_AVAILABLE = True 7 | except ModuleNotFoundError: 8 | FLASH_ATTN_3_AVAILABLE = False 9 | 10 | try: 11 | import flash_attn 12 | FLASH_ATTN_2_AVAILABLE = True 13 | except ModuleNotFoundError: 14 | FLASH_ATTN_2_AVAILABLE = False 15 | 16 | import warnings 17 | 18 | __all__ = [ 19 | 'flash_attention', 20 | 'attention', 21 | ] 22 | 23 | 24 | def flash_attention( 25 | q, 26 | k, 27 | v, 28 | q_lens=None, 29 | k_lens=None, 30 | dropout_p=0., 31 | softmax_scale=None, 32 | q_scale=None, 33 | causal=False, 34 | window_size=(-1, -1), 35 | deterministic=False, 36 | dtype=torch.bfloat16, 37 | version=None, 38 | ): 39 | """ 40 | q: [B, Lq, Nq, C1]. 41 | k: [B, Lk, Nk, C1]. 42 | v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. 43 | q_lens: [B]. 44 | k_lens: [B]. 45 | dropout_p: float. Dropout probability. 46 | softmax_scale: float. The scaling of QK^T before applying softmax. 47 | causal: bool. Whether to apply causal attention mask. 48 | window_size: (left right). If not (-1, -1), apply sliding window local attention. 49 | deterministic: bool. If True, slightly slower and uses more memory. 50 | dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16. 51 | """ 52 | half_dtypes = (torch.float16, torch.bfloat16) 53 | assert dtype in half_dtypes 54 | assert q.device.type == 'cuda' and q.size(-1) <= 256 55 | 56 | # params 57 | b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype 58 | 59 | def half(x): 60 | return x if x.dtype in half_dtypes else x.to(dtype) 61 | 62 | # preprocess query 63 | if q_lens is None: 64 | q = half(q.flatten(0, 1)) 65 | q_lens = torch.tensor( 66 | [lq] * b, dtype=torch.int32).to( 67 | device=q.device, non_blocking=True) 68 | else: 69 | q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) 70 | 71 | # preprocess key, value 72 | if k_lens is None: 73 | k = half(k.flatten(0, 1)) 74 | v = half(v.flatten(0, 1)) 75 | k_lens = torch.tensor( 76 | [lk] * b, dtype=torch.int32).to( 77 | device=k.device, non_blocking=True) 78 | else: 79 | k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) 80 | v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) 81 | 82 | q = q.to(v.dtype) 83 | k = k.to(v.dtype) 84 | 85 | if q_scale is not None: 86 | q = q * q_scale 87 | 88 | if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: 89 | warnings.warn( 90 | 'Flash attention 3 is not available, use flash attention 2 instead.' 91 | ) 92 | 93 | # apply attention 94 | if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: 95 | # Note: dropout_p, window_size are not supported in FA3 now. 
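        # The varlen kernels take packed (unpadded) q/k/v plus cumulative sequence
        # lengths. As a worked example with assumed values: q_lens = [3, 5] gives
        # cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0)
        #              = [0, 3, 8],
        # so sample i occupies rows cu_seqlens_q[i]:cu_seqlens_q[i+1] of the packed
        # tensor, and max_seqlen_q bounds the longest sample.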
96 | x = flash_attn_interface.flash_attn_varlen_func( 97 | q=q, 98 | k=k, 99 | v=v, 100 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 101 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 102 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 103 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 104 | seqused_q=None, 105 | seqused_k=None, 106 | max_seqlen_q=lq, 107 | max_seqlen_k=lk, 108 | softmax_scale=softmax_scale, 109 | causal=causal, 110 | deterministic=deterministic)[0].unflatten(0, (b, lq)) 111 | else: 112 | assert FLASH_ATTN_2_AVAILABLE 113 | x = flash_attn.flash_attn_varlen_func( 114 | q=q, 115 | k=k, 116 | v=v, 117 | cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( 118 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 119 | cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( 120 | 0, dtype=torch.int32).to(q.device, non_blocking=True), 121 | max_seqlen_q=lq, 122 | max_seqlen_k=lk, 123 | dropout_p=dropout_p, 124 | softmax_scale=softmax_scale, 125 | causal=causal, 126 | window_size=window_size, 127 | deterministic=deterministic).unflatten(0, (b, lq)) 128 | 129 | # output 130 | return x.type(out_dtype) 131 | 132 | 133 | def attention( 134 | q, 135 | k, 136 | v, 137 | q_lens=None, 138 | k_lens=None, 139 | dropout_p=0., 140 | softmax_scale=None, 141 | q_scale=None, 142 | causal=False, 143 | window_size=(-1, -1), 144 | deterministic=False, 145 | dtype=torch.bfloat16, 146 | fa_version=None, 147 | ): 148 | if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: 149 | return flash_attention( 150 | q=q, 151 | k=k, 152 | v=v, 153 | q_lens=q_lens, 154 | k_lens=k_lens, 155 | dropout_p=dropout_p, 156 | softmax_scale=softmax_scale, 157 | q_scale=q_scale, 158 | causal=causal, 159 | window_size=window_size, 160 | deterministic=deterministic, 161 | dtype=dtype, 162 | version=fa_version, 163 | ) 164 | else: 165 | if q_lens is not None or k_lens is not None: 166 | warnings.warn( 167 | 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' 
168 | ) 169 | attn_mask = None 170 | 171 | q = q.transpose(1, 2).to(dtype) 172 | k = k.transpose(1, 2).to(dtype) 173 | v = v.transpose(1, 2).to(dtype) 174 | 175 | out = torch.nn.functional.scaled_dot_product_attention( 176 | q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) 177 | 178 | out = out.transpose(1, 2).contiguous() 179 | return out 180 | -------------------------------------------------------------------------------- /datasets/mixture.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from torch.utils.data import IterableDataset, Dataset 3 | from omegaconf import DictConfig 4 | import torch 5 | import numpy as np 6 | from datasets.dummy import DummyVideoDataset 7 | from datasets.openx_base import OpenXVideoDataset 8 | from datasets.droid import DroidVideoDataset 9 | from datasets.something_something import SomethingSomethingDataset 10 | from datasets.epic_kitchen import EpicKitchenDataset 11 | from datasets.pandas import PandasVideoDataset 12 | from datasets.deprecated.video_1x_wm import WorldModel1XDataset 13 | from datasets.agibot_world import AgibotWorldDataset 14 | from datasets.ego4d import Ego4DVideoDataset 15 | 16 | subset_classes = dict( 17 | dummy=DummyVideoDataset, 18 | something_something=SomethingSomethingDataset, 19 | epic_kitchen=EpicKitchenDataset, 20 | pandas=PandasVideoDataset, 21 | agibot_world=AgibotWorldDataset, 22 | video_1x_wm=WorldModel1XDataset, 23 | ego4d=Ego4DVideoDataset, 24 | droid=DroidVideoDataset, 25 | austin_buds=OpenXVideoDataset, 26 | austin_sailor=OpenXVideoDataset, 27 | austin_sirius=OpenXVideoDataset, 28 | bc_z=OpenXVideoDataset, 29 | berkeley_autolab=OpenXVideoDataset, 30 | berkeley_cable=OpenXVideoDataset, 31 | berkeley_fanuc=OpenXVideoDataset, 32 | bridge=OpenXVideoDataset, 33 | cmu_stretch=OpenXVideoDataset, 34 | dlr_edan=OpenXVideoDataset, 35 | dobbe=OpenXVideoDataset, 36 | fmb=OpenXVideoDataset, 37 | fractal=OpenXVideoDataset, 38 | iamlab_cmu=OpenXVideoDataset, 39 | jaco_play=OpenXVideoDataset, 40 | language_table=OpenXVideoDataset, 41 | nyu_franka=OpenXVideoDataset, 42 | roboturk=OpenXVideoDataset, 43 | stanford_hydra=OpenXVideoDataset, 44 | taco_play=OpenXVideoDataset, 45 | toto=OpenXVideoDataset, 46 | ucsd_kitchen=OpenXVideoDataset, 47 | utaustin_mutex=OpenXVideoDataset, 48 | viola=OpenXVideoDataset, 49 | ) 50 | 51 | 52 | class MixtureDataset(IterableDataset): 53 | """ 54 | A fault tolerant mixture of video datasets 55 | """ 56 | 57 | def __init__(self, cfg: DictConfig, split: str = "training"): 58 | super().__init__() 59 | self.cfg = cfg 60 | self.debug = cfg.debug 61 | self.split = split 62 | self.random_seed = np.random.get_state()[1][0] # Get current numpy random seed 63 | self.subset_cfg = { 64 | k.split("/")[1]: v for k, v in self.cfg.items() if k.startswith("subset/") 65 | } 66 | if split == "all": 67 | raise ValueError("split cannot be `all` for MixtureDataset`") 68 | weight = dict(self.cfg[split].weight) 69 | # Check if all keys in weight exist in subset_cfg 70 | for key in weight: 71 | if key not in self.subset_cfg: 72 | raise ValueError( 73 | f"Dataset '{key}' specified in weights but not found in configuration" 74 | ) 75 | self.subset_cfg = {k: v for k, v in self.subset_cfg.items() if k in weight} 76 | weight_type = self.cfg[split].weight_type # one of relative or absolute 77 | self.subsets: List[Dataset] = [] 78 | for subset_name, subset_cfg in self.subset_cfg.items(): 79 | subset_cfg["height"] = self.cfg.height 80 | subset_cfg["width"] = 
self.cfg.width 81 | subset_cfg["n_frames"] = self.cfg.n_frames 82 | subset_cfg["fps"] = self.cfg.fps 83 | subset_cfg["load_video_latent"] = self.cfg.load_video_latent 84 | subset_cfg["load_prompt_embed"] = self.cfg.load_prompt_embed 85 | subset_cfg["max_text_tokens"] = self.cfg.max_text_tokens 86 | subset_cfg["image_to_video"] = self.cfg.image_to_video 87 | self.subsets.append(subset_classes[subset_name](subset_cfg, split)) 88 | if weight_type == "relative": 89 | weight[subset_name] = weight[subset_name] * len(self.subsets[-1]) 90 | 91 | # Normalize weights to sum to 1 92 | total_weight = sum(weight.values()) 93 | self.normalized_weights = {k: v / total_weight for k, v in weight.items()} 94 | 95 | # Store dataset sizes for printing 96 | dataset_sizes = { 97 | subset_name: len(subset) 98 | for subset_name, subset in zip(self.subset_cfg.keys(), self.subsets) 99 | } 100 | 101 | # Print normalized weights and dataset sizes in a nice format 102 | print("\nDataset information for split '{}':".format(self.split)) 103 | print("-" * 60) 104 | print(f"{'Dataset':<25} {'Size':<10} {'Weight':<10} {'Normalized':<10}") 105 | print("-" * 60) 106 | for subset_name, norm_weight in sorted( 107 | self.normalized_weights.items(), key=lambda x: -x[1] 108 | ): 109 | size = dataset_sizes[subset_name] 110 | orig_weight = self.cfg[split].weight[subset_name] 111 | print( 112 | f"{subset_name:<25} {size:<10,d} {orig_weight:<10.4f} {norm_weight:<10.4f}" 113 | ) 114 | print("-" * 60) 115 | 116 | # Calculate cumulative probabilities for sampling 117 | self.cumsum_weights = {} 118 | cumsum = 0 119 | for k, v in self.normalized_weights.items(): 120 | cumsum += v 121 | self.cumsum_weights[k] = cumsum 122 | 123 | # some scripts want to access the records 124 | self.records = [] 125 | for subset in self.subsets: 126 | self.records.extend(subset.records) 127 | 128 | def __iter__(self): 129 | while True: 130 | # Sample a random subset based on weights using numpy random 131 | rand = np.random.random() 132 | for subset_name, cumsum in self.cumsum_weights.items(): 133 | if rand <= cumsum: 134 | selected_subset = subset_name 135 | break 136 | 137 | # Get the corresponding dataset index 138 | subset_idx = list(self.subset_cfg.keys()).index(selected_subset) 139 | 140 | try: 141 | # Sample randomly from the selected dataset using numpy random 142 | dataset = self.subsets[subset_idx] 143 | idx = np.random.randint(len(dataset)) 144 | sample = dataset[idx] 145 | yield sample 146 | except Exception as e: 147 | if self.debug: 148 | raise e 149 | else: 150 | print(f"Error sampling from {selected_subset}: {str(e)}") 151 | continue 152 | -------------------------------------------------------------------------------- /algorithms/wan/distributed/xdit_context_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
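# This module provides sequence-parallel (xDiT / USP) replacements for the Wan
# DiT forward pass and self-attention: the token sequence is chunked across
# sequence-parallel ranks, RoPE frequencies are padded to s * sp_size and each
# rank slices its own window (freqs_i[sp_rank * s : (sp_rank + 1) * s]), and the
# per-rank outputs are all-gathered along the sequence dimension before
# unpatchify. A rough sketch of the sharding, assuming a world size of 2:
#
#   x = torch.chunk(x, 2, dim=1)[rank]   # keep this rank's half of the tokens
#   x = blocks(x)                        # attention runs on local tokens only
#   x = sp_group.all_gather(x, dim=1)    # restore the full sequence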
2 | import torch 3 | import torch.amp as amp 4 | from xfuser.core.distributed import ( 5 | get_sequence_parallel_rank, 6 | get_sequence_parallel_world_size, 7 | get_sp_group, 8 | ) 9 | from xfuser.core.long_ctx_attention import xFuserLongContextAttention 10 | 11 | from ..modules.model import sinusoidal_embedding_1d 12 | 13 | 14 | def pad_freqs(original_tensor, target_len): 15 | seq_len, s1, s2 = original_tensor.shape 16 | pad_size = target_len - seq_len 17 | padding_tensor = torch.ones( 18 | pad_size, s1, s2, dtype=original_tensor.dtype, device=original_tensor.device 19 | ) 20 | padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0) 21 | return padded_tensor 22 | 23 | 24 | @amp.autocast("cuda", enabled=False) 25 | def rope_apply(x, grid_sizes, freqs): 26 | """ 27 | x: [B, L, N, C]. 28 | grid_sizes: [B, 3]. 29 | freqs: [M, C // 2]. 30 | """ 31 | s, n, c = x.size(1), x.size(2), x.size(3) // 2 32 | # split freqs 33 | freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) 34 | 35 | # loop over samples 36 | output = [] 37 | for i, (f, h, w) in enumerate(grid_sizes.tolist()): 38 | seq_len = f * h * w 39 | 40 | # precompute multipliers 41 | x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2)) 42 | freqs_i = torch.cat( 43 | [ 44 | freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), 45 | freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), 46 | freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), 47 | ], 48 | dim=-1, 49 | ).reshape(seq_len, 1, -1) 50 | 51 | # apply rotary embedding 52 | sp_size = get_sequence_parallel_world_size() 53 | sp_rank = get_sequence_parallel_rank() 54 | freqs_i = pad_freqs(freqs_i, s * sp_size) 55 | s_per_rank = s 56 | freqs_i_rank = freqs_i[ 57 | (sp_rank * s_per_rank) : ((sp_rank + 1) * s_per_rank), :, : 58 | ] 59 | x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2) 60 | x_i = torch.cat([x_i, x[i, s:]]) 61 | 62 | # append to collection 63 | output.append(x_i) 64 | return torch.stack(output).float() 65 | 66 | 67 | def usp_dit_forward( 68 | self, 69 | x, 70 | t, 71 | context, 72 | seq_len, 73 | clip_fea=None, 74 | y=None, 75 | ): 76 | """ 77 | x: A list of videos each with shape [C, T, H, W]. 78 | t: [B]. 79 | context: A list of text embeddings each with shape [L, C]. 
80 | """ 81 | if self.model_type == "i2v": 82 | assert clip_fea is not None and y is not None 83 | # params 84 | device = self.patch_embedding.weight.device 85 | if self.freqs.device != device: 86 | self.freqs = self.freqs.to(device) 87 | 88 | if y is not None: 89 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] 90 | 91 | # embeddings 92 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] 93 | grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) 94 | x = [u.flatten(2).transpose(1, 2) for u in x] 95 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) 96 | assert seq_lens.max() <= seq_len 97 | x = torch.cat( 98 | [ 99 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) 100 | for u in x 101 | ] 102 | ) 103 | 104 | # time embeddings 105 | with amp.autocast("cuda", dtype=torch.float32): 106 | e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).float()) 107 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) 108 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 109 | 110 | # context 111 | context_lens = None 112 | context = self.text_embedding( 113 | torch.stack( 114 | [ 115 | torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) 116 | for u in context 117 | ] 118 | ) 119 | ) 120 | 121 | if clip_fea is not None: 122 | context_clip = self.img_emb(clip_fea) # bs x 257 x dim 123 | context = torch.concat([context_clip, context], dim=1) 124 | 125 | # arguments 126 | kwargs = dict( 127 | e=e0, 128 | seq_lens=seq_lens, 129 | grid_sizes=grid_sizes, 130 | freqs=self.freqs, 131 | context=context, 132 | context_lens=context_lens, 133 | ) 134 | 135 | # Context Parallel 136 | x = torch.chunk(x, get_sequence_parallel_world_size(), dim=1)[ 137 | get_sequence_parallel_rank() 138 | ] 139 | 140 | for block in self.blocks: 141 | x = block(x, **kwargs) 142 | 143 | # head 144 | x = self.head(x, e) 145 | 146 | # Context Parallel 147 | x = get_sp_group().all_gather(x, dim=1) 148 | 149 | # unpatchify 150 | x = self.unpatchify(x, grid_sizes) 151 | return [u.float() for u in x] 152 | 153 | 154 | def usp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16): 155 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim 156 | half_dtypes = (torch.float16, torch.bfloat16) 157 | 158 | def half(x): 159 | return x if x.dtype in half_dtypes else x.to(dtype) 160 | 161 | # query, key, value function 162 | def qkv_fn(x): 163 | q = self.norm_q(self.q(x)).view(b, s, n, d) 164 | k = self.norm_k(self.k(x)).view(b, s, n, d) 165 | v = self.v(x).view(b, s, n, d) 166 | return q, k, v 167 | 168 | q, k, v = qkv_fn(x) 169 | q = rope_apply(q, grid_sizes, freqs) 170 | k = rope_apply(k, grid_sizes, freqs) 171 | 172 | # TODO: We should use unpaded q,k,v for attention. 173 | # k_lens = seq_lens // get_sequence_parallel_world_size() 174 | # if k_lens is not None: 175 | # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0) 176 | # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0) 177 | # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0) 178 | 179 | x = xFuserLongContextAttention()( 180 | None, query=half(q), key=half(k), value=half(v), window_size=self.window_size 181 | ) 182 | 183 | # TODO: padding after attention. 
184 | # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1) 185 | 186 | # output 187 | x = x.flatten(2) 188 | x = self.o(x) 189 | return x 190 | -------------------------------------------------------------------------------- /datasets/ego4d.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | import ijson 4 | from typing import Dict, Any 5 | from .video_base import VideoDataset 6 | 7 | 8 | class Ego4DVideoDataset(VideoDataset): 9 | 10 | def download(self): 11 | from ego4d.cli.cli import main_cfg as download_ego4d 12 | from ego4d.cli.config import Config as Ego4DConfig 13 | 14 | raw_dir = self.data_root / "raw" 15 | raw_dir.mkdir(parents=True, exist_ok=True) 16 | 17 | aws_credentials_path = Path.home() / ".aws" / "credentials" 18 | if not aws_credentials_path.exists(): 19 | raise FileNotFoundError( 20 | f"AWS credentials file not found at {aws_credentials_path}" 21 | "For Ego4D auto download, you need to request access and use the " 22 | "emailed key to set up AWS credentials first." 23 | "See https://ego4d-data.org/ for more information." 24 | ) 25 | 26 | cfg = Ego4DConfig( 27 | output_directory=str(raw_dir), 28 | datasets=["annotations", "clips"], 29 | benchmarks=["FHO"], 30 | metadata=True, 31 | assume_yes=True, 32 | ) 33 | 34 | import botocore 35 | 36 | try: 37 | download_ego4d(cfg) 38 | except botocore.exceptions.ClientError as e: 39 | print(e) 40 | raise RuntimeError( 41 | "Failed to download Ego4D dataset due to the above error." 42 | "If you see an error occurred (403) when calling the HeadObject operation: Forbidden", 43 | "It's likely due to an expired Ego4D AWS credential. Renew the dataset's online form and update the AWS credentials.", 44 | ) 45 | 46 | annotation_file = "v2/annotations/fho_main.json" 47 | print("Creating metadata CSV...") 48 | records = [] 49 | with open(raw_dir / annotation_file, "rb") as file: 50 | # Create a parser for the videos array 51 | videos = ijson.items(file, "videos.item") 52 | total = 0 53 | 54 | for v in videos: 55 | fps = round(v["video_metadata"]["fps"]) 56 | n_frames = v["video_metadata"]["num_frames"] 57 | width = v["video_metadata"]["width"] 58 | height = v["video_metadata"]["height"] 59 | for c in v["annotated_intervals"]: 60 | video_path = "raw/v2/clips/" + c["clip_uid"] + ".mp4" 61 | 62 | if not Path(self.data_root / video_path).exists(): 63 | continue 64 | 65 | for a in c["narrated_actions"]: 66 | total += 1 67 | critical_frames = a["clip_critical_frames"] 68 | is_valid_action = a["is_valid_action"] 69 | is_rejected = a["is_rejected"] 70 | is_invalid_annotation = a["is_invalid_annotation"] 71 | is_partial = a["is_partial"] 72 | if ( 73 | not critical_frames 74 | or not is_valid_action 75 | or is_rejected 76 | or is_invalid_annotation 77 | or is_partial 78 | ): 79 | continue 80 | caption = a["narration_text"] 81 | caption = ( 82 | caption.replace("#cC c ", " ") 83 | .replace("#Cc C ", " ") 84 | .replace("#C C ", "") 85 | .replace("#c c ", " ") 86 | .replace("#c- c ", " ") 87 | .replace("#c C ", " ") 88 | .replace("#c c", " ") 89 | .replace("#CC ", " ") 90 | .replace("#C C ", " ") 91 | .replace("#C c ", " ") 92 | .replace("#cc ", " ") 93 | .replace("#C- C ", " ") 94 | .replace("#c C ", " ") 95 | .replace("#C ", " ") 96 | .replace("#c ", " ") 97 | .replace("#", " ") 98 | ) 99 | pre_frame = critical_frames["pre_frame"] 100 | post_frame = critical_frames["post_frame"] 101 | pnr_frame = critical_frames["pnr_frame"] 102 | contact_frame = 
critical_frames["contact_frame"] 103 | 104 | # some manual heuristics to trim the video 105 | target_len = self._n_frames_in_src(fps) 106 | trim_start = pre_frame 107 | psudo_min_end = int((post_frame - pnr_frame) * 0.1) + pnr_frame 108 | if psudo_min_end - pre_frame >= target_len: 109 | trim_end = psudo_min_end 110 | elif post_frame - pnr_frame < target_len: 111 | trim_end = post_frame 112 | trim_start = max(trim_end - target_len, pre_frame - 15) 113 | else: 114 | trim_end = target_len + pre_frame 115 | 116 | trim_start = max(0, trim_start) 117 | trim_end = min(n_frames, trim_end) 118 | 119 | records.append( 120 | { 121 | "video_path": video_path, 122 | "height": height, 123 | "width": width, 124 | "n_frames": n_frames, 125 | "fps": fps, 126 | "original_caption": caption, 127 | "trim_start": trim_start, 128 | "trim_end": trim_end, 129 | "pre_frame": pre_frame, 130 | "pnr_frame": pnr_frame, 131 | "post_frame": post_frame, 132 | "contact_frame": contact_frame, 133 | } 134 | ) 135 | metadata_path = self.data_root / self.metadata_path 136 | metadata_path.parent.mkdir(parents=True, exist_ok=True) 137 | df = pd.DataFrame.from_records(records) 138 | df.to_csv(metadata_path, index=False) 139 | print(f"Created metadata CSV with {len(records)} records") 140 | -------------------------------------------------------------------------------- /algorithms/wan/wan_i2v.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange, repeat 4 | from transformers import get_scheduler 5 | from .modules.clip import clip_xlm_roberta_vit_h_14 6 | from .wan_t2v import WanTextToVideo 7 | 8 | 9 | class WanImageToVideo(WanTextToVideo): 10 | """ 11 | Main class for WanImageToVideo, inheriting from WanTextToVideo 12 | """ 13 | 14 | def __init__(self, cfg): 15 | super().__init__(cfg) 16 | self.cfg.model.in_dim = self.cfg.vae.z_dim * 2 + 4 17 | 18 | def configure_model(self): 19 | # Call parent's configure_model first 20 | super().configure_model() 21 | 22 | if self.cfg.model.tuned_ckpt_path is None: 23 | self.model.hack_embedding_ckpt() 24 | 25 | # Additionally initialize CLIP for image encoding 26 | clip, clip_transform = clip_xlm_roberta_vit_h_14( 27 | pretrained=False, 28 | return_transforms=True, 29 | return_tokenizer=False, 30 | dtype=torch.float16 if self.is_inference else self.dtype, 31 | device="cpu", 32 | ) 33 | if self.cfg.clip.ckpt_path is not None: 34 | clip.load_state_dict( 35 | torch.load( 36 | self.cfg.clip.ckpt_path, map_location="cpu", weights_only=True 37 | ) 38 | ) 39 | if self.cfg.clip.compile: 40 | clip = torch.compile(clip) 41 | self.clip = clip 42 | self.clip_normalize = clip_transform.transforms[-1] 43 | 44 | def configure_optimizers(self): 45 | optimizer = torch.optim.AdamW( 46 | [ 47 | {"params": self.model.parameters(), "lr": self.cfg.lr}, 48 | {"params": self.vae.parameters(), "lr": 0}, 49 | {"params": self.clip.parameters(), "lr": 0}, 50 | ], 51 | weight_decay=self.cfg.weight_decay, 52 | betas=self.cfg.betas, 53 | ) 54 | # optimizer = torch.optim.AdamW( 55 | # self.model.parameters(), 56 | # lr=self.cfg.lr, 57 | # weight_decay=self.cfg.weight_decay, 58 | # betas=self.cfg.betas, 59 | # ) 60 | lr_scheduler_config = { 61 | "scheduler": get_scheduler( 62 | optimizer=optimizer, 63 | **self.cfg.lr_scheduler, 64 | ), 65 | "interval": "step", 66 | "frequency": 1, 67 | } 68 | 69 | return { 70 | "optimizer": optimizer, 71 | "lr_scheduler": lr_scheduler_config, 72 | } 73 | 74 | def clip_features(self, 
videos): 75 | size = (self.clip.image_size,) * 2 76 | videos = rearrange(videos, "b t c h w -> (b t) c h w") 77 | videos = nn.functional.interpolate( 78 | videos, size=size, mode="bicubic", align_corners=False 79 | ) 80 | videos = self.clip_normalize(videos.mul_(0.5).add_(0.5)) 81 | return self.clip.visual(videos, use_31_block=True) 82 | 83 | @torch.no_grad() 84 | def prepare_embeds(self, batch): 85 | batch = super().prepare_embeds(batch) 86 | 87 | videos = batch["videos"] 88 | images = videos[:, :1] 89 | has_bbox = batch["has_bbox"] # [B, 2] 90 | bbox_render = batch["bbox_render"] # [B, 2, H, W] 91 | 92 | batch_size, t, _, h, w = videos.shape 93 | lat_c, lat_t, lat_h, lat_w = self.lat_c, self.lat_t, self.lat_h, self.lat_w 94 | 95 | clip_embeds = self.clip_features(images) 96 | batch["clip_embeds"] = clip_embeds 97 | 98 | mask = torch.zeros( 99 | batch_size, 100 | self.vae_stride[0], 101 | lat_t, 102 | lat_h, 103 | lat_w, 104 | device=self.device, 105 | dtype=self.dtype, 106 | ) 107 | # after the ckpt hack, we repurpose the 4 mask channels for bounding box conditioning 108 | # second last channel is indicator of bounding box 109 | mask[:, 2, 0] = has_bbox[..., 0, None, None] 110 | mask[:, 2, -1] = has_bbox[..., -1, None, None] 111 | # Interpolate bbox_render to match latent dimensions 112 | bbox_render_resized = nn.functional.interpolate( 113 | bbox_render, 114 | size=(lat_h, lat_w), 115 | mode="bicubic", 116 | align_corners=False, 117 | ) 118 | # last channel is renderred bbox 119 | mask[:, 3, 0] = bbox_render_resized[:, 0] 120 | mask[:, 3, -1] = bbox_render_resized[:, -1] 121 | 122 | if self.diffusion_forcing.enabled: 123 | image_embeds = torch.zeros( 124 | batch_size, 125 | 4 + lat_c, 126 | lat_t, 127 | lat_h, 128 | lat_w, 129 | device=self.device, 130 | dtype=self.dtype, 131 | ) 132 | else: 133 | padded_images = torch.zeros(batch_size, 3, t - 1, h, w, device=self.device) 134 | padded_images = torch.cat( 135 | [rearrange(images, "b 1 c h w -> b c 1 h w"), padded_images], dim=2 136 | ) 137 | image_embeds = self.encode_video( 138 | padded_images 139 | ) # b, lat_c, lat_t, lat_h, lat_w 140 | image_embeds = torch.cat([mask, image_embeds], 1) 141 | mask[:, :2, 0] = 1 142 | batch["image_embeds"] = image_embeds 143 | 144 | return batch 145 | 146 | def visualize(self, video_pred, batch): 147 | bbox_render = batch["bbox_render"] # b, 2, h, w for first and last frame 148 | has_bbox = batch["has_bbox"] # b, 2 for first and last frame 149 | video_gt = batch["videos"] # b, t, 3, h, w 150 | 151 | alpha = 0.4 152 | l = video_gt.shape[1] // 4 153 | 154 | # Apply green bbox overlay with transparency to first frame if has_bbox for first frame 155 | mask = has_bbox[:, 0].bool() 156 | green = torch.zeros_like(video_gt[mask, :1]) 157 | green[:, :, 1] = 1.0 158 | if mask.any(): 159 | bbox = bbox_render[:, None, 0:1][mask] * alpha # b', 1, 1, h, w 160 | video_gt[mask, :l] = (1 - bbox) * video_gt[mask, :l] + bbox * green 161 | 162 | # Apply green bbox overlay with transparency to last frame if has_bbox for last frame 163 | mask = has_bbox[:, 1].bool() 164 | green = torch.zeros_like(video_gt[mask, :1]) 165 | green[:, :, 1] = 1.0 166 | if mask.any(): 167 | bbox = bbox_render[:, None, 1:2][mask] * alpha # b', 1, 1, h, w 168 | video_gt[mask, -l:] = (1 - bbox) * video_gt[mask, -l:] + bbox * green 169 | 170 | batch["videos"] = video_gt 171 | 172 | return super().visualize(video_pred, batch) 173 | -------------------------------------------------------------------------------- /utils/wandb_utils.py: 
-------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from datetime import timedelta 3 | from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Union 4 | from typing_extensions import override 5 | from functools import wraps 6 | import os 7 | from wandb_osh.hooks import TriggerWandbSyncHook 8 | import time 9 | from lightning.pytorch.loggers.wandb import WandbLogger, _scan_checkpoints, ModelCheckpoint, Tensor 10 | from lightning.pytorch.utilities.rank_zero import rank_zero_only 11 | from lightning.fabric.utilities.types import _PATH 12 | 13 | 14 | if TYPE_CHECKING: 15 | from wandb.sdk.lib import RunDisabled 16 | from wandb.wandb_run import Run 17 | 18 | 19 | class SpaceEfficientWandbLogger(WandbLogger): 20 | """ 21 | A wandb logger that by default overrides artifacts to save space, instead of creating new version. 22 | A variable expiration_days can be set to control how long older versions of artifacts are kept. 23 | By default, the latest version is kept indefinitely, while older versions are kept for 5 days. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | name: Optional[str] = None, 29 | save_dir: _PATH = ".", 30 | version: Optional[str] = None, 31 | offline: bool = False, 32 | dir: Optional[_PATH] = None, 33 | id: Optional[str] = None, 34 | anonymous: Optional[bool] = None, 35 | project: Optional[str] = None, 36 | log_model: Union[Literal["all"], bool] = False, 37 | experiment: Union["Run", "RunDisabled", None] = None, 38 | prefix: str = "", 39 | checkpoint_name: Optional[str] = None, 40 | expiration_days: Optional[int] = 5, 41 | **kwargs: Any, 42 | ) -> None: 43 | super().__init__( 44 | name=name, 45 | save_dir=save_dir, 46 | version=version, 47 | offline=False, 48 | dir=dir, 49 | id=id, 50 | anonymous=anonymous, 51 | project=project, 52 | log_model=log_model, 53 | experiment=experiment, 54 | prefix=prefix, 55 | checkpoint_name=checkpoint_name, 56 | **kwargs, 57 | ) 58 | 59 | super().__init__( 60 | name=name, 61 | save_dir=save_dir, 62 | version=version, 63 | offline=offline, 64 | dir=dir, 65 | id=id, 66 | anonymous=anonymous, 67 | project=project, 68 | log_model=log_model, 69 | experiment=experiment, 70 | prefix=prefix, 71 | checkpoint_name=checkpoint_name, 72 | **kwargs, 73 | ) 74 | self.expiration_days = expiration_days 75 | self._last_artifacts = [] 76 | 77 | def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: 78 | import wandb 79 | 80 | # get checkpoints to be saved with associated score 81 | checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time) 82 | 83 | # log iteratively all new checkpoints 84 | artifacts = [] 85 | for t, p, s, tag in checkpoints: 86 | metadata = { 87 | "score": s.item() if isinstance(s, Tensor) else s, 88 | "original_filename": Path(p).name, 89 | checkpoint_callback.__class__.__name__: { 90 | k: getattr(checkpoint_callback, k) 91 | for k in [ 92 | "monitor", 93 | "mode", 94 | "save_last", 95 | "save_top_k", 96 | "save_weights_only", 97 | "_every_n_train_steps", 98 | ] 99 | # ensure it does not break if `ModelCheckpoint` args change 100 | if hasattr(checkpoint_callback, k) 101 | }, 102 | } 103 | if not self._checkpoint_name: 104 | self._checkpoint_name = f"model-{self.experiment.id}" 105 | 106 | artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata) 107 | artifact.add_file(p, name="model.ckpt") 108 | aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] 109 | 
self.experiment.log_artifact(artifact, aliases=aliases) 110 | # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name) 111 | self._logged_model_time[p] = t 112 | artifacts.append(artifact) 113 | 114 | for artifact in self._last_artifacts: 115 | if not self._offline: 116 | artifact.wait() 117 | artifact.ttl = timedelta(days=self.expiration_days) 118 | artifact.save() 119 | self._last_artifacts = artifacts 120 | 121 | 122 | class OfflineWandbLogger(SpaceEfficientWandbLogger): 123 | """ 124 | Wraps WandbLogger to trigger offline sync hook occasionally. 125 | This is useful when running on slurm clusters, many of which 126 | only has internet on login nodes, not compute nodes. 127 | """ 128 | 129 | def __init__( 130 | self, 131 | name: Optional[str] = None, 132 | save_dir: _PATH = ".", 133 | version: Optional[str] = None, 134 | offline: bool = False, 135 | dir: Optional[_PATH] = None, 136 | id: Optional[str] = None, 137 | anonymous: Optional[bool] = None, 138 | project: Optional[str] = None, 139 | log_model: Union[Literal["all"], bool] = False, 140 | experiment: Union["Run", "RunDisabled", None] = None, 141 | prefix: str = "", 142 | checkpoint_name: Optional[str] = None, 143 | **kwargs: Any, 144 | ) -> None: 145 | super().__init__( 146 | name=name, 147 | save_dir=save_dir, 148 | version=version, 149 | offline=False, 150 | dir=dir, 151 | id=id, 152 | anonymous=anonymous, 153 | project=project, 154 | log_model=log_model, 155 | experiment=experiment, 156 | prefix=prefix, 157 | checkpoint_name=checkpoint_name, 158 | **kwargs, 159 | ) 160 | self._offline = offline 161 | communication_dir = Path(".wandb_osh_command_dir") 162 | communication_dir.mkdir(parents=True, exist_ok=True) 163 | self.trigger_sync = TriggerWandbSyncHook(communication_dir) 164 | self.last_sync_time = 0.0 165 | self.min_sync_interval = 60 166 | self.wandb_dir = os.path.join(self._save_dir, "wandb/latest-run") 167 | 168 | @override 169 | @rank_zero_only 170 | def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: 171 | out = super().log_metrics(metrics, step) 172 | if time.time() - self.last_sync_time > self.min_sync_interval: 173 | self.trigger_sync(self.wandb_dir) 174 | self.last_sync_time = time.time() 175 | return out 176 | -------------------------------------------------------------------------------- /algorithms/common/models/cnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | def is_square_of_two(num): 7 | if num <= 0: 8 | return False 9 | return num & (num - 1) == 0 10 | 11 | 12 | class CnnEncoder(nn.Module): 13 | """ 14 | Simple cnn encoder that encodes a 64x64 image to embeddings 15 | """ 16 | 17 | def __init__(self, embedding_size, activation_function="relu"): 18 | super().__init__() 19 | self.act_fn = getattr(F, activation_function) 20 | self.embedding_size = embedding_size 21 | self.fc = nn.Linear(1024, self.embedding_size) 22 | self.conv1 = nn.Conv2d(3, 32, 4, stride=2) 23 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2) 24 | self.conv3 = nn.Conv2d(64, 128, 4, stride=2) 25 | self.conv4 = nn.Conv2d(128, 256, 4, stride=2) 26 | self.modules = [self.conv1, self.conv2, self.conv3, self.conv4] 27 | 28 | def forward(self, observation): 29 | batch_size = observation.shape[0] 30 | hidden = self.act_fn(self.conv1(observation)) 31 | hidden = self.act_fn(self.conv2(hidden)) 32 | hidden = self.act_fn(self.conv3(hidden)) 33 | 
hidden = self.act_fn(self.conv4(hidden)) 34 | hidden = self.fc(hidden.view(batch_size, 1024)) 35 | return hidden 36 | 37 | 38 | class CnnDecoder(nn.Module): 39 | """ 40 | Simple Cnn decoder that decodes an embedding to 64x64 images 41 | """ 42 | 43 | def __init__(self, embedding_size, activation_function="relu"): 44 | super().__init__() 45 | self.act_fn = getattr(F, activation_function) 46 | self.embedding_size = embedding_size 47 | self.fc = nn.Linear(embedding_size, 128) 48 | self.conv1 = nn.ConvTranspose2d(128, 128, 5, stride=2) 49 | self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2) 50 | self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2) 51 | self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2) 52 | self.modules = [self.conv1, self.conv2, self.conv3, self.conv4] 53 | 54 | def forward(self, embedding): 55 | batch_size = embedding.shape[0] 56 | hidden = self.fc(embedding) 57 | hidden = hidden.view(batch_size, 128, 1, 1) 58 | hidden = self.act_fn(self.conv1(hidden)) 59 | hidden = self.act_fn(self.conv2(hidden)) 60 | hidden = self.act_fn(self.conv3(hidden)) 61 | observation = self.conv4(hidden) 62 | return observation 63 | 64 | 65 | class FullyConvEncoder(nn.Module): 66 | """ 67 | Simple fully convolutional encoder, with 2D input and 2D output 68 | """ 69 | 70 | def __init__( 71 | self, 72 | input_shape=(3, 64, 64), 73 | embedding_shape=(8, 16, 16), 74 | activation_function="relu", 75 | init_channels=16, 76 | ): 77 | super().__init__() 78 | 79 | assert len(input_shape) == 3, "input_shape must be a tuple of length 3" 80 | assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3" 81 | assert input_shape[1] == input_shape[2] and is_square_of_two( 82 | input_shape[1] 83 | ), "input_shape must be square" 84 | assert ( 85 | embedding_shape[1] == embedding_shape[2] 86 | ), "embedding_shape must be square" 87 | assert ( 88 | input_shape[1] % embedding_shape[1] == 0 89 | ), "input_shape must be divisible by embedding_shape" 90 | assert is_square_of_two(init_channels), "init_channels must be a square of 2" 91 | 92 | depth = int(math.sqrt(input_shape[1] / embedding_shape[1])) + 1 93 | channels_per_layer = [init_channels * (2**i) for i in range(depth)] 94 | self.act_fn = getattr(F, activation_function) 95 | 96 | self.downs = nn.ModuleList([]) 97 | self.downs.append( 98 | nn.Conv2d( 99 | input_shape[0], 100 | channels_per_layer[0], 101 | kernel_size=3, 102 | stride=1, 103 | padding=1, 104 | ) 105 | ) 106 | 107 | for i in range(1, depth): 108 | self.downs.append( 109 | nn.Conv2d( 110 | channels_per_layer[i - 1], 111 | channels_per_layer[i], 112 | kernel_size=3, 113 | stride=2, 114 | padding=1, 115 | ) 116 | ) 117 | 118 | # Bottleneck layer 119 | self.downs.append( 120 | nn.Conv2d( 121 | channels_per_layer[-1], 122 | embedding_shape[0], 123 | kernel_size=1, 124 | stride=1, 125 | padding=0, 126 | ) 127 | ) 128 | 129 | def forward(self, observation): 130 | hidden = observation 131 | for layer in self.downs: 132 | hidden = self.act_fn(layer(hidden)) 133 | return hidden 134 | 135 | 136 | class FullyConvDecoder(nn.Module): 137 | """ 138 | Simple fully convolutional decoder, with 2D input and 2D output 139 | """ 140 | 141 | def __init__( 142 | self, 143 | embedding_shape=(8, 16, 16), 144 | output_shape=(3, 64, 64), 145 | activation_function="relu", 146 | init_channels=16, 147 | ): 148 | super().__init__() 149 | 150 | assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3" 151 | assert len(output_shape) == 3, "output_shape must be a tuple of length 3" 152 | 
assert output_shape[1] == output_shape[2] and is_square_of_two( 153 | output_shape[1] 154 | ), "output_shape must be square" 155 | assert embedding_shape[1] == embedding_shape[2], "input_shape must be square" 156 | assert ( 157 | output_shape[1] % embedding_shape[1] == 0 158 | ), "output_shape must be divisible by input_shape" 159 | assert is_square_of_two(init_channels), "init_channels must be a square of 2" 160 | 161 | depth = int(math.sqrt(output_shape[1] / embedding_shape[1])) + 1 162 | channels_per_layer = [init_channels * (2**i) for i in range(depth)] 163 | self.act_fn = getattr(F, activation_function) 164 | 165 | self.ups = nn.ModuleList([]) 166 | self.ups.append( 167 | nn.ConvTranspose2d( 168 | embedding_shape[0], 169 | channels_per_layer[-1], 170 | kernel_size=1, 171 | stride=1, 172 | padding=0, 173 | ) 174 | ) 175 | 176 | for i in range(1, depth): 177 | self.ups.append( 178 | nn.ConvTranspose2d( 179 | channels_per_layer[-i], 180 | channels_per_layer[-i - 1], 181 | kernel_size=3, 182 | stride=2, 183 | padding=1, 184 | output_padding=1, 185 | ) 186 | ) 187 | 188 | self.output_layer = nn.ConvTranspose2d( 189 | channels_per_layer[0], output_shape[0], kernel_size=3, stride=1, padding=1 190 | ) 191 | 192 | def forward(self, embedding): 193 | hidden = embedding 194 | for layer in self.ups: 195 | hidden = self.act_fn(layer(hidden)) 196 | 197 | return self.output_layer(hidden) 198 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research 3 | template [repo](https://github.com/buoyancy99/research-template). 4 | By its MIT license, you must keep the above sentence in `README.md` 5 | and the `LICENSE` file to credit the author. 6 | 7 | Main file for the project. This will create and run new experiments and load checkpoints from wandb. 8 | Borrowed part of the code from David Charatan and wandb. 9 | """ 10 | 11 | import sys 12 | import subprocess 13 | import time 14 | from pathlib import Path 15 | 16 | import hydra 17 | from omegaconf import DictConfig, OmegaConf 18 | from omegaconf.omegaconf import open_dict 19 | 20 | from utils.print_utils import cyan 21 | from utils.distributed_utils import is_rank_zero 22 | from utils.ckpt_utils import download_latest_checkpoint, is_run_id 23 | from utils.cluster_utils import submit_slurm_job 24 | 25 | 26 | def run_local(cfg: DictConfig): 27 | # delay some imports in case they are not needed in non-local envs for submission 28 | from experiments import build_experiment 29 | 30 | # Get yaml names 31 | hydra_cfg = hydra.core.hydra_config.HydraConfig.get() 32 | cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices) 33 | 34 | with open_dict(cfg): 35 | if cfg_choice["experiment"] is not None: 36 | cfg.experiment._name = cfg_choice["experiment"] 37 | if cfg_choice["dataset"] is not None: 38 | cfg.dataset._name = cfg_choice["dataset"] 39 | if cfg_choice["algorithm"] is not None: 40 | cfg.algorithm._name = cfg_choice["algorithm"] 41 | 42 | # Set up the output directory. 
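    # Hydra has already created hydra_cfg.runtime.output_dir for this run; the code
    # below additionally refreshes a `latest-run` symlink two levels up so the most
    # recent run is easy to find. Illustrative layout (exact paths depend on the
    # hydra run-dir configuration, so treat them as assumptions):
    #   outputs/2025-01-01/12-00-00/   <- output_dir for this run
    #   outputs/latest-run -> outputs/2025-01-01/12-00-00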
43 | output_dir = Path(hydra_cfg.runtime.output_dir) 44 | if is_rank_zero: 45 | print(cyan(f"Outputs will be saved to:"), output_dir) 46 | (output_dir.parents[1] / "latest-run").unlink(missing_ok=True) 47 | (output_dir.parents[1] / "latest-run").symlink_to( 48 | output_dir, target_is_directory=True 49 | ) 50 | 51 | # Resolve ckpt path 52 | resume = cfg.get("resume", None) 53 | load = cfg.get("load", None) 54 | checkpoint_path = None 55 | load_id = None 56 | if load and not is_run_id(load): 57 | checkpoint_path = load 58 | if resume: 59 | load_id = resume 60 | elif load and is_run_id(load): 61 | load_id = load 62 | else: 63 | load_id = None 64 | 65 | if load_id: 66 | run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}" 67 | checkpoint_path = Path("outputs/downloaded") / run_path / "model.ckpt" 68 | 69 | # launch experiment 70 | experiment = build_experiment(cfg, output_dir, checkpoint_path) 71 | 72 | # for those who are searching, this is where we call tasks like 'training, validation, main' 73 | for task in cfg.experiment.tasks: 74 | experiment.exec_task(task) 75 | 76 | 77 | def run_slurm(cfg: DictConfig): 78 | python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True" 79 | project_root = Path.cwd() 80 | while not (project_root / ".git").exists(): 81 | project_root = project_root.parent 82 | if project_root == Path("/"): 83 | raise Exception("Could not find repo directory!") 84 | 85 | slurm_log_dir = submit_slurm_job( 86 | cfg, 87 | python_args, 88 | project_root, 89 | ) 90 | 91 | if ( 92 | "cluster" in cfg 93 | and cfg.cluster.is_compute_node_offline 94 | and cfg.wandb.mode == "online" 95 | ): 96 | print( 97 | "Job submitted to a compute node without internet. This requires manual syncing on login node." 98 | ) 99 | osh_command_dir = project_root / ".wandb_osh_command_dir" 100 | 101 | osh_proc = None 102 | # if click.confirm("Do you want us to run the sync loop for you?", default=True): 103 | osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir]) 104 | print(f"Running wandb-osh in background... PID: {osh_proc.pid}") 105 | print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.") 106 | print( 107 | f"You can manually start a sync loop later by running the following:", 108 | cyan(f"wandb-osh --command-dir {osh_command_dir}"), 109 | ) 110 | 111 | print( 112 | "Once the job gets allocated and starts running, we will print a command below " 113 | "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)" 114 | ) 115 | msg = f"tail -f {slurm_log_dir}/* \n" 116 | try: 117 | while not list(slurm_log_dir.glob("*.out")) and not list( 118 | slurm_log_dir.glob("*.err") 119 | ): 120 | time.sleep(1) 121 | print(cyan("To trace the outputs and errors, run the following command:"), msg) 122 | except KeyboardInterrupt: 123 | print("Keyboard interrupt detected. 
Exiting...") 124 | print( 125 | cyan( 126 | "To trace the outputs and errors, manually wait for the job to start and run the following command:" 127 | ), 128 | msg, 129 | ) 130 | 131 | 132 | @hydra.main( 133 | version_base=None, 134 | config_path="configurations", 135 | config_name="config", 136 | ) 137 | def run(cfg: DictConfig): 138 | if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline: 139 | with open_dict(cfg): 140 | if cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online": 141 | cfg.wandb.mode = "offline" 142 | 143 | if "name" not in cfg: 144 | raise ValueError( 145 | "must specify a name for the run with command line argument '+name=[name]'" 146 | ) 147 | 148 | if not cfg.wandb.get("entity", None): 149 | raise ValueError( 150 | "must specify wandb entity in 'configurations/config.yaml' or with command line" 151 | " argument 'wandb.entity=[entity]' \n An entity is your wandb user name or group" 152 | " name. This is used for logging. If you don't have an wandb account, please signup at https://wandb.ai/" 153 | ) 154 | 155 | if cfg.wandb.project is None: 156 | cfg.wandb.project = str(Path(__file__).parent.name) 157 | 158 | # If resuming or loading a wandb ckpt and not on a compute node, download the checkpoint. 159 | resume = cfg.get("resume", None) 160 | load = cfg.get("load", None) 161 | 162 | if resume and load: 163 | raise ValueError( 164 | "When resuming a wandb run with `resume=[wandb id]`, checkpoint will be loaded from the cloud" 165 | "and `load` should not be specified." 166 | ) 167 | 168 | if resume: 169 | load_id = resume 170 | elif load and is_run_id(load): 171 | load_id = load 172 | else: 173 | load_id = None 174 | 175 | if load_id and "_on_compute_node" not in cfg: 176 | run_path = f"{cfg.wandb.entity}/{cfg.wandb.project}/{load_id}" 177 | download_latest_checkpoint(run_path, Path("outputs/downloaded")) 178 | 179 | if "cluster" in cfg and not "_on_compute_node" in cfg: 180 | print( 181 | cyan( 182 | "Slurm detected, submitting to compute node instead of running locally..." 
183 | ) 184 | ) 185 | run_slurm(cfg) 186 | else: 187 | run_local(cfg) 188 | 189 | 190 | if __name__ == "__main__": 191 | run() 192 | -------------------------------------------------------------------------------- /datasets/openx_base.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import pickle 6 | from tqdm import tqdm 7 | 8 | from .video_base import VideoDataset 9 | from utils.video_utils import write_numpy_to_mp4 10 | 11 | 12 | class OpenXVideoDataset(VideoDataset): 13 | def preprocess_record(self, record): 14 | record["fps"] = self.cfg.download.openx_fps 15 | # if "bbox" in record: 16 | # bbox = eval(record["bbox"]) 17 | # if len(bbox) == 5: 18 | # record["has_bbox"] = True 19 | # record["bbox_left"] = bbox[0] 20 | # record["bbox_top"] = bbox[1] 21 | # record["bbox_right"] = bbox[2] 22 | # record["bbox_bottom"] = bbox[3] 23 | # else: 24 | # record["has_bbox"] = False 25 | # record["bbox_left"] = 0 26 | # record["bbox_top"] = 0 27 | # record["bbox_right"] = 0 28 | # record["bbox_bottom"] = 0 29 | return record 30 | 31 | def download(self): 32 | import tensorflow_datasets as tfds 33 | import tensorflow as tf 34 | from utils.tf_utils import recursive_cast_to_numpy 35 | 36 | all_episode_dir = self.data_root / "episodes" 37 | all_episode_dir.mkdir(parents=True, exist_ok=True) 38 | 39 | builder = tfds.builder_from_directory( 40 | builder_dir=f"gs://gresearch/robotics/{self.cfg.download.openx_name}/{self.cfg.download.openx_version}" 41 | ) 42 | info = builder.info 43 | n_episodes = info.splits["train"].num_examples 44 | 45 | # Count number of episodes to skip based on existing state files 46 | for episode_id in range(n_episodes): 47 | episode_dir = all_episode_dir / f"episode_{episode_id}" 48 | state_path = episode_dir / "states.pkl" 49 | if not state_path.exists(): 50 | break 51 | 52 | if episode_id > 0: 53 | print(f"Skipping {episode_id} already downloaded episodes") 54 | dataset = builder.as_dataset(split=f"train[{episode_id}:]") 55 | 56 | dataset = dataset.prefetch(tf.data.AUTOTUNE) 57 | for episode_data in tqdm(dataset, total=n_episodes - episode_id): 58 | episode_dir = all_episode_dir / f"episode_{episode_id}" 59 | episode_dir.mkdir(parents=True, exist_ok=True) 60 | episode_records = defaultdict(list) 61 | state_path = episode_dir / "states.pkl" 62 | if state_path.exists(): 63 | continue 64 | 65 | episode = defaultdict(list) 66 | videos = defaultdict(list) 67 | fields_to_stack = [] 68 | for k, v in episode_data.items(): 69 | if k != "steps": 70 | episode[k] = recursive_cast_to_numpy(v) 71 | 72 | # sometimes we can split a video into multiple segments based on caption 73 | segments = { 74 | "natural_language_instruction": [], 75 | "instruction": [], 76 | "language_instruction": [], 77 | "language_instruction_2": [], 78 | "language_instruction_3": [], 79 | } 80 | for idx, step in enumerate(episode_data["steps"]): 81 | step = recursive_cast_to_numpy(step) 82 | obs_dict = step["observation"] 83 | action_dict = step["action"] 84 | if hasattr(obs_dict, "shape"): 85 | obs_dict = dict(observation=obs_dict) 86 | if hasattr(action_dict, "shape"): 87 | action_dict = dict(action=action_dict) 88 | 89 | # some times caption field is here but mostly in observation 90 | for k, v in step.items(): 91 | if k in segments: 92 | obs_dict[k] = v 93 | 94 | for k, v in obs_dict.items(): 95 | if hasattr(v, "shape") and len(v.shape) == 3 and v.shape[-1] == 3: 96 | 
videos[k].append(v) 97 | elif k in segments: 98 | if ( 99 | k == "instruction" 100 | and self.cfg.download.openx_name == "language_table" 101 | ): 102 | # special case for language table dataset 103 | v = tf.convert_to_tensor(v) 104 | v = tf.strings.unicode_encode(v, output_encoding="UTF-8") 105 | v = v.numpy().decode("utf-8").split("\x00")[0] 106 | if not segments[k] or segments[k][-1][1] != v: 107 | segments[k].append((idx, v)) 108 | elif k != "natural_language_embedding": 109 | if hasattr(v, "shape"): 110 | fields_to_stack.append("observation/" + k) 111 | episode["observation/" + k].append(v) 112 | 113 | for k, v in action_dict.items(): 114 | fields_to_stack.append("action/" + k) 115 | episode["action/" + k].append(v) 116 | 117 | for k in list(segments.keys()): 118 | if not segments[k]: 119 | del segments[k] 120 | continue 121 | segments[k].append((idx + 1, "")) 122 | if not segments: 123 | segments["not_captioned"] = [(0, ""), (idx + 1, "")] 124 | 125 | for view, frames in videos.items(): 126 | frames = np.stack(frames) 127 | n, h, w, _ = frames.shape 128 | video_path = episode_dir / f"{view}.mp4" 129 | 130 | if h % 2 != 0: 131 | h = h - 1 132 | frames = frames[:, :h, :, :] 133 | if w % 2 != 0: 134 | w = w - 1 135 | frames = frames[:, :, :w, :] 136 | write_numpy_to_mp4(frames, str(video_path)) 137 | 138 | for k, v in segments.items(): 139 | for s in range(len(v) - 1): 140 | start_idx, caption = v[s] 141 | end_idx = v[s + 1][0] 142 | record = dict( 143 | video_path=str(video_path.relative_to(self.data_root)), 144 | state_path=str(state_path.relative_to(self.data_root)), 145 | height=h, 146 | width=w, 147 | n_frames=end_idx - start_idx, 148 | trim_start=start_idx, 149 | trim_end=end_idx, 150 | fps=self.cfg.download.openx_fps, 151 | original_caption=caption, 152 | has_caption=v[0][1] != "", 153 | ) 154 | episode_records[view].append(record) 155 | for view, records in episode_records.items(): 156 | df = pd.DataFrame.from_records(records) 157 | df.to_csv(episode_dir / f"{view}.csv", index=False) 158 | 159 | for k in fields_to_stack: 160 | episode[k] = np.stack(episode[k]) 161 | with open(state_path, "wb") as f: 162 | pickle.dump(episode, f) 163 | episode_id += 1 164 | 165 | # Save metadata 166 | metadata_path = self.data_root / self.metadata_path 167 | metadata_dir = metadata_path.parent 168 | metadata_dir.mkdir(parents=True, exist_ok=True) 169 | record_dict = defaultdict(list) 170 | for episode_dir in all_episode_dir.glob("episode_*"): 171 | for view_csv in episode_dir.glob("*.csv"): 172 | view_csv = view_csv.name 173 | view_df = pd.read_csv(episode_dir / view_csv) 174 | record_dict[view_csv].extend(view_df.to_dict("records")) 175 | all_df = [] 176 | for view_csv, records in record_dict.items(): 177 | df = pd.DataFrame.from_records(records) 178 | df.to_csv(metadata_dir / view_csv, index=False) 179 | print( 180 | f"Created metadata csv for view {view_csv.split('.')[0]} with {len(df)} records" 181 | ) 182 | if view_csv.replace(".csv", "") in self.cfg.download.views: 183 | all_df.append(df) 184 | all_df = pd.concat(all_df) 185 | all_df.to_csv(metadata_path, index=False) 186 | print(f"Created metadata CSV with {len(all_df)} records") 187 | -------------------------------------------------------------------------------- /utils/gemini_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import base64 4 | import queue 5 | import threading 6 | import traceback 7 | import time 8 | import gc 9 | from typing import Any, Dict, 
List 10 | from dataclasses import dataclass 11 | 12 | # Gemini / Vertex AI imports 13 | import vertexai 14 | from vertexai.generative_models import GenerativeModel, Part 15 | 16 | 17 | @dataclass 18 | class VideoEntry: 19 | mp4_path: str 20 | # optional keys below: 21 | youtube_key_segment: str = None 22 | duration: float = None 23 | fps: float = None 24 | height: int = None 25 | width: int = None 26 | n_frames: int = None 27 | # Add other metadata fields as needed 28 | 29 | 30 | @dataclass 31 | class CaptionResult: 32 | mp4_path: str 33 | caption: str 34 | # optional keys below: 35 | youtube_key_segment: str = None 36 | duration: float = None 37 | fps: float = None 38 | height: int = None 39 | width: int = None 40 | n_frames: int = None 41 | 42 | 43 | class GeminiCaptionProcessor: 44 | def __init__(self, output_file: str, num_workers: int = 12): 45 | self.output_file = output_file 46 | self.num_workers = num_workers 47 | self.entry_queue = queue.Queue() 48 | self.results_queue = queue.Queue() 49 | self.workers = [] 50 | self.success_count = 0 51 | self.fail_count = 0 52 | self.start_time = None 53 | self.end_time = None 54 | 55 | # Initialize Vertex AI 56 | PROJECT_ID = "" 57 | model_index = 0 58 | LOCATION = ["us-central1", "us-east5"][model_index] 59 | vertexai.init(project=PROJECT_ID, location=LOCATION) 60 | MODEL_NAME = ["gemini-2.0-flash-001"][model_index] # "gemini-2.0-flash-001" 61 | self.model = GenerativeModel(model_name=MODEL_NAME) 62 | print(f"Using model: {MODEL_NAME}") 63 | 64 | self.prompt = ( 65 | "Summarize this video directly, when summarizing please provide a detailed description of major subjects, actions, and interactions. " 66 | "Focus on key actions, interactions, and movements. Include camera movements. " 67 | "Keep the summary brief and clear. " 68 | "Only include information that is certain, and avoid speculation or assumptions." 69 | "In the last sentence, answer the question with just Yes or No, does the video contain rich human hand motions?" 70 | ) 71 | # Lock for updating success and fail counts 72 | self.count_lock = threading.Lock() 73 | 74 | self.optional_keys = [ 75 | "duration", 76 | "fps", 77 | "height", 78 | "width", 79 | "n_frames", 80 | "youtube_key_segment", 81 | ] 82 | 83 | def process_entries(self, records: List[Dict[str, Any]]): 84 | self.start_time = time.time() 85 | # Start worker threads 86 | for _ in range(self.num_workers): 87 | worker = threading.Thread(target=self._worker_process, daemon=True) 88 | worker.start() 89 | self.workers.append(worker) 90 | 91 | # Producer: read input lines and put them into the queue 92 | to_process_count = 0 93 | for data in records: 94 | entry = VideoEntry( 95 | mp4_path=data["video_path"], 96 | ) 97 | # add optional keys to entry: 98 | for key in self.optional_keys: 99 | if key in data: 100 | entry.__dict__[key] = data[key] 101 | self.entry_queue.put(entry) 102 | to_process_count += 1 103 | 104 | if to_process_count == 0: 105 | print("No new entries to process. 
All done!") 106 | # Even if none, still send sentinels to avoid blocking 107 | for _ in range(self.num_workers): 108 | self.entry_queue.put(None) 109 | return 110 | 111 | # Add sentinel values to signal workers to stop 112 | for _ in range(self.num_workers): 113 | self.entry_queue.put(None) 114 | 115 | # Wait for all workers to finish 116 | for worker in self.workers: 117 | worker.join() 118 | 119 | # Collect results 120 | results = [] 121 | while not self.results_queue.empty(): 122 | result = self.results_queue.get() 123 | # Only append results that aren't error messages 124 | if not result.caption.startswith("Error"): 125 | results.append(result) 126 | 127 | # Append results to output file 128 | with open(self.output_file, "a", encoding="utf-8") as f: 129 | for result in results: 130 | obj = {"video_path": result.mp4_path, "caption": result.caption} 131 | for key in self.optional_keys: 132 | if key in result.__dict__ and result.__dict__[key] is not None: 133 | obj[key] = result.__dict__[key] 134 | f.write(json.dumps(obj) + "\n") 135 | 136 | self.end_time = time.time() 137 | total_time = self.end_time - self.start_time 138 | print(f"Processed {len(results)} entries successfully.") 139 | print(f"Failed on {self.fail_count} entries.") 140 | print(f"Total time: {total_time:.2f} seconds.") 141 | if to_process_count > 0: 142 | print(f"Throughput: {to_process_count / total_time:.2f} videos/second.") 143 | print(f"Output file: {self.output_file}") 144 | 145 | def _read_video_file(self, file_path): 146 | """Read video file and convert it to base64.""" 147 | if not os.path.exists(file_path): 148 | raise FileNotFoundError(f"Video file not found: {file_path}") 149 | with open(file_path, "rb") as video_file: 150 | return base64.b64encode(video_file.read()).decode("utf-8") 151 | 152 | def get_gemini_caption(self, video_path: str) -> str: 153 | """Generate a caption for a single video using Gemini Flash.""" 154 | video_data = self._read_video_file(video_path) 155 | video_part = Part.from_data(data=video_data, mime_type="video/mp4") 156 | try: 157 | response = self.model.generate_content( 158 | [video_part, self.prompt], 159 | # generation_config={ 160 | # "max_output_tokens": 1024, 161 | # "temperature": 0.4 162 | # }, 163 | stream=False, 164 | ) 165 | return response.text 166 | except Exception as e: 167 | print(f"Error from Gemini API: {e}") 168 | return f"Error from Gemini API: {e}" 169 | 170 | def _process_single_entry(self, entry: VideoEntry) -> CaptionResult: 171 | caption = self.get_gemini_caption(entry.mp4_path) 172 | 173 | ret_result = CaptionResult(mp4_path=entry.mp4_path, caption=caption) 174 | for key in self.optional_keys: 175 | if key in entry.__dict__ and entry.__dict__[key] is not None: 176 | ret_result.__dict__[key] = entry.__dict__[key] 177 | return ret_result 178 | 179 | def _worker_process(self): 180 | while True: 181 | entry = self.entry_queue.get() 182 | if entry is None: # Check for sentinel value 183 | break 184 | if self.entry_queue.qsize() % 100 == 0: 185 | print( 186 | f"Processing {entry.mp4_path}. {self.entry_queue.qsize()} entries left in queue." 187 | ) 188 | gc_s_time = time.time() 189 | num_gc = gc.collect() 190 | gc_e_time = time.time() 191 | print( 192 | f"Garbage collection took {gc_e_time - gc_s_time} seconds, collected {num_gc} objects" 193 | ) 194 | try: 195 | result = self._process_single_entry(entry) 196 | # Check if result is error. If not, add to results_queue. 
197 | if not result.caption.startswith("Error"): 198 | with self.count_lock: 199 | self.success_count += 1 200 | self.results_queue.put(result) 201 | else: 202 | with self.count_lock: 203 | self.fail_count += 1 204 | print(f"Skipping {entry.mp4_path} due to error in captioning.") 205 | except Exception as e: 206 | with self.count_lock: 207 | self.fail_count += 1 208 | print(f"Error processing {entry.mp4_path}: {str(e)}") 209 | traceback.print_exc() 210 | finally: 211 | self.entry_queue.task_done() 212 | -------------------------------------------------------------------------------- /algorithms/common/base_pytorch_algo.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import warnings 3 | from typing import Any, Union, Sequence, Optional 4 | 5 | from lightning.pytorch.utilities.types import STEP_OUTPUT 6 | from omegaconf import DictConfig 7 | import lightning.pytorch as pl 8 | import torch 9 | import numpy as np 10 | from PIL import Image 11 | import wandb 12 | import einops 13 | 14 | 15 | class BasePytorchAlgo(pl.LightningModule, ABC): 16 | """ 17 | A base class for Pytorch algorithms using Pytorch Lightning. 18 | See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details. 19 | """ 20 | 21 | def __init__(self, cfg: DictConfig): 22 | self.cfg = cfg 23 | self.debug = self.cfg.debug 24 | super().__init__() 25 | 26 | @abstractmethod 27 | def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: 28 | r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or 29 | logger. 30 | 31 | Args: 32 | batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`. 33 | batch_idx: The index of this batch. 34 | dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch. 35 | 36 | Return: 37 | Any of these options: 38 | - :class:`~torch.Tensor` - The loss tensor 39 | - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``. 40 | - ``None`` - Skip to the next batch. This is only supported for automatic optimization. 41 | This is not supported for multi-GPU, TPU, IPU, or DeepSpeed. 42 | 43 | In this step you'd normally do the forward pass and calculate the loss for a batch. 44 | You can also do fancier things like multiple forward passes or something model specific. 45 | 46 | Example:: 47 | 48 | def training_step(self, batch, batch_idx): 49 | x, y, z = batch 50 | out = self.encoder(x) 51 | loss = self.loss(out, x) 52 | return loss 53 | 54 | To use multiple optimizers, you can switch to 'manual optimization' and control their stepping: 55 | 56 | .. code-block:: python 57 | 58 | def __init__(self): 59 | super().__init__() 60 | self.automatic_optimization = False 61 | 62 | 63 | # Multiple optimizers (e.g.: GANs) 64 | def training_step(self, batch, batch_idx): 65 | opt1, opt2 = self.optimizers() 66 | 67 | # do training_step with encoder 68 | ... 69 | opt1.step() 70 | # do training_step with decoder 71 | ... 72 | opt2.step() 73 | 74 | Note: 75 | When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically 76 | normalized by ``accumulate_grad_batches`` internally. 77 | 78 | """ 79 | return super().training_step(*args, **kwargs) 80 | 81 | def configure_optimizers(self): 82 | """ 83 | Return an optimizer. 
If you need to use more than one optimizer, refer to pytorch lightning documentation: 84 | https://lightning.ai/docs/pytorch/stable/common/optimization.html 85 | """ 86 | parameters = self.parameters() 87 | return torch.optim.Adam(parameters, lr=self.cfg.lr) 88 | 89 | def log_video( 90 | self, 91 | key: str, 92 | video: Union[np.ndarray, torch.Tensor], 93 | mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None, 94 | std: Union[np.ndarray, torch.Tensor, Sequence, float] = None, 95 | fps: int = 12, 96 | format: str = "mp4", 97 | caption: str = None, 98 | step: int = None, 99 | ): 100 | """ 101 | Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly. 102 | 103 | Args: 104 | video: a numpy array or tensor, either in form (time, channel, height, width) or in the form 105 | (batch, time, channel, height, width). The content must be be in 0-255 if under dtype uint8 106 | or [0, 1] otherwise. 107 | mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1]. 108 | std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1]. 109 | key: the name of the video. 110 | fps: the frame rate of the video. 111 | format: the format of the video. Can be either "mp4" or "gif". 112 | """ 113 | 114 | if isinstance(video, torch.Tensor): 115 | video = video.detach().cpu().float().numpy() 116 | 117 | expand_shape = [1] * (len(video.shape) - 2) + [3, 1, 1] 118 | if std is not None: 119 | if isinstance(std, (float, int)): 120 | std = [std] * 3 121 | if isinstance(std, torch.Tensor): 122 | std = std.detach().cpu().numpy() 123 | std = np.array(std).reshape(*expand_shape) 124 | video = video * std 125 | if mean is not None: 126 | if isinstance(mean, (float, int)): 127 | mean = [mean] * 3 128 | if isinstance(mean, torch.Tensor): 129 | mean = mean.detach().cpu().numpy() 130 | mean = np.array(mean).reshape(*expand_shape) 131 | video = video + mean 132 | 133 | if video.dtype != np.uint8: 134 | video = np.clip(video, a_min=0, a_max=1) * 255 135 | video = video.astype(np.uint8) 136 | 137 | self.logger.experiment.log( 138 | { 139 | key: wandb.Video(video, fps=fps, format=format, caption=caption), 140 | }, 141 | step=self.global_step if step is None else step, 142 | ) 143 | 144 | def log_image( 145 | self, 146 | key: str, 147 | image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]], 148 | mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None, 149 | std: Union[np.ndarray, torch.Tensor, Sequence, float] = None, 150 | **kwargs: Any, 151 | ): 152 | """ 153 | Log image(s) using WandbLogger. 154 | Args: 155 | key: the name of the video. 156 | image: a single image or a batch of images. If a batch of images, the shape should be (batch, channel, height, width). 157 | mean: optional, the mean to unnormalize image tensor, assuming unnormalized data is in [0, 1]. 158 | std: optional, the std to unnormalize tensor, assuming unnormalized data is in [0, 1]. 159 | kwargs: optional, WandbLogger log_image kwargs, such as captions=xxx. 160 | """ 161 | if isinstance(image, Image.Image): 162 | image = [image] 163 | elif len(image) and not isinstance(image[0], Image.Image): 164 | if isinstance(image, torch.Tensor): 165 | image = image.detach().cpu().numpy() 166 | 167 | if len(image.shape) == 3: 168 | image = image[None] 169 | 170 | if image.shape[1] == 3: 171 | if image.shape[-1] == 3: 172 | warnings.warn( 173 | f"Two channels in shape {image.shape} have size 3, assuming channel first." 
174 | ) 175 | image = einops.rearrange(image, "b c h w -> b h w c") 176 | 177 | if std is not None: 178 | if isinstance(std, (float, int)): 179 | std = [std] * 3 180 | if isinstance(std, torch.Tensor): 181 | std = std.detach().cpu().numpy() 182 | std = np.array(std)[None, None, None] 183 | image = image * std 184 | if mean is not None: 185 | if isinstance(mean, (float, int)): 186 | mean = [mean] * 3 187 | if isinstance(mean, torch.Tensor): 188 | mean = mean.detach().cpu().numpy() 189 | mean = np.array(mean)[None, None, None] 190 | image = image + mean 191 | 192 | if image.dtype != np.uint8: 193 | image = np.clip(image, a_min=0.0, a_max=1.0) * 255 194 | image = image.astype(np.uint8) 195 | image = [img for img in image] 196 | 197 | self.logger.log_image(key=key, images=image, **kwargs) 198 | 199 | def log_gradient_stats(self): 200 | """Log gradient statistics such as the mean or std of norm.""" 201 | 202 | with torch.no_grad(): 203 | grad_norms = [] 204 | gpr = [] # gradient-to-parameter ratio 205 | for param in self.parameters(): 206 | if param.grad is not None: 207 | grad_norms.append(torch.norm(param.grad).item()) 208 | gpr.append(torch.norm(param.grad) / torch.norm(param)) 209 | if len(grad_norms) == 0: 210 | return 211 | grad_norms = torch.tensor(grad_norms) 212 | gpr = torch.tensor(gpr) 213 | self.log_dict( 214 | { 215 | "train/grad_norm/min": grad_norms.min(), 216 | "train/grad_norm/max": grad_norms.max(), 217 | "train/grad_norm/std": grad_norms.std(), 218 | "train/grad_norm/mean": grad_norms.mean(), 219 | "train/grad_norm/median": torch.median(grad_norms), 220 | "train/gpr/min": gpr.min(), 221 | "train/gpr/max": gpr.max(), 222 | "train/gpr/std": gpr.std(), 223 | "train/gpr/mean": gpr.mean(), 224 | "train/gpr/median": torch.median(gpr), 225 | } 226 | ) 227 | 228 | def register_data_mean_std( 229 | self, 230 | mean: Union[str, float, Sequence], 231 | std: Union[str, float, Sequence], 232 | namespace: str = "data", 233 | ): 234 | """ 235 | Register mean and std of data as tensor buffer. 236 | 237 | Args: 238 | mean: the mean of data. 239 | std: the std of data. 240 | namespace: the namespace of the registered buffer. 241 | """ 242 | for k, v in [("mean", mean), ("std", std)]: 243 | if isinstance(v, str): 244 | if v.endswith(".npy"): 245 | v = torch.from_numpy(np.load(v)) 246 | elif v.endswith(".pt"): 247 | v = torch.load(v) 248 | else: 249 | raise ValueError(f"Unsupported file type {v.split('.')[-1]}.") 250 | else: 251 | v = torch.tensor(v) 252 | self.register_buffer(f"{namespace}_{k}", v.float().to(self.device)) 253 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Large Video Planner Enables Generalizable Robot Control 4 | 5 | This repo provides training and inference code for the paper "Large Video Planner Enables Generalizable Robot Control" 6 | 7 | [Paper](http://arxiv.org/abs/2512.15840)   8 | [Project Webpage](https://www.boyuan.space/large-video-planner/)   9 | [Hugging Face Demo](https://huggingface.co/spaces/KempnerInstituteAI/LVP) 10 | 11 | 12 |
13 | 14 | 15 | # Downloading the dataset 16 | Download all the metadata files for the eight filtered datasets and our third-party collected test set: 17 | ```bash 18 | huggingface-cli download KempnerInstituteAI/LVP \ 19 | --include "data/**" \ 20 | --local-dir . \ 21 | --local-dir-use-symlinks False 22 | ``` 23 | This will download each data folder under `data/`. 24 | 25 | # Downloading the checkpoints 26 | 27 | Please put all downloaded checkpoints within `data/ckpts`. 28 | 29 | ## Downloading Our Fine-tuned Checkpoints 30 | 31 | ```bash 32 | huggingface-cli download KempnerInstituteAI/LVP \ 33 | --include "checkpoints/**" \ 34 | --local-dir . \ 35 | --local-dir-use-symlinks False 36 | 37 | mv checkpoints data/ckpts 38 | ``` 39 | Note that this will take 66 GB of disk space. 40 | 41 | After downloading, the fine-tuned checkpoint should be at `data/ckpts/lvp_14B.ckpt`. 42 | This path is specified in `configurations/algorithm/wan_i2v.yaml`. 43 | 44 | 45 | ## Downloading Wan 2.1 Pre-trained Checkpoints 46 | 47 | This codebase uses the **Wan 2.1 Image-to-Video (I2V) 14B** model for video generation. The checkpoint includes: 48 | - Wan2.1 diffusion model weights (14B parameters), which we fine-tune from. 49 | - VAE encoder/decoder 50 | - T5 text encoder (UMT5-XXL) 51 | - CLIP image encoder (XLM-Roberta-Large-ViT-Huge) 52 | 53 | **Official Download Instructions**: Please refer to the [Wan 2.1 GitHub repository](https://github.com/Wan-Video/Wan2.1#model-download) for the most up-to-date checkpoint download instructions. 54 | 55 | **Quick Download** (using Hugging Face CLI): 56 | ```bash 57 | # Download Wan 2.1 I2V 14B 480P (recommended for this codebase) 58 | huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./data/ckpts/Wan2.1-I2V-14B-480P 59 | ``` 60 | 61 | The checkpoint will be downloaded to `./data/ckpts/Wan2.1-I2V-14B-480P/` and automatically includes all necessary components (VAE, T5, CLIP, main model). 62 | 63 | **Note**: The 480P model is used in our training pipeline. The checkpoint path is configured in [configurations/algorithm/wan_i2v.yaml](configurations/algorithm/wan_i2v.yaml). 64 | 65 | 66 | 67 | 68 | # Instructions for running the code 69 | 70 | This document provides detailed instructions for running inference and training with the EI World Model codebase. 71 | 72 | ## Table of Contents 73 | - [Environment Setup](#environment-setup) 74 | - [How to Run Inference](#how-to-run-inference) 75 | - [How to Run Training](#how-to-run-training) 76 | 77 | --- 78 | 79 | ## Environment Setup 80 | 81 | ### Prerequisites 82 | - Python 3.10 83 | - CUDA 12.1+ (for GPU support) 84 | - Conda or Mamba package manager 85 | 86 | ### Step 1: Create Conda Environment 87 | 88 | ```bash 89 | # using conda 90 | conda create python=3.10 -n ei_world_model 91 | conda activate ei_world_model 92 | ``` 93 | 94 | ### Step 2: Install Dependencies 95 | 96 | We store Python dependencies in **`requirements.txt`**: 97 | 98 | ```bash 99 | # Install core dependencies 100 | pip install -r requirements.txt 101 | 102 | 103 | # Install Flash Attention (for efficient attention) 104 | # This may take several minutes to compile 105 | pip install flash-attn --no-build-isolation 106 | ``` 107 | 108 | **Note**: If you encounter issues with `flash-attn`, you can skip it for inference-only usage. It's primarily needed for efficient training. 109 | 110 | ### Step 3: Configure WandB (Weights & Biases) 111 | 112 | WandB is used for experiment tracking and logging.
113 | 114 | ```bash 115 | # Login to WandB 116 | wandb login 117 | 118 | # Or set your API key 119 | export WANDB_API_KEY=your_api_key_here 120 | ``` 121 | 122 | Update your WandB entity in [configurations/config.yaml](configurations/config.yaml): 123 | 124 | ```yaml 125 | wandb: 126 | entity: your-wandb-username # Change this to your WandB username or org 127 | project: ei_world_model 128 | mode: online # Use 'offline' for no internet, 'dryrun' for testing 129 | ``` 130 | 131 | Note that WandB is set to offline mode by default, so you can explore the rest of the code without configuring WandB first. 132 | 133 | ### Step 4: Verify Installation 134 | 135 | Test your installation with a quick inference run: 136 | 137 | ```bash 138 | # Test with toy model (no checkpoints needed) 139 | python -m main \ 140 | +name=test_installation \ 141 | experiment=exp_video \ 142 | algorithm=wan_toy \ 143 | dataset=dummy \ 144 | experiment.tasks=[validation] \ 145 | experiment.validation.limit_batch=1 146 | ``` 147 | 148 | If this runs without errors, your environment is set up correctly! 149 | 150 | ### Environment Variables 151 | 152 | For distributed training on SLURM clusters, you may need to set: 153 | 154 | ```bash 155 | # For offline compute nodes with WandB sync 156 | export WANDB_MODE=offline 157 | export WANDB_DIR=/path/to/wandb/logs 158 | 159 | # For debugging 160 | export HYDRA_FULL_ERROR=1 161 | export CUDA_LAUNCH_BLOCKING=1 162 | ``` 163 | 164 | ### Troubleshooting 165 | 166 | **Issue: CUDA out of memory** 167 | - Reduce batch size in experiment config: `experiment.training.batch_size=1` 168 | - Enable gradient checkpointing: `algorithm.gradient_checkpointing_rate=1.0` 169 | 170 | 171 | --- 172 | 173 | ## How to Run Inference 174 | 175 | Inference generates videos given an image and a text prompt using a pretrained model. 176 | 177 | ### Basic Inference Command 178 | 179 | ```bash 180 | mkdir -p 181 | python -m main \ 182 | +name=[name] \ 183 | experiment=exp_video \ 184 | algorithm=wan_i2v \ 185 | dataset=ours_test \ 186 | experiment.tasks=[validation] \ 187 | algorithm.logging.video_type=single \ 188 | experiment.num_nodes=1 \ 189 | experiment.validation.limit_batch=null \ 190 | algorithm.hist_guidance=1.5 \ 191 | algorithm.lang_guidance=2.5 192 | ``` 193 | 194 | ### Command Arguments Explained 195 | 196 | #### Required Arguments 197 | - **`+name=[name]`**: Unique experiment name for this run, used for logging and organizing outputs in WandB and the file system.
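#### Optional: Resuming or Loading Checkpoints
The entry point in [main.py](main.py) also understands `resume` (a WandB run id whose latest checkpoint is downloaded to `outputs/downloaded/<entity>/<project>/<run id>/model.ckpt` before launch) and `load` (either a local checkpoint path or a WandB run id); `resume` and `load` cannot be combined. The commands below are a minimal sketch: the run names, run id, and checkpoint path are placeholders, and the `+` prefix assumes these keys are not already defined in [configurations/config.yaml](configurations/config.yaml) (drop it if they are).

```bash
# Start inference from a local checkpoint file (path is illustrative)
python -m main \
    +name=eval_from_local_ckpt \
    experiment=exp_video \
    algorithm=wan_i2v \
    dataset=ours_test \
    experiment.tasks=[validation] \
    +load=data/ckpts/lvp_14B.ckpt

# Resume a previous WandB run by id; its latest checkpoint is fetched from WandB before launching
python -m main \
    +name=resume_previous_run \
    experiment=exp_video \
    algorithm=wan_i2v \
    dataset=ours_test \
    +resume=abc123de
```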
198 | 199 | #### Core Configuration 200 | - **`experiment=exp_video`**: Specifies the experiment type 201 | - Points to: [configurations/experiment/exp_video.yaml](configurations/experiment/exp_video.yaml) 202 | - Defines: Training/validation settings, tasks, precision, batch size 203 | 204 | - **`algorithm=wan_i2v`**: Selects the Wan 2.1 Image-to-Video model 205 | - Points to: [configurations/algorithm/wan_i2v.yaml](configurations/algorithm/wan_i2v.yaml) 206 | - Inherits from: [wan_t2v.yaml](configurations/algorithm/wan_t2v.yaml) 207 | - **You need to set the checkpoint paths in these two YAML files** 208 | 209 | - **`dataset=ours_test`**: Specifies the evaluation dataset 210 | - Should point to: `configurations/dataset/ours_test.yaml` 211 | - Format: CSV with metadata (video_path, caption, height, width, fps, n_frames) 212 | - The specific CSV format is discussed in datasets/README.md; a hypothetical example row is sketched below, just before the training commands 213 | 214 | #### Task Configuration 215 | - **`experiment.tasks=[validation]`**: Runs validation/inference mode 216 | - Executes the `validation()` method in [experiments/exp_video.py](experiments/exp_video.py) 217 | 218 | #### Cluster Configuration 219 | - **`cluster=fas_high`**: SLURM cluster settings we used for evaluation 220 | - Points to: [configurations/cluster/fas_high.yaml](configurations/cluster/fas_high.yaml) 221 | - Settings: 4 H100 GPUs, 48 CPUs, 512GB memory, 1-day time limit 222 | 223 | - **`experiment.num_nodes=1`**: Number of compute nodes (1 for inference) 224 | 225 | #### Inference Parameters 226 | - **`experiment.validation.limit_batch=null`**: Process all batches 227 | - Set to a number (e.g., `10`) to limit evaluation to N batches for quick testing 228 | 229 | - **`algorithm.hist_guidance=1.5`**: Historical guidance scale for conditioning on previous frames 230 | - Controls how strongly the model follows the input image 231 | - Range: 0.0 (no guidance) to 3.0+ (strong guidance) 232 | - Recommended: 1.5 233 | 234 | - **`algorithm.lang_guidance=2.5`**: Language guidance scale (classifier-free guidance) 235 | - Controls how strongly the model follows the text prompt 236 | - Range: 0.0 (no guidance) to 5.0+ (strong guidance) 237 | - Recommended: 2.0 or 2.5 238 | 239 | #### Logging Configuration 240 | - **`algorithm.logging.video_type=single`**: Save videos individually 241 | - Alternative: `grid` - saves all videos in a grid layout 242 | 243 | --- 244 | 245 | ## How to Run Training 246 | 247 | Training fine-tunes the Wan 2.1 models on custom video datasets.
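Whether you evaluate on `dataset=ours_test` or train on the dataset mixture described below, each underlying dataset is indexed by a metadata CSV with the columns listed above. The snippet below is a purely hypothetical sketch of writing such a file with pandas; the path, caption, and sizes are made up, and the authoritative schema is documented in [datasets/README.md](datasets/README.md).

```python
# Hypothetical example: write a one-row metadata CSV with the columns the dataset configs expect.
import pandas as pd

rows = [
    {
        "video_path": "videos/episode_0/front_cam.mp4",  # relative to the dataset root (made-up path)
        "caption": "A robot arm picks up a red cup and places it on a plate.",
        "height": 480,
        "width": 832,
        "fps": 16,
        "n_frames": 81,
    }
]
pd.DataFrame.from_records(rows).to_csv("metadata.csv", index=False)
```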
248 | 249 | ### Full-Scale Training Command 250 | 251 | ```bash 252 | python -m main \ 253 | +name=final_i2v \ 254 | experiment=exp_video \ 255 | algorithm=wan_i2v \ 256 | dataset=mixture \ 257 | experiment.num_nodes=32 \ 258 | algorithm.lang_guidance=0 \ 259 | algorithm.hist_guidance=0 \ 260 | experiment.validation.val_every_n_step=100000000 261 | ``` 262 | 263 | ### Debug Training Command (Toy Model) 264 | 265 | For rapid iteration and debugging, use a smaller toy model: 266 | 267 | ```bash 268 | python -m main \ 269 | +name=print_dataset_mix_debug_train \ 270 | experiment=exp_video \ 271 | algorithm=wan_toy \ 272 | dataset=mixture \ 273 | experiment.num_nodes=1 \ 274 | algorithm.lang_guidance=0 \ 275 | algorithm.hist_guidance=0 \ 276 | experiment.validation.val_every_n_step=100000000 277 | ``` 278 | 279 | ### Training Arguments Explained 280 | 281 | #### Required Arguments 282 | - **`+name=final_i2v`**: Experiment name for WandB logging and checkpoints 283 | 284 | #### Core Configuration 285 | - **`experiment=exp_video`**: Same as inference 286 | - Default task: `[training]` (defined in [exp_video.yaml](configurations/experiment/exp_video.yaml)) 287 | 288 | - **`algorithm=wan_i2v`** or **`algorithm=wan_toy`**: 289 | - `wan_i2v`: Full 14B parameter model ([wan_i2v.yaml](configurations/algorithm/wan_i2v.yaml)) 290 | - `wan_toy`: Tiny model for debugging ([wan_toy.yaml](configurations/algorithm/wan_toy.yaml)) 291 | - Only 2 layers, 128 dimensions (vs 40 layers, 5120 dimensions) 292 | - No checkpoint loading required 293 | 294 | - **`dataset=mixture`**: Combined dataset of multiple sources 295 | - Points to: [configurations/dataset/mixture.yaml](configurations/dataset/mixture.yaml) 296 | - Includes: Pandas, Epic Kitchen, Ego4D, DROID, Something-Something, Bridge, AgibotWorld, Language Table 297 | - Weighted mixture based on dataset sizes and importance 298 | 299 | #### Cluster Configuration 300 | - **`cluster=phase3`**: SLURM cluster settings we used for training 301 | - Points to: [configurations/cluster/phase3.yaml](configurations/cluster/phase3.yaml) 302 | - Settings: 4 H100 GPUs per node, 32 nodes, priority queue, 14-day time limit 303 | 304 | - **`experiment.num_nodes=32`**: Multi-node distributed training 305 | - 32 nodes × 4 GPUs = 128 GPUs for full training 306 | - Set to `1` for debugging with toy model 307 | 308 | ### Configuration System (Hydra) 309 | 310 | The codebase uses Hydra for hierarchical configuration management: 311 | 312 | 1. **Base Config**: [configurations/config.yaml](configurations/config.yaml) 313 | - Specifies defaults: experiment, dataset, algorithm, cluster 314 | - WandB settings for logging 315 | 316 | 2. **Config Composition**: Hydra composes configs from multiple YAML files 317 | - Command-line overrides: `algorithm.lang_guidance=2.5` 318 | - Inheritance: `wan_i2v.yaml` inherits from `wan_t2v.yaml` 319 | 320 | 3. **Config Resolution**: [main.py](main.py) resolves all configs and passes them to the experiment 321 | 322 | ### Execution Flow 323 | 324 | 1. **Entry**: `python -m main +name=... experiment=... algorithm=... dataset=...` 325 | 2. **Hydra Setup**: [main.py](main.py) loads and merges all configs 326 | 3. **Experiment Creation**: [experiments/exp_video.py](experiments/exp_video.py) builds the experiment 327 | 4.
**Task Execution**: Calls `experiment.exec_task(task)` for each task in `experiment.tasks` 328 | - Training: Sets up dataloaders, trainer, and runs training loop 329 | - Validation: Loads model, generates videos, saves outputs 330 | --------------------------------------------------------------------------------
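Since everything above is composed by Hydra, two generic Hydra command-line features are handy when debugging configs (they are standard Hydra behavior, not repo-specific): printing the fully composed config without launching anything, and overriding the task list directly. The run names below are placeholders, and running multiple tasks in one launch assumes the experiment class implements each listed task, as described in the execution flow above.

```bash
# Print the composed job config and exit without running any task (standard Hydra --cfg flag)
python -m main +name=inspect_cfg experiment=exp_video algorithm=wan_toy dataset=dummy --cfg job

# Override the task list so a single launch runs training followed by validation
python -m main \
    +name=train_then_validate \
    experiment=exp_video \
    algorithm=wan_toy \
    dataset=dummy \
    experiment.tasks=[training,validation]
```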