├── README.md
├── assets
│   └── teaser
│       ├── teaser.png
│       ├── video_editing.gif
│       └── video_generation.gif
├── callbacks
│   ├── __init__.py
│   └── log_master.py
├── configs
│   ├── _meta_
│   │   └── inference.yaml
│   ├── callbacks
│   │   ├── early_stopping.yaml
│   │   ├── log_master.yaml
│   │   ├── lr_monitor.yaml
│   │   └── model_checkpoint.yaml
│   ├── data
│   │   ├── default.yaml
│   │   ├── train
│   │   │   └── pexels300k.yaml
│   │   └── val
│   │       └── pexels300k.yaml
│   ├── default.yaml
│   ├── logdir
│   │   └── default.yaml
│   ├── loggers
│   │   └── tensorboard.yaml
│   ├── model
│   │   └── sdvideo.yaml
│   └── trainer
│       ├── default.yaml
│       └── distributed.yaml
├── data
│   ├── __init__.py
│   ├── constants.py
│   ├── data_downloader.py
│   ├── dataloader.py
│   ├── datasets.py
│   ├── registry.py
│   └── utils.py
├── inference.sh
├── main.py
├── model
│   ├── __init__.py
│   ├── model.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── attention_hasFFN.py
│   │   ├── modules.py
│   │   ├── resnet.py
│   │   ├── uent.py
│   │   ├── unet_blocks.py
│   │   ├── utils.py
│   │   └── zero_snr_ddpm.py
│   ├── pipeline.py
│   └── utils.py
├── requirements.txt
├── scripts
│   └── launcher.py
└── utils
    ├── __init__.py
    ├── dist.py
    ├── file_ops.py
    ├── logger.py
    ├── registry.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # MagDiff: Multi-Alignment Diffusion for High-Fidelity Video Generation and Editing
3 |
4 | Haoyu Zhao · Tianyi Lu · Jiaxi Gu · Xing Zhang · Qingping Zheng · Zuxuan Wu · Hang Xu · Yu-Gang Jiang
5 |
6 | Fudan University | Huawei Noah's Ark Lab
7 |
8 | ![video generation](assets/teaser/video_generation.gif)
9 | ![video editing](assets/teaser/video_editing.gif)
10 |
41 | ## 📢 News
42 | * **[2024.12.22]** Released the inference code. We are working to improve MagDiff, so stay tuned!
43 | * **[2024.07.04]** Our paper was accepted to the 18th European Conference on Computer Vision (ECCV 2024).
44 | * **[2023.11.29]** Released the first version of the paper on arXiv.
45 |
46 | ## 🏃‍♂️ Getting Started
47 | Download the pretrained base model [Stable Diffusion 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
48 |
49 | Download our MagDiff [checkpoints](https://huggingface.co/).
50 |
51 | Please follow the Hugging Face download instructions to obtain the models and checkpoints above.
52 |
53 | Below is an example structure of these model files.
54 |
55 | ```
56 | assets/
57 | ├── MagDiff.pth
58 | └── stable-diffusion-2-1-base/
59 |     ├── scheduler/...
60 |     ├── text_encoder/...
61 |     ├── tokenizer/...
62 |     ├── unet/...
63 |     ├── vae/...
64 |     ├── ...
65 |     └── README.md
66 | ```
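If it helps, the base model can be fetched with `huggingface-cli` — a minimal sketch, assuming the `huggingface_hub` package is installed; place the MagDiff checkpoint at `assets/MagDiff.pth` to match the layout above:

```bash
# Sketch: download the Stable Diffusion 2.1 base weights into assets/
# (assumes the huggingface-cli tool from the huggingface_hub package is on PATH)
huggingface-cli download stabilityai/stable-diffusion-2-1-base --local-dir assets/stable-diffusion-2-1-base
```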
67 |
68 | ## ⚒️ Installation
69 | Prerequisites: `python>=3.10`, `CUDA>=11.8`.
70 |
71 | Install with `pip`:
72 | ```bash
73 | pip3 install -r requirements.txt
74 | ```
75 |
76 | ## 💃 Inference
77 | Run inference on a single GPU:
78 | ```bash
79 | bash inference.sh
80 | ```
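The model paths in `configs/_meta_/inference.yaml` point to `/home/user/...`. If your weights live elsewhere (for example under `assets/` as shown above), the same keys can be overridden on the Hydra command line instead of editing the config. A minimal sketch, assuming that layout:

```bash
# Sketch: Hydra dot-notation overrides; these can also be appended to the
# python command inside inference.sh. The paths below are placeholders.
python main.py --config-name=_meta_/inference.yaml \
  model.pretrained_model_path=assets/stable-diffusion-2-1-base \
  model.ckpt_path=assets/MagDiff.pth
```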
81 |
82 | ## 🎓 Citation
83 | If you find this codebase useful for your research, please cite our paper using the following BibTeX entry.
84 | ```BibTeX
85 | @inproceedings{zhao2024magdiff,
86 | author = {Zhao, Haoyu and Lu, Tianyi and Gu, Jiaxi and Zhang, Xing and Zheng, Qingping and Wu, Zuxuan and Xu, Hang and Jiang, Yu-Gang},
87 | title = {MagDiff: Multi-Alignment Diffusion for High-Fidelity Video Generation and Editing},
88 | booktitle = {European Conference on Computer Vision},
89 | year = {2024}
90 | }
91 | ```
--------------------------------------------------------------------------------
/assets/teaser/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/videoassembler/1fb50cd5f85aa5896b85003aff3641b00946eddc/assets/teaser/teaser.png
--------------------------------------------------------------------------------
/assets/teaser/video_editing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/videoassembler/1fb50cd5f85aa5896b85003aff3641b00946eddc/assets/teaser/video_editing.gif
--------------------------------------------------------------------------------
/assets/teaser/video_generation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/videoassembler/1fb50cd5f85aa5896b85003aff3641b00946eddc/assets/teaser/video_generation.gif
--------------------------------------------------------------------------------
/callbacks/__init__.py:
--------------------------------------------------------------------------------
1 | from .log_master import LogMaster
--------------------------------------------------------------------------------
/callbacks/log_master.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import logging
4 | import os
5 | import time
6 | from copy import deepcopy
7 | from pathlib import Path
8 |
9 | from pytorch_lightning.callbacks.callback import Callback
10 | from pytorch_lightning.loggers import TensorBoardLogger
11 | from pytorch_lightning.utilities import rank_zero_only
12 |
13 | from utils import ops
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class LogMaster(Callback):
19 | def __init__(
20 | self, name, log_file, monitor="val_loss", monitor_fn="min",
21 | save_ckpt=True, remote_dir=None
22 | ):
23 | super().__init__()
24 | # The logger's file handler is not set here; it should be set outside,
25 | # because the rank is unknown when this function is called under DDP.
26 | self.name = name
27 | self.log_file = log_file
28 | self.monitor = monitor
29 | self.monitor_fn = {"min": min, "max": max}[monitor_fn]  # map the config string to the builtin (avoids eval)
30 | self.save_ckpt = save_ckpt
31 | self.remote_dir = os.path.join(remote_dir, self.name) \
32 | if remote_dir else remote_dir
33 | self.model_perf = {}
34 | self.t_train_start, self.t_batch_start = .0, .0
35 |
36 | def on_train_start(self, trainer, pl_module) -> None:
37 | # log train configs
38 | logger.info(
39 | f"WORLD_SIZE={trainer.world_size}; "
40 | f"NUM_NODES={trainer.num_nodes}; "
41 | f"GPUS_PER_NODE={trainer.num_devices}; "
42 | f"STEPS_PER_EPOCH={trainer.num_training_batches}."
43 | )
44 | # upload training task config
45 | self.upload_to_cloud("config.yaml")
46 | # upload log file
47 | self.upload_to_cloud(self.log_file)
48 | self.t_train_start = time.monotonic()
49 |
50 | def on_train_batch_start(self, trainer, module, batch, batch_idx, unused=0):
51 | if batch_idx % trainer.log_every_n_steps == 1:
52 | self.t_batch_start = time.monotonic()
53 |
54 | def on_train_batch_end(
55 | self, trainer, module, outputs, batch, batch_idx, unused=0) -> None:
56 | if batch_idx % trainer.log_every_n_steps == 1:
57 | metrics = deepcopy(trainer.callback_metrics)
58 | msg = "; ".join([f"{k}={v.item():.4f}" for k, v in metrics.items()
59 | if k.startswith("train")])
60 | zfill_batch = len(str(trainer.estimated_stepping_batches))
61 | time_elapsed = datetime.timedelta(
62 | seconds=int(time.monotonic() - self.t_train_start))
63 | time_remained = datetime.timedelta(
64 | seconds=int(
65 | (time.monotonic() - self.t_batch_start) *
66 | (trainer.estimated_stepping_batches - trainer.global_step)
67 | )
68 | )
69 | time_info = f"{time_elapsed} < {time_remained}"
70 | logger.info("[Steps {}/{}]: {} (Time: {})".format(
71 | str(trainer.global_step).zfill(zfill_batch),
72 | str(trainer.estimated_stepping_batches).zfill(zfill_batch),
73 | msg, time_info
74 | ))
75 | # upload log files
76 | self.upload_to_cloud(self.log_file)
77 | self.upload_to_cloud(self.get_tblog_dir(trainer.loggers))
78 |
79 | def on_validation_end(self, trainer, pl_module):
80 | # exit()
81 | zfill_batch = len(str(trainer.estimated_stepping_batches))
82 | metrics = deepcopy(trainer.callback_metrics)
83 | monitor_metric = metrics.get(self.monitor)
84 | if trainer.global_step > 200:
85 | if trainer.global_step % 5000 == 0:
86 | ckpt_name = f"steps_{str(trainer.global_step).zfill(zfill_batch)}.pth"
87 | self.model_perf.setdefault(ckpt_name, (ckpt_name, monitor_metric))
88 | if self.save_ckpt:
89 | trainer.save_checkpoint(ckpt_name, weights_only=True)
90 | return
91 | if trainer.sanity_checking:
92 | return
93 | metrics = deepcopy(trainer.callback_metrics)
94 | metrics = {k: v.item() for k, v in metrics.items() if k.startswith("val")}
95 | if len(metrics) < 1:
96 | logger.warning("There are no metrics for validation!")
97 | zfill_batch = len(str(trainer.estimated_stepping_batches))
98 | msg = "; ".join([f"{k}={v:.6f}" for k, v in metrics.items()])
99 | logger.info("[Evaluation] [Steps={}/{}]: {}".format(
100 | str(trainer.global_step).zfill(zfill_batch),
101 | str(trainer.estimated_stepping_batches).zfill(zfill_batch), msg
102 | ))
103 | # upload log file
104 | if self.monitor not in metrics:
105 | raise KeyError(f"Metric `{self.monitor}` not in callback metrics: {metrics.keys()}")
106 | monitor_metric = metrics.get(self.monitor)
107 | self.upload_to_cloud(self.log_file)
108 | self.upload_to_cloud(self.get_tblog_dir(trainer.loggers))
109 | # compute model performance
110 | ckpt_name = f"steps_{str(trainer.global_step).zfill(zfill_batch)}.pth"
111 | self.model_perf.setdefault(ckpt_name, (ckpt_name, monitor_metric))
112 | # update best model information
113 | best_model = self.monitor_fn(self.model_perf, key=lambda x: self.model_perf.get(x)[-1])
114 | self.model_perf["best_model"] = (
115 | self.model_perf[best_model][0], self.model_perf[best_model][-1])
116 | perf_file = f"performances_{self.monitor}.json"
117 | with open(perf_file, "w") as f:
118 | json.dump(self.model_perf, f, indent=2)
119 | self.upload_to_cloud(perf_file)
120 | # save to checkpoint and upload to remote server
121 | if self.save_ckpt:
122 | trainer.save_checkpoint(ckpt_name, weights_only=True)
123 | if self.upload_to_cloud(ckpt_name):
124 | Path(ckpt_name).unlink()
125 | if self.remote_dir:
126 | logger.info(f"Detailed results are saved in: {self.remote_dir}")
127 | # upload output samples
128 | samples_dir = Path(f"samples_s{str(trainer.global_step).zfill(zfill_batch)}")
129 | self.upload_to_cloud(samples_dir.as_posix())
130 |
131 | @staticmethod
132 | def get_tblog_dir(loggers):
133 | tb_loggers = [x for x in loggers if isinstance(x, TensorBoardLogger)]
134 | return tb_loggers[0].log_dir if tb_loggers else ""
135 |
136 | @rank_zero_only
137 | def upload_to_cloud(self, local_file):
138 | if ops.mox_valid and self.remote_dir and os.path.exists(local_file):
139 | remote_path = os.path.join(
140 | self.remote_dir, os.path.basename(os.path.normpath(local_file)))
141 | ops.copy(local_file, remote_path)
142 | return remote_path
143 | else:
144 | return None
145 |
--------------------------------------------------------------------------------
/configs/_meta_/inference.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - /default
5 | - override /hydra/job_logging: default
6 | - override /data/train:
7 | - webvid5s
8 | - override /data/val:
9 | - webvid5s
10 | - _self_
11 |
12 | job_name: "magdiff_inference"
13 | output_dir: "_.outputs_20241206"
14 |
15 | callbacks:
16 | log_master:
17 | remote_dir: ""
18 | save_ckpt: true
19 |
20 | loggers:
21 | tensorboard:
22 | log_graph: True
23 | name: "tensorboard_save_dir"
24 |
25 | trainer:
26 | max_steps: 1
27 | log_every_n_steps: 1
28 | val_check_interval: 1
29 | num_sanity_val_steps: 1
30 | enable_progress_bar: True
31 | max_epochs: 1
32 |
33 | evaluator: pl_validate
34 |
35 | model:
36 | pretrained_model_path: "/home/user/model/stable-diffusion-2-1-base/"
37 | ckpt_path: "/home/user/model/magdiff.pth"
38 | lr: 0.00005
39 | scheduler_name: constant_with_warmup
40 | warmup_steps: 100
41 | num_inference_steps: 50 # debug
42 | # classifier-free guidance
43 | null_text_ratio: 0.15
44 | guidance_scale: 7.5
45 | # model component variants
46 | add_temp_embed: true
47 | prepend_first_frame: false
48 | add_temp_transformer: false
49 | add_temp_conv: true
50 | # trainable module
51 | freeze_text_encoder: true
52 | trainable_modules:
53 | - "temp_"
54 | - "transformer_blocks\\.\\d+"
55 | - "conv_in"
56 | # added model config
57 | load_pretrained_conv_in: True
58 | enable_xformers: False
59 | resolution: 256
60 | in_channels: 8
61 | add_entity_vae: True
62 | add_entity_clip: True
63 |
64 | data:
65 | batch_size_train: 8
66 | batch_size_val: 2
67 | resolution: 256
68 | sample_rate: 1
69 | train:
70 | webvid5s:
71 | data_dir: /home/user/data/magdiff/videos
72 | csv_subdir: annotations_76k.jsonl
73 | val:
74 | webvid5s:
75 | data_dir: /home/user/data/magdiff/val_videos
76 | csv_subdir: annotations_val.jsonl
77 |
--------------------------------------------------------------------------------
/configs/callbacks/early_stopping.yaml:
--------------------------------------------------------------------------------
1 | early_stopping:
2 | _target_: pytorch_lightning.callbacks.EarlyStopping
3 | monitor: loss
4 | mode: min
5 | patience: 20
--------------------------------------------------------------------------------
/configs/callbacks/log_master.yaml:
--------------------------------------------------------------------------------
1 | log_master:
2 | _target_: callbacks.LogMaster
3 | name: ${job_name}
4 | log_file: mlog_${hydra:job.name}.log
5 | remote_dir: null
6 | monitor: val_clip_score
7 | monitor_fn: max
--------------------------------------------------------------------------------
/configs/callbacks/lr_monitor.yaml:
--------------------------------------------------------------------------------
1 | lr_monitor:
2 | _target_: pytorch_lightning.callbacks.LearningRateMonitor
3 | logging_interval: step
--------------------------------------------------------------------------------
/configs/callbacks/model_checkpoint.yaml:
--------------------------------------------------------------------------------
1 | model_checkpoint:
2 | _target_: pytorch_lightning.callbacks.ModelCheckpoint
3 | dirpath: ckpt
4 | every_n_train_steps: 10000
5 | filename: "model_e{epoch:03d}s{step:08d}"
6 | save_last: true
7 | save_weights_only: false
--------------------------------------------------------------------------------
/configs/data/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: data.SDVideoDataModule
2 |
3 | defaults:
4 | - train: null
5 | - train_alt: null
6 | - val: null
7 | - _self_
8 |
9 | batch_size_train: 1
10 | batch_size_val: 1
11 |
12 | num_frames: 8
13 | resolution: 512
14 |
15 | num_workers: 8
--------------------------------------------------------------------------------
/configs/data/train/pexels300k.yaml:
--------------------------------------------------------------------------------
1 | pexels300k:
2 | type: PexelsDataset
3 | data_dir: s3://bucket-9329/gujiaxi/data/Pexels300k/
4 | csv_subdir: csv
5 | tar_subdir: tar
6 |
7 | num_frames: ${data.num_frames}
8 | resolution: ${data.resolution}
9 |
--------------------------------------------------------------------------------
/configs/data/val/pexels300k.yaml:
--------------------------------------------------------------------------------
1 | pexels300k:
2 | type: PexelsDataset
3 | data_dir: s3://bucket-9329/gujiaxi/data/Pexels300k/
4 | csv_subdir: csv
5 | tar_subdir: tar
6 |
7 | num_frames: ${data.num_frames}
8 | resolution: ${data.resolution}
9 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - data: default
5 | - model: sdvideo
6 | - callbacks:
7 | - lr_monitor
8 | - log_master
9 | - loggers:
10 | - tensorboard
11 | - trainer: distributed
12 | - logdir: default
13 | - override hydra/hydra_logging: disabled
14 | - override hydra/job_logging: disabled
15 | - _self_
16 |
17 | seed: 42
18 |
19 | job_name: "RUN_1"
20 |
21 | output_dir: "output"
--------------------------------------------------------------------------------
/configs/logdir/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | hydra:
4 | output_subdir: null
5 | job:
6 | name: ${job_name}_${now:%Y%m%d}_${now:%H%M%S}
7 | chdir: true
8 | env_set:
9 | TOKENIZERS_PARALLELISM: false
10 | run:
11 | dir: ${output_dir}/${job_name}
12 | sweep:
13 | dir: ${output_dir}/${job_name}
--------------------------------------------------------------------------------
/configs/loggers/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | tensorboard:
2 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
3 | save_dir: "."
4 | name: null
5 | version: tblog_${hydra:job.name}
6 | log_graph: False
7 | default_hp_metric: True
8 | prefix: ""
9 |
--------------------------------------------------------------------------------
/configs/model/sdvideo.yaml:
--------------------------------------------------------------------------------
1 | _target_: model.SDVideoModel
2 |
3 | pretrained_model_path: null
4 | ckpt_path: null
5 |
6 | # train
7 | null_text_ratio: 0.1
8 | lr: 0.001
9 | weight_decay: 0.01
10 | warmup_steps: 10
11 |
12 | # inference
13 | num_inference_steps: 50
14 | guidance_scale: 7.5
15 |
16 | # input resolution
17 | resolution: ${data.resolution}
18 |
19 | # memory saving techniques
20 | enable_gradient_checkpointing: true
21 | enable_xformers: true
--------------------------------------------------------------------------------
/configs/trainer/default.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 |
3 | check_val_every_n_epoch: null
4 | max_steps: 8
5 | val_check_interval: 2
6 | log_every_n_steps: 2
7 |
8 | # disable checkpointing and the progress bar; both are handled by the log_master callback
9 | enable_checkpointing: false
10 | enable_progress_bar: false
11 |
12 | # disables sampler replacement
13 | replace_sampler_ddp: false
--------------------------------------------------------------------------------
/configs/trainer/distributed.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default.yaml
3 |
4 | strategy:
5 | _target_: pytorch_lightning.strategies.DDPStrategy
6 | find_unused_parameters: false
7 | gradient_as_bucket_view: true
8 | timeout:
9 | _target_: datetime.timedelta
10 | hours: 8
11 |
12 | # numerical precision (32 = full precision; set 16 for mixed precision)
13 | precision: 32
14 |
15 | # gradient clipping
16 | gradient_clip_val: 1.0
17 |
18 | # devices per node ("auto" selects all available GPUs)
19 | accelerator: gpu
20 | devices: auto
21 |
22 | num_nodes: 1
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_downloader import download_data, download_model
2 | from .dataloader import SDVideoDataModule
--------------------------------------------------------------------------------
/data/constants.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | dir_path = Path(__file__).parent.as_posix()
4 |
5 | DEFAULT_DATA_DIR = "/cache/data/"
6 | DEFAULT_MODEL_DIR = "/cache/pretrained/"
7 | DEFAULT_TOKENIZER_DIR = Path(dir_path).parent.joinpath("assets/tokenizer/")
8 | DEFAULT_CSV_SUBDIR = "csv"
9 | DEFAULT_TAR_SUBDIR = "tar"
10 |
11 | DEFAULT_LABEL_FILENAME = "classnames.txt"
12 |
--------------------------------------------------------------------------------
/data/data_downloader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import tarfile
4 | import time
5 | from multiprocessing import Pool
6 | from pathlib import Path
7 |
8 | from utils import dist_envs
9 | from utils import get_free_space
10 | from utils import ops
11 | from .constants import (
12 | DEFAULT_DATA_DIR,
13 | DEFAULT_MODEL_DIR,
14 | DEFAULT_LABEL_FILENAME,
15 | DEFAULT_CSV_SUBDIR,
16 | DEFAULT_TAR_SUBDIR,
17 | )
18 | from .utils import read_multi_csv
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class DataDownloader:
24 | def __init__(self, processes=8):
25 | self.pool = Pool(processes)
26 |
27 | def apply_job(self, remote_path, local_path, unzip_dir=None):
28 | # download file
29 | try:
30 | ops.copy(remote_path, local_path)
31 | except Exception as e:
32 | logger.error(f"File download failed "
33 | f"(remote_path={remote_path}; local_path={local_path}): {e}")
34 | return  # skip the unzip step when the download failed
35 | logger.info(f"File downloaded: {os.path.basename(local_path)}")
36 | # unzip file
37 | if unzip_dir:
38 | self.wait_for_load_below(self.pool._processes)
39 | self.pool.apply_async(self.unzip, args=(local_path, unzip_dir))
40 |
41 | @staticmethod
42 | def unzip(file_path, unzip_dir):
43 | if not tarfile.is_tarfile(file_path):
44 | return
45 | with tarfile.open(file_path) as tarf:
46 | tarf.extractall(unzip_dir)
47 | logger.info(f"File unzipped: {os.path.basename(file_path)}")
48 | # delete zip file after unzipped
49 | os.remove(file_path)
50 |
51 | def wait_for_load_below(self, load_limit: int):
52 | t_start = time.monotonic()
53 | logger.info(f"Current free space: {get_free_space('/cache/'):.2f} GB.")
54 | while len(self.pool._cache) > load_limit:
55 | time.sleep(5)
56 | if time.monotonic() - t_start > 60:
57 | logger.info("Downloader is currently overloaded: waiting...")
58 | if load_limit == 0:
59 | self.pool.close()
60 | self.pool.join()
61 |
62 |
63 | def download_dataset(downloader, dataset_cfg, local_dir, max_num_csv=0):
64 | # data are downloaded only on main process of each node
65 | if not dist_envs.is_initialized:
66 | raise ValueError(f"Distributed environments are not initialized!")
67 | # only the main process of each node is used for data downloading
68 | if dist_envs.local_rank != 0:
69 | return
70 | local_dir = Path(local_dir)
71 | # get directory names
72 | remote_dir = dataset_cfg["data_dir"]
73 | csv_subdir = dataset_cfg.get("csv_subdir", DEFAULT_CSV_SUBDIR)
74 | tar_subdir = dataset_cfg.get("tar_subdir", DEFAULT_TAR_SUBDIR)
75 | # download label file
76 | label_filename = dataset_cfg.get("label_filename", DEFAULT_LABEL_FILENAME)
77 | remote_label_file = os.path.join(remote_dir, label_filename)
78 | local_label_file = local_dir.joinpath(label_filename)
79 | if not local_label_file.exists() and ops.exists(remote_label_file):
80 | downloader.apply_job(remote_label_file, local_label_file.as_posix())
81 | # download csvs
82 | remote_csv_dir = os.path.join(remote_dir, csv_subdir)
83 | csv_files = [x for x in ops.listdir(remote_csv_dir) if x.endswith(".csv")]
84 | csv_files = csv_files[:max_num_csv] if max_num_csv > 0 else csv_files
85 | local_csv_dir = local_dir.joinpath(csv_subdir)
86 | for idx, csv_file in enumerate(csv_files):
87 | local_csv = local_csv_dir.joinpath(csv_file)
88 | if not local_csv.exists():
89 | logger.info(f"[{idx + 1}/{len(csv_files)}] "
90 | f"Downloading csv file: {csv_file}")
91 | downloader.apply_job(
92 | os.path.join(remote_csv_dir, csv_file), local_csv.as_posix())
93 |
94 | tar_name = Path(csv_files[0]).stem
95 | tar_ext_real = None
96 | for tar_ext in ["", ".tar"]:
97 | if ops.exists(os.path.join(remote_dir, tar_subdir, tar_name + tar_ext)):
98 | tar_ext_real = tar_ext
99 | break
100 | if tar_ext_real is None:
101 | raise NotImplementedError(
102 | f"Extension of tar files is not recognized: {os.path.join(remote_dir, tar_subdir)}")
103 | tar_files = [Path(x).with_suffix(tar_ext_real) for x in csv_files]
104 |
105 | if dataset_cfg.get("split_among_nodes", False):
106 | # Only download a part of tar files according to current node_rank
107 | total_df = read_multi_csv(local_csv_dir)
108 | packages = total_df["package"].to_list()
109 | node_rank, num_nodes, world_size = \
110 | dist_envs.node_rank, dist_envs.num_nodes, dist_envs.world_size
111 | # drop_last is necessary when split_among_nodes=True
112 | packages = packages[:(len(packages) // world_size) * world_size]
113 | chunk_size = len(packages) // num_nodes
114 | packages = set(packages[node_rank * chunk_size:(node_rank + 1) * chunk_size])
115 | tar_files = [x for x in tar_files if Path(x).stem in packages]
116 |
117 | # Start download tar files
118 | remote_tar_dir = os.path.join(remote_dir, tar_subdir)
119 | local_tar_dir = local_dir.joinpath(tar_subdir)
120 | for idx, tar_file in enumerate(tar_files):
121 | local_tar = local_tar_dir.joinpath(tar_file)
122 | local_subdir = local_tar_dir.joinpath(local_tar.stem)
123 | if local_subdir.exists():
124 | continue
125 | logger.info(
126 | f"[{idx + 1}/{len(tar_files)}] "
127 | f"Downloading & extracting tar file: {tar_file}")
128 | downloader.apply_job(
129 | remote_path=os.path.join(remote_tar_dir, tar_file),
130 | local_path=local_tar.as_posix(),
131 | unzip_dir=local_tar_dir.as_posix()
132 | )
133 | logger.info(f"Dataset download complete: {Path(local_dir).name}")
134 |
135 |
136 | def download_data(data_cfg):
137 | downloader = DataDownloader()
138 | stages = ("train", "train_alt", "val")
139 | for stage in stages:
140 | datasets = data_cfg.get(stage, dict()).items()
141 | for dataset_name, dataset_cfg in datasets:
142 | if dataset_cfg["data_dir"].startswith("s3://"):
143 | local_dir = Path(DEFAULT_DATA_DIR, dataset_name).as_posix()
144 | logger.info(f"Downloading dataset: {dataset_name}")
145 | download_dataset(downloader, dataset_cfg, local_dir,
146 | max_num_csv=dataset_cfg.get("max_num_csv", 0))
147 | data_cfg[stage][dataset_name]["data_dir"] = local_dir
148 | downloader.wait_for_load_below(0)
149 | return data_cfg
150 |
151 |
152 | def download_model(model_path):
153 | if not model_path or not model_path.startswith("s3://"):
154 | return model_path
155 | local_path = Path(DEFAULT_MODEL_DIR, os.path.basename(os.path.normpath(model_path)))
156 | if dist_envs.local_rank == 0 and not local_path.exists():
157 | logger.info(f"Downloading model: {local_path.name}")
158 | downloader = DataDownloader()
159 | downloader.apply_job(model_path, local_path.as_posix())
160 | downloader.wait_for_load_below(0)
161 | return local_path.as_posix()
162 |
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytorch_lightning as pl
4 | from omegaconf import DictConfig
5 | from pytorch_lightning.utilities.data import CombinedLoader
6 | from torch.utils.data import ConcatDataset
7 | from torch.utils.data import DataLoader
8 |
9 | from .datasets import DATASETS
10 | from .utils import GroupDistributedSampler, MergeDataset, DistributedSampler
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def build_dataloader(data_config: DictConfig, batch_size=4, shuffle=False,
16 | drop_last=True, num_workers=8, pin_memory=True):
17 | cfg_datasets = list(data_config.values())
18 | datasets = []
19 | for cfg_dataset in cfg_datasets:
20 | dataset_cls = DATASETS.get(cfg_dataset.get("type"))
21 | cfg_dataset.update({"random_sample": shuffle})
22 | dataset = dataset_cls(**cfg_dataset)
23 | datasets.append(dataset)
24 | if len(datasets) == 1:
25 | dataset = ConcatDataset(datasets)
26 | else:
27 | if not drop_last:
28 | logger.warning(f"Option `drop_last` is forced activated when merging multiple datasets.")
29 | dataset = MergeDataset(datasets)
30 | dataloader_args = dict(
31 | dataset=dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
32 | if not drop_last:
33 | dataloader = DataLoader(**dataloader_args, shuffle=shuffle, drop_last=drop_last)
34 | # sampler = DistributedSampler(dataset, shuffle=False, drop_last=False)
35 | # dataloader = DataLoader(**dataloader_args, sampler=sampler)
36 | else:
37 | sampler = GroupDistributedSampler(dataset, shuffle=shuffle, drop_last=drop_last)
38 | dataloader = DataLoader(**dataloader_args, sampler=sampler)
39 | return dataloader
40 |
41 |
42 | class SDVideoDataModule(pl.LightningDataModule):
43 | def __init__(
44 | self, train, val, train_alt=None, batch_size_train=1, batch_size_val=1,
45 | num_workers=8, pin_memory=True, **kwargs
46 | ):
47 | super().__init__()
48 | # make `prepare_data` called in each node when ddp is used.
49 | self.prepare_data_per_node = True
50 | # dir config
51 | self.train = train
52 | self.val = val
53 | self.train_alt = train_alt
54 | # dataloader config
55 | self.batch_size_train = batch_size_train
56 | self.batch_size_val = batch_size_val
57 | self.num_workers = num_workers
58 | self.pin_memory = pin_memory
59 |
60 | def train_dataloader(self):
61 | shuffle, drop_last = True, True
62 | dataloader = build_dataloader(
63 | self.train, shuffle=shuffle, drop_last=drop_last, batch_size=self.batch_size_train,
64 | num_workers=self.num_workers, pin_memory=self.pin_memory
65 | )
66 | dataloaders = {"train": dataloader}
67 | if not self.train_alt:
68 | return dataloaders
69 | dataloader_alt = build_dataloader(
70 | self.train_alt, shuffle=shuffle, drop_last=drop_last, batch_size=self.batch_size_train * 8,
71 | num_workers=self.num_workers, pin_memory=self.pin_memory
72 | )
73 | dataloaders.update({"train_alt": dataloader_alt})
74 | return CombinedLoader(dataloaders, mode="max_size_cycle")
75 |
76 | def val_dataloader(self):
77 | dataloader = build_dataloader(
78 | self.val, shuffle=False, drop_last=False, batch_size=self.batch_size_val,
79 | num_workers=self.num_workers
80 | )
81 | return dataloader
82 |
--------------------------------------------------------------------------------
/data/datasets.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | import warnings
4 | from ast import literal_eval
5 | from pathlib import Path
6 |
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from torchvision.datasets.folder import pil_loader
10 | from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, Lambda
11 | from torchvision.transforms.functional import pil_to_tensor
12 | from transformers import CLIPTokenizer
13 |
14 | from utils import dist_envs
15 | import os
16 | import json
17 | import open_clip
18 | import kornia
19 |
20 | from .constants import DEFAULT_TOKENIZER_DIR
21 | from .registry import DATASETS
22 | from .utils import read_multi_csv, load_video, load_entity_vae, load_entity_clip
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 | warnings.simplefilter("error", Image.DecompressionBombWarning)
27 |
28 | DATALOAD_TRY_TIMES = 64
29 |
30 | def load_jsonl(filename):
31 | with open(filename, "r") as f:
32 | return [json.loads(l.strip("\n")) for l in f.readlines()]
33 |
34 | class GeneralPairDataset(Dataset):
35 | def __init__(
36 | self, data_dir, csv_subdir, tar_subdir, resolution, tokenizer_dir,
37 | random_sample,
38 | ):
39 | self.random_sample = random_sample
40 | # data related
41 | self.data_dir = data_dir
42 | self.anns_dir = os.path.join(data_dir, csv_subdir)
43 | self.frame_anns = load_jsonl(self.anns_dir)
44 | # image related
45 | # image_crop_op = RandomCrop if random_sample else CenterCrop
46 | image_crop_op = CenterCrop
47 | # resize; crop; scale pixels from [0, 255) to [-1, 1)
48 | self.transform = Compose([
49 | Resize(resolution, antialias=False), image_crop_op(resolution),
50 | Lambda(lambda pixels: pixels / 127.5 - 1.0)
51 | ])
52 | self.entity_clip_transform = Compose([
53 | Resize(448, antialias=False), image_crop_op(448)
54 | ])
55 | # text related
56 | self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_dir)
57 |
58 | self.preprocess = None
59 |
60 | def __len__(self):
61 | return len(self.frame_anns)
62 |
63 | def __getitem__(self, idx):
64 | raise NotImplementedError
65 |
66 |
67 | class VideoTextDataset(GeneralPairDataset):
68 | def __init__(
69 | self, data_dir, csv_subdir, tar_subdir, tokenizer_dir=DEFAULT_TOKENIZER_DIR,
70 | num_frames=8, resolution=512, random_sample=True, **kwargs,
71 | ):
72 | super().__init__(
73 | data_dir=data_dir, csv_subdir=csv_subdir, tar_subdir=tar_subdir,
74 | resolution=resolution, tokenizer_dir=tokenizer_dir,
75 | random_sample=random_sample
76 | )
77 | # sample config
78 | self.num_frames = num_frames
79 | # process dataframe
80 | # self.videos, self.texts = self.read_data(self.meta_dir)
81 | self.videos, self.tars, self.texts, self.coordinates = self.read_anns(self.frame_anns)
82 |
83 | @staticmethod
84 | def read_anns(anns):
85 | videos = []
86 | tars = []
87 | texts = []
88 | coordinates = []
89 | for ann in anns:
90 | videos.append(ann['vid'])
91 | tars.append(ann['tar'])
92 | texts.append(ann['text'])
93 | coordinates.append(ann['coordinates'])
94 | return videos, tars, texts, coordinates
95 |
96 | @staticmethod
97 | def read_data(meta_dir):
98 | df = read_multi_csv(meta_dir)
99 | videos, texts = df["video"].to_list(), df["caption"].to_list()
100 | return videos, texts
101 |
102 | def __len__(self):
103 | return len(self.videos)
104 |
105 | def __getitem__(self, idx):
106 | try_num = 0
107 | while try_num < DATALOAD_TRY_TIMES:
108 | try:
109 | video_path = Path(self.data_dir, "videos", self.tars[idx], "raw_vid", self.videos[idx])
110 | text = self.texts[idx]
111 | coords = self.coordinates[idx]
112 | if isinstance(text, list):
113 | text = random.choice(text)
114 | # VIDEO
115 | frame_rate = 1
116 | pixel_values, sample_ids = load_video(
117 | video_path.as_posix(),
118 | n_sample_frames=self.num_frames,
119 | sample_rate=frame_rate,
120 | random_sample=self.random_sample,
121 | transform=self.transform,
122 | selected_frames=coords
123 | )
124 | use_rand_entity_sample = True
125 | chosed_index = 0.2
126 | entity_vae = load_entity_vae(
127 | video_path, sample_ids=sample_ids, transform=self.transform,
128 | use_rand_entity_sample=use_rand_entity_sample, chosed_index=chosed_index
129 | )
130 | entity_clip = load_entity_clip(
131 | video_path, sample_ids=sample_ids,
132 | preprocess=self.preprocess, use_rand_entity_sample=use_rand_entity_sample,
133 | chosed_index=chosed_index, transform=self.entity_clip_transform
134 | )
135 | # TEXT
136 | text_token_ids = self.tokenizer(
137 | text,
138 | max_length=self.tokenizer.model_max_length,
139 | padding="max_length",
140 | truncation=True,
141 | return_tensors="pt",
142 | ).input_ids[0]
143 | break
144 | except Exception as e:
145 | logger.warning(f"Exception occurred parsing video file ({self.videos[idx]}): {e}")
146 | idx = random.randrange(
147 | len(self) // dist_envs.num_nodes * dist_envs.node_rank,
148 | len(self) // dist_envs.num_nodes * (dist_envs.node_rank + 1)
149 | ) if try_num < DATALOAD_TRY_TIMES // 2 else random.randrange(len(self))
150 | try_num += 1
151 | # output
152 | output = dict(
153 | pixel_values=pixel_values, entity_vae=entity_vae, entity_clip=entity_clip, text_token_ids=text_token_ids, frame_rates=frame_rate
154 | )
155 | return output
156 |
157 |
158 | @DATASETS.register_module()
159 | class TgifDataset(VideoTextDataset):
160 | @staticmethod
161 | def read_data(meta_dir):
162 | df = read_multi_csv(meta_dir)
163 | df["video"] = df[["package", "id"]].agg(
164 | lambda xs: "{0}/{1}.gif".format(*xs), axis=1)
165 | # load data paths
166 | videos = df["video"].to_list()
167 | texts = list(map(literal_eval, df["caption"]))
168 | return videos, texts
169 |
170 |
171 | @DATASETS.register_module()
172 | class VatexDataset(VideoTextDataset):
173 | @staticmethod
174 | def read_data(meta_dir):
175 | df = read_multi_csv(meta_dir)
176 | df["video"] = df[["package", "videoID"]].agg(
177 | lambda xs: "{0}/{1}".format(*xs), axis=1)
178 | # load data paths
179 | videos = df["video"].to_list()
180 | texts = list(map(literal_eval, df["enCap"]))
181 | return videos, texts
182 |
183 |
184 | @DATASETS.register_module()
185 | class WebvidDataset(VideoTextDataset):
186 | @staticmethod
187 | def read_data(meta_dir):
188 | df = read_multi_csv(meta_dir)
189 | df["video"] = df[["package", "videoid"]].agg(
190 | lambda xs: "{0}/{1}.mp4".format(*xs), axis=1)
191 | # load data paths
192 | videos = df["video"].to_list()
193 | # add watermark as keyword since all videos in WebVid have watermarks
194 | texts = [[f"{x}, watermark", f"{x} with watermark"]
195 | for x in df["name"].to_list()]
196 | return videos, texts
197 |
198 |
199 | @DATASETS.register_module()
200 | class K700CaptionDataset(VideoTextDataset):
201 | @staticmethod
202 | def read_data(meta_dir):
203 | df = read_multi_csv(meta_dir)
204 | df["video"] = df[["package", "id"]].agg(
205 | lambda xs: "{0}/{1}".format(*xs), axis=1)
206 | df["text"] = df[["class", "caption"]].agg(
207 | lambda xs: ["{0}: {1}".format(*xs),
208 | "{1}, {0}".format(*xs)], axis=1)
209 | # load data paths
210 | videos = df["video"].to_list()
211 | texts = df["text"].to_list()
212 | return videos, texts
213 |
214 |
215 | @DATASETS.register_module()
216 | class MidjourneyVideoDataset(VideoTextDataset):
217 | @staticmethod
218 | def read_data(meta_dir):
219 | df = read_multi_csv(meta_dir)
220 | df["video"] = df[["package", "videoname"]].agg(
221 | lambda xs: "{0}/{1}".format(*xs), axis=1)
222 | # load data paths
223 | videos = df["video"].to_list()
224 | texts = df["caption"].to_list()
225 | return videos, texts
226 |
227 |
228 | @DATASETS.register_module()
229 | class MomentsInTimeDataset(VideoTextDataset):
230 | @staticmethod
231 | def read_data(meta_dir):
232 | df = read_multi_csv(meta_dir)
233 | df["video"] = df[["package", "Video Name"]].agg(
234 | lambda xs: "{0}/{1}".format(*xs), axis=1)
235 | # load data paths
236 | videos = df["video"].to_list()
237 | texts = df["caption"].to_list()
238 | return videos, texts
239 |
240 |
241 | @DATASETS.register_module()
242 | class PexelsDataset(VideoTextDataset):
243 | @staticmethod
244 | def read_data(meta_dir):
245 | df = read_multi_csv(meta_dir)
246 | df["video"] = df[["package", "id"]].agg(
247 | lambda xs: "{0}/{1}.mp4".format(*xs), axis=1)
248 | # load data paths
249 | videos = df["video"].to_list()
250 | texts = df["caption"].to_list()
251 | return videos, texts
252 |
253 |
254 | class ImageTextDataset(GeneralPairDataset):
255 | def __init__(
256 | self, data_dir, csv_subdir, tar_subdir, tokenizer_dir=DEFAULT_TOKENIZER_DIR,
257 | resolution=512, random_sample=True, **kwargs,
258 | ):
259 | super().__init__(
260 | data_dir=data_dir, csv_subdir=csv_subdir, tar_subdir=tar_subdir,
261 | resolution=resolution, tokenizer_dir=tokenizer_dir,
262 | random_sample=random_sample
263 | )
264 | self.images, self.texts = self.read_data(self.meta_dir)
265 |
266 | @staticmethod
267 | def read_data(meta_dir):
268 | df = read_multi_csv(meta_dir)
269 | images, texts = df["image"].to_list(), df["text"].to_list()
270 | return images, texts
271 |
272 | def __len__(self):
273 | return len(self.images)
274 |
275 | def __getitem__(self, idx):
276 | try_num = 0
277 | while try_num < DATALOAD_TRY_TIMES:
278 | try:
279 | # IMAGE
280 | image_path = Path(self.data_dir, self.images[idx])
281 | image = pil_loader(image_path.as_posix())
282 | pixel_values = self.transform(pil_to_tensor(image))
283 | # TEXT
284 | text = self.texts[idx]
285 | text_token_ids = self.tokenizer(
286 | text,
287 | max_length=self.tokenizer.model_max_length,
288 | padding="max_length",
289 | truncation=True,
290 | return_tensors="pt",
291 | ).input_ids[0]
292 | break
293 | except Exception as e:
294 | logger.warning(f"Exception occurred parsing image file ({self.images[idx]}): {e}")
295 | idx = random.randrange(
296 | len(self) // dist_envs.num_nodes * dist_envs.node_rank,
297 | len(self) // dist_envs.num_nodes * (dist_envs.node_rank + 1)
298 | ) if try_num < DATALOAD_TRY_TIMES // 2 else random.randrange(len(self))
299 | try_num += 1
300 | # output
301 | frame_rates = random.randrange(30)
302 | output = dict(
303 | pixel_values=pixel_values, text_token_ids=text_token_ids, frame_rates=frame_rates
304 | )
305 | return output
306 |
307 |
308 | @DATASETS.register_module()
309 | class LaionDataset(ImageTextDataset):
310 | @staticmethod
311 | def read_data(meta_dir):
312 | df = read_multi_csv(meta_dir)
313 | df = df.dropna(subset=["dir", "text"], how="any").reset_index(drop=True)
314 | images, texts = df["dir"].to_list(), df["text"].to_list()
315 | return images, texts
316 |
317 |
318 | @DATASETS.register_module()
319 | class MidJourneyImageDataset(ImageTextDataset):
320 | @staticmethod
321 | def read_data(meta_dir):
322 | df = read_multi_csv(meta_dir)
323 | df = df.dropna(subset=["videoname", "caption"], how="any").reset_index(drop=True)
324 | df["videoname"] = df[["package", "videoname"]].agg(
325 | lambda xs: "{0}/{1}.png".format(*xs), axis=1)
326 | images, texts = df["videoname"].to_list(), df["caption"].to_list()
327 | return images, texts
328 |
--------------------------------------------------------------------------------
/data/registry.py:
--------------------------------------------------------------------------------
1 | from utils import Registry
2 |
3 | DATASETS = Registry("datasets")
4 |
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import importlib
3 | import logging
4 | import os
5 | import random
6 | from collections import Counter
7 | from operator import itemgetter
8 | from pathlib import Path
9 | from typing import Union
10 |
11 | import numpy as np
12 | import pandas as pd
13 | import torch
14 | from PIL import Image, ImageSequence
15 | from torch.utils.data import ConcatDataset
16 | from torch.utils.data.distributed import DistributedSampler
17 | from torchvision.datasets.folder import pil_loader
18 | from torchvision.transforms import Compose, Resize, CenterCrop, Lambda
19 | from torchvision.transforms.functional import pil_to_tensor
20 |
21 | from utils import dist_envs
22 | import cv2
24 |
25 | if importlib.util.find_spec("decord"):
26 | import decord
27 |
28 | logger = logging.getLogger(__name__)
29 |
30 |
31 | def sparse_sample(total_frames, sample_frames, sample_rate, random_sample=False):
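"""Pick frame indices: every `sample_rate`-th frame when sample_frames <= 0;
a strided window (with a random offset if random_sample) when enough frames
exist; otherwise `sample_frames` evenly spaced indices, which may repeat."""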
32 | if sample_frames <= 0: # sample over the total sequence of frames
33 | ids = np.arange(0, total_frames, sample_rate, dtype=int).tolist()
34 | elif sample_rate * (sample_frames - 1) + 1 <= total_frames:
35 | offset = random.randrange(total_frames - (sample_rate * (sample_frames - 1))) \
36 | if random_sample else 0
37 | ids = list(range(offset, total_frames + offset, sample_rate))[:sample_frames]
38 | else:
39 | ids = np.linspace(0, total_frames, sample_frames, endpoint=False, dtype=int).tolist()
40 | return ids
41 |
42 | def load_image(img_path, resize_res=None, crop=False, rescale=False):
43 | """ Load an image to tensor
44 |
45 | Args:
46 | img_path: image file path to load.
47 | resize_res: resolution of resize, no resize if None
48 | crop: center crop if True
49 | rescale: values in [-1, 1) if rescale=True otherwise [0, 255)
50 |
51 | Returns:
52 | torch.Tensor: image tensor in the shape of [c, h, w]
53 | """
54 | img = pil_to_tensor(pil_loader(img_path))
55 | tsfm_ops = []
56 | crop_size = min(img.shape[-2:])
57 | if resize_res:
58 | tsfm_ops.append(Resize(resize_res, antialias=False))
59 | crop_size = resize_res
60 | if crop:
61 | tsfm_ops.append(CenterCrop(crop_size))
62 | if rescale:
63 | tsfm_ops.append(Lambda(lambda pixels: pixels / 127.5 - 1.0))
64 | transform = Compose(tsfm_ops)
65 | return transform(img)
66 |
67 |
68 | def load_video(file_path, n_sample_frames, sample_rate=4, random_sample=False, transform=None, selected_frames=None):
69 | # random_sample = True
70 | sample_args = dict(
71 | sample_frames=n_sample_frames, sample_rate=sample_rate, random_sample=random_sample)
72 | video = []
73 | if Path(file_path).is_dir():
74 | img_files = sorted(Path(file_path).glob("*"), key=lambda i: int(i.stem[6:]))
75 | if len(img_files) < 1:
76 | logger.error(f"No data in video directory: {file_path}")
77 | raise FileNotFoundError(f"No data in video directory: {file_path}")
78 |
79 | # sample_ids = sparse_sample(len(img_files), **sample_args)
80 | selected_frames_ids = sparse_sample(len(selected_frames), **sample_args) ##FIXME change to selected frames, not all the frames
81 | sample_ids = []
82 | for i in range(0, len(selected_frames_ids)):
83 | sample_ids.append(selected_frames[selected_frames_ids[i]])
84 |
85 | for img_file in itemgetter(*sample_ids)(img_files):
86 | img = pil_loader(img_file.as_posix())
87 | img = pil_to_tensor(img)
88 | video.append(img)
89 | elif file_path.endswith(".gif"):
90 | with Image.open(file_path) as gif:
91 | sample_ids = sparse_sample(gif.n_frames, **sample_args)
92 | sample_ids_counter = Counter(sample_ids)
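# the Counter tolerates repeated indices produced by the evenly-spaced fallback in sparse_sample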
93 | for frame_idx, frame in enumerate(ImageSequence.Iterator(gif)):
94 | if frame_idx in sample_ids_counter:
95 | frame = pil_to_tensor(frame.convert("RGB"))
96 | for _ in range(sample_ids_counter[frame_idx]):
97 | video.append(frame)
98 | else:
99 | vreader = decord.VideoReader(file_path)
100 | sample_ids = sparse_sample(len(vreader), **sample_args)
101 | frames = vreader.get_batch(sample_ids).asnumpy() # (f, h, w, c)
102 | for frame_idx in range(frames.shape[0]):
103 | video.append(pil_to_tensor(Image.fromarray(frames[frame_idx]).convert("RGB")))
104 | video = torch.stack(video) # (f, c, h, w)
105 | if transform is not None:
106 | video = transform(video)
107 | return video, sample_ids
108 |
109 | def load_entity_vae(file_path, sample_ids, transform=None, use_rand_entity_sample=False, chosed_index=None):
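"""Load entity frames: each sampled raw frame is masked with its corresponding
mask image (path derived from the frame path); when use_rand_entity_sample is
set, all but the first sampled frame are zeroed out."""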
110 | video = []
111 | selected_frames = [1]
112 | if Path(file_path).is_dir():
113 | img_files = sorted(Path(file_path).glob("*"), key=lambda i: int(i.stem[6:]))
114 | if len(img_files) < 1:
115 | logger.error(f"No data in video directory: {file_path}")
116 | raise FileNotFoundError(f"No data in video directory: {file_path}")
117 | index = 1
118 | for img_file in itemgetter(*sample_ids)(img_files):
119 | img_file = str(img_file)
120 | mask_file = img_file.replace("raw_vid/", "").replace("videos", "mask").replace("frame_","").replace("jpg","png")
121 | video_img = cv2.imread(img_file)
122 | mask_img = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
123 | entity = cv2.bitwise_and(video_img, video_img, mask=mask_img)
124 | # cv2 -> PIL
125 | img = Image.fromarray(cv2.cvtColor(entity, cv2.COLOR_BGR2RGB))
126 | img = pil_to_tensor(img)
127 | if use_rand_entity_sample: ###FIXME random mask some frames
128 | if index not in selected_frames:
129 | img = torch.zeros_like(img)
130 | video.append(img)
131 | index += 1
132 | video = torch.stack(video) # (f, c, h, w)
133 | if transform is not None:
134 | video = transform(video)
135 | return video
136 |
137 | def load_entity_clip(file_path, sample_ids, preprocess=None, use_rand_entity_sample=False, chosed_index=None, transform=None):
138 | video = []
139 | selected_frames = [1]
140 | if Path(file_path).is_dir():
141 | img_files = sorted(Path(file_path).glob("*"), key=lambda i: int(i.stem[6:]))
142 | if len(img_files) < 1:
143 | logger.error(f"No data in video directory: {file_path}")
144 | raise FileNotFoundError(f"No data in video directory: {file_path}")
145 | index = 1
146 | for img_file in itemgetter(*sample_ids)(img_files):
147 | img_file = str(img_file)
148 | mask_file = img_file.replace("raw_vid/", "").replace("videos", "mask").replace("frame_","").replace("jpg","png")
149 | video_img = cv2.imread(img_file)
150 | mask_img = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
151 | entity = cv2.bitwise_and(video_img, video_img, mask=mask_img)
152 |
153 | img_tensor = torch.from_numpy(entity).permute(2, 0, 1).float()
154 | if use_rand_entity_sample: ###FIXME random mask some frames
155 | if index not in selected_frames:
156 | img_tensor = torch.zeros_like(img_tensor)
157 | img = (img_tensor / 255. - 0.5) * 2
158 | video.append(img)
159 | index += 1
160 | video = torch.stack(video, dim=0) # (f, c, h, w)
161 | if transform is not None:
162 | video = transform(video)
163 | return video
164 |
165 |
166 | class MergeDataset(ConcatDataset):
167 | r"""Dataset as a merge of multiple datasets.
168 |
169 | Datasets are firstly split into multiple chunks and these chunks are merged
170 | one after another.
171 |
172 | Example:
173 | 3 datasets with sizes [6, 17, 27] and split_num=4 give per-split chunk sizes:
174 | [[1, 4, 6]
175 | [1, 4, 6]
176 | [1, 4, 6]
177 | [3, 5, 9]]
178 | Cumulative sums over these chunks (row by row) give the global chunk boundaries.
179 |
180 | Args:
181 | datasets: List of datasets to be merged
182 | """
183 |
184 | @staticmethod
185 | def split_cumsum(group_sizes, split_num):
186 | r, s = [], 0
187 | for split in range(split_num):
188 | chunk_sizes = [x // split_num for x in group_sizes]
189 | if split == split_num - 1:
190 | chunk_sizes = [chunk_sizes[i] + group_sizes[i] % split_num
191 | for i in range(len(group_sizes))]
192 | for chunk_size in chunk_sizes:
193 | r.append(chunk_size + s)
194 | s += chunk_size
195 | return r
196 |
197 | def __init__(self, datasets) -> None:
198 | super(MergeDataset, self).__init__(datasets)
199 | num_nodes, world_size = dist_envs.num_nodes, dist_envs.world_size
200 | # drop_last for all datasets
201 | self.datasets = list(datasets)
202 | self.dataset_sizes = list()
203 | for dataset in self.datasets:
204 | dataset_size = len(dataset) # type: ignore[arg-type]
205 | self.dataset_sizes.append((dataset_size // world_size) * world_size)
206 | self.chunk_sizes = [x // num_nodes for x in self.dataset_sizes]
207 | self.cumulative_sizes = self.split_cumsum(self.dataset_sizes, num_nodes)
208 |
209 | def __len__(self):
210 | return sum(self.dataset_sizes)
211 |
212 | def __getitem__(self, idx):
213 | if idx < 0:
214 | if -idx > len(self):
215 | raise ValueError("absolute value of index should not exceed dataset length")
216 | idx = len(self) + idx
217 | global_chunk_idx = bisect.bisect_right(self.cumulative_sizes, idx)
218 | if global_chunk_idx == 0:
219 | dataset_idx = 0
220 | sample_idx = idx
221 | else:
222 | dataset_idx = global_chunk_idx % len(self.datasets)
223 | sample_delta = idx - self.cumulative_sizes[global_chunk_idx - 1]
224 | sample_idx = global_chunk_idx // len(self.datasets) * self.chunk_sizes[dataset_idx] + sample_delta
225 | return self.datasets[dataset_idx][sample_idx]
226 |
227 |
228 | class GroupDistributedSampler(DistributedSampler):
229 | r"""Sampler that restricts grouped data loading to a subset of the dataset.
230 |
231 | The dataset is firstly split into groups determined by `num_nodes`. The
232 | grouped datasets are then shuffled and distributed among devices in each
233 | node.
234 | """
235 |
236 | def __iter__(self):
237 | num_nodes = dist_envs.num_nodes
238 | assert self.num_replicas % num_nodes == 0
239 | indices = list(range(len(self.dataset))) # type: ignore[arg-type]
240 | chunk_size = self.total_size // num_nodes
241 | if self.shuffle:
242 | indices = []
243 | for node_rank in range(num_nodes):
244 | g = torch.Generator()
245 | g.manual_seed(self.seed + node_rank + self.epoch)
246 | chunk_indices = torch.randperm(
247 | chunk_size, generator=g).tolist() # type: ignore[arg-type]
248 | indices.extend([x + chunk_size * node_rank for x in chunk_indices])
249 |
250 | # remove tail of data to make it evenly divisible.
251 | indices = indices[:self.total_size]
252 | assert len(indices) == self.total_size
253 |
254 | # subsample
255 | num_devices = self.num_replicas // num_nodes
256 | node_rank = int(self.rank / num_devices)
257 | local_rank = self.rank % num_devices
258 | indices = indices[local_rank:chunk_size:num_devices]
259 | indices = [x + chunk_size * node_rank for x in indices]
260 | assert len(indices) == self.num_samples
261 |
262 | return iter(indices)
263 |
264 |
265 | class LabelEncoder:
266 | """ Encodes an label via a dictionary.
267 | Args:
268 | label_source (list of strings): labels of data used to build encoding dictionary.
269 | Example:
270 | >>> labels = ['label_a', 'label_b']
271 | >>> encoder = LabelEncoder(labels)
272 | >>> encoder.encode('label_a')
273 | tensor(0)
274 | >>> encoder.decode(encoder.encode('label_a'))
275 | 'label_a'
276 | >>> encoder.encode('label_b')
277 | tensor(1)
278 | >>> encoder.size
279 | 2
280 | """
281 |
282 | def __init__(self, label_source: Union[list, str, os.PathLike]):
283 | if isinstance(label_source, list):
284 | self.labels = label_source
285 | else:
286 | with open(label_source, "r", encoding="utf-8") as f:
287 | lines = [x.strip() for x in f.readlines() if x.strip()]
288 | self.labels = lines
289 | self.idx_to_label = {idx: lab for idx, lab in enumerate(self.labels)}
290 | self.label_to_idx = {lab: idx for idx, lab in enumerate(self.labels)}
291 |
292 | @property
293 | def size(self):
294 | """
295 | Returns:
296 | int: Number of labels in the dictionary.
297 | """
298 | return len(self.labels)
299 |
300 | def encode(self, label):
301 | """ Encodes a ``label``.
302 |
303 | Args:
304 | label (object): Label to encode.
305 |
306 | Returns:
307 | torch.Tensor: Encoding of the label.
308 | """
309 | return torch.tensor(self.label_to_idx.get(label), dtype=torch.long)
310 |
311 | def batch_encode(self, iterator, dim=0):
312 | """
313 | Args:
314 | iterator (iterator): Batch of labels to encode.
315 | dim (int, optional): Dimension along which to concatenate tensors.
316 |
317 | Returns:
318 | torch.Tensor: Tensor of encoded labels.
319 | """
320 | return torch.stack([self.encode(x) for x in iterator], dim=dim)
321 |
322 | def decode(self, encoded):
323 | """ Decodes ``encoded`` label.
324 |
325 | Args:
326 | encoded (torch.Tensor): Encoded label.
327 |
328 | Returns:
329 | object: Label decoded from ``encoded``.
330 | """
331 | if encoded.numel() > 1:
332 | raise ValueError(
333 | '``decode`` decodes one label at a time, use ``batch_decode`` instead.')
334 |
335 | return self.idx_to_label[encoded.squeeze().item()]
336 |
337 | def batch_decode(self, tensor, dim=0):
338 | """
339 | Args:
340 | tensor (torch.Tensor): Batch of tensors.
341 | dim (int, optional): Dimension along which to split tensors.
342 |
343 | Returns:
344 | list: Batch of decoded labels.
345 | """
346 | return [self.decode(x) for x in [t.squeeze(0) for t in tensor.split(1, dim=dim)]]
347 |
348 |
349 | def read_multi_csv(csv_dir):
350 | csvs = sorted(Path(csv_dir).glob("*.csv"))
351 | df_all = []
352 | for c in csvs:
353 | df = pd.read_csv(c)
354 | df["package"] = Path(c).stem
355 | df_all.append(df)
356 | df_all = pd.concat(df_all, ignore_index=True)
357 | return df_all
358 |
--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 MASTER_ADDR=127.0.0.1 MASTER_PORT=10086 NODE_RANK=0 WORLD_SIZE=1 HYDRA_FULL_ERROR=1 \
2 | python main.py --config-name=_meta_/inference.yaml
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import hydra
5 | from hydra.utils import instantiate
6 | from lightning_fabric.utilities.seed import seed_everything
7 | from omegaconf import OmegaConf
8 |
9 | from data import download_data, download_model
10 | from utils import instantiate_multi, save_config, dist_envs, enable_logger
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | @hydra.main(config_path="configs", config_name="example", version_base=None)
16 | def main(config):
17 | # 1. initialize trainer (should be done at the first place)
18 | trainer_callbacks = instantiate_multi(config, "callbacks")
19 | trainer_loggers = instantiate_multi(config, "loggers")
20 | trainer = instantiate(
21 | config.trainer, callbacks=trainer_callbacks, logger=trainer_loggers,
22 | )
23 | # 2. save envs and configs
24 | save_config(config)
25 | seed_everything(config.get("seed", 42), workers=True)
26 | dist_envs.init_envs(trainer)
27 | # 3. enable some loggers
28 | if "log_master" in config.callbacks and config.callbacks.log_master is not None:
29 | enable_logger("callbacks.log_master", config.callbacks.log_master.log_file)
30 | for log_name in (
31 | "__main__", "data.dataloader", "data.datasets", "data.data_downloader",
32 | "model.model", "components.interpolation.module", "components.interpolation.pipeline"
33 | ):
34 | enable_logger(log_name)
35 | # 4. build model
36 | remote_paths = ["pretrained_model_path", "ckpt_path", "temporal_vae_path"]
37 | for remote_path in remote_paths:
38 | if hasattr(config.model, remote_path):
39 | rpath = getattr(config.model, remote_path)
40 | setattr(config.model, remote_path, download_model(rpath))
41 | model = instantiate(config.model)
42 | # [DEBUG] show whole config
43 | if os.environ.get("DEBUG_ON", None):
44 | logger.info(OmegaConf.to_yaml(config))
45 | # 5. starting training/testing
46 | evaluator = config.get("evaluator", None)
47 | if evaluator is None or evaluator == "pl_validate":
48 | # config.data = download_data(config.data)
49 | datamodule = instantiate(config.data)
50 | run_fn = trainer.fit if evaluator is None else trainer.validate
51 | run_fn(model=model, datamodule=datamodule)
52 | else:
53 | model.setup(stage="test")
54 | evaluator = instantiate(config.evaluator)
55 | evaluator(model)
56 | logger.info("All finished.")
57 |
58 |
59 | if __name__ == "__main__":
60 | main()
61 |
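`main.py` delegates object construction to Hydra's `instantiate`, which builds whatever class a config node's `_target_` key points to and passes the remaining keys as kwargs. A minimal sketch of that mechanism with an illustrative config node (not the repo's actual `configs/model/sdvideo.yaml`):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Illustrative node: _target_ names the class to build, other keys become kwargs.
cfg = OmegaConf.create({
    "_target_": "model.model.SDVideoModel",
    "pretrained_model_path": "assets/stable-diffusion-2-1-base",  # README layout
    "resolution": 512,
})
model = instantiate(cfg)  # ~ SDVideoModel(pretrained_model_path=..., resolution=512)
```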
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import SDVideoModel, SDVideoModelEvaluator
2 |
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import random
5 | import re
6 | from datetime import datetime
7 | from pathlib import Path
8 | from typing import Union
9 | import time
10 |
11 | import pandas as pd
12 | import pytorch_lightning as pl
13 | import torch
14 | import torch.nn.functional as F
15 | import torch.utils.checkpoint
16 | from diffusers import AutoencoderKL, DDIMScheduler
17 | from .modules.zero_snr_ddpm import DDPMScheduler  ###FIXME: replaced with the zero-SNR DDPMScheduler
18 | from diffusers.optimization import get_scheduler
19 | from diffusers.utils.import_utils import is_xformers_available
20 | from einops import rearrange
21 | from torch.optim import AdamW
22 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPProcessor
23 |
24 | from components.vae import TemporalAutoencoderKL
25 | from data.utils import load_image
26 | from utils import dist_envs
27 | from utils import slugify
28 | from .modules.unet import UNet3DConditionModel
29 | from .pipeline import SDVideoPipeline
30 | from .utils import save_videos_grid, compute_clip_score, prepare_masked_latents, prepare_entity_latents, FrozenOpenCLIPImageEmbedderV2, Resampler
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 |
35 | class SDVideoModel(pl.LightningModule):
36 | def __init__(self, pretrained_model_path, **kwargs):
37 | super().__init__()
38 | self.save_hyperparameters(ignore=["pretrained_model_path"], logger=False)
39 | # main training module
40 | self.unet: Union[str, UNet3DConditionModel] = Path(pretrained_model_path, "unet").as_posix()
41 | # components for training
42 | self.noise_scheduler_dir = Path(pretrained_model_path, "scheduler").as_posix()
43 | self.vae = Path(pretrained_model_path, "vae").as_posix()
44 | self.text_encoder = Path(pretrained_model_path, "text_encoder").as_posix()
45 | self.tokenizer: Union[str, CLIPTokenizer] = Path(pretrained_model_path, "tokenizer").as_posix()
46 | # clip model for metric
47 | self.clip = Path(pretrained_model_path, "clip").as_posix()
48 | self.clip_processor = Path(pretrained_model_path, "clip").as_posix()
49 | # define pipeline for inference
50 | self.val_pipeline = None
51 | # video frame resolution
52 | self.resolution = kwargs.get("resolution", 512)
53 | # use temporal_vae
54 | self.temporal_vae_path = kwargs.get("temporal_vae_path", None)
55 | # use prompt image
56 | self.in_channels = kwargs.get("in_channels", 4)
57 | self.use_prompt_image = self.in_channels > 4
58 | self.add_entity_vae = kwargs.get("add_entity_vae", False)
59 | self.add_entity_clip = kwargs.get("add_entity_clip", False)
60 |
61 | ### add open clip model
62 | if self.add_entity_clip:
63 | self.embedding_dim = 1280
64 | self.entity_clip_model = FrozenOpenCLIPImageEmbedderV2(arch="ViT-H-14") ###FIXME
65 | self.enclip_projector = Resampler(dim=1024, depth=4, dim_head=64, heads=12, num_queries=16, embedding_dim=self.embedding_dim, output_dim=1024, ff_mult=4)
66 | else:
67 | self.entity_clip_model = None
68 | self.enclip_projector = None
69 |
70 | def setup(self, stage: str) -> None:
71 | # build modules
72 | self.noise_scheduler = DDPMScheduler.from_pretrained(self.noise_scheduler_dir)
73 | self.tokenizer = CLIPTokenizer.from_pretrained(self.tokenizer)
74 |
75 | if self.temporal_vae_path:
76 | self.vae = TemporalAutoencoderKL.from_pretrained(self.temporal_vae_path)
77 | else:
78 | self.vae = AutoencoderKL.from_pretrained(self.vae)
79 | self.text_encoder = CLIPTextModel.from_pretrained(self.text_encoder)
80 | self.unet = UNet3DConditionModel.from_pretrained_2d(
81 | self.unet, sample_size=self.resolution // (2 ** (len(self.vae.config.block_out_channels) - 1)),
82 | in_channels=self.in_channels,
83 | add_temp_transformer=self.hparams.get("add_temp_transformer", False),
84 | add_temp_attn_only_on_upblocks=self.hparams.get("add_temp_attn_only_on_upblocks", False),
85 | prepend_first_frame=self.hparams.get("prepend_first_frame", False),
86 | add_temp_embed=self.hparams.get("add_temp_embed", False),
87 | add_temp_ff=self.hparams.get("add_temp_ff", False),
88 | add_temp_conv=self.hparams.get("add_temp_conv", False),
89 | num_class_embeds=self.hparams.get("num_class_embeds", None)
90 | )
91 |
92 | # load previously trained components for resumed training
93 | ckpt_path = self.hparams.get("ckpt_path", None)
94 | if ckpt_path is not None:
95 | state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
96 |
97 | mod_list = ["unet", "text_encoder"] if self.temporal_vae_path \
98 | else ["unet", "text_encoder", "vae", "enclip_projector"]
99 | for mod in mod_list:
100 | if any(filter(lambda x: x.startswith(mod), state_dict.keys())):
101 | mod_instance = getattr(self, mod)
102 | mod_instance.load_state_dict(
103 | {k[len(mod) + 1:]: v for k, v in state_dict.items() if k.startswith(mod)}, strict=False
104 | )
105 |
106 | # null text for classifier-free guidance
107 | self.null_text_token_ids = self.tokenizer( # noqa
108 | "", max_length=self.tokenizer.model_max_length, padding="max_length",
109 | truncation=True, return_tensors="pt",
110 | ).input_ids[0]
111 |
112 | # train only the trainable modules
113 | trainable_modules = self.hparams.get("trainable_modules", None)
114 | if trainable_modules is not None:
115 | self.unet.requires_grad_(False)
116 | for pname, params in self.unet.named_parameters():
117 | if any([re.search(pat, pname) for pat in trainable_modules]):
118 | params.requires_grad = True
119 | if self.add_entity_clip:
120 | for pname, params in self.entity_clip_model.named_parameters():
121 | params.requires_grad = False
122 | for pname, params in self.enclip_projector.named_parameters():
123 | params.requires_grad = True
124 |
125 |         # raise an error when `in_channels` > 4 but `conv_in` is not trainable
126 |         if self.use_prompt_image and not self.unet.conv_in.weight.requires_grad:
127 |             raise AssertionError("use_prompt_image=True but `unet.conv_in` is frozen.")
128 |         if not self.use_prompt_image and self.unet.conv_in.weight.requires_grad:
129 |             logger.warning("use_prompt_image=False but `unet.conv_in` is trainable.")
130 |
131 | # load clip modules for evaluation
132 | self.clip = CLIPModel.from_pretrained(self.clip)
133 | self.clip_processor = CLIPProcessor.from_pretrained(self.clip_processor)
134 | # prepare modules
135 | for component in [self.vae, self.text_encoder, self.clip]:
136 | if not isinstance(component, CLIPTextModel) or self.hparams.get("freeze_text_encoder", False):
137 | component.requires_grad_(False).eval()
138 | if stage != "test" and self.trainer.precision.startswith("16"):
139 | component.to(dtype=torch.float16)
140 | # [DEBUG] show which parameters are trainable
141 | if os.environ.get("DEBUG_ON", None):
142 | params_trainable, params_frozen = [], []
143 | for name, params in self.named_parameters():
144 | if params.requires_grad:
145 | params_trainable.append(name)
146 | else:
147 | params_frozen.append(name)
148 | logger.info(f"*** [Trainable parameters]: {params_trainable}")
149 | logger.info(f"*** [Frozen parameters]: {params_frozen}")
150 | # use gradient checkpointing
151 | if self.hparams.get("enable_gradient_checkpointing", True):
152 | if not self.hparams.get("freeze_text_encoder", False):
153 | self.text_encoder.gradient_checkpointing_enable()
154 | self.unet.enable_gradient_checkpointing()
155 | # use xformers for efficient training
156 | if self.hparams.get("enable_xformers", True) and not hasattr(F, "scaled_dot_product_attention"):
157 | if is_xformers_available():
158 | self.unet.enable_xformers_memory_efficient_attention()
159 | if self.unet.temp_transformer is not None:
160 |                     # FIXME: disable xformers for this specific layer, otherwise a CUDA error occurs
161 | self.unet.temp_transformer.set_use_memory_efficient_attention_xformers(False)
162 | else:
163 | raise ValueError("xformers is not available. Make sure it is installed correctly")
164 | # construct pipeline for inference
165 | self.val_pipeline = SDVideoPipeline(
166 | vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
167 |             entity_clip_model=self.entity_clip_model, enclip_projector=self.enclip_projector,
168 | scheduler=DDIMScheduler.from_pretrained(self.noise_scheduler_dir),
169 | )
170 |
171 | def validation_step(self, batch, batch_idx):
172 | pixel_values = batch["pixel_values"].to(dtype=torch.float16) \
173 | if self.trainer.precision.startswith("16") else batch["pixel_values"]
174 | entity_vae = batch["entity_vae"].to(dtype=torch.float16) \
175 | if self.trainer.precision.startswith("16") else batch["entity_vae"]
176 | entity_clip = batch["entity_clip"].to(dtype=torch.float16) \
177 | if self.trainer.precision.startswith("16") else batch["entity_clip"]
178 |
179 | text_token_ids = batch["text_token_ids"]
180 | video_len = pixel_values.shape[1]
181 | # inference arguments
182 | num_inference_steps = self.hparams.get("num_inference_steps", 50)
183 | guidance_scale = self.hparams.get("guidance_scale", 7.5)
184 | noise_alpha = self.hparams.get("noise_alpha", .0)
185 |         # parse prompts
186 | prompts = self.tokenizer.batch_decode(text_token_ids, skip_special_tokens=True)
187 | generator = torch.Generator(device=self.device)
188 | seed = 42
189 | generator.manual_seed(seed)
190 | # compose args
191 | pipeline_args = dict(
192 | generator=generator, num_frames=video_len,
193 | num_inference_steps=num_inference_steps,
194 | guidance_scale=guidance_scale, noise_alpha=noise_alpha,
195 | frame_rate=1 if self.hparams.get("num_class_embeds", None) else None
196 | )
197 | samples = []
198 |
199 | for example_id in range(len(prompts)):
200 | prompt = prompts[example_id]
201 | prompt_image = pixel_values[example_id][:1, ...]
202 | entity_vae_image = torch.unsqueeze(entity_vae[example_id], 0)
203 | entity_clip_image = torch.unsqueeze(entity_clip[example_id], 0)
204 | sample = self.val_pipeline(
205 | prompt, prompt_image if self.use_prompt_image else None,
206 | entity_vae=entity_vae_image, entity_clip=entity_clip_image,
207 | add_entity_vae=self.add_entity_vae, add_entity_clip=self.add_entity_clip,
208 | **pipeline_args
209 | ).videos
210 | if self.trainer.is_global_zero:
211 | num_step_str = str(self.global_step).zfill(len(str(self.trainer.estimated_stepping_batches)))
212 | if prompt == "":
213 | prompt = "text_prompt_equal_null"
214 | save_videos_grid(sample, Path(f"samples_s{num_step_str}", f"{prompt}.gif"))
215 | samples.append(sample)
216 | # clip model for metric
217 | clip_scores = compute_clip_score(
218 | model=self.clip, model_processor=self.clip_processor,
219 | images=torch.cat(samples), texts=list(prompts), rescale=False,
220 | )
221 | self.log("val_clip_score", clip_scores.mean(), on_step=False, on_epoch=True, sync_dist=True)
222 |
223 | def configure_optimizers(self):
224 | optimizer_args = dict(
225 | params=filter(lambda p: p.requires_grad, self.parameters()),
226 | lr=self.hparams.get("lr", 1e-3),
227 | weight_decay=self.hparams.get("weight_decay", 1e-2)
228 | )
229 | optimizer = AdamW(**optimizer_args)
230 | # valid scheduler names: diffusers.optimization.SchedulerType
231 | scheduler_name = self.hparams.get("scheduler_name", "cosine")
232 | scheduler = get_scheduler(
233 | scheduler_name, optimizer=optimizer,
234 | num_warmup_steps=self.hparams.get("warmup_steps", 8),
235 | num_training_steps=self.trainer.estimated_stepping_batches
236 | )
237 | lr_scheduler = dict(scheduler=scheduler, interval="step", frequency=1)
238 | return dict(optimizer=optimizer, lr_scheduler=lr_scheduler)
239 |
240 |
241 | class SDVideoModelEvaluator:
242 | def __init__(self, **kwargs):
243 | torch.multiprocessing.set_start_method("spawn", force=True)
244 | torch.multiprocessing.set_sharing_strategy("file_system")
245 |
246 | self.seed = kwargs.pop("seed", 42)
247 | self.prompts = kwargs.pop("prompts", None)
248 | if self.prompts is None:
249 |             raise ValueError("No prompts provided.")
250 | elif isinstance(self.prompts, str) and not Path(self.prompts).exists():
251 | raise FileNotFoundError(f"Prompt file not found: {self.prompts}")
252 | elif isinstance(self.prompts, str):
253 | if self.prompts.endswith(".txt"):
254 | with open(self.prompts, "r", encoding="utf-8") as f:
255 | self.prompts = [x.strip() for x in f.readlines() if x.strip()]
256 | elif self.prompts.endswith(".json"):
257 | with open(self.prompts, "r", encoding="utf-8") as f:
258 | self.prompts = sorted([
259 | random.choice(x) if isinstance(x, list) else x
260 | for x in json.load(f).values()
261 | ])
262 | elif self.prompts.endswith(".csv"):
263 | # prompt images can be set in this condition.
264 | csv_path = self.prompts
265 | df_prompts = pd.read_csv(self.prompts)
266 | self.prompts = df_prompts.iloc[:, 0].tolist()
267 | if len(df_prompts.columns) >= 2:
268 | self.prompts_img = [Path(csv_path).parent.joinpath(x).as_posix() for x in df_prompts.iloc[:, 1]]
269 |         # ensure prompts_img is defined even when prompts do not come from a CSV
270 |         self.prompts_img = getattr(self, "prompts_img", None)
271 |
272 | self.add_file_logger(logger, kwargs.pop("log_file", None))
273 | self.output_file = kwargs.pop("output_file", "results.csv")
274 | self.batch_size = kwargs.pop("batch_size", 4)
275 | self.val_params = kwargs
276 |
277 | @staticmethod
278 | def add_file_logger(logger, log_file=None, log_level=logging.INFO):
279 | if dist_envs.global_rank == 0 and log_file is not None:
280 | log_handler = logging.FileHandler(log_file, "w")
281 | log_handler.setFormatter(
282 | logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s"))
283 | log_handler.setLevel(log_level)
284 | logger.addHandler(log_handler)
285 |
286 | @staticmethod
287 | def infer(rank, model, model_params, q_input, q_output, seed=42):
288 | device = torch.device(f"cuda:{rank}")
289 | model.to(device)
290 | generator = torch.Generator(device=device)
291 | generator.manual_seed(seed + rank)
292 | output_video_dir = Path("output_videos")
293 | output_video_dir.mkdir(parents=True, exist_ok=True)
294 | while True:
295 | inputs = q_input.get()
296 | if inputs is None: # check for sentinel value
297 | print(f"[{datetime.now()}] Process #{rank} ended.")
298 | break
299 | start_idx, prompts, prompts_img = inputs
300 | if prompts_img is not None:
301 |                 prompts_img = prompts_img.to(device)
302 | videos = model.val_pipeline(
303 | prompts, prompts_img, generator=generator, negative_prompt=["watermark"] * len(prompts),
304 | **model_params
305 | ).videos
306 | for idx, prompt in enumerate(prompts):
307 | gif_file = output_video_dir.joinpath(f"{start_idx + idx}_{prompt}.gif")
308 | save_videos_grid(videos[idx:idx + 1, ...], gif_file)
309 | print(f"[{datetime.now()}] Sample is saved #{start_idx + idx}: \"{prompt}\"")
310 | clip_scores = compute_clip_score(
311 | model=model.clip, model_processor=model.clip_processor,
312 | images=videos, texts=prompts, rescale=False,
313 | )
314 | q_output.put((prompts, clip_scores.cpu().tolist()))
315 | return None
316 |
317 | def __call__(self, model):
318 | model.eval()
319 |
320 | # load prompts images if exist
321 | if model.use_prompt_image and self.prompts_img is not None:
322 | self.prompts_img = torch.stack([
323 | load_image(x, model.resolution, True, True)
324 | for x in self.prompts_img
325 | ])
326 |
327 | if not torch.cuda.is_available():
328 |             raise NotImplementedError("No GPU found.")
329 |
330 | self.val_params.setdefault(
331 | "num_inference_steps", model.hparams.get("num_inference_steps", 50)
332 | )
333 | self.val_params.setdefault(
334 | "guidance_scale", model.hparams.get("guidance_scale", 7.5)
335 | )
336 | self.val_params.setdefault(
337 | "noise_alpha", model.hparams.get("noise_alpha", .0)
338 | )
339 | logger.info(f"val_params: {self.val_params}")
340 |
341 | q_input = torch.multiprocessing.Queue()
342 | q_output = torch.multiprocessing.Queue()
343 | processes = []
344 | for rank in range(torch.cuda.device_count()):
345 | p = torch.multiprocessing.Process(
346 | target=self.infer,
347 | args=(rank, model, self.val_params, q_input, q_output, self.seed)
348 | )
349 | p.start()
350 | processes.append(p)
351 | # send model inputs to queue
352 | result_num = 0
353 | for start_idx in range(0, len(self.prompts), self.batch_size):
354 | result_num += 1
355 | ref_images = self.prompts_img[start_idx:start_idx + self.batch_size] \
356 | if model.use_prompt_image and self.prompts_img is not None else None
357 | q_input.put((
358 | start_idx,
359 | self.prompts[start_idx:start_idx + self.batch_size],
360 | ref_images
361 | ))
362 | for _ in processes:
363 | q_input.put(None) # sentinel value to signal subprocesses to exit
364 | # The result queue has to be processed before joining the processes.
365 | results = [q_output.get() for _ in range(result_num)]
366 | # joining the processes
367 | for p in processes:
368 | p.join() # wait for all subprocesses to finish
369 | all_prompts, all_clip_scores = [], []
370 | for prompts, clip_scores in results:
371 | all_prompts.extend(prompts)
372 | all_clip_scores.extend(clip_scores)
373 | output_df = pd.DataFrame({
374 | "prompt": all_prompts, "clip_score": all_clip_scores
375 | })
376 | output_df.to_csv(self.output_file, index=False)
377 | logger.info(f"--- Metrics ---")
378 | logger.info(f"Mean CLIP_SCORE: {sum(all_clip_scores) / len(all_clip_scores)}")
379 | logger.info(f"Test results saved in: {self.output_file}")
380 |
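`SDVideoModelEvaluator` accepts `prompts` as a `.txt`, `.json`, or `.csv` path; in the CSV case the first column is the text prompt and an optional second column is a reference-image path resolved relative to the CSV's directory. A small sketch of writing such a file (file and column names are illustrative):

```python
import pandas as pd

# Column order matters: the evaluator reads prompts from the first column and,
# if present, image paths from the second (relative to the CSV's parent dir).
pd.DataFrame({
    "prompt": ["a corgi running on the beach", "timelapse of clouds over a city"],
    "image": ["refs/corgi.png", "refs/clouds.png"],
}).to_csv("prompts.csv", index=False)
```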
--------------------------------------------------------------------------------
/model/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gulucaptain/videoassembler/1fb50cd5f85aa5896b85003aff3641b00946eddc/model/modules/__init__.py
--------------------------------------------------------------------------------
/model/modules/attention_hasFFN.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2 | from dataclasses import dataclass
3 | from typing import Callable, Optional
4 |
5 | import torch
6 | from diffusers.configuration_utils import ConfigMixin, register_to_config
7 | from diffusers.models import ModelMixin
8 | from diffusers.models.attention import FeedForward, AdaLayerNorm
9 | from diffusers.models.cross_attention import CrossAttention
10 | from diffusers.utils import BaseOutput
11 | from diffusers.utils.import_utils import is_xformers_available
12 | from einops import rearrange, repeat
13 | from torch import nn
14 |
15 | from .modules import get_sin_pos_embedding
16 | from .utils import zero_module
17 |
18 | if is_xformers_available():
19 | import xformers
20 | import xformers.ops
21 | else:
22 | xformers = None
23 |
24 |
25 | class BasicTransformerBlock(nn.Module):
26 | r"""
27 | A basic Transformer block.
28 | Parameters:
29 | dim (`int`): The number of channels in the input and output.
30 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
31 | attention_head_dim (`int`): The number of channels in each head.
32 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
33 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
34 | only_cross_attention (`bool`, *optional*):
35 | Whether to use only cross-attention layers. In this case two cross attention layers are used.
36 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
37 |         num_embeds_ada_norm (`int`, *optional*):
38 |             The number of diffusion steps used during training. See `Transformer2DModel`.
39 |         attention_bias (`bool`, *optional*, defaults to `False`):
40 |             Configure if the attentions should contain a bias parameter.
41 | """
42 |
43 | def __init__(
44 | self,
45 | dim: int,
46 | num_attention_heads: int,
47 | attention_head_dim: int,
48 | dropout=0.0,
49 | cross_attention_dim: Optional[int] = None,
50 | activation_fn: str = "geglu",
51 | num_embeds_ada_norm: Optional[int] = None,
52 | attention_bias: bool = False,
53 | only_cross_attention: bool = False,
54 | upcast_attention: bool = False,
55 | norm_elementwise_affine: bool = True,
56 | final_dropout: bool = False,
57 | add_temp_attn: bool = False,
58 | prepend_first_frame: bool = False,
59 | add_temp_embed: bool = False,
60 | ):
61 | super().__init__()
62 | self.only_cross_attention = only_cross_attention
63 | self.use_ada_layer_norm = num_embeds_ada_norm is not None
64 |
65 | # temporal embedding
66 | self.add_temp_embed = add_temp_embed
67 |
68 | if add_temp_attn:
69 | if prepend_first_frame:
70 | # SC-Attn
71 | self.attn1 = SparseCausalAttention(
72 | query_dim=dim,
73 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
74 | heads=num_attention_heads,
75 | dim_head=attention_head_dim,
76 | dropout=dropout,
77 | bias=attention_bias,
78 | upcast_attention=upcast_attention,
79 | )
80 | else:
81 | # Normal CrossAttn
82 | self.attn1 = CrossAttention(
83 | query_dim=dim,
84 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
85 | heads=num_attention_heads,
86 | dim_head=attention_head_dim,
87 | dropout=dropout,
88 | bias=attention_bias,
89 | upcast_attention=upcast_attention,
90 | )
91 |
92 | # Temp-Attn
93 | self.temp_norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
94 | self.temp_attn = CrossAttention(
95 | query_dim=dim,
96 | heads=num_attention_heads,
97 | dim_head=attention_head_dim,
98 | dropout=dropout,
99 | bias=attention_bias,
100 | upcast_attention=upcast_attention,
101 | )
102 | self.temp_norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
103 | self.temp_ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
104 | if isinstance(self.temp_ff.net[-1], nn.Linear):
105 | self.temp_ff.net[-1] = zero_module(self.temp_ff.net[-1])
106 | elif isinstance(self.temp_ff.net[-2], nn.Linear):
107 | self.temp_ff.net[-2] = zero_module(self.temp_ff.net[-2])
108 | else:
109 | # Normal Attention
110 | self.attn1 = CrossAttention(
111 | query_dim=dim,
112 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
113 | heads=num_attention_heads,
114 | dim_head=attention_head_dim,
115 | dropout=dropout,
116 | bias=attention_bias,
117 | upcast_attention=upcast_attention,
118 | )
119 | self.temp_attn = None
120 |
121 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) \
122 | if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
123 |
124 | # Cross-Attn
125 | if cross_attention_dim is not None:
126 | self.attn2 = CrossAttention(
127 | query_dim=dim,
128 | cross_attention_dim=cross_attention_dim,
129 | heads=num_attention_heads,
130 | dim_head=attention_head_dim,
131 | dropout=dropout,
132 | bias=attention_bias,
133 | upcast_attention=upcast_attention,
134 | )
135 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
136 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
137 | # the second cross attention block.
138 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) \
139 | if self.use_ada_layer_norm else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
140 | else:
141 | self.attn2 = None
142 | self.norm2 = None
143 |
144 | # Feed-forward
145 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
146 | self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
147 |
148 | def set_use_memory_efficient_attention_xformers(
149 | self, use_memory_efficient_attention_xformers: bool,
150 | attention_op: Optional[Callable] = None
151 | ):
152 | if not is_xformers_available():
153 | print("Here is how to install it")
154 | raise ModuleNotFoundError(
155 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
156 | " xformers",
157 | name="xformers",
158 | )
159 | elif not torch.cuda.is_available():
160 | raise ValueError(
161 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
162 | " available for GPU "
163 | )
164 | else:
165 | try:
166 | # Make sure we can run the memory efficient attention
167 | xformers.ops.memory_efficient_attention(
168 | torch.randn((1, 2, 40), device="cuda"),
169 | torch.randn((1, 2, 40), device="cuda"),
170 | torch.randn((1, 2, 40), device="cuda"),
171 | )
172 | except Exception as e:
173 | raise e
174 | self.attn1.set_use_memory_efficient_attention_xformers(
175 | use_memory_efficient_attention_xformers, attention_op=attention_op)
176 | if self.attn2 is not None:
177 | self.attn2.set_use_memory_efficient_attention_xformers(
178 | use_memory_efficient_attention_xformers, attention_op=attention_op)
179 | if self.temp_attn is not None:
180 | self.temp_attn.set_use_memory_efficient_attention_xformers(
181 | use_memory_efficient_attention_xformers, attention_op=attention_op)
182 |
183 | def forward(
184 | self,
185 | hidden_states,
186 | encoder_hidden_states=None,
187 | timestep=None,
188 | attention_mask=None,
189 | video_length=None,
190 | ):
191 | # SparseCausal-Attention
192 | norm_hidden_states = (
193 | self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
194 | )
195 |
196 | attn1_args = dict(hidden_states=norm_hidden_states, attention_mask=attention_mask)
197 | if self.temp_attn is not None and isinstance(self.attn1, SparseCausalAttention):
198 | attn1_args.update({"video_length": video_length})
199 | # Self-/Sparse-Attention
200 | if self.only_cross_attention:
201 | hidden_states = self.attn1(
202 | **attn1_args, encoder_hidden_states=encoder_hidden_states
203 | ) + hidden_states
204 | else:
205 | hidden_states = self.attn1(**attn1_args) + hidden_states
206 |
207 | if self.attn2 is not None:
208 | # Cross-Attention
209 | norm_hidden_states = (
210 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
211 | )
212 | hidden_states = (
213 | self.attn2(
214 | norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
215 | )
216 | + hidden_states
217 | )
218 |
219 | if self.temp_attn is not None:
220 | identity = hidden_states
221 | d = hidden_states.shape[1]
222 | # add temporal embedding
223 | if self.add_temp_embed:
224 | temp_emb = get_sin_pos_embedding(
225 | hidden_states.shape[-1], video_length).to(hidden_states)
226 | hidden_states = rearrange(hidden_states, "(b f) d c -> b d f c", f=video_length)
227 | hidden_states += temp_emb
228 | hidden_states = rearrange(hidden_states, "b d f c -> (b f) d c")
229 | # normalization
230 | hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
231 | norm_hidden_states1 = (
232 | self.temp_norm1(hidden_states, timestep) if self.use_ada_layer_norm
233 | else self.temp_norm1(hidden_states)
234 | )
235 | # apply temporal attention
236 | hidden_states = self.temp_attn(norm_hidden_states1) + hidden_states
237 | # apply temporal feed-forward
238 | norm_hidden_states2 = (
239 | self.temp_norm2(hidden_states, timestep) if self.use_ada_layer_norm
240 | else self.temp_norm2(hidden_states)
241 | )
242 | hidden_states = self.temp_ff(norm_hidden_states2) + hidden_states
243 | # rearrange back
244 | hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
245 | # ignore effects of temporal layers on image inputs
246 | if video_length <= 1:
247 | hidden_states = identity + 0.0 * hidden_states
248 |
249 | # Feed-forward
250 | hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
251 |
252 | return hidden_states
253 |
254 |
255 | @dataclass
256 | class Transformer3DModelOutput(BaseOutput):
257 | sample: torch.FloatTensor
258 |
259 |
260 | class Transformer3DModel(ModelMixin, ConfigMixin):
261 | @register_to_config
262 | def __init__(
263 | self,
264 | num_attention_heads: int = 16,
265 | attention_head_dim: int = 88,
266 | in_channels: Optional[int] = None,
267 | num_layers: int = 1,
268 | dropout: float = 0.0,
269 | norm_num_groups: int = 32,
270 | cross_attention_dim: Optional[int] = None,
271 | attention_bias: bool = False,
272 | activation_fn: str = "geglu",
273 | num_embeds_ada_norm: Optional[int] = None,
274 | use_linear_projection: bool = False,
275 | only_cross_attention: bool = False,
276 | upcast_attention: bool = False,
277 | add_temp_attn: bool = False,
278 | prepend_first_frame: bool = False,
279 | add_temp_embed: bool = False
280 | ):
281 | super().__init__()
282 | self.use_linear_projection = use_linear_projection
283 | self.num_attention_heads = num_attention_heads
284 | self.attention_head_dim = attention_head_dim
285 | inner_dim = num_attention_heads * attention_head_dim
286 |
287 | # Define input layers
288 | self.in_channels = in_channels
289 |
290 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
291 | if use_linear_projection:
292 | self.proj_in = nn.Linear(in_channels, inner_dim)
293 | else:
294 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
295 |
296 | # Define transformers blocks
297 | self.transformer_blocks = nn.ModuleList(
298 | [
299 | BasicTransformerBlock(
300 | inner_dim,
301 | num_attention_heads,
302 | attention_head_dim,
303 | dropout=dropout,
304 | cross_attention_dim=cross_attention_dim,
305 | activation_fn=activation_fn,
306 | num_embeds_ada_norm=num_embeds_ada_norm,
307 | attention_bias=attention_bias,
308 | only_cross_attention=only_cross_attention,
309 | upcast_attention=upcast_attention,
310 | add_temp_attn=add_temp_attn,
311 | prepend_first_frame=prepend_first_frame,
312 | add_temp_embed=add_temp_embed
313 | )
314 | for _ in range(num_layers)
315 | ]
316 | )
317 |
318 |         # Define output layers
319 |         if use_linear_projection:
320 |             self.proj_out = nn.Linear(inner_dim, in_channels)
321 | else:
322 | self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
323 |
324 | def forward(self, hidden_states, encoder_hidden_states=None, timestep=None,
325 | return_dict=False):
326 | # Input
327 | assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
328 | video_length = hidden_states.shape[2]
329 | hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
330 | encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
331 |
332 |         batch, channel, height, width = hidden_states.shape
333 | residual = hidden_states
334 |
335 | hidden_states = self.norm(hidden_states)
336 | if not self.use_linear_projection:
337 | hidden_states = self.proj_in(hidden_states)
338 | inner_dim = hidden_states.shape[1]
339 |             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
340 | else:
341 | inner_dim = hidden_states.shape[1]
342 |             hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
343 | hidden_states = self.proj_in(hidden_states)
344 |
345 | # Blocks
346 | for block in self.transformer_blocks:
347 | hidden_states = block(
348 | hidden_states,
349 | encoder_hidden_states=encoder_hidden_states,
350 | timestep=timestep,
351 | video_length=video_length
352 | )
353 |
354 | # Output
355 | if not self.use_linear_projection:
356 | hidden_states = (
357 |                 hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
358 | )
359 | hidden_states = self.proj_out(hidden_states)
360 | else:
361 | hidden_states = self.proj_out(hidden_states)
362 | hidden_states = (
363 |                 hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
364 | )
365 |
366 | output = hidden_states + residual
367 |
368 | output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
369 | if not return_dict:
370 | return output,
371 |
372 | return Transformer3DModelOutput(sample=output)
373 |
374 |
375 | @dataclass
376 | class TransformerTemporalModelOutput(BaseOutput):
377 | """
378 | Args:
379 | sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`)
380 | Hidden states conditioned on `encoder_hidden_states` input.
381 | """
382 | sample: torch.FloatTensor
383 |
384 |
385 | class TransformerTemporalModel(ModelMixin, ConfigMixin):
386 | """
387 | Transformer model for video-like data.
388 |
389 | Parameters:
390 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
391 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
392 | in_channels (`int`, *optional*):
393 | Pass if the input is continuous. The number of channels in the input and output.
394 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
395 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
396 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
397 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
398 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See
399 | `ImagePositionalEmbeddings`.
400 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
401 | attention_bias (`bool`, *optional*):
402 | Configure if the TransformerBlocks' attention should contain a bias parameter.
403 | """
404 |
405 | @register_to_config
406 | def __init__(
407 | self,
408 | num_attention_heads: int = 16,
409 | attention_head_dim: int = 88,
410 | in_channels: Optional[int] = None,
411 | num_layers: int = 1,
412 | dropout: float = 0.0,
413 | norm_num_groups: int = 32,
414 | cross_attention_dim: Optional[int] = None,
415 | attention_bias: bool = False,
416 | activation_fn: str = "geglu",
417 | norm_elementwise_affine: bool = True,
418 | add_temp_embed: bool = False,
419 | ):
420 | super().__init__()
421 | self.num_attention_heads = num_attention_heads
422 | self.attention_head_dim = attention_head_dim
423 | inner_dim = num_attention_heads * attention_head_dim
424 |
425 | self.in_channels = in_channels
426 |
427 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
428 | self.proj_in = nn.Linear(in_channels, inner_dim)
429 |
430 | # 3. Define transformers blocks
431 | self.transformer_blocks = nn.ModuleList(
432 | [
433 | BasicTransformerBlock(
434 | inner_dim,
435 | num_attention_heads,
436 | attention_head_dim,
437 | dropout=dropout,
438 | cross_attention_dim=cross_attention_dim,
439 | activation_fn=activation_fn,
440 | attention_bias=attention_bias,
441 | norm_elementwise_affine=norm_elementwise_affine,
442 | add_temp_embed=add_temp_embed
443 | )
444 | for _ in range(num_layers)
445 | ]
446 | )
447 |
448 | self.proj_out = nn.Linear(inner_dim, in_channels)
449 | self.proj_out = zero_module(self.proj_out)
450 |
451 | def forward(
452 | self,
453 | hidden_states,
454 | encoder_hidden_states=None,
455 | timestep=None,
456 | return_dict: bool = True,
457 | ):
458 | """
459 | Args:
460 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
461 |                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
462 | hidden_states
463 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
464 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
465 | self-attention.
466 | timestep ( `torch.long`, *optional*):
467 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
468 | return_dict (`bool`, *optional*, defaults to `True`):
469 |                 Whether or not to return a [`~models.transformer_2d.TransformerTemporalModelOutput`] instead of a plain tuple.
470 |
471 | Returns:
472 | [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`:
473 | [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`.
474 | When returning a tuple, the first element is the sample tensor.
475 | """
476 | # 1. Input
477 | batch_size, channel, num_frames, height, width = hidden_states.shape
478 |
479 | residual = hidden_states
480 |
481 | hidden_states = self.norm(hidden_states)
482 | hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(
483 | batch_size * height * width, num_frames, channel)
484 | hidden_states = self.proj_in(hidden_states)
485 |
486 | # 2. Blocks
487 | for block in self.transformer_blocks:
488 | hidden_states = block(
489 | hidden_states,
490 | encoder_hidden_states=encoder_hidden_states,
491 | timestep=timestep,
492 | video_length=num_frames
493 | )
494 |
495 | # 3. Output
496 | hidden_states = self.proj_out(hidden_states)
497 | hidden_states = (
498 | hidden_states[None, None, :]
499 | .reshape(batch_size, height, width, channel, num_frames)
500 | .permute(0, 3, 4, 1, 2)
501 | .contiguous()
502 | )
503 | output = hidden_states + residual
504 |
505 | if not return_dict:
506 | return output,
507 |
508 | return TransformerTemporalModelOutput(sample=output)
509 |
510 |
511 | class SparseCausalAttention(CrossAttention):
512 | def forward(self, hidden_states, encoder_hidden_states=None,
513 | attention_mask=None, **cross_attention_kwargs):
514 | batch_size, sequence_length, _ = hidden_states.shape
515 | video_length = cross_attention_kwargs.get("video_length", 8)
516 | attention_mask = self.prepare_attention_mask(
517 | attention_mask, sequence_length, batch_size)
518 | query = self.to_q(hidden_states)
519 | dim = query.shape[-1]
520 |
521 | if self.added_kv_proj_dim is not None:
522 | raise NotImplementedError
523 |
524 | if encoder_hidden_states is None:
525 | encoder_hidden_states = hidden_states
526 | elif self.cross_attention_norm:
527 | encoder_hidden_states = self.norm_cross(encoder_hidden_states)
528 |
529 | key = self.to_k(encoder_hidden_states)
530 | value = self.to_v(encoder_hidden_states)
531 |
532 | former_frame_index = torch.arange(video_length) - 1
533 | former_frame_index[0] = 0
534 |
535 | key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
536 | if video_length > 1:
537 | key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
538 | key = rearrange(key, "b f d c -> (b f) d c")
539 |
540 | value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
541 | if video_length > 1:
542 | value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
543 | value = rearrange(value, "b f d c -> (b f) d c")
544 |
545 | query = self.head_to_batch_dim(query)
546 | key = self.head_to_batch_dim(key)
547 | value = self.head_to_batch_dim(value)
548 |
549 | # attention, what we cannot get enough of
550 | if hasattr(self.processor, "attention_op"):
551 | hidden_states = xformers.ops.memory_efficient_attention(
552 | query, key, value, attn_bias=attention_mask, op=self.processor.attention_op
553 | )
554 | hidden_states = hidden_states.to(query.dtype)
555 | elif hasattr(self.processor, "slice_size"):
556 | batch_size_attention = query.shape[0]
557 | hidden_states = torch.zeros(
558 | (batch_size_attention, sequence_length, dim // self.heads),
559 | device=query.device, dtype=query.dtype
560 | )
561 | for i in range(hidden_states.shape[0] // self.processor.slice_size):
562 |                 start_idx = i * self.processor.slice_size
563 |                 end_idx = (i + 1) * self.processor.slice_size
564 | query_slice = query[start_idx:end_idx]
565 | key_slice = key[start_idx:end_idx]
566 | attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
567 | attn_slice = self.get_attention_scores(query_slice, key_slice, attn_mask_slice)
568 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
569 | hidden_states[start_idx:end_idx] = attn_slice
570 | else:
571 | attention_probs = self.get_attention_scores(query, key, attention_mask)
572 | hidden_states = torch.bmm(attention_probs, value)
573 | hidden_states = self.batch_to_head_dim(hidden_states)
574 |
575 | # linear proj
576 | hidden_states = self.to_out[0](hidden_states)
577 |
578 | # dropout
579 | hidden_states = self.to_out[1](hidden_states)
580 | return hidden_states
581 |
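The sparse-causal indexing in `SparseCausalAttention` gives every frame keys/values built from the first frame plus the previous frame (frame 0 reuses itself). A standalone sketch of that gather on a toy tensor (all shapes are illustrative):

```python
import torch
from einops import rearrange

b, f, d, c = 2, 4, 3, 8                       # batch, frames, tokens, channels
key = torch.randn(b * f, d, c)                # frames folded into the batch, as in the block

former = torch.arange(f) - 1
former[0] = 0                                 # frame 0 attends to itself

key = rearrange(key, "(b f) d c -> b f d c", f=f)
key = torch.cat([key[:, [0] * f], key[:, former]], dim=2)  # first frame + previous frame
key = rearrange(key, "b f d c -> (b f) d c")
print(key.shape)                              # torch.Size([8, 6, 8]): tokens doubled per frame
```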
--------------------------------------------------------------------------------
/model/modules/modules.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 |
6 | def get_sin_pos_embedding(embed_dim, seq_len):
7 | """
8 | :param embed_dim: dimension of the model
9 | :param seq_len: length of positions
10 | :return: [length, embed_dim] position matrix
11 | """
12 | if embed_dim % 2 != 0:
13 | raise ValueError("Cannot use sin/cos positional encoding with "
14 | "odd dim (got dim={:d})".format(embed_dim))
15 | pe = torch.zeros(seq_len, embed_dim)
16 | position = torch.arange(0, seq_len).unsqueeze(1)
17 | div_term = torch.exp(torch.arange(0, embed_dim, 2, dtype=torch.float) *
18 | -(math.log(10000.0) / embed_dim))
19 | pe[:, 0::2] = torch.sin(position.float() * div_term)
20 | pe[:, 1::2] = torch.cos(position.float() * div_term)
21 |
22 | return pe
23 |
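A short sketch of how `get_sin_pos_embedding` is consumed: `BasicTransformerBlock` rearranges hidden states to `(b, d, f, c)` and adds one sinusoidal row per frame position (the sizes here are illustrative):

```python
import torch
# assumes: from model.modules.modules import get_sin_pos_embedding

pos = get_sin_pos_embedding(embed_dim=320, seq_len=16)  # -> torch.Size([16, 320])

hidden = torch.randn(2, 77, 16, 320)   # (batch, tokens, frames, channels) after rearrange
hidden = hidden + pos                  # broadcasts over batch and tokens, one row per frame
```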
--------------------------------------------------------------------------------
/model/modules/resnet.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from einops import rearrange
8 |
9 |
10 | class InflatedConv3d(nn.Conv2d):
11 | def forward(self, x):
12 | video_length = x.shape[2]
13 |
14 | x = rearrange(x, "b c f h w -> (b f) c h w")
15 | x = super().forward(x)
16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17 | return x
18 |
19 |
20 | class Upsample3D(nn.Module):
21 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
22 | super().__init__()
23 | self.channels = channels
24 | self.out_channels = out_channels or channels
25 | self.use_conv = use_conv
26 | self.use_conv_transpose = use_conv_transpose
27 | self.name = name
28 |
29 | conv = None
30 | if use_conv_transpose:
31 | raise NotImplementedError
32 | elif use_conv:
33 | conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
34 |
35 | if name == "conv":
36 | self.conv = conv
37 | else:
38 | self.Conv2d_0 = conv
39 |
40 | def forward(self, hidden_states, output_size=None):
41 | assert hidden_states.shape[1] == self.channels
42 |
43 | if self.use_conv_transpose:
44 | raise NotImplementedError
45 |
46 |         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
47 | dtype = hidden_states.dtype
48 | if dtype == torch.bfloat16:
49 | hidden_states = hidden_states.to(torch.float32)
50 |
51 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
52 | if hidden_states.shape[0] >= 64:
53 | hidden_states = hidden_states.contiguous()
54 |
55 | # if `output_size` is passed we force the interpolation output
56 | # size and do not make use of `scale_factor=2`
57 | if output_size is None:
58 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
59 | else:
60 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
61 |
62 | # If the input is bfloat16, we cast back to bfloat16
63 | if dtype == torch.bfloat16:
64 | hidden_states = hidden_states.to(dtype)
65 |
66 | if self.use_conv:
67 | if self.name == "conv":
68 | hidden_states = self.conv(hidden_states)
69 | else:
70 | hidden_states = self.Conv2d_0(hidden_states)
71 |
72 | return hidden_states
73 |
74 |
75 | class Downsample3D(nn.Module):
76 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
77 | super().__init__()
78 | self.channels = channels
79 | self.out_channels = out_channels or channels
80 | self.use_conv = use_conv
81 | self.padding = padding
82 | stride = 2
83 | self.name = name
84 |
85 | if use_conv:
86 | conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
87 | else:
88 | raise NotImplementedError
89 |
90 | if name == "conv":
91 | self.Conv2d_0 = conv
92 | self.conv = conv
93 | elif name == "Conv2d_0":
94 | self.conv = conv
95 | else:
96 | self.conv = conv
97 |
98 | def forward(self, hidden_states):
99 | assert hidden_states.shape[1] == self.channels
100 | if self.use_conv and self.padding == 0:
101 | raise NotImplementedError
102 |
103 | assert hidden_states.shape[1] == self.channels
104 | hidden_states = self.conv(hidden_states)
105 |
106 | return hidden_states
107 |
108 |
109 | class ResnetBlock3D(nn.Module):
110 | def __init__(
111 | self,
112 | *,
113 | in_channels,
114 | out_channels=None,
115 | conv_shortcut=False,
116 | dropout=0.0,
117 | temb_channels=512,
118 | groups=32,
119 | groups_out=None,
120 | pre_norm=True,
121 | eps=1e-6,
122 | non_linearity="swish",
123 | time_embedding_norm="default",
124 | output_scale_factor=1.0,
125 | use_in_shortcut=None,
126 | ):
127 | super().__init__()
128 | self.pre_norm = pre_norm
129 | self.pre_norm = True
130 | self.in_channels = in_channels
131 | out_channels = in_channels if out_channels is None else out_channels
132 | self.out_channels = out_channels
133 | self.use_conv_shortcut = conv_shortcut
134 | self.time_embedding_norm = time_embedding_norm
135 | self.output_scale_factor = output_scale_factor
136 |
137 | if groups_out is None:
138 | groups_out = groups
139 |
140 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
141 |
142 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
143 |
144 | if temb_channels is not None:
145 | if self.time_embedding_norm == "default":
146 | time_emb_proj_out_channels = out_channels
147 | elif self.time_embedding_norm == "scale_shift":
148 | time_emb_proj_out_channels = out_channels * 2
149 | else:
150 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
151 |
152 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
153 | else:
154 | self.time_emb_proj = None
155 |
156 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
157 | self.dropout = torch.nn.Dropout(dropout)
158 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
159 |
160 | if non_linearity == "swish":
161 | self.nonlinearity = lambda x: F.silu(x)
162 | elif non_linearity == "mish":
163 | self.nonlinearity = Mish()
164 | elif non_linearity == "silu":
165 | self.nonlinearity = nn.SiLU()
166 |
167 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
168 |
169 | self.conv_shortcut = None
170 | if self.use_in_shortcut:
171 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
172 |
173 | def forward(self, input_tensor, temb):
174 | hidden_states = input_tensor
175 |
176 | hidden_states = self.norm1(hidden_states)
177 | hidden_states = self.nonlinearity(hidden_states)
178 |
179 | hidden_states = self.conv1(hidden_states)
180 |
181 | if temb is not None:
182 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
183 |
184 | if temb is not None and self.time_embedding_norm == "default":
185 | hidden_states = hidden_states + temb
186 |
187 | hidden_states = self.norm2(hidden_states)
188 |
189 | if temb is not None and self.time_embedding_norm == "scale_shift":
190 | scale, shift = torch.chunk(temb, 2, dim=1)
191 | hidden_states = hidden_states * (1 + scale) + shift
192 |
193 | hidden_states = self.nonlinearity(hidden_states)
194 |
195 | hidden_states = self.dropout(hidden_states)
196 | hidden_states = self.conv2(hidden_states)
197 |
198 | if self.conv_shortcut is not None:
199 | input_tensor = self.conv_shortcut(input_tensor)
200 |
201 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
202 |
203 | return output_tensor
204 |
205 |
206 | class Mish(torch.nn.Module):
207 | def forward(self, hidden_states):
208 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
209 |
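`InflatedConv3d` is a plain `nn.Conv2d` applied frame-by-frame: the temporal axis is folded into the batch, convolved, and unfolded again. A quick sketch on SD-sized latents (shapes are illustrative):

```python
import torch
# assumes: from model.modules.resnet import InflatedConv3d

conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)
latents = torch.randn(2, 4, 16, 64, 64)   # (batch, channels, frames, height, width)
out = conv(latents)
print(out.shape)                           # torch.Size([2, 320, 16, 64, 64])
```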
--------------------------------------------------------------------------------
/model/modules/uent.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2 |
3 | import json
4 | import os
5 | from dataclasses import dataclass
6 | from typing import List, Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.utils.checkpoint
11 | from diffusers.configuration_utils import ConfigMixin, register_to_config
12 | from diffusers.models import ModelMixin
13 | from diffusers.models.embeddings import TimestepEmbedding, Timesteps
14 | from diffusers.utils import BaseOutput, logging
15 |
16 | from .attention import TransformerTemporalModel
17 | from .resnet import InflatedConv3d
18 | from .unet_blocks import (
19 | CrossAttnDownBlock3D,
20 | CrossAttnUpBlock3D,
21 | DownBlock3D,
22 | UNetMidBlock3DCrossAttn,
23 | UpBlock3D,
24 | get_down_block,
25 | get_up_block,
26 | )
27 |
28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29 |
30 |
31 | @dataclass
32 | class UNet3DConditionOutput(BaseOutput):
33 | sample: torch.FloatTensor
34 |
35 |
36 | class UNet3DConditionModel(ModelMixin, ConfigMixin):
37 | _supports_gradient_checkpointing = True
38 |
39 | @register_to_config
40 | def __init__(
41 | self,
42 | sample_size: Optional[int] = None,
43 | in_channels: int = 4,
44 | out_channels: int = 4,
45 | center_input_sample: bool = False,
46 | flip_sin_to_cos: bool = True,
47 | freq_shift: int = 0,
48 | down_block_types: Tuple[str] = (
49 | "CrossAttnDownBlock3D",
50 | "CrossAttnDownBlock3D",
51 | "CrossAttnDownBlock3D",
52 | "DownBlock3D",
53 | ),
54 | mid_block_type: str = "UNetMidBlock3DCrossAttn",
55 | up_block_types: Tuple[str] = (
56 | "UpBlock3D",
57 | "CrossAttnUpBlock3D",
58 | "CrossAttnUpBlock3D",
59 | "CrossAttnUpBlock3D",
60 | ),
61 | only_cross_attention: Union[bool, Tuple[bool]] = False,
62 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
63 | layers_per_block: int = 2,
64 | downsample_padding: int = 1,
65 | mid_block_scale_factor: float = 1,
66 | act_fn: str = "silu",
67 | norm_num_groups: int = 32,
68 | norm_eps: float = 1e-5,
69 | cross_attention_dim: int = 1280,
70 | attention_head_dim: Union[int, Tuple[int]] = 8,
71 | dual_cross_attention: bool = False,
72 | use_linear_projection: bool = False,
73 | class_embed_type: Optional[str] = None,
74 | num_class_embeds: Optional[int] = None,
75 | upcast_attention: bool = False,
76 | resnet_time_scale_shift: str = "default",
77 | add_temp_transformer: bool = False,
78 | add_temp_attn_only_on_upblocks: bool = False,
79 | prepend_first_frame: bool = False,
80 | add_temp_embed: bool = False,
81 | add_temp_ff: bool = False,
82 | add_temp_conv: bool = False,
83 | ):
84 | super().__init__()
85 |
86 | self.sample_size = sample_size
87 | time_embed_dim = block_out_channels[0] * 4
88 |
89 | # input
90 | self.conv_in = InflatedConv3d(
91 | in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
92 | )
93 |
94 | self.temp_transformer = TransformerTemporalModel(
95 | num_attention_heads=8, attention_head_dim=64,
96 | in_channels=block_out_channels[0],
97 | num_layers=1, add_temp_embed=add_temp_embed,
98 | add_temp_ff=add_temp_ff
99 | ) if add_temp_transformer else None
100 |
101 | # time
102 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
103 | timestep_input_dim = block_out_channels[0]
104 |
105 | self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
106 |
107 | # class embedding
108 | if class_embed_type is None and num_class_embeds is not None:
109 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
110 | elif class_embed_type == "timestep":
111 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
112 | elif class_embed_type == "identity":
113 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
114 | else:
115 | self.class_embedding = None
116 |
117 | self.down_blocks = nn.ModuleList([])
118 | self.mid_block = None
119 | self.up_blocks = nn.ModuleList([])
120 |
121 | if isinstance(only_cross_attention, bool):
122 | only_cross_attention = [only_cross_attention] * len(down_block_types)
123 |
124 | if isinstance(attention_head_dim, int):
125 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
126 |
127 | # down
128 | output_channel = block_out_channels[0]
129 | for i, down_block_type in enumerate(down_block_types):
130 | input_channel = output_channel
131 | output_channel = block_out_channels[i]
132 | is_final_block = i == len(block_out_channels) - 1
133 |
134 | down_block = get_down_block(
135 | down_block_type,
136 | num_layers=layers_per_block,
137 | in_channels=input_channel,
138 | out_channels=output_channel,
139 | temb_channels=time_embed_dim,
140 | add_downsample=not is_final_block,
141 | resnet_eps=norm_eps,
142 | resnet_act_fn=act_fn,
143 | resnet_groups=norm_num_groups,
144 | cross_attention_dim=cross_attention_dim,
145 | attn_num_head_channels=attention_head_dim[i],
146 | downsample_padding=downsample_padding,
147 | dual_cross_attention=dual_cross_attention,
148 | use_linear_projection=use_linear_projection,
149 | only_cross_attention=only_cross_attention[i],
150 | upcast_attention=upcast_attention,
151 | resnet_time_scale_shift=resnet_time_scale_shift,
152 | add_temp_attn=not add_temp_attn_only_on_upblocks,
153 | prepend_first_frame=prepend_first_frame,
154 | add_temp_embed=add_temp_embed,
155 | add_temp_ff=add_temp_ff,
156 | add_temp_conv=add_temp_conv
157 | )
158 | self.down_blocks.append(down_block)
159 |
160 | # mid
161 | if mid_block_type == "UNetMidBlock3DCrossAttn":
162 | self.mid_block = UNetMidBlock3DCrossAttn(
163 | in_channels=block_out_channels[-1],
164 | temb_channels=time_embed_dim,
165 | resnet_eps=norm_eps,
166 | resnet_act_fn=act_fn,
167 | output_scale_factor=mid_block_scale_factor,
168 | resnet_time_scale_shift=resnet_time_scale_shift,
169 | cross_attention_dim=cross_attention_dim,
170 | attn_num_head_channels=attention_head_dim[-1],
171 | resnet_groups=norm_num_groups,
172 | dual_cross_attention=dual_cross_attention,
173 | use_linear_projection=use_linear_projection,
174 | upcast_attention=upcast_attention,
175 | add_temp_attn=not add_temp_attn_only_on_upblocks,
176 | prepend_first_frame=prepend_first_frame,
177 | add_temp_embed=add_temp_embed,
178 | add_temp_ff=add_temp_ff,
179 | add_temp_conv=add_temp_conv,
180 | )
181 | else:
182 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
183 |
184 | # count how many layers upsample the videos
185 | self.num_upsamplers = 0
186 |
187 | # up
188 | reversed_block_out_channels = list(reversed(block_out_channels))
189 | reversed_attention_head_dim = list(reversed(attention_head_dim))
190 | only_cross_attention = list(reversed(only_cross_attention))
191 | output_channel = reversed_block_out_channels[0]
192 | for i, up_block_type in enumerate(up_block_types):
193 | is_final_block = i == len(block_out_channels) - 1
194 |
195 | prev_output_channel = output_channel
196 | output_channel = reversed_block_out_channels[i]
197 | input_channel = reversed_block_out_channels[
198 | min(i + 1, len(block_out_channels) - 1)
199 | ]
200 |
201 | # add upsample block for all BUT final layer
202 | if not is_final_block:
203 | add_upsample = True
204 | self.num_upsamplers += 1
205 | else:
206 | add_upsample = False
207 |
208 | up_block = get_up_block(
209 | up_block_type,
210 | num_layers=layers_per_block + 1,
211 | in_channels=input_channel,
212 | out_channels=output_channel,
213 | prev_output_channel=prev_output_channel,
214 | temb_channels=time_embed_dim,
215 | add_upsample=add_upsample,
216 | resnet_eps=norm_eps,
217 | resnet_act_fn=act_fn,
218 | resnet_groups=norm_num_groups,
219 | cross_attention_dim=cross_attention_dim,
220 | attn_num_head_channels=reversed_attention_head_dim[i],
221 | dual_cross_attention=dual_cross_attention,
222 | use_linear_projection=use_linear_projection,
223 | only_cross_attention=only_cross_attention[i],
224 | upcast_attention=upcast_attention,
225 | resnet_time_scale_shift=resnet_time_scale_shift,
226 | add_temp_attn=True,
227 | prepend_first_frame=prepend_first_frame,
228 | add_temp_embed=add_temp_embed,
229 | add_temp_ff=add_temp_ff,
230 | add_temp_conv=add_temp_conv
231 | )
232 | self.up_blocks.append(up_block)
233 | prev_output_channel = output_channel
234 |
235 | # out
236 | self.conv_norm_out = nn.GroupNorm(
237 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
238 | )
239 | self.conv_act = nn.SiLU()
240 | self.conv_out = InflatedConv3d(
241 | block_out_channels[0], out_channels, kernel_size=3, padding=1
242 | )
243 |
244 | def set_attention_slice(self, slice_size):
245 | r"""
246 | Enable sliced attention computation.
247 |
248 |         When this option is enabled, the attention module splits the input tensor into slices and computes attention
249 |         in several steps. This is useful for saving some memory in exchange for a small decrease in speed.
250 |
251 | Args:
252 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
253 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
254 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
255 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
256 | must be a multiple of `slice_size`.
257 | """
258 | sliceable_head_dims = []
259 |
260 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
261 | if hasattr(module, "set_attention_slice"):
262 | sliceable_head_dims.append(module.sliceable_head_dim)
263 |
264 | for child in module.children():
265 | fn_recursive_retrieve_slicable_dims(child)
266 |
267 | # retrieve number of attention layers
268 | for module in self.children():
269 | fn_recursive_retrieve_slicable_dims(module)
270 |
271 | num_slicable_layers = len(sliceable_head_dims)
272 |
273 | if slice_size == "auto":
274 | # half the attention head size is usually a good trade-off between
275 | # speed and memory
276 | slice_size = [dim // 2 for dim in sliceable_head_dims]
277 | elif slice_size == "max":
278 | # make smallest slice possible
279 | slice_size = num_slicable_layers * [1]
280 |
281 | slice_size = (
282 | num_slicable_layers * [slice_size]
283 | if not isinstance(slice_size, list)
284 | else slice_size
285 | )
286 |
287 | if len(slice_size) != len(sliceable_head_dims):
288 | raise ValueError(
289 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
290 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
291 | )
292 |
293 | for i in range(len(slice_size)):
294 | size = slice_size[i]
295 | dim = sliceable_head_dims[i]
296 | if size is not None and size > dim:
297 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
298 |
299 |         # Recursively walk through all the children.
300 |         # Any child that exposes the set_attention_slice method
301 |         # receives the new slice size.
302 | def fn_recursive_set_attention_slice(
303 | module: torch.nn.Module, slice_size: List[int]
304 | ):
305 | if hasattr(module, "set_attention_slice"):
306 | module.set_attention_slice(slice_size.pop())
307 |
308 | for child in module.children():
309 | fn_recursive_set_attention_slice(child, slice_size)
310 |
311 | reversed_slice_size = list(reversed(slice_size))
312 | for module in self.children():
313 | fn_recursive_set_attention_slice(module, reversed_slice_size)
314 |
315 | def _set_gradient_checkpointing(self, module, value=False):
316 | if isinstance(
317 | module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)
318 | ):
319 | module.gradient_checkpointing = value
320 |
321 | def forward(
322 | self,
323 | sample: torch.FloatTensor,
324 | timestep: Union[torch.Tensor, float, int],
325 | encoder_hidden_states: torch.Tensor,
326 | class_labels: Optional[torch.Tensor] = None,
327 | attention_mask: Optional[torch.Tensor] = None,
328 | return_dict: bool = True,
329 | ) -> Union[UNet3DConditionOutput, Tuple]:
330 | r"""
331 | Args:
332 |             sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
333 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
334 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
335 | return_dict (`bool`, *optional*, defaults to `True`):
336 |                 Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
337 |
338 | Returns:
339 |             [`UNet3DConditionOutput`] or `tuple`:
340 |             [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
341 | returning a tuple, the first element is the sample tensor.
342 | """
343 |         # By default, samples have to be at least a multiple of the overall upsampling factor.
344 |         # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
345 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
346 | # on the fly if necessary.
347 | default_overall_up_factor = 2 ** self.num_upsamplers
348 |
349 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
350 | forward_upsample_size = False
351 | upsample_size = None
352 |
353 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
354 | logger.info("Forward upsample size to force interpolation output size.")
355 | forward_upsample_size = True
356 |
357 | # prepare attention_mask
358 | if attention_mask is not None:
359 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
360 | attention_mask = attention_mask.unsqueeze(1)
361 |
362 | # center input if necessary
363 | if self.config.center_input_sample:
364 | sample = 2 * sample - 1.0
365 |
366 | # time
367 | timesteps = timestep
368 | if not torch.is_tensor(timesteps):
369 | # This would be a good case for the `match` statement (Python 3.10+)
370 | is_mps = sample.device.type == "mps"
371 | if isinstance(timestep, float):
372 | dtype = torch.float32 if is_mps else torch.float64
373 | else:
374 | dtype = torch.int32 if is_mps else torch.int64
375 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
376 | elif len(timesteps.shape) == 0:
377 | timesteps = timesteps[None].to(sample.device)
378 |
379 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
380 | timesteps = timesteps.expand(sample.shape[0])
381 |
382 | t_emb = self.time_proj(timesteps)
383 |
384 | # timesteps does not contain any weights and will always return f32 tensors
385 | # but time_embedding might actually be running in fp16. so we need to cast here.
386 | # there might be better ways to encapsulate this.
387 | t_emb = t_emb.to(dtype=self.dtype)
388 | emb = self.time_embedding(t_emb)
389 |
390 | if self.class_embedding is not None:
391 | if class_labels is None:
392 | raise ValueError(
393 | "class_labels should be provided when num_class_embeds > 0"
394 | )
395 |
396 | if self.config.class_embed_type == "timestep":
397 | class_labels = self.time_proj(class_labels)
398 |
399 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
400 | emb = emb + class_emb
401 |
402 | # fp16: cast to model dtype
403 | sample = sample.to(self.dtype)
404 | encoder_hidden_states = encoder_hidden_states.to(self.dtype)
405 |
406 | # pre-process
407 | sample = self.conv_in(sample)
408 | if self.temp_transformer is not None:
409 | sample_new = self.temp_transformer(sample).sample
410 |             sample = sample_new if sample.shape[2] > 1 else sample + 0.0 * sample_new  # single-frame inputs bypass the temporal branch; the zero-scaled term keeps its parameters in the autograd graph
411 |
412 | # down
413 | down_block_res_samples = (sample,)
414 | for downsample_block in self.down_blocks:
415 | if (
416 | hasattr(downsample_block, "has_cross_attention")
417 | and downsample_block.has_cross_attention
418 | ):
419 | sample, res_samples = downsample_block(
420 | hidden_states=sample,
421 | temb=emb,
422 | encoder_hidden_states=encoder_hidden_states,
423 | attention_mask=attention_mask,
424 | )
425 | else:
426 | sample, res_samples = downsample_block(
427 | hidden_states=sample, temb=emb
428 | )
429 | down_block_res_samples += res_samples
430 |
431 | # mid
432 | sample = self.mid_block(
433 | sample,
434 | emb,
435 | encoder_hidden_states=encoder_hidden_states,
436 | attention_mask=attention_mask,
437 | )
438 |
439 | # up
440 | for i, upsample_block in enumerate(self.up_blocks):
441 | is_final_block = i == len(self.up_blocks) - 1
442 |
443 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
444 | down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]
445 |
446 | # if we have not reached the final block and need to forward the
447 | # upsample size, we do it here
448 | if not is_final_block and forward_upsample_size:
449 | upsample_size = down_block_res_samples[-1].shape[2:]
450 |
451 | if (
452 | hasattr(upsample_block, "has_cross_attention")
453 | and upsample_block.has_cross_attention
454 | ):
455 | sample = upsample_block(
456 | hidden_states=sample,
457 | temb=emb,
458 | res_hidden_states_tuple=res_samples,
459 | encoder_hidden_states=encoder_hidden_states,
460 | upsample_size=upsample_size,
461 | attention_mask=attention_mask,
462 | )
463 | else:
464 | sample = upsample_block(
465 | hidden_states=sample,
466 | temb=emb,
467 | res_hidden_states_tuple=res_samples,
468 | upsample_size=upsample_size,
469 | )
470 | # post-process
471 | sample = self.conv_norm_out(sample)
472 | sample = self.conv_act(sample)
473 | sample = self.conv_out(sample)
474 |
475 | if not return_dict:
476 |             return (sample,)
477 |
478 | return UNet3DConditionOutput(sample=sample)
479 |
480 | @classmethod
481 | def from_pretrained_2d(
482 | cls, pretrained_model_path, subfolder=None, **kwargs):
483 | if subfolder is not None:
484 | pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
485 |
486 | config_file = os.path.join(pretrained_model_path, "config.json")
487 | if not os.path.isfile(config_file):
488 | raise RuntimeError(f"{config_file} does not exist")
489 | with open(config_file, "r") as f:
490 | config = json.load(f)
491 | config["_class_name"] = cls.__name__
492 | config["down_block_types"] = [
493 | "CrossAttnDownBlock3D",
494 | "CrossAttnDownBlock3D",
495 | "CrossAttnDownBlock3D",
496 | "DownBlock3D",
497 | ]
498 | config["up_block_types"] = [
499 | "UpBlock3D",
500 | "CrossAttnUpBlock3D",
501 | "CrossAttnUpBlock3D",
502 | "CrossAttnUpBlock3D",
503 | ]
504 | if "sample_size" in kwargs and "sample_size" in config:
505 | config["sample_size"] = kwargs.get("sample_size")
506 | if "in_channels" in kwargs and "in_channels" in config:
507 | config["in_channels"] = kwargs.get("in_channels")
508 |
509 | from diffusers.utils import WEIGHTS_NAME
510 |
511 | model = cls.from_config(
512 | config, add_temp_transformer=kwargs.get("add_temp_transformer", False),
513 | add_temp_attn_only_on_upblocks=kwargs.get("add_temp_attn_only_on_upblocks", False),
514 | prepend_first_frame=kwargs.get("prepend_first_frame", False),
515 | add_temp_embed=kwargs.get("add_temp_embed", False),
516 | add_temp_ff=kwargs.get("add_temp_ff", False),
517 | add_temp_conv=kwargs.get("add_temp_conv", False),
518 | num_class_embeds=kwargs.get("num_class_embeds", None)
519 | )
520 | model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
521 | if not os.path.isfile(model_file):
522 | raise RuntimeError(f"{model_file} does not exist")
523 | state_dict = torch.load(model_file, map_location="cpu")
524 | for k, v in model.state_dict().items():
525 | if ("temp_" in k or "class_embedding" in k) and k not in state_dict:
526 |                 state_dict.update({k: v})  # keep freshly initialized temporal / class-embedding weights absent from the 2D checkpoint
527 | if "conv_in" in k and v.shape != state_dict[k].shape:
528 |                 state_dict.update({k: v})  # keep the model's own conv_in weights when in_channels differs from the 2D checkpoint
529 | model.load_state_dict(state_dict, strict=False)
530 |
531 | return model
532 |
--------------------------------------------------------------------------------
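A minimal usage sketch of the `from_pretrained_2d` classmethod defined above: it loads the 2D UNet config, swaps the block types for their 3D counterparts, builds the inflated model, and copies matching weights while keeping newly added temporal modules at their initialization. The class name `UNet3DConditionModel` and the import path are assumptions inferred from the repository layout and the `UNet3DConditionOutput` return type; the checkpoint path is illustrative.

```python
# Hypothetical sketch -- class name and paths are assumptions, not confirmed by this excerpt.
from model.modules.uent import UNet3DConditionModel  # note: the module file is named `uent.py` in this repository

unet = UNet3DConditionModel.from_pretrained_2d(
    "assets/stable-diffusion-2-1-base",    # local SD 2.1 checkpoint directory (illustrative path)
    subfolder="unet",                      # config.json and the weights file are read from this subfolder
    add_temp_conv=True,                    # insert zero-initialized TemporalConvLayer branches
    add_temp_attn_only_on_upblocks=False,  # temporal attention in down, mid, and up blocks
)
```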
/model/modules/unet_blocks.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from .attention import Transformer3DModel
7 | from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8 | from .utils import checkpoint, zero_module
9 |
10 |
11 | def get_down_block(
12 | down_block_type,
13 | num_layers,
14 | in_channels,
15 | out_channels,
16 | temb_channels,
17 | add_downsample,
18 | resnet_eps,
19 | resnet_act_fn,
20 | attn_num_head_channels,
21 | resnet_groups=None,
22 | cross_attention_dim=None,
23 | downsample_padding=None,
24 | dual_cross_attention=False,
25 | use_linear_projection=False,
26 | only_cross_attention=False,
27 | upcast_attention=False,
28 | resnet_time_scale_shift="default",
29 | add_temp_attn=False,
30 | prepend_first_frame=False,
31 | add_temp_embed=False,
32 | add_temp_ff=False,
33 | add_temp_conv=False,
34 | ):
35 | down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
36 | if down_block_type == "DownBlock3D":
37 | return DownBlock3D(
38 | num_layers=num_layers,
39 | in_channels=in_channels,
40 | out_channels=out_channels,
41 | temb_channels=temb_channels,
42 | add_downsample=add_downsample,
43 | resnet_eps=resnet_eps,
44 | resnet_act_fn=resnet_act_fn,
45 | resnet_groups=resnet_groups,
46 | downsample_padding=downsample_padding,
47 | resnet_time_scale_shift=resnet_time_scale_shift,
48 | add_temp_conv=add_temp_conv
49 | )
50 | elif down_block_type == "CrossAttnDownBlock3D":
51 | if cross_attention_dim is None:
52 | raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
53 | return CrossAttnDownBlock3D(
54 | num_layers=num_layers,
55 | in_channels=in_channels,
56 | out_channels=out_channels,
57 | temb_channels=temb_channels,
58 | add_downsample=add_downsample,
59 | resnet_eps=resnet_eps,
60 | resnet_act_fn=resnet_act_fn,
61 | resnet_groups=resnet_groups,
62 | downsample_padding=downsample_padding,
63 | cross_attention_dim=cross_attention_dim,
64 | attn_num_head_channels=attn_num_head_channels,
65 | dual_cross_attention=dual_cross_attention,
66 | use_linear_projection=use_linear_projection,
67 | only_cross_attention=only_cross_attention,
68 | upcast_attention=upcast_attention,
69 | resnet_time_scale_shift=resnet_time_scale_shift,
70 | add_temp_attn=add_temp_attn,
71 | prepend_first_frame=prepend_first_frame,
72 | add_temp_embed=add_temp_embed,
73 | add_temp_ff=add_temp_ff,
74 | add_temp_conv=add_temp_conv
75 | )
76 | raise ValueError(f"{down_block_type} does not exist.")
77 |
78 |
79 | def get_up_block(
80 | up_block_type,
81 | num_layers,
82 | in_channels,
83 | out_channels,
84 | prev_output_channel,
85 | temb_channels,
86 | add_upsample,
87 | resnet_eps,
88 | resnet_act_fn,
89 | attn_num_head_channels,
90 | resnet_groups=None,
91 | cross_attention_dim=None,
92 | dual_cross_attention=False,
93 | use_linear_projection=False,
94 | only_cross_attention=False,
95 | upcast_attention=False,
96 | resnet_time_scale_shift="default",
97 | add_temp_attn=False,
98 | prepend_first_frame=False,
99 | add_temp_embed=False,
100 | add_temp_ff=False,
101 | add_temp_conv=False,
102 | ):
103 | up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
104 | if up_block_type == "UpBlock3D":
105 | return UpBlock3D(
106 | num_layers=num_layers,
107 | in_channels=in_channels,
108 | out_channels=out_channels,
109 | prev_output_channel=prev_output_channel,
110 | temb_channels=temb_channels,
111 | add_upsample=add_upsample,
112 | resnet_eps=resnet_eps,
113 | resnet_act_fn=resnet_act_fn,
114 | resnet_groups=resnet_groups,
115 | resnet_time_scale_shift=resnet_time_scale_shift,
116 | add_temp_conv=add_temp_conv
117 | )
118 | elif up_block_type == "CrossAttnUpBlock3D":
119 | if cross_attention_dim is None:
120 | raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
121 | return CrossAttnUpBlock3D(
122 | num_layers=num_layers,
123 | in_channels=in_channels,
124 | out_channels=out_channels,
125 | prev_output_channel=prev_output_channel,
126 | temb_channels=temb_channels,
127 | add_upsample=add_upsample,
128 | resnet_eps=resnet_eps,
129 | resnet_act_fn=resnet_act_fn,
130 | resnet_groups=resnet_groups,
131 | cross_attention_dim=cross_attention_dim,
132 | attn_num_head_channels=attn_num_head_channels,
133 | dual_cross_attention=dual_cross_attention,
134 | use_linear_projection=use_linear_projection,
135 | only_cross_attention=only_cross_attention,
136 | upcast_attention=upcast_attention,
137 | resnet_time_scale_shift=resnet_time_scale_shift,
138 | add_temp_attn=add_temp_attn,
139 | prepend_first_frame=prepend_first_frame,
140 | add_temp_embed=add_temp_embed,
141 | add_temp_ff=add_temp_ff,
142 | add_temp_conv=add_temp_conv
143 | )
144 | raise ValueError(f"{up_block_type} does not exist.")
145 |
146 |
147 | class TemporalConvLayer(nn.Module):
148 | """
149 |     Temporal convolutional layer that can be used for video (sequence of images) input. Code mostly copied from:
150 | https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
151 | """
152 |
153 | def __init__(self, in_dim, out_dim=None, num_layers=4, dropout=0.0):
154 | super().__init__()
155 | out_dim = out_dim or in_dim
156 |
157 | # conv layers
158 | convs = []
159 | prev_dim, next_dim = in_dim, out_dim
160 | for i in range(num_layers):
161 | if i == num_layers - 1:
162 | next_dim = out_dim
163 | convs.extend([
164 | nn.GroupNorm(32, prev_dim), nn.SiLU(), nn.Dropout(dropout),
165 | nn.Conv3d(prev_dim, next_dim, (3, 1, 1), padding=(1, 0, 0))
166 | ])
167 | prev_dim, next_dim = next_dim, prev_dim
168 | self.convs = nn.ModuleList(convs)
169 |
170 | def forward(self, hidden_states):
171 | video_length = hidden_states.shape[2]
172 |
173 | identity = hidden_states
174 | for conv in self.convs:
175 | hidden_states = conv(hidden_states)
176 |
177 | # ignore effects of temporal layers on image inputs
178 | hidden_states = identity + hidden_states if video_length > 1 \
179 | else identity + 0.0 * hidden_states
180 |
181 | return hidden_states
182 |
183 |
184 | class UNetMidBlock3DCrossAttn(nn.Module):
185 | def __init__(
186 | self,
187 | in_channels: int,
188 | temb_channels: int,
189 | dropout: float = 0.0,
190 | num_layers: int = 1,
191 | resnet_eps: float = 1e-6,
192 | resnet_time_scale_shift: str = "default",
193 | resnet_act_fn: str = "swish",
194 | resnet_groups: int = 32,
195 | resnet_pre_norm: bool = True,
196 | attn_num_head_channels=1,
197 | output_scale_factor=1.0,
198 | cross_attention_dim=1280,
199 | dual_cross_attention=False,
200 | use_linear_projection=False,
201 | upcast_attention=False,
202 | add_temp_attn=False,
203 | prepend_first_frame=False,
204 | add_temp_embed=False,
205 | add_temp_ff=False,
206 | add_temp_conv=False,
207 | ):
208 | super().__init__()
209 |
210 | self.has_cross_attention = True
211 | self.attn_num_head_channels = attn_num_head_channels
212 | resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
213 |
214 | # there is always at least one resnet
215 | resnets = [
216 | ResnetBlock3D(
217 | in_channels=in_channels,
218 | out_channels=in_channels,
219 | temb_channels=temb_channels,
220 | eps=resnet_eps,
221 | groups=resnet_groups,
222 | dropout=dropout,
223 | time_embedding_norm=resnet_time_scale_shift,
224 | non_linearity=resnet_act_fn,
225 | output_scale_factor=output_scale_factor,
226 | pre_norm=resnet_pre_norm,
227 | )
228 | ]
229 | attentions = []
230 | if add_temp_conv:
231 |             self.temp_convs = None  # placeholder; the hasattr(self, "temp_convs") checks below enable the temporal conv path
232 | temp_convs = [TemporalConvLayer(in_channels, in_channels, dropout=0.1)]
233 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1])
234 |
235 | for _ in range(num_layers):
236 | if dual_cross_attention:
237 | raise NotImplementedError
238 | attentions.append(
239 | Transformer3DModel(
240 | attn_num_head_channels,
241 | in_channels // attn_num_head_channels,
242 | in_channels=in_channels,
243 | num_layers=1,
244 | cross_attention_dim=cross_attention_dim,
245 | norm_num_groups=resnet_groups,
246 | use_linear_projection=use_linear_projection,
247 | upcast_attention=upcast_attention,
248 | add_temp_attn=add_temp_attn,
249 | prepend_first_frame=prepend_first_frame,
250 | add_temp_embed=add_temp_embed,
251 | add_temp_ff=add_temp_ff
252 | )
253 | )
254 | resnets.append(
255 | ResnetBlock3D(
256 | in_channels=in_channels,
257 | out_channels=in_channels,
258 | temb_channels=temb_channels,
259 | eps=resnet_eps,
260 | groups=resnet_groups,
261 | dropout=dropout,
262 | time_embedding_norm=resnet_time_scale_shift,
263 | non_linearity=resnet_act_fn,
264 | output_scale_factor=output_scale_factor,
265 | pre_norm=resnet_pre_norm,
266 | )
267 | )
268 | if hasattr(self, "temp_convs"):
269 | temp_convs.append(TemporalConvLayer(in_channels, in_channels, dropout=0.1))
270 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1])
271 |
272 | self.attentions = nn.ModuleList(attentions)
273 | self.resnets = nn.ModuleList(resnets)
274 | if hasattr(self, "temp_convs"):
275 | self.temp_convs = nn.ModuleList(temp_convs)
276 |
277 | def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
278 | hidden_states = self.resnets[0](hidden_states, temb)
279 | if hasattr(self, "temp_convs"):
280 | hidden_states = self.temp_convs[0](hidden_states)
281 | for layer_idx in range(len(self.attentions)):
282 | attn = self.attentions[layer_idx]
283 | resnet = self.resnets[layer_idx + 1]
284 | hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states)[0]
285 | hidden_states = resnet(hidden_states, temb)
286 | if hasattr(self, "temp_convs"):
287 | temp_conv = self.temp_convs[layer_idx + 1]
288 | hidden_states = temp_conv(hidden_states)
289 |
290 | return hidden_states
291 |
292 |
293 | class CrossAttnDownBlock3D(nn.Module):
294 | def __init__(
295 | self,
296 | in_channels: int,
297 | out_channels: int,
298 | temb_channels: int,
299 | dropout: float = 0.0,
300 | num_layers: int = 1,
301 | resnet_eps: float = 1e-6,
302 | resnet_time_scale_shift: str = "default",
303 | resnet_act_fn: str = "swish",
304 | resnet_groups: int = 32,
305 | resnet_pre_norm: bool = True,
306 | attn_num_head_channels=1,
307 | cross_attention_dim=1280,
308 | output_scale_factor=1.0,
309 | downsample_padding=1,
310 | add_downsample=True,
311 | dual_cross_attention=False,
312 | use_linear_projection=False,
313 | only_cross_attention=False,
314 | upcast_attention=False,
315 | add_temp_attn=False,
316 | prepend_first_frame=False,
317 | add_temp_embed=False,
318 | add_temp_ff=False,
319 | add_temp_conv=False,
320 | ):
321 | super().__init__()
322 | resnets = []
323 | attentions = []
324 | if add_temp_conv:
325 | self.temp_convs = None
326 | temp_convs = []
327 |
328 | self.has_cross_attention = True
329 | self.attn_num_head_channels = attn_num_head_channels
330 |
331 | for i in range(num_layers):
332 | in_channels = in_channels if i == 0 else out_channels
333 | resnets.append(
334 | ResnetBlock3D(
335 | in_channels=in_channels,
336 | out_channels=out_channels,
337 | temb_channels=temb_channels,
338 | eps=resnet_eps,
339 | groups=resnet_groups,
340 | dropout=dropout,
341 | time_embedding_norm=resnet_time_scale_shift,
342 | non_linearity=resnet_act_fn,
343 | output_scale_factor=output_scale_factor,
344 | pre_norm=resnet_pre_norm,
345 | )
346 | )
347 | if hasattr(self, "temp_convs"):
348 | temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1))
349 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1])
350 | if dual_cross_attention:
351 | raise NotImplementedError
352 | attentions.append(
353 | Transformer3DModel(
354 | attn_num_head_channels,
355 | out_channels // attn_num_head_channels,
356 | in_channels=out_channels,
357 | num_layers=1,
358 | cross_attention_dim=cross_attention_dim,
359 | norm_num_groups=resnet_groups,
360 | use_linear_projection=use_linear_projection,
361 | only_cross_attention=only_cross_attention,
362 | upcast_attention=upcast_attention,
363 | add_temp_attn=add_temp_attn,
364 | prepend_first_frame=prepend_first_frame,
365 | add_temp_embed=add_temp_embed,
366 | add_temp_ff=add_temp_ff,
367 | )
368 | )
369 | self.attentions = nn.ModuleList(attentions)
370 | self.resnets = nn.ModuleList(resnets)
371 | if hasattr(self, "temp_convs"):
372 | self.temp_convs = nn.ModuleList(temp_convs)
373 |
374 | if add_downsample:
375 | self.downsamplers = nn.ModuleList(
376 | [
377 | Downsample3D(
378 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
379 | )
380 | ]
381 | )
382 | else:
383 | self.downsamplers = None
384 |
385 | self.gradient_checkpointing = False
386 |
387 | def forward(
388 | self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
389 | output_states = ()
390 |
391 | for layer_idx in range(len(self.resnets)):
392 | resnet, attn = self.resnets[layer_idx], self.attentions[layer_idx]
393 | is_checkpointing = self.training and self.gradient_checkpointing
394 | hidden_states = checkpoint(func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing)
395 | if hasattr(self, "temp_convs"):
396 | temp_conv = self.temp_convs[layer_idx]
397 | hidden_states = checkpoint(func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing)
398 | hidden_states = checkpoint(
399 | func=attn, inputs=(hidden_states, encoder_hidden_states), flag=is_checkpointing)[0]
400 |
401 | output_states += (hidden_states,)
402 |
403 | if self.downsamplers is not None:
404 | for downsampler in self.downsamplers:
405 | hidden_states = downsampler(hidden_states)
406 |
407 | output_states += (hidden_states,)
408 |
409 | return hidden_states, output_states
410 |
411 |
412 | class DownBlock3D(nn.Module):
413 | def __init__(
414 | self,
415 | in_channels: int,
416 | out_channels: int,
417 | temb_channels: int,
418 | dropout: float = 0.0,
419 | num_layers: int = 1,
420 | resnet_eps: float = 1e-6,
421 | resnet_time_scale_shift: str = "default",
422 | resnet_act_fn: str = "swish",
423 | resnet_groups: int = 32,
424 | resnet_pre_norm: bool = True,
425 | output_scale_factor=1.0,
426 | add_downsample=True,
427 | downsample_padding=1,
428 | add_temp_conv=False,
429 | ):
430 | super().__init__()
431 | resnets = []
432 | if add_temp_conv:
433 | self.temp_convs = None
434 | temp_convs = []
435 | for i in range(num_layers):
436 | in_channels = in_channels if i == 0 else out_channels
437 | resnets.append(
438 | ResnetBlock3D(
439 | in_channels=in_channels,
440 | out_channels=out_channels,
441 | temb_channels=temb_channels,
442 | eps=resnet_eps,
443 | groups=resnet_groups,
444 | dropout=dropout,
445 | time_embedding_norm=resnet_time_scale_shift,
446 | non_linearity=resnet_act_fn,
447 | output_scale_factor=output_scale_factor,
448 | pre_norm=resnet_pre_norm,
449 | )
450 | )
451 | if hasattr(self, "temp_convs"):
452 | temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1))
453 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1])
454 |
455 | self.resnets = nn.ModuleList(resnets)
456 | if hasattr(self, "temp_convs"):
457 | self.temp_convs = nn.ModuleList(temp_convs)
458 |
459 | if add_downsample:
460 | self.downsamplers = nn.ModuleList(
461 | [
462 | Downsample3D(
463 | out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
464 | )
465 | ]
466 | )
467 | else:
468 | self.downsamplers = None
469 |
470 | self.gradient_checkpointing = False
471 |
472 | def forward(self, hidden_states, temb=None):
473 | output_states = ()
474 |
475 | for layer_idx in range(len(self.resnets)):
476 | resnet = self.resnets[layer_idx]
477 | is_checkpointing = self.training and self.gradient_checkpointing
478 | hidden_states = checkpoint(func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing)
479 | if hasattr(self, "temp_convs"):
480 | temp_conv = self.temp_convs[layer_idx]
481 | hidden_states = checkpoint(func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing)
482 |
483 | output_states += (hidden_states,)
484 |
485 | if self.downsamplers is not None:
486 | for downsampler in self.downsamplers:
487 | hidden_states = downsampler(hidden_states)
488 |
489 | output_states += (hidden_states,)
490 |
491 | return hidden_states, output_states
492 |
493 |
494 | class CrossAttnUpBlock3D(nn.Module):
495 | def __init__(
496 | self,
497 | in_channels: int,
498 | out_channels: int,
499 | prev_output_channel: int,
500 | temb_channels: int,
501 | dropout: float = 0.0,
502 | num_layers: int = 1,
503 | resnet_eps: float = 1e-6,
504 | resnet_time_scale_shift: str = "default",
505 | resnet_act_fn: str = "swish",
506 | resnet_groups: int = 32,
507 | resnet_pre_norm: bool = True,
508 | attn_num_head_channels=1,
509 | cross_attention_dim=1280,
510 | output_scale_factor=1.0,
511 | add_upsample=True,
512 | dual_cross_attention=False,
513 | use_linear_projection=False,
514 | only_cross_attention=False,
515 | upcast_attention=False,
516 | add_temp_attn=False,
517 | prepend_first_frame=False,
518 | add_temp_embed=False,
519 | add_temp_ff=False,
520 | add_temp_conv=False,
521 | ):
522 | super().__init__()
523 | resnets = []
524 | attentions = []
525 | if add_temp_conv:
526 | self.temp_convs = None
527 | temp_convs = []
528 |
529 | self.has_cross_attention = True
530 | self.attn_num_head_channels = attn_num_head_channels
531 |
532 | for i in range(num_layers):
533 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
534 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
535 |
536 | resnets.append(
537 | ResnetBlock3D(
538 | in_channels=resnet_in_channels + res_skip_channels,
539 | out_channels=out_channels,
540 | temb_channels=temb_channels,
541 | eps=resnet_eps,
542 | groups=resnet_groups,
543 | dropout=dropout,
544 | time_embedding_norm=resnet_time_scale_shift,
545 | non_linearity=resnet_act_fn,
546 | output_scale_factor=output_scale_factor,
547 | pre_norm=resnet_pre_norm,
548 | )
549 | )
550 | if dual_cross_attention:
551 | raise NotImplementedError
552 | attentions.append(
553 | Transformer3DModel(
554 | attn_num_head_channels,
555 | out_channels // attn_num_head_channels,
556 | in_channels=out_channels,
557 | num_layers=1,
558 | cross_attention_dim=cross_attention_dim,
559 | norm_num_groups=resnet_groups,
560 | use_linear_projection=use_linear_projection,
561 | only_cross_attention=only_cross_attention,
562 | upcast_attention=upcast_attention,
563 | add_temp_attn=add_temp_attn,
564 | prepend_first_frame=prepend_first_frame,
565 | add_temp_embed=add_temp_embed,
566 | add_temp_ff=add_temp_ff
567 | )
568 | )
569 | if hasattr(self, "temp_convs"):
570 | temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1))
571 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1])
572 |
573 | self.attentions = nn.ModuleList(attentions)
574 | self.resnets = nn.ModuleList(resnets)
575 | if hasattr(self, "temp_convs"):
576 | self.temp_convs = nn.ModuleList(temp_convs)
577 |
578 | if add_upsample:
579 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
580 | else:
581 | self.upsamplers = None
582 |
583 | self.gradient_checkpointing = False
584 |
585 | def forward(
586 | self,
587 | hidden_states,
588 | res_hidden_states_tuple,
589 | temb=None,
590 | encoder_hidden_states=None,
591 | upsample_size=None,
592 | attention_mask=None
593 | ):
594 | for layer_idx in range(len(self.resnets)):
595 | resnet, attn = self.resnets[layer_idx], self.attentions[layer_idx]
596 | # pop res hidden states
597 | res_hidden_states = res_hidden_states_tuple[-1]
598 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
599 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
600 |
601 | is_checkpointing = self.training and self.gradient_checkpointing
602 | hidden_states = checkpoint(func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing)
603 | if hasattr(self, "temp_convs"):
604 | temp_conv = self.temp_convs[layer_idx]
605 | hidden_states = checkpoint(func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing)
606 | hidden_states = checkpoint(
607 | func=attn, inputs=(hidden_states, encoder_hidden_states), flag=is_checkpointing)[0]
608 |
609 | if self.upsamplers is not None:
610 | for upsampler in self.upsamplers:
611 | hidden_states = upsampler(hidden_states, upsample_size)
612 |
613 | return hidden_states
614 |
615 |
616 | class UpBlock3D(nn.Module):
617 | def __init__(
618 | self,
619 | in_channels: int,
620 | prev_output_channel: int,
621 | out_channels: int,
622 | temb_channels: int,
623 | dropout: float = 0.0,
624 | num_layers: int = 1,
625 | resnet_eps: float = 1e-6,
626 | resnet_time_scale_shift: str = "default",
627 | resnet_act_fn: str = "swish",
628 | resnet_groups: int = 32,
629 | resnet_pre_norm: bool = True,
630 | output_scale_factor=1.0,
631 | add_upsample=True,
632 | add_temp_conv=False,
633 | ):
634 | super().__init__()
635 | resnets = []
636 |
637 | if add_temp_conv:
638 | self.temp_convs = None
639 | temp_convs = []
640 |
641 | for i in range(num_layers):
642 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
643 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
644 |
645 | resnets.append(
646 | ResnetBlock3D(
647 | in_channels=resnet_in_channels + res_skip_channels,
648 | out_channels=out_channels,
649 | temb_channels=temb_channels,
650 | eps=resnet_eps,
651 | groups=resnet_groups,
652 | dropout=dropout,
653 | time_embedding_norm=resnet_time_scale_shift,
654 | non_linearity=resnet_act_fn,
655 | output_scale_factor=output_scale_factor,
656 | pre_norm=resnet_pre_norm,
657 | )
658 | )
659 | if hasattr(self, "temp_convs"):
660 | temp_convs.append(TemporalConvLayer(out_channels, out_channels, dropout=0.1))
661 | temp_convs[-1].convs[-1] = zero_module(temp_convs[-1].convs[-1])
662 |
663 | self.resnets = nn.ModuleList(resnets)
664 | if hasattr(self, "temp_convs"):
665 | self.temp_convs = nn.ModuleList(temp_convs)
666 |
667 | if add_upsample:
668 | self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
669 | else:
670 | self.upsamplers = None
671 |
672 | self.gradient_checkpointing = False
673 |
674 | def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
675 | for layer_idx in range(len(self.resnets)):
676 | resnet = self.resnets[layer_idx]
677 | # pop res hidden states
678 | res_hidden_states = res_hidden_states_tuple[-1]
679 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
680 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
681 |
682 | is_checkpointing = self.training and self.gradient_checkpointing
683 | hidden_states = checkpoint(func=resnet, inputs=(hidden_states, temb), flag=is_checkpointing)
684 | if hasattr(self, "temp_convs"):
685 | temp_conv = self.temp_convs[layer_idx]
686 | hidden_states = checkpoint(func=temp_conv, inputs=(hidden_states,), flag=is_checkpointing)
687 |
688 | if self.upsamplers is not None:
689 | for upsampler in self.upsamplers:
690 | hidden_states = upsampler(hidden_states, upsample_size)
691 |
692 | return hidden_states
693 |
--------------------------------------------------------------------------------
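The temporal layers in the blocks above are wired so that a freshly inflated model initially reproduces the 2D base model: the last convolution of each `TemporalConvLayer` is zero-initialized with `zero_module`, so its residual branch contributes nothing at the start of training, and single-frame (image) inputs bypass the temporal path entirely. A small sketch of that property, assuming the import paths shown in the repository tree:

```python
import torch

# Assumed import paths based on the repository layout.
from model.modules.unet_blocks import TemporalConvLayer
from model.modules.utils import zero_module

# Zero the final conv, exactly as the blocks above do when add_temp_conv is enabled.
layer = TemporalConvLayer(in_dim=64, out_dim=64, dropout=0.1).eval()
layer.convs[-1] = zero_module(layer.convs[-1])

x = torch.randn(2, 64, 8, 16, 16)  # (batch, channels, frames, height, width)
with torch.no_grad():
    y = layer(x)
assert torch.allclose(y, x)  # the zeroed residual branch leaves the input unchanged
```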
/model/modules/utils.py:
--------------------------------------------------------------------------------
1 | import torch.utils.checkpoint  # explicit import so torch.utils.checkpoint.checkpoint below always resolves; also binds `torch`
2 |
3 |
4 | def checkpoint(func, inputs, flag):
5 | """
6 | Evaluate a function without caching intermediate activations, allowing for
7 | reduced memory at the expense of extra compute in the backward pass.
8 | :param func: the function to evaluate.
9 | :param inputs: the argument sequence to pass to `func`.
10 | :param flag: if False, disable gradient checkpointing.
11 | """
12 | if flag:
13 | return torch.utils.checkpoint.checkpoint(func, *inputs, use_reentrant=False)
14 | else:
15 | return func(*inputs)
16 |
17 |
18 | def zero_module(module):
19 | """
20 | Zero out the parameters of a module and return it.
21 | """
22 | for p in module.parameters():
23 | p.detach().zero_()
24 | return module
25 |
--------------------------------------------------------------------------------
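For reference, a brief sketch of how these two helpers behave (import path assumed from the repository layout): `checkpoint` recomputes activations during the backward pass when its flag is set, and `zero_module` is what lets the newly added temporal branches start out as identity mappings.

```python
import torch
from torch import nn

from model.modules.utils import checkpoint, zero_module  # assumed import path

layer = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# flag=True: activations are recomputed in the backward pass instead of being stored.
y = checkpoint(func=layer, inputs=(x,), flag=True)
y.sum().backward()

# flag=False: a plain forward call with no checkpointing.
_ = checkpoint(func=layer, inputs=(x,), flag=False)

# zero_module zeroes every parameter so a residual branch initially contributes nothing.
zeroed = zero_module(nn.Conv3d(8, 8, kernel_size=(3, 1, 1), padding=(1, 0, 0)))
assert all(float(p.abs().sum()) == 0.0 for p in zeroed.parameters())
```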
/model/modules/zero_snr_ddpm.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
16 |
17 | import math
18 | from dataclasses import dataclass
19 | from typing import List, Optional, Tuple, Union
20 |
21 | import numpy as np
22 | import torch
23 |
24 | from diffusers.configuration_utils import ConfigMixin, register_to_config
25 | from diffusers.utils import BaseOutput, randn_tensor
26 | from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
27 |
28 |
29 | @dataclass
30 | class DDPMSchedulerOutput(BaseOutput):
31 | """
32 | Output class for the scheduler's step function output.
33 |
34 | Args:
35 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
36 | Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
37 | denoising loop.
38 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
39 | The predicted denoised sample (x_{0}) based on the model output from the current timestep.
40 | `pred_original_sample` can be used to preview progress or for guidance.
41 | """
42 |
43 | prev_sample: torch.FloatTensor
44 | pred_original_sample: Optional[torch.FloatTensor] = None
45 |
46 |
47 | def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
48 | """
49 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
50 | (1-beta) over time from t = [0,1].
51 |
52 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
53 | to that part of the diffusion process.
54 |
55 |
56 | Args:
57 | num_diffusion_timesteps (`int`): the number of betas to produce.
58 | max_beta (`float`): the maximum beta to use; use values lower than 1 to
59 | prevent singularities.
60 |
61 | Returns:
62 |         betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
63 | """
64 |
65 | def alpha_bar(time_step):
66 | return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
67 |
68 | betas = []
69 | for i in range(num_diffusion_timesteps):
70 | t1 = i / num_diffusion_timesteps
71 | t2 = (i + 1) / num_diffusion_timesteps
72 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
73 | return torch.tensor(betas, dtype=torch.float32)
74 |
75 | def enforce_zero_terminal_snr(betas):  # rescale betas so the terminal SNR is exactly zero (Lin et al., arXiv:2305.08891)
76 | # convert betas to alphas_bar_sqrt
77 | alphas = 1 - betas
78 | alphas_bar = alphas.cumprod(0)
79 | alphas_bar_sqrt = alphas_bar.sqrt()
80 |
81 | # store old values
82 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
83 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
84 | # shift so last timestep is zero.
85 | alphas_bar_sqrt -= alphas_bar_sqrt_T
86 | # scale so first timestep is back to old value.
87 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
88 |
89 | # convert alphas_bar_sqrt to betas
90 | alphas_bar = alphas_bar_sqrt ** 2
91 | alphas = alphas_bar[1:] / alphas_bar[:-1]
92 | alphas = torch.cat([alphas_bar[0:1], alphas])
93 | betas = 1 - alphas
94 | return betas
95 |
96 | class DDPMScheduler(SchedulerMixin, ConfigMixin):
97 | """
98 | Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
99 | Langevin dynamics sampling.
100 |
101 | [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
102 | function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
103 | [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
104 | [`~SchedulerMixin.from_pretrained`] functions.
105 |
106 | For more details, see the original paper: https://arxiv.org/abs/2006.11239
107 |
108 | Args:
109 | num_train_timesteps (`int`): number of diffusion steps used to train the model.
110 | beta_start (`float`): the starting `beta` value of inference.
111 | beta_end (`float`): the final `beta` value.
112 | beta_schedule (`str`):
113 | the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
114 |             `linear`, `scaled_linear` (with zero-terminal-SNR rescaling applied), `squaredcos_cap_v2`, or `sigmoid`.
115 | trained_betas (`np.ndarray`, optional):
116 | option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
117 | variance_type (`str`):
118 | options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
119 | `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
120 | clip_sample (`bool`, default `True`):
121 | option to clip predicted sample between -1 and 1 for numerical stability.
122 | prediction_type (`str`, default `epsilon`, optional):
123 | prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
124 |             process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
125 | https://imagen.research.google/video/paper.pdf)
126 | """
127 |
128 | _compatibles = [e.name for e in KarrasDiffusionSchedulers]
129 | order = 1
130 |
131 | @register_to_config
132 | def __init__(
133 | self,
134 | num_train_timesteps: int = 1000,
135 | beta_start: float = 0.0001,
136 | beta_end: float = 0.02,
137 | beta_schedule: str = "linear",
138 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
139 | variance_type: str = "fixed_small",
140 | clip_sample: bool = True,
141 | prediction_type: str = "epsilon",
142 | clip_sample_range: Optional[float] = 1.0,
143 | ):
144 | if trained_betas is not None:
145 | self.betas = torch.tensor(trained_betas, dtype=torch.float32)
146 | elif beta_schedule == "linear":
147 | self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
148 | elif beta_schedule == "scaled_linear":
149 | # this schedule is very specific to the latent diffusion model.
150 | self.betas = (
151 | torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
152 | )
153 | self.betas = enforce_zero_terminal_snr(self.betas) # Zero-SNR
154 | elif beta_schedule == "squaredcos_cap_v2":
155 | # Glide cosine schedule
156 | self.betas = betas_for_alpha_bar(num_train_timesteps)
157 | elif beta_schedule == "sigmoid":
158 | # GeoDiff sigmoid schedule
159 | betas = torch.linspace(-6, 6, num_train_timesteps)
160 | self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
161 | else:
162 |             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
163 |
164 | self.alphas = 1.0 - self.betas
165 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
166 | self.one = torch.tensor(1.0)
167 |
168 | # standard deviation of the initial noise distribution
169 | self.init_noise_sigma = 1.0
170 |
171 | # setable values
172 | self.num_inference_steps = None
173 | self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
174 |
175 | self.variance_type = variance_type
176 |
177 | def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
178 | """
179 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
180 | current timestep.
181 |
182 | Args:
183 | sample (`torch.FloatTensor`): input sample
184 | timestep (`int`, optional): current timestep
185 |
186 | Returns:
187 | `torch.FloatTensor`: scaled input sample
188 | """
189 | return sample
190 |
191 | def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
192 | """
193 | Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
194 |
195 | Args:
196 | num_inference_steps (`int`):
197 | the number of diffusion steps used when generating samples with a pre-trained model.
198 | """
199 |
200 | if num_inference_steps > self.config.num_train_timesteps:
201 | raise ValueError(
202 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
203 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
204 | f" maximal {self.config.num_train_timesteps} timesteps."
205 | )
206 |
207 | self.num_inference_steps = num_inference_steps
208 |
209 | step_ratio = self.config.num_train_timesteps // self.num_inference_steps
210 | timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
211 | self.timesteps = torch.from_numpy(timesteps).to(device)
212 |
213 | def _get_variance(self, t, predicted_variance=None, variance_type=None):
214 | num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
215 | prev_t = t - self.config.num_train_timesteps // num_inference_steps
216 | alpha_prod_t = self.alphas_cumprod[t]
217 | alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
218 | current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
219 |
220 | # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
221 | # and sample from it to get previous sample
222 | # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
223 | variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
224 |
225 | if variance_type is None:
226 | variance_type = self.config.variance_type
227 |
228 | # hacks - were probably added for training stability
229 | if variance_type == "fixed_small":
230 | variance = torch.clamp(variance, min=1e-20)
231 | # for rl-diffuser https://arxiv.org/abs/2205.09991
232 | elif variance_type == "fixed_small_log":
233 | variance = torch.log(torch.clamp(variance, min=1e-20))
234 | variance = torch.exp(0.5 * variance)
235 | elif variance_type == "fixed_large":
236 | variance = current_beta_t
237 | elif variance_type == "fixed_large_log":
238 | # Glide max_log
239 | variance = torch.log(current_beta_t)
240 | elif variance_type == "learned":
241 | return predicted_variance
242 | elif variance_type == "learned_range":
243 | min_log = torch.log(variance)
244 | max_log = torch.log(self.betas[t])
245 | frac = (predicted_variance + 1) / 2
246 | variance = frac * max_log + (1 - frac) * min_log
247 |
248 | return variance
249 |
250 | def step(
251 | self,
252 | model_output: torch.FloatTensor,
253 | timestep: int,
254 | sample: torch.FloatTensor,
255 | generator=None,
256 | return_dict: bool = True,
257 | ) -> Union[DDPMSchedulerOutput, Tuple]:
258 | """
259 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
260 | process from the learned model outputs (most often the predicted noise).
261 |
262 | Args:
263 | model_output (`torch.FloatTensor`): direct output from learned diffusion model.
264 | timestep (`int`): current discrete timestep in the diffusion chain.
265 | sample (`torch.FloatTensor`):
266 | current instance of sample being created by diffusion process.
267 | generator: random number generator.
268 | return_dict (`bool`): option for returning tuple rather than DDPMSchedulerOutput class
269 |
270 | Returns:
271 | [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] or `tuple`:
272 | [`~schedulers.scheduling_utils.DDPMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
273 | returning a tuple, the first element is the sample tensor.
274 |
275 | """
276 | t = timestep
277 | num_inference_steps = self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
278 | prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
279 |
280 | if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
281 | model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
282 | else:
283 | predicted_variance = None
284 |
285 | # 1. compute alphas, betas
286 | alpha_prod_t = self.alphas_cumprod[t]
287 | alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
288 | beta_prod_t = 1 - alpha_prod_t
289 | beta_prod_t_prev = 1 - alpha_prod_t_prev
290 | current_alpha_t = alpha_prod_t / alpha_prod_t_prev
291 | current_beta_t = 1 - current_alpha_t
292 |
293 | # 2. compute predicted original sample from predicted noise also called
294 | # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
295 | if self.config.prediction_type == "epsilon":
296 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
297 | elif self.config.prediction_type == "sample":
298 | pred_original_sample = model_output
299 | elif self.config.prediction_type == "v_prediction":
300 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
301 | else:
302 | raise ValueError(
303 | f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
304 | " `v_prediction` for the DDPMScheduler."
305 | )
306 |
307 | # 3. Clip "predicted x_0"
308 | if self.config.clip_sample:
309 | pred_original_sample = torch.clamp(
310 | pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
311 | )
312 |
313 | # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
314 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
315 | pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
316 | current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
317 |
318 | # 5. Compute predicted previous sample µ_t
319 | # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
320 | pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
321 |
322 | # 6. Add noise
323 | variance = 0
324 | if t > 0:
325 | device = model_output.device
326 | variance_noise = randn_tensor(
327 | model_output.shape, generator=generator, device=device, dtype=model_output.dtype
328 | )
329 | if self.variance_type == "fixed_small_log":
330 | variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
331 | elif self.variance_type == "learned_range":
332 | variance = self._get_variance(t, predicted_variance=predicted_variance)
333 | variance = torch.exp(0.5 * variance) * variance_noise
334 | else:
335 | variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
336 |
337 | pred_prev_sample = pred_prev_sample + variance
338 |
339 | if not return_dict:
340 | return (pred_prev_sample,)
341 |
342 | return DDPMSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
343 |
344 | def add_noise(
345 | self,
346 | original_samples: torch.FloatTensor,
347 | noise: torch.FloatTensor,
348 | timesteps: torch.IntTensor,
349 | ) -> torch.FloatTensor:
350 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
351 | self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
352 | timesteps = timesteps.to(original_samples.device)
353 |
354 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
355 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
356 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
357 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
358 |
359 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
360 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
361 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
362 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
363 |
364 | noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
365 | return noisy_samples
366 |
367 | def get_velocity(
368 | self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
369 | ) -> torch.FloatTensor:
370 | # Make sure alphas_cumprod and timestep have same device and dtype as sample
371 | self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
372 | timesteps = timesteps.to(sample.device)
373 |
374 | sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
375 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
376 | while len(sqrt_alpha_prod.shape) < len(sample.shape):
377 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
378 |
379 | sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
380 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
381 | while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
382 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
383 |
384 | velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
385 | return velocity
386 |
387 | def __len__(self):
388 | return self.config.num_train_timesteps
389 |
--------------------------------------------------------------------------------
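To make the zero-terminal-SNR rescaling above concrete, the sketch below checks its two key effects: the final cumulative alpha becomes exactly zero (so the last timestep carries no signal), while the first step is left essentially unchanged. The beta range is only an example and is not necessarily this project's configuration; the import path assumes the repository root is on `PYTHONPATH`.

```python
import torch

from model.modules.zero_snr_ddpm import enforce_zero_terminal_snr  # assumed import path

# Example scaled-linear schedule; the exact beta range is illustrative.
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2

rescaled = enforce_zero_terminal_snr(betas)
alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)

print(float(alphas_cumprod[0]))   # ~(1 - betas[0]); the first step is preserved
print(float(alphas_cumprod[-1]))  # 0.0; the terminal SNR is forced to zero
```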
/model/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import random
3 | from math import sqrt
4 | from pathlib import Path
5 | from typing import Union, Tuple, List, Optional
6 |
7 | import imageio
8 | import torch
9 | import torchvision
10 | from einops import rearrange
11 | from torch.utils.checkpoint import checkpoint
12 |
13 | import torch.nn as nn
14 | import kornia
15 | import open_clip
16 | import math
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | def randn_base(
22 | shape: Union[Tuple, List],
23 | mean: float = .0,
24 | std: float = 1.,
25 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
26 | device: Optional["torch.device"] = None,
27 | dtype: Optional["torch.dtype"] = None
28 | ):
29 | if isinstance(generator, list):
30 | shape = (1,) + shape[1:]
31 | tensor = [
32 | torch.normal(
33 | mean=mean, std=std, size=shape, generator=generator[i],
34 | device=device, dtype=dtype
35 | )
36 | for i in range(len(generator))
37 | ]
38 | tensor = torch.cat(tensor, dim=0).to(device)
39 | else:
40 | tensor = torch.normal(
41 | mean=mean, std=std, size=shape, generator=generator, device=device,
42 | dtype=dtype
43 | )
44 | return tensor
45 |
46 |
47 | def randn_mixed(
48 | shape: Union[Tuple, List],
49 | dim: int,
50 | alpha: float = .0,
51 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
52 | device: Optional["torch.device"] = None,
53 | dtype: Optional["torch.dtype"] = None
54 | ):
55 | """ Refer to Section 4 of Preserve Your Own Correlation:
56 | [A Noise Prior for Video Diffusion Models](https://arxiv.org/abs/2305.10474)
57 | """
58 | shape_shared = shape[:dim] + (1,) + shape[dim + 1:]
59 |
60 | # shared random tensor
61 | shared_std = alpha ** 2 / (1. + alpha ** 2)
62 | shared_tensor = randn_base(
63 | shape=shape_shared, mean=.0, std=shared_std, generator=generator,
64 | device=device, dtype=dtype
65 | )
66 |
67 | # individual random tensor
68 | indv_std = 1. / (1. + alpha ** 2)
69 | indv_tensor = randn_base(
70 | shape=shape, mean=.0, std=indv_std, generator=generator, device=device,
71 | dtype=dtype
72 | )
73 |
74 | return shared_tensor + indv_tensor
75 |
76 |
77 | def randn_progressive(
78 | shape: Union[Tuple, List],
79 | dim: int,
80 | alpha: float = .0,
81 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
82 | device: Optional["torch.device"] = None,
83 | dtype: Optional["torch.dtype"] = None
84 | ):
85 | """ Refer to Section 4 of Preserve Your Own Correlation:
86 | [A Noise Prior for Video Diffusion Models](https://arxiv.org/abs/2305.10474)
87 | """
88 | num_prog = shape[dim]
89 | shape_slice = shape[:dim] + (1,) + shape[dim + 1:]
90 | tensors = [randn_base(shape=shape_slice, mean=.0, std=1., generator=generator, device=device, dtype=dtype)]
91 | beta = alpha / sqrt(1. + alpha ** 2)
92 |     std = sqrt(1. / (1. + alpha ** 2))  # std of the per-step innovation, so every frame keeps unit variance
93 | for i in range(1, num_prog):
94 | tensor_i = beta * tensors[-1] + randn_base(
95 | shape=shape_slice, mean=.0, std=std, generator=generator, device=device, dtype=dtype
96 | )
97 | tensors.append(tensor_i)
98 | tensors = torch.cat(tensors, dim=dim)
99 | return tensors
100 |
101 |
102 | def prepare_masked_latents(images, vae_encode_func, scaling_factor=0.18215, sample_size=32, null_img_ratio=0):
103 | masks = torch.ones_like(images) # shape: [b, f, c, h, w]
104 | if random.random() < (1 - null_img_ratio):
105 | masks[:, 0, ...] = 0
106 | # masks[:, random.randrange(masks.shape[1]), ...] = 0 # TK
107 |     masked_latents = images * (masks < 0.5)  # zero out every frame except the unmasked reference frame
108 | # map masks into latent space
109 | masks = masks[:, :, :1, :sample_size, :sample_size]
110 | masks = rearrange(masks, "b f c h w -> b c f h w")
111 | # map masked_latents into latent space
112 | masked_latents = vae_encode_func(
113 | masked_latents.view(masked_latents.shape[0] * masked_latents.shape[1], *masked_latents.shape[2:])
114 | ).latent_dist.sample() * scaling_factor
115 | masked_latents = rearrange(masked_latents, "(b f) c h w -> b c f h w", f=images.shape[1])
116 |
117 | return masks, masked_latents
118 |
119 | def decode_latents(vae_func, latents):
120 | scaling_factor = 0.18215
121 | video_length = latents.shape[2]
122 | latents = 1 / scaling_factor * latents
123 | latents = rearrange(latents, "b c f h w -> (b f) c h w")
124 | video = vae_func(latents).sample
125 | video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
126 | video = (video / 2 + 0.5).clamp(0, 1)
127 |     # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
128 | video = video.float()
129 | return video
130 |
131 | def prepare_entity_latents(images, vae_encode_func, vae_decode_func, scaling_factor=0.18215, sample_size=32, null_img_ratio=0):
132 | # map masked_latents into latent space
133 | masked_latents = images
134 | masked_latents = vae_encode_func(
135 | masked_latents.view(masked_latents.shape[0] * masked_latents.shape[1], *masked_latents.shape[2:])
136 | ).latent_dist.sample() * scaling_factor
137 | entity_vae_latent = rearrange(masked_latents, "(b f) c h w -> b c f h w", f=images.shape[1])
138 |
139 | # save decode video
140 | # video = decode_latents(vae_decode_func, entity_vae_latent)
141 | # save_videos_grid(video, Path(f"samples_s{3000}", f"test decode.gif"))
142 | return entity_vae_latent
143 |
144 | def save_videos_grid(videos, path, rescale=False, n_rows=4, fps=4):
145 | if videos.dim() == 4:
146 | videos = videos.unsqueeze(0)
147 | videos = rearrange(videos, "b c t h w -> t b c h w")
148 | outputs = []
149 | for x in videos:
150 | x = torchvision.utils.make_grid(x, nrow=n_rows)
151 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
152 | if rescale:
153 | x = (x + 1.0) / 2.0 # [-1, 1) -> [0, 1)
154 | x = (x * 255).to(dtype=torch.uint8, device="cpu")
155 | outputs.append(x)
156 | Path(path).parent.mkdir(parents=True, exist_ok=True)
157 | imageio.mimwrite(Path(path).as_posix(), outputs, duration=1000 / fps, loop=0)
158 |
159 |
160 | @torch.no_grad()
161 | def compute_clip_score(model, model_processor, images, texts, local_bs=32, rescale=False):
162 | if rescale:
163 | images = (images + 1.0) / 2.0 # -1,1 -> 0,1
164 | images = (images * 255).to(torch.uint8)
165 | clip_scores = []
166 | for start_idx in range(0, images.shape[0], local_bs):
167 | img_batch = images[start_idx:start_idx + local_bs]
168 | batch_size = img_batch.shape[0] # shape: [b c t h w]
169 | img_batch = rearrange(img_batch, "b c t h w -> (b t) c h w")
170 | outputs = []
171 | for i in range(len(img_batch)):
172 | images_part = img_batch[i:i + 1]
173 | model_inputs = model_processor(
174 | text=texts, images=list(images_part), return_tensors="pt", padding=True
175 | )
176 | model_inputs = {
177 | k: v.to(device=model.device, dtype=model.dtype)
178 | if k in ["pixel_values"] else v.to(device=model.device)
179 | for k, v in model_inputs.items()
180 | }
181 | logits = model(**model_inputs)["logits_per_image"]
182 | # For consistency with `torchmetrics.functional.multimodal.clip_score`.
183 | logits = logits / model.logit_scale.exp()
184 | outputs.append(logits)
185 | logits = torch.cat(outputs)
186 | logits = rearrange(logits, "(b t) p -> t b p", b=batch_size)
187 | frame_sims = []
188 | for logit in logits:
189 | frame_sims.append(logit.diagonal())
190 | frame_sims = torch.stack(frame_sims) # [t, b]
191 | clip_scores.append(frame_sims.mean(dim=0))
192 | return torch.cat(clip_scores)
193 |
194 |
195 | class AbstractEncoder(nn.Module):
196 | def __init__(self):
197 | super().__init__()
198 |
199 | def encode(self, *args, **kwargs):
200 | raise NotImplementedError
201 |
202 | class FrozenOpenCLIPEmbedder(AbstractEncoder):
203 | """
204 | Uses the OpenCLIP transformer encoder for text
205 | """
206 | LAYERS = [
207 | # "pooled",
208 | "last",
209 | "penultimate"
210 | ]
211 |
212 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
213 | freeze=True, layer="last"):
214 | super().__init__()
215 | assert layer in self.LAYERS
216 |         pretrained = "/home/user/model/open_clip/open_clip_ViT_H_14/open_clip_pytorch_model.bin"  # hard-coded local path to the OpenCLIP ViT-H/14 weights; adjust for your environment
217 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
218 | # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'))
219 | del model.visual
220 | self.model = model
221 |
222 | self.device = device
223 | self.max_length = max_length
224 | if freeze:
225 | self.freeze()
226 | self.layer = layer
227 | if self.layer == "last":
228 | self.layer_idx = 0
229 | elif self.layer == "penultimate":
230 | self.layer_idx = 1
231 | else:
232 | raise NotImplementedError()
233 |
234 | def freeze(self):
235 | self.model = self.model.eval()
236 | for param in self.parameters():
237 | param.requires_grad = False
238 |
239 | def forward(self, text):
240 | self.device = self.model.positional_embedding.device
241 | tokens = open_clip.tokenize(text)
242 | z = self.encode_with_transformer(tokens.to(self.device))
243 | return z
244 |
245 | def encode_with_transformer(self, text):
246 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
247 | x = x + self.model.positional_embedding
248 | x = x.permute(1, 0, 2) # NLD -> LND
249 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
250 | x = x.permute(1, 0, 2) # LND -> NLD
251 | x = self.model.ln_final(x)
252 | return x
253 |
254 | def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
255 | for i, r in enumerate(self.model.transformer.resblocks):
256 | if i == len(self.model.transformer.resblocks) - self.layer_idx:
257 | break
258 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
259 | x = checkpoint(r, x, attn_mask)
260 | else:
261 | x = r(x, attn_mask=attn_mask)
262 | return x
263 |
264 | def encode(self, text):
265 | return self(text)
266 |
267 |
268 | class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
269 | """
270 | Uses the OpenCLIP vision transformer encoder for images
271 | """
272 |
273 | def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda",
274 | freeze=True, layer="pooled", antialias=True):
275 | super().__init__()
276 |
277 |         pretrained = "/home/user/model/open_clip/open_clip_ViT_H_14/open_clip_pytorch_model.bin"  # hard-coded local path to the OpenCLIP ViT-H/14 weights; adjust for your environment
278 | # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
279 | model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=pretrained)
280 |
281 | del model.transformer
282 | self.model = model
283 | self.device = device
284 |
285 | if freeze:
286 | self.freeze()
287 | self.layer = layer
288 | if self.layer == "penultimate":
289 | raise NotImplementedError()
290 | self.layer_idx = 1
291 |
292 | self.antialias = antialias
293 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
294 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
295 |
296 | def preprocess(self, x):
297 | # normalize to [0,1]
298 | x = kornia.geometry.resize(x, (224, 224),
299 | interpolation='bicubic', align_corners=True,
300 | antialias=self.antialias)
301 | x = (x + 1.) / 2.
302 | # renormalize according to clip
303 | x = kornia.enhance.normalize(x, self.mean, self.std)
304 | return x
305 |
306 | def freeze(self):
307 | self.model = self.model.eval()
308 | for param in self.model.parameters():
309 | param.requires_grad = False
310 |
311 | def forward(self, image, no_dropout=False):
312 | ## image: b c h w
313 | z = self.encode_with_vision_transformer(image)
314 | return z
315 |
316 | def encode_with_vision_transformer(self, x):
317 | # x.shape: [1, 3, 320, 512]
318 | x = self.preprocess(x)
319 | # x.shape: [1, 3, 224, 224]
320 |
321 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
322 | if self.model.visual.input_patchnorm: # False
323 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
324 | x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1])
325 | x = x.permute(0, 2, 4, 1, 3, 5)
326 | x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1)
327 | x = self.model.visual.patchnorm_pre_ln(x)
328 | x = self.model.visual.conv1(x)
329 | else:
330 | x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
331 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
332 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
333 | # x.shape: [1, 256, 1280]
334 |
335 | # class embeddings and positional embeddings
336 | x = torch.cat(
337 | [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
338 | x], dim=1) # shape = [*, grid ** 2 + 1, width]
339 | x = x + self.model.visual.positional_embedding.to(x.dtype)
340 |
341 | # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
342 | x = self.model.visual.patch_dropout(x)
343 | x = self.model.visual.ln_pre(x)
344 |
345 | x = x.permute(1, 0, 2) # NLD -> LND
346 | x = self.model.visual.transformer(x)
347 |
348 | x = x.permute(1, 0, 2) # LND -> NLD
349 |
350 | # x.shape: [1, 257, 1280]
351 | return x
352 |
353 | def reshape_tensor(x, heads):
354 | bs, length, width = x.shape
355 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
356 | x = x.view(bs, length, heads, -1)
357 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
358 | x = x.transpose(1, 2)
359 |     # final shape: (bs, n_heads, length, dim_per_head)
360 | x = x.reshape(bs, heads, length, -1)
361 | return x
362 |
363 | class PerceiverAttention(nn.Module):
364 | def __init__(self, *, dim, dim_head=64, heads=8):
365 | super().__init__()
366 | self.scale = dim_head**-0.5
367 | self.dim_head = dim_head
368 | self.heads = heads
369 | inner_dim = dim_head * heads
370 |
371 | self.norm1 = nn.LayerNorm(dim)
372 | self.norm2 = nn.LayerNorm(dim)
373 |
374 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
375 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
376 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
377 |
378 |
379 | def forward(self, x, latents):
380 | """
381 | Args:
382 | x (torch.Tensor): image features
383 | shape (b, n1, D)
384 | latent (torch.Tensor): latent features
385 | shape (b, n2, D)
386 | """
387 | x = self.norm1(x)
388 | latents = self.norm2(latents)
389 |
390 | b, l, _ = latents.shape
391 |
392 | q = self.to_q(latents)
393 | kv_input = torch.cat((x, latents), dim=-2)
394 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
395 |
396 | q = reshape_tensor(q, self.heads)
397 | k = reshape_tensor(k, self.heads)
398 | v = reshape_tensor(v, self.heads)
399 |
400 | # attention
401 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
402 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
403 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
404 | out = weight @ v
405 |
406 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
407 |
408 | return self.to_out(out)
409 |
410 | def FeedForward(dim, mult=4):
411 | inner_dim = int(dim * mult)
412 | return nn.Sequential(
413 | nn.LayerNorm(dim),
414 | nn.Linear(dim, inner_dim, bias=False),
415 | nn.GELU(),
416 | nn.Linear(inner_dim, dim, bias=False),
417 | )
418 |
419 | class Resampler(nn.Module):
420 | def __init__(
421 | self,
422 | dim=1024,
423 | depth=8,
424 | dim_head=64,
425 | heads=16,
426 | num_queries=8,
427 | embedding_dim=768,
428 | output_dim=1024,
429 | ff_mult=4,
430 | ):
431 | super().__init__()
432 |
433 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
434 |
435 | self.proj_in = nn.Linear(embedding_dim, dim)
436 |
437 | self.proj_out = nn.Linear(dim, output_dim)
438 | self.norm_out = nn.LayerNorm(output_dim)
439 |
440 | self.layers = nn.ModuleList([])
441 | for _ in range(depth):
442 | self.layers.append(
443 | nn.ModuleList(
444 | [
445 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
446 | FeedForward(dim=dim, mult=ff_mult),
447 | ]
448 | )
449 | )
450 |
451 | def forward(self, x):
452 | latents = self.latents.repeat(x.size(0), 1, 1)
453 |
454 | x = self.proj_in(x)
455 |
456 | for attn, ff in self.layers:
457 | latents = attn(x, latents) + latents
458 | latents = ff(latents) + latents
459 |
460 | latents = self.proj_out(latents)
461 | return self.norm_out(latents)
462 |
463 | class ImageProjModel(nn.Module):
464 | """Projection Model"""
465 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
466 | super().__init__()
467 | self.cross_attention_dim = cross_attention_dim
468 | self.clip_extra_context_tokens = clip_extra_context_tokens
469 | self.proj = nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
470 | self.norm = nn.LayerNorm(cross_attention_dim)
471 |
472 | def forward(self, image_embeds):
473 | #embeds = image_embeds
474 | embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
475 | clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
476 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
477 | return clip_extra_context_tokens
478 |
--------------------------------------------------------------------------------
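A minimal smoke-test sketch (not part of the repository) for some of the helpers above; the shapes, `alpha`, and import path are illustrative assumptions (repo root on `PYTHONPATH`).

```python
import torch

from model.utils import randn_mixed, randn_progressive, Resampler

shape = (2, 4, 16, 32, 32)                  # (batch, channels, frames, height, width)
gen = torch.Generator().manual_seed(0)

# correlated noise priors along the frame axis (dim=2)
noise_mixed = randn_mixed(shape, dim=2, alpha=1.0, generator=gen)
noise_prog = randn_progressive(shape, dim=2, alpha=1.0, generator=gen)
assert noise_mixed.shape == noise_prog.shape == shape

# Resampler maps a sequence of image tokens to a fixed number of query tokens
resampler = Resampler(dim=1024, depth=2, heads=16, num_queries=8,
                      embedding_dim=1280, output_dim=1024)
image_tokens = torch.randn(2, 257, 1280)    # e.g. ViT-H/14 tokens from the image embedder
queries = resampler(image_tokens)
assert queries.shape == (2, 8, 1024)
```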
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.12.1
2 | torchvision==0.13.1
3 | diffusers[torch]==0.14.0
4 | transformers>=4.25.1
5 | hydra-core>=1.2.0
6 | pytorch-lightning>=1.9.2
7 | decord==0.6.0
8 | accelerate
9 | tensorboard
10 | modelcards
11 | omegaconf
12 | einops
13 | imageio
14 | ftfy
--------------------------------------------------------------------------------
/scripts/launcher.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | # parse arguments
5 | parser = argparse.ArgumentParser(description="Launcher for ModelArts")
6 | parser.add_argument(
7 | "--script_path",
8 | required=True,
9 | type=str,
10 | default="VATMM/scripts/runs/run1.sh",
11 | help="Shell script path to evaluate.",
12 | )
13 | parser.add_argument(
14 | "--world_size",
15 | default=1,
16 | type=int,
17 | help="Number of nodes.",
18 | )
19 | parser.add_argument(
20 | "--rank",
21 | default=0,
22 | type=int,
23 | help="Node rank.",
24 | )
25 | args, _ = parser.parse_known_args()
26 |
27 | # get base directory
28 | JOB_DIR = os.getenv("MA_JOB_DIR", "/home/ma-user/modelarts/user-job-dir/")
29 |
30 | # get absolute path of script
31 | script_path = os.path.join(JOB_DIR, args.script_path)
32 |
33 | # check that the script exists
34 | if not os.path.exists(script_path):
35 | raise FileNotFoundError(script_path)
36 |
37 | # run script
38 | os.system(
39 | f"/bin/bash {script_path}"
40 | f" --world_size {args.world_size}"
41 | f" --node_rank {args.rank}"
42 | )
43 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dist import dist_envs
2 | from .file_ops import ops
3 | from .logger import enable_logger
4 | from .registry import Registry
5 | from .utils import save_config, get_free_space, instantiate_multi, get_free_mem, slugify
6 |
--------------------------------------------------------------------------------
/utils/dist.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger = logging.getLogger(__name__)
4 |
5 |
6 | def ensure_env_init(fn):
7 | def wrapped_fn(obj):
8 | if not obj.is_initialized:
9 | logger.error("Distributed environments are not initialized!")
10 | return fn(obj)
11 |
12 | return wrapped_fn
13 |
14 |
15 | class DistEnvs:
16 | """Environments for distributed training."""
17 |
18 | def __init__(self):
19 | self._world_size = 1
20 | self._num_nodes = 1
21 | self._node_rank = 0
22 | self._global_rank = 0
23 | self._local_rank = 0
24 | self.is_initialized = False
25 |
26 | def init_envs(self, trainer):
27 |         """Distributed environments must be initialized once and only once."""
28 | if self.is_initialized:
29 | logger.warning("Distributed environments are repeatedly initialized!")
30 | self.is_initialized = True
31 | self._world_size = trainer.world_size
32 | self._num_nodes = trainer.num_nodes
33 | self._node_rank = trainer.node_rank
34 | self._global_rank = trainer.global_rank
35 | self._local_rank = trainer.local_rank
36 |
37 | @property
38 | @ensure_env_init
39 | def world_size(self):
40 | return self._world_size
41 |
42 | @property
43 | @ensure_env_init
44 | def num_nodes(self):
45 | return self._num_nodes
46 |
47 | @property
48 | @ensure_env_init
49 | def node_rank(self):
50 | return self._node_rank
51 |
52 | @property
53 | @ensure_env_init
54 | def global_rank(self):
55 | return self._global_rank
56 |
57 | @property
58 | @ensure_env_init
59 | def local_rank(self):
60 | return self._local_rank
61 |
62 |
63 | # singleton by module
64 | dist_envs = DistEnvs()
65 |
--------------------------------------------------------------------------------
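A minimal sketch (not part of the repository) of how `dist_envs` could be populated once from a PyTorch Lightning `Trainer`, e.g. from a callback's `setup` hook; the callback name is illustrative, and how the repo actually wires this up may differ.

```python
import pytorch_lightning as pl

from utils.dist import dist_envs


class InitDistEnvs(pl.Callback):
    def setup(self, trainer, pl_module, stage=None):
        # record world size and ranks exactly once; otherwise later accesses to
        # the dist_envs properties would log "not initialized" errors
        if not dist_envs.is_initialized:
            dist_envs.init_envs(trainer)
```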
/utils/file_ops.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import shutil
4 | from distutils.dir_util import copy_tree  # note: distutils was removed in Python 3.12; shutil.copytree(..., dirs_exist_ok=True) is the modern equivalent
5 |
6 |
7 | class FileOps:
8 | """
9 | Unified file operations for both s3 and local paths
10 | """
11 |
12 | def __init__(self):
13 | if importlib.util.find_spec("moxing"):
14 | import moxing as mox
15 | self.mox_valid = True
16 | self.mox = mox.file
17 | else:
18 | self.mox_valid = False
19 | self.mox = None
20 |
21 | @property
22 | def open(self):
23 | if self.mox_valid:
24 | return self.mox.File
25 | else:
26 | return open
27 |
28 | @property
29 | def exists(self):
30 | if self.mox_valid:
31 | return self.mox.exists
32 | else:
33 | return os.path.exists
34 |
35 | @property
36 | def listdir(self):
37 | if self.mox_valid:
38 | return self.mox.list_directory
39 | else:
40 | return os.listdir
41 |
42 | @property
43 | def isdir(self):
44 | if self.mox_valid:
45 | return self.mox.is_directory
46 | else:
47 | return os.path.isdir
48 |
49 | @property
50 | def makedirs(self):
51 | if self.mox_valid:
52 | return self.mox.make_dirs
53 | else:
54 | return os.makedirs
55 |
56 | @property
57 | def copy_dir(self):
58 | if self.mox_valid:
59 | return self.mox.copy_parallel
60 | else:
61 | return copy_tree
62 |
63 | @property
64 | def copy_file(self):
65 | if self.mox_valid:
66 | return self.mox.copy
67 | else:
68 | return shutil.copy
69 |
70 | def copy(self, src, dst, *args, **kwargs):
71 | if not self.exists(src):
72 | raise IOError('Source file {} does not exist.'.format(src))
73 | if self.isdir(src):
74 | self.copy_dir(src, dst, *args, **kwargs)
75 | else:
76 | self.copy_file(src, dst, *args, **kwargs)
77 |
78 | def mkdir_or_exist(self, path):
79 | if not self.exists(path):
80 | self.makedirs(path)
81 |
82 | @property
83 | def remove(self):
84 | if self.mox_valid:
85 | return self.mox.remove
86 | else:
87 | return os.remove
88 |
89 |
90 | ops = FileOps()
91 |
--------------------------------------------------------------------------------
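A minimal usage sketch (not part of the repository): `ops` transparently falls back to the standard library when `moxing` is unavailable, so the same calls work for local paths and, on ModelArts, for `s3://` paths. The paths below are illustrative.

```python
from utils.file_ops import ops

ops.mkdir_or_exist("outputs/samples")                        # no-op if the directory exists
if ops.exists("assets/MagDiff.pth"):
    ops.copy("assets/MagDiff.pth", "outputs/MagDiff.pth")    # single file copy
# directories are detected via ops.isdir and copied recursively:
# ops.copy("assets/stable-diffusion-2-1-base", "outputs/stable-diffusion-2-1-base")
```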
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from .dist import dist_envs
4 |
5 |
6 | def enable_logger(logger_name, log_file=None, log_level=logging.INFO):
7 | logger = logging.getLogger(logger_name)
8 | logger.disabled = False
9 | if len(logger.handlers) > 0:
10 | return
11 | logger.propagate = False
12 | stream_handler = logging.StreamHandler()
13 | handlers = [stream_handler]
14 |
15 | if not dist_envs.is_initialized:
16 | raise ValueError(f"Distributed environments are not initialized!")
17 |
18 | # only global_rank=0 will add a FileHandler
19 | if dist_envs.global_rank == 0 and log_file is not None:
20 | file_handler = logging.FileHandler(log_file, 'w')
21 | handlers.append(file_handler)
22 |
23 | # only rank=0 for each node will print logs
24 | log_level = log_level if dist_envs.local_rank == 0 else logging.ERROR
25 |
26 | formatter = logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
27 | for handler in handlers:
28 | handler.setFormatter(formatter)
29 | handler.setLevel(log_level)
30 | logger.addHandler(handler)
31 |
32 | logger.setLevel(log_level)
33 |
--------------------------------------------------------------------------------
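A minimal sketch (not part of the repository) of enabling a named logger after the distributed environment has been initialized; the logger name and file path are illustrative. Note that `enable_logger` raises if `dist_envs` has not been initialized yet.

```python
import logging

from utils import enable_logger

# only global rank 0 writes the log file; non-zero local ranks are silenced to ERROR
enable_logger("magdiff", log_file="train.log", log_level=logging.INFO)
logging.getLogger("magdiff").info("logging is configured")
```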
/utils/registry.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import warnings
3 | from collections import abc
4 | from functools import partial
5 |
6 |
7 | def is_seq_of(seq, expected_type, seq_type=None):
8 | """Check whether it is a sequence of some type.
9 |
10 | Args:
11 | seq (Sequence): The sequence to be checked.
12 | expected_type (type): Expected type of sequence items.
13 | seq_type (type, optional): Expected sequence type.
14 |
15 | Returns:
16 | bool: Whether the sequence is valid.
17 | """
18 | if seq_type is None:
19 | exp_seq_type = abc.Sequence
20 | else:
21 | assert isinstance(seq_type, type)
22 | exp_seq_type = seq_type
23 | if not isinstance(seq, exp_seq_type):
24 | return False
25 | for item in seq:
26 | if not isinstance(item, expected_type):
27 | return False
28 | return True
29 |
30 |
31 | class Registry:
32 | """A registry to map strings to classes.
33 |
34 | Args:
35 | name (str): Registry name.
36 | """
37 |
38 | def __init__(self, name):
39 | self._name = name
40 | self._module_dict = dict()
41 |
42 | def __len__(self):
43 | return len(self._module_dict)
44 |
45 | def __contains__(self, key):
46 | return self.get(key) is not None
47 |
48 | def __repr__(self):
49 | format_str = self.__class__.__name__ + \
50 | f'(name={self._name}, ' \
51 | f'items={self._module_dict})'
52 | return format_str
53 |
54 | @property
55 | def name(self):
56 | return self._name
57 |
58 | @property
59 | def module_dict(self):
60 | return self._module_dict
61 |
62 | def get(self, key):
63 | """Get the registry record.
64 |
65 | Args:
66 | key (str): The class name in string format.
67 |
68 | Returns:
69 | class: The corresponding class.
70 | """
71 | return self._module_dict.get(key, None)
72 |
73 | def _register_module(self, module_class, module_name=None, force=False):
74 | if not inspect.isclass(module_class):
75 | raise TypeError('module must be a class, '
76 | f'but got {type(module_class)}')
77 |
78 | if module_name is None:
79 | module_name = module_class.__name__
80 | if isinstance(module_name, str):
81 | module_name = [module_name]
82 | else:
83 | assert is_seq_of(
84 | module_name,
85 | str), ('module_name should be either of None, an '
86 | f'instance of str or list, but got {type(module_name)}')
87 | for name in module_name:
88 | if not force and name in self._module_dict:
89 | raise KeyError(f'{name} is already registered '
90 | f'in {self.name}')
91 | self._module_dict[name] = module_class
92 |
93 | def deprecated_register_module(self, cls=None, force=False):
94 | warnings.warn(
95 | 'The old API of register_module(module, force=False) '
96 | 'is deprecated and will be removed, please use the new API '
97 | 'register_module(name=None, force=False, module=None) instead.')
98 | if cls is None:
99 | return partial(self.deprecated_register_module, force=force)
100 | self._register_module(cls, force=force)
101 | return cls
102 |
103 | def register_module(self, name=None, force=False, module=None):
104 | """Register a module.
105 |
106 | A record will be added to `self._module_dict`, whose key is the class
107 | name or the specified name, and value is the class itself.
108 | It can be used as a decorator or a normal function.
109 |
110 | Example:
111 | >>> backbones = Registry('backbone')
112 | >>> @backbones.register_module()
113 | >>> class ResNet:
114 | >>> pass
115 |
116 | >>> backbones = Registry('backbone')
117 | >>> @backbones.register_module(name='mnet')
118 | >>> class MobileNet:
119 | >>> pass
120 |
121 | >>> backbones = Registry('backbone')
122 | >>> class ResNet:
123 | >>> pass
124 | >>> backbones.register_module(ResNet)
125 |
126 | Args:
127 | name (str | None): The module name to be registered. If not
128 | specified, the class name will be used.
129 | force (bool, optional): Whether to override an existing class with
130 | the same name. Default: False.
131 | module (type): Module class to be registered.
132 | """
133 | if not isinstance(force, bool):
134 | raise TypeError(f'force must be a boolean, but got {type(force)}')
135 |         # NOTE: This is a workaround to be compatible with the old API,
136 | # while it may introduce unexpected bugs.
137 | if isinstance(name, type):
138 | return self.deprecated_register_module(name, force=force)
139 |
140 | # use it as a normal method: x.register_module(module=SomeClass)
141 | if module is not None:
142 | self._register_module(
143 | module_class=module, module_name=name, force=force)
144 | return module
145 |
146 | # raise the error ahead of time
147 | if not (name is None or isinstance(name, str)):
148 | raise TypeError(f'name must be a str, but got {type(name)}')
149 |
150 | # use it as a decorator: @x.register_module()
151 | def _register(cls):
152 | self._register_module(
153 | module_class=cls, module_name=name, force=force)
154 | return cls
155 |
156 | return _register
157 |
--------------------------------------------------------------------------------
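Complementing the registration examples in the docstring above, here is a minimal sketch (not part of the repository) of looking a class back up by name and instantiating it; `DATASETS` and `Pexels300k` are illustrative names.

```python
from utils.registry import Registry

DATASETS = Registry("dataset")


@DATASETS.register_module()
class Pexels300k:
    def __init__(self, root):
        self.root = root


dataset_cls = DATASETS.get("Pexels300k")      # look up by class name
dataset = dataset_cls(root="/data/pexels")    # instantiate manually; Registry only stores classes
```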
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import re
4 | import unicodedata
5 |
6 | import torch
7 | from hydra.utils import instantiate
8 | from omegaconf import OmegaConf
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def save_config(config, config_file="config.yaml"):
14 | with open(config_file, "w") as f:
15 | OmegaConf.save(config, f)
16 |
17 |
18 | def get_free_space(path):
19 | if not os.path.exists(path):
20 |         logger.info(f"Path does not exist: {path}")
21 | return 0
22 | info = os.statvfs(path)
23 | free_size = info.f_bsize * info.f_bavail / 1024 ** 3 # GB
24 | return free_size
25 |
26 |
27 | def instantiate_multi(config, name):
28 | instances = [instantiate(x) for x in config.get(name).values()
29 | if x is not None and "_target_" in x] if name in config else []
30 | return instances
31 |
32 |
33 | def get_free_mem(device):
34 |     """ Get free memory of device in GB. """
35 |     mem_free_bytes, _mem_total = torch.cuda.mem_get_info(device)  # returns (free, total) in bytes
36 |     mem_free = mem_free_bytes / 1024.0 ** 3
37 |     return mem_free
38 |
39 |
40 | def slugify(value, allow_unicode=False):
41 | """
42 | Taken from https://github.com/django/django/blob/master/django/utils/text.py
43 | Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
44 | dashes to single dashes. Remove characters that aren't alphanumerics,
45 | underscores, or hyphens. Convert to lowercase. Also strip leading and
46 | trailing whitespace, dashes, and underscores.
47 | """
48 | value = str(value)
49 | if allow_unicode:
50 | value = unicodedata.normalize('NFKC', value)
51 | else:
52 | value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
53 | value = re.sub(r'[^\w\s-]', '', value.lower())
54 | return re.sub(r'[-\s]+', '-', value).strip('-_')
55 |
--------------------------------------------------------------------------------
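A minimal sketch (not part of the repository) exercising the helpers above; the config structure is an illustrative assumption following the Hydra-style `_target_` convention used by the repo's configs.

```python
from omegaconf import OmegaConf

from utils.utils import instantiate_multi, save_config, slugify

print(slugify("MagDiff: Multi-Alignment Diffusion"))
# -> "magdiff-multi-alignment-diffusion"

config = OmegaConf.create({
    "callbacks": {
        "lr_monitor": {"_target_": "pytorch_lightning.callbacks.LearningRateMonitor"},
    }
})
callbacks = instantiate_multi(config, "callbacks")   # -> [LearningRateMonitor()]
save_config(config, "config.yaml")                   # dump the resolved config to YAML
```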