├── .gitignore
├── .isort.cfg
├── .pre-commit-config.yaml
├── Flag-DiT-ImageNet
├── README.md
├── exps
│ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ └── slurm
│ │ ├── 3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ │ └── 7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
├── grad_norm.py
├── models
│ ├── __init__.py
│ ├── components.py
│ └── model.py
├── parallel.py
├── scripts
│ ├── run_8gpus.sh
│ └── slurm
│ │ ├── run_32gpus.sh
│ │ └── run_8gpus.sh
├── train.py
└── transport
│ ├── __init__.py
│ ├── integrators.py
│ ├── path.py
│ ├── transport.py
│ └── utils.py
├── LICENSE
├── Next-DiT-ImageNet
├── .isort.cfg
├── README.md
├── exps
│ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ └── slurm
│ │ ├── 2B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ │ ├── 3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ │ └── 7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
├── fid_is.png
├── grad_norm.py
├── init_loss.py
├── models
│ ├── __init__.py
│ └── models.py
├── parallel.py
├── sample.py
├── scripts
│ ├── run_8gpus.sh
│ └── slurm
│ │ ├── run_32gpus.sh
│ │ └── run_8gpus.sh
├── train.py
└── transport
│ ├── __init__.py
│ ├── integrators.py
│ ├── path.py
│ ├── transport.py
│ └── utils.py
├── Next-DiT-MoE
├── README.md
├── exps
│ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ └── slurm
│ │ ├── 2B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ │ ├── 3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ │ ├── 600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
│ │ └── 7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
├── grad_norm.py
├── loss_spacemoe.png
├── loss_timemoe.png
├── loss_timespacemoe.png
├── models
│ ├── __init__.py
│ ├── models.py
│ ├── models1.py
│ └── models2.py
├── moe_model.png
├── parallel.py
├── sample.py
├── scripts
│ ├── run_8gpus.sh
│ └── slurm
│ │ ├── run_32gpus.sh
│ │ └── run_8gpus.sh
├── train.py
└── transport
│ ├── __init__.py
│ ├── integrators.py
│ ├── path.py
│ ├── transport.py
│ └── utils.py
├── README.md
├── README_cn.md
├── assets
├── audios
│ ├── a_telephone_bell_rings.wav
│ └── a_telephone_bell_rings_gt.wav
├── compositional_intro.png
├── diverse_config.png
├── images
│ ├── demo_image.png
│ └── resolution_extrapolation_2.jpg
├── lumina-intro.png
└── lumina-logo.png
├── lumina_audio
├── README.md
├── configs
│ └── lumina-text2audio.yaml
├── demo_audio.py
├── models
│ ├── __init__.py
│ ├── autoencoder1d.py
│ ├── diffusion
│ │ ├── __init__.py
│ │ ├── component.py
│ │ ├── ddim.py
│ │ ├── ddpm.py
│ │ ├── ddpm_audio.py
│ │ ├── distributions
│ │ │ ├── __init__.py
│ │ │ └── distributions.py
│ │ ├── ema.py
│ │ ├── flag_large_dit.py
│ │ └── util.py
│ ├── encoders
│ │ ├── CLAP
│ │ │ ├── CLAPWrapper.py
│ │ │ ├── __init__.py
│ │ │ ├── audio.py
│ │ │ ├── clap.py
│ │ │ ├── config.yml
│ │ │ └── utils.py
│ │ ├── __init__.py
│ │ └── modules.py
│ ├── lr_scheduler.py
│ ├── util.py
│ └── vocoder
│ │ └── bigvgan
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── alias_free_torch
│ │ ├── __init__.py
│ │ ├── act.py
│ │ ├── filter.py
│ │ └── resample.py
│ │ └── models.py
├── n2s_openai.py
├── requirements.txt
├── run_audio.sh
└── style.css
├── lumina_music
├── README.md
├── configs
│ └── lumina-text2music.yaml
├── demo_music.py
├── models
│ ├── __init__.py
│ ├── autoencoder1d.py
│ ├── diffusion
│ │ ├── __init__.py
│ │ ├── component.py
│ │ ├── ddim.py
│ │ ├── ddpm.py
│ │ ├── ddpm_audio.py
│ │ ├── distributions
│ │ │ ├── __init__.py
│ │ │ └── distributions.py
│ │ ├── ema.py
│ │ ├── flag_large_dit.py
│ │ └── util.py
│ ├── encoders
│ │ ├── CLAP
│ │ │ ├── CLAPWrapper.py
│ │ │ ├── __init__.py
│ │ │ ├── audio.py
│ │ │ ├── clap.py
│ │ │ ├── config.yml
│ │ │ └── utils.py
│ │ ├── __init__.py
│ │ └── modules.py
│ ├── lr_scheduler.py
│ ├── util.py
│ └── vocoder
│ │ └── bigvgan
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ ├── alias_free_torch
│ │ ├── __init__.py
│ │ ├── act.py
│ │ ├── filter.py
│ │ └── resample.py
│ │ └── models.py
├── requirements.txt
└── run_music.sh
├── lumina_next_compositional_generation
├── README.md
├── demo.py
├── models
│ ├── __init__.py
│ ├── components.py
│ └── model.py
└── transport
│ ├── __init__.py
│ ├── integrators.py
│ ├── path.py
│ ├── transport.py
│ └── utils.py
├── lumina_next_t2i
├── README.md
├── __init__.py
├── configs
│ ├── data
│ │ └── JourneyDB.yaml
│ └── infer
│ │ └── settings.yaml
├── data
│ ├── __init__.py
│ ├── data_reader.py
│ └── dataset.py
├── demo.py
├── entry_point.py
├── grad_norm.py
├── imgproc.py
├── models
│ ├── __init__.py
│ ├── components.py
│ └── model.py
├── parallel.py
├── sample.py
├── train.py
├── transport
│ ├── __init__.py
│ ├── integrators.py
│ ├── path.py
│ ├── transport.py
│ └── utils.py
└── utils
│ ├── __init__.py
│ ├── cli.py
│ └── group.py
├── lumina_next_t2i_mini
├── README.md
├── __init__.py
├── configs
│ └── data
│ │ └── JourneyDB.yaml
├── data
│ ├── __init__.py
│ ├── data_reader.py
│ └── dataset.py
├── demo.py
├── grad_norm.py
├── imgproc.py
├── models
│ ├── __init__.py
│ ├── components.py
│ └── nextdit.py
├── parallel.py
├── sample.py
├── sample_img2img.py
├── sample_sd3.py
├── scripts
│ ├── sample.sh
│ ├── sample_img2img.sh
│ └── sample_sd3.sh
├── train.py
├── train_dreambooth_sd3.py
└── transport.py
├── lumina_t2i
├── .isort.cfg
├── README.md
├── __init__.py
├── configs
│ ├── data
│ │ └── JourneyDB.yaml
│ └── infer
│ │ └── settings.yaml
├── data
│ ├── __init__.py
│ ├── data_reader.py
│ └── dataset.py
├── demo.py
├── entry_point.py
├── exps
│ ├── 5B_bs512_lr1e-4_bf16_1024px_sdxlvae.sh
│ ├── 5B_bs512_lr1e-4_bf16_256px_sdxlvae.sh
│ ├── 5B_bs512_lr1e-4_bf16_512px_sdxlvae.sh
│ └── slurm
│ │ ├── 5B_bs512_lr1e-4_bf16_1024px_sdxlvae.sh
│ │ ├── 5B_bs512_lr1e-4_bf16_256px_sdxlvae.sh
│ │ └── 5B_bs512_lr1e-4_bf16_512px_sdxlvae.sh
├── grad_norm.py
├── imgproc.py
├── models
│ ├── __init__.py
│ ├── components.py
│ └── model.py
├── parallel.py
├── requirements.txt
├── train.py
├── transport
│ ├── __init__.py
│ ├── integrators.py
│ ├── path.py
│ ├── transport.py
│ └── utils.py
└── utils
│ ├── __init__.py
│ ├── cli.py
│ └── group.py
├── pyproject.toml
├── requirements.txt
└── visual_anagrams
├── .DS_Store
├── LICENSE
├── MANIFEST.in
├── animate.py
├── environment.yml
├── generate.py
├── huggingface_login.py
├── models
├── __init__.py
├── components.py
└── nextdit.py
├── readme.md
├── run.sh
├── setup.py
└── visual_anagrams
├── .DS_Store
├── __init__.py
├── animate.py
├── assets
└── CourierPrime-Regular.ttf
├── samplers.py
├── utils.py
└── views
├── __init__.py
├── assets
└── 4x4
│ ├── 4x4_corner_1024.png
│ ├── 4x4_corner_256.png
│ ├── 4x4_corner_64.png
│ ├── 4x4_edge1_1024.png
│ ├── 4x4_edge1_256.png
│ ├── 4x4_edge1_64.png
│ ├── 4x4_edge2_1024.png
│ ├── 4x4_edge2_256.png
│ ├── 4x4_edge2_64.png
│ ├── 4x4_inner_1024.png
│ ├── 4x4_inner_256.png
│ └── 4x4_inner_64.png
├── jigsaw_helpers.py
├── permutations.py
├── view_base.py
├── view_blur.py
├── view_color.py
├── view_flip.py
├── view_hybrid.py
├── view_identity.py
├── view_inner_circle.py
├── view_jigsaw.py
├── view_motion.py
├── view_negate.py
├── view_patch_permute.py
├── view_permute.py
├── view_rotate.py
├── view_scale.py
├── view_skew.py
├── view_square_hinge.py
└── view_white_balance.py
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile = black
3 | line_length = 120
4 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
5 | no_lines_before = STDLIB,LOCALFOLDER
6 | lines_between_types = 1
7 | combine_as_imports = True
8 | force_sort_within_sections = true
9 | order_by_type = True
10 | src_paths = *
11 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # Exclude all third-party libraries and auto-generated files globally
2 | exclude: |
3 | (?x)^(
4 | assets/.+|
5 | Flag-DiT-ImageNet/exps/.+|
6 | Next-DiT-ImageNet/exps/.+|
7 | lumina_t2i/configs/.+|
8 | lumina_next_t2i/configs/.+|
9 | )$
10 | repos:
11 | # Common hooks
12 | - repo: https://github.com/pre-commit/pre-commit-hooks
13 | rev: v4.6.0
14 | hooks:
15 | - id: check-merge-conflict
16 | - id: check-symlinks
17 | - id: detect-private-key
18 | - id: end-of-file-fixer
19 | - id: trailing-whitespace
20 | files: (.*\.(py|bzl|md|rst|c|cc|cxx|cpp|cu|h|hpp|hxx|xpu|kps|cmake|yaml|yml|hook)|BUILD|.*\.BUILD|WORKSPACE|CMakeLists\.txt)$
21 | # For Python files
22 | - repo: https://github.com/PyCQA/isort
23 | rev: 5.13.2
24 | hooks:
25 | - id: isort
26 | - repo: https://github.com/psf/black.git
27 | rev: 24.4.2
28 | hooks:
29 | - id: black
30 | files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
31 | args: [--line-length=120]
32 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='/mnt/petrelfs/share/images/train'
4 |
5 | model=DiT_Llama_600M_patch2
6 | batch_size=256
7 | lr=5e-4
8 | precision=bf16
9 |
10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm
11 | mkdir -p results/"$exp_name"
12 |
13 | torchrun --nproc-per-node=8 train.py \
14 | --model ${model} \
15 | --data_path ${train_data_root} \
16 | --results_dir results/"$exp_name" \
17 | --micro_batch_size 32 \
18 | --global_batch_size ${batch_size} --lr ${lr} \
19 | --data_parallel sdp \
20 | --max_steps 3000000 \
21 | --ckpt_every 10000 --log_every 100 \
22 | --precision ${precision} --grad_precision fp32 --qk_norm \
23 | --snr_type "lognorm" \
24 | 2>&1 | tee -a results/"$exp_name"/output.log
25 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/exps/slurm/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='/mnt/petrelfs/share/images/train'
4 |
5 | model=DiT_Llama_3B_patch2
6 | batch_size=256
7 | lr=5e-4
8 | precision=bf16
9 |
10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm
11 | mkdir -p results/"$exp_name"
12 |
13 | python -u train.py \
14 | --model ${model} \
15 | --data_path ${train_data_root} \
16 | --results_dir results/"$exp_name" \
17 | --micro_batch_size 32 \
18 | --global_batch_size ${batch_size} --lr ${lr} \
19 | --data_parallel sdp \
20 | --max_steps 3000000 \
21 | --ckpt_every 10000 --log_every 100 \
22 | --precision ${precision} --grad_precision fp32 --qk_norm \
23 | --snr_type "lognorm" \
24 | 2>&1 | tee -a results/"$exp_name"/output.log
25 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/exps/slurm/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='/mnt/petrelfs/share/images/train'
4 |
5 | model=DiT_Llama_600M_patch2
6 | batch_size=256
7 | lr=5e-4
8 | precision=bf16
9 |
10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm
11 | mkdir -p results/"$exp_name"
12 |
13 | python -u train.py \
14 | --model ${model} \
15 | --data_path ${train_data_root} \
16 | --results_dir results/"$exp_name" \
17 | --micro_batch_size 32 \
18 | --global_batch_size ${batch_size} --lr ${lr} \
19 | --data_parallel sdp \
20 | --max_steps 3000000 \
21 | --ckpt_every 10000 --log_every 100 \
22 | --precision ${precision} --grad_precision fp32 --qk_norm \
23 | --snr_type "lognorm" \
24 | 2>&1 | tee -a results/"$exp_name"/output.log
25 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/exps/slurm/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='/mnt/petrelfs/share/images/train'
4 |
5 | model=DiT_Llama_7B_patch2
6 | batch_size=256
7 | lr=5e-4
8 | precision=bf16
9 |
10 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_qknorm
11 | mkdir -p results/"$exp_name"
12 |
13 | python -u train.py \
14 | --model ${model} \
15 | --data_path ${train_data_root} \
16 | --results_dir results/"$exp_name" \
17 | --micro_batch_size 32 \
18 | --global_batch_size ${batch_size} --lr ${lr} \
19 | --data_parallel sdp \
20 | --max_steps 3000000 \
21 | --ckpt_every 10000 --log_every 100 \
22 | --precision ${precision} --grad_precision fp32 --qk_norm \
23 | --snr_type "lognorm" \
24 | 2>&1 | tee -a results/"$exp_name"/output.log
25 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/grad_norm.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import fairscale.nn.model_parallel.initialize as fs_init
4 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear
5 | import torch
6 | import torch.distributed as dist
7 | import torch.nn as nn
8 |
9 |
10 | def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]:
11 | ret_dict = {}
12 | for module_name, module in model.named_modules():
13 |
14 | def param_fqn(param_name):
15 | return param_name if module_name == "" else module_name + "." + param_name
16 |
17 | if isinstance(module, ColumnParallelLinear):
18 | ret_dict[param_fqn("weight")] = 0
19 | if module.bias is not None:
20 | ret_dict[param_fqn("bias")] = 0
21 | elif isinstance(module, RowParallelLinear):
22 | ret_dict[param_fqn("weight")] = 1
23 | if module.bias is not None:
24 | ret_dict[param_fqn("bias")] = -1
25 | elif isinstance(module, ParallelEmbedding):
26 | ret_dict[param_fqn("weight")] = 1
27 | else:
28 | for param_name, param in module.named_parameters(recurse=False):
29 | ret_dict[param_fqn(param_name)] = -1
30 | return ret_dict
31 |
32 |
33 | def calculate_l2_grad_norm(
34 | model: nn.Module,
35 | model_parallel_dim_dict: Dict[str, int],
36 | ) -> float:
37 | mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
38 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
39 |
40 | for name, param in model.named_parameters():
41 | if param.grad is None:
42 | continue
43 | name = ".".join(x for x in name.split(".") if not x.startswith("_"))
44 | assert name in model_parallel_dim_dict
45 | if model_parallel_dim_dict[name] < 0:
46 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
47 | else:
48 | mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
49 |
50 | dist.all_reduce(mp_norm_sq)
51 | dist.all_reduce(non_mp_norm_sq)
52 | non_mp_norm_sq /= fs_init.get_model_parallel_world_size()
53 |
54 | return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5
55 |
56 |
57 | def scale_grad(model: nn.Module, factor: float) -> None:
58 | for param in model.parameters():
59 | if param.grad is not None:
60 | param.grad.mul_(factor)
61 |
--------------------------------------------------------------------------------
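The three helpers above are combined for model-parallel-aware gradient clipping in the training loop. A minimal sketch of that pattern, assuming the usual train.py setup; `model`, `loss`, and `optimizer` are placeholders and the clipping threshold is illustrative:

```python
from grad_norm import calculate_l2_grad_norm, get_model_parallel_dim_dict, scale_grad

# Built once after model construction: maps each parameter to its model-parallel shard dim.
mp_dim_dict = get_model_parallel_dim_dict(model)

loss.backward()
grad_norm = calculate_l2_grad_norm(model, mp_dim_dict)  # global L2 norm across all ranks
clip_threshold = 2.0  # illustrative value, not taken from the training scripts
if grad_norm > clip_threshold:
    # Rescale gradients in place rather than using torch's clip_grad_norm_,
    # which would not account for model-parallel sharding.
    scale_grad(model, clip_threshold / grad_norm)
optimizer.step()
optimizer.zero_grad()
```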
/Flag-DiT-ImageNet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import DiT_Llama_3B_patch2, DiT_Llama_7B_patch2, DiT_Llama_600M_patch2
2 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/models/components.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
--------------------------------------------------------------------------------
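Both branches of the try/except expose the same interface: RMSNorm(dim) rescales the input by the reciprocal root mean square of its last dimension and multiplies by a learned weight, i.e. y = x / sqrt(mean(x^2) + eps) * weight. A quick shape check of the fallback path (run from the Flag-DiT-ImageNet directory; input values are random):

```python
import torch

from models.components import RMSNorm  # FusedRMSNorm if apex is installed, vanilla class otherwise

norm = RMSNorm(8)
x = torch.randn(2, 16, 8)
y = norm(x)
assert y.shape == x.shape  # normalization acts over the last (feature) dimension only
```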
/Flag-DiT-ImageNet/parallel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import subprocess
5 | from time import sleep
6 |
7 | import fairscale.nn.model_parallel.initialize as fs_init
8 | import torch
9 | import torch.distributed as dist
10 |
11 |
12 | def _setup_dist_env_from_slurm(args):
13 | while not os.environ.get("MASTER_ADDR", ""):
14 | os.environ["MASTER_ADDR"] = (
15 | subprocess.check_output(
16 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"],
17 | shell=True,
18 | )
19 | .decode()
20 | .strip()
21 | )
22 | sleep(1)
23 | os.environ["MASTER_PORT"] = str(args.master_port)
24 | os.environ["RANK"] = os.environ["SLURM_PROCID"]
25 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"]
26 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
27 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"]
28 |
29 |
30 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None
31 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1
32 |
33 |
34 | def get_local_rank() -> int:
35 | return _LOCAL_RANK
36 |
37 |
38 | def get_local_world_size() -> int:
39 | return _LOCAL_WORLD_SIZE
40 |
41 |
42 | def distributed_init(args):
43 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]):
44 | _setup_dist_env_from_slurm(args)
45 |
46 | dist.init_process_group("nccl")
47 | fs_init.initialize_model_parallel(args.model_parallel_size)
48 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
49 |
50 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE
51 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"])
52 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])
53 |
54 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP
55 | local_ranks, local_world_sizes = [
56 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1)
57 | ]
58 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda"))
59 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda"))
60 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist()
61 | node_ranks = [[0]]
62 | for i in range(1, dist.get_world_size()):
63 | if len(node_ranks[-1]) == local_world_sizes[i - 1]:
64 | node_ranks.append([])
65 | else:
66 | assert local_world_sizes[i] == local_world_sizes[i - 1]
67 | node_ranks[-1].append(i)
68 | for ranks in node_ranks:
69 | group = dist.new_group(ranks)
70 | if dist.get_rank() in ranks:
71 | assert _INTRA_NODE_PROCESS_GROUP is None
72 | _INTRA_NODE_PROCESS_GROUP = group
73 | assert _INTRA_NODE_PROCESS_GROUP is not None
74 |
75 | if min(local_world_sizes) == max(local_world_sizes):
76 | for i in range(get_local_world_size()):
77 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size())))
78 | if i == get_local_rank():
79 | assert _INTER_NODE_PROCESS_GROUP is None
80 | _INTER_NODE_PROCESS_GROUP = group
81 | assert _INTER_NODE_PROCESS_GROUP is not None
82 |
83 |
84 | def get_intra_node_process_group():
85 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized."
86 | return _INTRA_NODE_PROCESS_GROUP
87 |
88 |
89 | def get_inter_node_process_group():
90 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized."
91 | return _INTER_NODE_PROCESS_GROUP
92 |
--------------------------------------------------------------------------------
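distributed_init() reads only two attributes from its argument, master_port (used when falling back to SLURM environment variables) and model_parallel_size (forwarded to fairscale), so any args namespace carrying those two fields works. A minimal launch sketch; the argparse wiring below is illustrative rather than copied from train.py:

```python
import argparse

from parallel import distributed_init, get_local_rank, get_local_world_size

parser = argparse.ArgumentParser()
parser.add_argument("--model_parallel_size", type=int, default=1)
parser.add_argument("--master_port", type=int, default=18181)
args = parser.parse_args()

# Under torchrun the RANK/WORLD_SIZE/MASTER_* variables already exist; on a raw
# SLURM launch they are reconstructed from the SLURM_* variables instead.
distributed_init(args)
print(f"local rank {get_local_rank()} of {get_local_world_size()} on this node")
```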
/Flag-DiT-ImageNet/scripts/run_8gpus.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # run Flag-DiT on a single node
4 |
5 | # run Flag-DiT 600M
6 | bash exps/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
7 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/scripts/slurm/run_32gpus.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # run Flag-DiT on a slurm cluster
4 |
5 | # add your slurm job configuration (partition, nodes, ntasks) here for a 32-GPU run
6 |
7 | # run Flag-DiT 600M
8 | srun bash exps/slurm/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
9 | # run Flag-DiT 3B
10 | srun bash exps/slurm/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
11 | # run Flag-DiT 7B
12 | srun bash exps/slurm/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
13 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/scripts/slurm/run_8gpus.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # run Flag-DiT on a slurm cluster
4 |
5 | # add your slurm job configuration (partition, nodes, ntasks) here for an 8-GPU run
6 |
7 | # run Flag-DiT 600M
8 | srun bash exps/slurm/600M_bs256_lr5e-4_bf16_qknorm_lognorm.sh
9 | # run Flag-DiT 3B
10 | srun bash exps/slurm/3B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
11 | # run Flag-DiT 7B
12 | srun bash exps/slurm/7B_bs256_lr5e-4_bf16_qknorm_lognorm.sh
13 |
--------------------------------------------------------------------------------
/Flag-DiT-ImageNet/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .transport import ModelType, PathType, Sampler, SNRType, Transport, WeightType
2 |
3 |
4 | def create_transport(
5 | path_type="Linear", prediction="velocity", loss_weight=None, train_eps=None, sample_eps=None, snr_type="uniform"
6 | ):
7 | """function for creating Transport object
8 | **Note**: model prediction defaults to velocity
9 | Args:
10 |     - path_type: type of path to use: "Linear" (default), "GVP", or "VP"
11 |     - prediction: model output type: "velocity" (default), "noise", or "score"
12 |     - loss_weight: loss weighting: "velocity", "likelihood", or None (default, unweighted)
13 |     - snr_type: distribution for sampling the diffusion time: "uniform" (default) or "lognorm"
14 |     - train_eps: small epsilon for avoiding instability during training
15 |     - sample_eps: small epsilon for avoiding instability during sampling
16 |
17 | """
18 |
19 | if prediction == "noise":
20 | model_type = ModelType.NOISE
21 | elif prediction == "score":
22 | model_type = ModelType.SCORE
23 | else:
24 | model_type = ModelType.VELOCITY
25 |
26 | if loss_weight == "velocity":
27 | loss_type = WeightType.VELOCITY
28 | elif loss_weight == "likelihood":
29 | loss_type = WeightType.LIKELIHOOD
30 | else:
31 | loss_type = WeightType.NONE
32 |
33 | if snr_type == "lognorm":
34 | snr_type = SNRType.LOGNORM
35 | elif snr_type == "uniform":
36 | snr_type = SNRType.UNIFORM
37 | else:
38 | raise ValueError(f"Invalid snr type {snr_type}")
39 |
40 | path_choice = {
41 | "Linear": PathType.LINEAR,
42 | "GVP": PathType.GVP,
43 | "VP": PathType.VP,
44 | }
45 |
46 | path_type = path_choice[path_type]
47 |
48 | if path_type in [PathType.VP]:
49 | train_eps = 1e-5 if train_eps is None else train_eps
50 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
51 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
52 | train_eps = 1e-3 if train_eps is None else train_eps
53 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
54 | else: # velocity & [GVP, LINEAR] is stable everywhere
55 | train_eps = 0
56 | sample_eps = 0
57 |
58 | # create flow state
59 | state = Transport(
60 | model_type=model_type,
61 | path_type=path_type,
62 | loss_type=loss_type,
63 | train_eps=train_eps,
64 | sample_eps=sample_eps,
65 | snr_type=snr_type,
66 | )
67 |
68 | return state
69 |
--------------------------------------------------------------------------------
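For reference, a minimal training-side usage sketch of create_transport, assuming the SiT-style Transport.training_losses(model, x1, model_kwargs) interface that this package follows; `model`, `latents`, and `labels` are placeholders:

```python
from transport import create_transport

transport = create_transport(
    path_type="Linear",
    prediction="velocity",
    snr_type="lognorm",  # matches --snr_type "lognorm" in the exps/ scripts
)
# Assumed SiT-style call: returns a dict whose "loss" entry is a per-sample tensor.
loss_dict = transport.training_losses(model, latents, model_kwargs=dict(y=labels))
loss = loss_dict["loss"].mean()
```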
/Flag-DiT-ImageNet/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 |
4 | class EasyDict:
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 |
13 | def mean_flat(x):
14 | """
15 | Take the mean over all non-batch dimensions.
16 | """
17 | return th.mean(x, dim=list(range(1, len(x.size()))))
18 |
19 |
20 | def log_state(state):
21 | result = []
22 |
23 | sorted_state = dict(sorted(state.items()))
24 | for key, value in sorted_state.items():
25 | # Check if the value is an instance of a class
26 |         if "<object" in str(value) or "object at" in str(value):
27 |             result.append(f"{key}: [{value.__class__.__name__}]")
28 |         else:
29 |             result.append(f"{key}: {value}")
30 |
31 |     return "\n".join(result)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
5 |
6 | # $\textbf{Lumina-T2X}$ : Transform text to any modality with Flow-based Large Diffusion Transformer
7 |
--------------------------------------------------------------------------------
/assets/audios/a_telephone_bell_rings.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/audios/a_telephone_bell_rings.wav
--------------------------------------------------------------------------------
/assets/audios/a_telephone_bell_rings_gt.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/audios/a_telephone_bell_rings_gt.wav
--------------------------------------------------------------------------------
/assets/compositional_intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/compositional_intro.png
--------------------------------------------------------------------------------
/assets/diverse_config.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/diverse_config.png
--------------------------------------------------------------------------------
/assets/images/demo_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/images/demo_image.png
--------------------------------------------------------------------------------
/assets/images/resolution_extrapolation_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/images/resolution_extrapolation_2.jpg
--------------------------------------------------------------------------------
/assets/lumina-intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/lumina-intro.png
--------------------------------------------------------------------------------
/assets/lumina-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/assets/lumina-logo.png
--------------------------------------------------------------------------------
/lumina_audio/configs/lumina-text2audio.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 3.0e-06
3 | target: models.diffusion.ddpm_audio.CFM
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | mel_dim: 20
13 | mel_length: 256
14 | channels: 0
15 | cond_stage_trainable: True
16 | conditioning_key: crossattn
17 | monitor: val/loss_simple_ema
18 | scale_by_std: true
19 | use_ema: false
20 | scheduler_config:
21 | target: models.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps:
24 | - 10000
25 | cycle_lengths:
26 | - 10000000000000
27 | f_start:
28 | - 1.0e-06
29 | f_max:
30 | - 1.0
31 | f_min:
32 | - 1.0
33 | unet_config:
34 | target: models.diffusion.flag_large_dit.FlagDiTv2
35 | params:
36 | in_channels: 20
37 | context_dim: 1024
38 | hidden_size: 768
39 | num_heads: 32
40 | depth: 16
41 | max_len: 1000
42 |
43 | first_stage_config:
44 | target: models.autoencoder1d.AutoencoderKL
45 | params:
46 | embed_dim: 20
47 | monitor: val/rec_loss
48 | ckpt_path: /path/to/ckpt/maa2/maa2.ckpt
49 | ddconfig:
50 | double_z: true
51 | in_channels: 80
52 | out_ch: 80
53 | z_channels: 20
54 | kernel_size: 5
55 | ch: 384
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | num_res_blocks: 2
61 | attn_layers:
62 | - 3
63 | down_layers:
64 | - 0
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 | cond_stage_config:
69 | target: models.encoders.modules.FrozenCLAPFLANEmbedder
70 | params:
71 | weights_path: /path/to/ckpt/CLAP/CLAP_weights_2022.pth
72 |
--------------------------------------------------------------------------------
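run_audio.sh points --config_path at this file. A hedged sketch of how such a target/params config is typically resolved in ldm-derived codebases; the instantiate_from_config helper living in models/util.py is an assumption, and the ckpt_path/weights_path entries must be edited to real checkpoint locations first:

```python
from omegaconf import OmegaConf

from models.util import instantiate_from_config  # assumption: the usual ldm-style helper

config = OmegaConf.load("configs/lumina-text2audio.yaml")
# Builds the CFM wrapper together with its nested unet/first_stage/cond_stage sub-modules.
model = instantiate_from_config(config.model)
model.eval()
```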
/lumina_audio/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/__init__.py
--------------------------------------------------------------------------------
/lumina_audio/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/lumina_audio/models/diffusion/component.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
--------------------------------------------------------------------------------
/lumina_audio/models/diffusion/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/diffusion/distributions/__init__.py
--------------------------------------------------------------------------------
/lumina_audio/models/diffusion/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.0])
42 | else:
43 | sum_dim = list(range(1, len(self.mean.shape)))
44 | if other is None:
45 |
46 | return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=sum_dim)
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var
51 | - 1.0
52 | - self.logvar
53 | + other.logvar,
54 | dim=sum_dim,
55 | )
56 |
57 | def nll(self, sample, dims=[1, 2, 3]):
58 | if self.deterministic:
59 | return torch.Tensor([0.0])
60 | logtwopi = np.log(2.0 * np.pi)
61 | return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
62 |
63 | def mode(self):
64 | return self.mean
65 |
66 |
67 | def normal_kl(mean1, logvar1, mean2, logvar2):
68 | """
69 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
70 | Compute the KL divergence between two gaussians.
71 | Shapes are automatically broadcasted, so batches can be compared to
72 | scalars, among other use cases.
73 | """
74 | tensor = None
75 | for obj in (mean1, logvar1, mean2, logvar2):
76 | if isinstance(obj, torch.Tensor):
77 | tensor = obj
78 | break
79 | assert tensor is not None, "at least one argument must be a Tensor"
80 |
81 | # Force variances to be Tensors. Broadcasting helps convert scalars to
82 | # Tensors, but it does not work for torch.exp().
83 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
84 |
85 | return 0.5 * (
86 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
87 | )
88 |
--------------------------------------------------------------------------------
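DiagonalGaussianDistribution is the standard ldm-style posterior wrapper: the constructor splits its input into mean and log-variance halves along dim=1. A small self-contained shape check:

```python
import torch

from models.diffusion.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(4, 16, 32)      # 8 mean + 8 log-variance channels packed along dim=1
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()                # reparameterized sample, shape (4, 8, 32)
kl = posterior.kl()                   # KL against a standard normal, shape (4,)
nll = posterior.nll(z, dims=[1, 2])   # negative log-likelihood over non-batch dims, shape (4,)
```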
/lumina_audio/models/diffusion/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int)
15 | )
16 |
17 | for name, p in model.named_parameters():
18 | if p.requires_grad:
19 | # remove as '.'-character is not allowed in buffers
20 | s_name = name.replace(".", "")
21 | self.m_name2s_name.update({name: s_name})
22 | self.register_buffer(s_name, p.clone().detach().data)
23 |
24 | self.collected_params = []
25 |
26 | def forward(self, model):
27 | decay = self.decay
28 |
29 | if self.num_updates >= 0:
30 | self.num_updates += 1
31 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
32 |
33 | one_minus_decay = 1.0 - decay
34 |
35 | with torch.no_grad():
36 | m_param = dict(model.named_parameters())
37 | shadow_params = dict(self.named_buffers())
38 |
39 | for key in m_param:
40 | if m_param[key].requires_grad:
41 | sname = self.m_name2s_name[key]
42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
44 | else:
45 | assert not key in self.m_name2s_name
46 |
47 | def copy_to(self, model):
48 | m_param = dict(model.named_parameters())
49 | shadow_params = dict(self.named_buffers())
50 | for key in m_param:
51 | if m_param[key].requires_grad:
52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
53 | else:
54 | assert not key in self.m_name2s_name
55 |
56 | def store(self, parameters):
57 | """
58 | Save the current parameters for restoring later.
59 | Args:
60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
61 | temporarily stored.
62 | """
63 | self.collected_params = [param.clone() for param in parameters]
64 |
65 | def restore(self, parameters):
66 | """
67 | Restore the parameters stored with the `store` method.
68 | Useful to validate the model with EMA parameters without affecting the
69 | original optimization process. Store the parameters before the
70 | `copy_to` method. After validation (or model saving), use this to
71 | restore the former parameters.
72 | Args:
73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
74 | updated with the stored parameters.
75 | """
76 | for c_param, param in zip(self.collected_params, parameters):
77 | param.data.copy_(c_param.data)
78 |
--------------------------------------------------------------------------------
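LitEma keeps a shadow copy of every trainable parameter as a buffer: calling the module updates the shadow weights, and store/copy_to/restore temporarily swap them in for evaluation. A self-contained toy example:

```python
from torch import nn

from models.diffusion.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

# Training: update the shadow weights after each optimizer step.
for _ in range(3):
    model.weight.data.add_(0.01)
    ema(model)

# Evaluation: swap the EMA weights in, then restore the raw weights afterwards.
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation with the EMA weights ...
ema.restore(model.parameters())
```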
/lumina_audio/models/encoders/CLAP/__init__.py:
--------------------------------------------------------------------------------
1 | from . import audio, clap, utils
2 |
--------------------------------------------------------------------------------
/lumina_audio/models/encoders/CLAP/clap.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from transformers import AutoModel
8 |
9 | from .audio import get_audio_encoder
10 |
11 |
12 | class Projection(nn.Module):
13 | def __init__(self, d_in: int, d_out: int, p: float = 0.5) -> None:
14 | super().__init__()
15 | self.linear1 = nn.Linear(d_in, d_out, bias=False)
16 | self.linear2 = nn.Linear(d_out, d_out, bias=False)
17 | self.layer_norm = nn.LayerNorm(d_out)
18 | self.drop = nn.Dropout(p)
19 |
20 | def forward(self, x: torch.Tensor) -> torch.Tensor:
21 | embed1 = self.linear1(x)
22 | embed2 = self.drop(self.linear2(F.gelu(embed1)))
23 | embeds = self.layer_norm(embed1 + embed2)
24 | return embeds
25 |
26 |
27 | class AudioEncoder(nn.Module):
28 | def __init__(
29 | self,
30 | audioenc_name: str,
31 | d_in: int,
32 | d_out: int,
33 | sample_rate: int,
34 | window_size: int,
35 | hop_size: int,
36 | mel_bins: int,
37 | fmin: int,
38 | fmax: int,
39 | classes_num: int,
40 | ) -> None:
41 | super().__init__()
42 |
43 | audio_encoder = get_audio_encoder(audioenc_name)
44 |
45 | self.base = audio_encoder(sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num, d_in)
46 |
47 | self.projection = Projection(d_in, d_out)
48 |
49 | def forward(self, x):
50 | out_dict = self.base(x)
51 | audio_features, audio_classification_output = out_dict["embedding"], out_dict["clipwise_output"]
52 | projected_vec = self.projection(audio_features)
53 | return projected_vec, audio_classification_output
54 |
55 |
56 | class TextEncoder(nn.Module):
57 | def __init__(self, d_out: int, text_model: str, transformer_embed_dim: int) -> None:
58 | super().__init__()
59 | # if os.path.exists('/apdcephfs/share_1316500/nlphuang/results/Text_to_audio/pretrained/CLAP_AutoModel'):
60 | # root = '/apdcephfs'
61 | # else:
62 | # root = '/apdcephfs_intern'
63 |
64 | self.base = AutoModel.from_pretrained(text_model)
65 | self.projection = Projection(transformer_embed_dim, d_out)
66 |
67 | def forward(self, x):
68 | out = self.base(**x)[0]
69 | out = out[:, 0, :] # get CLS token output
70 | projected_vec = self.projection(out)
71 | return projected_vec
72 |
73 |
74 | class CLAP(nn.Module):
75 | def __init__(
76 | self,
77 | # audio
78 | audioenc_name: str,
79 | sample_rate: int,
80 | window_size: int,
81 | hop_size: int,
82 | mel_bins: int,
83 | fmin: int,
84 | fmax: int,
85 | classes_num: int,
86 | out_emb: int,
87 | # text
88 | text_model: str,
89 | transformer_embed_dim: int,
90 | # common
91 | d_proj: int,
92 | ):
93 | super().__init__()
94 |
95 | self.audio_encoder = AudioEncoder(
96 | audioenc_name, out_emb, d_proj, sample_rate, window_size, hop_size, mel_bins, fmin, fmax, classes_num
97 | )
98 |
99 | self.caption_encoder = TextEncoder(d_proj, text_model, transformer_embed_dim)
100 |
101 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
102 |
103 | def forward(self, audio, text):
104 | audio_embed, _ = self.audio_encoder(audio)
105 | caption_embed = self.caption_encoder(text)
106 |
107 | return caption_embed, audio_embed, self.logit_scale.exp()
108 |
--------------------------------------------------------------------------------
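CLAP.forward returns the projected caption embeddings, the projected audio embeddings, and the exponentiated logit scale. Turning those into contrastive logits follows the standard CLIP recipe rather than code shown in this file, so the following is only a hedged sketch; `clap_model`, `audio_batch`, and `tokenized_text` are placeholders:

```python
import torch.nn.functional as F

caption_embed, audio_embed, logit_scale = clap_model(audio_batch, tokenized_text)
caption_embed = F.normalize(caption_embed, dim=-1)
audio_embed = F.normalize(audio_embed, dim=-1)
logits_per_audio = logit_scale * audio_embed @ caption_embed.t()  # (n_audio, n_text) similarity logits
```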
/lumina_audio/models/encoders/CLAP/config.yml:
--------------------------------------------------------------------------------
1 | # TEXT ENCODER CONFIG
2 | text_model: 'bert-base-uncased'
3 | text_len: 100
4 | transformer_embed_dim: 768
5 | freeze_text_encoder_weights: True
6 |
7 | # AUDIO ENCODER CONFIG
8 | audioenc_name: 'Cnn14'
9 | out_emb: 2048
10 | sampling_rate: 44100
11 | duration: 5
12 | fmin: 50
13 | fmax: 14000
14 | n_fft: 1028
15 | hop_size: 320
16 | mel_bins: 64
17 | window_size: 1024
18 |
19 | # PROJECTION SPACE CONFIG
20 | d_proj: 1024
21 | temperature: 0.003
22 |
23 | # TRAINING AND EVALUATION CONFIG
24 | num_classes: 527
25 | batch_size: 1024
26 | demo: False
27 |
--------------------------------------------------------------------------------
/lumina_audio/models/encoders/CLAP/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 |
4 | import yaml
5 |
6 |
7 | def read_config_as_args(config_path, args=None, is_config_str=False):
8 | return_dict = {}
9 |
10 | if config_path is not None:
11 | if is_config_str:
12 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
13 | else:
14 | with open(config_path, "r") as f:
15 | yml_config = yaml.load(f, Loader=yaml.FullLoader)
16 |
17 | if args != None:
18 | for k, v in yml_config.items():
19 | if k in args.__dict__:
20 | args.__dict__[k] = v
21 | else:
22 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
23 | else:
24 | for k, v in yml_config.items():
25 | return_dict[k] = v
26 |
27 | args = args if args != None else return_dict
28 | return argparse.Namespace(**args)
29 |
--------------------------------------------------------------------------------
/lumina_audio/models/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/encoders/__init__.py
--------------------------------------------------------------------------------
/lumina_audio/models/vocoder/bigvgan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_audio/models/vocoder/bigvgan/__init__.py
--------------------------------------------------------------------------------
/lumina_audio/models/vocoder/bigvgan/alias_free_torch/__init__.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | from .act import *
5 | from .filter import *
6 | from .resample import *
7 |
--------------------------------------------------------------------------------
/lumina_audio/models/vocoder/bigvgan/alias_free_torch/act.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch.nn as nn
5 |
6 | from .resample import DownSample1d, UpSample1d
7 |
8 |
9 | class Activation1d(nn.Module):
10 | def __init__(
11 | self, activation, up_ratio: int = 2, down_ratio: int = 2, up_kernel_size: int = 12, down_kernel_size: int = 12
12 | ):
13 | super().__init__()
14 | self.up_ratio = up_ratio
15 | self.down_ratio = down_ratio
16 | self.act = activation
17 | self.upsample = UpSample1d(up_ratio, up_kernel_size)
18 | self.downsample = DownSample1d(down_ratio, down_kernel_size)
19 |
20 | # x: [B,C,T]
21 | def forward(self, x):
22 | x = self.upsample(x)
23 | x = self.act(x)
24 | x = self.downsample(x)
25 |
26 | return x
27 |
--------------------------------------------------------------------------------
/lumina_audio/models/vocoder/bigvgan/alias_free_torch/filter.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | if "sinc" in dir(torch):
11 | sinc = torch.sinc
12 | else:
13 | # This code is adopted from adefossez's julius.core.sinc under the MIT License
14 | # https://adefossez.github.io/julius/julius/core.html
15 | # LICENSE is in incl_licenses directory.
16 | def sinc(x: torch.Tensor):
17 | """
18 | Implementation of sinc, i.e. sin(pi * x) / (pi * x)
19 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
20 | """
21 | return torch.where(
22 | x == 0, torch.tensor(1.0, device=x.device, dtype=x.dtype), torch.sin(math.pi * x) / math.pi / x
23 | )
24 |
25 |
26 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
27 | # https://adefossez.github.io/julius/julius/lowpass.html
28 | # LICENSE is in incl_licenses directory.
29 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
30 | even = kernel_size % 2 == 0
31 | half_size = kernel_size // 2
32 |
33 | # For kaiser window
34 | delta_f = 4 * half_width
35 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36 | if A > 50.0:
37 | beta = 0.1102 * (A - 8.7)
38 | elif A >= 21.0:
39 | beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
40 | else:
41 | beta = 0.0
42 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43 |
44 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45 | if even:
46 | time = torch.arange(-half_size, half_size) + 0.5
47 | else:
48 | time = torch.arange(kernel_size) - half_size
49 | if cutoff == 0:
50 | filter_ = torch.zeros_like(time)
51 | else:
52 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53 | # Normalize filter to have sum = 1, otherwise we will have a small leakage
54 | # of the constant component in the input signal.
55 | filter_ /= filter_.sum()
56 | filter = filter_.view(1, 1, kernel_size)
57 |
58 | return filter
59 |
60 |
61 | class LowPassFilter1d(nn.Module):
62 | def __init__(
63 | self,
64 | cutoff=0.5,
65 | half_width=0.6,
66 | stride: int = 1,
67 | padding: bool = True,
68 | padding_mode: str = "replicate",
69 | kernel_size: int = 12,
70 | ):
71 | # kernel_size should be even number for stylegan3 setup,
72 | # in this implementation, odd number is also possible.
73 | super().__init__()
74 | if cutoff < -0.0:
75 | raise ValueError("Minimum cutoff must be larger than zero.")
76 | if cutoff > 0.5:
77 | raise ValueError("A cutoff above 0.5 does not make sense.")
78 | self.kernel_size = kernel_size
79 | self.even = kernel_size % 2 == 0
80 | self.pad_left = kernel_size // 2 - int(self.even)
81 | self.pad_right = kernel_size // 2
82 | self.stride = stride
83 | self.padding = padding
84 | self.padding_mode = padding_mode
85 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
86 | self.register_buffer("filter", filter)
87 |
88 | # input [B, C, T]
89 | def forward(self, x):
90 | _, C, _ = x.shape
91 |
92 | if self.padding:
93 | x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
94 | out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
95 |
96 | return out
97 |
--------------------------------------------------------------------------------
/lumina_audio/models/vocoder/bigvgan/alias_free_torch/resample.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | from .filter import LowPassFilter1d, kaiser_sinc_filter1d
8 |
9 |
10 | class UpSample1d(nn.Module):
11 | def __init__(self, ratio=2, kernel_size=None):
12 | super().__init__()
13 | self.ratio = ratio
14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15 | self.stride = ratio
16 | self.pad = self.kernel_size // ratio - 1
17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
20 | self.register_buffer("filter", filter)
21 |
22 | # x: [B, C, T]
23 | def forward(self, x):
24 | _, C, _ = x.shape
25 |
26 | x = F.pad(x, (self.pad, self.pad), mode="replicate")
27 | x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
28 | x = x[..., self.pad_left : -self.pad_right]
29 |
30 | return x
31 |
32 |
33 | class DownSample1d(nn.Module):
34 | def __init__(self, ratio=2, kernel_size=None):
35 | super().__init__()
36 | self.ratio = ratio
37 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
38 | self.lowpass = LowPassFilter1d(
39 | cutoff=0.5 / ratio, half_width=0.6 / ratio, stride=ratio, kernel_size=self.kernel_size
40 | )
41 |
42 | def forward(self, x):
43 | xx = self.lowpass(x)
44 |
45 | return xx
46 |
--------------------------------------------------------------------------------
/lumina_audio/n2s_openai.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | from openai import OpenAI
5 | import pandas as pd
6 | import requests
7 |
8 | openai_key = "your openai api key here"
9 | base_url = ""
10 |
11 |
12 | def get_struct(caption):
13 | if base_url != "":
14 | client = OpenAI(api_key=openai_key, base_url=base_url)
15 | else:
16 | client = OpenAI(api_key=openai_key)
17 |
18 | completion = client.chat.completions.create(
19 | model="gpt-3.5-turbo",
20 | messages=[
21 | {
22 | "role": "user",
23 | "content": f"I want to know what sound might be in the given scene and you need to give me the results in the following format:\
24 | Question: A bird sings on the river in the morning, a cow passes by and scares away the bird.\
25 | Answer: @@@.\
26 | Question: cellphone ringing a variety of tones followed by a loud explosion and fire crackling as a truck engine runs idle\
27 | Answer: @@@\
28 | Question: Train passing followed by short honks three times \
29 | Answer: @\
30 | All indicates the sound exists in the whole scene \
31 | Start, mid, end indicates the time period the sound appear.\
32 | Question: {caption} \
33 | Answer:",
34 | },
35 | ],
36 | temperature=0.0,
37 | )
38 |
39 | return completion.choices[0].message.content
40 |
41 |
42 | def parse_args():
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument("--tsv_path", type=str)
45 | return parser.parse_args()
46 |
47 |
48 | if __name__ == "__main__":
49 | args = parse_args()
50 | tsv_path = args.tsv_path
51 | ori_df = pd.read_csv(tsv_path, sep="\t")
52 | index = 0
53 | end = len(ori_df)
54 | name = os.path.basename(tsv_path)[:-4]
55 | f = open(f"{name}.txt", "w")
56 | newcap_list = []
57 | while index < end - 1:
58 | try:
59 | df = ori_df.iloc[index:end]
60 | for t in df.itertuples():
61 | index = int(t[0])
62 | ori_caption = getattr(t, "caption")
63 |                 struct_cap = get_struct(ori_caption)
64 |                 if "sorry" in struct_cap.lower():
65 |                     struct_cap = f"<{ori_caption.lower()}, all>"
66 |                 newcap_list.append(struct_cap)
67 |                 f.write(f"{index}\t{struct_cap}\n")
68 | f.flush()
69 |         except Exception as e:  # keep retrying from the last processed index
70 |             print(f"error at index {index}: {e}")
71 | f.flush()
72 | f.close()
73 | with open(f"{name}.txt") as f:
74 | lines = f.readlines()
75 | id2cap = {}
76 | for line in lines:
77 | index, caption = line.strip().split("\t")
78 | id2cap[int(index)] = caption
79 |
80 | df = pd.read_csv(f"{name}.tsv", sep="\t")
81 | df["struct_cap"] = df.index.map(id2cap)
82 | df.to_csv(f"{name}_struct.tsv", sep="\t", index=False)
83 |
--------------------------------------------------------------------------------
/lumina_audio/requirements.txt:
--------------------------------------------------------------------------------
1 | soundfile
2 | omegaconf
3 | torchdyn
4 | pytorch_lightning
5 | pytorch_memlab
6 | einops
7 | ninja
8 | torchlibrosa
9 | protobuf
10 | sentencepiece
11 | transformers
12 | gradio
13 |
--------------------------------------------------------------------------------
/lumina_audio/run_audio.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | export HF_ENDPOINT=https://hf-mirror.com
4 |
5 | python -u demo_audio.py \
6 | --ckpt "/path/to/ckpt/audio_generation" \
7 | --vocoder_ckpt "/path/to/ckpt/bigvnat" \
8 | --config_path "configs/lumina-text2audio.yaml" \
9 | --sample_rate 16000
10 |
--------------------------------------------------------------------------------
/lumina_audio/style.css:
--------------------------------------------------------------------------------
1 | .gallery {
2 | text-align: left;
3 | }
4 |
--------------------------------------------------------------------------------
/lumina_music/configs/lumina-text2music.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 3.0e-06
3 | target: models.diffusion.ddpm_audio.CFM
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | mel_dim: 20
13 | mel_length: 256
14 | channels: 0
15 | cond_stage_trainable: True
16 | conditioning_key: crossattn
17 | monitor: val/loss_simple_ema
18 | scale_by_std: true
19 | use_ema: false
20 | scheduler_config:
21 | target: models.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps:
24 | - 10000
25 | cycle_lengths:
26 | - 10000000000000
27 | f_start:
28 | - 1.0e-06
29 | f_max:
30 | - 1.0
31 | f_min:
32 | - 1.0
33 | unet_config:
34 | target: models.diffusion.flag_large_dit.FlagDiTv2
35 | params:
36 | in_channels: 20
37 | context_dim: 1024
38 | hidden_size: 768
39 | num_heads: 32
40 | depth: 16
41 | max_len: 1000
42 |
43 | first_stage_config:
44 | target: models.autoencoder1d.AutoencoderKL
45 | params:
46 | embed_dim: 20
47 | monitor: val/rec_loss
48 | ckpt_path: /path/to/ckpt/maa2/maa2.ckpt
49 | ddconfig:
50 | double_z: true
51 | in_channels: 80
52 | out_ch: 80
53 | z_channels: 20
54 | kernel_size: 5
55 | ch: 384
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | num_res_blocks: 2
61 | attn_layers:
62 | - 3
63 | down_layers:
64 | - 0
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 | cond_stage_config:
69 | target: models.encoders.modules.FrozenFLANEmbedder
70 |
71 | test_dataset:
72 | target: data.joinaudiodataset_struct_sample_anylen.TestManifest
73 | params:
74 | manifest: ./musiccaps_test_16000_struct.tsv
75 | spec_crop_len: 624
76 |
--------------------------------------------------------------------------------
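A minimal sketch (not from the repository) of inspecting the config above with OmegaConf, which is listed in lumina_music/requirements.txt; the actual loading path used by demo_music.py may differ. The paths and printed values simply restate entries visible in the YAML.

from omegaconf import OmegaConf

# load the config shown above (path relative to lumina_music/)
cfg = OmegaConf.load("configs/lumina-text2music.yaml")

print(cfg.model.target)                                       # models.diffusion.ddpm_audio.CFM
print(cfg.model.params.unet_config.params.hidden_size)        # 768
print(cfg.model.params.first_stage_config.params.embed_dim)   # 20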
/lumina_music/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/__init__.py
--------------------------------------------------------------------------------
/lumina_music/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/lumina_music/models/diffusion/component.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
--------------------------------------------------------------------------------
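A short usage sketch for the RMSNorm fallback defined above (not part of the repository); run it alongside the module so `RMSNorm` resolves. With the learnable weight at its initial value of one, each feature vector of the output has root-mean-square close to 1.

import torch

norm = RMSNorm(dim=8, eps=1e-6)
x = torch.randn(2, 4, 8)             # (batch, tokens, dim)
y = norm(x)

# per-vector RMS of the output is ~1.0 when weight == 1
print(y.pow(2).mean(-1).sqrt())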
/lumina_music/models/diffusion/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/diffusion/distributions/__init__.py
--------------------------------------------------------------------------------
/lumina_music/models/diffusion/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.0])
42 | else:
43 | sum_dim = list(range(1, len(self.mean.shape)))
44 | if other is None:
45 |
46 | return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=sum_dim)
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var
51 | - 1.0
52 | - self.logvar
53 | + other.logvar,
54 | dim=sum_dim,
55 | )
56 |
57 | def nll(self, sample, dims=[1, 2, 3]):
58 | if self.deterministic:
59 | return torch.Tensor([0.0])
60 | logtwopi = np.log(2.0 * np.pi)
61 | return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
62 |
63 | def mode(self):
64 | return self.mean
65 |
66 |
67 | def normal_kl(mean1, logvar1, mean2, logvar2):
68 | """
69 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
70 | Compute the KL divergence between two gaussians.
71 | Shapes are automatically broadcasted, so batches can be compared to
72 | scalars, among other use cases.
73 | """
74 | tensor = None
75 | for obj in (mean1, logvar1, mean2, logvar2):
76 | if isinstance(obj, torch.Tensor):
77 | tensor = obj
78 | break
79 | assert tensor is not None, "at least one argument must be a Tensor"
80 |
81 | # Force variances to be Tensors. Broadcasting helps convert scalars to
82 | # Tensors, but it does not work for torch.exp().
83 | logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
84 |
85 | return 0.5 * (
86 | -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
87 | )
88 |
--------------------------------------------------------------------------------
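A small sketch (not from the repository) of how DiagonalGaussianDistribution above is driven: the parameter tensor packs mean and log-variance along dim=1, sample() is the reparameterized draw, and kl() with no argument gives the KL to a standard normal per batch element.

import torch

mean = torch.zeros(2, 4, 8)
logvar = torch.full((2, 4, 8), -1.0)
posterior = DiagonalGaussianDistribution(torch.cat([mean, logvar], dim=1))

z = posterior.sample()    # reparameterized sample, shape (2, 4, 8)
kl = posterior.kl()       # KL to N(0, I), shape (2,)
print(z.shape, kl.shape)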
/lumina_music/models/diffusion/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int)
15 | )
16 |
17 | for name, p in model.named_parameters():
18 | if p.requires_grad:
19 | # remove as '.'-character is not allowed in buffers
20 | s_name = name.replace(".", "")
21 | self.m_name2s_name.update({name: s_name})
22 | self.register_buffer(s_name, p.clone().detach().data)
23 |
24 | self.collected_params = []
25 |
26 | def forward(self, model):
27 | decay = self.decay
28 |
29 | if self.num_updates >= 0:
30 | self.num_updates += 1
31 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
32 |
33 | one_minus_decay = 1.0 - decay
34 |
35 | with torch.no_grad():
36 | m_param = dict(model.named_parameters())
37 | shadow_params = dict(self.named_buffers())
38 |
39 | for key in m_param:
40 | if m_param[key].requires_grad:
41 | sname = self.m_name2s_name[key]
42 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
43 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
44 | else:
45 | assert not key in self.m_name2s_name
46 |
47 | def copy_to(self, model):
48 | m_param = dict(model.named_parameters())
49 | shadow_params = dict(self.named_buffers())
50 | for key in m_param:
51 | if m_param[key].requires_grad:
52 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
53 | else:
54 | assert not key in self.m_name2s_name
55 |
56 | def store(self, parameters):
57 | """
58 | Save the current parameters for restoring later.
59 | Args:
60 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
61 | temporarily stored.
62 | """
63 | self.collected_params = [param.clone() for param in parameters]
64 |
65 | def restore(self, parameters):
66 | """
67 | Restore the parameters stored with the `store` method.
68 | Useful to validate the model with EMA parameters without affecting the
69 | original optimization process. Store the parameters before the
70 | `copy_to` method. After validation (or model saving), use this to
71 | restore the former parameters.
72 | Args:
73 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
74 | updated with the stored parameters.
75 | """
76 | for c_param, param in zip(self.collected_params, parameters):
77 | param.data.copy_(c_param.data)
78 |
--------------------------------------------------------------------------------
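A hedged sketch of the usual LitEma cycle implied by the docstrings above: update the shadow weights after each optimizer step, then store / copy_to / restore around validation. The model and loop here are placeholders, not repository code.

import torch
from torch import nn

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

for _ in range(3):
    model.weight.data.add_(0.01)    # stand-in for an optimizer step
    ema(model)                      # update the shadow (EMA) weights

ema.store(model.parameters())       # keep the raw training weights
ema.copy_to(model)                  # evaluate with EMA weights here
ema.restore(model.parameters())     # back to the raw training weights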
/lumina_music/models/encoders/CLAP/__init__.py:
--------------------------------------------------------------------------------
1 | from . import audio, clap, utils
2 |
--------------------------------------------------------------------------------
/lumina_music/models/encoders/CLAP/config.yml:
--------------------------------------------------------------------------------
1 | # TEXT ENCODER CONFIG
2 | text_model: 'bert-base-uncased'
3 | text_len: 100
4 | transformer_embed_dim: 768
5 | freeze_text_encoder_weights: True
6 |
7 | # AUDIO ENCODER CONFIG
8 | audioenc_name: 'Cnn14'
9 | out_emb: 2048
10 | sampling_rate: 44100
11 | duration: 5
12 | fmin: 50
13 | fmax: 14000
14 | n_fft: 1028
15 | hop_size: 320
16 | mel_bins: 64
17 | window_size: 1024
18 |
19 | # PROJECTION SPACE CONFIG
20 | d_proj: 1024
21 | temperature: 0.003
22 |
23 | # TRAINING AND EVALUATION CONFIG
24 | num_classes: 527
25 | batch_size: 1024
26 | demo: False
27 |
--------------------------------------------------------------------------------
/lumina_music/models/encoders/CLAP/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 |
4 | import yaml
5 |
6 |
7 | def read_config_as_args(config_path, args=None, is_config_str=False):
8 | return_dict = {}
9 |
10 | if config_path is not None:
11 | if is_config_str:
12 | yml_config = yaml.load(config_path, Loader=yaml.FullLoader)
13 | else:
14 | with open(config_path, "r") as f:
15 | yml_config = yaml.load(f, Loader=yaml.FullLoader)
16 |
17 | if args != None:
18 | for k, v in yml_config.items():
19 | if k in args.__dict__:
20 | args.__dict__[k] = v
21 | else:
22 | sys.stderr.write("Ignored unknown parameter {} in yaml.\n".format(k))
23 | else:
24 | for k, v in yml_config.items():
25 | return_dict[k] = v
26 |
27 | args = args if args != None else return_dict
28 | return argparse.Namespace(**args)
29 |
--------------------------------------------------------------------------------
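A small sketch (not from the repository) of reading the CLAP config.yml shown above through read_config_as_args; run it alongside the module above, with the path taken relative to the lumina_music package root.

args = read_config_as_args("models/encoders/CLAP/config.yml")
print(args.text_model, args.sampling_rate)   # bert-base-uncased 44100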
/lumina_music/models/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/encoders/__init__.py
--------------------------------------------------------------------------------
/lumina_music/models/vocoder/bigvgan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/lumina_music/models/vocoder/bigvgan/__init__.py
--------------------------------------------------------------------------------
/lumina_music/models/vocoder/bigvgan/alias_free_torch/__init__.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | from .act import *
5 | from .filter import *
6 | from .resample import *
7 |
--------------------------------------------------------------------------------
/lumina_music/models/vocoder/bigvgan/alias_free_torch/act.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch.nn as nn
5 |
6 | from .resample import DownSample1d, UpSample1d
7 |
8 |
9 | class Activation1d(nn.Module):
10 | def __init__(
11 | self, activation, up_ratio: int = 2, down_ratio: int = 2, up_kernel_size: int = 12, down_kernel_size: int = 12
12 | ):
13 | super().__init__()
14 | self.up_ratio = up_ratio
15 | self.down_ratio = down_ratio
16 | self.act = activation
17 | self.upsample = UpSample1d(up_ratio, up_kernel_size)
18 | self.downsample = DownSample1d(down_ratio, down_kernel_size)
19 |
20 | # x: [B,C,T]
21 | def forward(self, x):
22 | x = self.upsample(x)
23 | x = self.act(x)
24 | x = self.downsample(x)
25 |
26 | return x
27 |
--------------------------------------------------------------------------------
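A usage sketch for Activation1d above (not from the repository). nn.SiLU is only a stand-in activation for this example; BigVGAN pairs Activation1d with its Snake activations from activations.py. The time length is preserved because the 2x anti-aliased upsample is undone by the 2x low-pass downsample.

import torch
import torch.nn as nn

act = Activation1d(activation=nn.SiLU(), up_ratio=2, down_ratio=2)

x = torch.randn(1, 8, 256)   # [B, C, T]
y = act(x)                   # upsample -> activation -> downsample
print(y.shape)               # torch.Size([1, 8, 256])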
/lumina_music/models/vocoder/bigvgan/alias_free_torch/filter.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | if "sinc" in dir(torch):
11 | sinc = torch.sinc
12 | else:
13 | # This code is adopted from adefossez's julius.core.sinc under the MIT License
14 | # https://adefossez.github.io/julius/julius/core.html
15 | # LICENSE is in incl_licenses directory.
16 | def sinc(x: torch.Tensor):
17 | """
18 | Implementation of sinc, i.e. sin(pi * x) / (pi * x)
19 | __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
20 | """
21 | return torch.where(
22 | x == 0, torch.tensor(1.0, device=x.device, dtype=x.dtype), torch.sin(math.pi * x) / math.pi / x
23 | )
24 |
25 |
26 | # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
27 | # https://adefossez.github.io/julius/julius/lowpass.html
28 | # LICENSE is in incl_licenses directory.
29 | def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
30 | even = kernel_size % 2 == 0
31 | half_size = kernel_size // 2
32 |
33 | # For kaiser window
34 | delta_f = 4 * half_width
35 | A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
36 | if A > 50.0:
37 | beta = 0.1102 * (A - 8.7)
38 | elif A >= 21.0:
39 | beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
40 | else:
41 | beta = 0.0
42 | window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
43 |
44 | # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
45 | if even:
46 | time = torch.arange(-half_size, half_size) + 0.5
47 | else:
48 | time = torch.arange(kernel_size) - half_size
49 | if cutoff == 0:
50 | filter_ = torch.zeros_like(time)
51 | else:
52 | filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
53 | # Normalize filter to have sum = 1, otherwise we will have a small leakage
54 | # of the constant component in the input signal.
55 | filter_ /= filter_.sum()
56 | filter = filter_.view(1, 1, kernel_size)
57 |
58 | return filter
59 |
60 |
61 | class LowPassFilter1d(nn.Module):
62 | def __init__(
63 | self,
64 | cutoff=0.5,
65 | half_width=0.6,
66 | stride: int = 1,
67 | padding: bool = True,
68 | padding_mode: str = "replicate",
69 | kernel_size: int = 12,
70 | ):
71 | # kernel_size should be even number for stylegan3 setup,
72 | # in this implementation, odd number is also possible.
73 | super().__init__()
74 | if cutoff < -0.0:
75 | raise ValueError("Minimum cutoff must be larger than zero.")
76 | if cutoff > 0.5:
77 | raise ValueError("A cutoff above 0.5 does not make sense.")
78 | self.kernel_size = kernel_size
79 | self.even = kernel_size % 2 == 0
80 | self.pad_left = kernel_size // 2 - int(self.even)
81 | self.pad_right = kernel_size // 2
82 | self.stride = stride
83 | self.padding = padding
84 | self.padding_mode = padding_mode
85 | filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
86 | self.register_buffer("filter", filter)
87 |
88 | # input [B, C, T]
89 | def forward(self, x):
90 | _, C, _ = x.shape
91 |
92 | if self.padding:
93 | x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
94 | out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
95 |
96 | return out
97 |
--------------------------------------------------------------------------------
/lumina_music/models/vocoder/bigvgan/alias_free_torch/resample.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
2 | # LICENSE is in incl_licenses directory.
3 |
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 | from .filter import LowPassFilter1d, kaiser_sinc_filter1d
8 |
9 |
10 | class UpSample1d(nn.Module):
11 | def __init__(self, ratio=2, kernel_size=None):
12 | super().__init__()
13 | self.ratio = ratio
14 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
15 | self.stride = ratio
16 | self.pad = self.kernel_size // ratio - 1
17 | self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
18 | self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
19 | filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
20 | self.register_buffer("filter", filter)
21 |
22 | # x: [B, C, T]
23 | def forward(self, x):
24 | _, C, _ = x.shape
25 |
26 | x = F.pad(x, (self.pad, self.pad), mode="replicate")
27 | x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
28 | x = x[..., self.pad_left : -self.pad_right]
29 |
30 | return x
31 |
32 |
33 | class DownSample1d(nn.Module):
34 | def __init__(self, ratio=2, kernel_size=None):
35 | super().__init__()
36 | self.ratio = ratio
37 | self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
38 | self.lowpass = LowPassFilter1d(
39 | cutoff=0.5 / ratio, half_width=0.6 / ratio, stride=ratio, kernel_size=self.kernel_size
40 | )
41 |
42 | def forward(self, x):
43 | xx = self.lowpass(x)
44 |
45 | return xx
46 |
--------------------------------------------------------------------------------
/lumina_music/requirements.txt:
--------------------------------------------------------------------------------
1 | soundfile
2 | omegaconf
3 | torchdyn
4 | pytorch_lightning
5 | pytorch_memlab
6 | einops
7 | ninja
8 | torchlibrosa
9 | protobuf
10 | sentencepiece
11 | transformers
12 | gradio
13 |
--------------------------------------------------------------------------------
/lumina_music/run_music.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/bash
2 |
3 | export HF_ENDPOINT=https://hf-mirror.com
4 |
5 | python -u demo_music.py \
6 | --ckpt "/path/to/ckpt/music_generation" \
7 | --vocoder_ckpt "/path/to/ckpt/bigvnat" \
8 | --config_path "configs/lumina-text2music.yaml" \
9 | --sample_rate 16000
10 |
--------------------------------------------------------------------------------
/lumina_next_compositional_generation/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2
2 |
--------------------------------------------------------------------------------
/lumina_next_compositional_generation/models/components.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
--------------------------------------------------------------------------------
/lumina_next_compositional_generation/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .transport import ModelType, PathType, Sampler, SNRType, Transport, WeightType
2 |
3 |
4 | def create_transport(
5 | path_type="Linear",
6 | prediction="velocity",
7 | loss_weight=None,
8 | train_eps=None,
9 | sample_eps=None,
10 | snr_type="uniform",
11 | ):
12 | """function for creating Transport object
13 | **Note**: model prediction defaults to velocity
14 | Args:
15 | - path_type: type of path to use; default to linear
16 | - learn_score: set model prediction to score
17 | - learn_noise: set model prediction to noise
18 | - velocity_weighted: weight loss by velocity weight
19 | - likelihood_weighted: weight loss by likelihood weight
20 | - train_eps: small epsilon for avoiding instability during training
21 | - sample_eps: small epsilon for avoiding instability during sampling
22 | """
23 |
24 | if prediction == "noise":
25 | model_type = ModelType.NOISE
26 | elif prediction == "score":
27 | model_type = ModelType.SCORE
28 | else:
29 | model_type = ModelType.VELOCITY
30 |
31 | if loss_weight == "velocity":
32 | loss_type = WeightType.VELOCITY
33 | elif loss_weight == "likelihood":
34 | loss_type = WeightType.LIKELIHOOD
35 | else:
36 | loss_type = WeightType.NONE
37 |
38 | if snr_type == "lognorm":
39 | snr_type = SNRType.LOGNORM
40 | elif snr_type == "uniform":
41 | snr_type = SNRType.UNIFORM
42 | else:
43 | raise ValueError(f"Invalid snr type {snr_type}")
44 |
45 | path_choice = {
46 | "Linear": PathType.LINEAR,
47 | "GVP": PathType.GVP,
48 | "VP": PathType.VP,
49 | }
50 |
51 | path_type = path_choice[path_type]
52 |
53 | if path_type in [PathType.VP]:
54 | train_eps = 1e-5 if train_eps is None else train_eps
55 | sample_eps = 1e-3 if train_eps is None else sample_eps
56 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
57 | train_eps = 1e-3 if train_eps is None else train_eps
58 | sample_eps = 1e-3 if train_eps is None else sample_eps
59 | else: # velocity & [GVP, LINEAR] is stable everywhere
60 | train_eps = 0
61 | sample_eps = 0
62 |
63 | # create flow state
64 | state = Transport(
65 | model_type=model_type,
66 | path_type=path_type,
67 | loss_type=loss_type,
68 | train_eps=train_eps,
69 | sample_eps=sample_eps,
70 | snr_type=snr_type,
71 | )
72 |
73 | return state
74 |
--------------------------------------------------------------------------------
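A minimal sketch (not from the repository) of calling create_transport above with the settings used for velocity-prediction training on the linear path; the comments restate the option mapping implemented in the function.

transport = create_transport(
    path_type="Linear",       # PathType.LINEAR
    prediction="velocity",    # ModelType.VELOCITY
    loss_weight=None,         # WeightType.NONE
    snr_type="lognorm",       # SNRType.LOGNORM timestep sampling
)
# for velocity prediction on the Linear path, train_eps and sample_eps both resolve to 0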
/lumina_next_compositional_generation/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 |
4 | class EasyDict:
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 |
13 | def mean_flat(x):
14 | """
15 | Take the mean over all non-batch dimensions.
16 | """
17 | return th.mean(x, dim=list(range(1, len(x.size()))))
18 |
19 |
20 | def log_state(state):
21 | result = []
22 |
23 | sorted_state = dict(sorted(state.items()))
24 | for key, value in sorted_state.items():
25 | # Check if the value is an instance of a class
26 | if " Union[str, BytesIO]:
13 | if "s3://" in path:
14 | init_ceph_client_if_needed()
15 | file_bytes = BytesIO(client.get(path))
16 | return file_bytes
17 | else:
18 | return path
19 |
20 |
21 | def init_ceph_client_if_needed():
22 | global client
23 | if client is None:
24 | logger.info(f"initializing ceph client ...")
25 | st = time.time()
26 | from petrel_client.client import Client # noqa
27 |
28 | client = Client("../petreloss.conf")
29 | ed = time.time()
30 | logger.info(f"initialize client cost {ed - st:.2f} s")
31 |
32 |
33 | client = None
34 |
--------------------------------------------------------------------------------
/lumina_next_t2i/grad_norm.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import fairscale.nn.model_parallel.initialize as fs_init
4 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear
5 | import torch
6 | import torch.distributed as dist
7 | import torch.nn as nn
8 |
9 |
10 | def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]:
11 | ret_dict = {}
12 | for module_name, module in model.named_modules():
13 |
14 | def param_fqn(param_name):
15 | return param_name if module_name == "" else module_name + "." + param_name
16 |
17 | if isinstance(module, ColumnParallelLinear):
18 | ret_dict[param_fqn("weight")] = 0
19 | if module.bias is not None:
20 | ret_dict[param_fqn("bias")] = 0
21 | elif isinstance(module, RowParallelLinear):
22 | ret_dict[param_fqn("weight")] = 1
23 | if module.bias is not None:
24 | ret_dict[param_fqn("bias")] = -1
25 | elif isinstance(module, ParallelEmbedding):
26 | ret_dict[param_fqn("weight")] = 1
27 | else:
28 | for param_name, param in module.named_parameters(recurse=False):
29 | ret_dict[param_fqn(param_name)] = -1
30 | return ret_dict
31 |
32 |
33 | def calculate_l2_grad_norm(
34 | model: nn.Module,
35 | model_parallel_dim_dict: Dict[str, int],
36 | ) -> float:
37 | mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
38 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
39 |
40 | for name, param in model.named_parameters():
41 | if param.grad is None:
42 | continue
43 | name = ".".join(x for x in name.split(".") if not x.startswith("_"))
44 | assert name in model_parallel_dim_dict
45 | if model_parallel_dim_dict[name] < 0:
46 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
47 | else:
48 | mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
49 |
50 | dist.all_reduce(mp_norm_sq)
51 | dist.all_reduce(non_mp_norm_sq)
52 | non_mp_norm_sq /= fs_init.get_model_parallel_world_size()
53 |
54 | return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5
55 |
56 |
57 | def scale_grad(model: nn.Module, factor: float) -> None:
58 | for param in model.parameters():
59 | if param.grad is not None:
60 | param.grad.mul_(factor)
61 |
--------------------------------------------------------------------------------
/lumina_next_t2i/imgproc.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from PIL import Image
4 | import numpy as np
5 |
6 |
7 | def center_crop_arr(pil_image, image_size):
8 | """
9 | Center cropping implementation from ADM.
10 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
11 | """
12 | while min(*pil_image.size) >= 2 * image_size:
13 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
14 |
15 | scale = image_size / min(*pil_image.size)
16 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
17 |
18 | arr = np.array(pil_image)
19 | crop_y = (arr.shape[0] - image_size) // 2
20 | crop_x = (arr.shape[1] - image_size) // 2
21 | return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
22 |
23 |
24 | def center_crop(pil_image, crop_size):
25 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
26 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
27 |
28 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
29 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
30 |
31 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
32 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
33 | crop_right = crop_left + crop_size[0]
34 | crop_lower = crop_upper + crop_size[1]
35 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))
36 |
37 |
38 | def var_center_crop(pil_image, crop_size_list, random_top_k=4):
39 | w, h = pil_image.size
40 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
41 | crop_size = random.choice(
42 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
43 | )[1]
44 | return center_crop(pil_image, crop_size)
45 |
46 |
47 | def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
48 | assert max_ratio >= 1.0
49 | crop_size_list = []
50 | wp, hp = num_patches, 1
51 | while wp > 0:
52 | if max(wp, hp) / min(wp, hp) <= max_ratio:
53 | crop_size_list.append((wp * patch_size, hp * patch_size))
54 | if (hp + 1) * wp <= num_patches:
55 | hp += 1
56 | else:
57 | wp -= 1
58 | return crop_size_list
59 |
--------------------------------------------------------------------------------
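A short sketch (not from the repository) of how the helpers above combine for any-resolution training: generate_crop_size_list enumerates patch-aligned target sizes under a token budget, and var_center_crop picks a size whose aspect ratio is close to the image's before cropping. Run it alongside the module so the two functions resolve; the budget and image below are only examples.

from PIL import Image

# a (1024 / 16)^2-patch budget yields sizes such as (1024, 1024) and (2048, 512)
crop_size_list = generate_crop_size_list(num_patches=(1024 // 16) ** 2, patch_size=16)

img = Image.new("RGB", (1280, 800))             # stand-in image
cropped = var_center_crop(img, crop_size_list)  # crop size chosen by aspect-ratio match
print(cropped.size)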
/lumina_next_t2i/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2
2 |
--------------------------------------------------------------------------------
/lumina_next_t2i/models/components.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
--------------------------------------------------------------------------------
/lumina_next_t2i/parallel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import subprocess
5 | from time import sleep
6 |
7 | import fairscale.nn.model_parallel.initialize as fs_init
8 | import torch
9 | import torch.distributed as dist
10 |
11 |
12 | def _setup_dist_env_from_slurm(args):
13 | while not os.environ.get("MASTER_ADDR", ""):
14 | os.environ["MASTER_ADDR"] = (
15 | subprocess.check_output(
16 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"],
17 | shell=True,
18 | )
19 | .decode()
20 | .strip()
21 | )
22 | sleep(1)
23 | os.environ["MASTER_PORT"] = str(args.master_port)
24 | os.environ["RANK"] = os.environ["SLURM_PROCID"]
25 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"]
26 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
27 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"]
28 |
29 |
30 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None
31 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1
32 |
33 |
34 | def get_local_rank() -> int:
35 | return _LOCAL_RANK
36 |
37 |
38 | def get_local_world_size() -> int:
39 | return _LOCAL_WORLD_SIZE
40 |
41 |
42 | def distributed_init(args):
43 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]):
44 | _setup_dist_env_from_slurm(args)
45 |
46 | dist.init_process_group("nccl")
47 | fs_init.initialize_model_parallel(args.model_parallel_size)
48 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
49 |
50 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE
51 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"])
52 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])
53 |
54 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP
55 | local_ranks, local_world_sizes = [
56 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1)
57 | ]
58 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda"))
59 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda"))
60 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist()
61 | node_ranks = [[0]]
62 | for i in range(1, dist.get_world_size()):
63 | if len(node_ranks[-1]) == local_world_sizes[i - 1]:
64 | node_ranks.append([])
65 | else:
66 | assert local_world_sizes[i] == local_world_sizes[i - 1]
67 | node_ranks[-1].append(i)
68 | for ranks in node_ranks:
69 | group = dist.new_group(ranks)
70 | if dist.get_rank() in ranks:
71 | assert _INTRA_NODE_PROCESS_GROUP is None
72 | _INTRA_NODE_PROCESS_GROUP = group
73 | assert _INTRA_NODE_PROCESS_GROUP is not None
74 |
75 | if min(local_world_sizes) == max(local_world_sizes):
76 | for i in range(get_local_world_size()):
77 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size())))
78 | if i == get_local_rank():
79 | assert _INTER_NODE_PROCESS_GROUP is None
80 | _INTER_NODE_PROCESS_GROUP = group
81 | assert _INTER_NODE_PROCESS_GROUP is not None
82 |
83 |
84 | def get_intra_node_process_group():
85 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized."
86 | return _INTRA_NODE_PROCESS_GROUP
87 |
88 |
89 | def get_inter_node_process_group():
90 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized."
91 | return _INTER_NODE_PROCESS_GROUP
92 |
--------------------------------------------------------------------------------
/lumina_next_t2i/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .transport import ModelType, PathType, Sampler, Transport, WeightType
2 |
3 |
4 | def create_transport(
5 | path_type="Linear",
6 | prediction="velocity",
7 | loss_weight=None,
8 | train_eps=None,
9 | sample_eps=None,
10 | snr_type="uniform",
11 | ):
12 | """function for creating Transport object
13 | **Note**: model prediction defaults to velocity
14 | Args:
15 | - path_type: type of path to use; default to linear
16 | - learn_score: set model prediction to score
17 | - learn_noise: set model prediction to noise
18 | - velocity_weighted: weight loss by velocity weight
19 | - likelihood_weighted: weight loss by likelihood weight
20 | - train_eps: small epsilon for avoiding instability during training
21 | - sample_eps: small epsilon for avoiding instability during sampling
22 | """
23 |
24 | if prediction == "noise":
25 | model_type = ModelType.NOISE
26 | elif prediction == "score":
27 | model_type = ModelType.SCORE
28 | else:
29 | model_type = ModelType.VELOCITY
30 |
31 | if loss_weight == "velocity":
32 | loss_type = WeightType.VELOCITY
33 | elif loss_weight == "likelihood":
34 | loss_type = WeightType.LIKELIHOOD
35 | else:
36 | loss_type = WeightType.NONE
37 |
38 | path_choice = {
39 | "Linear": PathType.LINEAR,
40 | "GVP": PathType.GVP,
41 | "VP": PathType.VP,
42 | }
43 |
44 | path_type = path_choice[path_type]
45 |
46 | if path_type in [PathType.VP]:
47 | train_eps = 1e-5 if train_eps is None else train_eps
48 | sample_eps = 1e-3 if train_eps is None else sample_eps
49 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
50 | train_eps = 1e-3 if train_eps is None else train_eps
51 | sample_eps = 1e-3 if train_eps is None else sample_eps
52 | else: # velocity & [GVP, LINEAR] is stable everywhere
53 | train_eps = 0
54 | sample_eps = 0
55 |
56 | # create flow state
57 | state = Transport(
58 | model_type=model_type,
59 | path_type=path_type,
60 | loss_type=loss_type,
61 | train_eps=train_eps,
62 | sample_eps=sample_eps,
63 | snr_type=snr_type,
64 | )
65 |
66 | return state
67 |
--------------------------------------------------------------------------------
/lumina_next_t2i/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 |
4 | class EasyDict:
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 |
13 | def mean_flat(x):
14 | """
15 | Take the mean over all non-batch dimensions.
16 | """
17 | return th.mean(x, dim=list(range(1, len(x.size()))))
18 |
19 |
20 | def log_state(state):
21 | result = []
22 |
23 | sorted_state = dict(sorted(state.items()))
24 | for key, value in sorted_state.items():
25 | # Check if the value is an instance of a class
26 | if " Union[str, BytesIO]:
13 | if "s3://" in path:
14 | init_ceph_client_if_needed()
15 | file_bytes = BytesIO(client.get(path))
16 | return file_bytes
17 | else:
18 | return path
19 |
20 |
21 | def init_ceph_client_if_needed():
22 | global client
23 | if client is None:
24 | logger.info(f"initializing ceph client ...")
25 | st = time.time()
26 | from petrel_client.client import Client # noqa
27 |
28 | client = Client("../petreloss.conf")
29 | ed = time.time()
30 | logger.info(f"initialize client cost {ed - st:.2f} s")
31 |
32 |
33 | client = None
34 |
--------------------------------------------------------------------------------
/lumina_next_t2i_mini/grad_norm.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.distributed as dist
5 | import torch.nn as nn
6 |
7 |
8 | def calculate_l2_grad_norm(model: nn.Module) -> float:
9 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
10 |
11 | for name, param in model.named_parameters():
12 | if param.grad is None:
13 | continue
14 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
15 |
16 | dist.all_reduce(non_mp_norm_sq)
17 |
18 | return non_mp_norm_sq.item() ** 0.5
19 |
20 |
21 | def scale_grad(model: nn.Module, factor: float) -> None:
22 | for param in model.parameters():
23 | if param.grad is not None:
24 | param.grad.mul_(factor)
25 |
--------------------------------------------------------------------------------
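A hedged sketch of how the two helpers above are typically combined for global-norm gradient clipping. It assumes torch.distributed is already initialized (see parallel.py), the model sits on the current CUDA device as in train.py, and the 2.0 threshold is only an example, not a value from the repository.

import torch
import torch.nn as nn

model = nn.Linear(8, 8).cuda()
loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
loss.backward()

grad_norm = calculate_l2_grad_norm(model)    # all-reduced global L2 norm
max_norm = 2.0                               # example clip threshold
if grad_norm > max_norm:
    scale_grad(model, max_norm / grad_norm)  # rescale so the norm equals max_norm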
/lumina_next_t2i_mini/imgproc.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from PIL import Image
4 | import numpy as np
5 |
6 |
7 | def center_crop_arr(pil_image, image_size):
8 | """
9 | Center cropping implementation from ADM.
10 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
11 | """
12 | while min(*pil_image.size) >= 2 * image_size:
13 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
14 |
15 | scale = image_size / min(*pil_image.size)
16 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
17 |
18 | arr = np.array(pil_image)
19 | crop_y = (arr.shape[0] - image_size) // 2
20 | crop_x = (arr.shape[1] - image_size) // 2
21 | return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
22 |
23 |
24 | def center_crop(pil_image, crop_size):
25 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
26 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
27 |
28 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
29 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
30 |
31 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
32 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
33 | crop_right = crop_left + crop_size[0]
34 | crop_lower = crop_upper + crop_size[1]
35 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))
36 |
37 |
38 | def var_center_crop(pil_image, crop_size_list, random_top_k=4):
39 | w, h = pil_image.size
40 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
41 | crop_size = random.choice(
42 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
43 | )[1]
44 | return center_crop(pil_image, crop_size)
45 |
46 |
47 | def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
48 | assert max_ratio >= 1.0
49 | crop_size_list = []
50 | wp, hp = num_patches, 1
51 | while wp > 0:
52 | if max(wp, hp) / min(wp, hp) <= max_ratio:
53 | crop_size_list.append((wp * patch_size, hp * patch_size))
54 | if (hp + 1) * wp <= num_patches:
55 | hp += 1
56 | else:
57 | wp -= 1
58 | return crop_size_list
59 |
--------------------------------------------------------------------------------
/lumina_next_t2i_mini/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .nextdit import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2
2 |
--------------------------------------------------------------------------------
/lumina_next_t2i_mini/models/components.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
--------------------------------------------------------------------------------
/lumina_next_t2i_mini/parallel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import subprocess
5 | from time import sleep
6 |
7 | import torch
8 | import torch.distributed as dist
9 |
10 |
11 | def _setup_dist_env_from_slurm(args):
12 | while not os.environ.get("MASTER_ADDR", ""):
13 | os.environ["MASTER_ADDR"] = (
14 | subprocess.check_output(
15 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"],
16 | shell=True,
17 | )
18 | .decode()
19 | .strip()
20 | )
21 | sleep(1)
22 | os.environ["MASTER_PORT"] = str(args.master_port)
23 | os.environ["RANK"] = os.environ["SLURM_PROCID"]
24 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"]
25 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
26 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"]
27 |
28 |
29 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None
30 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1
31 |
32 |
33 | def get_local_rank() -> int:
34 | return _LOCAL_RANK
35 |
36 |
37 | def get_local_world_size() -> int:
38 | return _LOCAL_WORLD_SIZE
39 |
40 |
41 | def distributed_init(args):
42 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]):
43 | _setup_dist_env_from_slurm(args)
44 |
45 | dist.init_process_group("nccl")
46 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
47 |
48 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE
49 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"])
50 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])
51 |
52 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP
53 | local_ranks, local_world_sizes = [
54 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1)
55 | ]
56 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda"))
57 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda"))
58 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist()
59 | node_ranks = [[0]]
60 | for i in range(1, dist.get_world_size()):
61 | if len(node_ranks[-1]) == local_world_sizes[i - 1]:
62 | node_ranks.append([])
63 | else:
64 | assert local_world_sizes[i] == local_world_sizes[i - 1]
65 | node_ranks[-1].append(i)
66 | for ranks in node_ranks:
67 | group = dist.new_group(ranks)
68 | if dist.get_rank() in ranks:
69 | assert _INTRA_NODE_PROCESS_GROUP is None
70 | _INTRA_NODE_PROCESS_GROUP = group
71 | assert _INTRA_NODE_PROCESS_GROUP is not None
72 |
73 | if min(local_world_sizes) == max(local_world_sizes):
74 | for i in range(get_local_world_size()):
75 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size())))
76 | if i == get_local_rank():
77 | assert _INTER_NODE_PROCESS_GROUP is None
78 | _INTER_NODE_PROCESS_GROUP = group
79 | assert _INTER_NODE_PROCESS_GROUP is not None
80 |
81 |
82 | def get_intra_node_process_group():
83 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized."
84 | return _INTRA_NODE_PROCESS_GROUP
85 |
86 |
87 | def get_inter_node_process_group():
88 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra- and inter-node process groups are not initialized."
89 | return _INTER_NODE_PROCESS_GROUP
90 |
--------------------------------------------------------------------------------
/lumina_next_t2i_mini/scripts/sample.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # Lumina-Next supports any resolution (up to 2K)
5 | # res="1024:024x1024 1536:1536x1536 1664:1664x1664 1792:1792x1792 2048:2048x2048"
6 | res=1024:1024x1024
7 | t=4
8 | cfg=4.0
9 | seed=25
10 | steps=20
11 | solver=midpoint
12 | model_dir=your/model/dir/here
13 | cap_dir=your/caption/dir/here
14 | out_dir=your/output/dir/here
15 | python -u sample.py --ckpt ${model_dir} \
16 | --image_save_path ${out_dir} \
17 | --solver ${solver} --num_sampling_steps ${steps} \
18 | --caption_path ${cap_dir} \
19 | --seed ${seed} \
20 | --resolution ${res} \
21 | --time_shifting_factor ${t} \
22 | --cfg_scale ${cfg} \
23 | --batch_size 1 \
24 | --use_flash_attn True # Set this to False to disable flash attention
25 |
--------------------------------------------------------------------------------
/lumina_next_t2i_mini/scripts/sample_img2img.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # Lumina-Next supports any resolution (up to 2K)
5 | # res="1024:024x1024 1536:1536x1536 1664:1664x1664 1792:1792x1792 2048:2048x2048"
6 | res=1024:1024x1024
7 | t=4
8 | cfg=4.0
9 | seed=25
10 | steps=50
11 | solver=euler
12 | strength=0.6
13 | image_dir=your/image/dir/here
14 | model_dir=your/model/dir/here
15 | cap_dir=your/caption/dir/here
16 | out_dir=your/output/dir/here
17 | python -u sample_img2img.py --ckpt ${model_dir} \
18 | --image_save_path ${out_dir} \
19 | --solver ${solver} --num_sampling_steps ${steps} \
20 | --caption_path ${cap_dir} \
21 | --image ${image_dir} \
22 | --seed ${seed} \
23 | --resolution ${res} \
24 | --strength ${strength} \
25 | --time_shifting_factor ${t} \
26 | --cfg_scale ${cfg} \
27 | --batch_size 1 \
28 | --use_flash_attn True # Set this to False to disable flash attention
--------------------------------------------------------------------------------
/lumina_next_t2i_mini/scripts/sample_sd3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | # SD3 supports up to 1.2K resolution
5 | # res="1024:024x1024 1280:1280x1280"
6 | res=1024:1024x1024
7 | shift=3
8 | cfg=7.0
9 | seed=25
10 | steps=20
11 | solver=midpoint
12 | model_dir=stabilityai/stable-diffusion-3-medium-diffusers
13 | cap_dir=your/caption/dir/here
14 | out_dir=your/output/dir/here
15 | python -u sample_sd3.py --ckpt ${model_dir} \
16 | --image_save_path ${out_dir} \
17 | --solver ${solver} --num_sampling_steps ${steps} \
18 | --caption_path ${cap_dir} \
19 | --seed ${seed} \
20 | --resolution ${res} \
21 | --time_shifting_factor ${shift} \
22 | --cfg_scale ${cfg} \
23 | --batch_size 1 \
24 |
--------------------------------------------------------------------------------
/lumina_t2i/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile = black
3 | line_length = 120
4 | sections = FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
5 | no_lines_before = STDLIB,LOCALFOLDER
6 | lines_between_types = 1
7 | combine_as_imports = True
8 | force_sort_within_sections = true
9 | order_by_type = True
10 |
--------------------------------------------------------------------------------
/lumina_t2i/__init__.py:
--------------------------------------------------------------------------------
1 | from .entry_point import *
2 |
--------------------------------------------------------------------------------
/lumina_t2i/configs/data/JourneyDB.yaml:
--------------------------------------------------------------------------------
1 | META:
2 | -
3 | path: '/path/to/journeyDB_train.json'
4 |
--------------------------------------------------------------------------------
/lumina_t2i/configs/infer/settings.yaml:
--------------------------------------------------------------------------------
1 | - settings:
2 |
3 | model:
4 | ckpt: ""
5 | ckpt_lm: ""
6 | token: ""
7 |
8 | transport:
9 | path_type: "Linear" # option: ["Linear", "GVP", "VP"]
10 | prediction: "velocity" # option: ["velocity", "score", "noise"]
11 | loss_weight: "velocity" # option: [None, "velocity", "likelihood"]
12 | sample_eps: 0.1
13 | train_eps: 0.2
14 |
15 | ode:
16 | atol: 1e-6 # Absolute tolerance
17 | rtol: 1e-3 # Relative tolerance
18 | reverse: false # option: true or false
19 | likelihood: false # option: true or false
20 |
21 | infer:
22 | resolution: "1024x1024" # option: ["1024x1024", "512x2048", "2048x512", "(Extrapolation) 1664x1664", "(Extrapolation) 1024x2048", "(Extrapolation) 2048x1024"]
23 | num_sampling_steps: 60 # range: 1-1000
24 | cfg_scale: 4. # range: 1-20
25 | solver: "euler" # option: ["euler", "dopri5", "dopri8"]
26 | t_shift: 4 # range: 1-20 (int only)
27 | ntk_scaling: true # option: true or false
28 | proportional_attn: true # option: true or false
29 | seed: 0 # range: any number
30 |
--------------------------------------------------------------------------------
/lumina_t2i/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_reader import *
2 | from .dataset import *
3 |
--------------------------------------------------------------------------------
/lumina_t2i/data/data_reader.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import logging
3 | import time
4 | from typing import Union
5 |
6 | from PIL import Image
7 |
8 | Image.MAX_IMAGE_PIXELS = None
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def read_general(path) -> Union[str, BytesIO]:
13 | if "s3://" in path:
14 | init_ceph_client_if_needed()
15 | file_bytes = BytesIO(client.get(path))
16 | return file_bytes
17 | else:
18 | return path
19 |
20 |
21 | def init_ceph_client_if_needed():
22 | global client
23 | if client is None:
24 | logger.info(f"initializing ceph client ...")
25 | st = time.time()
26 | from petrel_client.client import Client # noqa
27 |
28 | client = Client("../petreloss.conf")
29 | ed = time.time()
30 | logger.info(f"initialize client cost {ed - st:.2f} s")
31 |
32 |
33 | client = None
34 |
--------------------------------------------------------------------------------
/lumina_t2i/exps/5B_bs512_lr1e-4_bf16_1024px_sdxlvae.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='configs/data/JourneyDB.yaml'
4 |
5 | model=DiT_Llama_5B_patch2
6 | batch_size=512
7 | lr=1e-4
8 | precision=bf16
9 | image_size=1024
10 | vae=sdxl
11 | init_from=$1
12 | load_str=$2
13 |
14 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}_init${load_str}
15 | mkdir -p results/"$exp_name"
16 |
17 | torchrun --nproc-per-node=8 train.py \
18 | --master_port 18181 \
19 | --model ${model} \
20 | --data_path ${train_data_root} \
21 | --results_dir results/${exp_name} \
22 | --micro_batch_size 2 \
23 | --global_batch_size ${batch_size} --lr ${lr} \
24 | --data_parallel fsdp \
25 | --max_steps 3000000 \
26 | --ckpt_every 2000 --log_every 10 \
27 | --precision ${precision} --grad_precision fp32 --qk_norm \
28 | --image_size ${image_size} \
29 | --init_from $init_from \
30 | --global_seed 3 \
31 | --vae ${vae} \
32 | 2>&1 | tee -a results/"$exp_name"/output.log
33 |
--------------------------------------------------------------------------------
/lumina_t2i/exps/5B_bs512_lr1e-4_bf16_256px_sdxlvae.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='configs/data/JourneyDB.yaml'
4 |
5 | model=DiT_Llama_5B_patch2
6 | batch_size=512
7 | lr=1e-4
8 | precision=bf16
9 | image_size=256
10 | vae=sdxl
11 |
12 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}
13 | mkdir -p results/"$exp_name"
14 |
15 | torchrun --nproc-per-node=8 train.py \
16 | --master_port 18181 \
17 | --model ${model} \
18 | --data_path ${train_data_root} \
19 | --results_dir results/${exp_name} \
20 | --micro_batch_size 16 \
21 | --global_batch_size ${batch_size} --lr ${lr} \
22 | --data_parallel fsdp \
23 | --max_steps 3000000 \
24 | --ckpt_every 20000 --log_every 100 \
25 | --precision ${precision} --grad_precision fp32 --qk_norm \
26 | --image_size ${image_size} \
27 | --vae ${vae} \
28 | 2>&1 | tee -a results/"$exp_name"/output.log
29 |
--------------------------------------------------------------------------------
/lumina_t2i/exps/5B_bs512_lr1e-4_bf16_512px_sdxlvae.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='configs/data/JourneyDB.yaml'
4 |
5 | model=DiT_Llama_5B_patch2
6 | batch_size=512
7 | lr=1e-4
8 | precision=bf16
9 | image_size=512
10 | vae=sdxl
11 | init_from=$1
12 | load_str=$2
13 |
14 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}_init${load_str}
15 | mkdir -p results/"$exp_name"
16 |
17 | torchrun --nproc-per-node=8 train.py \
18 | --master_port 18181 \
19 | --model ${model} \
20 | --data_path ${train_data_root} \
21 | --results_dir results/${exp_name} \
22 | --micro_batch_size 8 \
23 | --global_batch_size ${batch_size} --lr ${lr} \
24 | --data_parallel fsdp \
25 | --max_steps 3000000 \
26 | --ckpt_every 20000 --log_every 100 \
27 | --precision ${precision} --grad_precision fp32 --qk_norm \
28 | --image_size ${image_size} \
29 | --init_from $init_from \
30 | --global_seed 2 \
31 | --vae ${vae} \
32 | 2>&1 | tee -a results/"$exp_name"/output.log
33 |
--------------------------------------------------------------------------------
/lumina_t2i/exps/slurm/5B_bs512_lr1e-4_bf16_1024px_sdxlvae.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='configs/data/JourneyDB.yaml'
4 |
5 | model=DiT_Llama_5B_patch2
6 | batch_size=512
7 | lr=1e-4
8 | precision=bf16
9 | image_size=1024
10 | vae=sdxl
11 | init_from=$1
12 | load_str=$2
13 |
14 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}_init${load_str}
15 | mkdir -p results/"$exp_name"
16 |
17 | python -u train.py \
18 | --master_port 18181 \
19 | --model ${model} \
20 | --data_path ${train_data_root} \
21 | --results_dir results/${exp_name} \
22 | --micro_batch_size 2 \
23 | --global_batch_size ${batch_size} --lr ${lr} \
24 | --data_parallel fsdp \
25 | --max_steps 3000000 \
26 | --ckpt_every 2000 --log_every 10 \
27 | --precision ${precision} --grad_precision fp32 --qk_norm \
28 | --image_size ${image_size} \
29 | --init_from $init_from \
30 | --global_seed 3 \
31 | --vae ${vae} \
32 | 2>&1 | tee -a results/"$exp_name"/output.log
33 |
--------------------------------------------------------------------------------
/lumina_t2i/exps/slurm/5B_bs512_lr1e-4_bf16_256px_sdxlvae.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='configs/data/JourneyDB.yaml'
4 |
5 | model=DiT_Llama_5B_patch2
6 | batch_size=512
7 | lr=1e-4
8 | precision=bf16
9 | image_size=256
10 | vae=sdxl
11 |
12 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}
13 | mkdir -p results/"$exp_name"
14 |
15 | python -u train.py \
16 | --master_port 18181 \
17 | --model ${model} \
18 | --data_path ${train_data_root} \
19 | --results_dir results/${exp_name} \
20 | --micro_batch_size 16 \
21 | --global_batch_size ${batch_size} --lr ${lr} \
22 | --data_parallel fsdp \
23 | --max_steps 3000000 \
24 | --ckpt_every 20000 --log_every 100 \
25 | --precision ${precision} --grad_precision fp32 --qk_norm \
26 | --image_size ${image_size} \
27 | --vae ${vae} \
28 | 2>&1 | tee -a results/"$exp_name"/output.log
29 |
--------------------------------------------------------------------------------
/lumina_t2i/exps/slurm/5B_bs512_lr1e-4_bf16_512px_sdxlvae.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | train_data_root='configs/data/JourneyDB.yaml'
4 |
5 | model=DiT_Llama_5B_patch2
6 | batch_size=512
7 | lr=1e-4
8 | precision=bf16
9 | image_size=512
10 | vae=sdxl
11 | init_from=$1
12 | load_str=$2
13 |
14 | exp_name=${model}_bs${batch_size}_lr${lr}_${precision}_${image_size}px_vae${vae}_init${load_str}
15 | mkdir -p results/"$exp_name"
16 |
17 | python -u train.py \
18 | --master_port 18181 \
19 | --model ${model} \
20 | --data_path ${train_data_root} \
21 | --results_dir results/${exp_name} \
22 | --micro_batch_size 8 \
23 | --global_batch_size ${batch_size} --lr ${lr} \
24 | --data_parallel fsdp \
25 | --max_steps 3000000 \
26 | --ckpt_every 20000 --log_every 100 \
27 | --precision ${precision} --grad_precision fp32 --qk_norm \
28 | --image_size ${image_size} \
29 | --init_from $init_from \
30 | --global_seed 2 \
31 | --vae ${vae} \
32 | 2>&1 | tee -a results/"$exp_name"/output.log
33 |
--------------------------------------------------------------------------------
/lumina_t2i/grad_norm.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import fairscale.nn.model_parallel.initialize as fs_init
4 | from fairscale.nn.model_parallel.layers import ColumnParallelLinear, ParallelEmbedding, RowParallelLinear
5 | import torch
6 | import torch.distributed as dist
7 | import torch.nn as nn
8 |
9 |
10 | def get_model_parallel_dim_dict(model: nn.Module) -> Dict[str, int]:
11 | ret_dict = {}
12 | for module_name, module in model.named_modules():
13 |
14 | def param_fqn(param_name):
15 | return param_name if module_name == "" else module_name + "." + param_name
16 |
17 | if isinstance(module, ColumnParallelLinear):
18 | ret_dict[param_fqn("weight")] = 0
19 | if module.bias is not None:
20 | ret_dict[param_fqn("bias")] = 0
21 | elif isinstance(module, RowParallelLinear):
22 | ret_dict[param_fqn("weight")] = 1
23 | if module.bias is not None:
24 | ret_dict[param_fqn("bias")] = -1
25 | elif isinstance(module, ParallelEmbedding):
26 | ret_dict[param_fqn("weight")] = 1
27 | else:
28 | for param_name, param in module.named_parameters(recurse=False):
29 | ret_dict[param_fqn(param_name)] = -1
30 | return ret_dict
31 |
32 |
33 | def calculate_l2_grad_norm(
34 | model: nn.Module,
35 | model_parallel_dim_dict: Dict[str, int],
36 | ) -> float:
37 | mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
38 | non_mp_norm_sq = torch.tensor(0.0, dtype=torch.float32, device="cuda")
39 |
40 | for name, param in model.named_parameters():
41 | if param.grad is None:
42 | continue
43 | name = ".".join(x for x in name.split(".") if not x.startswith("_"))
44 | assert name in model_parallel_dim_dict
45 | if model_parallel_dim_dict[name] < 0:
46 | non_mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
47 | else:
48 | mp_norm_sq += param.grad.norm(dtype=torch.float32) ** 2
49 |
50 | dist.all_reduce(mp_norm_sq)
51 | dist.all_reduce(non_mp_norm_sq)
52 | non_mp_norm_sq /= fs_init.get_model_parallel_world_size()
53 |
54 | return (mp_norm_sq.item() + non_mp_norm_sq.item()) ** 0.5
55 |
56 |
57 | def scale_grad(model: nn.Module, factor: float) -> None:
58 | for param in model.parameters():
59 | if param.grad is not None:
60 | param.grad.mul_(factor)
61 |
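
A hedged sketch of how the two helpers above are typically combined for global-norm gradient clipping under model parallelism; the `max_norm` value and the wrapper name `clip_grad_norm_` are illustrative, not taken from this repository.

```python
def clip_grad_norm_(model, model_parallel_dim_dict, max_norm=2.0):
    # Global L2 grad norm, reduced across data- and model-parallel ranks.
    grad_norm = calculate_l2_grad_norm(model, model_parallel_dim_dict)
    # Rescale all gradients in place if the norm exceeds the threshold.
    if grad_norm > max_norm:
        scale_grad(model, max_norm / grad_norm)
    return grad_norm
```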
--------------------------------------------------------------------------------
/lumina_t2i/imgproc.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from PIL import Image
4 | import numpy as np
5 |
6 |
7 | def center_crop_arr(pil_image, image_size):
8 | """
9 | Center cropping implementation from ADM.
10 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
11 | """
12 | while min(*pil_image.size) >= 2 * image_size:
13 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
14 |
15 | scale = image_size / min(*pil_image.size)
16 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
17 |
18 | arr = np.array(pil_image)
19 | crop_y = (arr.shape[0] - image_size) // 2
20 | crop_x = (arr.shape[1] - image_size) // 2
21 | return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
22 |
23 |
24 | def center_crop(pil_image, crop_size):
25 | while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
26 | pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
27 |
28 | scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
29 | pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
30 |
31 | crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
32 | crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
33 | crop_right = crop_left + crop_size[0]
34 | crop_lower = crop_upper + crop_size[1]
35 | return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))
36 |
37 |
38 | def var_center_crop(pil_image, crop_size_list, random_top_k=4):
39 | w, h = pil_image.size
40 | rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
41 | crop_size = random.choice(
42 | sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
43 | )[1]
44 | return center_crop(pil_image, crop_size)
45 |
46 |
47 | def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
48 | assert max_ratio >= 1.0
49 | crop_size_list = []
50 | wp, hp = num_patches, 1
51 | while wp > 0:
52 | if max(wp, hp) / min(wp, hp) <= max_ratio:
53 | crop_size_list.append((wp * patch_size, hp * patch_size))
54 | if (hp + 1) * wp <= num_patches:
55 | hp += 1
56 | else:
57 | wp -= 1
58 | return crop_size_list
59 |
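
For illustration, a small sketch of how the crop utilities above fit together; the patch budget, patch size, and input image are placeholders, not values taken from the training configs.

```python
from PIL import Image

# Hypothetical numbers: a budget of 1024 patches of size 32 px yields crop
# sizes from square up to a 4:1 aspect ratio.
crop_sizes = generate_crop_size_list(num_patches=1024, patch_size=32)

img = Image.open("example.jpg").convert("RGB")  # placeholder input image
cropped = var_center_crop(img, crop_sizes)      # picks a crop size close to the image's aspect ratio
```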
--------------------------------------------------------------------------------
/lumina_t2i/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import DiT_Llama_5B_patch2
2 |
--------------------------------------------------------------------------------
/lumina_t2i/models/components.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
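
A minimal usage sketch of the fallback layer above (tensor shapes are illustrative):

```python
import torch

norm = RMSNorm(dim=64)
x = torch.randn(2, 16, 64)   # (batch, sequence, feature) example
y = norm(x)                  # same shape; last dim rescaled to unit RMS, then scaled by weight
assert y.shape == x.shape
```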
--------------------------------------------------------------------------------
/lumina_t2i/parallel.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import subprocess
5 | from time import sleep
6 |
7 | import fairscale.nn.model_parallel.initialize as fs_init
8 | import torch
9 | import torch.distributed as dist
10 |
11 |
12 | def _setup_dist_env_from_slurm(args):
13 | while not os.environ.get("MASTER_ADDR", ""):
14 | os.environ["MASTER_ADDR"] = (
15 | subprocess.check_output(
16 | "sinfo -Nh -n %s | head -n 1 | awk '{print $1}'" % os.environ["SLURM_NODELIST"],
17 | shell=True,
18 | )
19 | .decode()
20 | .strip()
21 | )
22 | sleep(1)
23 | os.environ["MASTER_PORT"] = str(args.master_port)
24 | os.environ["RANK"] = os.environ["SLURM_PROCID"]
25 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NPROCS"]
26 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"]
27 | os.environ["LOCAL_WORLD_SIZE"] = os.environ["SLURM_NTASKS_PER_NODE"]
28 |
29 |
30 | _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP = None, None
31 | _LOCAL_RANK, _LOCAL_WORLD_SIZE = -1, -1
32 |
33 |
34 | def get_local_rank() -> int:
35 | return _LOCAL_RANK
36 |
37 |
38 | def get_local_world_size() -> int:
39 | return _LOCAL_WORLD_SIZE
40 |
41 |
42 | def distributed_init(args):
43 | if any([x not in os.environ for x in ["RANK", "WORLD_SIZE", "MASTER_PORT", "MASTER_ADDR"]]):
44 | _setup_dist_env_from_slurm(args)
45 |
46 | dist.init_process_group("nccl")
47 | fs_init.initialize_model_parallel(args.model_parallel_size)
48 | torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
49 |
50 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE
51 | _LOCAL_RANK = int(os.environ["LOCAL_RANK"])
52 | _LOCAL_WORLD_SIZE = int(os.environ["LOCAL_WORLD_SIZE"])
53 |
54 | global _INTRA_NODE_PROCESS_GROUP, _INTER_NODE_PROCESS_GROUP
55 | local_ranks, local_world_sizes = [
56 | torch.empty([dist.get_world_size()], dtype=torch.long, device="cuda") for _ in (0, 1)
57 | ]
58 | dist.all_gather_into_tensor(local_ranks, torch.tensor(get_local_rank(), device="cuda"))
59 | dist.all_gather_into_tensor(local_world_sizes, torch.tensor(get_local_world_size(), device="cuda"))
60 | local_ranks, local_world_sizes = local_ranks.tolist(), local_world_sizes.tolist()
61 | node_ranks = [[0]]
62 | for i in range(1, dist.get_world_size()):
63 | if len(node_ranks[-1]) == local_world_sizes[i - 1]:
64 | node_ranks.append([])
65 | else:
66 | assert local_world_sizes[i] == local_world_sizes[i - 1]
67 | node_ranks[-1].append(i)
68 | for ranks in node_ranks:
69 | group = dist.new_group(ranks)
70 | if dist.get_rank() in ranks:
71 | assert _INTRA_NODE_PROCESS_GROUP is None
72 | _INTRA_NODE_PROCESS_GROUP = group
73 | assert _INTRA_NODE_PROCESS_GROUP is not None
74 |
75 | if min(local_world_sizes) == max(local_world_sizes):
76 | for i in range(get_local_world_size()):
77 | group = dist.new_group(list(range(i, dist.get_world_size(), get_local_world_size())))
78 | if i == get_local_rank():
79 | assert _INTER_NODE_PROCESS_GROUP is None
80 | _INTER_NODE_PROCESS_GROUP = group
81 | assert _INTER_NODE_PROCESS_GROUP is not None
82 |
83 |
84 | def get_intra_node_process_group():
85 | assert _INTRA_NODE_PROCESS_GROUP is not None, "Intra-node process group is not initialized."
86 | return _INTRA_NODE_PROCESS_GROUP
87 |
88 |
89 | def get_inter_node_process_group():
90 |     assert _INTER_NODE_PROCESS_GROUP is not None, "Inter-node process group is not initialized."
91 | return _INTER_NODE_PROCESS_GROUP
92 |
--------------------------------------------------------------------------------
/lumina_t2i/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | fairscale
3 | accelerate
4 | tensorboard
5 | transformers
6 | gradio
7 | torchdiffeq
8 | click
9 |
--------------------------------------------------------------------------------
/lumina_t2i/transport/__init__.py:
--------------------------------------------------------------------------------
1 | from .transport import ModelType, PathType, Sampler, SNRType, Transport, WeightType
2 |
3 |
4 | def create_transport(
5 | path_type="Linear",
6 | prediction="velocity",
7 | loss_weight=None,
8 | train_eps=None,
9 | sample_eps=None,
10 | snr_type="uniform",
11 | ):
12 | """function for creating Transport object
13 | **Note**: model prediction defaults to velocity
14 | Args:
15 | - path_type: type of path to use; default to linear
16 | - learn_score: set model prediction to score
17 | - learn_noise: set model prediction to noise
18 | - velocity_weighted: weight loss by velocity weight
19 | - likelihood_weighted: weight loss by likelihood weight
20 | - train_eps: small epsilon for avoiding instability during training
21 | - sample_eps: small epsilon for avoiding instability during sampling
22 | """
23 |
24 | if prediction == "noise":
25 | model_type = ModelType.NOISE
26 | elif prediction == "score":
27 | model_type = ModelType.SCORE
28 | else:
29 | model_type = ModelType.VELOCITY
30 |
31 | if loss_weight == "velocity":
32 | loss_type = WeightType.VELOCITY
33 | elif loss_weight == "likelihood":
34 | loss_type = WeightType.LIKELIHOOD
35 | else:
36 | loss_type = WeightType.NONE
37 |
38 | if snr_type == "lognorm":
39 | snr_type = SNRType.LOGNORM
40 | elif snr_type == "uniform":
41 | snr_type = SNRType.UNIFORM
42 | else:
43 | raise ValueError(f"Invalid snr type {snr_type}")
44 |
45 | path_choice = {
46 | "Linear": PathType.LINEAR,
47 | "GVP": PathType.GVP,
48 | "VP": PathType.VP,
49 | }
50 |
51 | path_type = path_choice[path_type]
52 |
53 | if path_type in [PathType.VP]:
54 | train_eps = 1e-5 if train_eps is None else train_eps
55 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
56 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
57 | train_eps = 1e-3 if train_eps is None else train_eps
58 |         sample_eps = 1e-3 if sample_eps is None else sample_eps
59 | else: # velocity & [GVP, LINEAR] is stable everywhere
60 | train_eps = 0
61 | sample_eps = 0
62 |
63 | # create flow state
64 | state = Transport(
65 | model_type=model_type,
66 | path_type=path_type,
67 | loss_type=loss_type,
68 | train_eps=train_eps,
69 | sample_eps=sample_eps,
70 | snr_type=snr_type,
71 | )
72 |
73 | return state
74 |
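
A short, hedged usage sketch of the factory above; the lognorm variant mirrors the `*_lognorm` experiment scripts elsewhere in this repo, though exact training-time arguments may differ.

```python
# Default: velocity prediction on the linear path with uniform time sampling.
transport = create_transport()

# Same path and prediction, but with log-normal time/SNR sampling.
transport_lognorm = create_transport(
    path_type="Linear",
    prediction="velocity",
    snr_type="lognorm",
)
```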
--------------------------------------------------------------------------------
/lumina_t2i/transport/utils.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 |
4 | class EasyDict:
5 | def __init__(self, sub_dict):
6 | for k, v in sub_dict.items():
7 | setattr(self, k, v)
8 |
9 | def __getitem__(self, key):
10 | return getattr(self, key)
11 |
12 |
13 | def mean_flat(x):
14 | """
15 | Take the mean over all non-batch dimensions.
16 | """
17 | return th.mean(x, dim=list(range(1, len(x.size()))))
18 |
19 |
20 | def log_state(state):
21 | result = []
22 |
23 | sorted_state = dict(sorted(state.items()))
24 | for key, value in sorted_state.items():
25 | # Check if the value is an instance of a class
26 | if "=61.0"]
9 | build-backend = "setuptools.build_meta"
10 |
11 | [project]
12 | name = "lumina-t2x" # REQUIRED, is the only field that cannot be marked as dynamic.
13 | version = "1.5.0" # REQUIRED, although can be dynamic
14 | description = "Lumina-T2X is a model for Text to Any Modality Generation"
15 | readme = "README.md"
16 |
17 | requires-python = ">=3.10"
18 |
19 | license = {file = "LICENSE.txt"}
20 |
21 | keywords = ["generation", "multi-modal", "transformer", "aigc", "diffusion"]
22 |
23 | authors = [
24 | {name = "Alpha-VLLM", email = "author@example.com" }
25 | ]
26 |
27 | maintainers = [
28 | {name = "Chris Liu", email = "author@example.com" },
29 | {name = "PommesPeter", email = "xepxa6823@gmail.com" }
30 | ]
31 |
32 | classifiers = [
33 | # How mature is this project? Common values are
34 | # 3 - Alpha
35 | # 4 - Beta
36 | # 5 - Production/Stable
37 | "Development Status :: 3 - Alpha",
38 | "Programming Language :: Python :: 3",
39 | "Intended Audience :: Developers",
40 | "License :: OSI Approved :: Apache Software License",
41 | ]
42 |
43 | dependencies = [
44 | "diffusers",
45 | "fairscale",
46 | "accelerate",
47 | "tensorboard",
48 | "transformers",
49 | "gradio",
50 | "torchdiffeq",
51 | "click"
52 | ]
53 |
54 | [project.optional-dependencies]
55 | dev = ["coverage", "pre-commit", "isort", "black"]
56 | image = ["diffusers", "fairscale", "accelerate", "tensorboard", "transformers", "gradio", "torchdiffeq", "click"
57 | ]
58 | music = ["soundfile", "omegaconf", "torchdyn", "pytorch_lightning", "pytorch_memlab", "einops", "ninja", "torchlibrosa", "protobuf", "sentencepiece", "gradio", "transformers"]
59 | audio = ["soundfile", "omegaconf", "torchdyn", "pytorch_lightning", "pytorch_memlab", "einops", "ninja", "torchlibrosa", "protobuf", "sentencepiece", "gradio", "transformers"]
60 |
61 |
62 | [project.scripts]
63 | lumina = "lumina_t2i:entry_point"
64 | lumina_next = "lumina_next_t2i:entry_point"
65 |
66 | [tool.setuptools.packages.find]
67 | exclude = ["assets*"]
68 |
69 | [tool.wheel]
70 | exclude = ["assets*"]
71 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | fairscale
3 | accelerate
4 | tensorboard
5 | transformers
6 | gradio
7 | torchdiffeq
8 | click
9 |
--------------------------------------------------------------------------------
/visual_anagrams/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/.DS_Store
--------------------------------------------------------------------------------
/visual_anagrams/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Daniel Geng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/visual_anagrams/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include visual_anagrams/views/assets/4x4/*.png
2 | include visual_anagrams/assets/CourierPrime-Regular.ttf
3 |
--------------------------------------------------------------------------------
/visual_anagrams/animate.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | from visual_anagrams.views import get_views
4 | from visual_anagrams.animate import animate_two_view, animate_two_view_motion_blur
5 | from visual_anagrams.views.view_motion import MotionBlurView
6 |
7 |
8 | if __name__ == '__main__':
9 | import argparse
10 | import pickle
11 | from pathlib import Path
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--im_path", required=True, type=str, help='Path to the illusion to animate')
15 | parser.add_argument("--save_video_path", default=None, type=str,
16 | help='Path to save video to. If None, defaults to `im_path`, with extension `.mp4`')
17 | parser.add_argument("--metadata_path", default=None, type=str, help='Path to metadata. If specified, overrides `view` and `prompt` args')
18 | parser.add_argument("--view", default=None, type=str, help='Name of view to use')
19 | parser.add_argument("--prompt_1", default='', nargs='+', type=str,
20 | help='Prompt for first view. Passing multiple will join them with newlines.')
21 | parser.add_argument("--prompt_2", default='', nargs='+', type=str,
22 |                         help='Prompt for second view. Passing multiple will join them with newlines.')
23 | args = parser.parse_args()
24 |
25 |
26 | # Load image to animate
27 | im_path = Path(args.im_path)
28 | im = Image.open(im_path)
29 |
30 | # Get save dir
31 |     save_video_path = args.save_video_path or im_path.with_suffix('.mp4')
32 |     save_video_path = Path(save_video_path)
33 |
34 | # Get prompts and views from metadata
35 | if args.metadata_path is None:
36 | # Join prompts with newlines
37 | prompt_1 = '\n'.join(args.prompt_1)
38 | prompt_2 = '\n'.join(args.prompt_2)
39 |
40 | # Get paths and views
41 | view = get_views([args.view])[0]
42 | else:
43 | with open(args.metadata_path, 'rb') as f:
44 | metadata = pickle.load(f)
45 | view = metadata['views'][1]
46 | m_args = metadata['args']
47 | prompt_1 = f'{m_args.style} {m_args.prompts[0]}'.strip()
48 | prompt_2 = f'{m_args.style} {m_args.prompts[1]}'.strip()
49 |
50 | # Get sizes
51 | im_size = im.size[0]
52 | frame_size = int(im_size * 1.5)
53 |
54 |     if args.metadata_path is not None and any(isinstance(v, MotionBlurView) for v in metadata['views']):
55 | # Animate specifically motion blur views
56 | animate_two_view_motion_blur(
57 | im,
58 | view,
59 | prompt_1,
60 | prompt_2,
61 | save_video_path=save_video_path,
62 | hold_duration=60,
63 | text_fade_duration=10,
64 | transition_duration=2000,
65 | im_size=im_size,
66 | frame_size=frame_size,
67 | )
68 | else:
69 | # Animate all other views
70 | animate_two_view(
71 | im,
72 | view,
73 | prompt_1,
74 | prompt_2,
75 | save_video_path=save_video_path,
76 | hold_duration=120,
77 | text_fade_duration=10,
78 | transition_duration=45,
79 | im_size=im_size,
80 | frame_size=frame_size,
81 | )
--------------------------------------------------------------------------------
/visual_anagrams/environment.yml:
--------------------------------------------------------------------------------
1 | name: visual_anagrams
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.08.22=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.4.4=h6a678d5_0
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=3.0.12=h7f8727e_0
15 | - pip=23.3.1=py39h06a4308_0
16 | - python=3.9.18=h955ad1f_0
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=68.0.0=py39h06a4308_0
19 | - sqlite=3.41.2=h5eee18b_0
20 | - tk=8.6.12=h1ccaba5_0
21 | - tzdata=2023c=h04d1e81_0
22 | - wheel=0.41.2=py39h06a4308_0
23 | - xz=5.4.5=h5eee18b_0
24 | - zlib=1.2.13=h5eee18b_0
25 | - pip:
26 | - accelerate==0.25.0
27 | - certifi==2023.11.17
28 | - charset-normalizer==3.3.2
29 | - diffusers==0.24.0
30 | - einops==0.7.0
31 | - filelock==3.13.1
32 | - fsspec==2023.12.1
33 | - huggingface-hub==0.19.4
34 | - idna==3.6
35 | - imageio==2.33.0
36 | - imageio-ffmpeg==0.4.9
37 | - importlib-metadata==7.0.0
38 | - jinja2==3.1.2
39 | - markupsafe==2.1.3
40 | - mpmath==1.3.0
41 | - networkx==3.2.1
42 | - numpy==1.26.2
43 | - nvidia-cublas-cu12==12.1.3.1
44 | - nvidia-cuda-cupti-cu12==12.1.105
45 | - nvidia-cuda-nvrtc-cu12==12.1.105
46 | - nvidia-cuda-runtime-cu12==12.1.105
47 | - nvidia-cudnn-cu12==8.9.2.26
48 | - nvidia-cufft-cu12==11.0.2.54
49 | - nvidia-curand-cu12==10.3.2.106
50 | - nvidia-cusolver-cu12==11.4.5.107
51 | - nvidia-cusparse-cu12==12.1.0.106
52 | - nvidia-nccl-cu12==2.18.1
53 | - nvidia-nvjitlink-cu12==12.3.101
54 | - nvidia-nvtx-cu12==12.1.105
55 | - packaging==23.2
56 | - pillow==10.1.0
57 | - psutil==5.9.6
58 | - pyyaml==6.0.1
59 | - regex==2023.10.3
60 | - requests==2.31.0
61 | - safetensors==0.4.1
62 | - sentencepiece==0.1.99
63 | - sympy==1.12
64 | - tokenizers==0.15.0
65 | - torch==2.1.1
66 | - torchvision==0.16.1
67 | - tqdm==4.66.1
68 | - transformers==4.35.2
69 | - triton==2.1.0
70 | - typing-extensions==4.8.0
71 | - urllib3==2.1.0
72 | - zipp==3.17.0
73 |
--------------------------------------------------------------------------------
/visual_anagrams/huggingface_login.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import login
2 | login()
--------------------------------------------------------------------------------
/visual_anagrams/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .nextdit import NextDiT_2B_GQA_patch2, NextDiT_2B_patch2, DiT_Llama_2B_GQA_patch2
2 |
--------------------------------------------------------------------------------
/visual_anagrams/models/components.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | try:
7 | from apex.normalization import FusedRMSNorm as RMSNorm
8 | except ImportError:
9 | warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
10 |
11 | class RMSNorm(torch.nn.Module):
12 | def __init__(self, dim: int, eps: float = 1e-6):
13 | """
14 | Initialize the RMSNorm normalization layer.
15 |
16 | Args:
17 | dim (int): The dimension of the input tensor.
18 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
19 |
20 | Attributes:
21 | eps (float): A small value added to the denominator for numerical stability.
22 | weight (nn.Parameter): Learnable scaling parameter.
23 |
24 | """
25 | super().__init__()
26 | self.eps = eps
27 | self.weight = nn.Parameter(torch.ones(dim))
28 |
29 | def _norm(self, x):
30 | """
31 | Apply the RMSNorm normalization to the input tensor.
32 |
33 | Args:
34 | x (torch.Tensor): The input tensor.
35 |
36 | Returns:
37 | torch.Tensor: The normalized tensor.
38 |
39 | """
40 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
41 |
42 | def forward(self, x):
43 | """
44 | Forward pass through the RMSNorm layer.
45 |
46 | Args:
47 | x (torch.Tensor): The input tensor.
48 |
49 | Returns:
50 | torch.Tensor: The output tensor after applying RMSNorm.
51 |
52 | """
53 | output = self._norm(x.float()).type_as(x)
54 | return output * self.weight
55 |
--------------------------------------------------------------------------------
/visual_anagrams/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | # Lumina-Pro Visual Anagrams
7 |
8 | `Lumina-Pro Visual Anagrams` is an implementation of the paper [Visual Anagrams: Generating Multi-View Optical Illusions with Diffusion Models](https://dangeng.github.io/visual_anagrams/) based on `Lumina-Pro`.
9 |
10 |
11 | ## Installation
12 |
13 | Please refer to the `Lumina-Pro` folder.
14 |
15 |
16 | ## Usage
17 |
18 | To generate an illusion, replace `path_to_your_ckpt` in the file `run.sh` with your Lumina-Pro model path and run the following command:
19 | ```bash
20 | bash run.sh
21 | ```
22 | Here is a description of some useful arguments in the script:
23 |
24 | - `--name`: Name for the illusion. Will save samples to `./results/{name}`.
25 | - `--prompts`: A list of prompts for illusions
26 | - `--style`: Optional style prompt to prepend to each of the prompts. For example, could be `"an oil painting of"`. Saves some writing.
27 | - `--views`: A list of views to use. Must match the number of prompts. For a list of views see the `get_views` function in `visual_anagrams/views/__init__.py`. (Note: Only rotation and flip are supported so far.)
28 | - `--num_samples`: Number of illusions to sample.
--------------------------------------------------------------------------------
/visual_anagrams/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # res=1024:1024x1024
4 | # res=2048:2048x2048
5 | # res=2560:2560x2560
6 | res=4096:4096x4096
7 | t=7
8 | cfg=8.0
9 | seed=17
10 | steps=30
11 | model_dir=path_to_your_ckpt
12 | n=10
13 |
14 | CUDA_VISIBLE_DEVICES=1 \
15 | python generate.py --name flip.campfire.man \
16 | --prompts "people at a campfire" "an old man"\
17 | --style "an oil painting of" \
18 | --views identity flip \
19 | --num_samples ${n} \
20 | --num_inference_steps ${steps} \
21 | --ckpt ${model_dir} \
22 | --seed ${seed} \
23 | --resolution ${res} \
24 | --time_shifting_factor ${t} \
25 | --cfg_scale ${cfg} \
26 | --batch_size 1 \
27 | --use_flash_attn True # You can set this to False if you want to disable the flash attention
28 |
--------------------------------------------------------------------------------
/visual_anagrams/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='visual_anagrams',
5 | version='0.1',
6 | packages=find_packages(),
7 | include_package_data=True,
8 | install_requires=[],
9 | )
10 |
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/.DS_Store
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/__init__.py
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/assets/CourierPrime-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/assets/CourierPrime-Regular.ttf
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from PIL import Image
3 | import numpy as np
4 |
5 | from .view_identity import IdentityView
6 | from .view_flip import FlipView
7 | from .view_rotate import Rotate180View, Rotate90CCWView, Rotate90CWView
8 | from .view_negate import NegateView
9 | from .view_skew import SkewView
10 | from .view_patch_permute import PatchPermuteView
11 | from .view_jigsaw import JigsawView
12 | from .view_inner_circle import InnerCircleView, InnerCircleViewFailure
13 | from .view_square_hinge import SquareHingeView
14 | from .view_blur import BlurViewFailure
15 | from .view_white_balance import WhiteBalanceViewFailure
16 | from .view_hybrid import HybridLowPassView, HybridHighPassView, \
17 | TripleHybridHighPassView, TripleHybridLowPassView, \
18 | TripleHybridMediumPassView
19 | from .view_color import ColorView, GrayscaleView
20 | from .view_motion import MotionBlurResView, MotionBlurView
21 | from .view_scale import ScaleView
22 |
23 | VIEW_MAP = {
24 | 'identity': IdentityView,
25 | 'flip': FlipView,
26 | 'rotate_cw': Rotate90CWView,
27 | 'rotate_ccw': Rotate90CCWView,
28 | 'rotate_180': Rotate180View,
29 | 'negate': NegateView,
30 | 'skew': SkewView,
31 | 'patch_permute': PatchPermuteView,
32 | 'pixel_permute': PatchPermuteView,
33 | 'jigsaw': JigsawView,
34 | 'inner_circle': InnerCircleView,
35 | 'square_hinge': SquareHingeView,
36 | 'inner_circle_failure': InnerCircleViewFailure,
37 | 'blur_failure': BlurViewFailure,
38 | 'white_balance_failure': WhiteBalanceViewFailure,
39 | 'low_pass': HybridLowPassView,
40 | 'high_pass': HybridHighPassView,
41 | 'triple_low_pass': TripleHybridLowPassView,
42 | 'triple_medium_pass': TripleHybridMediumPassView,
43 | 'triple_high_pass': TripleHybridHighPassView,
44 | 'grayscale': GrayscaleView,
45 | 'color': ColorView,
46 | 'motion': MotionBlurView,
47 | 'motion_res': MotionBlurResView,
48 | 'scale': ScaleView,
49 | }
50 |
51 | def get_anagrams_views(view_names, view_args=None):
52 | '''
53 | Bespoke function to get views (just to make command line usage easier)
54 | '''
55 |
56 | views = []
57 | if view_args is None:
58 | view_args = [None for _ in view_names]
59 |
60 | for view_name, view_arg in zip(view_names, view_args):
61 | if view_name == 'patch_permute':
62 | args = [8 if view_arg is None else int(view_arg)]
63 | elif view_name == 'pixel_permute':
64 | args = [64 if view_arg is None else int(view_arg)]
65 | elif view_name == 'skew':
66 | args = [1.5 if view_arg is None else float(view_arg)]
67 | elif view_name in ['low_pass', 'high_pass']:
68 | args = [2.0 if view_arg is None else float(view_arg)]
69 | elif view_name in ['scale']:
70 | args = [0.5 if view_arg is None else float(view_arg)]
71 | else:
72 | args = []
73 |
74 | view = VIEW_MAP[view_name](*args)
75 | views.append(view)
76 |
77 | return views
78 |
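
A small usage sketch of the helper above; view names must be keys of `VIEW_MAP`, and the per-view argument shown is illustrative:

```python
# Two views over the same image: identity plus a 180-degree rotation.
views = get_anagrams_views(['identity', 'rotate_180'])

# Per-view arguments are optional, e.g. an 8x8 patch permutation.
views = get_anagrams_views(['identity', 'patch_permute'], view_args=[None, 8])
```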
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_1024.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_256.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_corner_64.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_1024.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_256.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge1_64.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_1024.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_256.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_edge2_64.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_1024.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_256.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Alpha-VLLM/Lumina-T2X/1c606962f95899da711633ee3a333d21c753e2d9/visual_anagrams/visual_anagrams/views/assets/4x4/4x4_inner_64.png
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/jigsaw_helpers.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from PIL import Image
3 | import numpy as np
4 |
5 | def get_jigsaw_pieces(size):
6 | '''
7 | Load all pieces of the 4x4 jigsaw puzzle.
8 |
9 | size (int) :
10 | Should be 64, 256, or 1024 indicating side length of jigsaw puzzle
11 | '''
12 |
13 | # Location of pieces
14 | piece_dir = Path(__file__).parent / 'assets'
15 |
16 | # Helper function to load pieces as np arrays
17 | def load_pieces(path):
18 | '''
19 | Load a piece, from the given path, as a binary numpy array.
20 |         Return an array stacking the "base" piece and its three rotations (four orientations in total).
21 | '''
22 | piece = Image.open(path)
23 | piece = np.array(piece)[:,:,0] // 255
24 | pieces = np.stack([np.rot90(piece, k=-i) for i in range(4)])
25 | return pieces
26 |
27 | # Load pieces and rotate to get 16 pieces, and cat
28 | pieces_corner = load_pieces(piece_dir / f'4x4/4x4_corner_{size}.png')
29 | pieces_inner = load_pieces(piece_dir / f'4x4/4x4_inner_{size}.png')
30 | pieces_edge1 = load_pieces(piece_dir / f'4x4/4x4_edge1_{size}.png')
31 | pieces_edge2 = load_pieces(piece_dir / f'4x4/4x4_edge2_{size}.png')
32 | pieces = np.concatenate([pieces_corner, pieces_inner, pieces_edge1, pieces_edge2])
33 |
34 | return pieces
35 |
36 |
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_base.py:
--------------------------------------------------------------------------------
1 | class BaseView:
2 | '''
3 | BaseView class, from which all views inherit. Implements the
4 | following functions:
5 | '''
6 |
7 | def __init__(self):
8 | pass
9 |
10 | def view(self, im):
11 | '''
12 | Apply transform to an image.
13 |
14 | im (`torch.tensor`):
15 | For stage 1: Tensor of shape (3, H, W) representing a noisy image
16 | OR
17 | For stage 2: Tensor of shape (6, H, W) representing a noisy image
18 | concatenated with an upsampled conditioning image from stage 1
19 | '''
20 | raise NotImplementedError()
21 |
22 | def inverse_view(self, noise):
23 | '''
24 | Apply inverse transform to noise estimates.
25 | Because DeepFloyd estimates the variance in addition to
26 | the noise, this function must apply the inverse to the
27 | variance as well.
28 |
29 | noise (`torch.tensor`):
30 | Tensor of shape (6, H, W) representing the noise estimate
31 | (first three channel dims) and variance estimates (last
32 | three channel dims)
33 | '''
34 | raise NotImplementedError()
35 |
36 | def make_frame(self, im, t):
37 | '''
38 | Make a frame, transitioning linearly from the identity view (t=0)
39 | to this view (t=1)
40 |
41 | im (`PIL.Image`):
42 | A PIL Image of the illusion
43 |
44 | t (float):
45 | A float in [0,1] indicating time in the animation. Should start
46 | at the identity view at t=0, and continuously transition to the
47 | view at t=1.
48 | '''
49 | raise NotImplementedError()
50 |
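
To make the interface concrete, here is a minimal hypothetical subclass (not part of the package, assumed to live alongside `view_base.py` in the views package) that flips an image vertically; it mirrors the `FlipView` implementation found in `view_flip.py`.

```python
import torch

from .view_base import BaseView


class UpsideDownView(BaseView):
    '''Hypothetical example view: a vertical flip, which is its own inverse.'''

    def view(self, im):
        # im: (C, H, W) tensor; flip along the height dimension.
        return torch.flip(im, dims=[1])

    def inverse_view(self, noise):
        # A flip is an involution, so the same flip also inverts the
        # noise (and variance) estimates.
        return torch.flip(noise, dims=[1])
```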
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_blur.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | import torch
5 | import torchvision.transforms.functional as TF
6 |
7 | from .view_base import BaseView
8 |
9 |
10 | class BlurViewFailure(BaseView):
11 | '''
12 | A failing blur view, which blurs an image, in an attempt
13 | to synthesize hybrid images.
14 | '''
15 | def __init__(self, factor=8):
16 | self.factor = factor
17 |
18 | def make_frame(self, im, t):
19 | im_size = im.size[0]
20 | frame_size = int(im_size * 1.5)
21 | new_size = int( im_size / (1 + (self.factor - 1) * t) )
22 |
23 | # Convert to tensor
24 | im = torch.tensor(np.array(im) / 255.).permute(2,0,1)
25 |
26 | # Resize to new size
27 | im = TF.resize(im, new_size)
28 |
29 | # Convert back to PIL
30 | im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8))
31 |
32 | # Paste on to canvas
33 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
34 | frame.paste(im, ((frame_size - new_size) // 2, (frame_size - new_size) // 2))
35 |
36 | return frame
37 |
38 | def view(self, im):
39 | im_size = im.shape[-1]
40 |
41 | # Downsample then upsample to "blur"
42 | im_small = TF.resize(im, im_size // self.factor)
43 | im_blur = TF.resize(im_small, im_size)
44 |
45 | return im_blur
46 |
47 | def inverse_view(self, noise):
48 | # The transform is technically uninvertible, so just do pass through
49 | return noise
50 |
51 |
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_color.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | import torch
5 |
6 | from .view_base import BaseView
7 |
8 | def make_frame_color(im, t):
9 | im_size = im.size[0]
10 | frame_size = int(im_size * 1.5)
11 |
12 | # Convert to tensor
13 | im = torch.tensor(np.array(im) / 255.).permute(2,0,1)
14 |
15 | # Extract color and greyscale components
16 | im_grey = im.clone()
17 | im_grey[:] = im.mean(dim=0, keepdim=True)
18 | im_color = im - im_grey
19 |
20 | # Take linear interpolation
21 | im = im_grey + t * im_color
22 |
23 | # Convert back to PIL
24 | im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8))
25 |
26 | # Paste on to canvas
27 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
28 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
29 |
30 | return frame
31 |
32 | class GrayscaleView(BaseView):
33 | def __init__(self):
34 | pass
35 |
36 | def make_frame(self, im, t):
37 | return make_frame_color(im, t)
38 |
39 | def view(self, im):
40 | return im
41 |
42 | def save_view(self, im):
43 | im = torch.stack([im.mean(0)] * 3)
44 | return im
45 |
46 | def inverse_view(self, noise):
47 | # Get grayscale component by averaging color channels
48 | noise[:3] = torch.stack([noise[:3].mean(0)] * 3)
49 | return noise
50 |
51 |
52 | class ColorView(BaseView):
53 | def __init__(self):
54 | pass
55 |
56 | def make_frame(self, im, t):
57 | return make_frame_color(im, t)
58 |
59 | def view(self, im):
60 | return im
61 |
62 | def inverse_view(self, noise):
63 | # Get color component by taking residual
64 | noise[:3] = noise[:3] - torch.stack([noise[:3].mean(0)] * 3)
65 | return noise
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_flip.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | import torch
4 |
5 | from .view_base import BaseView
6 |
7 | class FlipView(BaseView):
8 | def __init__(self):
9 | pass
10 |
11 | def view(self, im):
12 | return torch.flip(im, [1])
13 |
14 | def inverse_view(self, noise):
15 | return torch.flip(noise, [1])
16 |
17 | def make_frame(self, im, t):
18 | im_size = im.size[0]
19 | frame_size = int(im_size * 1.5)
20 | theta = -t * 180
21 |
22 | # TODO: Technically not a flip, change this to a homography later
23 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
24 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
25 | frame = frame.rotate(theta,
26 | resample=Image.Resampling.BILINEAR,
27 | expand=False,
28 | fillcolor=(255,255,255))
29 |
30 | return frame
31 |
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_identity.py:
--------------------------------------------------------------------------------
1 | from .view_base import BaseView
2 |
3 | class IdentityView(BaseView):
4 | def __init__(self):
5 | pass
6 |
7 | def view(self, im):
8 | return im
9 |
10 | def inverse_view(self, noise):
11 | return noise
12 |
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_motion.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from .view_base import BaseView
8 |
9 | def make_frame_motion(im, t):
10 | im_size = im.size[0]
11 | factor = im_size / 64 / 4.0
12 | frame_size = int(im_size * 1.5)
13 | vel = 20
14 | amp = 29 * factor / 2 # 29 @ 256, 29 * 4 @ 1024, need to divide by 2 b/c amp
15 |
16 | # Triangular wave
17 | offset = int(amp * 2 / np.pi * np.arcsin(np.sin(2 * np.pi * vel * t)))
18 |
19 | # Paste on to canvas
20 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
21 | frame.paste(im, (offset + (frame_size - im_size) // 2, offset + (frame_size - im_size) // 2))
22 |
23 | return frame
24 |
25 | class MotionBlurView(BaseView):
26 | def __init__(self, size=7):
27 | self.size = size
28 |
29 | def make_frame(self, im, t):
30 | return make_frame_motion(im, t)
31 |
32 | def view(self, im):
33 | return im
34 |
35 | def inverse_view(self, noise):
36 | c, h, w = noise.shape
37 | factor = h // 64 # Account for image size
38 |
39 | # Make kernel on the fly
40 | size = self.size * factor
41 | size = size + ((factor + 1) % 2) # Make sure it's odd
42 | self.K = torch.eye(size)[None, None] / size
43 | self.K = self.K.to(noise.dtype).to(noise.device)
44 |
45 | # Apply kernel to each channel independently
46 | noise[:3] = torch.cat([F.conv2d(noise[i][None], self.K, padding=size//2) for i in range(3)])
47 |
48 | return noise
49 |
50 | def save_view(self, im):
51 | c, h, w = im.shape
52 | factor = h // 64 # Account for image size
53 |
54 | # Make kernel on the fly
55 | size = self.size * factor
56 | size = size + ((factor + 1) % 2) # Make sure it's odd
57 | self.K = torch.eye(size)[None, None] / size
58 | self.K = self.K.to(im.dtype).to(im.device)
59 |
60 | # Apply kernel to each channel independently
61 | im = torch.cat([F.conv2d(im[i][None], self.K, padding=size//2) for i in range(3)])
62 |
63 | return im
64 |
65 |
66 |
67 | class MotionBlurResView(BaseView):
68 | def __init__(self, size=7):
69 | self.size = size
70 |
71 | def make_frame(self, im, t):
72 | return make_frame_motion(im, t)
73 |
74 | def view(self, im):
75 | return im
76 |
77 | def inverse_view(self, noise):
78 | c, h, w = noise.shape
79 | factor = h // 64 # Account for image size
80 |
81 | # Make kernel on the fly
82 | size = self.size * factor
83 | size = size + ((factor + 1) % 2) # Make sure it's odd
84 | self.K = torch.eye(size)[None, None] / size
85 | self.K = self.K.to(noise.dtype).to(noise.device)
86 |
87 | # Apply kernel to each channel independently, and take residual
88 | noise[:3] = noise[:3] - torch.cat([F.conv2d(noise[i][None], self.K, padding=size//2) for i in range(3)])
89 |
90 | return noise
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_negate.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | import torch
5 |
6 | from .view_base import BaseView
7 |
8 | class NegateView(BaseView):
9 | def __init__(self):
10 | pass
11 |
12 | def view(self, im):
13 | return -im
14 |
15 | def inverse_view(self, noise):
16 | '''
17 | Negating the variance estimate is "weird" so just don't do it.
18 | This hack seems to work just fine
19 | '''
20 | invert_mask = torch.ones_like(noise)
21 | invert_mask[:3] = -1
22 | return noise * invert_mask
23 |
24 | def make_frame(self, im, t):
25 | im_size = im.size[0]
26 | frame_size = int(im_size * 1.5)
27 |
28 | # map t from [0, 1] -> [1, -1]
29 | t = 1 - t
30 | t = t * 2 - 1
31 |
32 | # Interpolate from pixels from [0, 1] to [1, 0]
33 | im = np.array(im) / 255.
34 | im = ((2 * im - 1) * t + 1) / 2.
35 | im = Image.fromarray((im * 255.).astype(np.uint8))
36 |
37 | # Paste on to canvas
38 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
39 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
40 |
41 | return frame
42 |
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_rotate.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | import torchvision.transforms.functional as TF
4 | import torch
5 | from torchvision.transforms import InterpolationMode
6 |
7 | from .view_base import BaseView
8 |
9 |
10 | class Rotate90CWView(BaseView):
11 | def __init__(self):
12 | pass
13 |
14 | def view(self, im):
15 | # TODO: Is nearest-exact better?
16 | # return TF.rotate(im, -90, interpolation=InterpolationMode.NEAREST) # clockwise 90
17 | return torch.rot90(im, -1, dims=[1, 2])
18 |
19 | def inverse_view(self, noise):
20 | # return TF.rotate(noise, 90, interpolation=InterpolationMode.NEAREST) # counter-clockwise 90
21 | return torch.rot90(noise, 1, dims=[1, 2])
22 |
23 | def make_frame(self, im, t):
24 | im_size = im.size[0]
25 | frame_size = int(im_size * 1.5)
26 | theta = t * -90
27 |
28 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
29 | centered_loc = (frame_size - im_size) // 2
30 | frame.paste(im, (centered_loc, centered_loc))
31 | frame = frame.rotate(theta,
32 | resample=Image.Resampling.BILINEAR,
33 | expand=False,
34 | fillcolor=(255,255,255))
35 |
36 | return frame
37 |
38 |
39 | class Rotate90CCWView(BaseView):
40 | def __init__(self):
41 | pass
42 |
43 | def view(self, im):
44 | # TODO: Is nearest-exact better?
45 | # return TF.rotate(im, 90, interpolation=InterpolationMode.NEAREST)
46 | return torch.rot90(im, 1, dims=[1, 2])
47 |
48 | def inverse_view(self, noise):
49 | # return TF.rotate(noise, -90, interpolation=InterpolationMode.NEAREST)
50 | return torch.rot90(noise, -1, dims=[1, 2])
51 |
52 | def make_frame(self, im, t):
53 | im_size = im.size[0]
54 | frame_size = int(im_size * 1.5)
55 | theta = t * 90
56 |
57 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
58 | centered_loc = (frame_size - im_size) // 2
59 | frame.paste(im, (centered_loc, centered_loc))
60 | frame = frame.rotate(theta,
61 | resample=Image.Resampling.BILINEAR,
62 | expand=False,
63 | fillcolor=(255,255,255))
64 |
65 | return frame
66 |
67 |
68 | class Rotate180View(BaseView):
69 | def __init__(self):
70 | pass
71 |
72 | def view(self, im):
73 | # TODO: Is nearest-exact better?
74 | # return TF.rotate(im, 180, interpolation=InterpolationMode.NEAREST)
75 | return torch.rot90(im, 2, dims=[1, 2])
76 |
77 | def inverse_view(self, noise):
78 | # return TF.rotate(noise, -180, interpolation=InterpolationMode.NEAREST)
79 | return torch.rot90(noise, -2, dims=[1, 2])
80 |
81 | def make_frame(self, im, t):
82 | im_size = im.size[0]
83 | frame_size = int(im_size * 1.5)
84 | theta = t * 180
85 |
86 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
87 | centered_loc = (frame_size - im_size) // 2
88 | frame.paste(im, (centered_loc, centered_loc))
89 | frame = frame.rotate(theta,
90 | resample=Image.Resampling.BILINEAR,
91 | expand=False,
92 | fillcolor=(255,255,255))
93 |
94 | return frame
95 |
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_scale.py:
--------------------------------------------------------------------------------
1 | from .view_base import BaseView
2 |
3 | class ScaleView(BaseView):
4 | def __init__(self, scale=0.5):
5 | self.scale = scale
6 |
7 | def view(self, im):
8 | return im
9 |
10 | def inverse_view(self, noise):
11 | noise[:3] = self.scale * noise[:3]
12 | return noise
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_skew.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | import torch
5 |
6 | from .view_base import BaseView
7 |
8 |
9 | class SkewView(BaseView):
10 | def __init__(self, skew_factor=1.5):
11 | self.skew_factor = skew_factor
12 |
13 | def skew_image(self, im, skew_factor):
14 | '''
15 | Roll each column of the image by increasing displacements.
16 | This is a permutation of pixels
17 | '''
18 |
19 | # Params
20 | c,h,w = im.shape
21 | h_center = h//2
22 |
23 | # Roll columns
24 | cols = []
25 | for i in range(w):
26 | d = int(skew_factor * (i - h_center)) # Displacement
27 | col = im[:,:,i]
28 | cols.append(col.roll(d, dims=1))
29 |
30 | # Stack rolled columns
31 | skewed = torch.stack(cols, dim=2)
32 | return skewed
33 |
34 | def view(self, im):
35 | return self.skew_image(im, self.skew_factor)
36 |
37 | def inverse_view(self, noise):
38 | return self.skew_image(noise, -self.skew_factor)
39 |
40 | def make_frame(self, im, t):
41 | im_size = im.size[0]
42 | frame_size = int(im_size * 1.5)
43 | skew_factor = t * self.skew_factor
44 |
45 | # Convert to tensor, skew, then convert back to PIL
46 | im = torch.tensor(np.array(im) / 255.).permute(2,0,1)
47 | im = self.skew_image(im, skew_factor)
48 | im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8))
49 |
50 | # Paste on to canvas
51 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
52 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
53 |
54 | return frame
55 |
56 |
--------------------------------------------------------------------------------
/visual_anagrams/visual_anagrams/views/view_white_balance.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 |
4 | import torch
5 | import torchvision.transforms.functional as TF
6 |
7 | from .view_base import BaseView
8 |
9 |
10 | class WhiteBalanceViewFailure(BaseView):
11 | '''
12 | A failing white balancing view, which simply scales the pixel values
13 | by some constant factor. An attempt to reproduce the "dress" illusion
14 | '''
15 | def __init__(self, factor=1.5):
16 | self.factor = factor
17 |
18 | def make_frame(self, im, t):
19 | im_size = im.size[0]
20 | frame_size = int(im_size * 1.5)
21 |
22 | # Interpolate factor on t
23 | factor = 1 + (self.factor - 1) * t
24 |
25 | # Convert to tensor
26 | im = torch.tensor(np.array(im) / 255.).permute(2,0,1)
27 |
28 | # Adjust colors
29 | im = im * factor
30 | im = torch.clip(im, 0, 1)
31 |
32 | # Convert back to PIL
33 | im = Image.fromarray((np.array(im.permute(1,2,0)) * 255.).astype(np.uint8))
34 |
35 | # Paste on to canvas
36 | frame = Image.new('RGB', (frame_size, frame_size), (255, 255, 255))
37 | frame.paste(im, ((frame_size - im_size) // 2, (frame_size - im_size) // 2))
38 |
39 | return frame
40 |
41 | def view(self, im):
42 | return im * self.factor
43 |
44 | def inverse_view(self, noise):
45 | noise[:3] = noise[:3] / self.factor
46 | return noise
47 |
48 |
--------------------------------------------------------------------------------