├── README.md ├── assets └── teaser │ ├── teaser.png │ ├── video_editing.gif │ └── video_generation.gif ├── callbacks ├── __init__.py └── log_master.py ├── configs ├── _meta_ │ └── inference.yaml ├── callbacks │ ├── early_stopping.yaml │ ├── log_master.yaml │ ├── lr_monitor.yaml │ └── model_checkpoint.yaml ├── data │ ├── default.yaml │ ├── train │ │ └── pexels300k.yaml │ └── val │ │ └── pexels300k.yaml ├── default.yaml ├── logdir │ └── default.yaml ├── loggers │ └── tensorboard.yaml ├── model │ └── sdvideo.yaml └── trainer │ ├── default.yaml │ └── distributed.yaml ├── data ├── __init__.py ├── constants.py ├── data_downloader.py ├── dataloader.py ├── datasets.py ├── registry.py └── utils.py ├── inference.sh ├── main.py ├── model ├── __init__.py ├── model.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── attention_hasFFN.py │ ├── modules.py │ ├── resnet.py │ ├── uent.py │ ├── unet_blocks.py │ ├── utils.py │ └── zero_snr_ddpm.py ├── pipeline.py └── utils.py ├── requirements.txt ├── scripts └── launcher.py └── utils ├── __init__.py ├── dist.py ├── file_ops.py ├── logger.py ├── registry.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 |

# MagDiff: Multi-Alignment Diffusion for High-Fidelity Video Generation and Editing

Haoyu Zhao · Tianyi Lu · Jiaxi Gu · Xing Zhang · Qingping Zheng · Zuxuan Wu · Hang Xu · Yu-Gang Jiang

Paper PDF · Project Page

Fudan University | Huawei Noah's Ark Lab

(Teaser media: assets/teaser/teaser.png, assets/teaser/video_generation.gif, assets/teaser/video_editing.gif)
40 | 41 | ## 📢 News 42 | * **[2024.12.22]** Released the inference code. We are working to improve MagDiff, stay tuned! 43 | * **[2024.07.04]** Our paper has been accepted by the 18th European Conference on Computer Vision (ECCV) 2024. 44 | * **[2023.11.29]** Released the first version of the paper on arXiv. 45 | 46 | ## 🏃‍♂️ Getting Started 47 | Download the pretrained [Stable Diffusion 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) base model. 48 | 49 | Download our MagDiff [checkpoints](https://huggingface.co/). 50 | 51 | Please follow the Hugging Face download instructions to download the models and checkpoints above. 52 | 53 | Below is an example structure of these model files. 54 | 55 | ``` 56 | assets/ 57 | ├── MagDiff.pth 58 | └── stable-diffusion-2-1-base/ 59 | ├── scheduler/... 60 | ├── text_encoder/... 61 | ├── tokenizer/... 62 | ├── unet/... 63 | ├── vae/... 64 | ├── ... 65 | └── README.md 66 | ``` 67 | 68 | ## ⚒️ Installation 69 | Prerequisites: `python>=3.10`, `CUDA>=11.8`. 70 | 71 | Install with `pip`: 72 | ```bash 73 | pip3 install -r requirements.txt 74 | ``` 75 | 76 | ## 💃 Inference 77 | Run inference on a single GPU: 78 | ```bash 79 | bash inference.sh 80 | ``` 81 | 82 | ## 🎓 Citation 83 | If you find this codebase useful for your research, please cite it using the following BibTeX entry. 84 | ```BibTeX 85 | @inproceedings{zhao2024magdiff, 86 | author = {Zhao, Haoyu and Lu, Tianyi and Gu, Jiaxi and Zhang, Xing and Zheng, Qingping and Wu, Zuxuan and Xu, Hang and Jiang, Yu-Gang}, 87 | title = {MagDiff: Multi-Alignment Diffusion for High-Fidelity Video Generation and Editing}, 88 | booktitle = {European Conference on Computer Vision}, 89 | year = {2024} 90 | } 91 | ``` -------------------------------------------------------------------------------- /assets/teaser/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/videoassembler/1fb50cd5f85aa5896b85003aff3641b00946eddc/assets/teaser/teaser.png -------------------------------------------------------------------------------- /assets/teaser/video_editing.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/videoassembler/1fb50cd5f85aa5896b85003aff3641b00946eddc/assets/teaser/video_editing.gif -------------------------------------------------------------------------------- /assets/teaser/video_generation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/videoassembler/1fb50cd5f85aa5896b85003aff3641b00946eddc/assets/teaser/video_generation.gif -------------------------------------------------------------------------------- /callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .log_master import LogMaster -------------------------------------------------------------------------------- /callbacks/log_master.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import logging 4 | import os 5 | import time 6 | from copy import deepcopy 7 | from pathlib import Path 8 | 9 | from pytorch_lightning.callbacks.callback import Callback 10 | from pytorch_lightning.loggers import TensorBoardLogger 11 | from pytorch_lightning.utilities import rank_zero_only 12 | 13 | from utils import ops 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class LogMaster(Callback): 19 |
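# LogMaster is the central logging callback: it prints step/ETA summaries during training, summarizes validation metrics, tracks the best checkpoint according to `monitor`/`monitor_fn`, and optionally uploads logs, checkpoints and generated samples to `remote_dir`.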
def __init__( 20 | self, name, log_file, monitor="val_loss", monitor_fn="min", 21 | save_ckpt=True, remote_dir=None 22 | ): 23 | super().__init__() 24 | # logger's file handler is not set here and should be set outside 25 | # because rank is unknown when this function is called in ddp. 26 | self.name = name 27 | self.log_file = log_file 28 | self.monitor = monitor 29 | self.monitor_fn = eval(monitor_fn) 30 | self.save_ckpt = save_ckpt 31 | self.remote_dir = os.path.join(remote_dir, self.name) \ 32 | if remote_dir else remote_dir 33 | self.model_perf = {} 34 | self.t_train_start, self.t_batch_start = .0, .0 35 | 36 | def on_train_start(self, trainer, pl_module) -> None: 37 | # log train configs 38 | logger.info( 39 | f"WORLD_SIZE={trainer.world_size}; " 40 | f"NUM_NODES={trainer.num_nodes}; " 41 | f"GPUS_PER_NODE={trainer.num_devices}; " 42 | f"STEPS_PER_EPOCH={trainer.num_training_batches}." 43 | ) 44 | # upload training task config 45 | self.upload_to_cloud("config.yaml") 46 | # upload log file 47 | self.upload_to_cloud(self.log_file) 48 | self.t_train_start = time.monotonic() 49 | 50 | def on_train_batch_start(self, trainer, module, batch, batch_idx, unused=0): 51 | if batch_idx % trainer.log_every_n_steps == 1: 52 | self.t_batch_start = time.monotonic() 53 | 54 | def on_train_batch_end( 55 | self, trainer, module, outputs, batch, batch_idx, unused=0) -> None: 56 | if batch_idx % trainer.log_every_n_steps == 1: 57 | metrics = deepcopy(trainer.callback_metrics) 58 | msg = "; ".join([f"{k}={v.item():.4f}" for k, v in metrics.items() 59 | if k.startswith("train")]) 60 | zfill_batch = len(str(trainer.estimated_stepping_batches)) 61 | time_elapsed = datetime.timedelta( 62 | seconds=int(time.monotonic() - self.t_train_start)) 63 | time_remained = datetime.timedelta( 64 | seconds=int( 65 | (time.monotonic() - self.t_batch_start) * 66 | (trainer.estimated_stepping_batches - trainer.global_step) 67 | ) 68 | ) 69 | time_info = f"{time_elapsed} < {time_remained}" 70 | logger.info("[Steps {}/{}]: {} (Time: {})".format( 71 | str(trainer.global_step).zfill(zfill_batch), 72 | str(trainer.estimated_stepping_batches).zfill(zfill_batch), 73 | msg, time_info 74 | )) 75 | # upload log files 76 | self.upload_to_cloud(self.log_file) 77 | self.upload_to_cloud(self.get_tblog_dir(trainer.loggers)) 78 | 79 | def on_validation_end(self, trainer, pl_module): 80 | # exit() 81 | zfill_batch = len(str(trainer.estimated_stepping_batches)) 82 | metrics = deepcopy(trainer.callback_metrics) 83 | monitor_metric = metrics.get(self.monitor) 84 | if trainer.global_step > 200: 85 | if trainer.global_step % 5000 == 0: 86 | ckpt_name = f"steps_{str(trainer.global_step).zfill(zfill_batch)}.pth" 87 | self.model_perf.setdefault(ckpt_name, (ckpt_name, monitor_metric)) 88 | if self.save_ckpt: 89 | trainer.save_checkpoint(ckpt_name, weights_only=True) 90 | return 91 | if trainer.sanity_checking: 92 | return 93 | metrics = deepcopy(trainer.callback_metrics) 94 | metrics = {k: v.item() for k, v in metrics.items() if k.startswith("val")} 95 | if len(metrics) < 1: 96 | logger.warning("There are no metrics for validation!") 97 | zfill_batch = len(str(trainer.estimated_stepping_batches)) 98 | msg = "; ".join([f"{k}={v:.6f}" for k, v in metrics.items()]) 99 | logger.info("[Evaluation] [Steps={}/{}]: {}".format( 100 | str(trainer.global_step).zfill(zfill_batch), 101 | str(trainer.estimated_stepping_batches).zfill(zfill_batch), msg 102 | )) 103 | # upload log file 104 | if self.monitor not in metrics: 105 | raise KeyError(f"Metric 
`{self.monitor}` not in callback metrics: {metrics.keys()}") 106 | monitor_metric = metrics.get(self.monitor) 107 | self.upload_to_cloud(self.log_file) 108 | self.upload_to_cloud(self.get_tblog_dir(trainer.loggers)) 109 | # compute model performance 110 | ckpt_name = f"steps_{str(trainer.global_step).zfill(zfill_batch)}.pth" 111 | self.model_perf.setdefault(ckpt_name, (ckpt_name, monitor_metric)) 112 | # update best model information 113 | best_model = self.monitor_fn(self.model_perf, key=lambda x: self.model_perf.get(x)[-1]) 114 | self.model_perf["best_model"] = ( 115 | self.model_perf[best_model][0], self.model_perf[best_model][-1]) 116 | perf_file = f"performances_{self.monitor}.json" 117 | with open(perf_file, "w") as f: 118 | json.dump(self.model_perf, f, indent=2) 119 | self.upload_to_cloud(perf_file) 120 | # save to checkpoint and upload to remote server 121 | if self.save_ckpt: 122 | trainer.save_checkpoint(ckpt_name, weights_only=True) 123 | if self.upload_to_cloud(ckpt_name): 124 | Path(ckpt_name).unlink() 125 | if self.remote_dir: 126 | logger.info(f"Detailed results are saved in: {self.remote_dir}") 127 | # upload output samples 128 | samples_dir = Path(f"samples_s{str(trainer.global_step).zfill(zfill_batch)}") 129 | self.upload_to_cloud(samples_dir.as_posix()) 130 | 131 | @staticmethod 132 | def get_tblog_dir(loggers): 133 | tb_loggers = [x for x in loggers if isinstance(x, TensorBoardLogger)] 134 | return tb_loggers[0].log_dir if tb_loggers else "" 135 | 136 | @rank_zero_only 137 | def upload_to_cloud(self, local_file): 138 | if ops.mox_valid and self.remote_dir and os.path.exists(local_file): 139 | remote_path = os.path.join( 140 | self.remote_dir, os.path.basename(os.path.normpath(local_file))) 141 | ops.copy(local_file, remote_path) 142 | return remote_path 143 | else: 144 | return None 145 | -------------------------------------------------------------------------------- /configs/_meta_/inference.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /default 5 | - override /hydra/job_logging: default 6 | - override /data/train: 7 | - webvid5s 8 | - override /data/val: 9 | - webvid5s 10 | - _self_ 11 | 12 | job_name: "magdiff_inference" 13 | output_dir: "_.outputs_20241206" 14 | 15 | callbacks: 16 | log_master: 17 | remote_dir: "" 18 | save_ckpt: true 19 | 20 | loggers: 21 | tensorboard: 22 | log_graph: True 23 | name: "tensorboard_save_dir" 24 | 25 | trainer: 26 | max_steps: 1 27 | log_every_n_steps: 1 28 | val_check_interval: 1 29 | num_sanity_val_steps: 1 30 | enable_progress_bar: True 31 | max_epochs: 1 32 | 33 | evaluator: pl_validate 34 | 35 | model: 36 | pretrained_model_path: "/home/user/model/stable-diffusion-2-1-base/" 37 | ckpt_path: "/home/user/model/magdiff.pth" 38 | lr: 0.00005 39 | scheduler_name: constant_with_warmup 40 | warmup_steps: 100 41 | num_inference_steps: 50 # debug 42 | # classifier-free guidance 43 | null_text_ratio: 0.15 44 | guidance_scale: 7.5 45 | # model component variants 46 | add_temp_embed: true 47 | prepend_first_frame: false 48 | add_temp_transformer: false 49 | add_temp_conv: true 50 | # trainable module 51 | freeze_text_encoder: true 52 | trainable_modules: 53 | - "temp_" 54 | - "transformer_blocks\\.\\d+" 55 | - "conv_in" 56 | # added model config 57 | load_pretrained_conv_in: True 58 | enable_xformers: False 59 | resolution: 256 60 | in_channels: 8 61 | add_entity_vae: True 62 | add_entity_clip: True 63 | 64 | data: 65 | batch_size_train: 8 66 | 
batch_size_val: 2 67 | resolution: 256 68 | sample_rate: 1 69 | train: 70 | webvid5s: 71 | data_dir: /home/user/data/magdiff/videos 72 | csv_subdir: annotations_76k.jsonl 73 | val: 74 | webvid5s: 75 | data_dir: /home/user/data/magdiff/val_videos 76 | csv_subdir: annotations_val.jsonl 77 | -------------------------------------------------------------------------------- /configs/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | early_stopping: 2 | _target_: pytorch_lightning.callbacks.EarlyStopping 3 | monitor: loss 4 | mode: min 5 | patience: 20 -------------------------------------------------------------------------------- /configs/callbacks/log_master.yaml: -------------------------------------------------------------------------------- 1 | log_master: 2 | _target_: callbacks.LogMaster 3 | name: ${job_name} 4 | log_file: mlog_${hydra:job.name}.log 5 | remote_dir: null 6 | monitor: val_clip_score 7 | monitor_fn: max -------------------------------------------------------------------------------- /configs/callbacks/lr_monitor.yaml: -------------------------------------------------------------------------------- 1 | lr_monitor: 2 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 3 | logging_interval: step -------------------------------------------------------------------------------- /configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | model_checkpoint: 2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 3 | dirpath: ckpt 4 | every_n_train_steps: 10000 5 | filename: "model_e{epoch:03d}s{step:08d}" 6 | save_last: true 7 | save_weights_only: false -------------------------------------------------------------------------------- /configs/data/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: data.SDVideoDataModule 2 | 3 | defaults: 4 | - train: null 5 | - train_alt: null 6 | - val: null 7 | - _self_ 8 | 9 | batch_size_train: 1 10 | batch_size_val: 1 11 | 12 | num_frames: 8 13 | resolution: 512 14 | 15 | num_workers: 8 -------------------------------------------------------------------------------- /configs/data/train/pexels300k.yaml: -------------------------------------------------------------------------------- 1 | pexels300k: 2 | type: PexelsDataset 3 | data_dir: s3://bucket-9329/gujiaxi/data/Pexels300k/ 4 | csv_subdir: csv 5 | tar_subdir: tar 6 | 7 | num_frames: ${data.num_frames} 8 | resolution: ${data.resolution} 9 | -------------------------------------------------------------------------------- /configs/data/val/pexels300k.yaml: -------------------------------------------------------------------------------- 1 | pexels300k: 2 | type: PexelsDataset 3 | data_dir: s3://bucket-9329/gujiaxi/data/Pexels300k/ 4 | csv_subdir: csv 5 | tar_subdir: tar 6 | 7 | num_frames: ${data.num_frames} 8 | resolution: ${data.resolution} 9 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - data: default 5 | - model: sdvideo 6 | - callbacks: 7 | - lr_monitor 8 | - log_master 9 | - loggers: 10 | - tensorboard 11 | - trainer: distributed 12 | - logdir: default 13 | - override hydra/hydra_logging: disabled 14 | - override hydra/job_logging: disabled 15 | - _self_ 16 | 17 | seed: 42 18 | 19 | job_name: "RUN_1" 20 | 21 | 
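# Hydra composes the run directory as ${output_dir}/${job_name} (see configs/logdir/default.yaml).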
output_dir: "output" -------------------------------------------------------------------------------- /configs/logdir/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | output_subdir: null 5 | job: 6 | name: ${job_name}_${now:%Y%m%d}_${now:%H%M%S} 7 | chdir: true 8 | env_set: 9 | TOKENIZERS_PARALLELISM: false 10 | run: 11 | dir: ${output_dir}/${job_name} 12 | sweep: 13 | dir: ${output_dir}/${job_name} -------------------------------------------------------------------------------- /configs/loggers/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | tensorboard: 2 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 3 | save_dir: "." 4 | name: null 5 | version: tblog_${hydra:job.name} 6 | log_graph: False 7 | default_hp_metric: True 8 | prefix: "" 9 | -------------------------------------------------------------------------------- /configs/model/sdvideo.yaml: -------------------------------------------------------------------------------- 1 | _target_: model.SDVideoModel 2 | 3 | pretrained_model_path: null 4 | ckpt_path: null 5 | 6 | # train 7 | null_text_ratio: 0.1 8 | lr: 0.001 9 | weight_decay: 0.01 10 | warmup_steps: 10 11 | 12 | # inference 13 | num_inference_steps: 50 14 | guidance_scale: 7.5 15 | 16 | # input resolution 17 | resolution: ${data.resolution} 18 | 19 | # memory saving techniques 20 | enable_gradient_checkpointing: true 21 | enable_xformers: true -------------------------------------------------------------------------------- /configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | 3 | check_val_every_n_epoch: null 4 | max_steps: 8 5 | val_check_interval: 2 6 | log_every_n_steps: 2 7 | 8 | # disable checkpoint and progress bar because they are included in log_master callback 9 | enable_checkpointing: false 10 | enable_progress_bar: false 11 | 12 | # disables sampler replacement 13 | replace_sampler_ddp: false -------------------------------------------------------------------------------- /configs/trainer/distributed.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default.yaml 3 | 4 | strategy: 5 | _target_: pytorch_lightning.strategies.DDPStrategy 6 | find_unused_parameters: false 7 | gradient_as_bucket_view: true 8 | timeout: 9 | _target_: datetime.timedelta 10 | hours: 8 11 | 12 | # mixed-precision 13 | precision: 32 14 | 15 | # gradient clipping 16 | gradient_clip_val: 1.0 17 | 18 | # number of devices (-1 for all) 19 | accelerator: gpu 20 | devices: auto 21 | 22 | num_nodes: 1 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_downloader import download_data, download_model 2 | from .dataloader import SDVideoDataModule -------------------------------------------------------------------------------- /data/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | dir_path = Path(__file__).parent.as_posix() 4 | 5 | DEFAULT_DATA_DIR = "/cache/data/" 6 | DEFAULT_MODEL_DIR = "/cache/pretrained/" 7 | DEFAULT_TOKENIZER_DIR = Path(dir_path).parent.joinpath("assets/tokenizer/") 8 | DEFAULT_CSV_SUBDIR = "csv" 9 | DEFAULT_TAR_SUBDIR = "tar" 10 | 11 | 
DEFAULT_LABEL_FILENAME = "classnames.txt" 12 | -------------------------------------------------------------------------------- /data/data_downloader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import tarfile 4 | import time 5 | from multiprocessing import Pool 6 | from pathlib import Path 7 | 8 | from utils import dist_envs 9 | from utils import get_free_space 10 | from utils import ops 11 | from .constants import ( 12 | DEFAULT_DATA_DIR, 13 | DEFAULT_MODEL_DIR, 14 | DEFAULT_LABEL_FILENAME, 15 | DEFAULT_CSV_SUBDIR, 16 | DEFAULT_TAR_SUBDIR, 17 | ) 18 | from .utils import read_multi_csv 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class DataDownloader: 24 | def __init__(self, processes=8): 25 | self.pool = Pool(processes) 26 | 27 | def apply_job(self, remote_path, local_path, unzip_dir=None): 28 | # download file 29 | try: 30 | ops.copy(remote_path, local_path) 31 | except Exception as e: 32 | logger.error( 33 | f"File download failed (remote_path={remote_path}; local_path={local_path}): {e}" 34 | ) 35 | logger.info(f"File downloaded: {os.path.basename(local_path)}") 36 | # unzip file 37 | if unzip_dir: 38 | self.wait_for_load_below(self.pool._processes) 39 | self.pool.apply_async(self.unzip, args=(local_path, unzip_dir)) 40 | 41 | @staticmethod 42 | def unzip(file_path, unzip_dir): 43 | if not tarfile.is_tarfile(file_path): 44 | return 45 | with tarfile.open(file_path) as tarf: 46 | tarf.extractall(unzip_dir) 47 | logger.info(f"File unzipped: {os.path.basename(file_path)}") 48 | # delete zip file after unzipped 49 | os.remove(file_path) 50 | 51 | def wait_for_load_below(self, load_limit: int): 52 | t_start = time.monotonic() 53 | logger.info(f"Current free space: {get_free_space('/cache/'):.2f} GB.") 54 | while len(self.pool._cache) > load_limit: 55 | time.sleep(5) 56 | if time.monotonic() - t_start > 60: 57 | logger.info("Downloader is currently overloaded: waiting...") 58 | if load_limit == 0: 59 | self.pool.close() 60 | self.pool.join() 61 | 62 | 63 | def download_dataset(downloader, dataset_cfg, local_dir, max_num_csv=0): 64 | # data are downloaded only on main process of each node 65 | if not dist_envs.is_initialized: 66 | raise ValueError(f"Distributed environments are not initialized!") 67 | # only the main process of each node is used for data downloading 68 | if dist_envs.local_rank != 0: 69 | return 70 | local_dir = Path(local_dir) 71 | # get directory names 72 | remote_dir = dataset_cfg["data_dir"] 73 | csv_subdir = dataset_cfg.get("csv_subdir", DEFAULT_CSV_SUBDIR) 74 | tar_subdir = dataset_cfg.get("tar_subdir", DEFAULT_TAR_SUBDIR) 75 | # download label file 76 | label_filename = dataset_cfg.get("label_filename", DEFAULT_LABEL_FILENAME) 77 | remote_label_file = os.path.join(remote_dir, label_filename) 78 | local_label_file = local_dir.joinpath(label_filename) 79 | if not local_label_file.exists() and ops.exists(remote_label_file): 80 | downloader.apply_job(remote_label_file, local_label_file.as_posix()) 81 | # download csvs 82 | remote_csv_dir = os.path.join(remote_dir, csv_subdir) 83 | csv_files = [x for x in ops.listdir(remote_csv_dir) if x.endswith(".csv")] 84 | csv_files = csv_files[:max_num_csv] if max_num_csv > 0 else csv_files 85 | local_csv_dir = local_dir.joinpath(csv_subdir) 86 | for idx, csv_file in enumerate(csv_files): 87 | local_csv = local_csv_dir.joinpath(csv_file) 88 | if not local_csv.exists(): 89 | logger.info(f"[{idx + 1}/{len(csv_files)}] " 90 | f"Downloading csv 
file: {csv_file}") 91 | downloader.apply_job( 92 | os.path.join(remote_csv_dir, csv_file), local_csv.as_posix()) 93 | 94 | tar_name = Path(csv_files[0]).stem 95 | tar_ext_real = None 96 | for tar_ext in ["", ".tar"]: 97 | if ops.exists(os.path.join(remote_dir, tar_subdir, tar_name + tar_ext)): 98 | tar_ext_real = tar_ext 99 | break 100 | if tar_ext_real is None: 101 | raise NotImplementedError( 102 | f"Extension of tar files is not recognized: {os.path.join(remote_dir, tar_subdir)}") 103 | tar_files = [Path(x).with_suffix(tar_ext_real) for x in csv_files] 104 | 105 | if dataset_cfg.get("split_among_nodes", False): 106 | # Only download a part of tar files according to current node_rank 107 | total_df = read_multi_csv(local_csv_dir) 108 | packages = total_df["package"].to_list() 109 | node_rank, num_nodes, world_size = \ 110 | dist_envs.node_rank, dist_envs.num_nodes, dist_envs.world_size 111 | # drop_last is necessary when split_among_nodes=True 112 | packages = packages[:(len(packages) // world_size) * world_size] 113 | chunk_size = len(packages) // num_nodes 114 | packages = set(packages[node_rank * chunk_size:(node_rank + 1) * chunk_size]) 115 | tar_files = [x for x in tar_files if Path(x).stem in packages] 116 | 117 | # Start download tar files 118 | remote_tar_dir = os.path.join(remote_dir, tar_subdir) 119 | local_tar_dir = local_dir.joinpath(tar_subdir) 120 | for idx, tar_file in enumerate(tar_files): 121 | local_tar = local_tar_dir.joinpath(tar_file) 122 | local_subdir = local_tar_dir.joinpath(local_tar.stem) 123 | if local_subdir.exists(): 124 | continue 125 | logger.info( 126 | f"[{idx + 1}/{len(tar_files)}] " 127 | f"Downloading & extracting tar file: {tar_file}") 128 | downloader.apply_job( 129 | remote_path=os.path.join(remote_tar_dir, tar_file), 130 | local_path=local_tar.as_posix(), 131 | unzip_dir=local_tar_dir.as_posix() 132 | ) 133 | logger.info(f"Dataset download complete: {Path(local_dir).name}") 134 | 135 | 136 | def download_data(data_cfg): 137 | downloader = DataDownloader() 138 | stages = ("train", "train_alt", "val") 139 | for stage in stages: 140 | datasets = data_cfg.get(stage, dict()).items() 141 | for dataset_name, dataset_cfg in datasets: 142 | if dataset_cfg["data_dir"].startswith("s3://"): 143 | local_dir = Path(DEFAULT_DATA_DIR, dataset_name).as_posix() 144 | logger.info(f"Downloading dataset: {dataset_name}") 145 | download_dataset(downloader, dataset_cfg, local_dir, 146 | max_num_csv=dataset_cfg.get("max_num_csv", 0)) 147 | data_cfg[stage][dataset_name]["data_dir"] = local_dir 148 | downloader.wait_for_load_below(0) 149 | return data_cfg 150 | 151 | 152 | def download_model(model_path): 153 | if not model_path or not model_path.startswith("s3://"): 154 | return model_path 155 | local_path = Path(DEFAULT_MODEL_DIR, os.path.basename(os.path.normpath(model_path))) 156 | if dist_envs.local_rank == 0 and not local_path.exists(): 157 | logger.info(f"Downloading model: {local_path.name}") 158 | downloader = DataDownloader() 159 | downloader.apply_job(model_path, local_path.as_posix()) 160 | downloader.wait_for_load_below(0) 161 | return local_path.as_posix() 162 | -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytorch_lightning as pl 4 | from omegaconf import DictConfig 5 | from pytorch_lightning.utilities.data import CombinedLoader 6 | from torch.utils.data import ConcatDataset 7 | from 
torch.utils.data import DataLoader 8 | 9 | from .datasets import DATASETS 10 | from .utils import GroupDistributedSampler, MergeDataset, DistributedSampler 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def build_dataloader(data_config: DictConfig, batch_size=4, shuffle=False, 16 | drop_last=True, num_workers=8, pin_memory=True): 17 | cfg_datasets = list(data_config.values()) 18 | datasets = [] 19 | for cfg_dataset in cfg_datasets: 20 | dataset_cls = DATASETS.get(cfg_dataset.get("type")) 21 | cfg_dataset.update({"random_sample": shuffle}) 22 | dataset = dataset_cls(**cfg_dataset) 23 | datasets.append(dataset) 24 | if len(datasets) == 1: 25 | dataset = ConcatDataset(datasets) 26 | else: 27 | if not drop_last: 28 | logger.warning(f"Option `drop_last` is forced activated when merging multiple datasets.") 29 | dataset = MergeDataset(datasets) 30 | dataloader_args = dict( 31 | dataset=dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory) 32 | if not drop_last: 33 | dataloader = DataLoader(**dataloader_args, shuffle=shuffle, drop_last=drop_last) 34 | # sampler = DistributedSampler(dataset, shuffle=False, drop_last=False) 35 | # dataloader = DataLoader(**dataloader_args, sampler=sampler) 36 | else: 37 | sampler = GroupDistributedSampler(dataset, shuffle=shuffle, drop_last=drop_last) 38 | dataloader = DataLoader(**dataloader_args, sampler=sampler) 39 | return dataloader 40 | 41 | 42 | class SDVideoDataModule(pl.LightningDataModule): 43 | def __init__( 44 | self, train, val, train_alt=None, batch_size_train=1, batch_size_val=1, 45 | num_workers=8, pin_memory=True, **kwargs 46 | ): 47 | super().__init__() 48 | # make `prepare_data` called in each node when ddp is used. 49 | self.prepare_data_per_node = True 50 | # dir config 51 | self.train = train 52 | self.val = val 53 | self.train_alt = train_alt 54 | # dataloader config 55 | self.batch_size_train = batch_size_train 56 | self.batch_size_val = batch_size_val 57 | self.num_workers = num_workers 58 | self.pin_memory = pin_memory 59 | 60 | def train_dataloader(self): 61 | shuffle, drop_last = True, True 62 | dataloader = build_dataloader( 63 | self.train, shuffle=shuffle, drop_last=drop_last, batch_size=self.batch_size_train, 64 | num_workers=self.num_workers, pin_memory=self.pin_memory 65 | ) 66 | dataloaders = {"train": dataloader} 67 | if not self.train_alt: 68 | return dataloaders 69 | dataloader_alt = build_dataloader( 70 | self.train_alt, shuffle=shuffle, drop_last=drop_last, batch_size=self.batch_size_train * 8, 71 | num_workers=self.num_workers, pin_memory=self.pin_memory 72 | ) 73 | dataloaders.update({"train_alt": dataloader_alt}) 74 | return CombinedLoader(dataloaders, mode="max_size_cycle") 75 | 76 | def val_dataloader(self): 77 | dataloader = build_dataloader( 78 | self.val, shuffle=False, drop_last=False, batch_size=self.batch_size_val, 79 | num_workers=self.num_workers 80 | ) 81 | return dataloader 82 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import warnings 4 | from ast import literal_eval 5 | from pathlib import Path 6 | 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from torchvision.datasets.folder import pil_loader 10 | from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, Lambda 11 | from torchvision.transforms.functional import pil_to_tensor 12 | from transformers 
import CLIPTokenizer 13 | 14 | from utils import dist_envs 15 | import os 16 | import json 17 | import open_clip 18 | import kornia 19 | 20 | from .constants import DEFAULT_TOKENIZER_DIR 21 | from .registry import DATASETS 22 | from .utils import read_multi_csv, load_video, load_entity_vae, load_entity_clip 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | warnings.simplefilter("error", Image.DecompressionBombWarning) 27 | 28 | DATALOAD_TRY_TIMES = 64 29 | 30 | def load_jsonl(filename): 31 | with open(filename, "r") as f: 32 | return [json.loads(l.strip("\n")) for l in f.readlines()] 33 | 34 | class GeneralPairDataset(Dataset): 35 | def __init__( 36 | self, data_dir, csv_subdir, tar_subdir, resolution, tokenizer_dir, 37 | random_sample, 38 | ): 39 | self.random_sample = random_sample 40 | # data related 41 | self.data_dir = data_dir 42 | self.anns_dir = os.path.join(data_dir, csv_subdir) 43 | self.frame_anns = load_jsonl(self.anns_dir) 44 | # image related 45 | # image_crop_op = RandomCrop if random_sample else CenterCrop 46 | image_crop_op = CenterCrop 47 | # resize; crop; scale pixels from [0, 255) to [-1, 1) 48 | self.transform = Compose([ 49 | Resize(resolution, antialias=False), image_crop_op(resolution), 50 | Lambda(lambda pixels: pixels / 127.5 - 1.0) 51 | ]) 52 | self.entity_clip_transform = Compose([ 53 | Resize(448, antialias=False), image_crop_op(448) 54 | ]) 55 | # text related 56 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir) 57 | 58 | self.preprocess = None 59 | 60 | def __len__(self): 61 | return len(read_multi_csv(self.meta_dir)) 62 | 63 | def __getitem__(self, idx): 64 | raise NotImplementedError 65 | 66 | 67 | class VideoTextDataset(GeneralPairDataset): 68 | def __init__( 69 | self, data_dir, csv_subdir, tar_subdir, tokenizer_dir=DEFAULT_TOKENIZER_DIR, 70 | num_frames=8, resolution=512, random_sample=True, **kwargs, 71 | ): 72 | super().__init__( 73 | data_dir=data_dir, csv_subdir=csv_subdir, tar_subdir=tar_subdir, 74 | resolution=resolution, tokenizer_dir=tokenizer_dir, 75 | random_sample=random_sample 76 | ) 77 | # sample config 78 | self.num_frames = num_frames 79 | # process dataframe 80 | # self.videos, self.texts = self.read_data(self.meta_dir) 81 | self.videos, self.tars, self.texts, self.coordinates = self.read_anns(self.frame_anns) 82 | 83 | @staticmethod 84 | def read_anns(anns): 85 | videos = [] 86 | tars = [] 87 | texts = [] 88 | coordinates = [] 89 | for ann in anns: 90 | videos.append(ann['vid']) 91 | tars.append(ann['tar']) 92 | texts.append(ann['text']) 93 | coordinates.append(ann['coordinates']) 94 | return videos, tars, texts, coordinates 95 | 96 | @staticmethod 97 | def read_data(meta_dir): 98 | df = read_multi_csv(meta_dir) 99 | videos, texts = df["video"].to_list(), df["caption"].to_list() 100 | return videos, texts 101 | 102 | def __len__(self): 103 | return len(self.videos) 104 | 105 | def __getitem__(self, idx): 106 | try_num = 0 107 | while try_num < DATALOAD_TRY_TIMES: 108 | try: 109 | video_path = Path(self.data_dir, "videos", self.tars[idx], "raw_vid", self.videos[idx]) 110 | text = self.texts[idx] 111 | coords = self.coordinates[idx] 112 | if isinstance(text, list): 113 | text = random.choice(text) 114 | # VIDEO 115 | frame_rate = 1 116 | pixel_values, sample_ids = load_video( 117 | video_path.as_posix(), 118 | n_sample_frames=self.num_frames, 119 | sample_rate=frame_rate, 120 | random_sample=self.random_sample, 121 | transform=self.transform, 122 | selected_frames=coords 123 | ) 124 | use_rand_entity_sample = True 125 
| chosed_index = 0.2 126 | entity_vae = load_entity_vae( 127 | video_path, sample_ids=sample_ids, transform=self.transform, 128 | use_rand_entity_sample=use_rand_entity_sample, chosed_index=chosed_index 129 | ) 130 | entity_clip = load_entity_clip( 131 | video_path, sample_ids=sample_ids, 132 | preprocess=self.preprocess, use_rand_entity_sample=use_rand_entity_sample, 133 | chosed_index=chosed_index, transform = self.entity_clip_transform 134 | ) 135 | # TEXT 136 | text_token_ids = self.tokenizer( 137 | text, 138 | max_length=self.tokenizer.model_max_length, 139 | padding="max_length", 140 | truncation=True, 141 | return_tensors="pt", 142 | ).input_ids[0] 143 | break 144 | except Exception as e: 145 | logger.warning(f"Exception occurred parsing video file ({self.videos[idx]}): {e}") 146 | idx = random.randrange( 147 | len(self) // dist_envs.num_nodes * dist_envs.node_rank, 148 | len(self) // dist_envs.num_nodes * (dist_envs.node_rank + 1) 149 | ) if try_num < DATALOAD_TRY_TIMES // 2 else random.randrange(len(self)) 150 | try_num += 1 151 | # output 152 | output = dict( 153 | pixel_values=pixel_values, entity_vae=entity_vae, entity_clip=entity_clip, text_token_ids=text_token_ids, frame_rates=frame_rate 154 | ) 155 | return output 156 | 157 | 158 | @DATASETS.register_module() 159 | class TgifDataset(VideoTextDataset): 160 | @staticmethod 161 | def read_data(meta_dir): 162 | df = read_multi_csv(meta_dir) 163 | df["video"] = df[["package", "id"]].agg( 164 | lambda xs: "{0}/{1}.gif".format(*xs), axis=1) 165 | # load data paths 166 | videos = df["video"].to_list() 167 | texts = list(map(literal_eval, df["caption"])) 168 | return videos, texts 169 | 170 | 171 | @DATASETS.register_module() 172 | class VatexDataset(VideoTextDataset): 173 | @staticmethod 174 | def read_data(meta_dir): 175 | df = read_multi_csv(meta_dir) 176 | df["video"] = df[["package", "videoID"]].agg( 177 | lambda xs: "{0}/{1}".format(*xs), axis=1) 178 | # load data paths 179 | videos = df["video"].to_list() 180 | texts = list(map(literal_eval, df["enCap"])) 181 | return videos, texts 182 | 183 | 184 | @DATASETS.register_module() 185 | class WebvidDataset(VideoTextDataset): 186 | @staticmethod 187 | def read_data(meta_dir): 188 | df = read_multi_csv(meta_dir) 189 | df["video"] = df[["package", "videoid"]].agg( 190 | lambda xs: "{0}/{1}.mp4".format(*xs), axis=1) 191 | # load data paths 192 | videos = df["video"].to_list() 193 | # add watermark as keyword since all videos in WebVid have watermarks 194 | texts = [[f"{x}, watermark", f"{x} with watermark"] 195 | for x in df["name"].to_list()] 196 | return videos, texts 197 | 198 | 199 | @DATASETS.register_module() 200 | class K700CaptionDataset(VideoTextDataset): 201 | @staticmethod 202 | def read_data(meta_dir): 203 | df = read_multi_csv(meta_dir) 204 | df["video"] = df[["package", "id"]].agg( 205 | lambda xs: "{0}/{1}".format(*xs), axis=1) 206 | df["text"] = df[["class", "caption"]].agg( 207 | lambda xs: ["{0}: {1}".format(*xs), 208 | "{1}, {0}".format(*xs)], axis=1) 209 | # load data paths 210 | videos = df["video"].to_list() 211 | texts = df["text"].to_list() 212 | return videos, texts 213 | 214 | 215 | @DATASETS.register_module() 216 | class MidjourneyVideoDataset(VideoTextDataset): 217 | @staticmethod 218 | def read_data(meta_dir): 219 | df = read_multi_csv(meta_dir) 220 | df["video"] = df[["package", "videoname"]].agg( 221 | lambda xs: "{0}/{1}".format(*xs), axis=1) 222 | # load data paths 223 | videos = df["video"].to_list() 224 | texts = df["caption"].to_list() 225 | return 
videos, texts 226 | 227 | 228 | @DATASETS.register_module() 229 | class MomentsInTimeDataset(VideoTextDataset): 230 | @staticmethod 231 | def read_data(meta_dir): 232 | df = read_multi_csv(meta_dir) 233 | df["video"] = df[["package", "Video Name"]].agg( 234 | lambda xs: "{0}/{1}".format(*xs), axis=1) 235 | # load data paths 236 | videos = df["video"].to_list() 237 | texts = df["caption"].to_list() 238 | return videos, texts 239 | 240 | 241 | @DATASETS.register_module() 242 | class PexelsDataset(VideoTextDataset): 243 | @staticmethod 244 | def read_data(meta_dir): 245 | df = read_multi_csv(meta_dir) 246 | df["video"] = df[["package", "id"]].agg( 247 | lambda xs: "{0}/{1}.mp4".format(*xs), axis=1) 248 | # load data paths 249 | videos = df["video"].to_list() 250 | texts = df["caption"].to_list() 251 | return videos, texts 252 | 253 | 254 | class ImageTextDataset(GeneralPairDataset): 255 | def __init__( 256 | self, data_dir, csv_subdir, tar_subdir, tokenizer_dir=DEFAULT_TOKENIZER_DIR, 257 | resolution=512, random_sample=True, **kwargs, 258 | ): 259 | super().__init__( 260 | data_dir=data_dir, csv_subdir=csv_subdir, tar_subdir=tar_subdir, 261 | resolution=resolution, tokenizer_dir=tokenizer_dir, 262 | random_sample=random_sample 263 | ) 264 | self.images, self.texts = self.read_data(self.meta_dir) 265 | 266 | @staticmethod 267 | def read_data(meta_dir): 268 | df = read_multi_csv(meta_dir) 269 | images, texts = df["image"].to_list(), df["text"].to_list() 270 | return images, texts 271 | 272 | def __len__(self): 273 | return len(self.images) 274 | 275 | def __getitem__(self, idx): 276 | try_num = 0 277 | while try_num < DATALOAD_TRY_TIMES: 278 | try: 279 | # IMAGE 280 | image_path = Path(self.data_dir, self.images[idx]) 281 | image = pil_loader(image_path.as_posix()) 282 | pixel_values = self.transform(pil_to_tensor(image)) 283 | # TEXT 284 | text = self.texts[idx] 285 | text_token_ids = self.tokenizer( 286 | text, 287 | max_length=self.tokenizer.model_max_length, 288 | padding="max_length", 289 | truncation=True, 290 | return_tensors="pt", 291 | ).input_ids[0] 292 | break 293 | except Exception as e: 294 | logger.warning(f"Exception occurred parsing image file ({self.images[idx]}): {e}") 295 | idx = random.randrange( 296 | len(self) // dist_envs.num_nodes * dist_envs.node_rank, 297 | len(self) // dist_envs.num_nodes * (dist_envs.node_rank + 1) 298 | ) if try_num < DATALOAD_TRY_TIMES // 2 else random.randrange(len(self)) 299 | try_num += 1 300 | # output 301 | frame_rates = random.randrange(30) 302 | output = dict( 303 | pixel_values=pixel_values, text_token_ids=text_token_ids, frame_rates=frame_rates 304 | ) 305 | return output 306 | 307 | 308 | @DATASETS.register_module() 309 | class LaionDataset(ImageTextDataset): 310 | @staticmethod 311 | def read_data(meta_dir): 312 | df = read_multi_csv(meta_dir) 313 | df = df.dropna(subset=["dir", "text"], how="any").reset_index(drop=True) 314 | images, texts = df["dir"].to_list(), df["text"].to_list() 315 | return images, texts 316 | 317 | 318 | @DATASETS.register_module() 319 | class MidJourneyImageDataset(ImageTextDataset): 320 | @staticmethod 321 | def read_data(meta_dir): 322 | df = read_multi_csv(meta_dir) 323 | df = df.dropna(subset=["videoname", "caption"], how="any").reset_index(drop=True) 324 | df["videoname"] = df[["package", "videoname"]].agg( 325 | lambda xs: "{0}/{1}.png".format(*xs), axis=1) 326 | images, texts = df["videoname"].to_list(), df["caption"].to_list() 327 | return images, texts 328 | 
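# --- Minimal usage sketch (illustrative only; the paths below are hypothetical placeholders). ---
# It mirrors how build_dataloader() in data/dataloader.py resolves the `type` field from
# configs/data/*.yaml through the DATASETS registry and instantiates the dataset class.
if __name__ == "__main__":
    dataset_cls = DATASETS.get("PexelsDataset")      # registered above via @DATASETS.register_module()
    dataset = dataset_cls(
        data_dir="/path/to/magdiff/videos",          # hypothetical local data root
        csv_subdir="annotations.jsonl",              # JSONL records with vid / tar / text / coordinates
        tar_subdir="tar",
        num_frames=8,
        resolution=256,
        random_sample=False,
    )
    sample = dataset[0]
    # each sample is a dict: pixel_values, entity_vae, entity_clip, text_token_ids, frame_rates
    for key, value in sample.items():
        print(key, getattr(value, "shape", value))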
-------------------------------------------------------------------------------- /data/registry.py: -------------------------------------------------------------------------------- 1 | from utils import Registry 2 | 3 | DATASETS = Registry("datasets") 4 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import importlib 3 | import logging 4 | import os 5 | import random 6 | from collections import Counter 7 | from operator import itemgetter 8 | from pathlib import Path 9 | from typing import Union 10 | 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from PIL import Image, ImageSequence 15 | from torch.utils.data import ConcatDataset 16 | from torch.utils.data.distributed import DistributedSampler 17 | from torchvision.datasets.folder import pil_loader 18 | from torchvision.transforms import Compose, Resize, CenterCrop, Lambda 19 | from torchvision.transforms.functional import pil_to_tensor 20 | 21 | from utils import dist_envs 22 | import cv2 23 | import random 24 | 25 | if importlib.util.find_spec("decord"): 26 | import decord 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | def sparse_sample(total_frames, sample_frames, sample_rate, random_sample=False): 32 | if sample_frames <= 0: # sample over the total sequence of frames 33 | ids = np.arange(0, total_frames, sample_rate, dtype=int).tolist() 34 | elif sample_rate * (sample_frames - 1) + 1 <= total_frames: 35 | offset = random.randrange(total_frames - (sample_rate * (sample_frames - 1))) \ 36 | if random_sample else 0 37 | ids = list(range(offset, total_frames + offset, sample_rate))[:sample_frames] 38 | else: 39 | ids = np.linspace(0, total_frames, sample_frames, endpoint=False, dtype=int).tolist() 40 | return ids 41 | 42 | def load_image(img_path, resize_res=None, crop=False, rescale=False): 43 | """ Load an image to tensor 44 | 45 | Args: 46 | img_path: image file path to load. 
47 | resize_res: resolution of resize, no resize if None 48 | crop: center crop if True 49 | rescale: values in [-1, 1) if rescale=True otherwise [0, 255) 50 | 51 | Returns: 52 | torch.Tensor: image tensor in the shape of [c, h, w] 53 | """ 54 | img = pil_to_tensor(pil_loader(img_path)) 55 | tsfm_ops = [] 56 | crop_size = min(img.shape[-2:]) 57 | if resize_res: 58 | tsfm_ops.append(Resize(resize_res, antialias=False)) 59 | crop_size = resize_res 60 | if crop: 61 | tsfm_ops.append(CenterCrop(crop_size)) 62 | if rescale: 63 | tsfm_ops.append(Lambda(lambda pixels: pixels / 127.5 - 1.0)) 64 | transform = Compose(tsfm_ops) 65 | return transform(img) 66 | 67 | 68 | def load_video(file_path, n_sample_frames, sample_rate=4, random_sample=False, transform=None, selected_frames=None): 69 | # random_sample = True 70 | sample_args = dict( 71 | sample_frames=n_sample_frames, sample_rate=sample_rate, random_sample=random_sample) 72 | video = [] 73 | if Path(file_path).is_dir(): 74 | img_files = sorted(Path(file_path).glob("*"), key=lambda i: int(i.stem[6:])) 75 | if len(img_files) < 1: 76 | logger.error(f"No data in video directory: {file_path}") 77 | raise FileNotFoundError(f"No data in video directory: {file_path}") 78 | 79 | # sample_ids = sparse_sample(len(img_files), **sample_args) 80 | selected_frames_ids = sparse_sample(len(selected_frames), **sample_args) ##FIXME change to selected frames, not all the frames 81 | sample_ids = [] 82 | for i in range(0, len(selected_frames_ids)): 83 | sample_ids.append(selected_frames[selected_frames_ids[i]]) 84 | 85 | for img_file in itemgetter(*sample_ids)(img_files): 86 | img = pil_loader(img_file.as_posix()) 87 | img = pil_to_tensor(img) 88 | video.append(img) 89 | elif file_path.endswith(".gif"): 90 | with Image.open(file_path) as gif: 91 | sample_ids = sparse_sample(gif.n_frames, **sample_args) 92 | sample_ids_counter = Counter(sample_ids) 93 | for frame_idx, frame in enumerate(ImageSequence.Iterator(gif)): 94 | if frame_idx in sample_ids_counter: 95 | frame = pil_to_tensor(frame.convert("RGB")) 96 | for _ in range(sample_ids_counter[frame_idx]): 97 | video.append(frame) 98 | else: 99 | vreader = decord.VideoReader(file_path) 100 | sample_ids = sparse_sample(len(vreader), **sample_args) 101 | frames = vreader.get_batch(sample_ids).asnumpy() # (f, h, w, c) 102 | for frame_idx in range(frames.shape[0]): 103 | video.append(pil_to_tensor(Image.fromarray(frames[frame_idx]).convert("RGB"))) 104 | video = torch.stack(video) # (f, c, h, w) 105 | if transform is not None: 106 | video = transform(video) 107 | return video, sample_ids 108 | 109 | def load_entity_vae(file_path, sample_ids, transform=None, use_rand_entity_sample=False, chosed_index=None): 110 | video = [] 111 | selected_frames = [1] 112 | if Path(file_path).is_dir(): 113 | img_files = sorted(Path(file_path).glob("*"), key=lambda i: int(i.stem[6:])) 114 | if len(img_files) < 1: 115 | logger.error(f"No data in video directory: {file_path}") 116 | raise FileNotFoundError(f"No data in video directory: {file_path}") 117 | index = 1 118 | for img_file in itemgetter(*sample_ids)(img_files): 119 | img_file = str(img_file) 120 | mask_file = img_file.replace("raw_vid/", "").replace("videos", "mask").replace("frame_","").replace("jpg","png") 121 | video_img = cv2.imread(img_file) 122 | mask_img = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) 123 | entity = cv2.bitwise_and(video_img, video_img, mask=mask_img) 124 | # cv2 -> PIL 125 | img = Image.fromarray(cv2.cvtColor(entity, cv2.COLOR_BGR2RGB)) 126 | img = 
pil_to_tensor(img) 127 | if use_rand_entity_sample: ###FIXME random mask some frames 128 | if index not in selected_frames: 129 | img = torch.zeros_like(img) 130 | video.append(img) 131 | index += 1 132 | video = torch.stack(video) # (f, c, h, w) 133 | if transform is not None: 134 | video = transform(video) 135 | return video 136 | 137 | def load_entity_clip(file_path, sample_ids, preprocess=None, use_rand_entity_sample=False, chosed_index=None, transform=None): 138 | video = [] 139 | selected_frames = [1] 140 | if Path(file_path).is_dir(): 141 | img_files = sorted(Path(file_path).glob("*"), key=lambda i: int(i.stem[6:])) 142 | if len(img_files) < 1: 143 | logger.error(f"No data in video directory: {file_path}") 144 | raise FileNotFoundError(f"No data in video directory: {file_path}") 145 | index = 1 146 | for img_file in itemgetter(*sample_ids)(img_files): 147 | img_file = str(img_file) 148 | mask_file = img_file.replace("raw_vid/", "").replace("videos", "mask").replace("frame_","").replace("jpg","png") 149 | video_img = cv2.imread(img_file) 150 | mask_img = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE) 151 | entity = cv2.bitwise_and(video_img, video_img, mask=mask_img) 152 | 153 | img_tensor = torch.from_numpy(entity).permute(2, 0, 1).float() 154 | if use_rand_entity_sample: ###FIXME random mask some frames 155 | if index not in selected_frames: 156 | img_tensor = torch.zeros_like(img_tensor) 157 | img = (img_tensor / 255. - 0.5) * 2 158 | video.append(img) 159 | index += 1 160 | video = torch.stack(video, dim=0) # (f, c, h, w) 161 | if transform is not None: 162 | video = transform(video) 163 | return video 164 | 165 | 166 | class MergeDataset(ConcatDataset): 167 | r"""Dataset as a merge of multiple datasets. 168 | 169 | Datasets are firstly split into multiple chunks and these chunks are merged 170 | one after another. 
171 | 172 | Example: 173 | 3 Datasets with sizes: [6, 17, 27]; split_num=4 174 | The global indices will be like: 175 | [[1, 4, 6] 176 | [1, 4, 6] 177 | [1, 4, 6] 178 | [3, 5, 9]] 179 | 180 | Args: 181 | datasets: List of datasets to be merged 182 | """ 183 | 184 | @staticmethod 185 | def split_cumsum(group_sizes, split_num): 186 | r, s = [], 0 187 | for split in range(split_num): 188 | chunk_sizes = [x // split_num for x in group_sizes] 189 | if split == split_num - 1: 190 | chunk_sizes = [chunk_sizes[i] + group_sizes[i] % split_num 191 | for i in range(len(group_sizes))] 192 | for chunk_size in chunk_sizes: 193 | r.append(chunk_size + s) 194 | s += chunk_size 195 | return r 196 | 197 | def __init__(self, datasets) -> None: 198 | super(MergeDataset, self).__init__(datasets) 199 | num_nodes, world_size = dist_envs.num_nodes, dist_envs.world_size 200 | # drop_last for all datasets 201 | self.datasets = list(datasets) 202 | self.dataset_sizes = list() 203 | for dataset in self.datasets: 204 | dataset_size = len(dataset) # type: ignore[arg-type] 205 | self.dataset_sizes.append((dataset_size // world_size) * world_size) 206 | self.chunk_sizes = [x // num_nodes for x in self.dataset_sizes] 207 | self.cumulative_sizes = self.split_cumsum(self.dataset_sizes, num_nodes) 208 | 209 | def __len__(self): 210 | return sum(self.dataset_sizes) 211 | 212 | def __getitem__(self, idx): 213 | if idx < 0: 214 | if -idx > len(self): 215 | raise ValueError("absolute value of index should not exceed dataset length") 216 | idx = len(self) + idx 217 | global_chunk_idx = bisect.bisect_right(self.cumulative_sizes, idx) 218 | if global_chunk_idx == 0: 219 | dataset_idx = 0 220 | sample_idx = idx 221 | else: 222 | dataset_idx = global_chunk_idx % len(self.datasets) 223 | sample_delta = idx - self.cumulative_sizes[global_chunk_idx - 1] 224 | sample_idx = global_chunk_idx // len(self.datasets) * self.chunk_sizes[dataset_idx] + sample_delta 225 | return self.datasets[dataset_idx][sample_idx] 226 | 227 | 228 | class GroupDistributedSampler(DistributedSampler): 229 | r"""Sampler that restricts grouped data loading to a subset of the dataset. 230 | 231 | The dataset is firstly split into groups determined by `num_nodes`. The 232 | grouped datasets are then shuffled and distributed among devices in each 233 | node. 234 | """ 235 | 236 | def __iter__(self): 237 | num_nodes = dist_envs.num_nodes 238 | assert self.num_replicas % num_nodes == 0 239 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 240 | chunk_size = self.total_size // num_nodes 241 | if self.shuffle: 242 | indices = [] 243 | for node_rank in range(num_nodes): 244 | g = torch.Generator() 245 | g.manual_seed(self.seed + node_rank + self.epoch) 246 | chunk_indices = torch.randperm( 247 | chunk_size, generator=g).tolist() # type: ignore[arg-type] 248 | indices.extend([x + chunk_size * node_rank for x in chunk_indices]) 249 | 250 | # remove tail of data to make it evenly divisible. 251 | indices = indices[:self.total_size] 252 | assert len(indices) == self.total_size 253 | 254 | # subsample 255 | num_devices = self.num_replicas // num_nodes 256 | node_rank = int(self.rank / num_devices) 257 | local_rank = self.rank % num_devices 258 | indices = indices[local_rank:chunk_size:num_devices] 259 | indices = [x + chunk_size * node_rank for x in indices] 260 | assert len(indices) == self.num_samples 261 | 262 | return iter(indices) 263 | 264 | 265 | class LabelEncoder: 266 | """ Encodes an label via a dictionary. 
267 | Args: 268 | label_source (list of strings): labels of data used to build encoding dictionary. 269 | Example: 270 | >>> labels = ['label_a', 'label_b'] 271 | >>> encoder = LabelEncoder(labels) 272 | >>> encoder.encode('label_a') 273 | tensor(0) 274 | >>> encoder.decode(encoder.encode('label_a')) 275 | 'label_a' 276 | >>> encoder.encode('label_b') 277 | tensor(1) 278 | >>> encoder.size 279 | ['label_a', 'label_b'] 280 | """ 281 | 282 | def __init__(self, label_source: Union[list, str, os.PathLike]): 283 | if isinstance(label_source, list): 284 | self.labels = label_source 285 | else: 286 | with open(label_source, "r", encoding="utf-8") as f: 287 | lines = [x.strip() for x in f.readlines() if x.strip()] 288 | self.labels = lines 289 | self.idx_to_label = {idx: lab for idx, lab in enumerate(self.labels)} 290 | self.label_to_idx = {lab: idx for idx, lab in enumerate(self.labels)} 291 | 292 | @property 293 | def size(self): 294 | """ 295 | Returns: 296 | int: Number of labels in the dictionary. 297 | """ 298 | return len(self.labels) 299 | 300 | def encode(self, label): 301 | """ Encodes a ``label``. 302 | 303 | Args: 304 | label (object): Label to encode. 305 | 306 | Returns: 307 | torch.Tensor: Encoding of the label. 308 | """ 309 | return torch.tensor(self.label_to_idx.get(label), dtype=torch.long) 310 | 311 | def batch_encode(self, iterator, dim=0): 312 | """ 313 | Args: 314 | iterator (iterator): Batch of labels to encode. 315 | dim (int, optional): Dimension along which to concatenate tensors. 316 | 317 | Returns: 318 | torch.Tensor: Tensor of encoded labels. 319 | """ 320 | return torch.stack([self.encode(x) for x in iterator], dim=dim) 321 | 322 | def decode(self, encoded): 323 | """ Decodes ``encoded`` label. 324 | 325 | Args: 326 | encoded (torch.Tensor): Encoded label. 327 | 328 | Returns: 329 | object: Label decoded from ``encoded``. 330 | """ 331 | if encoded.numel() > 1: 332 | raise ValueError( 333 | '``decode`` decodes one label at a time, use ``batch_decode`` instead.') 334 | 335 | return self.idx_to_label[encoded.squeeze().item()] 336 | 337 | def batch_decode(self, tensor, dim=0): 338 | """ 339 | Args: 340 | tensor (torch.Tensor): Batch of tensors. 341 | dim (int, optional): Dimension along which to split tensors. 342 | 343 | Returns: 344 | list: Batch of decoded labels. 
345 | """ 346 | return [self.decode(x) for x in [t.squeeze(0) for t in tensor.split(1, dim=dim)]] 347 | 348 | 349 | def read_multi_csv(csv_dir): 350 | csvs = sorted(Path(csv_dir).glob("*.csv")) 351 | df_all = [] 352 | for c in csvs: 353 | df = pd.read_csv(c) 354 | df["package"] = Path(c).stem 355 | df_all.append(df) 356 | df_all = pd.concat(df_all, ignore_index=True) 357 | return df_all 358 | -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 MASTER_ADDR=127.0.0.1 MASTER_PORT=10086 NODE_RANK=0 WORLD_SIZE=1 HYDRA_FULL_ERROR=1 \ 2 | python main.py --config-name=_meta_/inference.yaml -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import hydra 5 | from hydra.utils import instantiate 6 | from lightning_fabric.utilities.seed import seed_everything 7 | from omegaconf import OmegaConf 8 | 9 | from data import download_data, download_model 10 | from utils import instantiate_multi, save_config, dist_envs, enable_logger 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | @hydra.main(config_path="configs", config_name="example", version_base=None) 16 | def main(config): 17 | # 1. initialize trainer (should be done at the first place) 18 | trainer_callbacks = instantiate_multi(config, "callbacks") 19 | trainer_loggers = instantiate_multi(config, "loggers") 20 | trainer = instantiate( 21 | config.trainer, callbacks=trainer_callbacks, logger=trainer_loggers, 22 | ) 23 | # 2. save envs and configs 24 | save_config(config) 25 | seed_everything(config.get("seed", 42), workers=True) 26 | dist_envs.init_envs(trainer) 27 | # 3. enable some loggers 28 | if "log_master" in config.callbacks and config.callbacks.log_master is not None: 29 | enable_logger("callbacks.log_master", config.callbacks.log_master.log_file) 30 | for log_name in ( 31 | "__main__", "data.dataloader", "data.datasets", "data.data_downloader", 32 | "model.model", "components.interpolation.module", "components.interpolation.pipeline" 33 | ): 34 | enable_logger(log_name) 35 | # 4. build model 36 | remote_paths = ["pretrained_model_path", "ckpt_path", "temporal_vae_path"] 37 | for remote_path in remote_paths: 38 | if hasattr(config.model, remote_path): 39 | rpath = getattr(config.model, remote_path) 40 | setattr(config.model, remote_path, download_model(rpath)) 41 | model = instantiate(config.model) 42 | # [DEBUG] show whole config 43 | if os.environ.get("DEBUG_ON", None): 44 | logger.info(OmegaConf.to_yaml(config)) 45 | # 5. 
starting training/testing 46 | evaluator = config.get("evaluator", None) 47 | if evaluator is None or evaluator == "pl_validate": 48 | # config.data = download_data(config.data) 49 | datamodule = instantiate(config.data) 50 | run_fn = trainer.fit if evaluator is None else trainer.validate 51 | run_fn(model=model, datamodule=datamodule) 52 | else: 53 | model.setup(stage="test") 54 | evaluator = instantiate(config.evaluator) 55 | evaluator(model) 56 | logger.info("All finished.") 57 | 58 | 59 | if __name__ == "__main__": 60 | main() 61 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import SDVideoModel, SDVideoModelEvaluator 2 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import random 5 | import re 6 | from datetime import datetime 7 | from pathlib import Path 8 | from typing import Union 9 | import time 10 | 11 | import pandas as pd 12 | import pytorch_lightning as pl 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.utils.checkpoint 16 | from diffusers import AutoencoderKL, DDIMScheduler 17 | from .modules.zero_snr_ddpm import DDPMScheduler ###FIXME changed zero-SNR 18 | from diffusers.optimization import get_scheduler 19 | from diffusers.utils.import_utils import is_xformers_available 20 | from einops import rearrange 21 | from torch.optim import AdamW 22 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPProcessor 23 | 24 | from components.vae import TemporalAutoencoderKL 25 | from data.utils import load_image 26 | from utils import dist_envs 27 | from utils import slugify 28 | from .modules.unet import UNet3DConditionModel 29 | from .pipeline import SDVideoPipeline 30 | from .utils import save_videos_grid, compute_clip_score, prepare_masked_latents, prepare_entity_latents, FrozenOpenCLIPImageEmbedderV2, Resampler 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | 35 | class SDVideoModel(pl.LightningModule): 36 | def __init__(self, pretrained_model_path, **kwargs): 37 | super().__init__() 38 | self.save_hyperparameters(ignore=["pretrained_model_path"], logger=False) 39 | # main training module 40 | self.unet: Union[str, UNet3DConditionModel] = Path(pretrained_model_path, "unet").as_posix() 41 | # components for training 42 | self.noise_scheduler_dir = Path(pretrained_model_path, "scheduler").as_posix() 43 | self.vae = Path(pretrained_model_path, "vae").as_posix() 44 | self.text_encoder = Path(pretrained_model_path, "text_encoder").as_posix() 45 | self.tokenizer: Union[str, CLIPTokenizer] = Path(pretrained_model_path, "tokenizer").as_posix() 46 | # clip model for metric 47 | self.clip = Path(pretrained_model_path, "clip").as_posix() 48 | self.clip_processor = Path(pretrained_model_path, "clip").as_posix() 49 | # define pipeline for inference 50 | self.val_pipeline = None 51 | # video frame resolution 52 | self.resolution = kwargs.get("resolution", 512) 53 | # use temporal_vae 54 | self.temporal_vae_path = kwargs.get("temporal_vae_path", None) 55 | # use prompt image 56 | self.in_channels = kwargs.get("in_channels", 4) 57 | self.use_prompt_image = self.in_channels > 4 58 | self.add_entity_vae = kwargs.get("add_entity_vae", False) 59 | self.add_entity_clip = kwargs.get("add_entity_clip", False) 60 | 61 | 
### add open clip model 62 | if self.add_entity_clip: 63 | self.embedding_dim = 1280 64 | self.entity_clip_model = FrozenOpenCLIPImageEmbedderV2(arch="ViT-H-14") ###FIXME 65 | self.enclip_projector = Resampler(dim=1024, depth=4, dim_head=64, heads=12, num_queries=16, embedding_dim=self.embedding_dim, output_dim=1024, ff_mult=4) 66 | else: 67 | self.entity_clip_model = None 68 | self.enclip_projector = None 69 | 70 | def setup(self, stage: str) -> None: 71 | # build modules 72 | self.noise_scheduler = DDPMScheduler.from_pretrained(self.noise_scheduler_dir) 73 | self.tokenizer = CLIPTokenizer.from_pretrained(self.tokenizer) 74 | 75 | if self.temporal_vae_path: 76 | self.vae = TemporalAutoencoderKL.from_pretrained(self.temporal_vae_path) 77 | else: 78 | self.vae = AutoencoderKL.from_pretrained(self.vae) 79 | self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder) 80 | self.unet = UNet3DConditionModel.from_pretrained_2d( 81 | self.unet, sample_size=self.resolution // (2 ** (len(self.vae.config.block_out_channels) - 1)), 82 | in_channels=self.in_channels, 83 | add_temp_transformer=self.hparams.get("add_temp_transformer", False), 84 | add_temp_attn_only_on_upblocks=self.hparams.get("add_temp_attn_only_on_upblocks", False), 85 | prepend_first_frame=self.hparams.get("prepend_first_frame", False), 86 | add_temp_embed=self.hparams.get("add_temp_embed", False), 87 | add_temp_ff=self.hparams.get("add_temp_ff", False), 88 | add_temp_conv=self.hparams.get("add_temp_conv", False), 89 | num_class_embeds=self.hparams.get("num_class_embeds", None) 90 | ) 91 | 92 | # load previously trained components for resumed training 93 | ckpt_path = self.hparams.get("ckpt_path", None) 94 | if ckpt_path is not None: 95 | state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] 96 | 97 | mod_list = ["unet", "text_encoder"] if self.temporal_vae_path \ 98 | else ["unet", "text_encoder", "vae", "enclip_projector"] 99 | for mod in mod_list: 100 | if any(filter(lambda x: x.startswith(mod), state_dict.keys())): 101 | mod_instance = getattr(self, mod) 102 | mod_instance.load_state_dict( 103 | {k[len(mod) + 1:]: v for k, v in state_dict.items() if k.startswith(mod)}, strict=False 104 | ) 105 | 106 | # null text for classifier-free guidance 107 | self.null_text_token_ids = self.tokenizer( # noqa 108 | "", max_length=self.tokenizer.model_max_length, padding="max_length", 109 | truncation=True, return_tensors="pt", 110 | ).input_ids[0] 111 | 112 | # train only the trainable modules 113 | trainable_modules = self.hparams.get("trainable_modules", None) 114 | if trainable_modules is not None: 115 | self.unet.requires_grad_(False) 116 | for pname, params in self.unet.named_parameters(): 117 | if any([re.search(pat, pname) for pat in trainable_modules]): 118 | params.requires_grad = True 119 | if self.add_entity_clip: 120 | for pname, params in self.entity_clip_model.named_parameters(): 121 | params.requires_grad = False 122 | for pname, params in self.enclip_projector.named_parameters(): 123 | params.requires_grad = True 124 | 125 | # raise error when `in_channel` > 4 and `conv_in` is not trainable 126 | if self.use_prompt_image and not self.unet.conv_in.weight.requires_grad: 127 | raise AssertionError(f"use_prompt_image=True but `unet.conv_in` is frozen.") 128 | if not self.use_prompt_image and self.unet.conv_in.weight.requires_grad: 129 | logger.warning(f"use_prompt_image=False but `unet.conv_in` is trainable.") 130 | 131 | # load clip modules for evaluation 132 | self.clip = 
CLIPModel.from_pretrained(self.clip) 133 | self.clip_processor = CLIPProcessor.from_pretrained(self.clip_processor) 134 | # prepare modules 135 | for component in [self.vae, self.text_encoder, self.clip]: 136 | if not isinstance(component, CLIPTextModel) or self.hparams.get("freeze_text_encoder", False): 137 | component.requires_grad_(False).eval() 138 | if stage != "test" and self.trainer.precision.startswith("16"): 139 | component.to(dtype=torch.float16) 140 | # [DEBUG] show which parameters are trainable 141 | if os.environ.get("DEBUG_ON", None): 142 | params_trainable, params_frozen = [], [] 143 | for name, params in self.named_parameters(): 144 | if params.requires_grad: 145 | params_trainable.append(name) 146 | else: 147 | params_frozen.append(name) 148 | logger.info(f"*** [Trainable parameters]: {params_trainable}") 149 | logger.info(f"*** [Frozen parameters]: {params_frozen}") 150 | # use gradient checkpointing 151 | if self.hparams.get("enable_gradient_checkpointing", True): 152 | if not self.hparams.get("freeze_text_encoder", False): 153 | self.text_encoder.gradient_checkpointing_enable() 154 | self.unet.enable_gradient_checkpointing() 155 | # use xformers for efficient training 156 | if self.hparams.get("enable_xformers", True) and not hasattr(F, "scaled_dot_product_attention"): 157 | if is_xformers_available(): 158 | self.unet.enable_xformers_memory_efficient_attention() 159 | if self.unet.temp_transformer is not None: 160 | # FIXME: disable xformers for this specific layer, otherwise a CUDA error occurs 161 | self.unet.temp_transformer.set_use_memory_efficient_attention_xformers(False) 162 | else: 163 | raise ValueError("xformers is not available. Make sure it is installed correctly") 164 | # construct pipeline for inference 165 | self.val_pipeline = SDVideoPipeline( 166 | vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, 167 | entity_clip_model=self.entity_clip_model, enclip_projector=self.enclip_projector, 168 | scheduler=DDIMScheduler.from_pretrained(self.noise_scheduler_dir), 169 | ) 170 | 171 | def validation_step(self, batch, batch_idx): 172 | pixel_values = batch["pixel_values"].to(dtype=torch.float16) \ 173 | if self.trainer.precision.startswith("16") else batch["pixel_values"] 174 | entity_vae = batch["entity_vae"].to(dtype=torch.float16) \ 175 | if self.trainer.precision.startswith("16") else batch["entity_vae"] 176 | entity_clip = batch["entity_clip"].to(dtype=torch.float16) \ 177 | if self.trainer.precision.startswith("16") else batch["entity_clip"] 178 | 179 | text_token_ids = batch["text_token_ids"] 180 | video_len = pixel_values.shape[1] 181 | # inference arguments 182 | num_inference_steps = self.hparams.get("num_inference_steps", 50) 183 | guidance_scale = self.hparams.get("guidance_scale", 7.5) 184 | noise_alpha = self.hparams.get("noise_alpha", .0) 185 | # parse prompts 186 | prompts = self.tokenizer.batch_decode(text_token_ids, skip_special_tokens=True) 187 | generator = torch.Generator(device=self.device) 188 | seed = 42 189 | generator.manual_seed(seed) 190 | # compose args 191 | pipeline_args = dict( 192 | generator=generator, num_frames=video_len, 193 | num_inference_steps=num_inference_steps, 194 | guidance_scale=guidance_scale, noise_alpha=noise_alpha, 195 | frame_rate=1 if self.hparams.get("num_class_embeds", None) else None 196 | ) 197 | samples = [] 198 | 199 | for example_id in range(len(prompts)): 200 | prompt = prompts[example_id] 201 | prompt_image = pixel_values[example_id][:1, ...]
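# the first frame of the ground-truth clip doubles as the reference (prompt) image; it is only forwarded to the pipeline below when use_prompt_image is enabled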
202 | entity_vae_image = torch.unsqueeze(entity_vae[example_id], 0) 203 | entity_clip_image = torch.unsqueeze(entity_clip[example_id], 0) 204 | sample = self.val_pipeline( 205 | prompt, prompt_image if self.use_prompt_image else None, 206 | entity_vae=entity_vae_image, entity_clip=entity_clip_image, 207 | add_entity_vae=self.add_entity_vae, add_entity_clip=self.add_entity_clip, 208 | **pipeline_args 209 | ).videos 210 | if self.trainer.is_global_zero: 211 | num_step_str = str(self.global_step).zfill(len(str(self.trainer.estimated_stepping_batches))) 212 | if prompt == "": 213 | prompt = "text_prompt_equal_null" 214 | save_videos_grid(sample, Path(f"samples_s{num_step_str}", f"{prompt}.gif")) 215 | samples.append(sample) 216 | # clip model for metric 217 | clip_scores = compute_clip_score( 218 | model=self.clip, model_processor=self.clip_processor, 219 | images=torch.cat(samples), texts=list(prompts), rescale=False, 220 | ) 221 | self.log("val_clip_score", clip_scores.mean(), on_step=False, on_epoch=True, sync_dist=True) 222 | 223 | def configure_optimizers(self): 224 | optimizer_args = dict( 225 | params=filter(lambda p: p.requires_grad, self.parameters()), 226 | lr=self.hparams.get("lr", 1e-3), 227 | weight_decay=self.hparams.get("weight_decay", 1e-2) 228 | ) 229 | optimizer = AdamW(**optimizer_args) 230 | # valid scheduler names: diffusers.optimization.SchedulerType 231 | scheduler_name = self.hparams.get("scheduler_name", "cosine") 232 | scheduler = get_scheduler( 233 | scheduler_name, optimizer=optimizer, 234 | num_warmup_steps=self.hparams.get("warmup_steps", 8), 235 | num_training_steps=self.trainer.estimated_stepping_batches 236 | ) 237 | lr_scheduler = dict(scheduler=scheduler, interval="step", frequency=1) 238 | return dict(optimizer=optimizer, lr_scheduler=lr_scheduler) 239 | 240 | 241 | class SDVideoModelEvaluator: 242 | def __init__(self, **kwargs): 243 | torch.multiprocessing.set_start_method("spawn", force=True) 244 | torch.multiprocessing.set_sharing_strategy("file_system") 245 | 246 | self.seed = kwargs.pop("seed", 42) 247 | self.prompts = kwargs.pop("prompts", None) 248 | if self.prompts is None: 249 | raise ValueError(f"No prompts provided.") 250 | elif isinstance(self.prompts, str) and not Path(self.prompts).exists(): 251 | raise FileNotFoundError(f"Prompt file not found: {self.prompts}") 252 | elif isinstance(self.prompts, str): 253 | if self.prompts.endswith(".txt"): 254 | with open(self.prompts, "r", encoding="utf-8") as f: 255 | self.prompts = [x.strip() for x in f.readlines() if x.strip()] 256 | elif self.prompts.endswith(".json"): 257 | with open(self.prompts, "r", encoding="utf-8") as f: 258 | self.prompts = sorted([ 259 | random.choice(x) if isinstance(x, list) else x 260 | for x in json.load(f).values() 261 | ]) 262 | elif self.prompts.endswith(".csv"): 263 | # prompt images can be set in this condition. 
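# (illustrative, assumed layout: column 0 holds the text prompt and an optional column 1 holds a reference-image path resolved relative to the CSV file, e.g. "a corgi running on the beach,images/corgi.png")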
264 | csv_path = self.prompts 265 | df_prompts = pd.read_csv(self.prompts) 266 | self.prompts = df_prompts.iloc[:, 0].tolist() 267 | if len(df_prompts.columns) >= 2: 268 | self.prompts_img = [Path(csv_path).parent.joinpath(x).as_posix() for x in df_prompts.iloc[:, 1]] 269 | else: 270 | self.prompts_img = None 271 | 272 | self.add_file_logger(logger, kwargs.pop("log_file", None)) 273 | self.output_file = kwargs.pop("output_file", "results.csv") 274 | self.batch_size = kwargs.pop("batch_size", 4) 275 | self.val_params = kwargs 276 | 277 | @staticmethod 278 | def add_file_logger(logger, log_file=None, log_level=logging.INFO): 279 | if dist_envs.global_rank == 0 and log_file is not None: 280 | log_handler = logging.FileHandler(log_file, "w") 281 | log_handler.setFormatter( 282 | logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")) 283 | log_handler.setLevel(log_level) 284 | logger.addHandler(log_handler) 285 | 286 | @staticmethod 287 | def infer(rank, model, model_params, q_input, q_output, seed=42): 288 | device = torch.device(f"cuda:{rank}") 289 | model.to(device) 290 | generator = torch.Generator(device=device) 291 | generator.manual_seed(seed + rank) 292 | output_video_dir = Path("output_videos") 293 | output_video_dir.mkdir(parents=True, exist_ok=True) 294 | while True: 295 | inputs = q_input.get() 296 | if inputs is None: # check for sentinel value 297 | print(f"[{datetime.now()}] Process #{rank} ended.") 298 | break 299 | start_idx, prompts, prompts_img = inputs 300 | if prompts_img is not None: 301 | prompts_img= prompts_img.to(device) 302 | videos = model.val_pipeline( 303 | prompts, prompts_img, generator=generator, negative_prompt=["watermark"] * len(prompts), 304 | **model_params 305 | ).videos 306 | for idx, prompt in enumerate(prompts): 307 | gif_file = output_video_dir.joinpath(f"{start_idx + idx}_{prompt}.gif") 308 | save_videos_grid(videos[idx:idx + 1, ...], gif_file) 309 | print(f"[{datetime.now()}] Sample is saved #{start_idx + idx}: \"{prompt}\"") 310 | clip_scores = compute_clip_score( 311 | model=model.clip, model_processor=model.clip_processor, 312 | images=videos, texts=prompts, rescale=False, 313 | ) 314 | q_output.put((prompts, clip_scores.cpu().tolist())) 315 | return None 316 | 317 | def __call__(self, model): 318 | model.eval() 319 | 320 | # load prompts images if exist 321 | if model.use_prompt_image and self.prompts_img is not None: 322 | self.prompts_img = torch.stack([ 323 | load_image(x, model.resolution, True, True) 324 | for x in self.prompts_img 325 | ]) 326 | 327 | if not torch.cuda.is_available(): 328 | raise NotImplementedError(f"No GPU found.") 329 | 330 | self.val_params.setdefault( 331 | "num_inference_steps", model.hparams.get("num_inference_steps", 50) 332 | ) 333 | self.val_params.setdefault( 334 | "guidance_scale", model.hparams.get("guidance_scale", 7.5) 335 | ) 336 | self.val_params.setdefault( 337 | "noise_alpha", model.hparams.get("noise_alpha", .0) 338 | ) 339 | logger.info(f"val_params: {self.val_params}") 340 | 341 | q_input = torch.multiprocessing.Queue() 342 | q_output = torch.multiprocessing.Queue() 343 | processes = [] 344 | for rank in range(torch.cuda.device_count()): 345 | p = torch.multiprocessing.Process( 346 | target=self.infer, 347 | args=(rank, model, self.val_params, q_input, q_output, self.seed) 348 | ) 349 | p.start() 350 | processes.append(p) 351 | # send model inputs to queue 352 | result_num = 0 353 | for start_idx in range(0, len(self.prompts), self.batch_size): 354 | result_num += 1 355 | ref_images 
= self.prompts_img[start_idx:start_idx + self.batch_size] \ 356 | if model.use_prompt_image and self.prompts_img is not None else None 357 | q_input.put(( 358 | start_idx, 359 | self.prompts[start_idx:start_idx + self.batch_size], 360 | ref_images 361 | )) 362 | for _ in processes: 363 | q_input.put(None) # sentinel value to signal subprocesses to exit 364 | # The result queue has to be processed before joining the processes. 365 | results = [q_output.get() for _ in range(result_num)] 366 | # joining the processes 367 | for p in processes: 368 | p.join() # wait for all subprocesses to finish 369 | all_prompts, all_clip_scores = [], [] 370 | for prompts, clip_scores in results: 371 | all_prompts.extend(prompts) 372 | all_clip_scores.extend(clip_scores) 373 | output_df = pd.DataFrame({ 374 | "prompt": all_prompts, "clip_score": all_clip_scores 375 | }) 376 | output_df.to_csv(self.output_file, index=False) 377 | logger.info(f"--- Metrics ---") 378 | logger.info(f"Mean CLIP_SCORE: {sum(all_clip_scores) / len(all_clip_scores)}") 379 | logger.info(f"Test results saved in: {self.output_file}") 380 | -------------------------------------------------------------------------------- /model/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gulucaptain/videoassembler/1fb50cd5f85aa5896b85003aff3641b00946eddc/model/modules/__init__.py -------------------------------------------------------------------------------- /model/modules/attention_hasFFN.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py 2 | from dataclasses import dataclass 3 | from typing import Callable, Optional 4 | 5 | import torch 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers.models import ModelMixin 8 | from diffusers.models.attention import FeedForward, AdaLayerNorm 9 | from diffusers.models.cross_attention import CrossAttention 10 | from diffusers.utils import BaseOutput 11 | from diffusers.utils.import_utils import is_xformers_available 12 | from einops import rearrange, repeat 13 | from torch import nn 14 | 15 | from .modules import get_sin_pos_embedding 16 | from .utils import zero_module 17 | 18 | if is_xformers_available(): 19 | import xformers 20 | import xformers.ops 21 | else: 22 | xformers = None 23 | 24 | 25 | class BasicTransformerBlock(nn.Module): 26 | r""" 27 | A basic Transformer block. 28 | Parameters: 29 | dim (`int`): The number of channels in the input and output. 30 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 31 | attention_head_dim (`int`): The number of channels in each head. 32 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 33 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 34 | only_cross_attention (`bool`, *optional*): 35 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 36 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 37 | num_embeds_ada_norm (: 38 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 39 | attention_bias (: 40 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
41 | """ 42 | 43 | def __init__( 44 | self, 45 | dim: int, 46 | num_attention_heads: int, 47 | attention_head_dim: int, 48 | dropout=0.0, 49 | cross_attention_dim: Optional[int] = None, 50 | activation_fn: str = "geglu", 51 | num_embeds_ada_norm: Optional[int] = None, 52 | attention_bias: bool = False, 53 | only_cross_attention: bool = False, 54 | upcast_attention: bool = False, 55 | norm_elementwise_affine: bool = True, 56 | final_dropout: bool = False, 57 | add_temp_attn: bool = False, 58 | prepend_first_frame: bool = False, 59 | add_temp_embed: bool = False, 60 | ): 61 | super().__init__() 62 | self.only_cross_attention = only_cross_attention 63 | self.use_ada_layer_norm = num_embeds_ada_norm is not None 64 | 65 | # temporal embedding 66 | self.add_temp_embed = add_temp_embed 67 | 68 | if add_temp_attn: 69 | if prepend_first_frame: 70 | # SC-Attn 71 | self.attn1 = SparseCausalAttention( 72 | query_dim=dim, 73 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 74 | heads=num_attention_heads, 75 | dim_head=attention_head_dim, 76 | dropout=dropout, 77 | bias=attention_bias, 78 | upcast_attention=upcast_attention, 79 | ) 80 | else: 81 | # Normal CrossAttn 82 | self.attn1 = CrossAttention( 83 | query_dim=dim, 84 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 85 | heads=num_attention_heads, 86 | dim_head=attention_head_dim, 87 | dropout=dropout, 88 | bias=attention_bias, 89 | upcast_attention=upcast_attention, 90 | ) 91 | 92 | # Temp-Attn 93 | self.temp_norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 94 | self.temp_attn = CrossAttention( 95 | query_dim=dim, 96 | heads=num_attention_heads, 97 | dim_head=attention_head_dim, 98 | dropout=dropout, 99 | bias=attention_bias, 100 | upcast_attention=upcast_attention, 101 | ) 102 | self.temp_norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) 103 | self.temp_ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 104 | if isinstance(self.temp_ff.net[-1], nn.Linear): 105 | self.temp_ff.net[-1] = zero_module(self.temp_ff.net[-1]) 106 | elif isinstance(self.temp_ff.net[-2], nn.Linear): 107 | self.temp_ff.net[-2] = zero_module(self.temp_ff.net[-2]) 108 | else: 109 | # Normal Attention 110 | self.attn1 = CrossAttention( 111 | query_dim=dim, 112 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 113 | heads=num_attention_heads, 114 | dim_head=attention_head_dim, 115 | dropout=dropout, 116 | bias=attention_bias, 117 | upcast_attention=upcast_attention, 118 | ) 119 | self.temp_attn = None 120 | 121 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) \ 122 | if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 123 | 124 | # Cross-Attn 125 | if cross_attention_dim is not None: 126 | self.attn2 = CrossAttention( 127 | query_dim=dim, 128 | cross_attention_dim=cross_attention_dim, 129 | heads=num_attention_heads, 130 | dim_head=attention_head_dim, 131 | dropout=dropout, 132 | bias=attention_bias, 133 | upcast_attention=upcast_attention, 134 | ) 135 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 136 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 137 | # the second cross attention block. 
138 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) \ 139 | if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 140 | else: 141 | self.attn2 = None 142 | self.norm2 = None 143 | 144 | # Feed-forward 145 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 146 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) 147 | 148 | def set_use_memory_efficient_attention_xformers( 149 | self, use_memory_efficient_attention_xformers: bool, 150 | attention_op: Optional[Callable] = None 151 | ): 152 | if not is_xformers_available(): 153 | print("Here is how to install it") 154 | raise ModuleNotFoundError( 155 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 156 | " xformers", 157 | name="xformers", 158 | ) 159 | elif not torch.cuda.is_available(): 160 | raise ValueError( 161 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" 162 | " available for GPU " 163 | ) 164 | else: 165 | try: 166 | # Make sure we can run the memory efficient attention 167 | xformers.ops.memory_efficient_attention( 168 | torch.randn((1, 2, 40), device="cuda"), 169 | torch.randn((1, 2, 40), device="cuda"), 170 | torch.randn((1, 2, 40), device="cuda"), 171 | ) 172 | except Exception as e: 173 | raise e 174 | self.attn1.set_use_memory_efficient_attention_xformers( 175 | use_memory_efficient_attention_xformers, attention_op=attention_op) 176 | if self.attn2 is not None: 177 | self.attn2.set_use_memory_efficient_attention_xformers( 178 | use_memory_efficient_attention_xformers, attention_op=attention_op) 179 | if self.temp_attn is not None: 180 | self.temp_attn.set_use_memory_efficient_attention_xformers( 181 | use_memory_efficient_attention_xformers, attention_op=attention_op) 182 | 183 | def forward( 184 | self, 185 | hidden_states, 186 | encoder_hidden_states=None, 187 | timestep=None, 188 | attention_mask=None, 189 | video_length=None, 190 | ): 191 | # SparseCausal-Attention 192 | norm_hidden_states = ( 193 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) 194 | ) 195 | 196 | attn1_args = dict(hidden_states=norm_hidden_states, attention_mask=attention_mask) 197 | if self.temp_attn is not None and isinstance(self.attn1, SparseCausalAttention): 198 | attn1_args.update({"video_length": video_length}) 199 | # Self-/Sparse-Attention 200 | if self.only_cross_attention: 201 | hidden_states = self.attn1( 202 | **attn1_args, encoder_hidden_states=encoder_hidden_states 203 | ) + hidden_states 204 | else: 205 | hidden_states = self.attn1(**attn1_args) + hidden_states 206 | 207 | if self.attn2 is not None: 208 | # Cross-Attention 209 | norm_hidden_states = ( 210 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 211 | ) 212 | hidden_states = ( 213 | self.attn2( 214 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask 215 | ) 216 | + hidden_states 217 | ) 218 | 219 | if self.temp_attn is not None: 220 | identity = hidden_states 221 | d = hidden_states.shape[1] 222 | # add temporal embedding 223 | if self.add_temp_embed: 224 | temp_emb = get_sin_pos_embedding( 225 | hidden_states.shape[-1], video_length).to(hidden_states) 226 | hidden_states = rearrange(hidden_states, "(b f) d c -> b d f c", f=video_length) 227 | hidden_states += temp_emb 228 | hidden_states = rearrange(hidden_states, "b d f 
c -> (b f) d c") 229 | # normalization 230 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) 231 | norm_hidden_states1 = ( 232 | self.temp_norm1(hidden_states, timestep) if self.use_ada_layer_norm 233 | else self.temp_norm1(hidden_states) 234 | ) 235 | # apply temporal attention 236 | hidden_states = self.temp_attn(norm_hidden_states1) + hidden_states 237 | # apply temporal feed-forward 238 | norm_hidden_states2 = ( 239 | self.temp_norm2(hidden_states, timestep) if self.use_ada_layer_norm 240 | else self.temp_norm2(hidden_states) 241 | ) 242 | hidden_states = self.temp_ff(norm_hidden_states2) + hidden_states 243 | # rearrange back 244 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) 245 | # ignore effects of temporal layers on image inputs 246 | if video_length <= 1: 247 | hidden_states = identity + 0.0 * hidden_states 248 | 249 | # Feed-forward 250 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states 251 | 252 | return hidden_states 253 | 254 | 255 | @dataclass 256 | class Transformer3DModelOutput(BaseOutput): 257 | sample: torch.FloatTensor 258 | 259 | 260 | class Transformer3DModel(ModelMixin, ConfigMixin): 261 | @register_to_config 262 | def __init__( 263 | self, 264 | num_attention_heads: int = 16, 265 | attention_head_dim: int = 88, 266 | in_channels: Optional[int] = None, 267 | num_layers: int = 1, 268 | dropout: float = 0.0, 269 | norm_num_groups: int = 32, 270 | cross_attention_dim: Optional[int] = None, 271 | attention_bias: bool = False, 272 | activation_fn: str = "geglu", 273 | num_embeds_ada_norm: Optional[int] = None, 274 | use_linear_projection: bool = False, 275 | only_cross_attention: bool = False, 276 | upcast_attention: bool = False, 277 | add_temp_attn: bool = False, 278 | prepend_first_frame: bool = False, 279 | add_temp_embed: bool = False 280 | ): 281 | super().__init__() 282 | self.use_linear_projection = use_linear_projection 283 | self.num_attention_heads = num_attention_heads 284 | self.attention_head_dim = attention_head_dim 285 | inner_dim = num_attention_heads * attention_head_dim 286 | 287 | # Define input layers 288 | self.in_channels = in_channels 289 | 290 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 291 | if use_linear_projection: 292 | self.proj_in = nn.Linear(in_channels, inner_dim) 293 | else: 294 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 295 | 296 | # Define transformers blocks 297 | self.transformer_blocks = nn.ModuleList( 298 | [ 299 | BasicTransformerBlock( 300 | inner_dim, 301 | num_attention_heads, 302 | attention_head_dim, 303 | dropout=dropout, 304 | cross_attention_dim=cross_attention_dim, 305 | activation_fn=activation_fn, 306 | num_embeds_ada_norm=num_embeds_ada_norm, 307 | attention_bias=attention_bias, 308 | only_cross_attention=only_cross_attention, 309 | upcast_attention=upcast_attention, 310 | add_temp_attn=add_temp_attn, 311 | prepend_first_frame=prepend_first_frame, 312 | add_temp_embed=add_temp_embed 313 | ) 314 | for _ in range(num_layers) 315 | ] 316 | ) 317 | 318 | # 4. 
Define output layers 319 | if use_linear_projection: 320 | self.proj_out = nn.Linear(in_channels, inner_dim) 321 | else: 322 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 323 | 324 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, 325 | return_dict=False): 326 | # Input 327 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 328 | video_length = hidden_states.shape[2] 329 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") 330 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) 331 | 332 | batch, channel, height, weight = hidden_states.shape 333 | residual = hidden_states 334 | 335 | hidden_states = self.norm(hidden_states) 336 | if not self.use_linear_projection: 337 | hidden_states = self.proj_in(hidden_states) 338 | inner_dim = hidden_states.shape[1] 339 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 340 | else: 341 | inner_dim = hidden_states.shape[1] 342 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) 343 | hidden_states = self.proj_in(hidden_states) 344 | 345 | # Blocks 346 | for block in self.transformer_blocks: 347 | hidden_states = block( 348 | hidden_states, 349 | encoder_hidden_states=encoder_hidden_states, 350 | timestep=timestep, 351 | video_length=video_length 352 | ) 353 | 354 | # Output 355 | if not self.use_linear_projection: 356 | hidden_states = ( 357 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 358 | ) 359 | hidden_states = self.proj_out(hidden_states) 360 | else: 361 | hidden_states = self.proj_out(hidden_states) 362 | hidden_states = ( 363 | hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() 364 | ) 365 | 366 | output = hidden_states + residual 367 | 368 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) 369 | if not return_dict: 370 | return output, 371 | 372 | return Transformer3DModelOutput(sample=output) 373 | 374 | 375 | @dataclass 376 | class TransformerTemporalModelOutput(BaseOutput): 377 | """ 378 | Args: 379 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) 380 | Hidden states conditioned on `encoder_hidden_states` input. 381 | """ 382 | sample: torch.FloatTensor 383 | 384 | 385 | class TransformerTemporalModel(ModelMixin, ConfigMixin): 386 | """ 387 | Transformer model for video-like data. 388 | 389 | Parameters: 390 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 391 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 392 | in_channels (`int`, *optional*): 393 | Pass if the input is continuous. The number of channels in the input and output. 394 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 395 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 396 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 397 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 398 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 399 | `ImagePositionalEmbeddings`. 
400 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 401 | attention_bias (`bool`, *optional*): 402 | Configure if the TransformerBlocks' attention should contain a bias parameter. 403 | """ 404 | 405 | @register_to_config 406 | def __init__( 407 | self, 408 | num_attention_heads: int = 16, 409 | attention_head_dim: int = 88, 410 | in_channels: Optional[int] = None, 411 | num_layers: int = 1, 412 | dropout: float = 0.0, 413 | norm_num_groups: int = 32, 414 | cross_attention_dim: Optional[int] = None, 415 | attention_bias: bool = False, 416 | activation_fn: str = "geglu", 417 | norm_elementwise_affine: bool = True, 418 | add_temp_embed: bool = False, 419 | ): 420 | super().__init__() 421 | self.num_attention_heads = num_attention_heads 422 | self.attention_head_dim = attention_head_dim 423 | inner_dim = num_attention_heads * attention_head_dim 424 | 425 | self.in_channels = in_channels 426 | 427 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 428 | self.proj_in = nn.Linear(in_channels, inner_dim) 429 | 430 | # 3. Define transformers blocks 431 | self.transformer_blocks = nn.ModuleList( 432 | [ 433 | BasicTransformerBlock( 434 | inner_dim, 435 | num_attention_heads, 436 | attention_head_dim, 437 | dropout=dropout, 438 | cross_attention_dim=cross_attention_dim, 439 | activation_fn=activation_fn, 440 | attention_bias=attention_bias, 441 | norm_elementwise_affine=norm_elementwise_affine, 442 | add_temp_embed=add_temp_embed 443 | ) 444 | for _ in range(num_layers) 445 | ] 446 | ) 447 | 448 | self.proj_out = nn.Linear(inner_dim, in_channels) 449 | self.proj_out = zero_module(self.proj_out) 450 | 451 | def forward( 452 | self, 453 | hidden_states, 454 | encoder_hidden_states=None, 455 | timestep=None, 456 | return_dict: bool = True, 457 | ): 458 | """ 459 | Args: 460 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 461 | When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 462 | hidden_states 463 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 464 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 465 | self-attention. 466 | timestep ( `torch.long`, *optional*): 467 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 468 | return_dict (`bool`, *optional*, defaults to `True`): 469 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 470 | 471 | Returns: 472 | [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: 473 | [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. 474 | When returning a tuple, the first element is the sample tensor. 475 | """ 476 | # 1. Input 477 | batch_size, channel, num_frames, height, width = hidden_states.shape 478 | 479 | residual = hidden_states 480 | 481 | hidden_states = self.norm(hidden_states) 482 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape( 483 | batch_size * height * width, num_frames, channel) 484 | hidden_states = self.proj_in(hidden_states) 485 | 486 | # 2. 
Blocks 487 | for block in self.transformer_blocks: 488 | hidden_states = block( 489 | hidden_states, 490 | encoder_hidden_states=encoder_hidden_states, 491 | timestep=timestep, 492 | video_length=num_frames 493 | ) 494 | 495 | # 3. Output 496 | hidden_states = self.proj_out(hidden_states) 497 | hidden_states = ( 498 | hidden_states[None, None, :] 499 | .reshape(batch_size, height, width, channel, num_frames) 500 | .permute(0, 3, 4, 1, 2) 501 | .contiguous() 502 | ) 503 | output = hidden_states + residual 504 | 505 | if not return_dict: 506 | return output, 507 | 508 | return TransformerTemporalModelOutput(sample=output) 509 | 510 | 511 | class SparseCausalAttention(CrossAttention): 512 | def forward(self, hidden_states, encoder_hidden_states=None, 513 | attention_mask=None, **cross_attention_kwargs): 514 | batch_size, sequence_length, _ = hidden_states.shape 515 | video_length = cross_attention_kwargs.get("video_length", 8) 516 | attention_mask = self.prepare_attention_mask( 517 | attention_mask, sequence_length, batch_size) 518 | query = self.to_q(hidden_states) 519 | dim = query.shape[-1] 520 | 521 | if self.added_kv_proj_dim is not None: 522 | raise NotImplementedError 523 | 524 | if encoder_hidden_states is None: 525 | encoder_hidden_states = hidden_states 526 | elif self.cross_attention_norm: 527 | encoder_hidden_states = self.norm_cross(encoder_hidden_states) 528 | 529 | key = self.to_k(encoder_hidden_states) 530 | value = self.to_v(encoder_hidden_states) 531 | 532 | former_frame_index = torch.arange(video_length) - 1 533 | former_frame_index[0] = 0 534 | 535 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length) 536 | if video_length > 1: 537 | key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) 538 | key = rearrange(key, "b f d c -> (b f) d c") 539 | 540 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length) 541 | if video_length > 1: 542 | value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) 543 | value = rearrange(value, "b f d c -> (b f) d c") 544 | 545 | query = self.head_to_batch_dim(query) 546 | key = self.head_to_batch_dim(key) 547 | value = self.head_to_batch_dim(value) 548 | 549 | # attention, what we cannot get enough of 550 | if hasattr(self.processor, "attention_op"): 551 | hidden_states = xformers.ops.memory_efficient_attention( 552 | query, key, value, attn_bias=attention_mask, op=self.processor.attention_op 553 | ) 554 | hidden_states = hidden_states.to(query.dtype) 555 | elif hasattr(self.processor, "slice_size"): 556 | batch_size_attention = query.shape[0] 557 | hidden_states = torch.zeros( 558 | (batch_size_attention, sequence_length, dim // self.heads), 559 | device=query.device, dtype=query.dtype 560 | ) 561 | for i in range(hidden_states.shape[0] // self.processor.slice_size): 562 | start_idx = i * self.processor.slice_size 563 | end_idx = (i + 1) * self.processor.slice_size 564 | query_slice = query[start_idx:end_idx] 565 | key_slice = key[start_idx:end_idx] 566 | attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None 567 | attn_slice = self.get_attention_scores(query_slice, key_slice, attn_mask_slice) 568 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 569 | hidden_states[start_idx:end_idx] = attn_slice 570 | else: 571 | attention_probs = self.get_attention_scores(query, key, attention_mask) 572 | hidden_states = torch.bmm(attention_probs, value) 573 | hidden_states = self.batch_to_head_dim(hidden_states) 574 | 575 | # linear proj 576 | 
hidden_states = self.to_out[0](hidden_states) 577 | 578 | # dropout 579 | hidden_states = self.to_out[1](hidden_states) 580 | return hidden_states 581 | -------------------------------------------------------------------------------- /model/modules/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def get_sin_pos_embedding(embed_dim, seq_len): 7 | """ 8 | :param embed_dim: dimension of the model 9 | :param seq_len: length of positions 10 | :return: [length, embed_dim] position matrix 11 | """ 12 | if embed_dim % 2 != 0: 13 | raise ValueError("Cannot use sin/cos positional encoding with " 14 | "odd dim (got dim={:d})".format(embed_dim)) 15 | pe = torch.zeros(seq_len, embed_dim) 16 | position = torch.arange(0, seq_len).unsqueeze(1) 17 | div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float) * 18 | -(math.log(10000.0) / embed_dim)) 19 | pe[:, 0::2] = torch.sin(position.float() * div_term) 20 | pe[:, 1::2] = torch.cos(position.float() * div_term) 21 | 22 | return pe 23 | -------------------------------------------------------------------------------- /model/modules/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | return x 18 | 19 | 20 | class Upsample3D(nn.Module): 21 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 22 | super().__init__() 23 | self.channels = channels 24 | self.out_channels = out_channels or channels 25 | self.use_conv = use_conv 26 | self.use_conv_transpose = use_conv_transpose 27 | self.name = name 28 | 29 | conv = None 30 | if use_conv_transpose: 31 | raise NotImplementedError 32 | elif use_conv: 33 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 34 | 35 | if name == "conv": 36 | self.conv = conv 37 | else: 38 | self.Conv2d_0 = conv 39 | 40 | def forward(self, hidden_states, output_size=None): 41 | assert hidden_states.shape[1] == self.channels 42 | 43 | if self.use_conv_transpose: 44 | raise NotImplementedError 45 | 46 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 47 | dtype = hidden_states.dtype 48 | if dtype == torch.bfloat16: 49 | hidden_states = hidden_states.to(torch.float32) 50 | 51 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 52 | if hidden_states.shape[0] >= 64: 53 | hidden_states = hidden_states.contiguous() 54 | 55 | # if `output_size` is passed we force the interpolation output 56 | # size and do not make use of `scale_factor=2` 57 | if output_size is None: 58 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 59 | else: 60 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 61 | 62 | # If the input is bfloat16, we cast back to bfloat16 63 | if dtype == torch.bfloat16: 64 | hidden_states = hidden_states.to(dtype) 65 | 66 | if self.use_conv: 67 | if self.name == "conv": 68 | hidden_states = self.conv(hidden_states) 69 | else: 70 | hidden_states = self.Conv2d_0(hidden_states) 71 | 72 | return hidden_states 73 | 74 | 75 | class Downsample3D(nn.Module): 76 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 77 | super().__init__() 78 | self.channels = channels 79 | self.out_channels = out_channels or channels 80 | self.use_conv = use_conv 81 | self.padding = padding 82 | stride = 2 83 | self.name = name 84 | 85 | if use_conv: 86 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 87 | else: 88 | raise NotImplementedError 89 | 90 | if name == "conv": 91 | self.Conv2d_0 = conv 92 | self.conv = conv 93 | elif name == "Conv2d_0": 94 | self.conv = conv 95 | else: 96 | self.conv = conv 97 | 98 | def forward(self, hidden_states): 99 | assert hidden_states.shape[1] == self.channels 100 | if self.use_conv and self.padding == 0: 101 | raise NotImplementedError 102 | 103 | assert hidden_states.shape[1] == self.channels 104 | hidden_states = self.conv(hidden_states) 105 | 106 | return hidden_states 107 | 108 | 109 | class ResnetBlock3D(nn.Module): 110 | def __init__( 111 | self, 112 | *, 113 | in_channels, 114 | out_channels=None, 115 | conv_shortcut=False, 116 | dropout=0.0, 117 | temb_channels=512, 118 | groups=32, 119 | groups_out=None, 120 | pre_norm=True, 121 | eps=1e-6, 122 | non_linearity="swish", 123 | time_embedding_norm="default", 124 | output_scale_factor=1.0, 125 | use_in_shortcut=None, 126 | ): 127 | super().__init__() 128 | self.pre_norm = pre_norm 129 | self.pre_norm = True 130 | self.in_channels = in_channels 131 | out_channels = in_channels if out_channels is None else out_channels 132 | self.out_channels = out_channels 133 | self.use_conv_shortcut = conv_shortcut 134 | self.time_embedding_norm = time_embedding_norm 135 | self.output_scale_factor = output_scale_factor 136 | 137 | if groups_out is None: 138 | groups_out = groups 139 | 140 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 141 | 142 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 143 | 144 | if temb_channels is not None: 145 | if self.time_embedding_norm == "default": 146 | time_emb_proj_out_channels = out_channels 147 | elif self.time_embedding_norm == "scale_shift": 148 | time_emb_proj_out_channels = out_channels * 2 149 | else: 150 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 151 | 152 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 153 | else: 154 | self.time_emb_proj = None 155 | 156 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 157 | self.dropout = torch.nn.Dropout(dropout) 158 | self.conv2 = 
InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 159 | 160 | if non_linearity == "swish": 161 | self.nonlinearity = lambda x: F.silu(x) 162 | elif non_linearity == "mish": 163 | self.nonlinearity = Mish() 164 | elif non_linearity == "silu": 165 | self.nonlinearity = nn.SiLU() 166 | 167 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 168 | 169 | self.conv_shortcut = None 170 | if self.use_in_shortcut: 171 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 172 | 173 | def forward(self, input_tensor, temb): 174 | hidden_states = input_tensor 175 | 176 | hidden_states = self.norm1(hidden_states) 177 | hidden_states = self.nonlinearity(hidden_states) 178 | 179 | hidden_states = self.conv1(hidden_states) 180 | 181 | if temb is not None: 182 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 183 | 184 | if temb is not None and self.time_embedding_norm == "default": 185 | hidden_states = hidden_states + temb 186 | 187 | hidden_states = self.norm2(hidden_states) 188 | 189 | if temb is not None and self.time_embedding_norm == "scale_shift": 190 | scale, shift = torch.chunk(temb, 2, dim=1) 191 | hidden_states = hidden_states * (1 + scale) + shift 192 | 193 | hidden_states = self.nonlinearity(hidden_states) 194 | 195 | hidden_states = self.dropout(hidden_states) 196 | hidden_states = self.conv2(hidden_states) 197 | 198 | if self.conv_shortcut is not None: 199 | input_tensor = self.conv_shortcut(input_tensor) 200 | 201 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 202 | 203 | return output_tensor 204 | 205 | 206 | class Mish(torch.nn.Module): 207 | def forward(self, hidden_states): 208 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 209 | -------------------------------------------------------------------------------- /model/modules/uent.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py 2 | 3 | import json 4 | import os 5 | from dataclasses import dataclass 6 | from typing import List, Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.utils.checkpoint 11 | from diffusers.configuration_utils import ConfigMixin, register_to_config 12 | from diffusers.models import ModelMixin 13 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps 14 | from diffusers.utils import BaseOutput, logging 15 | 16 | from .attention import TransformerTemporalModel 17 | from .resnet import InflatedConv3d 18 | from .unet_blocks import ( 19 | CrossAttnDownBlock3D, 20 | CrossAttnUpBlock3D, 21 | DownBlock3D, 22 | UNetMidBlock3DCrossAttn, 23 | UpBlock3D, 24 | get_down_block, 25 | get_up_block, 26 | ) 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | @dataclass 32 | class UNet3DConditionOutput(BaseOutput): 33 | sample: torch.FloatTensor 34 | 35 | 36 | class UNet3DConditionModel(ModelMixin, ConfigMixin): 37 | _supports_gradient_checkpointing = True 38 | 39 | @register_to_config 40 | def __init__( 41 | self, 42 | sample_size: Optional[int] = None, 43 | in_channels: int = 4, 44 | out_channels: int = 4, 45 | center_input_sample: bool = False, 46 | flip_sin_to_cos: bool = True, 47 | freq_shift: int = 0, 48 | down_block_types: Tuple[str] = ( 49 | "CrossAttnDownBlock3D", 50 
| "CrossAttnDownBlock3D", 51 | "CrossAttnDownBlock3D", 52 | "DownBlock3D", 53 | ), 54 | mid_block_type: str = "UNetMidBlock3DCrossAttn", 55 | up_block_types: Tuple[str] = ( 56 | "UpBlock3D", 57 | "CrossAttnUpBlock3D", 58 | "CrossAttnUpBlock3D", 59 | "CrossAttnUpBlock3D", 60 | ), 61 | only_cross_attention: Union[bool, Tuple[bool]] = False, 62 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 63 | layers_per_block: int = 2, 64 | downsample_padding: int = 1, 65 | mid_block_scale_factor: float = 1, 66 | act_fn: str = "silu", 67 | norm_num_groups: int = 32, 68 | norm_eps: float = 1e-5, 69 | cross_attention_dim: int = 1280, 70 | attention_head_dim: Union[int, Tuple[int]] = 8, 71 | dual_cross_attention: bool = False, 72 | use_linear_projection: bool = False, 73 | class_embed_type: Optional[str] = None, 74 | num_class_embeds: Optional[int] = None, 75 | upcast_attention: bool = False, 76 | resnet_time_scale_shift: str = "default", 77 | add_temp_transformer: bool = False, 78 | add_temp_attn_only_on_upblocks: bool = False, 79 | prepend_first_frame: bool = False, 80 | add_temp_embed: bool = False, 81 | add_temp_ff: bool = False, 82 | add_temp_conv: bool = False, 83 | ): 84 | super().__init__() 85 | 86 | self.sample_size = sample_size 87 | time_embed_dim = block_out_channels[0] * 4 88 | 89 | # input 90 | self.conv_in = InflatedConv3d( 91 | in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1) 92 | ) 93 | 94 | self.temp_transformer = TransformerTemporalModel( 95 | num_attention_heads=8, attention_head_dim=64, 96 | in_channels=block_out_channels[0], 97 | num_layers=1, add_temp_embed=add_temp_embed, 98 | add_temp_ff=add_temp_ff 99 | ) if add_temp_transformer else None 100 | 101 | # time 102 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 103 | timestep_input_dim = block_out_channels[0] 104 | 105 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 106 | 107 | # class embedding 108 | if class_embed_type is None and num_class_embeds is not None: 109 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 110 | elif class_embed_type == "timestep": 111 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 112 | elif class_embed_type == "identity": 113 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 114 | else: 115 | self.class_embedding = None 116 | 117 | self.down_blocks = nn.ModuleList([]) 118 | self.mid_block = None 119 | self.up_blocks = nn.ModuleList([]) 120 | 121 | if isinstance(only_cross_attention, bool): 122 | only_cross_attention = [only_cross_attention] * len(down_block_types) 123 | 124 | if isinstance(attention_head_dim, int): 125 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 126 | 127 | # down 128 | output_channel = block_out_channels[0] 129 | for i, down_block_type in enumerate(down_block_types): 130 | input_channel = output_channel 131 | output_channel = block_out_channels[i] 132 | is_final_block = i == len(block_out_channels) - 1 133 | 134 | down_block = get_down_block( 135 | down_block_type, 136 | num_layers=layers_per_block, 137 | in_channels=input_channel, 138 | out_channels=output_channel, 139 | temb_channels=time_embed_dim, 140 | add_downsample=not is_final_block, 141 | resnet_eps=norm_eps, 142 | resnet_act_fn=act_fn, 143 | resnet_groups=norm_num_groups, 144 | cross_attention_dim=cross_attention_dim, 145 | attn_num_head_channels=attention_head_dim[i], 146 | downsample_padding=downsample_padding, 147 | 
dual_cross_attention=dual_cross_attention, 148 | use_linear_projection=use_linear_projection, 149 | only_cross_attention=only_cross_attention[i], 150 | upcast_attention=upcast_attention, 151 | resnet_time_scale_shift=resnet_time_scale_shift, 152 | add_temp_attn=not add_temp_attn_only_on_upblocks, 153 | prepend_first_frame=prepend_first_frame, 154 | add_temp_embed=add_temp_embed, 155 | add_temp_ff=add_temp_ff, 156 | add_temp_conv=add_temp_conv 157 | ) 158 | self.down_blocks.append(down_block) 159 | 160 | # mid 161 | if mid_block_type == "UNetMidBlock3DCrossAttn": 162 | self.mid_block = UNetMidBlock3DCrossAttn( 163 | in_channels=block_out_channels[-1], 164 | temb_channels=time_embed_dim, 165 | resnet_eps=norm_eps, 166 | resnet_act_fn=act_fn, 167 | output_scale_factor=mid_block_scale_factor, 168 | resnet_time_scale_shift=resnet_time_scale_shift, 169 | cross_attention_dim=cross_attention_dim, 170 | attn_num_head_channels=attention_head_dim[-1], 171 | resnet_groups=norm_num_groups, 172 | dual_cross_attention=dual_cross_attention, 173 | use_linear_projection=use_linear_projection, 174 | upcast_attention=upcast_attention, 175 | add_temp_attn=not add_temp_attn_only_on_upblocks, 176 | prepend_first_frame=prepend_first_frame, 177 | add_temp_embed=add_temp_embed, 178 | add_temp_ff=add_temp_ff, 179 | add_temp_conv=add_temp_conv, 180 | ) 181 | else: 182 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 183 | 184 | # count how many layers upsample the videos 185 | self.num_upsamplers = 0 186 | 187 | # up 188 | reversed_block_out_channels = list(reversed(block_out_channels)) 189 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 190 | only_cross_attention = list(reversed(only_cross_attention)) 191 | output_channel = reversed_block_out_channels[0] 192 | for i, up_block_type in enumerate(up_block_types): 193 | is_final_block = i == len(block_out_channels) - 1 194 | 195 | prev_output_channel = output_channel 196 | output_channel = reversed_block_out_channels[i] 197 | input_channel = reversed_block_out_channels[ 198 | min(i + 1, len(block_out_channels) - 1) 199 | ] 200 | 201 | # add upsample block for all BUT final layer 202 | if not is_final_block: 203 | add_upsample = True 204 | self.num_upsamplers += 1 205 | else: 206 | add_upsample = False 207 | 208 | up_block = get_up_block( 209 | up_block_type, 210 | num_layers=layers_per_block + 1, 211 | in_channels=input_channel, 212 | out_channels=output_channel, 213 | prev_output_channel=prev_output_channel, 214 | temb_channels=time_embed_dim, 215 | add_upsample=add_upsample, 216 | resnet_eps=norm_eps, 217 | resnet_act_fn=act_fn, 218 | resnet_groups=norm_num_groups, 219 | cross_attention_dim=cross_attention_dim, 220 | attn_num_head_channels=reversed_attention_head_dim[i], 221 | dual_cross_attention=dual_cross_attention, 222 | use_linear_projection=use_linear_projection, 223 | only_cross_attention=only_cross_attention[i], 224 | upcast_attention=upcast_attention, 225 | resnet_time_scale_shift=resnet_time_scale_shift, 226 | add_temp_attn=True, 227 | prepend_first_frame=prepend_first_frame, 228 | add_temp_embed=add_temp_embed, 229 | add_temp_ff=add_temp_ff, 230 | add_temp_conv=add_temp_conv 231 | ) 232 | self.up_blocks.append(up_block) 233 | prev_output_channel = output_channel 234 | 235 | # out 236 | self.conv_norm_out = nn.GroupNorm( 237 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps 238 | ) 239 | self.conv_act = nn.SiLU() 240 | self.conv_out = InflatedConv3d( 241 | block_out_channels[0], out_channels, 
kernel_size=3, padding=1 242 | ) 243 | 244 | def set_attention_slice(self, slice_size): 245 | r""" 246 | Enable sliced attention computation. 247 | 248 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 249 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 250 | 251 | Args: 252 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 253 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 254 | `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is 255 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 256 | must be a multiple of `slice_size`. 257 | """ 258 | sliceable_head_dims = [] 259 | 260 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 261 | if hasattr(module, "set_attention_slice"): 262 | sliceable_head_dims.append(module.sliceable_head_dim) 263 | 264 | for child in module.children(): 265 | fn_recursive_retrieve_slicable_dims(child) 266 | 267 | # retrieve number of attention layers 268 | for module in self.children(): 269 | fn_recursive_retrieve_slicable_dims(module) 270 | 271 | num_slicable_layers = len(sliceable_head_dims) 272 | 273 | if slice_size == "auto": 274 | # half the attention head size is usually a good trade-off between 275 | # speed and memory 276 | slice_size = [dim // 2 for dim in sliceable_head_dims] 277 | elif slice_size == "max": 278 | # make smallest slice possible 279 | slice_size = num_slicable_layers * [1] 280 | 281 | slice_size = ( 282 | num_slicable_layers * [slice_size] 283 | if not isinstance(slice_size, list) 284 | else slice_size 285 | ) 286 | 287 | if len(slice_size) != len(sliceable_head_dims): 288 | raise ValueError( 289 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 290 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 291 | ) 292 | 293 | for i in range(len(slice_size)): 294 | size = slice_size[i] 295 | dim = sliceable_head_dims[i] 296 | if size is not None and size > dim: 297 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 298 | 299 | # Recursively walk through all the children. 
300 |         # Any child that exposes the set_attention_slice method
301 |         # receives the message
302 |         def fn_recursive_set_attention_slice(
303 |             module: torch.nn.Module, slice_size: List[int]
304 |         ):
305 |             if hasattr(module, "set_attention_slice"):
306 |                 module.set_attention_slice(slice_size.pop())
307 | 
308 |             for child in module.children():
309 |                 fn_recursive_set_attention_slice(child, slice_size)
310 | 
311 |         reversed_slice_size = list(reversed(slice_size))
312 |         for module in self.children():
313 |             fn_recursive_set_attention_slice(module, reversed_slice_size)
314 | 
315 |     def _set_gradient_checkpointing(self, module, value=False):
316 |         if isinstance(
317 |             module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)
318 |         ):
319 |             module.gradient_checkpointing = value
320 | 
321 |     def forward(
322 |         self,
323 |         sample: torch.FloatTensor,
324 |         timestep: Union[torch.Tensor, float, int],
325 |         encoder_hidden_states: torch.Tensor,
326 |         class_labels: Optional[torch.Tensor] = None,
327 |         attention_mask: Optional[torch.Tensor] = None,
328 |         return_dict: bool = True,
329 |     ) -> Union[UNet3DConditionOutput, Tuple]:
330 |         r"""
331 |         Args:
332 |             sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
333 |             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
334 |             encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
335 |             return_dict (`bool`, *optional*, defaults to `True`):
336 |                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
337 | 
338 |         Returns:
339 |             [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
340 |             [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
341 |             returning a tuple, the first element is the sample tensor.
342 |         """
343 |         # By default, the sample's spatial dimensions must be a multiple of the overall upsampling factor.
344 |         # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
345 |         # However, the upsampling interpolation output size can be forced to fit any upsampling size
346 |         # on the fly if necessary.
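        # Worked example (assuming the default four up blocks, three of which
        # upsample): the overall factor is 2 ** 3 = 8, so 64x64 latents pass
        # through unchanged, while e.g. 60x60 latents set forward_upsample_size
        # and the skip-connection shapes are forwarded as explicit output sizes.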
347 | default_overall_up_factor = 2 ** self.num_upsamplers 348 | 349 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 350 | forward_upsample_size = False 351 | upsample_size = None 352 | 353 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 354 | logger.info("Forward upsample size to force interpolation output size.") 355 | forward_upsample_size = True 356 | 357 | # prepare attention_mask 358 | if attention_mask is not None: 359 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 360 | attention_mask = attention_mask.unsqueeze(1) 361 | 362 | # center input if necessary 363 | if self.config.center_input_sample: 364 | sample = 2 * sample - 1.0 365 | 366 | # time 367 | timesteps = timestep 368 | if not torch.is_tensor(timesteps): 369 | # This would be a good case for the `match` statement (Python 3.10+) 370 | is_mps = sample.device.type == "mps" 371 | if isinstance(timestep, float): 372 | dtype = torch.float32 if is_mps else torch.float64 373 | else: 374 | dtype = torch.int32 if is_mps else torch.int64 375 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 376 | elif len(timesteps.shape) == 0: 377 | timesteps = timesteps[None].to(sample.device) 378 | 379 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 380 | timesteps = timesteps.expand(sample.shape[0]) 381 | 382 | t_emb = self.time_proj(timesteps) 383 | 384 | # timesteps does not contain any weights and will always return f32 tensors 385 | # but time_embedding might actually be running in fp16. so we need to cast here. 386 | # there might be better ways to encapsulate this. 387 | t_emb = t_emb.to(dtype=self.dtype) 388 | emb = self.time_embedding(t_emb) 389 | 390 | if self.class_embedding is not None: 391 | if class_labels is None: 392 | raise ValueError( 393 | "class_labels should be provided when num_class_embeds > 0" 394 | ) 395 | 396 | if self.config.class_embed_type == "timestep": 397 | class_labels = self.time_proj(class_labels) 398 | 399 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 400 | emb = emb + class_emb 401 | 402 | # fp16: cast to model dtype 403 | sample = sample.to(self.dtype) 404 | encoder_hidden_states = encoder_hidden_states.to(self.dtype) 405 | 406 | # pre-process 407 | sample = self.conv_in(sample) 408 | if self.temp_transformer is not None: 409 | sample_new = self.temp_transformer(sample).sample 410 | sample = sample_new if sample.shape[2] > 1 else sample + 0.0 * sample_new 411 | 412 | # down 413 | down_block_res_samples = (sample,) 414 | for downsample_block in self.down_blocks: 415 | if ( 416 | hasattr(downsample_block, "has_cross_attention") 417 | and downsample_block.has_cross_attention 418 | ): 419 | sample, res_samples = downsample_block( 420 | hidden_states=sample, 421 | temb=emb, 422 | encoder_hidden_states=encoder_hidden_states, 423 | attention_mask=attention_mask, 424 | ) 425 | else: 426 | sample, res_samples = downsample_block( 427 | hidden_states=sample, temb=emb 428 | ) 429 | down_block_res_samples += res_samples 430 | 431 | # mid 432 | sample = self.mid_block( 433 | sample, 434 | emb, 435 | encoder_hidden_states=encoder_hidden_states, 436 | attention_mask=attention_mask, 437 | ) 438 | 439 | # up 440 | for i, upsample_block in enumerate(self.up_blocks): 441 | is_final_block = i == len(self.up_blocks) - 1 442 | 443 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 444 | down_block_res_samples = 
down_block_res_samples[:-len(upsample_block.resnets)] 445 | 446 | # if we have not reached the final block and need to forward the 447 | # upsample size, we do it here 448 | if not is_final_block and forward_upsample_size: 449 | upsample_size = down_block_res_samples[-1].shape[2:] 450 | 451 | if ( 452 | hasattr(upsample_block, "has_cross_attention") 453 | and upsample_block.has_cross_attention 454 | ): 455 | sample = upsample_block( 456 | hidden_states=sample, 457 | temb=emb, 458 | res_hidden_states_tuple=res_samples, 459 | encoder_hidden_states=encoder_hidden_states, 460 | upsample_size=upsample_size, 461 | attention_mask=attention_mask, 462 | ) 463 | else: 464 | sample = upsample_block( 465 | hidden_states=sample, 466 | temb=emb, 467 | res_hidden_states_tuple=res_samples, 468 | upsample_size=upsample_size, 469 | ) 470 | # post-process 471 | sample = self.conv_norm_out(sample) 472 | sample = self.conv_act(sample) 473 | sample = self.conv_out(sample) 474 | 475 | if not return_dict: 476 | return sample, 477 | 478 | return UNet3DConditionOutput(sample=sample) 479 | 480 | @classmethod 481 | def from_pretrained_2d( 482 | cls, pretrained_model_path, subfolder=None, **kwargs): 483 | if subfolder is not None: 484 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder) 485 | 486 | config_file = os.path.join(pretrained_model_path, "config.json") 487 | if not os.path.isfile(config_file): 488 | raise RuntimeError(f"{config_file} does not exist") 489 | with open(config_file, "r") as f: 490 | config = json.load(f) 491 | config["_class_name"] = cls.__name__ 492 | config["down_block_types"] = [ 493 | "CrossAttnDownBlock3D", 494 | "CrossAttnDownBlock3D", 495 | "CrossAttnDownBlock3D", 496 | "DownBlock3D", 497 | ] 498 | config["up_block_types"] = [ 499 | "UpBlock3D", 500 | "CrossAttnUpBlock3D", 501 | "CrossAttnUpBlock3D", 502 | "CrossAttnUpBlock3D", 503 | ] 504 | if "sample_size" in kwargs and "sample_size" in config: 505 | config["sample_size"] = kwargs.get("sample_size") 506 | if "in_channels" in kwargs and "in_channels" in config: 507 | config["in_channels"] = kwargs.get("in_channels") 508 | 509 | from diffusers.utils import WEIGHTS_NAME 510 | 511 | model = cls.from_config( 512 | config, add_temp_transformer=kwargs.get("add_temp_transformer", False), 513 | add_temp_attn_only_on_upblocks=kwargs.get("add_temp_attn_only_on_upblocks", False), 514 | prepend_first_frame=kwargs.get("prepend_first_frame", False), 515 | add_temp_embed=kwargs.get("add_temp_embed", False), 516 | add_temp_ff=kwargs.get("add_temp_ff", False), 517 | add_temp_conv=kwargs.get("add_temp_conv", False), 518 | num_class_embeds=kwargs.get("num_class_embeds", None) 519 | ) 520 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) 521 | if not os.path.isfile(model_file): 522 | raise RuntimeError(f"{model_file} does not exist") 523 | state_dict = torch.load(model_file, map_location="cpu") 524 | for k, v in model.state_dict().items(): 525 | if ("temp_" in k or "class_embedding" in k) and k not in state_dict: 526 | state_dict.update({k: v}) 527 | if "conv_in" in k and v.shape != state_dict[k].shape: 528 | state_dict.update({k: v}) 529 | model.load_state_dict(state_dict, strict=False) 530 | 531 | return model 532 | -------------------------------------------------------------------------------- /model/modules/unet_blocks.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py 2 | 3 | 
import torch 4 | from torch import nn 5 | 6 | from .attention import Transformer3DModel 7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D 8 | from .utils import checkpoint, zero_module 9 | 10 | 11 | def get_down_block( 12 | down_block_type, 13 | num_layers, 14 | in_channels, 15 | out_channels, 16 | temb_channels, 17 | add_downsample, 18 | resnet_eps, 19 | resnet_act_fn, 20 | attn_num_head_channels, 21 | resnet_groups=None, 22 | cross_attention_dim=None, 23 | downsample_padding=None, 24 | dual_cross_attention=False, 25 | use_linear_projection=False, 26 | only_cross_attention=False, 27 | upcast_attention=False, 28 | resnet_time_scale_shift="default", 29 | add_temp_attn=False, 30 | prepend_first_frame=False, 31 | add_temp_embed=False, 32 | add_temp_ff=False, 33 | add_temp_conv=False, 34 | ): 35 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type 36 | if down_block_type == "DownBlock3D": 37 | return DownBlock3D( 38 | num_layers=num_layers, 39 | in_channels=in_channels, 40 | out_channels=out_channels, 41 | temb_channels=temb_channels, 42 | add_downsample=add_downsample, 43 | resnet_eps=resnet_eps, 44 | resnet_act_fn=resnet_act_fn, 45 | resnet_groups=resnet_groups, 46 | downsample_padding=downsample_padding, 47 | resnet_time_scale_shift=resnet_time_scale_shift, 48 | add_temp_conv=add_temp_conv 49 | ) 50 | elif down_block_type == "CrossAttnDownBlock3D": 51 | if cross_attention_dim is None: 52 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") 53 | return CrossAttnDownBlock3D( 54 | num_layers=num_layers, 55 | in_channels=in_channels, 56 | out_channels=out_channels, 57 | temb_channels=temb_channels, 58 | add_downsample=add_downsample, 59 | resnet_eps=resnet_eps, 60 | resnet_act_fn=resnet_act_fn, 61 | resnet_groups=resnet_groups, 62 | downsample_padding=downsample_padding, 63 | cross_attention_dim=cross_attention_dim, 64 | attn_num_head_channels=attn_num_head_channels, 65 | dual_cross_attention=dual_cross_attention, 66 | use_linear_projection=use_linear_projection, 67 | only_cross_attention=only_cross_attention, 68 | upcast_attention=upcast_attention, 69 | resnet_time_scale_shift=resnet_time_scale_shift, 70 | add_temp_attn=add_temp_attn, 71 | prepend_first_frame=prepend_first_frame, 72 | add_temp_embed=add_temp_embed, 73 | add_temp_ff=add_temp_ff, 74 | add_temp_conv=add_temp_conv 75 | ) 76 | raise ValueError(f"{down_block_type} does not exist.") 77 | 78 | 79 | def get_up_block( 80 | up_block_type, 81 | num_layers, 82 | in_channels, 83 | out_channels, 84 | prev_output_channel, 85 | temb_channels, 86 | add_upsample, 87 | resnet_eps, 88 | resnet_act_fn, 89 | attn_num_head_channels, 90 | resnet_groups=None, 91 | cross_attention_dim=None, 92 | dual_cross_attention=False, 93 | use_linear_projection=False, 94 | only_cross_attention=False, 95 | upcast_attention=False, 96 | resnet_time_scale_shift="default", 97 | add_temp_attn=False, 98 | prepend_first_frame=False, 99 | add_temp_embed=False, 100 | add_temp_ff=False, 101 | add_temp_conv=False, 102 | ): 103 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 104 | if up_block_type == "UpBlock3D": 105 | return UpBlock3D( 106 | num_layers=num_layers, 107 | in_channels=in_channels, 108 | out_channels=out_channels, 109 | prev_output_channel=prev_output_channel, 110 | temb_channels=temb_channels, 111 | add_upsample=add_upsample, 112 | resnet_eps=resnet_eps, 113 | resnet_act_fn=resnet_act_fn, 114 | resnet_groups=resnet_groups, 
115 | resnet_time_scale_shift=resnet_time_scale_shift, 116 | add_temp_conv=add_temp_conv 117 | ) 118 | elif up_block_type == "CrossAttnUpBlock3D": 119 | if cross_attention_dim is None: 120 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") 121 | return CrossAttnUpBlock3D( 122 | num_layers=num_layers, 123 | in_channels=in_channels, 124 | out_channels=out_channels, 125 | prev_output_channel=prev_output_channel, 126 | temb_channels=temb_channels, 127 | add_upsample=add_upsample, 128 | resnet_eps=resnet_eps, 129 | resnet_act_fn=resnet_act_fn, 130 | resnet_groups=resnet_groups, 131 | cross_attention_dim=cross_attention_dim, 132 | attn_num_head_channels=attn_num_head_channels, 133 | dual_cross_attention=dual_cross_attention, 134 | use_linear_projection=use_linear_projection, 135 | only_cross_attention=only_cross_attention, 136 | upcast_attention=upcast_attention, 137 | resnet_time_scale_shift=resnet_time_scale_shift, 138 | add_temp_attn=add_temp_attn, 139 | prepend_first_frame=prepend_first_frame, 140 | add_temp_embed=add_temp_embed, 141 | add_temp_ff=add_temp_ff, 142 | add_temp_conv=add_temp_conv 143 | ) 144 | raise ValueError(f"{up_block_type} does not exist.") 145 | 146 | 147 | class TemporalConvLayer(nn.Module): 148 | """ 149 | Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: 150 | https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 151 | """ 152 | 153 | def __init__(self, in_dim, out_dim=None, num_layers=4, dropout=0.0): 154 | super().__init__() 155 | out_dim = out_dim or in_dim 156 | 157 | # conv layers 158 | convs = [] 159 | prev_dim, next_dim = in_dim, out_dim 160 | for i in range(num_layers): 161 | if i == num_layers - 1: 162 | next_dim = out_dim 163 | convs.extend([ 164 | nn.GroupNorm(32, prev_dim), nn.SiLU(), nn.Dropout(dropout), 165 | nn.Conv3d(prev_dim, next_dim, (3, 1, 1), padding=(1, 0, 0)) 166 | ]) 167 | prev_dim, next_dim = next_dim, prev_dim 168 | self.convs = nn.ModuleList(convs) 169 | 170 | def forward(self, hidden_states): 171 | video_length = hidden_states.shape[2] 172 | 173 | identity = hidden_states 174 | for conv in self.convs: 175 | hidden_states = conv(hidden_states) 176 | 177 | # ignore effects of temporal layers on image inputs 178 | hidden_states = identity + hidden_states if video_length > 1 \ 179 | else identity + 0.0 * hidden_states 180 | 181 | return hidden_states 182 | 183 | 184 | class UNetMidBlock3DCrossAttn(nn.Module): 185 | def __init__( 186 | self, 187 | in_channels: int, 188 | temb_channels: int, 189 | dropout: float = 0.0, 190 | num_layers: int = 1, 191 | resnet_eps: float = 1e-6, 192 | resnet_time_scale_shift: str = "default", 193 | resnet_act_fn: str = "swish", 194 | resnet_groups: int = 32, 195 | resnet_pre_norm: bool = True, 196 | attn_num_head_channels=1, 197 | output_scale_factor=1.0, 198 | cross_attention_dim=1280, 199 | dual_cross_attention=False, 200 | use_linear_projection=False, 201 | upcast_attention=False, 202 | add_temp_attn=False, 203 | prepend_first_frame=False, 204 | add_temp_embed=False, 205 | add_temp_ff=False, 206 | add_temp_conv=False, 207 | ): 208 | super().__init__() 209 | 210 | self.has_cross_attention = True 211 | self.attn_num_head_channels = attn_num_head_channels 212 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) 213 | 214 | # there is always at least one resnet 215 | resnets = [ 216 | 
ResnetBlock3D( 217 | in_channels=in_channels, 218 | out_channels=in_channels, 219 | temb_channels=temb_channels, 220 | eps=resnet_eps, 221 | groups=resnet_groups, 222 | dropout=dropout, 223 | time_embedding_norm=resnet_time_scale_shift, 224 | non_linearity=resnet_act_fn, 225 | output_scale_factor=output_scale_factor, 226 | pre_norm=resnet_pre_norm, 227 | ) 228 | ] 229 | attentions = [] 230 | if add_temp_conv: 231 | self.temp_convs = None 232 | temp_convs = [TemporalConvLayer(in_channels, in_channels, dropout=0.1)] 233 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 234 | 235 | for _ in range(num_layers): 236 | if dual_cross_attention: 237 | raise NotImplementedError 238 | attentions.append( 239 | Transformer3DModel( 240 | attn_num_head_channels, 241 | in_channels // attn_num_head_channels, 242 | in_channels=in_channels, 243 | num_layers=1, 244 | cross_attention_dim=cross_attention_dim, 245 | norm_num_groups=resnet_groups, 246 | use_linear_projection=use_linear_projection, 247 | upcast_attention=upcast_attention, 248 | add_temp_attn=add_temp_attn, 249 | prepend_first_frame=prepend_first_frame, 250 | add_temp_embed=add_temp_embed, 251 | add_temp_ff=add_temp_ff 252 | ) 253 | ) 254 | resnets.append( 255 | ResnetBlock3D( 256 | in_channels=in_channels, 257 | out_channels=in_channels, 258 | temb_channels=temb_channels, 259 | eps=resnet_eps, 260 | groups=resnet_groups, 261 | dropout=dropout, 262 | time_embedding_norm=resnet_time_scale_shift, 263 | non_linearity=resnet_act_fn, 264 | output_scale_factor=output_scale_factor, 265 | pre_norm=resnet_pre_norm, 266 | ) 267 | ) 268 | if hasattr(self, "temp_convs"): 269 | temp_convs.append(TemporalConvLayer(in_channels, in_channels, dropout=0.1)) 270 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 271 | 272 | self.attentions = nn.ModuleList(attentions) 273 | self.resnets = nn.ModuleList(resnets) 274 | if hasattr(self, "temp_convs"): 275 | self.temp_convs = nn.ModuleList(temp_convs) 276 | 277 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 278 | hidden_states = self.resnets[0](hidden_states, temb) 279 | if hasattr(self, "temp_convs"): 280 | hidden_states = self.temp_convs[0](hidden_states) 281 | for layer_idx in range(len(self.attentions)): 282 | attn = self.attentions[layer_idx] 283 | resnet = self.resnets[layer_idx + 1] 284 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)[0] 285 | hidden_states = resnet(hidden_states, temb) 286 | if hasattr(self, "temp_convs"): 287 | temp_conv = self.temp_convs[layer_idx + 1] 288 | hidden_states = temp_conv(hidden_states) 289 | 290 | return hidden_states 291 | 292 | 293 | class CrossAttnDownBlock3D(nn.Module): 294 | def __init__( 295 | self, 296 | in_channels: int, 297 | out_channels: int, 298 | temb_channels: int, 299 | dropout: float = 0.0, 300 | num_layers: int = 1, 301 | resnet_eps: float = 1e-6, 302 | resnet_time_scale_shift: str = "default", 303 | resnet_act_fn: str = "swish", 304 | resnet_groups: int = 32, 305 | resnet_pre_norm: bool = True, 306 | attn_num_head_channels=1, 307 | cross_attention_dim=1280, 308 | output_scale_factor=1.0, 309 | downsample_padding=1, 310 | add_downsample=True, 311 | dual_cross_attention=False, 312 | use_linear_projection=False, 313 | only_cross_attention=False, 314 | upcast_attention=False, 315 | add_temp_attn=False, 316 | prepend_first_frame=False, 317 | add_temp_embed=False, 318 | add_temp_ff=False, 319 | add_temp_conv=False, 320 | ): 321 | super().__init__() 322 | 
resnets = [] 323 | attentions = [] 324 | if add_temp_conv: 325 | self.temp_convs = None 326 | temp_convs = [] 327 | 328 | self.has_cross_attention = True 329 | self.attn_num_head_channels = attn_num_head_channels 330 | 331 | for i in range(num_layers): 332 | in_channels = in_channels if i == 0 else out_channels 333 | resnets.append( 334 | ResnetBlock3D( 335 | in_channels=in_channels, 336 | out_channels=out_channels, 337 | temb_channels=temb_channels, 338 | eps=resnet_eps, 339 | groups=resnet_groups, 340 | dropout=dropout, 341 | time_embedding_norm=resnet_time_scale_shift, 342 | non_linearity=resnet_act_fn, 343 | output_scale_factor=output_scale_factor, 344 | pre_norm=resnet_pre_norm, 345 | ) 346 | ) 347 | if hasattr(self, "temp_convs"): 348 | temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1)) 349 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 350 | if dual_cross_attention: 351 | raise NotImplementedError 352 | attentions.append( 353 | Transformer3DModel( 354 | attn_num_head_channels, 355 | out_channels // attn_num_head_channels, 356 | in_channels=out_channels, 357 | num_layers=1, 358 | cross_attention_dim=cross_attention_dim, 359 | norm_num_groups=resnet_groups, 360 | use_linear_projection=use_linear_projection, 361 | only_cross_attention=only_cross_attention, 362 | upcast_attention=upcast_attention, 363 | add_temp_attn=add_temp_attn, 364 | prepend_first_frame=prepend_first_frame, 365 | add_temp_embed=add_temp_embed, 366 | add_temp_ff=add_temp_ff, 367 | ) 368 | ) 369 | self.attentions = nn.ModuleList(attentions) 370 | self.resnets = nn.ModuleList(resnets) 371 | if hasattr(self, "temp_convs"): 372 | self.temp_convs = nn.ModuleList(temp_convs) 373 | 374 | if add_downsample: 375 | self.downsamplers = nn.ModuleList( 376 | [ 377 | Downsample3D( 378 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 379 | ) 380 | ] 381 | ) 382 | else: 383 | self.downsamplers = None 384 | 385 | self.gradient_checkpointing = False 386 | 387 | def forward( 388 | self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): 389 | output_states = () 390 | 391 | for layer_idx in range(len(self.resnets)): 392 | resnet, attn = self.resnets[layer_idx], self.attentions[layer_idx] 393 | is_checkpointing = self.training and self.gradient_checkpointing 394 | hidden_states = checkpoint(func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing) 395 | if hasattr(self, "temp_convs"): 396 | temp_conv = self.temp_convs[layer_idx] 397 | hidden_states = checkpoint(func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing) 398 | hidden_states = checkpoint( 399 | func=attn, inputs=(hidden_states, encoder_hidden_states), flag=is_checkpointing)[0] 400 | 401 | output_states += (hidden_states,) 402 | 403 | if self.downsamplers is not None: 404 | for downsampler in self.downsamplers: 405 | hidden_states = downsampler(hidden_states) 406 | 407 | output_states += (hidden_states,) 408 | 409 | return hidden_states, output_states 410 | 411 | 412 | class DownBlock3D(nn.Module): 413 | def __init__( 414 | self, 415 | in_channels: int, 416 | out_channels: int, 417 | temb_channels: int, 418 | dropout: float = 0.0, 419 | num_layers: int = 1, 420 | resnet_eps: float = 1e-6, 421 | resnet_time_scale_shift: str = "default", 422 | resnet_act_fn: str = "swish", 423 | resnet_groups: int = 32, 424 | resnet_pre_norm: bool = True, 425 | output_scale_factor=1.0, 426 | add_downsample=True, 427 | downsample_padding=1, 428 | 
add_temp_conv=False, 429 | ): 430 | super().__init__() 431 | resnets = [] 432 | if add_temp_conv: 433 | self.temp_convs = None 434 | temp_convs = [] 435 | for i in range(num_layers): 436 | in_channels = in_channels if i == 0 else out_channels 437 | resnets.append( 438 | ResnetBlock3D( 439 | in_channels=in_channels, 440 | out_channels=out_channels, 441 | temb_channels=temb_channels, 442 | eps=resnet_eps, 443 | groups=resnet_groups, 444 | dropout=dropout, 445 | time_embedding_norm=resnet_time_scale_shift, 446 | non_linearity=resnet_act_fn, 447 | output_scale_factor=output_scale_factor, 448 | pre_norm=resnet_pre_norm, 449 | ) 450 | ) 451 | if hasattr(self, "temp_convs"): 452 | temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1)) 453 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 454 | 455 | self.resnets = nn.ModuleList(resnets) 456 | if hasattr(self, "temp_convs"): 457 | self.temp_convs = nn.ModuleList(temp_convs) 458 | 459 | if add_downsample: 460 | self.downsamplers = nn.ModuleList( 461 | [ 462 | Downsample3D( 463 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" 464 | ) 465 | ] 466 | ) 467 | else: 468 | self.downsamplers = None 469 | 470 | self.gradient_checkpointing = False 471 | 472 | def forward(self, hidden_states, temb=None): 473 | output_states = () 474 | 475 | for layer_idx in range(len(self.resnets)): 476 | resnet = self.resnets[layer_idx] 477 | is_checkpointing = self.training and self.gradient_checkpointing 478 | hidden_states = checkpoint(func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing) 479 | if hasattr(self, "temp_convs"): 480 | temp_conv = self.temp_convs[layer_idx] 481 | hidden_states = checkpoint(func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing) 482 | 483 | output_states += (hidden_states,) 484 | 485 | if self.downsamplers is not None: 486 | for downsampler in self.downsamplers: 487 | hidden_states = downsampler(hidden_states) 488 | 489 | output_states += (hidden_states,) 490 | 491 | return hidden_states, output_states 492 | 493 | 494 | class CrossAttnUpBlock3D(nn.Module): 495 | def __init__( 496 | self, 497 | in_channels: int, 498 | out_channels: int, 499 | prev_output_channel: int, 500 | temb_channels: int, 501 | dropout: float = 0.0, 502 | num_layers: int = 1, 503 | resnet_eps: float = 1e-6, 504 | resnet_time_scale_shift: str = "default", 505 | resnet_act_fn: str = "swish", 506 | resnet_groups: int = 32, 507 | resnet_pre_norm: bool = True, 508 | attn_num_head_channels=1, 509 | cross_attention_dim=1280, 510 | output_scale_factor=1.0, 511 | add_upsample=True, 512 | dual_cross_attention=False, 513 | use_linear_projection=False, 514 | only_cross_attention=False, 515 | upcast_attention=False, 516 | add_temp_attn=False, 517 | prepend_first_frame=False, 518 | add_temp_embed=False, 519 | add_temp_ff=False, 520 | add_temp_conv=False, 521 | ): 522 | super().__init__() 523 | resnets = [] 524 | attentions = [] 525 | if add_temp_conv: 526 | self.temp_convs = None 527 | temp_convs = [] 528 | 529 | self.has_cross_attention = True 530 | self.attn_num_head_channels = attn_num_head_channels 531 | 532 | for i in range(num_layers): 533 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 534 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 535 | 536 | resnets.append( 537 | ResnetBlock3D( 538 | in_channels=resnet_in_channels + res_skip_channels, 539 | out_channels=out_channels, 540 | temb_channels=temb_channels, 541 | 
eps=resnet_eps, 542 | groups=resnet_groups, 543 | dropout=dropout, 544 | time_embedding_norm=resnet_time_scale_shift, 545 | non_linearity=resnet_act_fn, 546 | output_scale_factor=output_scale_factor, 547 | pre_norm=resnet_pre_norm, 548 | ) 549 | ) 550 | if dual_cross_attention: 551 | raise NotImplementedError 552 | attentions.append( 553 | Transformer3DModel( 554 | attn_num_head_channels, 555 | out_channels // attn_num_head_channels, 556 | in_channels=out_channels, 557 | num_layers=1, 558 | cross_attention_dim=cross_attention_dim, 559 | norm_num_groups=resnet_groups, 560 | use_linear_projection=use_linear_projection, 561 | only_cross_attention=only_cross_attention, 562 | upcast_attention=upcast_attention, 563 | add_temp_attn=add_temp_attn, 564 | prepend_first_frame=prepend_first_frame, 565 | add_temp_embed=add_temp_embed, 566 | add_temp_ff=add_temp_ff 567 | ) 568 | ) 569 | if hasattr(self, "temp_convs"): 570 | temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1)) 571 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 572 | 573 | self.attentions = nn.ModuleList(attentions) 574 | self.resnets = nn.ModuleList(resnets) 575 | if hasattr(self, "temp_convs"): 576 | self.temp_convs = nn.ModuleList(temp_convs) 577 | 578 | if add_upsample: 579 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 580 | else: 581 | self.upsamplers = None 582 | 583 | self.gradient_checkpointing = False 584 | 585 | def forward( 586 | self, 587 | hidden_states, 588 | res_hidden_states_tuple, 589 | temb=None, 590 | encoder_hidden_states=None, 591 | upsample_size=None, 592 | attention_mask=None 593 | ): 594 | for layer_idx in range(len(self.resnets)): 595 | resnet, attn = self.resnets[layer_idx], self.attentions[layer_idx] 596 | # pop res hidden states 597 | res_hidden_states = res_hidden_states_tuple[-1] 598 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 599 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 600 | 601 | is_checkpointing = self.training and self.gradient_checkpointing 602 | hidden_states = checkpoint(func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing) 603 | if hasattr(self, "temp_convs"): 604 | temp_conv = self.temp_convs[layer_idx] 605 | hidden_states = checkpoint(func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing) 606 | hidden_states = checkpoint( 607 | func=attn, inputs=(hidden_states, encoder_hidden_states), flag=is_checkpointing)[0] 608 | 609 | if self.upsamplers is not None: 610 | for upsampler in self.upsamplers: 611 | hidden_states = upsampler(hidden_states, upsample_size) 612 | 613 | return hidden_states 614 | 615 | 616 | class UpBlock3D(nn.Module): 617 | def __init__( 618 | self, 619 | in_channels: int, 620 | prev_output_channel: int, 621 | out_channels: int, 622 | temb_channels: int, 623 | dropout: float = 0.0, 624 | num_layers: int = 1, 625 | resnet_eps: float = 1e-6, 626 | resnet_time_scale_shift: str = "default", 627 | resnet_act_fn: str = "swish", 628 | resnet_groups: int = 32, 629 | resnet_pre_norm: bool = True, 630 | output_scale_factor=1.0, 631 | add_upsample=True, 632 | add_temp_conv=False, 633 | ): 634 | super().__init__() 635 | resnets = [] 636 | 637 | if add_temp_conv: 638 | self.temp_convs = None 639 | temp_convs = [] 640 | 641 | for i in range(num_layers): 642 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 643 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 644 | 645 | resnets.append( 
646 | ResnetBlock3D( 647 | in_channels=resnet_in_channels + res_skip_channels, 648 | out_channels=out_channels, 649 | temb_channels=temb_channels, 650 | eps=resnet_eps, 651 | groups=resnet_groups, 652 | dropout=dropout, 653 | time_embedding_norm=resnet_time_scale_shift, 654 | non_linearity=resnet_act_fn, 655 | output_scale_factor=output_scale_factor, 656 | pre_norm=resnet_pre_norm, 657 | ) 658 | ) 659 | if hasattr(self, "temp_convs"): 660 | temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1)) 661 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1]) 662 | 663 | self.resnets = nn.ModuleList(resnets) 664 | if hasattr(self, "temp_convs"): 665 | self.temp_convs = nn.ModuleList(temp_convs) 666 | 667 | if add_upsample: 668 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) 669 | else: 670 | self.upsamplers = None 671 | 672 | self.gradient_checkpointing = False 673 | 674 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): 675 | for layer_idx in range(len(self.resnets)): 676 | resnet = self.resnets[layer_idx] 677 | # pop res hidden states 678 | res_hidden_states = res_hidden_states_tuple[-1] 679 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 680 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 681 | 682 | is_checkpointing = self.training and self.gradient_checkpointing 683 | hidden_states = checkpoint(func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing) 684 | if hasattr(self, "temp_convs"): 685 | temp_conv = self.temp_convs[layer_idx] 686 | hidden_states = checkpoint(func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing) 687 | 688 | if self.upsamplers is not None: 689 | for upsampler in self.upsamplers: 690 | hidden_states = upsampler(hidden_states, upsample_size) 691 | 692 | return hidden_states 693 | -------------------------------------------------------------------------------- /model/modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def checkpoint(func, inputs, flag): 5 | """ 6 | Evaluate a function without caching intermediate activations, allowing for 7 | reduced memory at the expense of extra compute in the backward pass. 8 | :param func: the function to evaluate. 9 | :param inputs: the argument sequence to pass to `func`. 10 | :param flag: if False, disable gradient checkpointing. 11 | """ 12 | if flag: 13 | return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False) 14 | else: 15 | return func(*inputs) 16 | 17 | 18 | def zero_module(module): 19 | """ 20 | Zero out the parameters of a module and return it. 21 | """ 22 | for p in module.parameters(): 23 | p.detach().zero_() 24 | return module 25 | -------------------------------------------------------------------------------- /model/modules/zero_snr_ddpm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim 16 | 17 | import math 18 | from dataclasses import dataclass 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import numpy as np 22 | import torch 23 | 24 | from diffusers.configuration_utils import ConfigMixin, register_to_config 25 | from diffusers.utils import BaseOutput, randn_tensor 26 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin 27 | 28 | 29 | @dataclass 30 | class DDPMSchedulerOutput(BaseOutput): 31 | """ 32 | Output class for the scheduler's step function output. 33 | 34 | Args: 35 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 36 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the 37 | denoising loop. 38 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 39 | The predicted denoised sample (x_{0}) based on the model output from the current timestep. 40 | `pred_original_sample` can be used to preview progress or for guidance. 41 | """ 42 | 43 | prev_sample: torch.FloatTensor 44 | pred_original_sample: Optional[torch.FloatTensor] = None 45 | 46 | 47 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): 48 | """ 49 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 50 | (1-beta) over time from t = [0,1]. 51 | 52 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 53 | to that part of the diffusion process. 54 | 55 | 56 | Args: 57 | num_diffusion_timesteps (`int`): the number of betas to produce. 58 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 59 | prevent singularities. 60 | 61 | Returns: 62 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 63 | """ 64 | 65 | def alpha_bar(time_step): 66 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 67 | 68 | betas = [] 69 | for i in range(num_diffusion_timesteps): 70 | t1 = i / num_diffusion_timesteps 71 | t2 = (i + 1) / num_diffusion_timesteps 72 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 73 | return torch.tensor(betas, dtype=torch.float32) 74 | 75 | def enforce_zero_terminal_snr(betas): 76 | # convert betas to alphas_bar_sqrt 77 | alphas = 1 - betas 78 | alphas_bar = alphas.cumprod(0) 79 | alphas_bar_sqrt = alphas_bar.sqrt() 80 | 81 | # store old values 82 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 83 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 84 | # shift so last timestep is zero. 85 | alphas_bar_sqrt -= alphas_bar_sqrt_T 86 | # scale so first timestep is back to old value. 
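    # (i.e. multiply by alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T),
    # an affine rescale that keeps step 0 unchanged while pinning the terminal
    # step to exactly zero SNR; cf. "Common Diffusion Noise Schedules and
    # Sample Steps are Flawed")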
87 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 88 | 89 | # convert alphas_bar_sqrt to betas 90 | alphas_bar = alphas_bar_sqrt ** 2 91 | alphas = alphas_bar[1:] / alphas_bar[:-1] 92 | alphas = torch.cat([alphas_bar[0:1], alphas]) 93 | betas = 1 - alphas 94 | return betas 95 | 96 | class DDPMScheduler(SchedulerMixin, ConfigMixin): 97 | """ 98 | Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and 99 | Langevin dynamics sampling. 100 | 101 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` 102 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 103 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and 104 | [`~SchedulerMixin.from_pretrained`] functions. 105 | 106 | For more details, see the original paper: https://arxiv.org/abs/2006.11239 107 | 108 | Args: 109 | num_train_timesteps (`int`): number of diffusion steps used to train the model. 110 | beta_start (`float`): the starting `beta` value of inference. 111 | beta_end (`float`): the final `beta` value. 112 | beta_schedule (`str`): 113 | the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from 114 | `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 115 | trained_betas (`np.ndarray`, optional): 116 | option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 117 | variance_type (`str`): 118 | options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, 119 | `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. 120 | clip_sample (`bool`, default `True`): 121 | option to clip predicted sample between -1 and 1 for numerical stability. 122 | prediction_type (`str`, default `epsilon`, optional): 123 | prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion 124 | process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 125 | https://imagen.research.google/video/paper.pdf) 126 | """ 127 | 128 | _compatibles = [e.name for e in KarrasDiffusionSchedulers] 129 | order = 1 130 | 131 | @register_to_config 132 | def __init__( 133 | self, 134 | num_train_timesteps: int = 1000, 135 | beta_start: float = 0.0001, 136 | beta_end: float = 0.02, 137 | beta_schedule: str = "linear", 138 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None, 139 | variance_type: str = "fixed_small", 140 | clip_sample: bool = True, 141 | prediction_type: str = "epsilon", 142 | clip_sample_range: Optional[float] = 1.0, 143 | ): 144 | if trained_betas is not None: 145 | self.betas = torch.tensor(trained_betas, dtype=torch.float32) 146 | elif beta_schedule == "linear": 147 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) 148 | elif beta_schedule == "scaled_linear": 149 | # this schedule is very specific to the latent diffusion model. 
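            # note: after the sqrt-space linspace below, this adapted copy also
            # runs enforce_zero_terminal_snr on the betas, so the final training
            # timestep carries no signal (zero terminal SNR).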
150 | self.betas = ( 151 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 152 | ) 153 | self.betas = enforce_zero_terminal_snr(self.betas) # Zero-SNR 154 | elif beta_schedule == "squaredcos_cap_v2": 155 | # Glide cosine schedule 156 | self.betas = betas_for_alpha_bar(num_train_timesteps) 157 | elif beta_schedule == "sigmoid": 158 | # GeoDiff sigmoid schedule 159 | betas = torch.linspace(-6, 6, num_train_timesteps) 160 | self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start 161 | else: 162 | raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") 163 | 164 | self.alphas = 1.0 - self.betas 165 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 166 | self.one = torch.tensor(1.0) 167 | 168 | # standard deviation of the initial noise distribution 169 | self.init_noise_sigma = 1.0 170 | 171 | # setable values 172 | self.num_inference_steps = None 173 | self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy()) 174 | 175 | self.variance_type = variance_type 176 | 177 | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: 178 | """ 179 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 180 | current timestep. 181 | 182 | Args: 183 | sample (`torch.FloatTensor`): input sample 184 | timestep (`int`, optional): current timestep 185 | 186 | Returns: 187 | `torch.FloatTensor`: scaled input sample 188 | """ 189 | return sample 190 | 191 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): 192 | """ 193 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. 194 | 195 | Args: 196 | num_inference_steps (`int`): 197 | the number of diffusion steps used when generating samples with a pre-trained model. 198 | """ 199 | 200 | if num_inference_steps > self.config.num_train_timesteps: 201 | raise ValueError( 202 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 203 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 204 | f" maximal {self.config.num_train_timesteps} timesteps." 
205 | ) 206 | 207 | self.num_inference_steps = num_inference_steps 208 | 209 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps 210 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) 211 | self.timesteps = torch.from_numpy(timesteps).to(device) 212 | 213 | def _get_variance(self, t, predicted_variance=None, variance_type=None): 214 | num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps 215 | prev_t = t - self.config.num_train_timesteps // num_inference_steps 216 | alpha_prod_t = self.alphas_cumprod[t] 217 | alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one 218 | current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev 219 | 220 | # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) 221 | # and sample from it to get previous sample 222 | # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample 223 | variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t 224 | 225 | if variance_type is None: 226 | variance_type = self.config.variance_type 227 | 228 | # hacks - were probably added for training stability 229 | if variance_type == "fixed_small": 230 | variance = torch.clamp(variance, min=1e-20) 231 | # for rl-diffuser https://arxiv.org/abs/2205.09991 232 | elif variance_type == "fixed_small_log": 233 | variance = torch.log(torch.clamp(variance, min=1e-20)) 234 | variance = torch.exp(0.5 * variance) 235 | elif variance_type == "fixed_large": 236 | variance = current_beta_t 237 | elif variance_type == "fixed_large_log": 238 | # Glide max_log 239 | variance = torch.log(current_beta_t) 240 | elif variance_type == "learned": 241 | return predicted_variance 242 | elif variance_type == "learned_range": 243 | min_log = torch.log(variance) 244 | max_log = torch.log(self.betas[t]) 245 | frac = (predicted_variance + 1) / 2 246 | variance = frac * max_log + (1 - frac) * min_log 247 | 248 | return variance 249 | 250 | def step( 251 | self, 252 | model_output: torch.FloatTensor, 253 | timestep: int, 254 | sample: torch.FloatTensor, 255 | generator=None, 256 | return_dict: bool = True, 257 | ) -> Union[DDPMSchedulerOutput, Tuple]: 258 | """ 259 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion 260 | process from the learned model outputs (most often the predicted noise). 261 | 262 | Args: 263 | model_output (`torch.FloatTensor`): direct output from learned diffusion model. 264 | timestep (`int`): current discrete timestep in the diffusion chain. 265 | sample (`torch.FloatTensor`): 266 | current instance of sample being created by diffusion process. 267 | generator: random number generator. 268 | return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class 269 | 270 | Returns: 271 | [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`: 272 | [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When 273 | returning a tuple, the first element is the sample tensor. 
274 | 275 | """ 276 | t = timestep 277 | num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps 278 | prev_t = timestep - self.config.num_train_timesteps // num_inference_steps 279 | 280 | if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: 281 | model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1) 282 | else: 283 | predicted_variance = None 284 | 285 | # 1. compute alphas, betas 286 | alpha_prod_t = self.alphas_cumprod[t] 287 | alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one 288 | beta_prod_t = 1 - alpha_prod_t 289 | beta_prod_t_prev = 1 - alpha_prod_t_prev 290 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev 291 | current_beta_t = 1 - current_alpha_t 292 | 293 | # 2. compute predicted original sample from predicted noise also called 294 | # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf 295 | if self.config.prediction_type == "epsilon": 296 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) 297 | elif self.config.prediction_type == "sample": 298 | pred_original_sample = model_output 299 | elif self.config.prediction_type == "v_prediction": 300 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output 301 | else: 302 | raise ValueError( 303 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" 304 | " `v_prediction` for the DDPMScheduler." 305 | ) 306 | 307 | # 3. Clip "predicted x_0" 308 | if self.config.clip_sample: 309 | pred_original_sample = torch.clamp( 310 | pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range 311 | ) 312 | 313 | # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t 314 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 315 | pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t 316 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t 317 | 318 | # 5. Compute predicted previous sample µ_t 319 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 320 | pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample 321 | 322 | # 6. 
Add noise 323 | variance = 0 324 | if t > 0: 325 | device = model_output.device 326 | variance_noise = randn_tensor( 327 | model_output.shape, generator=generator, device=device, dtype=model_output.dtype 328 | ) 329 | if self.variance_type == "fixed_small_log": 330 | variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise 331 | elif self.variance_type == "learned_range": 332 | variance = self._get_variance(t, predicted_variance=predicted_variance) 333 | variance = torch.exp(0.5 * variance) * variance_noise 334 | else: 335 | variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise 336 | 337 | pred_prev_sample = pred_prev_sample + variance 338 | 339 | if not return_dict: 340 | return (pred_prev_sample,) 341 | 342 | return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample) 343 | 344 | def add_noise( 345 | self, 346 | original_samples: torch.FloatTensor, 347 | noise: torch.FloatTensor, 348 | timesteps: torch.IntTensor, 349 | ) -> torch.FloatTensor: 350 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 351 | self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) 352 | timesteps = timesteps.to(original_samples.device) 353 | 354 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 355 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 356 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 357 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 358 | 359 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 360 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 361 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 362 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 363 | 364 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 365 | return noisy_samples 366 | 367 | def get_velocity( 368 | self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor 369 | ) -> torch.FloatTensor: 370 | # Make sure alphas_cumprod and timestep have same device and dtype as sample 371 | self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) 372 | timesteps = timesteps.to(sample.device) 373 | 374 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 375 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 376 | while len(sqrt_alpha_prod.shape) < len(sample.shape): 377 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 378 | 379 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 380 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 381 | while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): 382 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 383 | 384 | velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample 385 | return velocity 386 | 387 | def __len__(self): 388 | return self.config.num_train_timesteps 389 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from math import sqrt 4 | from pathlib import Path 5 | from typing import Union, Tuple, List, Optional 6 | 7 | import imageio 8 | import torch 9 | import torchvision 10 | from einops import rearrange 11 | from 
torch.utils.checkpoint import checkpoint 12 | 13 | import torch.nn as nn 14 | import kornia 15 | import open_clip 16 | import math 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def randn_base( 22 | shape: Union[Tuple, List], 23 | mean: float = .0, 24 | std: float = 1., 25 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 26 | device: Optional["torch.device"] = None, 27 | dtype: Optional["torch.dtype"] = None 28 | ): 29 | if isinstance(generator, list): 30 | shape = (1,) + shape[1:] 31 | tensor = [ 32 | torch.normal( 33 | mean=mean, std=std, size=shape, generator=generator[i], 34 | device=device, dtype=dtype 35 | ) 36 | for i in range(len(generator)) 37 | ] 38 | tensor = torch.cat(tensor, dim=0).to(device) 39 | else: 40 | tensor = torch.normal( 41 | mean=mean, std=std, size=shape, generator=generator, device=device, 42 | dtype=dtype 43 | ) 44 | return tensor 45 | 46 | 47 | def randn_mixed( 48 | shape: Union[Tuple, List], 49 | dim: int, 50 | alpha: float = .0, 51 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 52 | device: Optional["torch.device"] = None, 53 | dtype: Optional["torch.dtype"] = None 54 | ): 55 | """ Refer to Section 4 of Preserve Your Own Correlation: 56 | [A Noise Prior for Video Diffusion Models](https://arxiv.org/abs/2305.10474) 57 | """ 58 | shape_shared = shape[:dim] + (1,) + shape[dim + 1:] 59 | 60 | # shared random tensor 61 | shared_std = alpha ** 2 / (1. + alpha ** 2) 62 | shared_tensor = randn_base( 63 | shape=shape_shared, mean=.0, std=shared_std, generator=generator, 64 | device=device, dtype=dtype 65 | ) 66 | 67 | # individual random tensor 68 | indv_std = 1. / (1. + alpha ** 2) 69 | indv_tensor = randn_base( 70 | shape=shape, mean=.0, std=indv_std, generator=generator, device=device, 71 | dtype=dtype 72 | ) 73 | 74 | return shared_tensor + indv_tensor 75 | 76 | 77 | def randn_progressive( 78 | shape: Union[Tuple, List], 79 | dim: int, 80 | alpha: float = .0, 81 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 82 | device: Optional["torch.device"] = None, 83 | dtype: Optional["torch.dtype"] = None 84 | ): 85 | """ Refer to Section 4 of Preserve Your Own Correlation: 86 | [A Noise Prior for Video Diffusion Models](https://arxiv.org/abs/2305.10474) 87 | """ 88 | num_prog = shape[dim] 89 | shape_slice = shape[:dim] + (1,) + shape[dim + 1:] 90 | tensors = [randn_base(shape=shape_slice, mean=.0, std=1., generator=generator, device=device, dtype=dtype)] 91 | beta = alpha / sqrt(1. + alpha ** 2) 92 | std = 1. / (1. + alpha ** 2) 93 | for i in range(1, num_prog): 94 | tensor_i = beta * tensors[-1] + randn_base( 95 | shape=shape_slice, mean=.0, std=std, generator=generator, device=device, dtype=dtype 96 | ) 97 | tensors.append(tensor_i) 98 | tensors = torch.cat(tensors, dim=dim) 99 | return tensors 100 | 101 | 102 | def prepare_masked_latents(images, vae_encode_func, scaling_factor=0.18215, sample_size=32, null_img_ratio=0): 103 | masks = torch.ones_like(images) # shape: [b, f, c, h, w] 104 | if random.random() < (1 - null_img_ratio): 105 | masks[:, 0, ...] = 0 106 | # masks[:, random.randrange(masks.shape[1]), ...] 
= 0 # TK 107 | masked_latents = images * (masks < 0.5) 108 | # map masks into latent space 109 | masks = masks[:, :, :1, :sample_size, :sample_size] 110 | masks = rearrange(masks, "b f c h w -> b c f h w") 111 | # map masked_latents into latent space 112 | masked_latents = vae_encode_func( 113 | masked_latents.view(masked_latents.shape[0] * masked_latents.shape[1], *masked_latents.shape[2:]) 114 | ).latent_dist.sample() * scaling_factor 115 | masked_latents = rearrange(masked_latents, "(b f) c h w -> b c f h w", f=images.shape[1]) 116 | 117 | return masks, masked_latents 118 | 119 | def decode_latents(vae_func, latents): 120 | scaling_factor = 0.18215 121 | video_length = latents.shape[2] 122 | latents = 1 / scaling_factor * latents 123 | latents = rearrange(latents, "b c f h w -> (b f) c h w") 124 | video = vae_func(latents).sample 125 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) 126 | video = (video / 2 + 0.5).clamp(0, 1) 127 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 128 | video = video.float() 129 | return video 130 | 131 | def prepare_entity_latents(images, vae_encode_func, vae_decode_func, scaling_factor=0.18215, sample_size=32, null_img_ratio=0): 132 | # map masked_latents into latent space 133 | masked_latents = images 134 | masked_latents = vae_encode_func( 135 | masked_latents.view(masked_latents.shape[0] * masked_latents.shape[1], *masked_latents.shape[2:]) 136 | ).latent_dist.sample() * scaling_factor 137 | entity_vae_latent = rearrange(masked_latents, "(b f) c h w -> b c f h w", f=images.shape[1]) 138 | 139 | # save decode video 140 | # video = decode_latents(vae_decode_func, entity_vae_latent) 141 | # save_videos_grid(video, Path(f"samples_s{3000}", f"test decode.gif")) 142 | return entity_vae_latent 143 | 144 | def save_videos_grid(videos, path, rescale=False, n_rows=4, fps=4): 145 | if videos.dim() == 4: 146 | videos = videos.unsqueeze(0) 147 | videos = rearrange(videos, "b c t h w -> t b c h w") 148 | outputs = [] 149 | for x in videos: 150 | x = torchvision.utils.make_grid(x, nrow=n_rows) 151 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 152 | if rescale: 153 | x = (x + 1.0) / 2.0 # [-1, 1) -> [0, 1) 154 | x = (x * 255).to(dtype=torch.uint8, device="cpu") 155 | outputs.append(x) 156 | Path(path).parent.mkdir(parents=True, exist_ok=True) 157 | imageio.mimwrite(Path(path).as_posix(), outputs, duration=1000 / fps, loop=0) 158 | 159 | 160 | @torch.no_grad() 161 | def compute_clip_score(model, model_processor, images, texts, local_bs=32, rescale=False): 162 | if rescale: 163 | images = (images + 1.0) / 2.0 # -1,1 -> 0,1 164 | images = (images * 255).to(torch.uint8) 165 | clip_scores = [] 166 | for start_idx in range(0, images.shape[0], local_bs): 167 | img_batch = images[start_idx:start_idx + local_bs] 168 | batch_size = img_batch.shape[0] # shape: [b c t h w] 169 | img_batch = rearrange(img_batch, "b c t h w -> (b t) c h w") 170 | outputs = [] 171 | for i in range(len(img_batch)): 172 | images_part = img_batch[i:i + 1] 173 | model_inputs = model_processor( 174 | text=texts, images=list(images_part), return_tensors="pt", padding=True 175 | ) 176 | model_inputs = { 177 | k: v.to(device=model.device, dtype=model.dtype) 178 | if k in ["pixel_values"] else v.to(device=model.device) 179 | for k, v in model_inputs.items() 180 | } 181 | logits = model(**model_inputs)["logits_per_image"] 182 | # For consistency with `torchmetrics.functional.multimodal.clip_score`. 
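            # (the HF CLIP model returns logits_per_image = logit_scale.exp() * cosine
            # similarity between normalized image/text embeddings, so dividing by the
            # learned temperature recovers the raw per-frame similarities that are
            # averaged over frames below)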
183 | logits = logits / model.logit_scale.exp() 184 | outputs.append(logits) 185 | logits = torch.cat(outputs) 186 | logits = rearrange(logits, "(b t) p -> t b p", b=batch_size) 187 | frame_sims = [] 188 | for logit in logits: 189 | frame_sims.append(logit.diagonal()) 190 | frame_sims = torch.stack(frame_sims) # [t, b] 191 | clip_scores.append(frame_sims.mean(dim=0)) 192 | return torch.cat(clip_scores) 193 | 194 | 195 | class AbstractEncoder(nn.Module): 196 | def __init__(self): 197 | super().__init__() 198 | 199 | def encode(self, *args, **kwargs): 200 | raise NotImplementedError 201 | 202 | class FrozenOpenCLIPEmbedder(AbstractEncoder): 203 | """ 204 | Uses the OpenCLIP transformer encoder for text 205 | """ 206 | LAYERS = [ 207 | # "pooled", 208 | "last", 209 | "penultimate" 210 | ] 211 | 212 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, 213 | freeze=True, layer="last"): 214 | super().__init__() 215 | assert layer in self.LAYERS 216 | pretrained = "/home/user/model/open_clip/open_clip_ViT_H_14/open_clip_pytorch_model.bin" 217 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) 218 | # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu')) 219 | del model.visual 220 | self.model = model 221 | 222 | self.device = device 223 | self.max_length = max_length 224 | if freeze: 225 | self.freeze() 226 | self.layer = layer 227 | if self.layer == "last": 228 | self.layer_idx = 0 229 | elif self.layer == "penultimate": 230 | self.layer_idx = 1 231 | else: 232 | raise NotImplementedError() 233 | 234 | def freeze(self): 235 | self.model = self.model.eval() 236 | for param in self.parameters(): 237 | param.requires_grad = False 238 | 239 | def forward(self, text): 240 | self.device = self.model.positional_embedding.device 241 | tokens = open_clip.tokenize(text) 242 | z = self.encode_with_transformer(tokens.to(self.device)) 243 | return z 244 | 245 | def encode_with_transformer(self, text): 246 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 247 | x = x + self.model.positional_embedding 248 | x = x.permute(1, 0, 2) # NLD -> LND 249 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 250 | x = x.permute(1, 0, 2) # LND -> NLD 251 | x = self.model.ln_final(x) 252 | return x 253 | 254 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): 255 | for i, r in enumerate(self.model.transformer.resblocks): 256 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 257 | break 258 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 259 | x = checkpoint(r, x, attn_mask) 260 | else: 261 | x = r(x, attn_mask=attn_mask) 262 | return x 263 | 264 | def encode(self, text): 265 | return self(text) 266 | 267 | 268 | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder): 269 | """ 270 | Uses the OpenCLIP vision transformer encoder for images 271 | """ 272 | 273 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", 274 | freeze=True, layer="pooled", antialias=True): 275 | super().__init__() 276 | 277 | pretrained = "/home/user/model/open_clip/open_clip_ViT_H_14/open_clip_pytorch_model.bin" 278 | # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 279 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained) 280 | 281 | del model.transformer 282 | self.model 
= model 283 | self.device = device 284 | 285 | if freeze: 286 | self.freeze() 287 | self.layer = layer 288 | if self.layer == "penultimate": 289 | raise NotImplementedError() 290 | self.layer_idx = 1 291 | 292 | self.antialias = antialias 293 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 294 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 295 | 296 | def preprocess(self, x): 297 | # normalize to [0,1] 298 | x = kornia.geometry.resize(x, (224, 224), 299 | interpolation='bicubic', align_corners=True, 300 | antialias=self.antialias) 301 | x = (x + 1.) / 2. 302 | # renormalize according to clip 303 | x = kornia.enhance.normalize(x, self.mean, self.std) 304 | return x 305 | 306 | def freeze(self): 307 | self.model = self.model.eval() 308 | for param in self.model.parameters(): 309 | param.requires_grad = False 310 | 311 | def forward(self, image, no_dropout=False): 312 | ## image: b c h w 313 | z = self.encode_with_vision_transformer(image) 314 | return z 315 | 316 | def encode_with_vision_transformer(self, x): 317 | # x.shape: [1, 3, 320, 512] 318 | x = self.preprocess(x) 319 | # x.shape: [1, 3, 224, 224] 320 | 321 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 322 | if self.model.visual.input_patchnorm: # False 323 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 324 | x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1]) 325 | x = x.permute(0, 2, 4, 1, 3, 5) 326 | x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1) 327 | x = self.model.visual.patchnorm_pre_ln(x) 328 | x = self.model.visual.conv1(x) 329 | else: 330 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid] 331 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 332 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 333 | # x.shape: [1, 256, 1280] 334 | 335 | # class embeddings and positional embeddings 336 | x = torch.cat( 337 | [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 338 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 339 | x = x + self.model.visual.positional_embedding.to(x.dtype) 340 | 341 | # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in 342 | x = self.model.visual.patch_dropout(x) 343 | x = self.model.visual.ln_pre(x) 344 | 345 | x = x.permute(1, 0, 2) # NLD -> LND 346 | x = self.model.visual.transformer(x) 347 | 348 | x = x.permute(1, 0, 2) # LND -> NLD 349 | 350 | # x.shape: [1, 257, 1280] 351 | return x 352 | 353 | def reshape_tensor(x, heads): 354 | bs, length, width = x.shape 355 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 356 | x = x.view(bs, length, heads, -1) 357 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 358 | x = x.transpose(1, 2) 359 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 360 | x = x.reshape(bs, heads, length, -1) 361 | return x 362 | 363 | class PerceiverAttention(nn.Module): 364 | def __init__(self, *, dim, dim_head=64, heads=8): 365 | super().__init__() 366 | self.scale = dim_head**-0.5 367 | self.dim_head = dim_head 368 | self.heads = heads 369 | inner_dim = dim_head * heads 370 | 371 | self.norm1 = nn.LayerNorm(dim) 372 | self.norm2 = nn.LayerNorm(dim) 373 | 374 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 375 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 376 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 377 | 378 | 379 | def forward(self, x, latents): 380 | """ 381 | Args: 382 | x (torch.Tensor): image features 383 | shape (b, n1, D) 384 | latent (torch.Tensor): latent features 385 | shape (b, n2, D) 386 | """ 387 | x = self.norm1(x) 388 | latents = self.norm2(latents) 389 | 390 | b, l, _ = latents.shape 391 | 392 | q = self.to_q(latents) 393 | kv_input = torch.cat((x, latents), dim=-2) 394 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 395 | 396 | q = reshape_tensor(q, self.heads) 397 | k = reshape_tensor(k, self.heads) 398 | v = reshape_tensor(v, self.heads) 399 | 400 | # attention 401 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 402 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 403 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 404 | out = weight @ v 405 | 406 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 407 | 408 | return self.to_out(out) 409 | 410 | def FeedForward(dim, mult=4): 411 | inner_dim = int(dim * mult) 412 | return nn.Sequential( 413 | nn.LayerNorm(dim), 414 | nn.Linear(dim, inner_dim, bias=False), 415 | nn.GELU(), 416 | nn.Linear(inner_dim, dim, bias=False), 417 | ) 418 | 419 | class Resampler(nn.Module): 420 | def __init__( 421 | self, 422 | dim=1024, 423 | depth=8, 424 | dim_head=64, 425 | heads=16, 426 | num_queries=8, 427 | embedding_dim=768, 428 | output_dim=1024, 429 | ff_mult=4, 430 | ): 431 | super().__init__() 432 | 433 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 434 | 435 | self.proj_in = nn.Linear(embedding_dim, dim) 436 | 437 | self.proj_out = nn.Linear(dim, output_dim) 438 | self.norm_out = nn.LayerNorm(output_dim) 439 | 440 | self.layers = nn.ModuleList([]) 441 | for _ in range(depth): 442 | self.layers.append( 443 | nn.ModuleList( 444 | [ 445 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 446 | FeedForward(dim=dim, mult=ff_mult), 447 | ] 448 | ) 449 | ) 450 | 451 | def forward(self, x): 452 | latents = self.latents.repeat(x.size(0), 1, 1) 453 | 454 | x = self.proj_in(x) 455 | 456 | for attn, ff in self.layers: 457 | latents = attn(x, latents) + latents 458 | latents = ff(latents) + latents 459 | 460 | latents = self.proj_out(latents) 
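        # At this point latents has shape (batch, num_queries, output_dim):
        # each learned query token has attended over the projected input
        # features for `depth` rounds before being mapped to the output width.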
461 | return self.norm_out(latents) 462 | 463 | class ImageProjModel(nn.Module): 464 | """Projection Model""" 465 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 466 | super().__init__() 467 | self.cross_attention_dim = cross_attention_dim 468 | self.clip_extra_context_tokens = clip_extra_context_tokens 469 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 470 | self.norm = nn.LayerNorm(cross_attention_dim) 471 | 472 | def forward(self, image_embeds): 473 | #embeds = image_embeds 474 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype) 475 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) 476 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 477 | return clip_extra_context_tokens 478 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.12.1 2 | torchvision==0.13.1 3 | diffusers[torch]==0.14.0 4 | transformers>=4.25.1 5 | hydra:hydra-core>=1.2.0 6 | pytorch-lightning>=1.9.2 7 | decord==0.6.0 8 | accelerate 9 | tensorboard 10 | modelcards 11 | omegaconf 12 | einops 13 | imageio 14 | ftfy -------------------------------------------------------------------------------- /scripts/launcher.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | # parse arguments 5 | parser = argparse.ArgumentParser(description="Launcher for ModelArts") 6 | parser.add_argument( 7 | "--script_path", 8 | required=True, 9 | type=str, 10 | default="VATMM/scripts/runs/run1.sh", 11 | help="Shell script path to evaluate.", 12 | ) 13 | parser.add_argument( 14 | "--world_size", 15 | default=1, 16 | type=int, 17 | help="Number of nodes.", 18 | ) 19 | parser.add_argument( 20 | "--rank", 21 | default=0, 22 | type=int, 23 | help="Node rank.", 24 | ) 25 | args, _ = parser.parse_known_args() 26 | 27 | # get base directory 28 | JOB_DIR = os.getenv("MA_JOB_DIR", "/home/ma-user/modelarts/user-job-dir/") 29 | 30 | # get absolute path of script 31 | script_path = os.path.join(JOB_DIR, args.script_path) 32 | 33 | # test if script exist 34 | if not os.path.exists(script_path): 35 | raise FileNotFoundError(script_path) 36 | 37 | # run script 38 | os.system( 39 | f"/bin/bash {script_path}" 40 | f" --world_size {args.world_size}" 41 | f" --node_rank {args.rank}" 42 | ) 43 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist import dist_envs 2 | from .file_ops import ops 3 | from .logger import enable_logger 4 | from .registry import Registry 5 | from .utils import save_config, get_free_space, instantiate_multi, get_free_mem, slugify 6 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | 5 | 6 | def ensure_env_init(fn): 7 | def wrapped_fn(obj): 8 | if not obj.is_initialized: 9 | logger.error("Distributed environments are not initialized!") 10 | return fn(obj) 11 | 12 | return wrapped_fn 13 | 14 | 15 | class DistEnvs: 16 | """Environments for distributed training.""" 17 | 18 | def __init__(self): 
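        # These defaults describe a single-process run; init_envs(trainer)
        # later overwrites them with the trainer's world_size/node/rank
        # attributes and sets is_initialized, which stops ensure_env_init
        # from logging errors when the properties are read.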
19 | self._world_size = 1 20 | self._num_nodes = 1 21 | self._node_rank = 0 22 | self._global_rank = 0 23 | self._local_rank = 0 24 | self.is_initialized = False 25 | 26 | def init_envs(self, trainer): 27 | """Distributed Environments need be initialized once and only once.""" 28 | if self.is_initialized: 29 | logger.warning("Distributed environments are repeatedly initialized!") 30 | self.is_initialized = True 31 | self._world_size = trainer.world_size 32 | self._num_nodes = trainer.num_nodes 33 | self._node_rank = trainer.node_rank 34 | self._global_rank = trainer.global_rank 35 | self._local_rank = trainer.local_rank 36 | 37 | @property 38 | @ensure_env_init 39 | def world_size(self): 40 | return self._world_size 41 | 42 | @property 43 | @ensure_env_init 44 | def num_nodes(self): 45 | return self._num_nodes 46 | 47 | @property 48 | @ensure_env_init 49 | def node_rank(self): 50 | return self._node_rank 51 | 52 | @property 53 | @ensure_env_init 54 | def global_rank(self): 55 | return self._global_rank 56 | 57 | @property 58 | @ensure_env_init 59 | def local_rank(self): 60 | return self._local_rank 61 | 62 | 63 | # singleton by module 64 | dist_envs = DistEnvs() 65 | -------------------------------------------------------------------------------- /utils/file_ops.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import shutil 4 | from distutils.dir_util import copy_tree 5 | 6 | 7 | class FileOps: 8 | """ 9 | Unified file operations for both s3 and local paths 10 | """ 11 | 12 | def __init__(self): 13 | if importlib.util.find_spec("moxing"): 14 | import moxing as mox 15 | self.mox_valid = True 16 | self.mox = mox.file 17 | else: 18 | self.mox_valid = False 19 | self.mox = None 20 | 21 | @property 22 | def open(self): 23 | if self.mox_valid: 24 | return self.mox.File 25 | else: 26 | return open 27 | 28 | @property 29 | def exists(self): 30 | if self.mox_valid: 31 | return self.mox.exists 32 | else: 33 | return os.path.exists 34 | 35 | @property 36 | def listdir(self): 37 | if self.mox_valid: 38 | return self.mox.list_directory 39 | else: 40 | return os.listdir 41 | 42 | @property 43 | def isdir(self): 44 | if self.mox_valid: 45 | return self.mox.is_directory 46 | else: 47 | return os.path.isdir 48 | 49 | @property 50 | def makedirs(self): 51 | if self.mox_valid: 52 | return self.mox.make_dirs 53 | else: 54 | return os.makedirs 55 | 56 | @property 57 | def copy_dir(self): 58 | if self.mox_valid: 59 | return self.mox.copy_parallel 60 | else: 61 | return copy_tree 62 | 63 | @property 64 | def copy_file(self): 65 | if self.mox_valid: 66 | return self.mox.copy 67 | else: 68 | return shutil.copy 69 | 70 | def copy(self, src, dst, *args, **kwargs): 71 | if not self.exists(src): 72 | raise IOError('Source file {} does not exist.'.format(src)) 73 | if self.isdir(src): 74 | self.copy_dir(src, dst, *args, **kwargs) 75 | else: 76 | self.copy_file(src, dst, *args, **kwargs) 77 | 78 | def mkdir_or_exist(self, path): 79 | if not self.exists(path): 80 | self.makedirs(path) 81 | 82 | @property 83 | def remove(self): 84 | if self.mox_valid: 85 | return self.mox.remove 86 | else: 87 | return os.remove 88 | 89 | 90 | ops = FileOps() 91 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from .dist import dist_envs 4 | 5 | 6 | def enable_logger(logger_name, log_file=None, 
log_level=logging.INFO): 7 | logger = logging.getLogger(logger_name) 8 | logger.disabled = False 9 | if len(logger.handlers) > 0: 10 | return 11 | logger.propagate = False 12 | stream_handler = logging.StreamHandler() 13 | handlers = [stream_handler] 14 | 15 | if not dist_envs.is_initialized: 16 | raise ValueError(f"Distributed environments are not initialized!") 17 | 18 | # only global_rank=0 will add a FileHandler 19 | if dist_envs.global_rank == 0 and log_file is not None: 20 | file_handler = logging.FileHandler(log_file, 'w') 21 | handlers.append(file_handler) 22 | 23 | # only rank=0 for each node will print logs 24 | log_level = log_level if dist_envs.local_rank == 0 else logging.ERROR 25 | 26 | formatter = logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s") 27 | for handler in handlers: 28 | handler.setFormatter(formatter) 29 | handler.setLevel(log_level) 30 | logger.addHandler(handler) 31 | 32 | logger.setLevel(log_level) 33 | -------------------------------------------------------------------------------- /utils/registry.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import warnings 3 | from collections import abc 4 | from functools import partial 5 | 6 | 7 | def is_seq_of(seq, expected_type, seq_type=None): 8 | """Check whether it is a sequence of some type. 9 | 10 | Args: 11 | seq (Sequence): The sequence to be checked. 12 | expected_type (type): Expected type of sequence items. 13 | seq_type (type, optional): Expected sequence type. 14 | 15 | Returns: 16 | bool: Whether the sequence is valid. 17 | """ 18 | if seq_type is None: 19 | exp_seq_type = abc.Sequence 20 | else: 21 | assert isinstance(seq_type, type) 22 | exp_seq_type = seq_type 23 | if not isinstance(seq, exp_seq_type): 24 | return False 25 | for item in seq: 26 | if not isinstance(item, expected_type): 27 | return False 28 | return True 29 | 30 | 31 | class Registry: 32 | """A registry to map strings to classes. 33 | 34 | Args: 35 | name (str): Registry name. 36 | """ 37 | 38 | def __init__(self, name): 39 | self._name = name 40 | self._module_dict = dict() 41 | 42 | def __len__(self): 43 | return len(self._module_dict) 44 | 45 | def __contains__(self, key): 46 | return self.get(key) is not None 47 | 48 | def __repr__(self): 49 | format_str = self.__class__.__name__ + \ 50 | f'(name={self._name}, ' \ 51 | f'items={self._module_dict})' 52 | return format_str 53 | 54 | @property 55 | def name(self): 56 | return self._name 57 | 58 | @property 59 | def module_dict(self): 60 | return self._module_dict 61 | 62 | def get(self, key): 63 | """Get the registry record. 64 | 65 | Args: 66 | key (str): The class name in string format. 67 | 68 | Returns: 69 | class: The corresponding class. 
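
        Example:
            >>> backbones = Registry('backbone')
            >>> backbones.get('ResNet') is None  # nothing registered yet
            True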
70 | """ 71 | return self._module_dict.get(key, None) 72 | 73 | def _register_module(self, module_class, module_name=None, force=False): 74 | if not inspect.isclass(module_class): 75 | raise TypeError('module must be a class, ' 76 | f'but got {type(module_class)}') 77 | 78 | if module_name is None: 79 | module_name = module_class.__name__ 80 | if isinstance(module_name, str): 81 | module_name = [module_name] 82 | else: 83 | assert is_seq_of( 84 | module_name, 85 | str), ('module_name should be either of None, an ' 86 | f'instance of str or list, but got {type(module_name)}') 87 | for name in module_name: 88 | if not force and name in self._module_dict: 89 | raise KeyError(f'{name} is already registered ' 90 | f'in {self.name}') 91 | self._module_dict[name] = module_class 92 | 93 | def deprecated_register_module(self, cls=None, force=False): 94 | warnings.warn( 95 | 'The old API of register_module(module, force=False) ' 96 | 'is deprecated and will be removed, please use the new API ' 97 | 'register_module(name=None, force=False, module=None) instead.') 98 | if cls is None: 99 | return partial(self.deprecated_register_module, force=force) 100 | self._register_module(cls, force=force) 101 | return cls 102 | 103 | def register_module(self, name=None, force=False, module=None): 104 | """Register a module. 105 | 106 | A record will be added to `self._module_dict`, whose key is the class 107 | name or the specified name, and value is the class itself. 108 | It can be used as a decorator or a normal function. 109 | 110 | Example: 111 | >>> backbones = Registry('backbone') 112 | >>> @backbones.register_module() 113 | >>> class ResNet: 114 | >>> pass 115 | 116 | >>> backbones = Registry('backbone') 117 | >>> @backbones.register_module(name='mnet') 118 | >>> class MobileNet: 119 | >>> pass 120 | 121 | >>> backbones = Registry('backbone') 122 | >>> class ResNet: 123 | >>> pass 124 | >>> backbones.register_module(ResNet) 125 | 126 | Args: 127 | name (str | None): The module name to be registered. If not 128 | specified, the class name will be used. 129 | force (bool, optional): Whether to override an existing class with 130 | the same name. Default: False. 131 | module (type): Module class to be registered. 132 | """ 133 | if not isinstance(force, bool): 134 | raise TypeError(f'force must be a boolean, but got {type(force)}') 135 | # NOTE: This is a walkaround to be compatible with the old api, 136 | # while it may introduce unexpected bugs. 
137 | if isinstance(name, type): 138 | return self.deprecated_register_module(name, force=force) 139 | 140 | # use it as a normal method: x.register_module(module=SomeClass) 141 | if module is not None: 142 | self._register_module( 143 | module_class=module, module_name=name, force=force) 144 | return module 145 | 146 | # raise the error ahead of time 147 | if not (name is None or isinstance(name, str)): 148 | raise TypeError(f'name must be a str, but got {type(name)}') 149 | 150 | # use it as a decorator: @x.register_module() 151 | def _register(cls): 152 | self._register_module( 153 | module_class=cls, module_name=name, force=force) 154 | return cls 155 | 156 | return _register 157 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import unicodedata 5 | 6 | import torch 7 | from hydra.utils import instantiate 8 | from omegaconf import OmegaConf 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def save_config(config, config_file="config.yaml"): 14 | with open(config_file, "w") as f: 15 | OmegaConf.save(config, f) 16 | 17 | 18 | def get_free_space(path): 19 | if not os.path.exists(path): 20 | logger.info(f"Path dose not exist: {path}") 21 | return 0 22 | info = os.statvfs(path) 23 | free_size = info.f_bsize * info.f_bavail / 1024 ** 3 # GB 24 | return free_size 25 | 26 | 27 | def instantiate_multi(config, name): 28 | instances = [instantiate(x) for x in config.get(name).values() 29 | if x is not None and "_target_" in x] if name in config else [] 30 | return instances 31 | 32 | 33 | def get_free_mem(device): 34 | """ Get free memory of device. (MB) """ 35 | mem_used, mem_total = torch.cuda.mem_get_info(device) 36 | mem_free = (mem_total - mem_used) / 1024.0 ** 3 37 | return mem_free 38 | 39 | 40 | def slugify(value, allow_unicode=False): 41 | """ 42 | Taken from https://github.com/django/django/blob/master/django/utils/text.py 43 | Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated 44 | dashes to single dashes. Remove characters that aren't alphanumerics, 45 | underscores, or hyphens. Convert to lowercase. Also strip leading and 46 | trailing whitespace, dashes, and underscores. 47 | """ 48 | value = str(value) 49 | if allow_unicode: 50 | value = unicodedata.normalize('NFKC', value) 51 | else: 52 | value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii') 53 | value = re.sub(r'[^\w\s-]', '', value.lower()) 54 | return re.sub(r'[-\s]+', '-', value).strip('-_') 55 | --------------------------------------------------------------------------------
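A minimal usage sketch of the `Registry` and `slugify` helpers above, for orientation only; `ToyModel`, the registry name `"model"`, and the example strings are illustrative and not part of the repository:

```python
from utils import Registry, slugify

MODELS = Registry("model")           # the name only labels this registry

@MODELS.register_module()            # decorator form registers under the class name
class ToyModel:
    pass

assert "ToyModel" in MODELS          # __contains__ delegates to get()
model_cls = MODELS.get("ToyModel")   # returns the class, or None if unknown

print(slugify("MagDiff: Video Editing!"))  # -> "magdiff-video-editing"
```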