├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── model_overview.png ├── sample_image.jpg └── sample_video.mp4 ├── configs └── deepspeed │ ├── zero2_bf16.json │ └── zero2_fp16.json ├── diffnext ├── __init__.py ├── config │ ├── __init__.py │ ├── defaults.py │ └── yacs.py ├── data │ ├── __init__.py │ ├── builder.py │ ├── flex_loaders.py │ ├── flex_pipelines.py │ ├── flex_transforms.py │ └── utils.py ├── engine │ ├── __init__.py │ ├── builder.py │ ├── coordinator.py │ ├── lr_scheduler.py │ ├── model_ema.py │ ├── train_engine.py │ └── utils.py ├── image_processor.py ├── models │ ├── __init__.py │ ├── autoencoders │ │ ├── __init__.py │ │ ├── autoencoder_kl.py │ │ ├── autoencoder_kl_cogvideox.py │ │ ├── autoencoder_kl_ltx.py │ │ ├── autoencoder_kl_opensora.py │ │ └── modeling_utils.py │ ├── diffusion_mlp.py │ ├── diffusion_transformer.py │ ├── embeddings.py │ ├── flex_attention.py │ ├── guidance_scaler.py │ ├── normalization.py │ ├── text_encoders │ │ ├── __init__.py │ │ └── phi.py │ ├── transformers │ │ ├── __init__.py │ │ ├── transformer_3d.py │ │ └── transformer_nova.py │ └── vision_transformer.py ├── pipelines │ ├── __init__.py │ ├── builder.py │ └── nova │ │ ├── __init__.py │ │ ├── pipeline_nova.py │ │ ├── pipeline_nova_c2i.py │ │ ├── pipeline_train_c2i.py │ │ ├── pipeline_train_t2i.py │ │ ├── pipeline_train_t2v.py │ │ └── pipeline_utils.py ├── schedulers │ ├── __init__.py │ ├── scheduling_ddpm.py │ └── scheduling_flow.py └── utils │ ├── __init__.py │ ├── export_utils.py │ ├── logging.py │ ├── profiler │ ├── __init__.py │ ├── stats.py │ └── timer.py │ ├── registry.py │ └── tensorboard.py ├── docs ├── environment.md ├── evaluation.md ├── inference.md ├── model_zoo.md └── training.md ├── evaluations ├── geneval │ ├── metadata.jsonl │ ├── prompts.json │ └── sample.py └── vbench │ ├── prompts.json │ └── sample.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── app_nova_t2i.py ├── app_nova_t2v.py └── train.py ├── setup.py └── version.txt /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | ignore = 4 | # whitespace before ':' (conflicted with Black) 5 | E203, 6 | # ambiguous variable name 7 | E741, 8 | # ‘from module import *’ used; unable to detect undefined names 9 | F403, 10 | # name may be undefined, or defined from star imports: module 11 | F405, 12 | # redefinition of unused name from line N 13 | F811, 14 | # undefined name 15 | F821, 16 | # line break before binary operator 17 | W503, 18 | # line break after binary operator 19 | W504 20 | # module imported but unused 21 | per-file-ignores = __init__.py: F401 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files 2 | *.slo 3 | *.lo 4 | *.o 5 | *.cuo 6 | 7 | # Compiled Dynamic libraries 8 | *.so 9 | *.dll 10 | *.dylib 11 | 12 | # Compiled Static libraries 13 | *.lai 14 | *.la 15 | *.a 16 | *.lib 17 | 18 | # Compiled python 19 | *.pyc 20 | __pycache__ 21 | 22 | # Compiled MATLAB 23 | *.mex* 24 | 25 | # IPython notebook checkpoints 26 | .ipynb_checkpoints 27 | 28 | # Editor temporaries 29 | *.swp 30 | *~ 31 | 32 | # Sublime Text settings 33 | *.sublime-workspace 34 | *.sublime-project 35 | 36 | # Eclipse Project settings 37 | *.*project 38 | .settings 39 | 40 | # QtCreator files 41 | *.user 42 | 43 | # VSCode files 44 | .vscode 45 | 46 | # IDEA files 47 | .idea 48 | 49 | # OSX dir files 
50 | .DS_Store 51 | 52 | # Android files 53 | .gradle 54 | *.iml 55 | local.properties 56 | -------------------------------------------------------------------------------- /assets/model_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/NOVA/f4524db37dfec29c10c6afcd43979bdb59311688/assets/model_overview.png -------------------------------------------------------------------------------- /assets/sample_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/NOVA/f4524db37dfec29c10c6afcd43979bdb59311688/assets/sample_image.jpg -------------------------------------------------------------------------------- /assets/sample_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baaivision/NOVA/f4524db37dfec29c10c6afcd43979bdb59311688/assets/sample_video.mp4 -------------------------------------------------------------------------------- /configs/deepspeed/zero2_bf16.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "bf16": { 4 | "enabled": true 5 | }, 6 | "zero_optimization": { 7 | "stage": 2 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /configs/deepspeed/zero2_fp16.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "fp16": { 4 | "enabled": true, 5 | "initial_scale_power": 16, 6 | "loss_scale_window": 1000 7 | }, 8 | "zero_optimization": { 9 | "stage": 2 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /diffnext/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # -------------------------------------------------------------------------- 16 | """DiffNext: A diffusers-based library for autoregressive diffusion models.""" 17 | -------------------------------------------------------------------------------- /diffnext/config/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Platform configurations.""" 17 | 18 | from diffnext.config.defaults import cfg # noqa 19 | -------------------------------------------------------------------------------- /diffnext/config/defaults.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Default configurations.""" 17 | 18 | from diffnext.config.yacs import CfgNode 19 | 20 | _C = cfg = CfgNode() 21 | 22 | # ------------------------------------------------------------ 23 | # Training options 24 | # ------------------------------------------------------------ 25 | _C.TRAIN = CfgNode() 26 | 27 | # The train dataset 28 | _C.TRAIN.DATASET = "" 29 | 30 | # The train dataset2 31 | _C.TRAIN.DATASET2 = "" 32 | 33 | # The loader type for training 34 | _C.TRAIN.LOADER = "vae_train" 35 | 36 | # The number of threads to load train data per GPU 37 | _C.TRAIN.NUM_THREADS = 4 38 | 39 | # Images to fill per mini-batch 40 | _C.TRAIN.BATCH_SIZE = 1 41 | 42 | # The EMA decay to smooth the checkpoints 43 | _C.TRAIN.MODEL_EMA = 0.99 44 | 45 | # Device to place the EMA model ("cpu", "gpu") 46 | _C.TRAIN.DEVICE_EMA = "gpu" 47 | 48 | # Condition repeat factor to enlarge noise sampling 49 | _C.TRAIN.LOSS_REPEAT = 4 50 | 51 | # The model checkpointing level 52 | _C.TRAIN.CHECKPOINTING = 2 53 | 54 | # ------------------------------------------------------------ 55 | # Model options 56 | # ------------------------------------------------------------ 57 | _C.MODEL = CfgNode() 58 | 59 | # The module type of model 60 | _C.MODEL.TYPE = "transformer" 61 | 62 | # The config dict 63 | _C.MODEL.CONFIG = {} 64 | 65 | # Initialize model with weights from this file 66 | _C.MODEL.WEIGHTS = "" 67 | 68 | # The compute precision 69 | _C.MODEL.PRECISION = "bfloat16" 70 | 71 | # ------------------------------------------------------------ 72 | # Pipeline options 73 | # ------------------------------------------------------------ 74 | _C.PIPELINE = CfgNode() 75 | 76 | # The registered pipeline type 77 | _C.PIPELINE.TYPE = "" 78 | 79 | # The dict of pipeline modules 80 | _C.PIPELINE.MODULES = {} 81 | 82 | # ------------------------------------------------------------ 83 | # Solver options 84 | # 
------------------------------------------------------------ 85 | _C.SOLVER = CfgNode() 86 | 87 | # The interval to display logs 88 | _C.SOLVER.DISPLAY = 20 89 | 90 | # The interval to update ema model 91 | _C.SOLVER.EMA_EVERY = 100 92 | 93 | # The interval to snapshot a model 94 | _C.SOLVER.SNAPSHOT_EVERY = 5000 95 | 96 | # Prefix to yield the path: _iter_XYZ 97 | _C.SOLVER.SNAPSHOT_PREFIX = "model" 98 | 99 | # Maximum number of SGD iterations 100 | _C.SOLVER.MAX_STEPS = 2147483647 101 | 102 | # Base learning rate for the specified scheduler 103 | _C.SOLVER.BASE_LR = 0.0001 104 | 105 | # Minimal learning rate for the specified scheduler 106 | _C.SOLVER.MIN_LR = 0.0 107 | 108 | # The decay intervals for LRScheduler 109 | _C.SOLVER.DECAY_STEPS = [] 110 | 111 | # The decay factor for exponential LRScheduler 112 | _C.SOLVER.DECAY_GAMMA = 0.5 113 | 114 | # Warm up to ``BASE_LR`` over this number of steps 115 | _C.SOLVER.WARM_UP_STEPS = 250 116 | 117 | # Start the warm up from ``BASE_LR`` * ``FACTOR`` 118 | _C.SOLVER.WARM_UP_FACTOR = 1.0 / 1000 119 | 120 | # The type of optimizer 121 | _C.SOLVER.OPTIMIZER = "AdamW" 122 | 123 | # The adam beta2 value 124 | _C.SOLVER.ADAM_BETA2 = 0.95 125 | 126 | # The type of lr scheduler 127 | _C.SOLVER.LR_POLICY = "" 128 | 129 | # Gradient accumulation steps per SGD iteration 130 | _C.SOLVER.ACCUM_STEPS = 1 131 | 132 | # L2 regularization for weight parameters 133 | _C.SOLVER.WEIGHT_DECAY = 0.02 134 | 135 | # ------------------------------------------------------------ 136 | # Misc options 137 | # ------------------------------------------------------------ 138 | # Number of GPUs for distributed training 139 | _C.NUM_GPUS = 1 140 | 141 | # Random seed for reproducibility 142 | _C.RNG_SEED = 3 143 | 144 | # Default GPU device index 145 | _C.GPU_ID = 0 146 | -------------------------------------------------------------------------------- /diffnext/config/yacs.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ------------------------------------------------------------------------ 16 | """Yet Another Configuration System (YACS).""" 17 | 18 | import copy 19 | 20 | import numpy as np 21 | import yaml 22 | 23 | 24 | class CfgNode(dict): 25 | """Node for configuration options.""" 26 | 27 | IMMUTABLE = "__immutable__" 28 | 29 | def __init__(self, *args, **kwargs): 30 | super(CfgNode, self).__init__(*args, **kwargs) 31 | self.__dict__[CfgNode.IMMUTABLE] = False 32 | 33 | def clone(self): 34 | """Recursively copy this CfgNode.""" 35 | return copy.deepcopy(self) 36 | 37 | def freeze(self): 38 | """Make this CfgNode and all of its children immutable.""" 39 | self._immutable(True) 40 | 41 | def is_frozen(self): 42 | """Return mutability.""" 43 | return self.__dict__[CfgNode.IMMUTABLE] 44 | 45 | def merge_from_file(self, cfg_filename): 46 | """Load a yaml config file and merge it into this CfgNode.""" 47 | with open(cfg_filename, "r") as f: 48 | other_cfg = CfgNode(yaml.safe_load(f)) 49 | self.merge_from_other_cfg(other_cfg) 50 | 51 | def merge_from_list(self, cfg_list): 52 | """Merge config (keys, values) in a list into this CfgNode.""" 53 | assert len(cfg_list) % 2 == 0 54 | from ast import literal_eval 55 | 56 | for k, v in zip(cfg_list[0::2], cfg_list[1::2]): 57 | key_list = k.split(".") 58 | d = self 59 | for sub_key in key_list[:-1]: 60 | assert sub_key in d 61 | d = d[sub_key] 62 | sub_key = key_list[-1] 63 | assert sub_key in d 64 | try: 65 | value = literal_eval(v) 66 | except: # noqa 67 | # Handle the case when v is a string literal 68 | value = v 69 | if type(value) != type(d[sub_key]): # noqa 70 | raise TypeError( 71 | "Type {} does not match original type {}".format(type(value), type(d[sub_key])) 72 | ) 73 | d[sub_key] = value 74 | 75 | def merge_from_other_cfg(self, other_cfg): 76 | """Merge ``other_cfg`` into this CfgNode.""" 77 | _merge_a_into_b(other_cfg, self) 78 | 79 | def _immutable(self, is_immutable): 80 | """Set immutability recursively to all nested CfgNode.""" 81 | self.__dict__[CfgNode.IMMUTABLE] = is_immutable 82 | for v in self.__dict__.values(): 83 | if isinstance(v, CfgNode): 84 | v._immutable(is_immutable) 85 | for v in self.values(): 86 | if isinstance(v, CfgNode): 87 | v._immutable(is_immutable) 88 | 89 | def __getattr__(self, name): 90 | if name in self.__dict__: 91 | return self.__dict__[name] 92 | elif name in self: 93 | return self[name] 94 | else: 95 | raise AttributeError(name) 96 | 97 | def __repr__(self): 98 | return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__()) 99 | 100 | def __setattr__(self, name, value): 101 | if not self.__dict__[CfgNode.IMMUTABLE]: 102 | if name in self.__dict__: 103 | self.__dict__[name] = value 104 | else: 105 | self[name] = value 106 | else: 107 | raise AttributeError( 108 | 'Attempted to set "{}" to "{}", but CfgNode is immutable'.format(name, value) 109 | ) 110 | 111 | def __str__(self): 112 | def _indent(s_, num_spaces): 113 | s = s_.split("\n") 114 | if len(s) == 1: 115 | return s_ 116 | first = s.pop(0) 117 | s = [(num_spaces * " ") + line for line in s] 118 | s = "\n".join(s) 119 | s = first + "\n" + s 120 | return s 121 | 122 | r = "" 123 | s = [] 124 | for k, v in sorted(self.items()): 125 | seperator = "\n" if isinstance(v, CfgNode) else " " 126 | attr_str = "{}:{}{}".format(str(k), seperator, str(v)) 127 | attr_str = _indent(attr_str, 2) 128 | s.append(attr_str) 129 | r += "\n".join(s) 130 | return r 131 | 132 | 133 | def _merge_a_into_b(a, b): 134 | """Merge config dictionary a into config 
dictionary b, clobbering the 135 | options in b whenever they are also specified in a.""" 136 | if not isinstance(a, dict): 137 | return 138 | for k, v in a.items(): 139 | # a must specify keys that are in b 140 | if k not in b: 141 | raise KeyError("{} is not a valid config key".format(k)) 142 | # The types must match, too 143 | v = _check_and_coerce_cfg_value_type(v, b[k], k) 144 | # Recursively merge dicts 145 | if type(v) is CfgNode: 146 | try: 147 | _merge_a_into_b(a[k], b[k]) 148 | except: # noqa 149 | print("Error under config key: {}".format(k)) 150 | raise 151 | else: 152 | b[k] = v 153 | 154 | 155 | def _check_and_coerce_cfg_value_type(value_a, value_b, key): 156 | """Check if the value type matched.""" 157 | type_a, type_b = type(value_a), type(value_b) 158 | if type_a is type_b: 159 | return value_a 160 | if type_b is float and type_a is int: 161 | return float(value_a) 162 | # Exceptions: numpy arrays, strings, tuple<->list 163 | if isinstance(value_b, np.ndarray): 164 | value_a = np.array(value_a, dtype=value_b.dtype) 165 | elif isinstance(value_a, tuple) and isinstance(value_b, list): 166 | value_a = list(value_a) 167 | elif isinstance(value_a, list) and isinstance(value_b, tuple): 168 | value_a = tuple(value_a) 169 | elif isinstance(value_a, dict) and isinstance(value_b, CfgNode): 170 | value_a = CfgNode(value_a) 171 | return value_a 172 | -------------------------------------------------------------------------------- /diffnext/data/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Data components.""" 17 | 18 | from diffnext.data import flex_pipelines 19 | from diffnext.data.builder import build_loader_train 20 | from diffnext.data.utils import get_dataset_size 21 | -------------------------------------------------------------------------------- /diffnext/data/builder.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ------------------------------------------------------------------------ 16 | """Build for data.""" 17 | 18 | from diffnext.config import cfg 19 | from diffnext.utils.registry import Registry 20 | 21 | LOADERS = Registry("loaders") 22 | 23 | 24 | def build_loader_train(**kwargs): 25 | """Build the train loader.""" 26 | args = { 27 | "dataset": cfg.TRAIN.DATASET, 28 | "dataset2": cfg.TRAIN.DATASET2, 29 | "batch_size": cfg.TRAIN.BATCH_SIZE, 30 | "num_threads": cfg.TRAIN.NUM_THREADS, 31 | "seed": cfg.RNG_SEED + cfg.GPU_ID, 32 | "shuffle": True, 33 | } 34 | args.update(kwargs) 35 | return LOADERS.get(cfg.TRAIN.LOADER)(**args) 36 | -------------------------------------------------------------------------------- /diffnext/data/flex_loaders.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Flex data loaders.""" 17 | 18 | import collections 19 | import multiprocessing as mp 20 | import time 21 | import threading 22 | import queue 23 | 24 | import codewithgpu 25 | import numpy as np 26 | import torch 27 | 28 | from diffnext.config import cfg 29 | from diffnext.utils import logging 30 | 31 | 32 | class BalancedQueues(object): 33 | """Balanced queues.""" 34 | 35 | def __init__(self, base_queue, num=1): 36 | self.queues = [base_queue] 37 | self.queues += [mp.Queue(base_queue._maxsize) for _ in range(num - 1)] 38 | self.index = 0 39 | 40 | def put(self, obj, block=True, timeout=None): 41 | q = self.queues[self.index] 42 | q.put(obj, block=block, timeout=timeout) 43 | self.index = (self.index + 1) % len(self.queues) 44 | 45 | def get(self, block=True, timeout=None): 46 | q = self.queues[self.index] 47 | obj = q.get(block=block, timeout=timeout) 48 | self.index = (self.index + 1) % len(self.queues) 49 | return obj 50 | 51 | def get_n(self, num=1): 52 | outputs = [] 53 | while len(outputs) < num: 54 | obj = self.get() 55 | if obj is not None: 56 | outputs.append(obj) 57 | return outputs 58 | 59 | 60 | class DatasetReader(codewithgpu.DatasetReader): 61 | """Enhanced dataset reader to apply update.""" 62 | 63 | def before_first(self): 64 | """Move the cursor before begin.""" 65 | self._current = self._first 66 | self._dataset.seek(self._first) 67 | self._path2 = self._kwargs.get("path2", "") 68 | if self._path2 and not hasattr(self, "_dataset2"): 69 | self._dataset2 = self._dataset_getter(path=self._path2) 70 | self._dataset2.seek(self._first) if self._path2 else None 71 | 72 | def next_example(self): 73 | """Return the next example.""" 74 | example = super(DatasetReader, self).next_example() 75 | example.update(self._dataset2.read()) if self._path2 else None 76 | return example 77 | 78 | 79 | class DataLoaderBase(threading.Thread): 80 | """Base class of data 
loader.""" 81 | 82 | def __init__(self, worker, **kwargs): 83 | super(DataLoaderBase, self).__init__(daemon=True) 84 | self.batch_size = kwargs.get("batch_size", 2) 85 | self.num_readers = kwargs.get("num_readers", 1) 86 | self.num_workers = kwargs.get("num_workers", 3) 87 | self.queue_depth = kwargs.get("queue_depth", 2) 88 | # Initialize distributed group. 89 | from diffnext.engine import get_ddp_group 90 | 91 | rank, dist_size, dist_group = 0, 1, get_ddp_group() 92 | if dist_group is not None: 93 | rank = torch.distributed.get_rank(dist_group) 94 | dist_size = torch.distributed.get_world_size(dist_group) 95 | # Build queues. 96 | self.reader_queue = mp.Queue(self.queue_depth * self.batch_size) 97 | self.worker_queue = mp.Queue(self.queue_depth * self.batch_size) 98 | self.batch_queue = queue.Queue(self.queue_depth) 99 | self.reader_queue = BalancedQueues(self.reader_queue, self.num_workers) 100 | self.worker_queue = BalancedQueues(self.worker_queue, self.num_workers) 101 | # Build readers. 102 | self.readers = [] 103 | for i in range(self.num_readers): 104 | partition_id = i 105 | num_partitions = self.num_readers 106 | num_partitions *= dist_size 107 | partition_id += rank * self.num_readers 108 | self.readers.append( 109 | DatasetReader( 110 | output_queue=self.reader_queue, 111 | partition_id=partition_id, 112 | num_partitions=num_partitions, 113 | seed=cfg.RNG_SEED + partition_id, 114 | **kwargs, 115 | ) 116 | ) 117 | self.readers[i].start() 118 | time.sleep(0.1) 119 | # Build workers. 120 | self.workers = [] 121 | for i in range(self.num_workers): 122 | p = worker() 123 | p.seed += i + rank * self.num_workers 124 | p.reader_queue = self.reader_queue.queues[i] 125 | p.worker_queue = self.worker_queue.queues[i] 126 | p.start() 127 | self.workers.append(p) 128 | time.sleep(0.1) 129 | 130 | # Register cleanup callbacks. 131 | def cleanup(): 132 | def terminate(processes): 133 | for p in processes: 134 | p.terminate() 135 | p.join() 136 | 137 | terminate(self.workers) 138 | terminate(self.readers) 139 | 140 | import atexit 141 | 142 | atexit.register(cleanup) 143 | # Start batch prefetching. 144 | self.start() 145 | 146 | def next(self): 147 | """Return the next batch of data.""" 148 | return self.__next__() 149 | 150 | def run(self): 151 | """Main loop.""" 152 | 153 | def __call__(self): 154 | return self.next() 155 | 156 | def __iter__(self): 157 | """Return the iterator self.""" 158 | return self 159 | 160 | def __next__(self): 161 | """Return the next batch of data.""" 162 | return [self.batch_queue.get()] 163 | 164 | 165 | class DataLoader(DataLoaderBase): 166 | """Loader to return the batch of data.""" 167 | 168 | def __init__(self, dataset, worker, **kwargs): 169 | base_args = {"path": dataset, "path2": kwargs.get("dataset2", None)} 170 | self.contiguous = kwargs.get("contiguous", True) 171 | self.prefetch_count = kwargs.get("prefetch_count", 50) 172 | base_args["shuffle"] = kwargs.get("shuffle", True) 173 | base_args["batch_size"] = kwargs.get("batch_size", 1) 174 | base_args["num_workers"] = kwargs.get("num_workers", 1) 175 | super(DataLoader, self).__init__(worker, **base_args) 176 | 177 | def run(self): 178 | """Main loop.""" 179 | logging.info("Prefetch batches...") 180 | next_inputs = [] 181 | prev_inputs = self.worker_queue.get_n(self.prefetch_count * self.batch_size) 182 | while True: 183 | # Collect the next batch. 
184 | if len(next_inputs) == 0: 185 | next_inputs, prev_inputs = prev_inputs, [] 186 | outputs = collections.defaultdict(list) 187 | for _ in range(self.batch_size): 188 | inputs = next_inputs.pop(0) 189 | for k, v in inputs.items(): 190 | outputs[k].extend(v) 191 | prev_inputs += self.worker_queue.get_n(1) 192 | # Stack batch data. 193 | if self.contiguous: 194 | outputs["moments"] = np.stack(outputs["moments"]) 195 | # Send batch data to consumer. 196 | self.batch_queue.put(outputs) 197 | -------------------------------------------------------------------------------- /diffnext/data/flex_pipelines.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Flex data pipelines.""" 17 | 18 | import multiprocessing 19 | 20 | import cv2 21 | import numpy.random as npr 22 | 23 | from diffnext.config import cfg 24 | from diffnext.data import flex_transforms 25 | from diffnext.data.builder import LOADERS 26 | from diffnext.data.flex_loaders import DataLoader 27 | 28 | 29 | class Worker(multiprocessing.Process): 30 | """Data worker.""" 31 | 32 | def __init__(self): 33 | super(Worker, self).__init__(daemon=True) 34 | self.seed = cfg.RNG_SEED 35 | self.reader_queue = None 36 | self.worker_queue = None 37 | 38 | def run(self): 39 | """Run implementation.""" 40 | # Disable opencv threading. 41 | cv2.setNumThreads(1) 42 | # Fix numpy random seed. 43 | npr.seed(self.seed) 44 | # Main loop. 
45 | while True: 46 | outputs = self.get_outputs(self.reader_queue.get()) 47 | self.worker_queue.put(outputs) 48 | 49 | 50 | class VAETrainPipe(object): 51 | """VAE training pipeline.""" 52 | 53 | def __init__(self): 54 | super(VAETrainPipe, self).__init__() 55 | self.parse_moments = flex_transforms.ParseMoments() 56 | self.parse_annotations = flex_transforms.ParseAnnotations() 57 | 58 | def get_outputs(self, inputs): 59 | """Return the outputs.""" 60 | moments = self.parse_moments(inputs) 61 | label, caption = self.parse_annotations(inputs) 62 | aspect_ratio = float(moments.shape[-2]) / float(moments.shape[-1]) 63 | outputs = {"moments": [moments], "aspect_ratio": [aspect_ratio]} 64 | outputs.setdefault("prompt", [label]) if label is not None else None 65 | outputs.setdefault("prompt", [caption]) if caption is not None else None 66 | outputs.setdefault("motion_flow", [inputs["flow"]]) if "flow" in inputs else None 67 | return outputs 68 | 69 | 70 | class VAETrainWorker(VAETrainPipe, Worker): 71 | """VAE training worker.""" 72 | 73 | 74 | LOADERS.register("vae_train", DataLoader, worker=VAETrainWorker) 75 | -------------------------------------------------------------------------------- /diffnext/data/flex_transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Flex data transforms.""" 17 | 18 | import re 19 | 20 | import numpy as np 21 | import numpy.random as npr 22 | 23 | 24 | class Transform(object): 25 | """Base transform type.""" 26 | 27 | def filter_outputs(self, *outputs): 28 | outputs = [x for x in outputs if x is not None] 29 | return outputs if len(outputs) > 1 else outputs[0] 30 | 31 | 32 | class ParseMoments(Transform): 33 | """Parse VAE moments.""" 34 | 35 | def __init__(self): 36 | super(ParseMoments, self).__init__() 37 | 38 | def __call__(self, inputs): 39 | return np.frombuffer(inputs["moments"], "float16").reshape(inputs["shape"]) 40 | 41 | 42 | class ParseAnnotations(Transform): 43 | """Parse ground-truth annotations.""" 44 | 45 | def __init__(self, short_prob=0.5): 46 | super(ParseAnnotations, self).__init__() 47 | self.short_prob = short_prob 48 | 49 | def __call__(self, inputs): 50 | text = inputs.get("text", None) 51 | label = inputs.get("label", None) 52 | caption = inputs.get("caption", None) 53 | if caption and isinstance(caption, dict): # Cached. 54 | caption = np.frombuffer(caption["data"], "float16").reshape(caption["shape"]) 55 | if text and isinstance(text, dict) and len(text["data"]) > 0 and npr.rand() < 0.5: 56 | caption = np.frombuffer(text["data"], "float16").reshape(text["shape"]) 57 | return label, caption 58 | 59 | # Improved short caption. 
60 | if label is None: 61 | text_match = re.match(r"^(.*?[.!?])\s+", caption) 62 | text = text if text else (text_match.group(1) if text_match else caption) 63 | caption = text if text and npr.rand() < self.short_prob else caption 64 | return label, caption 65 | -------------------------------------------------------------------------------- /diffnext/data/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # -------------------------------------------------------------------------- 16 | """Data utilities.""" 17 | 18 | import os 19 | import json 20 | 21 | 22 | def get_dataset_size(source): 23 | """Return the dataset size.""" 24 | if source.endswith(".json"): 25 | return len(json.load(open(source, "r", encoding="utf-8"))) 26 | if source.endswith(".txt"): 27 | return len(open(source, "r").readlines()) 28 | meta_file = os.path.join(source, "METADATA") 29 | if os.path.exists(meta_file): 30 | with open(meta_file, "r") as f: 31 | return json.load(f)["entries"] 32 | raise ValueError("Unsupported dataset: " + source) 33 | -------------------------------------------------------------------------------- /diffnext/engine/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ------------------------------------------------------------------------ 16 | """Engine components.""" 17 | 18 | from diffnext.engine.builder import build_lr_scheduler 19 | from diffnext.engine.builder import build_model_ema 20 | from diffnext.engine.builder import build_optimizer 21 | from diffnext.engine.builder import build_tensorboard 22 | from diffnext.engine.coordinator import Coordinator 23 | from diffnext.engine.train_engine import run_train 24 | from diffnext.engine.utils import apply_ddp 25 | from diffnext.engine.utils import apply_deepspeed 26 | from diffnext.engine.utils import count_params 27 | from diffnext.engine.utils import create_ddp_group 28 | from diffnext.engine.utils import freeze_module 29 | from diffnext.engine.utils import get_ddp_group 30 | from diffnext.engine.utils import get_ddp_rank 31 | from diffnext.engine.utils import get_device 32 | from diffnext.engine.utils import get_param_groups 33 | from diffnext.engine.utils import load_weights 34 | from diffnext.engine.utils import manual_seed 35 | -------------------------------------------------------------------------------- /diffnext/engine/builder.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # -------------------------------------------------------------------------- 16 | """Engine builders.""" 17 | 18 | import torch 19 | 20 | from diffnext.config import cfg 21 | from diffnext.engine import lr_scheduler 22 | from diffnext.engine import model_ema 23 | 24 | 25 | def build_optimizer(params, **kwargs): 26 | """Build the optimizer.""" 27 | args = {"lr": cfg.SOLVER.BASE_LR, "weight_decay": cfg.SOLVER.WEIGHT_DECAY} 28 | optimizer = kwargs.pop("optimizer", cfg.SOLVER.OPTIMIZER) 29 | args.update(kwargs) 30 | args.setdefault("betas", (0.9, cfg.SOLVER.ADAM_BETA2)) if "Adam" in optimizer else None 31 | return getattr(torch.optim, optimizer)(params, **args) 32 | 33 | 34 | def build_model_ema(model, decay=0): 35 | """Build the EMA model.""" 36 | return model_ema.ModelEMA(model, decay) if decay else None 37 | 38 | 39 | def build_lr_scheduler(**kwargs): 40 | """Build the LR scheduler.""" 41 | args = { 42 | "lr_max": cfg.SOLVER.BASE_LR, 43 | "lr_min": cfg.SOLVER.MIN_LR, 44 | "warmup_steps": cfg.SOLVER.WARM_UP_STEPS, 45 | "warmup_factor": cfg.SOLVER.WARM_UP_FACTOR, 46 | "max_steps": cfg.SOLVER.MAX_STEPS, 47 | } 48 | policy = kwargs.pop("policy", cfg.SOLVER.LR_POLICY) 49 | args.update(kwargs) 50 | if policy == "steps_with_decay": 51 | args["decay_steps"] = cfg.SOLVER.DECAY_STEPS 52 | args["decay_gamma"] = cfg.SOLVER.DECAY_GAMMA 53 | return lr_scheduler.MultiStepLR(**args) 54 | elif policy == "cosine_decay": 55 | return lr_scheduler.CosineLR(**args) 56 | return lr_scheduler.ConstantLR(**args) 57 | 58 | 59 | def build_tensorboard(log_dir): 60 | """Build the tensorboard.""" 61 | from diffnext.utils.tensorboard import TensorBoard 62 | 63 | if TensorBoard.is_available(): 64 | return TensorBoard(log_dir) 65 | return None 66 | -------------------------------------------------------------------------------- /diffnext/engine/coordinator.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # -------------------------------------------------------------------------- 16 | """Experiment coordinator.""" 17 | 18 | import os 19 | import os.path as osp 20 | import time 21 | 22 | import numpy as np 23 | 24 | from diffnext.config import cfg 25 | from diffnext.utils import logging 26 | 27 | 28 | class Coordinator(object): 29 | """Manage the unique experiments.""" 30 | 31 | def __init__(self, cfg_file, exp_dir=None): 32 | cfg.merge_from_file(cfg_file) 33 | if logging.is_root(): 34 | if exp_dir is None: 35 | name = time.strftime("%Y%m%d_%H%M%S", time.localtime(time.time())) 36 | exp_dir = "../experiments/{}".format(name) 37 | if not osp.exists(exp_dir): 38 | os.makedirs(exp_dir, exist_ok=True) 39 | else: 40 | if not osp.exists(exp_dir): 41 | os.makedirs(exp_dir, exist_ok=True) 42 | self.exp_dir = exp_dir 43 | self.deepspeed = None 44 | 45 | def path_at(self, file, auto_create=True): 46 | try: 47 | path = osp.abspath(osp.join(self.exp_dir, file)) 48 | if auto_create and not osp.exists(path): 49 | os.makedirs(path) 50 | except OSError: 51 | path = osp.abspath(osp.join("/tmp", file)) 52 | if auto_create and not osp.exists(path): 53 | os.makedirs(path) 54 | return path 55 | 56 | def get_checkpoint(self, step=None, last_idx=1, wait=False): 57 | path = self.path_at("checkpoints") 58 | 59 | def locate(last_idx=None): 60 | files = os.listdir(path) 61 | files = list(filter(lambda x: "_iter_" in x, files)) 62 | file_steps = [] 63 | for i, file in enumerate(files): 64 | file_step = int(file.split("_iter_")[-1].split(".")[0]) 65 | if step == file_step: 66 | return osp.join(path, files[i]), file_step 67 | file_steps.append(file_step) 68 | if step is None: 69 | if len(files) == 0: 70 | return None, 0 71 | if last_idx > len(files): 72 | return None, 0 73 | file = files[np.argsort(file_steps)[-last_idx]] 74 | file_step = file_steps[np.argsort(file_steps)[-last_idx]] 75 | return osp.join(path, file), file_step 76 | return None, 0 77 | 78 | file, file_step = locate(last_idx) 79 | while file is None and wait: 80 | logging.info("Wait for checkpoint at {}.".format(step)) 81 | time.sleep(10) 82 | file, file_step = locate(last_idx) 83 | return file, file_step 84 | -------------------------------------------------------------------------------- /diffnext/engine/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ------------------------------------------------------------------------ 16 | """Learning rate schedulers.""" 17 | 18 | import math 19 | 20 | 21 | class ConstantLR(object): 22 | """Constant LR scheduler.""" 23 | 24 | def __init__(self, **kwargs): 25 | self._lr_max = kwargs.pop("lr_max") 26 | self._lr_min = kwargs.pop("lr_min", 0) 27 | self._warmup_steps = kwargs.pop("warmup_steps", 0) 28 | self._warmup_factor = kwargs.pop("warmup_factor", 0) 29 | self._step_count = 0 30 | self._last_decay = 1.0 31 | 32 | def step(self): 33 | self._step_count += 1 34 | 35 | def get_lr(self): 36 | if self._step_count < self._warmup_steps: 37 | alpha = (self._step_count + 1.0) / self._warmup_steps 38 | return self._lr_max * (alpha + (1.0 - alpha) * self._warmup_factor) 39 | return self._lr_min + (self._lr_max - self._lr_min) * self.get_decay() 40 | 41 | def get_decay(self): 42 | return self._last_decay 43 | 44 | 45 | class CosineLR(ConstantLR): 46 | """LR scheduler with cosine decay.""" 47 | 48 | def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs): 49 | super(CosineLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs) 50 | self._decay_step = decay_step 51 | self._max_steps = max_steps 52 | 53 | def get_decay(self): 54 | t = self._step_count - self._warmup_steps 55 | t_max = self._max_steps - self._warmup_steps 56 | if t > 0 and t % self._decay_step == 0: 57 | self._last_decay = 0.5 * (1.0 + math.cos(math.pi * t / t_max)) 58 | return self._last_decay 59 | 60 | 61 | class MultiStepLR(ConstantLR): 62 | """LR scheduler with multi-steps decay.""" 63 | 64 | def __init__(self, lr_max, decay_steps, decay_gamma, **kwargs): 65 | super(MultiStepLR, self).__init__(lr_max=lr_max, **kwargs) 66 | self._decay_steps = decay_steps 67 | self._decay_gamma = decay_gamma 68 | self._stage_count = 0 69 | self._num_stages = len(decay_steps) 70 | 71 | def get_decay(self): 72 | if self._stage_count < self._num_stages: 73 | k = self._decay_steps[self._stage_count] 74 | while self._step_count >= k: 75 | self._stage_count += 1 76 | if self._stage_count >= self._num_stages: 77 | break 78 | k = self._decay_steps[self._stage_count] 79 | self._last_decay = self._decay_gamma**self._stage_count 80 | return self._last_decay 81 | -------------------------------------------------------------------------------- /diffnext/engine/model_ema.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # -------------------------------------------------------------------------- 16 | """Exponential Moving Average (EMA) of model updates.""" 17 | 18 | import copy 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class ModelEMA(nn.Module): 25 | """Model Exponential Moving Average.""" 26 | 27 | def __init__(self, model, decay=0.9999): 28 | super(ModelEMA, self).__init__() 29 | self.decay = decay 30 | self.ema = copy.deepcopy(model).eval() 31 | self.ema._apply(lambda t: t.float() if t.requires_grad else t) # FP32. 32 | [setattr(p, "requires_grad", False) for p in self.ema.parameters()] 33 | 34 | @torch.no_grad() 35 | def update(self, model): 36 | for ema_v, model_v in zip(self.ema.parameters(), model.parameters()): 37 | if model_v.requires_grad: 38 | new_value = model_v.data.float() 39 | value = ema_v.to(device=new_value.device) 40 | ema_v.copy_(value.mul_(self.decay).add_(new_value, alpha=1 - self.decay)) 41 | -------------------------------------------------------------------------------- /diffnext/engine/train_engine.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | # ------------------------------------------------------------------------ 16 | """Custom deepspeed trainer focused on data parallelism specialization.""" 17 | 18 | import collections 19 | import os 20 | import shutil 21 | 22 | import torch 23 | 24 | from diffnext import engine 25 | from diffnext.config import cfg 26 | from diffnext.data.builder import build_loader_train 27 | from diffnext.pipelines.builder import build_pipeline, get_pipeline_path 28 | from diffnext.utils import logging 29 | from diffnext.utils import profiler 30 | 31 | 32 | class Trainer(object): 33 | """Schedule the iterative model training.""" 34 | 35 | def __init__(self, coordinator, start_iter=0): 36 | self.coordinator = coordinator 37 | self.loader = build_loader_train() 38 | self.precision = cfg.MODEL.PRECISION.lower() 39 | self.dtype = getattr(torch, self.precision) 40 | self.device_type = engine.get_device(0).type 41 | pipe_conf = {cfg.MODEL.TYPE: cfg.MODEL.CONFIG} if cfg.MODEL.CONFIG else None 42 | pipe_path = get_pipeline_path(cfg.MODEL.WEIGHTS, cfg.PIPELINE.MODULES or None, pipe_conf) 43 | self.pipe = build_pipeline(pipe_path, config=cfg) 44 | self.pipe = self.pipe.to(device=engine.get_device(cfg.GPU_ID)) 45 | self.ema_model = engine.build_model_ema(self.pipe.model, cfg.TRAIN.MODEL_EMA) 46 | self.ema_model.ema.cpu() if cfg.TRAIN.DEVICE_EMA.lower() == "cpu" else None 47 | self.model = self.pipe.configure_model(config=cfg) 48 | self.autocast = torch.autocast(self.device_type, self.dtype) 49 | param_groups = engine.get_param_groups(self.model) 50 | self.optimizer = engine.build_optimizer(param_groups) 51 | self.loss_scaler = torch.amp.GradScaler("cuda", enabled=self.precision == "float16") 52 | self.ds_model = engine.apply_deepspeed(self.model, self.optimizer, coordinator.deepspeed) 53 | self.ddp_model = engine.apply_ddp(self.model.float()) if self.ds_model is None else None 54 | self.scheduler = engine.build_lr_scheduler() 55 | self.metrics, self.board = collections.OrderedDict(), None 56 | if self.ema_model and start_iter > 0: 57 | ema_weights = cfg.MODEL.WEIGHTS.replace("checkpoints", "ema_checkpoints") 58 | ema_weights += "/%s/diffusion_pytorch_model.bin" % cfg.MODEL.TYPE 59 | engine.load_weights(self.ema_model.ema, ema_weights) 60 | 61 | @property 62 | def iter(self): 63 | return self.scheduler._step_count 64 | 65 | def snapshot(self): 66 | """Save the checkpoint of current iterative step.""" 67 | f = cfg.SOLVER.SNAPSHOT_PREFIX + "_iter_{}/{}".format(self.iter, cfg.MODEL.TYPE) 68 | f = os.path.join(self.coordinator.path_at("checkpoints"), f) 69 | if logging.is_root() and not os.path.exists(f): 70 | self.model.save_pretrained(f, safe_serialization=False) 71 | logging.info("Wrote snapshot to: {:s}".format(f)) 72 | if self.ema_model is not None: 73 | config_json = os.path.join(f, "config.json") 74 | f = f.replace("checkpoints", "ema_checkpoints") 75 | os.makedirs(f), shutil.copy(config_json, os.path.join(f, "config.json")) 76 | f = os.path.join(f, "diffusion_pytorch_model.bin") 77 | torch.save(self.ema_model.ema.state_dict(), f) 78 | 79 | def add_metrics(self, stats): 80 | """Add or update the metrics.""" 81 | for k, v in stats["metrics"].items(): 82 | if k not in self.metrics: 83 | self.metrics[k] = profiler.SmoothedValue() 84 | self.metrics[k].update(v) 85 | 86 | def display_metrics(self, stats): 87 | """Send metrics to the monitor.""" 88 | iter_template = "Iteration %d, lr = %.8f, time = %.2fs" 89 | metric_template = " " * 4 + "Train net output({}): {:.4f} ({:.4f})" 90 | logging.info(iter_template % 
(stats["iter"], stats["lr"], stats["time"])) 91 | for k, v in self.metrics.items(): 92 | logging.info(metric_template.format(k, stats["metrics"][k], v.average())) 93 | if self.board is not None: 94 | self.board.scalar_summary("lr", stats["lr"], stats["iter"]) 95 | self.board.scalar_summary("time", stats["time"], stats["iter"]) 96 | for k, v in self.metrics.items(): 97 | self.board.scalar_summary(k, v.average(), stats["iter"]) 98 | 99 | def step_ddp(self, metrics, accum_steps=1): 100 | """Single DDP optimization step.""" 101 | self.optimizer.zero_grad() 102 | for _ in range(accum_steps): 103 | inputs, _ = self.loader.next()[0], self.autocast.__enter__() 104 | outputs, losses, _ = self.ddp_model(inputs), [], self.autocast.__exit__(0, 0, 0) 105 | for k, v in outputs.items(): 106 | if "loss" not in k: 107 | continue 108 | if isinstance(v, torch.Tensor) and v.requires_grad: 109 | losses.append(v) 110 | metrics[k] += float(v) / accum_steps 111 | losses = sum(losses[1:], losses[0]) 112 | losses = losses.mul_(1.0 / accum_steps) if accum_steps > 1 else losses 113 | self.loss_scaler.scale(losses).backward() 114 | if self.loss_scaler.is_enabled(): 115 | metrics["~loss_scale"] += self.loss_scaler.get_scale() 116 | self.loss_scaler.step(self.optimizer) 117 | self.loss_scaler.update() 118 | 119 | def step_ds(self, metrics, accum_steps=1): 120 | """Single DeepSpeed optimization step.""" 121 | for _ in range(accum_steps): 122 | inputs = self.loader.next()[0] 123 | outputs, losses = self.ds_model(inputs), [] 124 | for k, v in outputs.items(): 125 | if "loss" not in k: 126 | continue 127 | if isinstance(v, torch.Tensor) and v.requires_grad: 128 | losses.append(v) 129 | metrics[k] += float(v) / accum_steps 130 | losses = sum(losses[1:], losses[0]) 131 | losses = losses.mul_(1.0 / accum_steps) if accum_steps > 1 else losses 132 | self.ds_model.backward(losses) 133 | if self.loss_scaler.is_enabled(): 134 | metrics["~loss_scale"] += float(self.ds_model.optimizer._get_loss_scale()) 135 | self.ds_model.step() 136 | 137 | def step(self, accum_steps=1): 138 | """Single model optimization step.""" 139 | stats = {"iter": self.iter} 140 | metrics = collections.defaultdict(float) 141 | timer = profiler.Timer().tic() 142 | stats["lr"] = self.scheduler.get_lr() 143 | for group in self.optimizer.param_groups: 144 | group["lr"] = stats["lr"] * group.get("lr_scale", 1.0) 145 | self.step_ds(metrics, accum_steps) if self.ds_model else None 146 | self.step_ddp(metrics, accum_steps) if self.ddp_model else None 147 | self.scheduler.step() 148 | stats["time"] = timer.toc() 149 | stats["metrics"] = collections.OrderedDict(sorted(metrics.items())) 150 | return stats 151 | 152 | def train_model(self, start_iter=0): 153 | """Training loop.""" 154 | timer = profiler.Timer() 155 | max_steps = cfg.SOLVER.MAX_STEPS 156 | accum_steps = cfg.SOLVER.ACCUM_STEPS 157 | display_every = cfg.SOLVER.DISPLAY 158 | progress_every = 10 * display_every 159 | ema_every = cfg.SOLVER.EMA_EVERY 160 | snapshot_every = cfg.SOLVER.SNAPSHOT_EVERY 161 | self.scheduler._step_count = start_iter 162 | while self.iter < max_steps: 163 | with timer.tic_and_toc(): 164 | stats = self.step(accum_steps) 165 | self.add_metrics(stats) 166 | if stats["iter"] % display_every == 0: 167 | self.display_metrics(stats) 168 | if self.iter % progress_every == 0: 169 | logging.info(profiler.get_progress(timer, self.iter, max_steps)) 170 | if self.iter % ema_every == 0 and self.ema_model: 171 | self.ema_model.update(self.model) 172 | if self.iter % snapshot_every == 0: 173 | 
self.snapshot() 174 | self.metrics.clear() 175 | 176 | 177 | def run_train(coordinator, start_iter=0, enable_tensorboard=False): 178 | """Start a model training task.""" 179 | trainer = Trainer(coordinator, start_iter=start_iter) 180 | if enable_tensorboard and logging.is_root(): 181 | trainer.board = engine.build_tensorboard(coordinator.path_at("logs")) 182 | logging.info("#Params: %.2fM" % engine.count_params(trainer.model)) 183 | logging.info("Start training...") 184 | trainer.train_model(start_iter) 185 | trainer.ema_model.update(trainer.model) if trainer.ema_model else None 186 | trainer.snapshot() 187 | -------------------------------------------------------------------------------- /diffnext/engine/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # -------------------------------------------------------------------------- 16 | """Engine utilities.""" 17 | 18 | import collections 19 | import pickle 20 | 21 | import numpy as np 22 | import torch 23 | from torch import nn 24 | 25 | from diffnext.utils import logging 26 | 27 | GLOBAL_DDP_GROUP = None 28 | 29 | 30 | def count_params(module, trainable=True, unit="M"): 31 | """Return the number of parameters.""" 32 | counts = [v.size().numel() for v in module.parameters() if v.requires_grad or (not trainable)] 33 | return sum(counts) / {"M": 1e6, "B": 1e9}[unit] 34 | 35 | 36 | def freeze_module(module, trainable=False): 37 | """Freeze parameters of given module.""" 38 | module.eval() if not trainable else module.train() 39 | for param in module.parameters(): 40 | param.requires_grad = trainable 41 | return module 42 | 43 | 44 | def get_device(index): 45 | """Create the available device object.""" 46 | if torch.cuda.is_available(): 47 | return torch.device("cuda", index) 48 | for device_type in ("mps",): 49 | try: 50 | if getattr(torch.backends, device_type).is_available(): 51 | return torch.device(device_type, index) 52 | except AttributeError: 53 | pass 54 | return torch.device("cpu") 55 | 56 | 57 | def get_param_groups(model): 58 | """Separate parameters into groups.""" 59 | memo, groups, lr_scale_getter = set(), collections.OrderedDict(), None 60 | norm_types = (nn.BatchNorm2d, nn.GroupNorm, nn.SyncBatchNorm, nn.LayerNorm) 61 | for module_name, module in model.named_modules(): 62 | for param_name, param in module.named_parameters(recurse=False): 63 | if not param.requires_grad or param in memo: 64 | continue 65 | memo.add(param) 66 | attrs = collections.OrderedDict() 67 | if lr_scale_getter: 68 | attrs["lr_scale"] = lr_scale_getter(f"{module_name}.{param_name}") 69 | if hasattr(param, "lr_scale"): 70 | attrs["lr_scale"] = param.lr_scale 71 | if getattr(param, "no_weight_decay", False) or isinstance(module, norm_types): 72 | attrs["weight_decay"] = 0 73 | group_name =
"/".join(["%s:%s" % (v[0], v[1]) for v in list(attrs.items())]) 74 | groups[group_name] = groups.get(group_name, {**attrs, **{"params": []}}) 75 | groups[group_name]["params"].append(param) 76 | return list(groups.values()) 77 | 78 | 79 | def load_weights(module, weights_file, prefix_removed="", strict=True): 80 | """Load a weights file.""" 81 | if not weights_file: 82 | return 83 | if weights_file.endswith(".pkl"): 84 | with open(weights_file, "rb") as f: 85 | state_dict = pickle.load(f) 86 | for k, v in state_dict.items(): 87 | state_dict[k] = torch.as_tensor(v) 88 | else: 89 | state_dict = torch.load(weights_file, map_location="cpu", weights_only=False) 90 | if prefix_removed: 91 | new_state_dict = type(state_dict)() 92 | for k in list(state_dict.keys()): 93 | if k.startswith(prefix_removed): 94 | new_state_dict[k.replace(prefix_removed, "")] = state_dict.pop(k) 95 | state_dict = new_state_dict 96 | module.load_state_dict(state_dict, strict=strict) 97 | 98 | 99 | def manual_seed(seed, device_and_seed=None): 100 | """Set the cpu and device random seed.""" 101 | torch.manual_seed(seed) 102 | if device_and_seed is not None: 103 | device_index, device_seed = device_and_seed 104 | device_type = get_device(device_index).type 105 | np.random.seed(device_seed) 106 | if device_type in ("cuda", "mps"): 107 | getattr(torch, device_type).manual_seed(device_seed) 108 | 109 | 110 | def synchronize_device(device): 111 | """Synchronize the computation of device.""" 112 | if device.type in ("cuda", "mps"): 113 | getattr(torch, device.type).synchronize(device) 114 | 115 | 116 | def create_ddp_group(cfg, ranks=None, devices=None): 117 | """Create group for data parallelism.""" 118 | if not torch.distributed.is_initialized(): 119 | torch.distributed.init_process_group(backend="nccl") 120 | world_rank = torch.distributed.get_rank() 121 | ranks = ranks if ranks else [i for i in range(cfg.NUM_GPUS)] 122 | logging.set_root(world_rank == ranks[0]) 123 | devices = devices if devices else [i % 8 for i in range(len(ranks))] 124 | cfg.GPU_ID = devices[world_rank] 125 | torch.cuda.set_device(cfg.GPU_ID) 126 | global GLOBAL_DDP_GROUP 127 | GLOBAL_DDP_GROUP = torch.distributed.new_group(ranks) 128 | return GLOBAL_DDP_GROUP 129 | 130 | 131 | def get_ddp_group(): 132 | """Return the process group for data parallelism.""" 133 | return GLOBAL_DDP_GROUP 134 | 135 | 136 | def get_ddp_rank(): 137 | """Return the rank in the data parallelism group.""" 138 | ddp_group = get_ddp_group() 139 | if ddp_group is None: 140 | return 0 141 | return torch.distributed.get_rank(ddp_group) 142 | 143 | 144 | def apply_ddp(model): 145 | """Apply distributed data parallelism for given module.""" 146 | ddp_group = get_ddp_group() 147 | if ddp_group is None: 148 | return model 149 | return nn.parallel.DistributedDataParallel(model, process_group=ddp_group) 150 | 151 | 152 | def apply_deepspeed(model, optimizer, ds_config=None, log_lvl="WARNING"): 153 | """Apply deepspeed parallelism for given module.""" 154 | if not ds_config: 155 | return None 156 | import deepspeed 157 | 158 | deepspeed.logger.setLevel(log_lvl) 159 | ds_model = deepspeed.initialize(None, model, optimizer, config=ds_config)[0] 160 | return ds_model 161 | -------------------------------------------------------------------------------- /diffnext/image_processor.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. 
All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Image processor.""" 17 | 18 | from typing import List, Union 19 | 20 | import numpy as np 21 | import PIL.Image 22 | import torch 23 | from torch import nn 24 | 25 | from diffusers.configuration_utils import ConfigMixin 26 | 27 | 28 | class VaeImageProcessor(ConfigMixin): 29 | """Image processor for VAE.""" 30 | 31 | def postprocess( 32 | self, image: torch.Tensor, output_type: str = "pil" 33 | ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: 34 | """Postprocess the image output from tensor. 35 | 36 | Args: 37 | image (torch.Tensor): 38 | The image tensor. 39 | output_type (str, *optional*, defaults to `pil`): 40 | The output image type, can be one of `pil`, `np`, `pt`, `latent`. 41 | 42 | Returns: 43 | Union[PIL.Image.Image, np.ndarray, torch.Tensor]: The postprocessed image. 44 | """ 45 | if output_type == "latent" or output_type == "pt": 46 | return image 47 | image = self.pt_to_numpy(image) 48 | if output_type == "np": 49 | return image 50 | if output_type == "pil" and len(image.shape) == 4: 51 | return self.numpy_to_pil(image) 52 | return image 53 | 54 | @staticmethod 55 | @torch.no_grad() 56 | def decode_latents(vae: nn.Module, latents: torch.Tensor, vae_batch_size=1) -> torch.Tensor: 57 | """Decode VAE latents. 58 | 59 | Args: 60 | vae (torch.nn.Module): 61 | The VAE model. 62 | latents (torch.Tensor): 63 | The input latents. 64 | vae_batch_size (int, *optional*, defaults to 1): 65 | The maximum images in a batch to decode. 66 | 67 | Returns: 68 | torch.Tensor: The output tensor. 69 | 70 | """ 71 | x, batch_size = vae.unscale_(latents), latents.size(0) 72 | sizes, splits = [vae_batch_size] * (batch_size // vae_batch_size), [] 73 | sizes += [batch_size - sum(sizes)] if sum(sizes) != batch_size else [] 74 | for x_split in x.split(sizes) if len(sizes) > 1 else [x]: 75 | splits.append(vae.decode(x_split).sample) 76 | return torch.cat(splits) if len(splits) > 1 else splits[0] 77 | 78 | @staticmethod 79 | def pt_to_numpy(images: torch.Tensor) -> np.ndarray: 80 | """Convert images from a torch tensor to a numpy array. 81 | 82 | Args: 83 | images (torch.Tensor): 84 | The image tensor. 85 | 86 | Returns: 87 | np.ndarray: The image array. 88 | """ 89 | x = images.permute(0, 2, 3, 4, 1) if images.dim() == 5 else images.permute(0, 2, 3, 1) 90 | return x.mul(127.5).add_(127.5).clamp(0, 255).byte().cpu().numpy() 91 | 92 | @staticmethod 93 | def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]: 94 | """Convert images from a numpy array to a list of PIL objects. 95 | 96 | Args: 97 | images (np.ndarray): 98 | The image array. 99 | 100 | Returns: 101 | List[PIL.Image.Image]: A list of PIL images. 102 | """ 103 | images = images[None, ...]
if images.ndim == 3 else images 104 | return [PIL.Image.fromarray(image) for image in images] 105 | -------------------------------------------------------------------------------- /diffnext/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Models.""" 17 | -------------------------------------------------------------------------------- /diffnext/models/autoencoders/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | -------------------------------------------------------------------------------- /diffnext/models/autoencoders/autoencoder_kl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
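# NOTE: an illustrative usage sketch for the AutoencoderKL class defined below; it is not part
# of the original module. Shapes assume the default config (three downsampling stages, i.e. 8x
# spatial reduction) and 16 latent channels.
import torch

vae = AutoencoderKL(latent_channels=16).eval()
with torch.no_grad():
    images = torch.randn(1, 3, 256, 256)              # [B, C, H, W] pixels in [-1, 1]
    posterior = vae.encode(images).latent_dist        # DiagonalGaussianDistribution
    latents = vae.scale_(posterior.sample())          # [1, 16, 32, 32], scaled for diffusion
    recon = vae.decode(vae.unscale_(latents)).sample  # back to [1, 3, 256, 256]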
14 | ############################################################################## 15 | """Simple implementation of AutoEncoderKL.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | from diffusers.configuration_utils import ConfigMixin, register_to_config 21 | from diffusers.models.modeling_outputs import AutoencoderKLOutput 22 | from diffusers.models.modeling_utils import ModelMixin 23 | 24 | from diffnext.models.autoencoders.modeling_utils import DecoderOutput 25 | from diffnext.models.autoencoders.modeling_utils import DiagonalGaussianDistribution 26 | from diffnext.models.autoencoders.modeling_utils import IdentityDistribution 27 | 28 | 29 | class Attention(nn.Module): 30 | """Multi-headed attention.""" 31 | 32 | def __init__(self, dim, num_heads=1): 33 | super(Attention, self).__init__() 34 | self.num_heads = num_heads or dim // 64 35 | self.head_dim = dim // self.num_heads 36 | self.group_norm = nn.GroupNorm(32, dim, eps=1e-6) 37 | self.to_q, self.to_k, self.to_v = [nn.Linear(dim, dim) for _ in range(3)] 38 | self.to_out = nn.ModuleList([nn.Linear(dim, dim)]) 39 | self._from_deprecated_attn_block = True # Fix for diffusers>=0.15.0 40 | 41 | def forward(self, x) -> torch.Tensor: 42 | x, shape = self.group_norm(x), (-1,) + x.shape[1:] 43 | x = x.flatten(2).transpose(1, 2).contiguous() 44 | qkv_shape = (-1, x.size(1), self.num_heads, self.head_dim) 45 | q, k, v = [f(x).view(qkv_shape).transpose(1, 2) for f in (self.to_q, self.to_k, self.to_v)] 46 | o = nn.functional.scaled_dot_product_attention(q, k, v).transpose(1, 2) 47 | return self.to_out[0](o.flatten(2)).transpose(1, 2).reshape(shape) 48 | 49 | 50 | class Resize(nn.Module): 51 | """Resize layer.""" 52 | 53 | def __init__(self, dim, downsample=1): 54 | super(Resize, self).__init__() 55 | self.conv = nn.Conv2d(dim, dim, 3, 2, 0) if downsample else None 56 | self.conv = nn.Conv2d(dim, dim, 3, 1, 1) if not downsample else self.conv 57 | self.downsample, self.upsample = downsample, int(not downsample) 58 | 59 | def forward(self, x) -> torch.Tensor: 60 | x = nn.functional.pad(x, (0, 1, 0, 1)) if self.downsample else x 61 | return self.conv(nn.functional.interpolate(x, None, (2, 2)) if self.upsample else x) 62 | 63 | 64 | class ResBlock(nn.Module): 65 | """Resnet block.""" 66 | 67 | def __init__(self, dim, out_dim): 68 | super(ResBlock, self).__init__() 69 | self.norm1 = nn.GroupNorm(32, dim, eps=1e-6) 70 | self.conv1 = nn.Conv2d(dim, out_dim, 3, 1, 1) 71 | self.norm2 = nn.GroupNorm(32, out_dim, eps=1e-6) 72 | self.conv2 = nn.Conv2d(out_dim, out_dim, 3, 1, 1) 73 | self.conv_shortcut = nn.Conv2d(dim, out_dim, 1) if out_dim != dim else None 74 | self.nonlinearity = nn.SiLU() 75 | 76 | def forward(self, x) -> torch.Tensor: 77 | shortcut = self.conv_shortcut(x) if self.conv_shortcut else x 78 | x = self.conv1(self.nonlinearity(self.norm1(x))) 79 | return self.conv2(self.nonlinearity(self.norm2(x))).add_(shortcut) 80 | 81 | 82 | class UNetResBlock(nn.Module): 83 | """UNet resnet block.""" 84 | 85 | def __init__(self, dim, out_dim, depth=2, downsample=0, upsample=0): 86 | super(UNetResBlock, self).__init__() 87 | block_dims = [(out_dim, out_dim) if i > 0 else (dim, out_dim) for i in range(depth)] 88 | self.resnets = nn.ModuleList(ResBlock(*dims) for dims in block_dims) 89 | self.downsamplers = nn.ModuleList([Resize(out_dim, 1)]) if downsample else [] 90 | self.upsamplers = nn.ModuleList([Resize(out_dim, 0)]) if upsample else [] 91 | 92 | def forward(self, x) -> torch.Tensor: 93 | for resnet in self.resnets: 94 | x = resnet(x) 95 | x 
= self.downsamplers[0](x) if self.downsamplers else x 96 | return self.upsamplers[0](x) if self.upsamplers else x 97 | 98 | 99 | class UNetMidBlock(nn.Module): 100 | """UNet mid block.""" 101 | 102 | def __init__(self, dim, num_heads=1, depth=1): 103 | super(UNetMidBlock, self).__init__() 104 | self.resnets = nn.ModuleList(ResBlock(dim, dim) for _ in range(depth + 1)) 105 | self.attentions = nn.ModuleList(Attention(dim, num_heads) for _ in range(depth)) 106 | 107 | def forward(self, x) -> torch.Tensor: 108 | x = self.resnets[0](x) 109 | for attn, resnet in zip(self.attentions, self.resnets[1:]): 110 | x = resnet(attn(x).add_(x)) 111 | return x 112 | 113 | 114 | class Encoder(nn.Module): 115 | """VAE encoder.""" 116 | 117 | def __init__(self, dim, out_dim, block_dims, block_depth=2): 118 | super(Encoder, self).__init__() 119 | self.conv_in = nn.Conv2d(dim, block_dims[0], 3, 1, 1) 120 | self.down_blocks = nn.ModuleList() 121 | for i, block_dim in enumerate(block_dims): 122 | downsample = 1 if i < (len(block_dims) - 1) else 0 123 | args = (block_dims[max(i - 1, 0)], block_dim, block_depth) 124 | self.down_blocks += [UNetResBlock(*args, downsample=downsample)] 125 | self.mid_block = UNetMidBlock(block_dims[-1]) 126 | self.conv_act = nn.SiLU() 127 | self.conv_norm_out = nn.GroupNorm(32, block_dims[-1], eps=1e-6) 128 | self.conv_out = nn.Conv2d(block_dims[-1], out_dim, 3, 1, 1) 129 | 130 | def forward(self, x) -> torch.Tensor: 131 | x = self.conv_in(x) 132 | for blk in self.down_blocks: 133 | x = blk(x) 134 | x = self.mid_block(x) 135 | return self.conv_out(self.conv_act(self.conv_norm_out(x))) 136 | 137 | 138 | class Decoder(nn.Module): 139 | """VAE decoder.""" 140 | 141 | def __init__(self, dim, out_dim, block_dims, block_depth=2): 142 | super(Decoder, self).__init__() 143 | block_dims = list(reversed(block_dims)) 144 | self.up_blocks = nn.ModuleList() 145 | for i, block_dim in enumerate(block_dims): 146 | upsample = 1 if i < (len(block_dims) - 1) else 0 147 | args = (block_dims[max(i - 1, 0)], block_dim, block_depth + 1) 148 | self.up_blocks += [UNetResBlock(*args, upsample=upsample)] 149 | self.conv_in = nn.Conv2d(dim, block_dims[0], 3, 1, 1) 150 | self.mid_block = UNetMidBlock(block_dims[0]) 151 | self.conv_act = nn.SiLU() 152 | self.conv_norm_out = nn.GroupNorm(32, block_dims[-1], eps=1e-6) 153 | self.conv_out = nn.Conv2d(block_dims[-1], out_dim, 3, 1, 1) 154 | 155 | def forward(self, x) -> torch.Tensor: 156 | x = self.conv_in(x) 157 | x = self.mid_block(x) 158 | for blk in self.up_blocks: 159 | x = blk(x) 160 | return self.conv_out(self.conv_act(self.conv_norm_out(x))) 161 | 162 | 163 | class AutoencoderKL(ModelMixin, ConfigMixin): 164 | """AutoEncoder KL.""" 165 | 166 | @register_to_config 167 | def __init__( 168 | self, 169 | in_channels=3, 170 | out_channels=3, 171 | down_block_types=("DownEncoderBlock2D",) * 4, 172 | up_block_types=("UpDecoderBlock2D",) * 4, 173 | block_out_channels=(128, 256, 512, 512), 174 | layers_per_block=2, 175 | act_fn="silu", 176 | latent_channels=16, 177 | norm_num_groups=32, 178 | sample_size=1024, 179 | scaling_factor=0.18215, 180 | shift_factor=None, 181 | latents_mean=None, 182 | latents_std=None, 183 | force_upcast=True, 184 | double_z=True, 185 | use_quant_conv=True, 186 | use_post_quant_conv=True, 187 | ): 188 | super(AutoencoderKL, self).__init__() 189 | channels, layers = block_out_channels, layers_per_block 190 | self.encoder = Encoder(in_channels, (1 + double_z) * latent_channels, channels, layers) 191 | self.decoder = Decoder(latent_channels, 
out_channels, channels, layers) 192 | quant_conv_type = type(self.decoder.conv_in) if use_quant_conv else nn.Identity 193 | post_quant_conv_type = type(self.decoder.conv_in) if use_post_quant_conv else nn.Identity 194 | self.quant_conv = quant_conv_type(*([(1 + double_z) * latent_channels] * 2 + [1])) 195 | self.post_quant_conv = post_quant_conv_type(latent_channels, latent_channels, 1) 196 | self.latent_dist = DiagonalGaussianDistribution if double_z else IdentityDistribution 197 | 198 | def scale_(self, x) -> torch.Tensor: 199 | """Scale the input latents.""" 200 | x.add_(-self.config.shift_factor) if self.config.shift_factor else None 201 | return x.mul_(self.config.scaling_factor) 202 | 203 | def unscale_(self, x) -> torch.Tensor: 204 | """Unscale the input latents.""" 205 | x.mul_(1 / self.config.scaling_factor) 206 | return x.add_(self.config.shift_factor) if self.config.shift_factor else x 207 | 208 | def encode(self, x) -> AutoencoderKLOutput: 209 | """Encode the input samples.""" 210 | z = self.quant_conv(self.encoder(self.forward(x))) 211 | posterior = self.latent_dist(z) 212 | return AutoencoderKLOutput(latent_dist=posterior) 213 | 214 | def decode(self, z) -> DecoderOutput: 215 | """Decode the input latents.""" 216 | z = z.squeeze_(2) if z.dim() == 5 else z 217 | x = self.decoder(self.post_quant_conv(self.forward(z))) 218 | return DecoderOutput(sample=x) 219 | 220 | def forward(self, x): # NOOP. 221 | return x 222 | -------------------------------------------------------------------------------- /diffnext/models/autoencoders/modeling_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
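# NOTE: an illustrative sketch, not part of the original module, of the reparameterization
# trick implemented by DiagonalGaussianDistribution below: the encoder output is split
# channel-wise into mean / log-variance halves and a latent is drawn as mean + std * eps.
# Shapes are assumptions for illustration only.
import torch

z = torch.randn(2, 32, 8, 8)                   # 2 * latent_channels from the VAE encoder
mean, logvar = z.float().chunk(2, dim=1)       # 16-channel mean and log-variance
std = logvar.clamp(-30.0, 20.0).mul(0.5).exp()
latent = mean + std * torch.randn_like(mean)   # same result as DiagonalGaussianDistribution(z).sample()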
14 | ############################################################################## 15 | """AutoEncoder utilities.""" 16 | 17 | from diffusers.models.modeling_outputs import BaseOutput 18 | import torch 19 | 20 | 21 | class DecoderOutput(BaseOutput): 22 | """Output of decoding method.""" 23 | 24 | sample: torch.Tensor 25 | 26 | 27 | class IdentityDistribution(object): 28 | """IdentityGaussianDistribution.""" 29 | 30 | def __init__(self, z): 31 | self.parameters = z 32 | 33 | def sample(self, generator=None): 34 | return self.parameters 35 | 36 | 37 | class DiagonalGaussianDistribution(object): 38 | """DiagonalGaussianDistribution.""" 39 | 40 | def __init__(self, z): 41 | self.parameters = z 42 | self.device, self.dtype = z.device, z.dtype 43 | if z.size(1) % 2: 44 | z = torch.cat([z, z[:, -1:].expand((-1, z.shape[1] - 2) + (-1,) * (z.dim() - 2))], 1) 45 | self.mean, self.logvar = z.float().chunk(2, dim=1) 46 | self.logvar = self.logvar.clamp(-30.0, 20.0) 47 | self.std, self.var = self.logvar.mul(0.5).exp_(), self.logvar.exp() 48 | 49 | def sample(self, generator=None) -> torch.Tensor: 50 | """Sample a latent from distribution.""" 51 | device, dtype = self.mean.device, self.mean.dtype 52 | norm_dist = torch.randn(self.mean.shape, generator=generator, device=device, dtype=dtype) 53 | return norm_dist.mul_(self.std).add_(self.mean).to(device=self.device, dtype=self.dtype) 54 | 55 | 56 | class TilingMixin(object): 57 | """Base class for input tiling.""" 58 | 59 | def __init__(self, sample_min_t=17, latent_min_t=5, sample_ovr_t=1, latent_ovr_t=0): 60 | self.sample_min_t, self.latent_min_t = sample_min_t, latent_min_t 61 | self.sample_ovr_t, self.latent_ovr_t = sample_ovr_t, latent_ovr_t 62 | 63 | def tiled_encoder(self, x) -> torch.Tensor: 64 | """Encode tiled samples.""" 65 | if x.dim() == 4 or x.size(2) <= self.sample_min_t: 66 | return self.encoder(x) 67 | t = x.shape[2] 68 | t_start = [i for i in range(0, t, self.sample_min_t - self.sample_ovr_t)] 69 | t_slice = [slice(i, i + self.sample_min_t) for i in t_start] 70 | t_tiles = [self.encoder(x[:, :, s]) for s in t_slice if s.stop <= t] 71 | t_tiles = [x[:, :, self.latent_ovr_t :] if i else x for i, x in enumerate(t_tiles)] 72 | return torch.cat(t_tiles, dim=2) 73 | 74 | def tiled_decoder(self, x) -> torch.Tensor: 75 | """Decode tiled latents.""" 76 | if x.dim() == 4 or x.size(2) <= self.latent_min_t: 77 | return self.decoder(x) 78 | t = x.shape[2] 79 | t_start = [i for i in range(0, t, self.latent_min_t - self.latent_ovr_t)] 80 | t_slice = [slice(i, i + self.latent_min_t) for i in t_start] 81 | t_tiles = [self.decoder(x[:, :, s]) for s in t_slice if s.stop <= t] 82 | t_tiles = [x[:, :, self.sample_ovr_t :] if i else x for i, x in enumerate(t_tiles)] 83 | return torch.cat(t_tiles, dim=2) 84 | -------------------------------------------------------------------------------- /diffnext/models/diffusion_mlp.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
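# NOTE: an illustrative, standalone reproduction (not part of the original module) of the
# sinusoidal timestep embedding computed by TimeCondEmbed.get_freq_embed below: half of the
# frequency dim carries cosines and half sines, with frequencies exp(-log(10000) * i / (freq_dim / 2)).
import math
import torch

freq_dim, timestep = 256, torch.tensor([999.0])
dim = freq_dim // 2
freq = torch.arange(dim, dtype=torch.float32).mul(-math.log(10000.0) / dim).exp()
emb = timestep.unsqueeze(-1) * freq.unsqueeze(0)
emb = torch.cat([emb.cos(), emb.sin()], dim=-1)  # [1, 256] embedding fed to timestep_proj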
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Diffusion MLP.""" 17 | 18 | import torch 19 | from torch import nn 20 | from torch.utils.checkpoint import checkpoint as apply_ckpt 21 | 22 | from diffnext.models.embeddings import PatchEmbed 23 | from diffnext.models.normalization import AdaLayerNormZero 24 | 25 | 26 | class Projector(nn.Module): 27 | """MLP Projector layer.""" 28 | 29 | def __init__(self, dim, mlp_dim=None, out_dim=None): 30 | super(Projector, self).__init__() 31 | self.fc1 = nn.Linear(dim, mlp_dim or dim) 32 | self.fc2 = nn.Linear(mlp_dim or dim, out_dim or dim) 33 | self.activation = nn.SiLU() 34 | 35 | def forward(self, x) -> torch.Tensor: 36 | return self.fc2(self.activation(self.fc1(x))) 37 | 38 | 39 | class DiffusionBlock(nn.Module): 40 | """Diffusion block.""" 41 | 42 | def __init__(self, dim): 43 | super(DiffusionBlock, self).__init__() 44 | self.dim, self.mlp_checkpointing = dim, False 45 | self.norm1 = AdaLayerNormZero(dim, num_stats=3, eps=1e-6) 46 | self.proj, self.norm2 = Projector(dim, dim, dim), nn.LayerNorm(dim) 47 | 48 | def forward(self, x, z) -> torch.Tensor: 49 | if self.mlp_checkpointing and x.requires_grad: 50 | h, (gate,) = apply_ckpt(self.norm1, x, z, use_reentrant=False) 51 | return self.norm2(apply_ckpt(self.proj, h, use_reentrant=False)).mul(gate).add_(x) 52 | h, (gate,) = self.norm1(x, z) 53 | return self.norm2(self.proj(h)).mul(gate).add_(x) 54 | 55 | 56 | class TimeCondEmbed(nn.Module): 57 | """Time-Condition embedding layer.""" 58 | 59 | def __init__(self, cond_dim, embed_dim, freq_dim=256): 60 | super(TimeCondEmbed, self).__init__() 61 | self.timestep_proj = Projector(freq_dim, embed_dim, embed_dim) 62 | self.condition_proj = Projector(cond_dim, embed_dim, embed_dim) 63 | self.freq_dim, self.time_freq = freq_dim, None 64 | 65 | def get_freq_embed(self, timestep, dtype) -> torch.Tensor: 66 | if self.time_freq is None: 67 | dim, log_theta = self.freq_dim // 2, 9.210340371976184 # math.log(10000) 68 | freq = torch.arange(dim, dtype=torch.float32, device=timestep.device) 69 | self.time_freq = freq.mul(-log_theta / dim).exp().unsqueeze(0) 70 | emb = timestep.unsqueeze(-1).float() * self.time_freq 71 | return torch.cat([emb.cos(), emb.sin()], dim=-1).to(dtype=dtype) 72 | 73 | def forward(self, timestep, z) -> torch.Tensor: 74 | t = self.timestep_proj(self.get_freq_embed(timestep, z.dtype)) 75 | return self.condition_proj(z).add_(t.unsqueeze_(1) if t.dim() == 2 else t) 76 | 77 | 78 | class DiffusionMLP(nn.Module): 79 | """Diffusion MLP model.""" 80 | 81 | def __init__(self, depth, embed_dim, cond_dim, patch_size=2, image_dim=4): 82 | super(DiffusionMLP, self).__init__() 83 | self.patch_embed = PatchEmbed(image_dim, embed_dim, patch_size) 84 | self.time_cond_embed = TimeCondEmbed(cond_dim, embed_dim) 85 | self.blocks = nn.ModuleList(DiffusionBlock(embed_dim) for _ in range(depth)) 86 | self.norm = AdaLayerNormZero(embed_dim, num_stats=2, eps=1e-6) 87 | self.head = nn.Linear(embed_dim, patch_size**2 * image_dim) 88 | 89 | def forward(self, x, timestep, z, pred_ids=None) -> 
torch.Tensor: 90 | x, o = self.patch_embed(x), None if pred_ids is None else x 91 | o = None if pred_ids is None else self.patch_embed.patchify(o) 92 | x = x if pred_ids is None else x.gather(1, pred_ids.expand(-1, -1, x.size(-1))) 93 | z = z if pred_ids is None else z.gather(1, pred_ids.expand(-1, -1, z.size(-1))) 94 | z = self.time_cond_embed(timestep, z) 95 | for blk in self.blocks: 96 | x = blk(x, z) 97 | x = self.norm(x, z)[0] 98 | x = self.head(x) 99 | return x if pred_ids is None else o.scatter(1, pred_ids.expand(-1, -1, x.size(-1)), x) 100 | -------------------------------------------------------------------------------- /diffnext/models/diffusion_transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Diffusion Transformer.""" 17 | 18 | from functools import partial 19 | from typing import Tuple 20 | 21 | import torch 22 | from torch import nn 23 | from torch.utils.checkpoint import checkpoint as apply_ckpt 24 | 25 | from diffnext.models.embeddings import PatchEmbed, RotaryEmbed3D 26 | from diffnext.models.normalization import AdaLayerNormZero, AdaLayerNormSingle 27 | from diffnext.models.diffusion_mlp import Projector, TimeCondEmbed 28 | 29 | 30 | class TimeEmbed(TimeCondEmbed): 31 | """Time embedding layer.""" 32 | 33 | def __init__(self, embed_dim, freq_dim=256): 34 | nn.Module.__init__(self) 35 | self.timestep_proj = Projector(freq_dim, embed_dim, embed_dim) 36 | self.freq_dim, self.time_freq = freq_dim, None 37 | 38 | def forward(self, timestep) -> torch.Tensor: 39 | dtype = self.timestep_proj.fc1.weight.dtype 40 | temb = self.timestep_proj(self.get_freq_embed(timestep, dtype)) 41 | return temb.unsqueeze_(1) if temb.dim() == 2 else temb 42 | 43 | 44 | class MLP(nn.Module): 45 | """Two layers MLP.""" 46 | 47 | def __init__(self, dim, mlp_ratio=4): 48 | super(MLP, self).__init__() 49 | self.fc1 = nn.Linear(dim, int(dim * mlp_ratio)) 50 | self.fc2 = nn.Linear(int(dim * mlp_ratio), dim) 51 | self.activation = nn.GELU() 52 | 53 | def forward(self, x) -> torch.Tensor: 54 | return self.fc2(self.activation(self.fc1(x))) 55 | 56 | 57 | class Attention(nn.Module): 58 | """Multihead attention.""" 59 | 60 | def __init__(self, dim, num_heads, qkv_bias=True): 61 | super(Attention, self).__init__() 62 | self.num_heads, self.head_dim = num_heads, dim // num_heads 63 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 64 | self.proj, self.pe_func = nn.Linear(dim, dim), None 65 | 66 | def forward(self, x) -> torch.Tensor: 67 | qkv_shape = [-1, x.size(1), 3, self.num_heads, self.head_dim] 68 | q, k, v = self.qkv(x).view(qkv_shape).permute(2, 0, 3, 1, 4).unbind(dim=0) 69 | q, k = (self.pe_func(q), self.pe_func(k)) if self.pe_func else (q, k) 70 | o 
= nn.functional.scaled_dot_product_attention(q, k, v) 71 | return self.proj(o.transpose(1, 2).flatten(2)) 72 | 73 | 74 | class Block(nn.Module): 75 | """Transformer block.""" 76 | 77 | def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True, modulation_type=None): 78 | super(Block, self).__init__() 79 | self.modulation = (modulation_type or AdaLayerNormZero)(dim, num_stats=6, eps=1e-6) 80 | self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim) 81 | self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias) 82 | self.mlp = MLP(dim, mlp_ratio=mlp_ratio) 83 | self.attn_checkpointing = self.mlp_checkpointing = self.stg_skip = False 84 | 85 | def forward_modulation(self, x, z) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: 86 | return self.modulation(x, z) 87 | 88 | def forward_attn(self, x) -> torch.Tensor: 89 | return self.norm1(self.attn(x)) 90 | 91 | def forward_mlp(self, x) -> torch.Tensor: 92 | return self.norm2(self.mlp(x)) 93 | 94 | def forward_ckpt(self, x, name) -> torch.Tensor: 95 | if getattr(self, f"{name}_checkpointing", False) and x.requires_grad: 96 | return apply_ckpt(getattr(self, f"forward_{name}"), x, use_reentrant=False) 97 | return getattr(self, f"forward_{name}")(x) 98 | 99 | def forward(self, x, z, pe_func: callable = None) -> torch.Tensor: 100 | self.attn.pe_func = pe_func 101 | stg_x = x.chunk(3)[-1] if self.stg_skip else None 102 | if self.mlp_checkpointing and x.requires_grad: 103 | x, stats = apply_ckpt(self.forward_modulation, x, z, use_reentrant=False) 104 | else: 105 | x, stats = self.forward_modulation(x, z) 106 | gate_msa, scale_mlp, shift_mlp, gate_mlp = stats 107 | x = self.forward_ckpt(x, "attn").mul(gate_msa).add_(x) 108 | x = self.modulation.norm(x).mul(1 + scale_mlp).add_(shift_mlp) 109 | x = self.forward_ckpt(x, "mlp").mul(gate_mlp).add_(x) 110 | return torch.cat(x.chunk(3)[:2] + (stg_x,)) if self.stg_skip else x 111 | 112 | 113 | class DiffusionTransformer(nn.Module): 114 | """Diffusion transformer.""" 115 | 116 | def __init__( 117 | self, 118 | depth, 119 | embed_dim, 120 | num_heads, 121 | mlp_ratio=4, 122 | patch_size=2, 123 | image_size=32, 124 | image_dim=None, 125 | modulation=True, 126 | ): 127 | super(DiffusionTransformer, self).__init__() 128 | final_norm = AdaLayerNormSingle if modulation else AdaLayerNormZero 129 | block = partial(Block, modulation_type=AdaLayerNormSingle) if modulation else Block 130 | self.embed_dim, self.image_size, self.image_dim = embed_dim, image_size, image_dim 131 | self.patch_embed = PatchEmbed(image_dim, embed_dim, patch_size) 132 | self.time_embed = TimeEmbed(embed_dim) 133 | self.modulation = AdaLayerNormZero(embed_dim, num_stats=6, eps=1e-6) if modulation else None 134 | self.rope = RotaryEmbed3D(embed_dim // num_heads) 135 | self.blocks = nn.ModuleList(block(embed_dim, num_heads, mlp_ratio) for _ in range(depth)) 136 | self.norm = final_norm(embed_dim, num_stats=2, eps=1e-6) 137 | self.head = nn.Linear(embed_dim, patch_size**2 * image_dim) 138 | 139 | def prepare_pe(self, c=None, pos=None) -> Tuple[callable, callable]: 140 | return self.rope.get_func(pos, pad=0 if c is None else c.size(1)) 141 | 142 | def forward(self, x, timestep, c=None, pos=None) -> torch.Tensor: 143 | x = self.patch_embed(x) 144 | t = self.time_embed(timestep) 145 | z = self.modulation.proj(self.modulation.activation(t)) if self.modulation else t 146 | pe = self.prepare_pe(c, pos) if pos is not None else None 147 | x = x if c is None else torch.cat([c, x], dim=1) 148 | for blk in self.blocks: 149 | x = blk(x, z, pe) 150 | x = 
self.norm(x if c is None else x[:, c.size(1) :], t)[0] 151 | return self.head(x) 152 | -------------------------------------------------------------------------------- /diffnext/models/flex_attention.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Flex attention layers.""" 17 | 18 | from itertools import accumulate 19 | from typing import List 20 | 21 | import torch 22 | from torch import nn 23 | 24 | try: 25 | from torch.nn.attention.flex_attention import create_block_mask 26 | from torch.nn.attention.flex_attention import flex_attention 27 | except ImportError: 28 | flex_attention = create_block_mask = None 29 | 30 | 31 | class FlexAttentionCausal2D(nn.Module): 32 | """Block-wise causal flex attention.""" 33 | 34 | def __init__(self): 35 | super(FlexAttentionCausal2D, self).__init__() 36 | self.attn_func, self.offsets = None, None 37 | self.cu_offsets, self.block_mask = None, None 38 | 39 | def set_offsets(self, offsets: List[int]): 40 | """Set block-wise mask offsets.""" 41 | offsets = list(type(offsets)([0]) + offsets if offsets[0] != 0 else offsets) 42 | if offsets != self.offsets: 43 | self.offsets, self.block_mask = offsets, None 44 | 45 | def set_offsets_by_lens(self, lens: List[int]): 46 | """Set block-wise mask offsets by lengths.""" 47 | self.set_offsets(list(accumulate(type(lens)([0]) + lens if lens[0] != 0 else lens))) 48 | 49 | def get_mask_mod(self) -> callable: 50 | """Return the mask modification.""" 51 | counts = self.cu_offsets[1:] - self.cu_offsets[:-1] 52 | ids = torch.arange(len(counts), device=self.cu_offsets.device, dtype=torch.int32) 53 | ids = ids.repeat_interleave(counts) 54 | return lambda b, h, q_idx, kv_idx: (q_idx >= kv_idx) | (ids[q_idx] == ids[kv_idx]) 55 | 56 | def get_attn_func(self) -> callable: 57 | """Return the attention function.""" 58 | if flex_attention is None: 59 | raise NotImplementedError(f"FlexAttn requires torch>=2.5 but got {torch.__version__}") 60 | if self.attn_func is None: 61 | self.attn_func = torch.compile(flex_attention) 62 | return self.attn_func 63 | 64 | def get_block_mask(self, q: torch.Tensor) -> torch.Tensor: 65 | """Return the attention block mask according to inputs.""" 66 | if self.block_mask is not None: 67 | return self.block_mask 68 | b, h, q_len = q.shape[:3] 69 | q_pad = (self.offsets[-1] + 127) // 128 * 128 - q_len 70 | offsets_pad = self.offsets + ([self.offsets[-1] + q_pad] if q_pad else []) 71 | args = {"B": b, "H": h, "Q_LEN": q_len + q_pad, "KV_LEN": q_len + q_pad, "_compile": True} 72 | self.cu_offsets = torch.as_tensor(offsets_pad, device=q.device, dtype=torch.int32) 73 | self.block_mask = create_block_mask(self.get_mask_mod(), **args) 74 | return self.block_mask 75 | 
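# NOTE: an illustrative toy example, not part of the original class, of the mask rule built by
# get_mask_mod above: a query may attend to any earlier key (causal term) or to any key in its
# own block (same-id term), so tokens inside one block attend bidirectionally.
import torch

offsets = torch.tensor([0, 2, 5])                        # two blocks: [0, 2) and [2, 5)
counts = offsets[1:] - offsets[:-1]                      # tokens per block: [2, 3]
ids = torch.arange(len(counts)).repeat_interleave(counts)
q_idx, kv_idx = torch.arange(5).view(-1, 1), torch.arange(5).view(1, -1)
mask = (q_idx >= kv_idx) | (ids[q_idx] == ids[kv_idx])   # 5x5 boolean attention mask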
76 | def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: 77 | return self.get_attn_func()(q, k, v, block_mask=self.get_block_mask(q)) 78 | -------------------------------------------------------------------------------- /diffnext/models/guidance_scaler.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Classifier-free guidance scaler.""" 17 | 18 | import torch 19 | 20 | 21 | class GuidanceScaler(object): 22 | """Guidance scaler.""" 23 | 24 | def __init__(self, **kwargs): 25 | self.guidance_scale = kwargs.get("guidance_scale", 1) 26 | self.guidance_trunc = kwargs.get("guidance_trunc", 0) 27 | self.guidance_renorm = kwargs.get("guidance_renorm", 1) 28 | self.image_guidance_scale = kwargs.get("image_guidance_scale", 0) 29 | self.spatiotemporal_guidance_scale = kwargs.get("spatiotemporal_guidance_scale", 0) 30 | self.min_guidance_scale = kwargs.get("min_guidance_scale", None) or self.guidance_scale 31 | self.inc_guidance_scale = self.guidance_scale - self.min_guidance_scale 32 | 33 | @property 34 | def extra_pass(self) -> bool: 35 | """Return if an additional (third) guidance pass is required.""" 36 | return self.image_guidance_scale + self.spatiotemporal_guidance_scale > 0 37 | 38 | def clone(self): 39 | """Return a deepcopy of current guidance scaler.""" 40 | return GuidanceScaler(**self.__dict__) 41 | 42 | def decay_guidance_scale(self, decay=0): 43 | """Scale guidance scale according to decay.""" 44 | self.guidance_scale = self.inc_guidance_scale * decay + self.min_guidance_scale 45 | 46 | def expand(self, x: torch.Tensor, padding: torch.Tensor = None) -> torch.Tensor: 47 | """Expand input tensor for guidance passes.""" 48 | x = torch.stack([x] * (3 if self.extra_pass else 2)) if self.guidance_scale > 1 else x 49 | x.__setitem__(1, padding) if self.image_guidance_scale and padding is not None else None 50 | return x.flatten(0, 1) if self.guidance_scale > 1 else x 51 | 52 | def expand_text(self, c: torch.Tensor) -> torch.Tensor: 53 | """Expand text embedding tensor for guidance passes.""" 54 | c = list(c.chunk(2)) if self.extra_pass else c 55 | c.append(c[1]) if self.image_guidance_scale else None # Null, Null 56 | c.append(c[0]) if self.spatiotemporal_guidance_scale else None # Null, Text 57 | return torch.cat(c) if self.extra_pass else c 58 | 59 | def maybe_disable(self, timestep, *args): 60 | """Disable all guidance passes if matching truncation threshold.""" 61 | if self.guidance_scale > 1 and self.guidance_trunc: 62 | if float(timestep) < self.guidance_trunc: 63 | self.guidance_scale = 1 64 | return [_.chunk(3 if self.extra_pass else 2)[0] for _ in args] 65 | return args 66 | 67 | def renorm(self, x, cond): 68 | 
"""Apply guidance renormalization to input logits.""" 69 | if self.guidance_renorm >= 1: 70 | return x 71 | args = {"dim": tuple(range(1, len(x.shape))), "keepdim": True} 72 | return x.mul_(cond.norm(**args).div_(x.norm(**args)).clamp(self.guidance_renorm, 1)) 73 | 74 | def scale(self, x: torch.Tensor) -> torch.Tensor: 75 | """Apply guidance passes to input logits.""" 76 | if self.guidance_scale <= 1: 77 | return x 78 | if self.image_guidance_scale: 79 | cond, uncond, imgcond = x.chunk(3) 80 | x = self.renorm(uncond.add(cond.sub(imgcond).mul_(self.guidance_scale)), cond) 81 | return x.add_(imgcond.sub_(uncond).mul_(self.image_guidance_scale)) 82 | if self.spatiotemporal_guidance_scale: 83 | cond, uncond, perturb = x.chunk(3) 84 | x = self.renorm(uncond.add_(cond.sub(uncond).mul_(self.guidance_scale)), cond) 85 | return x.add_(cond.sub_(perturb).mul_(self.spatiotemporal_guidance_scale)) 86 | cond, uncond = x.chunk(2) 87 | return self.renorm(uncond.add_(cond.sub(uncond).mul_(self.guidance_scale)), cond) 88 | -------------------------------------------------------------------------------- /diffnext/models/normalization.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ------------------------------------------------------------------------ 16 | """Normalization Layers.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class AdaLayerNormZero(nn.Module): 25 | """Adaptive LayerNorm with residual stats.""" 26 | 27 | def __init__(self, dim, rank=None, num_stats=2, eps=1e-6): 28 | super(AdaLayerNormZero, self).__init__() 29 | self.lora = nn.Linear(dim, rank, bias=False) if rank else nn.Identity() 30 | self.proj = nn.Linear(rank if rank else dim, num_stats * dim) 31 | self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False) if eps else nn.Identity() 32 | self.activation, self.num_stats = nn.SiLU(), num_stats 33 | 34 | def forward(self, x, z) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: 35 | stats = self.proj(self.lora(self.activation(z))).chunk(self.num_stats, dim=-1) 36 | return self.norm(x).mul(1 + stats[0]).add_(stats[1]), stats[2:] 37 | 38 | 39 | class AdaLayerNorm(AdaLayerNormZero): 40 | """Adaptive LayerNorm.""" 41 | 42 | def __init__(self, dim, rank=None, eps=1e-6): 43 | super(AdaLayerNorm, self).__init__(dim, rank, num_stats=2, eps=eps) 44 | 45 | def forward(self, x, z) -> torch.Tensor: 46 | return super().forward(x, z)[0] 47 | 48 | 49 | class AdaLayerNormSingle(nn.Module): 50 | """Adaptive LayerNorm with shared residual stats.""" 51 | 52 | def __init__(self, dim, num_stats=2, eps=1e-6): 53 | super(AdaLayerNormSingle, self).__init__() 54 | self.bias = nn.Parameter(torch.randn(num_stats, dim) / dim**0.5) 55 | self.norm = nn.LayerNorm(dim, eps, elementwise_affine=False) if eps else nn.Identity() 56 | self.num_stats = num_stats 57 | 58 | def forward(self, x, z) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]: 59 | axis = -2 if z.size(-1) == self.bias.size(-1) else -1 60 | bias = self.bias.flatten(-1 if z.size(-1) == self.bias.size(-1) else 0) 61 | stats = z.add(bias).chunk(self.num_stats, dim=axis) 62 | return self.norm(x).mul(1 + stats[0]).add_(stats[1]), stats[2:] 63 | -------------------------------------------------------------------------------- /diffnext/models/text_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | -------------------------------------------------------------------------------- /diffnext/models/text_encoders/phi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
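# NOTE: an illustrative sketch, not from the repository, of the adaptive-LayerNorm modulation
# implemented by AdaLayerNormZero in normalization.py above: a conditioning embedding is
# projected to num_stats * dim values, the first two act as scale/shift on LayerNorm(x), and
# any remaining chunks are returned as gates. Shapes are assumptions for illustration.
import torch

norm = AdaLayerNormZero(dim=768, num_stats=3)
x, z = torch.randn(2, 16, 768), torch.randn(2, 16, 768)
h, (gate,) = norm(x, z)    # h = LN(x) * (1 + scale) + shift, plus one gating tensor
out = h.mul(gate).add_(x)  # the gate-and-residual pattern used by DiffusionBlock in diffusion_mlp.py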
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | """Simple implementation of Phi model.""" 16 | 17 | from typing import Tuple 18 | 19 | try: 20 | from flash_attn import flash_attn_func 21 | except ImportError: 22 | flash_attn_func = None 23 | try: 24 | from flash_attn.layers.rotary import apply_rotary_emb 25 | except ImportError: 26 | from einops import rearrange, repeat 27 | 28 | apply_rotary_emb = None 29 | 30 | 31 | import torch 32 | from torch import nn 33 | 34 | from transformers.activations import ACT2FN 35 | from transformers.modeling_outputs import BaseModelOutput 36 | from transformers.modeling_utils import PreTrainedModel 37 | from transformers.models.phi.configuration_phi import PhiConfig 38 | 39 | 40 | def rotate_half(x, interleaved=False) -> torch.Tensor: 41 | if not interleaved: 42 | x1, x2 = x.chunk(2, dim=-1) 43 | return torch.cat((-x2, x1), dim=-1) 44 | x1, x2 = x[..., ::2], x[..., 1::2] 45 | return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) 46 | 47 | 48 | def apply_rotary_emb_torch(x, cos, sin, interleaved=False, inplace=False) -> torch.Tensor: 49 | ro_dim = cos.shape[-1] * 2 50 | cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") 51 | sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 
1 (d 2)") 52 | return torch.cat( 53 | [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], 54 | dim=-1, 55 | ) 56 | 57 | 58 | apply_rotary_emb = apply_rotary_emb or apply_rotary_emb_torch 59 | 60 | 61 | class PhiCache(nn.Module): 62 | """Execution cache.""" 63 | 64 | def __init__(self, config: PhiConfig, device=None, dtype=None): 65 | super(PhiCache, self).__init__() 66 | self.config, self.device, self.dtype = config, device, dtype 67 | self.start_pos, self.end_pos, self.cache_dict = 0, 0, {} 68 | 69 | def reset(self, device=None, dtype=None): 70 | self.device, self.dtype = device, dtype 71 | max_seq_len = self.config.max_position_embeddings 72 | head_dim = self.config.hidden_size // self.config.num_attention_heads 73 | rotary_dim = int(self.config.partial_rotary_factor * head_dim) 74 | self.init_rotary(max_seq_len, rotary_dim, self.config.rope_theta) 75 | 76 | def init_rotary(self, seq_len, dim, theta=10000.0): 77 | grid = torch.arange(seq_len, dtype=torch.float32).unsqueeze_(-1) 78 | freq = torch.pow(theta, torch.arange(0, dim, 2)[: dim // 2].float().div_(dim)) 79 | broadcast_freq = grid.mul(freq.reciprocal_().unsqueeze_(0)) 80 | cache_cos = broadcast_freq.cos().view((-1, dim // 2)) 81 | cache_sin = broadcast_freq.sin().view((-1, dim // 2)) 82 | self.cache_dict["cos"] = cache_cos.to(self.device, self.dtype) 83 | self.cache_dict["sin"] = cache_sin.to(self.device, self.dtype) 84 | 85 | def set_seq(self, start_pos=0, end_pos=None): 86 | self.start_pos, self.end_pos = start_pos, end_pos 87 | if "cos" in self.cache_dict and end_pos is not None: 88 | self.cache_dict["seq_cos"] = self.cache_dict["cos"][self.start_pos : end_pos] 89 | self.cache_dict["seq_sin"] = self.cache_dict["sin"][self.start_pos : end_pos] 90 | 91 | def forward_rotary(self, q, k, inplace=False) -> Tuple[torch.Tensor, torch.Tensor]: 92 | cos = self.cache_dict.get("seq_cos", self.cache_dict.get("cos", None)) 93 | sin = self.cache_dict.get("seq_sin", self.cache_dict.get("sin", None)) 94 | q = apply_rotary_emb(q, cos, sin, interleaved=False, inplace=inplace) 95 | k = apply_rotary_emb(k, cos, sin, interleaved=False, inplace=inplace) 96 | return q, k 97 | 98 | 99 | class PhiMLP(nn.Module): 100 | """Two layers MLP.""" 101 | 102 | def __init__(self, config: PhiConfig): 103 | super().__init__() 104 | self.activation = ACT2FN[config.hidden_act] 105 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 106 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 107 | 108 | def forward(self, x) -> torch.Tensor: 109 | return self.fc2(self.activation(self.fc1(x))) 110 | 111 | 112 | class PhiAttention(nn.Module): 113 | """Multi-headed attention.""" 114 | 115 | def __init__(self, config: PhiConfig): 116 | super().__init__() 117 | self.config = config 118 | self.num_heads = config.num_attention_heads 119 | self.head_dim = config.hidden_size // self.num_heads 120 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim) 121 | self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim) 122 | self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim) 123 | self.dense = nn.Linear(self.num_heads * self.head_dim, config.hidden_size) 124 | 125 | 126 | class PhiFlashAttention2(PhiAttention): 127 | """Multi-headed attention using FA2.""" 128 | 129 | def forward(self, x, attn_mask=None) -> torch.Tensor: 130 | qkv_shape = (-1, x.shape[1], self.num_heads, self.head_dim) 131 | q, k, v = [f(x).view(qkv_shape) for f in (self.q_proj, 
self.k_proj, self.v_proj)] 132 | q, k = self.cache.forward_rotary(q, k, inplace=True) 133 | if flash_attn_func is None: 134 | q, k, v = [_.transpose(1, 2) for _ in (q, k, v)] 135 | o = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True) 136 | return self.dense(o.transpose(1, 2).flatten(2)) 137 | return self.dense(flash_attn_func(q, k, v, causal=True).flatten(2)) 138 | 139 | 140 | class PhiLayer(nn.Module): 141 | """Transformer layer.""" 142 | 143 | def __init__(self, config: PhiConfig): 144 | super().__init__() 145 | self.self_attn = PhiFlashAttention2(config) 146 | self.mlp = PhiMLP(config) 147 | self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 148 | 149 | def forward(self, x, attn_mask=None) -> torch.Tensor: 150 | shortcut, x = x, self.input_layernorm(x) 151 | return self.self_attn(x, attn_mask).add_(self.mlp(x)).add_(shortcut) 152 | 153 | 154 | class PhiPreTrainedModel(PreTrainedModel): 155 | """Base model.""" 156 | 157 | config_class = PhiConfig 158 | 159 | 160 | class PhiModel(PhiPreTrainedModel): 161 | """Standard model.""" 162 | 163 | def __init__(self, config: PhiConfig): 164 | super().__init__(config) 165 | self.padding_idx = config.pad_token_id 166 | self.vocab_size = config.vocab_size 167 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 168 | self.layers = nn.ModuleList(PhiLayer(config) for _ in range(config.num_hidden_layers)) 169 | self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 170 | self.cache, _ = PhiCache(config), self.post_init() 171 | 172 | def maybe_init_cache(self, **kwargs): 173 | if self.cache.device is not None: 174 | return 175 | self.cache.reset(self.device, self.dtype) 176 | [layer.self_attn.__dict__.setdefault("cache", self.cache) for layer in self.layers] 177 | 178 | def forward(self, input_ids, attention_mask=None, **kwargs) -> BaseModelOutput: 179 | self.maybe_init_cache(**kwargs) 180 | h = kwargs.get("inputs_embeds", None) 181 | h = self.embed_tokens(input_ids) if h is None else h 182 | start_pos = 0 if kwargs.get("past_key_values", None) is None else self.cache.end_pos 183 | self.cache.set_seq(start_pos, start_pos + h.shape[1]) 184 | for layer in self.layers: 185 | h = layer(h, attention_mask) 186 | h = self.final_layernorm(h) 187 | return BaseModelOutput(last_hidden_state=h) 188 | 189 | 190 | class PhiEncoderModel(PhiPreTrainedModel): 191 | """Encoder model.""" 192 | 193 | def __init__(self, config): 194 | super().__init__(config) 195 | self.model = PhiModel(config) 196 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) 197 | self.vocab_size, _ = config.vocab_size, self.post_init() 198 | 199 | def forward(self, input_ids, attention_mask=None, **kwargs) -> BaseModelOutput: 200 | return self.model(input_ids, attention_mask, **kwargs) 201 | -------------------------------------------------------------------------------- /diffnext/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
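# NOTE: an illustrative usage sketch, not from the repository, for the PhiEncoderModel defined
# in phi.py above when used as a text encoder: tokenized prompts go in and last_hidden_state
# embeddings come out. The tiny config values are assumptions, and the sketch assumes the
# pure-PyTorch attention/rotary fallback path (flash-attn not installed).
import torch
from transformers.models.phi.configuration_phi import PhiConfig

config = PhiConfig(vocab_size=1000, hidden_size=256, intermediate_size=1024,
                   num_hidden_layers=2, num_attention_heads=8)
encoder = PhiEncoderModel(config).eval()
input_ids = torch.randint(0, 1000, (1, 16))
with torch.no_grad():
    text_embeds = encoder(input_ids).last_hidden_state  # [1, 16, 256]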
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | -------------------------------------------------------------------------------- /diffnext/models/transformers/transformer_nova.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """3D transformer model for AR generation in NOVA.""" 17 | 18 | from diffusers.configuration_utils import ConfigMixin, register_to_config 19 | from diffusers.models.modeling_utils import ModelMixin 20 | 21 | from diffnext.models.diffusion_mlp import DiffusionMLP 22 | from diffnext.models.embeddings import PosEmbed, VideoPosEmbed, RotaryEmbed3D 23 | from diffnext.models.embeddings import MaskEmbed, MotionEmbed, TextEmbed, LabelEmbed 24 | from diffnext.models.normalization import AdaLayerNorm 25 | from diffnext.models.transformers.transformer_3d import Transformer3DModel 26 | from diffnext.models.vision_transformer import VisionTransformer 27 | from diffnext.utils import Registry 28 | 29 | VIDEO_ENCODERS = Registry("video_encoders") 30 | IMAGE_ENCODERS = Registry("image_encoders") 31 | IMAGE_DECODERS = Registry("image_decoders") 32 | 33 | 34 | @VIDEO_ENCODERS.register("vit_d16w768", depth=16, embed_dim=768, num_heads=12) 35 | @VIDEO_ENCODERS.register("vit_d16w1024", depth=16, embed_dim=1024, num_heads=16) 36 | @VIDEO_ENCODERS.register("vit_d16w1536", depth=16, embed_dim=1536, num_heads=16) 37 | def video_encoder(depth, embed_dim, num_heads, patch_size, image_size, image_dim): 38 | return VisionTransformer(**locals()) 39 | 40 | 41 | @IMAGE_ENCODERS.register("vit_d32w768", depth=32, embed_dim=768, num_heads=12) 42 | @IMAGE_ENCODERS.register("vit_d32w1024", depth=32, embed_dim=1024, num_heads=16) 43 | @IMAGE_ENCODERS.register("vit_d32w1536", depth=32, embed_dim=1536, num_heads=16) 44 | def image_encoder(depth, embed_dim, num_heads, patch_size, image_size, image_dim): 45 | return VisionTransformer(**locals()) 46 | 47 | 48 | @IMAGE_DECODERS.register("mlp_d3w1280", depth=3, embed_dim=1280) 49 | @IMAGE_DECODERS.register("mlp_d6w768", depth=6, embed_dim=768) 50 | @IMAGE_DECODERS.register("mlp_d6w1024", depth=6, embed_dim=1024) 51 | @IMAGE_DECODERS.register("mlp_d6w1536", depth=6, embed_dim=1536) 52 | def image_decoder(depth, embed_dim, patch_size, image_dim, 
cond_dim): 53 | return DiffusionMLP(**locals()) 54 | 55 | 56 | class NOVATransformer3DModel(Transformer3DModel, ModelMixin, ConfigMixin): 57 | """A 3D transformer model for AR generation in NOVA.""" 58 | 59 | @register_to_config 60 | def __init__( 61 | self, 62 | image_dim=None, 63 | image_size=None, 64 | image_stride=None, 65 | text_token_dim=None, 66 | text_token_len=None, 67 | image_base_size=None, 68 | video_base_size=None, 69 | video_mixer_rank=None, 70 | rotary_pos_embed=False, 71 | arch=("", "", ""), 72 | ): 73 | image_size = (image_size,) * 2 if isinstance(image_size, int) else image_size 74 | image_size = tuple(v // image_stride for v in image_size) 75 | image_args = {"image_dim": image_dim, "patch_size": 15 // image_stride + 1} 76 | video_args = {**image_args, "patch_size": image_args["patch_size"] * 2} 77 | video_encoder = VIDEO_ENCODERS.get(arch[0])(image_size=image_size, **video_args) 78 | image_encoder = IMAGE_ENCODERS.get(arch[1])(image_size=image_size, **image_args) 79 | image_decoder = IMAGE_DECODERS.get(arch[2])(cond_dim=image_encoder.embed_dim, **image_args) 80 | if rotary_pos_embed: 81 | video_pos_embed = RotaryEmbed3D(video_encoder.rope.dim, video_base_size[1:]) 82 | image_pos_embed = RotaryEmbed3D(image_encoder.rope.dim, image_base_size) 83 | else: 84 | video_pos_embed = VideoPosEmbed(video_encoder.embed_dim, video_base_size) 85 | image_encoder.pos_embed = PosEmbed(image_encoder.embed_dim, image_base_size) 86 | image_pos_embed = image_pos_embed if rotary_pos_embed else None 87 | if video_mixer_rank: 88 | video_mixer_rank = max(video_mixer_rank, 0) # Use vanilla AdaLN if ``rank`` < 0. 89 | video_encoder.mixer = AdaLayerNorm(video_encoder.embed_dim, video_mixer_rank, eps=None) 90 | if text_token_dim: 91 | text_embed = TextEmbed(text_token_dim, image_encoder.embed_dim, text_token_len) 92 | super(NOVATransformer3DModel, self).__init__( 93 | video_encoder=video_encoder, 94 | image_encoder=image_encoder, 95 | image_decoder=image_decoder, 96 | mask_embed=MaskEmbed(image_encoder.embed_dim), 97 | text_embed=text_embed if text_token_dim else None, 98 | label_embed=LabelEmbed(image_encoder.embed_dim) if not text_token_dim else None, 99 | video_pos_embed=video_pos_embed, 100 | image_pos_embed=image_pos_embed, 101 | motion_embed=MotionEmbed(video_encoder.embed_dim) if video_base_size[0] > 1 else None, 102 | ) 103 | -------------------------------------------------------------------------------- /diffnext/models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ------------------------------------------------------------------------ 16 | """Vision Transformer.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | from torch.utils.checkpoint import checkpoint as apply_ckpt 23 | 24 | from diffnext.models.embeddings import PatchEmbed, RotaryEmbed3D 25 | from diffnext.models.flex_attention import FlexAttentionCausal2D 26 | 27 | 28 | class MLP(nn.Module): 29 | """Two layers MLP.""" 30 | 31 | def __init__(self, dim, mlp_ratio=4): 32 | super(MLP, self).__init__() 33 | self.fc1 = nn.Linear(dim, int(dim * mlp_ratio)) 34 | self.fc2 = nn.Linear(int(dim * mlp_ratio), dim) 35 | self.activation = nn.GELU() 36 | 37 | def forward(self, x) -> torch.Tensor: 38 | return self.fc2(self.activation(self.fc1(x))) 39 | 40 | 41 | class Attention(nn.Module): 42 | """Multihead attention.""" 43 | 44 | def __init__(self, dim, num_heads, qkv_bias=True): 45 | super(Attention, self).__init__() 46 | self.num_heads, self.head_dim = num_heads, dim // num_heads 47 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 48 | self.proj = nn.Linear(dim, dim) 49 | self.attn_mask, self.cache_kv, self.pe_func, self.flex_attn = None, None, None, None 50 | 51 | def forward(self, x) -> torch.Tensor: 52 | qkv_shape = [-1, x.size(1), 3, self.num_heads, self.head_dim] 53 | q, k, v = self.qkv(x).view(qkv_shape).permute(2, 0, 3, 1, 4).unbind(dim=0) 54 | q, k = (self.pe_func(q), self.pe_func(k)) if self.pe_func else (q, k) 55 | if self.cache_kv is not None and self.cache_kv: 56 | if isinstance(self.cache_kv, list): 57 | k = self.cache_kv[0] = torch.cat([self.cache_kv[0], k], dim=2) 58 | v = self.cache_kv[1] = torch.cat([self.cache_kv[1], v], dim=2) 59 | else: 60 | self.cache_kv = [k, v] 61 | if self.flex_attn and self.flex_attn.offsets: 62 | return self.proj(self.flex_attn(q, k, v).transpose(1, 2).flatten(2)) 63 | o = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=self.attn_mask) 64 | return self.proj(o.transpose(1, 2).flatten(2)) 65 | 66 | 67 | class Block(nn.Module): 68 | """Transformer block.""" 69 | 70 | def __init__(self, dim, num_heads, mlp_ratio=4, qkv_bias=True): 71 | super(Block, self).__init__() 72 | self.norm1 = nn.LayerNorm(dim) 73 | self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias) 74 | self.norm2 = nn.LayerNorm(dim) 75 | self.mlp = MLP(dim, mlp_ratio=mlp_ratio) 76 | self.attn_checkpointing, self.mlp_checkpointing = False, False 77 | 78 | def forward_attn(self, x) -> torch.Tensor: 79 | return self.norm1(self.attn(x)) 80 | 81 | def forward_mlp(self, x) -> torch.Tensor: 82 | return self.norm2(self.mlp(x)) 83 | 84 | def forward_ckpt(self, x, name) -> torch.Tensor: 85 | if getattr(self, f"{name}_checkpointing", False) and x.requires_grad: 86 | return apply_ckpt(getattr(self, f"forward_{name}"), x, use_reentrant=False) 87 | return getattr(self, f"forward_{name}")(x) 88 | 89 | def forward(self, x, pe_func: callable = None) -> torch.Tensor: 90 | self.attn.pe_func = pe_func 91 | x = self.forward_ckpt(x, "attn").add_(x) 92 | return self.forward_ckpt(x, "mlp").add_(x) 93 | 94 | 95 | class VisionTransformer(nn.Module): 96 | """Vision transformer.""" 97 | 98 | def __init__( 99 | self, 100 | depth, 101 | embed_dim, 102 | num_heads, 103 | mlp_ratio=4, 104 | patch_size=2, 105 | image_size=32, 106 | image_dim=4, 107 | encoder_depth=None, 108 | ): 109 | super(VisionTransformer, self).__init__() 110 | self.embed_dim, self.image_size, self.image_dim = embed_dim, image_size, image_dim 111 | self.patch_embed = PatchEmbed(image_dim, embed_dim, 
patch_size) 112 | self.pos_embed, self.rope = nn.Identity(), RotaryEmbed3D(embed_dim // num_heads) 113 | self.blocks = nn.ModuleList(Block(embed_dim, num_heads, mlp_ratio) for _ in range(depth)) 114 | self.norm, self.mixer = nn.LayerNorm(embed_dim), nn.Identity() 115 | self.encoder_depth = len(self.blocks) // 2 if encoder_depth is None else encoder_depth 116 | self.flex_attn = FlexAttentionCausal2D() 117 | [setattr(blk.attn, "flex_attn", self.flex_attn) for blk in self.blocks] 118 | 119 | def prepare_pe(self, c=None, ids=None, pos=None) -> Tuple[callable, callable]: 120 | pad = 0 if c is None else c.size(1) 121 | pe1 = pe2 = self.rope.get_func(pos, pad) 122 | pe1 = self.rope.get_func(pos, pad, ids.expand(-1, -1, 3)) if ids is not None else pe1 123 | return pe1, pe2 124 | 125 | def enable_kvcache(self, mode=True): 126 | [setattr(blk.attn, "cache_kv", mode) for blk in self.blocks] 127 | 128 | def forward(self, x, c=None, prev_ids=None, pos=None) -> torch.Tensor: 129 | x, prev_ids = x if isinstance(x, (tuple, list)) else (x, prev_ids) 130 | prev_ids = prev_ids if self.encoder_depth else None 131 | x = x_masked = self.pos_embed(self.patch_embed(x)) 132 | pe1, pe2 = self.prepare_pe(c, prev_ids, pos) if pos is not None else [None] * 2 133 | if prev_ids is not None: # Split mask from x. 134 | prev_ids = prev_ids.expand(-1, -1, x.size(-1)) 135 | x = x.gather(1, prev_ids) 136 | x = x if c is None else torch.cat([c, x], dim=1) 137 | for blk in self.blocks[: self.encoder_depth]: 138 | x = blk(x, pe1) 139 | if prev_ids is not None and c is not None: # Split c from x. 140 | c, x = x.split((c.size(1), x.size(1) - c.size(1)), dim=1) 141 | if prev_ids is not None: # Merge mask with x. 142 | x = x_masked.to(dtype=x.dtype).scatter(1, prev_ids, x) 143 | x = x if c is None else torch.cat([c, x], dim=1) 144 | for blk in self.blocks[self.encoder_depth :]: 145 | x = blk(x, pe2) 146 | return self.norm(x if c is None else x[:, c.size(1) :]) 147 | -------------------------------------------------------------------------------- /diffnext/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Pipelines.""" 17 | 18 | from diffnext.pipelines.builder import build_pipeline 19 | from diffnext.pipelines.builder import build_diffusion_scheduler 20 | from diffnext.pipelines.nova import NOVAPipeline 21 | -------------------------------------------------------------------------------- /diffnext/pipelines/builder.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Pipeline builders."""
17 | 
18 | from typing import Dict
19 | 
20 | import json
21 | import os
22 | import tempfile
23 | 
24 | import torch
25 | 
26 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
27 | from diffusers.schedulers.scheduling_utils import SchedulerMixin
28 | from diffnext.utils.registry import Registry
29 | 
30 | PIPELINES = Registry("pipelines")
31 | 
32 | 
33 | def get_pipeline_path(
34 |     pretrained_path,
35 |     module_dict: dict = None,
36 |     module_config: Dict[str, dict] = None,
37 |     target_path: str = None,
38 | ) -> str:
39 |     """Return the pipeline loading path.
40 | 
41 |     Args:
42 |         pretrained_path (str)
43 |             The pretrained path to load pipeline.
44 |         module_dict (dict, *optional*)
45 |             The path dict to load custom modules.
46 |         module_config (Dict[str, dict], *optional*)
47 |             The custom configurations to dump into ``config.json``.
48 |         target_path (str, *optional*)
49 |             The path to store custom modules and configs.
50 | 
51 |     Returns:
52 |         str: The pipeline loading path.
53 | 
54 |     """
55 |     if module_dict is None and module_config is None:
56 |         return pretrained_path
57 |     target_path = target_path or tempfile.mkdtemp()
58 |     for k in os.listdir(pretrained_path):
59 |         if not os.path.isdir(os.path.join(pretrained_path, k)):
60 |             continue
61 |         os.makedirs(os.path.join(target_path, k), exist_ok=True)
62 |         for _ in os.listdir(os.path.join(pretrained_path, k)):
63 |             os.symlink(os.path.join(pretrained_path, k, _), os.path.join(target_path, k, _))
64 |     module_dict = module_dict.copy() if module_dict is not None else {}
65 |     model_index = module_dict.pop("model_index", os.path.join(pretrained_path, "model_index.json"))
66 |     model_index = json.load(open(model_index))
67 |     for k, v in module_dict.items():
68 |         model_index.pop(k) if not v else None
69 |         try:
70 |             os.symlink(v, os.path.join(target_path, k)) if v else None
71 |         except FileExistsError:  # Some components may be provided.
72 |             pass
73 |     for k, v in (module_config or {}).items():
74 |         config_file = os.path.join(target_path, k, "config.json")
75 |         os.remove(config_file) if v and os.path.exists(config_file) else None
76 |         json.dump(v, open(config_file, "w")) if v else None
77 |     json.dump(model_index, open(os.path.join(target_path, "model_index.json"), "w"))
78 |     return target_path
79 | 
80 | 
81 | def build_diffusion_scheduler(scheduler_path, sample=False, **kwargs) -> SchedulerMixin:
82 |     """Create a diffusion scheduler instance.
83 | 
84 |     Args:
85 |         scheduler_path (str or scheduler instance)
86 |             The path to load a diffusion scheduler.
87 |         sample (bool, *optional*, defaults to False)
88 |             Whether to create the sampling-specific scheduler.
89 | 
90 |     Returns:
91 |         SchedulerMixin: The diffusion scheduler.
92 | 
93 |     """
94 |     from diffnext.schedulers.scheduling_ddpm import DDPMScheduler
95 |     from diffnext.schedulers.scheduling_flow import FlowMatchEulerDiscreteScheduler  # noqa
96 | 
97 |     if isinstance(scheduler_path, str):
98 |         class_key = "_{}_class_name".format("sample" if sample else "noise")
99 |         class_type = locals()[DDPMScheduler.load_config(**locals())[class_key]]
100 |         return class_type.from_pretrained(**locals())
101 |     elif hasattr(scheduler_path, "config"):
102 |         class_type = locals()[type(scheduler_path).__name__]
103 |         return class_type.from_config(scheduler_path.config)
104 |     return None
105 | 
106 | 
107 | def build_pipeline(
108 |     pretrained_path,
109 |     pipe_type=None,
110 |     precision="float16",
111 |     config=None,
112 |     **kwargs,
113 | ) -> DiffusionPipeline:
114 |     """Create a diffnext pipeline instance.
115 | 
116 |     Examples:
117 |         ```py
118 |         >>> from diffnext.pipelines import build_pipeline
119 |         >>> pipe = build_pipeline("BAAI/nova-d48w768-sdxl1024", "nova_train_t2i")
120 |         ```
121 | 
122 |     Args:
123 |         pretrained_path (str):
124 |             The model path that includes ``model_index.json`` to create pipeline.
125 |         pipe_type (str or `type(XXXPipeline)`, *optional*)
126 |             The registered pipeline class or specific pipeline type.
127 |         precision (str, *optional*, defaults to ``float16``)
128 |             The compute precision used for all pipeline components.
129 |         config (object, *optional*)
130 |             The config object.
131 | 
132 |     Returns:
133 |         DiffusionPipeline: The diffusion pipeline.
134 | 
135 |     """
136 |     pipe_type = config.PIPELINE.TYPE if config else pipe_type
137 |     pipe_type = PIPELINES.get(pipe_type).func if isinstance(pipe_type, str) else pipe_type
138 |     precision = config.MODEL.PRECISION if config else precision
139 |     kwargs.setdefault("trust_remote_code", True)
140 |     kwargs.setdefault("torch_dtype", getattr(torch, precision.lower()))
141 |     return pipe_type.from_pretrained(pretrained_path, **kwargs)
142 | 
--------------------------------------------------------------------------------
/diffnext/pipelines/nova/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------ 16 | """NOVA pipelines.""" 17 | 18 | from diffnext.pipelines.nova.pipeline_nova import NOVAPipeline 19 | from diffnext.pipelines.nova.pipeline_nova_c2i import NOVAC2IPipeline 20 | from diffnext.pipelines.nova.pipeline_train_c2i import NOVATrainC2IPipeline 21 | from diffnext.pipelines.nova.pipeline_train_t2i import NOVATrainT2IPipeline 22 | from diffnext.pipelines.nova.pipeline_train_t2v import NOVATrainT2VPipeline 23 | -------------------------------------------------------------------------------- /diffnext/pipelines/nova/pipeline_nova_c2i.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | """Non-quantized autoregressive pipeline for NOVA.""" 16 | 17 | import numpy as np 18 | import torch 19 | 20 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 21 | 22 | from diffnext.image_processor import VaeImageProcessor 23 | from diffnext.pipelines.builder import PIPELINES 24 | from diffnext.pipelines.nova.pipeline_utils import NOVAPipelineOutput, PipelineMixin 25 | 26 | 27 | @PIPELINES.register("nova_c2i") 28 | class NOVAC2IPipeline(DiffusionPipeline, PipelineMixin): 29 | """NOVA autoregressive diffusion pipeline.""" 30 | 31 | _optional_components = ["transformer", "scheduler", "vae"] 32 | 33 | def __init__(self, transformer=None, scheduler=None, vae=None, trust_remote_code=True): 34 | super(NOVAC2IPipeline, self).__init__() 35 | self.vae = self.register_module(vae, "vae") 36 | self.transformer = self.register_module(transformer, "transformer") 37 | self.scheduler = self.register_module(scheduler, "scheduler") 38 | self.transformer.sample_scheduler, self.guidance_scale = self.scheduler, 5.0 39 | self.image_processor = VaeImageProcessor() 40 | 41 | @torch.no_grad() 42 | def __call__( 43 | self, 44 | prompt=None, 45 | num_inference_steps=64, 46 | num_diffusion_steps=25, 47 | guidance_scale=5, 48 | min_guidance_scale=None, 49 | negative_prompt=None, 50 | num_images_per_prompt=1, 51 | generator=None, 52 | latents=None, 53 | disable_progress_bar=False, 54 | output_type="pil", 55 | **kwargs, 56 | ) -> NOVAPipelineOutput: 57 | """The call function to the pipeline for generation. 58 | 59 | Args: 60 | prompt (int or List[int], *optional*): 61 | The prompt to be encoded. 62 | num_inference_steps (int, *optional*, defaults to 64): 63 | The number of autoregressive steps. 64 | num_diffusion_steps (int, *optional*, defaults to 25): 65 | The number of denoising steps. 66 | guidance_scale (float, *optional*, defaults to 5): 67 | The classifier guidance scale. 68 | min_guidance_scale (float, *optional*): 69 | The minimum classifier guidance scale. 
70 | negative_prompt (int or List[int], *optional*): 71 | The prompt or prompts to guide what to not include in image generation. 72 | num_images_per_prompt (int, *optional*, defaults to 1): 73 | The number of images that should be generated per prompt. 74 | generator (torch.Generator, *optional*): 75 | The random generator. 76 | disable_progress_bar (bool, *optional*) 77 | Whether to disable all progress bars. 78 | output_type (str, *optional*, defaults to `"pil"`): 79 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 80 | 81 | Returns: 82 | NOVAPipelineOutput: The pipeline output. 83 | """ 84 | self.guidance_scale = guidance_scale 85 | inputs = {"generator": generator, **locals()} 86 | num_patches = int(np.prod(self.transformer.config.image_base_size)) 87 | mask_ratios = np.cos(0.5 * np.pi * np.arange(num_inference_steps + 1) / num_inference_steps) 88 | mask_length = np.round(mask_ratios * num_patches).astype("int64") 89 | inputs["num_preds"] = mask_length[:-1] - mask_length[1:] 90 | inputs["tqdm1"], inputs["tqdm2"], inputs["latents"] = False, not disable_progress_bar, [] 91 | inputs["c"] = [self.encode_prompt(**dict(_ for _ in inputs.items() if "prompt" in _[0]))] 92 | inputs["batch_size"] = len(inputs["c"][0]) // (2 if guidance_scale > 1 else 1) 93 | _, outputs = inputs.pop("self"), self.transformer(inputs) 94 | if output_type != "latent": 95 | outputs["x"] = self.image_processor.decode_latents(self.vae, outputs["x"]) 96 | outputs["x"] = self.image_processor.postprocess(outputs["x"], output_type) 97 | return NOVAPipelineOutput(**{"images": outputs["x"]}) 98 | 99 | def encode_prompt( 100 | self, 101 | prompt, 102 | num_images_per_prompt=1, 103 | negative_prompt=None, 104 | ) -> torch.Tensor: 105 | """Encode class prompts. 106 | 107 | Args: 108 | prompt (int or List[int], *optional*): 109 | The prompt to be encoded. 110 | num_images_per_prompt (int, *optional*, defaults to 1): 111 | The number of images that should be generated per prompt. 112 | negative_prompt (int or List[int], *optional*): 113 | The prompt or prompts to guide what to not include in image generation. 114 | 115 | Returns: 116 | torch.Tensor: The prompt index. 117 | """ 118 | 119 | def select_or_pad(a, b, n=1): 120 | return [a or b] * n if isinstance(a or b, int) else (a or b) 121 | 122 | num_classes = self.transformer.label_embed.num_classes 123 | prompt = [prompt] if isinstance(prompt, int) else prompt 124 | negative_prompt = select_or_pad(negative_prompt, num_classes, len(prompt)) 125 | prompts = prompt + (negative_prompt if self.guidance_scale > 1 else []) 126 | c = self.transformer.label_embed(torch.as_tensor(prompts, device=self.device)) 127 | return c.repeat_interleave(num_images_per_prompt, dim=0) 128 | -------------------------------------------------------------------------------- /diffnext/pipelines/nova/pipeline_train_c2i.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ##############################################################################
15 | """NOVA C2I training pipeline."""
16 | 
17 | from typing import Dict
18 | 
19 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
20 | import torch
21 | 
22 | from diffnext import engine
23 | from diffnext.pipelines.builder import PIPELINES, build_diffusion_scheduler
24 | from diffnext.pipelines.nova.pipeline_utils import PipelineMixin
25 | 
26 | 
27 | @PIPELINES.register("nova_train_c2i")
28 | class NOVATrainC2IPipeline(DiffusionPipeline, PipelineMixin):
29 |     """Pipeline for training NOVA C2I models."""
30 | 
31 |     _optional_components = ["transformer", "scheduler", "vae"]
32 | 
33 |     def __init__(self, transformer=None, scheduler=None, vae=None, trust_remote_code=True):
34 |         super(NOVATrainC2IPipeline, self).__init__()
35 |         self.vae = self.register_module(vae, "vae")
36 |         self.transformer = self.register_module(transformer, "transformer")
37 |         self.scheduler = self.register_module(scheduler, "scheduler")
38 |         self.transformer.noise_scheduler = build_diffusion_scheduler(self.scheduler)
39 |         self.transformer.sample_scheduler, self.guidance_scale = self.scheduler, 5.0
40 | 
41 |     @property
42 |     def model(self) -> torch.nn.Module:
43 |         """Return the trainable model."""
44 |         return self.transformer
45 | 
46 |     def configure_model(self, loss_repeat=4, checkpointing=0, config=None) -> torch.nn.Module:
47 |         """Configure the trainable model."""
48 |         self.model.loss_repeat = config.TRAIN.LOSS_REPEAT if config else loss_repeat
49 |         ckpt_lvl = config.TRAIN.CHECKPOINTING if config else checkpointing
50 |         [setattr(blk, "mlp_checkpointing", ckpt_lvl) for blk in self.model.video_encoder.blocks]
51 |         [setattr(blk, "mlp_checkpointing", ckpt_lvl > 1) for blk in self.model.image_encoder.blocks]
52 |         [setattr(blk, "mlp_checkpointing", ckpt_lvl > 2) for blk in self.model.image_decoder.blocks]
53 |         engine.freeze_module(self.model.label_embed.norm)  # We always use frozen LN.
54 |         engine.freeze_module(self.model.video_pos_embed)  # Freeze this module during C2I.
55 |         engine.freeze_module(self.model.video_encoder.patch_embed)  # Freeze this module during C2I.
56 |         self.model.pipeline_preprocess = self.preprocess
57 |         return self.model.train()
58 | 
59 |     def prepare_latents(self, inputs: Dict):
60 |         """Prepare the video latents."""
61 |         if "images" in inputs:
62 |             raise NotImplementedError
63 |         elif "moments" in inputs:
64 |             x = torch.as_tensor(inputs.pop("moments"), device=self.device).to(dtype=self.dtype)
65 |             inputs["x"] = self.vae.scale_(self.vae.latent_dist(x).sample())
66 | 
67 |     def encode_prompt(self, inputs: Dict):
68 |         """Encode class prompts."""
69 |         prompts = torch.as_tensor(inputs.pop("prompt"), device=self.device)
70 |         inputs["c"] = [self.transformer.label_embed(prompts)]
71 | 
72 |     def preprocess(self, inputs: Dict) -> Dict:
73 |         """Define the pipeline preprocess at every call."""
74 |         if not self.model.training:
75 |             raise RuntimeError("Expected a trainable model.")
76 |         self.prepare_latents(inputs)
77 |         self.encode_prompt(inputs)
78 | 
--------------------------------------------------------------------------------
/diffnext/pipelines/nova/pipeline_train_t2i.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | """NOVA T2I training pipeline.""" 16 | 17 | from typing import Dict 18 | 19 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 20 | import torch 21 | 22 | from diffnext import engine 23 | from diffnext.pipelines.builder import PIPELINES, build_diffusion_scheduler 24 | from diffnext.pipelines.nova.pipeline_utils import PipelineMixin 25 | 26 | 27 | @PIPELINES.register("nova_train_t2i") 28 | class NOVATrainT2IPipeline(DiffusionPipeline, PipelineMixin): 29 | """Pipeline for training NOVA T2I models.""" 30 | 31 | _optional_components = ["transformer", "scheduler", "vae", "text_encoder", "tokenizer"] 32 | 33 | def __init__( 34 | self, 35 | transformer=None, 36 | scheduler=None, 37 | vae=None, 38 | text_encoder=None, 39 | tokenizer=None, 40 | trust_remote_code=True, 41 | ): 42 | super(NOVATrainT2IPipeline, self).__init__() 43 | self.vae = self.register_module(vae, "vae") 44 | self.text_encoder = self.register_module(text_encoder, "text_encoder") 45 | self.tokenizer = self.register_module(tokenizer, "tokenizer") 46 | self.transformer = self.register_module(transformer, "transformer") 47 | self.scheduler = self.register_module(scheduler, "scheduler") 48 | self.transformer.noise_scheduler = build_diffusion_scheduler(self.scheduler) 49 | self.transformer.sample_scheduler, self.guidance_scale = self.scheduler, 5.0 50 | 51 | @property 52 | def model(self) -> torch.nn.Module: 53 | """Return the trainable model.""" 54 | return self.transformer 55 | 56 | def configure_model(self, loss_repeat=4, checkpointing=0, config=None) -> torch.nn.Module: 57 | """Configure the trainable model.""" 58 | self.model.loss_repeat = config.TRAIN.LOSS_REPEAT if config else loss_repeat 59 | ckpt_lvl = config.TRAIN.CHECKPOINTING if config else checkpointing 60 | [setattr(blk, "mlp_checkpointing", ckpt_lvl) for blk in self.model.video_encoder.blocks] 61 | [setattr(blk, "mlp_checkpointing", ckpt_lvl > 1) for blk in self.model.image_encoder.blocks] 62 | [setattr(blk, "mlp_checkpointing", ckpt_lvl > 2) for blk in self.model.image_decoder.blocks] 63 | engine.freeze_module(self.model.text_embed.norm) # We always use frozen LN. 64 | engine.freeze_module(self.model.video_pos_embed) # Freeze this module during T2I. 65 | engine.freeze_module(self.model.video_encoder.patch_embed) # Freeze this module during T2I. 
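        # The motion embedding is a video-only component, so when present it is also
        # kept frozen for image-only (T2I) tuning; the encoder and decoder transformer
        # blocks themselves remain trainable.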
66 |         engine.freeze_module(self.model.motion_embed) if self.model.motion_embed else None
67 |         self.model.pipeline_preprocess = self.preprocess
68 |         self.model.text_embed.encoders = [self.tokenizer, self.text_encoder]
69 |         return self.model.train()
70 | 
71 |     def prepare_latents(self, inputs: Dict):
72 |         """Prepare the video latents."""
73 |         if "images" in inputs:
74 |             raise NotImplementedError
75 |         elif "moments" in inputs:
76 |             x = torch.as_tensor(inputs.pop("moments"), device=self.device).to(dtype=self.dtype)
77 |             inputs["x"] = self.vae.scale_(self.vae.latent_dist(x).sample())
78 | 
79 |     def encode_prompt(self, inputs: Dict):
80 |         """Encode text prompts."""
81 |         inputs["c"] = inputs.get("c", [])
82 |         if inputs.get("prompt", None) is not None and self.transformer.text_embed:
83 |             inputs["c"].append(self.transformer.text_embed(inputs.pop("prompt")))
84 | 
85 |     def preprocess(self, inputs: Dict) -> Dict:
86 |         """Define the pipeline preprocess at every call."""
87 |         if not self.model.training:
88 |             raise RuntimeError("Expected a trainable model.")
89 |         self.prepare_latents(inputs)
90 |         self.encode_prompt(inputs)
91 | 
--------------------------------------------------------------------------------
/diffnext/pipelines/nova/pipeline_train_t2v.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ############################################################################## 15 | """NOVA T2V training pipeline.""" 16 | 17 | from typing import Dict 18 | 19 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 20 | import torch 21 | 22 | from diffnext import engine 23 | from diffnext.pipelines.builder import PIPELINES, build_diffusion_scheduler 24 | from diffnext.pipelines.nova.pipeline_utils import PipelineMixin 25 | 26 | 27 | @PIPELINES.register("nova_train_t2v") 28 | class NOVATrainT2VPipeline(DiffusionPipeline, PipelineMixin): 29 | """Pipeline for training NOVA T2V models.""" 30 | 31 | _optional_components = ["transformer", "scheduler", "vae", "text_encoder", "tokenizer"] 32 | 33 | def __init__( 34 | self, 35 | transformer=None, 36 | scheduler=None, 37 | vae=None, 38 | text_encoder=None, 39 | tokenizer=None, 40 | trust_remote_code=True, 41 | ): 42 | super(NOVATrainT2VPipeline, self).__init__() 43 | self.vae = self.register_module(vae, "vae") 44 | self.text_encoder = self.register_module(text_encoder, "text_encoder") 45 | self.tokenizer = self.register_module(tokenizer, "tokenizer") 46 | self.transformer = self.register_module(transformer, "transformer") 47 | self.scheduler = self.register_module(scheduler, "scheduler") 48 | self.transformer.noise_scheduler = build_diffusion_scheduler(self.scheduler) 49 | self.transformer.sample_scheduler, self.guidance_scale = self.scheduler, 5.0 50 | 51 | @property 52 | def model(self) -> torch.nn.Module: 53 | """Return the trainable model.""" 54 | return self.transformer 55 | 56 | def configure_model(self, loss_repeat=4, checkpointing=0, config=None) -> torch.nn.Module: 57 | """Configure the trainable model.""" 58 | self.model.loss_repeat = config.TRAIN.LOSS_REPEAT if config else loss_repeat 59 | ckpt_lvl = config.TRAIN.CHECKPOINTING if config else checkpointing 60 | [setattr(blk, "mlp_checkpointing", ckpt_lvl) for blk in self.model.video_encoder.blocks] 61 | [setattr(blk, "mlp_checkpointing", ckpt_lvl > 1) for blk in self.model.image_encoder.blocks] 62 | [setattr(blk, "mlp_checkpointing", ckpt_lvl > 2) for blk in self.model.image_decoder.blocks] 63 | engine.freeze_module(self.model.text_embed.norm) # We always use frozen LN. 64 | engine.freeze_module(self.model.motion_embed) # We always use frozen motion embedding. 
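        # The hook registered below runs before every training step: ``preprocess``
        # turns cached VAE moments into latents and encodes text prompts into the
        # conditioning list consumed by the transformer.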
65 |         self.model.pipeline_preprocess = self.preprocess
66 |         self.model.text_embed.encoders = [self.tokenizer, self.text_encoder]
67 |         return self.model.train()
68 | 
69 |     def prepare_latents(self, inputs: Dict):
70 |         """Prepare the video latents."""
71 |         if "images" in inputs:
72 |             raise NotImplementedError
73 |         elif "moments" in inputs:
74 |             x = torch.as_tensor(inputs.pop("moments"), device=self.device).to(dtype=self.dtype)
75 |             inputs["x"] = self.vae.scale_(self.vae.latent_dist(x).sample())
76 | 
77 |     def encode_prompt(self, inputs: Dict):
78 |         """Encode text prompts."""
79 |         inputs["c"] = inputs.get("c", [])
80 |         if inputs.get("prompt", None) is not None and self.transformer.text_embed:
81 |             inputs["c"].append(self.transformer.text_embed(inputs.pop("prompt")))
82 | 
83 |     def preprocess(self, inputs: Dict) -> Dict:
84 |         """Define the pipeline preprocess at every call."""
85 |         if not self.model.training:
86 |             raise RuntimeError("Expected a trainable model.")
87 |         self.prepare_latents(inputs)
88 |         self.encode_prompt(inputs)
89 | 
--------------------------------------------------------------------------------
/diffnext/pipelines/nova/pipeline_utils.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Pipeline utilities."""
17 | 
18 | from typing import List, Union
19 | 
20 | from diffusers.utils import BaseOutput
21 | import numpy as np
22 | import PIL.Image
23 | import torch
24 | 
25 | 
26 | class NOVAPipelineOutput(BaseOutput):
27 |     """Output class for NOVA pipelines.
28 | 
29 |     Args:
30 |         images (List[PIL.Image.Image] or np.ndarray)
31 |             List of PIL images or numpy array of shape `(batch_size, height, width, num_channels)`.
32 |         frames (np.ndarray)
33 |             List of video frames. The array shape is `(batch_size, num_frames, height, width, num_channels)`
34 |     """  # noqa
35 | 
36 |     images: Union[List[PIL.Image.Image], np.ndarray]
37 |     frames: np.ndarray
38 | 
39 | 
40 | class PipelineMixin(object):
41 |     """Base class for diffusion pipeline."""
42 | 
43 |     def register_module(self, model_or_path, name) -> torch.nn.Module:
44 |         """Register pipeline component.
45 | 
46 |         Args:
47 |             model_or_path (str or torch.nn.Module):
48 |                 The model or path to model.
49 |             name (str):
50 |                 The module name.
51 | 
52 |         Returns:
53 |             torch.nn.Module: The registered module.
54 | 
55 |         """
56 |         model = model_or_path
57 |         if isinstance(model_or_path, str):
58 |             cls = self.__init__.__annotations__[name]
59 |             if hasattr(cls, "from_pretrained") and model_or_path:
60 |                 model = cls.from_pretrained(model_or_path, torch_dtype=self.dtype)
61 |                 model = model.to(self.device) if isinstance(model, torch.nn.Module) else model
62 |             model = cls() if isinstance(model, str) else model  # Fall back to a default instance.
63 |         self.register_to_config(**{name: model.__class__.__name__})
64 |         return model
65 | 
--------------------------------------------------------------------------------
/diffnext/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ------------------------------------------------------------------------
16 | """Schedulers."""
17 | 
--------------------------------------------------------------------------------
/diffnext/schedulers/scheduling_flow.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ############################################################################## 15 | """Simple implementation of Flow match scheduler.""" 16 | 17 | import dataclasses 18 | 19 | import numpy as np 20 | import torch 21 | 22 | from diffusers.configuration_utils import ConfigMixin, register_to_config 23 | from diffusers.models.modeling_outputs import BaseOutput 24 | from diffusers.schedulers.scheduling_utils import SchedulerMixin 25 | 26 | 27 | @dataclasses.dataclass 28 | class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput): 29 | """Output for scheduler's `step` function output.""" 30 | 31 | prev_sample: torch.FloatTensor 32 | 33 | 34 | class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): 35 | 36 | order = 1 37 | 38 | @register_to_config 39 | def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False): 40 | timesteps = np.arange(1, num_train_timesteps + 1, dtype="float32")[::-1] 41 | sigmas = timesteps / num_train_timesteps 42 | if not use_dynamic_shifting: 43 | sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) 44 | self.timesteps = torch.as_tensor(sigmas * num_train_timesteps) 45 | self.sigmas = torch.as_tensor(sigmas) 46 | self.sigma_min, self.sigma_max = float(sigmas[-1]), float(sigmas[0]) 47 | self.timestep = self.sigma = None # Training states. 48 | self._begin_index = self._step_index = None # Inference counters. 49 | 50 | @property 51 | def step_index(self): 52 | """The index counter for current timestep.""" 53 | return self._step_index 54 | 55 | @property 56 | def begin_index(self): 57 | """The index for the first timestep.""" 58 | return self._begin_index 59 | 60 | def _sigma_to_t(self, sigma): 61 | return sigma * self.config.num_train_timesteps 62 | 63 | def _init_step_index(self, timestep): 64 | if self.begin_index is None: 65 | self._step_index = self.index_for_timestep(timestep) 66 | else: 67 | self._step_index = self._begin_index 68 | 69 | def index_for_timestep(self, timestep, schedule_timesteps=None): 70 | if schedule_timesteps is None: 71 | schedule_timesteps = self.timesteps 72 | indices = (schedule_timesteps == timestep).nonzero() 73 | return indices[1 if len(indices) > 1 else 0].item() 74 | 75 | def sample_timesteps(self, size, device=None): 76 | dist = torch.normal(0, 1, size, device=device).sigmoid_() 77 | return dist.mul_(self.config.num_train_timesteps).to(dtype=torch.int64) 78 | 79 | def set_timesteps(self, num_inference_steps): 80 | """Sets the discrete timesteps used for the diffusion chain.""" 81 | self.num_inference_steps = num_inference_steps 82 | t_max, t_min = self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min) 83 | timesteps = np.linspace(t_max, t_min, num_inference_steps, dtype="float32") 84 | sigmas = timesteps / self.config.num_train_timesteps 85 | sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) 86 | self.sigmas = sigmas.tolist() + [0] 87 | self.timesteps = sigmas * self.config.num_train_timesteps 88 | self._begin_index = self._step_index = None 89 | 90 | def add_noise( 91 | self, 92 | original_samples: torch.Tensor, 93 | noise: torch.Tensor, 94 | timesteps: torch.Tensor, 95 | ): 96 | """Add forward noise to samples for training.""" 97 | dtype, device = original_samples.dtype, original_samples.device 98 | self.timestep = self.timesteps.to(device=device)[timesteps] 99 | self.sigma = self.sigmas.to(device=device, dtype=dtype)[timesteps] 100 | self.sigma = self.sigma.view(timesteps.shape + (1,) * (noise.dim() - timesteps.dim())) 101 | return self.sigma * noise + (1.0 - 
self.sigma) * original_samples 102 | 103 | def scale_noise(self, sample: torch.Tensor, timestep: float, noise: torch.Tensor): 104 | """Add forward noise to samples for inference.""" 105 | self._init_step_index(timestep) if self.step_index is None else None 106 | sigma = self.sigmas[self.step_index] 107 | return sigma * noise + (1.0 - sigma) * sample 108 | 109 | def step( 110 | self, 111 | model_output: torch.Tensor, 112 | timestep: float, 113 | sample: torch.FloatTensor, 114 | generator: torch.Generator = None, 115 | return_dict=True, 116 | ): 117 | self._init_step_index(timestep) if self.step_index is None else None 118 | dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index] 119 | prev_sample = model_output.mul(dt).add_(sample) 120 | self._step_index += 1 121 | if not return_dict: 122 | return (prev_sample,) 123 | return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) 124 | -------------------------------------------------------------------------------- /diffnext/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Utilities.""" 17 | 18 | from diffnext.utils.export_utils import export_to_image 19 | from diffnext.utils.export_utils import export_to_video 20 | from diffnext.utils.registry import Registry 21 | -------------------------------------------------------------------------------- /diffnext/utils/export_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ------------------------------------------------------------------------ 16 | """Export utilities.""" 17 | 18 | import tempfile 19 | 20 | try: 21 | import imageio 22 | except ImportError: 23 | imageio = None 24 | import PIL.Image 25 | 26 | 27 | def export_to_image(image, output_image_path=None, suffix=".webp", quality=100): 28 | """Export to image.""" 29 | if output_image_path is None: 30 | output_image_path = tempfile.NamedTemporaryFile(suffix=suffix).name 31 | if isinstance(image, PIL.Image.Image): 32 | image.save(output_image_path, quality=quality) 33 | else: 34 | PIL.Image.fromarray(image).save(output_image_path, quality=quality) 35 | return output_image_path 36 | 37 | 38 | def export_to_video(video_frames, output_video_path=None, fps=12): 39 | """Export to video.""" 40 | if output_video_path is None: 41 | output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name 42 | if imageio is None: 43 | raise ImportError("Failed to import library.") 44 | with imageio.get_writer(output_video_path, fps=fps) as writer: 45 | for frame in video_frames: 46 | writer.append_data(frame) 47 | return output_video_path 48 | -------------------------------------------------------------------------------- /diffnext/utils/logging.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Logging utilities.""" 17 | 18 | import inspect 19 | import logging as _logging 20 | import os 21 | import sys as _sys 22 | import threading 23 | 24 | 25 | _logger = None 26 | _logger_lock = threading.Lock() 27 | 28 | 29 | def get_logger(): 30 | global _logger 31 | # Use double-checked locking to avoid taking lock unnecessarily. 32 | if _logger: 33 | return _logger 34 | _logger_lock.acquire() 35 | try: 36 | if _logger: 37 | return _logger 38 | logger = _logging.getLogger("diffnext") 39 | logger.setLevel("INFO") 40 | logger.propagate = False 41 | logger._is_root = True 42 | if True: 43 | # Determine whether we are in an interactive environment. 44 | _interactive = False 45 | try: 46 | # This is only defined in interactive shells. 47 | if _sys.ps1: 48 | _interactive = True 49 | except AttributeError: 50 | # Even now, we may be in an interactive shell with `python -i`. 51 | _interactive = _sys.flags.interactive 52 | # If we are in an interactive environment (like Jupyter), set loglevel 53 | # to INFO and pipe the output to stdout. 54 | if _interactive: 55 | logger.setLevel("INFO") 56 | _logging_target = _sys.stdout 57 | else: 58 | _logging_target = _sys.stderr 59 | # Add the output handler. 
60 | _handler = _logging.StreamHandler(_logging_target) 61 | _handler.setFormatter(_logging.Formatter("%(levelname)s %(message)s")) 62 | logger.addHandler(_handler) 63 | _logger = logger 64 | return _logger 65 | finally: 66 | _logger_lock.release() 67 | 68 | 69 | def _detailed_msg(msg): 70 | file, lineno = inspect.stack()[:3][2][1:3] 71 | return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg) 72 | 73 | 74 | def log(level, msg, *args, **kwargs): 75 | get_logger().log(level, _detailed_msg(msg), *args, **kwargs) 76 | 77 | 78 | def debug(msg, *args, **kwargs): 79 | if is_root(): 80 | get_logger().debug(_detailed_msg(msg), *args, **kwargs) 81 | 82 | 83 | def error(msg, *args, **kwargs): 84 | get_logger().error(_detailed_msg(msg), *args, **kwargs) 85 | assert 0 86 | 87 | 88 | def fatal(msg, *args, **kwargs): 89 | get_logger().fatal(_detailed_msg(msg), *args, **kwargs) 90 | assert 0 91 | 92 | 93 | def info(msg, *args, **kwargs): 94 | if is_root(): 95 | get_logger().info(_detailed_msg(msg), *args, **kwargs) 96 | 97 | 98 | def warning(msg, *args, **kwargs): 99 | if is_root(): 100 | get_logger().warning(_detailed_msg(msg), *args, **kwargs) 101 | 102 | 103 | def get_verbosity(): 104 | """Return how much logging output will be produced.""" 105 | return get_logger().getEffectiveLevel() 106 | 107 | 108 | def set_verbosity(v): 109 | """Set the threshold for what messages will be logged.""" 110 | get_logger().setLevel(v) 111 | 112 | 113 | def set_formatter(fmt=None, datefmt=None): 114 | """Set the formatter.""" 115 | handler = _logging.StreamHandler(_sys.stderr) 116 | handler.setFormatter(_logging.Formatter(fmt, datefmt)) 117 | logger = get_logger() 118 | logger.removeHandler(logger.handlers[0]) 119 | logger.addHandler(handler) 120 | 121 | 122 | def set_root(is_root=True): 123 | """Set logger to the root.""" 124 | get_logger()._is_root = is_root 125 | 126 | 127 | def is_root(): 128 | """Return logger is the root.""" 129 | return get_logger()._is_root 130 | -------------------------------------------------------------------------------- /diffnext/utils/profiler/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Profiler utilities.""" 17 | 18 | from diffnext.utils.profiler.stats import SmoothedValue 19 | from diffnext.utils.profiler.timer import Timer 20 | from diffnext.utils.profiler.timer import get_progress 21 | -------------------------------------------------------------------------------- /diffnext/utils/profiler/stats.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Trackable statistics.""" 17 | 18 | import collections 19 | import numpy as np 20 | 21 | 22 | class SmoothedValue(object): 23 | """Track values and provide smoothed report.""" 24 | 25 | def __init__(self, window_size=None): 26 | self.deque = collections.deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | 30 | def update(self, value): 31 | self.deque.append(value) 32 | self.count += 1 33 | self.total += value 34 | 35 | def mean(self): 36 | return np.mean(self.deque) 37 | 38 | def median(self): 39 | return np.median(self.deque) 40 | 41 | def average(self): 42 | return self.total / self.count 43 | -------------------------------------------------------------------------------- /diffnext/utils/profiler/timer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, esither express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ------------------------------------------------------------------------ 16 | """Timing functions.""" 17 | 18 | import contextlib 19 | import datetime 20 | import time 21 | 22 | 23 | class Timer(object): 24 | """Simple timer.""" 25 | 26 | def __init__(self): 27 | self.total_time = 0.0 28 | self.calls = 0 29 | self.start_time = 0.0 30 | self.diff = 0.0 31 | self.average_time = 0.0 32 | 33 | def add_diff(self, diff, n=1, average=True): 34 | self.total_time += diff 35 | self.calls += n 36 | self.average_time = self.total_time / self.calls 37 | return self.average_time if average else self.diff 38 | 39 | @contextlib.contextmanager 40 | def tic_and_toc(self, n=1, average=True): 41 | try: 42 | yield self.tic() 43 | finally: 44 | self.toc(n, average) 45 | 46 | def tic(self): 47 | self.start_time = time.time() 48 | return self 49 | 50 | def toc(self, n=1, average=True): 51 | self.diff = time.time() - self.start_time 52 | return self.add_diff(self.diff, n, average) 53 | 54 | 55 | def get_progress(timer, step, max_steps): 56 | """Return the progress information.""" 57 | eta_seconds = timer.average_time * (max_steps - step) 58 | eta = str(datetime.timedelta(seconds=int(eta_seconds))) 59 | progress = (step + 1.0) / max_steps 60 | return "< PROGRESS: {:.2%} | SPEED: {:.3f}s / iter | ETA: {} >".format( 61 | progress, timer.average_time, eta 62 | ) 63 | -------------------------------------------------------------------------------- /diffnext/utils/registry.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Registry utilities.""" 17 | 18 | import collections 19 | import functools 20 | 21 | 22 | class Registry(object): 23 | """Registry class.""" 24 | 25 | def __init__(self, name): 26 | self.name = name 27 | self.registry = collections.OrderedDict() 28 | 29 | def has(self, key) -> bool: 30 | return key in self.registry 31 | 32 | def register(self, name, func=None, **kwargs): 33 | def decorated(inner_function): 34 | for key in name if isinstance(name, (tuple, list)) else [name]: 35 | self.registry[key] = functools.partial(inner_function, **kwargs) 36 | return inner_function 37 | 38 | if func is not None: 39 | return decorated(func) 40 | return decorated 41 | 42 | def get(self, name, default=None): 43 | if name is None: 44 | return None 45 | if not self.has(name): 46 | if default is not None: 47 | return default 48 | raise KeyError("`%s` is not registered in <%s>." 
% (name, self.name)) 49 | return self.registry[name] 50 | 51 | def try_get(self, name): 52 | if self.has(name): 53 | return self.get(name) 54 | return None 55 | -------------------------------------------------------------------------------- /diffnext/utils/tensorboard.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Tensorboard application.""" 17 | 18 | import time 19 | 20 | import numpy as np 21 | 22 | try: 23 | import tensorflow as tf 24 | except ImportError: 25 | tf = None 26 | 27 | 28 | class TensorBoard(object): 29 | """TensorBoard application.""" 30 | 31 | def __init__(self, log_dir=None): 32 | """Create a summary writer logging to log_dir.""" 33 | if tf is None: 34 | raise ImportError("Failed to import ``tensorflow`` package.") 35 | tf.config.set_visible_devices([], "GPU") 36 | if log_dir is None: 37 | log_dir = "./logs/" + time.strftime("%Y%m%d_%H%M%S", time.localtime(time.time())) 38 | self.writer = tf.summary.create_file_writer(log_dir) 39 | 40 | @staticmethod 41 | def is_available(): 42 | """Return if tensor board is available.""" 43 | return tf is not None 44 | 45 | def close(self): 46 | """Close board and apply all cached summaries.""" 47 | self.writer.close() 48 | 49 | def histogram_summary(self, tag, values, step, buckets=10): 50 | """Write a histogram of values.""" 51 | with self.writer.as_default(): 52 | tf.summary.histogram(tag, values, step, buckets=buckets) 53 | 54 | def image_summary(self, tag, images, step, order="BGR"): 55 | """Write a list of images.""" 56 | if isinstance(images, (tuple, list)): 57 | images = np.stack(images) 58 | if len(images.shape) != 4: 59 | raise ValueError("Images can not be packed to (N, H, W, C).") 60 | if order == "BGR": 61 | images = images[:, :, :, ::-1] 62 | with self.writer.as_default(): 63 | tf.summary.image(tag, images, step, max_outputs=images.shape[0]) 64 | 65 | def scalar_summary(self, tag, value, step): 66 | """Write a scalar.""" 67 | with self.writer.as_default(): 68 | tf.summary.scalar(tag, value, step) 69 | -------------------------------------------------------------------------------- /docs/environment.md: -------------------------------------------------------------------------------- 1 | # 1. Use Conda 2 | ```bash 3 | conda create python=3.10 -n nova 4 | conda activate nova 5 | pip install -r requirements.txt 6 | ``` 7 | 8 | # 2. Use Docker 9 | > **Note**: Coming soon. -------------------------------------------------------------------------------- /docs/evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluations 2 | 3 | ## GenEval 4 | 5 | ### 1. 
Generate prompt embeddings 6 | ```python 7 | import json, torch 8 | from transformers import CodeGenTokenizerFast 9 | from diffnext.models.text_encoders.phi import PhiEncoderModel 10 | 11 | model_path = "/path/to/nova-d48w1024-sdxl1024" 12 | device, dtype = torch.device("cuda", 0), torch.float16 13 | 14 | tokenizer = CodeGenTokenizerFast.from_pretrained(model_path + "/tokenizer") 15 | model = PhiEncoderModel.from_pretrained(model_path + "/text_encoder", torch_dtype=dtype) 16 | model = model.eval().to(device=device) 17 | 18 | coll_embeds = [[], []] 19 | for data in json.load(open("./evaluations/geneval/prompts.json")): 20 | for i, prompt in enumerate((data["prompt"], data["dense_prompt"])): 21 | input_ids = tokenizer(prompt, max_length=256, truncation=True).input_ids 22 | input_ids = torch.as_tensor(input_ids, device=device, dtype=torch.int64) 23 | with torch.no_grad(): 24 | coll_embeds[i].append(model(input_ids.unsqueeze_(0)).last_hidden_state[0].cpu()) 25 | torch.save({"prompts": coll_embeds[0]}, "./evaluations/geneval/prompts.pth") 26 | torch.save({"prompts": coll_embeds[1]}, "./evaluations/geneval/prompts_rewrite.pth") 27 | ``` 28 | 29 | ### 2. Sample prompt images 30 | ```bash 31 | python ./evaluations/geneval/sample.py \ 32 | --metadata ./evaluations/geneval/metadata.jsonl \ 33 | --prompt ./evaluations/geneval/prompts.pth \ 34 | --ckpt /path/to/nova-d48w1024-sdxl1024 \ 35 | --num_pred_steps 128 --guidance_scale 7 --prompt_size 16 --sample_size 4 \ 36 | --outdir ./evaluations/geneval/nova-d48w1024-sdxl1024-cfg7 37 | ``` 38 | 39 | ### 3. Evaluation 40 | Evaluate the sampled images in `./evaluations/geneval/nova-d48w1024-sdxl1024-cfg7`. 41 | 42 | Please refer to the [GenEval](https://github.com/djghosh13/geneval?tab=readme-ov-file#evaluation) evaluation guide. 43 | 44 | ## VBench 45 | 46 | ### 1. Generate prompt embeddings 47 | ```python 48 | import json, torch 49 | from transformers import CodeGenTokenizerFast 50 | from diffnext.models.text_encoders.phi import PhiEncoderModel 51 | 52 | model_path = "/path/to/nova-d48w1024-osp480" 53 | device, dtype = torch.device("cuda", 0), torch.float16 54 | 55 | tokenizer = CodeGenTokenizerFast.from_pretrained(model_path + "/tokenizer") 56 | model = PhiEncoderModel.from_pretrained(model_path + "/text_encoder", torch_dtype=dtype) 57 | model = model.eval().to(device=device) 58 | 59 | coll_embeds, tags, texts = [[], []], [], [] 60 | for data in json.load(open("./evaluations/vbench/prompts.json")): 61 | for i, prompt in enumerate((data["prompt"], data["dense_prompt"])): 62 | input_ids = tokenizer(prompt, max_length=256, truncation=True).input_ids 63 | input_ids = torch.as_tensor(input_ids, device=device, dtype=torch.int64) 64 | with torch.no_grad(): 65 | coll_embeds[i].append(model(input_ids.unsqueeze_(0)).last_hidden_state[0].cpu()) 66 | tags.append(data["tag"]), texts.append(data["prompt"]) 67 | torch.save({"prompts": coll_embeds[0], "tags": tags, "texts": texts}, "./evaluations/vbench/prompts.pth") 68 | torch.save({"prompts": coll_embeds[1], "tags": tags, "texts": texts}, "./evaluations/vbench/prompts_rewrite.pth") 69 | ``` 70 | 71 | ### 2. Sample prompt videos 72 | ```bash 73 | python ./evaluations/vbench/sample.py \ 74 | --prompt ./evaluations/vbench/prompts.pth \ 75 | --ckpt /path/to/nova-d48w1024-osp480 \ 76 | --num_pred_steps 128 --guidance_scale 7 --prompt_size 8 --sample_size 5 --max_latent_length 9 --flow 5 \ 77 | --outdir ./evaluations/vbench/nova-d48w1024-osp480-cfg7-flow5 78 | ``` 79 | 80 | ### 3. Evaluation
81 | Evaluate the sampled videos in `./evaluations/vbench/nova-d48w1024-osp480-cfg7-flow5`. 82 | 83 | Please refer to the [VBench](https://github.com/Vchitect/VBench?tab=readme-ov-file#evaluation-on-the-standard-prompt-suite-of-vbench) evaluation guide. 84 | -------------------------------------------------------------------------------- /docs/inference.md: -------------------------------------------------------------------------------- 1 | # 1. Gradio 2 | ```bash 3 | # For text-to-image demo 4 | python scripts/app_nova_t2i.py --model "BAAI/nova-d48w1024-sdxl1024" --device 0 5 | 6 | # For text-to-video demo 7 | python scripts/app_nova_t2v.py --model "BAAI/nova-d48w1024-osp480" --device 0 8 | ``` 9 | 
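# 2. Python API
The Gradio demos above are thin wrappers around `NOVAPipeline`, which can also be driven directly from Python. The snippet below is a minimal sketch assembled from the calls made in `scripts/app_nova_t2i.py` and `scripts/app_nova_t2v.py`: the checkpoints are the same ones used above, the sampling arguments simply mirror the demo defaults rather than tuned recommendations, and the prompts are illustrative.

```python
import torch

from diffnext.pipelines import NOVAPipeline
from diffnext.utils import export_to_image, export_to_video

device = torch.device("cuda", 0)
model_args = {"torch_dtype": torch.float16, "trust_remote_code": True}
generator = torch.Generator(device=device).manual_seed(1337)

# Text-to-image, using the same checkpoint as the T2I demo.
pipe = NOVAPipeline.from_pretrained("BAAI/nova-d48w1024-sdxl1024", **model_args).to(device)
images = pipe(
    prompt="a shiba inu wearing a beret and black turtleneck.",
    generator=generator,
    guidance_scale=5,
    num_inference_steps=64,
    num_diffusion_steps=25,
).images
image_file = export_to_image(images[0], quality=95)  # same export helper as the demo

# Text-to-video, using the same checkpoint as the T2V demo.
pipe = NOVAPipeline.from_pretrained("BAAI/nova-d48w1024-osp480", **model_args).to(device)
frames = pipe(
    prompt="Niagara Falls with colorful paint instead of water.",
    generator=generator,
    motion_flow=5,
    max_latent_length=9,
    guidance_scale=7,
    num_inference_steps=128,
    num_diffusion_steps=100,
).frames[0]
video_file = export_to_video(frames, fps=12)  # same export helper as the demo
```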
-------------------------------------------------------------------------------- /docs/model_zoo.md: -------------------------------------------------------------------------------- 1 | # 1. Text-to-Image 2 | | Model | Parameters | Resolution | Data | Weight | GenEval | DPGBench | 3 | |:-----------:|:----------:|:----------:|:----:|:---------------------------------------------------------------------:|:--------:|:-------:| 4 | | NOVA-0.6B | 0.6B | 512x512 | 16M | [🤗 HF link](https://huggingface.co/BAAI/nova-d48w1024-sd512) | 0.75 | 81.76 | 5 | | NOVA-0.3B | 0.3B | 1024x1024 | 600M | [🤗 HF link](https://huggingface.co/BAAI/nova-d48w768-sdxl1024) | 0.67 | 80.60 | 6 | | NOVA-0.6B | 0.6B | 1024x1024 | 600M | [🤗 HF link](https://huggingface.co/BAAI/nova-d48w1024-sdxl1024) | 0.69 | 82.25 | 7 | | NOVA-1.4B | 1.4B | 1024x1024 | 600M | [🤗 HF link](https://huggingface.co/BAAI/nova-d48w1536-sdxl1024) | 0.71 | 83.01 | 8 | 9 | 10 | # 2. Text-to-Video 11 | | Model | Parameters | Resolution | Data | Weight | VBench | 12 | |:-----------:|:-----------:|:----------:|:----:|-----------------------------------------------------------------------|:------:| 13 | | NOVA-0.6B | 0.6B | 33x768x480 | 20M | [🤗 HF link](https://huggingface.co/BAAI/nova-d48w1024-osp480) | 80.12 | -------------------------------------------------------------------------------- /docs/training.md: -------------------------------------------------------------------------------- 1 | # Training Guide 2 | This guide provides simple snippets to train diffnext models. 3 | 4 | # 1. Build VAE cache 5 | To optimize the training workflow, we preprocess images or videos into VAE latents. 6 | 7 | ## Requirements 8 | ```bash 9 | pip install protobuf==3.20.3 codewithgpu decord 10 | ``` 11 | 12 | ## Build T2I cache 13 | The following snippet can be used to cache image latents: 14 | 15 | ```python 16 | import os, codewithgpu, torch, PIL.Image, numpy as np 17 | from diffnext.models.autoencoders.autoencoder_kl import AutoencoderKL 18 | 19 | device, dtype = torch.device("cuda"), torch.float16 20 | vae = AutoencoderKL.from_pretrained("/path/to/nova-d48w1024-sdxl1024/vae") 21 | vae = vae.to(device=device, dtype=dtype).eval() 22 | 23 | features = {"moments": "bytes", "caption": "string", "text": "string", "shape": ["int64"]} 24 | _, writer = os.makedirs("./img_dataset"), codewithgpu.RecordWriter("./img_dataset", features) 25 | 26 | img = PIL.Image.open("./assets/sample_image.jpg") 27 | x = torch.as_tensor(np.array(img)[None, ...].transpose(0, 3, 1, 2)).to(device).to(dtype) 28 | with torch.no_grad(): 29 | x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.cpu().numpy()[0] 30 | example = {"caption": "long caption", "text": "short text"} 31 | writer.write({"shape": x.shape, "moments": x.tobytes(), **example}), writer.close() 32 | ``` 33 | 34 | ## Build T2V cache 35 | The following snippet can be used to cache video latents: 36 | 37 | ```python 38 | import os, codewithgpu, torch, decord, numpy as np 39 | from diffnext.models.autoencoders.autoencoder_kl_opensora import AutoencoderKLOpenSora 40 | 41 | device, dtype = torch.device("cuda"), torch.float16 42 | vae = AutoencoderKLOpenSora.from_pretrained("/path/to/nova-d48w1024-osp480/vae") 43 | vae = vae.to(device=device, dtype=dtype).eval() 44 | 45 | features = {"moments": "bytes", "caption": "string", "text": "string", "shape": ["int64"], "flow": "float64"} 46 | _, writer = os.makedirs("./vid_dataset"), codewithgpu.RecordWriter("./vid_dataset", features) 47 | 48 | resize, crop_size, frame_ids = 480, (480, 768), list(range(0, 65, 2)) 49 | vid = decord.VideoReader("./assets/sample_video.mp4") 50 | h, w = vid[0].shape[:2] 51 | scale = float(resize) / float(min(h, w)) 52 | size = int(h * scale + 0.5), int(w * scale + 0.5) 53 | y, x = (size[0] - crop_size[0]) // 2, (size[1] - crop_size[1]) // 2 54 | vid = decord.VideoReader("./assets/sample_video.mp4", height=size[0], width=size[1]) 55 | vid = vid.get_batch(frame_ids).asnumpy() 56 | vid = vid[:, y : y + crop_size[0], x : x + crop_size[1]] 57 | x = torch.as_tensor(vid[None, ...].transpose((0, 4, 1, 2, 3))).to(device).to(dtype) 58 | with torch.no_grad(): 59 | x = vae.encode(x.sub(127.5).div(127.5)).latent_dist.parameters.cpu().numpy()[0] 60 | example = {"caption": "long caption", "text": "short text", "flow": 5} 61 | writer.write({"shape": x.shape, "moments": x.tobytes(), **example}), writer.close() 62 | ``` 63 | 
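## Inspect a cached record
Both writers store the flattened VAE posterior parameters as raw `moments` bytes next to their `shape`, so a record can be decoded back into a latent sample for a quick sanity check. The sketch below assumes the cache was written in float16 as in the snippets above and that one record has already been read back as a Python dict (how you iterate the `codewithgpu` records is left out here); the helper name and the example shape are illustrative, not part of the training pipelines.

```python
import numpy as np
import torch


def decode_moments(example):
    """Rebuild a latent sample from a cached record dict."""
    # "moments" holds the raw bytes of `latent_dist.parameters`, "shape" its tensor shape.
    moments = np.frombuffer(example["moments"], dtype=np.float16).reshape(example["shape"])
    moments = torch.as_tensor(moments.copy()).float()
    # The channel dimension packs the posterior mean and log-variance halves.
    mean, logvar = moments.chunk(2, dim=0)
    return mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)


# Hypothetical image record, matching the features written above:
# example = {"moments": b"...", "shape": [8, 128, 128], "caption": "long caption", "text": "short text"}
# latent = decode_moments(example)
```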
64 | # 2. Train models 65 | 66 | ## Train T2I model 67 | The following snippet provides simple T2I training arguments: 68 | 69 | ```python 70 | from diffnext.config import cfg 71 | cfg.PIPELINE.TYPE = "nova_train_t2i" 72 | cfg.MODEL.WEIGHTS = "/path/to/nova-d48w1024-sdxl1024" 73 | cfg.TRAIN.DATASET = "./img_dataset" 74 | cfg.SOLVER.BASE_LR, cfg.SOLVER.MAX_STEPS = 1e-4, 100 75 | open("./nova_d48w1024_1024px.yml", "w").write(str(cfg)) 76 | ``` 77 | ```bash 78 | python scripts/train.py --cfg ./nova_d48w1024_1024px.yml 79 | ``` 80 | 81 | ## Train T2V model 82 | The following snippet provides simple T2V training arguments: 83 | 84 | ```python 85 | from diffnext.config import cfg 86 | cfg.PIPELINE.TYPE = "nova_train_t2v" 87 | cfg.MODEL.WEIGHTS = "/path/to/nova-d48w1024-osp480" 88 | cfg.TRAIN.DATASET = "./vid_dataset" 89 | cfg.SOLVER.BASE_LR, cfg.SOLVER.MAX_STEPS = 1e-4, 100 90 | open("./nova_d48w1024_480px.yml", "w").write(str(cfg)) 91 | ``` 92 | ```bash 93 | python scripts/train.py --cfg ./nova_d48w1024_480px.yml 94 | ``` 95 | 96 | ## Train with DeepSpeed 97 | ```bash 98 | python scripts/train.py --cfg ./nova_d48w1024_1024px.yml --deepspeed ./configs/deepspeed/zero2_bf16.json 99 | ``` 100 | 101 | The training script can also launch multi-node jobs using a *hostfile*. 102 | 103 | Argument usage: 104 | ```bash 105 | python scripts/train.py --host /path/to/my_hostfile 106 | ``` 107 | 108 | Requirements: 109 | 110 | - The total number of slots accumulated in the *hostfile* should be equal to ``cfg.NUM_GPUS``. 111 | - The launcher machine must be able to SSH to all host machines with *passwordless login*. 112 | 113 | See the [DeepSpeed documentation](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) for hostfile details. 114 | 115 | -------------------------------------------------------------------------------- /evaluations/geneval/sample.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ------------------------------------------------------------------------ 16 | """Sample GenEval images.""" 17 | 18 | import argparse 19 | import os 20 | import json 21 | 22 | import torch 23 | import PIL 24 | import tqdm 25 | 26 | from diffnext.pipelines.builder import build_pipeline, get_pipeline_path 27 | 28 | 29 | def parse_args(): 30 | """Parse arguments.""" 31 | parser = argparse.ArgumentParser(description="sample geneval images") 32 | parser.add_argument("--metadata", type=str, help="JSONL metadata") 33 | parser.add_argument("--ckpt", type=str, default=None, help="checkpoint file") 34 | parser.add_argument("--prompt", type=str, default="", help="prompt pth file") 35 | parser.add_argument("--num_pred_steps", type=int, default=128, help="inference steps") 36 | parser.add_argument("--num_diff_steps", type=int, default=25, help="diffusion steps") 37 | parser.add_argument("--guidance_scale", type=float, default=7, help="guidance scale") 38 | parser.add_argument("--prompt_size", type=int, default=16, help="prompt size for each batch") 39 | parser.add_argument("--sample_size", type=int, default=4, help="sample size for each prompt") 40 | parser.add_argument("--vae_batch_size", type=int, default=16, help="vae batch size") 41 | parser.add_argument("--outdir", type=str, default="", help="write to") 42 | return parser.parse_args() 43 | 44 | 45 | if __name__ == "__main__": 46 | args = parse_args() 47 | 48 | prompt_dict = torch.load(args.prompt, weights_only=False) 49 | 50 | num_pred_steps, num_diff_steps = args.num_pred_steps, args.num_diff_steps 51 | gen_args = {"num_inference_steps": num_pred_steps, "num_diffusion_steps": num_diff_steps} 52 | img_args = {"guidance_scale": args.guidance_scale, "output_type": "np"} 53 | img_args["vae_batch_size"] = args.vae_batch_size 54 | 55 | rank, world_size = 0, 1 56 | device = torch.device("cuda", rank) 57 | torch.cuda.set_device(device), torch.manual_seed(1337) 58 | generator = torch.Generator(device).manual_seed(1337) 59 | gen_args.update({"generator": generator, "disable_progress_bar": True}) 60 | is_root = device.index == 0 61 | 62 | pipe_path = get_pipeline_path(args.ckpt, {"text_encoder": ""}) 63 | pipe = build_pipeline(pipe_path, "nova", precison="float16").to(device=device) 64 | 65 | prompts = prompt_dict["prompts"] 66 | metadatas = [json.loads(v) for v in open(args.metadata)] 67 | os.makedirs(args.outdir, exist_ok=True) if is_root else None 68 | 69 | grids, prompt_inds = (args.prompt_size, args.sample_size), [] 70 | rank_prompt_inds = list(range(len(prompts)))[slice(rank, None, world_size)] 71 | 72 | for i, idx in enumerate(tqdm.tqdm(rank_prompt_inds, disable=not is_root)): 73 | prompt_inds.append(idx) 74 | if len(prompt_inds) != grids[0] and i != len(rank_prompt_inds) - 1: 75 | continue 76 | batch_prompts = sum([[prompts[i]] * grids[1] for i in prompt_inds], []) 77 | outputs = pipe(prompt_embeds=batch_prompts, **img_args, **gen_args) 78 | img = outputs["frames"][:, 0] if "frames" in outputs else outputs["images"] 79 | for i, idx in enumerate(prompt_inds): 80 | out_path = os.path.join(args.outdir, f"{idx:0>5}") 81 | sample_path = os.path.join(out_path, "samples") 82 | os.makedirs(out_path, exist_ok=True), os.makedirs(sample_path, exist_ok=True) 83 | json.dump(metadatas[idx], open(os.path.join(out_path, "metadata.jsonl"), "w")) 84 | for j in range(grids[1]): 85 | pil_img = PIL.Image.fromarray(img[i * grids[1] + j]) 86 | pil_img.save(os.path.join(sample_path, f"{j:05}.png")) 87 | prompt_inds = [] 88 | 
-------------------------------------------------------------------------------- /evaluations/vbench/sample.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Sample VBench videos.""" 17 | 18 | import argparse 19 | import collections 20 | import os 21 | 22 | import imageio 23 | import tqdm 24 | import torch 25 | 26 | from diffnext.pipelines.builder import build_pipeline, get_pipeline_path 27 | 28 | 29 | def parse_args(): 30 | """Parse arguments.""" 31 | parser = argparse.ArgumentParser(description="sample vbench videos") 32 | parser.add_argument("--ckpt", type=str, default=None, help="checkpoint file") 33 | parser.add_argument("--prompt", type=str, default="", help="prompt pth file") 34 | parser.add_argument("--max_latent_length", type=int, default=9, help="max latent length") 35 | parser.add_argument("--num_pred_steps", type=int, default=128, help="inference steps") 36 | parser.add_argument("--num_diff_steps", type=int, default=25, help="diffusion steps") 37 | parser.add_argument("--guidance_scale", type=float, default=7, help="guidance scale") 38 | parser.add_argument("--flow", default=5, type=float, help="motion flow") 39 | parser.add_argument("--prompt_size", type=int, default=8, help="prompt size for each batch") 40 | parser.add_argument("--sample_size", type=int, default=5, help="sample size for each prompt") 41 | parser.add_argument("--vae_batch_size", type=int, default=1, help="vae batch size") 42 | parser.add_argument("--outdir", type=str, default="", help="write to") 43 | return parser.parse_args() 44 | 45 | 46 | if __name__ == "__main__": 47 | args = parse_args() 48 | 49 | prompt_dict = torch.load(args.prompt, weights_only=False) 50 | 51 | num_pred_steps, num_diff_steps = args.num_pred_steps, args.num_diff_steps 52 | gen_args = {"num_inference_steps": num_pred_steps, "num_diffusion_steps": num_diff_steps} 53 | vid_args = {"guidance_scale": args.guidance_scale, "output_type": "np"} 54 | vid_args["vae_batch_size"] = args.vae_batch_size 55 | vid_args["max_latent_length"] = args.max_latent_length 56 | 57 | rank, world_size = 0, 1 58 | device = torch.device("cuda", rank) 59 | torch.cuda.set_device(device), torch.manual_seed(1337) 60 | generator = torch.Generator(device).manual_seed(1337) 61 | is_root = rank == 0 62 | gen_args.update({"generator": generator, "disable_progress_bar": True}) 63 | pipe_path = get_pipeline_path(args.ckpt, {"text_encoder": ""}) 64 | pipe = build_pipeline(pipe_path, "nova", precison="float16").to(device=device) 65 | 66 | grids, prompt_inds = (args.prompt_size, args.sample_size), [] 67 | prompts, tags, texts = prompt_dict["prompts"], prompt_dict["tags"], prompt_dict["texts"] 68 | rank_prompt_inds = 
list(range(len(prompts)))[slice(rank, None, world_size)] 69 | 70 | for i, idx in enumerate(tqdm.tqdm(rank_prompt_inds, disable=True)): 71 | prompt_inds.append(idx) 72 | if len(prompt_inds) != grids[0] and i != len(rank_prompt_inds) - 1: 73 | continue 74 | batch_names = sum( 75 | [[os.path.join(args.outdir, tags[i], texts[i])] * grids[1] for i in prompt_inds], [] 76 | ) 77 | batch_prompts = sum([[prompts[i]] * grids[1] for i in prompt_inds], []) 78 | outputs = pipe(prompt_embeds=batch_prompts, motion_flow=args.flow, **vid_args, **gen_args) 79 | batch_frames = outputs["frames"] 80 | name_cnt = collections.defaultdict(int) 81 | for j, frames in enumerate(batch_frames): 82 | name = batch_names[j].replace(".mp4", "-{}.mp4".format(name_cnt[batch_names[j]])) 83 | name_cnt[batch_names[j]] += 1 84 | with imageio.get_writer(name, fps=12, ffmpeg_log_level="error") as writer: 85 | [writer.append_data(frames[k]) for k in range(frames.shape[0])] 86 | prompt_inds = [] 87 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | target-version = ['py310'] 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | diffusers 3 | transformers 4 | accelerate 5 | imageio[ffmpeg] 6 | pyyaml 7 | scipy 8 | codewithgpu 9 | -------------------------------------------------------------------------------- /scripts/app_nova_t2i.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | """NOVA T2I application.""" 16 | 17 | import argparse 18 | import os 19 | 20 | import gradio as gr 21 | import numpy as np 22 | import torch 23 | 24 | from diffnext.pipelines import NOVAPipeline 25 | from diffnext.utils import export_to_image 26 | 27 | # Switch to the allocator optimized for dynamic shape. 
28 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 29 | 30 | 31 | def parse_args(): 32 | """Parse arguments.""" 33 | parser = argparse.ArgumentParser(description="Serve NOVA T2I application") 34 | parser.add_argument("--model", default="", help="model path") 35 | parser.add_argument("--device", type=int, default=0, help="device index") 36 | parser.add_argument("--precision", default="float16", help="compute precision") 37 | return parser.parse_args() 38 | 39 | 40 | def generate_image4( 41 | prompt, 42 | negative_prompt, 43 | seed, 44 | randomize_seed, 45 | guidance_scale, 46 | num_inference_steps, 47 | num_diffusion_steps, 48 | progress=gr.Progress(track_tqdm=True), 49 | ): 50 | """Generate 4 images.""" 51 | args = locals() 52 | seed = np.random.randint(2147483647) if randomize_seed else seed 53 | device = getattr(pipe, "_offload_device", pipe.device) 54 | generator = torch.Generator(device=device).manual_seed(seed) 55 | images = pipe(generator=generator, num_images_per_prompt=4, **args).images 56 | return [export_to_image(image, quality=95) for image in images] + [seed] 57 | 58 | 59 | css = """#col-container {margin: 0 auto; max-width: 1366px}""" 60 | title = "Autoregressive Video Generation without Vector Quantization" 61 | abbr = "NOn-quantized Video Autoregressive" 62 | header = ( 63 | "
" 64 | "

Autoregressive Video Generation without Vector Quantization" 65 | " [paper]" 66 | " [code]" 67 | 68 | ) 69 | header2 = f"🖼️ A {abbr} model for continuous visual generation" 70 | 71 | examples = [ 72 | "a selfie of an old man with a white beard.", 73 | "a woman with long hair next to a luminescent bird.", 74 | "a digital artwork of a cat styled in a whimsical fashion. The overall vibe is quirky and artistic.", # noqa 75 | "a shiba inu wearing a beret and black turtleneck.", 76 | "a beautiful Afghan woman with red hair and green eyes.", 77 | "beautiful fireworks in the sky with red, white and blue.", 78 | "A dragon perched majestically on a craggy, smoke-wreathed mountain.", 79 | "A photo of a llama wearing sunglasses standing on the deck of a spaceship with the Earth in the background.", # noqa 80 | "Two pandas in fluffy slippers and bathrobes, lazily munching on bamboo.", 81 | ] 82 | 83 | 84 | if __name__ == "__main__": 85 | args = parse_args() 86 | 87 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu", args.device) 88 | model_args = {"torch_dtype": getattr(torch, args.precision.lower()), "trust_remote_code": True} 89 | pipe = NOVAPipeline.from_pretrained(args.model, **model_args).to(device) 90 | 91 | # Main Application. 92 | app = gr.Blocks(css=css, theme="origin").__enter__() 93 | container = gr.Column(elem_id="col-container").__enter__() 94 | _, main_row = gr.Markdown(header), gr.Row().__enter__() 95 | 96 | # Input. 97 | input_col = gr.Column().__enter__() 98 | prompt = gr.Text( 99 | label="Prompt", 100 | placeholder="Describe the image you want to generate", 101 | value="a shiba inu wearing a beret and black turtleneck.", 102 | lines=5, 103 | ) 104 | negative_prompt = gr.Text( 105 | label="Negative Prompt", 106 | placeholder="Describe what you don't want in the image", 107 | value="low quality, deformed, distorted, disfigured, fused fingers, bad anatomy, weird hand", # noqa 108 | lines=5, 109 | ) 110 | # fmt: off 111 | adv_opt = gr.Accordion("Advanced Options", open=True).__enter__() 112 | seed = gr.Slider(label="Seed", maximum=2147483647, step=1, value=0) 113 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 114 | guidance_scale = gr.Slider(label="Guidance scale", minimum=1, maximum=10, step=0.1, value=5) 115 | with gr.Row(): 116 | num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=128, step=1, value=64) # noqa 117 | num_diffusion_steps = gr.Slider(label="Diffusion steps", minimum=1, maximum=50, step=1, value=25) # noqa 118 | adv_opt.__exit__() 119 | generate = gr.Button("Generate Image", variant="primary", size="lg") 120 | input_col.__exit__() 121 | # fmt: on 122 | 123 | # Results. 124 | result_col, _ = gr.Column().__enter__(), gr.Markdown(header2) 125 | with gr.Row(): 126 | result1 = gr.Image(label="Result1", show_label=False) 127 | result2 = gr.Image(label="Result2", show_label=False) 128 | with gr.Row(): 129 | result3 = gr.Image(label="Result3", show_label=False) 130 | result4 = gr.Image(label="Result4", show_label=False) 131 | result_col.__exit__(), main_row.__exit__() 132 | 133 | # Examples. 134 | with gr.Row(): 135 | gr.Examples(examples=examples, inputs=[prompt]) 136 | 137 | # Events. 
138 | container.__exit__() 139 | gr.on( 140 | triggers=[generate.click, prompt.submit, negative_prompt.submit], 141 | fn=generate_image4, 142 | inputs=[ 143 | prompt, 144 | negative_prompt, 145 | seed, 146 | randomize_seed, 147 | guidance_scale, 148 | num_inference_steps, 149 | num_diffusion_steps, 150 | ], 151 | outputs=[result1, result2, result3, result4, seed], 152 | ) 153 | app.__exit__(), app.launch(share=False) 154 | -------------------------------------------------------------------------------- /scripts/app_nova_t2v.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | ############################################################################## 15 | """NOVA T2V application.""" 16 | 17 | import argparse 18 | import os 19 | 20 | import gradio as gr 21 | import numpy as np 22 | import PIL.Image 23 | import torch 24 | 25 | from diffnext.pipelines import NOVAPipeline 26 | from diffnext.utils import export_to_video 27 | 28 | # Fix tokenizer fork issue. 29 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 30 | # Switch to the allocator optimized for dynamic shape. 31 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 32 | 33 | 34 | def parse_args(): 35 | """Parse arguments.""" 36 | parser = argparse.ArgumentParser(description="Serve NOVA T2V application") 37 | parser.add_argument("--model", default="", help="model path") 38 | parser.add_argument("--device", type=int, default=0, help="device index") 39 | parser.add_argument("--precision", default="float16", help="compute precision") 40 | return parser.parse_args() 41 | 42 | 43 | def crop_image(image, target_h, target_w): 44 | """Center crop image to target size.""" 45 | h, w = image.height, image.width 46 | aspect_ratio_target, aspect_ratio = target_w / target_h, w / h 47 | if aspect_ratio > aspect_ratio_target: 48 | new_w = int(h * aspect_ratio_target) 49 | x_start = (w - new_w) // 2 50 | image = image.crop((x_start, 0, x_start + new_w, h)) 51 | else: 52 | new_h = int(w / aspect_ratio_target) 53 | y_start = (h - new_h) // 2 54 | image = image.crop((0, y_start, w, y_start + new_h)) 55 | return np.array(image.resize((target_w, target_h), PIL.Image.Resampling.BILINEAR)) 56 | 57 | 58 | def generate_video( 59 | prompt, 60 | negative_prompt, 61 | image_prompt, 62 | motion_flow, 63 | preset, 64 | seed, 65 | randomize_seed, 66 | guidance_scale, 67 | num_inference_steps, 68 | num_diffusion_steps, 69 | progress=gr.Progress(track_tqdm=True), 70 | ): 71 | """Generate a video.""" 72 | args = locals() 73 | preset = [p for p in video_presets if p["label"] == preset][0] 74 | args["max_latent_length"] = preset["#latents"] 75 | args["image"] = crop_image(image_prompt, preset["h"], preset["w"]) if image_prompt else None 76 | seed = np.random.randint(2147483647) if randomize_seed else seed 77 | device = getattr(pipe, "_offload_device", pipe.device) 78 | generator 
= torch.Generator(device=device).manual_seed(seed) 79 | frames = pipe(generator=generator, **args).frames[0] 80 | return export_to_video(frames, fps=12), seed 81 | 82 | 83 | title = "Autoregressive Video Generation without Vector Quantization" 84 | abbr = "NOn-quantized Video Autoregressive" 85 | header = ( 86 | "
" 87 | "

Autoregressive Video Generation without Vector Quantization" 88 | " [paper]" 89 | " [code]" 90 | 91 | ) 92 | header2 = f"🎞️ A {abbr} model for continuous visual generation
" 93 | 94 | video_presets = [ 95 | {"label": "33x768x480", "w": 768, "h": 480, "#latents": 9}, 96 | {"label": "17x768x480", "w": 768, "h": 480, "#latents": 5}, 97 | {"label": "1x768x480", "w": 768, "h": 480, "#latents": 1}, 98 | ] 99 | 100 | 101 | prompts = [ 102 | "Niagara falls with colorful paint instead of water.", 103 | "Many spotted jellyfish pulsating under water. Their bodies are transparent and glowing in deep ocean.", # noqa 104 | "An intense close-up of a soldier’s face, covered in dirt and sweat, his eyes filled with determination as he surveys the battlefield.", # noqa 105 | "a close-up shot of a woman standing in a dimly lit room. she is wearing a traditional chinese outfit, which includes a red and gold dress with intricate designs and a matching headpiece. the woman has her hair styled in an updo, adorned with a gold accessory. her makeup is done in a way that accentuates her features, with red lipstick and dark eyeshadow. she is looking directly at the camera with a neutral expression. the room has a rustic feel, with wooden beams and a stone wall visible in the background. the lighting in the room is soft and warm, creating a contrast with the woman's vibrant attire. there are no texts or other objects in the video. the style of the video is a portrait, focusing on the woman and her attire.", # noqa 106 | "The camera slowly rotates around a massive stack of vintage televisions that are placed within a large New York museum gallery. Each of the televisions is showing a different program. There are 1950s sci-fi movies with their distinctive visuals, horror movies with their creepy scenes, news broadcasts with moving images and words, static on some screens, and a 1970s sitcom with its characteristic look. The televisions are of various sizes and designs, some with rounded edges and others with more angular shapes. The gallery is well-lit, with light falling on the stack of televisions and highlighting the different programs being shown. There are no people visible in the immediate vicinity, only the stack of televisions and the surrounding gallery space.", # noqa 107 | ] 108 | motion_flows = [5, 5, 5, 5, 5] 109 | videos = ["", "", "", "", ""] 110 | examples = [list(x) for x in zip(prompts, motion_flows)] 111 | 112 | 113 | if __name__ == "__main__": 114 | args = parse_args() 115 | 116 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu", args.device) 117 | model_args = {"torch_dtype": getattr(torch, args.precision.lower()), "trust_remote_code": True} 118 | pipe = NOVAPipeline.from_pretrained(args.model, **model_args).to(device) 119 | 120 | # Application. 121 | app = gr.Blocks(theme="origin").__enter__() 122 | container = gr.Column(elem_id="col-container").__enter__() 123 | _, main_row = gr.Markdown(header), gr.Row().__enter__() 124 | 125 | # Input. 
126 | input_col = gr.Column().__enter__() 127 | prompt = gr.Text( 128 | label="Prompt", 129 | placeholder="Describe the video you want to generate", 130 | value="Niagara falls with colorful paint instead of water.", 131 | lines=5, 132 | ) 133 | negative_prompt = gr.Text( 134 | label="Negative Prompt", 135 | placeholder="Describe what you don't want in the video", 136 | value="", 137 | lines=1, 138 | ) 139 | image_prompt = gr.Image(label="Image Prompt (Optional) ", type="pil") 140 | # fmt: off 141 | adv_opt = gr.Accordion("Advanced Options", open=False).__enter__() 142 | seed = gr.Slider(label="Seed", maximum=2147483647, step=1, value=0) 143 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 144 | guidance_scale = gr.Slider(label="Guidance scale", minimum=1, maximum=10.0, step=0.1, value=7.0) 145 | with gr.Row(): 146 | num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=128, step=1, value=128) # noqa 147 | num_diffusion_steps = gr.Slider(label="Diffusion steps", minimum=1, maximum=100, step=1, value=100) # noqa 148 | adv_opt.__exit__() 149 | generate = gr.Button("Generate Video", variant="primary", size="lg") 150 | input_col.__exit__() 151 | 152 | # Results. 153 | result_col, _ = gr.Column().__enter__(), gr.Markdown(header2) 154 | preset = gr.Dropdown([p["label"] for p in video_presets], label="Video Preset", value=video_presets[0]["label"]) # noqa 155 | motion_flow = gr.Slider(label="Motion Flow", minimum=1, maximum=10, step=1, value=5) 156 | result = gr.Video(label="Result", show_label=False, autoplay=True) 157 | result_col.__exit__(), main_row.__exit__() 158 | # fmt: on 159 | 160 | # Examples. 161 | with gr.Row(): 162 | gr.Examples(examples=examples, inputs=[prompt, motion_flow]) 163 | 164 | # Events. 165 | container.__exit__() 166 | gr.on( 167 | triggers=[generate.click, prompt.submit, negative_prompt.submit], 168 | fn=generate_video, 169 | inputs=[ 170 | prompt, 171 | negative_prompt, 172 | image_prompt, 173 | motion_flow, 174 | preset, 175 | seed, 176 | randomize_seed, 177 | guidance_scale, 178 | num_inference_steps, 179 | num_diffusion_steps, 180 | ], 181 | outputs=[result, seed], 182 | ) 183 | app.__exit__(), app.launch(share=False) 184 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ------------------------------------------------------------------------ 16 | """Train a diffnext model.""" 17 | 18 | import argparse 19 | import os 20 | import sys 21 | import subprocess 22 | 23 | from diffnext import engine 24 | from diffnext.config import cfg 25 | from diffnext.data import get_dataset_size 26 | from diffnext.utils import logging 27 | 28 | 29 | def parse_args(): 30 | """Parse arguments.""" 31 | parser = argparse.ArgumentParser(description="Train a diffnext model") 32 | parser.add_argument("--cfg", default=None, help="config file") 33 | parser.add_argument("--exp-dir", default=None, help="experiment dir") 34 | parser.add_argument("--tensorboard", action="store_true", help="write metrics to tensorboard") 35 | parser.add_argument("--distributed", action="store_true", help="spawn distributed processes") 36 | parser.add_argument("--host", default="", help="hostfile for distributed training") 37 | parser.add_argument("--deepspeed", type=str, default="", help="deepspeed config file") 38 | return parser.parse_args() 39 | 40 | 41 | def spawn_processes(args, coordinator): 42 | """Spawn distributed processes.""" 43 | if args.deepspeed: 44 | cmd = "deepspeed --no_local_rank " 45 | cmd += '-H {} --launcher_args="-N" '.format(args.host) if args.host else "" 46 | cmd += "--num_gpus {} ".format(cfg.NUM_GPUS) if not args.host else "" 47 | else: 48 | cmd = "torchrun --nproc_per_node {} ".format(cfg.NUM_GPUS) 49 | cmd += "{} --distributed".format(os.path.abspath(__file__)) 50 | cmd += " --cfg {}".format(os.path.abspath(args.cfg)) 51 | cmd += " --exp-dir {}".format(coordinator.exp_dir) 52 | cmd += " --tensorboard" if args.tensorboard else "" 53 | cmd += " --deepspeed {}".format(args.deepspeed) if args.deepspeed else "" 54 | return subprocess.call(cmd, shell=True), sys.exit() 55 | 56 | 57 | def main(args): 58 | """Main entry point.""" 59 | logging.info("Called with args:\n" + str(args)) 60 | coordinator = engine.Coordinator(args.cfg, args.exp_dir) 61 | checkpoint, start_iter = coordinator.get_checkpoint() 62 | cfg.MODEL.WEIGHTS = checkpoint or cfg.MODEL.WEIGHTS 63 | logging.info("Using config:\n" + str(cfg)) 64 | spawn_processes(args, coordinator) if cfg.NUM_GPUS > 1 else None 65 | engine.manual_seed(cfg.RNG_SEED, (cfg.GPU_ID, cfg.RNG_SEED)) 66 | dataset_size = get_dataset_size(cfg.TRAIN.DATASET) 67 | logging.info("Dataset({}): {} examples for training.".format(cfg.TRAIN.DATASET, dataset_size)) 68 | logging.info("Checkpoints will be saved to `{:s}`".format(coordinator.path_at("checkpoints"))) 69 | engine.run_train(coordinator, start_iter, enable_tensorboard=args.tensorboard) 70 | 71 | 72 | def main_distributed(args): 73 | """Main distributed entry point.""" 74 | coordinator = engine.Coordinator(args.cfg, exp_dir=args.exp_dir) 75 | coordinator.deepspeed = args.deepspeed 76 | checkpoint, start_iter = coordinator.get_checkpoint() 77 | cfg.MODEL.WEIGHTS = checkpoint or cfg.MODEL.WEIGHTS 78 | engine.create_ddp_group(cfg) 79 | engine.manual_seed(cfg.RNG_SEED, (cfg.GPU_ID, cfg.RNG_SEED + engine.get_ddp_rank())) 80 | dataset_size = get_dataset_size(cfg.TRAIN.DATASET) 81 | logging.info("Dataset({}): {} examples for training.".format(cfg.TRAIN.DATASET, dataset_size)) 82 | logging.info("Checkpoints will be saved to `{:s}`".format(coordinator.path_at("checkpoints"))) 83 | engine.run_train(coordinator, start_iter, enable_tensorboard=args.tensorboard) 84 | 85 | 86 | if __name__ == "__main__": 87 | args = parse_args() 88 | main_distributed(args) if args.distributed else main(args) 89 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2024-present, BAAI. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ------------------------------------------------------------------------ 16 | """Python setup script.""" 17 | 18 | import argparse 19 | import os 20 | import shutil 21 | import subprocess 22 | import sys 23 | 24 | import setuptools 25 | import setuptools.command.build_py 26 | import setuptools.command.install 27 | 28 | 29 | def parse_args(): 30 | """Parse arguments.""" 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--version", default=None) 33 | args, unknown = parser.parse_known_args() 34 | sys.argv = [sys.argv[0]] + unknown 35 | args.git_version = None 36 | args.long_description = "" 37 | if args.version is None and os.path.exists("version.txt"): 38 | with open("version.txt", "r") as f: 39 | args.version = f.read().strip() 40 | if os.path.exists(".git"): 41 | try: 42 | git_version = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd="./") 43 | args.git_version = git_version.decode("ascii").strip() 44 | except (OSError, subprocess.CalledProcessError): 45 | pass 46 | if os.path.exists("README.md"): 47 | with open(os.path.join("README.md"), encoding="utf-8") as f: 48 | args.long_description = f.read() 49 | return args 50 | 51 | 52 | def clean_builds(): 53 | for path in ["build", "diffnext.egg-info"]: 54 | if os.path.exists(path): 55 | shutil.rmtree(path) 56 | 57 | 58 | def find_packages(top): 59 | """Return the python sources installed to package.""" 60 | packages = [] 61 | for root, _, _ in os.walk(top): 62 | if os.path.exists(os.path.join(root, "__init__.py")): 63 | packages.append(root) 64 | return packages 65 | 66 | 67 | def find_package_data(): 68 | """Return the external data installed to package.""" 69 | return [] 70 | 71 | 72 | class BuildPyCommand(setuptools.command.build_py.build_py): 73 | """Enhanced 'build_py' command.""" 74 | 75 | def build_packages(self): 76 | with open("diffnext/version.py", "w") as f: 77 | f.write( 78 | 'version = "{}"\n' 79 | 'git_version = "{}"\n' 80 | "__version__ = version\n".format(args.version, args.git_version) 81 | ) 82 | super(BuildPyCommand, self).build_packages() 83 | 84 | def build_package_data(self): 85 | self.package_data = {"diffnext": find_package_data()} 86 | super(BuildPyCommand, self).build_package_data() 87 | 88 | 89 | class InstallCommand(setuptools.command.install.install): 90 | """Enhanced 'install' command.""" 91 | 92 | def initialize_options(self): 93 | super(InstallCommand, self).initialize_options() 94 | self.old_and_unmanageable = True 95 | 96 | 97 | args = parse_args() 98 | setuptools.setup( 99 | name="diffnext", 100 | version=args.version, 101 | description="A diffusers based library for autoregressive 
diffusion models.", 102 | long_description=args.long_description, 103 | long_description_content_type="text/markdown", 104 | url="https://github.com/baaivision/NOVA", 105 | author="BAAI", 106 | license="Apache License", 107 | packages=find_packages("diffnext"), 108 | cmdclass={"build_py": BuildPyCommand, "install": InstallCommand}, 109 | install_requires=[ 110 | "torch", 111 | "diffusers", 112 | "transformers", 113 | "accelerate", 114 | "imageio[ffmpeg]", 115 | "pyyaml", 116 | "scipy", 117 | ], 118 | classifiers=[ 119 | "Development Status :: 5 - Production/Stable", 120 | "Intended Audience :: Developers", 121 | "Intended Audience :: Education", 122 | "Intended Audience :: Science/Research", 123 | "License :: OSI Approved :: Apache Software License", 124 | "Programming Language :: Python :: 3", 125 | "Programming Language :: Python :: 3 :: Only", 126 | "Topic :: Scientific/Engineering", 127 | "Topic :: Scientific/Engineering :: Mathematics", 128 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 129 | ], 130 | ) 131 | clean_builds() 132 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.1.0a0 2 | --------------------------------------------------------------------------------