├── modeling
│   ├── __init__.py
│   ├── generator
│   │   ├── __init__.py
│   │   ├── net
│   │   │   ├── import_utils.py
│   │   │   ├── act.py
│   │   │   ├── norms.py
│   │   │   └── basic_modules.py
│   │   └── dc_ar.py
│   ├── tokenizer
│   │   ├── quantizer
│   │   │   ├── __init__.py
│   │   │   ├── hybrid_vq.py
│   │   │   └── quantizer.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── transforms.py
│   │   │   ├── metric.py
│   │   │   ├── list.py
│   │   │   ├── dist.py
│   │   │   ├── random.py
│   │   │   ├── init.py
│   │   │   ├── network.py
│   │   │   └── misc.py
│   │   ├── networks
│   │   │   ├── dc_ae_blocks
│   │   │   │   ├── act.py
│   │   │   │   ├── norm.py
│   │   │   │   └── triton_rms_norm.py
│   │   │   └── dc_ae.py
│   │   ├── configuration.py
│   │   └── dc_ht.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── losses.py
│   │   ├── base_model.py
│   │   └── ema_model.py
│   ├── diffusion
│   │   ├── __init__.py
│   │   ├── diffusion_utils.py
│   │   ├── respace.py
│   │   └── diffloss.py
│   └── utils.py
├── assets
│   └── teaser.png
├── utils
│   ├── safety_check.py
│   ├── __init__.py
│   └── demo_util.py
├── pyproject.toml
├── configs
│   └── inference
│       └── dc_ar_t2i_512.yaml
├── LICENSE
│   ├── dc_ar_models.txt
│   └── code.txt
├── .gitignore
├── README.md
├── sample.py
└── app.py
/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modeling/generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .dc_ar import DCAR -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dc-ai-projects/DC-AR/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /modeling/tokenizer/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantizer import VectorQuantizer 2 | from .hybrid_vq import HybridVQ -------------------------------------------------------------------------------- /modeling/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .ema_model import EMAModel 3 | from .losses import MLMLoss -------------------------------------------------------------------------------- /modeling/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from . import gaussian_diffusion as gd 2 | from .gaussian_diffusion import GaussianDiffusion 3 | from .respace import SpacedDiffusion, space_timesteps 4 | from .diffloss import DiffLoss -------------------------------------------------------------------------------- /modeling/tokenizer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist import * 2 | from .init import * 3 | from .metric import * 4 | from .misc import * 5 | from .random import * 6 | from .transforms import * 7 | from .list import * 8 | from .network import * -------------------------------------------------------------------------------- /modeling/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/NVlabs/VILA 18 | 19 | from typing import Dict, Sequence 20 | 21 | import torch 22 | import transformers 23 | 24 | __all__ = ["tokenize_fn"] 25 | 26 | 27 | def tokenize_fn( 28 | strings: Sequence[str], 29 | tokenizer: transformers.PreTrainedTokenizer, 30 | max_length: int, 31 | padding_mode: str = "longest", 32 | ) -> Dict: 33 | """Tokenize a list of strings.""" 34 | tokenized_list = [ 35 | tokenizer( 36 | text, 37 | return_tensors="pt", 38 | padding=padding_mode, 39 | max_length=max_length, 40 | truncation=True, 41 | ) 42 | for text in strings 43 | ] 44 | input_ids = [tokenized.input_ids[0] for tokenized in tokenized_list] 45 | return input_ids -------------------------------------------------------------------------------- /utils/safety_check.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import softmax 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | 5 | __all__ = ["is_dangerous"] 6 | 7 | safety_policy = """ 8 | The prompt must not seek to generate harmful, abusive, hateful, sexually explicit, or otherwise inappropriate content targeting individuals or protected groups. 9 | """ 10 | 11 | 12 | def is_dangerous(tokenizer, model, user_prompt, threshold=0.5): 13 | prompt = f"""You are a policy expert trying to help determine whether a user 14 | prompt is in violation of the defined safety policies. 15 | 16 | 17 | Human Question: {user_prompt.strip()} 18 | 19 | 20 | Our safety principle is defined in the below: 21 | 22 | {safety_policy.strip()} 23 | 24 | Does the human question violate the above principle? Your answer must start 25 | with 'Yes' or 'No'. And then walk through step by step to be sure we answer 26 | correctly. 27 | """ 28 | 29 | inputs = tokenizer(prompt, return_tensors="pt").to("cuda") 30 | with torch.no_grad(): 31 | logits = model(**inputs).logits 32 | 33 | # Extract the logits for the Yes and No tokens 34 | vocab = tokenizer.get_vocab() 35 | selected_logits = logits[0, -1, [vocab["Yes"], vocab["No"]]] 36 | 37 | # Convert these logits to a probability with softmax 38 | probabilities = softmax(selected_logits, dim=0) 39 | 40 | # Return probability of 'Yes' 41 | score = probabilities[0].item() 42 | 43 | return score > threshold -------------------------------------------------------------------------------- /modeling/tokenizer/networks/dc_ae_blocks/act.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | from functools import partial 20 | from typing import Optional 21 | 22 | import torch.nn as nn 23 | 24 | from modeling.tokenizer.utils import build_kwargs_from_config 25 | 26 | __all__ = ["build_act"] 27 | 28 | 29 | # register activation function here 30 | REGISTERED_ACT_DICT: dict[str, type] = { 31 | "relu": nn.ReLU, 32 | "relu6": nn.ReLU6, 33 | "hswish": nn.Hardswish, 34 | "silu": nn.SiLU, 35 | "gelu": partial(nn.GELU, approximate="tanh"), 36 | } 37 | 38 | 39 | def build_act(name: str, **kwargs) -> Optional[nn.Module]: 40 | if name in REGISTERED_ACT_DICT: 41 | act_cls = REGISTERED_ACT_DICT[name] 42 | args = build_kwargs_from_config(kwargs, act_cls) 43 | return act_cls(**args) 44 | else: 45 | return None -------------------------------------------------------------------------------- /modeling/tokenizer/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | from typing import Any, List, Optional 20 | 21 | import torch 22 | import torch.nn.functional as F 23 | 24 | 25 | def resize( 26 | x: torch.Tensor, 27 | size: Optional[Any] = None, 28 | scale_factor: Optional[List[float]] = None, 29 | mode: str = "bicubic", 30 | align_corners: Optional[bool] = False, 31 | ) -> torch.Tensor: 32 | if mode in {"bilinear", "bicubic"}: 33 | return F.interpolate( 34 | x, 35 | size=size, 36 | scale_factor=scale_factor, 37 | mode=mode, 38 | align_corners=align_corners, 39 | ) 40 | elif mode in {"nearest", "area"}: 41 | return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) 42 | else: 43 | raise NotImplementedError(f"resize(mode={mode}) not implemented.") 44 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | default_prompts = [ 2 | "beautiful lady, freckles, big smile, blue eyes, short ginger hair, dark makeup, wearing a floral blue vest top, soft light, dark grey background", 3 | "anthropomorphic profile of the white snow owl Crystal priestess , art deco painting, pretty and expressive eyes, ornate costume, mythical, ethereal, intricate, elaborate, hyperrealism, hyper detailed, 3D, 8K, Ultra Realistic, high octane, ultra resolution, amazing detail, perfection, In frame, photorealistic, cinematic lighting, visual clarity, shading , Lumen Reflections, Super-Resolution, gigapixel, color grading, retouch, enhanced, PBR, Blender, V-ray, Procreate, zBrush, Unreal Engine 5, cinematic, volumetric, dramatic, neon lighting, wide angle lens ,no digital painting blur.", 4 | "Bright scene, aerial view, ancient city, fantasy, gorgeous light, mirror reflection, high detail, wide angle lens.", 5 | "A 4k dslr image of a lemur wearing a red magician hat and a blue coat performing magic tricks with cards in a garden.", 6 | "A silhouette of a grand piano overlooking a dusky cityscape viewed from a top-floor penthouse, rendered in the bold and vivid style of a vintage travel poster.", 7 | "Crocodile in a sweater.", 8 | "Luffy from ONEPIECE, handsome face, fantasy.", 9 | "3d digital art of an adorable ghost, glowing within, holding a heart shaped pumpkin, Halloween, super cute, spooky haunted house background.", 10 | "an astronaut sitting in a diner, eating fries, cinematic, analog film", 11 | "Chinese architecture, ancient style,mountain, bird, lotus, pond, big tree, 4K Unity, octane rendering.", 12 | ] -------------------------------------------------------------------------------- /modeling/generator/net/import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/NVlabs/Sana 18 | 19 | import importlib.util 20 | import logging 21 | import warnings 22 | 23 | import importlib_metadata 24 | from packaging import version 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | _xformers_available = importlib.util.find_spec("xformers") is not None 29 | try: 30 | if _xformers_available: 31 | _xformers_version = importlib_metadata.version("xformers") 32 | _torch_version = importlib_metadata.version("torch") 33 | if version.Version(_torch_version) < version.Version("1.12"): 34 | raise ValueError("xformers is installed but requires PyTorch >= 1.12") 35 | logger.debug(f"Successfully imported xformers version {_xformers_version}") 36 | except importlib_metadata.PackageNotFoundError: 37 | _xformers_available = False 38 | 39 | 40 | def is_xformers_available(): 41 | return _xformers_available 42 | 43 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "dcar" 7 | version = "1.0.0" 8 | description = "DC-AR" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.3.0", "torchvision==0.18.0", 17 | "transformers==4.42.2", "tokenizers>=0.15.2", "sentencepiece==0.2.0", "shortuuid", 18 | "accelerate==0.27.2", "peft==0.5.0", "bitsandbytes==0.41.0", "pydantic==2.2.0", 19 | "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2", 20 | "gradio==4.44.1", "gradio_client==1.3.0", 21 | "requests", "httpx==0.24.1", "uvicorn", "fastapi==0.101.1", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.12", 23 | "openpyxl==3.1.2", "pytorchvideo==0.1.5", "decord==0.6.0", 24 | "datasets==2.16.1", "openai==1.8.0", "webdataset==0.2.86", 25 | "nltk==3.3", "pywsd==1.2.4", "opencv-python==4.8.0.74", 26 | "omegaconf==2.3.0", "diffusers==0.28.2", 27 | "einx", "wandb", "xformers==0.0.26.post1", "open_clip_torch", "termcolor", "iopath", "torchinfo", 28 | "spaces==0.30.3", "pre-commit==4.0.1", "black==24.10.0", "isort==5.13.2", 29 | ] 30 | 31 | [project.optional-dependencies] 32 | train = ["deepspeed==0.9.5", "ninja", "wandb"] 33 | eval = ["mmengine", "word2number", "Levenshtein", "nltk", "pywsd"] 34 | 35 | 36 | [tool.setuptools.packages.find] 37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 38 | 39 | [tool.wheel] 40 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] -------------------------------------------------------------------------------- /modeling/tokenizer/quantizer/hybrid_vq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import Mapping, Text, Tuple 18 | 19 | import random 20 | import torch 21 | import torch.nn.functional as F 22 | from einops import rearrange 23 | from torch.cuda.amp import autocast 24 | 25 | from .quantizer import VectorQuantizer 26 | 27 | class HybridVQ(VectorQuantizer): 28 | def __init__(self, *args, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | 31 | @autocast(enabled=False) 32 | def forward(self, z: torch.Tensor, skip_continuous_prob: float = 1.0) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: 33 | z_quantized, result_dict = super(HybridVQ, self).forward(z) 34 | 35 | if self.use_l2_norm: 36 | z = F.normalize(z, p=2, dim=1) 37 | 38 | residual_features = z - z_quantized 39 | residual_features = rearrange(residual_features, 'b c h w -> b (h w) c') 40 | 41 | result_dict['residual_features'] = residual_features 42 | result_dict['z_quantized'] = z_quantized 43 | result_dict['features'] = z 44 | 45 | p = random.random() 46 | if p >= skip_continuous_prob: 47 | feature = z.clone() 48 | else: 49 | feature = z_quantized 50 | return feature, result_dict -------------------------------------------------------------------------------- /modeling/tokenizer/utils/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | import torch 20 | 21 | from ..utils import sync_tensor 22 | 23 | __all__ = ["AverageMeter"] 24 | 25 | 26 | class AverageMeter: 27 | """Computes and stores the average and current value.""" 28 | 29 | def __init__(self, is_distributed=True): 30 | self.is_distributed = is_distributed 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def _sync(self, val: torch.Tensor | int | float) -> torch.Tensor | int | float: 35 | return sync_tensor(val, reduce="sum") if self.is_distributed else val 36 | 37 | def update(self, val: torch.Tensor | int | float, delta_n=1): 38 | self.count += self._sync(delta_n) 39 | self.sum += self._sync(val * delta_n) 40 | 41 | def get_count(self) -> torch.Tensor | int | float: 42 | return ( 43 | self.count.item() 44 | if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 45 | else self.count 46 | ) 47 | 48 | @property 49 | def avg(self): 50 | avg = -1 if self.count == 0 else self.sum / self.count 51 | return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg 52 | -------------------------------------------------------------------------------- /modeling/tokenizer/configuration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | from typing import Optional 18 | 19 | from transformers import PretrainedConfig 20 | 21 | class DCHTConfig(PretrainedConfig): 22 | model_type = 'dc_ht' 23 | 24 | def __init__( 25 | self, 26 | model_name: str = 'dc-ae-f32-proxy-in-1.0', 27 | pretrained_path: Optional[str] = None, 28 | 29 | codebook_size: int = 16384, 30 | codebook_embed_dim: int = 32, 31 | codebook_l2_norm: bool = True, 32 | codebook_show_usage: bool = True, 33 | commit_loss_beta: float = 0.25, 34 | entropy_loss_ratio: float = 0.0, 35 | quantizer_type: str = 'vq', 36 | 37 | disc_updated: bool = True, 38 | **kwargs 39 | ): 40 | super().__init__() 41 | 42 | self.model_name = model_name 43 | self.pretrained_path = pretrained_path 44 | 45 | self.codebook_size = codebook_size 46 | self.codebook_embed_dim = codebook_embed_dim 47 | self.codebook_l2_norm = codebook_l2_norm 48 | self.codebook_show_usage = codebook_show_usage 49 | self.commit_loss_beta = commit_loss_beta 50 | self.entropy_loss_ratio = entropy_loss_ratio 51 | self.quantizer_type = quantizer_type 52 | 53 | self.disc_updated = disc_updated 54 | 55 | -------------------------------------------------------------------------------- /configs/inference/dc_ar_t2i_512.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | tokenizer_checkpoint: "pretrained_models/dc-ht" 3 | generator_checkpoint: "pretrained_models/dc-ar-512/pytorch_model.bin" 4 | 5 | model: 6 | type: vq_dc_ae 7 | hybrid: True 8 | quantizer_type: hybrid_vq 9 | vq_model: 10 | type: vq_dc_ae 11 | pretrained_path: "pretrained_models/dc-ht" 12 | codebook_size: 16384 13 | token_size: 32 14 | num_latent_tokens: 256 15 | finetune_decoder: True 16 | 17 | generator: 18 | type: maskgit_t2i 19 | model_type: "DCAR" 20 | hidden_size: 1152 21 | num_hidden_layers: 28 22 | num_attention_heads: 16 23 | intermediate_size: 4608 24 | attn_type: flash 25 | ffn_type: mlp 26 | qk_norm: True 27 | dropout: 0.1 28 | attn_drop: 0.1 29 | num_steps: 12 30 | class_label_dropout: 0.1 31 | drop_path: 0. 32 | image_seq_len: 256 33 | 34 | # sampling hyper-params 35 | guidance_scale_pow: 2.5 36 | randomize_temperature: 1.5 37 | guidance_scale: 4.5 38 | guidance_decay: "constant" 39 | 40 | diffusion: 41 | width: 1024 42 | depth: 6 43 | sampler: iddpm 44 | num_sampling_steps: '32' 45 | batch_mul: 4 46 | vae_scale: 0.107 47 | 48 | text_model: 'google-t5/t5-base' 49 | text_token_length: 300 50 | context_dim: 768 51 | use_llm_system_prompt: False 52 | 53 | dataset: 54 | train: 55 | dataset_name: pixart_wds 56 | dataset_names: null 57 | data_path: /home/yechengw/dataset/InternData/mjdata_v2 58 | resolution: 512 59 | intermediate_resolution: 576 60 | per_gpu_batch_size: 64 61 | mean: 0.5 62 | std: 0.5 63 | mj_splits: mj2_2_10M_new+mj_1-5_new+mj2_1_11M_new 64 | tokenizer_path: 'google-t5/t5-base' 65 | tokenizer_max_length: 300 66 | eval: 67 | dataset_name: mjhq 68 | data_path: /home/yechengw/dataset/MJHQ-30K 69 | resolution: 512 70 | intermediate_resolution: 576 71 | per_gpu_batch_size: 50 72 | mean: 0.5 73 | std: 0.5 74 | no_train_aug: False -------------------------------------------------------------------------------- /modeling/tokenizer/utils/list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | from typing import Any, Optional 20 | 21 | __all__ = [ 22 | "list_sum", 23 | "list_mean", 24 | "weighted_list_sum", 25 | "list_join", 26 | "val2list", 27 | "val2tuple", 28 | "squeeze_list", 29 | ] 30 | 31 | 32 | def list_sum(x: list) -> Any: 33 | return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) 34 | 35 | 36 | def list_mean(x: list) -> Any: 37 | return list_sum(x) / len(x) 38 | 39 | 40 | def weighted_list_sum(x: list, weights: list) -> Any: 41 | assert len(x) == len(weights) 42 | return x[0] * weights[0] if len(x) == 1 else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:]) 43 | 44 | 45 | def list_join(x: list, sep="\t", format_str="%s") -> str: 46 | return sep.join([format_str % val for val in x]) 47 | 48 | 49 | def val2list(x: list | tuple | Any, repeat_time=1) -> list: 50 | if isinstance(x, (list, tuple)): 51 | return list(x) 52 | return [x for _ in range(repeat_time)] 53 | 54 | 55 | def val2tuple(x: list | tuple | Any, min_len: int = 1, idx_repeat: int = -1) -> tuple: 56 | x = val2list(x) 57 | 58 | # repeat elements if necessary 59 | if len(x) > 0: 60 | x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))] 61 | 62 | return tuple(x) 63 | 64 | 65 | def squeeze_list(x: Optional[list]) -> list | Any: 66 | if x is not None and len(x) == 1: 67 | return x[0] 68 | else: 69 | return x 70 | -------------------------------------------------------------------------------- /modeling/modules/losses.py: -------------------------------------------------------------------------------- 1 | """This files contains training loss implementation. 2 | 3 | Copyright (2025) Bytedance Ltd. and/or its affiliates 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 
16 | 17 | Ref: 18 | https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/vqperceptual.py 19 | """ 20 | from typing import Mapping, Text, Tuple 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | from einops import rearrange 26 | from torch.cuda.amp import autocast 27 | 28 | class MLMLoss(torch.nn.Module): 29 | def __init__(self, 30 | config): 31 | super().__init__() 32 | self.label_smoothing = config.losses.label_smoothing 33 | self.loss_weight_unmasked_token = config.losses.loss_weight_unmasked_token 34 | self.criterion = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing, 35 | reduction="none") 36 | 37 | def forward(self, inputs: torch.Tensor, targets: torch.Tensor, 38 | weights=None) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: 39 | inputs = rearrange(inputs, "b n c -> b c n") 40 | loss = self.criterion(inputs, targets) 41 | weights = weights.to(loss) 42 | loss_weights = (1.0 - weights) * self.loss_weight_unmasked_token + weights # unmasked tokens (weight 0) are scaled by self.loss_weight_unmasked_token 43 | loss = (loss * loss_weights).sum() / (loss_weights.sum() + 1e-8) 44 | # accuracy is computed over masked tokens only 45 | correct_tokens = ((torch.argmax(inputs, dim=1) == targets) * weights).sum(dim=1) / (weights.sum(1) + 1e-8) 46 | return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()} -------------------------------------------------------------------------------- /modeling/generator/net/act.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import copy 18 | 19 | import torch.nn as nn 20 | 21 | __all__ = ["build_act", "get_act_name"] 22 | 23 | # register activation function here 24 | # name: module, kwargs with default values 25 | REGISTERED_ACT_DICT: dict[str, tuple[type, dict[str, any]]] = { 26 | "relu": (nn.ReLU, {"inplace": True}), 27 | "relu6": (nn.ReLU6, {"inplace": True}), 28 | "hswish": (nn.Hardswish, {"inplace": True}), 29 | "hsigmoid": (nn.Hardsigmoid, {"inplace": True}), 30 | "swish": (nn.SiLU, {"inplace": True}), 31 | "silu": (nn.SiLU, {"inplace": True}), 32 | "tanh": (nn.Tanh, {}), 33 | "sigmoid": (nn.Sigmoid, {}), 34 | "gelu": (nn.GELU, {"approximate": "tanh"}), 35 | "mish": (nn.Mish, {"inplace": True}), 36 | "identity": (nn.Identity, {}), 37 | } 38 | 39 | 40 | def build_act(name: str, **kwargs) -> nn.Module: 41 | if name in REGISTERED_ACT_DICT: 42 | act_cls, default_args = copy.deepcopy(REGISTERED_ACT_DICT[name]) 43 | for key in default_args: 44 | if key in kwargs: 45 | default_args[key] = kwargs[key] 46 | return act_cls(**default_args) 47 | elif name is None or name.lower() == "none": 48 | return None 49 | else: 50 | raise ValueError(f"do not support: {name}") 51 | 52 | 53 | def get_act_name(act: nn.Module) -> str: 54 | if act is None: 55 | return None 56 | module2name = {} 57 | for key, config in REGISTERED_ACT_DICT.items(): 58 | module2name[config[0].__name__] = key 59 | return module2name.get(type(act).__name__, "unknown") -------------------------------------------------------------------------------- /modeling/tokenizer/utils/dist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | import os 20 | from typing import List, Union 21 | 22 | import torch 23 | import torch.distributed as dist 24 | 25 | 26 | def is_master() -> bool: 27 | return get_dist_rank() == 0 28 | 29 | 30 | def list_sum(x: list) -> any: 31 | return x[0] if len(x) == 1 else x[0] + list_sum(x[1:]) 32 | 33 | 34 | def list_mean(x: list) -> any: 35 | return list_sum(x) / len(x) 36 | 37 | 38 | def get_dist_rank() -> int: 39 | return int(os.environ["RANK"]) 40 | 41 | 42 | def get_dist_size() -> int: 43 | return int(os.environ["WORLD_SIZE"]) 44 | 45 | 46 | @torch.no_grad() 47 | def sync_tensor( 48 | tensor: Union[torch.Tensor, float], reduce="mean" 49 | ) -> Union[torch.Tensor, List[torch.Tensor]]: 50 | if not isinstance(tensor, torch.Tensor): 51 | tensor = torch.Tensor(1).fill_(tensor).cuda() 52 | tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())] 53 | torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False) 54 | if reduce == "mean": 55 | return list_mean(tensor_list) 56 | elif reduce == "sum": 57 | return list_sum(tensor_list) 58 | elif reduce == "cat": 59 | return torch.cat(tensor_list, dim=0) 60 | elif reduce == "root": 61 | return tensor_list[0] 62 | else: 63 | return tensor_list 64 | 65 | 66 | @torch.no_grad() 67 | def all_gather_cat(world_size, tensor, dim=0): 68 | if world_size == 1: 69 | return tensor 70 | 71 | g_tensor = [torch.ones_like(tensor) for _ in range(world_size)] 72 | dist.all_gather(g_tensor, tensor) 73 | g_tensor = torch.cat(g_tensor, dim=dim) 74 | 75 | return g_tensor 76 | -------------------------------------------------------------------------------- /modeling/tokenizer/utils/random.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | from typing import List, Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | 24 | __all__ = [ 25 | "torch_randint", 26 | "torch_random", 27 | "torch_shuffle", 28 | "torch_uniform", 29 | "torch_random_choices", 30 | ] 31 | 32 | 33 | def torch_randint( 34 | low: int, high: int, generator: Optional[torch.Generator] = None 35 | ) -> int: 36 | """uniform: [low, high)""" 37 | if low == high: 38 | return low 39 | else: 40 | assert low < high 41 | return int(torch.randint(low=low, high=high, generator=generator, size=(1,))) 42 | 43 | 44 | def torch_random(generator: Optional[torch.Generator] = None) -> float: 45 | """uniform distribution on the interval [0, 1)""" 46 | return float(torch.rand(1, generator=generator)) 47 | 48 | 49 | def torch_shuffle( 50 | src_list: List[any], generator: Optional[torch.Generator] = None 51 | ) -> List[any]: 52 | rand_indexes = torch.randperm(len(src_list), generator=generator).tolist() 53 | return [src_list[i] for i in rand_indexes] 54 | 55 | 56 | def torch_uniform( 57 | low: float, high: float, generator: Optional[torch.Generator] = None 58 | ) -> float: 59 | """uniform distribution on the interval [low, high)""" 60 | rand_val = torch_random(generator) 61 | return (high - low) * rand_val + low 62 | 63 | 64 | def torch_random_choices( 65 | src_list: List[any], 66 | generator: Optional[torch.Generator] = None, 67 | k=1, 68 | weight_list: Optional[List[float]] = None, 69 | ) -> Union[any, list]: 70 | if weight_list is None: 71 | rand_idx = torch.randint( 72 | low=0, high=len(src_list), generator=generator, size=(k,) 73 | ) 74 | out_list = [src_list[i] for i in rand_idx] 75 | else: 76 | assert len(weight_list) == len(src_list) 77 | accumulate_weight_list = np.cumsum(weight_list) 78 | 79 | out_list = [] 80 | for _ in range(k): 81 | val = torch_uniform(0, accumulate_weight_list[-1], generator) 82 | active_id = 0 83 | for i, weight_val in enumerate(accumulate_weight_list): 84 | active_id = i 85 | if weight_val > val: 86 | break 87 | out_list.append(src_list[active_id]) 88 | 89 | return out_list[0] if k == 1 else out_list 90 | -------------------------------------------------------------------------------- /modeling/tokenizer/utils/init.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | import torch 20 | import torch.nn as nn 21 | from torch.nn.modules.batchnorm import _BatchNorm 22 | 23 | __all__ = ["init_modules", "zero_last_gamma"] 24 | 25 | 26 | def init_modules(model: nn.Module | list[nn.Module], init_type="trunc_normal") -> None: 27 | _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02} 28 | 29 | if isinstance(model, list): 30 | for sub_module in model: 31 | init_modules(sub_module, init_type) 32 | else: 33 | init_params = init_type.split("@") 34 | init_params = float(init_params[1]) if len(init_params) > 1 else None 35 | 36 | if init_type.startswith("trunc_normal"): 37 | init_func = lambda param: nn.init.trunc_normal_( 38 | param, std=(_DEFAULT_INIT_PARAM["trunc_normal"] if init_params is None else init_params) 39 | ) 40 | else: 41 | raise NotImplementedError 42 | 43 | for m in model.modules(): 44 | if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)): 45 | init_func(m.weight) 46 | if m.bias is not None: 47 | m.bias.data.zero_() 48 | elif isinstance(m, nn.Embedding): 49 | init_func(m.weight) 50 | elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)): 51 | m.weight.data.fill_(1) 52 | m.bias.data.zero_() 53 | else: 54 | weight = getattr(m, "weight", None) 55 | bias = getattr(m, "bias", None) 56 | if isinstance(weight, torch.nn.Parameter): 57 | init_func(weight) 58 | if isinstance(bias, torch.nn.Parameter): 59 | bias.data.zero_() 60 | 61 | 62 | def zero_last_gamma(model: nn.Module, init_val=0) -> None: 63 | import modules.nn.ops as ops 64 | 65 | for m in model.modules(): 66 | if isinstance(m, ops.ResidualBlock) and isinstance( 67 | m.shortcut, ops.IdentityLayer 68 | ): 69 | if isinstance( 70 | m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv, ops.GLUMBConv) 71 | ): 72 | parent_module = m.main.point_conv 73 | elif isinstance(m.main, ops.ResBlock): 74 | parent_module = m.main.conv2 75 | elif isinstance(m.main, (ops.ConvLayer, ops.LinearLayer)): 76 | parent_module = m.main 77 | elif isinstance(m.main, (ops.LiteMLA, ops.SoftmaxAtt, ops.CrossAtt)): 78 | parent_module = m.main.proj 79 | else: 80 | parent_module = None 81 | if parent_module is not None: 82 | norm = getattr(parent_module, "norm", None) 83 | if norm is not None and norm.weight.requires_grad: 84 | nn.init.constant_(norm.weight, init_val) 85 | -------------------------------------------------------------------------------- /modeling/tokenizer/dc_ht.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from einops import rearrange 21 | from transformers import AutoConfig, AutoModel, PreTrainedModel 22 | 23 | from .configuration import DCHTConfig 24 | from .quantizer import VectorQuantizer as VQ 25 | from .quantizer import HybridVQ 26 | from .networks.dc_ae import Encoder, Decoder, DCAEConfig, dc_ae_f32, dc_ae_f16 27 | 28 | class DCHT(PreTrainedModel): 29 | config_class = DCHTConfig 30 | 31 | def __init__( 32 | self, 33 | config: DCHTConfig 34 | ): 35 | super().__init__(config) 36 | 37 | if config.model_name in ["dc-ae-f32-in-1.0", "dc-ae-f32-mix-1.0"]: 38 | dc_ae_cfg = dc_ae_f32(config.model_name, config.codebook_embed_dim, config.pretrained_path) 39 | elif config.model_name in ['dc-ae-f16-in-1.0']: 40 | dc_ae_cfg = dc_ae_f16(config.model_name, config.codebook_embed_dim, config.pretrained_path) 41 | else: 42 | raise NotImplementedError 43 | 44 | self.cfg = dc_ae_cfg 45 | self.encoder = Encoder(dc_ae_cfg.encoder) 46 | self.decoder = Decoder(dc_ae_cfg.decoder) 47 | 48 | self.hybrid = False 49 | if config.quantizer_type == 'vq': 50 | self.quantize = VQ(config.codebook_size, config.codebook_embed_dim, 51 | config.commit_loss_beta, config.codebook_l2_norm) 52 | elif config.quantizer_type == 'hybrid_vq': 53 | self.quantize = HybridVQ(config.codebook_size, config.codebook_embed_dim, 54 | config.commit_loss_beta, config.codebook_l2_norm) 55 | self.hybrid = True 56 | 57 | def encode(self, x, **kwargs): 58 | h = self.encoder(x) 59 | quant, info = self.quantize(h, **kwargs) 60 | return quant, info 61 | 62 | def decode(self, quant): 63 | dec = self.decoder(quant) 64 | return dec 65 | 66 | def decode_tokens(self, tokens, residual_features = None): 67 | batch, seq_len = tokens.shape # B x N 68 | z_quantized = self.quantize.get_codebook_entry( 69 | tokens.reshape(-1)).reshape(batch, int(seq_len ** 0.5), int(seq_len ** 0.5), -1) 70 | z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous() 71 | if self.hybrid and residual_features is not None: 72 | residual_features = residual_features.reshape(batch, int(seq_len ** 0.5), int(seq_len ** 0.5), -1) 73 | residual_features = rearrange(residual_features, 'b h w c -> b c h w').contiguous() 74 | decoded = self.decode(z_quantized + residual_features) 75 | else: 76 | decoded = self.decode(z_quantized) 77 | return decoded 78 | 79 | def forward(self, x, **kwargs): 80 | z_q, info = self.encode(x, **kwargs) 81 | decoded = self.decode(z_q) 82 | return decoded, info 83 | 84 | AutoConfig.register("dc_ht", DCHTConfig) 85 | AutoModel.register(DCHTConfig, DCHT) -------------------------------------------------------------------------------- /modeling/tokenizer/quantizer/quantizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit and https://github.com/CompVis/taming-transformers 18 | 19 | from typing import Mapping, Text, Tuple 20 | 21 | import torch 22 | from einops import rearrange 23 | from torch.cuda.amp import autocast 24 | 25 | class VectorQuantizer(torch.nn.Module): 26 | def __init__(self, 27 | codebook_size: int = 1024, 28 | token_size: int = 256, 29 | commitment_cost: float = 0.25, 30 | use_l2_norm: bool = False, 31 | ): 32 | super().__init__() 33 | self.commitment_cost = commitment_cost 34 | 35 | self.embedding = torch.nn.Embedding(codebook_size, token_size) 36 | self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) 37 | self.use_l2_norm = use_l2_norm 38 | 39 | # Ensure quantization is performed using f32 40 | @autocast(enabled=False) 41 | def forward(self, z: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: 42 | z = z.float() 43 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 44 | z_flattened = rearrange(z, 'b h w c -> (b h w) c') 45 | 46 | if self.use_l2_norm: 47 | z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1) 48 | embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1) 49 | else: 50 | embedding = self.embedding.weight 51 | d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \ 52 | torch.sum(embedding**2, dim=1) - 2 * \ 53 | torch.einsum('bd,dn->bn', z_flattened, embedding.T) 54 | 55 | min_encoding_indices = torch.argmin(d, dim=1) # num_ele 56 | z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) 57 | 58 | if self.use_l2_norm: 59 | z = torch.nn.functional.normalize(z, dim=-1) 60 | 61 | # compute loss for embedding 62 | commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2) 63 | codebook_loss = torch.mean((z_quantized - z.detach()) **2) 64 | 65 | loss = commitment_loss + codebook_loss 66 | 67 | # preserve gradients 68 | z_quantized = z + (z_quantized - z).detach() 69 | 70 | # reshape back to match original input shape 71 | z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous() 72 | 73 | result_dict = dict( 74 | quantizer_loss=loss, 75 | commitment_loss=commitment_loss, 76 | codebook_loss=codebook_loss, 77 | min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]) 78 | ) 79 | 80 | return z_quantized, result_dict 81 | 82 | def get_codebook_entry(self, indices): 83 | if len(indices.shape) == 1: 84 | z_quantized = self.embedding(indices) 85 | elif len(indices.shape) == 2: 86 | z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight) 87 | else: 88 | raise NotImplementedError 89 | if self.use_l2_norm: 90 | z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) 91 | return z_quantized -------------------------------------------------------------------------------- /modeling/tokenizer/utils/network.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | import collections 20 | import os 21 | from inspect import signature 22 | from typing import Any, Callable, Optional 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | __all__ = [ 29 | "is_parallel", 30 | "get_device", 31 | "get_same_padding", 32 | "resize", 33 | "build_kwargs_from_config", 34 | "load_state_dict_from_file", 35 | "get_submodule_weights", 36 | ] 37 | 38 | 39 | def is_parallel(model: nn.Module) -> bool: 40 | return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) 41 | 42 | 43 | def get_device(model: nn.Module) -> torch.device: 44 | return model.parameters().__next__().device 45 | 46 | 47 | def get_dtype(model: nn.Module) -> torch.dtype: 48 | return model.parameters().__next__().dtype 49 | 50 | 51 | def get_same_padding(kernel_size: int | tuple[int, ...]) -> int | tuple[int, ...]: 52 | if isinstance(kernel_size, tuple): 53 | return tuple([get_same_padding(ks) for ks in kernel_size]) 54 | else: 55 | assert kernel_size % 2 > 0, "kernel size should be odd number" 56 | return kernel_size // 2 57 | 58 | 59 | def resize( 60 | x: torch.Tensor, 61 | size: Optional[Any] = None, 62 | scale_factor: Optional[list[float]] = None, 63 | mode: str = "bicubic", 64 | align_corners: Optional[bool] = False, 65 | ) -> torch.Tensor: 66 | if mode in {"bilinear", "bicubic"}: 67 | return F.interpolate( 68 | x, 69 | size=size, 70 | scale_factor=scale_factor, 71 | mode=mode, 72 | align_corners=align_corners, 73 | ) 74 | elif mode in {"nearest", "area"}: 75 | return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode) 76 | else: 77 | raise NotImplementedError(f"resize(mode={mode}) not implemented.") 78 | 79 | 80 | def build_kwargs_from_config(config: dict, target_func: Callable) -> dict[str, Any]: 81 | valid_keys = list(signature(target_func).parameters) 82 | kwargs = {} 83 | for key in config: 84 | if key in valid_keys: 85 | kwargs[key] = config[key] 86 | return kwargs 87 | 88 | 89 | def load_state_dict_from_file(file: str, only_state_dict=True) -> dict[str, torch.Tensor]: 90 | file = os.path.realpath(os.path.expanduser(file)) 91 | checkpoint = torch.load(file, map_location="cpu") 92 | if only_state_dict and "state_dict" in checkpoint: 93 | checkpoint = checkpoint["state_dict"] 94 | return checkpoint 95 | 96 | 97 | def get_submodule_weights(weights: collections.OrderedDict, prefix: str): 98 | submodule_weights = collections.OrderedDict() 99 | len_prefix = len(prefix) 100 | for key, weight in weights.items(): 101 | if key.startswith(prefix): 102 | submodule_weights[key[len_prefix:]] = weight 103 | return submodule_weights 104 | 105 | 106 | def get_dtype_from_str(dtype: str) -> torch.dtype: 107 | if dtype == "fp32": 108 | return torch.float32 109 | if dtype == "fp16": 110 | return torch.float16 111 | if dtype == "bf16": 112 | return torch.bfloat16 113 | raise NotImplementedError(f"dtype {dtype} is not supported") 114 | 
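
Note (illustrative sketch, not repository code): the tokenizer components above — VectorQuantizer, HybridVQ, and DCHT — fit together roughly as follows. An image is encoded into discrete codebook indices plus continuous residual features, and decode_tokens reconstructs it from the tokens alone or from tokens plus residuals (the hybrid path). The checkpoint path follows configs/inference/dc_ar_t2i_512.yaml, loading through AutoModel relies on the "dc_ht" registration at the bottom of dc_ht.py, and the random tensor stands in for a real preprocessed image; all of these are assumptions of this sketch rather than documented usage.

import torch
from transformers import AutoModel

import modeling.tokenizer.dc_ht  # noqa: F401  (importing registers DCHTConfig/DCHT with AutoConfig/AutoModel)

# assumed checkpoint location, taken from the inference config
tokenizer = AutoModel.from_pretrained("pretrained_models/dc-ht").eval()

image = torch.randn(1, 3, 512, 512)  # placeholder for a normalized 512x512 input image
with torch.no_grad():
    # encode() runs the DC-AE encoder followed by the (hybrid) quantizer
    quant, info = tokenizer.encode(image)
    tokens = info["min_encoding_indices"].flatten(1)  # B x N discrete codebook indices
    residual = info["residual_features"]              # B x N x C continuous residuals (hybrid path)
    recon = tokenizer.decode_tokens(tokens, residual_features=residual)  # hybrid reconstruction

With a 512x512 input and the f32 autoencoder, the token grid is 16x16, i.e. 256 tokens, which matches image_seq_len and num_latent_tokens in the inference config.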
-------------------------------------------------------------------------------- /LICENSE/dc_ar_models.txt: -------------------------------------------------------------------------------- 1 | NVIDIA License 2 | 3 | 4 | 1. Definitions 5 | 6 | “Licensor” means any person or entity that distributes its Work. 7 | 8 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. 9 | 10 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 11 | 12 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. 13 | 14 | 2. License Grant 15 | 16 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 17 | 18 | 3. Limitations 19 | 20 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 21 | 22 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 23 | 24 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 25 | 26 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. 27 | 28 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. 29 | 30 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. 31 | 32 | 4. Disclaimer of Warranty. 
33 | 34 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 35 | 36 | 5. Limitation of Liability. 37 | 38 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | !kernels/csrc/qgemm/w8a8/libs/*.so 10 | !kernels/csrc/qgemm/w4a8_per_chn/libs/*.so 11 | !kernels/csrc/qgemm/w4a8_per_group/libs/*.so 12 | 13 | # Checkpoints 14 | qserve_checkpoints/ 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 
108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 167 | #.idea/ 168 | 169 | # Result files 170 | *.csv 171 | 172 | logs/ 173 | wandb/ 174 | pretrained_models/ 175 | samples/ 176 | gradio_cached_examples/ -------------------------------------------------------------------------------- /modeling/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # Modified from OpenAI's diffusion repos 18 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 19 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 20 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 21 | 22 | import numpy as np 23 | import torch as th 24 | 25 | 26 | def normal_kl(mean1, logvar1, mean2, logvar2): 27 | """ 28 | Compute the KL divergence between two gaussians. 29 | Shapes are automatically broadcasted, so batches can be compared to 30 | scalars, among other use cases. 
31 | """ 32 | tensor = None 33 | for obj in (mean1, logvar1, mean2, logvar2): 34 | if isinstance(obj, th.Tensor): 35 | tensor = obj 36 | break 37 | assert tensor is not None, "at least one argument must be a Tensor" 38 | 39 | # Force variances to be Tensors. Broadcasting helps convert scalars to 40 | # Tensors, but it does not work for th.exp(). 41 | logvar1, logvar2 = ( 42 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 43 | for x in (logvar1, logvar2) 44 | ) 45 | 46 | return 0.5 * ( 47 | -1.0 48 | + logvar2 49 | - logvar1 50 | + th.exp(logvar1 - logvar2) 51 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 52 | ) 53 | 54 | 55 | def approx_standard_normal_cdf(x): 56 | """ 57 | A fast approximation of the cumulative distribution function of the 58 | standard normal. 59 | """ 60 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 61 | 62 | 63 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 64 | """ 65 | Compute the log-likelihood of a continuous Gaussian distribution. 66 | :param x: the targets 67 | :param means: the Gaussian mean Tensor. 68 | :param log_scales: the Gaussian log stddev Tensor. 69 | :return: a tensor like x of log probabilities (in nats). 70 | """ 71 | centered_x = x - means 72 | inv_stdv = th.exp(-log_scales) 73 | normalized_x = centered_x * inv_stdv 74 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob( 75 | normalized_x 76 | ) 77 | return log_probs 78 | 79 | 80 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 81 | """ 82 | Compute the log-likelihood of a Gaussian distribution discretizing to a 83 | given image. 84 | :param x: the target images. It is assumed that this was uint8 values, 85 | rescaled to the range [-1, 1]. 86 | :param means: the Gaussian mean Tensor. 87 | :param log_scales: the Gaussian log stddev Tensor. 88 | :return: a tensor like x of log probabilities (in nats). 89 | """ 90 | assert x.shape == means.shape == log_scales.shape 91 | centered_x = x - means 92 | inv_stdv = th.exp(-log_scales) 93 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 94 | cdf_plus = approx_standard_normal_cdf(plus_in) 95 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 96 | cdf_min = approx_standard_normal_cdf(min_in) 97 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 98 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 99 | cdf_delta = cdf_plus - cdf_min 100 | log_probs = th.where( 101 | x < -0.999, 102 | log_cdf_plus, 103 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 104 | ) 105 | assert log_probs.shape == x.shape 106 | return log_probs -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DC-AR: Efficient Masked Autoregressive Image Generation with Deep Compression Hybrid Tokenizer 2 | 3 | ###
ICCV 2025
4 | 5 |
6 | 7 | 8 | 9 | 10 |
11 | 12 | ![teaser_Page1](assets/teaser.png) 13 | 14 | ## 🔥🔥 News 15 | 16 | - \[2025/10/18\] 🔥 We released the DC-AR code and pre-trained models. 17 | - \[2025/6\] 🔥 DC-AR is accepted by ICCV 2025! 18 | 19 | ## Abstract 20 | 21 | We introduce DC-AR, a novel masked autoregressive (AR) text-to-image generation framework that delivers superior image generation quality with exceptional computational efficiency. Due to the tokenizers' limitations, prior masked AR models have lagged behind diffusion models in terms of quality or efficiency. We overcome this limitation by introducing DC-HT - a deep compression hybrid tokenizer for AR models that achieves a 32× spatial compression ratio while maintaining high reconstruction fidelity and cross-resolution generalization ability. Building upon DC-HT, we extend MaskGIT and create a new hybrid masked autoregressive image generation framework that first produces the structural elements through discrete tokens and then applies refinements via residual tokens. DC-AR achieves state-of-the-art results with a gFID of **5.49** on MJHQ-30K and an overall score of **0.69** on GenEval, while offering **1.5-7.9×** higher throughput and **2.0-3.5×** lower latency compared to prior leading diffusion and autoregressive models. 22 | 23 | ## Setup 24 | 25 | Download the repo and install the environment: 26 | 27 | ```bash 28 | git clone https://github.com/mit-han-lab/dc-ar 29 | cd dc-ar 30 | conda create -n dcar python=3.10 31 | conda activate dcar 32 | pip install -e . 33 | ``` 34 | 35 | Download DC-HT and DC-AR 36 | 37 | ```bash 38 | git clone https://huggingface.co/mit-han-lab/dc-ar-512 39 | git clone https://huggingface.co/mit-han-lab/dc-ht 40 | ``` 41 | 42 | Download the safety check model: 43 | 44 | ```bash 45 | git clone https://huggingface.co/google/shieldgemma-2b 46 | ``` 47 | 48 | Note: We use ShieldGemma-2B from Google DeepMind to filter out unsafe prompts in our demo. We strongly recommend using it if you are distributing our demo publicly. 49 | 50 | ## Usage 51 | 52 | ### Gradio demo 53 | 54 | You may launch the Gradio demo using the following script: 55 | 56 | ```bash 57 | python app.py --shield_model_path /path/to/ShieldGemma2B 58 | ``` 59 | 60 | ### Command Line Inference 61 | 62 | 1. Sampling with single prompt: 63 | 64 | ```bash 65 | python sample.py --prompt "YOUR_PROMPT" \ 66 | --sample_folder_dir /path/to/save_dir \ 67 | --shield_model_path /path/to/ShieldGemma2B 68 | ``` 69 | 70 | 2. Sampling with multiple prompts: 71 | 72 | ```bash 73 | # You can add --store_seperately to store each image individually, otherwise images will be stored in one grid. 74 | python sample.py --prompt_list [Prompt1, Prompt2, ..., PromptN] \ 75 | --sample_folder_dir /path/to/save_dir \ 76 | --shield_model_path /path/to/ShieldGemma2B 77 | ``` 78 | 79 | ## Acknowledgements 80 | 81 | Our codebase is inspired by awesome open source research projects such as [1D-Tokenizer](https://github.com/bytedance/1d-tokenizer) and [MAR](https://github.com/LTH14/mar). Thanks for their efforts! 
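
The command-line entry points in the Usage section above are thin wrappers around `utils.demo_util.sample_fn`. For reference, here is a minimal sketch of programmatic sampling that mirrors `sample.py`; it assumes the default 512px config and downloaded checkpoints, uses an illustrative prompt, and omits the ShieldGemma safety check for brevity:

```python
import torch
from transformers import AutoModel

import utils.demo_util as demo_util

device = torch.device("cuda")
config = demo_util.get_config("configs/inference/dc_ar_t2i_512.yaml")

# DC-HT tokenizer and DC-AR generator; checkpoint paths come from the config.
tokenizer = AutoModel.from_pretrained(config.experiment.tokenizer_checkpoint)
tokenizer.eval().requires_grad_(False)
tokenizer = tokenizer.to(device)
generator = demo_util.get_generator(config).to(device)

images = demo_util.sample_fn(
    generator=generator,
    tokenizer=tokenizer,
    conditions=["A photo of a red heart."],  # illustrative prompt
    randomize_temperature=config.model.generator.randomize_temperature,
    softmax_temperature_annealing=True,
    num_sample_steps=config.model.generator.get("num_steps", 256),
    guidance_scale=config.model.generator.guidance_scale,
    guidance_decay=config.model.generator.get("guidance_decay", "power-cosine"),
    guidance_scale_pow=config.model.generator.get("guidance_scale_pow", 2.75),
    return_tensor=True,
    mean=config.dataset.eval.mean,
    std=config.dataset.eval.std,
    device=device,
    hybrid=config.model.get("hybrid", False),
    model_type=config.model.generator.get("type", "maskgit"),
).clamp_(0.0, 1.0)  # (N, 3, H, W) float tensor in [0, 1]
```

`sample.py` additionally screens prompts with ShieldGemma-2B and saves the resulting tensors as PNGs; see that script for the full pipeline.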
82 | 83 | ## License 84 | + [Code](./LICENSE/code) 85 | + [DC-AR Models](./LICENSE/dc_ar_models) 86 | 87 | ## Contact 88 | + [Han Cai](http://hancai.ai/) 89 | + [Song Han](https://hanlab.mit.edu/songhan) 90 | 91 | ## 📖 BibTeX 92 | 93 | ```bibtex 94 | @article{wu2025dcar, 95 | title={DC-AR: Efficient Masked Autoregressive Image Generation with Deep Compression Hybrid Tokenizer}, 96 | author={Wu, Yecheng and Chen, Junyu and Zhang, Zhuoyang and Xie, Enze and Yu, Jincheng and Chen, Junsong and Hu, Jinyi and Lu, Yao and Han, Song and Cai, Han}, 97 | journal={arXiv preprint arXiv:2410.10733}, 98 | year={2025} 99 | } 100 | ``` 101 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import random 5 | from typing import Optional, Tuple 6 | from tqdm import tqdm 7 | import time 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.distributed as dist 12 | import torchvision 13 | from PIL import Image 14 | import numpy as np 15 | 16 | import utils.demo_util as demo_util 17 | from utils import default_prompts 18 | from utils.safety_check import is_dangerous 19 | 20 | from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM 21 | from modeling.tokenizer.dc_ht import DCHT 22 | 23 | 24 | def save_images(sample_imgs, sample_folder_dir, store_separately, prompts): 25 | if not store_separately and len(sample_imgs) > 1: 26 | grid = torchvision.utils.make_grid(sample_imgs, nrow=12) 27 | grid_np = grid.to(torch.float16).permute(1, 2, 0).mul_(255).cpu().numpy() 28 | 29 | os.makedirs(sample_folder_dir, exist_ok=True) 30 | grid_np = Image.fromarray(grid_np.astype(np.uint8)) 31 | grid_np.save(os.path.join(sample_folder_dir, f"sample_images.png")) 32 | print(f"Example images are saved to {sample_folder_dir}") 33 | else: 34 | # bs, 3, r, r 35 | sample_imgs_np = sample_imgs.mul_(255).cpu().numpy() 36 | num_imgs = sample_imgs_np.shape[0] 37 | os.makedirs(sample_folder_dir, exist_ok=True) 38 | for img_idx in range(num_imgs): 39 | cur_img = sample_imgs_np[img_idx] 40 | cur_img = cur_img.transpose(1, 2, 0).astype(np.uint8) 41 | cur_img_store = Image.fromarray(cur_img) 42 | cur_img_store.save(os.path.join(sample_folder_dir, f"{img_idx:06d}.png")) 43 | print(f"Image {img_idx} saved.") 44 | 45 | with open(os.path.join(sample_folder_dir, "prompt.txt"), "w") as f: 46 | f.write("\n".join(prompts)) 47 | 48 | 49 | def main(args): 50 | device = torch.device("cuda") 51 | 52 | torch.manual_seed(args.seed) 53 | 54 | config = demo_util.get_config(args.config) 55 | tokenizer = AutoModel.from_pretrained(config.experiment.tokenizer_checkpoint) 56 | tokenizer.eval() 57 | tokenizer.requires_grad_(False) 58 | 59 | generator = demo_util.get_generator(config) 60 | 61 | tokenizer = tokenizer.to(device) 62 | generator = generator.to(device) 63 | 64 | safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path) 65 | safety_checker_model = AutoModelForCausalLM.from_pretrained( 66 | args.shield_model_path, 67 | device_map="auto", 68 | torch_dtype=torch.bfloat16, 69 | ).to(device) 70 | 71 | prompts = [] 72 | if args.prompt_list: 73 | prompts = args.prompt_list 74 | elif args.prompt: 75 | prompts = [args.prompt] 76 | else: 77 | print( 78 | "No prompt is provided. Will randomly sample 4 prompts from default prompts." 
79 | ) 80 | prompts = random.sample(default_prompts, 4) 81 | 82 | for idx, prompt in enumerate(prompts): 83 | if is_dangerous( 84 | safety_checker_tokenizer, safety_checker_model, prompt 85 | ): 86 | prompts[idx] = random.sample(default_prompts, 1)[0] 87 | print( 88 | f"Detected Unsafe prompt with index {idx}, will replace by one of default prompts." 89 | ) 90 | 91 | start_time = time.time() 92 | generated_images = demo_util.sample_fn( 93 | generator=generator, 94 | tokenizer=tokenizer, 95 | conditions=prompts, 96 | randomize_temperature=config.model.generator.randomize_temperature, 97 | softmax_temperature_annealing=True, 98 | num_sample_steps=config.model.generator.get('num_steps', 256), 99 | guidance_scale=config.model.generator.guidance_scale, 100 | guidance_decay=config.model.generator.get('guidance_decay', 'power-cosine'), 101 | guidance_scale_pow=config.model.generator.get('guidance_scale_pow', 2.75), 102 | return_tensor=True, 103 | mean=config.dataset.eval.mean, 104 | std=config.dataset.eval.std, 105 | device=device, 106 | hybrid=config.model.get('hybrid', False), 107 | model_type=config.model.generator.get('type', 'maskgit') 108 | ).clamp_(0., 1.) 109 | 110 | total_time = time.time() - start_time 111 | print(f"Generating {len(prompts)} images takes {total_time:2f}s.") 112 | 113 | save_images( 114 | generated_images.clone(), 115 | args.sample_folder_dir, 116 | args.store_seperately, 117 | prompts 118 | ) 119 | 120 | 121 | if __name__ == '__main__': 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument( 124 | "--config", 125 | type=str, 126 | default="configs/inference/dc_ar_t2i_512.yaml", 127 | ) 128 | parser.add_argument( 129 | "--shield_model_path", 130 | type=str, 131 | help="The path to shield model, we employ ShieldGemma-2B by default.", 132 | default="pretrained_models/shieldgemma-2b", 133 | ) 134 | parser.add_argument( 135 | "--prompt", 136 | type=str, 137 | help="A single prompt.", 138 | default="" 139 | ) 140 | parser.add_argument( 141 | "--prompt_list", 142 | type=list[str], 143 | default=[] 144 | ) 145 | parser.add_argument( 146 | "--seed", 147 | type=int, 148 | default=42, 149 | ) 150 | parser.add_argument( 151 | "--sample_folder_dir", 152 | type=str, 153 | help="The folder where the image samples are stored", 154 | default="samples/examples/", 155 | ) 156 | parser.add_argument( 157 | "--store_seperately", 158 | help="Store image samples in a grid or separately, set to False by default.", 159 | action="store_true", 160 | ) 161 | args = parser.parse_args() 162 | 163 | main(args) -------------------------------------------------------------------------------- /modeling/tokenizer/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | import os 20 | import sys 21 | from inspect import signature 22 | from urllib.request import urlretrieve 23 | 24 | import torch 25 | import yaml 26 | 27 | __all__ = [ 28 | "load_state_dict_from_file", 29 | "parse_with_yaml", 30 | "parse_unknown_args", 31 | "partial_update_config", 32 | "resolve_and_load_config", 33 | "build_kwargs_from_config", 34 | "load_config", 35 | "dump_config", 36 | "download_via_url", 37 | ] 38 | 39 | 40 | def load_state_dict_from_file( 41 | file: str, only_state_dict=True 42 | ) -> dict[str, torch.Tensor]: 43 | file = os.path.realpath(os.path.expanduser(file)) 44 | checkpoint = torch.load(file, map_location="cpu") 45 | if only_state_dict and "state_dict" in checkpoint: 46 | checkpoint = checkpoint["state_dict"] 47 | return checkpoint 48 | 49 | 50 | def parse_with_yaml(config_str: str) -> str | dict: 51 | try: 52 | # add space manually for dict 53 | if "{" in config_str and "}" in config_str and ":" in config_str: 54 | out_str = config_str.replace(":", ": ") 55 | else: 56 | out_str = config_str 57 | return yaml.safe_load(out_str) 58 | except ValueError: 59 | # return raw string if parsing fails 60 | return config_str 61 | 62 | 63 | def parse_unknown_args(unknown: list) -> dict: 64 | """Parse unknown args.""" 65 | index = 0 66 | parsed_dict = {} 67 | while index < len(unknown): 68 | key, val = unknown[index], unknown[index + 1] 69 | index += 2 70 | if not key.startswith("--"): 71 | continue 72 | key = key[2:] 73 | 74 | # try parsing with either dot notation or full yaml notation 75 | # Note that the vanilla case "--key value" will be parsed the same 76 | if "." in key: 77 | # key == a.b.c, val == val --> parsed_dict[a][b][c] = val 78 | keys = key.split(".") 79 | dict_to_update = parsed_dict 80 | for key in keys[:-1]: 81 | if not ( 82 | key in dict_to_update and isinstance(dict_to_update[key], dict) 83 | ): 84 | dict_to_update[key] = {} 85 | dict_to_update = dict_to_update[key] 86 | dict_to_update[keys[-1]] = parse_with_yaml( 87 | val 88 | ) # so we can parse lists, bools, etc... 
89 | else: 90 | parsed_dict[key] = parse_with_yaml(val) 91 | return parsed_dict 92 | 93 | 94 | def partial_update_config(config: dict, partial_config: dict) -> dict: 95 | for key in partial_config: 96 | if ( 97 | key in config 98 | and isinstance(partial_config[key], dict) 99 | and isinstance(config[key], dict) 100 | ): 101 | partial_update_config(config[key], partial_config[key]) 102 | else: 103 | config[key] = partial_config[key] 104 | return config 105 | 106 | 107 | def resolve_and_load_config(path: str, config_name="config.yaml") -> dict: 108 | path = os.path.realpath(os.path.expanduser(path)) 109 | if os.path.isdir(path): 110 | config_path = os.path.join(path, config_name) 111 | else: 112 | config_path = path 113 | if os.path.isfile(config_path): 114 | pass 115 | else: 116 | raise Exception(f"Cannot find a valid config at {path}") 117 | config = load_config(config_path) 118 | return config 119 | 120 | 121 | def build_kwargs_from_config(config: dict, target_func: callable) -> dict[str, any]: 122 | valid_keys = list(signature(target_func).parameters) 123 | kwargs = {} 124 | for key in config: 125 | if key in valid_keys: 126 | kwargs[key] = config[key] 127 | return kwargs 128 | 129 | 130 | class SafeLoaderWithTuple(yaml.SafeLoader): 131 | """A yaml safe loader with python tuple loading capabilities.""" 132 | 133 | def construct_python_tuple(self, node): 134 | return tuple(self.construct_sequence(node)) 135 | 136 | 137 | SafeLoaderWithTuple.add_constructor( 138 | "tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple 139 | ) 140 | 141 | 142 | def load_config(filename: str) -> dict: 143 | """Load a yaml file.""" 144 | filename = os.path.realpath(os.path.expanduser(filename)) 145 | return yaml.load(open(filename), Loader=SafeLoaderWithTuple) 146 | 147 | 148 | def dump_config(config: dict, filename: str) -> None: 149 | """Dump a config file""" 150 | filename = os.path.realpath(os.path.expanduser(filename)) 151 | yaml.dump(config, open(filename, "w"), sort_keys=False) 152 | 153 | 154 | def download_via_url(url: str, save_path: str, overwrite=False) -> str | None: 155 | save_path = os.path.expanduser(save_path) 156 | try: 157 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 158 | if not os.path.exists(save_path) or overwrite: 159 | sys.stderr.write(f"Downloading: {url} to {save_path}\n") 160 | urlretrieve(url, save_path) 161 | return save_path 162 | except Exception as e: 163 | # remove lock file so download can be executed next time. 164 | sys.stderr.write("Failed to download from url {url}\n{e}\n") 165 | return None 166 | -------------------------------------------------------------------------------- /modeling/modules/base_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/bytedance/1d-tokenizer 18 | 19 | import os 20 | from typing import Union, Callable, Dict, Optional 21 | 22 | import torch 23 | 24 | 25 | class BaseModel(torch.nn.Module): 26 | 27 | def __init__(self): 28 | super().__init__() 29 | 30 | def save_pretrained_weight( 31 | self, 32 | save_directory: Union[str, os.PathLike], 33 | save_function: Callable = None, 34 | state_dict: Optional[Dict[str, torch.Tensor]] = None, 35 | ): 36 | """Saves a model and its configuration file to a directory. 37 | 38 | Args: 39 | save_directory: A string or os.PathLike, directory to which to save. 40 | Will be created if it doesn't exist. 41 | save_function: A Callable function, the function to use to save the state dictionary. 42 | Useful on distributed training like TPUs when one need to replace `torch.save` by 43 | another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. 44 | state_dict: A dictionary from str to torch.Tensor, the state dictionary to save. 45 | If `None`, the model's state dictionary will be saved. 46 | """ 47 | if os.path.isfile(save_directory): 48 | print(f"Provided path ({save_directory}) should be a directory, not a file") 49 | return 50 | 51 | if save_function is None: 52 | save_function = torch.save 53 | 54 | os.makedirs(save_directory, exist_ok=True) 55 | 56 | model_to_save = self 57 | 58 | if state_dict is None: 59 | state_dict = model_to_save.state_dict() 60 | weights_name = "pytorch_model.bin" 61 | 62 | save_function(state_dict, os.path.join(save_directory, weights_name)) 63 | 64 | print(f"Model weights saved in {os.path.join(save_directory, weights_name)}") 65 | 66 | def load_pretrained_weight( 67 | self, 68 | pretrained_model_path: Union[str, os.PathLike], 69 | strict_loading: bool = True, 70 | torch_dtype: Optional[torch.dtype] = None 71 | ): 72 | r"""Instantiates a pretrained pytorch model from a pre-trained model configuration. 73 | 74 | The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train 75 | the model, you should first set it back in training mode with `model.train()`. 76 | 77 | Args: 78 | pretrained_model_path: A string or os.PathLike, a path to a *directory* or *file* containing model weights. 79 | 80 | Raises: 81 | ValueError: If pretrained_model_path does not exist. 82 | """ 83 | # If pretrained_model_path is a file, set model_file to this file. 84 | if os.path.isfile(pretrained_model_path): 85 | model_file = pretrained_model_path 86 | # If pretrained_model_path is a directory, set model_file to the path of the 87 | # file "pytorch_model.bin" in this directory. 88 | elif os.path.isdir(pretrained_model_path): 89 | pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin") 90 | if os.path.isfile(pretrained_model_path): 91 | model_file = pretrained_model_path 92 | else: 93 | raise ValueError(f"{pretrained_model_path} does not exist") 94 | else: 95 | raise ValueError(f"{pretrained_model_path} does not exist") 96 | 97 | # Load model state from checkpoint. 98 | checkpoint = torch.load(model_file, map_location="cpu") 99 | # Load state dictionary into self. 100 | msg = self.load_state_dict(checkpoint, strict=strict_loading) 101 | # Print information about loading weights. 102 | print(f"loading weight from {model_file}, msg: {msg}") 103 | # If torch_dtype is specified and is a valid torch.dtype, convert self to this dtype. 
104 | if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): 105 | raise ValueError( 106 | f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 107 | ) 108 | elif torch_dtype is not None: 109 | self.to(torch_dtype) 110 | 111 | # Set model in evaluation mode to deactivate DropOut modules by default. 112 | self.eval() 113 | 114 | def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: 115 | """Gets the number of parameters in the module. 116 | 117 | Args: 118 | only_trainable: A boolean, whether to only include trainable parameters. 119 | exclude_embeddings: A boolean, whether to exclude parameters associated with embeddings. 120 | 121 | Returns: 122 | An integer, the number of parameters. 123 | """ 124 | 125 | if exclude_embeddings: 126 | embedding_param_names = [ 127 | f"{name}.weight" 128 | for name, module_type in self.named_modules() 129 | if isinstance(module_type, torch.nn.Embedding) 130 | ] 131 | non_embedding_parameters = [ 132 | parameter for name, parameter in self.named_parameters() if name not in embedding_param_names 133 | ] 134 | return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) 135 | else: 136 | return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) 137 | 138 | -------------------------------------------------------------------------------- /utils/demo_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/bytedance/1d-tokenizer 18 | 19 | import torch 20 | 21 | from omegaconf import OmegaConf 22 | from modeling.generator.dc_ar import DCAR 23 | 24 | def get_config_cli(): 25 | cli_conf = OmegaConf.from_cli() 26 | 27 | yaml_conf = OmegaConf.load(cli_conf.config) 28 | conf = OmegaConf.merge(yaml_conf, cli_conf) 29 | 30 | return conf 31 | 32 | def get_config(config_path): 33 | conf = OmegaConf.load(config_path) 34 | return conf 35 | 36 | 37 | def get_generator(config, tokenizer=None): 38 | if config.model.generator.model_type == 'DCAR': 39 | model_cls = DCAR 40 | else: 41 | raise ValueError(f"Unsupported model type {config.model.generator.model_type}") 42 | generator = model_cls(config) 43 | generator.load_state_dict(torch.load(config.experiment.generator_checkpoint, map_location="cpu")) 44 | generator.eval() 45 | generator.requires_grad_(False) 46 | return generator 47 | 48 | @torch.no_grad() 49 | def sample_fn(generator, 50 | tokenizer, 51 | conditions=None, 52 | guidance_scale=3.0, 53 | guidance_decay="constant", 54 | guidance_scale_pow=3.0, 55 | randomize_temperature=2.0, 56 | softmax_temperature_annealing=False, 57 | num_sample_steps=8, 58 | device="cuda", 59 | return_tensor=False, 60 | mean=0.0, 61 | std=1.0, 62 | hybrid=False, 63 | model_type='maskgit', 64 | init_image_tokens=None, 65 | init_residual_features=None, 66 | init_mask=None): 67 | generator.eval() 68 | tokenizer.eval() 69 | if model_type in ['maskgit']: 70 | if conditions is None: 71 | # goldfish, chicken, tiger, cat, hourglass, ship, dog, race car, airliner, teddy bear, random 72 | conditions = [1, 7, 282, 604, 724, 179, 751, 404, 850, torch.randint(0, 999, size=(1,))] 73 | if not isinstance(conditions, torch.Tensor): 74 | conditions = torch.LongTensor(conditions) 75 | conditions = conditions.to(device) 76 | elif model_type in ['maskgit_t2i', 'maskgit_t2i_inpainting']: 77 | # goldfish, chicken, tiger, cat, hourglass, ship, dog, race car, airliner, teddy bear, random 78 | if conditions is None: 79 | conditions = [ 80 | "dog", 81 | "portrait photo of a girl, photograph, highly detailed face, depth of field", 82 | "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", 83 | "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", 84 | "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", 85 | # MJV6 dataset 86 | "A young child is sitting on some pita bread next to some trees", 87 | "A slice of pizza", 88 | "A elegant sofa and chairs", 89 | "Group of people ski on the front of a field", 90 | "A person bent over on the mountain", 91 | "The action at home plate during a competitive baseball game as others watch", 92 | # overfitting prompts from training set 93 | "New Year card illustration, elegant nature style", 94 | "graphic for vintage stereo cassette, in the style of fauvist colors, kawaii aesthetic, flickr, color-streaked, holography, crisp outlines, strong use of color, white background, png", 95 | "russian snow day photos, art, in the style of adorable toy sculptures, orange and gold, cute cartoonish designs, neo-geo, frostpunk, emotive, orange", 96 | "Old vintage scrapbook wallpaper with line page, beautiful Black Eyed Susan pattern border, watercolor, high detail, HDR, self shadow, unique, intricate detail, hand-painted", 97 | "change to blue tang dynasty costumes, unchanged appearance", 98 | ] 99 | 100 | if model_type in ['maskgit_t2i', 
'maskgit_t2i_inpainting']: 101 | generated_tokens = generator.generate( 102 | init_image_tokens=init_image_tokens, 103 | init_residual_features=init_residual_features, 104 | init_mask=init_mask, 105 | condition=conditions, 106 | guidance_scale=guidance_scale, 107 | guidance_decay=guidance_decay, 108 | guidance_scale_pow=guidance_scale_pow, 109 | randomize_temperature=randomize_temperature, 110 | softmax_temperature_annealing=softmax_temperature_annealing, 111 | num_sample_steps=num_sample_steps) 112 | else: 113 | raise NotImplementedError 114 | if isinstance(generated_tokens, tuple): 115 | generated_tokens, residual_features = generated_tokens 116 | if model_type in ['maskgit_t2i', 'maskgit_t2i_inpainting']: 117 | if hybrid: 118 | generated_image = tokenizer.decode_tokens( 119 | generated_tokens, residual_features 120 | ).mul_(std).add_(mean) 121 | else: 122 | generated_image = tokenizer.decode_tokens( 123 | generated_tokens, None 124 | ).mul_(std).add_(mean) 125 | else: 126 | raise NotImplementedError 127 | 128 | 129 | if return_tensor: 130 | return generated_image 131 | 132 | generated_image = torch.clamp(generated_image, 0.0, 1.0) 133 | generated_image = (generated_image * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 134 | 135 | return generated_image 136 | -------------------------------------------------------------------------------- /modeling/tokenizer/networks/dc_ae_blocks/norm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | from typing import Optional 20 | 21 | import torch 22 | import torch.nn as nn 23 | from torch.nn.modules.batchnorm import _BatchNorm 24 | 25 | from .triton_rms_norm import TritonRMSNorm2dFunc 26 | from modeling.tokenizer.utils import build_kwargs_from_config 27 | 28 | __all__ = ["LayerNorm2d", "TritonRMSNorm2d", "build_norm", "reset_bn", "set_norm_eps"] 29 | 30 | 31 | class LayerNorm2d(nn.LayerNorm): 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | out = x - torch.mean(x, dim=1, keepdim=True) 34 | out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps) 35 | if self.elementwise_affine: 36 | out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) 37 | return out 38 | 39 | 40 | class TritonRMSNorm2d(nn.LayerNorm): 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | return TritonRMSNorm2dFunc.apply(x, self.weight, self.bias, self.eps) 43 | 44 | 45 | # register normalization function here 46 | REGISTERED_NORM_DICT: dict[str, type] = { 47 | "bn2d": nn.BatchNorm2d, 48 | "ln": nn.LayerNorm, 49 | "ln2d": LayerNorm2d, 50 | "trms2d": TritonRMSNorm2d, 51 | } 52 | 53 | 54 | def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]: 55 | if name in ["ln", "ln2d", "trms2d"]: 56 | kwargs["normalized_shape"] = num_features 57 | else: 58 | kwargs["num_features"] = num_features 59 | if name in REGISTERED_NORM_DICT: 60 | norm_cls = REGISTERED_NORM_DICT[name] 61 | args = build_kwargs_from_config(kwargs, norm_cls) 62 | return norm_cls(**args) 63 | else: 64 | return None 65 | 66 | 67 | def reset_bn( 68 | model: nn.Module, 69 | data_loader: list, 70 | sync=True, 71 | progress_bar=False, 72 | ) -> None: 73 | import copy 74 | 75 | import torch.nn.functional as F 76 | from tqdm import tqdm 77 | 78 | from apps.utils import AverageMeter, is_master, sync_tensor, get_device, list_join 79 | 80 | bn_mean = {} 81 | bn_var = {} 82 | 83 | tmp_model = copy.deepcopy(model) 84 | for name, m in tmp_model.named_modules(): 85 | if isinstance(m, _BatchNorm): 86 | bn_mean[name] = AverageMeter(is_distributed=False) 87 | bn_var[name] = AverageMeter(is_distributed=False) 88 | 89 | def new_forward(bn, mean_est, var_est): 90 | def lambda_forward(x): 91 | x = x.contiguous() 92 | if sync: 93 | batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 94 | batch_mean = sync_tensor(batch_mean, reduce="cat") 95 | batch_mean = torch.mean(batch_mean, dim=0, keepdim=True) 96 | 97 | batch_var = (x - batch_mean) * (x - batch_mean) 98 | batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 99 | batch_var = sync_tensor(batch_var, reduce="cat") 100 | batch_var = torch.mean(batch_var, dim=0, keepdim=True) 101 | else: 102 | batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 103 | batch_var = (x - batch_mean) * (x - batch_mean) 104 | batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 105 | 106 | batch_mean = torch.squeeze(batch_mean) 107 | batch_var = torch.squeeze(batch_var) 108 | 109 | mean_est.update(batch_mean.data, x.size(0)) 110 | var_est.update(batch_var.data, x.size(0)) 111 | 112 | # bn forward using calculated mean & var 113 | _feature_dim = batch_mean.shape[0] 114 | return F.batch_norm( 115 | x, 116 | batch_mean, 117 | batch_var, 118 | bn.weight[:_feature_dim], 119 | 
bn.bias[:_feature_dim], 120 | False, 121 | 0.0, 122 | bn.eps, 123 | ) 124 | 125 | return lambda_forward 126 | 127 | m.forward = new_forward(m, bn_mean[name], bn_var[name]) 128 | 129 | # skip if there is no batch normalization layers in the network 130 | if len(bn_mean) == 0: 131 | return 132 | 133 | tmp_model.eval() 134 | with torch.no_grad(): 135 | with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t: 136 | for images in data_loader: 137 | images = images.to(get_device(tmp_model)) 138 | tmp_model(images) 139 | t.set_postfix( 140 | { 141 | "bs": images.size(0), 142 | "res": list_join(images.shape[-2:], "x"), 143 | } 144 | ) 145 | t.update() 146 | 147 | for name, m in model.named_modules(): 148 | if name in bn_mean and bn_mean[name].count > 0: 149 | feature_dim = bn_mean[name].avg.size(0) 150 | assert isinstance(m, _BatchNorm) 151 | m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg) 152 | m.running_var.data[:feature_dim].copy_(bn_var[name].avg) 153 | 154 | 155 | def set_norm_eps(model: nn.Module, eps: Optional[float] = None) -> None: 156 | for m in model.modules(): 157 | if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)): 158 | if eps is not None: 159 | m.eps = eps -------------------------------------------------------------------------------- /modeling/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # Modified from OpenAI's diffusion repos 18 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 19 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 20 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 21 | import numpy as np 22 | import torch as th 23 | 24 | from .gaussian_diffusion import GaussianDiffusion 25 | 26 | 27 | def space_timesteps(num_timesteps, section_counts): 28 | """ 29 | Create a list of timesteps to use from an original diffusion process, 30 | given the number of timesteps we want to take from equally-sized portions 31 | of the original process. 32 | For example, if there's 300 timesteps and the section counts are [10,15,20] 33 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 34 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 35 | If the stride is a string starting with "ddim", then the fixed striding 36 | from the DDIM paper is used, and only one section is allowed. 37 | :param num_timesteps: the number of diffusion steps in the original 38 | process to divide up. 39 | :param section_counts: either a list of numbers, or a string containing 40 | comma-separated numbers, indicating the step count 41 | per section. 
As a special case, use "ddimN" where N 42 | is a number of steps to use the striding from the 43 | DDIM paper. 44 | :return: a set of diffusion steps from the original process to use. 45 | """ 46 | if isinstance(section_counts, str): 47 | if section_counts.startswith("ddim"): 48 | desired_count = int(section_counts[len("ddim") :]) 49 | for i in range(1, num_timesteps): 50 | if len(range(0, num_timesteps, i)) == desired_count: 51 | return set(range(0, num_timesteps, i)) 52 | raise ValueError( 53 | f"cannot create exactly {num_timesteps} steps with an integer stride" 54 | ) 55 | section_counts = [int(x) for x in section_counts.split(",")] 56 | size_per = num_timesteps // len(section_counts) 57 | extra = num_timesteps % len(section_counts) 58 | start_idx = 0 59 | all_steps = [] 60 | for i, section_count in enumerate(section_counts): 61 | size = size_per + (1 if i < extra else 0) 62 | if size < section_count: 63 | raise ValueError( 64 | f"cannot divide section of {size} steps into {section_count}" 65 | ) 66 | if section_count <= 1: 67 | frac_stride = 1 68 | else: 69 | frac_stride = (size - 1) / (section_count - 1) 70 | cur_idx = 0.0 71 | taken_steps = [] 72 | for _ in range(section_count): 73 | taken_steps.append(start_idx + round(cur_idx)) 74 | cur_idx += frac_stride 75 | all_steps += taken_steps 76 | start_idx += size 77 | return set(all_steps) 78 | 79 | 80 | class SpacedDiffusion(GaussianDiffusion): 81 | """ 82 | A diffusion process which can skip steps in a base diffusion process. 83 | :param use_timesteps: a collection (sequence or set) of timesteps from the 84 | original diffusion process to retain. 85 | :param kwargs: the kwargs to create the base diffusion process. 86 | """ 87 | 88 | def __init__(self, use_timesteps, **kwargs): 89 | self.use_timesteps = set(use_timesteps) 90 | self.timestep_map = [] 91 | self.original_num_steps = len(kwargs["betas"]) 92 | 93 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 94 | last_alpha_cumprod = 1.0 95 | new_betas = [] 96 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 97 | if i in self.use_timesteps: 98 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 99 | last_alpha_cumprod = alpha_cumprod 100 | self.timestep_map.append(i) 101 | kwargs["betas"] = np.array(new_betas) 102 | super().__init__(**kwargs) 103 | 104 | def p_mean_variance( 105 | self, model, *args, **kwargs 106 | ): # pylint: disable=signature-differs 107 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 108 | 109 | def training_losses( 110 | self, model, *args, **kwargs 111 | ): # pylint: disable=signature-differs 112 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 113 | 114 | def condition_mean(self, cond_fn, *args, **kwargs): 115 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 116 | 117 | def condition_score(self, cond_fn, *args, **kwargs): 118 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 119 | 120 | def _wrap_model(self, model): 121 | if isinstance(model, _WrappedModel): 122 | return model 123 | return _WrappedModel(model, self.timestep_map, self.original_num_steps) 124 | 125 | def _scale_timesteps(self, t): 126 | # Scaling is done by the wrapped model. 
127 | return t 128 | 129 | 130 | class _WrappedModel: 131 | def __init__(self, model, timestep_map, original_num_steps): 132 | self.model = model 133 | self.timestep_map = timestep_map 134 | # self.rescale_timesteps = rescale_timesteps 135 | self.original_num_steps = original_num_steps 136 | 137 | def __call__(self, x, ts, **kwargs): 138 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 139 | new_ts = map_tensor[ts] 140 | # if self.rescale_timesteps: 141 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 142 | return self.model(x, new_ts, **kwargs) -------------------------------------------------------------------------------- /modeling/tokenizer/networks/dc_ae_blocks/triton_rms_norm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | import torch 20 | import triton 21 | import triton.language as tl 22 | 23 | __all__ = ["TritonRMSNorm2dFunc"] 24 | 25 | 26 | @triton.jit 27 | def _rms_norm_2d_fwd_fused( 28 | X, # pointer to the input 29 | Y, # pointer to the output 30 | W, # pointer to the weights 31 | B, # pointer to the biases 32 | Rrms, # pointer to the 1/rms 33 | M, 34 | C, 35 | N, 36 | num_blocks, # number of columns in X 37 | eps, # epsilon to avoid division by zero 38 | BLOCK_SIZE: tl.constexpr, 39 | ): 40 | # Map the program id to the row of X and Y it should compute. 
41 | m_n = tl.program_id(0) 42 | m, n = m_n // num_blocks, m_n % num_blocks 43 | 44 | Y += m * C * N 45 | X += m * C * N 46 | # Compute mean 47 | 48 | cols = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 49 | mask = cols < N 50 | 51 | x_sum_square = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 52 | for off in range(0, C): 53 | x = tl.load(X + off * N + cols, mask=mask, other=0.0).to(tl.float32) 54 | x_sum_square += x * x 55 | mean_square = x_sum_square / C 56 | rrms = 1 / tl.sqrt(mean_square + eps) 57 | # Write rstd 58 | tl.store(Rrms + m * N + cols, rrms, mask=mask) 59 | # Normalize and apply linear transformation 60 | for off in range(0, C): 61 | pos = off * N + cols 62 | w = tl.load(W + off) 63 | b = tl.load(B + off) 64 | x = tl.load(X + pos, mask=mask, other=0.0).to(tl.float32) 65 | x_hat = x * rrms 66 | y = x_hat * w + b 67 | # Write output 68 | tl.store(Y + pos, y, mask=mask) 69 | 70 | 71 | @triton.jit 72 | def _rms_norm_2d_bwd_dx_fused( 73 | DX, # pointer to the input gradient 74 | DY, # pointer to the output gradient 75 | DW, # pointer to the partial sum of weights gradient 76 | DB, # pointer to the partial sum of biases gradient 77 | X, # pointer to the input 78 | W, # pointer to the weights 79 | B, # pointer to the biases 80 | Rrms, # pointer to the 1/rms 81 | M, 82 | C, 83 | N, # number of columns in X 84 | num_blocks, 85 | eps, # epsilon to avoid division by zero 86 | GROUP_SIZE_M: tl.constexpr, 87 | BLOCK_SIZE: tl.constexpr, 88 | BLOCK_SIZE_C: tl.constexpr, 89 | ): 90 | # Map the program id to the elements of X, DX, and DY it should compute. 91 | m_n = tl.program_id(0) 92 | m, n = m_n // num_blocks, m_n % num_blocks 93 | X += m * C * N 94 | DY += m * C * N 95 | DX += m * C * N 96 | Rrms += m * N 97 | 98 | cols = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 99 | mask = cols < N 100 | # Offset locks and weights/biases gradient pointer for parallel reduction 101 | DW = DW + m_n * C 102 | DB = DB + m_n * C 103 | rrms = tl.load(Rrms + cols, mask=mask, other=1) 104 | # Load data to SRAM 105 | c1 = tl.zeros([BLOCK_SIZE], dtype=tl.float32) 106 | for off in range(0, C): 107 | pos = off * N + cols 108 | x = tl.load(X + pos, mask=mask, other=0).to(tl.float32) 109 | dy = tl.load(DY + pos, mask=mask, other=0).to(tl.float32) 110 | w = tl.load(W + off).to(tl.float32) 111 | # Compute dx 112 | xhat = x * rrms 113 | wdy = w * dy 114 | xhat = tl.where(mask, xhat, 0.0) 115 | wdy = tl.where(mask, wdy, 0.0) 116 | c1 += xhat * wdy 117 | # Accumulate partial sums for dw/db 118 | tl.store(DW + off, tl.sum((dy * xhat).to(w.dtype), axis=0)) 119 | tl.store(DB + off, tl.sum(dy.to(w.dtype), axis=0)) 120 | 121 | c1 /= C 122 | for off in range(0, C): 123 | pos = off * N + cols 124 | x = tl.load(X + pos, mask=mask, other=0).to(tl.float32) 125 | dy = tl.load(DY + pos, mask=mask, other=0).to(tl.float32) 126 | w = tl.load(W + off).to(tl.float32) 127 | xhat = x * rrms 128 | wdy = w * dy 129 | dx = (wdy - (xhat * c1)) * rrms 130 | # Write dx 131 | tl.store(DX + pos, dx, mask=mask) 132 | 133 | 134 | class TritonRMSNorm2dFunc(torch.autograd.Function): 135 | @staticmethod 136 | def forward(ctx, x, weight, bias, eps): 137 | # allocate output 138 | y = torch.empty_like(x) 139 | # reshape input data into 2D tensor 140 | x_arg = x.reshape(x.shape[0], x.shape[1], -1) 141 | M, C, N = x_arg.shape 142 | rrms = torch.empty((M, N), dtype=torch.float32, device="cuda") 143 | # Less than 64KB per feature: enqueue fused kernel 144 | BLOCK_SIZE = 256 145 | num_blocks = triton.cdiv(N, BLOCK_SIZE) 146 | num_warps = 8 147 | # enqueue kernel 148 
| _rms_norm_2d_fwd_fused[(M * num_blocks,)]( # 149 | x_arg, 150 | y, 151 | weight, 152 | bias, 153 | rrms, # 154 | M, 155 | C, 156 | N, 157 | num_blocks, 158 | eps, # 159 | BLOCK_SIZE=BLOCK_SIZE, 160 | num_warps=num_warps, 161 | num_ctas=1, 162 | ) 163 | ctx.save_for_backward(x, weight, bias, rrms) 164 | ctx.BLOCK_SIZE = BLOCK_SIZE 165 | ctx.num_blocks = num_blocks 166 | ctx.num_warps = num_warps 167 | ctx.eps = eps 168 | return y 169 | 170 | @staticmethod 171 | def backward(ctx, dy): 172 | x, w, b, rrms = ctx.saved_tensors 173 | num_blocks = ctx.num_blocks 174 | 175 | x_arg = x.reshape(x.shape[0], x.shape[1], -1) 176 | M, C, N = x_arg.shape 177 | # GROUP_SIZE_M = 64 178 | GROUP_SIZE_M = M * num_blocks 179 | # allocate output 180 | _dw = torch.empty((GROUP_SIZE_M, C), dtype=x.dtype, device=w.device) 181 | _db = torch.empty((GROUP_SIZE_M, C), dtype=x.dtype, device=w.device) 182 | dw = torch.empty((C,), dtype=w.dtype, device=w.device) 183 | db = torch.empty((C,), dtype=w.dtype, device=w.device) 184 | dx = torch.empty_like(dy) 185 | # enqueue kernel using forward pass heuristics 186 | # also compute partial sums for DW and DB 187 | # print(f"M={M}, num_blocks={num_blocks}, dx={dx.shape}, dy={dy.shape}, _dw={_dw.shape}, _db={_db.shape}, x={x.shape}, w={w.shape}, b={b.shape}, m={m.shape}, v={v.shape}, M={M}, C={C}, N={N}") 188 | _rms_norm_2d_bwd_dx_fused[(M * num_blocks,)]( # 189 | dx, 190 | dy, 191 | _dw, 192 | _db, 193 | x, 194 | w, 195 | b, 196 | rrms, # 197 | M, 198 | C, 199 | N, 200 | num_blocks, 201 | ctx.eps, # 202 | BLOCK_SIZE=ctx.BLOCK_SIZE, 203 | GROUP_SIZE_M=GROUP_SIZE_M, # 204 | BLOCK_SIZE_C=triton.next_power_of_2(C), 205 | num_warps=ctx.num_warps, 206 | ) 207 | dw = _dw.sum(dim=0) 208 | db = _db.sum(dim=0) 209 | return dx, dw, db, None 210 | -------------------------------------------------------------------------------- /modeling/generator/net/norms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/NVlabs/Sana 18 | 19 | import copy 20 | import warnings 21 | 22 | import torch 23 | import torch.nn as nn 24 | from torch.nn.modules.batchnorm import _BatchNorm 25 | 26 | __all__ = ["LayerNorm2d", "build_norm", "get_norm_name", "reset_bn", "remove_bn", "set_norm_eps"] 27 | 28 | 29 | class LayerNorm2d(nn.LayerNorm): 30 | rmsnorm = False 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | out = x if LayerNorm2d.rmsnorm else x - torch.mean(x, dim=1, keepdim=True) 34 | out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps) 35 | if self.elementwise_affine: 36 | out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) 37 | return out 38 | 39 | def extra_repr(self) -> str: 40 | return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}, rmsnorm={self.rmsnorm}" 41 | 42 | 43 | # register normalization function here 44 | # name: module, kwargs with default values 45 | REGISTERED_NORMALIZATION_DICT: dict[str, tuple[type, dict[str, any]]] = { 46 | "bn2d": (nn.BatchNorm2d, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}), 47 | "syncbn": (nn.SyncBatchNorm, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}), 48 | "ln": (nn.LayerNorm, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}), 49 | "ln2d": (LayerNorm2d, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}), 50 | } 51 | 52 | 53 | def build_norm(name="bn2d", num_features=None, affine=True, **kwargs) -> nn.Module: 54 | if name in ["ln", "ln2d"]: 55 | kwargs["normalized_shape"] = num_features 56 | kwargs["elementwise_affine"] = affine 57 | else: 58 | kwargs["num_features"] = num_features 59 | kwargs["affine"] = affine 60 | if name in REGISTERED_NORMALIZATION_DICT: 61 | norm_cls, default_args = copy.deepcopy(REGISTERED_NORMALIZATION_DICT[name]) 62 | for key in default_args: 63 | if key in kwargs: 64 | default_args[key] = kwargs[key] 65 | return norm_cls(**default_args) 66 | elif name is None or name.lower() == "none": 67 | return None 68 | else: 69 | raise ValueError("do not support: %s" % name) 70 | 71 | 72 | def get_norm_name(norm: nn.Module) -> str: 73 | if norm is None: 74 | return None 75 | module2name = {} 76 | for key, config in REGISTERED_NORMALIZATION_DICT.items(): 77 | module2name[config[0].__name__] = key 78 | return module2name.get(type(norm).__name__, "unknown") 79 | 80 | 81 | def reset_bn( 82 | model: nn.Module, 83 | data_loader: list, 84 | sync=True, 85 | progress_bar=False, 86 | ) -> None: 87 | import copy 88 | 89 | import torch.nn.functional as F 90 | from apps.utils import AverageMeter, is_master, sync_tensor, get_device, list_join 91 | from tqdm import tqdm 92 | 93 | bn_mean = {} 94 | bn_var = {} 95 | 96 | tmp_model = copy.deepcopy(model) 97 | for name, m in tmp_model.named_modules(): 98 | if isinstance(m, _BatchNorm): 99 | bn_mean[name] = AverageMeter(is_distributed=False) 100 | bn_var[name] = AverageMeter(is_distributed=False) 101 | 102 | def new_forward(bn, mean_est, var_est): 103 | def lambda_forward(x): 104 | x = x.contiguous() 105 | if sync: 106 | batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 107 | batch_mean = sync_tensor(batch_mean, reduce="cat") 108 | batch_mean = torch.mean(batch_mean, dim=0, keepdim=True) 109 | 110 | batch_var = (x - batch_mean) * (x - batch_mean) 111 | batch_var = batch_var.mean(0, 
keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 112 | batch_var = sync_tensor(batch_var, reduce="cat") 113 | batch_var = torch.mean(batch_var, dim=0, keepdim=True) 114 | else: 115 | batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1 116 | batch_var = (x - batch_mean) * (x - batch_mean) 117 | batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) 118 | 119 | batch_mean = torch.squeeze(batch_mean) 120 | batch_var = torch.squeeze(batch_var) 121 | 122 | mean_est.update(batch_mean.data, x.size(0)) 123 | var_est.update(batch_var.data, x.size(0)) 124 | 125 | # bn forward using calculated mean & var 126 | _feature_dim = batch_mean.shape[0] 127 | return F.batch_norm( 128 | x, 129 | batch_mean, 130 | batch_var, 131 | bn.weight[:_feature_dim], 132 | bn.bias[:_feature_dim], 133 | False, 134 | 0.0, 135 | bn.eps, 136 | ) 137 | 138 | return lambda_forward 139 | 140 | m.forward = new_forward(m, bn_mean[name], bn_var[name]) 141 | 142 | # skip if there is no batch normalization layers in the network 143 | if len(bn_mean) == 0: 144 | return 145 | 146 | tmp_model.eval() 147 | with torch.inference_mode(): 148 | with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t: 149 | for images in data_loader: 150 | images = images.to(get_device(tmp_model)) 151 | tmp_model(images) 152 | t.set_postfix( 153 | { 154 | "bs": images.size(0), 155 | "res": list_join(images.shape[-2:], "x"), 156 | } 157 | ) 158 | t.update() 159 | 160 | for name, m in model.named_modules(): 161 | if name in bn_mean and bn_mean[name].count > 0: 162 | feature_dim = bn_mean[name].avg.size(0) 163 | assert isinstance(m, _BatchNorm) 164 | m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg) 165 | m.running_var.data[:feature_dim].copy_(bn_var[name].avg) 166 | 167 | 168 | def remove_bn(model: nn.Module) -> None: 169 | for m in model.modules(): 170 | if isinstance(m, _BatchNorm): 171 | m.weight = m.bias = None 172 | m.forward = lambda x: x 173 | 174 | 175 | def set_norm_eps(model: nn.Module, eps: float = None, momentum: float = None) -> None: 176 | for m in model.modules(): 177 | if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)): 178 | if eps is not None: 179 | m.eps = eps 180 | if momentum is not None: 181 | m.momentum = momentum 182 | 183 | 184 | class RMSNorm(torch.nn.Module): 185 | def __init__(self, dim: int, scale_factor=1.0, eps: float = 1e-6): 186 | """ 187 | Initialize the RMSNorm normalization layer. 188 | 189 | Args: 190 | dim (int): The dimension of the input tensor. 191 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 192 | 193 | Attributes: 194 | eps (float): A small value added to the denominator for numerical stability. 195 | weight (nn.Parameter): Learnable scaling parameter. 196 | 197 | """ 198 | super().__init__() 199 | self.eps = eps 200 | self.weight = nn.Parameter(torch.ones(dim) * scale_factor) 201 | 202 | def _norm(self, x): 203 | """ 204 | Apply the RMSNorm normalization to the input tensor. 205 | 206 | Args: 207 | x (torch.Tensor): The input tensor. 208 | 209 | Returns: 210 | torch.Tensor: The normalized tensor. 211 | 212 | """ 213 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 214 | 215 | def forward(self, x): 216 | """ 217 | Forward pass through the RMSNorm layer. 218 | 219 | Args: 220 | x (torch.Tensor): The input tensor. 221 | 222 | Returns: 223 | torch.Tensor: The output tensor after applying RMSNorm. 
224 | 225 | """ 226 | return (self.weight * self._norm(x.float())).type_as(x) -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import uuid 5 | 6 | import gradio as gr 7 | import numpy as np 8 | import spaces 9 | import torch 10 | from PIL import Image 11 | from transformers import ( 12 | AutoConfig, 13 | AutoModel, 14 | AutoModelForCausalLM, 15 | AutoTokenizer, 16 | HfArgumentParser, 17 | set_seed 18 | ) 19 | 20 | from modeling.tokenizer.dc_ht import DCHT 21 | from utils.safety_check import is_dangerous 22 | import utils.demo_util as demo_util 23 | 24 | # DESCRIPTION = ( 25 | # """ 26 | #

DC-AR-0.7B512px

27 | #

DC-AR: Efficient Masked Autoregressive Image Generation with Deep Compression Hybrid Tokenizer

28 | #

Powered by DC-AE and HART

29 | # """ + """\n

Note: We will replace unsafe prompts with a default prompt: \"A red heart.\"

""" 30 | # ) 31 | DESCRIPTION = f""" 32 |

DC-AR-0.7B512px

33 |

DC-AR: Efficient Masked Autoregressive Image Generation with Deep Compression Hybrid Tokenizer

34 |

Powered by DC-AE, and HART.

35 |

Unsafe word will give you a 'Red Heart❤️' in the image instead.

36 | """ 37 | if not torch.cuda.is_available(): 38 | DESCRIPTION += "\n

Running on CPU 🥶 This demo may not work on CPU.

" 39 | 40 | MAX_SEED = np.iinfo(np.int32).max 41 | CACHE_EXAMPLES = False 42 | MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "512")) 43 | USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1" 44 | ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1" 45 | 46 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 47 | 48 | NUM_IMAGES_PER_PROMPT = 1 49 | 50 | 51 | generator = None 52 | tokenizer = None 53 | safety_checker_tokenizer = None 54 | safety_checker_model = None 55 | 56 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: 57 | if randomize_seed: 58 | seed = random.randint(0, MAX_SEED) 59 | return seed 60 | 61 | 62 | @spaces.GPU(enable_queue=True) 63 | def generate( 64 | prompt: str, 65 | seed: int = 0, 66 | guidance_scale: float = 4.5, 67 | randomize_temperature: float = 1.5, 68 | randomize_seed: bool = False, 69 | num_sample_steps: int = 12, 70 | progress=gr.Progress(track_tqdm=True), 71 | ): 72 | global generator, tokenizer, safety_checker_tokenizer, safety_checker_model 73 | 74 | seed = int(randomize_seed_fn(seed, randomize_seed)) 75 | 76 | if is_dangerous(safety_checker_tokenizer, safety_checker_model, prompt): 77 | prompt = "A red heart." 78 | 79 | generated_images = demo_util.sample_fn( 80 | generator=generator, 81 | tokenizer=tokenizer, 82 | conditions=[prompt], 83 | guidance_scale=guidance_scale, 84 | guidance_decay='constant', 85 | randomize_temperature=randomize_temperature, 86 | softmax_temperature_annealing=True, 87 | num_sample_steps=num_sample_steps, 88 | device=device, 89 | mean=0.5, 90 | std=0.5, 91 | return_tensor=True, 92 | hybrid=True, 93 | model_type='maskgit_t2i', 94 | ).clamp_(0., 1.) 95 | 96 | images = [] 97 | sample_imgs_np = generated_images.clone().mul_(255).cpu().numpy() 98 | num_imgs = sample_imgs_np.shape[0] 99 | for img_idx in range(num_imgs): 100 | cur_img = sample_imgs_np[img_idx] 101 | cur_img = cur_img.transpose(1, 2, 0).astype(np.uint8) 102 | cur_img_store = Image.fromarray(cur_img) 103 | images.append(cur_img_store) 104 | 105 | return images, seed 106 | 107 | 108 | def main(args): 109 | 110 | global generator, tokenizer, safety_checker_tokenizer, safety_checker_model 111 | 112 | config = demo_util.get_config(args.config) 113 | tokenizer = AutoModel.from_pretrained(config.experiment.tokenizer_checkpoint) 114 | tokenizer.eval() 115 | tokenizer.requires_grad_(False) 116 | 117 | generator = demo_util.get_generator(config) 118 | 119 | tokenizer = tokenizer.to(device) 120 | generator = generator.to(device) 121 | 122 | safety_checker_tokenizer = AutoTokenizer.from_pretrained(args.shield_model_path) 123 | safety_checker_model = AutoModelForCausalLM.from_pretrained( 124 | args.shield_model_path, 125 | device_map="auto", 126 | torch_dtype=torch.bfloat16, 127 | ).to(device) 128 | 129 | examples = [ 130 | "melting apple", 131 | "A penguin wearing sunglasses on a beach.", 132 | "A moonlit path through a mystical forest.", 133 | "An astronaut riding a horse on the moon, oil painting by Van Gogh.", 134 | "A close-up photo of a honeycomb with bees actively working, golden honey visible in cells, wings a blur of movement.", 135 | "A train traveling through snowy mountains.", 136 | "A plate of cookies with a glass of milk.", 137 | "A close-up photo of a lotus flower emerging from muddy water, perfect pink petals opening toward sunlight, water droplets visible.", 138 | ] 139 | 140 | css = """ 141 | .gradio-container{max-width: 560px !important} 142 | h1{text-align:center} 143 | """ 144 | with gr.Blocks(css=css) 
as demo: 145 | gr.Markdown(DESCRIPTION) 146 | gr.DuplicateButton( 147 | value="Duplicate Space for private use", 148 | elem_id="duplicate-button", 149 | visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1", 150 | ) 151 | with gr.Group(): 152 | with gr.Row(): 153 | prompt = gr.Text( 154 | label="Prompt", 155 | show_label=False, 156 | max_lines=1, 157 | placeholder="Enter your prompt", 158 | container=False, 159 | ) 160 | run_button = gr.Button("Run", scale=0) 161 | 162 | result = gr.Gallery( 163 | label="Result", 164 | columns=NUM_IMAGES_PER_PROMPT, 165 | show_label=False, 166 | # height=800, 167 | ) 168 | with gr.Accordion("Advanced options", open=False): 169 | seed = gr.Slider( 170 | label="Seed", 171 | minimum=0, 172 | maximum=MAX_SEED, 173 | step=1, 174 | value=args.seed, 175 | ) 176 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 177 | with gr.Row(): 178 | guidance_scale = gr.Slider( 179 | label="Guidance Scale", 180 | minimum=0.1, 181 | maximum=20, 182 | step=0.1, 183 | value=4.5, 184 | ) 185 | with gr.Row(): 186 | randomize_temperature = gr.Slider( 187 | label="Randomize Temperature", 188 | minimum=0.1, 189 | maximum=10, 190 | step=0.1, 191 | value=1.5, 192 | ) 193 | with gr.Row(): 194 | num_sample_steps = gr.Slider( 195 | label="Number of Sample Steps", 196 | minimum=8, 197 | maximum=32, 198 | step=1, 199 | value=12, 200 | ) 201 | 202 | gr.Examples( 203 | examples=examples, 204 | inputs=prompt, 205 | outputs=[result, seed], 206 | fn=generate, 207 | cache_examples=CACHE_EXAMPLES, 208 | ) 209 | 210 | gr.on( 211 | triggers=[ 212 | prompt.submit, 213 | run_button.click, 214 | ], 215 | fn=generate, 216 | inputs=[ 217 | prompt, 218 | seed, 219 | guidance_scale, 220 | randomize_temperature, 221 | randomize_seed, 222 | num_sample_steps, 223 | ], 224 | outputs=[result, seed], 225 | api_name="run", 226 | ) 227 | 228 | demo.queue(max_size=20).launch(share=True) 229 | 230 | 231 | if __name__ == '__main__': 232 | parser = argparse.ArgumentParser() 233 | parser.add_argument( 234 | "--config", 235 | type=str, 236 | default="configs/inference/dc_ar_t2i_512.yaml", 237 | ) 238 | parser.add_argument( 239 | "--shield_model_path", 240 | type=str, 241 | help="The path to shield model, we employ ShieldGemma-2B by default.", 242 | default="pretrained_models/shieldgemma-2b", 243 | ) 244 | parser.add_argument( 245 | "--seed", 246 | type=int, 247 | default=42, 248 | ) 249 | args = parser.parse_args() 250 | 251 | main(args) -------------------------------------------------------------------------------- /modeling/modules/ema_model.py: -------------------------------------------------------------------------------- 1 | """This file contains some base class implementation for EMA. 2 | 3 | This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 4 | All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 
5 | 6 | Reference: 7 | https://github.com/huggingface/open-muse/blob/64e1afe033717d795866ab8204484705cd4dc3f7/muse/modeling_ema.py#L8 8 | """ 9 | 10 | 11 | import copy 12 | from typing import Any, Iterable, Optional, Union 13 | 14 | import torch 15 | 16 | 17 | class EMAModel: 18 | """Exponential Moving Average of models weights.""" 19 | def __init__( 20 | self, 21 | parameters: Iterable[torch.nn.Parameter], 22 | decay: float = 0.9999, 23 | min_decay: float = 0.0, 24 | update_after_step: int = 0, 25 | update_every: int = 1, 26 | current_step: int = 0, 27 | use_ema_warmup: bool = False, 28 | inv_gamma: Union[float, int] = 1.0, 29 | power: Union[float, int] = 2 / 3, 30 | model_cls: Optional[Any] = None, 31 | **model_config_kwargs 32 | ): 33 | """ 34 | Args: 35 | parameters (Iterable[torch.nn.Parameter]): The parameters to track. 36 | decay (float): The decay factor for the exponential moving average. 37 | min_decay (float): The minimum decay factor for the exponential moving average. 38 | update_after_step (int): The number of steps to wait before starting to update the EMA weights. 39 | update_every (int): The number of steps between each EMA update. 40 | current_step (int): The current training step. 41 | use_ema_warmup (bool): Whether to use EMA warmup. 42 | inv_gamma (float): 43 | Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. 44 | power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. 45 | 46 | notes on EMA Warmup: 47 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan 48 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), 49 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 50 | at 215.4k steps). 
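# Illustrative sketch (standalone): the EMA warmup schedule described in the notes
# above, decay(step) = 1 - (1 + step / inv_gamma) ** -power. With inv_gamma=1 and
# power=2/3 it crosses 0.999 near step 31.6K and 0.9999 near 1M, matching the
# figures quoted in the docstring.
def warmup_decay(step: int, inv_gamma: float = 1.0, power: float = 2 / 3) -> float:
    return 1 - (1 + step / inv_gamma) ** -power

for s in (1_000, 31_623, 1_000_000):
    print(s, round(warmup_decay(s), 5))       # 0.99001, 0.999, 0.9999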
51 | """ 52 | 53 | parameters = list(parameters) 54 | self.shadow_params = [p.clone().detach() for p in parameters] 55 | self.temp_stored_params = None 56 | 57 | self.decay = decay 58 | self.min_decay = min_decay 59 | self.update_after_step = update_after_step 60 | self.update_every = update_every 61 | self.use_ema_warmup = use_ema_warmup 62 | self.inv_gamma = inv_gamma 63 | self.power = power 64 | self.optimization_step = current_step 65 | self.cur_decay_value = None # set in `step()` 66 | 67 | self.model_cls = model_cls 68 | self.model_config_kwargs = model_config_kwargs 69 | 70 | @classmethod 71 | def from_pretrained(cls, checkpoint, model_cls, **model_config_kwargs) -> "EMAModel": 72 | model = model_cls(**model_config_kwargs) 73 | model.load_pretrained_weight(checkpoint) 74 | 75 | ema_model = cls(model.parameters(), model_cls=model_cls, **model_config_kwargs) 76 | return ema_model 77 | 78 | def save_pretrained(self, path): 79 | if self.model_cls is None: 80 | raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") 81 | 82 | if self.model_config_kwargs is None: 83 | raise ValueError("`save_pretrained` can only be used if `model_config_kwargs` was defined at __init__.") 84 | 85 | model = self.model_cls(**self.model_config_kwargs) 86 | self.copy_to(model.parameters()) 87 | model.save_pretrained_weight(path) 88 | 89 | def set_step(self, optimization_step: int): 90 | self.optimization_step = optimization_step 91 | 92 | def get_decay(self, optimization_step: int) -> float: 93 | """Computes the decay factor for the exponential moving average.""" 94 | step = max(0, optimization_step - self.update_after_step - 1) 95 | 96 | if step <= 0: 97 | return 0.0 98 | 99 | if self.use_ema_warmup: 100 | cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power 101 | else: 102 | cur_decay_value = (1 + step) / (10 + step) 103 | 104 | cur_decay_value = min(cur_decay_value, self.decay) 105 | # Make sure decay is not smaller than min_decay. 106 | cur_decay_value = max(cur_decay_value, self.min_decay) 107 | return cur_decay_value 108 | 109 | @torch.no_grad() 110 | def step(self, parameters: Iterable[torch.nn.Parameter]): 111 | parameters = list(parameters) 112 | 113 | self.optimization_step += 1 114 | 115 | if (self.optimization_step - 1) % self.update_every != 0: 116 | return 117 | 118 | # Compute the decay factor for the exponential moving average. 119 | decay = self.get_decay(self.optimization_step) 120 | self.cur_decay_value = decay 121 | one_minus_decay = 1 - decay 122 | 123 | for s_param, param in zip(self.shadow_params, parameters): 124 | if param.requires_grad: 125 | s_param.sub_(one_minus_decay * (s_param - param)) 126 | else: 127 | s_param.copy_(param) 128 | 129 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 130 | """Copies current averaged parameters into given collection of parameters. 131 | 132 | Args: 133 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 134 | updated with the stored moving averages. If `None`, the parameters with which this 135 | `ExponentialMovingAverage` was initialized will be used. 136 | """ 137 | parameters = list(parameters) 138 | for s_param, param in zip(self.shadow_params, parameters): 139 | param.data.copy_(s_param.to(param.device).data) 140 | 141 | def to(self, device=None, dtype=None) -> None: 142 | r"""Moves internal buffers of the ExponentialMovingAverage to `device`. 
143 | 144 | Args: 145 | device: like `device` argument to `torch.Tensor.to` 146 | """ 147 | # .to() on the tensors handles None correctly 148 | self.shadow_params = [ 149 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) 150 | for p in self.shadow_params 151 | ] 152 | 153 | def state_dict(self) -> dict: 154 | r"""Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during 155 | checkpointing to save the ema state dict. 156 | """ 157 | # Following PyTorch conventions, references to tensors are returned: 158 | # "returns a reference to the state and not its copy!" - 159 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict 160 | return { 161 | "decay": self.decay, 162 | "min_decay": self.min_decay, 163 | "optimization_step": self.optimization_step, 164 | "update_after_step": self.update_after_step, 165 | "use_ema_warmup": self.use_ema_warmup, 166 | "inv_gamma": self.inv_gamma, 167 | "power": self.power, 168 | "shadow_params": self.shadow_params, 169 | } 170 | 171 | def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: 172 | r""" 173 | Args: 174 | Save the current parameters for restoring later. 175 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 176 | temporarily stored. 177 | """ 178 | self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] 179 | 180 | def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: 181 | r"""Restores the parameters stored with the `store` method. Useful to validate 182 | the model with EMA parameters without affecting the original optimization process. 183 | Store the parameters before the `copy_to()` method. After validation (or 184 | model saving), use this to restore the former parameters. 185 | 186 | Args: 187 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 188 | updated with the stored parameters. If `None`, the parameters with which this 189 | `ExponentialMovingAverage` was initialized will be used. 190 | """ 191 | if self.temp_stored_params is None: 192 | raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`") 193 | for c_param, param in zip(self.temp_stored_params, parameters): 194 | param.data.copy_(c_param.data) 195 | 196 | # Better memory-wise. 197 | self.temp_stored_params = None 198 | 199 | def load_state_dict(self, state_dict: dict) -> None: 200 | r"""Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the 201 | ema state dict. 202 | 203 | Args: 204 | state_dict (dict): EMA state. Should be an object returned 205 | from a call to :meth:`state_dict`. 
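# Illustrative sketch (not from the repo): the intended usage of the methods above,
# updating the shadow weights after every optimizer step and swapping them in around
# evaluation with store()/copy_to()/restore(). The toy model and optimizer are
# placeholders.
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
ema = EMAModel(net.parameters(), decay=0.999)
opt = torch.optim.SGD(net.parameters(), lr=1e-2)

for _ in range(10):
    loss = net(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.step(net.parameters())                # update the EMA after each optimizer step

ema.store(net.parameters())                   # stash the raw training weights
ema.copy_to(net.parameters())                 # evaluate with the averaged weights
# ... run validation with `net` here ...
ema.restore(net.parameters())                 # put the training weights back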
206 | """ 207 | # Deepcopy, to be consistent with module API 208 | state_dict = copy.deepcopy(state_dict) 209 | 210 | self.decay = state_dict.get("decay", self.decay) 211 | if self.decay < 0.0 or self.decay > 1.0: 212 | raise ValueError("Decay must be between 0 and 1") 213 | 214 | self.min_decay = state_dict.get("min_decay", self.min_decay) 215 | if not isinstance(self.min_decay, float): 216 | raise ValueError("Invalid min_decay") 217 | 218 | self.optimization_step = state_dict.get("optimization_step", self.optimization_step) 219 | if not isinstance(self.optimization_step, int): 220 | raise ValueError("Invalid optimization_step") 221 | 222 | self.update_after_step = state_dict.get("update_after_step", self.update_after_step) 223 | if not isinstance(self.update_after_step, int): 224 | raise ValueError("Invalid update_after_step") 225 | 226 | self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) 227 | if not isinstance(self.use_ema_warmup, bool): 228 | raise ValueError("Invalid use_ema_warmup") 229 | 230 | self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) 231 | if not isinstance(self.inv_gamma, (float, int)): 232 | raise ValueError("Invalid inv_gamma") 233 | 234 | self.power = state_dict.get("power", self.power) 235 | if not isinstance(self.power, (float, int)): 236 | raise ValueError("Invalid power") 237 | 238 | shadow_params = state_dict.get("shadow_params", None) 239 | if shadow_params is not None: 240 | self.shadow_params = shadow_params 241 | if not isinstance(self.shadow_params, list): 242 | raise ValueError("shadow_params must be a list") 243 | if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): 244 | raise ValueError("shadow_params must all be Tensors") -------------------------------------------------------------------------------- /modeling/generator/net/basic_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/PixArt-alpha/PixArt-sigma 18 | 19 | import torch 20 | import torch.nn as nn 21 | from timm.models.vision_transformer import Mlp 22 | 23 | from .act import build_act, get_act_name 24 | from .norms import build_norm, get_norm_name 25 | from .utils import get_same_padding, val2tuple 26 | 27 | 28 | class ConvLayer(nn.Module): 29 | def __init__( 30 | self, 31 | in_dim: int, 32 | out_dim: int, 33 | kernel_size=3, 34 | stride=1, 35 | dilation=1, 36 | groups=1, 37 | padding: int = None, 38 | use_bias=False, 39 | dropout=0.0, 40 | norm="bn2d", 41 | act="relu", 42 | ): 43 | super().__init__() 44 | if padding is None: 45 | padding = get_same_padding(kernel_size) 46 | padding *= dilation 47 | 48 | self.in_dim = in_dim 49 | self.out_dim = out_dim 50 | self.kernel_size = kernel_size 51 | self.stride = stride 52 | self.dilation = dilation 53 | self.groups = groups 54 | self.padding = padding 55 | self.use_bias = use_bias 56 | 57 | self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None 58 | self.conv = nn.Conv2d( 59 | in_dim, 60 | out_dim, 61 | kernel_size=(kernel_size, kernel_size), 62 | stride=(stride, stride), 63 | padding=padding, 64 | dilation=(dilation, dilation), 65 | groups=groups, 66 | bias=use_bias, 67 | ) 68 | self.norm = build_norm(norm, num_features=out_dim) 69 | self.act = build_act(act) 70 | 71 | def forward(self, x: torch.Tensor) -> torch.Tensor: 72 | if self.dropout is not None: 73 | x = self.dropout(x) 74 | x = self.conv(x) 75 | if self.norm: 76 | x = self.norm(x) 77 | if self.act: 78 | x = self.act(x) 79 | return x 80 | 81 | 82 | class GLUMBConv(nn.Module): 83 | def __init__( 84 | self, 85 | in_features: int, 86 | hidden_features: int, 87 | out_feature=None, 88 | kernel_size=3, 89 | stride=1, 90 | padding: int = None, 91 | use_bias=False, 92 | norm=(None, None, None), 93 | act=("silu", "silu", None), 94 | dilation=1, 95 | ): 96 | out_feature = out_feature or in_features 97 | super().__init__() 98 | use_bias = val2tuple(use_bias, 3) 99 | norm = val2tuple(norm, 3) 100 | act = val2tuple(act, 3) 101 | 102 | self.glu_act = build_act(act[1], inplace=False) 103 | self.inverted_conv = ConvLayer( 104 | in_features, 105 | hidden_features * 2, 106 | 1, 107 | use_bias=use_bias[0], 108 | norm=norm[0], 109 | act=act[0], 110 | ) 111 | self.depth_conv = ConvLayer( 112 | hidden_features * 2, 113 | hidden_features * 2, 114 | kernel_size, 115 | stride=stride, 116 | groups=hidden_features * 2, 117 | padding=padding, 118 | use_bias=use_bias[1], 119 | norm=norm[1], 120 | act=None, 121 | dilation=dilation, 122 | ) 123 | self.point_conv = ConvLayer( 124 | hidden_features, 125 | out_feature, 126 | 1, 127 | use_bias=use_bias[2], 128 | norm=norm[2], 129 | act=act[2], 130 | ) 131 | # from IPython import embed; embed(header='debug dilate conv') 132 | 133 | def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor: 134 | B, N, C = x.shape 135 | if HW is None: 136 | H = W = int(N**0.5) 137 | else: 138 | H, W = HW 139 | 140 | x = x.reshape(B, H, W, C).permute(0, 3, 1, 2) 141 | x = self.inverted_conv(x) 142 | x = self.depth_conv(x) 143 | 144 | x, gate = torch.chunk(x, 2, dim=1) 145 | gate = self.glu_act(gate) 146 | x = x * gate 147 | 148 | x = self.point_conv(x) 149 | x = x.reshape(B, C, N).permute(0, 2, 1) 150 | 151 | return x 152 | 153 | 154 | class SlimGLUMBConv(GLUMBConv): 155 | def __init__(self, *args, **kwargs): 156 | super().__init__(*args, **kwargs) 157 | 158 | # 移除 
self.inverted_conv 层 159 | del self.inverted_conv 160 | self.out_dim = self.point_conv.out_dim 161 | 162 | def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor: 163 | B, N, C = x.shape 164 | if HW is None: 165 | H = W = int(N**0.5) 166 | else: 167 | H, W = HW 168 | 169 | # 直接使用 x,跳过 self.inverted_conv 层的调用 170 | x = x.reshape(B, H, W, C).permute(0, 3, 1, 2) 171 | # x = self.inverted_conv(x) 172 | x = self.depth_conv(x) 173 | 174 | x, gate = torch.chunk(x, 2, dim=1) 175 | gate = self.glu_act(gate) 176 | x = x * gate 177 | 178 | x = self.point_conv(x) 179 | x = x.reshape(B, self.out_dim, N).permute(0, 2, 1) 180 | 181 | return x 182 | 183 | 184 | class MBConvPreGLU(nn.Module): 185 | def __init__( 186 | self, 187 | in_dim: int, 188 | out_dim: int, 189 | kernel_size=3, 190 | stride=1, 191 | mid_dim=None, 192 | expand=6, 193 | padding: int = None, 194 | use_bias=False, 195 | norm=(None, None, "ln2d"), 196 | act=("silu", "silu", None), 197 | ): 198 | super().__init__() 199 | use_bias = val2tuple(use_bias, 3) 200 | norm = val2tuple(norm, 3) 201 | act = val2tuple(act, 3) 202 | 203 | mid_dim = mid_dim or round(in_dim * expand) 204 | 205 | self.inverted_conv = ConvLayer( 206 | in_dim, 207 | mid_dim * 2, 208 | 1, 209 | use_bias=use_bias[0], 210 | norm=norm[0], 211 | act=None, 212 | ) 213 | self.glu_act = build_act(act[0], inplace=False) 214 | self.depth_conv = ConvLayer( 215 | mid_dim, 216 | mid_dim, 217 | kernel_size, 218 | stride=stride, 219 | groups=mid_dim, 220 | padding=padding, 221 | use_bias=use_bias[1], 222 | norm=norm[1], 223 | act=act[1], 224 | ) 225 | self.point_conv = ConvLayer( 226 | mid_dim, 227 | out_dim, 228 | 1, 229 | use_bias=use_bias[2], 230 | norm=norm[2], 231 | act=act[2], 232 | ) 233 | 234 | def forward(self, x: torch.Tensor, HW=None) -> torch.Tensor: 235 | B, N, C = x.shape 236 | if HW is None: 237 | H = W = int(N**0.5) 238 | else: 239 | H, W = HW 240 | 241 | x = x.reshape(B, H, W, C).permute(0, 3, 1, 2) 242 | 243 | x = self.inverted_conv(x) 244 | x, gate = torch.chunk(x, 2, dim=1) 245 | gate = self.glu_act(gate) 246 | x = x * gate 247 | 248 | x = self.depth_conv(x) 249 | x = self.point_conv(x) 250 | 251 | x = x.reshape(B, C, N).permute(0, 2, 1) 252 | return x 253 | 254 | @property 255 | def module_str(self) -> str: 256 | _str = f"{self.depth_conv.kernel_size}{type(self).__name__}(" 257 | _str += f"in={self.inverted_conv.in_dim},mid={self.depth_conv.in_dim},out={self.point_conv.out_dim},s={self.depth_conv.stride}" 258 | _str += ( 259 | f",norm={get_norm_name(self.inverted_conv.norm)}" 260 | f"+{get_norm_name(self.depth_conv.norm)}" 261 | f"+{get_norm_name(self.point_conv.norm)}" 262 | ) 263 | _str += ( 264 | f",act={get_act_name(self.inverted_conv.act)}" 265 | f"+{get_act_name(self.depth_conv.act)}" 266 | f"+{get_act_name(self.point_conv.act)}" 267 | ) 268 | _str += f",glu_act={get_act_name(self.glu_act)})" 269 | return _str 270 | 271 | 272 | class DWMlp(Mlp): 273 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 274 | 275 | def __init__( 276 | self, 277 | in_features, 278 | hidden_features=None, 279 | out_features=None, 280 | act_layer=nn.GELU, 281 | bias=True, 282 | drop=0.0, 283 | kernel_size=3, 284 | stride=1, 285 | dilation=1, 286 | padding=None, 287 | ): 288 | super().__init__( 289 | in_features=in_features, 290 | hidden_features=hidden_features, 291 | out_features=out_features, 292 | act_layer=act_layer, 293 | bias=bias, 294 | drop=drop, 295 | ) 296 | hidden_features = hidden_features or in_features 297 | self.hidden_features = 
hidden_features 298 | if padding is None: 299 | padding = get_same_padding(kernel_size) 300 | padding *= dilation 301 | 302 | self.conv = nn.Conv2d( 303 | hidden_features, 304 | hidden_features, 305 | kernel_size=(kernel_size, kernel_size), 306 | stride=(stride, stride), 307 | padding=padding, 308 | dilation=(dilation, dilation), 309 | groups=hidden_features, 310 | bias=bias, 311 | ) 312 | 313 | def forward(self, x, HW=None): 314 | B, N, C = x.shape 315 | if HW is None: 316 | H = W = int(N**0.5) 317 | else: 318 | H, W = HW 319 | x = self.fc1(x) 320 | x = self.act(x) 321 | x = self.drop1(x) 322 | x = x.reshape(B, H, W, self.hidden_features).permute(0, 3, 1, 2) 323 | x = self.conv(x) 324 | x = x.reshape(B, self.hidden_features, N).permute(0, 2, 1) 325 | x = self.fc2(x) 326 | x = self.drop2(x) 327 | return x 328 | 329 | 330 | class Mlp(Mlp): 331 | """MLP as used in Vision Transformer, MLP-Mixer and related networks""" 332 | 333 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.0): 334 | super().__init__( 335 | in_features=in_features, 336 | hidden_features=hidden_features, 337 | out_features=out_features, 338 | act_layer=act_layer, 339 | bias=bias, 340 | drop=drop, 341 | ) 342 | 343 | def forward(self, x, HW=None): 344 | x = self.fc1(x) 345 | x = self.act(x) 346 | x = self.drop1(x) 347 | x = self.fc2(x) 348 | x = self.drop2(x) 349 | return x 350 | 351 | 352 | if __name__ == "__main__": 353 | model = GLUMBConv( 354 | 1152, 355 | 1152 * 4, 356 | 1152, 357 | use_bias=(True, True, False), 358 | norm=(None, None, None), 359 | act=("silu", "silu", None), 360 | ).cuda() 361 | input = torch.randn(4, 256, 1152).cuda() 362 | output = model(input) -------------------------------------------------------------------------------- /LICENSE/code.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 NVIDIA 190 | Licensed under the Apache License, Version 2.0 (the "License"); 191 | you may not use this file except in compliance with the License. 192 | You may obtain a copy of the License at 193 | 194 | http://www.apache.org/licenses/LICENSE-2.0 195 | 196 | Unless required by applicable law or agreed to in writing, software 197 | distributed under the License is distributed on an "AS IS" BASIS, 198 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 199 | See the License for the specific language governing permissions and 200 | limitations under the License. 201 | -------------------------------------------------------------------------------- /modeling/diffusion/diffloss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # Modified from OpenAI's diffusion repos 18 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 19 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 20 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 21 | 22 | import math 23 | 24 | import torch 25 | import torch.nn as nn 26 | 27 | from . import gaussian_diffusion as gd 28 | from .respace import SpacedDiffusion, space_timesteps 29 | 30 | class DiffLoss(nn.Module): 31 | def __init__( 32 | self, 33 | target_channels, 34 | z_channels, 35 | depth, 36 | width, 37 | num_sampling_steps, 38 | sampler="iddpm", 39 | vae_scale=0.6 40 | ): 41 | super().__init__() 42 | self.vae_scale = vae_scale 43 | self.in_channels = target_channels 44 | self.net = SimpleMLPAdaLN( 45 | in_channels=target_channels, 46 | model_channels=width, 47 | out_channels=target_channels * 2, 48 | z_channels=z_channels, 49 | num_res_blocks=depth, 50 | ) 51 | self.num_sampling_steps = num_sampling_steps 52 | self.sampler = sampler 53 | 54 | self.train_diffusion = create_diffusion( 55 | timestep_respacing="", noise_schedule="cosine" 56 | ) 57 | self.gen_diffusion = create_diffusion( 58 | timestep_respacing=num_sampling_steps, noise_schedule="cosine" 59 | ) 60 | 61 | def initialize_weights(self): 62 | self.net.initialize_weights() 63 | 64 | def forward(self, target, z, mask=None): 65 | t = torch.randint( 66 | 0, 67 | self.train_diffusion.num_timesteps, 68 | (target.shape[0],), 69 | device=target.device, 70 | ) 71 | model_kwargs = dict(c=z) 72 | loss_dict = self.train_diffusion.training_losses( 73 | self.net, target / self.vae_scale, t, model_kwargs 74 | ) 75 | loss = loss_dict["loss"] 76 | if mask is not None: 77 | loss = (loss * mask).sum() / mask.sum() 78 | return loss.mean() 79 | 80 | def sample(self, z, temperature=1.0, cfg=1.5, sampler=None): 81 | # diffusion loss sampling 82 | if not cfg == 1.0: 83 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda() 84 | noise = torch.cat([noise, noise], dim=0) 85 | model_kwargs = dict(c=z, cfg_scale=cfg) 86 | sample_fn = self.net.forward_with_cfg 87 | else: 88 | noise = torch.randn(z.shape[0], self.in_channels).cuda() 89 | model_kwargs = dict(c=z) 90 | sample_fn = self.net.forward 91 | 92 | if sampler is None: 93 | sampler = self.sampler 94 | 95 | if sampler == "iddpm": 96 | sampled_token_latent = ( 97 | self.gen_diffusion.p_sample_loop( 98 | sample_fn, 99 | noise.shape, 100 | noise, 101 | clip_denoised=True, 102 | model_kwargs=model_kwargs, 103 | progress=False, 104 | temperature=temperature, 105 | ) 106 | * self.vae_scale 107 | ) 108 | else: 109 | raise NotImplementedError 110 | 111 | return sampled_token_latent 112 | 113 | 114 | def modulate(x, shift, scale): 115 | return x * (1 + scale) + shift 116 | 117 | 118 | class TimestepEmbedder(nn.Module): 119 | """ 120 | Embeds scalar timesteps into vector representations. 
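# Illustrative sketch (standalone): the sinusoidal embedding computed by
# timestep_embedding() below, written out for a small batch. Frequencies decay as
# max_period**(-i/half) and the cos/sin halves are concatenated to length dim.
import math
import torch

dim, half = 256, 128
t = torch.tensor([0.0, 10.0, 999.0])
freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
args = t[:, None] * freqs[None]
emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
print(tuple(emb.shape))                       # (3, 256)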
121 | """ 122 | 123 | def __init__(self, hidden_size, frequency_embedding_size=256): 124 | super().__init__() 125 | self.mlp = nn.Sequential( 126 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 127 | nn.SiLU(), 128 | nn.Linear(hidden_size, hidden_size, bias=True), 129 | ) 130 | self.frequency_embedding_size = frequency_embedding_size 131 | 132 | @staticmethod 133 | def timestep_embedding(t, dim, max_period=10000): 134 | """ 135 | Create sinusoidal timestep embeddings. 136 | :param t: a 1-D Tensor of N indices, one per batch element. 137 | These may be fractional. 138 | :param dim: the dimension of the output. 139 | :param max_period: controls the minimum frequency of the embeddings. 140 | :return: an (N, D) Tensor of positional embeddings. 141 | """ 142 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 143 | half = dim // 2 144 | freqs = torch.exp( 145 | -math.log(max_period) 146 | * torch.arange(start=0, end=half, dtype=torch.float32) 147 | / half 148 | ).to(device=t.device) 149 | args = t[:, None].float() * freqs[None] 150 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 151 | if dim % 2: 152 | embedding = torch.cat( 153 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 154 | ) 155 | return embedding 156 | 157 | def forward(self, t): 158 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 159 | t_emb = self.mlp(t_freq) 160 | return t_emb 161 | 162 | 163 | class ResBlock(nn.Module): 164 | """ 165 | A residual block that can optionally change the number of channels. 166 | :param channels: the number of input channels. 167 | """ 168 | 169 | def __init__(self, channels): 170 | super().__init__() 171 | self.channels = channels 172 | 173 | self.in_ln = nn.LayerNorm(channels, eps=1e-6) 174 | self.mlp = nn.Sequential( 175 | nn.Linear(channels, channels, bias=True), 176 | nn.SiLU(), 177 | nn.Linear(channels, channels, bias=True), 178 | ) 179 | 180 | self.adaLN_modulation = nn.Sequential( 181 | nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True) 182 | ) 183 | 184 | def forward(self, x, y): 185 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) 186 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp) 187 | h = self.mlp(h) 188 | return x + gate_mlp * h 189 | 190 | 191 | class FinalLayer(nn.Module): 192 | """ 193 | The final layer of DiT. 
194 | """ 195 | 196 | def __init__(self, model_channels, out_channels): 197 | super().__init__() 198 | self.norm_final = nn.LayerNorm( 199 | model_channels, elementwise_affine=False, eps=1e-6 200 | ) 201 | self.linear = nn.Linear(model_channels, out_channels, bias=True) 202 | self.adaLN_modulation = nn.Sequential( 203 | nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True) 204 | ) 205 | 206 | def forward(self, x, c): 207 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 208 | x = modulate(self.norm_final(x), shift, scale) 209 | x = self.linear(x) 210 | return x 211 | 212 | 213 | class SimpleMLPAdaLN(nn.Module): 214 | def __init__( 215 | self, 216 | in_channels, 217 | model_channels, 218 | out_channels, 219 | z_channels, 220 | num_res_blocks, 221 | ): 222 | super().__init__() 223 | 224 | self.in_channels = in_channels 225 | self.model_channels = model_channels 226 | self.out_channels = out_channels 227 | self.num_res_blocks = num_res_blocks 228 | 229 | self.time_embed = TimestepEmbedder(model_channels) 230 | self.cond_embed = nn.Linear(z_channels, model_channels) 231 | 232 | self.input_proj = nn.Linear(in_channels, model_channels) 233 | 234 | res_blocks = [] 235 | for i in range(num_res_blocks): 236 | res_blocks.append(ResBlock(model_channels)) 237 | 238 | self.res_blocks = nn.ModuleList(res_blocks) 239 | self.final_layer = FinalLayer(model_channels, out_channels) 240 | 241 | self.initialize_weights() 242 | 243 | def initialize_weights(self): 244 | def _basic_init(module): 245 | if isinstance(module, nn.Linear): 246 | torch.nn.init.xavier_uniform_(module.weight) 247 | if module.bias is not None: 248 | nn.init.constant_(module.bias, 0) 249 | 250 | self.apply(_basic_init) 251 | 252 | # Initialize timestep embedding MLP 253 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) 254 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) 255 | 256 | # Zero-out adaLN modulation layers 257 | for block in self.res_blocks: 258 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 259 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 260 | 261 | # Zero-out output layers 262 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 263 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 264 | nn.init.constant_(self.final_layer.linear.weight, 0) 265 | nn.init.constant_(self.final_layer.linear.bias, 0) 266 | 267 | def forward(self, x, t, c): 268 | x = self.input_proj(x) 269 | t = self.time_embed(t) 270 | c = self.cond_embed(c) 271 | 272 | y = t + c 273 | 274 | for block in self.res_blocks: 275 | x = block(x, y) 276 | o = self.final_layer(x, y) 277 | return o 278 | 279 | def forward_with_cfg(self, x, t, c, cfg_scale): 280 | half = x[: len(x) // 2] 281 | combined = torch.cat([half, half], dim=0) 282 | model_out = self.forward(combined, t, c) 283 | eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :] 284 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 285 | half_eps = cond_eps * (1 + cfg_scale) - cfg_scale * uncond_eps 286 | eps = torch.cat([half_eps, half_eps], dim=0) 287 | return torch.cat([eps, rest], dim=1) 288 | 289 | def create_diffusion( 290 | timestep_respacing, 291 | noise_schedule="linear", 292 | use_kl=False, 293 | sigma_small=False, 294 | predict_xstart=False, 295 | learn_sigma=True, 296 | rescale_learned_sigmas=False, 297 | diffusion_steps=1000, 298 | ): 299 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 300 | if use_kl: 301 | loss_type = gd.LossType.RESCALED_KL 
302 | elif rescale_learned_sigmas: 303 | loss_type = gd.LossType.RESCALED_MSE 304 | else: 305 | loss_type = gd.LossType.MSE 306 | if timestep_respacing is None or timestep_respacing == "": 307 | timestep_respacing = [diffusion_steps] 308 | return SpacedDiffusion( 309 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 310 | betas=betas, 311 | model_mean_type=( 312 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 313 | ), 314 | model_var_type=( 315 | ( 316 | gd.ModelVarType.FIXED_LARGE 317 | if not sigma_small 318 | else gd.ModelVarType.FIXED_SMALL 319 | ) 320 | if not learn_sigma 321 | else gd.ModelVarType.LEARNED_RANGE 322 | ), 323 | loss_type=loss_type, 324 | # rescale_timesteps=rescale_timesteps, 325 | ) -------------------------------------------------------------------------------- /modeling/tokenizer/networks/dc_ae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/mit-han-lab/efficientvit 18 | 19 | from dataclasses import dataclass, field 20 | from typing import Any, Optional 21 | 22 | import torch 23 | import torch.nn as nn 24 | from omegaconf import MISSING, OmegaConf 25 | 26 | from .dc_ae_blocks.act import build_act 27 | from .dc_ae_blocks.norm import build_norm 28 | from .dc_ae_blocks.ops import ( 29 | ChannelDuplicatingPixelUnshuffleUpSampleLayer, 30 | ConvLayer, 31 | ConvPixelShuffleUpSampleLayer, 32 | ConvPixelUnshuffleDownSampleLayer, 33 | EfficientViTBlock, 34 | IdentityLayer, 35 | OpSequential, 36 | PixelUnshuffleChannelAveragingDownSampleLayer, 37 | ResBlock, 38 | ResidualBlock, 39 | ) 40 | 41 | __all__ = ["DCAE", "dc_ae_f16", "dc_ae_f16_res", "dc_ae_f32", "dc_ae_f64c128", "dc_ae_f128c512"] 42 | 43 | 44 | @dataclass 45 | class EncoderConfig: 46 | in_channels: int = MISSING 47 | latent_channels: int = MISSING 48 | width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024) 49 | depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2) 50 | block_type: Any = "ResBlock" 51 | norm: str = "ln2d" 52 | act: str = "silu" 53 | downsample_block_type: str = "ConvPixelUnshuffle" 54 | downsample_match_channel: bool = True 55 | downsample_shortcut: Optional[str] = "averaging" 56 | out_norm: Optional[str] = None 57 | out_act: Optional[str] = None 58 | out_shortcut: Optional[str] = "averaging" 59 | double_latent: bool = False 60 | 61 | 62 | @dataclass 63 | class DecoderConfig: 64 | in_channels: int = MISSING 65 | latent_channels: int = MISSING 66 | in_shortcut: Optional[str] = "duplicating" 67 | width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024) 68 | depth_list: tuple[int, ...] 
= (2, 2, 2, 2, 2, 2) 69 | block_type: Any = "ResBlock" 70 | norm: Any = "ln2d" 71 | act: Any = "silu" 72 | upsample_block_type: str = "ConvPixelShuffle" 73 | upsample_match_channel: bool = True 74 | upsample_shortcut: str = "duplicating" 75 | out_norm: str = "ln2d" 76 | out_act: str = "relu" 77 | 78 | 79 | @dataclass 80 | class DCAEConfig: 81 | in_channels: int = 3 82 | latent_channels: int = 32 83 | encoder: EncoderConfig = field( 84 | default_factory=lambda: EncoderConfig(in_channels="${..in_channels}", latent_channels="${..latent_channels}") 85 | ) 86 | decoder: DecoderConfig = field( 87 | default_factory=lambda: DecoderConfig(in_channels="${..in_channels}", latent_channels="${..latent_channels}") 88 | ) 89 | use_quant_conv: bool = False 90 | 91 | pretrained_path: Optional[str] = None 92 | pretrained_source: str = "dc-ae" 93 | 94 | 95 | def build_block( 96 | block_type: str, in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str] 97 | ) -> nn.Module: 98 | if block_type == "ResBlock": 99 | assert in_channels == out_channels 100 | main_block = ResBlock( 101 | in_channels=in_channels, 102 | out_channels=out_channels, 103 | kernel_size=3, 104 | stride=1, 105 | use_bias=(True, False), 106 | norm=(None, norm), 107 | act_func=(act, None), 108 | ) 109 | block = ResidualBlock(main_block, IdentityLayer()) 110 | elif block_type == "EViT_GLU": 111 | assert in_channels == out_channels 112 | block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=()) 113 | else: 114 | raise ValueError(f"block_type {block_type} is not supported") 115 | return block 116 | 117 | 118 | def build_stage_main( 119 | width: int, depth: int, block_type: str | list[str], norm: str, act: str, input_width: int 120 | ) -> list[nn.Module]: 121 | assert isinstance(block_type, str) or (isinstance(block_type, list) and depth == len(block_type)) 122 | stage = [] 123 | for d in range(depth): 124 | current_block_type = block_type[d] if isinstance(block_type, list) else block_type 125 | block = build_block( 126 | block_type=current_block_type, 127 | in_channels=width if d > 0 else input_width, 128 | out_channels=width, 129 | norm=norm, 130 | act=act, 131 | ) 132 | stage.append(block) 133 | return stage 134 | 135 | 136 | def build_downsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module: 137 | if block_type == "Conv": 138 | block = ConvLayer( 139 | in_channels=in_channels, 140 | out_channels=out_channels, 141 | kernel_size=3, 142 | stride=2, 143 | use_bias=True, 144 | norm=None, 145 | act=None, 146 | ) 147 | elif block_type == "ConvPixelUnshuffle": 148 | block = ConvPixelUnshuffleDownSampleLayer( 149 | in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2 150 | ) 151 | else: 152 | raise ValueError(f"block_type {block_type} is not supported for downsampling") 153 | if shortcut is None: 154 | pass 155 | elif shortcut == "averaging": 156 | shortcut_block = PixelUnshuffleChannelAveragingDownSampleLayer( 157 | in_channels=in_channels, out_channels=out_channels, factor=2 158 | ) 159 | block = ResidualBlock(block, shortcut_block) 160 | else: 161 | raise ValueError(f"shortcut {shortcut} is not supported for downsample") 162 | return block 163 | 164 | 165 | def build_upsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module: 166 | if block_type == "ConvPixelShuffle": 167 | block = ConvPixelShuffleUpSampleLayer( 168 | in_channels=in_channels, 
out_channels=out_channels, kernel_size=3, factor=2 169 | ) 170 | else: 171 | raise ValueError(f"block_type {block_type} is not supported for upsampling") 172 | if shortcut is None: 173 | pass 174 | elif shortcut == "duplicating": 175 | shortcut_block = ChannelDuplicatingPixelUnshuffleUpSampleLayer( 176 | in_channels=in_channels, out_channels=out_channels, factor=2 177 | ) 178 | block = ResidualBlock(block, shortcut_block) 179 | else: 180 | raise ValueError(f"shortcut {shortcut} is not supported for upsample") 181 | return block 182 | 183 | 184 | def build_encoder_project_in_block(in_channels: int, out_channels: int, factor: int, downsample_block_type: str): 185 | if factor == 1: 186 | block = ConvLayer( 187 | in_channels=in_channels, 188 | out_channels=out_channels, 189 | kernel_size=3, 190 | stride=1, 191 | use_bias=True, 192 | norm=None, 193 | act_func=None, 194 | ) 195 | elif factor == 2: 196 | block = build_downsample_block( 197 | block_type=downsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None 198 | ) 199 | else: 200 | raise ValueError(f"downsample factor {factor} is not supported for encoder project in") 201 | return block 202 | 203 | 204 | def build_encoder_project_out_block( 205 | in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str] 206 | ): 207 | block = OpSequential( 208 | [ 209 | build_norm(norm), 210 | build_act(act), 211 | ConvLayer( 212 | in_channels=in_channels, 213 | out_channels=out_channels, 214 | kernel_size=3, 215 | stride=1, 216 | use_bias=True, 217 | norm=None, 218 | act_func=None, 219 | ), 220 | ] 221 | ) 222 | if shortcut is None: 223 | pass 224 | elif shortcut == "averaging": 225 | shortcut_block = PixelUnshuffleChannelAveragingDownSampleLayer( 226 | in_channels=in_channels, out_channels=out_channels, factor=1 227 | ) 228 | block = ResidualBlock(block, shortcut_block) 229 | else: 230 | raise ValueError(f"shortcut {shortcut} is not supported for encoder project out") 231 | return block 232 | 233 | 234 | def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut: Optional[str]): 235 | block = ConvLayer( 236 | in_channels=in_channels, 237 | out_channels=out_channels, 238 | kernel_size=3, 239 | stride=1, 240 | use_bias=True, 241 | norm=None, 242 | act_func=None, 243 | ) 244 | if shortcut is None: 245 | pass 246 | elif shortcut == "duplicating": 247 | shortcut_block = ChannelDuplicatingPixelUnshuffleUpSampleLayer( 248 | in_channels=in_channels, out_channels=out_channels, factor=1 249 | ) 250 | block = ResidualBlock(block, shortcut_block) 251 | else: 252 | raise ValueError(f"shortcut {shortcut} is not supported for decoder project in") 253 | return block 254 | 255 | 256 | def build_decoder_project_out_block( 257 | in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str] 258 | ): 259 | layers: list[nn.Module] = [ 260 | build_norm(norm, in_channels), 261 | build_act(act), 262 | ] 263 | if factor == 1: 264 | layers.append( 265 | ConvLayer( 266 | in_channels=in_channels, 267 | out_channels=out_channels, 268 | kernel_size=3, 269 | stride=1, 270 | use_bias=True, 271 | norm=None, 272 | act_func=None, 273 | ) 274 | ) 275 | elif factor == 2: 276 | layers.append( 277 | build_upsample_block( 278 | block_type=upsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None 279 | ) 280 | ) 281 | else: 282 | raise ValueError(f"upsample factor {factor} is not supported for decoder project out") 283 | 
return OpSequential(layers) 284 | 285 | 286 | class Encoder(nn.Module): 287 | def __init__(self, cfg: EncoderConfig): 288 | super().__init__() 289 | self.cfg = cfg 290 | num_stages = len(cfg.width_list) 291 | self.num_stages = num_stages 292 | assert len(cfg.depth_list) == num_stages 293 | assert len(cfg.width_list) == num_stages 294 | assert isinstance(cfg.block_type, str) or ( 295 | isinstance(cfg.block_type, list) and len(cfg.block_type) == num_stages 296 | ) 297 | 298 | self.project_in = build_encoder_project_in_block( 299 | in_channels=cfg.in_channels, 300 | out_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1], 301 | factor=1 if cfg.depth_list[0] > 0 else 2, 302 | downsample_block_type=cfg.downsample_block_type, 303 | ) 304 | 305 | self.stages: list[OpSequential] = [] 306 | for stage_id, (width, depth) in enumerate(zip(cfg.width_list, cfg.depth_list)): 307 | block_type = cfg.block_type[stage_id] if isinstance(cfg.block_type, list) else cfg.block_type 308 | stage = build_stage_main( 309 | width=width, depth=depth, block_type=block_type, norm=cfg.norm, act=cfg.act, input_width=width 310 | ) 311 | 312 | if stage_id < num_stages - 1 and depth > 0: 313 | downsample_block = build_downsample_block( 314 | block_type=cfg.downsample_block_type, 315 | in_channels=width, 316 | out_channels=cfg.width_list[stage_id + 1] if cfg.downsample_match_channel else width, 317 | shortcut=cfg.downsample_shortcut, 318 | ) 319 | stage.append(downsample_block) 320 | self.stages.append(OpSequential(stage)) 321 | self.stages = nn.ModuleList(self.stages) 322 | 323 | self.project_out = build_encoder_project_out_block( 324 | in_channels=cfg.width_list[-1], 325 | out_channels=2 * cfg.latent_channels if cfg.double_latent else cfg.latent_channels, 326 | norm=cfg.out_norm, 327 | act=cfg.out_act, 328 | shortcut=cfg.out_shortcut, 329 | ) 330 | 331 | def forward(self, x: torch.Tensor) -> torch.Tensor: 332 | x = self.project_in(x) 333 | for stage in self.stages: 334 | if len(stage.op_list) == 0: 335 | continue 336 | x = stage(x) 337 | x = self.project_out(x) 338 | return x 339 | 340 | 341 | class Decoder(nn.Module): 342 | def __init__(self, cfg: DecoderConfig): 343 | super().__init__() 344 | self.cfg = cfg 345 | num_stages = len(cfg.width_list) 346 | self.num_stages = num_stages 347 | assert len(cfg.depth_list) == num_stages 348 | assert len(cfg.width_list) == num_stages 349 | assert isinstance(cfg.block_type, str) or ( 350 | isinstance(cfg.block_type, list) and len(cfg.block_type) == num_stages 351 | ) 352 | assert isinstance(cfg.norm, str) or (isinstance(cfg.norm, list) and len(cfg.norm) == num_stages) 353 | assert isinstance(cfg.act, str) or (isinstance(cfg.act, list) and len(cfg.act) == num_stages) 354 | 355 | self.project_in = build_decoder_project_in_block( 356 | in_channels=cfg.latent_channels, 357 | out_channels=cfg.width_list[-1], 358 | shortcut=cfg.in_shortcut, 359 | ) 360 | 361 | self.stages: list[OpSequential] = [] 362 | for stage_id, (width, depth) in reversed(list(enumerate(zip(cfg.width_list, cfg.depth_list)))): 363 | stage = [] 364 | if stage_id < num_stages - 1 and depth > 0: 365 | upsample_block = build_upsample_block( 366 | block_type=cfg.upsample_block_type, 367 | in_channels=cfg.width_list[stage_id + 1], 368 | out_channels=width if cfg.upsample_match_channel else cfg.width_list[stage_id + 1], 369 | shortcut=cfg.upsample_shortcut, 370 | ) 371 | stage.append(upsample_block) 372 | 373 | block_type = cfg.block_type[stage_id] if isinstance(cfg.block_type, list) else cfg.block_type 
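# norm and act may likewise be specified per stage (the provided decoder configs use bn2d/relu for the high-resolution stages and ln2d/silu for the deeper ones).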
374 | norm = cfg.norm[stage_id] if isinstance(cfg.norm, list) else cfg.norm 375 | act = cfg.act[stage_id] if isinstance(cfg.act, list) else cfg.act 376 | stage.extend( 377 | build_stage_main( 378 | width=width, 379 | depth=depth, 380 | block_type=block_type, 381 | norm=norm, 382 | act=act, 383 | input_width=( 384 | width if cfg.upsample_match_channel else cfg.width_list[min(stage_id + 1, num_stages - 1)] 385 | ), 386 | ) 387 | ) 388 | self.stages.insert(0, OpSequential(stage)) 389 | self.stages = nn.ModuleList(self.stages) 390 | 391 | self.project_out = build_decoder_project_out_block( 392 | in_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1], 393 | out_channels=cfg.in_channels, 394 | factor=1 if cfg.depth_list[0] > 0 else 2, 395 | upsample_block_type=cfg.upsample_block_type, 396 | norm=cfg.out_norm, 397 | act=cfg.out_act, 398 | ) 399 | 400 | def forward(self, x: torch.Tensor) -> torch.Tensor: 401 | x = self.project_in(x) 402 | for stage in reversed(self.stages): 403 | if len(stage.op_list) == 0: 404 | continue 405 | x = stage(x) 406 | x = self.project_out(x) 407 | return x 408 | 409 | 410 | class DCAE(nn.Module): 411 | def __init__(self, cfg: DCAEConfig): 412 | super().__init__() 413 | self.cfg = cfg 414 | self.encoder = Encoder(cfg.encoder) 415 | self.decoder = Decoder(cfg.decoder) 416 | 417 | if self.cfg.pretrained_path is not None: 418 | self.load_model() 419 | 420 | def load_model(self): 421 | if self.cfg.pretrained_source == "dc-ae": 422 | state_dict = torch.load(self.cfg.pretrained_path, map_location="cpu", weights_only=True)["state_dict"] 423 | self.load_state_dict(state_dict) 424 | else: 425 | raise NotImplementedError 426 | 427 | @property 428 | def spatial_compression_ratio(self) -> int: 429 | return 2 ** (self.decoder.num_stages - 1) 430 | 431 | def encode(self, x: torch.Tensor) -> torch.Tensor: 432 | x = self.encoder(x) 433 | return x 434 | 435 | def decode(self, x: torch.Tensor) -> torch.Tensor: 436 | x = self.decoder(x) 437 | return x 438 | 439 | def forward(self, x: torch.Tensor, global_step: int = 0) -> torch.Tensor: 440 | x = self.encoder(x) 441 | x = self.decoder(x) 442 | return x, torch.tensor(0), {} 443 | 444 | def dc_ae_f16(name: str, channels, pretrained_path: str) -> DCAEConfig: 445 | if name in ['dc-ae-f16-in-1.0']: 446 | cfg_str = ( 447 | f"latent_channels={channels} " 448 | "encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU] " 449 | "encoder.width_list=[128,256,512,512,1024] encoder.depth_list=[0,4,8,2,2] " 450 | "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU] " 451 | "decoder.width_list=[128,256,512,512,1024] decoder.depth_list=[0,5,10,2,2] " 452 | "decoder.norm=[bn2d,bn2d,bn2d,ln2d,ln2d] decoder.act=[relu,relu,relu,silu,silu]" 453 | ) 454 | else: 455 | raise NotImplementedError 456 | cfg = OmegaConf.from_dotlist(cfg_str.split(" ")) 457 | cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg)) 458 | cfg.pretrained_path = pretrained_path 459 | return cfg 460 | 461 | def dc_ae_f16_res(name: str, channels, pretrained_path: str) -> DCAEConfig: 462 | if name in ['dc-ae-f16-res-in-1.0']: 463 | cfg_str = ( 464 | f"latent_channels={channels} " 465 | "encoder.block_type=[ResBlock,ResBlock,ResBlock,ResBlock,ResBlock] " 466 | "encoder.width_list=[128,256,512,512,1024] encoder.depth_list=[0,4,8,4,4] " 467 | "decoder.block_type=[ResBlock,ResBlock,ResBlock,ResBlock,ResBlock] " 468 | "decoder.width_list=[128,256,512,512,1024] decoder.depth_list=[0,5,10,4,4] " 469 | 
"decoder.norm=[bn2d,bn2d,bn2d,ln2d,ln2d] decoder.act=[relu,relu,relu,silu,silu]" 470 | ) 471 | else: 472 | raise NotImplementedError 473 | cfg = OmegaConf.from_dotlist(cfg_str.split(" ")) 474 | cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg)) 475 | cfg.pretrained_path = pretrained_path 476 | return cfg 477 | 478 | def dc_ae_f32(name: str, channels, pretrained_path: str) -> DCAEConfig: 479 | if name in ["dc-ae-f32-in-1.0", "dc-ae-f32-mix-1.0"]: 480 | cfg_str = ( 481 | f"latent_channels={channels} " 482 | "encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] " 483 | "encoder.width_list=[128,256,512,512,1024,1024] encoder.depth_list=[0,4,8,2,2,2] " 484 | "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] " 485 | "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[0,5,10,2,2,2] " 486 | "decoder.norm=[bn2d,bn2d,bn2d,ln2d,ln2d,ln2d] decoder.act=[relu,relu,relu,silu,silu,silu]" 487 | ) 488 | else: 489 | raise NotImplementedError 490 | cfg = OmegaConf.from_dotlist(cfg_str.split(" ")) 491 | cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg)) 492 | cfg.pretrained_path = pretrained_path 493 | return cfg 494 | 495 | 496 | def dc_ae_f64c128(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig: 497 | if name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]: 498 | cfg_str = ( 499 | "latent_channels=128 " 500 | "encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] " 501 | "encoder.width_list=[128,256,512,512,1024,1024,2048] encoder.depth_list=[0,4,8,2,2,2,2] " 502 | "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] " 503 | "decoder.width_list=[128,256,512,512,1024,1024,2048] decoder.depth_list=[0,5,10,2,2,2,2] " 504 | "decoder.norm=[bn2d,bn2d,bn2d,ln2d,ln2d,ln2d,ln2d] decoder.act=[relu,relu,relu,silu,silu,silu,silu]" 505 | ) 506 | else: 507 | raise NotImplementedError 508 | cfg = OmegaConf.from_dotlist(cfg_str.split(" ")) 509 | cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg)) 510 | cfg.pretrained_path = pretrained_path 511 | return cfg 512 | 513 | 514 | def dc_ae_f128c512(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig: 515 | if name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]: 516 | cfg_str = ( 517 | "latent_channels=512 " 518 | "encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] " 519 | "encoder.width_list=[128,256,512,512,1024,1024,2048,2048] encoder.depth_list=[0,4,8,2,2,2,2,2] " 520 | "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] " 521 | "decoder.width_list=[128,256,512,512,1024,1024,2048,2048] decoder.depth_list=[0,5,10,2,2,2,2,2] " 522 | "decoder.norm=[bn2d,bn2d,bn2d,ln2d,ln2d,ln2d,ln2d,ln2d] decoder.act=[relu,relu,relu,silu,silu,silu,silu,silu]" 523 | ) 524 | else: 525 | raise NotImplementedError 526 | cfg = OmegaConf.from_dotlist(cfg_str.split(" ")) 527 | cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg)) 528 | cfg.pretrained_path = pretrained_path 529 | return cfg 530 | -------------------------------------------------------------------------------- /modeling/generator/dc_ar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 NVIDIA CORPORATION & AFFILIATES 2 | # 3 | # Licensed under the Apache License, Version 2.0 
(the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # SPDX-License-Identifier: Apache-2.0 16 | 17 | # This file is modified from https://github.com/bytedance/1d-tokenizer and https://github.com/NVlabs/Sana 18 | 19 | 20 | import math 21 | import json 22 | from pathlib import Path 23 | 24 | import numpy as np 25 | import torch 26 | import torch.nn as nn 27 | from einops import rearrange 28 | from timm.models.layers import DropPath 29 | 30 | from huggingface_hub import PyTorchModelHubMixin 31 | from omegaconf import OmegaConf 32 | from transformers import AutoModel, AutoTokenizer 33 | 34 | from modeling.utils import tokenize_fn 35 | from modeling.modules.base_model import BaseModel 36 | from modeling.diffusion import DiffLoss 37 | 38 | from .net.basic_modules import DWMlp, GLUMBConv, MBConvPreGLU, Mlp 39 | from .net.blocks import ( 40 | Attention, 41 | CaptionEmbedder, 42 | FlashAttention, 43 | LiteLA, 44 | MultiHeadCrossAttention, 45 | T2IFinalLayer, 46 | ) 47 | from .net.utils import auto_grad_checkpoint, to_2tuple 48 | 49 | 50 | 51 | class DCARBlock(nn.Module): 52 | 53 | def __init__( 54 | self, 55 | hidden_size, 56 | num_heads, 57 | mlp_ratio=4.0, 58 | drop_path=0, 59 | input_size=None, 60 | qk_norm=False, 61 | attn_type='flash', 62 | ffn_type='mlp', 63 | mlp_acts={"silu", "silu", None}, 64 | linear_head_dim=32, 65 | **block_kwargs, 66 | ): 67 | super().__init__() 68 | self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 69 | if attn_type == 'flash': 70 | self.attn = FlashAttention( 71 | hidden_size, 72 | num_heads=num_heads, 73 | qkv_bias=True, 74 | qk_norm=qk_norm, 75 | **block_kwargs, 76 | ) 77 | elif attn_type == "linear": 78 | self_num_heads = hidden_size // linear_head_dim 79 | self.attn = LiteLA(hidden_size, hidden_size, heads=self_num_heads, eps=1e-8, qk_norm=qk_norm) 80 | elif attn_type == "vanilla": 81 | self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True) 82 | else: 83 | raise ValueError(f"{attn_type} type is not defined.") 84 | 85 | self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) 86 | self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 87 | if ffn_type == "dwmlp": 88 | approx_gelu = lambda: nn.GELU(approximate="tanh") 89 | self.mlp = DWMlp( 90 | in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 91 | ) 92 | elif ffn_type == "glumbconv": 93 | self.mlp = GLUMBConv( 94 | in_features=hidden_size, 95 | hidden_features=int(hidden_size * mlp_ratio), 96 | use_bias=(True, True, False), 97 | norm=(None, None, None), 98 | act=mlp_acts, 99 | ) 100 | elif ffn_type == "glumbconv_dilate": 101 | self.mlp = GLUMBConv( 102 | in_features=hidden_size, 103 | hidden_features=int(hidden_size * mlp_ratio), 104 | use_bias=(True, True, False), 105 | norm=(None, None, None), 106 | act=mlp_acts, 107 | dilation=2, 108 | ) 109 | elif ffn_type == "mbconvpreglu": 110 | self.mlp = MBConvPreGLU( 111 | in_dim=hidden_size, 112 | out_dim=hidden_size, 113 | mid_dim=int(hidden_size * 
mlp_ratio), 114 | use_bias=(True, True, False), 115 | norm=None, 116 | act=("silu", "silu", None), 117 | ) 118 | elif ffn_type == "mlp": 119 | approx_gelu = lambda: nn.GELU(approximate="tanh") 120 | self.mlp = Mlp( 121 | in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0 122 | ) 123 | else: 124 | raise ValueError(f"{ffn_type} type is not defined.") 125 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 126 | 127 | def forward(self, x, y, mask=None, attn_bias=None, **kwargs): 128 | B, N, C = x.shape 129 | 130 | x = x + self.drop_path(self.attn(self.norm1(x), attn_bias=attn_bias).reshape(B, N, C)) 131 | x = x + self.cross_attn(x, y, mask) 132 | x = x + self.drop_path(self.mlp(self.norm2(x))) 133 | 134 | return x 135 | 136 | 137 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, pe_interpolation=1.0, base_size=16): 138 | """ 139 | grid_size: int of the grid height and width 140 | return: 141 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 142 | """ 143 | if isinstance(grid_size, int): 144 | grid_size = to_2tuple(grid_size) 145 | grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / pe_interpolation 146 | grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / pe_interpolation 147 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 148 | grid = np.stack(grid, axis=0) 149 | grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) 150 | 151 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 152 | if cls_token and extra_tokens > 0: 153 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 154 | return pos_embed 155 | 156 | 157 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 158 | assert embed_dim % 2 == 0 159 | 160 | # use half of dimensions to encode grid_h 161 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 162 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 163 | 164 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 165 | return emb 166 | 167 | 168 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 169 | """ 170 | embed_dim: output dimension for each position 171 | pos: a list of positions to be encoded: size (M,) 172 | out: (M, D) 173 | """ 174 | assert embed_dim % 2 == 0 175 | omega = np.arange(embed_dim // 2, dtype=np.float64) 176 | omega /= embed_dim / 2.0 177 | omega = 1.0 / 10000**omega # (D/2,) 178 | 179 | pos = pos.reshape(-1) # (M,) 180 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 181 | 182 | emb_sin = np.sin(out) # (M, D/2) 183 | emb_cos = np.cos(out) # (M, D/2) 184 | 185 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 186 | return emb 187 | 188 | 189 | class DCAR(BaseModel, PyTorchModelHubMixin): 190 | def __init__(self, config): 191 | 192 | if isinstance(config, dict): 193 | config = OmegaConf.create(config) 194 | 195 | super().__init__() 196 | self.config = config 197 | self.target_codebook_size = config.model.vq_model.codebook_size 198 | self.image_seq_len = config.model.generator.image_seq_len 199 | self.mask_token_id = self.target_codebook_size 200 | self.hidden_size = config.model.generator.hidden_size 201 | self.num_hidden_layers = config.model.generator.num_hidden_layers 202 | self.num_attention_heads = config.model.generator.num_attention_heads 203 | self.intermediate_size = 
config.model.generator.intermediate_size 204 | 205 | self.text_tokenizer = AutoTokenizer.from_pretrained(config.model.text_model) 206 | text_encoder = AutoModel.from_pretrained(config.model.text_model) 207 | text_encoder.requires_grad_(False) 208 | text_encoder.eval().cuda() 209 | self.text_encoder = (text_encoder,) 210 | self.text_tokenizer_max_length = config.model.text_token_length 211 | 212 | self.reasoning = False 213 | if config.model.type == 'vq_dc_ae_reasoning': 214 | self.reasoning = True 215 | self.thinking_token_len = int(self.image_seq_len / (1 + config.model.vq_model.reasoning_downscale ** 2)) 216 | self.generating_token_len = self.image_seq_len - self.thinking_token_len 217 | 218 | self.x_embedder = nn.Embedding( 219 | self.target_codebook_size + 1, 220 | self.hidden_size 221 | ) 222 | approx_gelu = lambda: nn.GELU(approximate="tanh") 223 | self.y_embedder = CaptionEmbedder(in_channels=config.model.context_dim, hidden_size=self.hidden_size, act_layer=approx_gelu, token_num=config.model.text_token_length) 224 | self.pos_embed = nn.Parameter(torch.zeros(1, self.image_seq_len, self.hidden_size), requires_grad=False) 225 | self.pe_interpolation = config.model.generator.get('pe_interpolation', 1.0) 226 | self.base_size = config.model.generator.get('base_size', 16) 227 | drop_path = [x.item() for x in torch.linspace(0, config.model.generator.drop_path, self.num_hidden_layers)] 228 | 229 | self.blocks = nn.ModuleList([ 230 | DCARBlock(self.hidden_size, self.num_attention_heads, self.intermediate_size / self.hidden_size, 231 | drop_path=drop_path[i], input_size=(int(self.image_seq_len ** 0.5), int(self.image_seq_len ** 0.5)), 232 | qk_norm=config.model.generator.get('qk_norm', False), attn_type=config.model.generator.get('attn_type', 'flash'), ffn_type=config.model.generator.get('ffn_type', 'mlp'), 233 | mlp_acts=("silu", "silu", None), linear_head_dim=32,) 234 | for i in range(self.num_hidden_layers) 235 | ]) 236 | self.final_layer = T2IFinalLayer(self.hidden_size, 1, self.target_codebook_size) 237 | 238 | diffusion_config = config.model.generator.diffusion 239 | self.diffloss = DiffLoss( 240 | target_channels=config.model.vq_model.token_size, 241 | z_channels=self.hidden_size, 242 | width=diffusion_config.width, 243 | depth=diffusion_config.depth, 244 | num_sampling_steps=diffusion_config.num_sampling_steps, 245 | sampler=diffusion_config.sampler, 246 | vae_scale=diffusion_config.vae_scale 247 | ) 248 | self.diffusion_batch_mul = diffusion_config.batch_mul 249 | 250 | self.initialize_weights() 251 | 252 | def _save_pretrained(self, save_directory: Path) -> None: 253 | """Save weights and config to a local directory.""" 254 | dict_config = OmegaConf.to_container(self.config) 255 | file_path = Path(save_directory) / "config.json" 256 | with open(file_path, 'w') as json_file: 257 | json.dump(dict_config, json_file, indent=4) 258 | super()._save_pretrained(save_directory) 259 | 260 | def forward_diff_loss(self, z, target, mask=None): 261 | bs, seq_len, _ = target.shape 262 | target = target.reshape(bs * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 263 | z = z.reshape(bs * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 264 | loss = self.diffloss(z=z, target=target, mask=mask) 265 | return loss 266 | 267 | 268 | def initialize_weights(self): 269 | def _basic_init(module): 270 | if isinstance(module, nn.Linear): 271 | torch.nn.init.xavier_uniform_(module.weight) 272 | if module.bias is not None: 273 | nn.init.constant_(module.bias, 0) 274 | 275 | self.apply(_basic_init) 276 | 277 | 
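# pos_embed is a frozen (requires_grad=False) 2D sin-cos table; in reasoning mode it is built by concatenating a generating-token grid with a thinking-token grid, each at its own resolution.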
if not self.reasoning: 278 | pos_embed = get_2d_sincos_pos_embed( 279 | self.pos_embed.shape[-1], int(self.image_seq_len ** 0.5), 280 | pe_interpolation=self.pe_interpolation, base_size=self.base_size 281 | ) 282 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 283 | else: 284 | generating_pos_embed = get_2d_sincos_pos_embed( 285 | self.pos_embed.shape[-1], int(self.generating_token_len ** 0.5), 286 | pe_interpolation=self.pe_interpolation, base_size=self.base_size 287 | ) 288 | thinking_pos_embed = get_2d_sincos_pos_embed( 289 | self.pos_embed.shape[-1], int(self.thinking_token_len ** 0.5), 290 | pe_interpolation=self.pe_interpolation, base_size=self.base_size 291 | ) 292 | pos_embed = np.concatenate([generating_pos_embed, thinking_pos_embed], axis=0) 293 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 294 | 295 | w = self.x_embedder.weight.data 296 | nn.init.trunc_normal_(w, mean=0.0, std=0.02) 297 | 298 | nn.init.normal_(self.y_embedder.y_proj.fc1.weight, std=0.02) 299 | nn.init.normal_(self.y_embedder.y_proj.fc2.weight, std=0.02) 300 | 301 | nn.init.constant_(self.final_layer.linear.weight, 0) 302 | nn.init.constant_(self.final_layer.linear.bias, 0) 303 | 304 | 305 | def encode_text(self, text): 306 | encoder_attention_mask = text != self.text_tokenizer.pad_token_id 307 | with torch.no_grad(): 308 | encoder_hidden_states = self.text_encoder[0].encoder(text).last_hidden_state 309 | 310 | return encoder_hidden_states, encoder_attention_mask 311 | 312 | 313 | def forward(self, input_ids=None, condition=None, encoder_hidden_states=None, encoder_attention_mask=None, residual_features=None, cond_drop_prob=0.1): 314 | if input_ids is None: 315 | raise NotImplementedError 316 | 317 | if not isinstance(condition, torch.Tensor): 318 | condition = tokenize_fn( 319 | condition, 320 | tokenizer=self.text_tokenizer, 321 | max_length=self.text_tokenizer_max_length, 322 | padding_mode='max_length', 323 | ) 324 | condition = torch.stack(condition).cuda() 325 | 326 | 327 | x = self.x_embedder(input_ids) + self.pos_embed 328 | 329 | if encoder_hidden_states is None: 330 | with torch.no_grad(): 331 | y, mask = self.encode_text(condition) 332 | else: 333 | y, mask = encoder_hidden_states, encoder_attention_mask 334 | 335 | y = self.y_embedder(y, self.training, cond_drop_prob) 336 | 337 | if mask is not None: 338 | if mask.shape[0] != y.shape[0]: 339 | mask = mask.repeat(y.shape[0] // mask.shape[0], 1) 340 | mask = mask.squeeze(1) 341 | y = y.masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) 342 | y_lens = mask.sum(dim=1).tolist() 343 | else: 344 | y_lens = [y.shape[2]] * y.shape[0] 345 | y = y.view(1, -1, x.shape[-1]) 346 | 347 | for block in self.blocks: 348 | x = auto_grad_checkpoint(block, x, y, y_lens) 349 | 350 | logits = self.final_layer(x) 351 | if residual_features is not None: 352 | diff_loss = self.forward_diff_loss( 353 | z=x, target=residual_features 354 | ) 355 | else: 356 | diff_loss = None 357 | 358 | return logits, x, diff_loss 359 | 360 | def masking_input_tokens(self, input_tokens): 361 | batch_size, seq_len = input_tokens.shape 362 | device = input_tokens.device 363 | 364 | timesteps = torch.zeros((batch_size,), device=device).float().uniform_(0, 1.0) 365 | mask_ratio = torch.acos(timesteps) / (math.pi * 0.5) # arccos schedule 366 | mask_ratio = torch.clamp(mask_ratio, min=1e-6, max=1.) 
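# Pick num_token_masked positions per sample uniformly at random (via a per-sample random permutation) and replace them with the mask token id.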
367 | num_token_masked = (seq_len * mask_ratio).round().clamp(min=1) 368 | batch_randperm = torch.rand(batch_size, seq_len, device=device).argsort(dim=-1) 369 | masks = batch_randperm < rearrange(num_token_masked, 'b -> b 1') 370 | masked_tokens = torch.where(masks, self.mask_token_id, input_tokens) 371 | return masked_tokens, masks 372 | 373 | 374 | @torch.no_grad() 375 | def generate(self, 376 | condition, 377 | init_image_tokens=None, 378 | init_residual_features=None, 379 | init_mask=None, 380 | guidance_scale=3.0, 381 | guidance_decay="constant", 382 | guidance_scale_pow=3.0, 383 | randomize_temperature=4.5, 384 | softmax_temperature_annealing=False, 385 | num_sample_steps=8): 386 | if guidance_decay not in ["constant", "linear", "power-cosine"]: 387 | # constant: constant guidance scale 388 | # linear: linearly increase the guidance scale as in MUSE 389 | # power-cosine: the guidance schedule from MDT 390 | raise ValueError(f"Unsupported guidance decay {guidance_decay}") 391 | 392 | if not isinstance(condition, torch.Tensor): 393 | condition = tokenize_fn( 394 | condition, 395 | tokenizer=self.text_tokenizer, 396 | max_length=self.text_tokenizer_max_length, 397 | padding_mode='max_length', 398 | ) 399 | condition = torch.stack(condition).cuda() 400 | encoder_hidden_states, encoder_attention_mask = self.encode_text(condition) 401 | 402 | device = condition.device 403 | if init_image_tokens is not None: 404 | ids = init_image_tokens 405 | else: 406 | ids = torch.full((condition.shape[0], self.image_seq_len), 407 | self.mask_token_id, device=device) 408 | 409 | cfg_scale = guidance_scale if guidance_decay == "constant" else 0. 410 | 411 | for step in range(num_sample_steps): 412 | ratio = 1. * (step + 1) / num_sample_steps 413 | annealed_temp = randomize_temperature * (1.0 - ratio) 414 | is_mask = (ids == self.mask_token_id) 415 | 416 | if guidance_decay == "power-cosine": 417 | guidance_scale_pow = torch.ones((1), device=device) * guidance_scale_pow 418 | scale_step = (1 - torch.cos(((step / num_sample_steps) ** guidance_scale_pow) * torch.pi)) * 1/2 419 | cfg_scale = (guidance_scale - 1) * scale_step + 1 420 | 421 | if cfg_scale != 0: 422 | logits, latents, _ = self.forward( 423 | torch.cat([ids, ids], dim=0), 424 | torch.cat([condition, condition], dim=0), 425 | encoder_hidden_states=torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0), 426 | encoder_attention_mask=torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0), 427 | cond_drop_prob=torch.cat([torch.zeros(ids.shape[0]), torch.ones(ids.shape[0])], dim=0).cuda() 428 | ) 429 | cond_logits, uncond_logits = logits.chunk(2, dim=0) 430 | cond_latents, uncond_latents = latents.chunk(2, dim=0) 431 | if guidance_decay == "power-cosine": 432 | logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale 433 | else: 434 | logits = cond_logits + (cond_logits - uncond_logits) * cfg_scale 435 | else: 436 | logits, latents, _ = self.forward( 437 | ids, condition, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, cond_drop_prob=0.0 438 | ) 439 | 440 | if softmax_temperature_annealing: 441 | softmax_temperature = 0.5 + 0.8 * (1 - ratio) 442 | logits = logits / softmax_temperature 443 | 444 | # Add Gumbel noise 445 | def log(t, eps=1e-20): 446 | return torch.log(t.clamp(min=eps)) 447 | def gumbel_noise(t): 448 | noise = torch.zeros_like(t).uniform_(0, 1) 449 | return -log(-log(noise)) 450 | def add_gumbel_noise(t, temperature): 451 | return t + temperature * 
gumbel_noise(t) 452 | 453 | sampled_ids = add_gumbel_noise(logits, annealed_temp).argmax(dim=-1) 454 | sampled_logits = torch.squeeze( 455 | torch.gather(logits, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1) 456 | sampled_ids = torch.where(is_mask, sampled_ids, ids) 457 | sampled_logits = torch.where(is_mask, sampled_logits, +np.inf).float() 458 | # masking 459 | mask_ratio = np.arccos(ratio) / (math.pi * 0.5) 460 | 461 | mask_len = torch.Tensor([np.floor(self.image_seq_len * mask_ratio)]).to(device) 462 | mask_len = torch.maximum(torch.Tensor([1]).to(device), 463 | torch.minimum(torch.sum(is_mask, dim=-1, keepdims=True) - 1, 464 | mask_len))[0].squeeze() 465 | confidence = add_gumbel_noise(sampled_logits, annealed_temp) 466 | sorted_confidence, _ = torch.sort(confidence, axis=-1) 467 | cut_off = sorted_confidence[:, mask_len.long() - 1:mask_len.long()] 468 | masking = (confidence <= cut_off) 469 | if step == num_sample_steps - 1: 470 | ids = sampled_ids 471 | latents = torch.cat([cond_latents, uncond_latents], dim=0) 472 | else: 473 | ids = torch.where(masking, self.mask_token_id, sampled_ids) 474 | 475 | if guidance_decay == "linear": 476 | cfg_scale = ratio * guidance_scale 477 | 478 | bs, seq_len, _ = latents.shape 479 | predicted_latents = latents.reshape(bs * seq_len, -1) 480 | residual_features = self.diffloss.sample( 481 | z=predicted_latents, temperature=1.0, cfg=cfg_scale 482 | ).reshape(bs, seq_len, -1) 483 | residual_features, _ = residual_features.chunk(2, dim=0) 484 | 485 | if init_residual_features is not None: 486 | residual_features = torch.where(init_mask.unsqueeze(-1).expand_as(residual_features), init_residual_features, residual_features) 487 | 488 | return ids, residual_features --------------------------------------------------------------------------------