├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── Flag-DiT-ImageNet ├── README.md ├── exps │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ └── slurm │ │ ├── 3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ │ └── 7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh ├── grad_norm.py ├── models │ ├── __init__.py │ ├── components.py │ └── model.py ├── parallel.py ├── scripts │ ├── run_8gpus.sh │ └── slurm │ │ ├── run_32gpus.sh │ │ └── run_8gpus.sh ├── train.py └── transport │ ├── __init__.py │ ├── integrators.py │ ├── path.py │ ├── transport.py │ └── utils.py ├── LICENSE ├── Next-DiT-ImageNet ├── .isort.cfg ├── README.md ├── exps │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ └── slurm │ │ ├── 2B_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ │ ├── 3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ │ └── 7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh ├── fid_is.png ├── grad_norm.py ├── init_loss.py ├── models │ ├── __init__.py │ └── models.py ├── parallel.py ├── sample.py ├── scripts │ ├── run_8gpus.sh │ └── slurm │ │ ├── run_32gpus.sh │ │ └── run_8gpus.sh ├── train.py └── transport │ ├── __init__.py │ ├── integrators.py │ ├── path.py │ ├── transport.py │ └── utils.py ├── Next-DiT-MoE ├── README.md ├── exps │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ └── slurm │ │ ├── 2B_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ │ ├── 3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh │ │ └── 7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh ├── grad_norm.py ├── loss_spacemoe.png ├── loss_timemoe.png ├── loss_timespacemoe.png ├── models │ ├── __init__.py │ ├── models.py │ ├── models1.py │ └── models2.py ├── moe_model.png ├── parallel.py ├── sample.py ├── scripts │ ├── run_8gpus.sh │ └── slurm │ │ ├── run_32gpus.sh │ │ └── run_8gpus.sh ├── train.py └── transport │ ├── __init__.py │ ├── integrators.py │ ├── path.py │ ├── transport.py │ └── utils.py ├── README.md ├── README_cn.md ├── assets ├── audios │ ├── a_telephone_bell_rings.wav │ └── a_telephone_bell_rings_gt.wav ├── compositional_intro.png ├── diverse_config.png ├── images │ ├── demo_image.png │ └── resolution_extrapolation_2.jpg ├── lumina-intro.png └── lumina-logo.png ├── lumina_audio ├── README.md ├── configs │ └── lumina-text2audio.yaml ├── demo_audio.py ├── models │ ├── __init__.py │ ├── autoencoder1d.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── component.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── ddpm_audio.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── flag_large_dit.py │ │ └── util.py │ ├── encoders │ │ ├── CLAP │ │ │ ├── CLAPWrapper.py │ │ │ ├── __init__.py │ │ │ ├── audio.py │ │ │ ├── clap.py │ │ │ ├── config.yml │ │ │ └── utils.py │ │ ├── __init__.py │ │ └── modules.py │ ├── lr_scheduler.py │ ├── util.py │ └── vocoder │ │ └── bigvgan │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── alias_free_torch │ │ ├── __init__.py │ │ ├── act.py │ │ ├── filter.py │ │ └── resample.py │ │ └── models.py ├── n2s_openai.py ├── requirements.txt ├── run_audio.sh └── style.css ├── lumina_music ├── README.md ├── configs │ └── lumina-text2music.yaml ├── demo_music.py ├── models │ ├── __init__.py │ ├── autoencoder1d.py │ ├── diffusion │ │ ├── __init__.py │ │ ├── component.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── ddpm_audio.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── flag_large_dit.py │ │ └── util.py │ ├── encoders │ │ ├── CLAP │ │ │ ├── CLAPWrapper.py │ │ │ ├── __init__.py 
│ │ │ ├── audio.py │ │ │ ├── clap.py │ │ │ ├── config.yml │ │ │ └── utils.py │ │ ├── __init__.py │ │ └── modules.py │ ├── lr_scheduler.py │ ├── util.py │ └── vocoder │ │ └── bigvgan │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── alias_free_torch │ │ ├── __init__.py │ │ ├── act.py │ │ ├── filter.py │ │ └── resample.py │ │ └── models.py ├── requirements.txt └── run_music.sh ├── lumina_next_compositional_generation ├── README.md ├── demo.py ├── models │ ├── __init__.py │ ├── components.py │ └── model.py └── transport │ ├── __init__.py │ ├── integrators.py │ ├── path.py │ ├── transport.py │ └── utils.py ├── lumina_next_t2i ├── README.md ├── __init__.py ├── configs │ ├── data │ │ └── JourneyDB.yaml │ └── infer │ │ └── settings.yaml ├── data │ ├── __init__.py │ ├── data_reader.py │ └── dataset.py ├── demo.py ├── entry_point.py ├── grad_norm.py ├── imgproc.py ├── models │ ├── __init__.py │ ├── components.py │ └── model.py ├── parallel.py ├── sample.py ├── train.py ├── transport │ ├── __init__.py │ ├── integrators.py │ ├── path.py │ ├── transport.py │ └── utils.py └── utils │ ├── __init__.py │ ├── cli.py │ └── group.py ├── lumina_next_t2i_mini ├── README.md ├── __init__.py ├── configs │ └── data │ │ └── JourneyDB.yaml ├── data │ ├── __init__.py │ ├── data_reader.py │ └── dataset.py ├── demo.py ├── grad_norm.py ├── imgproc.py ├── models │ ├── __init__.py │ ├── components.py │ └── nextdit.py ├── parallel.py ├── sample.py ├── sample_img2img.py ├── sample_sd3.py ├── scripts │ ├── sample.sh │ ├── sample_img2img.sh │ └── sample_sd3.sh ├── train.py ├── train_dreambooth_sd3.py └── transport.py ├── lumina_t2i ├── .isort.cfg ├── README.md ├── __init__.py ├── configs │ ├── data │ │ └── JourneyDB.yaml │ └── infer │ │ └── settings.yaml ├── data │ ├── __init__.py │ ├── data_reader.py │ └── dataset.py ├── demo.py ├── entry_point.py ├── exps │ ├── 5B_bs512_lr1e-4_bf16_1024px_sdxlvae.sh │ ├── 5B_bs512_lr1e-4_bf16_256px_sdxlvae.sh │ ├── 5B_bs512_lr1e-4_bf16_512px_sdxlvae.sh │ └── slurm │ │ ├── 5B_bs512_lr1e-4_bf16_1024px_sdxlvae.sh │ │ ├── 5B_bs512_lr1e-4_bf16_256px_sdxlvae.sh │ │ └── 5B_bs512_lr1e-4_bf16_512px_sdxlvae.sh ├── grad_norm.py ├── imgproc.py ├── models │ ├── __init__.py │ ├── components.py │ └── model.py ├── parallel.py ├── requirements.txt ├── train.py ├── transport │ ├── __init__.py │ ├── integrators.py │ ├── path.py │ ├── transport.py │ └── utils.py └── utils │ ├── __init__.py │ ├── cli.py │ └── group.py ├── pyproject.toml ├── requirements.txt └── visual_anagrams ├── .DS_Store ├── LICENSE ├── MANIFEST.in ├── animate.py ├── environment.yml ├── generate.py ├── huggingface_login.py ├── models ├── __init__.py ├── components.py └── nextdit.py ├── readme.md ├── run.sh ├── setup.py └── visual_anagrams ├── .DS_Store ├── __init__.py ├── animate.py ├── assets └── CourierPrime-Regular.ttf ├── samplers.py ├── utils.py └── views ├── __init__.py ├── assets └── 4x4 │ ├── 4x4_corner_1024.png │ ├── 4x4_corner_256.png │ ├── 4x4_corner_64.png │ ├── 4x4_edge1_1024.png │ ├── 4x4_edge1_256.png │ ├── 4x4_edge1_64.png │ ├── 4x4_edge2_1024.png │ ├── 4x4_edge2_256.png │ ├── 4x4_edge2_64.png │ ├── 4x4_inner_1024.png │ ├── 4x4_inner_256.png │ └── 4x4_inner_64.png ├── jigsaw_helpers.py ├── permutations.py ├── view_base.py ├── view_blur.py ├── view_color.py ├── view_flip.py ├── view_hybrid.py ├── view_identity.py ├── view_inner_circle.py ├── view_jigsaw.py ├── view_motion.py ├── view_negate.py ├── view_patch_permute.py ├── view_permute.py ├── view_rotate.py ├── view_scale.py ├── view_skew.py ├── view_square_hinge.py └── 
view_white_balance.py /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | line_length = 120 4 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER 5 | no_lines_before = STDLIB,LOCALFOLDER 6 | lines_between_types = 1 7 | combine_as_imports = True 8 | force_sort_within_sections = true 9 | order_by_type = True 10 | src_paths = * 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Exclude all third-party libraries and auto-generated files globally 2 | exclude: | 3 | (?x)^( 4 | assets/.+| 5 | Flag-DiT-ImageNet/exps/.+| 6 | Next-DiT-ImageNet/exps/.+| 7 | lumina_t2i/configs/.+| 8 | lumina_next_t2i/configs/.+| 9 | )$ 10 | repos: 11 | # Common hooks 12 | - repo: https://github.com/pre-commit/pre-commit-hooks 13 | rev: v4.6.0 14 | hooks: 15 | - id: check-merge-conflict 16 | - id: check-symlinks 17 | - id: detect-private-key 18 | - id: end-of-file-fixer 19 | - id: trailing-whitespace 20 | files: (.*\.(py|bzl|md|rst|c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps|cmake|yaml|yml|hook)|BUILD|.*\.BUILD|WORKSPACE|CMakeLists\.txt)$ 21 | # For Python files 22 | - repo: https://github.com/PyCQA/isort 23 | rev: 5.13.2 24 | hooks: 25 | - id: isort 26 | - repo: https://github.com/psf/black.git 27 | rev: 24.4.2 28 | hooks: 29 | - id: black 30 | files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ 31 | args: [--line-length=120] 32 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/mnt/petrelfs/share/images/train' 4 | 5 | model=DiT_Llama_600M_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | torchrun --nproc-per-node=8 train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/exps/slurm/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/mnt/petrelfs/share/images/train' 4 | 5 | model=DiT_Llama_3B_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | 
-------------------------------------------------------------------------------- /Flag-DiT-ImageNet/exps/slurm/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/mnt/petrelfs/share/images/train' 4 | 5 | model=DiT_Llama_600M_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/exps/slurm/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/mnt/petrelfs/share/images/train' 4 | 5 | model=DiT_Llama_7B_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import fairscale.nn.model_parallel.initialize as fs_init 4 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | 9 | 10 | def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]: 11 | ret_dict = {} 12 | for module_name, module in model.named_modules(): 13 | 14 | def param_fqn(param_name): 15 | return param_name if module_name == "" else module_name + "." 
+ param_name 16 | 17 | if isinstance(module, ColumnParallelLinear): 18 | ret_dict[param_fqn("weight")] = 0 19 | if module.bias is not None: 20 | ret_dict[param_fqn("bias")] = 0 21 | elif isinstance(module, RowParallelLinear): 22 | ret_dict[param_fqn("weight")] = 1 23 | if module.bias is not None: 24 | ret_dict[param_fqn("bias")] = -1 25 | elif isinstance(module, ParallelEmbedding): 26 | ret_dict[param_fqn("weight")] = 1 27 | else: 28 | for param_name, param in module.named_parameters(recurse=False): 29 | ret_dict[param_fqn(param_name)] = -1 30 | return ret_dict 31 | 32 | 33 | def calculate_l2_grad_norm( 34 | model: nn.Module, 35 | model_parallel_dim_dict: Dict[str, int], 36 | ) -> float: 37 | mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 38 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 39 | 40 | for name, param in model.named_parameters(): 41 | if param.grad is None: 42 | continue 43 | name = ".".join(x for x in name.split(".") if not x.startswith("_")) 44 | assert name in model_parallel_dim_dict 45 | if model_parallel_dim_dict[name] < 0: 46 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2 47 | else: 48 | mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2 49 | 50 | dist.all_reduce(mp_norm_sq) 51 | dist.all_reduce(non_mp_norm_sq) 52 | non_mp_norm_sq /= fs_init.get_model_parallel_world_size() 53 | 54 | return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5 55 | 56 | 57 | def scale_grad(model: nn.Module, factor: float) -> None: 58 | for param in model.parameters(): 59 | if param.grad is not None: 60 | param.grad.mul_(factor) 61 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DiT_Llama_3B_patch2, DiT_Llama_7B_patch2, DiT_Llama_600M_patch2 2 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/models/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 
51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | from time import sleep 6 | 7 | import fairscale.nn.model_parallel.initialize as fs_init 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def _setup_dist_env_from_slurm(args): 13 | while not os.environ.get("MASTER_ADDR", ""): 14 | os.environ["MASTER_ADDR"] = ( 15 | subprocess.check_output( 16 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"], 17 | shell=True, 18 | ) 19 | .decode() 20 | .strip() 21 | ) 22 | sleep(1) 23 | os.environ["MASTER_PORT"] = str(args.master_port) 24 | os.environ["RANK"] = os.environ["SLURM_PROCID"] 25 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"] 26 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 27 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"] 28 | 29 | 30 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None 31 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1 32 | 33 | 34 | def get_local_rank() -> int: 35 | return _LOCAL_RANK 36 | 37 | 38 | def get_local_world_size() -> int: 39 | return _LOCAL_WORLD_SIZE 40 | 41 | 42 | def distributed_init(args): 43 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]): 44 | _setup_dist_env_from_slurm(args) 45 | 46 | dist.init_process_group("nccl") 47 | fs_init.initialize_model_parallel(args.model_parallel_size) 48 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) 49 | 50 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 51 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"]) 52 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"]) 53 | 54 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP 55 | local_ranks, local_world_sizes = [ 56 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1) 57 | ] 58 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda")) 59 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda")) 60 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist() 61 | node_ranks = [[0]] 62 | for i in range(1, dist.get_world_size()): 63 | if len(node_ranks[-1]) == local_world_sizes[i - 1]: 64 | node_ranks.append([]) 65 | else: 66 | assert local_world_sizes[i] == local_world_sizes[i - 1] 67 | node_ranks[-1].append(i) 68 | for ranks in node_ranks: 69 | group = dist.new_group(ranks) 70 | if dist.get_rank() in ranks: 71 | assert _INTRA_NODE_PROCESS_GROUP is None 72 | _INTRA_NODE_PROCESS_GROUP = group 73 | assert _INTRA_NODE_PROCESS_GROUP is not None 74 | 75 | if min(local_world_sizes) == max(local_world_sizes): 76 | for i in range(get_local_world_size()): 77 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size()))) 78 | if i == get_local_rank(): 79 | assert _INTER_NODE_PROCESS_GROUP is None 80 | _INTER_NODE_PROCESS_GROUP = group 81 | assert _INTER_NODE_PROCESS_GROUP is not None 82 | 83 | 84 | def get_intra_node_process_group(): 85 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized." 
86 | return _INTRA_NODE_PROCESS_GROUP 87 | 88 | 89 | def get_inter_node_process_group(): 90 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized." 91 | return _INTER_NODE_PROCESS_GROUP 92 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/scripts/run_8gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Flag-DiT with single node 4 | 5 | # run Flag-DiT 600M 6 | bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 7 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/scripts/slurm/run_32gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Flag-DiT with cluster 4 | 5 | # added config here for slurm cluster using 32 GPUs 6 | 7 | # run Flag-DiT 600M 8 | srun bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 9 | # run Flag-DiT 3B 10 | srun bash exps/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 11 | # run Flag-DiT 7B 12 | srun bash exps/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 13 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/scripts/slurm/run_8gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Flag-DiT with cluster 4 | 5 | # added config here for slurm cluster using 8 GPUs 6 | 7 | # run Flag-DiT 600M 8 | srun bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 9 | # run Flag-DiT 3B 10 | srun bash exps/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 11 | # run Flag-DiT 7B 12 | srun bash exps/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 13 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import ModelType, PathType, Sampler, SNRType, Transport, WeightType 2 | 3 | 4 | def create_transport( 5 | path_type="Linear", prediction="velocity", loss_weight=None, train_eps=None, sample_eps=None, snr_type="uniform" 6 | ): 7 | """function for creating Transport object 8 | **Note**: model prediction defaults to velocity 9 | Args: 10 | - path_type: type of path to use; default to linear 11 | - learn_score: set model prediction to score 12 | - learn_noise: set model prediction to noise 13 | - velocity_weighted: weight loss by velocity weight 14 | - likelihood_weighted: weight loss by likelihood weight 15 | - train_eps: small epsilon for avoiding instability during training 16 | - sample_eps: small epsilon for avoiding instability during sampling 17 | """ 18 | 19 | if prediction == "noise": 20 | model_type = ModelType.NOISE 21 | elif prediction == "score": 22 | model_type = ModelType.SCORE 23 | else: 24 | model_type = ModelType.VELOCITY 25 | 26 | if loss_weight == "velocity": 27 | loss_type = WeightType.VELOCITY 28 | elif loss_weight == "likelihood": 29 | loss_type = WeightType.LIKELIHOOD 30 | else: 31 | loss_type = WeightType.NONE 32 | 33 | if snr_type == "lognorm": 34 | snr_type = SNRType.LOGNORM 35 | elif snr_type == "uniform": 36 | snr_type = SNRType.UNIFORM 37 | else: 38 | raise ValueError(f"Invalid snr type {snr_type}") 39 | 40 | path_choice = { 41 | "Linear": PathType.LINEAR, 42 | "GVP": PathType.GVP, 43 | "VP": PathType.VP, 44 | } 45 | 46 | path_type = path_choice[path_type] 47 | 48 | if path_type in [PathType.VP]: 49 | train_eps = 1e-5 if train_eps is 
None else train_eps 50 | sample_eps = 1e-3 if train_eps is None else sample_eps 51 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY: 52 | train_eps = 1e-3 if train_eps is None else train_eps 53 | sample_eps = 1e-3 if train_eps is None else sample_eps 54 | else: # velocity & [GVP, LINEAR] is stable everywhere 55 | train_eps = 0 56 | sample_eps = 0 57 | 58 | # create flow state 59 | state = Transport( 60 | model_type=model_type, 61 | path_type=path_type, 62 | loss_type=loss_type, 63 | train_eps=train_eps, 64 | sample_eps=sample_eps, 65 | snr_type=snr_type, 66 | ) 67 | 68 | return state 69 | -------------------------------------------------------------------------------- /Flag-DiT-ImageNet/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | class EasyDict: 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | 13 | def mean_flat(x): 14 | """ 15 | Take the mean over all non-batch dimensions. 16 | """ 17 | return th.mean(x, dim=list(range(1, len(x.size())))) 18 | 19 | 20 | def log_state(state): 21 | result = [] 22 | 23 | sorted_state = dict(sorted(state.items())) 24 | for key, value in sorted_state.items(): 25 | # Check if the value is an instance of a class 26 | if "&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/exps/slurm/2B_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/path/to/imagenet/images/train' 4 | 5 | model=DiT_Llama_2B_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/exps/slurm/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/path/to/imagenet/images/train' 4 | 5 | model=DiT_Llama_3B_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/exps/slurm/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/path/to/imagenet/images/train' 4 | 5 | model=DiT_Llama_600M_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/exps/slurm/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/path/to/imagenet/images/train' 4 | 5 | model=DiT_Llama_7B_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/fid_is.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/Next-DiT-ImageNet/fid_is.png -------------------------------------------------------------------------------- /Next-DiT-ImageNet/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import fairscale.nn.model_parallel.initialize as fs_init 4 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | 9 | 10 | def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]: 11 | ret_dict = {} 12 | for module_name, module in model.named_modules(): 13 | if isinstance(module, ColumnParallelLinear): 14 | ret_dict[module_name + ".weight"] = 0 15 | if module.bias is not None: 16 | ret_dict[module_name + ".bias"] = 0 17 | elif isinstance(module, RowParallelLinear): 18 | ret_dict[module_name + ".weight"] = 1 19 | if module.bias is not None: 20 | ret_dict[module_name + ".bias"] = -1 21 | elif isinstance(module, ParallelEmbedding): 22 | ret_dict[module_name + ".weight"] = 1 23 | else: 24 | for param_name, param in module.named_parameters(recurse=False): 25 | ret_dict[(module_name + "." 
if len(module_name) > 0 else "") + param_name] = -1 26 | return ret_dict 27 | 28 | 29 | def calculate_l2_grad_norm( 30 | model: nn.Module, 31 | model_parallel_dim_dict: Dict[str, int], 32 | ) -> float: 33 | mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 34 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 35 | 36 | for name, param in model.named_parameters(): 37 | if param.grad is None: 38 | continue 39 | name = ".".join(x for x in name.split(".") if not x.startswith("_")) 40 | assert name in model_parallel_dim_dict 41 | if model_parallel_dim_dict[name] < 0: 42 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) 43 | else: 44 | mp_norm_sq += param.grad.norm(dtype=torch.float32) 45 | 46 | dist.all_reduce(mp_norm_sq) 47 | dist.all_reduce(non_mp_norm_sq) 48 | non_mp_norm_sq /= fs_init.get_model_parallel_world_size() 49 | 50 | return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5 51 | 52 | 53 | def scale_grad(model: nn.Module, factor: float) -> None: 54 | for param in model.parameters(): 55 | if param.grad is not None: 56 | param.grad.mul_(factor) 57 | 58 | 59 | def get_param_norm_dict(model: nn.Module, model_parallel_dim_dict: Dict[str, int]) -> Dict[str, float]: 60 | param_norm_dict = {} 61 | for name, param in model.named_parameters(): 62 | name = ".".join(x for x in name.split(".") if not x.startswith("_")) 63 | norm_sq = param.norm(dtype=torch.float32) ** 2 64 | dist.all_reduce(norm_sq) 65 | norm_sq = norm_sq.item() 66 | if model_parallel_dim_dict[name] < 0: 67 | norm_sq /= fs_init.get_model_parallel_world_size() 68 | norm = norm_sq**0.5 69 | param_norm_dict[name] = norm 70 | return param_norm_dict 71 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/init_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def extract_loss_from_log(log_file): 8 | with open(log_file, "r") as f: 9 | log_text = f.read() 10 | pattern = r"\(step=(\d+)\) Train Loss: ([\d.]+)" 11 | matches = re.findall(pattern, log_text) 12 | steps = [] 13 | losses = [] 14 | for match in matches: 15 | step, train_loss = match 16 | steps.append(int(step)) 17 | # losses.append(min(float(train_loss), 0.9)) 18 | losses.append(float(train_loss)) 19 | return steps, losses 20 | 21 | 22 | def smooth_loss(losses, alpha): 23 | smoothed = [losses[0]] 24 | for i in range(1, len(losses)): 25 | smoothed.append((1 - alpha) * losses[i] + alpha * smoothed[i - 1]) 26 | return smoothed 27 | 28 | 29 | def plot_losses(log_folder): 30 | plt.figure(figsize=(10, 6)) 31 | 32 | # 设置全局字体大小和粗细 33 | plt.rcParams["font.size"] = 12 # 字体大小 34 | plt.rcParams["font.weight"] = "bold" # 字体粗细 35 | 36 | for log_file in os.listdir(log_folder): 37 | if log_file.endswith(".txt"): # 假设日志文件都以'.txt'结尾 38 | steps, losses = extract_loss_from_log(os.path.join(log_folder, log_file)) 39 | losses = smooth_loss(losses, 0.8) 40 | steps = [i / 1000 for i in steps] # 每1000步绘制一个点 41 | plt.plot(steps[80:], losses[80:], label=log_file.replace(".txt", "").replace("_", " ")) 42 | 43 | # 设置x轴和y轴的标签字体样式 44 | plt.xlabel("Training Iterations (k)", fontweight="bold", fontsize=14) 45 | plt.ylabel("Loss", fontweight="bold", fontsize=14) 46 | 47 | # 设置标题字体样式 48 | plt.legend() 49 | plt.grid(False) 50 | plt.savefig(os.path.join(log_folder, "loss_curve.png")) 51 | plt.show() 52 | 53 | 54 | if __name__ == "__main__": 55 | log_folder = 
"/mnt/petrelfs/share_data/liuwenze/results/Large-SiT-rope2d/3B_optimal_lr" 56 | plot_losses(log_folder) 57 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import DiT_Llama_2B_patch2, DiT_Llama_3B_patch2, DiT_Llama_7B_patch2, DiT_Llama_600M_patch2 2 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | from time import sleep 6 | 7 | import fairscale.nn.model_parallel.initialize as fs_init 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def _setup_dist_env_from_slurm(args): 13 | while not os.environ.get("MASTER_ADDR", ""): 14 | os.environ["MASTER_ADDR"] = ( 15 | subprocess.check_output( 16 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"], 17 | shell=True, 18 | ) 19 | .decode() 20 | .strip() 21 | ) 22 | sleep(1) 23 | os.environ["MASTER_PORT"] = str(args.master_port) 24 | os.environ["RANK"] = os.environ["SLURM_PROCID"] 25 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"] 26 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 27 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"] 28 | 29 | 30 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None 31 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1 32 | 33 | 34 | def get_local_rank() -> int: 35 | return _LOCAL_RANK 36 | 37 | 38 | def get_local_world_size() -> int: 39 | return _LOCAL_WORLD_SIZE 40 | 41 | 42 | def distributed_init(args): 43 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]): 44 | _setup_dist_env_from_slurm(args) 45 | 46 | dist.init_process_group("nccl") 47 | fs_init.initialize_model_parallel(args.model_parallel_size) 48 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) 49 | 50 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 51 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"]) 52 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"]) 53 | 54 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP 55 | local_ranks, local_world_sizes = [ 56 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1) 57 | ] 58 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda")) 59 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda")) 60 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist() 61 | node_ranks = [[0]] 62 | for i in range(1, dist.get_world_size()): 63 | if len(node_ranks[-1]) == local_world_sizes[i - 1]: 64 | node_ranks.append([]) 65 | else: 66 | assert local_world_sizes[i] == local_world_sizes[i - 1] 67 | node_ranks[-1].append(i) 68 | for ranks in node_ranks: 69 | group = dist.new_group(ranks) 70 | if dist.get_rank() in ranks: 71 | assert _INTRA_NODE_PROCESS_GROUP is None 72 | _INTRA_NODE_PROCESS_GROUP = group 73 | assert _INTRA_NODE_PROCESS_GROUP is not None 74 | 75 | if min(local_world_sizes) == max(local_world_sizes): 76 | for i in range(get_local_world_size()): 77 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size()))) 78 | if i == get_local_rank(): 79 | assert _INTER_NODE_PROCESS_GROUP is None 80 | _INTER_NODE_PROCESS_GROUP = group 81 | 
assert _INTER_NODE_PROCESS_GROUP is not None 82 | 83 | 84 | def get_intra_node_process_group(): 85 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized." 86 | return _INTRA_NODE_PROCESS_GROUP 87 | 88 | 89 | def get_inter_node_process_group(): 90 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized." 91 | return _INTER_NODE_PROCESS_GROUP 92 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/scripts/run_8gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Next-DiT with single node 4 | 5 | # run Next-DiT 600M 6 | bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 7 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/scripts/slurm/run_32gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Next-DiT with cluster 4 | 5 | # added config here for slurm cluster using 8 GPUs 6 | 7 | # run Next-DiT 600M 8 | srun bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 9 | # run Next-DiT 2B 10 | # srun bash exps/2B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 11 | # run Next-DiT 3B 12 | # srun bash exps/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 13 | # run Next-DiT 7B 14 | # srun bash exps/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 15 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/scripts/slurm/run_8gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Next-DiT with cluster 4 | 5 | # added config here for slurm cluster using 8 GPUs 6 | 7 | # run Next-DiT 600M 8 | srun bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 9 | # run Next-DiT 2B 10 | # srun bash exps/2BB_bs256_lr5e-4_bf16_qknorm_lognorm.sh 11 | # run Next-DiT 3B 12 | # srun bash exps/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 13 | # run Next-DiT 7B 14 | # srun bash exps/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 15 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import ModelType, PathType, Sampler, SNRType, Transport, WeightType 2 | 3 | 4 | def create_transport( 5 | path_type="Linear", prediction="velocity", loss_weight=None, train_eps=None, sample_eps=None, snr_type="uniform" 6 | ): 7 | """function for creating Transport object 8 | **Note**: model prediction defaults to velocity 9 | Args: 10 | - path_type: type of path to use; default to linear 11 | - learn_score: set model prediction to score 12 | - learn_noise: set model prediction to noise 13 | - velocity_weighted: weight loss by velocity weight 14 | - likelihood_weighted: weight loss by likelihood weight 15 | - train_eps: small epsilon for avoiding instability during training 16 | - sample_eps: small epsilon for avoiding instability during sampling 17 | """ 18 | 19 | if prediction == "noise": 20 | model_type = ModelType.NOISE 21 | elif prediction == "score": 22 | model_type = ModelType.SCORE 23 | else: 24 | model_type = ModelType.VELOCITY 25 | 26 | if loss_weight == "velocity": 27 | loss_type = WeightType.VELOCITY 28 | elif loss_weight == "likelihood": 29 | loss_type = WeightType.LIKELIHOOD 30 | else: 31 | loss_type = WeightType.NONE 32 | 33 | if snr_type == "lognorm": 34 | snr_type = SNRType.LOGNORM 
35 | elif snr_type == "uniform": 36 | snr_type = SNRType.UNIFORM 37 | else: 38 | raise ValueError(f"Invalid snr type {snr_type}") 39 | 40 | path_choice = { 41 | "Linear": PathType.LINEAR, 42 | "GVP": PathType.GVP, 43 | "VP": PathType.VP, 44 | } 45 | 46 | path_type = path_choice[path_type] 47 | 48 | if path_type in [PathType.VP]: 49 | train_eps = 1e-5 if train_eps is None else train_eps 50 | sample_eps = 1e-3 if train_eps is None else sample_eps 51 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY: 52 | train_eps = 1e-3 if train_eps is None else train_eps 53 | sample_eps = 1e-3 if train_eps is None else sample_eps 54 | else: # velocity & [GVP, LINEAR] is stable everywhere 55 | train_eps = 0 56 | sample_eps = 0 57 | 58 | # create flow state 59 | state = Transport( 60 | model_type=model_type, 61 | path_type=path_type, 62 | loss_type=loss_type, 63 | train_eps=train_eps, 64 | sample_eps=sample_eps, 65 | snr_type=snr_type, 66 | ) 67 | 68 | return state 69 | -------------------------------------------------------------------------------- /Next-DiT-ImageNet/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | class EasyDict: 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | 13 | def mean_flat(x): 14 | """ 15 | Take the mean over all non-batch dimensions. 16 | """ 17 | return th.mean(x, dim=list(range(1, len(x.size())))) 18 | 19 | 20 | def log_state(state): 21 | result = [] 22 | 23 | sorted_state = dict(sorted(state.items())) 24 | for key, value in sorted_state.items(): 25 | # Check if the value is an instance of a class 26 | if "&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-MoE/exps/slurm/2B_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/path/to/imagenet/images/train' 4 | 5 | model=DiT_Llama_2B_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-MoE/exps/slurm/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/path/to/imagenet/images/train' 4 | 5 | model=DiT_Llama_3B_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 
21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-MoE/exps/slurm/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/path/to/imagenet/images/train' 4 | 5 | model=DiT_Llama_600M_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-MoE/exps/slurm/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='/path/to/imagenet/images/train' 4 | 5 | model=DiT_Llama_7B_patch2 6 | batch_size=256 7 | lr=5e-4 8 | precision=bf16 9 | 10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm 11 | mkdir -p results/"$exp_name" 12 | 13 | python -u train.py \ 14 | --model ${model} \ 15 | --data_path ${train_data_root} \ 16 | --results_dir results/"$exp_name" \ 17 | --micro_batch_size 32 \ 18 | --global_batch_size ${batch_size} --lr ${lr} \ 19 | --data_parallel sdp \ 20 | --max_steps 3000000 \ 21 | --ckpt_every 10000 --log_every 100 \ 22 | --precision ${precision} --grad_precision fp32 --qk_norm \ 23 | --snr_type "lognorm" \ 24 | 2>&1 | tee -a results/"$exp_name"/output.log 25 | -------------------------------------------------------------------------------- /Next-DiT-MoE/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed as dist 5 | import fairscale.nn.model_parallel.initialize as fs_init 6 | from fairscale.nn.model_parallel.layers import ( 7 | ColumnParallelLinear, RowParallelLinear, ParallelEmbedding 8 | ) 9 | 10 | 11 | def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]: 12 | ret_dict = {} 13 | for module_name, module in model.named_modules(): 14 | if isinstance(module, ColumnParallelLinear): 15 | ret_dict[module_name + ".weight"] = 0 16 | if module.bias is not None: 17 | ret_dict[module_name + ".bias"] = 0 18 | elif isinstance(module, RowParallelLinear): 19 | ret_dict[module_name + ".weight"] = 1 20 | if module.bias is not None: 21 | ret_dict[module_name + ".bias"] = -1 22 | elif isinstance(module, ParallelEmbedding): 23 | ret_dict[module_name + ".weight"] = 1 24 | else: 25 | for param_name, param in module.named_parameters(recurse=False): 26 | ret_dict[(module_name + "." 
if len(module_name) > 0 else "") + param_name] = -1 27 | return ret_dict 28 | 29 | 30 | def calculate_l2_grad_norm( 31 | model: nn.Module, model_parallel_dim_dict: Dict[str, int], 32 | ) -> float: 33 | mp_norm_sq = torch.tensor(0., dtype=torch.float32, device="cuda") 34 | non_mp_norm_sq = torch.tensor(0., dtype=torch.float32, device="cuda") 35 | 36 | for name, param in model.named_parameters(): 37 | if param.grad is None: 38 | continue 39 | name = ".".join(x for x in name.split(".") if not x.startswith("_")) 40 | assert name in model_parallel_dim_dict 41 | if model_parallel_dim_dict[name] < 0: 42 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) 43 | else: 44 | mp_norm_sq += param.grad.norm(dtype=torch.float32) 45 | 46 | dist.all_reduce(mp_norm_sq) 47 | dist.all_reduce(non_mp_norm_sq) 48 | non_mp_norm_sq /= fs_init.get_model_parallel_world_size() 49 | 50 | return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5 51 | 52 | 53 | def scale_grad(model: nn.Module, factor: float) -> None: 54 | for param in model.parameters(): 55 | if param.grad is not None: 56 | param.grad.mul_(factor) 57 | 58 | 59 | def get_param_norm_dict( 60 | model: nn.Module, model_parallel_dim_dict: Dict[str, int] 61 | ) -> Dict[str, float]: 62 | param_norm_dict = {} 63 | for name, param in model.named_parameters(): 64 | name = ".".join(x for x in name.split(".") if not x.startswith("_")) 65 | norm_sq = param.norm(dtype=torch.float32) ** 2 66 | dist.all_reduce(norm_sq) 67 | norm_sq = norm_sq.item() 68 | if model_parallel_dim_dict[name] < 0: 69 | norm_sq /= fs_init.get_model_parallel_world_size() 70 | norm = norm_sq ** 0.5 71 | param_norm_dict[name] = norm 72 | return param_norm_dict 73 | -------------------------------------------------------------------------------- /Next-DiT-MoE/loss_spacemoe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/Next-DiT-MoE/loss_spacemoe.png -------------------------------------------------------------------------------- /Next-DiT-MoE/loss_timemoe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/Next-DiT-MoE/loss_timemoe.png -------------------------------------------------------------------------------- /Next-DiT-MoE/loss_timespacemoe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/Next-DiT-MoE/loss_timespacemoe.png -------------------------------------------------------------------------------- /Next-DiT-MoE/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import DiT_Llama_600M_patch2, DiT_Llama_2B_patch2, DiT_Llama_3B_patch2, DiT_Llama_7B_patch2 2 | from .models1 import DiT_Llama_600M_patch2_Spatial 3 | from .models2 import DiT_Llama_600M_patch2_Both -------------------------------------------------------------------------------- /Next-DiT-MoE/moe_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/Next-DiT-MoE/moe_model.png -------------------------------------------------------------------------------- /Next-DiT-MoE/parallel.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | from time import sleep 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | import fairscale.nn.model_parallel.initialize as fs_init 11 | 12 | 13 | def _setup_dist_env_from_slurm(args): 14 | while not os.environ.get("MASTER_ADDR", ""): 15 | os.environ["MASTER_ADDR"] = subprocess.check_output( 16 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % 17 | os.environ['SLURM_NODELIST'], 18 | shell=True, 19 | ).decode().strip() 20 | sleep(1) 21 | os.environ["MASTER_PORT"] = str(args.master_port) 22 | os.environ["RANK"] = os.environ["SLURM_PROCID"] 23 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"] 24 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 25 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"] 26 | 27 | 28 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None 29 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1 30 | 31 | 32 | def get_local_rank() -> int: 33 | return _LOCAL_RANK 34 | 35 | 36 | def get_local_world_size() -> int: 37 | return _LOCAL_WORLD_SIZE 38 | 39 | 40 | def distributed_init(args): 41 | if any([ 42 | x not in os.environ 43 | for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"] 44 | ]): 45 | _setup_dist_env_from_slurm(args) 46 | 47 | dist.init_process_group("nccl") 48 | fs_init.initialize_model_parallel(args.model_parallel_size) 49 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) 50 | 51 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 52 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"]) 53 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"]) 54 | 55 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP 56 | local_ranks, local_world_sizes = [torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") 57 | for _ in (0, 1)] 58 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda")) 59 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda")) 60 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist() 61 | node_ranks = [[0]] 62 | for i in range(1, dist.get_world_size()): 63 | if len(node_ranks[-1]) == local_world_sizes[i - 1]: 64 | node_ranks.append([]) 65 | else: 66 | assert local_world_sizes[i] == local_world_sizes[i - 1] 67 | node_ranks[-1].append(i) 68 | for ranks in node_ranks: 69 | group = dist.new_group(ranks) 70 | if dist.get_rank() in ranks: 71 | assert _INTRA_NODE_PROCESS_GROUP is None 72 | _INTRA_NODE_PROCESS_GROUP = group 73 | assert _INTRA_NODE_PROCESS_GROUP is not None 74 | 75 | if min(local_world_sizes) == max(local_world_sizes): 76 | for i in range(get_local_world_size()): 77 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size()))) 78 | if i == get_local_rank(): 79 | assert _INTER_NODE_PROCESS_GROUP is None 80 | _INTER_NODE_PROCESS_GROUP = group 81 | assert _INTER_NODE_PROCESS_GROUP is not None 82 | 83 | 84 | def get_intra_node_process_group(): 85 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized." 86 | return _INTRA_NODE_PROCESS_GROUP 87 | 88 | 89 | def get_inter_node_process_group(): 90 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized." 
91 | return _INTER_NODE_PROCESS_GROUP 92 | 93 | -------------------------------------------------------------------------------- /Next-DiT-MoE/scripts/run_8gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Next-DiT with single node 4 | 5 | # run Next-DiT 600M 6 | bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 7 | -------------------------------------------------------------------------------- /Next-DiT-MoE/scripts/slurm/run_32gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Next-DiT with cluster 4 | 5 | # added config here for slurm cluster using 8 GPUs 6 | 7 | # run Next-DiT 600M 8 | srun bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 9 | # run Next-DiT 2B 10 | # srun bash exps/2B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 11 | # run Next-DiT 3B 12 | # srun bash exps/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 13 | # run Next-DiT 7B 14 | # srun bash exps/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 15 | -------------------------------------------------------------------------------- /Next-DiT-MoE/scripts/slurm/run_8gpus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # run Next-DiT with cluster 4 | 5 | # added config here for slurm cluster using 8 GPUs 6 | 7 | # run Next-DiT 600M 8 | srun bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh 9 | # run Next-DiT 2B 10 | # srun bash exps/2BB_bs256_lr5e-4_bf16_qknorm_lognorm.sh 11 | # run Next-DiT 3B 12 | # srun bash exps/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 13 | # run Next-DiT 7B 14 | # srun bash exps/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh 15 | -------------------------------------------------------------------------------- /Next-DiT-MoE/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import Transport, ModelType, WeightType, PathType, SNRType, Sampler 2 | 3 | def create_transport( 4 | path_type='Linear', 5 | prediction="velocity", 6 | loss_weight=None, 7 | train_eps=None, 8 | sample_eps=None, 9 | snr_type="uniform" 10 | ): 11 | """function for creating Transport object 12 | **Note**: model prediction defaults to velocity 13 | Args: 14 | - path_type: type of path to use; default to linear 15 | - learn_score: set model prediction to score 16 | - learn_noise: set model prediction to noise 17 | - velocity_weighted: weight loss by velocity weight 18 | - likelihood_weighted: weight loss by likelihood weight 19 | - train_eps: small epsilon for avoiding instability during training 20 | - sample_eps: small epsilon for avoiding instability during sampling 21 | """ 22 | 23 | if prediction == "noise": 24 | model_type = ModelType.NOISE 25 | elif prediction == "score": 26 | model_type = ModelType.SCORE 27 | else: 28 | model_type = ModelType.VELOCITY 29 | 30 | if loss_weight == "velocity": 31 | loss_type = WeightType.VELOCITY 32 | elif loss_weight == "likelihood": 33 | loss_type = WeightType.LIKELIHOOD 34 | else: 35 | loss_type = WeightType.NONE 36 | 37 | if snr_type == "lognorm": 38 | snr_type = SNRType.LOGNORM 39 | elif snr_type == "uniform": 40 | snr_type = SNRType.UNIFORM 41 | else: 42 | raise ValueError(f"Invalid snr type {snr_type}") 43 | 44 | path_choice = { 45 | "Linear": PathType.LINEAR, 46 | "GVP": PathType.GVP, 47 | "VP": PathType.VP, 48 | } 49 | 50 | path_type = path_choice[path_type] 51 | 52 | if (path_type in [PathType.VP]): 53 | train_eps = 1e-5 if train_eps is None else 
train_eps 54 | sample_eps = 1e-3 if train_eps is None else sample_eps 55 | elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY): 56 | train_eps = 1e-3 if train_eps is None else train_eps 57 | sample_eps = 1e-3 if train_eps is None else sample_eps 58 | else: # velocity & [GVP, LINEAR] is stable everywhere 59 | train_eps = 0 60 | sample_eps = 0 61 | 62 | # create flow state 63 | state = Transport( 64 | model_type=model_type, 65 | path_type=path_type, 66 | loss_type=loss_type, 67 | train_eps=train_eps, 68 | sample_eps=sample_eps, 69 | snr_type=snr_type 70 | ) 71 | 72 | return state 73 | -------------------------------------------------------------------------------- /Next-DiT-MoE/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | class EasyDict: 4 | 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | def mean_flat(x): 13 | """ 14 | Take the mean over all non-batch dimensions. 15 | """ 16 | return th.mean(x, dim=list(range(1, len(x.size())))) 17 | 18 | def log_state(state): 19 | result = [] 20 | 21 | sorted_state = dict(sorted(state.items())) 22 | for key, value in sorted_state.items(): 23 | # Check if the value is an instance of a class 24 | if " 2 | 3 |
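A minimal sketch of how the `create_transport` helper from the transport packages listed above is typically wired into training; the `training_losses` call, its signature, and the surrounding variable names are assumptions for illustration and are not shown in this listing:

    # Illustrative only; training_losses and its arguments are assumed here.
    from transport import create_transport

    transport = create_transport(
        path_type="Linear",     # PathType.LINEAR
        prediction="velocity",  # ModelType.VELOCITY -> train_eps = sample_eps = 0
        snr_type="lognorm",     # SNRType.LOGNORM time sampling
    )
    loss_dict = transport.training_losses(model, x1, model_kwargs)  # hypothetical call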
4 |

5 | 6 | # $\textbf{Lumina-T2X}$ : Transform text to any modality with Flow-based Large Diffusion Transformer 7 | -------------------------------------------------------------------------------- /assets/audios/a_telephone_bell_rings.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/audios/a_telephone_bell_rings.wav -------------------------------------------------------------------------------- /assets/audios/a_telephone_bell_rings_gt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/audios/a_telephone_bell_rings_gt.wav -------------------------------------------------------------------------------- /assets/compositional_intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/compositional_intro.png -------------------------------------------------------------------------------- /assets/diverse_config.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/diverse_config.png -------------------------------------------------------------------------------- /assets/images/demo_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/images/demo_image.png -------------------------------------------------------------------------------- /assets/images/resolution_extrapolation_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/images/resolution_extrapolation_2.jpg -------------------------------------------------------------------------------- /assets/lumina-intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/lumina-intro.png -------------------------------------------------------------------------------- /assets/lumina-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/lumina-logo.png -------------------------------------------------------------------------------- /lumina_audio/configs/lumina-text2audio.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 3.0e-06 3 | target: models.diffusion.ddpm_audio.CFM 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | mel_dim: 20 13 | mel_length: 256 14 | channels: 0 15 | cond_stage_trainable: True 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | scale_by_std: true 19 | use_ema: false 20 | scheduler_config: 21 | target: models.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: 24 | - 10000 25 | cycle_lengths: 26 | - 
10000000000000 27 | f_start: 28 | - 1.0e-06 29 | f_max: 30 | - 1.0 31 | f_min: 32 | - 1.0 33 | unet_config: 34 | target: models.diffusion.flag_large_dit.FlagDiTv2 35 | params: 36 | in_channels: 20 37 | context_dim: 1024 38 | hidden_size: 768 39 | num_heads: 32 40 | depth: 16 41 | max_len: 1000 42 | 43 | first_stage_config: 44 | target: models.autoencoder1d.AutoencoderKL 45 | params: 46 | embed_dim: 20 47 | monitor: val/rec_loss 48 | ckpt_path: /path/to/ckpt/maa2/maa2.ckpt 49 | ddconfig: 50 | double_z: true 51 | in_channels: 80 52 | out_ch: 80 53 | z_channels: 20 54 | kernel_size: 5 55 | ch: 384 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | num_res_blocks: 2 61 | attn_layers: 62 | - 3 63 | down_layers: 64 | - 0 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | cond_stage_config: 69 | target: models.encoders.modules.FrozenCLAPFLANEmbedder 70 | params: 71 | weights_path: /path/to/ckpt/CLAP/CLAP_weights_2022.pth 72 | -------------------------------------------------------------------------------- /lumina_audio/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/__init__.py -------------------------------------------------------------------------------- /lumina_audio/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/diffusion/__init__.py -------------------------------------------------------------------------------- /lumina_audio/models/diffusion/component.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 
51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /lumina_audio/models/diffusion/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/diffusion/distributions/__init__.py -------------------------------------------------------------------------------- /lumina_audio/models/diffusion/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.0]) 42 | else: 43 | sum_dim = list(range(1, len(self.mean.shape))) 44 | if other is None: 45 | 46 | return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=sum_dim) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var 51 | - 1.0 52 | - self.logvar 53 | + other.logvar, 54 | dim=sum_dim, 55 | ) 56 | 57 | def nll(self, sample, dims=[1, 2, 3]): 58 | if self.deterministic: 59 | return torch.Tensor([0.0]) 60 | logtwopi = np.log(2.0 * np.pi) 61 | return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) 62 | 63 | def mode(self): 64 | return self.mean 65 | 66 | 67 | def normal_kl(mean1, logvar1, mean2, logvar2): 68 | """ 69 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 70 | Compute the KL divergence between two gaussians. 71 | Shapes are automatically broadcasted, so batches can be compared to 72 | scalars, among other use cases. 73 | """ 74 | tensor = None 75 | for obj in (mean1, logvar1, mean2, logvar2): 76 | if isinstance(obj, torch.Tensor): 77 | tensor = obj 78 | break 79 | assert tensor is not None, "at least one argument must be a Tensor" 80 | 81 | # Force variances to be Tensors. Broadcasting helps convert scalars to 82 | # Tensors, but it does not work for torch.exp(). 
83 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] 84 | 85 | return 0.5 * ( 86 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 87 | ) 88 | -------------------------------------------------------------------------------- /lumina_audio/models/diffusion/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int) 15 | ) 16 | 17 | for name, p in model.named_parameters(): 18 | if p.requires_grad: 19 | # remove as '.'-character is not allowed in buffers 20 | s_name = name.replace(".", "") 21 | self.m_name2s_name.update({name: s_name}) 22 | self.register_buffer(s_name, p.clone().detach().data) 23 | 24 | self.collected_params = [] 25 | 26 | def forward(self, model): 27 | decay = self.decay 28 | 29 | if self.num_updates >= 0: 30 | self.num_updates += 1 31 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 32 | 33 | one_minus_decay = 1.0 - decay 34 | 35 | with torch.no_grad(): 36 | m_param = dict(model.named_parameters()) 37 | shadow_params = dict(self.named_buffers()) 38 | 39 | for key in m_param: 40 | if m_param[key].requires_grad: 41 | sname = self.m_name2s_name[key] 42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 44 | else: 45 | assert not key in self.m_name2s_name 46 | 47 | def copy_to(self, model): 48 | m_param = dict(model.named_parameters()) 49 | shadow_params = dict(self.named_buffers()) 50 | for key in m_param: 51 | if m_param[key].requires_grad: 52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def store(self, parameters): 57 | """ 58 | Save the current parameters for restoring later. 59 | Args: 60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 61 | temporarily stored. 62 | """ 63 | self.collected_params = [param.clone() for param in parameters] 64 | 65 | def restore(self, parameters): 66 | """ 67 | Restore the parameters stored with the `store` method. 68 | Useful to validate the model with EMA parameters without affecting the 69 | original optimization process. Store the parameters before the 70 | `copy_to` method. After validation (or model saving), use this to 71 | restore the former parameters. 72 | Args: 73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 74 | updated with the stored parameters. 75 | """ 76 | for c_param, param in zip(self.collected_params, parameters): 77 | param.data.copy_(c_param.data) 78 | -------------------------------------------------------------------------------- /lumina_audio/models/encoders/CLAP/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import audio, clap, utils 2 | -------------------------------------------------------------------------------- /lumina_audio/models/encoders/CLAP/clap.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from transformers import AutoModel 8 | 9 | from .audio import get_audio_encoder 10 | 11 | 12 | class Projection(nn.Module): 13 | def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None: 14 | super().__init__() 15 | self.linear1 = nn.Linear(d_in, d_out, bias=False) 16 | self.linear2 = nn.Linear(d_out, d_out, bias=False) 17 | self.layer_norm = nn.LayerNorm(d_out) 18 | self.drop = nn.Dropout(p) 19 | 20 | def forward(self, x: torch.Tensor) -> torch.Tensor: 21 | embed1 = self.linear1(x) 22 | embed2 = self.drop(self.linear2(F.gelu(embed1))) 23 | embeds = self.layer_norm(embed1 + embed2) 24 | return embeds 25 | 26 | 27 | class AudioEncoder(nn.Module): 28 | def __init__( 29 | self, 30 | audioenc_name: str, 31 | d_in: int, 32 | d_out: int, 33 | sample_rate: int, 34 | window_size: int, 35 | hop_size: int, 36 | mel_bins: int, 37 | fmin: int, 38 | fmax: int, 39 | classes_num: int, 40 | ) -> None: 41 | super().__init__() 42 | 43 | audio_encoder = get_audio_encoder(audioenc_name) 44 | 45 | self.base = audio_encoder(sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num, d_in) 46 | 47 | self.projection = Projection(d_in, d_out) 48 | 49 | def forward(self, x): 50 | out_dict = self.base(x) 51 | audio_features, audio_classification_output = out_dict["embedding"], out_dict["clipwise_output"] 52 | projected_vec = self.projection(audio_features) 53 | return projected_vec, audio_classification_output 54 | 55 | 56 | class TextEncoder(nn.Module): 57 | def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None: 58 | super().__init__() 59 | # if os.path.exists('/apdcephfs/share_1316500/nlphuang/results/Text_to_audio/pretrained/CLAP_AutoModel'): 60 | # root = '/apdcephfs' 61 | # else: 62 | # root = '/apdcephfs_intern' 63 | 64 | self.base = AutoModel.from_pretrained(text_model) 65 | self.projection = Projection(transformer_embed_dim, d_out) 66 | 67 | def forward(self, x): 68 | out = self.base(**x)[0] 69 | out = out[:, 0, :] # get CLS token output 70 | projected_vec = self.projection(out) 71 | return projected_vec 72 | 73 | 74 | class CLAP(nn.Module): 75 | def __init__( 76 | self, 77 | # audio 78 | audioenc_name: str, 79 | sample_rate: int, 80 | window_size: int, 81 | hop_size: int, 82 | mel_bins: int, 83 | fmin: int, 84 | fmax: int, 85 | classes_num: int, 86 | out_emb: int, 87 | # text 88 | text_model: str, 89 | transformer_embed_dim: int, 90 | # common 91 | d_proj: int, 92 | ): 93 | super().__init__() 94 | 95 | self.audio_encoder = AudioEncoder( 96 | audioenc_name, out_emb, d_proj, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num 97 | ) 98 | 99 | self.caption_encoder = TextEncoder(d_proj, text_model, transformer_embed_dim) 100 | 101 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 102 | 103 | def forward(self, audio, text): 104 | audio_embed, _ = self.audio_encoder(audio) 105 | caption_embed = self.caption_encoder(text) 106 | 107 | return caption_embed, audio_embed, self.logit_scale.exp() 108 | -------------------------------------------------------------------------------- /lumina_audio/models/encoders/CLAP/config.yml: 
-------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: 'bert-base-uncased' 3 | text_len: 100 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'Cnn14' 9 | out_emb: 2048 10 | sampling_rate: 44100 11 | duration: 5 12 | fmin: 50 13 | fmax: 14000 14 | n_fft: 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 | batch_size: 1024 26 | demo: False 27 | -------------------------------------------------------------------------------- /lumina_audio/models/encoders/CLAP/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import yaml 5 | 6 | 7 | def read_config_as_args(config_path, args=None, is_config_str=False): 8 | return_dict = {} 9 | 10 | if config_path is not None: 11 | if is_config_str: 12 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 13 | else: 14 | with open(config_path, "r") as f: 15 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 16 | 17 | if args != None: 18 | for k, v in yml_config.items(): 19 | if k in args.__dict__: 20 | args.__dict__[k] = v 21 | else: 22 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 23 | else: 24 | for k, v in yml_config.items(): 25 | return_dict[k] = v 26 | 27 | args = args if args != None else return_dict 28 | return argparse.Namespace(**args) 29 | -------------------------------------------------------------------------------- /lumina_audio/models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/encoders/__init__.py -------------------------------------------------------------------------------- /lumina_audio/models/vocoder/bigvgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/vocoder/bigvgan/__init__.py -------------------------------------------------------------------------------- /lumina_audio/models/vocoder/bigvgan/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .act import * 5 | from .filter import * 6 | from .resample import * 7 | -------------------------------------------------------------------------------- /lumina_audio/models/vocoder/bigvgan/alias_free_torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
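# Usage sketch (annotation, not part of the original act.py): Activation1d
# defined below wraps a pointwise nonlinearity with band-limited 2x
# up/downsampling, so the nonlinearity is applied at a higher sample rate and
# aliasing is reduced:
#
#   act = Activation1d(activation=torch.nn.SiLU())
#   y = act(torch.randn(4, 80, 256))   # input and output are [B, C, T]
#
# The SiLU choice and tensor shape are placeholders for illustration.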
3 | 4 | import torch.nn as nn 5 | 6 | from .resample import DownSample1d, UpSample1d 7 | 8 | 9 | class Activation1d(nn.Module): 10 | def __init__( 11 | self, activation, up_ratio: int = 2, down_ratio: int = 2, up_kernel_size: int = 12, down_kernel_size: int = 12 12 | ): 13 | super().__init__() 14 | self.up_ratio = up_ratio 15 | self.down_ratio = down_ratio 16 | self.act = activation 17 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 18 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 19 | 20 | # x: [B,C,T] 21 | def forward(self, x): 22 | x = self.upsample(x) 23 | x = self.act(x) 24 | x = self.downsample(x) 25 | 26 | return x 27 | -------------------------------------------------------------------------------- /lumina_audio/models/vocoder/bigvgan/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | if "sinc" in dir(torch): 11 | sinc = torch.sinc 12 | else: 13 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 14 | # https://adefossez.github.io/julius/julius/core.html 15 | # LICENSE is in incl_licenses directory. 16 | def sinc(x: torch.Tensor): 17 | """ 18 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 19 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 20 | """ 21 | return torch.where( 22 | x == 0, torch.tensor(1.0, device=x.device, dtype=x.dtype), torch.sin(math.pi * x) / math.pi / x 23 | ) 24 | 25 | 26 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 27 | # https://adefossez.github.io/julius/julius/lowpass.html 28 | # LICENSE is in incl_licenses directory. 29 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 30 | even = kernel_size % 2 == 0 31 | half_size = kernel_size // 2 32 | 33 | # For kaiser window 34 | delta_f = 4 * half_width 35 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 36 | if A > 50.0: 37 | beta = 0.1102 * (A - 8.7) 38 | elif A >= 21.0: 39 | beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) 40 | else: 41 | beta = 0.0 42 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 43 | 44 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 45 | if even: 46 | time = torch.arange(-half_size, half_size) + 0.5 47 | else: 48 | time = torch.arange(kernel_size) - half_size 49 | if cutoff == 0: 50 | filter_ = torch.zeros_like(time) 51 | else: 52 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 53 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 54 | # of the constant component in the input signal. 55 | filter_ /= filter_.sum() 56 | filter = filter_.view(1, 1, kernel_size) 57 | 58 | return filter 59 | 60 | 61 | class LowPassFilter1d(nn.Module): 62 | def __init__( 63 | self, 64 | cutoff=0.5, 65 | half_width=0.6, 66 | stride: int = 1, 67 | padding: bool = True, 68 | padding_mode: str = "replicate", 69 | kernel_size: int = 12, 70 | ): 71 | # kernel_size should be even number for stylegan3 setup, 72 | # in this implementation, odd number is also possible. 
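        # Annotation: cutoff and half_width are normalized frequencies
        # (fractions of the sampling rate, so Nyquist corresponds to 0.5);
        # kaiser_sinc_filter1d above sizes the Kaiser window from the
        # transition width delta_f = 4 * half_width.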
73 | super().__init__() 74 | if cutoff < -0.0: 75 | raise ValueError("Minimum cutoff must be larger than zero.") 76 | if cutoff > 0.5: 77 | raise ValueError("A cutoff above 0.5 does not make sense.") 78 | self.kernel_size = kernel_size 79 | self.even = kernel_size % 2 == 0 80 | self.pad_left = kernel_size // 2 - int(self.even) 81 | self.pad_right = kernel_size // 2 82 | self.stride = stride 83 | self.padding = padding 84 | self.padding_mode = padding_mode 85 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 86 | self.register_buffer("filter", filter) 87 | 88 | # input [B, C, T] 89 | def forward(self, x): 90 | _, C, _ = x.shape 91 | 92 | if self.padding: 93 | x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) 94 | out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 95 | 96 | return out 97 | -------------------------------------------------------------------------------- /lumina_audio/models/vocoder/bigvgan/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | from .filter import LowPassFilter1d, kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) 20 | self.register_buffer("filter", filter) 21 | 22 | # x: [B, C, T] 23 | def forward(self, x): 24 | _, C, _ = x.shape 25 | 26 | x = F.pad(x, (self.pad, self.pad), mode="replicate") 27 | x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 28 | x = x[..., self.pad_left : -self.pad_right] 29 | 30 | return x 31 | 32 | 33 | class DownSample1d(nn.Module): 34 | def __init__(self, ratio=2, kernel_size=None): 35 | super().__init__() 36 | self.ratio = ratio 37 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 38 | self.lowpass = LowPassFilter1d( 39 | cutoff=0.5 / ratio, half_width=0.6 / ratio, stride=ratio, kernel_size=self.kernel_size 40 | ) 41 | 42 | def forward(self, x): 43 | xx = self.lowpass(x) 44 | 45 | return xx 46 | -------------------------------------------------------------------------------- /lumina_audio/n2s_openai.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from openai import OpenAI 5 | import pandas as pd 6 | import requests 7 | 8 | openai_key = "your openai api key here" 9 | base_url = "" 10 | 11 | 12 | def get_struct(caption): 13 | if base_url != "": 14 | client = OpenAI(api_key=openai_key, base_url=base_url) 15 | else: 16 | client = OpenAI(api_key=openai_key) 17 | 18 | completion = client.chat.completions.create( 19 | model="gpt-3.5-turbo", 20 | messages=[ 21 | { 22 | "role": "user", 23 | "content": f"I want to know what sound might be in the given scene and you need to give me the results in the 
following format:\ 24 | Question: A bird sings on the river in the morning, a cow passes by and scares away the bird.\ 25 | Answer: @@@.\ 26 | Question: cellphone ringing a variety of tones followed by a loud explosion and fire crackling as a truck engine runs idle\ 27 | Answer: @@@\ 28 | Question: Train passing followed by short honks three times \ 29 | Answer: @\ 30 | All indicates the sound exists in the whole scene \ 31 | Start, mid, end indicates the time period the sound appear.\ 32 | Question: {caption} \ 33 | Answer:", 34 | }, 35 | ], 36 | temperature=0.0, 37 | ) 38 | 39 | return completion.choices[0].message.content 40 | 41 | 42 | def parse_args(): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--tsv_path", type=str) 45 | return parser.parse_args() 46 | 47 | 48 | if __name__ == "__main__": 49 | args = parse_args() 50 | tsv_path = args.tsv_path 51 | ori_df = pd.read_csv(tsv_path, sep="\t") 52 | index = 0 53 | end = len(ori_df) 54 | name = os.path.basename(tsv_path)[:-4] 55 | f = open(f"{name}.txt", "w") 56 | newcap_list = [] 57 | while index < end - 1: 58 | try: 59 | df = ori_df.iloc[index:end] 60 | for t in df.itertuples(): 61 | index = int(t[0]) 62 | ori_caption = getattr(t, "caption") 63 | strcut_cap = get_struct(ori_caption) 64 | if "sorry" in strcut_cap.lower(): 65 | strcut_cap = f"<{ori_caption.lower()}, all>" 66 | newcap_list.append(strcut_cap) 67 | f.write(f"{index}\t{strcut_cap}\n") 68 | f.flush() 69 | except: 70 | print("error") 71 | f.flush() 72 | f.close() 73 | with open(f"{name}.txt") as f: 74 | lines = f.readlines() 75 | id2cap = {} 76 | for line in lines: 77 | index, caption = line.strip().split("\t") 78 | id2cap[int(index)] = caption 79 | 80 | df = pd.read_csv(f"{name}.tsv", sep="\t") 81 | df["struct_cap"] = df.index.map(id2cap) 82 | df.to_csv(f"{name}_struct.tsv", sep="\t", index=False) 83 | -------------------------------------------------------------------------------- /lumina_audio/requirements.txt: -------------------------------------------------------------------------------- 1 | soundfile 2 | omegaconf 3 | torchdyn 4 | pytorch_lightning 5 | pytorch_memlab 6 | einops 7 | ninja 8 | torchlibrosa 9 | protobuf 10 | sentencepiece 11 | transformers 12 | gradio 13 | -------------------------------------------------------------------------------- /lumina_audio/run_audio.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | export HF_ENDPOINT=https://hf-mirror.com 4 | 5 | python -u demo_audio.py \ 6 | --ckpt "/path/to/ckpt/audio_generation" \ 7 | --vocoder_ckpt "/path/to/ckpt/bigvnat" \ 8 | --config_path "configs/lumina-text2audio.yaml" \ 9 | --sample_rate 16000 10 | -------------------------------------------------------------------------------- /lumina_audio/style.css: -------------------------------------------------------------------------------- 1 | .gallery { 2 | text-align: left; 3 | } 4 | -------------------------------------------------------------------------------- /lumina_music/configs/lumina-text2music.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 3.0e-06 3 | target: models.diffusion.ddpm_audio.CFM 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | mel_dim: 20 13 | mel_length: 256 14 | channels: 0 15 | cond_stage_trainable: True 16 | conditioning_key: crossattn 17 | 
monitor: val/loss_simple_ema 18 | scale_by_std: true 19 | use_ema: false 20 | scheduler_config: 21 | target: models.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: 24 | - 10000 25 | cycle_lengths: 26 | - 10000000000000 27 | f_start: 28 | - 1.0e-06 29 | f_max: 30 | - 1.0 31 | f_min: 32 | - 1.0 33 | unet_config: 34 | target: models.diffusion.flag_large_dit.FlagDiTv2 35 | params: 36 | in_channels: 20 37 | context_dim: 1024 38 | hidden_size: 768 39 | num_heads: 32 40 | depth: 16 41 | max_len: 1000 42 | 43 | first_stage_config: 44 | target: models.autoencoder1d.AutoencoderKL 45 | params: 46 | embed_dim: 20 47 | monitor: val/rec_loss 48 | ckpt_path: /path/to/ckpt/maa2/maa2.ckpt 49 | ddconfig: 50 | double_z: true 51 | in_channels: 80 52 | out_ch: 80 53 | z_channels: 20 54 | kernel_size: 5 55 | ch: 384 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | num_res_blocks: 2 61 | attn_layers: 62 | - 3 63 | down_layers: 64 | - 0 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | cond_stage_config: 69 | target: models.encoders.modules.FrozenFLANEmbedder 70 | 71 | test_dataset: 72 | target: data.joinaudiodataset_struct_sample_anylen.TestManifest 73 | params: 74 | manifest: ./musiccaps_test_16000_struct.tsv 75 | spec_crop_len: 624 76 | -------------------------------------------------------------------------------- /lumina_music/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/__init__.py -------------------------------------------------------------------------------- /lumina_music/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/diffusion/__init__.py -------------------------------------------------------------------------------- /lumina_music/models/diffusion/component.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 
51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /lumina_music/models/diffusion/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/diffusion/distributions/__init__.py -------------------------------------------------------------------------------- /lumina_music/models/diffusion/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.0]) 42 | else: 43 | sum_dim = list(range(1, len(self.mean.shape))) 44 | if other is None: 45 | 46 | return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=sum_dim) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var 51 | - 1.0 52 | - self.logvar 53 | + other.logvar, 54 | dim=sum_dim, 55 | ) 56 | 57 | def nll(self, sample, dims=[1, 2, 3]): 58 | if self.deterministic: 59 | return torch.Tensor([0.0]) 60 | logtwopi = np.log(2.0 * np.pi) 61 | return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims) 62 | 63 | def mode(self): 64 | return self.mean 65 | 66 | 67 | def normal_kl(mean1, logvar1, mean2, logvar2): 68 | """ 69 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 70 | Compute the KL divergence between two gaussians. 71 | Shapes are automatically broadcasted, so batches can be compared to 72 | scalars, among other use cases. 73 | """ 74 | tensor = None 75 | for obj in (mean1, logvar1, mean2, logvar2): 76 | if isinstance(obj, torch.Tensor): 77 | tensor = obj 78 | break 79 | assert tensor is not None, "at least one argument must be a Tensor" 80 | 81 | # Force variances to be Tensors. Broadcasting helps convert scalars to 82 | # Tensors, but it does not work for torch.exp(). 
83 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)] 84 | 85 | return 0.5 * ( 86 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 87 | ) 88 | -------------------------------------------------------------------------------- /lumina_music/models/diffusion/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int) 15 | ) 16 | 17 | for name, p in model.named_parameters(): 18 | if p.requires_grad: 19 | # remove as '.'-character is not allowed in buffers 20 | s_name = name.replace(".", "") 21 | self.m_name2s_name.update({name: s_name}) 22 | self.register_buffer(s_name, p.clone().detach().data) 23 | 24 | self.collected_params = [] 25 | 26 | def forward(self, model): 27 | decay = self.decay 28 | 29 | if self.num_updates >= 0: 30 | self.num_updates += 1 31 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 32 | 33 | one_minus_decay = 1.0 - decay 34 | 35 | with torch.no_grad(): 36 | m_param = dict(model.named_parameters()) 37 | shadow_params = dict(self.named_buffers()) 38 | 39 | for key in m_param: 40 | if m_param[key].requires_grad: 41 | sname = self.m_name2s_name[key] 42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 44 | else: 45 | assert not key in self.m_name2s_name 46 | 47 | def copy_to(self, model): 48 | m_param = dict(model.named_parameters()) 49 | shadow_params = dict(self.named_buffers()) 50 | for key in m_param: 51 | if m_param[key].requires_grad: 52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def store(self, parameters): 57 | """ 58 | Save the current parameters for restoring later. 59 | Args: 60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 61 | temporarily stored. 62 | """ 63 | self.collected_params = [param.clone() for param in parameters] 64 | 65 | def restore(self, parameters): 66 | """ 67 | Restore the parameters stored with the `store` method. 68 | Useful to validate the model with EMA parameters without affecting the 69 | original optimization process. Store the parameters before the 70 | `copy_to` method. After validation (or model saving), use this to 71 | restore the former parameters. 72 | Args: 73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 74 | updated with the stored parameters. 75 | """ 76 | for c_param, param in zip(self.collected_params, parameters): 77 | param.data.copy_(c_param.data) 78 | -------------------------------------------------------------------------------- /lumina_music/models/encoders/CLAP/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import audio, clap, utils 2 | -------------------------------------------------------------------------------- /lumina_music/models/encoders/CLAP/config.yml: -------------------------------------------------------------------------------- 1 | # TEXT ENCODER CONFIG 2 | text_model: 'bert-base-uncased' 3 | text_len: 100 4 | transformer_embed_dim: 768 5 | freeze_text_encoder_weights: True 6 | 7 | # AUDIO ENCODER CONFIG 8 | audioenc_name: 'Cnn14' 9 | out_emb: 2048 10 | sampling_rate: 44100 11 | duration: 5 12 | fmin: 50 13 | fmax: 14000 14 | n_fft: 1028 15 | hop_size: 320 16 | mel_bins: 64 17 | window_size: 1024 18 | 19 | # PROJECTION SPACE CONFIG 20 | d_proj: 1024 21 | temperature: 0.003 22 | 23 | # TRAINING AND EVALUATION CONFIG 24 | num_classes: 527 25 | batch_size: 1024 26 | demo: False 27 | -------------------------------------------------------------------------------- /lumina_music/models/encoders/CLAP/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import yaml 5 | 6 | 7 | def read_config_as_args(config_path, args=None, is_config_str=False): 8 | return_dict = {} 9 | 10 | if config_path is not None: 11 | if is_config_str: 12 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader) 13 | else: 14 | with open(config_path, "r") as f: 15 | yml_config = yaml.load(f, Loader=yaml.FullLoader) 16 | 17 | if args != None: 18 | for k, v in yml_config.items(): 19 | if k in args.__dict__: 20 | args.__dict__[k] = v 21 | else: 22 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k)) 23 | else: 24 | for k, v in yml_config.items(): 25 | return_dict[k] = v 26 | 27 | args = args if args != None else return_dict 28 | return argparse.Namespace(**args) 29 | -------------------------------------------------------------------------------- /lumina_music/models/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/encoders/__init__.py -------------------------------------------------------------------------------- /lumina_music/models/vocoder/bigvgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/vocoder/bigvgan/__init__.py -------------------------------------------------------------------------------- /lumina_music/models/vocoder/bigvgan/alias_free_torch/__init__.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | from .act import * 5 | from .filter import * 6 | from .resample import * 7 | -------------------------------------------------------------------------------- /lumina_music/models/vocoder/bigvgan/alias_free_torch/act.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 
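# Annotation (placed here because the flattened listing above contains the CLAP
# config.yml and utils.py): a minimal sketch of loading that config, assuming
# the relative path below:
#
#   cfg = read_config_as_args("models/encoders/CLAP/config.yml")
#   cfg.text_model      # 'bert-base-uncased'
#   cfg.sampling_rate   # 44100
#
# With args=None, read_config_as_args returns an argparse.Namespace built from
# the YAML keys.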
3 | 4 | import torch.nn as nn 5 | 6 | from .resample import DownSample1d, UpSample1d 7 | 8 | 9 | class Activation1d(nn.Module): 10 | def __init__( 11 | self, activation, up_ratio: int = 2, down_ratio: int = 2, up_kernel_size: int = 12, down_kernel_size: int = 12 12 | ): 13 | super().__init__() 14 | self.up_ratio = up_ratio 15 | self.down_ratio = down_ratio 16 | self.act = activation 17 | self.upsample = UpSample1d(up_ratio, up_kernel_size) 18 | self.downsample = DownSample1d(down_ratio, down_kernel_size) 19 | 20 | # x: [B,C,T] 21 | def forward(self, x): 22 | x = self.upsample(x) 23 | x = self.act(x) 24 | x = self.downsample(x) 25 | 26 | return x 27 | -------------------------------------------------------------------------------- /lumina_music/models/vocoder/bigvgan/alias_free_torch/filter.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | if "sinc" in dir(torch): 11 | sinc = torch.sinc 12 | else: 13 | # This code is adopted from adefossez's julius.core.sinc under the MIT License 14 | # https://adefossez.github.io/julius/julius/core.html 15 | # LICENSE is in incl_licenses directory. 16 | def sinc(x: torch.Tensor): 17 | """ 18 | Implementation of sinc, i.e. sin(pi * x) / (pi * x) 19 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`! 20 | """ 21 | return torch.where( 22 | x == 0, torch.tensor(1.0, device=x.device, dtype=x.dtype), torch.sin(math.pi * x) / math.pi / x 23 | ) 24 | 25 | 26 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License 27 | # https://adefossez.github.io/julius/julius/lowpass.html 28 | # LICENSE is in incl_licenses directory. 29 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] 30 | even = kernel_size % 2 == 0 31 | half_size = kernel_size // 2 32 | 33 | # For kaiser window 34 | delta_f = 4 * half_width 35 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 36 | if A > 50.0: 37 | beta = 0.1102 * (A - 8.7) 38 | elif A >= 21.0: 39 | beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) 40 | else: 41 | beta = 0.0 42 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) 43 | 44 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio 45 | if even: 46 | time = torch.arange(-half_size, half_size) + 0.5 47 | else: 48 | time = torch.arange(kernel_size) - half_size 49 | if cutoff == 0: 50 | filter_ = torch.zeros_like(time) 51 | else: 52 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) 53 | # Normalize filter to have sum = 1, otherwise we will have a small leakage 54 | # of the constant component in the input signal. 55 | filter_ /= filter_.sum() 56 | filter = filter_.view(1, 1, kernel_size) 57 | 58 | return filter 59 | 60 | 61 | class LowPassFilter1d(nn.Module): 62 | def __init__( 63 | self, 64 | cutoff=0.5, 65 | half_width=0.6, 66 | stride: int = 1, 67 | padding: bool = True, 68 | padding_mode: str = "replicate", 69 | kernel_size: int = 12, 70 | ): 71 | # kernel_size should be even number for stylegan3 setup, 72 | # in this implementation, odd number is also possible. 
73 | super().__init__() 74 | if cutoff < -0.0: 75 | raise ValueError("Minimum cutoff must be larger than zero.") 76 | if cutoff > 0.5: 77 | raise ValueError("A cutoff above 0.5 does not make sense.") 78 | self.kernel_size = kernel_size 79 | self.even = kernel_size % 2 == 0 80 | self.pad_left = kernel_size // 2 - int(self.even) 81 | self.pad_right = kernel_size // 2 82 | self.stride = stride 83 | self.padding = padding 84 | self.padding_mode = padding_mode 85 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) 86 | self.register_buffer("filter", filter) 87 | 88 | # input [B, C, T] 89 | def forward(self, x): 90 | _, C, _ = x.shape 91 | 92 | if self.padding: 93 | x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) 94 | out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 95 | 96 | return out 97 | -------------------------------------------------------------------------------- /lumina_music/models/vocoder/bigvgan/alias_free_torch/resample.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 2 | # LICENSE is in incl_licenses directory. 3 | 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | from .filter import LowPassFilter1d, kaiser_sinc_filter1d 8 | 9 | 10 | class UpSample1d(nn.Module): 11 | def __init__(self, ratio=2, kernel_size=None): 12 | super().__init__() 13 | self.ratio = ratio 14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 15 | self.stride = ratio 16 | self.pad = self.kernel_size // ratio - 1 17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size) 20 | self.register_buffer("filter", filter) 21 | 22 | # x: [B, C, T] 23 | def forward(self, x): 24 | _, C, _ = x.shape 25 | 26 | x = F.pad(x, (self.pad, self.pad), mode="replicate") 27 | x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) 28 | x = x[..., self.pad_left : -self.pad_right] 29 | 30 | return x 31 | 32 | 33 | class DownSample1d(nn.Module): 34 | def __init__(self, ratio=2, kernel_size=None): 35 | super().__init__() 36 | self.ratio = ratio 37 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size 38 | self.lowpass = LowPassFilter1d( 39 | cutoff=0.5 / ratio, half_width=0.6 / ratio, stride=ratio, kernel_size=self.kernel_size 40 | ) 41 | 42 | def forward(self, x): 43 | xx = self.lowpass(x) 44 | 45 | return xx 46 | -------------------------------------------------------------------------------- /lumina_music/requirements.txt: -------------------------------------------------------------------------------- 1 | soundfile 2 | omegaconf 3 | torchdyn 4 | pytorch_lightning 5 | pytorch_memlab 6 | einops 7 | ninja 8 | torchlibrosa 9 | protobuf 10 | sentencepiece 11 | transformers 12 | gradio 13 | -------------------------------------------------------------------------------- /lumina_music/run_music.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | export HF_ENDPOINT=https://hf-mirror.com 4 | 5 | python -u demo_music.py \ 6 | --ckpt "/path/to/ckpt/music_generation" \ 7 | --vocoder_ckpt "/path/to/ckpt/bigvnat" \ 8 | --config_path 
"configs/lumina-text2music.yaml" \ 9 | --sample_rate 16000 10 | -------------------------------------------------------------------------------- /lumina_next_compositional_generation/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2 2 | -------------------------------------------------------------------------------- /lumina_next_compositional_generation/models/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 
51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /lumina_next_compositional_generation/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import ModelType, PathType, Sampler, SNRType, Transport, WeightType 2 | 3 | 4 | def create_transport( 5 | path_type="Linear", 6 | prediction="velocity", 7 | loss_weight=None, 8 | train_eps=None, 9 | sample_eps=None, 10 | snr_type="uniform", 11 | ): 12 | """function for creating Transport object 13 | **Note**: model prediction defaults to velocity 14 | Args: 15 | - path_type: type of path to use; default to linear 16 | - learn_score: set model prediction to score 17 | - learn_noise: set model prediction to noise 18 | - velocity_weighted: weight loss by velocity weight 19 | - likelihood_weighted: weight loss by likelihood weight 20 | - train_eps: small epsilon for avoiding instability during training 21 | - sample_eps: small epsilon for avoiding instability during sampling 22 | """ 23 | 24 | if prediction == "noise": 25 | model_type = ModelType.NOISE 26 | elif prediction == "score": 27 | model_type = ModelType.SCORE 28 | else: 29 | model_type = ModelType.VELOCITY 30 | 31 | if loss_weight == "velocity": 32 | loss_type = WeightType.VELOCITY 33 | elif loss_weight == "likelihood": 34 | loss_type = WeightType.LIKELIHOOD 35 | else: 36 | loss_type = WeightType.NONE 37 | 38 | if snr_type == "lognorm": 39 | snr_type = SNRType.LOGNORM 40 | elif snr_type == "uniform": 41 | snr_type = SNRType.UNIFORM 42 | else: 43 | raise ValueError(f"Invalid snr type {snr_type}") 44 | 45 | path_choice = { 46 | "Linear": PathType.LINEAR, 47 | "GVP": PathType.GVP, 48 | "VP": PathType.VP, 49 | } 50 | 51 | path_type = path_choice[path_type] 52 | 53 | if path_type in [PathType.VP]: 54 | train_eps = 1e-5 if train_eps is None else train_eps 55 | sample_eps = 1e-3 if train_eps is None else sample_eps 56 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY: 57 | train_eps = 1e-3 if train_eps is None else train_eps 58 | sample_eps = 1e-3 if train_eps is None else sample_eps 59 | else: # velocity & [GVP, LINEAR] is stable everywhere 60 | train_eps = 0 61 | sample_eps = 0 62 | 63 | # create flow state 64 | state = Transport( 65 | model_type=model_type, 66 | path_type=path_type, 67 | loss_type=loss_type, 68 | train_eps=train_eps, 69 | sample_eps=sample_eps, 70 | snr_type=snr_type, 71 | ) 72 | 73 | return state 74 | -------------------------------------------------------------------------------- /lumina_next_compositional_generation/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | class EasyDict: 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | 13 | def mean_flat(x): 14 | """ 15 | Take the mean over all non-batch dimensions. 
16 | """ 17 | return th.mean(x, dim=list(range(1, len(x.size())))) 18 | 19 | 20 | def log_state(state): 21 | result = [] 22 | 23 | sorted_state = dict(sorted(state.items())) 24 | for key, value in sorted_state.items(): 25 | # Check if the value is an instance of a class 26 | if " Union[str, BytesIO]: 13 | if "s3://" in path: 14 | init_ceph_client_if_needed() 15 | file_bytes = BytesIO(client.get(path)) 16 | return file_bytes 17 | else: 18 | return path 19 | 20 | 21 | def init_ceph_client_if_needed(): 22 | global client 23 | if client is None: 24 | logger.info(f"initializing ceph client ...") 25 | st = time.time() 26 | from petrel_client.client import Client # noqa 27 | 28 | client = Client("../petreloss.conf") 29 | ed = time.time() 30 | logger.info(f"initialize client cost {ed - st:.2f} s") 31 | 32 | 33 | client = None 34 | -------------------------------------------------------------------------------- /lumina_next_t2i/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import fairscale.nn.model_parallel.initialize as fs_init 4 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | 9 | 10 | def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]: 11 | ret_dict = {} 12 | for module_name, module in model.named_modules(): 13 | 14 | def param_fqn(param_name): 15 | return param_name if module_name == "" else module_name + "." + param_name 16 | 17 | if isinstance(module, ColumnParallelLinear): 18 | ret_dict[param_fqn("weight")] = 0 19 | if module.bias is not None: 20 | ret_dict[param_fqn("bias")] = 0 21 | elif isinstance(module, RowParallelLinear): 22 | ret_dict[param_fqn("weight")] = 1 23 | if module.bias is not None: 24 | ret_dict[param_fqn("bias")] = -1 25 | elif isinstance(module, ParallelEmbedding): 26 | ret_dict[param_fqn("weight")] = 1 27 | else: 28 | for param_name, param in module.named_parameters(recurse=False): 29 | ret_dict[param_fqn(param_name)] = -1 30 | return ret_dict 31 | 32 | 33 | def calculate_l2_grad_norm( 34 | model: nn.Module, 35 | model_parallel_dim_dict: Dict[str, int], 36 | ) -> float: 37 | mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 38 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 39 | 40 | for name, param in model.named_parameters(): 41 | if param.grad is None: 42 | continue 43 | name = ".".join(x for x in name.split(".") if not x.startswith("_")) 44 | assert name in model_parallel_dim_dict 45 | if model_parallel_dim_dict[name] < 0: 46 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2 47 | else: 48 | mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2 49 | 50 | dist.all_reduce(mp_norm_sq) 51 | dist.all_reduce(non_mp_norm_sq) 52 | non_mp_norm_sq /= fs_init.get_model_parallel_world_size() 53 | 54 | return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5 55 | 56 | 57 | def scale_grad(model: nn.Module, factor: float) -> None: 58 | for param in model.parameters(): 59 | if param.grad is not None: 60 | param.grad.mul_(factor) 61 | -------------------------------------------------------------------------------- /lumina_next_t2i/imgproc.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from PIL import Image 4 | import numpy as np 5 | 6 | 7 | def center_crop_arr(pil_image, image_size): 8 | """ 9 | Center cropping implementation 
from ADM. 10 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 11 | """ 12 | while min(*pil_image.size) >= 2 * image_size: 13 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 14 | 15 | scale = image_size / min(*pil_image.size) 16 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 17 | 18 | arr = np.array(pil_image) 19 | crop_y = (arr.shape[0] - image_size) // 2 20 | crop_x = (arr.shape[1] - image_size) // 2 21 | return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) 22 | 23 | 24 | def center_crop(pil_image, crop_size): 25 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]: 26 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 27 | 28 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1]) 29 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 30 | 31 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0]) 32 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1]) 33 | crop_right = crop_left + crop_size[0] 34 | crop_lower = crop_upper + crop_size[1] 35 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower)) 36 | 37 | 38 | def var_center_crop(pil_image, crop_size_list, random_top_k=4): 39 | w, h = pil_image.size 40 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list] 41 | crop_size = random.choice( 42 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k] 43 | )[1] 44 | return center_crop(pil_image, crop_size) 45 | 46 | 47 | def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0): 48 | assert max_ratio >= 1.0 49 | crop_size_list = [] 50 | wp, hp = num_patches, 1 51 | while wp > 0: 52 | if max(wp, hp) / min(wp, hp) <= max_ratio: 53 | crop_size_list.append((wp * patch_size, hp * patch_size)) 54 | if (hp + 1) * wp <= num_patches: 55 | hp += 1 56 | else: 57 | wp -= 1 58 | return crop_size_list 59 | -------------------------------------------------------------------------------- /lumina_next_t2i/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2 2 | -------------------------------------------------------------------------------- /lumina_next_t2i/models/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 
23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /lumina_next_t2i/parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | from time import sleep 6 | 7 | import fairscale.nn.model_parallel.initialize as fs_init 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def _setup_dist_env_from_slurm(args): 13 | while not os.environ.get("MASTER_ADDR", ""): 14 | os.environ["MASTER_ADDR"] = ( 15 | subprocess.check_output( 16 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"], 17 | shell=True, 18 | ) 19 | .decode() 20 | .strip() 21 | ) 22 | sleep(1) 23 | os.environ["MASTER_PORT"] = str(args.master_port) 24 | os.environ["RANK"] = os.environ["SLURM_PROCID"] 25 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"] 26 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 27 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"] 28 | 29 | 30 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None 31 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1 32 | 33 | 34 | def get_local_rank() -> int: 35 | return _LOCAL_RANK 36 | 37 | 38 | def get_local_world_size() -> int: 39 | return _LOCAL_WORLD_SIZE 40 | 41 | 42 | def distributed_init(args): 43 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]): 44 | _setup_dist_env_from_slurm(args) 45 | 46 | dist.init_process_group("nccl") 47 | fs_init.initialize_model_parallel(args.model_parallel_size) 48 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) 49 | 50 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 51 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"]) 52 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"]) 53 | 54 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP 55 | local_ranks, local_world_sizes = [ 56 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1) 57 | ] 58 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda")) 59 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda")) 60 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist() 61 | node_ranks = [[0]] 62 | for i in range(1, dist.get_world_size()): 63 | if len(node_ranks[-1]) == local_world_sizes[i - 1]: 64 | node_ranks.append([]) 65 | else: 66 | assert local_world_sizes[i] == local_world_sizes[i - 1] 67 | node_ranks[-1].append(i) 68 | for ranks in node_ranks: 69 | group = dist.new_group(ranks) 70 | if dist.get_rank() in ranks: 71 | assert _INTRA_NODE_PROCESS_GROUP is None 72 | _INTRA_NODE_PROCESS_GROUP = group 73 | assert _INTRA_NODE_PROCESS_GROUP is not None 74 | 75 | if 
min(local_world_sizes) == max(local_world_sizes): 76 | for i in range(get_local_world_size()): 77 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size()))) 78 | if i == get_local_rank(): 79 | assert _INTER_NODE_PROCESS_GROUP is None 80 | _INTER_NODE_PROCESS_GROUP = group 81 | assert _INTER_NODE_PROCESS_GROUP is not None 82 | 83 | 84 | def get_intra_node_process_group(): 85 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized." 86 | return _INTRA_NODE_PROCESS_GROUP 87 | 88 | 89 | def get_inter_node_process_group(): 90 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized." 91 | return _INTER_NODE_PROCESS_GROUP 92 | -------------------------------------------------------------------------------- /lumina_next_t2i/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import ModelType, PathType, Sampler, Transport, WeightType 2 | 3 | 4 | def create_transport( 5 | path_type="Linear", 6 | prediction="velocity", 7 | loss_weight=None, 8 | train_eps=None, 9 | sample_eps=None, 10 | snr_type="uniform", 11 | ): 12 | """function for creating Transport object 13 | **Note**: model prediction defaults to velocity 14 | Args: 15 | - path_type: type of path to use; default to linear 16 | - learn_score: set model prediction to score 17 | - learn_noise: set model prediction to noise 18 | - velocity_weighted: weight loss by velocity weight 19 | - likelihood_weighted: weight loss by likelihood weight 20 | - train_eps: small epsilon for avoiding instability during training 21 | - sample_eps: small epsilon for avoiding instability during sampling 22 | """ 23 | 24 | if prediction == "noise": 25 | model_type = ModelType.NOISE 26 | elif prediction == "score": 27 | model_type = ModelType.SCORE 28 | else: 29 | model_type = ModelType.VELOCITY 30 | 31 | if loss_weight == "velocity": 32 | loss_type = WeightType.VELOCITY 33 | elif loss_weight == "likelihood": 34 | loss_type = WeightType.LIKELIHOOD 35 | else: 36 | loss_type = WeightType.NONE 37 | 38 | path_choice = { 39 | "Linear": PathType.LINEAR, 40 | "GVP": PathType.GVP, 41 | "VP": PathType.VP, 42 | } 43 | 44 | path_type = path_choice[path_type] 45 | 46 | if path_type in [PathType.VP]: 47 | train_eps = 1e-5 if train_eps is None else train_eps 48 | sample_eps = 1e-3 if train_eps is None else sample_eps 49 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY: 50 | train_eps = 1e-3 if train_eps is None else train_eps 51 | sample_eps = 1e-3 if train_eps is None else sample_eps 52 | else: # velocity & [GVP, LINEAR] is stable everywhere 53 | train_eps = 0 54 | sample_eps = 0 55 | 56 | # create flow state 57 | state = Transport( 58 | model_type=model_type, 59 | path_type=path_type, 60 | loss_type=loss_type, 61 | train_eps=train_eps, 62 | sample_eps=sample_eps, 63 | snr_type=snr_type, 64 | ) 65 | 66 | return state 67 | -------------------------------------------------------------------------------- /lumina_next_t2i/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | class EasyDict: 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | 13 | def mean_flat(x): 14 | """ 15 | Take the mean over all non-batch dimensions. 
16 | """ 17 | return th.mean(x, dim=list(range(1, len(x.size())))) 18 | 19 | 20 | def log_state(state): 21 | result = [] 22 | 23 | sorted_state = dict(sorted(state.items())) 24 | for key, value in sorted_state.items(): 25 | # Check if the value is an instance of a class 26 | if " Union[str, BytesIO]: 13 | if "s3://" in path: 14 | init_ceph_client_if_needed() 15 | file_bytes = BytesIO(client.get(path)) 16 | return file_bytes 17 | else: 18 | return path 19 | 20 | 21 | def init_ceph_client_if_needed(): 22 | global client 23 | if client is None: 24 | logger.info(f"initializing ceph client ...") 25 | st = time.time() 26 | from petrel_client.client import Client # noqa 27 | 28 | client = Client("../petreloss.conf") 29 | ed = time.time() 30 | logger.info(f"initialize client cost {ed - st:.2f} s") 31 | 32 | 33 | client = None 34 | -------------------------------------------------------------------------------- /lumina_next_t2i_mini/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import torch.nn as nn 6 | 7 | 8 | def calculate_l2_grad_norm(model: nn.Module) -> float: 9 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 10 | 11 | for name, param in model.named_parameters(): 12 | if param.grad is None: 13 | continue 14 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2 15 | 16 | dist.all_reduce(non_mp_norm_sq) 17 | 18 | return non_mp_norm_sq.item() ** 0.5 19 | 20 | 21 | def scale_grad(model: nn.Module, factor: float) -> None: 22 | for param in model.parameters(): 23 | if param.grad is not None: 24 | param.grad.mul_(factor) 25 | -------------------------------------------------------------------------------- /lumina_next_t2i_mini/imgproc.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from PIL import Image 4 | import numpy as np 5 | 6 | 7 | def center_crop_arr(pil_image, image_size): 8 | """ 9 | Center cropping implementation from ADM. 
10 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 11 | """ 12 | while min(*pil_image.size) >= 2 * image_size: 13 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 14 | 15 | scale = image_size / min(*pil_image.size) 16 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 17 | 18 | arr = np.array(pil_image) 19 | crop_y = (arr.shape[0] - image_size) // 2 20 | crop_x = (arr.shape[1] - image_size) // 2 21 | return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) 22 | 23 | 24 | def center_crop(pil_image, crop_size): 25 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]: 26 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 27 | 28 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1]) 29 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 30 | 31 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0]) 32 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1]) 33 | crop_right = crop_left + crop_size[0] 34 | crop_lower = crop_upper + crop_size[1] 35 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower)) 36 | 37 | 38 | def var_center_crop(pil_image, crop_size_list, random_top_k=4): 39 | w, h = pil_image.size 40 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list] 41 | crop_size = random.choice( 42 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k] 43 | )[1] 44 | return center_crop(pil_image, crop_size) 45 | 46 | 47 | def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0): 48 | assert max_ratio >= 1.0 49 | crop_size_list = [] 50 | wp, hp = num_patches, 1 51 | while wp > 0: 52 | if max(wp, hp) / min(wp, hp) <= max_ratio: 53 | crop_size_list.append((wp * patch_size, hp * patch_size)) 54 | if (hp + 1) * wp <= num_patches: 55 | hp += 1 56 | else: 57 | wp -= 1 58 | return crop_size_list 59 | -------------------------------------------------------------------------------- /lumina_next_t2i_mini/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .nextdit import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2 2 | -------------------------------------------------------------------------------- /lumina_next_t2i_mini/models/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 
23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /lumina_next_t2i_mini/parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | from time import sleep 6 | 7 | import torch 8 | import torch.distributed as dist 9 | 10 | 11 | def _setup_dist_env_from_slurm(args): 12 | while not os.environ.get("MASTER_ADDR", ""): 13 | os.environ["MASTER_ADDR"] = ( 14 | subprocess.check_output( 15 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"], 16 | shell=True, 17 | ) 18 | .decode() 19 | .strip() 20 | ) 21 | sleep(1) 22 | os.environ["MASTER_PORT"] = str(args.master_port) 23 | os.environ["RANK"] = os.environ["SLURM_PROCID"] 24 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"] 25 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 26 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"] 27 | 28 | 29 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None 30 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1 31 | 32 | 33 | def get_local_rank() -> int: 34 | return _LOCAL_RANK 35 | 36 | 37 | def get_local_world_size() -> int: 38 | return _LOCAL_WORLD_SIZE 39 | 40 | 41 | def distributed_init(args): 42 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]): 43 | _setup_dist_env_from_slurm(args) 44 | 45 | dist.init_process_group("nccl") 46 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) 47 | 48 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 49 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"]) 50 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"]) 51 | 52 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP 53 | local_ranks, local_world_sizes = [ 54 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1) 55 | ] 56 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda")) 57 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda")) 58 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist() 59 | node_ranks = [[0]] 60 | for i in range(1, dist.get_world_size()): 61 | if len(node_ranks[-1]) == local_world_sizes[i - 1]: 62 | node_ranks.append([]) 63 | else: 64 | assert local_world_sizes[i] == local_world_sizes[i - 1] 65 | node_ranks[-1].append(i) 66 | for ranks in node_ranks: 67 | group = dist.new_group(ranks) 68 | if dist.get_rank() in ranks: 69 | assert _INTRA_NODE_PROCESS_GROUP is None 70 | _INTRA_NODE_PROCESS_GROUP = group 71 | assert _INTRA_NODE_PROCESS_GROUP is not None 72 | 73 | if min(local_world_sizes) == max(local_world_sizes): 74 | for i in range(get_local_world_size()): 75 | group = 
dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size()))) 76 | if i == get_local_rank(): 77 | assert _INTER_NODE_PROCESS_GROUP is None 78 | _INTER_NODE_PROCESS_GROUP = group 79 | assert _INTER_NODE_PROCESS_GROUP is not None 80 | 81 | 82 | def get_intra_node_process_group(): 83 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized." 84 | return _INTRA_NODE_PROCESS_GROUP 85 | 86 | 87 | def get_inter_node_process_group(): 88 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized." 89 | return _INTER_NODE_PROCESS_GROUP 90 | -------------------------------------------------------------------------------- /lumina_next_t2i_mini/scripts/sample.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Lumina-Next supports any resolution (up to 2K) 5 | # res="1024:024x1024 1536:1536x1536 1664:1664x1664 1792:1792x1792 2048:2048x2048" 6 | res=1024:1024x1024 7 | t=4 8 | cfg=4.0 9 | seed=25 10 | steps=20 11 | solver=midpoint 12 | model_dir=your/model/dir/here 13 | cap_dir=your/caption/dir/here 14 | out_dir=your/output/dir/here 15 | python -u sample.py --ckpt ${model_dir} \ 16 | --image_save_path ${out_dir} \ 17 | --solver ${solver} --num_sampling_steps ${steps} \ 18 | --caption_path ${cap_dir} \ 19 | --seed ${seed} \ 20 | --resolution ${res} \ 21 | --time_shifting_factor ${t} \ 22 | --cfg_scale ${cfg} \ 23 | --batch_size 1 \ 24 | --use_flash_attn True # You can set this to False if you want to disable the flash attention 25 | -------------------------------------------------------------------------------- /lumina_next_t2i_mini/scripts/sample_img2img.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Lumina-Next supports any resolution (up to 2K) 5 | # res="1024:024x1024 1536:1536x1536 1664:1664x1664 1792:1792x1792 2048:2048x2048" 6 | res=1024:1024x1024 7 | t=4 8 | cfg=4.0 9 | seed=25 10 | steps=50 11 | solver=euler 12 | strength=0.6 13 | image_dir=your/image/dir/here 14 | model_dir=your/model/dir/here 15 | cap_dir=your/caption/dir/here 16 | out_dir=your/output/dir/here 17 | python -u sample_img2img.py --ckpt ${model_dir} \ 18 | --image_save_path ${out_dir} \ 19 | --solver ${solver} --num_sampling_steps ${steps} \ 20 | --caption_path ${cap_dir} \ 21 | --image ${image_dir} \ 22 | --seed ${seed} \ 23 | --resolution ${res} \ 24 | --strength ${strength} \ 25 | --time_shifting_factor ${t} \ 26 | --cfg_scale ${cfg} \ 27 | --batch_size 1 \ 28 | --use_flash_attn True # You can set this to False if you want to disable the flash attention -------------------------------------------------------------------------------- /lumina_next_t2i_mini/scripts/sample_sd3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # SD3 supports up to 1.2K resolution 5 | # res="1024:024x1024 1280:1280x1280" 6 | res=1024:1024x1024 7 | shift=3 8 | cfg=7.0 9 | seed=25 10 | steps=20 11 | solver=midpoint 12 | model_dir=stabilityai/stable-diffusion-3-medium-diffusers 13 | cap_dir=your/caption/dir/here 14 | out_dir=your/output/dir/here 15 | python -u sample_sd3.py --ckpt ${model_dir} \ 16 | --image_save_path ${out_dir} \ 17 | --solver ${solver} --num_sampling_steps ${steps} \ 18 | --caption_path ${cap_dir} \ 19 | --seed ${seed} \ 20 | --resolution ${res} \ 21 | --time_shifting_factor ${shift} \ 22 | --cfg_scale ${cfg} \ 23 | --batch_size 1 \ 24 | 
-------------------------------------------------------------------------------- /lumina_t2i/.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | line_length = 120 4 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER 5 | no_lines_before = STDLIB,LOCALFOLDER 6 | lines_between_types = 1 7 | combine_as_imports = True 8 | force_sort_within_sections = true 9 | order_by_type = True 10 | -------------------------------------------------------------------------------- /lumina_t2i/__init__.py: -------------------------------------------------------------------------------- 1 | from .entry_point import * 2 | -------------------------------------------------------------------------------- /lumina_t2i/configs/data/JourneyDB.yaml: -------------------------------------------------------------------------------- 1 | META: 2 | - 3 | path: '/path/to/journeyDB_train.json' 4 | -------------------------------------------------------------------------------- /lumina_t2i/configs/infer/settings.yaml: -------------------------------------------------------------------------------- 1 | - settings: 2 | 3 | model: 4 | ckpt: "" 5 | ckpt_lm: "" 6 | token: "" 7 | 8 | transport: 9 | path_type: "Linear" # option: ["Linear", "GVP", "VP"] 10 | prediction: "velocity" # option: ["velocity", "score", "noise"] 11 | loss_weight: "velocity" # option: [None, "velocity", "likelihood"] 12 | sample_eps: 0.1 13 | train_eps: 0.2 14 | 15 | ode: 16 | atol: 1e-6 # Absolute tolerance 17 | rtol: 1e-3 # Relative tolerance 18 | reverse: false # option: true or false 19 | likelihood: false # option: true or false 20 | 21 | infer: 22 | resolution: "1024x1024" # option: ["1024x1024", "512x2048", "2048x512", "(Extrapolation) 1664x1664", "(Extrapolation) 1024x2048", "(Extrapolation) 2048x1024"] 23 | num_sampling_steps: 60 # range: 1-1000 24 | cfg_scale: 4. 
# range: 1-20 25 | solver: "euler" # option: ["euler", "dopri5", "dopri8"] 26 | t_shift: 4 # range: 1-20 (int only) 27 | ntk_scaling: true # option: true or false 28 | proportional_attn: true # option: true or false 29 | seed: 0 # rnage: any number 30 | -------------------------------------------------------------------------------- /lumina_t2i/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_reader import * 2 | from .dataset import * 3 | -------------------------------------------------------------------------------- /lumina_t2i/data/data_reader.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import logging 3 | import time 4 | from typing import Union 5 | 6 | from PIL import Image 7 | 8 | Image.MAX_IMAGE_PIXELS = None 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def read_general(path) -> Union[str, BytesIO]: 13 | if "s3://" in path: 14 | init_ceph_client_if_needed() 15 | file_bytes = BytesIO(client.get(path)) 16 | return file_bytes 17 | else: 18 | return path 19 | 20 | 21 | def init_ceph_client_if_needed(): 22 | global client 23 | if client is None: 24 | logger.info(f"initializing ceph client ...") 25 | st = time.time() 26 | from petrel_client.client import Client # noqa 27 | 28 | client = Client("../petreloss.conf") 29 | ed = time.time() 30 | logger.info(f"initialize client cost {ed - st:.2f} s") 31 | 32 | 33 | client = None 34 | -------------------------------------------------------------------------------- /lumina_t2i/exps/5B_bs512_lr1e-4_bf16_1024px_sdxlvae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='configs/data/JourneyDB.yaml' 4 | 5 | model=DiT_Llama_5B_patch2 6 | batch_size=512 7 | lr=1e-4 8 | precision=bf16 9 | image_size=1024 10 | vae=sdxl 11 | init_from=$1 12 | load_str=$2 13 | 14 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}_init${load_str} 15 | mkdir -p results/"$exp_name" 16 | 17 | torchrun --nproc-per-node=8 train.py \ 18 | --master_port 18181 \ 19 | --model ${model} \ 20 | --data_path ${train_data_root} \ 21 | --results_dir results/${exp_name} \ 22 | --micro_batch_size 2 \ 23 | --global_batch_size ${batch_size} --lr ${lr} \ 24 | --data_parallel fsdp \ 25 | --max_steps 3000000 \ 26 | --ckpt_every 2000 --log_every 10 \ 27 | --precision ${precision} --grad_precision fp32 --qk_norm \ 28 | --image_size ${image_size} \ 29 | --init_from $init_from \ 30 | --global_seed 3 \ 31 | --vae ${vae} \ 32 | 2>&1 | tee -a results/"$exp_name"/output.log 33 | -------------------------------------------------------------------------------- /lumina_t2i/exps/5B_bs512_lr1e-4_bf16_256px_sdxlvae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='configs/data/JourneyDB.yaml' 4 | 5 | model=DiT_Llama_5B_patch2 6 | batch_size=512 7 | lr=1e-4 8 | precision=bf16 9 | image_size=256 10 | vae=sdxl 11 | 12 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae} 13 | mkdir -p results/"$exp_name" 14 | 15 | torchrun --nproc-per-node=8 train.py \ 16 | --master_port 18181 \ 17 | --model ${model} \ 18 | --data_path ${train_data_root} \ 19 | --results_dir results/${exp_name} \ 20 | --micro_batch_size 16 \ 21 | --global_batch_size ${batch_size} --lr ${lr} \ 22 | --data_parallel fsdp \ 23 | --max_steps 3000000 \ 24 | --ckpt_every 20000 --log_every 
100 \ 25 | --precision ${precision} --grad_precision fp32 --qk_norm \ 26 | --image_size ${image_size} \ 27 | --vae ${vae} \ 28 | 2>&1 | tee -a results/"$exp_name"/output.log 29 | -------------------------------------------------------------------------------- /lumina_t2i/exps/5B_bs512_lr1e-4_bf16_512px_sdxlvae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='configs/data/JourneyDB.yaml' 4 | 5 | model=DiT_Llama_5B_patch2 6 | batch_size=512 7 | lr=1e-4 8 | precision=bf16 9 | image_size=512 10 | vae=sdxl 11 | init_from=$1 12 | load_str=$2 13 | 14 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}_init${load_str} 15 | mkdir -p results/"$exp_name" 16 | 17 | torchrun --nproc-per-node=8 train.py \ 18 | --master_port 18181 \ 19 | --model ${model} \ 20 | --data_path ${train_data_root} \ 21 | --results_dir results/${exp_name} \ 22 | --micro_batch_size 8 \ 23 | --global_batch_size ${batch_size} --lr ${lr} \ 24 | --data_parallel fsdp \ 25 | --max_steps 3000000 \ 26 | --ckpt_every 20000 --log_every 100 \ 27 | --precision ${precision} --grad_precision fp32 --qk_norm \ 28 | --image_size ${image_size} \ 29 | --init_from $init_from \ 30 | --global_seed 2 \ 31 | --vae ${vae} \ 32 | 2>&1 | tee -a results/"$exp_name"/output.log 33 | -------------------------------------------------------------------------------- /lumina_t2i/exps/slurm/5B_bs512_lr1e-4_bf16_1024px_sdxlvae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='configs/data/JourneyDB.yaml' 4 | 5 | model=DiT_Llama_5B_patch2 6 | batch_size=512 7 | lr=1e-4 8 | precision=bf16 9 | image_size=1024 10 | vae=sdxl 11 | init_from=$1 12 | load_str=$2 13 | 14 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}_init${load_str} 15 | mkdir -p results/"$exp_name" 16 | 17 | python -u train.py \ 18 | --master_port 18181 \ 19 | --model ${model} \ 20 | --data_path ${train_data_root} \ 21 | --results_dir results/${exp_name} \ 22 | --micro_batch_size 2 \ 23 | --global_batch_size ${batch_size} --lr ${lr} \ 24 | --data_parallel fsdp \ 25 | --max_steps 3000000 \ 26 | --ckpt_every 2000 --log_every 10 \ 27 | --precision ${precision} --grad_precision fp32 --qk_norm \ 28 | --image_size ${image_size} \ 29 | --init_from $init_from \ 30 | --global_seed 3 \ 31 | --vae ${vae} \ 32 | 2>&1 | tee -a results/"$exp_name"/output.log 33 | -------------------------------------------------------------------------------- /lumina_t2i/exps/slurm/5B_bs512_lr1e-4_bf16_256px_sdxlvae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='configs/data/JourneyDB.yaml' 4 | 5 | model=DiT_Llama_5B_patch2 6 | batch_size=512 7 | lr=1e-4 8 | precision=bf16 9 | image_size=256 10 | vae=sdxl 11 | 12 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae} 13 | mkdir -p results/"$exp_name" 14 | 15 | python -u train.py \ 16 | --master_port 18181 \ 17 | --model ${model} \ 18 | --data_path ${train_data_root} \ 19 | --results_dir results/${exp_name} \ 20 | --micro_batch_size 16 \ 21 | --global_batch_size ${batch_size} --lr ${lr} \ 22 | --data_parallel fsdp \ 23 | --max_steps 3000000 \ 24 | --ckpt_every 20000 --log_every 100 \ 25 | --precision ${precision} --grad_precision fp32 --qk_norm \ 26 | --image_size ${image_size} \ 27 | --vae ${vae} \ 28 | 2>&1 | tee -a results/"$exp_name"/output.log 
29 | -------------------------------------------------------------------------------- /lumina_t2i/exps/slurm/5B_bs512_lr1e-4_bf16_512px_sdxlvae.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | train_data_root='configs/data/JourneyDB.yaml' 4 | 5 | model=DiT_Llama_5B_patch2 6 | batch_size=512 7 | lr=1e-4 8 | precision=bf16 9 | image_size=512 10 | vae=sdxl 11 | init_from=$1 12 | load_str=$2 13 | 14 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}_init${load_str} 15 | mkdir -p results/"$exp_name" 16 | 17 | python -u train.py \ 18 | --master_port 18181 \ 19 | --model ${model} \ 20 | --data_path ${train_data_root} \ 21 | --results_dir results/${exp_name} \ 22 | --micro_batch_size 8 \ 23 | --global_batch_size ${batch_size} --lr ${lr} \ 24 | --data_parallel fsdp \ 25 | --max_steps 3000000 \ 26 | --ckpt_every 20000 --log_every 100 \ 27 | --precision ${precision} --grad_precision fp32 --qk_norm \ 28 | --image_size ${image_size} \ 29 | --init_from $init_from \ 30 | --global_seed 2 \ 31 | --vae ${vae} \ 32 | 2>&1 | tee -a results/"$exp_name"/output.log 33 | -------------------------------------------------------------------------------- /lumina_t2i/grad_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import fairscale.nn.model_parallel.initialize as fs_init 4 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | 9 | 10 | def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]: 11 | ret_dict = {} 12 | for module_name, module in model.named_modules(): 13 | 14 | def param_fqn(param_name): 15 | return param_name if module_name == "" else module_name + "." 
+ param_name 16 | 17 | if isinstance(module, ColumnParallelLinear): 18 | ret_dict[param_fqn("weight")] = 0 19 | if module.bias is not None: 20 | ret_dict[param_fqn("bias")] = 0 21 | elif isinstance(module, RowParallelLinear): 22 | ret_dict[param_fqn("weight")] = 1 23 | if module.bias is not None: 24 | ret_dict[param_fqn("bias")] = -1 25 | elif isinstance(module, ParallelEmbedding): 26 | ret_dict[param_fqn("weight")] = 1 27 | else: 28 | for param_name, param in module.named_parameters(recurse=False): 29 | ret_dict[param_fqn(param_name)] = -1 30 | return ret_dict 31 | 32 | 33 | def calculate_l2_grad_norm( 34 | model: nn.Module, 35 | model_parallel_dim_dict: Dict[str, int], 36 | ) -> float: 37 | mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 38 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda") 39 | 40 | for name, param in model.named_parameters(): 41 | if param.grad is None: 42 | continue 43 | name = ".".join(x for x in name.split(".") if not x.startswith("_")) 44 | assert name in model_parallel_dim_dict 45 | if model_parallel_dim_dict[name] < 0: 46 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2 47 | else: 48 | mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2 49 | 50 | dist.all_reduce(mp_norm_sq) 51 | dist.all_reduce(non_mp_norm_sq) 52 | non_mp_norm_sq /= fs_init.get_model_parallel_world_size() 53 | 54 | return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5 55 | 56 | 57 | def scale_grad(model: nn.Module, factor: float) -> None: 58 | for param in model.parameters(): 59 | if param.grad is not None: 60 | param.grad.mul_(factor) 61 | -------------------------------------------------------------------------------- /lumina_t2i/imgproc.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from PIL import Image 4 | import numpy as np 5 | 6 | 7 | def center_crop_arr(pil_image, image_size): 8 | """ 9 | Center cropping implementation from ADM. 
10 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 11 | """ 12 | while min(*pil_image.size) >= 2 * image_size: 13 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 14 | 15 | scale = image_size / min(*pil_image.size) 16 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 17 | 18 | arr = np.array(pil_image) 19 | crop_y = (arr.shape[0] - image_size) // 2 20 | crop_x = (arr.shape[1] - image_size) // 2 21 | return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) 22 | 23 | 24 | def center_crop(pil_image, crop_size): 25 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]: 26 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) 27 | 28 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1]) 29 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) 30 | 31 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0]) 32 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1]) 33 | crop_right = crop_left + crop_size[0] 34 | crop_lower = crop_upper + crop_size[1] 35 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower)) 36 | 37 | 38 | def var_center_crop(pil_image, crop_size_list, random_top_k=4): 39 | w, h = pil_image.size 40 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list] 41 | crop_size = random.choice( 42 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k] 43 | )[1] 44 | return center_crop(pil_image, crop_size) 45 | 46 | 47 | def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0): 48 | assert max_ratio >= 1.0 49 | crop_size_list = [] 50 | wp, hp = num_patches, 1 51 | while wp > 0: 52 | if max(wp, hp) / min(wp, hp) <= max_ratio: 53 | crop_size_list.append((wp * patch_size, hp * patch_size)) 54 | if (hp + 1) * wp <= num_patches: 55 | hp += 1 56 | else: 57 | wp -= 1 58 | return crop_size_list 59 | -------------------------------------------------------------------------------- /lumina_t2i/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DiT_Llama_5B_patch2 2 | -------------------------------------------------------------------------------- /lumina_t2i/models/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 
23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /lumina_t2i/parallel.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import subprocess 5 | from time import sleep 6 | 7 | import fairscale.nn.model_parallel.initialize as fs_init 8 | import torch 9 | import torch.distributed as dist 10 | 11 | 12 | def _setup_dist_env_from_slurm(args): 13 | while not os.environ.get("MASTER_ADDR", ""): 14 | os.environ["MASTER_ADDR"] = ( 15 | subprocess.check_output( 16 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"], 17 | shell=True, 18 | ) 19 | .decode() 20 | .strip() 21 | ) 22 | sleep(1) 23 | os.environ["MASTER_PORT"] = str(args.master_port) 24 | os.environ["RANK"] = os.environ["SLURM_PROCID"] 25 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"] 26 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 27 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"] 28 | 29 | 30 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None 31 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1 32 | 33 | 34 | def get_local_rank() -> int: 35 | return _LOCAL_RANK 36 | 37 | 38 | def get_local_world_size() -> int: 39 | return _LOCAL_WORLD_SIZE 40 | 41 | 42 | def distributed_init(args): 43 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]): 44 | _setup_dist_env_from_slurm(args) 45 | 46 | dist.init_process_group("nccl") 47 | fs_init.initialize_model_parallel(args.model_parallel_size) 48 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count()) 49 | 50 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 51 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"]) 52 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"]) 53 | 54 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP 55 | local_ranks, local_world_sizes = [ 56 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1) 57 | ] 58 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda")) 59 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda")) 60 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist() 61 | node_ranks = [[0]] 62 | for i in range(1, dist.get_world_size()): 63 | if len(node_ranks[-1]) == local_world_sizes[i - 1]: 64 | node_ranks.append([]) 65 | else: 66 | assert local_world_sizes[i] == local_world_sizes[i - 1] 67 | node_ranks[-1].append(i) 68 | for ranks in node_ranks: 69 | group = dist.new_group(ranks) 70 | if dist.get_rank() in ranks: 71 | assert _INTRA_NODE_PROCESS_GROUP is None 72 | _INTRA_NODE_PROCESS_GROUP = group 73 | assert _INTRA_NODE_PROCESS_GROUP is not None 74 | 75 | if 
min(local_world_sizes) == max(local_world_sizes): 76 | for i in range(get_local_world_size()): 77 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size()))) 78 | if i == get_local_rank(): 79 | assert _INTER_NODE_PROCESS_GROUP is None 80 | _INTER_NODE_PROCESS_GROUP = group 81 | assert _INTER_NODE_PROCESS_GROUP is not None 82 | 83 | 84 | def get_intra_node_process_group(): 85 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized." 86 | return _INTRA_NODE_PROCESS_GROUP 87 | 88 | 89 | def get_inter_node_process_group(): 90 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized." 91 | return _INTER_NODE_PROCESS_GROUP 92 | -------------------------------------------------------------------------------- /lumina_t2i/requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | fairscale 3 | accelerate 4 | tensorboard 5 | transformers 6 | gradio 7 | torchdiffeq 8 | click 9 | -------------------------------------------------------------------------------- /lumina_t2i/transport/__init__.py: -------------------------------------------------------------------------------- 1 | from .transport import ModelType, PathType, Sampler, SNRType, Transport, WeightType 2 | 3 | 4 | def create_transport( 5 | path_type="Linear", 6 | prediction="velocity", 7 | loss_weight=None, 8 | train_eps=None, 9 | sample_eps=None, 10 | snr_type="uniform", 11 | ): 12 | """function for creating Transport object 13 | **Note**: model prediction defaults to velocity 14 | Args: 15 | - path_type: type of path to use; default to linear 16 | - learn_score: set model prediction to score 17 | - learn_noise: set model prediction to noise 18 | - velocity_weighted: weight loss by velocity weight 19 | - likelihood_weighted: weight loss by likelihood weight 20 | - train_eps: small epsilon for avoiding instability during training 21 | - sample_eps: small epsilon for avoiding instability during sampling 22 | """ 23 | 24 | if prediction == "noise": 25 | model_type = ModelType.NOISE 26 | elif prediction == "score": 27 | model_type = ModelType.SCORE 28 | else: 29 | model_type = ModelType.VELOCITY 30 | 31 | if loss_weight == "velocity": 32 | loss_type = WeightType.VELOCITY 33 | elif loss_weight == "likelihood": 34 | loss_type = WeightType.LIKELIHOOD 35 | else: 36 | loss_type = WeightType.NONE 37 | 38 | if snr_type == "lognorm": 39 | snr_type = SNRType.LOGNORM 40 | elif snr_type == "uniform": 41 | snr_type = SNRType.UNIFORM 42 | else: 43 | raise ValueError(f"Invalid snr type {snr_type}") 44 | 45 | path_choice = { 46 | "Linear": PathType.LINEAR, 47 | "GVP": PathType.GVP, 48 | "VP": PathType.VP, 49 | } 50 | 51 | path_type = path_choice[path_type] 52 | 53 | if path_type in [PathType.VP]: 54 | train_eps = 1e-5 if train_eps is None else train_eps 55 | sample_eps = 1e-3 if train_eps is None else sample_eps 56 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY: 57 | train_eps = 1e-3 if train_eps is None else train_eps 58 | sample_eps = 1e-3 if train_eps is None else sample_eps 59 | else: # velocity & [GVP, LINEAR] is stable everywhere 60 | train_eps = 0 61 | sample_eps = 0 62 | 63 | # create flow state 64 | state = Transport( 65 | model_type=model_type, 66 | path_type=path_type, 67 | loss_type=loss_type, 68 | train_eps=train_eps, 69 | sample_eps=sample_eps, 70 | snr_type=snr_type, 71 | ) 72 | 73 | return state 74 | 
-------------------------------------------------------------------------------- /lumina_t2i/transport/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | class EasyDict: 5 | def __init__(self, sub_dict): 6 | for k, v in sub_dict.items(): 7 | setattr(self, k, v) 8 | 9 | def __getitem__(self, key): 10 | return getattr(self, key) 11 | 12 | 13 | def mean_flat(x): 14 | """ 15 | Take the mean over all non-batch dimensions. 16 | """ 17 | return th.mean(x, dim=list(range(1, len(x.size())))) 18 | 19 | 20 | def log_state(state): 21 | result = [] 22 | 23 | sorted_state = dict(sorted(state.items())) 24 | for key, value in sorted_state.items(): 25 | # Check if the value is an instance of a class 26 | if "=61.0"] 9 | build-backend = "setuptools.build_meta" 10 | 11 | [project] 12 | name = "lumina-t2x" # REQUIRED, is the only field that cannot be marked as dynamic. 13 | version = "1.5.0" # REQUIRED, although can be dynamic 14 | description = "Lumina-T2X is a model for Text to Any Modality Generation" 15 | readme = "README.md" 16 | 17 | requires-python = ">=3.10" 18 | 19 | license = {file = "LICENSE.txt"} 20 | 21 | keywords = ["generation", "multi-modal", "transformer", "aigc", "diffusion"] 22 | 23 | authors = [ 24 | {name = "Alpha-VLLM", email = "author@example.com" } 25 | ] 26 | 27 | maintainers = [ 28 | {name = "Chris Liu", email = "author@example.com" }, 29 | {name = "PommesPeter", email = "xepxa6823@gmail.com" } 30 | ] 31 | 32 | classifiers = [ 33 | # How mature is this project? Common values are 34 | # 3 - Alpha 35 | # 4 - Beta 36 | # 5 - Production/Stable 37 | "Development Status :: 3 - Alpha", 38 | "Programming Language :: Python :: 3", 39 | "Intended Audience :: Developers", 40 | "License :: OSI Approved :: Apache Software License", 41 | ] 42 | 43 | dependencies = [ 44 | "diffusers", 45 | "fairscale", 46 | "accelerate", 47 | "tensorboard", 48 | "transformers", 49 | "gradio", 50 | "torchdiffeq", 51 | "click" 52 | ] 53 | 54 | [project.optional-dependencies] 55 | dev = ["coverage", "pre-commit", "isort", "black"] 56 | image = ["diffusers", "fairscale", "accelerate", "tensorboard", "transformers", "gradio", "torchdiffeq", "click" 57 | ] 58 | music = ["soundfile", "omegaconf", "torchdyn", "pytorch_lightning", "pytorch_memlab", "einops", "ninja", "torchlibrosa", "protobuf", "sentencepiece", "gradio", "transformers"] 59 | audio = ["soundfile", "omegaconf", "torchdyn", "pytorch_lightning", "pytorch_memlab", "einops", "ninja", "torchlibrosa", "protobuf", "sentencepiece", "gradio", "transformers"] 60 | 61 | 62 | [project.scripts] 63 | lumina = "lumina_t2i:entry_point" 64 | lumina_next = "lumina_next_t2i:entry_point" 65 | 66 | [tool.setuptools.packages.find] 67 | exclude = ["assets*"] 68 | 69 | [tool.wheel] 70 | exclude = ["assets*"] 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | fairscale 3 | accelerate 4 | tensorboard 5 | transformers 6 | gradio 7 | torchdiffeq 8 | click 9 | -------------------------------------------------------------------------------- /visual_anagrams/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/.DS_Store -------------------------------------------------------------------------------- 
/visual_anagrams/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Daniel Geng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /visual_anagrams/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include visual_anagrams/views/assets/4x4/*.png 2 | include visual_anagrams/assets/CourierPrime-Regular.ttf 3 | -------------------------------------------------------------------------------- /visual_anagrams/animate.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from visual_anagrams.views import get_views 4 | from visual_anagrams.animate import animate_two_view, animate_two_view_motion_blur 5 | from visual_anagrams.views.view_motion import MotionBlurView 6 | 7 | 8 | if __name__ == '__main__': 9 | import argparse 10 | import pickle 11 | from pathlib import Path 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--im_path", required=True, type=str, help='Path to the illusion to animate') 15 | parser.add_argument("--save_video_path", default=None, type=str, 16 | help='Path to save video to. If None, defaults to `im_path`, with extension `.mp4`') 17 | parser.add_argument("--metadata_path", default=None, type=str, help='Path to metadata. If specified, overrides `view` and `prompt` args') 18 | parser.add_argument("--view", default=None, type=str, help='Name of view to use') 19 | parser.add_argument("--prompt_1", default='', nargs='+', type=str, 20 | help='Prompt for first view. Passing multiple will join them with newlines.') 21 | parser.add_argument("--prompt_2", default='', nargs='+', type=str, 22 | help='Prompt for first view. 
Passing multiple will join them with newlines.') 23 | args = parser.parse_args() 24 | 25 | 26 | # Load image to animate 27 | im_path = Path(args.im_path) 28 | im = Image.open(im_path) 29 | 30 | # Get save dir 31 | if args.save_video_path is None: 32 | save_video_path = im_path.with_suffix('.mp4') 33 | 34 | # Get prompts and views from metadata 35 | if args.metadata_path is None: 36 | # Join prompts with newlines 37 | prompt_1 = '\n'.join(args.prompt_1) 38 | prompt_2 = '\n'.join(args.prompt_2) 39 | 40 | # Get paths and views 41 | view = get_views([args.view])[0] 42 | else: 43 | with open(args.metadata_path, 'rb') as f: 44 | metadata = pickle.load(f) 45 | view = metadata['views'][1] 46 | m_args = metadata['args'] 47 | prompt_1 = f'{m_args.style} {m_args.prompts[0]}'.strip() 48 | prompt_2 = f'{m_args.style} {m_args.prompts[1]}'.strip() 49 | 50 | # Get sizes 51 | im_size = im.size[0] 52 | frame_size = int(im_size * 1.5) 53 | 54 | if any([isinstance(view, MotionBlurView) for view in metadata['views']]): 55 | # Animate specifically motion blur views 56 | animate_two_view_motion_blur( 57 | im, 58 | view, 59 | prompt_1, 60 | prompt_2, 61 | save_video_path=save_video_path, 62 | hold_duration=60, 63 | text_fade_duration=10, 64 | transition_duration=2000, 65 | im_size=im_size, 66 | frame_size=frame_size, 67 | ) 68 | else: 69 | # Animate all other views 70 | animate_two_view( 71 | im, 72 | view, 73 | prompt_1, 74 | prompt_2, 75 | save_video_path=save_video_path, 76 | hold_duration=120, 77 | text_fade_duration=10, 78 | transition_duration=45, 79 | im_size=im_size, 80 | frame_size=frame_size, 81 | ) -------------------------------------------------------------------------------- /visual_anagrams/environment.yml: -------------------------------------------------------------------------------- 1 | name: visual_anagrams 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.08.22=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.12=h7f8727e_0 15 | - pip=23.3.1=py39h06a4308_0 16 | - python=3.9.18=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.0.0=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - tzdata=2023c=h04d1e81_0 22 | - wheel=0.41.2=py39h06a4308_0 23 | - xz=5.4.5=h5eee18b_0 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - accelerate==0.25.0 27 | - certifi==2023.11.17 28 | - charset-normalizer==3.3.2 29 | - diffusers==0.24.0 30 | - einops==0.7.0 31 | - filelock==3.13.1 32 | - fsspec==2023.12.1 33 | - huggingface-hub==0.19.4 34 | - idna==3.6 35 | - imageio==2.33.0 36 | - imageio-ffmpeg==0.4.9 37 | - importlib-metadata==7.0.0 38 | - jinja2==3.1.2 39 | - markupsafe==2.1.3 40 | - mpmath==1.3.0 41 | - networkx==3.2.1 42 | - numpy==1.26.2 43 | - nvidia-cublas-cu12==12.1.3.1 44 | - nvidia-cuda-cupti-cu12==12.1.105 45 | - nvidia-cuda-nvrtc-cu12==12.1.105 46 | - nvidia-cuda-runtime-cu12==12.1.105 47 | - nvidia-cudnn-cu12==8.9.2.26 48 | - nvidia-cufft-cu12==11.0.2.54 49 | - nvidia-curand-cu12==10.3.2.106 50 | - nvidia-cusolver-cu12==11.4.5.107 51 | - nvidia-cusparse-cu12==12.1.0.106 52 | - nvidia-nccl-cu12==2.18.1 53 | - nvidia-nvjitlink-cu12==12.3.101 54 | - nvidia-nvtx-cu12==12.1.105 55 | - packaging==23.2 56 | - pillow==10.1.0 57 | - psutil==5.9.6 58 | - pyyaml==6.0.1 59 | - regex==2023.10.3 60 | - 
requests==2.31.0 61 | - safetensors==0.4.1 62 | - sentencepiece==0.1.99 63 | - sympy==1.12 64 | - tokenizers==0.15.0 65 | - torch==2.1.1 66 | - torchvision==0.16.1 67 | - tqdm==4.66.1 68 | - transformers==4.35.2 69 | - triton==2.1.0 70 | - typing-extensions==4.8.0 71 | - urllib3==2.1.0 72 | - zipp==3.17.0 73 | -------------------------------------------------------------------------------- /visual_anagrams/huggingface_login.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import login 2 | login() -------------------------------------------------------------------------------- /visual_anagrams/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .nextdit import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2, DiT_Llama_2B_GQA_patch2 2 | -------------------------------------------------------------------------------- /visual_anagrams/models/components.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | try: 7 | from apex.normalization import FusedRMSNorm as RMSNorm 8 | except ImportError: 9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation") 10 | 11 | class RMSNorm(torch.nn.Module): 12 | def __init__(self, dim: int, eps: float = 1e-6): 13 | """ 14 | Initialize the RMSNorm normalization layer. 15 | 16 | Args: 17 | dim (int): The dimension of the input tensor. 18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 19 | 20 | Attributes: 21 | eps (float): A small value added to the denominator for numerical stability. 22 | weight (nn.Parameter): Learnable scaling parameter. 23 | 24 | """ 25 | super().__init__() 26 | self.eps = eps 27 | self.weight = nn.Parameter(torch.ones(dim)) 28 | 29 | def _norm(self, x): 30 | """ 31 | Apply the RMSNorm normalization to the input tensor. 32 | 33 | Args: 34 | x (torch.Tensor): The input tensor. 35 | 36 | Returns: 37 | torch.Tensor: The normalized tensor. 38 | 39 | """ 40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 41 | 42 | def forward(self, x): 43 | """ 44 | Forward pass through the RMSNorm layer. 45 | 46 | Args: 47 | x (torch.Tensor): The input tensor. 48 | 49 | Returns: 50 | torch.Tensor: The output tensor after applying RMSNorm. 51 | 52 | """ 53 | output = self._norm(x.float()).type_as(x) 54 | return output * self.weight 55 | -------------------------------------------------------------------------------- /visual_anagrams/readme.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 |
5 | 6 | # Lumina-Pro Visual Anagrams 7 | 8 | `Lumina-Pro Visual Anagrams` is an implementation of the paper [Visual Anagrams: Generating Multi-View Optical Illusions with Diffusion Models](https://dangeng.github.io/visual_anagrams/) based on `Lumina-Pro`. 9 | 10 | 11 | ## Installation 12 | 13 | Please refer to the `Lumina-Pro` folder. 14 | 15 | 16 | ## Usage 17 | 18 | To generate an illusion, replace `path_to_your_ckpt` in the file `run.sh` with your Lumina-Pro model path and run the following command: 19 | ```bash 20 | bash run.sh 21 | ``` 22 | Here is a description of some useful arguments in the script: 23 | 24 | - `--name`: Name for the illusion. Will save samples to `./results/{name}`. 25 | - `--prompts`: A list of prompts for illusions 26 | - `--style`: Optional style prompt to prepend to each of the prompts. For example, could be `"an oil painting of"`. Saves some writing. 27 | - `--views`: A list of views to use. Must match the number of prompts. For a list of views see the `get_views` function in `visual_anagrams/views/__init__.py`. (Note: Only rotation and flip are supported so far.) 28 | - `--num_samples`: Number of illusions to sample. -------------------------------------------------------------------------------- /visual_anagrams/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # res=1024:1024x1024 4 | # res=2048:2048x2048 5 | # res=2560:2560x2560 6 | res=4096:4096x4096 7 | t=7 8 | cfg=8.0 9 | seed=17 10 | steps=30 11 | model_dir=path_to_your_ckpt 12 | n=10 13 | 14 | CUDA_VISIBLE_DEVICES=1 \ 15 | python generate.py --name flip.campfire.man \ 16 | --prompts "people at a campfire" "an old man"\ 17 | --style "an oil painting of" \ 18 | --views identity flip \ 19 | --num_samples ${n} \ 20 | --num_inference_steps ${steps} \ 21 | --ckpt ${model_dir} \ 22 | --seed ${seed} \ 23 | --resolution ${res} \ 24 | --time_shifting_factor ${t} \ 25 | --cfg_scale ${cfg} \ 26 | --batch_size 1 \ 27 | --use_flash_attn True # You can set this to False if you want to disable the flash attention 28 | -------------------------------------------------------------------------------- /visual_anagrams/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='visual_anagrams', 5 | version='0.1', 6 | packages=find_packages(), 7 | include_package_data=True, 8 | install_requires=[], 9 | ) 10 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/.DS_Store -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/__init__.py -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/assets/CourierPrime-Regular.ttf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/assets/CourierPrime-Regular.ttf -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | import numpy as np 4 | 5 | from .view_identity import IdentityView 6 | from .view_flip import FlipView 7 | from .view_rotate import Rotate180View, Rotate90CCWView, Rotate90CWView 8 | from .view_negate import NegateView 9 | from .view_skew import SkewView 10 | from .view_patch_permute import PatchPermuteView 11 | from .view_jigsaw import JigsawView 12 | from .view_inner_circle import InnerCircleView, InnerCircleViewFailure 13 | from .view_square_hinge import SquareHingeView 14 | from .view_blur import BlurViewFailure 15 | from .view_white_balance import WhiteBalanceViewFailure 16 | from .view_hybrid import HybridLowPassView, HybridHighPassView, \ 17 | TripleHybridHighPassView, TripleHybridLowPassView, \ 18 | TripleHybridMediumPassView 19 | from .view_color import ColorView, GrayscaleView 20 | from .view_motion import MotionBlurResView, MotionBlurView 21 | from .view_scale import ScaleView 22 | 23 | VIEW_MAP = { 24 | 'identity': IdentityView, 25 | 'flip': FlipView, 26 | 'rotate_cw': Rotate90CWView, 27 | 'rotate_ccw': Rotate90CCWView, 28 | 'rotate_180': Rotate180View, 29 | 'negate': NegateView, 30 | 'skew': SkewView, 31 | 'patch_permute': PatchPermuteView, 32 | 'pixel_permute': PatchPermuteView, 33 | 'jigsaw': JigsawView, 34 | 'inner_circle': InnerCircleView, 35 | 'square_hinge': SquareHingeView, 36 | 'inner_circle_failure': InnerCircleViewFailure, 37 | 'blur_failure': BlurViewFailure, 38 | 'white_balance_failure': WhiteBalanceViewFailure, 39 | 'low_pass': HybridLowPassView, 40 | 'high_pass': HybridHighPassView, 41 | 'triple_low_pass': TripleHybridLowPassView, 42 | 'triple_medium_pass': TripleHybridMediumPassView, 43 | 'triple_high_pass': TripleHybridHighPassView, 44 | 'grayscale': GrayscaleView, 45 | 'color': ColorView, 46 | 'motion': MotionBlurView, 47 | 'motion_res': MotionBlurResView, 48 | 'scale': ScaleView, 49 | } 50 | 51 | def get_anagrams_views(view_names, view_args=None): 52 | ''' 53 | Bespoke function to get views (just to make command line usage easier) 54 | ''' 55 | 56 | views = [] 57 | if view_args is None: 58 | view_args = [None for _ in view_names] 59 | 60 | for view_name, view_arg in zip(view_names, view_args): 61 | if view_name == 'patch_permute': 62 | args = [8 if view_arg is None else int(view_arg)] 63 | elif view_name == 'pixel_permute': 64 | args = [64 if view_arg is None else int(view_arg)] 65 | elif view_name == 'skew': 66 | args = [1.5 if view_arg is None else float(view_arg)] 67 | elif view_name in ['low_pass', 'high_pass']: 68 | args = [2.0 if view_arg is None else float(view_arg)] 69 | elif view_name in ['scale']: 70 | args = [0.5 if view_arg is None else float(view_arg)] 71 | else: 72 | args = [] 73 | 74 | view = VIEW_MAP[view_name](*args) 75 | views.append(view) 76 | 77 | return views 78 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_1024.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_1024.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_256.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_64.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_1024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_1024.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_256.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_64.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_1024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_1024.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_256.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_64.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_1024.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_1024.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_256.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_64.png -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/jigsaw_helpers.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | import numpy as np 4 | 5 | def get_jigsaw_pieces(size): 6 | ''' 7 | Load all pieces of the 4x4 jigsaw puzzle. 8 | 9 | size (int) : 10 | Should be 64, 256, or 1024 indicating side length of jigsaw puzzle 11 | ''' 12 | 13 | # Location of pieces 14 | piece_dir = Path(__file__).parent / 'assets' 15 | 16 | # Helper function to load pieces as np arrays 17 | def load_pieces(path): 18 | ''' 19 | Load a piece, from the given path, as a binary numpy array. 20 | Return a list of the "base" piece, and all four of its rotations. 21 | ''' 22 | piece = Image.open(path) 23 | piece = np.array(piece)[:,:,0] // 255 24 | pieces = np.stack([np.rot90(piece, k=-i) for i in range(4)]) 25 | return pieces 26 | 27 | # Load pieces and rotate to get 16 pieces, and cat 28 | pieces_corner = load_pieces(piece_dir / f'4x4/4x4_corner_{size}.png') 29 | pieces_inner = load_pieces(piece_dir / f'4x4/4x4_inner_{size}.png') 30 | pieces_edge1 = load_pieces(piece_dir / f'4x4/4x4_edge1_{size}.png') 31 | pieces_edge2 = load_pieces(piece_dir / f'4x4/4x4_edge2_{size}.png') 32 | pieces = np.concatenate([pieces_corner, pieces_inner, pieces_edge1, pieces_edge2]) 33 | 34 | return pieces 35 | 36 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_base.py: -------------------------------------------------------------------------------- 1 | class BaseView: 2 | ''' 3 | BaseView class, from which all views inherit. Implements the 4 | following functions: 5 | ''' 6 | 7 | def __init__(self): 8 | pass 9 | 10 | def view(self, im): 11 | ''' 12 | Apply transform to an image. 13 | 14 | im (`torch.tensor`): 15 | For stage 1: Tensor of shape (3, H, W) representing a noisy image 16 | OR 17 | For stage 2: Tensor of shape (6, H, W) representing a noisy image 18 | concatenated with an upsampled conditioning image from stage 1 19 | ''' 20 | raise NotImplementedError() 21 | 22 | def inverse_view(self, noise): 23 | ''' 24 | Apply inverse transform to noise estimates. 25 | Because DeepFloyd estimates the variance in addition to 26 | the noise, this function must apply the inverse to the 27 | variance as well. 
28 | 29 | noise (`torch.tensor`): 30 | Tensor of shape (6, H, W) representing the noise estimate 31 | (first three channel dims) and variance estimates (last 32 | three channel dims) 33 | ''' 34 | raise NotImplementedError() 35 | 36 | def make_frame(self, im, t): 37 | ''' 38 | Make a frame, transitioning linearly from the identity view (t=0) 39 | to this view (t=1) 40 | 41 | im (`PIL.Image`): 42 | A PIL Image of the illusion 43 | 44 | t (float): 45 | A float in [0,1] indicating time in the animation. Should start 46 | at the identity view at t=0, and continuously transition to the 47 | view at t=1. 48 | ''' 49 | raise NotImplementedError() 50 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_blur.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | import torch 5 | import torchvision.transforms.functional as TF 6 | 7 | from .view_base import BaseView 8 | 9 | 10 | class BlurViewFailure(BaseView): 11 | ''' 12 | A failing blur view, which blurs an image, in an attempt 13 | to synthesize hybrid images. 14 | ''' 15 | def __init__(self, factor=8): 16 | self.factor = factor 17 | 18 | def make_frame(self, im, t): 19 | im_size = im.size[0] 20 | frame_size = int(im_size * 1.5) 21 | new_size = int( im_size / (1 + (self.factor - 1) * t) ) 22 | 23 | # Convert to tensor 24 | im = torch.tensor(np.array(im) / 255.).permute(2,0,1) 25 | 26 | # Resize to new size 27 | im = TF.resize(im, new_size) 28 | 29 | # Convert back to PIL 30 | im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8)) 31 | 32 | # Paste on to canvas 33 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 34 | frame.paste(im, ((frame_size - new_size) // 2, (frame_size - new_size) // 2)) 35 | 36 | return frame 37 | 38 | def view(self, im): 39 | im_size = im.shape[-1] 40 | 41 | # Downsample then upsample to "blur" 42 | im_small = TF.resize(im, im_size // self.factor) 43 | im_blur = TF.resize(im_small, im_size) 44 | 45 | return im_blur 46 | 47 | def inverse_view(self, noise): 48 | # The transform is technically uninvertible, so just do pass through 49 | return noise 50 | 51 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_color.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from .view_base import BaseView 7 | 8 | def make_frame_color(im, t): 9 | im_size = im.size[0] 10 | frame_size = int(im_size * 1.5) 11 | 12 | # Convert to tensor 13 | im = torch.tensor(np.array(im) / 255.).permute(2,0,1) 14 | 15 | # Extract color and greyscale components 16 | im_grey = im.clone() 17 | im_grey[:] = im.mean(dim=0, keepdim=True) 18 | im_color = im - im_grey 19 | 20 | # Take linear interpolation 21 | im = im_grey + t * im_color 22 | 23 | # Convert back to PIL 24 | im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8)) 25 | 26 | # Paste on to canvas 27 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 28 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2)) 29 | 30 | return frame 31 | 32 | class GrayscaleView(BaseView): 33 | def __init__(self): 34 | pass 35 | 36 | def make_frame(self, im, t): 37 | return make_frame_color(im, t) 38 | 39 | def view(self, im): 40 | return im 41 | 42 | def save_view(self, im): 43 
| im = torch.stack([im.mean(0)] * 3) 44 | return im 45 | 46 | def inverse_view(self, noise): 47 | # Get grayscale component by averaging color channels 48 | noise[:3] = torch.stack([noise[:3].mean(0)] * 3) 49 | return noise 50 | 51 | 52 | class ColorView(BaseView): 53 | def __init__(self): 54 | pass 55 | 56 | def make_frame(self, im, t): 57 | return make_frame_color(im, t) 58 | 59 | def view(self, im): 60 | return im 61 | 62 | def inverse_view(self, noise): 63 | # Get color component by taking residual 64 | noise[:3] = noise[:3] - torch.stack([noise[:3].mean(0)] * 3) 65 | return noise -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_flip.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | import torch 4 | 5 | from .view_base import BaseView 6 | 7 | class FlipView(BaseView): 8 | def __init__(self): 9 | pass 10 | 11 | def view(self, im): 12 | return torch.flip(im, [1]) 13 | 14 | def inverse_view(self, noise): 15 | return torch.flip(noise, [1]) 16 | 17 | def make_frame(self, im, t): 18 | im_size = im.size[0] 19 | frame_size = int(im_size * 1.5) 20 | theta = -t * 180 21 | 22 | # TODO: Technically not a flip, change this to a homography later 23 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 24 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2)) 25 | frame = frame.rotate(theta, 26 | resample=Image.Resampling.BILINEAR, 27 | expand=False, 28 | fillcolor=(255,255,255)) 29 | 30 | return frame 31 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_identity.py: -------------------------------------------------------------------------------- 1 | from .view_base import BaseView 2 | 3 | class IdentityView(BaseView): 4 | def __init__(self): 5 | pass 6 | 7 | def view(self, im): 8 | return im 9 | 10 | def inverse_view(self, noise): 11 | return noise 12 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_motion.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from .view_base import BaseView 8 | 9 | def make_frame_motion(im, t): 10 | im_size = im.size[0] 11 | factor = im_size / 64 / 4.0 12 | frame_size = int(im_size * 1.5) 13 | vel = 20 14 | amp = 29 * factor / 2 # 29 @ 256, 29 * 4 @ 1024, need to divide by 2 b/c amp 15 | 16 | # Triangular wave 17 | offset = int(amp * 2 / np.pi * np.arcsin(np.sin(2 * np.pi * vel * t))) 18 | 19 | # Paste on to canvas 20 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 21 | frame.paste(im, (offset + (frame_size - im_size) // 2, offset + (frame_size - im_size) // 2)) 22 | 23 | return frame 24 | 25 | class MotionBlurView(BaseView): 26 | def __init__(self, size=7): 27 | self.size = size 28 | 29 | def make_frame(self, im, t): 30 | return make_frame_motion(im, t) 31 | 32 | def view(self, im): 33 | return im 34 | 35 | def inverse_view(self, noise): 36 | c, h, w = noise.shape 37 | factor = h // 64 # Account for image size 38 | 39 | # Make kernel on the fly 40 | size = self.size * factor 41 | size = size + ((factor + 1) % 2) # Make sure it's odd 42 | self.K = torch.eye(size)[None, None] / size 43 | self.K = self.K.to(noise.dtype).to(noise.device) 44 | 45 | # Apply kernel to each 
channel independently 46 | noise[:3] = torch.cat([F.conv2d(noise[i][None], self.K, padding=size//2) for i in range(3)]) 47 | 48 | return noise 49 | 50 | def save_view(self, im): 51 | c, h, w = im.shape 52 | factor = h // 64 # Account for image size 53 | 54 | # Make kernel on the fly 55 | size = self.size * factor 56 | size = size + ((factor + 1) % 2) # Make sure it's odd 57 | self.K = torch.eye(size)[None, None] / size 58 | self.K = self.K.to(im.dtype).to(im.device) 59 | 60 | # Apply kernel to each channel independently 61 | im = torch.cat([F.conv2d(im[i][None], self.K, padding=size//2) for i in range(3)]) 62 | 63 | return im 64 | 65 | 66 | 67 | class MotionBlurResView(BaseView): 68 | def __init__(self, size=7): 69 | self.size = size 70 | 71 | def make_frame(self, im, t): 72 | return make_frame_motion(im, t) 73 | 74 | def view(self, im): 75 | return im 76 | 77 | def inverse_view(self, noise): 78 | c, h, w = noise.shape 79 | factor = h // 64 # Account for image size 80 | 81 | # Make kernel on the fly 82 | size = self.size * factor 83 | size = size + ((factor + 1) % 2) # Make sure it's odd 84 | self.K = torch.eye(size)[None, None] / size 85 | self.K = self.K.to(noise.dtype).to(noise.device) 86 | 87 | # Apply kernel to each channel independently, and take residual 88 | noise[:3] = noise[:3] - torch.cat([F.conv2d(noise[i][None], self.K, padding=size//2) for i in range(3)]) 89 | 90 | return noise -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_negate.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from .view_base import BaseView 7 | 8 | class NegateView(BaseView): 9 | def __init__(self): 10 | pass 11 | 12 | def view(self, im): 13 | return -im 14 | 15 | def inverse_view(self, noise): 16 | ''' 17 | Negating the variance estimate is "weird" so just don't do it. 18 | This hack seems to work just fine 19 | ''' 20 | invert_mask = torch.ones_like(noise) 21 | invert_mask[:3] = -1 22 | return noise * invert_mask 23 | 24 | def make_frame(self, im, t): 25 | im_size = im.size[0] 26 | frame_size = int(im_size * 1.5) 27 | 28 | # map t from [0, 1] -> [1, -1] 29 | t = 1 - t 30 | t = t * 2 - 1 31 | 32 | # Interpolate from pixels from [0, 1] to [1, 0] 33 | im = np.array(im) / 255. 34 | im = ((2 * im - 1) * t + 1) / 2. 35 | im = Image.fromarray((im * 255.).astype(np.uint8)) 36 | 37 | # Paste on to canvas 38 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 39 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2)) 40 | 41 | return frame 42 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_rotate.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | import torchvision.transforms.functional as TF 4 | import torch 5 | from torchvision.transforms import InterpolationMode 6 | 7 | from .view_base import BaseView 8 | 9 | 10 | class Rotate90CWView(BaseView): 11 | def __init__(self): 12 | pass 13 | 14 | def view(self, im): 15 | # TODO: Is nearest-exact better? 
16 | # return TF.rotate(im, -90, interpolation=InterpolationMode.NEAREST) # clockwise 90 17 | return torch.rot90(im, -1, dims=[1, 2]) 18 | 19 | def inverse_view(self, noise): 20 | # return TF.rotate(noise, 90, interpolation=InterpolationMode.NEAREST) # counter-clockwise 90 21 | return torch.rot90(noise, 1, dims=[1, 2]) 22 | 23 | def make_frame(self, im, t): 24 | im_size = im.size[0] 25 | frame_size = int(im_size * 1.5) 26 | theta = t * -90 27 | 28 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 29 | centered_loc = (frame_size - im_size) // 2 30 | frame.paste(im, (centered_loc, centered_loc)) 31 | frame = frame.rotate(theta, 32 | resample=Image.Resampling.BILINEAR, 33 | expand=False, 34 | fillcolor=(255,255,255)) 35 | 36 | return frame 37 | 38 | 39 | class Rotate90CCWView(BaseView): 40 | def __init__(self): 41 | pass 42 | 43 | def view(self, im): 44 | # TODO: Is nearest-exact better? 45 | # return TF.rotate(im, 90, interpolation=InterpolationMode.NEAREST) 46 | return torch.rot90(im, 1, dims=[1, 2]) 47 | 48 | def inverse_view(self, noise): 49 | # return TF.rotate(noise, -90, interpolation=InterpolationMode.NEAREST) 50 | return torch.rot90(noise, -1, dims=[1, 2]) 51 | 52 | def make_frame(self, im, t): 53 | im_size = im.size[0] 54 | frame_size = int(im_size * 1.5) 55 | theta = t * 90 56 | 57 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 58 | centered_loc = (frame_size - im_size) // 2 59 | frame.paste(im, (centered_loc, centered_loc)) 60 | frame = frame.rotate(theta, 61 | resample=Image.Resampling.BILINEAR, 62 | expand=False, 63 | fillcolor=(255,255,255)) 64 | 65 | return frame 66 | 67 | 68 | class Rotate180View(BaseView): 69 | def __init__(self): 70 | pass 71 | 72 | def view(self, im): 73 | # TODO: Is nearest-exact better? 
74 | # return TF.rotate(im, 180, interpolation=InterpolationMode.NEAREST) 75 | return torch.rot90(im, 2, dims=[1, 2]) 76 | 77 | def inverse_view(self, noise): 78 | # return TF.rotate(noise, -180, interpolation=InterpolationMode.NEAREST) 79 | return torch.rot90(noise, -2, dims=[1, 2]) 80 | 81 | def make_frame(self, im, t): 82 | im_size = im.size[0] 83 | frame_size = int(im_size * 1.5) 84 | theta = t * 180 85 | 86 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 87 | centered_loc = (frame_size - im_size) // 2 88 | frame.paste(im, (centered_loc, centered_loc)) 89 | frame = frame.rotate(theta, 90 | resample=Image.Resampling.BILINEAR, 91 | expand=False, 92 | fillcolor=(255,255,255)) 93 | 94 | return frame 95 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_scale.py: -------------------------------------------------------------------------------- 1 | from .view_base import BaseView 2 | 3 | class ScaleView(BaseView): 4 | def __init__(self, scale=0.5): 5 | self.scale = scale 6 | 7 | def view(self, im): 8 | return im 9 | 10 | def inverse_view(self, noise): 11 | noise[:3] = self.scale * noise[:3] 12 | return noise -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_skew.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | import torch 5 | 6 | from .view_base import BaseView 7 | 8 | 9 | class SkewView(BaseView): 10 | def __init__(self, skew_factor=1.5): 11 | self.skew_factor = skew_factor 12 | 13 | def skew_image(self, im, skew_factor): 14 | ''' 15 | Roll each column of the image by increasing displacements. 16 | This is a permutation of pixels 17 | ''' 18 | 19 | # Params 20 | c,h,w = im.shape 21 | h_center = h//2 22 | 23 | # Roll columns 24 | cols = [] 25 | for i in range(w): 26 | d = int(skew_factor * (i - h_center)) # Displacement 27 | col = im[:,:,i] 28 | cols.append(col.roll(d, dims=1)) 29 | 30 | # Stack rolled columns 31 | skewed = torch.stack(cols, dim=2) 32 | return skewed 33 | 34 | def view(self, im): 35 | return self.skew_image(im, self.skew_factor) 36 | 37 | def inverse_view(self, noise): 38 | return self.skew_image(noise, -self.skew_factor) 39 | 40 | def make_frame(self, im, t): 41 | im_size = im.size[0] 42 | frame_size = int(im_size * 1.5) 43 | skew_factor = t * self.skew_factor 44 | 45 | # Convert to tensor, skew, then convert back to PIL 46 | im = torch.tensor(np.array(im) / 255.).permute(2,0,1) 47 | im = self.skew_image(im, skew_factor) 48 | im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8)) 49 | 50 | # Paste on to canvas 51 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 52 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2)) 53 | 54 | return frame 55 | 56 | -------------------------------------------------------------------------------- /visual_anagrams/visual_anagrams/views/view_white_balance.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | import torch 5 | import torchvision.transforms.functional as TF 6 | 7 | from .view_base import BaseView 8 | 9 | 10 | class WhiteBalanceViewFailure(BaseView): 11 | ''' 12 | A failing white balancing view, which simply scales the pixel values 13 | by some constant factor. 
An attempt to reproduce the "dress" illusion 14 | ''' 15 | def __init__(self, factor=1.5): 16 | self.factor = factor 17 | 18 | def make_frame(self, im, t): 19 | im_size = im.size[0] 20 | frame_size = int(im_size * 1.5) 21 | 22 | # Interpolate factor on t 23 | factor = 1 + (self.factor - 1) * t 24 | 25 | # Convert to tensor 26 | im = torch.tensor(np.array(im) / 255.).permute(2,0,1) 27 | 28 | # Adjust colors 29 | im = im * factor 30 | im = torch.clip(im, 0, 1) 31 | 32 | # Convert back to PIL 33 | im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8)) 34 | 35 | # Paste on to canvas 36 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255)) 37 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2)) 38 | 39 | return frame 40 | 41 | def view(self, im): 42 | return im * self.factor 43 | 44 | def inverse_view(self, noise): 45 | noise[:3] = noise[:3] / self.factor 46 | return noise 47 | 48 | --------------------------------------------------------------------------------
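The view classes above all follow the `BaseView` contract from `view_base.py`: `view` applies a (preferably invertible) transform to an image tensor, `inverse_view` undoes that transform on the model's noise-and-variance estimate, and `make_frame` renders one animation frame transitioning from the identity view (t=0) to the view itself (t=1). As a quick orientation aid, here is a minimal sketch of a new view written against that contract. It is not a file from this repository: the class name `HorizontalFlipView` is hypothetical, and the (C, H, W) tensor layout and animation framing are simply copied from `FlipView` above.

```python
import torch
from PIL import Image

from visual_anagrams.views.view_base import BaseView


class HorizontalFlipView(BaseView):
    '''Hypothetical example view: mirror the image left-right.'''

    def view(self, im):
        # im is a (C, H, W) tensor; flipping the width dim is a pixel
        # permutation, so it is exactly invertible.
        return torch.flip(im, [2])

    def inverse_view(self, noise):
        # A flip is its own inverse; applying it to all channels also keeps
        # the variance-estimate channels consistent with the noise channels.
        return torch.flip(noise, [2])

    def make_frame(self, im, t):
        # Rotate the pasted illusion from 0 to 180 degrees as t goes 0 -> 1,
        # mirroring the animation used by FlipView.
        im_size = im.size[0]
        frame_size = int(im_size * 1.5)
        theta = t * 180

        frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
        offset = (frame_size - im_size) // 2
        frame.paste(im, (offset, offset))
        return frame.rotate(theta,
                            resample=Image.Resampling.BILINEAR,
                            expand=False,
                            fillcolor=(255, 255, 255))
```

To make such a view reachable from the command line, it would presumably also need an entry in `VIEW_MAP` in `views/__init__.py` (e.g. `'flip_horizontal': HorizontalFlipView`), since that mapping is how the names passed via `--views` are resolved.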