├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── eval
│   ├── calculate_image.py
│   ├── calculate_video.py
│   ├── extract_token.py
│   └── reconstruct.py
├── open_tokenizer
│   ├── __init__.py
│   ├── model
│   │   ├── common_modules.py
│   │   ├── cosmos_tokenizer
│   │   │   ├── __init__.py
│   │   │   ├── image_cli.py
│   │   │   ├── image_lib.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── distributions.py
│   │   │   │   ├── layers2d.py
│   │   │   │   ├── layers3d.py
│   │   │   │   ├── patching.py
│   │   │   │   ├── quantizers.py
│   │   │   │   └── utils.py
│   │   │   ├── networks
│   │   │   │   ├── __init__.py
│   │   │   │   ├── configs.py
│   │   │   │   ├── continuous_image.py
│   │   │   │   ├── continuous_video.py
│   │   │   │   ├── discrete_image.py
│   │   │   │   └── discrete_video.py
│   │   │   ├── utils.py
│   │   │   ├── video_cli.py
│   │   │   └── video_lib.py
│   │   ├── llamagen_tokenizer.py
│   │   ├── misc.py
│   │   ├── modeling_utils.py
│   │   ├── omnitokenizer.py
│   │   ├── omnitokenizer_modules.py
│   │   ├── showo.py
│   │   ├── tiktok
│   │   │   ├── __init__.py
│   │   │   ├── maskgit.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_model.py
│   │   │   │   ├── blocks.py
│   │   │   │   ├── discriminator.py
│   │   │   │   ├── ema_model.py
│   │   │   │   ├── losses.py
│   │   │   │   ├── maskgit_vqgan.py
│   │   │   │   └── perceptual_loss.py
│   │   │   ├── quantizer
│   │   │   │   ├── __init__.py
│   │   │   │   └── quantizer.py
│   │   │   ├── rar.py
│   │   │   └── titok.py
│   │   └── vidtok
│   │       ├── configs
│   │       │   ├── vidtok_fsq_causal_41616_262144.yaml
│   │       │   ├── vidtok_fsq_causal_488_262144.yaml
│   │       │   ├── vidtok_fsq_causal_488_32768.yaml
│   │       │   ├── vidtok_fsq_causal_488_4096.yaml
│   │       │   ├── vidtok_fsq_noncausal_41616_262144.yaml
│   │       │   ├── vidtok_fsq_noncausal_488_262144.yaml
│   │       │   ├── vidtok_kl_causal_288_8chn.yaml
│   │       │   ├── vidtok_kl_causal_41616_4chn.yaml
│   │       │   ├── vidtok_kl_causal_444_4chn.yaml
│   │       │   ├── vidtok_kl_causal_488_16chn.yaml
│   │       │   ├── vidtok_kl_causal_488_4chn.yaml
│   │       │   ├── vidtok_kl_causal_488_8chn.yaml
│   │       │   ├── vidtok_kl_noncausal_41616_4chn.yaml
│   │       │   ├── vidtok_kl_noncausal_488_4chn.yaml
│   │       │   └── vidtwin
│   │       │       └── vidtwin_structure_7_7_8_dynamics_7_8.yaml
│   │       ├── data
│   │       │   ├── datamodule.py
│   │       │   ├── video_read.py
│   │       │   └── vidtok.py
│   │       ├── models
│   │       │   └── autoencoder.py
│   │       └── modules
│   │           ├── discriminator.py
│   │           ├── distributions.py
│   │           ├── ema.py
│   │           ├── logger.py
│   │           ├── losses.py
│   │           ├── lpips.py
│   │           ├── model_3dcausal.py
│   │           ├── model_3dnoncausal.py
│   │           ├── regularizers.py
│   │           └── util.py
│   └── utils
│       ├── data_utils.py
│       ├── dataset.py
│       ├── ddp_distributed.py
│       ├── fvd.py
│       └── metrics.py
├── pyproject.toml
├── requirements.txt
├── scripts
│   └── eval
│       ├── image
│       │   ├── cosmos.sh
│       │   ├── emu3.sh
│       │   ├── llamagen.sh
│       │   ├── omnitokenizer.sh
│       │   ├── showo.sh
│       │   └── tiktok.sh
│       └── video
│           ├── cosmos.sh
│           ├── emu3.sh
│           └── omnitokenizer.sh
└── src
    ├── 1-cos.gif
    ├── 1-emu3.gif
    ├── 1-omni.gif
    ├── 2-cos.gif
    ├── 2-emu3.gif
    ├── 2-omni.gif
    └── vis-img.png
/.gitattributes:
--------------------------------------------------------------------------------
1 | # https://git-scm.com/docs/gitattributes
2 | # Set the default behavior, in case people don't have core.autocrlf set.
3 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion
4 | * text=auto
5 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes
6 | # Source files
7 | # ============
8 | *.pxd text diff=python
9 | *.py text diff=python
10 | *.py3 text diff=python
11 | *.pyw text diff=python
12 | *.pyx text diff=python
13 | *.pyz text diff=python
14 | *.pyi text diff=python
15 | # Binary files
16 | # ============
17 | *.db binary
18 | *.p binary
19 | *.pkl binary
20 | *.pickle binary
21 | *.pyc binary export-ignore
22 | *.pyo binary export-ignore
23 | *.pyd binary
24 | # Jupyter notebook
25 | *.ipynb text eol=lf
26 | *.large-file-extension filter=lfs diff=lfs merge=lfs -text
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__
3 | *.pyc
4 | *.egg-info
5 | dist
6 |
7 | # Log
8 | *.log
9 | *.log.*
10 | # *.json
11 | # *.jsonl
12 |
13 | # Data
14 | !**/alpaca-data-conversation.json
15 | # Editor
16 | .idea
17 | *.swp
18 | .vscode
19 |
20 | # Other
21 | .DS_Store
22 | wandb
23 | output
24 | llavavid
25 |
26 | checkpoints/
27 | visualize/
28 | project_checkpoints
29 | debug_checkpoints
30 | playground/data
31 | playground/cc3m_llava34b_cap
32 | ckpts*
33 | classify_image_graph_def.pb
34 | visualizations
35 |
36 | .ipynb_checkpoints
37 | chunyl_scripts
38 | *.ipynb
39 |
40 | # DevContainer
41 | !.devcontainer/*
42 |
43 | # Demo
44 | serve_images/
45 | notebooks/
46 | logs
47 | scripts/dist_*
48 | logs/
49 | submissions/
50 | cn_scripts/
51 | internal_project_checkpoints/
52 | work_dirs
53 | scripts/i18n/*
54 | playground/.nfs028b000000010add00000001
55 | HIP
56 | playground/.nfs028b0000017bff2c00000012
57 | scripts/qwen
58 | scripts/vicuna
59 | scripts/mistral
60 | scripts/baseline_rep
61 | scripts/cn_boli01_hl
62 | scripts/cn_boli01_lf
63 | scripts/cn_lf
64 | scripts/cn_lq
65 | scripts/cn_yg
66 | scripts/cn_yg_hao
67 | scripts/eva_encoder
68 | scripts/i18n
69 | scripts/i18n_higher_res
70 | scripts/multi-images
71 | scratchpad
72 | build/
73 | playground/*.json
74 | mlx_configs/
75 | data_processing/
76 | # demo/
77 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenTokenizer: A Comprehensive Comparison of Open-sourced Visual Tokenizers
2 |
3 | This repo provides a comprehensive comparison of open-sourced discrete visual tokenizers under a fair setting: the same resolution, the same data, and the same evaluation code.
4 |
5 | ## Setup
6 |
7 | We use Python 3.10 and PyTorch 2.1.2 with CUDA 12.2.
8 | Please install the required packages using the following command:
9 |
10 | ```
11 | pip install -e .
12 | ```
13 |
14 | ## Benchmark
15 |
16 | We use a subset of [laion-aesthetics-12m-umap](https://huggingface.co/datasets/dclure/laion-aesthetics-12m-umap) and [OpenVid-1M](https://huggingface.co/datasets/nkp37/OpenVid-1M) as the test set to benchmark all the available tokenizers. You can download the data directly from [here](https://huggingface.co/datasets/Daniel0724/VTokenizer_Bench).
17 |
18 | The comparison of image (first table) and video (second table) reconstruction performance is shown below:
19 |
20 | | Name | Compression | PSNR | SSIM | rFID | FPS² |
21 | | ---------- | ---------- | ---------- | ---------- | ----------- | ----------- |
22 | | [OmniTokenizer](https://huggingface.co/datasets/Daniel0724/VTokenizer_Bench/resolve/main/omnitokenizer_rq_down16_code16384_joint.ckpt)¹ | 8x8 | 32.16 | 0.86 | 2.02 | 121.24 |
23 | | [Cosmos-DI](https://huggingface.co/nvidia/Cosmos-0.1-Tokenizer-DI8x8) | 8x8 | 31.04 | 0.83 | 0.95 | 612.91 |
24 | | [Emu3](https://huggingface.co/BAAI/Emu3-VisionTokenizer) | 8x8 | 31.51 | 0.83 | 0.57 | 12.72 |
25 | | [LlamaGen](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt)| 16x16 | 26.98 | 0.71 | 1.64 | 198.06 |
26 | | [Show-O](https://huggingface.co/showlab/magvitv2) | 16x16 | 27.03 | 0.71 | 1.54 | 129.62 |
27 | | [TiTok](https://huggingface.co/yucornetto/tokenizer_titok_sl256_vq8k_imagenet) | 1D | 27.48 | 0.72 | 1.30 | 303.06 |
28 |
29 |
30 | | Name | Compression | PSNR | SSIM | rFID | FPS |
31 | | ---------- | ---------- | ---------- | ---------- | ----------- | ----------- |
32 | | [OmniTokenizer](https://huggingface.co/datasets/Daniel0724/VTokenizer_Bench/resolve/main/omnitokenizer_rq_down16_code16384_joint.ckpt) | 2x8x8 | 33.51 | 0.93 | 4.10 | 67.08 |
33 | | [Cosmos-DV](https://huggingface.co/nvidia/Cosmos-0.1-Tokenizer-DV4x8x8) | 4x8x8 | 34.54 | 0.94 | 4.19 | 76.46 |
34 | | [Emu3](https://huggingface.co/BAAI/Emu3-VisionTokenizer) | 4x8x8 | 33.44 | 0.92 | 8.15 | 3.24 |
35 |
36 |
37 | ¹ OmniTokenizer is the only tokenizer designed for joint image and video tokenization. Here we release an [any-res checkpoint](https://huggingface.co/datasets/Daniel0724/VTokenizer_Bench/resolve/main/omnitokenizer_rq_down16_code16384_joint.ckpt) of OmniTokenizer, i.e., the model is trained with 2D RoPE on images and videos of different aspect ratios and resolutions.
38 |
39 | ² We denote the throughput of each tokenizer as FPS, i.e., the number of images (videos) processed per second.
40 |
41 |
42 |
43 | ## Usage
44 |
45 | Please run the following command for reconstruction:
46 |
47 | ```
48 | torchrun \
49 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \
50 | eval/reconstruct.py \
51 | --vq_model_type "omnitokenizer" \
52 | --vq_model_ckpt "./checkpoints/omnitokenizer_rq_code16384_down16_joint2.ckpt" \
53 | --dataset_type "video" \
54 | --dataset_name "openvid" \
55 | --save_dir "./tokenizers" \
56 | --video_path "path_to_dir/tokenizer_bench/openvid.json" \
57 | --video_folder "" \
58 | --resolution 256 \
59 | --video_fps 16 \
60 | --sequence_length 17
61 | ```
62 |
63 | Please specify the tokenizer type with **--vq_model_type** and the checkpoint with **--vq_model_ckpt**. The above command will automatically save the reconstructed results under **--save_dir**. We provide scripts for running inference with the different tokenizers under *./scripts/eval*; a sketch of the corresponding image command is shown below.
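For image reconstruction, the command is analogous. The sketch below is illustrative only: the image flags mirror the `DataArguments` dataclass in *eval/extract_token.py* (`--image_path`, `--image_folder`, `--resolution`), and the exact arguments accepted by *eval/reconstruct.py* can be checked in the scripts under *./scripts/eval/image*:

```
torchrun \
    --nnodes=1 --nproc_per_node=4 --master_port 23456 \
    eval/reconstruct.py \
    --vq_model_type "omnitokenizer" \
    --vq_model_ckpt "./checkpoints/omnitokenizer_rq_code16384_down16_joint2.ckpt" \
    --dataset_type "image" \
    --dataset_name "laion" \
    --save_dir "./tokenizers" \
    --image_path "path_to_dir/tokenizer_bench/laion.json" \
    --image_folder "" \
    --resolution 256
```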
64 |
65 | After this, you may run the following command to obtain the reconstruction metrics:
66 |
67 | ```
68 | # For image reconstruction:
69 |
70 | python eval/calculate_image.py /path_to_gt_dir /path_to_recon_dir
71 |
72 | # For video reconstruction:
73 |
74 | python eval/calculate_video.py --video_dir /path_to_video_dir
75 |
76 | ```
77 |
78 | ## Offline Code Extraction
79 |
80 | We also provide *./eval/extract_token.py* to extract tokens with different tokenizers in an offline manner; an invocation sketch follows.
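A minimal invocation sketch, assuming the flags defined by the `DataArguments` dataclass in *eval/extract_token.py* (the checkpoint and data paths below are placeholders):

```
torchrun \
    --nnodes=1 --nproc_per_node=4 --master_port 23456 \
    eval/extract_token.py \
    --vq_model_type "omnitokenizer" \
    --vq_model_ckpt "./checkpoints/omnitokenizer_rq_down16_code16384_joint.ckpt" \
    --dataset_type "video" \
    --dataset_name "openvid" \
    --save_dir "./tokens" \
    --video_path "path_to_dir/tokenizer_bench/openvid.json" \
    --video_folder "" \
    --resolution 256 \
    --video_fps 16 \
    --sequence_length 17
```

The extracted token arrays and the corresponding captions are saved under **--save_dir**, one file per sample.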
81 |
82 | ## Results
83 |
84 | We show the 256x256 image reconstruction results below:
85 |
86 |
87 |
88 |
89 |
90 | From left to right, we compare OmniTokenizer, Cosmos-Tokenizer, and Emu3:
91 |
92 |
93 | 
94 |
95 |
96 |
97 | 
98 |
99 |
100 | ## To Do List
101 |
102 | - ~~open source the evaluation code~~
103 | - ~~support the calculation of different metrics~~
104 | - open source the training code
105 |
106 | ## Citation
107 |
108 | If you find our work useful, please consider citing our paper:
109 | ```bib
110 | @inproceedings{wang2024omnitokenizer,
111 | title={OmniTokenizer: A Joint Image-Video Tokenizer for Visual Generation},
112 | author={Wang, Junke and Jiang, Yi and Yuan, Zehuan and Peng, Binyue and Wu, Zuxuan and Jiang, Yu-Gang},
113 | booktitle={NeurIPS},
114 | year={2024}
115 | }
116 | ```
117 |
118 | ## Acknowledgement
119 |
120 | Thanks to [OmniTokenizer](https://arxiv.org/abs/2406.09399), [LlamaGen](https://arxiv.org/abs/2406.06525), [Cosmos-Tokenizer](https://arxiv.org/html/2501.03575v1), [Emu3](https://arxiv.org/abs/2409.18869), [Show-O](https://arxiv.org/abs/2408.12528), and [TiTok](https://arxiv.org/abs/2406.07550) for their great work. We also borrow several functions from [Reducio](https://arxiv.org/abs/2411.13552) for evaluation.
121 |
122 |
123 |
124 |
125 |
--------------------------------------------------------------------------------
/eval/calculate_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from tqdm import tqdm
5 |
6 | from PIL import Image
7 | from cleanfid import fid
8 |
9 |
10 | import torchvision.transforms as transforms
11 | from open_tokenizer.utils.metrics import compute_psnr, compute_ssim
12 |
13 | def load_image_as_tensor(image_path):
14 | """
15 | Loads an image from the given path, converts it to a tensor, and normalizes it to [0, 1].
16 | """
17 | img = Image.open(image_path).convert("RGB") # Convert to RGB to ensure 3 channels
18 | transform = transforms.Compose([
19 | transforms.ToTensor(), # Converts to tensor and scales to [0, 1]
20 | ])
21 | return transform(img)
22 |
23 |
24 | def calculate_metrics(source_path, target_path):
25 | """
26 | Calculates PSNR, SSIM, and FID between images in the source and target directories.
27 | Assumes both directories have corresponding images with identical filenames.
28 | """
29 | source_images = sorted([os.path.join(source_path, f) for f in os.listdir(source_path) if f.endswith('jpg')])
30 | target_images = sorted([os.path.join(target_path, f) for f in os.listdir(target_path) if f.endswith('jpg')])
31 |
32 | if len(source_images) != len(target_images):
33 | raise ValueError("The number of images in the source and target directories must match.")
34 |
35 | psnr_values = []
36 | ssim_values = []
37 |
38 | for i in tqdm(range(len(source_images))):
39 | src, tgt = source_images[i], target_images[i]
40 | src_tensor = load_image_as_tensor(src).unsqueeze(0)
41 | tgt_tensor = load_image_as_tensor(tgt).unsqueeze(0)
42 |
43 | psnr = compute_psnr(src_tensor, tgt_tensor)
44 | ssim = compute_ssim(src_tensor, tgt_tensor)
45 | psnr_values.append(psnr)
46 | ssim_values.append(ssim)
47 |
48 | avg_psnr = sum(psnr_values) / len(psnr_values)
49 | avg_ssim = sum(ssim_values) / len(ssim_values)
50 |
51 | # Calculate FID
52 | score = fid.compute_fid(source_path, target_path)
53 | return avg_psnr.item(), avg_ssim.item(), score
54 |
55 |
56 | if __name__ == "__main__":
57 | parser = argparse.ArgumentParser(description="Calculate PSNR, SSIM, and FID metrics between source and target image directories.")
58 | parser.add_argument('source_path', type=str, help="Path to the directory containing source images.")
59 | parser.add_argument('target_path', type=str, help="Path to the directory containing target images.")
60 | args = parser.parse_args()
61 |
62 |     avg_psnr, avg_ssim, fid_score = calculate_metrics(args.source_path, args.target_path)
63 |     print(f"Average PSNR: {avg_psnr:.4f}")
64 |     print(f"Average SSIM: {avg_ssim:.4f}")
65 |     print(f"FID: {fid_score:.4f}")
66 | 
67 |     with open(os.path.join(os.path.dirname(args.source_path), "metrics.json"), "w") as f:
68 |         json.dump(
69 |             {
70 |                 "psnr": avg_psnr,
71 |                 "ssim": avg_ssim,
72 |                 "fid": fid_score
73 | },
74 | f)
75 |
--------------------------------------------------------------------------------
/eval/calculate_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from tqdm import tqdm
5 | from einops import rearrange
6 |
7 | import torch
8 | from torch.utils.data import DataLoader
9 |
10 | from open_tokenizer.utils.data_utils import open_url
11 | from open_tokenizer.utils.dataset import PairedVideoDataset
12 | from open_tokenizer.utils.metrics import compute_psnr, compute_ssim
13 | from open_tokenizer.utils.fvd import compute_fvd
14 |
15 |
16 | class VideoMetrics:
17 | def __init__(self, video_folder, image_size, device="cuda"):
18 | self.dataset = PairedVideoDataset(video_folder, image_size)
19 | num_workers = 1
20 | self.dataloader = DataLoader(self.dataset, batch_size=4, num_workers=num_workers, drop_last=False)
21 |
22 | detector_ckpt = './checkpoints/i3d_torchscript.pt'
23 | if not os.path.exists(detector_ckpt):
24 | raise FileNotFoundError(f"Detector checkpoint not found at: {detector_ckpt}")
25 |
26 | self.detector_kwargs = dict(rescale=False, resize=False, return_features=True)
27 |
28 | with open_url(detector_ckpt, verbose=False) as f:
29 | detector = torch.jit.load(f).eval().to(device)
30 |
31 | self.detector = detector
32 | self.device = device
33 |
34 | def calculate_metrics(self, metrics="fvd,psnr,ssim"):
35 | metrics = metrics.split(",")
36 | all_gt_feats = []
37 | all_pd_feats = []
38 | all_psnr = []
39 | all_ssim = []
40 |
41 | total_samples = 0
42 | for video_gt, video_pd in tqdm(self.dataloader):
43 | videos_fake = video_pd.to(self.device)[:, :, :16]
44 | videos_real = video_gt.to(self.device)[:, :, :16]
45 | total_samples += video_gt.shape[0] * 16
46 |
47 | flatten_videos_real = rearrange(videos_real, "b c t h w -> (b t) c h w")
48 | flatten_videos_fake = rearrange(videos_fake, "b c t h w -> (b t) c h w")
49 |
50 | if "fvd" in metrics:
51 | with torch.no_grad():
52 | feats_fake = self.detector(videos_fake, **self.detector_kwargs)
53 | feats_real = self.detector(videos_real, **self.detector_kwargs)
54 |
55 | all_gt_feats.append(feats_real)
56 | all_pd_feats.append(feats_fake)
57 |
58 | if "psnr" in metrics:
59 | psnr = compute_psnr(flatten_videos_real, flatten_videos_fake)
60 | all_psnr.append(psnr)
61 |
62 | if "ssim" in metrics:
63 | ssim = compute_ssim(flatten_videos_real, flatten_videos_fake)
64 | all_ssim.append(ssim)
65 |
66 |         avg_psnr = sum(torch.cat(all_psnr, dim=0)) / total_samples if all_psnr else None
67 |         avg_ssim = sum(torch.cat(all_ssim, dim=0)) / total_samples if all_ssim else None
68 |
69 | if all_gt_feats:
70 | all_gt_feats = torch.cat(all_gt_feats, dim=0).cpu().numpy()
71 | all_pd_feats = torch.cat(all_pd_feats, dim=0).cpu().numpy()
72 | avg_fvd = compute_fvd(all_gt_feats, all_pd_feats)
73 | else:
74 | avg_fvd = None
75 |
76 | return avg_psnr, avg_ssim, avg_fvd
77 |
78 |
79 | if __name__ == "__main__":
80 | parser = argparse.ArgumentParser(description="Calculate PSNR, SSIM, and FVD metrics between source and target videos.")
81 | parser.add_argument("--video_dir", type=str, help="Path to the directory containing videos.")
82 | parser.add_argument("--image_size", type=int, default=224, help="Size to resize videos (default: 224).")
83 | args = parser.parse_args()
84 |
85 | metrics = VideoMetrics(args.video_dir, args.image_size, device="cuda")
86 | avg_psnr, avg_ssim, avg_fvd = metrics.calculate_metrics()
87 |
88 | print(f"Average PSNR: {avg_psnr.item():.4f}" if avg_psnr else "PSNR not computed.")
89 | print(f"Average SSIM: {avg_ssim.item():.4f}" if avg_ssim else "SSIM not computed.")
90 | print(f"FVD: {avg_fvd:.4f}" if avg_fvd else "FVD not computed.")
91 |
92 | results_path = os.path.join(os.path.dirname(args.video_dir), "metrics.json")
93 | with open(results_path, "w") as f:
94 | json.dump(
95 | {
96 | "psnr": avg_psnr.item() if avg_psnr is not None else "N/A",
97 | "ssim": avg_ssim.item() if avg_ssim is not None else "N/A",
98 | "fvd": avg_fvd if avg_fvd is not None else "N/A",
99 | },
100 | f,
101 | )
102 | print(f"Metrics saved to {results_path}")
--------------------------------------------------------------------------------
/eval/extract_token.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.backends.cuda.matmul.allow_tf32 = True
3 | torch.backends.cudnn.allow_tf32 = True
4 | import torch.distributed as dist
5 | from torch.utils.data import DataLoader
6 | from torch.utils.data.distributed import DistributedSampler
7 |
8 | from PIL import Image
9 | import os
10 | import numpy as np
11 | from tqdm import tqdm
12 | from dataclasses import dataclass, field
13 | from typing import Optional
14 |
15 | import transformers
16 | from transformers import AutoTokenizer
17 |
18 | from open_tokenizer.utils.ddp_distributed import init_distributed_mode
19 | from open_tokenizer.model.omnitokenizer import OmniTokenizer
20 | from open_tokenizer.utils.dataset import EvalT2IDataset, EvalT2VDataset
21 |
22 | @dataclass
23 | class DataArguments:
24 | image_folder: Optional[str] = field(default=None)
25 | image_path: str = field(default=None)
26 | resolution: Optional[int] = field(default=None)
27 |
28 | video_folder: Optional[str] = field(default=None)
29 | video_path: str = field(default=None)
30 | sequence_length: int = field(default=17)
31 | video_fps: Optional[int] = field(default=1)
32 |
33 |
34 | vq_model_type: str = field(default="omnitokenizer")
35 | vq_model_ckpt: str = field(default="./omnitokenizer_rq_down16_code16384_joint.ckpt")
36 | save_dir: str = field(default=None)
37 |
38 | dataset_name: str = field(default="laion")
39 | dataset_type: str = field(default="image")
40 |
41 | batch_size: int = field(default=8)
42 | workers: int = field(default=4)
43 |
44 |
45 | def main(args):
46 |     os.makedirs(f"{args.save_dir}/{args.dataset_name}/{args.resolution}_codes", exist_ok=True)
47 |     os.makedirs(f"{args.save_dir}/{args.dataset_name}/{args.resolution}_labels", exist_ok=True)
48 |
49 | init_distributed_mode(args)
50 | rank = dist.get_rank()
51 | device = rank % torch.cuda.device_count()
52 | seed = 0 * dist.get_world_size() + rank
53 | torch.manual_seed(seed)
54 | torch.cuda.set_device(device)
55 |
56 | model = OmniTokenizer.load_from_checkpoint(args.vq_model_ckpt, strict=False)
57 | model.eval()
58 | model.requires_grad_(False)
59 | model = model.to(device)
60 |
61 | downsample_factor = model.downsample_factor if model is not None else 16
62 | if args.dataset_type == "image":
63 | dataset = EvalT2IDataset(image_folder=args.image_folder, data_path=args.image_path, image_size=args.resolution)
64 | else:
65 | dataset = EvalT2VDataset(image_folder=args.video_folder, data_path=args.video_path, image_size=args.resolution, sequence_length=args.sequence_length, fps=args.video_fps)
66 |
67 | if dist.get_rank() == 0:
68 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
69 | print(f"Extracting image tokens for {len(dataset)} samples.")
70 |
71 | sampler = DistributedSampler(
72 | dataset,
73 | num_replicas=dist.get_world_size(),
74 | rank=rank,
75 | shuffle=False,
76 | )
77 |
78 | loader = DataLoader(
79 | dataset,
80 | batch_size=1, # important!
81 | shuffle=False,
82 | sampler=sampler,
83 | num_workers=0,
84 | pin_memory=False,
85 | drop_last=False
86 | )
87 |
88 | total = 0
89 | for image, caption in tqdm(loader):
90 | image_tensor = image.to(device)
91 |
92 | if image_tensor.ndim == 5:
93 | image_tensor = image_tensor.permute(0, 2, 1, 3, 4).contiguous()
94 |
95 | with torch.no_grad():
96 | image_token = model.encode(image_tensor)[1]
97 |
98 | x = image_token.detach().cpu().numpy()
99 | train_steps = rank + total
100 |
101 |         np.save(f'{args.save_dir}/{args.dataset_name}/{args.resolution}_codes/{train_steps}.npy', x)
102 |         with open(f'{args.save_dir}/{args.dataset_name}/{args.resolution}_labels/{train_steps}.npy', "w") as f:
103 | f.write(caption[0])
104 |
105 | total += dist.get_world_size()
106 |
107 | dist.destroy_process_group()
108 |
109 | if __name__ == "__main__":
110 | parser = transformers.HfArgumentParser([DataArguments])
111 | args = parser.parse_args_into_dataclasses()[0]
112 | args.gen_image_folder = ""
113 |
114 | main(args)
--------------------------------------------------------------------------------
/open_tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .model.omnitokenizer import OmniTokenizer
2 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/image_cli.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A CLI to run ImageTokenizer on plain images based on torch.jit.
16 |
17 | Usage:
18 | python3 -m cosmos_tokenizer.image_cli \
19 | --image_pattern 'path/to/input/folder/*.jpg' \
20 | --output_dir ./reconstructions \
21 | --checkpoint_enc ./pretrained_ckpts/CosmosCI_f8x8/encoder.jit \
22 | --checkpoint_dec ./pretrained_ckpts/CosmosCI_f8x8/decoder.jit
23 |
24 | Optionally, you can run the model in pure PyTorch mode:
25 | python3 -m cosmos_tokenizer.image_cli \
26 | --image_pattern 'path/to/input/folder/*.jpg' \
27 | --mode torch \
28 | --tokenizer_type CI \
29 | --spatial_compression 8 \
30 | --checkpoint_enc ./pretrained_ckpts/CosmosCI_f8x8/encoder.jit \
31 | --checkpoint_dec ./pretrained_ckpts/CosmosCI_f8x8/decoder.jit
32 | """
33 |
34 | import os
35 | from argparse import ArgumentParser, Namespace
36 | import sys
37 | from typing import Any
38 |
39 | import numpy as np
40 | from loguru import logger as logging
41 | from cosmos_tokenizer.networks import TokenizerConfigs
42 |
43 | from cosmos_tokenizer.image_lib import ImageTokenizer
44 | from cosmos_tokenizer.utils import (
45 | get_filepaths,
46 | get_output_filepath,
47 | read_image,
48 | resize_image,
49 | write_image,
50 | )
51 |
52 |
53 | def _parse_args() -> Namespace:
54 | parser = ArgumentParser(
55 | description="A CLI for running ImageTokenizer on plain images."
56 | )
57 | parser.add_argument(
58 | "--image_pattern",
59 | type=str,
60 | default="path/to/images/*.jpg",
61 | help="Glob pattern.",
62 | )
63 | parser.add_argument(
64 | "--checkpoint",
65 | type=str,
66 | default=None,
67 | help="JIT full Autoencoder model filepath.",
68 | )
69 | parser.add_argument(
70 | "--checkpoint_enc",
71 | type=str,
72 | default=None,
73 | help="JIT Encoder model filepath.",
74 | )
75 | parser.add_argument(
76 | "--checkpoint_dec",
77 | type=str,
78 | default=None,
79 | help="JIT Decoder model filepath.",
80 | )
81 | parser.add_argument(
82 | "--tokenizer_type",
83 | type=str,
84 | choices=["CI", "DI"],
85 | help="Specifies the tokenizer type.",
86 | )
87 | parser.add_argument(
88 | "--spatial_compression",
89 | type=int,
90 | choices=[8, 16],
91 | default=8,
92 | help="The spatial compression factor.",
93 | )
94 | parser.add_argument(
95 | "--mode",
96 | type=str,
97 | choices=["torch", "jit"],
98 | default="jit",
99 | help="Specify the backend: native 'torch' or 'jit' (default: 'jit')",
100 | )
101 | parser.add_argument(
102 | "--short_size",
103 | type=int,
104 | default=None,
105 | help="The size to resample inputs. None, by default.",
106 | )
107 | parser.add_argument(
108 | "--dtype",
109 | type=str,
110 | default="bfloat16",
111 | help="Sets the precision. Default bfloat16.",
112 | )
113 | parser.add_argument(
114 | "--device",
115 | type=str,
116 | default="cuda",
117 | help="Device for invoking the model.",
118 | )
119 | parser.add_argument(
120 | "--output_dir", type=str, default=None, help="Output directory."
121 | )
122 | parser.add_argument(
123 | "--save_input",
124 | action="store_true",
125 | help="If on, the input image will be be outputed too.",
126 | )
127 | args = parser.parse_args()
128 | return args
129 |
130 |
131 | logging.info("Initializes args ...")
132 | args = _parse_args()
133 | if args.mode == "torch" and args.tokenizer_type not in ["CI", "DI"]:
134 | logging.error("'torch' backend requires the tokenizer_type of 'CI' or 'DI'.")
135 | sys.exit(1)
136 |
137 |
138 | def _run_eval() -> None:
139 | """Invokes the evaluation pipeline."""
140 |
141 | if (
142 | args.checkpoint_enc is None
143 | and args.checkpoint_dec is None
144 | and args.checkpoint is None
145 | ):
146 |         logging.warning(
147 |             "Aborting. Provide either the full autoencoder JIT model, or both the encoder and decoder JIT models."
148 |         )
149 | return
150 |
151 | if args.mode == "torch":
152 | tokenizer_config = TokenizerConfigs[args.tokenizer_type].value
153 | tokenizer_config.update(dict(spatial_compression=args.spatial_compression))
154 | else:
155 | tokenizer_config = None
156 |
157 | logging.info(
158 | f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
159 | )
160 | autoencoder = ImageTokenizer(
161 | checkpoint=args.checkpoint,
162 | checkpoint_enc=args.checkpoint_enc,
163 | checkpoint_dec=args.checkpoint_dec,
164 | tokenizer_config=tokenizer_config,
165 | device=args.device,
166 | dtype=args.dtype,
167 | )
168 |
169 | filepaths = get_filepaths(args.image_pattern)
170 | logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.")
171 |
172 | for filepath in filepaths:
173 | logging.info(f"Reading image {filepath} ...")
174 | image = read_image(filepath)
175 | image = resize_image(image, short_size=args.short_size)
176 | batch_image = np.expand_dims(image, axis=0)
177 |
178 | logging.info("Invoking the autoencoder model in ... ")
179 | output_image = autoencoder(batch_image)[0]
180 |
181 | output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
182 | logging.info(f"Outputing {output_filepath} ...")
183 | write_image(output_filepath, output_image)
184 |
185 | if args.save_input:
186 | ext = os.path.splitext(output_filepath)[-1]
187 | input_filepath = output_filepath.replace(ext, "_input" + ext)
188 | write_image(input_filepath, image)
189 |
190 |
191 | @logging.catch(reraise=True)
192 | def main() -> None:
193 | _run_eval()
194 |
195 |
196 | if __name__ == "__main__":
197 | main()
198 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/image_lib.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A library for image tokenizers inference."""
16 |
17 | import numpy as np
18 | import torch
19 | from typing import Any
20 |
21 | from .utils import (
22 | load_model,
23 | load_encoder_model,
24 | load_decoder_model,
25 | numpy2tensor,
26 | pad_image_batch,
27 | tensor2numpy,
28 | unpad_image_batch,
29 | )
30 |
31 |
32 | class ImageTokenizer(torch.nn.Module):
33 | def __init__(
34 | self,
35 | checkpoint: str = None,
36 | checkpoint_enc: str = None,
37 | checkpoint_dec: str = None,
38 | tokenizer_config: dict[str, Any] = None,
39 | device: str = "cuda",
40 | dtype: str = "bfloat16",
41 | ) -> None:
42 | super().__init__()
43 | self._device = device
44 | self._dtype = getattr(torch, dtype)
45 | self._full_model = (
46 | load_model(checkpoint, tokenizer_config, device).to(self._dtype)
47 | if checkpoint is not None
48 | else None
49 | )
50 | self._enc_model = (
51 | load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
52 | if checkpoint_enc is not None
53 | else None
54 | )
55 | self._dec_model = (
56 | load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
57 | if checkpoint_dec is not None
58 | else None
59 | )
60 | self.codebook_embed_dim = tokenizer_config["z_channels"]
61 |
62 | @torch.no_grad()
63 | def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
64 | """Reconstrcuts a batch of image tensors after embedding into a latent.
65 |
66 | Args:
67 | input_tensor: The input image Bx3xHxW layout, range [-1..1].
68 | Returns:
69 | The reconstructed tensor, layout Bx3xHxW, range [-1..1].
70 | """
71 | if self._full_model is not None:
72 | output_tensor = self._full_model(input_tensor)
73 | output_tensor = (
74 | output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
75 | )
76 | else:
77 | output_latent = self.encode(input_tensor)[0]
78 | output_tensor = self.decode(output_latent)
79 | return output_tensor
80 |
81 | @torch.no_grad()
82 | def decode(self, input_latent: torch.Tensor, latent_shape=None):
83 | """Decodes an image from a provided latent embedding.
84 |
85 | Args:
86 | input_latent: The continuous latent Bx16xhxw for CI,
87 | or the discrete indices Bxhxw for DI.
88 | Returns:
89 | The output tensor in Bx3xHxW, range [-1..1].
90 | """
91 | output_tensor = self._dec_model(input_latent)
92 | return output_tensor
93 |
94 | @torch.no_grad()
95 | def encode(self, input_tensor: torch.Tensor):
96 | """Encodes an image into a latent embedding or code.
97 |
98 | Args:
99 | input_tensor: The input tensor Bx3xHxW layout, range [-1..1].
100 | Returns:
101 | For continuous image (CI) tokenizer, the tuple contains:
102 | - The latent embedding, Bx16x(h)x(w), where the compression
103 | rate is (H/h x W/w), and channel dimension of 16.
104 | For discrete image (DI) tokenizer, the tuple contains:
105 | - The indices, Bx(h)x(w), from a codebook of size 64K, which
106 | corresponds to FSQ levels of (8,8,8,5,5,5).
107 | - The discrete code, Bx6x(h)x(w), where the compression rate is
108 | again (H/h x W/w), and channel dimension of 6.
109 | """
110 | output_latent = self._enc_model(input_tensor)
111 | if isinstance(output_latent, torch.Tensor):
112 | return "pad", output_latent
113 | return "pad", output_latent[0], output_latent[1]
114 |
115 | @torch.no_grad()
116 | def forward(self, image: np.ndarray) -> np.ndarray:
117 | """Reconstructs an image using a pre-trained tokenizer.
118 |
119 | Args:
120 | image: The input image BxHxWxC layout, range [0..255].
121 | Returns:
122 | The reconstructed image in range [0..255], layout BxHxWxC.
123 | """
124 | padded_input_image, crop_region = pad_image_batch(image)
125 | input_tensor = numpy2tensor(
126 | padded_input_image, dtype=self._dtype, device=self._device
127 | )
128 | output_tensor = self.autoencode(input_tensor)
129 | padded_output_image = tensor2numpy(output_tensor)
130 | return unpad_image_batch(padded_output_image, crop_region)
131 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | from enum import Enum
16 |
17 | from .distributions import (
18 | GaussianDistribution,
19 | IdentityDistribution,
20 | )
21 | from .layers2d import Decoder, Encoder
22 | from .layers3d import (
23 | DecoderBase,
24 | DecoderFactorized,
25 | EncoderBase,
26 | EncoderFactorized,
27 | )
28 | from .quantizers import (
29 | FSQuantizer,
30 | LFQuantizer,
31 | ResidualFSQuantizer,
32 | VectorQuantizer,
33 | )
34 |
35 |
36 | class EncoderType(Enum):
37 | Default = Encoder
38 |
39 |
40 | class DecoderType(Enum):
41 | Default = Decoder
42 |
43 |
44 | class Encoder3DType(Enum):
45 | BASE = EncoderBase
46 | FACTORIZED = EncoderFactorized
47 |
48 |
49 | class Decoder3DType(Enum):
50 | BASE = DecoderBase
51 | FACTORIZED = DecoderFactorized
52 |
53 |
54 | class ContinuousFormulation(Enum):
55 | VAE = GaussianDistribution
56 | AE = IdentityDistribution
57 |
58 |
59 | class DiscreteQuantizer(Enum):
60 | VQ = VectorQuantizer
61 | LFQ = LFQuantizer
62 | FSQ = FSQuantizer
63 | RESFSQ = ResidualFSQuantizer
64 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/modules/distributions.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The distribution modes to use for continuous image tokenizers."""
16 |
17 | import torch
18 |
19 |
20 | class IdentityDistribution(torch.nn.Module):
21 | def __init__(self):
22 | super().__init__()
23 |
24 | def forward(self, parameters):
25 | return parameters, (torch.tensor([0.0]), torch.tensor([0.0]))
26 |
27 |
28 | class GaussianDistribution(torch.nn.Module):
29 | def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0):
30 | super().__init__()
31 | self.min_logvar = min_logvar
32 | self.max_logvar = max_logvar
33 |
34 | def sample(self, mean, logvar):
35 | std = torch.exp(0.5 * logvar)
36 | return mean + std * torch.randn_like(mean)
37 |
38 | def forward(self, parameters):
39 | mean, logvar = torch.chunk(parameters, 2, dim=1)
40 | logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar)
41 | return self.sample(mean, logvar), (mean, logvar)
42 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/modules/utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Shared utilities for the networks module."""
16 |
17 | from typing import Any
18 |
19 | import torch
20 | from einops import pack, rearrange, unpack
21 |
22 |
23 | def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
24 | batch_size = x.shape[0]
25 | return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
26 |
27 |
28 | def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
29 | return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
30 |
31 |
32 | def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
33 | batch_size, height = x.shape[0], x.shape[-2]
34 | return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
35 |
36 |
37 | def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
38 | return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
39 |
40 |
41 | def cast_tuple(t: Any, length: int = 1) -> Any:
42 | return t if isinstance(t, tuple) else ((t,) * length)
43 |
44 |
45 | def replication_pad(x):
46 | return torch.cat([x[:, :, :1, ...], x], dim=2)
47 |
48 |
49 | def divisible_by(num: int, den: int) -> bool:
50 | return (num % den) == 0
51 |
52 |
53 | def is_odd(n: int) -> bool:
54 | return not divisible_by(n, 2)
55 |
56 |
57 | def nonlinearity(x):
58 | return x * torch.sigmoid(x)
59 |
60 |
61 | def Normalize(in_channels, num_groups=32):
62 | return torch.nn.GroupNorm(
63 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
64 | )
65 |
66 |
67 | class CausalNormalize(torch.nn.Module):
68 | def __init__(self, in_channels, num_groups=1):
69 | super().__init__()
70 | self.norm = torch.nn.GroupNorm(
71 | num_groups=num_groups,
72 | num_channels=in_channels,
73 | eps=1e-6,
74 | affine=True,
75 | )
76 | self.num_groups = num_groups
77 |
78 | def forward(self, x):
79 | # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose.
80 | # All new models should use num_groups=1, otherwise causality is not guaranteed.
81 | if self.num_groups == 1:
82 | x, batch_size = time2batch(x)
83 | return batch2time(self.norm(x), batch_size)
84 | return self.norm(x)
85 |
86 |
87 | def exists(v):
88 | return v is not None
89 |
90 |
91 | def default(*args):
92 | for arg in args:
93 | if exists(arg):
94 | return arg
95 | return None
96 |
97 |
98 | def pack_one(t, pattern):
99 | return pack([t], pattern)
100 |
101 |
102 | def unpack_one(t, ps, pattern):
103 | return unpack(t, ps, pattern)[0]
104 |
105 |
106 | def round_ste(z: torch.Tensor) -> torch.Tensor:
107 | """Round with straight through gradients."""
108 | zhat = z.round()
109 | return z + (zhat - z).detach()
110 |
111 |
112 | def log(t, eps=1e-5):
113 | return t.clamp(min=eps).log()
114 |
115 |
116 | def entropy(prob):
117 | return (-prob * log(prob)).sum(dim=-1)
118 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from enum import Enum
17 |
18 | from .configs import (
19 | continuous_image as continuous_image_dict,
20 | )
21 | from .configs import (
22 | discrete_image as discrete_image_dict,
23 | )
24 | from .configs import (
25 | continuous_video as continuous_video_dict,
26 | )
27 | from .configs import (
28 | discrete_video as discrete_video_dict,
29 | )
30 |
31 | from .continuous_image import ContinuousImageTokenizer
32 | from .discrete_image import DiscreteImageTokenizer
33 | from .continuous_video import (
34 | CausalContinuousVideoTokenizer,
35 | )
36 | from .discrete_video import (
37 | CausalDiscreteVideoTokenizer,
38 | )
39 |
40 |
41 | class TokenizerConfigs(Enum):
42 | CI = continuous_image_dict
43 | DI = discrete_image_dict
44 | CV = continuous_video_dict
45 | DV = discrete_video_dict
46 |
47 |
48 | class TokenizerModels(Enum):
49 | CI = ContinuousImageTokenizer
50 | DI = DiscreteImageTokenizer
51 | CV = CausalContinuousVideoTokenizer
52 | DV = CausalDiscreteVideoTokenizer
53 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/networks/configs.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The default image and video tokenizer configs."""
16 |
17 | from ..modules import (
18 | ContinuousFormulation,
19 | DiscreteQuantizer,
20 | EncoderType,
21 | DecoderType,
22 | Encoder3DType,
23 | Decoder3DType,
24 | )
25 |
26 | continuous_image = dict(
27 | # The attention resolution for res blocks.
28 | attn_resolutions=[32],
29 | # The base number of channels.
30 | channels=128,
31 |     # The channel multiplier for each resolution.
32 | channels_mult=[2, 4, 4],
33 | dropout=0.0,
34 | in_channels=3,
35 | # The spatial compression ratio.
36 | spatial_compression=16,
37 | # The number of layers in each res block.
38 | num_res_blocks=2,
39 | out_channels=3,
40 | resolution=1024,
41 | patch_size=4,
42 | patch_method="haar",
43 | # The output latent dimension (channels).
44 | latent_channels=16,
45 | # The encoder output channels just before sampling.
46 | # Which is also the decoder's input channels.
47 | z_channels=16,
48 | # A factor over the z_channels, to get the total channels the encoder should output.
49 | # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels.
50 | z_factor=1,
51 | name="CI",
52 | # What formulation to use, either "AE" or "VAE".
53 |     # Chose AE here, since the pre-trained ckpts use the AE formulation.
54 | formulation=ContinuousFormulation.AE.name,
55 | # Specify type of encoder ["Default", "LiteVAE"]
56 | encoder=EncoderType.Default.name,
57 | # Specify type of decoder ["Default"]
58 | decoder=DecoderType.Default.name,
59 | )
60 |
61 | discrete_image = dict(
62 | # The attention resolution for res blocks.
63 | attn_resolutions=[32],
64 | # The base number of channels.
65 | channels=128,
66 |     # The channel multiplier for each resolution.
67 | channels_mult=[2, 4, 4],
68 | dropout=0.0,
69 | in_channels=3,
70 | # The spatial compression ratio.
71 | spatial_compression=16,
72 | # The number of layers in each res block.
73 | num_res_blocks=2,
74 | out_channels=3,
75 | resolution=1024,
76 | patch_size=4,
77 | patch_method="haar",
78 | # The encoder output channels just before sampling.
79 | z_channels=256,
80 | # A factor over the z_channels, to get the total channels the encoder should output.
81 | # for discrete tokenization, often we directly use the vector, so z_factor=1.
82 | z_factor=1,
83 | # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ.
84 | quantizer=DiscreteQuantizer.FSQ.name,
85 | # The embedding dimension post-quantization, which is also the input channels of the decoder.
86 |     # It is also the output channel dimension of the encoder's quant_conv.
87 | embedding_dim=6,
88 |     # The number of levels to use for finite-scalar quantization (FSQ).
89 | levels=[8, 8, 8, 5, 5, 5],
90 |     # The number of quantizers to use for residual finite-scalar quantization.
91 | num_quantizers=4,
92 | name="DI",
93 | # Specify type of encoder ["Default", "LiteVAE"]
94 | encoder=EncoderType.Default.name,
95 | # Specify type of decoder ["Default"]
96 | decoder=DecoderType.Default.name,
97 | )
98 |
99 | continuous_video = dict(
100 | attn_resolutions=[32],
101 | channels=128,
102 | channels_mult=[2, 4, 4],
103 | dropout=0.0,
104 | in_channels=3,
105 | num_res_blocks=2,
106 | out_channels=3,
107 | resolution=1024,
108 | patch_size=4,
109 | patch_method="haar",
110 | latent_channels=16,
111 | z_channels=16,
112 | z_factor=1,
113 | num_groups=1,
114 | legacy_mode=False,
115 | spatial_compression=8,
116 | temporal_compression=8,
117 | formulation=ContinuousFormulation.AE.name,
118 | encoder=Encoder3DType.FACTORIZED.name,
119 | decoder=Decoder3DType.FACTORIZED.name,
120 | name="CV",
121 | )
122 |
123 | discrete_video = dict(
124 | attn_resolutions=[32],
125 | channels=128,
126 | channels_mult=[2, 4, 4],
127 | dropout=0.0,
128 | in_channels=3,
129 | num_res_blocks=2,
130 | out_channels=3,
131 | resolution=1024,
132 | patch_size=4,
133 | patch_method="haar",
134 | z_channels=16,
135 | z_factor=1,
136 | num_groups=1,
137 | legacy_mode=False,
138 | spatial_compression=16,
139 | temporal_compression=8,
140 | quantizer=DiscreteQuantizer.FSQ.name,
141 | embedding_dim=6,
142 | levels=[8, 8, 8, 5, 5, 5],
143 | encoder=Encoder3DType.FACTORIZED.name,
144 | decoder=Decoder3DType.FACTORIZED.name,
145 | name="DV",
146 | )
147 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/networks/continuous_image.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The continuous image tokenizer with VAE or AE formulation for 2D data."""
16 |
17 | from collections import OrderedDict, namedtuple
18 |
19 | import torch
20 | from loguru import logger as logging
21 | from torch import nn
22 |
23 | from ..modules import (
24 | ContinuousFormulation,
25 | DecoderType,
26 | EncoderType,
27 | )
28 |
29 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"])
30 |
31 |
32 | class ContinuousImageTokenizer(nn.Module):
33 | def __init__(
34 | self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
35 | ) -> None:
36 | super().__init__()
37 | self.name = kwargs.get("name", "ContinuousImageTokenizer")
38 | self.latent_channels = latent_channels
39 |
40 | encoder_name = kwargs.get("encoder", EncoderType.Default.name)
41 | self.encoder = EncoderType[encoder_name].value(
42 | z_channels=z_factor * z_channels, **kwargs
43 | )
44 |
45 | decoder_name = kwargs.get("decoder", DecoderType.Default.name)
46 | self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs)
47 |
48 | self.quant_conv = torch.nn.Conv2d(
49 | z_factor * z_channels, z_factor * latent_channels, 1
50 | )
51 | self.post_quant_conv = torch.nn.Conv2d(latent_channels, z_channels, 1)
52 |
53 | formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
54 | self.distribution = ContinuousFormulation[formulation_name].value()
55 | logging.info(
56 | f"{self.name} based on {formulation_name} formulation, with {kwargs}."
57 | )
58 |
59 | num_parameters = sum(param.numel() for param in self.parameters())
60 | logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
61 | logging.info(
62 | f"z_channels={z_channels}, latent_channels={self.latent_channels}."
63 | )
64 |
65 | def encoder_jit(self):
66 | return nn.Sequential(
67 | OrderedDict(
68 | [
69 | ("encoder", self.encoder),
70 | ("quant_conv", self.quant_conv),
71 | ("distribution", self.distribution),
72 | ]
73 | )
74 | )
75 |
76 | def decoder_jit(self):
77 | return nn.Sequential(
78 | OrderedDict(
79 | [
80 | ("post_quant_conv", self.post_quant_conv),
81 | ("decoder", self.decoder),
82 | ]
83 | )
84 | )
85 |
86 | def last_decoder_layer(self):
87 | return self.decoder.conv_out
88 |
89 | def encode(self, x):
90 | h = self.encoder(x)
91 | moments = self.quant_conv(h)
92 | return self.distribution(moments)
93 |
94 | def decode(self, z):
95 | z = self.post_quant_conv(z)
96 | dec = self.decoder(z)
97 | return dec
98 |
99 | def forward(self, input) -> dict[str, torch.Tensor] | NetworkEval:
100 | latent, posteriors = self.encode(input)
101 | dec = self.decode(latent)
102 | if self.training:
103 | return dict(reconstructions=dec, posteriors=posteriors, latent=latent)
104 | return NetworkEval(reconstructions=dec, posteriors=posteriors, latent=latent)
105 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/networks/continuous_video.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The causal continuous video tokenizer with VAE or AE formulation for 3D data.."""
16 | from collections import OrderedDict, namedtuple
17 |
18 | from loguru import logger as logging
19 | from torch import nn
20 |
21 | from ..modules import (
22 | ContinuousFormulation,
23 | Decoder3DType,
24 | Encoder3DType,
25 | )
26 | from ..modules.layers3d import CausalConv3d
27 |
28 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"])
29 |
30 |
31 | class CausalContinuousVideoTokenizer(nn.Module):
32 | def __init__(
33 | self, z_channels: int, z_factor: int, latent_channels: int, **kwargs
34 | ) -> None:
35 | super().__init__()
36 | self.name = kwargs.get("name", "CausalContinuousVideoTokenizer")
37 | self.latent_channels = latent_channels
38 |
39 | encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
40 | self.encoder = Encoder3DType[encoder_name].value(
41 | z_channels=z_factor * z_channels, **kwargs
42 | )
43 | if kwargs.get("temporal_compression", 4) == 4:
44 | kwargs["channels_mult"] = [2, 4]
45 | decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
46 | self.decoder = Decoder3DType[decoder_name].value(
47 | z_channels=z_channels, **kwargs
48 | )
49 |
50 | self.quant_conv = CausalConv3d(
51 | z_factor * z_channels,
52 | z_factor * latent_channels,
53 | kernel_size=1,
54 | padding=0,
55 | )
56 | self.post_quant_conv = CausalConv3d(
57 | latent_channels, z_channels, kernel_size=1, padding=0
58 | )
59 |
60 | formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name)
61 | self.distribution = ContinuousFormulation[formulation_name].value()
62 | logging.info(
63 | f"{self.name} based on {formulation_name} formulation, with {kwargs}."
64 | )
65 |
66 | num_parameters = sum(param.numel() for param in self.parameters())
67 | logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
68 | logging.info(
69 | f"z_channels={z_channels}, latent_channels={self.latent_channels}."
70 | )
71 |
72 | def encoder_jit(self):
73 | return nn.Sequential(
74 | OrderedDict(
75 | [
76 | ("encoder", self.encoder),
77 | ("quant_conv", self.quant_conv),
78 | ("distribution", self.distribution),
79 | ]
80 | )
81 | )
82 |
83 | def decoder_jit(self):
84 | return nn.Sequential(
85 | OrderedDict(
86 | [
87 | ("post_quant_conv", self.post_quant_conv),
88 | ("decoder", self.decoder),
89 | ]
90 | )
91 | )
92 |
93 | def last_decoder_layer(self):
94 | return self.decoder.conv_out
95 |
96 | def encode(self, x):
97 | h = self.encoder(x)
98 | moments = self.quant_conv(h)
99 | return self.distribution(moments)
100 |
101 | def decode(self, z):
102 | z = self.post_quant_conv(z)
103 | return self.decoder(z)
104 |
105 | def forward(self, input):
106 | latent, posteriors = self.encode(input)
107 | reconstructions = self.decode(latent)
108 | if self.training:
109 | return dict(
110 | reconstructions=reconstructions,
111 | posteriors=posteriors,
112 | latent=latent,
113 | )
114 | return NetworkEval(
115 | reconstructions=reconstructions,
116 | posteriors=posteriors,
117 | latent=latent,
118 | )
119 |
--------------------------------------------------------------------------------
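For orientation, a minimal usage sketch of the encode/decode contract above. It assumes `model` is an already-constructed CausalContinuousVideoTokenizer in eval mode; the 17x256x256 clip size is an arbitrary example, not a required shape.

    import torch

    video = torch.randn(1, 3, 17, 256, 256)   # B x 3 x T x H x W, range [-1, 1]
    latent, posteriors = model.encode(video)  # same unpacking as in forward()
    recon = model.decode(latent)              # B x 3 x T x H x W reconstruction
    out = model(video)                        # eval mode -> NetworkEval(reconstructions, posteriors, latent)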
/open_tokenizer/model/cosmos_tokenizer/networks/discrete_image.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The network definition for discrete image tokenization with VQ, LFQ, FSQ or ResidualFSQ."""
16 | from collections import OrderedDict, namedtuple
17 |
18 | import torch
19 | from loguru import logger as logging
20 | from torch import nn
21 |
22 | from ..modules import DecoderType, DiscreteQuantizer, EncoderType
23 | from ..modules.quantizers import InvQuantizerJit
24 |
25 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])
26 |
27 |
28 | class DiscreteImageTokenizer(nn.Module):
29 | def __init__(self, z_channels: int, embedding_dim: int, **kwargs) -> None:
30 | super().__init__()
31 | self.name = kwargs.get("name", "DiscreteImageTokenizer")
32 | self.embedding_dim = embedding_dim
33 |
34 | encoder_name = kwargs.get("encoder", EncoderType.Default.name)
35 | self.encoder = EncoderType[encoder_name].value(z_channels=z_channels, **kwargs)
36 |
37 | decoder_name = kwargs.get("decoder", DecoderType.Default.name)
38 | self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs)
39 | self.quant_conv = nn.Conv2d(z_channels, embedding_dim, 1)
40 | self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1)
41 |
42 | quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name)
43 | if quantizer_name == DiscreteQuantizer.VQ.name:
44 | assert (
45 | "num_embeddings" in kwargs
46 | ), f"`num_embeddings` must be provided for {quantizer_name}."
47 | kwargs.update(dict(embedding_dim=embedding_dim))
48 | elif quantizer_name == DiscreteQuantizer.LFQ.name:
49 | assert (
50 | "codebook_size" in kwargs
51 | ), f"`codebook_size` must be provided for {quantizer_name}."
52 | assert (
53 | "codebook_dim" in kwargs
54 | ), f"`codebook_dim` must be provided for {quantizer_name}."
55 | elif quantizer_name == DiscreteQuantizer.FSQ.name:
56 | assert (
57 | "levels" in kwargs
58 | ), f"`levels` must be provided for {quantizer_name}."
59 | elif quantizer_name == DiscreteQuantizer.RESFSQ.name:
60 | assert (
61 | "levels" in kwargs
62 |             ), f"`levels` must be provided for {quantizer_name}."
63 | assert (
64 | "num_quantizers" in kwargs
65 | ), f"`num_quantizers` must be provided for {quantizer_name}."
66 | self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs)
67 | logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.")
68 |
69 | num_parameters = sum(param.numel() for param in self.parameters())
70 | logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
71 | logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.")
72 |
73 | def to(self, *args, **kwargs):
74 | setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16))
75 | return super(DiscreteImageTokenizer, self).to(*args, **kwargs)
76 |
77 | def encoder_jit(self):
78 | return nn.Sequential(
79 | OrderedDict(
80 | [
81 | ("encoder", self.encoder),
82 | ("quant_conv", self.quant_conv),
83 | ("quantizer", self.quantizer),
84 | ]
85 | )
86 | )
87 |
88 | def decoder_jit(self):
89 | return nn.Sequential(
90 | OrderedDict(
91 | [
92 | ("inv_quant", InvQuantizerJit(self.quantizer)),
93 | ("post_quant_conv", self.post_quant_conv),
94 | ("decoder", self.decoder),
95 | ]
96 | )
97 | )
98 |
99 | def last_decoder_layer(self):
100 | return self.decoder.conv_out
101 |
102 | def encode(self, x):
103 | h = self.encoder(x)
104 | h = self.quant_conv(h)
105 | return self.quantizer(h)
106 |
107 | def decode(self, quant):
108 | quant = self.post_quant_conv(quant)
109 | return self.decoder(quant)
110 |
111 | def decode_code(self, code_b):
112 | quant_b = self.quantizer.indices_to_codes(code_b)
113 | quant_b = self.post_quant_conv(quant_b)
114 | return self.decoder(quant_b)
115 |
116 | def forward(self, input):
117 | quant_info, quant_codes, quant_loss = self.encode(input)
118 | reconstructions = self.decode(quant_codes)
119 | if self.training:
120 | return dict(
121 | reconstructions=reconstructions,
122 | quant_loss=quant_loss,
123 | quant_info=quant_info,
124 | )
125 | return NetworkEval(
126 | reconstructions=reconstructions,
127 | quant_loss=quant_loss,
128 | quant_info=quant_info,
129 | )
130 |
--------------------------------------------------------------------------------
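A similarly hedged sketch of the discrete round trip, assuming `tokenizer` is an already-built DiscreteImageTokenizer in eval mode (the constructor kwargs depend on the chosen quantizer, e.g. `levels` for FSQ):

    import torch

    image = torch.randn(1, 3, 256, 256)                   # B x 3 x H x W, range [-1, 1]
    indices, codes, quant_loss = tokenizer.encode(image)   # (quant_info, quant_codes, quant_loss)
    recon_from_codes = tokenizer.decode(codes)             # decode continuous codes
    recon_from_indices = tokenizer.decode_code(indices)    # decode discrete indices via indices_to_codes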
/open_tokenizer/model/cosmos_tokenizer/networks/discrete_video.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """The network definition for discrete video tokenizer with VQ, LFQ, FSQ or ResidualFSQ. """
16 | from collections import OrderedDict, namedtuple
17 |
18 | import torch
19 | from loguru import logger as logging
20 | from torch import nn
21 |
22 | from ..modules import (
23 | Decoder3DType,
24 | DiscreteQuantizer,
25 | Encoder3DType,
26 | )
27 | from ..modules.layers3d import CausalConv3d
28 | from ..modules.quantizers import InvQuantizerJit
29 |
30 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])
31 |
32 |
33 | class CausalDiscreteVideoTokenizer(nn.Module):
34 | def __init__(
35 | self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs
36 | ) -> None:
37 | super().__init__()
38 | self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer")
39 | self.embedding_dim = embedding_dim
40 |
41 | encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name)
42 | self.encoder = Encoder3DType[encoder_name].value(
43 | z_channels=z_factor * z_channels, **kwargs
44 | )
45 |
46 | decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name)
47 | self.decoder = Decoder3DType[decoder_name].value(
48 | z_channels=z_channels, **kwargs
49 | )
50 |
51 | self.quant_conv = CausalConv3d(
52 | z_factor * z_channels, embedding_dim, kernel_size=1, padding=0
53 | )
54 | self.post_quant_conv = CausalConv3d(
55 | embedding_dim, z_channels, kernel_size=1, padding=0
56 | )
57 |
58 | quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name)
59 | if quantizer_name == DiscreteQuantizer.VQ.name:
60 | assert (
61 | "num_embeddings" in kwargs
62 | ), f"`num_embeddings` must be provided for {quantizer_name}."
63 | kwargs.update(dict(embedding_dim=embedding_dim))
64 | elif quantizer_name == DiscreteQuantizer.LFQ.name:
65 | assert (
66 | "codebook_size" in kwargs
67 | ), f"`codebook_size` must be provided for {quantizer_name}."
68 | assert (
69 | "codebook_dim" in kwargs
70 | ), f"`codebook_dim` must be provided for {quantizer_name}."
71 | elif quantizer_name == DiscreteQuantizer.FSQ.name:
72 | assert (
73 | "levels" in kwargs
74 | ), f"`levels` must be provided for {quantizer_name}."
75 | elif quantizer_name == DiscreteQuantizer.RESFSQ.name:
76 | assert (
77 | "levels" in kwargs
78 | ), f"`levels` must be provided for {quantizer_name}."
79 | assert (
80 | "num_quantizers" in kwargs
81 | ), f"`num_quantizers` must be provided for {quantizer_name}."
82 | self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs)
83 | logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.")
84 |
85 | num_parameters = sum(param.numel() for param in self.parameters())
86 | logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
87 | logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.")
88 |
89 | def to(self, *args, **kwargs):
90 | setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16))
91 | return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs)
92 |
93 | def encoder_jit(self):
94 | return nn.Sequential(
95 | OrderedDict(
96 | [
97 | ("encoder", self.encoder),
98 | ("quant_conv", self.quant_conv),
99 | ("quantizer", self.quantizer),
100 | ]
101 | )
102 | )
103 |
104 | def decoder_jit(self):
105 | return nn.Sequential(
106 | OrderedDict(
107 | [
108 | ("inv_quant", InvQuantizerJit(self.quantizer)),
109 | ("post_quant_conv", self.post_quant_conv),
110 | ("decoder", self.decoder),
111 | ]
112 | )
113 | )
114 |
115 | def last_decoder_layer(self):
116 | return self.decoder.conv_out
117 |
118 | def encode(self, x):
119 | h = self.encoder(x)
120 | h = self.quant_conv(h)
121 | return self.quantizer(h)
122 |
123 | def decode(self, quant):
124 | quant = self.post_quant_conv(quant)
125 | return self.decoder(quant)
126 |
127 | def decode_code(self, code_b):
128 | quant_b = self.quantizer.indices_to_codes(code_b)
129 | quant_b = self.post_quant_conv(quant_b)
130 | return self.decoder(quant_b)
131 |
132 | def forward(self, input):
133 | quant_info, quant_codes, quant_loss = self.encode(input)
134 | reconstructions = self.decode(quant_codes)
135 | if self.training:
136 | return dict(
137 | reconstructions=reconstructions,
138 | quant_loss=quant_loss,
139 | quant_info=quant_info,
140 | )
141 | return NetworkEval(
142 | reconstructions=reconstructions,
143 | quant_loss=quant_loss,
144 | quant_info=quant_info,
145 | )
146 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/video_cli.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A CLI to run CausalVideoTokenizer on plain videos based on torch.jit.
16 |
17 | Usage:
18 | python3 -m cosmos_tokenizer.video_cli \
19 | --video_pattern 'path/to/video/samples/*.mp4' \
20 | --output_dir ./reconstructions \
21 | --checkpoint_enc ./pretrained_ckpts/CosmosCV_f4x8x8/encoder.jit \
22 | --checkpoint_dec ./pretrained_ckpts/CosmosCV_f4x8x8/decoder.jit
23 |
24 | Optionally, you can run the model in pure PyTorch mode:
25 | python3 -m cosmos_tokenizer.video_cli \
26 | --video_pattern 'path/to/video/samples/*.mp4' \
27 | --mode=torch \
28 | --tokenizer_type=CV \
29 | --temporal_compression=4 \
30 | --spatial_compression=8 \
31 | --checkpoint_enc ./pretrained_ckpts/CosmosCV_f4x8x8/encoder.jit \
32 | --checkpoint_dec ./pretrained_ckpts/CosmosCV_f4x8x8/decoder.jit
33 | """
34 |
35 | import os
36 | from argparse import ArgumentParser, Namespace
37 | from typing import Any
38 | import sys
39 |
40 | import numpy as np
41 | from loguru import logger as logging
42 |
43 | from .networks import TokenizerConfigs
44 | from .utils import (
45 | get_filepaths,
46 | get_output_filepath,
47 | read_video,
48 | resize_video,
49 | write_video,
50 | )
51 | from .video_lib import CausalVideoTokenizer
52 |
53 |
54 | def _parse_args() -> Namespace:
55 | parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.")
56 | parser.add_argument(
57 | "--video_pattern",
58 | type=str,
59 | default="path/to/videos/*.mp4",
60 | help="Glob pattern.",
61 | )
62 | parser.add_argument(
63 | "--checkpoint",
64 | type=str,
65 | default=None,
66 | help="JIT full Autoencoder model filepath.",
67 | )
68 | parser.add_argument(
69 | "--checkpoint_enc",
70 | type=str,
71 | default=None,
72 | help="JIT Encoder model filepath.",
73 | )
74 | parser.add_argument(
75 | "--checkpoint_dec",
76 | type=str,
77 | default=None,
78 | help="JIT Decoder model filepath.",
79 | )
80 | parser.add_argument(
81 | "--tokenizer_type",
82 | type=str,
83 | choices=["CV", "DV"],
84 | help="Specifies the tokenizer type.",
85 | )
86 | parser.add_argument(
87 | "--spatial_compression",
88 | type=int,
89 | choices=[8, 16],
90 | default=8,
91 | help="The spatial compression factor.",
92 | )
93 | parser.add_argument(
94 | "--temporal_compression",
95 | type=int,
96 | choices=[4, 8],
97 | default=4,
98 | help="The temporal compression factor.",
99 | )
100 | parser.add_argument(
101 | "--mode",
102 | type=str,
103 | choices=["torch", "jit"],
104 | default="jit",
105 | help="Specify the backend: native 'torch' or 'jit' (default: 'jit')",
106 | )
107 | parser.add_argument(
108 | "--short_size",
109 | type=int,
110 | default=None,
111 | help="The size to resample inputs. None, by default.",
112 | )
113 | parser.add_argument(
114 | "--temporal_window",
115 | type=int,
116 | default=17,
117 | help="The temporal window to operate at a time.",
118 | )
119 | parser.add_argument(
120 | "--dtype",
121 | type=str,
122 | default="bfloat16",
123 | help="Sets the precision, default bfloat16.",
124 | )
125 | parser.add_argument(
126 | "--device",
127 | type=str,
128 | default="cuda",
129 | help="Device for invoking the model.",
130 | )
131 | parser.add_argument(
132 | "--output_dir", type=str, default=None, help="Output directory."
133 | )
134 | parser.add_argument(
135 | "--output_fps",
136 | type=float,
137 | default=24.0,
138 | help="Output frames-per-second (FPS).",
139 | )
140 | parser.add_argument(
141 | "--save_input",
142 | action="store_true",
143 |         help="If on, the input video will be saved as well.",
144 | )
145 |
146 | args = parser.parse_args()
147 | return args
148 |
149 |
150 | logging.info("Initializing args ...")
151 | args = _parse_args()
152 | if args.mode == "torch" and args.tokenizer_type not in ["CV", "DV"]:
153 | logging.error("'torch' backend requires the tokenizer_type of 'CV' or 'DV'.")
154 | sys.exit(1)
155 |
156 |
157 | def _run_eval() -> None:
158 | """Invokes JIT-compiled CausalVideoTokenizer on an input video."""
159 |
160 | if (
161 | args.checkpoint_enc is None
162 | and args.checkpoint_dec is None
163 | and args.checkpoint is None
164 | ):
165 | logging.warning(
166 |             "Aborting. Provide either both the encoder and decoder JIT checkpoints, or the full autoencoder JIT model."
167 | )
168 | return
169 |
170 | if args.mode == "torch":
171 | tokenizer_config = TokenizerConfigs[args.tokenizer_type].value
172 | tokenizer_config.update(dict(spatial_compression=args.spatial_compression))
173 | tokenizer_config.update(dict(temporal_compression=args.temporal_compression))
174 | else:
175 | tokenizer_config = None
176 |
177 | logging.info(
178 | f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..."
179 | )
180 | autoencoder = CausalVideoTokenizer(
181 | checkpoint=args.checkpoint,
182 | checkpoint_enc=args.checkpoint_enc,
183 | checkpoint_dec=args.checkpoint_dec,
184 | tokenizer_config=tokenizer_config,
185 | device=args.device,
186 | dtype=args.dtype,
187 | )
188 |
189 | logging.info(f"Looking for files matching video_pattern={args.video_pattern} ...")
190 | filepaths = get_filepaths(args.video_pattern)
191 | logging.info(f"Found {len(filepaths)} videos from {args.video_pattern}.")
192 |
193 | for filepath in filepaths:
194 | logging.info(f"Reading video {filepath} ...")
195 | video = read_video(filepath)
196 | video = resize_video(video, short_size=args.short_size)
197 |
198 |         logging.info("Invoking the autoencoder model ...")
199 | batch_video = video[np.newaxis, ...]
200 | output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0]
201 | logging.info("Constructing output filepath ...")
202 | output_filepath = get_output_filepath(filepath, output_dir=args.output_dir)
203 |         logging.info(f"Outputting {output_filepath} ...")
204 | write_video(output_filepath, output_video, fps=args.output_fps)
205 | if args.save_input:
206 | ext = os.path.splitext(output_filepath)[-1]
207 | input_filepath = output_filepath.replace(ext, "_input" + ext)
208 | write_video(input_filepath, video, fps=args.output_fps)
209 |
210 |
211 | @logging.catch(reraise=True)
212 | def main() -> None:
213 | _run_eval()
214 |
215 |
216 | if __name__ == "__main__":
217 | main()
218 |
--------------------------------------------------------------------------------
/open_tokenizer/model/cosmos_tokenizer/video_lib.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """A library for Causal Video Tokenizer inference."""
16 |
17 | import numpy as np
18 | import torch
19 | from typing import Any
20 |
21 | from tqdm import tqdm
22 |
23 | from .utils import (
24 | load_model,
25 | load_encoder_model,
26 | load_decoder_model,
27 | numpy2tensor,
28 | pad_video_batch,
29 | tensor2numpy,
30 | unpad_video_batch,
31 | )
32 |
33 |
34 | class CausalVideoTokenizer(torch.nn.Module):
35 | def __init__(
36 | self,
37 | checkpoint: str = None,
38 | checkpoint_enc: str = None,
39 | checkpoint_dec: str = None,
40 | tokenizer_config: dict[str, Any] = None,
41 | device: str = "cuda",
42 | dtype: str = "bfloat16",
43 | ) -> None:
44 | super().__init__()
45 | self._device = device
46 | self._dtype = getattr(torch, dtype)
47 | self._full_model = (
48 | load_model(checkpoint, tokenizer_config, device).to(self._dtype)
49 | if checkpoint is not None
50 | else None
51 | )
52 | self._enc_model = (
53 | load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype)
54 | if checkpoint_enc is not None
55 | else None
56 | )
57 | self._dec_model = (
58 | load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype)
59 | if checkpoint_dec is not None
60 | else None
61 | )
62 |
63 | @torch.no_grad()
64 | def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor:
65 |         """Reconstructs a batch of video tensors after embedding into a latent.
66 |
67 |         Args:
68 |             input_tensor: The input video Bx3xTxHxW layout, range [-1..1].
69 | Returns:
70 | The reconstructed video, layout Bx3xTxHxW, range [-1..1].
71 | """
72 | if self._full_model is not None:
73 | output_tensor = self._full_model(input_tensor)
74 | output_tensor = (
75 | output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor
76 | )
77 | else:
78 | output_latent = self.encode(input_tensor)[0]
79 | output_tensor = self.decode(output_latent)
80 | return output_tensor
81 |
82 | @torch.no_grad()
83 | def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]:
84 |         """Encodes a video tensor into a CausalVideo latent or code.
85 |
86 | Args:
87 | input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1].
88 | Returns:
89 | For causal continuous video (CV) tokenizer, the tuple contains:
90 | - The latent embedding, Bx16x(t)x(h)x(w), where the compression
91 | rate is (T/t x H/h x W/w), and channel dimension of 16.
92 | For causal discrete video (DV) tokenizer, the tuple contains:
93 | 1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which
94 | is formed by FSQ levels of (8,8,8,5,5,5).
95 | 2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate
96 | is again (T/t x H/h x W/w), and channel dimension of 6.
97 | """
98 |         assert input_tensor.ndim == 5, "input video should be 5D."
99 |
100 | output_latent = self._enc_model(input_tensor)
101 | if isinstance(output_latent, torch.Tensor):
102 | return output_latent
103 | return "pad", output_latent[0], output_latent[1]
104 |
105 | @torch.no_grad()
106 | def decode(self, input_latent: torch.Tensor) -> torch.Tensor:
107 |         """Decodes a CausalVideo latent or code into a video tensor.
108 |
109 | Args:
110 | input_latent: The continuous latent Bx16xtxhxw for CV,
111 | or the discrete indices Bxtxhxw for DV.
112 | Returns:
113 | The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1].
114 | """
115 | assert (
116 | input_latent.ndim >= 4
117 |         ), "input latent should be 5D for continuous and 4D for discrete."
118 | return self._dec_model(input_latent)
119 |
120 | def forward(
121 | self,
122 | video: np.ndarray,
123 | temporal_window: int = 17,
124 | ) -> np.ndarray:
125 | """Reconstructs video using a pre-trained CausalTokenizer autoencoder.
126 | Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer
127 | in a sliding manner with a `temporal_window` size.
128 |
129 | Args:
130 | video: The input video BxTxHxWx3 layout, range [0..255].
131 |             temporal_window: The length of the temporal window to process, default=17.
132 | Returns:
133 | The reconstructed video in range [0..255], layout BxTxHxWx3.
134 | """
135 |         assert video.ndim == 5, "input video should be 5D."
136 | num_frames = video.shape[1] # can be of any length.
137 | output_video_list = []
138 | for idx in tqdm(range(0, (num_frames - 1) // temporal_window + 1)):
139 | # Input video for the current window.
140 | start, end = idx * temporal_window, (idx + 1) * temporal_window
141 | input_video = video[:, start:end, ...]
142 |
143 | # Spatio-temporally pad input_video so it's evenly divisible.
144 | padded_input_video, crop_region = pad_video_batch(input_video)
145 | input_tensor = numpy2tensor(
146 | padded_input_video, dtype=self._dtype, device=self._device
147 | )
148 | output_tensor = self.autoencode(input_tensor)
149 | padded_output_video = tensor2numpy(output_tensor)
150 | output_video = unpad_video_batch(padded_output_video, crop_region)
151 |
152 | output_video_list.append(output_video)
153 | return np.concatenate(output_video_list, axis=1)
154 |
--------------------------------------------------------------------------------
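The forward() loop above walks the clip in ceil-division windows: a 49-frame video with the default temporal_window=17 is processed as three chunks of 17, 17 and 15 frames, and the per-window reconstructions are concatenated along the time axis. A standalone sketch of that index arithmetic (no model required; the example values are arbitrary):

    num_frames, temporal_window = 49, 17
    windows = [
        (idx * temporal_window, min((idx + 1) * temporal_window, num_frames))
        for idx in range((num_frames - 1) // temporal_window + 1)
    ]
    print(windows)  # [(0, 17), (17, 34), (34, 49)]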
/open_tokenizer/model/misc.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | import torch
3 | from typing import (
4 | Any,
5 | Callable,
6 | Dict,
7 | Iterable,
8 | List,
9 | NamedTuple,
10 | NewType,
11 | Optional,
12 | Sized,
13 | Tuple,
14 | Type,
15 | TypeVar,
16 | Union,
17 | )
18 | try:
19 | from typing import Literal
20 | except ImportError:
21 | from typing_extensions import Literal
22 |
23 | # Tensor dtype
24 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
25 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
26 |
27 | # Config type
28 | from omegaconf import DictConfig
29 |
30 | # PyTorch Tensor type
31 | from torch import Tensor
32 |
33 | # Runtime type checking decorator
34 | from typeguard import typechecked as typechecker
35 |
36 |
37 | def broadcast(tensor, src=0):
38 | if not _distributed_available():
39 | return tensor
40 | else:
41 | torch.distributed.broadcast(tensor, src=src)
42 | return tensor
43 |
44 | def _distributed_available():
45 | return torch.distributed.is_available() and torch.distributed.is_initialized()
46 |
47 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
48 |     # added by Xavier -- drop the '--local-rank' key that the multi-node launcher injects into the config
49 |     if '--local-rank' in cfg:
50 |         del cfg['--local-rank']
51 |
52 | scfg = OmegaConf.structured(fields(**cfg))
53 | return scfg
--------------------------------------------------------------------------------
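parse_structured simply validates a plain dict (or DictConfig) against a dataclass schema and returns a structured DictConfig. A minimal sketch with a hypothetical Config dataclass, assuming the module is importable as open_tokenizer.model.misc in this layout:

    from dataclasses import dataclass

    from open_tokenizer.model.misc import parse_structured

    @dataclass
    class Config:              # hypothetical schema, for illustration only
        lr: float = 1e-4
        batch_size: int = 8

    cfg = parse_structured(Config, {"lr": 3e-4})
    print(cfg.lr, cfg.batch_size)   # 0.0003 8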
/open_tokenizer/model/tiktok/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (2024) Bytedance Ltd. and/or its affiliates
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
--------------------------------------------------------------------------------
/open_tokenizer/model/tiktok/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from .ema_model import EMAModel
3 | from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, MLMLoss, ARLoss
4 | from .blocks import TiTokEncoder, TiTokDecoder, UViTBlock
5 | from .maskgit_vqgan import Decoder as Pixel_Decoder
6 | from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
--------------------------------------------------------------------------------
/open_tokenizer/model/tiktok/modules/base_model.py:
--------------------------------------------------------------------------------
1 | """This file contains some base class implementation for models.
2 |
3 | This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
4 | All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
5 |
6 | Reference:
7 | https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py
8 | """
9 | import os
10 | from typing import Union, Callable, Dict, Optional
11 |
12 | import torch
13 |
14 |
15 | class BaseModel(torch.nn.Module):
16 |
17 | def __init__(self):
18 | super().__init__()
19 |
20 | def save_pretrained_weight(
21 | self,
22 | save_directory: Union[str, os.PathLike],
23 | save_function: Callable = None,
24 | state_dict: Optional[Dict[str, torch.Tensor]] = None,
25 | ):
26 | """Saves a model and its configuration file to a directory.
27 |
28 | Args:
29 | save_directory: A string or os.PathLike, directory to which to save.
30 | Will be created if it doesn't exist.
31 | save_function: A Callable function, the function to use to save the state dictionary.
32 |                 Useful on distributed training like TPUs when one needs to replace `torch.save` by
33 | another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`.
34 | state_dict: A dictionary from str to torch.Tensor, the state dictionary to save.
35 | If `None`, the model's state dictionary will be saved.
36 | """
37 | if os.path.isfile(save_directory):
38 | print(f"Provided path ({save_directory}) should be a directory, not a file")
39 | return
40 |
41 | if save_function is None:
42 | save_function = torch.save
43 |
44 | os.makedirs(save_directory, exist_ok=True)
45 |
46 | model_to_save = self
47 |
48 | if state_dict is None:
49 | state_dict = model_to_save.state_dict()
50 | weights_name = "pytorch_model.bin"
51 |
52 | save_function(state_dict, os.path.join(save_directory, weights_name))
53 |
54 | print(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
55 |
56 | def load_pretrained_weight(
57 | self,
58 | pretrained_model_path: Union[str, os.PathLike],
59 | strict_loading: bool = True,
60 | torch_dtype: Optional[torch.dtype] = None
61 | ):
62 | r"""Instantiates a pretrained pytorch model from a pre-trained model configuration.
63 |
64 | The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
65 | the model, you should first set it back in training mode with `model.train()`.
66 |
67 | Args:
68 | pretrained_model_path: A string or os.PathLike, a path to a *directory* or *file* containing model weights.
69 |
70 | Raises:
71 | ValueError: If pretrained_model_path does not exist.
72 | """
73 | # If pretrained_model_path is a file, set model_file to this file.
74 | if os.path.isfile(pretrained_model_path):
75 | model_file = pretrained_model_path
76 | # If pretrained_model_path is a directory, set model_file to the path of the
77 | # file "pytorch_model.bin" in this directory.
78 | elif os.path.isdir(pretrained_model_path):
79 | pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin")
80 | if os.path.isfile(pretrained_model_path):
81 | model_file = pretrained_model_path
82 | else:
83 | raise ValueError(f"{pretrained_model_path} does not exist")
84 | else:
85 | raise ValueError(f"{pretrained_model_path} does not exist")
86 |
87 | # Load model state from checkpoint.
88 | checkpoint = torch.load(model_file, map_location="cpu")
89 | # Load state dictionary into self.
90 | msg = self.load_state_dict(checkpoint, strict=strict_loading)
91 | # Print information about loading weights.
92 | print(f"loading weight from {model_file}, msg: {msg}")
93 | # If torch_dtype is specified and is a valid torch.dtype, convert self to this dtype.
94 | if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
95 | raise ValueError(
96 | f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
97 | )
98 | elif torch_dtype is not None:
99 | self.to(torch_dtype)
100 |
101 | # Set model in evaluation mode to deactivate DropOut modules by default.
102 | self.eval()
103 |
104 | def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
105 | """Gets the number of parameters in the module.
106 |
107 | Args:
108 | only_trainable: A boolean, whether to only include trainable parameters.
109 | exclude_embeddings: A boolean, whether to exclude parameters associated with embeddings.
110 |
111 | Returns:
112 | An integer, the number of parameters.
113 | """
114 |
115 | if exclude_embeddings:
116 | embedding_param_names = [
117 | f"{name}.weight"
118 | for name, module_type in self.named_modules()
119 | if isinstance(module_type, torch.nn.Embedding)
120 | ]
121 | non_embedding_parameters = [
122 | parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
123 | ]
124 | return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
125 | else:
126 | return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
127 |
128 |
--------------------------------------------------------------------------------
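A quick round-trip sketch for the helpers above, using a trivial BaseModel subclass (Tiny and the ./tiny_ckpt path are illustrative, not part of the repo):

    import torch

    from open_tokenizer.model.tiktok.modules.base_model import BaseModel

    class Tiny(BaseModel):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(4, 4)

    m = Tiny()
    m.save_pretrained_weight("./tiny_ckpt")    # writes ./tiny_ckpt/pytorch_model.bin
    m2 = Tiny()
    m2.load_pretrained_weight("./tiny_ckpt")   # loads the directory's pytorch_model.bin and calls eval()
    print(m2.num_parameters())                 # 20 = 4*4 weights + 4 biases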
/open_tokenizer/model/tiktok/modules/discriminator.py:
--------------------------------------------------------------------------------
1 | """This file contains some base implementation for discrminators.
2 |
3 | Copyright (2024) Bytedance Ltd. and/or its affiliates
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 |
17 | TODO: Add reference to Mark Weber's tech report on the improved discriminator architecture.
18 | """
19 | import functools
20 | import math
21 | from typing import Tuple
22 |
23 |
24 | import torch
25 | import torch.nn as nn
26 | import torch.nn.functional as F
27 |
28 | from .maskgit_vqgan import Conv2dSame
29 |
30 |
31 | class BlurBlock(torch.nn.Module):
32 | def __init__(self,
33 | kernel: Tuple[int] = (1, 3, 3, 1)
34 | ):
35 | super().__init__()
36 |
37 | kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False)
38 | kernel = kernel[None, :] * kernel[:, None]
39 | kernel /= kernel.sum()
40 | kernel = kernel.unsqueeze(0).unsqueeze(0)
41 | self.register_buffer("kernel", kernel)
42 |
43 | def calc_same_pad(self, i: int, k: int, s: int) -> int:
44 | return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0)
45 |
46 | def forward(self, x: torch.Tensor) -> torch.Tensor:
47 | ic, ih, iw = x.size()[-3:]
48 | pad_h = self.calc_same_pad(i=ih, k=4, s=2)
49 | pad_w = self.calc_same_pad(i=iw, k=4, s=2)
50 | if pad_h > 0 or pad_w > 0:
51 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
52 |
53 | weight = self.kernel.expand(ic, -1, -1, -1)
54 |
55 | out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1])
56 | return out
57 |
58 |
59 | class NLayerDiscriminator(torch.nn.Module):
60 | def __init__(
61 | self,
62 | num_channels: int = 3,
63 | hidden_channels: int = 128,
64 | num_stages: int = 3,
65 | blur_resample: bool = True,
66 | blur_kernel_size: int = 4
67 | ):
68 | """ Initializes the NLayerDiscriminator.
69 |
70 | Args:
71 | num_channels -> int: The number of input channels.
72 | hidden_channels -> int: The number of hidden channels.
73 | num_stages -> int: The number of stages.
74 | blur_resample -> bool: Whether to use blur resampling.
75 | blur_kernel_size -> int: The blur kernel size.
76 | """
77 | super().__init__()
78 | assert num_stages > 0, "Discriminator cannot have 0 stages"
79 |         assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3, 5] when blur resampling is enabled"
80 |
81 | in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages)))
82 | init_kernel_size = 5
83 | activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1)
84 |
85 | self.block_in = torch.nn.Sequential(
86 | Conv2dSame(
87 | num_channels,
88 | hidden_channels,
89 | kernel_size=init_kernel_size
90 | ),
91 | activation(),
92 | )
93 |
94 | BLUR_KERNEL_MAP = {
95 | 3: (1,2,1),
96 | 4: (1,3,3,1),
97 | 5: (1,4,6,4,1),
98 | }
99 |
100 | discriminator_blocks = []
101 | for i_level in range(num_stages):
102 | in_channels = hidden_channels * in_channel_mult[i_level]
103 | out_channels = hidden_channels * in_channel_mult[i_level + 1]
104 | block = torch.nn.Sequential(
105 | Conv2dSame(
106 | in_channels,
107 | out_channels,
108 | kernel_size=3,
109 | ),
110 | torch.nn.AvgPool2d(kernel_size=2, stride=2) if not blur_resample else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]),
111 | torch.nn.GroupNorm(32, out_channels),
112 | activation(),
113 | )
114 | discriminator_blocks.append(block)
115 |
116 | self.blocks = torch.nn.ModuleList(discriminator_blocks)
117 |
118 | self.pool = torch.nn.AdaptiveMaxPool2d((16, 16))
119 |
120 | self.to_logits = torch.nn.Sequential(
121 | Conv2dSame(out_channels, out_channels, 1),
122 | activation(),
123 | Conv2dSame(out_channels, 1, kernel_size=5)
124 | )
125 |
126 | def forward(self, x: torch.Tensor) -> torch.Tensor:
127 | """ Forward pass.
128 |
129 | Args:
130 | x -> torch.Tensor: The input tensor.
131 |
132 | Returns:
133 | output -> torch.Tensor: The output tensor.
134 | """
135 | hidden_states = self.block_in(x)
136 | for block in self.blocks:
137 | hidden_states = block(hidden_states)
138 |
139 | hidden_states = self.pool(hidden_states)
140 |
141 | return self.to_logits(hidden_states)
142 |
--------------------------------------------------------------------------------
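With the defaults above (three stages of stride-2 downsampling, a 16x16 adaptive max-pool, and 'same'-padded convolutions), the discriminator maps an RGB image to a 16x16 grid of patch logits. A quick shape check; the 256x256 input size is just for illustration:

    import torch

    from open_tokenizer.model.tiktok.modules.discriminator import NLayerDiscriminator

    disc = NLayerDiscriminator(num_channels=3, hidden_channels=128, num_stages=3)
    logits = disc(torch.randn(2, 3, 256, 256))
    print(logits.shape)   # expected: torch.Size([2, 1, 16, 16])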
/open_tokenizer/model/tiktok/modules/perceptual_loss.py:
--------------------------------------------------------------------------------
1 | """This file contains perceptual loss module using ConvNeXt-S.
2 |
3 | Copyright (2024) Bytedance Ltd. and/or its affiliates
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 |
18 | import torch
19 | import torch.nn.functional as F
20 |
21 | from torchvision import models
22 |
23 | _IMAGENET_MEAN = [0.485, 0.456, 0.406]
24 | _IMAGENET_STD = [0.229, 0.224, 0.225]
25 |
26 |
27 | class PerceptualLoss(torch.nn.Module):
28 | def __init__(self, model_name: str = "convnext_s"):
29 | """Initializes the PerceptualLoss class.
30 |
31 | Args:
32 | model_name: A string, the name of the perceptual loss model to use.
33 |
34 |         Raises:
35 | ValueError: If the model_name does not contain "convnext_s".
36 | """
37 | super().__init__()
38 | if "convnext_s" not in model_name:
39 | raise ValueError(f"Unsupported Perceptual Loss model name {model_name}")
40 |
41 | self.convnext = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).eval()
42 | self.register_buffer("imagenet_mean", torch.Tensor(_IMAGENET_MEAN)[None, :, None, None])
43 | self.register_buffer("imagenet_std", torch.Tensor(_IMAGENET_STD)[None, :, None, None])
44 |
45 | for param in self.parameters():
46 | param.requires_grad = False
47 |
48 | def forward(self, input: torch.Tensor, target: torch.Tensor):
49 | """Computes the perceptual loss.
50 |
51 | Args:
52 | input: A tensor of shape (B, C, H, W), the input image. Normalized to [0, 1].
53 | target: A tensor of shape (B, C, H, W), the target image. Normalized to [0, 1].
54 |
55 | Returns:
56 | A scalar tensor, the perceptual loss.
57 | """
58 | # Always in eval mode.
59 | self.eval()
60 |
61 | input = torch.nn.functional.interpolate(input, size=224, mode="bilinear", align_corners=False, antialias=True)
62 | target = torch.nn.functional.interpolate(target, size=224, mode="bilinear", align_corners=False, antialias=True)
63 | pred_input = self.convnext((input - self.imagenet_mean) / self.imagenet_std)
64 | pred_target = self.convnext((target - self.imagenet_mean) / self.imagenet_std)
65 | loss = torch.nn.functional.mse_loss(
66 | pred_input,
67 | pred_target,
68 | reduction="mean")
69 |
70 | return loss
--------------------------------------------------------------------------------
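Usage sketch for the perceptual loss (instantiating it downloads the ConvNeXt-S ImageNet weights via torchvision; the 256x256 shapes are arbitrary since inputs are resized to 224 internally):

    import torch

    from open_tokenizer.model.tiktok.modules.perceptual_loss import PerceptualLoss

    ploss = PerceptualLoss("convnext_s")
    x = torch.rand(1, 3, 256, 256)    # prediction, already normalized to [0, 1]
    y = torch.rand(1, 3, 256, 256)    # target, already normalized to [0, 1]
    print(ploss(x, y).item())         # scalar MSE between ConvNeXt outputs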
/open_tokenizer/model/tiktok/quantizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .quantizer import VectorQuantizer, DiagonalGaussianDistribution
--------------------------------------------------------------------------------
/open_tokenizer/model/tiktok/quantizer/quantizer.py:
--------------------------------------------------------------------------------
1 | """Vector quantizer.
2 |
3 | Copyright (2024) Bytedance Ltd. and/or its affiliates
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 |
17 | Reference:
18 | https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py
19 | https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py
20 | https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py
21 | """
22 | from typing import Mapping, Text, Tuple
23 |
24 | import torch
25 | from einops import rearrange
26 | from torch.cuda.amp import autocast
27 |
28 | class VectorQuantizer(torch.nn.Module):
29 | def __init__(self,
30 | codebook_size: int = 1024,
31 | token_size: int = 256,
32 | commitment_cost: float = 0.25,
33 | use_l2_norm: bool = False,
34 | ):
35 | super().__init__()
36 | self.commitment_cost = commitment_cost
37 |
38 | self.embedding = torch.nn.Embedding(codebook_size, token_size)
39 | self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
40 | self.use_l2_norm = use_l2_norm
41 |
42 | # Ensure quantization is performed using f32
43 | @autocast(enabled=False)
44 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
45 | z = z.float()
46 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
47 | z_flattened = rearrange(z, 'b h w c -> (b h w) c')
48 |
49 | if self.use_l2_norm:
50 | z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
51 | embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
52 | else:
53 | embedding = self.embedding.weight
54 | d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
55 | torch.sum(embedding**2, dim=1) - 2 * \
56 | torch.einsum('bd,dn->bn', z_flattened, embedding.T)
57 |
58 | min_encoding_indices = torch.argmin(d, dim=1) # num_ele
59 | z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)
60 |
61 | if self.use_l2_norm:
62 | z = torch.nn.functional.normalize(z, dim=-1)
63 |
64 | # compute loss for embedding
65 | commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2)
66 | codebook_loss = torch.mean((z_quantized - z.detach()) **2)
67 |
68 | loss = commitment_loss + codebook_loss
69 |
70 | # preserve gradients
71 | z_quantized = z + (z_quantized - z).detach()
72 |
73 | # reshape back to match original input shape
74 | z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
75 |
76 | result_dict = dict(
77 | quantizer_loss=loss,
78 | commitment_loss=commitment_loss,
79 | codebook_loss=codebook_loss,
80 | min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
81 | )
82 |
83 | return z_quantized, result_dict
84 |
85 | def get_codebook_entry(self, indices):
86 | if len(indices.shape) == 1:
87 | z_quantized = self.embedding(indices)
88 | elif len(indices.shape) == 2:
89 | z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
90 | else:
91 | raise NotImplementedError
92 | if self.use_l2_norm:
93 | z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
94 | return z_quantized
95 |
96 |
97 | class DiagonalGaussianDistribution(object):
98 | @autocast(enabled=False)
99 | def __init__(self, parameters, deterministic=False):
100 | """Initializes a Gaussian distribution instance given the parameters.
101 |
102 | Args:
103 | parameters (torch.Tensor): The parameters for the Gaussian distribution. It is expected
104 | to be in shape [B, 2 * C, *], where B is batch size, and C is the embedding dimension.
105 | First C channels are used for mean and last C are used for logvar in the Gaussian distribution.
106 | deterministic (bool): Whether to use deterministic sampling. When it is true, the sampling results
107 | is purely based on mean (i.e., std = 0).
108 | """
109 | self.parameters = parameters
110 | self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1)
111 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
112 | self.deterministic = deterministic
113 | self.std = torch.exp(0.5 * self.logvar)
114 | self.var = torch.exp(self.logvar)
115 | if self.deterministic:
116 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
117 |
118 | @autocast(enabled=False)
119 | def sample(self):
120 | x = self.mean.float() + self.std.float() * torch.randn(self.mean.shape).to(device=self.parameters.device)
121 | return x
122 |
123 | @autocast(enabled=False)
124 | def mode(self):
125 | return self.mean
126 |
127 | @autocast(enabled=False)
128 | def kl(self):
129 | if self.deterministic:
130 | return torch.Tensor([0.])
131 | else:
132 | return 0.5 * torch.sum(torch.pow(self.mean.float(), 2)
133 | + self.var.float() - 1.0 - self.logvar.float(),
134 | dim=[1, 2])
135 |
--------------------------------------------------------------------------------
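Sketch of the VectorQuantizer forward contract: a straight-through quantized feature map plus a dict of losses and codebook indices (the feature shapes here are arbitrary examples):

    import torch

    from open_tokenizer.model.tiktok.quantizer.quantizer import VectorQuantizer

    vq = VectorQuantizer(codebook_size=1024, token_size=256)
    z = torch.randn(2, 256, 16, 16)              # B x C x H x W encoder features
    z_q, info = vq(z)
    print(z_q.shape)                             # torch.Size([2, 256, 16, 16])
    print(info["min_encoding_indices"].shape)    # torch.Size([2, 16, 16])
    print(info["quantizer_loss"])                # commitment_loss + codebook_loss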
/open_tokenizer/model/tiktok/titok.py:
--------------------------------------------------------------------------------
1 | """This file contains the model definition of TiTok.
2 |
3 | Copyright (2024) Bytedance Ltd. and/or its affiliates
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | """
17 |
18 | import torch
19 | import torch.nn as nn
20 | from einops import rearrange
21 |
22 | from .modules.base_model import BaseModel
23 | from .modules.blocks import TiTokEncoder, TiTokDecoder
24 | from .quantizer.quantizer import VectorQuantizer, DiagonalGaussianDistribution
25 | from .modules.maskgit_vqgan import Encoder as Pixel_Encoder
26 | from .modules.maskgit_vqgan import Decoder as Pixel_Decoder
27 | from .modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
28 | import json
29 | from omegaconf import OmegaConf
30 | from pathlib import Path
31 |
32 | from huggingface_hub import PyTorchModelHubMixin
33 |
34 |
35 | class PretrainedTokenizer(nn.Module):
36 | def __init__(self, pretrained_weight):
37 | super().__init__()
38 | conf = OmegaConf.create(
39 | {"channel_mult": [1, 1, 2, 2, 4],
40 | "num_resolutions": 5,
41 | "dropout": 0.0,
42 | "hidden_channels": 128,
43 | "num_channels": 3,
44 | "num_res_blocks": 2,
45 | "resolution": 256,
46 | "z_channels": 256})
47 |         self.encoder = Pixel_Encoder(conf)
48 | self.decoder = Pixel_Decoder(conf)
49 | self.quantize = Pixel_Quantizer(
50 | num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
51 | # Load pretrained weights
52 | self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True)
53 |
54 | self.eval()
55 | for param in self.parameters():
56 | param.requires_grad = False
57 |
58 | @torch.no_grad()
59 | def encode(self, x):
60 | hidden_states = self.encoder(x)
61 | quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states)
62 | return codebook_indices.detach()
63 |
64 | @torch.no_grad()
65 | def decode(self, codes):
66 | quantized_states = self.quantize.get_codebook_entry(codes)
67 | rec_images = self.decoder(quantized_states)
68 | rec_images = torch.clamp(rec_images, 0.0, 1.0)
69 | return rec_images.detach()
70 |
71 | @torch.no_grad()
72 | def decode_tokens(self, codes):
73 | return self.decode(codes)
74 |
75 |
76 | class TiTok(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-tokenization"], repo_url="https://github.com/bytedance/1d-tokenizer", license="apache-2.0"):
77 | def __init__(self, config):
78 |
79 | if isinstance(config, dict):
80 | config = OmegaConf.create(config)
81 |
82 | super().__init__()
83 | self.config = config
84 | # This should be False for stage1 and True for stage2.
85 | self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True)
86 |
87 | self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq")
88 | if self.quantize_mode not in ["vq", "vae"]:
89 | raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.")
90 |
91 | if self.finetune_decoder and self.quantize_mode not in ["vq"]:
92 |             raise ValueError("Only support finetune_decoder with vq quantization for now.")
93 |
94 | self.encoder = TiTokEncoder(config)
95 | self.decoder = TiTokDecoder(config)
96 |
97 | self.num_latent_tokens = config.model.vq_model.num_latent_tokens
98 | scale = self.encoder.width ** -0.5
99 | self.latent_tokens = nn.Parameter(
100 | scale * torch.randn(self.num_latent_tokens, self.encoder.width))
101 |
102 | self.apply(self._init_weights)
103 |
104 | if self.quantize_mode == "vq":
105 | self.quantize = VectorQuantizer(
106 | codebook_size=config.model.vq_model.codebook_size,
107 | token_size=config.model.vq_model.token_size,
108 | commitment_cost=config.model.vq_model.commitment_cost,
109 | use_l2_norm=config.model.vq_model.use_l2_norm,)
110 | elif self.quantize_mode == "vae":
111 | self.quantize = DiagonalGaussianDistribution
112 | else:
113 | raise NotImplementedError
114 |
115 | if self.finetune_decoder:
116 | # Freeze encoder/quantizer/latent tokens
117 | self.latent_tokens.requires_grad_(False)
118 | self.encoder.eval()
119 | self.encoder.requires_grad_(False)
120 | self.quantize.eval()
121 | self.quantize.requires_grad_(False)
122 |
123 | # Include MaskGiT-VQGAN's quantizer and decoder
124 | self.pixel_quantize = Pixel_Quantizer(
125 | num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)
126 | self.pixel_decoder = Pixel_Decoder(OmegaConf.create(
127 | {"channel_mult": [1, 1, 2, 2, 4],
128 | "num_resolutions": 5,
129 | "dropout": 0.0,
130 | "hidden_channels": 128,
131 | "num_channels": 3,
132 | "num_res_blocks": 2,
133 | "resolution": 256,
134 | "z_channels": 256}))
135 |
136 | def _save_pretrained(self, save_directory: Path) -> None:
137 | """Save weights and config to a local directory."""
138 | # Assume 'self.config' is your DictConfig object
139 | # Convert to a regular dictionary
140 | dict_config = OmegaConf.to_container(self.config)
141 | # Save as JSON
142 | file_path = Path(save_directory) / "config.json"
143 | with open(file_path, 'w') as json_file:
144 | json.dump(dict_config, json_file, indent=4)
145 | super()._save_pretrained(save_directory)
146 |
147 | def _init_weights(self, module):
148 | """ Initialize the weights.
149 | :param:
150 | module -> torch.nn.Module: module to initialize
151 | """
152 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d):
153 | module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
154 | if module.bias is not None:
155 | module.bias.data.zero_()
156 | elif isinstance(module, nn.Embedding):
157 | module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02)
158 | elif isinstance(module, nn.LayerNorm):
159 | module.bias.data.zero_()
160 | module.weight.data.fill_(1.0)
161 |
162 | def encode(self, x):
163 | if self.finetune_decoder:
164 | with torch.no_grad():
165 | self.encoder.eval()
166 | self.quantize.eval()
167 | z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)
168 | z_quantized, result_dict = self.quantize(z)
169 | result_dict["quantizer_loss"] *= 0
170 | result_dict["commitment_loss"] *= 0
171 | result_dict["codebook_loss"] *= 0
172 | else:
173 | z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)
174 | if self.quantize_mode == "vq":
175 | z_quantized, result_dict = self.quantize(z)
176 | elif self.quantize_mode == "vae":
177 | posteriors = self.quantize(z)
178 | z_quantized = posteriors.sample()
179 | result_dict = posteriors
180 |
181 | return z_quantized, result_dict["min_encoding_indices"], result_dict
182 |
183 | def _decode(self, z_quantized):
184 | decoded = self.decoder(z_quantized)
185 | if self.finetune_decoder:
186 | quantized_states = torch.einsum(
187 | 'nchw,cd->ndhw', decoded.softmax(1),
188 | self.pixel_quantize.embedding.weight)
189 | decoded = self.pixel_decoder(quantized_states)
190 | return decoded
191 |
192 | def decode(self, tokens, latent_shape=None):
193 | if self.quantize_mode == "vq":
194 | tokens = tokens.squeeze(1)
195 | batch, seq_len = tokens.shape # B x N
196 | z_quantized = self.quantize.get_codebook_entry(
197 | tokens.reshape(-1)).reshape(batch, 1, seq_len, -1)
198 | z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous()
199 | elif self.quantize_mode == "vae":
200 | z_quantized = tokens
201 |
202 | decoded = self._decode(z_quantized)
203 | return decoded
204 |
205 | def forward(self, x):
206 |         z_quantized, tokens, result_dict = self.encode(x)
207 |         decoded = self.decode(tokens)
208 | return decoded, result_dict
209 |
--------------------------------------------------------------------------------
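A hedged sketch of the TiTok tokenize/detokenize contract, assuming `titok` has been built from a valid VQ config (quantize_mode="vq") and put in eval mode; the 256x256 input and the B x 1 x N token layout follow the reshapes in decode() but are otherwise assumptions:

    import torch

    images = torch.rand(1, 3, 256, 256)          # [0, 1] RGB
    z_q, tokens, info = titok.encode(images)     # tokens: B x 1 x num_latent_tokens codebook indices
    recon = titok.decode(tokens)                 # indices -> latent codes -> pixels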
/open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_41616_262144.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: open_tokenizer.model.vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_fsq_causal_41616_262144.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: open_tokenizer.model.vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: false
15 | z_channels: 6
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4, 4]
20 | time_downsample_factor: 4
21 | num_res_blocks: 2
22 | dropout: 0.0
23 | use_checkpoint: false
24 | init_pad_mode: replicate
25 | norm_type: layernorm # layernorm, groupnorm
26 | fix_encoder: false # if True, fix it without updating params
27 | fix_decoder: false # if True, fix it without updating params
28 |
29 | decoder_config:
30 | target: open_tokenizer.model.vidtok.modules.model_3dcausal.DecoderCausal3DPadding
31 | params: ${model.params.encoder_config.params}
32 |
33 | regularizer_config:
34 | target: open_tokenizer.model.vidtok.modules.regularizers.FSQRegularizer
35 | params:
36 | levels: [8, 8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8*8=262144
37 | entropy_loss_weight: 0.1
38 | entropy_loss_annealing_steps: 2000
39 | entropy_loss_annealing_factor: 3
40 | commitment_loss_weight: 0.25
41 |
42 | loss_config:
43 | target: open_tokenizer.model.vidtok.modules.losses.GeneralLPIPSWithDiscriminator
44 | params:
45 | dims: 3 # video - [t,h,w]
46 | perceptual_weight: 1.0
47 | disc_start: 20001
48 | disc_weight: 0.2
49 | disc_type: 2d # 2d, 3d
50 | learn_logvar: true
51 | gen_loss_cross_entropy: true
52 | lecam_loss_weight: 0.005
53 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
54 |
55 | data:
56 | target: open_tokenizer.model.vidtok.data.datamodule.DataModuleFromConfig
57 | params:
58 | batch_size: 2
59 | num_workers: 12
60 |
61 | train:
62 | target: open_tokenizer.model.vidtok.data.vidtok.VidTokDataset
63 | params:
64 | data_dir: DATA_DIR_1 # DATA_DIR for training data
65 | meta_path: META_PATH_1 # path to the .csv meta file of training data
66 | video_params:
67 | input_height: INPUT_HEIGHT_1
68 | input_width: INPUT_WIDTH_1
69 | sample_num_frames: 17
70 | sample_fps: 3
71 |
72 | validation:
73 | target: open_tokenizer.model.vidtok.data.vidtok.VidTokDataset
74 | params:
75 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
76 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
77 | video_params:
78 | input_height: INPUT_HEIGHT_2
79 | input_width: INPUT_WIDTH_2
80 | sample_num_frames: 17
81 | sample_fps: 8
82 | start_index: 0
83 |
84 | lightning:
85 | strategy:
86 | target: lightning.pytorch.strategies.DDPStrategy
87 | params:
88 | find_unused_parameters: true
89 |
90 | modelcheckpoint:
91 | params:
92 | every_n_train_steps: 5000
93 |
94 | callbacks:
95 | image_logger:
96 | target: open_tokenizer.model.vidtok.modules.logger.ImageVideoLogger
97 | params:
98 | disabled: false
99 | rescale: true
100 | enable_autocast: false
101 | batch_frequency: 5000
102 | max_samples: 2
103 | increase_log_steps: false
104 | log_first_step: false
105 | log_before_first_step: false
106 | log_images_kwargs:
107 | n_rows: 17
108 |
109 | trainer:
110 | precision: bf16-mixed
111 | devices: auto
112 | num_nodes: 1
113 | benchmark: true
114 | num_sanity_val_steps: 10
115 | val_check_interval: 2000
116 | check_val_every_n_epoch: null # default: 1
117 | accumulate_grad_batches: 1
118 | max_epochs: 1000
119 |
--------------------------------------------------------------------------------
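
Note: every `target:`/`params:` pair in this config names a fully qualified class and its constructor kwargs, and the `${model.params.encoder_config.params}` interpolation reuses the encoder settings for the decoder. A hedged sketch of how such a YAML can be resolved and instantiated; the helper below mirrors the usual `instantiate_from_config` pattern (which `datamodule.py` later in this dump imports from `vidtok.modules.util`), and the real training entry point is not shown here.

```python
# Hedged sketch: build the model described by a VidTok YAML config via the
# target/params convention visible above. The importlib-based helper is a
# stand-in so the snippet stays self-contained.
import importlib
from omegaconf import OmegaConf

def get_obj_from_str(path: str):
    module, cls = path.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(cfg: dict):
    return get_obj_from_str(cfg["target"])(**(cfg.get("params") or {}))

cfg = OmegaConf.to_container(
    OmegaConf.load("open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_41616_262144.yaml"),
    resolve=True,  # resolves ${model.params.encoder_config.params} for the decoder
)
# Assumes AutoencodingEngine accepts the sub-configs as keyword arguments;
# no checkpoint is loaded and no Lightning Trainer is wired up here.
model = instantiate_from_config(cfg["model"])
```
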
/open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_488_262144.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_fsq_causal_488_262144.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: false
15 | z_channels: 6
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | time_downsample_factor: 4
21 | num_res_blocks: 2
22 | dropout: 0.0
23 | use_checkpoint: false
24 | init_pad_mode: replicate
25 | norm_type: layernorm # layernorm, groupnorm
26 | fix_encoder: false # if True, fix it without updating params
27 | fix_decoder: false # if True, fix it without updating params
28 |
29 | decoder_config:
30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
31 | params: ${model.params.encoder_config.params}
32 |
33 | regularizer_config:
34 | target: vidtok.modules.regularizers.FSQRegularizer
35 | params:
36 | levels: [8, 8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8*8=262144
37 | entropy_loss_weight: 0.1
38 | entropy_loss_annealing_steps: 2000
39 | entropy_loss_annealing_factor: 3
40 | commitment_loss_weight: 0.25
41 |
42 | loss_config:
43 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
44 | params:
45 | dims: 3 # video - [t,h,w]
46 | perceptual_weight: 1.0
47 | disc_start: 20001
48 | disc_weight: 0.2
49 | disc_type: 2d # 2d, 3d
50 | learn_logvar: true
51 | gen_loss_cross_entropy: true
52 | lecam_loss_weight: 0.005
53 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
54 |
55 | data:
56 | target: vidtok.data.datamodule.DataModuleFromConfig
57 | params:
58 | batch_size: 2
59 | num_workers: 12
60 |
61 | train:
62 | target: vidtok.data.vidtok.VidTokDataset
63 | params:
64 | data_dir: DATA_DIR_1 # DATA_DIR for training data
65 | meta_path: META_PATH_1 # path to the .csv meta file of training data
66 | video_params:
67 | input_height: INPUT_HEIGHT_1
68 | input_width: INPUT_WIDTH_1
69 | sample_num_frames: 17
70 | sample_fps: 3
71 |
72 | validation:
73 | target: vidtok.data.vidtok.VidTokDataset
74 | params:
75 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
76 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
77 | video_params:
78 | input_height: INPUT_HEIGHT_2
79 | input_width: INPUT_WIDTH_2
80 | sample_num_frames: 17
81 | sample_fps: 8
82 | start_index: 0
83 |
84 | lightning:
85 | strategy:
86 | target: lightning.pytorch.strategies.DDPStrategy
87 | params:
88 | find_unused_parameters: true
89 |
90 | modelcheckpoint:
91 | params:
92 | every_n_train_steps: 5000
93 |
94 | callbacks:
95 | image_logger:
96 | target: vidtok.modules.logger.ImageVideoLogger
97 | params:
98 | disabled: false
99 | rescale: true
100 | enable_autocast: false
101 | batch_frequency: 5000
102 | max_samples: 2
103 | increase_log_steps: false
104 | log_first_step: false
105 | log_before_first_step: false
106 | log_images_kwargs:
107 | n_rows: 17
108 |
109 | trainer:
110 | precision: bf16-mixed
111 | devices: auto
112 | num_nodes: 1
113 | benchmark: true
114 | num_sanity_val_steps: 10
115 | val_check_interval: 2000
116 | check_val_every_n_epoch: null # default: 1
117 | accumulate_grad_batches: 1
118 | max_epochs: 1000
119 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_488_32768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: open_tokenizer.model.vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_fsq_causal_488_32768.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: open_tokenizer.model.vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: false
15 | z_channels: 5
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | time_downsample_factor: 4
21 | num_res_blocks: 2
22 | dropout: 0.0
23 | use_checkpoint: false
24 | init_pad_mode: replicate
25 | norm_type: layernorm # layernorm, groupnorm
26 | fix_encoder: false # if True, fix it without updating params
27 | fix_decoder: false # if True, fix it without updating params
28 |
29 | decoder_config:
30 | target: open_tokenizer.model.vidtok.modules.model_3dcausal.DecoderCausal3DPadding
31 | params: ${model.params.encoder_config.params}
32 |
33 | regularizer_config:
34 | target: open_tokenizer.model.vidtok.modules.regularizers.FSQRegularizer
35 | params:
36 | levels: [8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8=32768
37 | entropy_loss_weight: 0.1
38 | entropy_loss_annealing_steps: 2000
39 | entropy_loss_annealing_factor: 3
40 | commitment_loss_weight: 0.25
41 |
42 | loss_config:
43 | target: open_tokenizer.model.vidtok.modules.losses.GeneralLPIPSWithDiscriminator
44 | params:
45 | dims: 3 # video - [t,h,w]
46 | perceptual_weight: 1.0
47 | disc_start: 20001
48 | disc_weight: 0.2
49 | disc_type: 2d # 2d, 3d
50 | learn_logvar: true
51 | gen_loss_cross_entropy: true
52 | lecam_loss_weight: 0.005
53 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
54 |
55 | data:
56 | target: open_tokenizer.model.vidtok.data.datamodule.DataModuleFromConfig
57 | params:
58 | batch_size: 2
59 | num_workers: 12
60 |
61 | train:
62 | target: open_tokenizer.model.vidtok.data.vidtok.VidTokDataset
63 | params:
64 | data_dir: DATA_DIR_1 # DATA_DIR for training data
65 | meta_path: META_PATH_1 # path to the .csv meta file of training data
66 | video_params:
67 | input_height: INPUT_HEIGHT_1
68 | input_width: INPUT_WIDTH_1
69 | sample_num_frames: 17
70 | sample_fps: 3
71 |
72 | validation:
73 | target: open_tokenizer.model.vidtok.data.vidtok.VidTokDataset
74 | params:
75 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
76 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
77 | video_params:
78 | input_height: INPUT_HEIGHT_2
79 | input_width: INPUT_WIDTH_2
80 | sample_num_frames: 17
81 | sample_fps: 8
82 | start_index: 0
83 |
84 | lightning:
85 | strategy:
86 | target: lightning.pytorch.strategies.DDPStrategy
87 | params:
88 | find_unused_parameters: true
89 |
90 | modelcheckpoint:
91 | params:
92 | every_n_train_steps: 5000
93 |
94 | callbacks:
95 | image_logger:
96 | target: open_tokenizer.model.vidtok.modules.logger.ImageVideoLogger
97 | params:
98 | disabled: false
99 | rescale: true
100 | enable_autocast: false
101 | batch_frequency: 5000
102 | max_samples: 2
103 | increase_log_steps: false
104 | log_first_step: false
105 | log_before_first_step: false
106 | log_images_kwargs:
107 | n_rows: 17
108 |
109 | trainer:
110 | precision: bf16-mixed
111 | devices: auto
112 | num_nodes: 1
113 | benchmark: true
114 | num_sanity_val_steps: 10
115 | val_check_interval: 2000
116 | check_val_every_n_epoch: null # default: 1
117 | accumulate_grad_batches: 1
118 | max_epochs: 1000
119 |
--------------------------------------------------------------------------------
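
Note: the three causal FSQ configs (`*_262144`, `*_32768`, `*_4096`) differ only in the `levels` list and the matching `z_channels`: each latent channel is quantized to one of `levels[i]` values, so the implicit codebook size is the product of the levels. A quick check of the sizes quoted in the inline comments:

```python
# Quick check of the codebook sizes quoted in the FSQ config comments above.
from math import prod

variants = {
    "vidtok_fsq_*_262144": [8, 8, 8, 8, 8, 8],  # z_channels: 6
    "vidtok_fsq_*_32768":  [8, 8, 8, 8, 8],     # z_channels: 5
    "vidtok_fsq_*_4096":   [8, 8, 8, 8],        # z_channels: 4
}
for name, levels in variants.items():
    print(f"{name}: {len(levels)} latent channels, codebook size = {prod(levels)}")
# -> 262144, 32768, 4096, matching the inline comments.
```
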
/open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_488_4096.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_fsq_causal_488_4096.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: false
15 | z_channels: 4
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | time_downsample_factor: 4
21 | num_res_blocks: 2
22 | dropout: 0.0
23 | use_checkpoint: false
24 | init_pad_mode: replicate
25 | norm_type: layernorm # layernorm, groupnorm
26 | fix_encoder: false # if True, fix it without updating params
27 | fix_decoder: false # if True, fix it without updating params
28 |
29 | decoder_config:
30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
31 | params: ${model.params.encoder_config.params}
32 |
33 | regularizer_config:
34 | target: vidtok.modules.regularizers.FSQRegularizer
35 | params:
36 | levels: [8, 8, 8, 8] # codebook size: 8*8*8*8=4096
37 | entropy_loss_weight: 0.1
38 | entropy_loss_annealing_steps: 2000
39 | entropy_loss_annealing_factor: 3
40 | commitment_loss_weight: 0.25
41 |
42 | loss_config:
43 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
44 | params:
45 | dims: 3 # video - [t,h,w]
46 | perceptual_weight: 1.0
47 | disc_start: 20001
48 | disc_weight: 0.2
49 | disc_type: 2d # 2d, 3d
50 | learn_logvar: true
51 | gen_loss_cross_entropy: true
52 | lecam_loss_weight: 0.005
53 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
54 |
55 | data:
56 | target: vidtok.data.datamodule.DataModuleFromConfig
57 | params:
58 | batch_size: 2
59 | num_workers: 12
60 |
61 | train:
62 | target: vidtok.data.vidtok.VidTokDataset
63 | params:
64 | data_dir: DATA_DIR_1 # DATA_DIR for training data
65 | meta_path: META_PATH_1 # path to the .csv meta file of training data
66 | video_params:
67 | input_height: INPUT_HEIGHT_1
68 | input_width: INPUT_WIDTH_1
69 | sample_num_frames: 17
70 | sample_fps: 3
71 |
72 | validation:
73 | target: vidtok.data.vidtok.VidTokDataset
74 | params:
75 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
76 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
77 | video_params:
78 | input_height: INPUT_HEIGHT_2
79 | input_width: INPUT_WIDTH_2
80 | sample_num_frames: 17
81 | sample_fps: 8
82 | start_index: 0
83 |
84 | lightning:
85 | strategy:
86 | target: lightning.pytorch.strategies.DDPStrategy
87 | params:
88 | find_unused_parameters: true
89 |
90 | modelcheckpoint:
91 | params:
92 | every_n_train_steps: 5000
93 |
94 | callbacks:
95 | image_logger:
96 | target: vidtok.modules.logger.ImageVideoLogger
97 | params:
98 | disabled: false
99 | rescale: true
100 | enable_autocast: false
101 | batch_frequency: 5000
102 | max_samples: 2
103 | increase_log_steps: false
104 | log_first_step: false
105 | log_before_first_step: false
106 | log_images_kwargs:
107 | n_rows: 17
108 |
109 | trainer:
110 | precision: bf16-mixed
111 | devices: auto
112 | num_nodes: 1
113 | benchmark: true
114 | num_sanity_val_steps: 10
115 | val_check_interval: 2000
116 | check_val_every_n_epoch: null # default: 1
117 | accumulate_grad_batches: 1
118 | max_epochs: 1000
119 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_fsq_noncausal_41616_262144.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_fsq_noncausal_41616_262144.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dnoncausal.Encoder3D
13 | params:
14 | double_z: false
15 | z_channels: 6
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4, 4]
20 | num_res_blocks: 2
21 | dropout: 0.0
22 | use_checkpoint: false
23 | norm_type: layernorm # layernorm, groupnorm
24 | fix_encoder: false
25 | fix_decoder: false
26 |
27 | decoder_config:
28 | target: vidtok.modules.model_3dnoncausal.Decoder3D
29 | params: ${model.params.encoder_config.params}
30 |
31 | regularizer_config:
32 | target: vidtok.modules.regularizers.FSQRegularizer
33 | params:
34 | levels: [8, 8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8*8=262144
35 | entropy_loss_weight: 0.1
36 | entropy_loss_annealing_steps: 2000
37 | entropy_loss_annealing_factor: 3
38 | commitment_loss_weight: 0.25
39 |
40 | loss_config:
41 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
42 | params:
43 | dims: 3 # video - [t,h,w]
44 | perceptual_weight: 1.0
45 | disc_start: 20001
46 | disc_weight: 0.2
47 | disc_type: 2d # 2d, 3d
48 | learn_logvar: true
49 | gen_loss_cross_entropy: true
50 | lecam_loss_weight: 0.005
51 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
52 |
53 | data:
54 | target: vidtok.data.datamodule.DataModuleFromConfig
55 | params:
56 | batch_size: 2
57 | num_workers: 12
58 |
59 | train:
60 | target: vidtok.data.vidtok.VidTokDataset
61 | params:
62 | data_dir: DATA_DIR_1 # DATA_DIR for training data
63 | meta_path: META_PATH_1 # path to the .csv meta file of training data
64 | video_params:
65 | input_height: INPUT_HEIGHT_1
66 | input_width: INPUT_WIDTH_1
67 | sample_num_frames: 16
68 | sample_fps: 3
69 |
70 | validation:
71 | target: vidtok.data.vidtok.VidTokDataset
72 | params:
73 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
74 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
75 | video_params:
76 | input_height: INPUT_HEIGHT_2
77 | input_width: INPUT_WIDTH_2
78 | sample_num_frames: 16
79 | sample_fps: 8
80 | start_index: 0
81 |
82 | lightning:
83 | strategy:
84 | target: lightning.pytorch.strategies.DDPStrategy
85 | params:
86 | find_unused_parameters: true
87 |
88 | modelcheckpoint:
89 | params:
90 | every_n_train_steps: 5000
91 |
92 | callbacks:
93 | image_logger:
94 | target: vidtok.modules.logger.ImageVideoLogger
95 | params:
96 | disabled: false
97 | rescale: true
98 | enable_autocast: false
99 | batch_frequency: 5000
100 | max_samples: 2
101 | increase_log_steps: false
102 | log_first_step: false
103 | log_before_first_step: false
104 | log_images_kwargs:
105 | n_rows: 16
106 |
107 | trainer:
108 | precision: bf16-mixed
109 | devices: auto
110 | num_nodes: 1
111 | benchmark: true
112 | num_sanity_val_steps: 10
113 | val_check_interval: 2000
114 | check_val_every_n_epoch: null # default: 1
115 | accumulate_grad_batches: 1
116 | max_epochs: 1000
117 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_fsq_noncausal_488_262144.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_fsq_noncausal_488_262144.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dnoncausal.Encoder3D
13 | params:
14 | double_z: false
15 | z_channels: 6
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | num_res_blocks: 2
21 | dropout: 0.0
22 | use_checkpoint: false
23 | norm_type: layernorm # layernorm, groupnorm
24 | fix_encoder: false
25 | fix_decoder: false
26 |
27 | decoder_config:
28 | target: vidtok.modules.model_3dnoncausal.Decoder3D
29 | params: ${model.params.encoder_config.params}
30 |
31 | regularizer_config:
32 | target: vidtok.modules.regularizers.FSQRegularizer
33 | params:
34 | levels: [8, 8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8*8=262144
35 | entropy_loss_weight: 0.1
36 | entropy_loss_annealing_steps: 2000
37 | entropy_loss_annealing_factor: 3
38 | commitment_loss_weight: 0.25
39 |
40 | loss_config:
41 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
42 | params:
43 | dims: 3 # video - [t,h,w]
44 | perceptual_weight: 1.0
45 | disc_start: 20001
46 | disc_weight: 0.2
47 | disc_type: 2d # 2d, 3d
48 | learn_logvar: true
49 | gen_loss_cross_entropy: true
50 | lecam_loss_weight: 0.005
51 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
52 |
53 | data:
54 | target: vidtok.data.datamodule.DataModuleFromConfig
55 | params:
56 | batch_size: 2
57 | num_workers: 12
58 |
59 | train:
60 | target: vidtok.data.vidtok.VidTokDataset
61 | params:
62 | data_dir: DATA_DIR_1 # DATA_DIR for training data
63 | meta_path: META_PATH_1 # path to the .csv meta file of training data
64 | video_params:
65 | input_height: INPUT_HEIGHT_1
66 | input_width: INPUT_WIDTH_1
67 | sample_num_frames: 16
68 | sample_fps: 3
69 |
70 | validation:
71 | target: vidtok.data.vidtok.VidTokDataset
72 | params:
73 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
74 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
75 | video_params:
76 | input_height: INPUT_HEIGHT_2
77 | input_width: INPUT_WIDTH_2
78 | sample_num_frames: 16
79 | sample_fps: 8
80 | start_index: 0
81 |
82 | lightning:
83 | strategy:
84 | target: lightning.pytorch.strategies.DDPStrategy
85 | params:
86 | find_unused_parameters: true
87 |
88 | modelcheckpoint:
89 | params:
90 | every_n_train_steps: 5000
91 |
92 | callbacks:
93 | image_logger:
94 | target: vidtok.modules.logger.ImageVideoLogger
95 | params:
96 | disabled: false
97 | rescale: true
98 | enable_autocast: false
99 | batch_frequency: 5000
100 | max_samples: 2
101 | increase_log_steps: false
102 | log_first_step: false
103 | log_before_first_step: false
104 | log_images_kwargs:
105 | n_rows: 16
106 |
107 | trainer:
108 | precision: bf16-mixed
109 | devices: auto
110 | num_nodes: 1
111 | benchmark: true
112 | num_sanity_val_steps: 10
113 | val_check_interval: 2000
114 | check_val_every_n_epoch: null # default: 1
115 | accumulate_grad_batches: 1
116 | max_epochs: 1000
117 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_kl_causal_288_8chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_kl_causal_288_8chn.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: true
15 | z_channels: 8
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | tempo_ds: [1]
21 | tempo_us: [2]
22 | time_downsample_factor: 2
23 | num_res_blocks: 2
24 | dropout: 0.0
25 | use_checkpoint: false
26 | init_pad_mode: replicate
27 | norm_type: layernorm # layernorm, groupnorm
28 | fix_encoder: false # if True, fix it without updating params
29 | fix_decoder: false # if True, fix it without updating params
30 |
31 | decoder_config:
32 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
33 | params: ${model.params.encoder_config.params}
34 |
35 | regularizer_config:
36 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
37 |
38 | loss_config:
39 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
40 | params:
41 | dims: 3 # video - [t,h,w]
42 | perceptual_weight: 1.0
43 | disc_start: 20001
44 | disc_weight: 0.2
45 | disc_type: 2d # 2d, 3d
46 | learn_logvar: true
47 | gen_loss_cross_entropy: true
48 | lecam_loss_weight: 0.005
49 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
50 |
51 | data:
52 | target: vidtok.data.datamodule.DataModuleFromConfig
53 | params:
54 | batch_size: 2
55 | num_workers: 12
56 |
57 | train:
58 | target: vidtok.data.vidtok.VidTokDataset
59 | params:
60 | data_dir: DATA_DIR_1 # DATA_DIR for training data
61 | meta_path: META_PATH_1 # path to the .csv meta file of training data
62 | video_params:
63 | input_height: INPUT_HEIGHT_1
64 | input_width: INPUT_WIDTH_1
65 | sample_num_frames: 17
66 | sample_fps: 3
67 |
68 | validation:
69 | target: vidtok.data.vidtok.VidTokDataset
70 | params:
71 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
72 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
73 | video_params:
74 | input_height: INPUT_HEIGHT_2
75 | input_width: INPUT_WIDTH_2
76 | sample_num_frames: 17
77 | sample_fps: 8
78 | start_index: 0
79 |
80 | lightning:
81 | strategy:
82 | target: lightning.pytorch.strategies.DDPStrategy
83 | params:
84 | find_unused_parameters: true
85 |
86 | modelcheckpoint:
87 | params:
88 | every_n_train_steps: 5000
89 |
90 | callbacks:
91 | image_logger:
92 | target: vidtok.modules.logger.ImageVideoLogger
93 | params:
94 | disabled: false
95 | rescale: true
96 | enable_autocast: false
97 | batch_frequency: 5000
98 | max_samples: 2
99 | increase_log_steps: false
100 | log_first_step: false
101 | log_before_first_step: false
102 | log_images_kwargs:
103 | n_rows: 17
104 |
105 | trainer:
106 | precision: bf16-mixed
107 | devices: auto
108 | num_nodes: 1
109 | benchmark: true
110 | num_sanity_val_steps: 10
111 | val_check_interval: 2000
112 | check_val_every_n_epoch: null # default: 1
113 | accumulate_grad_batches: 1
114 | max_epochs: 1000
115 |
--------------------------------------------------------------------------------
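
Note: the `*_kl_*` configs swap the FSQ regularizer for `DiagonalGaussianRegularizer` and set `double_z: true`, so the encoder emits `2 * z_channels` channels (mean and log-variance). The file names appear to encode the compression factors, e.g. `288` for 2x temporal (`time_downsample_factor: 2`) and 8x8 spatial downsampling (three downsampling stages in `ch_mult: [1, 2, 4, 4]`). Below is a sketch of the standard diagonal-Gaussian sampling such a regularizer typically performs; the actual implementation in `regularizers.py` may differ in detail.

```python
# Hedged sketch of diagonal-Gaussian regularization for double_z latents:
# split into mean/log-variance, sample with the reparameterization trick,
# and compute a KL term against N(0, I) (weighted by regularization_weights.kl_loss).
import torch

def diagonal_gaussian_sample(enc_out: torch.Tensor):
    # enc_out: (B, 2 * z_channels, T, H, W) for the video ("dims: 3") case
    mean, logvar = torch.chunk(enc_out, 2, dim=1)
    logvar = torch.clamp(logvar, -30.0, 20.0)       # common numerical guard
    std = torch.exp(0.5 * logvar)
    z = mean + std * torch.randn_like(std)          # reparameterization trick
    kl = 0.5 * torch.sum(
        mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3, 4]
    ).mean()                                        # KL(q(z|x) || N(0, I))
    return z, kl
```
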
/open_tokenizer/model/vidtok/configs/vidtok_kl_causal_41616_4chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_kl_causal_41616_4chn.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: true
15 | z_channels: 4
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4, 4]
20 | time_downsample_factor: 4
21 | num_res_blocks: 2
22 | dropout: 0.0
23 | use_checkpoint: false
24 | init_pad_mode: replicate
25 | norm_type: layernorm # layernorm, groupnorm
26 | fix_encoder: false # if True, fix it without updating params
27 | fix_decoder: false # if True, fix it without updating params
28 |
29 | decoder_config:
30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
31 | params: ${model.params.encoder_config.params}
32 |
33 | regularizer_config:
34 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
35 |
36 | loss_config:
37 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
38 | params:
39 | dims: 3 # video - [t,h,w]
40 | perceptual_weight: 1.0
41 | disc_start: 20001
42 | disc_weight: 0.2
43 | disc_type: 2d # 2d, 3d
44 | learn_logvar: true
45 | gen_loss_cross_entropy: true
46 | lecam_loss_weight: 0.005
47 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
48 |
49 | data:
50 | target: vidtok.data.datamodule.DataModuleFromConfig
51 | params:
52 | batch_size: 2
53 | num_workers: 12
54 |
55 | train:
56 | target: vidtok.data.vidtok.VidTokDataset
57 | params:
58 | data_dir: DATA_DIR_1 # DATA_DIR for training data
59 | meta_path: META_PATH_1 # path to the .csv meta file of training data
60 | video_params:
61 | input_height: INPUT_HEIGHT_1
62 | input_width: INPUT_WIDTH_1
63 | sample_num_frames: 17
64 | sample_fps: 3
65 |
66 | validation:
67 | target: vidtok.data.vidtok.VidTokDataset
68 | params:
69 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
70 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
71 | video_params:
72 | input_height: INPUT_HEIGHT_2
73 | input_width: INPUT_WIDTH_2
74 | sample_num_frames: 17
75 | sample_fps: 8
76 | start_index: 0
77 |
78 | lightning:
79 | strategy:
80 | target: lightning.pytorch.strategies.DDPStrategy
81 | params:
82 | find_unused_parameters: true
83 |
84 | modelcheckpoint:
85 | params:
86 | every_n_train_steps: 5000
87 |
88 | callbacks:
89 | image_logger:
90 | target: vidtok.modules.logger.ImageVideoLogger
91 | params:
92 | disabled: false
93 | rescale: true
94 | enable_autocast: false
95 | batch_frequency: 5000
96 | max_samples: 2
97 | increase_log_steps: false
98 | log_first_step: false
99 | log_before_first_step: false
100 | log_images_kwargs:
101 | n_rows: 17
102 |
103 | trainer:
104 | precision: bf16-mixed
105 | devices: auto
106 | num_nodes: 1
107 | benchmark: true
108 | num_sanity_val_steps: 10
109 | val_check_interval: 2000
110 | check_val_every_n_epoch: null # default: 1
111 | accumulate_grad_batches: 1
112 | max_epochs: 1000
113 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_kl_causal_444_4chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_kl_causal_444_4chn.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: true
15 | z_channels: 4
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | spatial_ds: [1, 2]
21 | spatial_us: [1, 2]
22 | time_downsample_factor: 4
23 | num_res_blocks: 2
24 | dropout: 0.0
25 | use_checkpoint: false
26 | init_pad_mode: replicate
27 | norm_type: layernorm # layernorm, groupnorm
28 | fix_encoder: false # if True, fix it without updating params
29 | fix_decoder: false # if True, fix it without updating params
30 |
31 | decoder_config:
32 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
33 | params: ${model.params.encoder_config.params}
34 |
35 | regularizer_config:
36 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
37 |
38 | loss_config:
39 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
40 | params:
41 | dims: 3 # video - [t,h,w]
42 | perceptual_weight: 1.0
43 | disc_start: 20001
44 | disc_weight: 0.2
45 | disc_type: 2d # 2d, 3d
46 | learn_logvar: true
47 | gen_loss_cross_entropy: true
48 | lecam_loss_weight: 0.005
49 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
50 |
51 | data:
52 | target: vidtok.data.datamodule.DataModuleFromConfig
53 | params:
54 | batch_size: 2
55 | num_workers: 12
56 |
57 | train:
58 | target: vidtok.data.vidtok.VidTokDataset
59 | params:
60 | data_dir: DATA_DIR_1 # DATA_DIR for training data
61 | meta_path: META_PATH_1 # path to the .csv meta file of training data
62 | video_params:
63 | input_height: INPUT_HEIGHT_1
64 | input_width: INPUT_WIDTH_1
65 | sample_num_frames: 17
66 | sample_fps: 3
67 |
68 | validation:
69 | target: vidtok.data.vidtok.VidTokDataset
70 | params:
71 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
72 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
73 | video_params:
74 | input_height: INPUT_HEIGHT_2
75 | input_width: INPUT_WIDTH_2
76 | sample_num_frames: 17
77 | sample_fps: 8
78 | start_index: 0
79 |
80 | lightning:
81 | strategy:
82 | target: lightning.pytorch.strategies.DDPStrategy
83 | params:
84 | find_unused_parameters: true
85 |
86 | modelcheckpoint:
87 | params:
88 | every_n_train_steps: 5000
89 |
90 | callbacks:
91 | image_logger:
92 | target: vidtok.modules.logger.ImageVideoLogger
93 | params:
94 | disabled: false
95 | rescale: true
96 | enable_autocast: false
97 | batch_frequency: 5000
98 | max_samples: 2
99 | increase_log_steps: false
100 | log_first_step: false
101 | log_before_first_step: false
102 | log_images_kwargs:
103 | n_rows: 17
104 |
105 | trainer:
106 | precision: bf16-mixed
107 | devices: auto
108 | num_nodes: 1
109 | benchmark: true
110 | num_sanity_val_steps: 10
111 | val_check_interval: 2000
112 | check_val_every_n_epoch: null # default: 1
113 | accumulate_grad_batches: 1
114 | max_epochs: 1000
115 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_kl_causal_488_16chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_kl_causal_488_16chn.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: true
15 | z_channels: 16
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | time_downsample_factor: 4
21 | num_res_blocks: 2
22 | dropout: 0.0
23 | use_checkpoint: false
24 | init_pad_mode: replicate
25 | norm_type: layernorm # layernorm, groupnorm
26 | fix_encoder: false # if True, fix it without updating params
27 | fix_decoder: false # if True, fix it without updating params
28 |
29 | decoder_config:
30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
31 | params: ${model.params.encoder_config.params}
32 |
33 | regularizer_config:
34 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
35 |
36 | loss_config:
37 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
38 | params:
39 | dims: 3 # video - [t,h,w]
40 | perceptual_weight: 1.0
41 | disc_start: 20001
42 | disc_weight: 0.2
43 | disc_type: 2d # 2d, 3d
44 | learn_logvar: true
45 | gen_loss_cross_entropy: true
46 | lecam_loss_weight: 0.005
47 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
48 |
49 | data:
50 | target: vidtok.data.datamodule.DataModuleFromConfig
51 | params:
52 | batch_size: 2
53 | num_workers: 12
54 |
55 | train:
56 | target: vidtok.data.vidtok.VidTokDataset
57 | params:
58 | data_dir: DATA_DIR_1 # DATA_DIR for training data
59 | meta_path: META_PATH_1 # path to the .csv meta file of training data
60 | video_params:
61 | input_height: INPUT_HEIGHT_1
62 | input_width: INPUT_WIDTH_1
63 | sample_num_frames: 17
64 | sample_fps: 3
65 |
66 | validation:
67 | target: vidtok.data.vidtok.VidTokDataset
68 | params:
69 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
70 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
71 | video_params:
72 | input_height: INPUT_HEIGHT_2
73 | input_width: INPUT_WIDTH_2
74 | sample_num_frames: 17
75 | sample_fps: 8
76 | start_index: 0
77 |
78 | lightning:
79 | strategy:
80 | target: lightning.pytorch.strategies.DDPStrategy
81 | params:
82 | find_unused_parameters: true
83 |
84 | modelcheckpoint:
85 | params:
86 | every_n_train_steps: 5000
87 |
88 | callbacks:
89 | image_logger:
90 | target: vidtok.modules.logger.ImageVideoLogger
91 | params:
92 | disabled: false
93 | rescale: true
94 | enable_autocast: false
95 | batch_frequency: 5000
96 | max_samples: 2
97 | increase_log_steps: false
98 | log_first_step: false
99 | log_before_first_step: false
100 | log_images_kwargs:
101 | n_rows: 17
102 |
103 | trainer:
104 | precision: bf16-mixed
105 | devices: auto
106 | num_nodes: 1
107 | benchmark: true
108 | num_sanity_val_steps: 10
109 | val_check_interval: 2000
110 | check_val_every_n_epoch: null # default: 1
111 | accumulate_grad_batches: 1
112 | max_epochs: 1000
113 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_kl_causal_488_4chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_kl_causal_488_4chn.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: true
15 | z_channels: 4
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | time_downsample_factor: 4
21 | num_res_blocks: 2
22 | dropout: 0.0
23 | use_checkpoint: false
24 | init_pad_mode: replicate
25 | norm_type: layernorm # layernorm, groupnorm
26 | fix_encoder: false # if True, fix it without updating params
27 | fix_decoder: false # if True, fix it without updating params
28 |
29 | decoder_config:
30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
31 | params: ${model.params.encoder_config.params}
32 |
33 | regularizer_config:
34 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
35 |
36 | loss_config:
37 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
38 | params:
39 | dims: 3 # video - [t,h,w]
40 | perceptual_weight: 1.0
41 | disc_start: 20001
42 | disc_weight: 0.2
43 | disc_type: 2d # 2d, 3d
44 | learn_logvar: true
45 | gen_loss_cross_entropy: true
46 | lecam_loss_weight: 0.005
47 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
48 |
49 | data:
50 | target: vidtok.data.datamodule.DataModuleFromConfig
51 | params:
52 | batch_size: 2
53 | num_workers: 12
54 |
55 | train:
56 | target: vidtok.data.vidtok.VidTokDataset
57 | params:
58 | data_dir: DATA_DIR_1 # DATA_DIR for training data
59 | meta_path: META_PATH_1 # path to the .csv meta file of training data
60 | video_params:
61 | input_height: INPUT_HEIGHT_1
62 | input_width: INPUT_WIDTH_1
63 | sample_num_frames: 17
64 | sample_fps: 3
65 |
66 | validation:
67 | target: vidtok.data.vidtok.VidTokDataset
68 | params:
69 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
70 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
71 | video_params:
72 | input_height: INPUT_HEIGHT_2
73 | input_width: INPUT_WIDTH_2
74 | sample_num_frames: 17
75 | sample_fps: 8
76 | start_index: 0
77 |
78 | lightning:
79 | strategy:
80 | target: lightning.pytorch.strategies.DDPStrategy
81 | params:
82 | find_unused_parameters: true
83 |
84 | modelcheckpoint:
85 | params:
86 | every_n_train_steps: 5000
87 |
88 | callbacks:
89 | image_logger:
90 | target: vidtok.modules.logger.ImageVideoLogger
91 | params:
92 | disabled: false
93 | rescale: true
94 | enable_autocast: false
95 | batch_frequency: 5000
96 | max_samples: 2
97 | increase_log_steps: false
98 | log_first_step: false
99 | log_before_first_step: false
100 | log_images_kwargs:
101 | n_rows: 17
102 |
103 | trainer:
104 | precision: bf16-mixed
105 | devices: auto
106 | num_nodes: 1
107 | benchmark: true
108 | num_sanity_val_steps: 10
109 | val_check_interval: 2000
110 | check_val_every_n_epoch: null # default: 1
111 | accumulate_grad_batches: 1
112 | max_epochs: 1000
113 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_kl_causal_488_8chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_kl_causal_488_8chn.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
13 | params:
14 | double_z: true
15 | z_channels: 8
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | time_downsample_factor: 4
21 | num_res_blocks: 2
22 | dropout: 0.0
23 | use_checkpoint: false
24 | init_pad_mode: replicate
25 | norm_type: layernorm # layernorm, groupnorm
26 | fix_encoder: false # if True, fix it without updating params
27 | fix_decoder: false # if True, fix it without updating params
28 |
29 | decoder_config:
30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
31 | params: ${model.params.encoder_config.params}
32 |
33 | regularizer_config:
34 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
35 |
36 | loss_config:
37 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
38 | params:
39 | dims: 3 # video - [t,h,w]
40 | perceptual_weight: 1.0
41 | disc_start: 20001
42 | disc_weight: 0.2
43 | disc_type: 2d # 2d, 3d
44 | learn_logvar: true
45 | gen_loss_cross_entropy: true
46 | lecam_loss_weight: 0.005
47 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
48 |
49 | data:
50 | target: vidtok.data.datamodule.DataModuleFromConfig
51 | params:
52 | batch_size: 2
53 | num_workers: 12
54 |
55 | train:
56 | target: vidtok.data.vidtok.VidTokDataset
57 | params:
58 | data_dir: DATA_DIR_1 # DATA_DIR for training data
59 | meta_path: META_PATH_1 # path to the .csv meta file of training data
60 | video_params:
61 | input_height: INPUT_HEIGHT_1
62 | input_width: INPUT_WIDTH_1
63 | sample_num_frames: 17
64 | sample_fps: 3
65 |
66 | validation:
67 | target: vidtok.data.vidtok.VidTokDataset
68 | params:
69 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
70 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
71 | video_params:
72 | input_height: INPUT_HEIGHT_2
73 | input_width: INPUT_WIDTH_2
74 | sample_num_frames: 17
75 | sample_fps: 8
76 | start_index: 0
77 |
78 | lightning:
79 | strategy:
80 | target: lightning.pytorch.strategies.DDPStrategy
81 | params:
82 | find_unused_parameters: true
83 |
84 | modelcheckpoint:
85 | params:
86 | every_n_train_steps: 5000
87 |
88 | callbacks:
89 | image_logger:
90 | target: vidtok.modules.logger.ImageVideoLogger
91 | params:
92 | disabled: false
93 | rescale: true
94 | enable_autocast: false
95 | batch_frequency: 5000
96 | max_samples: 2
97 | increase_log_steps: false
98 | log_first_step: false
99 | log_before_first_step: false
100 | log_images_kwargs:
101 | n_rows: 17
102 |
103 | trainer:
104 | precision: bf16-mixed
105 | devices: auto
106 | num_nodes: 1
107 | benchmark: true
108 | num_sanity_val_steps: 10
109 | val_check_interval: 2000
110 | check_val_every_n_epoch: null # default: 1
111 | accumulate_grad_batches: 1
112 | max_epochs: 1000
113 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_kl_noncausal_41616_4chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_kl_noncausal_41616_4chn.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dnoncausal.Encoder3D
13 | params:
14 | double_z: true
15 | z_channels: 4
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4, 4]
20 | num_res_blocks: 2
21 | dropout: 0.0
22 | use_checkpoint: false
23 | norm_type: layernorm # layernorm, groupnorm
24 | fix_encoder: false
25 | fix_decoder: false
26 |
27 | decoder_config:
28 | target: vidtok.modules.model_3dnoncausal.Decoder3D
29 | params: ${model.params.encoder_config.params}
30 |
31 | regularizer_config:
32 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
33 |
34 | loss_config:
35 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
36 | params:
37 | dims: 3 # video - [t,h,w]
38 | perceptual_weight: 1.0
39 | disc_start: 20001
40 | disc_weight: 0.2
41 | disc_type: 2d # 2d, 3d
42 | learn_logvar: true
43 | gen_loss_cross_entropy: true
44 | lecam_loss_weight: 0.005
45 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
46 |
47 | data:
48 | target: vidtok.data.datamodule.DataModuleFromConfig
49 | params:
50 | batch_size: 2
51 | num_workers: 12
52 |
53 | train:
54 | target: vidtok.data.vidtok.VidTokDataset
55 | params:
56 | data_dir: DATA_DIR_1 # DATA_DIR for training data
57 | meta_path: META_PATH_1 # path to the .csv meta file of training data
58 | video_params:
59 | input_height: INPUT_HEIGHT_1
60 | input_width: INPUT_WIDTH_1
61 | sample_num_frames: 16
62 | sample_fps: 3
63 |
64 | validation:
65 | target: vidtok.data.vidtok.VidTokDataset
66 | params:
67 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
68 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
69 | video_params:
70 | input_height: INPUT_HEIGHT_2
71 | input_width: INPUT_WIDTH_2
72 | sample_num_frames: 16
73 | sample_fps: 8
74 | start_index: 0
75 |
76 | lightning:
77 | strategy:
78 | target: lightning.pytorch.strategies.DDPStrategy
79 | params:
80 | find_unused_parameters: true
81 |
82 | modelcheckpoint:
83 | params:
84 | every_n_train_steps: 5000
85 |
86 | callbacks:
87 | image_logger:
88 | target: vidtok.modules.logger.ImageVideoLogger
89 | params:
90 | disabled: false
91 | rescale: true
92 | enable_autocast: false
93 | batch_frequency: 5000
94 | max_samples: 2
95 | increase_log_steps: false
96 | log_first_step: false
97 | log_before_first_step: false
98 | log_images_kwargs:
99 | n_rows: 16
100 |
101 | trainer:
102 | precision: bf16-mixed
103 | devices: auto
104 | num_nodes: 1
105 | benchmark: true
106 | num_sanity_val_steps: 10
107 | val_check_interval: 2000
108 | check_val_every_n_epoch: null # default: 1
109 | accumulate_grad_batches: 1
110 | max_epochs: 1000
111 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtok_kl_noncausal_488_4chn.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1e-5
3 | target: vidtok.models.autoencoder.AutoencodingEngine
4 | params:
5 | monitor: val/rec_loss
6 | mode: min
7 | # ckpt_path: checkpoints/vidtok_kl_noncausal_488_4chn.ckpt # train from existing checkpoint
8 | ignore_keys: []
9 | # ema_decay: 0.999
10 |
11 | encoder_config:
12 | target: vidtok.modules.model_3dnoncausal.Encoder3D
13 | params:
14 | double_z: true
15 | z_channels: 4
16 | in_channels: 3
17 | out_ch: 3
18 | ch: 128
19 | ch_mult: [1, 2, 4, 4]
20 | num_res_blocks: 2
21 | dropout: 0.0
22 | use_checkpoint: false
23 | norm_type: layernorm # layernorm, groupnorm
24 | fix_encoder: false
25 | fix_decoder: false
26 |
27 | decoder_config:
28 | target: vidtok.modules.model_3dnoncausal.Decoder3D
29 | params: ${model.params.encoder_config.params}
30 |
31 | regularizer_config:
32 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
33 |
34 | loss_config:
35 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
36 | params:
37 | dims: 3 # video - [t,h,w]
38 | perceptual_weight: 1.0
39 | disc_start: 20001
40 | disc_weight: 0.2
41 | disc_type: 2d # 2d, 3d
42 | learn_logvar: true
43 | gen_loss_cross_entropy: true
44 | lecam_loss_weight: 0.005
45 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}
46 |
47 | data:
48 | target: vidtok.data.datamodule.DataModuleFromConfig
49 | params:
50 | batch_size: 2
51 | num_workers: 12
52 |
53 | train:
54 | target: vidtok.data.vidtok.VidTokDataset
55 | params:
56 | data_dir: DATA_DIR_1 # DATA_DIR for training data
57 | meta_path: META_PATH_1 # path to the .csv meta file of training data
58 | video_params:
59 | input_height: INPUT_HEIGHT_1
60 | input_width: INPUT_WIDTH_1
61 | sample_num_frames: 16
62 | sample_fps: 3
63 |
64 | validation:
65 | target: vidtok.data.vidtok.VidTokDataset
66 | params:
67 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
68 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
69 | video_params:
70 | input_height: INPUT_HEIGHT_2
71 | input_width: INPUT_WIDTH_2
72 | sample_num_frames: 16
73 | sample_fps: 8
74 | start_index: 0
75 |
76 | lightning:
77 | strategy:
78 | target: lightning.pytorch.strategies.DDPStrategy
79 | params:
80 | find_unused_parameters: true
81 |
82 | modelcheckpoint:
83 | params:
84 | every_n_train_steps: 5000
85 |
86 | callbacks:
87 | image_logger:
88 | target: vidtok.modules.logger.ImageVideoLogger
89 | params:
90 | disabled: false
91 | rescale: true
92 | enable_autocast: false
93 | batch_frequency: 5000
94 | max_samples: 2
95 | increase_log_steps: false
96 | log_first_step: false
97 | log_before_first_step: false
98 | log_images_kwargs:
99 | n_rows: 16
100 |
101 | trainer:
102 | precision: bf16-mixed
103 | devices: auto
104 | num_nodes: 1
105 | benchmark: true
106 | num_sanity_val_steps: 10
107 | val_check_interval: 2000
108 | check_val_every_n_epoch: null # default: 1
109 | accumulate_grad_batches: 1
110 | max_epochs: 1000
111 |
--------------------------------------------------------------------------------
/open_tokenizer/model/vidtok/configs/vidtwin/vidtwin_structure_7_7_8_dynamics_7_8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.6e-4
3 | target: vidtwin.models.vidtwin_ae.VidAutoEncoderQformerCompactSymVidVAE
4 | params:
5 | input_key: jpg
6 | monitor: val/rec_loss
7 | ckpt_path: PATH_TO_CHECKPOINT
8 | ignore_keys: []
9 | expect_ch: 8
10 | cont_num_blocks: 1
11 | downsample_motion: True
12 | motion_num_blocks: 1
13 | d_dim: 8
14 |
15 | temporal_qformer_config:
16 | target: vidtwin.modules.qformer.MyQformerInterface
17 | params:
18 | num_query_tokens: 16
19 | query_hidden_size: 64
20 | encoder_hidden_size: 768
21 |
22 | encoder_config:
23 | target: vidtwin.modules.st_transformer.STTEncoder
24 | params:
25 | in_channels: 3
26 | input_size: [16, 224, 224]
27 | patch_size: [1, 16, 16]
28 | hidden_size: 768
29 | depth: 16
30 | num_heads: 12
31 | temporal_casual: true
32 |
33 | decoder_config:
34 | target: vidtwin.modules.st_transformer.STTDecoder
35 | params:
36 | in_channels: 3
37 | input_size: [16, 224, 224]
38 | patch_size: [1, 16, 16]
39 | hidden_size: 768
40 | depth: 16
41 | num_heads: 12
42 | temporal_casual: true
43 |
44 | loss_config:
45 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
46 | params:
47 | perceptual_weight: 0.05
48 | disc_start: 20001
49 | disc_weight: 0.05
50 | learn_logvar: True
51 | dims: 3
52 | disc_type: 2d
53 | regularization_weights:
54 | kl_loss: 0.001
55 |
56 | regularizer_config:
57 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer
58 | params:
59 | sample: True
60 |
61 |
62 | lr_scheduler_config_d:
63 | target: vidtok.models.vidtwin_ae.LambdaWarmUpCosineScheduler
64 | params:
65 | lr_min: 0
66 | lr_max: 1.5e-05
67 | lr_start: 1.0e-05
68 | warmup_steps: 5000
69 | lr_scheduler_config_g:
70 | target: vidtok.models.vidtwin_ae.LambdaWarmUpCosineScheduler
71 | params:
72 | lr_min: 0
73 | lr_max: 3.0e-05
74 | lr_start: 0
75 | warmup_steps: 5000
76 | optimizer_config:
77 | target: torch.optim.AdamW
78 | params:
79 | betas:
80 | - 0
81 | - 0.9
82 | weight_decay: 0.0001
83 | lr_scheduler_config:
84 | target: inverse_sqrt
85 | params:
86 | num_warmup_steps: 2000
87 | frequency: 1
88 |
89 | data:
90 | target: vidtok.data.datamodule.DataModuleFromConfig
91 | params:
92 | batch_size: 2
93 | num_workers: 12
94 |
95 | train:
96 | target: vidtok.data.vidtok.VidTokDataset
97 | params:
98 | data_dir: DATA_DIR_1 # DATA_DIR for training data
99 | meta_path: META_PATH_1 # path to the .csv meta file of training data
100 | video_params:
101 | input_height: 224
102 | input_width: 224
103 | sample_num_frames: 16
104 | sample_fps: 8
105 |
106 | validation:
107 | target: vidtok.data.vidtok.VidTokDataset
108 | params:
109 | data_dir: DATA_DIR_2 # DATA_DIR for validation data
110 | meta_path: META_PATH_2 # path to the .csv meta file of validation data
111 | video_params:
112 | input_height: 224
113 | input_width: 224
114 | sample_num_frames: 16
115 | sample_fps: 8
116 | start_index: 0
117 |
118 |
119 | lightning:
120 | strategy:
121 | target: lightning.pytorch.strategies.DDPStrategy
122 | params:
123 | find_unused_parameters: True
124 |
125 | modelcheckpoint:
126 | params:
127 | every_n_train_steps: 5000
128 |
129 |
130 | callbacks:
131 | image_logger:
132 | target: vidtok.modules.logger.ImageVideoLogger
133 | params:
134 | disabled: false
135 | rescale: true
136 | enable_autocast: false
137 | batch_frequency: 5000
138 | max_samples: 2
139 | increase_log_steps: false
140 | log_first_step: false
141 | log_before_first_step: false
142 | log_images_kwargs:
143 | n_rows: 2
144 |
145 |
146 |
147 | trainer:
148 | # precision: bf16-mixed # 16-mixed
149 | benchmark: True
150 | devices: 4
151 | num_sanity_val_steps: 10
152 | val_check_interval: 5000
153 | accumulate_grad_batches: 1
154 | max_epochs: 10
155 |
--------------------------------------------------------------------------------
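
Note: the VidTwin encoder/decoder above are patch-based spatio-temporal transformers; with `input_size: [16, 224, 224]` and `patch_size: [1, 16, 16]` each frame becomes a 14x14 patch grid, and the Q-Former is configured with 16 query tokens. A quick check of the token counts implied by those settings:

```python
# Token-count check for the VidTwin ST-transformer config above.
input_size = (16, 224, 224)   # (T, H, W) from encoder_config.params.input_size
patch_size = (1, 16, 16)      # from encoder_config.params.patch_size
grid = tuple(i // p for i, p in zip(input_size, patch_size))
print(grid, "->", grid[0] * grid[1] * grid[2], "patch tokens")
# (16, 14, 14) -> 3136 patch tokens, before the Q-Former reduces to 16 query tokens
```
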
/open_tokenizer/model/vidtok/data/datamodule.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from functools import partial
3 |
4 | import torch
5 | import lightning.pytorch as pl
6 | from torch.utils.data import DataLoader, Dataset, IterableDataset
7 |
8 | from vidtok.modules.util import instantiate_from_config
9 |
10 |
11 | class WrappedDataset(Dataset):
12 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
13 |
14 | def __init__(self, dataset):
15 | self.data = dataset
16 |
17 | def __len__(self):
18 | return len(self.data)
19 |
20 | def __getitem__(self, idx):
21 | return self.data[idx]
22 |
23 |
24 | def worker_init_fn(_):
25 | worker_info = torch.utils.data.get_worker_info()
26 |
27 | dataset = worker_info.dataset
28 | worker_id = worker_info.id
29 |
30 | if isinstance(dataset, IterableDataset):
31 | split_size = dataset.num_records // worker_info.num_workers
32 | # restrict this worker to its own slice of valid_ids so workers read disjoint subsets
33 | dataset.sample_ids = dataset.valid_ids[
34 | worker_id * split_size : (worker_id + 1) * split_size
35 | ]
36 | current_id = np.random.choice(len(np.random.get_state()[1]), 1)
37 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
38 | else:
39 | return np.random.seed(np.random.get_state()[1][0] + worker_id)
40 |
41 |
42 | class DataModuleFromConfig(pl.LightningDataModule):
43 | def __init__(
44 | self,
45 | batch_size,
46 | train=None,
47 | validation=None,
48 | test=None,
49 | predict=None,
50 | wrap=False,
51 | num_workers=None,
52 | pin_train_memory=True,
53 | is_iterable_dataset=False,
54 | shuffle_test_loader=False,
55 | use_worker_init_fn=False,
56 | shuffle_val_dataloader=False,
57 | ):
58 | super().__init__()
59 | self.batch_size = batch_size
60 | self.dataset_configs = dict()
61 | self.num_workers = num_workers if num_workers is not None else batch_size * 2
62 | self.pin_train_memory = pin_train_memory
63 | self.is_iterable_dataset = is_iterable_dataset
64 | self.use_worker_init_fn = use_worker_init_fn
65 | if train is not None:
66 | self.dataset_configs["train"] = train
67 | self.train_dataloader = self._train_dataloader
68 | if validation is not None:
69 | self.dataset_configs["validation"] = validation
70 | self.val_dataloader = partial(
71 | self._val_dataloader, shuffle=shuffle_val_dataloader
72 | )
73 | if test is not None:
74 | self.dataset_configs["test"] = test
75 | self.test_dataloader = partial(
76 | self._test_dataloader, shuffle=shuffle_test_loader
77 | )
78 | if predict is not None:
79 | self.dataset_configs["predict"] = predict
80 | self.predict_dataloader = self._predict_dataloader
81 | self.wrap = wrap
82 |
83 | def prepare_data(self):
84 | for data_cfg in self.dataset_configs.values():
85 | instantiate_from_config(data_cfg)
86 |
87 | def setup(self, stage=None):
88 | self.datasets = dict(
89 | (k, instantiate_from_config(self.dataset_configs[k]))
90 | for k in self.dataset_configs
91 | )
92 | if self.wrap:
93 | for k in self.datasets:
94 | self.datasets[k] = WrappedDataset(self.datasets[k])
95 |
96 | def _train_dataloader(self):
97 | if self.is_iterable_dataset or self.use_worker_init_fn:
98 | init_fn = worker_init_fn
99 | else:
100 | init_fn = None
101 | return DataLoader(
102 | self.datasets["train"],
103 | batch_size=self.batch_size,
104 | num_workers=self.num_workers,
105 | pin_memory=self.pin_train_memory,
106 | shuffle=False if self.is_iterable_dataset else True,
107 | worker_init_fn=init_fn,
108 | )
109 |
110 | def _val_dataloader(self, shuffle=False):
111 | if self.is_iterable_dataset or self.use_worker_init_fn:
112 | init_fn = worker_init_fn
113 | else:
114 | init_fn = None
115 | return DataLoader(
116 | self.datasets["validation"],
117 | batch_size=self.batch_size,
118 | num_workers=self.num_workers,
119 | worker_init_fn=init_fn,
120 | shuffle=shuffle,
121 | )
122 |
123 | def _test_dataloader(self, shuffle=False):
124 | if self.is_iterable_dataset or self.use_worker_init_fn:
125 | init_fn = worker_init_fn
126 | else:
127 | init_fn = None
128 |
129 | # do not shuffle dataloader for iterable dataset
130 | shuffle = shuffle and (not self.is_iterable_dataset)
131 |
132 | return DataLoader(
133 | self.datasets["test"],
134 | batch_size=self.batch_size,
135 | num_workers=self.num_workers,
136 | worker_init_fn=init_fn,
137 | shuffle=shuffle,
138 | )
139 |
140 | def _predict_dataloader(self, shuffle=False):
141 | if self.is_iterable_dataset or self.use_worker_init_fn:
142 | init_fn = worker_init_fn
143 | else:
144 | init_fn = None
145 | return DataLoader(
146 | self.datasets["predict"],
147 | batch_size=self.batch_size,
148 | num_workers=self.num_workers,
149 | worker_init_fn=init_fn,
150 | )
151 |
--------------------------------------------------------------------------------
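
Note: `DataModuleFromConfig` is the class the `data:` blocks of the YAML configs target: it stores per-split dataset configs, instantiates them in `setup()`, and exposes the corresponding Lightning dataloaders. A hedged usage sketch with the training split only; the paths and resolution are placeholders, and the import path assumes the package layout of this dump.

```python
# Hedged sketch: wire DataModuleFromConfig by hand, mirroring the "data:" block
# of the VidTok configs. data_dir / meta_path are placeholders, not real paths.
from open_tokenizer.model.vidtok.data.datamodule import DataModuleFromConfig

train_cfg = {
    "target": "open_tokenizer.model.vidtok.data.vidtok.VidTokDataset",
    "params": {
        "data_dir": "DATA_DIR_1",        # placeholder, as in the configs
        "meta_path": "META_PATH_1",      # placeholder .csv meta file
        "video_params": {
            "input_height": 256,         # assumed resolution
            "input_width": 256,
            "sample_num_frames": 17,
            "sample_fps": 3,
        },
    },
}

dm = DataModuleFromConfig(batch_size=2, num_workers=12, train=train_cfg)
dm.prepare_data()                 # instantiates each configured dataset once
dm.setup()                        # builds self.datasets["train"]
loader = dm.train_dataloader()    # a standard PyTorch DataLoader over VidTokDataset
```
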
/open_tokenizer/model/vidtok/data/video_read.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import decord
4 | import numpy as np
5 | import torch
6 |
7 | from vidtok.modules.util import print0
8 |
9 | decord.bridge.set_bridge("torch")
10 |
11 |
12 | def sample_frames_with_fps(
13 | total_frames,
14 | video_fps,
15 | sample_num_frames,
16 | sample_fps,
17 | start_index=None
18 | ):
19 | """sample frames proportional to the length of the frames in one second
20 | e.g., 1s video has 30 frames, when 'fps'=3, we sample frames with spacing of 30/3=10
21 | return the frame indices
22 |
23 | Parameters
24 | ----------
25 | total_frames : length of the video
26 | video_fps : original fps of the video
27 | sample_num_frames : number of frames to sample
28 | sample_fps : the fps to sample frames
29 | start_index : the starting frame index. If it is not None, it will be used as the starting frame index
30 |
31 | Returns
32 | -------
33 | frame indices
34 | """
35 | sample_num_frames = min(sample_num_frames, total_frames)
36 | interval = round(video_fps / sample_fps)
37 | frames_range = (sample_num_frames - 1) * interval + 1
38 |
39 | if start_index is not None:
40 | start = start_index
41 | elif total_frames - frames_range - 1 < 0:
42 | start = 0
43 | else:
44 | start = random.randint(0, total_frames - frames_range - 1)
45 |
46 | frame_idxs = np.linspace(
47 | start=start, stop=min(total_frames - 1, start + frames_range), num=sample_num_frames
48 | ).astype(int)
49 |
50 | return frame_idxs
51 |
52 |
53 | def read_frames_with_decord(
54 | video_path,
55 | sample_num_frames,
56 | sample_fps,
57 | start_index=None
58 | ) -> tuple[torch.Tensor, list[int]]:
59 | """read frames from video path using decord
60 |
61 | Parameters
62 | ----------
63 | video_path : path to video
64 | sample_num_frames : number of frames to sample
65 | sample_fps : the fps to sample frames
66 | start_index : the starting frame index. If it is not None, it will be used as the starting frame index
67 |
68 | Returns
69 | -------
70 | frames (tensor 0~1), frame indices
71 | """
72 | video_reader = decord.VideoReader(video_path, num_threads=0)
73 | total_frames = len(video_reader)
74 | video_fps = video_reader.get_avg_fps() # note that the fps here is float.
75 | frame_idxs = sample_frames_with_fps(
76 | total_frames=total_frames,
77 | video_fps=video_fps,
78 | sample_num_frames=sample_num_frames,
79 | sample_fps=sample_fps,
80 | start_index=start_index
81 | )
82 | frames = video_reader.get_batch(frame_idxs)
83 | frames = frames.float() / 255
84 | frames = frames.permute(0, 3, 1, 2)
85 | if (frames.shape[0] != sample_num_frames) or (len(frame_idxs) != sample_num_frames):
86 | print0(f"[bold yellow]\[vidtok.data.video_read][read_frames_with_decord][/bold yellow] Warning: need {sample_num_frames} frames, "
87 | f"but got {frames.shape[0]} frames, {len(frame_idxs)} frame indices, video_path={video_path}.")
88 | return frames, frame_idxs
89 |
--------------------------------------------------------------------------------
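
A minimal usage sketch of the frame-sampling helpers above, for readers who want to sanity-check the sampling logic. The clip path is a placeholder, and the import mirrors the module's own `vidtok.*` imports; adjust it to however the package is exposed in your environment.

from vidtok.data.video_read import read_frames_with_decord

# "./example.mp4" is a placeholder; any decord-readable video works.
frames, frame_idxs = read_frames_with_decord(
    video_path="./example.mp4",
    sample_num_frames=17,
    sample_fps=8,
    start_index=0,        # fix the start index for deterministic sampling
)
print(frames.shape)       # (17, 3, H, W), float values in [0, 1]
print(frame_idxs)         # the sampled frame indices
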
/open_tokenizer/model/vidtok/modules/discriminator.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | def weights_init(m):
8 | classname = m.__class__.__name__
9 | if classname.find("Conv") != -1:
10 | nn.init.normal_(m.weight.data, 0.0, 0.02)
11 | elif classname.find("BatchNorm") != -1:
12 | nn.init.normal_(m.weight.data, 1.0, 0.02)
13 | nn.init.constant_(m.bias.data, 0)
14 |
15 |
16 | class ActNorm(nn.Module):
17 | def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
18 | assert affine
19 | super().__init__()
20 | self.logdet = logdet
21 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
22 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
23 | self.allow_reverse_init = allow_reverse_init
24 |
25 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
26 |
27 | def initialize(self, input):
28 | with torch.no_grad():
29 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
30 | mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
31 | std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
32 |
33 | self.loc.data.copy_(-mean)
34 | self.scale.data.copy_(1 / (std + 1e-6))
35 |
36 | def forward(self, input, reverse=False):
37 | if reverse:
38 | return self.reverse(input)
39 | if len(input.shape) == 2:
40 | input = input[:, :, None, None]
41 | squeeze = True
42 | else:
43 | squeeze = False
44 |
45 | _, _, height, width = input.shape
46 |
47 | if self.training and self.initialized.item() == 0:
48 | self.initialize(input)
49 | self.initialized.fill_(1)
50 |
51 | h = self.scale * (input + self.loc)
52 |
53 | if squeeze:
54 | h = h.squeeze(-1).squeeze(-1)
55 |
56 | if self.logdet:
57 | log_abs = torch.log(torch.abs(self.scale))
58 | logdet = height * width * torch.sum(log_abs)
59 | logdet = logdet * torch.ones(input.shape[0]).to(input)
60 | return h, logdet
61 |
62 | return h
63 |
64 | def reverse(self, output):
65 | if self.training and self.initialized.item() == 0:
66 | if not self.allow_reverse_init:
67 | raise RuntimeError(
68 | "Initializing ActNorm in reverse direction is "
69 | "disabled by default. Use allow_reverse_init=True to enable."
70 | )
71 | else:
72 | self.initialize(output)
73 | self.initialized.fill_(1)
74 |
75 | if len(output.shape) == 2:
76 | output = output[:, :, None, None]
77 | squeeze = True
78 | else:
79 | squeeze = False
80 |
81 | h = output / self.scale - self.loc
82 |
83 | if squeeze:
84 | h = h.squeeze(-1).squeeze(-1)
85 | return h
86 |
87 |
88 | class NLayerDiscriminator(nn.Module):
89 | """Defines a PatchGAN discriminator as in Pix2Pix."""
90 | # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
91 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
92 | """Construct a PatchGAN discriminator
93 | Parameters:
94 | input_nc (int) -- the number of channels in input images
95 | ndf (int) -- the number of filters in the last conv layer
96 | n_layers (int) -- the number of conv layers in the discriminator
97 | """
98 | super(NLayerDiscriminator, self).__init__()
99 | if not use_actnorm:
100 | norm_layer = nn.BatchNorm2d
101 | else:
102 | norm_layer = ActNorm
103 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
104 | use_bias = norm_layer.func != nn.BatchNorm2d
105 | else:
106 | use_bias = norm_layer != nn.BatchNorm2d
107 |
108 | kw = 4
109 | padw = 1
110 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
111 | nf_mult = 1
112 | nf_mult_prev = 1
113 | for n in range(1, n_layers): # gradually increase the number of filters
114 | nf_mult_prev = nf_mult
115 | nf_mult = min(2**n, 8)
116 | sequence += [
117 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
118 | norm_layer(ndf * nf_mult),
119 | nn.LeakyReLU(0.2, True),
120 | ]
121 |
122 | nf_mult_prev = nf_mult
123 | nf_mult = min(2**n_layers, 8)
124 | sequence += [
125 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
126 | norm_layer(ndf * nf_mult),
127 | nn.LeakyReLU(0.2, True),
128 | ]
129 |
130 | sequence += [
131 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
132 | ] # output 1 channel prediction map
133 | self.main = nn.Sequential(*sequence)
134 |
135 | def forward(self, input):
136 | """Standard forward."""
137 | return self.main(input)
138 |
139 |
140 | class NLayerDiscriminator3D(nn.Module):
141 | """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs."""
142 | # https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/opensora/models/causalvideovae/model/losses/discriminator.py
143 | def __init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False):
144 | """
145 | Construct a 3D PatchGAN discriminator
146 |
147 | Parameters:
148 | input_nc (int) -- the number of channels in input volumes
149 | ndf (int) -- the number of filters in the last conv layer
150 | n_layers (int) -- the number of conv layers in the discriminator
151 | use_actnorm (bool) -- flag to use actnorm instead of batchnorm
152 | """
153 | super(NLayerDiscriminator3D, self).__init__()
154 | if not use_actnorm:
155 | norm_layer = nn.BatchNorm3d
156 | else:
157 | raise NotImplementedError("Not implemented.")
158 | if type(norm_layer) == functools.partial:
159 | use_bias = norm_layer.func != nn.BatchNorm3d
160 | else:
161 | use_bias = norm_layer != nn.BatchNorm3d
162 |
163 | kw = 3
164 | padw = 1
165 | sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
166 | nf_mult = 1
167 | nf_mult_prev = 1
168 | for n in range(1, n_layers): # gradually increase the number of filters
169 | nf_mult_prev = nf_mult
170 | nf_mult = min(2**n, 8)
171 | sequence += [
172 | nn.Conv3d(
173 | ndf * nf_mult_prev,
174 | ndf * nf_mult,
175 | kernel_size=(kw, kw, kw),
176 | stride=(2 if n == 1 else 1, 2, 2),
177 | padding=padw,
178 | bias=use_bias,
179 | ),
180 | norm_layer(ndf * nf_mult),
181 | nn.LeakyReLU(0.2, True),
182 | ]
183 |
184 | nf_mult_prev = nf_mult
185 | nf_mult = min(2**n_layers, 8)
186 | sequence += [
187 | nn.Conv3d(
188 | ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias
189 | ),
190 | norm_layer(ndf * nf_mult),
191 | nn.LeakyReLU(0.2, True),
192 | ]
193 |
194 | sequence += [
195 | nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
196 | ] # output 1 channel prediction map
197 | self.main = nn.Sequential(*sequence)
198 |
199 | def forward(self, input):
200 | """Standard forward."""
201 | return self.main(input)
202 |
--------------------------------------------------------------------------------
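
A minimal sketch of the 3D PatchGAN discriminator above on a dummy clip; the import path follows the repository layout and is an assumption, and the input shape is illustrative only.

import torch
from vidtok.modules.discriminator import NLayerDiscriminator3D, weights_init

disc = NLayerDiscriminator3D(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
video = torch.randn(2, 3, 16, 64, 64)   # (B, C, T, H, W) dummy clip
logits = disc(video)
print(logits.shape)                      # (B, 1, T', H', W') patch-wise real/fake logits
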
/open_tokenizer/model/vidtok/modules/distributions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class DiagonalGaussianDistribution(object):
6 | def __init__(self, parameters, deterministic=False):
7 | self.parameters = parameters
8 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
9 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
10 | self.deterministic = deterministic
11 | self.std = torch.exp(0.5 * self.logvar)
12 | self.var = torch.exp(self.logvar)
13 | if self.deterministic:
14 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
15 |
16 | def sample(self):
17 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
18 | return x
19 |
20 | def kl(self, other=None):
21 | if self.deterministic:
22 | return torch.Tensor([0.0])
23 | else:
24 | if other is None:
25 | return 0.5 * torch.sum(
26 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
27 | dim=[1, 2, 3],
28 | )
29 | else:
30 | return 0.5 * torch.sum(
31 | torch.pow(self.mean - other.mean, 2) / other.var
32 | + self.var / other.var
33 | - 1.0
34 | - self.logvar
35 | + other.logvar,
36 | dim=[1, 2, 3],
37 | )
38 |
39 | def nll(self, sample, dims=[1, 2, 3]):
40 | if self.deterministic:
41 | return torch.Tensor([0.0])
42 | logtwopi = np.log(2.0 * np.pi)
43 | return 0.5 * torch.sum(
44 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
45 | dim=dims,
46 | )
47 |
48 | def mode(self):
49 | return self.mean
50 |
--------------------------------------------------------------------------------
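
The distribution class above is what a KL-regularised encoder would return; a minimal sketch, assuming a random tensor in place of real encoder output (import path per the repository layout).

import torch
from vidtok.modules.distributions import DiagonalGaussianDistribution

# 2*C channels: the first half is interpreted as the mean, the second half as the log-variance.
params = torch.randn(4, 8, 16, 16)
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()       # (4, 4, 16, 16) reparameterised sample
kl = posterior.kl()          # per-sample KL to a standard normal, shape (4,)
print(z.shape, kl.shape)
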
/open_tokenizer/model/vidtok/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError("Decay must be between 0 and 1")
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer(
14 | "num_updates",
15 | torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
16 | )
17 |
18 | for name, p in model.named_parameters():
19 | if p.requires_grad:
20 | # remove as '.'-character is not allowed in buffers
21 | s_name = name.replace(".", "")
22 | self.m_name2s_name.update({name: s_name})
23 | self.register_buffer(s_name, p.clone().detach().data)
24 |
25 | self.collected_params = []
26 |
27 | def reset_num_updates(self):
28 | del self.num_updates
29 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
30 |
31 | def forward(self, model):
32 | decay = self.decay
33 |
34 | if self.num_updates >= 0:
35 | self.num_updates += 1
36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
37 |
38 | one_minus_decay = 1.0 - decay
39 |
40 | with torch.no_grad():
41 | m_param = dict(model.named_parameters())
42 | shadow_params = dict(self.named_buffers())
43 |
44 | for key in m_param:
45 | if m_param[key].requires_grad:
46 | sname = self.m_name2s_name[key]
47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
48 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
49 | else:
50 | assert not key in self.m_name2s_name
51 |
52 | def copy_to(self, model):
53 | m_param = dict(model.named_parameters())
54 | shadow_params = dict(self.named_buffers())
55 | for key in m_param:
56 | if m_param[key].requires_grad:
57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
58 | else:
59 | assert not key in self.m_name2s_name
60 |
61 | def store(self, parameters):
62 | """
63 | Save the current parameters for restoring later.
64 | Args:
65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66 | temporarily stored.
67 | """
68 | self.collected_params = [param.clone() for param in parameters]
69 |
70 | def restore(self, parameters):
71 | """
72 | Restore the parameters stored with the `store` method.
73 | Useful to validate the model with EMA parameters without affecting the
74 | original optimization process. Store the parameters before the
75 | `copy_to` method. After validation (or model saving), use this to
76 | restore the former parameters.
77 | Args:
78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
79 | updated with the stored parameters.
80 | """
81 | for c_param, param in zip(self.collected_params, parameters):
82 | param.data.copy_(c_param.data)
83 |
--------------------------------------------------------------------------------
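
A minimal sketch of how LitEma is typically driven from a training loop; the toy nn.Linear model and the manual weight bump stand in for the real autoencoder and an optimiser step.

import torch
import torch.nn as nn
from vidtok.modules.ema import LitEma

model = nn.Linear(8, 8)
ema = LitEma(model, decay=0.999)

# after each optimiser step, refresh the shadow weights
with torch.no_grad():
    model.weight.add_(0.01)   # stand-in for an optimiser update
ema(model)

# evaluate with EMA weights, then restore the raw ones
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation here ...
ema.restore(model.parameters())
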
/open_tokenizer/model/vidtok/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | from collections import namedtuple
4 | from tqdm import tqdm
5 |
6 | import requests
7 | import torch
8 | import torch.nn as nn
9 | from torchvision import models
10 |
11 | from .util import print0
12 |
13 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
14 |
15 | CKPT_MAP = {"vgg_lpips": "vgg.pth"}
16 |
17 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
18 |
19 |
20 | def download(url, local_path, chunk_size=1024):
21 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
22 | with requests.get(url, stream=True) as r:
23 | total_size = int(r.headers.get("content-length", 0))
24 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
25 | with open(local_path, "wb") as f:
26 | for data in r.iter_content(chunk_size=chunk_size):
27 | if data:
28 | f.write(data)
29 | pbar.update(chunk_size)
30 |
31 |
32 | def md5_hash(path):
33 | with open(path, "rb") as f:
34 | content = f.read()
35 | return hashlib.md5(content).hexdigest()
36 |
37 |
38 | def get_ckpt_path(name, root, check=False):
39 | assert name in URL_MAP
40 | path = os.path.join(root, CKPT_MAP[name])
41 | if os.path.exists(path) and not (check and not md5_hash(path) == MD5_MAP[name]):
42 | print0(
43 | "[bold cyan]\[vidtok.modules.lpips]\[get_ckpt_path][/bold cyan] Using existing path for {} model: {}".format(
44 | name, path
45 | )
46 | )
47 | return path
48 |
49 | # if not, download the model
50 | print0(
51 | "[bold cyan]\[vidtok.modules.lpips]\[get_ckpt_path][/bold cyan] Downloading {} model from {} to {}".format(
52 | name, URL_MAP[name], path
53 | )
54 | )
55 | download(URL_MAP[name], path)
56 | md5 = md5_hash(path)
57 | assert md5 == MD5_MAP[name], md5
58 | return path
59 |
60 |
61 | class LPIPS(nn.Module):
62 | # Learned perceptual metric
63 | def __init__(self, use_dropout=True):
64 | super().__init__()
65 | self.scaling_layer = ScalingLayer()
66 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
67 | self.net = vgg16(pretrained=True, requires_grad=False)
68 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
69 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
70 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
71 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
72 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
73 | self.load_from_pretrained()
74 | for param in self.parameters():
75 | param.requires_grad = False
76 |
77 | def load_from_pretrained(self, name="vgg_lpips"):
78 | ckpt = get_ckpt_path(name, "checkpoints/lpips")
79 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
80 | print0("[bold cyan]\[vidtok.modules.lpips][LPIPS][/bold cyan] loaded pretrained LPIPS loss from {}".format(ckpt))
81 |
82 | def forward(self, input, target):
83 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
84 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
85 | feats0, feats1, diffs = {}, {}, {}
86 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
87 | for kk in range(len(self.chns)):
88 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
89 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
90 |
91 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
92 | val = res[0]
93 | for l in range(1, len(self.chns)):
94 | val += res[l]
95 | return val
96 |
97 |
98 | class ScalingLayer(nn.Module):
99 | def __init__(self):
100 | super(ScalingLayer, self).__init__()
101 | self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None])
102 | self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None])
103 |
104 | def forward(self, inp):
105 | return (inp - self.shift) / self.scale
106 |
107 |
108 | class NetLinLayer(nn.Module):
109 | """A single linear layer which does a 1x1 conv"""
110 |
111 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
112 | super(NetLinLayer, self).__init__()
113 | layers = (
114 | [
115 | nn.Dropout(),
116 | ]
117 | if (use_dropout)
118 | else []
119 | )
120 | layers += [
121 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
122 | ]
123 | self.model = nn.Sequential(*layers)
124 |
125 |
126 | class vgg16(torch.nn.Module):
127 | def __init__(self, requires_grad=False, pretrained=True):
128 | super(vgg16, self).__init__()
129 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
130 | self.slice1 = torch.nn.Sequential()
131 | self.slice2 = torch.nn.Sequential()
132 | self.slice3 = torch.nn.Sequential()
133 | self.slice4 = torch.nn.Sequential()
134 | self.slice5 = torch.nn.Sequential()
135 | self.N_slices = 5
136 | for x in range(4):
137 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
138 | for x in range(4, 9):
139 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
140 | for x in range(9, 16):
141 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
142 | for x in range(16, 23):
143 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
144 | for x in range(23, 30):
145 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
146 | if not requires_grad:
147 | for param in self.parameters():
148 | param.requires_grad = False
149 |
150 | def forward(self, X):
151 | h = self.slice1(X)
152 | h_relu1_2 = h
153 | h = self.slice2(h)
154 | h_relu2_2 = h
155 | h = self.slice3(h)
156 | h_relu3_3 = h
157 | h = self.slice4(h)
158 | h_relu4_3 = h
159 | h = self.slice5(h)
160 | h_relu5_3 = h
161 | vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
162 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
163 | return out
164 |
165 |
166 | def normalize_tensor(x, eps=1e-10):
167 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
168 | return x / (norm_factor + eps)
169 |
170 |
171 | def spatial_average(x, keepdim=True):
172 | return x.mean([2, 3], keepdim=keepdim)
173 |
--------------------------------------------------------------------------------
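
A minimal sketch of the LPIPS metric above. Constructing it downloads the VGG-LPIPS weights to checkpoints/lpips (plus the torchvision VGG16 backbone) on first use, and inputs are expected roughly in [-1, 1]; the import path is an assumption.

import torch
from vidtok.modules.lpips import LPIPS

lpips = LPIPS().eval()
x = torch.rand(2, 3, 256, 256) * 2 - 1
y = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    d = lpips(x, y)          # (2, 1, 1, 1) per-image perceptual distances
print(d.flatten())
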
/open_tokenizer/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from PIL import Image, ImageFile
4 | ImageFile.LOAD_TRUNCATED_IMAGES = True
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 | from torchvision.transforms import InterpolationMode
10 |
11 | from .data_utils import load_video_from_path_decord
12 |
13 | class EvalT2IDataset(Dataset):
14 | def __init__(self, image_folder: str, data_path: str, image_size: int, scale: bool = True):
15 | super(EvalT2IDataset, self).__init__()
16 |
17 | self.image_folder = image_folder
18 | list_data_dict = json.load(open(data_path, "r"))
19 | self.list_data_dict = list_data_dict
20 |
21 | if scale:
22 | augmentations = transforms.Compose([
23 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
24 | transforms.ToTensor(),
25 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
26 | ])
27 | else:
28 | augmentations = transforms.Compose([
29 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
30 | transforms.ToTensor(),
31 | # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
32 | ])
33 | self.transform = augmentations
34 |
35 | def __len__(self):
36 | return len(self.list_data_dict)
37 |
38 |
39 | def __getitem__(self, i):
40 | item = self.list_data_dict[i]
41 | image_file = item["image_path"]
42 | image = Image.open(os.path.join(self.image_folder, image_file)).convert("RGB")
43 | image_tensor = self.transform(image)
44 |
45 | return image_tensor, image_file
46 |
47 |
48 | class EvalT2VDataset(Dataset):
49 |     def __init__(self, image_folder: str, data_path: str, image_size: int, sequence_length: int = 17, fps: int = 16, sampling="center"):
50 | super(EvalT2VDataset, self).__init__()
51 |
52 | self.image_folder = image_folder
53 | list_data_dict = json.load(open(data_path, "r"))
54 | self.list_data_dict = list_data_dict
55 |
56 | self.image_size = image_size
57 | self.sequence_length = sequence_length
58 | self.fps = fps
59 | self.sampling = sampling
60 |
61 | def __len__(self):
62 | return len(self.list_data_dict)
63 |
64 |
65 | def __getitem__(self, i):
66 | item = self.list_data_dict[i]
67 |         # os.path.join with a single argument is a no-op; use the stored path directly
68 |         video_path = item["video_path"]
69 |
70 | frames = load_video_from_path_decord(
71 | video_path,
72 | self.sampling,
73 | fps=self.fps,
74 | num_frm=self.sequence_length,
75 | crop_type="resize",
76 | crop_size=self.image_size
77 | )
78 |
79 | vid_frm_array = (
80 | torch.from_numpy(frames).float().permute(3, 0, 1, 2)
81 | )
82 | # T H W 3 -> 3 T H W
83 | assert vid_frm_array.shape[1] == self.sequence_length
84 | video = (vid_frm_array / 255.0 * 2.0 - 1.0)
85 | return video, video_path
86 |
87 |
88 |
89 | class PairedVideoDataset(Dataset):
90 | def __init__(self, video_folder, image_size=256):
91 | super(PairedVideoDataset, self).__init__()
92 |
93 | gt_folder = os.path.join(video_folder, "gt")
94 | pd_folder = os.path.join(video_folder, "pd")
95 |
96 | gt_videos = [
97 | os.path.join(gt_folder, v) for v in os.listdir(gt_folder)
98 | ]
99 | pd_videos = [
100 | os.path.join(pd_folder, v) for v in os.listdir(pd_folder)
101 | ]
102 |
103 | self.gt_videos = sorted(gt_videos)
104 | self.pd_videos = sorted(pd_videos)
105 |
106 | self.image_size = image_size
107 |
108 | def __len__(self):
109 | return len(self.gt_videos)
110 |
111 |
112 | def __getitem__(self, i):
113 | gt_video_path = self.gt_videos[i]
114 | pd_video_path = self.pd_videos[i]
115 |
116 | gt_frames = load_video_from_path_decord(
117 | gt_video_path,
118 | "all",
119 | crop_type="resize",
120 | crop_size=self.image_size
121 | )
122 |
123 | gt_video_tensor = (
124 | torch.from_numpy(gt_frames).float().permute(3, 0, 1, 2)
125 | )
126 |
127 | gt_video_tensor = (gt_video_tensor / 255.0 * 2.0 - 1.0)
128 |
129 | pd_frames = load_video_from_path_decord(
130 | pd_video_path,
131 | "all",
132 | crop_type="resize",
133 | crop_size=self.image_size
134 | )
135 |
136 | pd_video_tensor = (
137 | torch.from_numpy(pd_frames).float().permute(3, 0, 1, 2)
138 | )
139 |
140 | pd_video_tensor = (pd_video_tensor / 255.0 * 2.0 - 1.0)
141 |
142 | return gt_video_tensor, pd_video_tensor
143 |
--------------------------------------------------------------------------------
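
A minimal sketch of wiring EvalT2IDataset into a DataLoader. The paths are placeholders in the same style as the eval scripts below, the JSON file is a list of {"image_path": ...} entries as read in __getitem__ above, and the import path follows the repository layout.

from torch.utils.data import DataLoader
from open_tokenizer.utils.dataset import EvalT2IDataset

dataset = EvalT2IDataset(
    image_folder="/path_to_dir/tokenizer_bench",
    data_path="/path_to_dir/tokenizer_bench/laion.json",
    image_size=256,
    scale=True,              # normalise images to [-1, 1]
)
loader = DataLoader(dataset, batch_size=8, num_workers=4, shuffle=False)
images, image_files = next(iter(loader))
print(images.shape)          # (batch, 3, 256, 256)
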
/open_tokenizer/utils/ddp_distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import subprocess
4 |
5 |
6 | def setup_for_distributed(is_master):
7 | """
8 | This function disables printing when not in master process
9 | """
10 | import builtins as __builtin__
11 | builtin_print = __builtin__.print
12 |
13 | def print(*args, **kwargs):
14 | force = kwargs.pop('force', False)
15 | if is_master or force:
16 | builtin_print(*args, **kwargs)
17 |
18 | __builtin__.print = print
19 |
20 | def init_distributed_mode(args):
21 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
22 | args.rank = int(os.environ["RANK"])
23 | args.world_size = int(os.environ['WORLD_SIZE'])
24 | args.gpu = int(os.environ['LOCAL_RANK'])
25 | args.dist_url = 'env://'
26 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
27 | elif 'SLURM_PROCID' in os.environ:
28 | proc_id = int(os.environ['SLURM_PROCID'])
29 | ntasks = int(os.environ['SLURM_NTASKS'])
30 | node_list = os.environ['SLURM_NODELIST']
31 | num_gpus = torch.cuda.device_count()
32 | addr = subprocess.getoutput(
33 | 'scontrol show hostname {} | head -n1'.format(node_list))
34 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
35 | os.environ['MASTER_ADDR'] = addr
36 | os.environ['WORLD_SIZE'] = str(ntasks)
37 | os.environ['RANK'] = str(proc_id)
38 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
39 | os.environ['LOCAL_SIZE'] = str(num_gpus)
40 | args.dist_url = 'env://'
41 | args.world_size = ntasks
42 | args.rank = proc_id
43 | args.gpu = proc_id % num_gpus
44 | else:
45 | print('Not using distributed mode')
46 | args.distributed = False
47 | return
48 |
49 | args.distributed = True
50 |
51 | torch.cuda.set_device(args.gpu)
52 | args.dist_backend = 'nccl'
53 | print('| distributed init (rank {}): {}'.format(
54 | args.rank, args.dist_url), flush=True)
55 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
56 | world_size=args.world_size, rank=args.rank)
57 | torch.distributed.barrier()
58 | setup_for_distributed(args.rank == 0)
--------------------------------------------------------------------------------
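
A minimal sketch of the distributed setup above when launched with torchrun (which exports RANK, WORLD_SIZE and LOCAL_RANK); the bare argparse.Namespace stands in for parsed CLI arguments, and the import path is an assumption.

import argparse
from open_tokenizer.utils.ddp_distributed import init_distributed_mode

# e.g. torchrun --nnodes=1 --nproc_per_node=4 this_script.py
args = argparse.Namespace()
init_distributed_mode(args)
if args.distributed:
    print(f"rank {args.rank}/{args.world_size} on GPU {args.gpu}")
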
/open_tokenizer/utils/fvd.py:
--------------------------------------------------------------------------------
1 | """
2 | Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
3 | """
4 | from typing import Tuple
5 | import scipy.linalg
6 | import numpy as np
7 |
8 |
9 | def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
10 | mu_gen, sigma_gen = compute_stats(feats_fake)
11 | mu_real, sigma_real = compute_stats(feats_real)
12 |
13 | m = np.square(mu_gen - mu_real).sum()
14 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
15 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
16 |
17 | return float(fid)
18 |
19 |
20 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
21 | mu = feats.mean(axis=0) # [d]
22 | sigma = np.cov(feats, rowvar=False) # [d, d]
23 |
24 | return mu, sigma
--------------------------------------------------------------------------------
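
A minimal sketch of compute_fvd above with random stand-in features; in the FVD metric these would typically be I3D clip features of shape (num_videos, feature_dim). The import path follows the repository layout and is an assumption.

import numpy as np
from open_tokenizer.utils.fvd import compute_fvd

feats_real = np.random.randn(128, 400)
feats_fake = feats_real + 0.1 * np.random.randn(128, 400)
print(compute_fvd(feats_fake, feats_real))   # small value, since the statistics nearly match
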
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 240
3 |
4 | [build-system]
5 | requires = ["setuptools>=61.0"]
6 | build-backend = "setuptools.build_meta"
7 |
8 | [project]
9 | name = "open_tokenizer"
10 | version = "1.0"
11 | description = "OpenTokenizer: a comprehensive comparison on open-sourced tokenizers"
12 | readme = "README.md"
13 | requires-python = ">=3.8"
14 | classifiers = [
15 | "Programming Language :: Python :: 3",
16 | "License :: OSI Approved :: Apache Software License",
17 | ]
18 |
19 | [project.optional-dependencies]
20 | standalone = [
21 | "shortuuid",
22 | "httpx==0.24.0",
23 | "einops",
24 | "ftfy",
25 | ]
26 |
27 |
28 | train = [
29 | "open_tokenizer[standalone]",
30 | "numpy==1.26.1",
31 | "open_clip_torch",
32 | "fastapi",
33 | "markdown2[all]",
34 | "numpy",
35 | "requests",
36 | "sentencepiece",
37 | "torch==2.1.2",
38 | "torchvision==0.16.2",
39 | "uvicorn",
40 | "wandb",
41 | "deepspeed==0.14.4",
42 | "peft==0.4.0",
43 | "accelerate>=0.29.1",
44 | "tokenizers~=0.15.2",
45 | "transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4",
46 | "bitsandbytes==0.41.0",
47 | "scikit-learn==1.2.2",
48 | "sentencepiece~=0.1.99",
49 | "einops==0.6.1",
50 | "einops-exts==0.0.4",
51 | "gradio_client==0.2.9",
52 | "urllib3<=2.0.0",
53 | "datasets==2.16.1",
54 | "pydantic==1.10.8",
55 | "timm",
56 | "hf_transfer",
57 | "opencv-python",
58 | "av",
59 | "decord",
60 | "tyro",
61 | "scipy",
62 | ]
63 |
64 | [project.urls]
65 |
66 | [tool.setuptools.packages.find]
67 | include = ["open_tokenizer*", "trl*"]
68 | exclude = [
69 | "assets*",
70 | "benchmark*",
71 | "docs",
72 | "dist*",
73 | "playground*",
74 | "scripts*",
75 | "tests*",
76 | "checkpoints*",
77 | "project_checkpoints*",
78 | "debug_checkpoints*",
79 | "mlx_configs*",
80 | "wandb*",
81 | "notebooks*",
82 | ]
83 |
84 | [tool.wheel]
85 | exclude = [
86 | "assets*",
87 | "benchmark*",
88 | "docs",
89 | "dist*",
90 | "playground*",
91 | "scripts*",
92 | "tests*",
93 | "checkpoints*",
94 | "project_checkpoints*",
95 | "debug_checkpoints*",
96 | "mlx_configs*",
97 | "wandb*",
98 | "notebooks*",
99 | ]
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.21.0
2 | aiohttp==3.9.5
3 | aiosignal==1.3.1
4 | albumentations==0.3.2
5 | annotated-types==0.7.0
6 | antlr4-python3-runtime==4.9.3
7 | anykeystore==0.2
8 | asn1crypto==1.5.1
9 | asttokens==2.4.1
10 | async-timeout==4.0.3
11 | attrs==21.2.0
12 | bidict==0.23.1
13 | blessed==1.20.0
14 | boto3==1.34.113
15 | botocore==1.34.113
16 | braceexpand==0.1.7
17 | cachetools==5.3.3
18 | certifi==2024.2.2
19 | cffi==1.16.0
20 | chardet==5.2.0
21 | charset-normalizer==3.3.2
22 | click==8.1.7
23 | clip==0.2.0
24 | clip-openai==1.0.post20230121
25 | cmake==3.29.3
26 | cramjam==2.8.3
27 | crcmod==1.7
28 | cryptacular==1.6.2
29 | cryptography==39.0.2
30 | cycler==0.12.1
31 | datasets
32 | diffusers==0.30.1
33 | decorator==5.1.1
34 | decord==0.6.0
35 | deepspeed==0.14.2
36 | defusedxml==0.7.1
37 | Deprecated==1.2.14
38 | descartes==1.1.0
39 | dill==0.3.8
40 | distlib==0.3.8
41 | distro-info==1.0
42 | dnspython==2.6.1
43 | docker-pycreds==0.4.0
44 | docstring_parser==0.16
45 | ecdsa==0.19.0
46 | einops==0.6.0
47 | exceptiongroup==1.2.1
48 | executing==2.0.1
49 | fairscale==0.4.13
50 | fastparquet==2024.5.0
51 | ffmpegcv==0.3.13
52 | filelock==3.14.0
53 | fire==0.6.0
54 | fonttools==4.51.0
55 | frozenlist==1.4.1
56 | fsspec==2023.6.0
57 | ftfy==6.2.0
58 | gitdb==4.0.11
59 | GitPython==3.1.43
60 | gpustat==1.1.1
61 | greenlet==3.0.3
62 | grpcio==1.64.0
63 | h11==0.14.0
64 | hjson==3.1.0
65 | huggingface-hub==0.23.2
66 | hupper==1.12.1
67 | idna==3.7
68 | imageio==2.34.1
69 | imgaug==0.2.6
70 | iniconfig==2.0.0
71 | ipaddress==1.0.23
72 | ipdb==0.13.13
73 | ipython==8.18.1
74 | jaxtyping==0.2.28
75 | jedi==0.19.1
76 | Jinja2==3.1.4
77 | jmespath==1.0.1
78 | joblib==1.4.2
79 | jsonargparse==4.14.1
80 | jsonlines==4.0.0
81 | kiwisolver==1.4.5
82 | kornia==0.7.2
83 | kornia_rs==0.1.3
84 | lazy_loader==0.4
85 | lightning==2.2.3
86 | lightning-utilities==0.11.2
87 | lit==18.1.6
88 | MarkupSafe==2.1.5
89 | matplotlib==3.5.3
90 | matplotlib-inline==0.1.7
91 | miscreant==0.3.0
92 | mpmath==1.3.0
93 | msgpack==1.0.8
94 | multidict==6.0.5
95 | multiprocess==0.70.16
96 | natsort==8.4.0
97 | networkx==3.2.1
98 | ninja==1.11.1.1
99 | numpy==1.24.4
100 | nuscenes-devkit==1.1.11
101 | oauthlib==3.2.2
102 | omegaconf==2.3.0
103 | open-clip-torch==2.24.0
104 | openai-clip
105 | opencv-python==4.9.0.80
106 | opencv-python-headless==3.4.18.65
107 | packaging==22.0
108 | pandas==1.5.3
109 | parquet==1.3.1
110 | parso==0.8.4
111 | PasteDeploy==3.1.0
112 | pathlib2==2.3.7.post1
113 | pathtools==0.1.2
114 | pbkdf2==1.3
115 | pexpect==4.9.0
116 | pillow==10.3.0
117 | plaster==1.1.2
118 | plaster-pastedeploy==1.0.1
119 | platformdirs==4.2.2
120 | plotly==5.22.0
121 | pluggy==1.5.0
122 | ply==3.11
123 | promise==2.3
124 | prompt-toolkit==3.0.43
125 | protobuf==3.20.3
126 | psutil==5.9.8
127 | ptyprocess==0.7.0
128 | pure-eval==0.2.2
129 | py==1.11.0
130 | py-cpuinfo==9.0.0
131 | py-spy==0.3.14
132 | pyarrow==11.0.0
133 | pyarrow-hotfix==0.6
134 | pyasn1==0.6.0
135 | pycocotools==2.0.7
136 | pycparser==2.22
137 | pycryptodomex==3.20.0
138 | pycurl==7.43.0.6
139 | pydantic==1.10.15
140 | pydantic_core==2.18.3
141 | Pygments==2.18.0
142 | PyJWT==2.8.0
143 | pynvml==11.5.0
144 | pyope==0.2.2
145 | pyOpenSSL==23.2.0
146 | pyparsing==3.1.2
147 | pyquaternion==0.9.9
148 | pyramid==2.0.2
149 | pyramid-mailer==0.15.1
150 | pytest==6.2.5
151 | python-consul==1.1.0
152 | python-dateutil==2.9.0.post0
153 | python-engineio==4.9.1
154 | python-etcd==0.4.5
155 | python-jose==3.3.0
156 | python-socketio==5.11.2
157 | python3-openid==3.2.0
158 | pytorch-extension==0.2
159 | pytorch-lightning==2.2.3
160 | pytz==2024.1
161 | PyYAML==6.0.1
162 | regex==2024.5.15
163 | repoze.sendmail==4.4.1
164 | requests==2.31.0
165 | requests-oauthlib==2.0.0
166 | rsa==4.9
167 | s3transfer==0.10.1
168 | safetensors==0.4.3
169 | schedule==1.2.2
170 | scikit-image==0.22.0
171 | scikit-learn==1.5.0
172 | scipy==1.13.1
173 | sentencepiece==0.2.0
174 | sentry-sdk==2.3.1
175 | setproctitle==1.3.3
176 | Shapely==1.8.5.post1
177 | shortuuid==1.0.13
178 | simple-websocket==1.0.0
179 | six==1.16.0
180 | smmap==5.0.1
181 | SQLAlchemy==2.0.30
182 | stack-data==0.6.3
183 | sympy==1.12
184 | taming-transformers-rom1504==0.0.6
185 | tenacity==8.3.0
186 | tensorboardX==2.6.2.2
187 | termcolor==2.4.0
188 | threadpoolctl==3.5.0
189 | thriftpy2==0.5.0
190 | tifffile==2024.5.22
191 | timm==1.0.3
192 | tokenizers==0.19.1
193 | toml==0.10.2
194 | tomli==2.0.1
195 | torch==2.2.1
196 | torch-fidelity==0.3.0
197 | torchmetrics==1.4.0.post0
198 | torchvision==0.17.1
199 | tox==3.28.0
200 | tqdm==4.66.4
201 | traitlets==5.14.3
202 | transaction==4.0
203 | transformers==4.41.1
204 | translationstring==1.4
205 | triton==2.2.0
206 | typeguard==2.13.3
207 | typing_extensions==4.12.0
208 | tzdata==2024.1
209 | urllib3==1.26.18
210 | velruse==1.1.1
211 | venusian==3.1.0
212 | virtualenv==20.26.2
213 | wandb==0.17.0
214 | watchdog==4.0.1
215 | wcwidth==0.2.13
216 | webdataset==0.2.86
217 | WebOb==1.8.7
218 | websocket-client==1.8.0
219 | wrapt==1.16.0
220 | wsproto==1.2.0
221 | WTForms==3.1.2
222 | wtforms-recaptcha==0.3.2
223 | xformers==0.0.25
224 | xxhash==3.4.1
225 | yarl==1.9.4
226 | zope.deprecation==5.0
227 | zope.interface==6.4.post2
228 | zope.sqlalchemy==3.1
--------------------------------------------------------------------------------
/scripts/eval/image/cosmos.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/Cosmos-0.1-Tokenizer-DI8x8/" \
6 | --vq_model_type "cosmos-i" \
7 | --dataset_type "image" \
8 | --dataset_name "laion" \
9 | --save_dir "./tokenizers" \
10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \
11 | --image_folder "" \
12 | --resolution 256 \
13 | --video_fps 16 \
14 | --sequence_length 33
15 |
--------------------------------------------------------------------------------
/scripts/eval/image/emu3.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/Emu3-VisionTokenizer" \
6 | --vq_model_type "emu3" \
7 | --dataset_type "image" \
8 | --dataset_name "laion" \
9 | --save_dir "./tokenizers" \
10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \
11 | --image_folder "" \
12 | --resolution 256 \
13 | --video_fps 16 \
14 | --sequence_length 33
15 |
--------------------------------------------------------------------------------
/scripts/eval/image/llamagen.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/vq_ds16_c2i.pt" \
6 | --vq_model_type "llamagen" \
7 | --dataset_type "image" \
8 | --dataset_name "laion" \
9 | --save_dir "./tokenizers" \
10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \
11 | --image_folder "" \
12 | --resolution 256 \
13 | --video_fps 16 \
14 | --sequence_length 33
15 |
--------------------------------------------------------------------------------
/scripts/eval/image/omnitokenizer.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=1 --master_port 23456 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/omnitokenizer_rq_down16_code16384_joint.ckpt" \
6 | --dataset_type "image" \
7 | --dataset_name "laion" \
8 | --save_dir "./visualizations" \
9 |     --image_path "/path_to_dir/tokenizer_bench/laion_pub.json" \
10 |     --image_folder "/path_to_dir/tokenizer_bench" \
11 | --resolution 256 \
12 | --video_fps 16 \
13 | --sequence_length 33
14 |
--------------------------------------------------------------------------------
/scripts/eval/image/showo.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/magvitv2" \
6 | --vq_model_type "show-o" \
7 | --dataset_type "image" \
8 | --dataset_name "laion" \
9 | --save_dir "./tokenizers" \
10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \
11 | --image_folder "" \
12 | --resolution 256 \
13 | --video_fps 16 \
14 | --sequence_length 33
15 |
--------------------------------------------------------------------------------
/scripts/eval/image/tiktok.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/tokenizer_titok_sl256_vq8k_imagenet" \
6 | --vq_model_type "tiktok" \
7 | --dataset_type "image" \
8 | --dataset_name "laion" \
9 | --save_dir "./tokenizers" \
10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \
11 | --image_folder "" \
12 | --resolution 256 \
13 | --video_fps 16 \
14 | --sequence_length 33
15 |
--------------------------------------------------------------------------------
/scripts/eval/video/cosmos.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=4 --master_port 23457 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/Cosmos-0.1-Tokenizer-DV4x8x8/" \
6 | --vq_model_type "cosmos-v" \
7 | --dataset_type "video" \
8 | --dataset_name "openvid" \
9 | --save_dir "./tokenizers" \
10 | --video_path "/path_to_dir/tokenizer_bench/openvid.json" \
11 | --video_folder "" \
12 | --resolution 256 \
13 | --video_fps 16 \
14 | --sequence_length 16
15 |
--------------------------------------------------------------------------------
/scripts/eval/video/emu3.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/Emu3-VisionTokenizer" \
6 | --vq_model_type "emu3" \
7 | --dataset_type "video" \
8 | --dataset_name "openvid" \
9 | --save_dir "./tokenizers" \
10 | --video_path "/path_to_dir/tokenizer_bench/openvid.json" \
11 | --video_folder "" \
12 | --resolution 256 \
13 | --video_fps 16 \
14 | --sequence_length 16 \
15 | --batch_size 2
--------------------------------------------------------------------------------
/scripts/eval/video/omnitokenizer.sh:
--------------------------------------------------------------------------------
1 |
2 | torchrun \
3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \
4 | eval/reconstruct.py \
5 | --vq_model_ckpt "./checkpoints/omnitokenizer_rq_code16384_down16_joint2.ckpt" \
6 | --dataset_type "video" \
7 | --dataset_name "openvid" \
8 | --save_dir "./tokenizers" \
9 | --video_path "/path_to_dir/tokenizer_bench/openvid.json" \
10 | --video_folder "" \
11 | --resolution 256 \
12 | --video_fps 16 \
13 | --sequence_length 17
14 |
--------------------------------------------------------------------------------
/src/1-cos.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/1-cos.gif
--------------------------------------------------------------------------------
/src/1-emu3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/1-emu3.gif
--------------------------------------------------------------------------------
/src/1-omni.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/1-omni.gif
--------------------------------------------------------------------------------
/src/2-cos.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/2-cos.gif
--------------------------------------------------------------------------------
/src/2-emu3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/2-emu3.gif
--------------------------------------------------------------------------------
/src/2-omni.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/2-omni.gif
--------------------------------------------------------------------------------
/src/vis-img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/vis-img.png
--------------------------------------------------------------------------------