├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── eval ├── calculate_image.py ├── calculate_video.py ├── extract_token.py └── reconstruct.py ├── open_tokenizer ├── __init__.py ├── model │ ├── common_modules.py │ ├── cosmos_tokenizer │ │ ├── __init__.py │ │ ├── image_cli.py │ │ ├── image_lib.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── distributions.py │ │ │ ├── layers2d.py │ │ │ ├── layers3d.py │ │ │ ├── patching.py │ │ │ ├── quantizers.py │ │ │ └── utils.py │ │ ├── networks │ │ │ ├── __init__.py │ │ │ ├── configs.py │ │ │ ├── continuous_image.py │ │ │ ├── continuous_video.py │ │ │ ├── discrete_image.py │ │ │ └── discrete_video.py │ │ ├── utils.py │ │ ├── video_cli.py │ │ └── video_lib.py │ ├── llamagen_tokenizer.py │ ├── misc.py │ ├── modeling_utils.py │ ├── omnitokenizer.py │ ├── omnitokenizer_modules.py │ ├── showo.py │ ├── tiktok │ │ ├── __init__.py │ │ ├── maskgit.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── blocks.py │ │ │ ├── discriminator.py │ │ │ ├── ema_model.py │ │ │ ├── losses.py │ │ │ ├── maskgit_vqgan.py │ │ │ └── perceptual_loss.py │ │ ├── quantizer │ │ │ ├── __init__.py │ │ │ └── quantizer.py │ │ ├── rar.py │ │ └── titok.py │ └── vidtok │ │ ├── configs │ │ ├── vidtok_fsq_causal_41616_262144.yaml │ │ ├── vidtok_fsq_causal_488_262144.yaml │ │ ├── vidtok_fsq_causal_488_32768.yaml │ │ ├── vidtok_fsq_causal_488_4096.yaml │ │ ├── vidtok_fsq_noncausal_41616_262144.yaml │ │ ├── vidtok_fsq_noncausal_488_262144.yaml │ │ ├── vidtok_kl_causal_288_8chn.yaml │ │ ├── vidtok_kl_causal_41616_4chn.yaml │ │ ├── vidtok_kl_causal_444_4chn.yaml │ │ ├── vidtok_kl_causal_488_16chn.yaml │ │ ├── vidtok_kl_causal_488_4chn.yaml │ │ ├── vidtok_kl_causal_488_8chn.yaml │ │ ├── vidtok_kl_noncausal_41616_4chn.yaml │ │ ├── vidtok_kl_noncausal_488_4chn.yaml │ │ └── vidtwin │ │ │ └── vidtwin_structure_7_7_8_dynamics_7_8.yaml │ │ ├── data │ │ ├── datamodule.py │ │ ├── video_read.py │ │ └── vidtok.py │ │ ├── models │ │ └── autoencoder.py │ │ └── modules │ │ ├── discriminator.py │ │ ├── distributions.py │ │ ├── ema.py │ │ ├── logger.py │ │ ├── losses.py │ │ ├── lpips.py │ │ ├── model_3dcausal.py │ │ ├── model_3dnoncausal.py │ │ ├── regularizers.py │ │ └── util.py └── utils │ ├── data_utils.py │ ├── dataset.py │ ├── ddp_distributed.py │ ├── fvd.py │ └── metrics.py ├── pyproject.toml ├── requirements.txt ├── scripts └── eval │ ├── image │ ├── cosmos.sh │ ├── emu3.sh │ ├── llamagen.sh │ ├── omnitokenizer.sh │ ├── showo.sh │ └── tiktok.sh │ └── video │ ├── cosmos.sh │ ├── emu3.sh │ └── omnitokenizer.sh └── src ├── 1-cos.gif ├── 1-emu3.gif ├── 1-omni.gif ├── 2-cos.gif ├── 2-emu3.gif ├── 2-omni.gif └── vis-img.png /.gitattributes: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | # Set the default behavior, in case people don't have core.autocrlf set. 
3 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 4 | * text=auto 5 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 6 | # Source files 7 | # ============ 8 | *.pxd text diff=python 9 | *.py text diff=python 10 | *.py3 text diff=python 11 | *.pyw text diff=python 12 | *.pyx text diff=python 13 | *.pyz text diff=python 14 | *.pyi text diff=python 15 | # Binary files 16 | # ============ 17 | *.db binary 18 | *.p binary 19 | *.pkl binary 20 | *.pickle binary 21 | *.pyc binary export-ignore 22 | *.pyo binary export-ignore 23 | *.pyd binary 24 | # Jupyter notebook 25 | *.ipynb text eol=lf 26 | *.large-file-extension filter=lfs diff=lfs merge=lfs -text 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | # *.json 11 | # *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | # Editor 16 | .idea 17 | *.swp 18 | .vscode 19 | 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | llavavid 25 | 26 | checkpoints/ 27 | visualize/ 28 | project_checkpoints 29 | debug_checkpoints 30 | playground/data 31 | playground/cc3m_llava34b_cap 32 | ckpts* 33 | classify_image_graph_def.pb 34 | visualizations 35 | 36 | .ipynb_checkpoints 37 | chunyl_scripts 38 | *.ipynb 39 | 40 | # DevContainer 41 | !.devcontainer/* 42 | 43 | # Demo 44 | serve_images/ 45 | notebooks/ 46 | logs 47 | scripts/dist_* 48 | logs/ 49 | submissions/ 50 | cn_scripts/ 51 | internal_project_checkpoints/ 52 | work_dirs 53 | scripts/i18n/* 54 | playground/.nfs028b000000010add00000001 55 | HIP 56 | playground/.nfs028b0000017bff2c00000012 57 | scripts/qwen 58 | scripts/vicuna 59 | scripts/mistral 60 | scripts/baseline_rep 61 | scripts/cn_boli01_hl 62 | scripts/cn_boli01_lf 63 | scripts/cn_lf 64 | scripts/cn_lq 65 | scripts/cn_yg 66 | scripts/cn_yg_hao 67 | scripts/eva_encoder 68 | scripts/i18n 69 | scripts/i18n_higher_res 70 | scripts/multi-images 71 | scratchpad 72 | build/ 73 | playground/*.json 74 | mlx_configs/ 75 | data_processing/ 76 | # demo/ 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenTokenizer: A Comprehensive Comparison of Open-sourced Visual Tokenizers 2 | 3 | This repo aims to provide a comprehensive comparison of open-sourced discrete visual tokenizers under a fair setting: same resolution, same data, same evaluation code. 4 | 5 | ## Setup 6 | 7 | We use Python 3.10 and PyTorch 2.1.2, with CUDA Version 12.2. 8 | Please install the required packages using the following command: 9 | 10 | ``` 11 | pip install -e . 12 | ``` 13 | 14 | ## Benchmark 15 | 16 | We use a subset of [laion-aesthetics-12m-umap](https://huggingface.co/datasets/dclure/laion-aesthetics-12m-umap) and [OpenVid-1M](https://huggingface.co/datasets/nkp37/OpenVid-1M) as the test set to benchmark all the available tokenizers. You can download the data directly from [here](https://huggingface.co/datasets/Daniel0724/VTokenizer_Bench). 
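One convenient way to fetch this benchmark subset is via `huggingface_hub` (a sketch; the `local_dir` below is an arbitrary choice, not a path required by this repo):

```
# Sketch: download the VTokenizer_Bench dataset snapshot from the Hugging Face Hub.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Daniel0724/VTokenizer_Bench",
    repo_type="dataset",
    local_dir="./tokenizer_bench",  # arbitrary local directory
)
```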
17 | 18 | The comparison of image and video reconstruction performance is shown below: 19 | 20 | | Name | Compression | PSNR | SSIM | rFID | FPS2 | 21 | | ---------- | ---------- | ---------- | ---------- | ----------- | ----------- | 22 | | [OmniTokenizer](https://huggingface.co/datasets/Daniel0724/VTokenizer_Bench/resolve/main/omnitokenizer_rq_down16_code16384_joint.ckpt)1 | 8x8 | 32.16 | 0.86 | 2.02 | 121.24 | 23 | | [Cosmos-DI](https://huggingface.co/nvidia/Cosmos-0.1-Tokenizer-DI8x8) | 8x8 | 31.04 | 0.83 | 0.95 | 612.91 | 24 | | [Emu3](https://huggingface.co/BAAI/Emu3-VisionTokenizer) | 8x8 | 31.51 | 0.83 | 0.57 | 12.72 | 25 | | [LlamaGen](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt)| 16x16 | 26.98 | 0.71 | 1.64 | 198.06 | 26 | | [Show-O](https://huggingface.co/showlab/magvitv2) | 16x16 | 27.03 | 0.71 | 1.54 | 129.62 | 27 | | [Tiktok](https://huggingface.co/yucornetto/tokenizer_titok_sl256_vq8k_imagenet) | 1D | 27.48 | 0.72 | 1.30 | 303.06 | 28 | 29 | 30 | | Name | Compression | PSNR | SSIM | rFID | FPS | 31 | | ---------- | ---------- | ---------- | ---------- | ----------- | ----------- | 32 | | [OmniTokenizer](https://huggingface.co/datasets/Daniel0724/VTokenizer_Bench/resolve/main/omnitokenizer_rq_down16_code16384_joint.ckpt) | 2x8x8 | 33.51 | 0.93 | 4.10 | 67.08 | 33 | | [Cosmos-DV](https://huggingface.co/nvidia/Cosmos-0.1-Tokenizer-DV4x8x8) | 4x8x8 | 34.54 | 0.94 | 4.19 | 76.46 | 34 | | [Emu3](https://huggingface.co/BAAI/Emu3-VisionTokenizer) | 4x8x8 | 33.44 | 0.92 | 8.15 | 3.24 | 35 | 36 | 37 | 1: OmniTokenizer is the only tokenizer designed for joint image and video tokenization. Here we release an [any-res checkpoint](https://huggingface.co/datasets/Daniel0724/VTokenizer_Bench/resolve/main/omnitokenizer_rq_down16_code16384_joint.ckpt) of OmniTokenizer, meaning we train the model with 2D Rope on images / videos with different aspect ratios and resolutions. 38 | 39 | 2: we denote the throughput of tokenizers as FPS, images (videos) per second. 40 | 41 | 42 | 43 | ## Usage 44 | 45 | Please run the following code for reconstruction: 46 | 47 | ``` 48 | torchrun \ 49 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \ 50 | eval/reconstruct.py \ 51 | --vq_model_type "omnitokenizer" \ 52 | --vq_model_ckpt "./checkpoints/omnitokenizer_rq_code16384_down16_joint2.ckpt" \ 53 | --dataset_type "video" \ 54 | --dataset_name "openvid" \ 55 | --save_dir "./tokenizers" \ 56 | --video_path "path_to_dir/tokenizer_bench/openvid.json" \ 57 | --video_folder "" \ 58 | --resolution 256 \ 59 | --video_fps 16 \ 60 | --sequence_length 17 61 | ``` 62 | 63 | Please specify the tokenizer type with **--vq_model_type**, the ckpt with **--vq_model_ckpt**. The above command will automatically save the reconstructed results under **--save_dir**. We provide the script to inference different tokenizers under *./scripts/eval*. 64 | 65 | After this, you may run the following command to obtain the reconstruction metrics: 66 | 67 | ``` 68 | # For image reconstruction: 69 | 70 | python eval/calculate_image.py /path_to_gt_dir /path_to_recon_dir 71 | 72 | # For video reconstruction: 73 | 74 | python eval/calculate_video.py /path_to_video_dir 75 | 76 | ``` 77 | 78 | ## Offline Code Extraction 79 | 80 | We also provide a script to facilitate the token extraction using different tokenizers in an offline manner in *./eval/extract_token.py*. 81 | 82 | ## Results 83 | 84 | We show the 256x256 image reconstruction results below: 85 | 86 |

*(figure: image reconstruction comparison; see `src/vis-img.png`)* 87 | 88 | 89 | 90 | From left to right, we compare OmniTokenizer, Cosmos-Tokenizer, and Emu3: 91 | 92 | *(figures: video reconstruction GIFs; see `src/1-omni.gif`, `src/1-cos.gif`, `src/1-emu3.gif`, `src/2-omni.gif`, `src/2-cos.gif`, `src/2-emu3.gif`)* 93 | 94 | 95 | 96 | 97 | 98 | 
99 | 100 | ## To Do List 101 | 102 | - ~~open source the evaluation code~~ 103 | - ~~support the calculation of different metrics~~ 104 | - open source the training code 105 | 106 | ## Citation 107 | 108 | If you consider our work useful, please consider citing our paper using: 109 | ```bib 110 | @inproceedings{wang2024omnitokenizer, 111 | title={OmniTokenizer: A Joint Image-Video Tokenizer for Visual Generation}, 112 | author={Wang, Junke and Jiang, Yi and Yuan, Zehuan and Peng, Binyue and Wu, Zuxuan and Jiang, Yu-Gang}, 113 | booktitle={NeurIPS}, 114 | year={2024} 115 | } 116 | ``` 117 | 118 | ## Acknowledgement 119 | 120 | Thanks [OmniTokenizer](https://arxiv.org/abs/2406.09399), [LlamaGen](https://arxiv.org/abs/2406.06525), [Cosmos-Tokenizer](https://arxiv.org/html/2501.03575v1), [Emu3](https://arxiv.org/abs/2409.18869), [Show-O](https://arxiv.org/abs/2408.12528), and [TikTok](https://arxiv.org/abs/2406.07550) for their great work. We also borrow several functions from [Reducio](https://arxiv.org/abs/2411.13552) for evaluation. 121 | 122 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /eval/calculate_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | from PIL import Image 7 | from cleanfid import fid 8 | 9 | 10 | import torchvision.transforms as transforms 11 | from open_tokenizer.utils.metrics import compute_psnr, compute_ssim 12 | 13 | def load_image_as_tensor(image_path): 14 | """ 15 | Loads an image from the given path, converts it to a tensor, and normalizes it to [0, 1]. 16 | """ 17 | img = Image.open(image_path).convert("RGB") # Convert to RGB to ensure 3 channels 18 | transform = transforms.Compose([ 19 | transforms.ToTensor(), # Converts to tensor and scales to [0, 1] 20 | ]) 21 | return transform(img) 22 | 23 | 24 | def calculate_metrics(source_path, target_path): 25 | """ 26 | Calculates PSNR, SSIM, and FID between images in the source and target directories. 27 | Assumes both directories have corresponding images with identical filenames. 
28 | """ 29 | source_images = sorted([os.path.join(source_path, f) for f in os.listdir(source_path) if f.endswith('jpg')]) 30 | target_images = sorted([os.path.join(target_path, f) for f in os.listdir(target_path) if f.endswith('jpg')]) 31 | 32 | if len(source_images) != len(target_images): 33 | raise ValueError("The number of images in the source and target directories must match.") 34 | 35 | psnr_values = [] 36 | ssim_values = [] 37 | 38 | for i in tqdm(range(len(source_images))): 39 | src, tgt = source_images[i], target_images[i] 40 | src_tensor = load_image_as_tensor(src).unsqueeze(0) 41 | tgt_tensor = load_image_as_tensor(tgt).unsqueeze(0) 42 | 43 | psnr = compute_psnr(src_tensor, tgt_tensor) 44 | ssim = compute_ssim(src_tensor, tgt_tensor) 45 | psnr_values.append(psnr) 46 | ssim_values.append(ssim) 47 | 48 | avg_psnr = sum(psnr_values) / len(psnr_values) 49 | avg_ssim = sum(ssim_values) / len(ssim_values) 50 | 51 | # Calculate FID 52 | score = fid.compute_fid(source_path, target_path) 53 | return avg_psnr.item(), avg_ssim.item(), score 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser(description="Calculate PSNR, SSIM, and FID metrics between source and target image directories.") 58 | parser.add_argument('source_path', type=str, help="Path to the directory containing source images.") 59 | parser.add_argument('target_path', type=str, help="Path to the directory containing target images.") 60 | args = parser.parse_args() 61 | 62 | avg_psnr, avg_ssim, fid = calculate_metrics(args.source_path, args.target_path) 63 | print(f"Average PSNR: {avg_psnr:.4f}") 64 | print(f"Average SSIM: {avg_ssim:.4f}") 65 | print(f"FID: {fid:.4f}") 66 | 67 | with open(os.path.join(os.path.dirname(args.source_path), "metrics.json"), "w") as f: 68 | json.dump( 69 | { 70 | "psnr": avg_psnr, 71 | "ssim": avg_ssim, 72 | "fid": fid 73 | }, 74 | f) 75 | -------------------------------------------------------------------------------- /eval/calculate_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from tqdm import tqdm 5 | from einops import rearrange 6 | 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from open_tokenizer.utils.data_utils import open_url 11 | from open_tokenizer.utils.dataset import PairedVideoDataset 12 | from open_tokenizer.utils.metrics import compute_psnr, compute_ssim 13 | from open_tokenizer.utils.fvd import compute_fvd 14 | 15 | 16 | class VideoMetrics: 17 | def __init__(self, video_folder, image_size, device="cuda"): 18 | self.dataset = PairedVideoDataset(video_folder, image_size) 19 | num_workers = 1 20 | self.dataloader = DataLoader(self.dataset, batch_size=4, num_workers=num_workers, drop_last=False) 21 | 22 | detector_ckpt = './checkpoints/i3d_torchscript.pt' 23 | if not os.path.exists(detector_ckpt): 24 | raise FileNotFoundError(f"Detector checkpoint not found at: {detector_ckpt}") 25 | 26 | self.detector_kwargs = dict(rescale=False, resize=False, return_features=True) 27 | 28 | with open_url(detector_ckpt, verbose=False) as f: 29 | detector = torch.jit.load(f).eval().to(device) 30 | 31 | self.detector = detector 32 | self.device = device 33 | 34 | def calculate_metrics(self, metrics="fvd,psnr,ssim"): 35 | metrics = metrics.split(",") 36 | all_gt_feats = [] 37 | all_pd_feats = [] 38 | all_psnr = [] 39 | all_ssim = [] 40 | 41 | total_samples = 0 42 | for video_gt, video_pd in tqdm(self.dataloader): 43 | videos_fake = 
video_pd.to(self.device)[:, :, :16] 44 | videos_real = video_gt.to(self.device)[:, :, :16] 45 | total_samples += video_gt.shape[0] * 16 46 | 47 | flatten_videos_real = rearrange(videos_real, "b c t h w -> (b t) c h w") 48 | flatten_videos_fake = rearrange(videos_fake, "b c t h w -> (b t) c h w") 49 | 50 | if "fvd" in metrics: 51 | with torch.no_grad(): 52 | feats_fake = self.detector(videos_fake, **self.detector_kwargs) 53 | feats_real = self.detector(videos_real, **self.detector_kwargs) 54 | 55 | all_gt_feats.append(feats_real) 56 | all_pd_feats.append(feats_fake) 57 | 58 | if "psnr" in metrics: 59 | psnr = compute_psnr(flatten_videos_real, flatten_videos_fake) 60 | all_psnr.append(psnr) 61 | 62 | if "ssim" in metrics: 63 | ssim = compute_ssim(flatten_videos_real, flatten_videos_fake) 64 | all_ssim.append(ssim) 65 | 66 | avg_psnr = sum(torch.cat(all_psnr, dim=0)) / total_samples if total_samples > 0 else None 67 | avg_ssim = sum(torch.cat(all_ssim, dim=0)) / total_samples if total_samples > 0 else None 68 | 69 | if all_gt_feats: 70 | all_gt_feats = torch.cat(all_gt_feats, dim=0).cpu().numpy() 71 | all_pd_feats = torch.cat(all_pd_feats, dim=0).cpu().numpy() 72 | avg_fvd = compute_fvd(all_gt_feats, all_pd_feats) 73 | else: 74 | avg_fvd = None 75 | 76 | return avg_psnr, avg_ssim, avg_fvd 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser(description="Calculate PSNR, SSIM, and FVD metrics between source and target videos.") 81 | parser.add_argument("--video_dir", type=str, help="Path to the directory containing videos.") 82 | parser.add_argument("--image_size", type=int, default=224, help="Size to resize videos (default: 224).") 83 | args = parser.parse_args() 84 | 85 | metrics = VideoMetrics(args.video_dir, args.image_size, device="cuda") 86 | avg_psnr, avg_ssim, avg_fvd = metrics.calculate_metrics() 87 | 88 | print(f"Average PSNR: {avg_psnr.item():.4f}" if avg_psnr else "PSNR not computed.") 89 | print(f"Average SSIM: {avg_ssim.item():.4f}" if avg_ssim else "SSIM not computed.") 90 | print(f"FVD: {avg_fvd:.4f}" if avg_fvd else "FVD not computed.") 91 | 92 | results_path = os.path.join(os.path.dirname(args.video_dir), "metrics.json") 93 | with open(results_path, "w") as f: 94 | json.dump( 95 | { 96 | "psnr": avg_psnr.item() if avg_psnr is not None else "N/A", 97 | "ssim": avg_ssim.item() if avg_ssim is not None else "N/A", 98 | "fvd": avg_fvd if avg_fvd is not None else "N/A", 99 | }, 100 | f, 101 | ) 102 | print(f"Metrics saved to {results_path}") -------------------------------------------------------------------------------- /eval/extract_token.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | torch.backends.cudnn.allow_tf32 = True 4 | import torch.distributed as dist 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data.distributed import DistributedSampler 7 | 8 | from PIL import Image 9 | import os 10 | import numpy as np 11 | from tqdm import tqdm 12 | from dataclasses import dataclass, field 13 | from typing import Optional 14 | 15 | import transformers 16 | from transformers import AutoTokenizer 17 | 18 | from open_tokenizer.utils.ddp_distributed import init_distributed_mode 19 | from open_tokenizer.model.omnitokenizer import OmniTokenizer 20 | from open_tokenizer.utils.dataset import EvalT2IDataset, EvalT2VDataset 21 | 22 | @dataclass 23 | class DataArguments: 24 | image_folder: Optional[str] = field(default=None) 25 | image_path: str 
= field(default=None) 26 | resolution: Optional[int] = field(default=None) 27 | 28 | video_folder: Optional[str] = field(default=None) 29 | video_path: str = field(default=None) 30 | sequence_length: int = field(default=17) 31 | video_fps: Optional[int] = field(default=1) 32 | 33 | 34 | vq_model_type: str = field(default="omnitokenizer") 35 | vq_model_ckpt: str = field(default="./omnitokenizer_rq_down16_code16384_joint.ckpt") 36 | save_dir: str = field(default=None) 37 | 38 | dataset_name: str = field(default="laion") 39 | dataset_type: str = field(default="image") 40 | 41 | batch_size: int = field(default=8) 42 | workers: int = field(default=4) 43 | 44 | 45 | def main(args): 46 | os.makedirs(f"{args.code_path}/{args.dataset_name}/{args.gen_resolution}_codes", exist_ok=True) 47 | os.makedirs(f"{args.code_path}/{args.dataset_name}/{args.gen_resolution}_labels", exist_ok=True) 48 | 49 | init_distributed_mode(args) 50 | rank = dist.get_rank() 51 | device = rank % torch.cuda.device_count() 52 | seed = 0 * dist.get_world_size() + rank 53 | torch.manual_seed(seed) 54 | torch.cuda.set_device(device) 55 | 56 | model = OmniTokenizer.load_from_checkpoint(args.vq_model_ckpt, strict=False) 57 | model.eval() 58 | model.requires_grad_(False) 59 | model = model.to(device) 60 | 61 | downsample_factor = model.downsample_factor if model is not None else 16 62 | if args.dataset_type == "image": 63 | dataset = EvalT2IDataset(image_folder=args.image_folder, data_path=args.image_path, image_size=args.resolution) 64 | else: 65 | dataset = EvalT2VDataset(image_folder=args.video_folder, data_path=args.video_path, image_size=args.resolution, sequence_length=args.sequence_length, fps=args.video_fps) 66 | 67 | if dist.get_rank() == 0: 68 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 69 | print(f"Extracting image tokens for {len(dataset)} samples.") 70 | 71 | sampler = DistributedSampler( 72 | dataset, 73 | num_replicas=dist.get_world_size(), 74 | rank=rank, 75 | shuffle=False, 76 | ) 77 | 78 | loader = DataLoader( 79 | dataset, 80 | batch_size=1, # important! 
81 | shuffle=False, 82 | sampler=sampler, 83 | num_workers=0, 84 | pin_memory=False, 85 | drop_last=False 86 | ) 87 | 88 | total = 0 89 | for image, caption in tqdm(loader): 90 | image_tensor = image.to(device) 91 | 92 | if image_tensor.ndim == 5: 93 | image_tensor = image_tensor.permute(0, 2, 1, 3, 4).contiguous() 94 | 95 | with torch.no_grad(): 96 | image_token = model.encode(image_tensor)[1] 97 | 98 | x = image_token.detach().cpu().numpy() 99 | train_steps = rank + total 100 | 101 | np.save(f'{args.code_path}/{args.dataset_name}/{args.gen_resolution}_codes/{train_steps}.npy', x) 102 | with open(f'{args.code_path}/{args.dataset_name}/{args.gen_resolution}_labels/{train_steps}.npy', "w") as f: 103 | f.write(caption[0]) 104 | 105 | total += dist.get_world_size() 106 | 107 | dist.destroy_process_group() 108 | 109 | if __name__ == "__main__": 110 | parser = transformers.HfArgumentParser([DataArguments]) 111 | args = parser.parse_args_into_dataclasses()[0] 112 | args.gen_image_folder = "" 113 | 114 | main(args) -------------------------------------------------------------------------------- /open_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .model.omnitokenizer import OmniTokenizer 2 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/image_cli.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A CLI to run ImageTokenizer on plain images based on torch.jit. 
16 | 17 | Usage: 18 | python3 -m cosmos_tokenizer.image_cli \ 19 | --image_pattern 'path/to/input/folder/*.jpg' \ 20 | --output_dir ./reconstructions \ 21 | --checkpoint_enc ./pretrained_ckpts/CosmosCI_f8x8/encoder.jit \ 22 | --checkpoint_dec ./pretrained_ckpts/CosmosCI_f8x8/decoder.jit 23 | 24 | Optionally, you can run the model in pure PyTorch mode: 25 | python3 -m cosmos_tokenizer.image_cli \ 26 | --image_pattern 'path/to/input/folder/*.jpg' \ 27 | --mode torch \ 28 | --tokenizer_type CI \ 29 | --spatial_compression 8 \ 30 | --checkpoint_enc ./pretrained_ckpts/CosmosCI_f8x8/encoder.jit \ 31 | --checkpoint_dec ./pretrained_ckpts/CosmosCI_f8x8/decoder.jit 32 | """ 33 | 34 | import os 35 | from argparse import ArgumentParser, Namespace 36 | import sys 37 | from typing import Any 38 | 39 | import numpy as np 40 | from loguru import logger as logging 41 | from cosmos_tokenizer.networks import TokenizerConfigs 42 | 43 | from cosmos_tokenizer.image_lib import ImageTokenizer 44 | from cosmos_tokenizer.utils import ( 45 | get_filepaths, 46 | get_output_filepath, 47 | read_image, 48 | resize_image, 49 | write_image, 50 | ) 51 | 52 | 53 | def _parse_args() -> tuple[Namespace, dict[str, Any]]: 54 | parser = ArgumentParser( 55 | description="A CLI for running ImageTokenizer on plain images." 56 | ) 57 | parser.add_argument( 58 | "--image_pattern", 59 | type=str, 60 | default="path/to/images/*.jpg", 61 | help="Glob pattern.", 62 | ) 63 | parser.add_argument( 64 | "--checkpoint", 65 | type=str, 66 | default=None, 67 | help="JIT full Autoencoder model filepath.", 68 | ) 69 | parser.add_argument( 70 | "--checkpoint_enc", 71 | type=str, 72 | default=None, 73 | help="JIT Encoder model filepath.", 74 | ) 75 | parser.add_argument( 76 | "--checkpoint_dec", 77 | type=str, 78 | default=None, 79 | help="JIT Decoder model filepath.", 80 | ) 81 | parser.add_argument( 82 | "--tokenizer_type", 83 | type=str, 84 | choices=["CI", "DI"], 85 | help="Specifies the tokenizer type.", 86 | ) 87 | parser.add_argument( 88 | "--spatial_compression", 89 | type=int, 90 | choices=[8, 16], 91 | default=8, 92 | help="The spatial compression factor.", 93 | ) 94 | parser.add_argument( 95 | "--mode", 96 | type=str, 97 | choices=["torch", "jit"], 98 | default="jit", 99 | help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", 100 | ) 101 | parser.add_argument( 102 | "--short_size", 103 | type=int, 104 | default=None, 105 | help="The size to resample inputs. None, by default.", 106 | ) 107 | parser.add_argument( 108 | "--dtype", 109 | type=str, 110 | default="bfloat16", 111 | help="Sets the precision. Default bfloat16.", 112 | ) 113 | parser.add_argument( 114 | "--device", 115 | type=str, 116 | default="cuda", 117 | help="Device for invoking the model.", 118 | ) 119 | parser.add_argument( 120 | "--output_dir", type=str, default=None, help="Output directory." 
121 | ) 122 | parser.add_argument( 123 | "--save_input", 124 | action="store_true", 125 | help="If on, the input image will be be outputed too.", 126 | ) 127 | args = parser.parse_args() 128 | return args 129 | 130 | 131 | logging.info("Initializes args ...") 132 | args = _parse_args() 133 | if args.mode == "torch" and args.tokenizer_type not in ["CI", "DI"]: 134 | logging.error("'torch' backend requires the tokenizer_type of 'CI' or 'DI'.") 135 | sys.exit(1) 136 | 137 | 138 | def _run_eval() -> None: 139 | """Invokes the evaluation pipeline.""" 140 | 141 | if ( 142 | args.checkpoint_enc is None 143 | and args.checkpoint_dec is None 144 | and args.checkpoint is None 145 | ): 146 | logging.warning( 147 | "Aborting. Both encoder or decoder JIT required. Or provide the full autoencoder JIT model." 148 | ) 149 | return 150 | 151 | if args.mode == "torch": 152 | tokenizer_config = TokenizerConfigs[args.tokenizer_type].value 153 | tokenizer_config.update(dict(spatial_compression=args.spatial_compression)) 154 | else: 155 | tokenizer_config = None 156 | 157 | logging.info( 158 | f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..." 159 | ) 160 | autoencoder = ImageTokenizer( 161 | checkpoint=args.checkpoint, 162 | checkpoint_enc=args.checkpoint_enc, 163 | checkpoint_dec=args.checkpoint_dec, 164 | tokenizer_config=tokenizer_config, 165 | device=args.device, 166 | dtype=args.dtype, 167 | ) 168 | 169 | filepaths = get_filepaths(args.image_pattern) 170 | logging.info(f"Found {len(filepaths)} images from {args.image_pattern}.") 171 | 172 | for filepath in filepaths: 173 | logging.info(f"Reading image {filepath} ...") 174 | image = read_image(filepath) 175 | image = resize_image(image, short_size=args.short_size) 176 | batch_image = np.expand_dims(image, axis=0) 177 | 178 | logging.info("Invoking the autoencoder model in ... ") 179 | output_image = autoencoder(batch_image)[0] 180 | 181 | output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) 182 | logging.info(f"Outputing {output_filepath} ...") 183 | write_image(output_filepath, output_image) 184 | 185 | if args.save_input: 186 | ext = os.path.splitext(output_filepath)[-1] 187 | input_filepath = output_filepath.replace(ext, "_input" + ext) 188 | write_image(input_filepath, image) 189 | 190 | 191 | @logging.catch(reraise=True) 192 | def main() -> None: 193 | _run_eval() 194 | 195 | 196 | if __name__ == "__main__": 197 | main() 198 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/image_lib.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
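# NOTE: the ImageTokenizer below reads tokenizer_config["z_channels"] in __init__,
# so a tokenizer_config dict (e.g. TokenizerConfigs["DI"].value, as built in
# image_cli.py) is required even when only JIT checkpoints are supplied;
# passing tokenizer_config=None raises a TypeError.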
15 | """A library for image tokenizers inference.""" 16 | 17 | import numpy as np 18 | import torch 19 | from typing import Any 20 | 21 | from .utils import ( 22 | load_model, 23 | load_encoder_model, 24 | load_decoder_model, 25 | numpy2tensor, 26 | pad_image_batch, 27 | tensor2numpy, 28 | unpad_image_batch, 29 | ) 30 | 31 | 32 | class ImageTokenizer(torch.nn.Module): 33 | def __init__( 34 | self, 35 | checkpoint: str = None, 36 | checkpoint_enc: str = None, 37 | checkpoint_dec: str = None, 38 | tokenizer_config: dict[str, Any] = None, 39 | device: str = "cuda", 40 | dtype: str = "bfloat16", 41 | ) -> None: 42 | super().__init__() 43 | self._device = device 44 | self._dtype = getattr(torch, dtype) 45 | self._full_model = ( 46 | load_model(checkpoint, tokenizer_config, device).to(self._dtype) 47 | if checkpoint is not None 48 | else None 49 | ) 50 | self._enc_model = ( 51 | load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) 52 | if checkpoint_enc is not None 53 | else None 54 | ) 55 | self._dec_model = ( 56 | load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) 57 | if checkpoint_dec is not None 58 | else None 59 | ) 60 | self.codebook_embed_dim = tokenizer_config["z_channels"] 61 | 62 | @torch.no_grad() 63 | def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: 64 | """Reconstrcuts a batch of image tensors after embedding into a latent. 65 | 66 | Args: 67 | input_tensor: The input image Bx3xHxW layout, range [-1..1]. 68 | Returns: 69 | The reconstructed tensor, layout Bx3xHxW, range [-1..1]. 70 | """ 71 | if self._full_model is not None: 72 | output_tensor = self._full_model(input_tensor) 73 | output_tensor = ( 74 | output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor 75 | ) 76 | else: 77 | output_latent = self.encode(input_tensor)[0] 78 | output_tensor = self.decode(output_latent) 79 | return output_tensor 80 | 81 | @torch.no_grad() 82 | def decode(self, input_latent: torch.Tensor, latent_shape=None): 83 | """Decodes an image from a provided latent embedding. 84 | 85 | Args: 86 | input_latent: The continuous latent Bx16xhxw for CI, 87 | or the discrete indices Bxhxw for DI. 88 | Returns: 89 | The output tensor in Bx3xHxW, range [-1..1]. 90 | """ 91 | output_tensor = self._dec_model(input_latent) 92 | return output_tensor 93 | 94 | @torch.no_grad() 95 | def encode(self, input_tensor: torch.Tensor): 96 | """Encodes an image into a latent embedding or code. 97 | 98 | Args: 99 | input_tensor: The input tensor Bx3xHxW layout, range [-1..1]. 100 | Returns: 101 | For continuous image (CI) tokenizer, the tuple contains: 102 | - The latent embedding, Bx16x(h)x(w), where the compression 103 | rate is (H/h x W/w), and channel dimension of 16. 104 | For discrete image (DI) tokenizer, the tuple contains: 105 | - The indices, Bx(h)x(w), from a codebook of size 64K, which 106 | corresponds to FSQ levels of (8,8,8,5,5,5). 107 | - The discrete code, Bx6x(h)x(w), where the compression rate is 108 | again (H/h x W/w), and channel dimension of 6. 109 | """ 110 | output_latent = self._enc_model(input_tensor) 111 | if isinstance(output_latent, torch.Tensor): 112 | return "pad", output_latent 113 | return "pad", output_latent[0], output_latent[1] 114 | 115 | @torch.no_grad() 116 | def forward(self, image: np.ndarray) -> np.ndarray: 117 | """Reconstructs an image using a pre-trained tokenizer. 118 | 119 | Args: 120 | image: The input image BxHxWxC layout, range [0..255]. 
121 | Returns: 122 | The reconstructed image in range [0..255], layout BxHxWxC. 123 | """ 124 | padded_input_image, crop_region = pad_image_batch(image) 125 | input_tensor = numpy2tensor( 126 | padded_input_image, dtype=self._dtype, device=self._device 127 | ) 128 | output_tensor = self.autoencode(input_tensor) 129 | padded_output_image = tensor2numpy(output_tensor) 130 | return unpad_image_batch(padded_output_image, crop_region) 131 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | from enum import Enum 16 | 17 | from .distributions import ( 18 | GaussianDistribution, 19 | IdentityDistribution, 20 | ) 21 | from .layers2d import Decoder, Encoder 22 | from .layers3d import ( 23 | DecoderBase, 24 | DecoderFactorized, 25 | EncoderBase, 26 | EncoderFactorized, 27 | ) 28 | from .quantizers import ( 29 | FSQuantizer, 30 | LFQuantizer, 31 | ResidualFSQuantizer, 32 | VectorQuantizer, 33 | ) 34 | 35 | 36 | class EncoderType(Enum): 37 | Default = Encoder 38 | 39 | 40 | class DecoderType(Enum): 41 | Default = Decoder 42 | 43 | 44 | class Encoder3DType(Enum): 45 | BASE = EncoderBase 46 | FACTORIZED = EncoderFactorized 47 | 48 | 49 | class Decoder3DType(Enum): 50 | BASE = DecoderBase 51 | FACTORIZED = DecoderFactorized 52 | 53 | 54 | class ContinuousFormulation(Enum): 55 | VAE = GaussianDistribution 56 | AE = IdentityDistribution 57 | 58 | 59 | class DiscreteQuantizer(Enum): 60 | VQ = VectorQuantizer 61 | LFQ = LFQuantizer 62 | FSQ = FSQuantizer 63 | RESFSQ = ResidualFSQuantizer 64 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/modules/distributions.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
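# NOTE: IdentityDistribution is the pass-through used by the plain-AE formulation,
# while GaussianDistribution implements the VAE reparameterization trick:
# z = mean + exp(0.5 * logvar) * eps, with eps ~ N(0, I) and logvar clamped
# to [min_logvar, max_logvar].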
15 | """The distribution modes to use for continuous image tokenizers.""" 16 | 17 | import torch 18 | 19 | 20 | class IdentityDistribution(torch.nn.Module): 21 | def __init__(self): 22 | super().__init__() 23 | 24 | def forward(self, parameters): 25 | return parameters, (torch.tensor([0.0]), torch.tensor([0.0])) 26 | 27 | 28 | class GaussianDistribution(torch.nn.Module): 29 | def __init__(self, min_logvar: float = -30.0, max_logvar: float = 20.0): 30 | super().__init__() 31 | self.min_logvar = min_logvar 32 | self.max_logvar = max_logvar 33 | 34 | def sample(self, mean, logvar): 35 | std = torch.exp(0.5 * logvar) 36 | return mean + std * torch.randn_like(mean) 37 | 38 | def forward(self, parameters): 39 | mean, logvar = torch.chunk(parameters, 2, dim=1) 40 | logvar = torch.clamp(logvar, self.min_logvar, self.max_logvar) 41 | return self.sample(mean, logvar), (mean, logvar) 42 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/modules/utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Shared utilities for the networks module.""" 16 | 17 | from typing import Any 18 | 19 | import torch 20 | from einops import pack, rearrange, unpack 21 | 22 | 23 | def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: 24 | batch_size = x.shape[0] 25 | return rearrange(x, "b c t h w -> (b t) c h w"), batch_size 26 | 27 | 28 | def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor: 29 | return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size) 30 | 31 | 32 | def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]: 33 | batch_size, height = x.shape[0], x.shape[-2] 34 | return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height 35 | 36 | 37 | def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor: 38 | return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height) 39 | 40 | 41 | def cast_tuple(t: Any, length: int = 1) -> Any: 42 | return t if isinstance(t, tuple) else ((t,) * length) 43 | 44 | 45 | def replication_pad(x): 46 | return torch.cat([x[:, :, :1, ...], x], dim=2) 47 | 48 | 49 | def divisible_by(num: int, den: int) -> bool: 50 | return (num % den) == 0 51 | 52 | 53 | def is_odd(n: int) -> bool: 54 | return not divisible_by(n, 2) 55 | 56 | 57 | def nonlinearity(x): 58 | return x * torch.sigmoid(x) 59 | 60 | 61 | def Normalize(in_channels, num_groups=32): 62 | return torch.nn.GroupNorm( 63 | num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True 64 | ) 65 | 66 | 67 | class CausalNormalize(torch.nn.Module): 68 | def __init__(self, in_channels, num_groups=1): 69 | super().__init__() 70 | self.norm = torch.nn.GroupNorm( 71 | num_groups=num_groups, 72 | num_channels=in_channels, 73 | eps=1e-6, 74 | affine=True, 75 | ) 76 | self.num_groups = num_groups 77 | 78 | def forward(self, x): 79 | # if num_groups !=1, we apply a spatio-temporal groupnorm for backward compatibility purpose. 80 | # All new models should use num_groups=1, otherwise causality is not guaranteed. 81 | if self.num_groups == 1: 82 | x, batch_size = time2batch(x) 83 | return batch2time(self.norm(x), batch_size) 84 | return self.norm(x) 85 | 86 | 87 | def exists(v): 88 | return v is not None 89 | 90 | 91 | def default(*args): 92 | for arg in args: 93 | if exists(arg): 94 | return arg 95 | return None 96 | 97 | 98 | def pack_one(t, pattern): 99 | return pack([t], pattern) 100 | 101 | 102 | def unpack_one(t, ps, pattern): 103 | return unpack(t, ps, pattern)[0] 104 | 105 | 106 | def round_ste(z: torch.Tensor) -> torch.Tensor: 107 | """Round with straight through gradients.""" 108 | zhat = z.round() 109 | return z + (zhat - z).detach() 110 | 111 | 112 | def log(t, eps=1e-5): 113 | return t.clamp(min=eps).log() 114 | 115 | 116 | def entropy(prob): 117 | return (-prob * log(prob)).sum(dim=-1) 118 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from enum import Enum 17 | 18 | from .configs import ( 19 | continuous_image as continuous_image_dict, 20 | ) 21 | from .configs import ( 22 | discrete_image as discrete_image_dict, 23 | ) 24 | from .configs import ( 25 | continuous_video as continuous_video_dict, 26 | ) 27 | from .configs import ( 28 | discrete_video as discrete_video_dict, 29 | ) 30 | 31 | from .continuous_image import ContinuousImageTokenizer 32 | from .discrete_image import DiscreteImageTokenizer 33 | from .continuous_video import ( 34 | CausalContinuousVideoTokenizer, 35 | ) 36 | from .discrete_video import ( 37 | CausalDiscreteVideoTokenizer, 38 | ) 39 | 40 | 41 | class TokenizerConfigs(Enum): 42 | CI = continuous_image_dict 43 | DI = discrete_image_dict 44 | CV = continuous_video_dict 45 | DV = discrete_video_dict 46 | 47 | 48 | class TokenizerModels(Enum): 49 | CI = ContinuousImageTokenizer 50 | DI = DiscreteImageTokenizer 51 | CV = CausalContinuousVideoTokenizer 52 | DV = CausalDiscreteVideoTokenizer 53 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/networks/configs.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The default image and video tokenizer configs.""" 16 | 17 | from ..modules import ( 18 | ContinuousFormulation, 19 | DiscreteQuantizer, 20 | EncoderType, 21 | DecoderType, 22 | Encoder3DType, 23 | Decoder3DType, 24 | ) 25 | 26 | continuous_image = dict( 27 | # The attention resolution for res blocks. 28 | attn_resolutions=[32], 29 | # The base number of channels. 30 | channels=128, 31 | # The channel multipler for each resolution. 32 | channels_mult=[2, 4, 4], 33 | dropout=0.0, 34 | in_channels=3, 35 | # The spatial compression ratio. 36 | spatial_compression=16, 37 | # The number of layers in each res block. 38 | num_res_blocks=2, 39 | out_channels=3, 40 | resolution=1024, 41 | patch_size=4, 42 | patch_method="haar", 43 | # The output latent dimension (channels). 44 | latent_channels=16, 45 | # The encoder output channels just before sampling. 46 | # Which is also the decoder's input channels. 47 | z_channels=16, 48 | # A factor over the z_channels, to get the total channels the encoder should output. 49 | # For a VAE for instance, we want to output the mean and variance, so we need 2 * z_channels. 
50 | z_factor=1, 51 | name="CI", 52 | # What formulation to use, either "AE" or "VAE". 53 | # Chose VAE here, since the pre-trained ckpt were of a VAE formulation. 54 | formulation=ContinuousFormulation.AE.name, 55 | # Specify type of encoder ["Default", "LiteVAE"] 56 | encoder=EncoderType.Default.name, 57 | # Specify type of decoder ["Default"] 58 | decoder=DecoderType.Default.name, 59 | ) 60 | 61 | discrete_image = dict( 62 | # The attention resolution for res blocks. 63 | attn_resolutions=[32], 64 | # The base number of channels. 65 | channels=128, 66 | # The channel multipler for each resolution. 67 | channels_mult=[2, 4, 4], 68 | dropout=0.0, 69 | in_channels=3, 70 | # The spatial compression ratio. 71 | spatial_compression=16, 72 | # The number of layers in each res block. 73 | num_res_blocks=2, 74 | out_channels=3, 75 | resolution=1024, 76 | patch_size=4, 77 | patch_method="haar", 78 | # The encoder output channels just before sampling. 79 | z_channels=256, 80 | # A factor over the z_channels, to get the total channels the encoder should output. 81 | # for discrete tokenization, often we directly use the vector, so z_factor=1. 82 | z_factor=1, 83 | # The quantizer of choice, VQ, LFQ, FSQ, or ResFSQ. 84 | quantizer=DiscreteQuantizer.FSQ.name, 85 | # The embedding dimension post-quantization, which is also the input channels of the decoder. 86 | # Which is also the output 87 | embedding_dim=6, 88 | # The number of levels to use for fine-scalar quantization. 89 | levels=[8, 8, 8, 5, 5, 5], 90 | # The number of quantizers to use for residual fine-scalar quantization. 91 | num_quantizers=4, 92 | name="DI", 93 | # Specify type of encoder ["Default", "LiteVAE"] 94 | encoder=EncoderType.Default.name, 95 | # Specify type of decoder ["Default"] 96 | decoder=DecoderType.Default.name, 97 | ) 98 | 99 | continuous_video = dict( 100 | attn_resolutions=[32], 101 | channels=128, 102 | channels_mult=[2, 4, 4], 103 | dropout=0.0, 104 | in_channels=3, 105 | num_res_blocks=2, 106 | out_channels=3, 107 | resolution=1024, 108 | patch_size=4, 109 | patch_method="haar", 110 | latent_channels=16, 111 | z_channels=16, 112 | z_factor=1, 113 | num_groups=1, 114 | legacy_mode=False, 115 | spatial_compression=8, 116 | temporal_compression=8, 117 | formulation=ContinuousFormulation.AE.name, 118 | encoder=Encoder3DType.FACTORIZED.name, 119 | decoder=Decoder3DType.FACTORIZED.name, 120 | name="CV", 121 | ) 122 | 123 | discrete_video = dict( 124 | attn_resolutions=[32], 125 | channels=128, 126 | channels_mult=[2, 4, 4], 127 | dropout=0.0, 128 | in_channels=3, 129 | num_res_blocks=2, 130 | out_channels=3, 131 | resolution=1024, 132 | patch_size=4, 133 | patch_method="haar", 134 | z_channels=16, 135 | z_factor=1, 136 | num_groups=1, 137 | legacy_mode=False, 138 | spatial_compression=16, 139 | temporal_compression=8, 140 | quantizer=DiscreteQuantizer.FSQ.name, 141 | embedding_dim=6, 142 | levels=[8, 8, 8, 5, 5, 5], 143 | encoder=Encoder3DType.FACTORIZED.name, 144 | decoder=Decoder3DType.FACTORIZED.name, 145 | name="DV", 146 | ) 147 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/networks/continuous_image.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The continuous image tokenizer with VAE or AE formulation for 2D data.""" 16 | 17 | from collections import OrderedDict, namedtuple 18 | 19 | import torch 20 | from loguru import logger as logging 21 | from torch import nn 22 | 23 | from ..modules import ( 24 | ContinuousFormulation, 25 | DecoderType, 26 | EncoderType, 27 | ) 28 | 29 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) 30 | 31 | 32 | class ContinuousImageTokenizer(nn.Module): 33 | def __init__( 34 | self, z_channels: int, z_factor: int, latent_channels: int, **kwargs 35 | ) -> None: 36 | super().__init__() 37 | self.name = kwargs.get("name", "ContinuousImageTokenizer") 38 | self.latent_channels = latent_channels 39 | 40 | encoder_name = kwargs.get("encoder", EncoderType.Default.name) 41 | self.encoder = EncoderType[encoder_name].value( 42 | z_channels=z_factor * z_channels, **kwargs 43 | ) 44 | 45 | decoder_name = kwargs.get("decoder", DecoderType.Default.name) 46 | self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) 47 | 48 | self.quant_conv = torch.nn.Conv2d( 49 | z_factor * z_channels, z_factor * latent_channels, 1 50 | ) 51 | self.post_quant_conv = torch.nn.Conv2d(latent_channels, z_channels, 1) 52 | 53 | formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) 54 | self.distribution = ContinuousFormulation[formulation_name].value() 55 | logging.info( 56 | f"{self.name} based on {formulation_name} formulation, with {kwargs}." 57 | ) 58 | 59 | num_parameters = sum(param.numel() for param in self.parameters()) 60 | logging.info(f"model={self.name}, num_parameters={num_parameters:,}") 61 | logging.info( 62 | f"z_channels={z_channels}, latent_channels={self.latent_channels}." 
63 | ) 64 | 65 | def encoder_jit(self): 66 | return nn.Sequential( 67 | OrderedDict( 68 | [ 69 | ("encoder", self.encoder), 70 | ("quant_conv", self.quant_conv), 71 | ("distribution", self.distribution), 72 | ] 73 | ) 74 | ) 75 | 76 | def decoder_jit(self): 77 | return nn.Sequential( 78 | OrderedDict( 79 | [ 80 | ("post_quant_conv", self.post_quant_conv), 81 | ("decoder", self.decoder), 82 | ] 83 | ) 84 | ) 85 | 86 | def last_decoder_layer(self): 87 | return self.decoder.conv_out 88 | 89 | def encode(self, x): 90 | h = self.encoder(x) 91 | moments = self.quant_conv(h) 92 | return self.distribution(moments) 93 | 94 | def decode(self, z): 95 | z = self.post_quant_conv(z) 96 | dec = self.decoder(z) 97 | return dec 98 | 99 | def forward(self, input) -> dict[str, torch.Tensor] | NetworkEval: 100 | latent, posteriors = self.encode(input) 101 | dec = self.decode(latent) 102 | if self.training: 103 | return dict(reconstructions=dec, posteriors=posteriors, latent=latent) 104 | return NetworkEval(reconstructions=dec, posteriors=posteriors, latent=latent) 105 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/networks/continuous_video.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """The causal continuous video tokenizer with VAE or AE formulation for 3D data..""" 16 | from collections import OrderedDict, namedtuple 17 | 18 | from loguru import logger as logging 19 | from torch import nn 20 | 21 | from ..modules import ( 22 | ContinuousFormulation, 23 | Decoder3DType, 24 | Encoder3DType, 25 | ) 26 | from ..modules.layers3d import CausalConv3d 27 | 28 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "posteriors", "latent"]) 29 | 30 | 31 | class CausalContinuousVideoTokenizer(nn.Module): 32 | def __init__( 33 | self, z_channels: int, z_factor: int, latent_channels: int, **kwargs 34 | ) -> None: 35 | super().__init__() 36 | self.name = kwargs.get("name", "CausalContinuousVideoTokenizer") 37 | self.latent_channels = latent_channels 38 | 39 | encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) 40 | self.encoder = Encoder3DType[encoder_name].value( 41 | z_channels=z_factor * z_channels, **kwargs 42 | ) 43 | if kwargs.get("temporal_compression", 4) == 4: 44 | kwargs["channels_mult"] = [2, 4] 45 | decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) 46 | self.decoder = Decoder3DType[decoder_name].value( 47 | z_channels=z_channels, **kwargs 48 | ) 49 | 50 | self.quant_conv = CausalConv3d( 51 | z_factor * z_channels, 52 | z_factor * latent_channels, 53 | kernel_size=1, 54 | padding=0, 55 | ) 56 | self.post_quant_conv = CausalConv3d( 57 | latent_channels, z_channels, kernel_size=1, padding=0 58 | ) 59 | 60 | formulation_name = kwargs.get("formulation", ContinuousFormulation.AE.name) 61 | self.distribution = ContinuousFormulation[formulation_name].value() 62 | logging.info( 63 | f"{self.name} based on {formulation_name} formulation, with {kwargs}." 64 | ) 65 | 66 | num_parameters = sum(param.numel() for param in self.parameters()) 67 | logging.info(f"model={self.name}, num_parameters={num_parameters:,}") 68 | logging.info( 69 | f"z_channels={z_channels}, latent_channels={self.latent_channels}." 70 | ) 71 | 72 | def encoder_jit(self): 73 | return nn.Sequential( 74 | OrderedDict( 75 | [ 76 | ("encoder", self.encoder), 77 | ("quant_conv", self.quant_conv), 78 | ("distribution", self.distribution), 79 | ] 80 | ) 81 | ) 82 | 83 | def decoder_jit(self): 84 | return nn.Sequential( 85 | OrderedDict( 86 | [ 87 | ("post_quant_conv", self.post_quant_conv), 88 | ("decoder", self.decoder), 89 | ] 90 | ) 91 | ) 92 | 93 | def last_decoder_layer(self): 94 | return self.decoder.conv_out 95 | 96 | def encode(self, x): 97 | h = self.encoder(x) 98 | moments = self.quant_conv(h) 99 | return self.distribution(moments) 100 | 101 | def decode(self, z): 102 | z = self.post_quant_conv(z) 103 | return self.decoder(z) 104 | 105 | def forward(self, input): 106 | latent, posteriors = self.encode(input) 107 | reconstructions = self.decode(latent) 108 | if self.training: 109 | return dict( 110 | reconstructions=reconstructions, 111 | posteriors=posteriors, 112 | latent=latent, 113 | ) 114 | return NetworkEval( 115 | reconstructions=reconstructions, 116 | posteriors=posteriors, 117 | latent=latent, 118 | ) 119 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/networks/discrete_image.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The network definition for discrete image tokenization with VQ, LFQ, FSQ or ResidualFSQ.""" 16 | from collections import OrderedDict, namedtuple 17 | 18 | import torch 19 | from loguru import logger as logging 20 | from torch import nn 21 | 22 | from ..modules import DecoderType, DiscreteQuantizer, EncoderType 23 | from ..modules.quantizers import InvQuantizerJit 24 | 25 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) 26 | 27 | 28 | class DiscreteImageTokenizer(nn.Module): 29 | def __init__(self, z_channels: int, embedding_dim: int, **kwargs) -> None: 30 | super().__init__() 31 | self.name = kwargs.get("name", "DiscreteImageTokenizer") 32 | self.embedding_dim = embedding_dim 33 | 34 | encoder_name = kwargs.get("encoder", EncoderType.Default.name) 35 | self.encoder = EncoderType[encoder_name].value(z_channels=z_channels, **kwargs) 36 | 37 | decoder_name = kwargs.get("decoder", DecoderType.Default.name) 38 | self.decoder = DecoderType[decoder_name].value(z_channels=z_channels, **kwargs) 39 | self.quant_conv = nn.Conv2d(z_channels, embedding_dim, 1) 40 | self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1) 41 | 42 | quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) 43 | if quantizer_name == DiscreteQuantizer.VQ.name: 44 | assert ( 45 | "num_embeddings" in kwargs 46 | ), f"`num_embeddings` must be provided for {quantizer_name}." 47 | kwargs.update(dict(embedding_dim=embedding_dim)) 48 | elif quantizer_name == DiscreteQuantizer.LFQ.name: 49 | assert ( 50 | "codebook_size" in kwargs 51 | ), f"`codebook_size` must be provided for {quantizer_name}." 52 | assert ( 53 | "codebook_dim" in kwargs 54 | ), f"`codebook_dim` must be provided for {quantizer_name}." 55 | elif quantizer_name == DiscreteQuantizer.FSQ.name: 56 | assert ( 57 | "levels" in kwargs 58 | ), f"`levels` must be provided for {quantizer_name}." 59 | elif quantizer_name == DiscreteQuantizer.RESFSQ.name: 60 | assert ( 61 | "levels" in kwargs 62 | ), f"`levels` must be provided for {quantizer_name}.name." 63 | assert ( 64 | "num_quantizers" in kwargs 65 | ), f"`num_quantizers` must be provided for {quantizer_name}." 
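        # Illustrative kwargs per quantizer choice (example values, not defaults of
        # this repo): VQ -> num_embeddings=16384; LFQ -> codebook_size=65536,
        # codebook_dim=16; FSQ -> levels=[8, 8, 8, 5, 5, 5]; RESFSQ -> the same
        # levels plus num_quantizers=4. Whatever is passed via **kwargs is forwarded
        # unchanged to the quantizer constructed below.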
66 | self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) 67 | logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.") 68 | 69 | num_parameters = sum(param.numel() for param in self.parameters()) 70 | logging.info(f"model={self.name}, num_parameters={num_parameters:,}") 71 | logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") 72 | 73 | def to(self, *args, **kwargs): 74 | setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) 75 | return super(DiscreteImageTokenizer, self).to(*args, **kwargs) 76 | 77 | def encoder_jit(self): 78 | return nn.Sequential( 79 | OrderedDict( 80 | [ 81 | ("encoder", self.encoder), 82 | ("quant_conv", self.quant_conv), 83 | ("quantizer", self.quantizer), 84 | ] 85 | ) 86 | ) 87 | 88 | def decoder_jit(self): 89 | return nn.Sequential( 90 | OrderedDict( 91 | [ 92 | ("inv_quant", InvQuantizerJit(self.quantizer)), 93 | ("post_quant_conv", self.post_quant_conv), 94 | ("decoder", self.decoder), 95 | ] 96 | ) 97 | ) 98 | 99 | def last_decoder_layer(self): 100 | return self.decoder.conv_out 101 | 102 | def encode(self, x): 103 | h = self.encoder(x) 104 | h = self.quant_conv(h) 105 | return self.quantizer(h) 106 | 107 | def decode(self, quant): 108 | quant = self.post_quant_conv(quant) 109 | return self.decoder(quant) 110 | 111 | def decode_code(self, code_b): 112 | quant_b = self.quantizer.indices_to_codes(code_b) 113 | quant_b = self.post_quant_conv(quant_b) 114 | return self.decoder(quant_b) 115 | 116 | def forward(self, input): 117 | quant_info, quant_codes, quant_loss = self.encode(input) 118 | reconstructions = self.decode(quant_codes) 119 | if self.training: 120 | return dict( 121 | reconstructions=reconstructions, 122 | quant_loss=quant_loss, 123 | quant_info=quant_info, 124 | ) 125 | return NetworkEval( 126 | reconstructions=reconstructions, 127 | quant_loss=quant_loss, 128 | quant_info=quant_info, 129 | ) 130 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/networks/discrete_video.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """The network definition for discrete video tokenizer with VQ, LFQ, FSQ or ResidualFSQ. 
""" 16 | from collections import OrderedDict, namedtuple 17 | 18 | import torch 19 | from loguru import logger as logging 20 | from torch import nn 21 | 22 | from ..modules import ( 23 | Decoder3DType, 24 | DiscreteQuantizer, 25 | Encoder3DType, 26 | ) 27 | from ..modules.layers3d import CausalConv3d 28 | from ..modules.quantizers import InvQuantizerJit 29 | 30 | NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"]) 31 | 32 | 33 | class CausalDiscreteVideoTokenizer(nn.Module): 34 | def __init__( 35 | self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs 36 | ) -> None: 37 | super().__init__() 38 | self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer") 39 | self.embedding_dim = embedding_dim 40 | 41 | encoder_name = kwargs.get("encoder", Encoder3DType.BASE.name) 42 | self.encoder = Encoder3DType[encoder_name].value( 43 | z_channels=z_factor * z_channels, **kwargs 44 | ) 45 | 46 | decoder_name = kwargs.get("decoder", Decoder3DType.BASE.name) 47 | self.decoder = Decoder3DType[decoder_name].value( 48 | z_channels=z_channels, **kwargs 49 | ) 50 | 51 | self.quant_conv = CausalConv3d( 52 | z_factor * z_channels, embedding_dim, kernel_size=1, padding=0 53 | ) 54 | self.post_quant_conv = CausalConv3d( 55 | embedding_dim, z_channels, kernel_size=1, padding=0 56 | ) 57 | 58 | quantizer_name = kwargs.get("quantizer", DiscreteQuantizer.RESFSQ.name) 59 | if quantizer_name == DiscreteQuantizer.VQ.name: 60 | assert ( 61 | "num_embeddings" in kwargs 62 | ), f"`num_embeddings` must be provided for {quantizer_name}." 63 | kwargs.update(dict(embedding_dim=embedding_dim)) 64 | elif quantizer_name == DiscreteQuantizer.LFQ.name: 65 | assert ( 66 | "codebook_size" in kwargs 67 | ), f"`codebook_size` must be provided for {quantizer_name}." 68 | assert ( 69 | "codebook_dim" in kwargs 70 | ), f"`codebook_dim` must be provided for {quantizer_name}." 71 | elif quantizer_name == DiscreteQuantizer.FSQ.name: 72 | assert ( 73 | "levels" in kwargs 74 | ), f"`levels` must be provided for {quantizer_name}." 75 | elif quantizer_name == DiscreteQuantizer.RESFSQ.name: 76 | assert ( 77 | "levels" in kwargs 78 | ), f"`levels` must be provided for {quantizer_name}." 79 | assert ( 80 | "num_quantizers" in kwargs 81 | ), f"`num_quantizers` must be provided for {quantizer_name}." 
82 | self.quantizer = DiscreteQuantizer[quantizer_name].value(**kwargs) 83 | logging.info(f"{self.name} based on {quantizer_name}-VAE, with {kwargs}.") 84 | 85 | num_parameters = sum(param.numel() for param in self.parameters()) 86 | logging.info(f"model={self.name}, num_parameters={num_parameters:,}") 87 | logging.info(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.") 88 | 89 | def to(self, *args, **kwargs): 90 | setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16)) 91 | return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs) 92 | 93 | def encoder_jit(self): 94 | return nn.Sequential( 95 | OrderedDict( 96 | [ 97 | ("encoder", self.encoder), 98 | ("quant_conv", self.quant_conv), 99 | ("quantizer", self.quantizer), 100 | ] 101 | ) 102 | ) 103 | 104 | def decoder_jit(self): 105 | return nn.Sequential( 106 | OrderedDict( 107 | [ 108 | ("inv_quant", InvQuantizerJit(self.quantizer)), 109 | ("post_quant_conv", self.post_quant_conv), 110 | ("decoder", self.decoder), 111 | ] 112 | ) 113 | ) 114 | 115 | def last_decoder_layer(self): 116 | return self.decoder.conv_out 117 | 118 | def encode(self, x): 119 | h = self.encoder(x) 120 | h = self.quant_conv(h) 121 | return self.quantizer(h) 122 | 123 | def decode(self, quant): 124 | quant = self.post_quant_conv(quant) 125 | return self.decoder(quant) 126 | 127 | def decode_code(self, code_b): 128 | quant_b = self.quantizer.indices_to_codes(code_b) 129 | quant_b = self.post_quant_conv(quant_b) 130 | return self.decoder(quant_b) 131 | 132 | def forward(self, input): 133 | quant_info, quant_codes, quant_loss = self.encode(input) 134 | reconstructions = self.decode(quant_codes) 135 | if self.training: 136 | return dict( 137 | reconstructions=reconstructions, 138 | quant_loss=quant_loss, 139 | quant_info=quant_info, 140 | ) 141 | return NetworkEval( 142 | reconstructions=reconstructions, 143 | quant_loss=quant_loss, 144 | quant_info=quant_info, 145 | ) 146 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/video_cli.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A CLI to run CausalVideoTokenizer on plain videos based on torch.jit. 
16 | 17 | Usage: 18 | python3 -m cosmos_tokenizer.video_cli \ 19 | --video_pattern 'path/to/video/samples/*.mp4' \ 20 | --output_dir ./reconstructions \ 21 | --checkpoint_enc ./pretrained_ckpts/CosmosCV_f4x8x8/encoder.jit \ 22 | --checkpoint_dec ./pretrained_ckpts/CosmosCV_f4x8x8/decoder.jit 23 | 24 | Optionally, you can run the model in pure PyTorch mode: 25 | python3 -m cosmos_tokenizer.video_cli \ 26 | --video_pattern 'path/to/video/samples/*.mp4' \ 27 | --mode=torch \ 28 | --tokenizer_type=CV \ 29 | --temporal_compression=4 \ 30 | --spatial_compression=8 \ 31 | --checkpoint_enc ./pretrained_ckpts/CosmosCV_f4x8x8/encoder.jit \ 32 | --checkpoint_dec ./pretrained_ckpts/CosmosCV_f4x8x8/decoder.jit 33 | """ 34 | 35 | import os 36 | from argparse import ArgumentParser, Namespace 37 | from typing import Any 38 | import sys 39 | 40 | import numpy as np 41 | from loguru import logger as logging 42 | 43 | from .networks import TokenizerConfigs 44 | from .utils import ( 45 | get_filepaths, 46 | get_output_filepath, 47 | read_video, 48 | resize_video, 49 | write_video, 50 | ) 51 | from .video_lib import CausalVideoTokenizer 52 | 53 | 54 | def _parse_args() -> tuple[Namespace, dict[str, Any]]: 55 | parser = ArgumentParser(description="A CLI for CausalVideoTokenizer.") 56 | parser.add_argument( 57 | "--video_pattern", 58 | type=str, 59 | default="path/to/videos/*.mp4", 60 | help="Glob pattern.", 61 | ) 62 | parser.add_argument( 63 | "--checkpoint", 64 | type=str, 65 | default=None, 66 | help="JIT full Autoencoder model filepath.", 67 | ) 68 | parser.add_argument( 69 | "--checkpoint_enc", 70 | type=str, 71 | default=None, 72 | help="JIT Encoder model filepath.", 73 | ) 74 | parser.add_argument( 75 | "--checkpoint_dec", 76 | type=str, 77 | default=None, 78 | help="JIT Decoder model filepath.", 79 | ) 80 | parser.add_argument( 81 | "--tokenizer_type", 82 | type=str, 83 | choices=["CV", "DV"], 84 | help="Specifies the tokenizer type.", 85 | ) 86 | parser.add_argument( 87 | "--spatial_compression", 88 | type=int, 89 | choices=[8, 16], 90 | default=8, 91 | help="The spatial compression factor.", 92 | ) 93 | parser.add_argument( 94 | "--temporal_compression", 95 | type=int, 96 | choices=[4, 8], 97 | default=4, 98 | help="The temporal compression factor.", 99 | ) 100 | parser.add_argument( 101 | "--mode", 102 | type=str, 103 | choices=["torch", "jit"], 104 | default="jit", 105 | help="Specify the backend: native 'torch' or 'jit' (default: 'jit')", 106 | ) 107 | parser.add_argument( 108 | "--short_size", 109 | type=int, 110 | default=None, 111 | help="The size to resample inputs. None, by default.", 112 | ) 113 | parser.add_argument( 114 | "--temporal_window", 115 | type=int, 116 | default=17, 117 | help="The temporal window to operate at a time.", 118 | ) 119 | parser.add_argument( 120 | "--dtype", 121 | type=str, 122 | default="bfloat16", 123 | help="Sets the precision, default bfloat16.", 124 | ) 125 | parser.add_argument( 126 | "--device", 127 | type=str, 128 | default="cuda", 129 | help="Device for invoking the model.", 130 | ) 131 | parser.add_argument( 132 | "--output_dir", type=str, default=None, help="Output directory." 
133 | ) 134 | parser.add_argument( 135 | "--output_fps", 136 | type=float, 137 | default=24.0, 138 | help="Output frames-per-second (FPS).", 139 | ) 140 | parser.add_argument( 141 | "--save_input", 142 | action="store_true", 143 | help="If on, the input video will be outputted too.", 144 | ) 145 | 146 | args = parser.parse_args() 147 | return args 148 | 149 | 150 | logging.info("Initializes args ...") 151 | args = _parse_args() 152 | if args.mode == "torch" and args.tokenizer_type not in ["CV", "DV"]: 153 | logging.error("'torch' backend requires the tokenizer_type of 'CV' or 'DV'.") 154 | sys.exit(1) 155 | 156 | 157 | def _run_eval() -> None: 158 | """Invokes JIT-compiled CausalVideoTokenizer on an input video.""" 159 | 160 | if ( 161 | args.checkpoint_enc is None 162 | and args.checkpoint_dec is None 163 | and args.checkpoint is None 164 | ): 165 | logging.warning( 166 | "Aborting. Provide either both the encoder and decoder JIT models, or the full autoencoder JIT model." 167 | ) 168 | return 169 | 170 | if args.mode == "torch": 171 | tokenizer_config = TokenizerConfigs[args.tokenizer_type].value 172 | tokenizer_config.update(dict(spatial_compression=args.spatial_compression)) 173 | tokenizer_config.update(dict(temporal_compression=args.temporal_compression)) 174 | else: 175 | tokenizer_config = None 176 | 177 | logging.info( 178 | f"Loading a torch.jit model `{os.path.dirname(args.checkpoint or args.checkpoint_enc or args.checkpoint_dec)}` ..." 179 | ) 180 | autoencoder = CausalVideoTokenizer( 181 | checkpoint=args.checkpoint, 182 | checkpoint_enc=args.checkpoint_enc, 183 | checkpoint_dec=args.checkpoint_dec, 184 | tokenizer_config=tokenizer_config, 185 | device=args.device, 186 | dtype=args.dtype, 187 | ) 188 | 189 | logging.info(f"Looking for files matching video_pattern={args.video_pattern} ...") 190 | filepaths = get_filepaths(args.video_pattern) 191 | logging.info(f"Found {len(filepaths)} videos from {args.video_pattern}.") 192 | 193 | for filepath in filepaths: 194 | logging.info(f"Reading video {filepath} ...") 195 | video = read_video(filepath) 196 | video = resize_video(video, short_size=args.short_size) 197 | 198 | logging.info("Invoking the autoencoder model ...") 199 | batch_video = video[np.newaxis, ...] 200 | output_video = autoencoder(batch_video, temporal_window=args.temporal_window)[0] 201 | logging.info("Constructing output filepath ...") 202 | output_filepath = get_output_filepath(filepath, output_dir=args.output_dir) 203 | logging.info(f"Outputting {output_filepath} ...") 204 | write_video(output_filepath, output_video, fps=args.output_fps) 205 | if args.save_input: 206 | ext = os.path.splitext(output_filepath)[-1] 207 | input_filepath = output_filepath.replace(ext, "_input" + ext) 208 | write_video(input_filepath, video, fps=args.output_fps) 209 | 210 | 211 | @logging.catch(reraise=True) 212 | def main() -> None: 213 | _run_eval() 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | -------------------------------------------------------------------------------- /open_tokenizer/model/cosmos_tokenizer/video_lib.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """A library for Causal Video Tokenizer inference.""" 16 | 17 | import numpy as np 18 | import torch 19 | from typing import Any 20 | 21 | from tqdm import tqdm 22 | 23 | from .utils import ( 24 | load_model, 25 | load_encoder_model, 26 | load_decoder_model, 27 | numpy2tensor, 28 | pad_video_batch, 29 | tensor2numpy, 30 | unpad_video_batch, 31 | ) 32 | 33 | 34 | class CausalVideoTokenizer(torch.nn.Module): 35 | def __init__( 36 | self, 37 | checkpoint: str = None, 38 | checkpoint_enc: str = None, 39 | checkpoint_dec: str = None, 40 | tokenizer_config: dict[str, Any] = None, 41 | device: str = "cuda", 42 | dtype: str = "bfloat16", 43 | ) -> None: 44 | super().__init__() 45 | self._device = device 46 | self._dtype = getattr(torch, dtype) 47 | self._full_model = ( 48 | load_model(checkpoint, tokenizer_config, device).to(self._dtype) 49 | if checkpoint is not None 50 | else None 51 | ) 52 | self._enc_model = ( 53 | load_encoder_model(checkpoint_enc, tokenizer_config, device).to(self._dtype) 54 | if checkpoint_enc is not None 55 | else None 56 | ) 57 | self._dec_model = ( 58 | load_decoder_model(checkpoint_dec, tokenizer_config, device).to(self._dtype) 59 | if checkpoint_dec is not None 60 | else None 61 | ) 62 | 63 | @torch.no_grad() 64 | def autoencode(self, input_tensor: torch.Tensor) -> torch.Tensor: 65 | """Reconstructs a batch of video tensors after embedding into a latent. 66 | 67 | Args: 68 | input_tensor: The input video Bx3xTxHxW layout, range [-1..1]. 69 | Returns: 70 | The reconstructed video, layout Bx3xTxHxW, range [-1..1]. 71 | """ 72 | if self._full_model is not None: 73 | output_tensor = self._full_model(input_tensor) 74 | output_tensor = ( 75 | output_tensor[0] if isinstance(output_tensor, tuple) else output_tensor 76 | ) 77 | else: 78 | output_latent = self.encode(input_tensor)[0] 79 | output_tensor = self.decode(output_latent) 80 | return output_tensor 81 | 82 | @torch.no_grad() 83 | def encode(self, input_tensor: torch.Tensor) -> tuple[torch.Tensor]: 84 | """Encodes a video tensor into a CausalVideo latent or code. 85 | 86 | Args: 87 | input_tensor: The input tensor Bx3xTxHxW layout, range [-1..1]. 88 | Returns: 89 | For causal continuous video (CV) tokenizer, the tuple contains: 90 | - The latent embedding, Bx16x(t)x(h)x(w), where the compression 91 | rate is (T/t x H/h x W/w), and channel dimension of 16. 92 | For causal discrete video (DV) tokenizer, the tuple contains: 93 | 1) The indices, Bx(t)x(h)x(w), from a codebook of size 64K, which 94 | is formed by FSQ levels of (8,8,8,5,5,5). 95 | 2) The discrete code, Bx6x(t)x(h)x(w), where the compression rate 96 | is again (T/t x H/h x W/w), and channel dimension of 6. 97 | """ 98 | assert input_tensor.ndim == 5, "input video should be of 5D."
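        # Shape example (assuming the 4x8x8 variant, i.e. temporal_compression=4 and
        # spatial_compression=8): a 1x3x17x256x256 window yields t = 1 + (17-1)/4 = 5
        # and h = w = 256/8 = 32, so the CV latent is 1x16x5x32x32 and the DV indices
        # are 1x5x32x32.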
99 | 100 | output_latent = self._enc_model(input_tensor) 101 | if isinstance(output_latent, torch.Tensor): 102 | return output_latent 103 | return "pad", output_latent[0], output_latent[1] 104 | 105 | @torch.no_grad() 106 | def decode(self, input_latent: torch.Tensor) -> torch.Tensor: 107 | """Decodes a CausalVideo latent or discrete code back into a video tensor. 108 | 109 | Args: 110 | input_latent: The continuous latent Bx16xtxhxw for CV, 111 | or the discrete indices Bxtxhxw for DV. 112 | Returns: 113 | The reconstructed tensor, layout [B,3,1+(T-1)*8,H*16,W*16] in range [-1..1]. 114 | """ 115 | assert ( 116 | input_latent.ndim >= 4 117 | ), "input latent should be of 5D for continuous and 4D for discrete." 118 | return self._dec_model(input_latent) 119 | 120 | def forward( 121 | self, 122 | video: np.ndarray, 123 | temporal_window: int = 17, 124 | ) -> np.ndarray: 125 | """Reconstructs video using a pre-trained CausalTokenizer autoencoder. 126 | Given a video of arbitrary length, the forward invokes the CausalVideoTokenizer 127 | in a sliding manner with a `temporal_window` size. 128 | 129 | Args: 130 | video: The input video BxTxHxWx3 layout, range [0..255]. 131 | temporal_window: The length of the temporal window to process, default=17. 132 | Returns: 133 | The reconstructed video in range [0..255], layout BxTxHxWx3. 134 | """ 135 | assert video.ndim == 5, "input video should be of 5D." 136 | num_frames = video.shape[1] # can be of any length. 137 | output_video_list = [] 138 | for idx in tqdm(range(0, (num_frames - 1) // temporal_window + 1)): 139 | # Input video for the current window. 140 | start, end = idx * temporal_window, (idx + 1) * temporal_window 141 | input_video = video[:, start:end, ...] 142 | 143 | # Spatio-temporally pad input_video so it's evenly divisible.
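            # pad_video_batch returns the padded window together with the crop region
            # that unpad_video_batch uses below to trim the reconstruction back to the
            # original T x H x W.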
144 | padded_input_video, crop_region = pad_video_batch(input_video) 145 | input_tensor = numpy2tensor( 146 | padded_input_video, dtype=self._dtype, device=self._device 147 | ) 148 | output_tensor = self.autoencode(input_tensor) 149 | padded_output_video = tensor2numpy(output_tensor) 150 | output_video = unpad_video_batch(padded_output_video, crop_region) 151 | 152 | output_video_list.append(output_video) 153 | return np.concatenate(output_video_list, axis=1) 154 | -------------------------------------------------------------------------------- /open_tokenizer/model/misc.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | import torch 3 | from typing import ( 4 | Any, 5 | Callable, 6 | Dict, 7 | Iterable, 8 | List, 9 | NamedTuple, 10 | NewType, 11 | Optional, 12 | Sized, 13 | Tuple, 14 | Type, 15 | TypeVar, 16 | Union, 17 | ) 18 | try: 19 | from typing import Literal 20 | except ImportError: 21 | from typing_extensions import Literal 22 | 23 | # Tensor dtype 24 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 25 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 26 | 27 | # Config type 28 | from omegaconf import DictConfig 29 | 30 | # PyTorch Tensor type 31 | from torch import Tensor 32 | 33 | # Runtime type checking decorator 34 | from typeguard import typechecked as typechecker 35 | 36 | 37 | def broadcast(tensor, src=0): 38 | if not _distributed_available(): 39 | return tensor 40 | else: 41 | torch.distributed.broadcast(tensor, src=src) 42 | return tensor 43 | 44 | def _distributed_available(): 45 | return torch.distributed.is_available() and torch.distributed.is_initialized() 46 | 47 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 48 | # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword 49 | if '--local-rank' in cfg: 50 | del cfg['--local-rank'] 51 | # added by Xavier -- delete '--local-rank' in multi-nodes training, don't know why there is such a keyword 52 | scfg = OmegaConf.structured(fields(**cfg)) 53 | return scfg -------------------------------------------------------------------------------- /open_tokenizer/model/tiktok/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (2024) Bytedance Ltd. and/or its affiliates 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ -------------------------------------------------------------------------------- /open_tokenizer/model/tiktok/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .ema_model import EMAModel 3 | from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, MLMLoss, ARLoss 4 | from .blocks import TiTokEncoder, TiTokDecoder, UViTBlock 5 | from .maskgit_vqgan import Decoder as Pixel_Decoder 6 | from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer -------------------------------------------------------------------------------- /open_tokenizer/model/tiktok/modules/base_model.py: -------------------------------------------------------------------------------- 1 | """This file contains some base class implementation for models. 2 | 3 | This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”). 4 | All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 5 | 6 | Reference: 7 | https://github.com/huggingface/open-muse/blob/main/muse/modeling_utils.py 8 | """ 9 | import os 10 | from typing import Union, Callable, Dict, Optional 11 | 12 | import torch 13 | 14 | 15 | class BaseModel(torch.nn.Module): 16 | 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def save_pretrained_weight( 21 | self, 22 | save_directory: Union[str, os.PathLike], 23 | save_function: Callable = None, 24 | state_dict: Optional[Dict[str, torch.Tensor]] = None, 25 | ): 26 | """Saves a model and its configuration file to a directory. 27 | 28 | Args: 29 | save_directory: A string or os.PathLike, directory to which to save. 30 | Will be created if it doesn't exist. 31 | save_function: A Callable function, the function to use to save the state dictionary. 32 | Useful on distributed training like TPUs when one need to replace `torch.save` by 33 | another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. 34 | state_dict: A dictionary from str to torch.Tensor, the state dictionary to save. 35 | If `None`, the model's state dictionary will be saved. 36 | """ 37 | if os.path.isfile(save_directory): 38 | print(f"Provided path ({save_directory}) should be a directory, not a file") 39 | return 40 | 41 | if save_function is None: 42 | save_function = torch.save 43 | 44 | os.makedirs(save_directory, exist_ok=True) 45 | 46 | model_to_save = self 47 | 48 | if state_dict is None: 49 | state_dict = model_to_save.state_dict() 50 | weights_name = "pytorch_model.bin" 51 | 52 | save_function(state_dict, os.path.join(save_directory, weights_name)) 53 | 54 | print(f"Model weights saved in {os.path.join(save_directory, weights_name)}") 55 | 56 | def load_pretrained_weight( 57 | self, 58 | pretrained_model_path: Union[str, os.PathLike], 59 | strict_loading: bool = True, 60 | torch_dtype: Optional[torch.dtype] = None 61 | ): 62 | r"""Instantiates a pretrained pytorch model from a pre-trained model configuration. 63 | 64 | The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train 65 | the model, you should first set it back in training mode with `model.train()`. 66 | 67 | Args: 68 | pretrained_model_path: A string or os.PathLike, a path to a *directory* or *file* containing model weights. 69 | 70 | Raises: 71 | ValueError: If pretrained_model_path does not exist. 72 | """ 73 | # If pretrained_model_path is a file, set model_file to this file. 
74 | if os.path.isfile(pretrained_model_path): 75 | model_file = pretrained_model_path 76 | # If pretrained_model_path is a directory, set model_file to the path of the 77 | # file "pytorch_model.bin" in this directory. 78 | elif os.path.isdir(pretrained_model_path): 79 | pretrained_model_path = os.path.join(pretrained_model_path, "pytorch_model.bin") 80 | if os.path.isfile(pretrained_model_path): 81 | model_file = pretrained_model_path 82 | else: 83 | raise ValueError(f"{pretrained_model_path} does not exist") 84 | else: 85 | raise ValueError(f"{pretrained_model_path} does not exist") 86 | 87 | # Load model state from checkpoint. 88 | checkpoint = torch.load(model_file, map_location="cpu") 89 | # Load state dictionary into self. 90 | msg = self.load_state_dict(checkpoint, strict=strict_loading) 91 | # Print information about loading weights. 92 | print(f"loading weight from {model_file}, msg: {msg}") 93 | # If torch_dtype is specified and is a valid torch.dtype, convert self to this dtype. 94 | if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): 95 | raise ValueError( 96 | f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." 97 | ) 98 | elif torch_dtype is not None: 99 | self.to(torch_dtype) 100 | 101 | # Set model in evaluation mode to deactivate DropOut modules by default. 102 | self.eval() 103 | 104 | def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: 105 | """Gets the number of parameters in the module. 106 | 107 | Args: 108 | only_trainable: A boolean, whether to only include trainable parameters. 109 | exclude_embeddings: A boolean, whether to exclude parameters associated with embeddings. 110 | 111 | Returns: 112 | An integer, the number of parameters. 113 | """ 114 | 115 | if exclude_embeddings: 116 | embedding_param_names = [ 117 | f"{name}.weight" 118 | for name, module_type in self.named_modules() 119 | if isinstance(module_type, torch.nn.Embedding) 120 | ] 121 | non_embedding_parameters = [ 122 | parameter for name, parameter in self.named_parameters() if name not in embedding_param_names 123 | ] 124 | return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) 125 | else: 126 | return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) 127 | 128 | -------------------------------------------------------------------------------- /open_tokenizer/model/tiktok/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | """This file contains some base implementation for discrminators. 2 | 3 | Copyright (2024) Bytedance Ltd. and/or its affiliates 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | 17 | TODO: Add reference to Mark Weber's tech report on the improved discriminator architecture. 
18 | """ 19 | import functools 20 | import math 21 | from typing import Tuple 22 | 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | from .maskgit_vqgan import Conv2dSame 29 | 30 | 31 | class BlurBlock(torch.nn.Module): 32 | def __init__(self, 33 | kernel: Tuple[int] = (1, 3, 3, 1) 34 | ): 35 | super().__init__() 36 | 37 | kernel = torch.tensor(kernel, dtype=torch.float32, requires_grad=False) 38 | kernel = kernel[None, :] * kernel[:, None] 39 | kernel /= kernel.sum() 40 | kernel = kernel.unsqueeze(0).unsqueeze(0) 41 | self.register_buffer("kernel", kernel) 42 | 43 | def calc_same_pad(self, i: int, k: int, s: int) -> int: 44 | return max((math.ceil(i / s) - 1) * s + (k - 1) + 1 - i, 0) 45 | 46 | def forward(self, x: torch.Tensor) -> torch.Tensor: 47 | ic, ih, iw = x.size()[-3:] 48 | pad_h = self.calc_same_pad(i=ih, k=4, s=2) 49 | pad_w = self.calc_same_pad(i=iw, k=4, s=2) 50 | if pad_h > 0 or pad_w > 0: 51 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) 52 | 53 | weight = self.kernel.expand(ic, -1, -1, -1) 54 | 55 | out = F.conv2d(input=x, weight=weight, stride=2, groups=x.shape[1]) 56 | return out 57 | 58 | 59 | class NLayerDiscriminator(torch.nn.Module): 60 | def __init__( 61 | self, 62 | num_channels: int = 3, 63 | hidden_channels: int = 128, 64 | num_stages: int = 3, 65 | blur_resample: bool = True, 66 | blur_kernel_size: int = 4 67 | ): 68 | """ Initializes the NLayerDiscriminator. 69 | 70 | Args: 71 | num_channels -> int: The number of input channels. 72 | hidden_channels -> int: The number of hidden channels. 73 | num_stages -> int: The number of stages. 74 | blur_resample -> bool: Whether to use blur resampling. 75 | blur_kernel_size -> int: The blur kernel size. 76 | """ 77 | super().__init__() 78 | assert num_stages > 0, "Discriminator cannot have 0 stages" 79 | assert (not blur_resample) or (blur_kernel_size >= 3 and blur_kernel_size <= 5), "Blur kernel size must be in [3,5] when sampling]" 80 | 81 | in_channel_mult = (1,) + tuple(map(lambda t: 2**t, range(num_stages))) 82 | init_kernel_size = 5 83 | activation = functools.partial(torch.nn.LeakyReLU, negative_slope=0.1) 84 | 85 | self.block_in = torch.nn.Sequential( 86 | Conv2dSame( 87 | num_channels, 88 | hidden_channels, 89 | kernel_size=init_kernel_size 90 | ), 91 | activation(), 92 | ) 93 | 94 | BLUR_KERNEL_MAP = { 95 | 3: (1,2,1), 96 | 4: (1,3,3,1), 97 | 5: (1,4,6,4,1), 98 | } 99 | 100 | discriminator_blocks = [] 101 | for i_level in range(num_stages): 102 | in_channels = hidden_channels * in_channel_mult[i_level] 103 | out_channels = hidden_channels * in_channel_mult[i_level + 1] 104 | block = torch.nn.Sequential( 105 | Conv2dSame( 106 | in_channels, 107 | out_channels, 108 | kernel_size=3, 109 | ), 110 | torch.nn.AvgPool2d(kernel_size=2, stride=2) if not blur_resample else BlurBlock(BLUR_KERNEL_MAP[blur_kernel_size]), 111 | torch.nn.GroupNorm(32, out_channels), 112 | activation(), 113 | ) 114 | discriminator_blocks.append(block) 115 | 116 | self.blocks = torch.nn.ModuleList(discriminator_blocks) 117 | 118 | self.pool = torch.nn.AdaptiveMaxPool2d((16, 16)) 119 | 120 | self.to_logits = torch.nn.Sequential( 121 | Conv2dSame(out_channels, out_channels, 1), 122 | activation(), 123 | Conv2dSame(out_channels, 1, kernel_size=5) 124 | ) 125 | 126 | def forward(self, x: torch.Tensor) -> torch.Tensor: 127 | """ Forward pass. 128 | 129 | Args: 130 | x -> torch.Tensor: The input tensor. 131 | 132 | Returns: 133 | output -> torch.Tensor: The output tensor. 
134 | """ 135 | hidden_states = self.block_in(x) 136 | for block in self.blocks: 137 | hidden_states = block(hidden_states) 138 | 139 | hidden_states = self.pool(hidden_states) 140 | 141 | return self.to_logits(hidden_states) 142 | -------------------------------------------------------------------------------- /open_tokenizer/model/tiktok/modules/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | """This file contains perceptual loss module using ConvNeXt-S. 2 | 3 | Copyright (2024) Bytedance Ltd. and/or its affiliates 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | from torchvision import models 22 | 23 | _IMAGENET_MEAN = [0.485, 0.456, 0.406] 24 | _IMAGENET_STD = [0.229, 0.224, 0.225] 25 | 26 | 27 | class PerceptualLoss(torch.nn.Module): 28 | def __init__(self, model_name: str = "convnext_s"): 29 | """Initializes the PerceptualLoss class. 30 | 31 | Args: 32 | model_name: A string, the name of the perceptual loss model to use. 33 | 34 | Raise: 35 | ValueError: If the model_name does not contain "convnext_s". 36 | """ 37 | super().__init__() 38 | if "convnext_s" not in model_name: 39 | raise ValueError(f"Unsupported Perceptual Loss model name {model_name}") 40 | 41 | self.convnext = models.convnext_small(weights=models.ConvNeXt_Small_Weights.IMAGENET1K_V1).eval() 42 | self.register_buffer("imagenet_mean", torch.Tensor(_IMAGENET_MEAN)[None, :, None, None]) 43 | self.register_buffer("imagenet_std", torch.Tensor(_IMAGENET_STD)[None, :, None, None]) 44 | 45 | for param in self.parameters(): 46 | param.requires_grad = False 47 | 48 | def forward(self, input: torch.Tensor, target: torch.Tensor): 49 | """Computes the perceptual loss. 50 | 51 | Args: 52 | input: A tensor of shape (B, C, H, W), the input image. Normalized to [0, 1]. 53 | target: A tensor of shape (B, C, H, W), the target image. Normalized to [0, 1]. 54 | 55 | Returns: 56 | A scalar tensor, the perceptual loss. 57 | """ 58 | # Always in eval mode. 
59 | self.eval() 60 | 61 | input = torch.nn.functional.interpolate(input, size=224, mode="bilinear", align_corners=False, antialias=True) 62 | target = torch.nn.functional.interpolate(target, size=224, mode="bilinear", align_corners=False, antialias=True) 63 | pred_input = self.convnext((input - self.imagenet_mean) / self.imagenet_std) 64 | pred_target = self.convnext((target - self.imagenet_mean) / self.imagenet_std) 65 | loss = torch.nn.functional.mse_loss( 66 | pred_input, 67 | pred_target, 68 | reduction="mean") 69 | 70 | return loss -------------------------------------------------------------------------------- /open_tokenizer/model/tiktok/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .quantizer import VectorQuantizer, DiagonalGaussianDistribution -------------------------------------------------------------------------------- /open_tokenizer/model/tiktok/quantizer/quantizer.py: -------------------------------------------------------------------------------- 1 | """Vector quantizer. 2 | 3 | Copyright (2024) Bytedance Ltd. and/or its affiliates 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | 17 | Reference: 18 | https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py 19 | https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py 20 | https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py 21 | """ 22 | from typing import Mapping, Text, Tuple 23 | 24 | import torch 25 | from einops import rearrange 26 | from torch.cuda.amp import autocast 27 | 28 | class VectorQuantizer(torch.nn.Module): 29 | def __init__(self, 30 | codebook_size: int = 1024, 31 | token_size: int = 256, 32 | commitment_cost: float = 0.25, 33 | use_l2_norm: bool = False, 34 | ): 35 | super().__init__() 36 | self.commitment_cost = commitment_cost 37 | 38 | self.embedding = torch.nn.Embedding(codebook_size, token_size) 39 | self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size) 40 | self.use_l2_norm = use_l2_norm 41 | 42 | # Ensure quantization is performed using f32 43 | @autocast(enabled=False) 44 | def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: 45 | z = z.float() 46 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 47 | z_flattened = rearrange(z, 'b h w c -> (b h w) c') 48 | 49 | if self.use_l2_norm: 50 | z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1) 51 | embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1) 52 | else: 53 | embedding = self.embedding.weight 54 | d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \ 55 | torch.sum(embedding**2, dim=1) - 2 * \ 56 | torch.einsum('bd,dn->bn', z_flattened, embedding.T) 57 | 58 | min_encoding_indices = torch.argmin(d, dim=1) # num_ele 59 | z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape) 60 | 61 | if self.use_l2_norm: 62 | z = torch.nn.functional.normalize(z, 
dim=-1) 63 | 64 | # compute loss for embedding 65 | commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) **2) 66 | codebook_loss = torch.mean((z_quantized - z.detach()) **2) 67 | 68 | loss = commitment_loss + codebook_loss 69 | 70 | # preserve gradients 71 | z_quantized = z + (z_quantized - z).detach() 72 | 73 | # reshape back to match original input shape 74 | z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous() 75 | 76 | result_dict = dict( 77 | quantizer_loss=loss, 78 | commitment_loss=commitment_loss, 79 | codebook_loss=codebook_loss, 80 | min_encoding_indices=min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]) 81 | ) 82 | 83 | return z_quantized, result_dict 84 | 85 | def get_codebook_entry(self, indices): 86 | if len(indices.shape) == 1: 87 | z_quantized = self.embedding(indices) 88 | elif len(indices.shape) == 2: 89 | z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight) 90 | else: 91 | raise NotImplementedError 92 | if self.use_l2_norm: 93 | z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1) 94 | return z_quantized 95 | 96 | 97 | class DiagonalGaussianDistribution(object): 98 | @autocast(enabled=False) 99 | def __init__(self, parameters, deterministic=False): 100 | """Initializes a Gaussian distribution instance given the parameters. 101 | 102 | Args: 103 | parameters (torch.Tensor): The parameters for the Gaussian distribution. It is expected 104 | to be in shape [B, 2 * C, *], where B is batch size, and C is the embedding dimension. 105 | First C channels are used for mean and last C are used for logvar in the Gaussian distribution. 106 | deterministic (bool): Whether to use deterministic sampling. When it is true, the sampling results 107 | is purely based on mean (i.e., std = 0). 108 | """ 109 | self.parameters = parameters 110 | self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1) 111 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 112 | self.deterministic = deterministic 113 | self.std = torch.exp(0.5 * self.logvar) 114 | self.var = torch.exp(self.logvar) 115 | if self.deterministic: 116 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 117 | 118 | @autocast(enabled=False) 119 | def sample(self): 120 | x = self.mean.float() + self.std.float() * torch.randn(self.mean.shape).to(device=self.parameters.device) 121 | return x 122 | 123 | @autocast(enabled=False) 124 | def mode(self): 125 | return self.mean 126 | 127 | @autocast(enabled=False) 128 | def kl(self): 129 | if self.deterministic: 130 | return torch.Tensor([0.]) 131 | else: 132 | return 0.5 * torch.sum(torch.pow(self.mean.float(), 2) 133 | + self.var.float() - 1.0 - self.logvar.float(), 134 | dim=[1, 2]) 135 | -------------------------------------------------------------------------------- /open_tokenizer/model/tiktok/titok.py: -------------------------------------------------------------------------------- 1 | """This file contains the model definition of TiTok. 2 | 3 | Copyright (2024) Bytedance Ltd. and/or its affiliates 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 
7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | """ 17 | 18 | import torch 19 | import torch.nn as nn 20 | from einops import rearrange 21 | 22 | from .modules.base_model import BaseModel 23 | from .modules.blocks import TiTokEncoder, TiTokDecoder 24 | from .quantizer.quantizer import VectorQuantizer, DiagonalGaussianDistribution 25 | from .modules.maskgit_vqgan import Encoder as Pixel_Eecoder 26 | from .modules.maskgit_vqgan import Decoder as Pixel_Decoder 27 | from .modules.maskgit_vqgan import VectorQuantizer as Pixel_Quantizer 28 | import json 29 | from omegaconf import OmegaConf 30 | from pathlib import Path 31 | 32 | from huggingface_hub import PyTorchModelHubMixin 33 | 34 | 35 | class PretrainedTokenizer(nn.Module): 36 | def __init__(self, pretrained_weight): 37 | super().__init__() 38 | conf = OmegaConf.create( 39 | {"channel_mult": [1, 1, 2, 2, 4], 40 | "num_resolutions": 5, 41 | "dropout": 0.0, 42 | "hidden_channels": 128, 43 | "num_channels": 3, 44 | "num_res_blocks": 2, 45 | "resolution": 256, 46 | "z_channels": 256}) 47 | self.encoder = Pixel_Eecoder(conf) 48 | self.decoder = Pixel_Decoder(conf) 49 | self.quantize = Pixel_Quantizer( 50 | num_embeddings=1024, embedding_dim=256, commitment_cost=0.25) 51 | # Load pretrained weights 52 | self.load_state_dict(torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True) 53 | 54 | self.eval() 55 | for param in self.parameters(): 56 | param.requires_grad = False 57 | 58 | @torch.no_grad() 59 | def encode(self, x): 60 | hidden_states = self.encoder(x) 61 | quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states) 62 | return codebook_indices.detach() 63 | 64 | @torch.no_grad() 65 | def decode(self, codes): 66 | quantized_states = self.quantize.get_codebook_entry(codes) 67 | rec_images = self.decoder(quantized_states) 68 | rec_images = torch.clamp(rec_images, 0.0, 1.0) 69 | return rec_images.detach() 70 | 71 | @torch.no_grad() 72 | def decode_tokens(self, codes): 73 | return self.decode(codes) 74 | 75 | 76 | class TiTok(BaseModel, PyTorchModelHubMixin, tags=["arxiv:2406.07550", "image-tokenization"], repo_url="https://github.com/bytedance/1d-tokenizer", license="apache-2.0"): 77 | def __init__(self, config): 78 | 79 | if isinstance(config, dict): 80 | config = OmegaConf.create(config) 81 | 82 | super().__init__() 83 | self.config = config 84 | # This should be False for stage1 and True for stage2. 
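        # When finetune_decoder is True (stage 2), the TiTok encoder, quantizer and
        # latent tokens are frozen below, and a MaskGiT-VQGAN pixel quantizer/decoder
        # pair is appended to map the decoder's logits back to pixels.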
85 | self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True) 86 | 87 | self.quantize_mode = config.model.vq_model.get("quantize_mode", "vq") 88 | if self.quantize_mode not in ["vq", "vae"]: 89 | raise ValueError(f"Unsupported quantize mode {self.quantize_mode}.") 90 | 91 | if self.finetune_decoder and self.quantize_mode not in ["vq"]: 92 | raise ValueError("Only supprot finetune_decoder with vq quantization for now.") 93 | 94 | self.encoder = TiTokEncoder(config) 95 | self.decoder = TiTokDecoder(config) 96 | 97 | self.num_latent_tokens = config.model.vq_model.num_latent_tokens 98 | scale = self.encoder.width ** -0.5 99 | self.latent_tokens = nn.Parameter( 100 | scale * torch.randn(self.num_latent_tokens, self.encoder.width)) 101 | 102 | self.apply(self._init_weights) 103 | 104 | if self.quantize_mode == "vq": 105 | self.quantize = VectorQuantizer( 106 | codebook_size=config.model.vq_model.codebook_size, 107 | token_size=config.model.vq_model.token_size, 108 | commitment_cost=config.model.vq_model.commitment_cost, 109 | use_l2_norm=config.model.vq_model.use_l2_norm,) 110 | elif self.quantize_mode == "vae": 111 | self.quantize = DiagonalGaussianDistribution 112 | else: 113 | raise NotImplementedError 114 | 115 | if self.finetune_decoder: 116 | # Freeze encoder/quantizer/latent tokens 117 | self.latent_tokens.requires_grad_(False) 118 | self.encoder.eval() 119 | self.encoder.requires_grad_(False) 120 | self.quantize.eval() 121 | self.quantize.requires_grad_(False) 122 | 123 | # Include MaskGiT-VQGAN's quantizer and decoder 124 | self.pixel_quantize = Pixel_Quantizer( 125 | num_embeddings=1024, embedding_dim=256, commitment_cost=0.25) 126 | self.pixel_decoder = Pixel_Decoder(OmegaConf.create( 127 | {"channel_mult": [1, 1, 2, 2, 4], 128 | "num_resolutions": 5, 129 | "dropout": 0.0, 130 | "hidden_channels": 128, 131 | "num_channels": 3, 132 | "num_res_blocks": 2, 133 | "resolution": 256, 134 | "z_channels": 256})) 135 | 136 | def _save_pretrained(self, save_directory: Path) -> None: 137 | """Save weights and config to a local directory.""" 138 | # Assume 'self.config' is your DictConfig object 139 | # Convert to a regular dictionary 140 | dict_config = OmegaConf.to_container(self.config) 141 | # Save as JSON 142 | file_path = Path(save_directory) / "config.json" 143 | with open(file_path, 'w') as json_file: 144 | json.dump(dict_config, json_file, indent=4) 145 | super()._save_pretrained(save_directory) 146 | 147 | def _init_weights(self, module): 148 | """ Initialize the weights. 
149 | :param: 150 | module -> torch.nn.Module: module to initialize 151 | """ 152 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv1d) or isinstance(module, nn.Conv2d): 153 | module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02) 154 | if module.bias is not None: 155 | module.bias.data.zero_() 156 | elif isinstance(module, nn.Embedding): 157 | module.weight.data = nn.init.trunc_normal_(module.weight.data, mean=0.0, std=0.02) 158 | elif isinstance(module, nn.LayerNorm): 159 | module.bias.data.zero_() 160 | module.weight.data.fill_(1.0) 161 | 162 | def encode(self, x): 163 | if self.finetune_decoder: 164 | with torch.no_grad(): 165 | self.encoder.eval() 166 | self.quantize.eval() 167 | z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) 168 | z_quantized, result_dict = self.quantize(z) 169 | result_dict["quantizer_loss"] *= 0 170 | result_dict["commitment_loss"] *= 0 171 | result_dict["codebook_loss"] *= 0 172 | else: 173 | z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) 174 | if self.quantize_mode == "vq": 175 | z_quantized, result_dict = self.quantize(z) 176 | elif self.quantize_mode == "vae": 177 | posteriors = self.quantize(z) 178 | z_quantized = posteriors.sample() 179 | result_dict = posteriors 180 | 181 | return z_quantized, result_dict["min_encoding_indices"], result_dict 182 | 183 | def _decode(self, z_quantized): 184 | decoded = self.decoder(z_quantized) 185 | if self.finetune_decoder: 186 | quantized_states = torch.einsum( 187 | 'nchw,cd->ndhw', decoded.softmax(1), 188 | self.pixel_quantize.embedding.weight) 189 | decoded = self.pixel_decoder(quantized_states) 190 | return decoded 191 | 192 | def decode(self, tokens, latent_shape=None): 193 | if self.quantize_mode == "vq": 194 | tokens = tokens.squeeze(1) 195 | batch, seq_len = tokens.shape # B x N 196 | z_quantized = self.quantize.get_codebook_entry( 197 | tokens.reshape(-1)).reshape(batch, 1, seq_len, -1) 198 | z_quantized = rearrange(z_quantized, 'b h w c -> b c h w').contiguous() 199 | elif self.quantize_mode == "vae": 200 | z_quantized = tokens 201 | 202 | decoded = self._decode(z_quantized) 203 | return decoded 204 | 205 | def forward(self, x): 206 | z_quantized, result_dict = self.encode(x) 207 | decoded = self.decode(z_quantized) 208 | return decoded, result_dict 209 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_41616_262144.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: open_tokenizer.model.vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_fsq_causal_41616_262144.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: open_tokenizer.model.vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: false 15 | z_channels: 6 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4, 4] 20 | time_downsample_factor: 4 21 | num_res_blocks: 2 22 | dropout: 0.0 23 | use_checkpoint: false 24 | init_pad_mode: replicate 25 | norm_type: layernorm # layernorm, groupnorm 26 | fix_encoder: false # if True, fix it without updating params 27 | fix_decoder: false # if True, fix it without updating params 28 | 29 | decoder_config: 30 | target: 
open_tokenizer.model.vidtok.modules.model_3dcausal.DecoderCausal3DPadding 31 | params: ${model.params.encoder_config.params} 32 | 33 | regularizer_config: 34 | target: open_tokenizer.model.vidtok.modules.regularizers.FSQRegularizer 35 | params: 36 | levels: [8, 8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8*8=262144 37 | entropy_loss_weight: 0.1 38 | entropy_loss_annealing_steps: 2000 39 | entropy_loss_annealing_factor: 3 40 | commitment_loss_weight: 0.25 41 | 42 | loss_config: 43 | target: open_tokenizer.model.vidtok.modules.losses.GeneralLPIPSWithDiscriminator 44 | params: 45 | dims: 3 # video - [t,h,w] 46 | perceptual_weight: 1.0 47 | disc_start: 20001 48 | disc_weight: 0.2 49 | disc_type: 2d # 2d, 3d 50 | learn_logvar: true 51 | gen_loss_cross_entropy: true 52 | lecam_loss_weight: 0.005 53 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 54 | 55 | data: 56 | target: open_tokenizer.model.vidtok.data.datamodule.DataModuleFromConfig 57 | params: 58 | batch_size: 2 59 | num_workers: 12 60 | 61 | train: 62 | target: open_tokenizer.model.vidtok.data.vidtok.VidTokDataset 63 | params: 64 | data_dir: DATA_DIR_1 # DATA_DIR for training data 65 | meta_path: META_PATH_1 # path to the .csv meta file of training data 66 | video_params: 67 | input_height: INPUT_HEIGHT_1 68 | input_width: INPUT_WIDTH_1 69 | sample_num_frames: 17 70 | sample_fps: 3 71 | 72 | validation: 73 | target: open_tokenizer.model.vidtok.data.vidtok.VidTokDataset 74 | params: 75 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 76 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 77 | video_params: 78 | input_height: INPUT_HEIGHT_2 79 | input_width: INPUT_WIDTH_2 80 | sample_num_frames: 17 81 | sample_fps: 8 82 | start_index: 0 83 | 84 | lightning: 85 | strategy: 86 | target: lightning.pytorch.strategies.DDPStrategy 87 | params: 88 | find_unused_parameters: true 89 | 90 | modelcheckpoint: 91 | params: 92 | every_n_train_steps: 5000 93 | 94 | callbacks: 95 | image_logger: 96 | target: open_tokenizer.model.vidtok.modules.logger.ImageVideoLogger 97 | params: 98 | disabled: false 99 | rescale: true 100 | enable_autocast: false 101 | batch_frequency: 5000 102 | max_samples: 2 103 | increase_log_steps: false 104 | log_first_step: false 105 | log_before_first_step: false 106 | log_images_kwargs: 107 | n_rows: 17 108 | 109 | trainer: 110 | precision: bf16-mixed 111 | devices: auto 112 | num_nodes: 1 113 | benchmark: true 114 | num_sanity_val_steps: 10 115 | val_check_interval: 2000 116 | check_val_every_n_epoch: null # default: 1 117 | accumulate_grad_batches: 1 118 | max_epochs: 1000 119 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_488_262144.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_fsq_causal_488_262144.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: false 15 | z_channels: 6 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | time_downsample_factor: 4 21 | num_res_blocks: 2 22 | dropout: 0.0 23 | use_checkpoint: false 24 | init_pad_mode: replicate 25 | norm_type: layernorm # layernorm, groupnorm 26 | 
fix_encoder: false # if True, fix it without updating params 27 | fix_decoder: false # if True, fix it without updating params 28 | 29 | decoder_config: 30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding 31 | params: ${model.params.encoder_config.params} 32 | 33 | regularizer_config: 34 | target: vidtok.modules.regularizers.FSQRegularizer 35 | params: 36 | levels: [8, 8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8*8=262144 37 | entropy_loss_weight: 0.1 38 | entropy_loss_annealing_steps: 2000 39 | entropy_loss_annealing_factor: 3 40 | commitment_loss_weight: 0.25 41 | 42 | loss_config: 43 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 44 | params: 45 | dims: 3 # video - [t,h,w] 46 | perceptual_weight: 1.0 47 | disc_start: 20001 48 | disc_weight: 0.2 49 | disc_type: 2d # 2d, 3d 50 | learn_logvar: true 51 | gen_loss_cross_entropy: true 52 | lecam_loss_weight: 0.005 53 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 54 | 55 | data: 56 | target: vidtok.data.datamodule.DataModuleFromConfig 57 | params: 58 | batch_size: 2 59 | num_workers: 12 60 | 61 | train: 62 | target: vidtok.data.vidtok.VidTokDataset 63 | params: 64 | data_dir: DATA_DIR_1 # DATA_DIR for training data 65 | meta_path: META_PATH_1 # path to the .csv meta file of training data 66 | video_params: 67 | input_height: INPUT_HEIGHT_1 68 | input_width: INPUT_WIDTH_1 69 | sample_num_frames: 17 70 | sample_fps: 3 71 | 72 | validation: 73 | target: vidtok.data.vidtok.VidTokDataset 74 | params: 75 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 76 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 77 | video_params: 78 | input_height: INPUT_HEIGHT_2 79 | input_width: INPUT_WIDTH_2 80 | sample_num_frames: 17 81 | sample_fps: 8 82 | start_index: 0 83 | 84 | lightning: 85 | strategy: 86 | target: lightning.pytorch.strategies.DDPStrategy 87 | params: 88 | find_unused_parameters: true 89 | 90 | modelcheckpoint: 91 | params: 92 | every_n_train_steps: 5000 93 | 94 | callbacks: 95 | image_logger: 96 | target: vidtok.modules.logger.ImageVideoLogger 97 | params: 98 | disabled: false 99 | rescale: true 100 | enable_autocast: false 101 | batch_frequency: 5000 102 | max_samples: 2 103 | increase_log_steps: false 104 | log_first_step: false 105 | log_before_first_step: false 106 | log_images_kwargs: 107 | n_rows: 17 108 | 109 | trainer: 110 | precision: bf16-mixed 111 | devices: auto 112 | num_nodes: 1 113 | benchmark: true 114 | num_sanity_val_steps: 10 115 | val_check_interval: 2000 116 | check_val_every_n_epoch: null # default: 1 117 | accumulate_grad_batches: 1 118 | max_epochs: 1000 119 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_488_32768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: open_tokenizer.model.vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_fsq_causal_488_32768.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: open_tokenizer.model.vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: false 15 | z_channels: 5 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | time_downsample_factor: 4 21 | num_res_blocks: 2 22 | dropout: 0.0 23 | use_checkpoint: false 24 | init_pad_mode: 
replicate 25 | norm_type: layernorm # layernorm, groupnorm 26 | fix_encoder: false # if True, fix it without updating params 27 | fix_decoder: false # if True, fix it without updating params 28 | 29 | decoder_config: 30 | target: open_tokenizer.model.vidtok.modules.model_3dcausal.DecoderCausal3DPadding 31 | params: ${model.params.encoder_config.params} 32 | 33 | regularizer_config: 34 | target: open_tokenizer.model.vidtok.modules.regularizers.FSQRegularizer 35 | params: 36 | levels: [8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8=32768 37 | entropy_loss_weight: 0.1 38 | entropy_loss_annealing_steps: 2000 39 | entropy_loss_annealing_factor: 3 40 | commitment_loss_weight: 0.25 41 | 42 | loss_config: 43 | target: open_tokenizer.model.vidtok.modules.losses.GeneralLPIPSWithDiscriminator 44 | params: 45 | dims: 3 # video - [t,h,w] 46 | perceptual_weight: 1.0 47 | disc_start: 20001 48 | disc_weight: 0.2 49 | disc_type: 2d # 2d, 3d 50 | learn_logvar: true 51 | gen_loss_cross_entropy: true 52 | lecam_loss_weight: 0.005 53 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 54 | 55 | data: 56 | target: open_tokenizer.model.vidtok.data.datamodule.DataModuleFromConfig 57 | params: 58 | batch_size: 2 59 | num_workers: 12 60 | 61 | train: 62 | target: open_tokenizer.model.vidtok.data.vidtok.VidTokDataset 63 | params: 64 | data_dir: DATA_DIR_1 # DATA_DIR for training data 65 | meta_path: META_PATH_1 # path to the .csv meta file of training data 66 | video_params: 67 | input_height: INPUT_HEIGHT_1 68 | input_width: INPUT_WIDTH_1 69 | sample_num_frames: 17 70 | sample_fps: 3 71 | 72 | validation: 73 | target: open_tokenizer.model.vidtok.data.vidtok.VidTokDataset 74 | params: 75 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 76 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 77 | video_params: 78 | input_height: INPUT_HEIGHT_2 79 | input_width: INPUT_WIDTH_2 80 | sample_num_frames: 17 81 | sample_fps: 8 82 | start_index: 0 83 | 84 | lightning: 85 | strategy: 86 | target: lightning.pytorch.strategies.DDPStrategy 87 | params: 88 | find_unused_parameters: true 89 | 90 | modelcheckpoint: 91 | params: 92 | every_n_train_steps: 5000 93 | 94 | callbacks: 95 | image_logger: 96 | target: open_tokenizer.model.vidtok.modules.logger.ImageVideoLogger 97 | params: 98 | disabled: false 99 | rescale: true 100 | enable_autocast: false 101 | batch_frequency: 5000 102 | max_samples: 2 103 | increase_log_steps: false 104 | log_first_step: false 105 | log_before_first_step: false 106 | log_images_kwargs: 107 | n_rows: 17 108 | 109 | trainer: 110 | precision: bf16-mixed 111 | devices: auto 112 | num_nodes: 1 113 | benchmark: true 114 | num_sanity_val_steps: 10 115 | val_check_interval: 2000 116 | check_val_every_n_epoch: null # default: 1 117 | accumulate_grad_batches: 1 118 | max_epochs: 1000 119 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_fsq_causal_488_4096.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_fsq_causal_488_4096.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: false 15 | z_channels: 4 16 | in_channels: 3 17 | out_ch: 3 18 | 
ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | time_downsample_factor: 4 21 | num_res_blocks: 2 22 | dropout: 0.0 23 | use_checkpoint: false 24 | init_pad_mode: replicate 25 | norm_type: layernorm # layernorm, groupnorm 26 | fix_encoder: false # if True, fix it without updating params 27 | fix_decoder: false # if True, fix it without updating params 28 | 29 | decoder_config: 30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding 31 | params: ${model.params.encoder_config.params} 32 | 33 | regularizer_config: 34 | target: vidtok.modules.regularizers.FSQRegularizer 35 | params: 36 | levels: [8, 8, 8, 8] # codebook size: 8*8*8*8=4096 37 | entropy_loss_weight: 0.1 38 | entropy_loss_annealing_steps: 2000 39 | entropy_loss_annealing_factor: 3 40 | commitment_loss_weight: 0.25 41 | 42 | loss_config: 43 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 44 | params: 45 | dims: 3 # video - [t,h,w] 46 | perceptual_weight: 1.0 47 | disc_start: 20001 48 | disc_weight: 0.2 49 | disc_type: 2d # 2d, 3d 50 | learn_logvar: true 51 | gen_loss_cross_entropy: true 52 | lecam_loss_weight: 0.005 53 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 54 | 55 | data: 56 | target: vidtok.data.datamodule.DataModuleFromConfig 57 | params: 58 | batch_size: 2 59 | num_workers: 12 60 | 61 | train: 62 | target: vidtok.data.vidtok.VidTokDataset 63 | params: 64 | data_dir: DATA_DIR_1 # DATA_DIR for training data 65 | meta_path: META_PATH_1 # path to the .csv meta file of training data 66 | video_params: 67 | input_height: INPUT_HEIGHT_1 68 | input_width: INPUT_WIDTH_1 69 | sample_num_frames: 17 70 | sample_fps: 3 71 | 72 | validation: 73 | target: vidtok.data.vidtok.VidTokDataset 74 | params: 75 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 76 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 77 | video_params: 78 | input_height: INPUT_HEIGHT_2 79 | input_width: INPUT_WIDTH_2 80 | sample_num_frames: 17 81 | sample_fps: 8 82 | start_index: 0 83 | 84 | lightning: 85 | strategy: 86 | target: lightning.pytorch.strategies.DDPStrategy 87 | params: 88 | find_unused_parameters: true 89 | 90 | modelcheckpoint: 91 | params: 92 | every_n_train_steps: 5000 93 | 94 | callbacks: 95 | image_logger: 96 | target: vidtok.modules.logger.ImageVideoLogger 97 | params: 98 | disabled: false 99 | rescale: true 100 | enable_autocast: false 101 | batch_frequency: 5000 102 | max_samples: 2 103 | increase_log_steps: false 104 | log_first_step: false 105 | log_before_first_step: false 106 | log_images_kwargs: 107 | n_rows: 17 108 | 109 | trainer: 110 | precision: bf16-mixed 111 | devices: auto 112 | num_nodes: 1 113 | benchmark: true 114 | num_sanity_val_steps: 10 115 | val_check_interval: 2000 116 | check_val_every_n_epoch: null # default: 1 117 | accumulate_grad_batches: 1 118 | max_epochs: 1000 119 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_fsq_noncausal_41616_262144.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_fsq_noncausal_41616_262144.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dnoncausal.Encoder3D 13 | params: 14 | double_z: false 15 | z_channels: 6 16 | in_channels: 3 17 | out_ch: 3 
18 | ch: 128 19 | ch_mult: [1, 2, 4, 4, 4] 20 | num_res_blocks: 2 21 | dropout: 0.0 22 | use_checkpoint: false 23 | norm_type: layernorm # layernorm, groupnorm 24 | fix_encoder: false 25 | fix_decoder: false 26 | 27 | decoder_config: 28 | target: vidtok.modules.model_3dnoncausal.Decoder3D 29 | params: ${model.params.encoder_config.params} 30 | 31 | regularizer_config: 32 | target: vidtok.modules.regularizers.FSQRegularizer 33 | params: 34 | levels: [8, 8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8*8=262144 35 | entropy_loss_weight: 0.1 36 | entropy_loss_annealing_steps: 2000 37 | entropy_loss_annealing_factor: 3 38 | commitment_loss_weight: 0.25 39 | 40 | loss_config: 41 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 42 | params: 43 | dims: 3 # video - [t,h,w] 44 | perceptual_weight: 1.0 45 | disc_start: 20001 46 | disc_weight: 0.2 47 | disc_type: 2d # 2d, 3d 48 | learn_logvar: true 49 | gen_loss_cross_entropy: true 50 | lecam_loss_weight: 0.005 51 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 52 | 53 | data: 54 | target: vidtok.data.datamodule.DataModuleFromConfig 55 | params: 56 | batch_size: 2 57 | num_workers: 12 58 | 59 | train: 60 | target: vidtok.data.vidtok.VidTokDataset 61 | params: 62 | data_dir: DATA_DIR_1 # DATA_DIR for training data 63 | meta_path: META_PATH_1 # path to the .csv meta file of training data 64 | video_params: 65 | input_height: INPUT_HEIGHT_1 66 | input_width: INPUT_WIDTH_1 67 | sample_num_frames: 16 68 | sample_fps: 3 69 | 70 | validation: 71 | target: vidtok.data.vidtok.VidTokDataset 72 | params: 73 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 74 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 75 | video_params: 76 | input_height: INPUT_HEIGHT_2 77 | input_width: INPUT_WIDTH_2 78 | sample_num_frames: 16 79 | sample_fps: 8 80 | start_index: 0 81 | 82 | lightning: 83 | strategy: 84 | target: lightning.pytorch.strategies.DDPStrategy 85 | params: 86 | find_unused_parameters: true 87 | 88 | modelcheckpoint: 89 | params: 90 | every_n_train_steps: 5000 91 | 92 | callbacks: 93 | image_logger: 94 | target: vidtok.modules.logger.ImageVideoLogger 95 | params: 96 | disabled: false 97 | rescale: true 98 | enable_autocast: false 99 | batch_frequency: 5000 100 | max_samples: 2 101 | increase_log_steps: false 102 | log_first_step: false 103 | log_before_first_step: false 104 | log_images_kwargs: 105 | n_rows: 16 106 | 107 | trainer: 108 | precision: bf16-mixed 109 | devices: auto 110 | num_nodes: 1 111 | benchmark: true 112 | num_sanity_val_steps: 10 113 | val_check_interval: 2000 114 | check_val_every_n_epoch: null # default: 1 115 | accumulate_grad_batches: 1 116 | max_epochs: 1000 117 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_fsq_noncausal_488_262144.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_fsq_noncausal_488_262144.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dnoncausal.Encoder3D 13 | params: 14 | double_z: false 15 | z_channels: 6 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | num_res_blocks: 2 21 | dropout: 0.0 22 | use_checkpoint: false 23 | norm_type: layernorm # 
layernorm, groupnorm 24 | fix_encoder: false 25 | fix_decoder: false 26 | 27 | decoder_config: 28 | target: vidtok.modules.model_3dnoncausal.Decoder3D 29 | params: ${model.params.encoder_config.params} 30 | 31 | regularizer_config: 32 | target: vidtok.modules.regularizers.FSQRegularizer 33 | params: 34 | levels: [8, 8, 8, 8, 8, 8] # codebook size: 8*8*8*8*8*8=262144 35 | entropy_loss_weight: 0.1 36 | entropy_loss_annealing_steps: 2000 37 | entropy_loss_annealing_factor: 3 38 | commitment_loss_weight: 0.25 39 | 40 | loss_config: 41 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 42 | params: 43 | dims: 3 # video - [t,h,w] 44 | perceptual_weight: 1.0 45 | disc_start: 20001 46 | disc_weight: 0.2 47 | disc_type: 2d # 2d, 3d 48 | learn_logvar: true 49 | gen_loss_cross_entropy: true 50 | lecam_loss_weight: 0.005 51 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 52 | 53 | data: 54 | target: vidtok.data.datamodule.DataModuleFromConfig 55 | params: 56 | batch_size: 2 57 | num_workers: 12 58 | 59 | train: 60 | target: vidtok.data.vidtok.VidTokDataset 61 | params: 62 | data_dir: DATA_DIR_1 # DATA_DIR for training data 63 | meta_path: META_PATH_1 # path to the .csv meta file of training data 64 | video_params: 65 | input_height: INPUT_HEIGHT_1 66 | input_width: INPUT_WIDTH_1 67 | sample_num_frames: 16 68 | sample_fps: 3 69 | 70 | validation: 71 | target: vidtok.data.vidtok.VidTokDataset 72 | params: 73 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 74 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 75 | video_params: 76 | input_height: INPUT_HEIGHT_2 77 | input_width: INPUT_WIDTH_2 78 | sample_num_frames: 16 79 | sample_fps: 8 80 | start_index: 0 81 | 82 | lightning: 83 | strategy: 84 | target: lightning.pytorch.strategies.DDPStrategy 85 | params: 86 | find_unused_parameters: true 87 | 88 | modelcheckpoint: 89 | params: 90 | every_n_train_steps: 5000 91 | 92 | callbacks: 93 | image_logger: 94 | target: vidtok.modules.logger.ImageVideoLogger 95 | params: 96 | disabled: false 97 | rescale: true 98 | enable_autocast: false 99 | batch_frequency: 5000 100 | max_samples: 2 101 | increase_log_steps: false 102 | log_first_step: false 103 | log_before_first_step: false 104 | log_images_kwargs: 105 | n_rows: 16 106 | 107 | trainer: 108 | precision: bf16-mixed 109 | devices: auto 110 | num_nodes: 1 111 | benchmark: true 112 | num_sanity_val_steps: 10 113 | val_check_interval: 2000 114 | check_val_every_n_epoch: null # default: 1 115 | accumulate_grad_batches: 1 116 | max_epochs: 1000 117 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_kl_causal_288_8chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_kl_causal_288_8chn.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: true 15 | z_channels: 8 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | tempo_ds: [1] 21 | tempo_us: [2] 22 | time_downsample_factor: 2 23 | num_res_blocks: 2 24 | dropout: 0.0 25 | use_checkpoint: false 26 | init_pad_mode: replicate 27 | norm_type: layernorm # layernorm, groupnorm 28 | fix_encoder: false # 
if True, fix it without updating params 29 | fix_decoder: false # if True, fix it without updating params 30 | 31 | decoder_config: 32 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding 33 | params: ${model.params.encoder_config.params} 34 | 35 | regularizer_config: 36 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 37 | 38 | loss_config: 39 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 40 | params: 41 | dims: 3 # video - [t,h,w] 42 | perceptual_weight: 1.0 43 | disc_start: 20001 44 | disc_weight: 0.2 45 | disc_type: 2d # 2d, 3d 46 | learn_logvar: true 47 | gen_loss_cross_entropy: true 48 | lecam_loss_weight: 0.005 49 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 50 | 51 | data: 52 | target: vidtok.data.datamodule.DataModuleFromConfig 53 | params: 54 | batch_size: 2 55 | num_workers: 12 56 | 57 | train: 58 | target: vidtok.data.vidtok.VidTokDataset 59 | params: 60 | data_dir: DATA_DIR_1 # DATA_DIR for training data 61 | meta_path: META_PATH_1 # path to the .csv meta file of training data 62 | video_params: 63 | input_height: INPUT_HEIGHT_1 64 | input_width: INPUT_WIDTH_1 65 | sample_num_frames: 17 66 | sample_fps: 3 67 | 68 | validation: 69 | target: vidtok.data.vidtok.VidTokDataset 70 | params: 71 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 72 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 73 | video_params: 74 | input_height: INPUT_HEIGHT_2 75 | input_width: INPUT_WIDTH_2 76 | sample_num_frames: 17 77 | sample_fps: 8 78 | start_index: 0 79 | 80 | lightning: 81 | strategy: 82 | target: lightning.pytorch.strategies.DDPStrategy 83 | params: 84 | find_unused_parameters: true 85 | 86 | modelcheckpoint: 87 | params: 88 | every_n_train_steps: 5000 89 | 90 | callbacks: 91 | image_logger: 92 | target: vidtok.modules.logger.ImageVideoLogger 93 | params: 94 | disabled: false 95 | rescale: true 96 | enable_autocast: false 97 | batch_frequency: 5000 98 | max_samples: 2 99 | increase_log_steps: false 100 | log_first_step: false 101 | log_before_first_step: false 102 | log_images_kwargs: 103 | n_rows: 17 104 | 105 | trainer: 106 | precision: bf16-mixed 107 | devices: auto 108 | num_nodes: 1 109 | benchmark: true 110 | num_sanity_val_steps: 10 111 | val_check_interval: 2000 112 | check_val_every_n_epoch: null # default: 1 113 | accumulate_grad_batches: 1 114 | max_epochs: 1000 115 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_kl_causal_41616_4chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_kl_causal_41616_4chn.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: true 15 | z_channels: 4 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4, 4] 20 | time_downsample_factor: 4 21 | num_res_blocks: 2 22 | dropout: 0.0 23 | use_checkpoint: false 24 | init_pad_mode: replicate 25 | norm_type: layernorm # layernorm, groupnorm 26 | fix_encoder: false # if True, fix it without updating params 27 | fix_decoder: false # if True, fix it without updating params 28 | 29 | decoder_config: 30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding 
31 | params: ${model.params.encoder_config.params} 32 | 33 | regularizer_config: 34 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 35 | 36 | loss_config: 37 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 38 | params: 39 | dims: 3 # video - [t,h,w] 40 | perceptual_weight: 1.0 41 | disc_start: 20001 42 | disc_weight: 0.2 43 | disc_type: 2d # 2d, 3d 44 | learn_logvar: true 45 | gen_loss_cross_entropy: true 46 | lecam_loss_weight: 0.005 47 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 48 | 49 | data: 50 | target: vidtok.data.datamodule.DataModuleFromConfig 51 | params: 52 | batch_size: 2 53 | num_workers: 12 54 | 55 | train: 56 | target: vidtok.data.vidtok.VidTokDataset 57 | params: 58 | data_dir: DATA_DIR_1 # DATA_DIR for training data 59 | meta_path: META_PATH_1 # path to the .csv meta file of training data 60 | video_params: 61 | input_height: INPUT_HEIGHT_1 62 | input_width: INPUT_WIDTH_1 63 | sample_num_frames: 17 64 | sample_fps: 3 65 | 66 | validation: 67 | target: vidtok.data.vidtok.VidTokDataset 68 | params: 69 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 70 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 71 | video_params: 72 | input_height: INPUT_HEIGHT_2 73 | input_width: INPUT_WIDTH_2 74 | sample_num_frames: 17 75 | sample_fps: 8 76 | start_index: 0 77 | 78 | lightning: 79 | strategy: 80 | target: lightning.pytorch.strategies.DDPStrategy 81 | params: 82 | find_unused_parameters: true 83 | 84 | modelcheckpoint: 85 | params: 86 | every_n_train_steps: 5000 87 | 88 | callbacks: 89 | image_logger: 90 | target: vidtok.modules.logger.ImageVideoLogger 91 | params: 92 | disabled: false 93 | rescale: true 94 | enable_autocast: false 95 | batch_frequency: 5000 96 | max_samples: 2 97 | increase_log_steps: false 98 | log_first_step: false 99 | log_before_first_step: false 100 | log_images_kwargs: 101 | n_rows: 17 102 | 103 | trainer: 104 | precision: bf16-mixed 105 | devices: auto 106 | num_nodes: 1 107 | benchmark: true 108 | num_sanity_val_steps: 10 109 | val_check_interval: 2000 110 | check_val_every_n_epoch: null # default: 1 111 | accumulate_grad_batches: 1 112 | max_epochs: 1000 113 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_kl_causal_444_4chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_kl_causal_444_4chn.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: true 15 | z_channels: 4 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | spatial_ds: [1, 2] 21 | spatial_us: [1, 2] 22 | time_downsample_factor: 4 23 | num_res_blocks: 2 24 | dropout: 0.0 25 | use_checkpoint: false 26 | init_pad_mode: replicate 27 | norm_type: layernorm # layernorm, groupnorm 28 | fix_encoder: false # if True, fix it without updating params 29 | fix_decoder: false # if True, fix it without updating params 30 | 31 | decoder_config: 32 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding 33 | params: ${model.params.encoder_config.params} 34 | 35 | regularizer_config: 36 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 37 | 38 
| loss_config: 39 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 40 | params: 41 | dims: 3 # video - [t,h,w] 42 | perceptual_weight: 1.0 43 | disc_start: 20001 44 | disc_weight: 0.2 45 | disc_type: 2d # 2d, 3d 46 | learn_logvar: true 47 | gen_loss_cross_entropy: true 48 | lecam_loss_weight: 0.005 49 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 50 | 51 | data: 52 | target: vidtok.data.datamodule.DataModuleFromConfig 53 | params: 54 | batch_size: 2 55 | num_workers: 12 56 | 57 | train: 58 | target: vidtok.data.vidtok.VidTokDataset 59 | params: 60 | data_dir: DATA_DIR_1 # DATA_DIR for training data 61 | meta_path: META_PATH_1 # path to the .csv meta file of training data 62 | video_params: 63 | input_height: INPUT_HEIGHT_1 64 | input_width: INPUT_WIDTH_1 65 | sample_num_frames: 17 66 | sample_fps: 3 67 | 68 | validation: 69 | target: vidtok.data.vidtok.VidTokDataset 70 | params: 71 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 72 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 73 | video_params: 74 | input_height: INPUT_HEIGHT_2 75 | input_width: INPUT_WIDTH_2 76 | sample_num_frames: 17 77 | sample_fps: 8 78 | start_index: 0 79 | 80 | lightning: 81 | strategy: 82 | target: lightning.pytorch.strategies.DDPStrategy 83 | params: 84 | find_unused_parameters: true 85 | 86 | modelcheckpoint: 87 | params: 88 | every_n_train_steps: 5000 89 | 90 | callbacks: 91 | image_logger: 92 | target: vidtok.modules.logger.ImageVideoLogger 93 | params: 94 | disabled: false 95 | rescale: true 96 | enable_autocast: false 97 | batch_frequency: 5000 98 | max_samples: 2 99 | increase_log_steps: false 100 | log_first_step: false 101 | log_before_first_step: false 102 | log_images_kwargs: 103 | n_rows: 17 104 | 105 | trainer: 106 | precision: bf16-mixed 107 | devices: auto 108 | num_nodes: 1 109 | benchmark: true 110 | num_sanity_val_steps: 10 111 | val_check_interval: 2000 112 | check_val_every_n_epoch: null # default: 1 113 | accumulate_grad_batches: 1 114 | max_epochs: 1000 115 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_kl_causal_488_16chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_kl_causal_488_16chn.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: true 15 | z_channels: 16 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | time_downsample_factor: 4 21 | num_res_blocks: 2 22 | dropout: 0.0 23 | use_checkpoint: false 24 | init_pad_mode: replicate 25 | norm_type: layernorm # layernorm, groupnorm 26 | fix_encoder: false # if True, fix it without updating params 27 | fix_decoder: false # if True, fix it without updating params 28 | 29 | decoder_config: 30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding 31 | params: ${model.params.encoder_config.params} 32 | 33 | regularizer_config: 34 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 35 | 36 | loss_config: 37 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 38 | params: 39 | dims: 3 # video - [t,h,w] 40 | perceptual_weight: 1.0 41 | disc_start: 20001 42 | disc_weight: 0.2 43 | 
disc_type: 2d # 2d, 3d 44 | learn_logvar: true 45 | gen_loss_cross_entropy: true 46 | lecam_loss_weight: 0.005 47 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 48 | 49 | data: 50 | target: vidtok.data.datamodule.DataModuleFromConfig 51 | params: 52 | batch_size: 2 53 | num_workers: 12 54 | 55 | train: 56 | target: vidtok.data.vidtok.VidTokDataset 57 | params: 58 | data_dir: DATA_DIR_1 # DATA_DIR for training data 59 | meta_path: META_PATH_1 # path to the .csv meta file of training data 60 | video_params: 61 | input_height: INPUT_HEIGHT_1 62 | input_width: INPUT_WIDTH_1 63 | sample_num_frames: 17 64 | sample_fps: 3 65 | 66 | validation: 67 | target: vidtok.data.vidtok.VidTokDataset 68 | params: 69 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 70 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 71 | video_params: 72 | input_height: INPUT_HEIGHT_2 73 | input_width: INPUT_WIDTH_2 74 | sample_num_frames: 17 75 | sample_fps: 8 76 | start_index: 0 77 | 78 | lightning: 79 | strategy: 80 | target: lightning.pytorch.strategies.DDPStrategy 81 | params: 82 | find_unused_parameters: true 83 | 84 | modelcheckpoint: 85 | params: 86 | every_n_train_steps: 5000 87 | 88 | callbacks: 89 | image_logger: 90 | target: vidtok.modules.logger.ImageVideoLogger 91 | params: 92 | disabled: false 93 | rescale: true 94 | enable_autocast: false 95 | batch_frequency: 5000 96 | max_samples: 2 97 | increase_log_steps: false 98 | log_first_step: false 99 | log_before_first_step: false 100 | log_images_kwargs: 101 | n_rows: 17 102 | 103 | trainer: 104 | precision: bf16-mixed 105 | devices: auto 106 | num_nodes: 1 107 | benchmark: true 108 | num_sanity_val_steps: 10 109 | val_check_interval: 2000 110 | check_val_every_n_epoch: null # default: 1 111 | accumulate_grad_batches: 1 112 | max_epochs: 1000 113 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_kl_causal_488_4chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_kl_causal_488_4chn.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: true 15 | z_channels: 4 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | time_downsample_factor: 4 21 | num_res_blocks: 2 22 | dropout: 0.0 23 | use_checkpoint: false 24 | init_pad_mode: replicate 25 | norm_type: layernorm # layernorm, groupnorm 26 | fix_encoder: false # if True, fix it without updating params 27 | fix_decoder: false # if True, fix it without updating params 28 | 29 | decoder_config: 30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding 31 | params: ${model.params.encoder_config.params} 32 | 33 | regularizer_config: 34 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 35 | 36 | loss_config: 37 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 38 | params: 39 | dims: 3 # video - [t,h,w] 40 | perceptual_weight: 1.0 41 | disc_start: 20001 42 | disc_weight: 0.2 43 | disc_type: 2d # 2d, 3d 44 | learn_logvar: true 45 | gen_loss_cross_entropy: true 46 | lecam_loss_weight: 0.005 47 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 48 | 49 | data: 50 | 
target: vidtok.data.datamodule.DataModuleFromConfig 51 | params: 52 | batch_size: 2 53 | num_workers: 12 54 | 55 | train: 56 | target: vidtok.data.vidtok.VidTokDataset 57 | params: 58 | data_dir: DATA_DIR_1 # DATA_DIR for training data 59 | meta_path: META_PATH_1 # path to the .csv meta file of training data 60 | video_params: 61 | input_height: INPUT_HEIGHT_1 62 | input_width: INPUT_WIDTH_1 63 | sample_num_frames: 17 64 | sample_fps: 3 65 | 66 | validation: 67 | target: vidtok.data.vidtok.VidTokDataset 68 | params: 69 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 70 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 71 | video_params: 72 | input_height: INPUT_HEIGHT_2 73 | input_width: INPUT_WIDTH_2 74 | sample_num_frames: 17 75 | sample_fps: 8 76 | start_index: 0 77 | 78 | lightning: 79 | strategy: 80 | target: lightning.pytorch.strategies.DDPStrategy 81 | params: 82 | find_unused_parameters: true 83 | 84 | modelcheckpoint: 85 | params: 86 | every_n_train_steps: 5000 87 | 88 | callbacks: 89 | image_logger: 90 | target: vidtok.modules.logger.ImageVideoLogger 91 | params: 92 | disabled: false 93 | rescale: true 94 | enable_autocast: false 95 | batch_frequency: 5000 96 | max_samples: 2 97 | increase_log_steps: false 98 | log_first_step: false 99 | log_before_first_step: false 100 | log_images_kwargs: 101 | n_rows: 17 102 | 103 | trainer: 104 | precision: bf16-mixed 105 | devices: auto 106 | num_nodes: 1 107 | benchmark: true 108 | num_sanity_val_steps: 10 109 | val_check_interval: 2000 110 | check_val_every_n_epoch: null # default: 1 111 | accumulate_grad_batches: 1 112 | max_epochs: 1000 113 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_kl_causal_488_8chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_kl_causal_488_8chn.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding 13 | params: 14 | double_z: true 15 | z_channels: 8 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | time_downsample_factor: 4 21 | num_res_blocks: 2 22 | dropout: 0.0 23 | use_checkpoint: false 24 | init_pad_mode: replicate 25 | norm_type: layernorm # layernorm, groupnorm 26 | fix_encoder: false # if True, fix it without updating params 27 | fix_decoder: false # if True, fix it without updating params 28 | 29 | decoder_config: 30 | target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding 31 | params: ${model.params.encoder_config.params} 32 | 33 | regularizer_config: 34 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 35 | 36 | loss_config: 37 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 38 | params: 39 | dims: 3 # video - [t,h,w] 40 | perceptual_weight: 1.0 41 | disc_start: 20001 42 | disc_weight: 0.2 43 | disc_type: 2d # 2d, 3d 44 | learn_logvar: true 45 | gen_loss_cross_entropy: true 46 | lecam_loss_weight: 0.005 47 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 48 | 49 | data: 50 | target: vidtok.data.datamodule.DataModuleFromConfig 51 | params: 52 | batch_size: 2 53 | num_workers: 12 54 | 55 | train: 56 | target: vidtok.data.vidtok.VidTokDataset 57 | params: 58 | data_dir: DATA_DIR_1 
# DATA_DIR for training data 59 | meta_path: META_PATH_1 # path to the .csv meta file of training data 60 | video_params: 61 | input_height: INPUT_HEIGHT_1 62 | input_width: INPUT_WIDTH_1 63 | sample_num_frames: 17 64 | sample_fps: 3 65 | 66 | validation: 67 | target: vidtok.data.vidtok.VidTokDataset 68 | params: 69 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 70 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 71 | video_params: 72 | input_height: INPUT_HEIGHT_2 73 | input_width: INPUT_WIDTH_2 74 | sample_num_frames: 17 75 | sample_fps: 8 76 | start_index: 0 77 | 78 | lightning: 79 | strategy: 80 | target: lightning.pytorch.strategies.DDPStrategy 81 | params: 82 | find_unused_parameters: true 83 | 84 | modelcheckpoint: 85 | params: 86 | every_n_train_steps: 5000 87 | 88 | callbacks: 89 | image_logger: 90 | target: vidtok.modules.logger.ImageVideoLogger 91 | params: 92 | disabled: false 93 | rescale: true 94 | enable_autocast: false 95 | batch_frequency: 5000 96 | max_samples: 2 97 | increase_log_steps: false 98 | log_first_step: false 99 | log_before_first_step: false 100 | log_images_kwargs: 101 | n_rows: 17 102 | 103 | trainer: 104 | precision: bf16-mixed 105 | devices: auto 106 | num_nodes: 1 107 | benchmark: true 108 | num_sanity_val_steps: 10 109 | val_check_interval: 2000 110 | check_val_every_n_epoch: null # default: 1 111 | accumulate_grad_batches: 1 112 | max_epochs: 1000 113 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_kl_noncausal_41616_4chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_kl_noncausal_41616_4chn.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dnoncausal.Encoder3D 13 | params: 14 | double_z: true 15 | z_channels: 4 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4, 4] 20 | num_res_blocks: 2 21 | dropout: 0.0 22 | use_checkpoint: false 23 | norm_type: layernorm # layernorm, groupnorm 24 | fix_encoder: false 25 | fix_decoder: false 26 | 27 | decoder_config: 28 | target: vidtok.modules.model_3dnoncausal.Decoder3D 29 | params: ${model.params.encoder_config.params} 30 | 31 | regularizer_config: 32 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 33 | 34 | loss_config: 35 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 36 | params: 37 | dims: 3 # video - [t,h,w] 38 | perceptual_weight: 1.0 39 | disc_start: 20001 40 | disc_weight: 0.2 41 | disc_type: 2d # 2d, 3d 42 | learn_logvar: true 43 | gen_loss_cross_entropy: true 44 | lecam_loss_weight: 0.005 45 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 46 | 47 | data: 48 | target: vidtok.data.datamodule.DataModuleFromConfig 49 | params: 50 | batch_size: 2 51 | num_workers: 12 52 | 53 | train: 54 | target: vidtok.data.vidtok.VidTokDataset 55 | params: 56 | data_dir: DATA_DIR_1 # DATA_DIR for training data 57 | meta_path: META_PATH_1 # path to the .csv meta file of training data 58 | video_params: 59 | input_height: INPUT_HEIGHT_1 60 | input_width: INPUT_WIDTH_1 61 | sample_num_frames: 16 62 | sample_fps: 3 63 | 64 | validation: 65 | target: vidtok.data.vidtok.VidTokDataset 66 | params: 67 | data_dir: DATA_DIR_2 # DATA_DIR for 
validation data 68 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 69 | video_params: 70 | input_height: INPUT_HEIGHT_2 71 | input_width: INPUT_WIDTH_2 72 | sample_num_frames: 16 73 | sample_fps: 8 74 | start_index: 0 75 | 76 | lightning: 77 | strategy: 78 | target: lightning.pytorch.strategies.DDPStrategy 79 | params: 80 | find_unused_parameters: true 81 | 82 | modelcheckpoint: 83 | params: 84 | every_n_train_steps: 5000 85 | 86 | callbacks: 87 | image_logger: 88 | target: vidtok.modules.logger.ImageVideoLogger 89 | params: 90 | disabled: false 91 | rescale: true 92 | enable_autocast: false 93 | batch_frequency: 5000 94 | max_samples: 2 95 | increase_log_steps: false 96 | log_first_step: false 97 | log_before_first_step: false 98 | log_images_kwargs: 99 | n_rows: 16 100 | 101 | trainer: 102 | precision: bf16-mixed 103 | devices: auto 104 | num_nodes: 1 105 | benchmark: true 106 | num_sanity_val_steps: 10 107 | val_check_interval: 2000 108 | check_val_every_n_epoch: null # default: 1 109 | accumulate_grad_batches: 1 110 | max_epochs: 1000 111 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtok_kl_noncausal_488_4chn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1e-5 3 | target: vidtok.models.autoencoder.AutoencodingEngine 4 | params: 5 | monitor: val/rec_loss 6 | mode: min 7 | # ckpt_path: checkpoints/vidtok_kl_noncausal_488_4chn.ckpt # train from existing checkpoint 8 | ignore_keys: [] 9 | # ema_decay: 0.999 10 | 11 | encoder_config: 12 | target: vidtok.modules.model_3dnoncausal.Encoder3D 13 | params: 14 | double_z: true 15 | z_channels: 4 16 | in_channels: 3 17 | out_ch: 3 18 | ch: 128 19 | ch_mult: [1, 2, 4, 4] 20 | num_res_blocks: 2 21 | dropout: 0.0 22 | use_checkpoint: false 23 | norm_type: layernorm # layernorm, groupnorm 24 | fix_encoder: false 25 | fix_decoder: false 26 | 27 | decoder_config: 28 | target: vidtok.modules.model_3dnoncausal.Decoder3D 29 | params: ${model.params.encoder_config.params} 30 | 31 | regularizer_config: 32 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 33 | 34 | loss_config: 35 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 36 | params: 37 | dims: 3 # video - [t,h,w] 38 | perceptual_weight: 1.0 39 | disc_start: 20001 40 | disc_weight: 0.2 41 | disc_type: 2d # 2d, 3d 42 | learn_logvar: true 43 | gen_loss_cross_entropy: true 44 | lecam_loss_weight: 0.005 45 | regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001} 46 | 47 | data: 48 | target: vidtok.data.datamodule.DataModuleFromConfig 49 | params: 50 | batch_size: 2 51 | num_workers: 12 52 | 53 | train: 54 | target: vidtok.data.vidtok.VidTokDataset 55 | params: 56 | data_dir: DATA_DIR_1 # DATA_DIR for training data 57 | meta_path: META_PATH_1 # path to the .csv meta file of training data 58 | video_params: 59 | input_height: INPUT_HEIGHT_1 60 | input_width: INPUT_WIDTH_1 61 | sample_num_frames: 16 62 | sample_fps: 3 63 | 64 | validation: 65 | target: vidtok.data.vidtok.VidTokDataset 66 | params: 67 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 68 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 69 | video_params: 70 | input_height: INPUT_HEIGHT_2 71 | input_width: INPUT_WIDTH_2 72 | sample_num_frames: 16 73 | sample_fps: 8 74 | start_index: 0 75 | 76 | lightning: 77 | strategy: 78 | target: lightning.pytorch.strategies.DDPStrategy 79 | params: 80 | 
find_unused_parameters: true 81 | 82 | modelcheckpoint: 83 | params: 84 | every_n_train_steps: 5000 85 | 86 | callbacks: 87 | image_logger: 88 | target: vidtok.modules.logger.ImageVideoLogger 89 | params: 90 | disabled: false 91 | rescale: true 92 | enable_autocast: false 93 | batch_frequency: 5000 94 | max_samples: 2 95 | increase_log_steps: false 96 | log_first_step: false 97 | log_before_first_step: false 98 | log_images_kwargs: 99 | n_rows: 16 100 | 101 | trainer: 102 | precision: bf16-mixed 103 | devices: auto 104 | num_nodes: 1 105 | benchmark: true 106 | num_sanity_val_steps: 10 107 | val_check_interval: 2000 108 | check_val_every_n_epoch: null # default: 1 109 | accumulate_grad_batches: 1 110 | max_epochs: 1000 111 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/configs/vidtwin/vidtwin_structure_7_7_8_dynamics_7_8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.6e-4 3 | target: vidtwin.models.vidtwin_ae.VidAutoEncoderQformerCompactSymVidVAE 4 | params: 5 | input_key: jpg 6 | monitor: val/rec_loss 7 | ckpt_path: PATH_TO_CHECKPOINT 8 | ignore_keys: [] 9 | expect_ch: 8 10 | cont_num_blocks: 1 11 | downsample_motion: True 12 | motion_num_blocks: 1 13 | d_dim: 8 14 | 15 | temporal_qformer_config: 16 | target: vidtwin.modules.qformer.MyQformerInterface 17 | params: 18 | num_query_tokens: 16 19 | query_hidden_size: 64 20 | encoder_hidden_size: 768 21 | 22 | encoder_config: 23 | target: vidtwin.modules.st_transformer.STTEncoder 24 | params: 25 | in_channels: 3 26 | input_size: [16, 224, 224] 27 | patch_size: [1, 16, 16] 28 | hidden_size: 768 29 | depth: 16 30 | num_heads: 12 31 | temporal_casual: true 32 | 33 | decoder_config: 34 | target: vidtwin.modules.st_transformer.STTDecoder 35 | params: 36 | in_channels: 3 37 | input_size: [16, 224, 224] 38 | patch_size: [1, 16, 16] 39 | hidden_size: 768 40 | depth: 16 41 | num_heads: 12 42 | temporal_casual: true 43 | 44 | loss_config: 45 | target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator 46 | params: 47 | perceptual_weight: 0.05 48 | disc_start: 20001 49 | disc_weight: 0.05 50 | learn_logvar: True 51 | dims: 3 52 | disc_type: 2d 53 | regularization_weights: 54 | kl_loss: 0.001 55 | 56 | regularizer_config: 57 | target: vidtok.modules.regularizers.DiagonalGaussianRegularizer 58 | params: 59 | sample: True 60 | 61 | 62 | lr_scheduler_config_d: 63 | target: vidtok.models.vidtwin_ae.LambdaWarmUpCosineScheduler 64 | params: 65 | lr_min: 0 66 | lr_max: 1.5e-05 67 | lr_start: 1.0e-05 68 | warmup_steps: 5000 69 | lr_scheduler_config_g: 70 | target: vidtok.models.vidtwin_ae.LambdaWarmUpCosineScheduler 71 | params: 72 | lr_min: 0 73 | lr_max: 3.0e-05 74 | lr_start: 0 75 | warmup_steps: 5000 76 | optimizer_config: 77 | target: torch.optim.AdamW 78 | params: 79 | betas: 80 | - 0 81 | - 0.9 82 | weight_decay: 0.0001 83 | lr_scheduler_config: 84 | target: inverse_sqrt 85 | params: 86 | num_warmup_steps: 2000 87 | frequency: 1 88 | 89 | data: 90 | target: vidtok.data.datamodule.DataModuleFromConfig 91 | params: 92 | batch_size: 2 93 | num_workers: 12 94 | 95 | train: 96 | target: vidtok.data.vidtok.VidTokDataset 97 | params: 98 | data_dir: DATA_DIR_1 # DATA_DIR for training data 99 | meta_path: META_PATH_1 # path to the .csv meta file of training data 100 | video_params: 101 | input_height: 224 102 | input_width: 224 103 | sample_num_frames: 16 104 | sample_fps: 8 105 | 106 | validation: 107 | target: 
vidtok.data.vidtok.VidTokDataset 108 | params: 109 | data_dir: DATA_DIR_2 # DATA_DIR for validation data 110 | meta_path: META_PATH_2 # path to the .csv meta file of validation data 111 | video_params: 112 | input_height: 224 113 | input_width: 224 114 | sample_num_frames: 16 115 | sample_fps: 8 116 | start_index: 0 117 | 118 | 119 | lightning: 120 | strategy: 121 | target: lightning.pytorch.strategies.DDPStrategy 122 | params: 123 | find_unused_parameters: True 124 | 125 | modelcheckpoint: 126 | params: 127 | every_n_train_steps: 5000 128 | 129 | 130 | callbacks: 131 | image_logger: 132 | target: vidtok.modules.logger.ImageVideoLogger 133 | params: 134 | disabled: false 135 | rescale: true 136 | enable_autocast: false 137 | batch_frequency: 5000 138 | max_samples: 2 139 | increase_log_steps: false 140 | log_first_step: false 141 | log_before_first_step: false 142 | log_images_kwargs: 143 | n_rows: 2 144 | 145 | 146 | 147 | trainer: 148 | # precision: bf16-mixed # 16-mixed 149 | benchmark: True 150 | devices: 4 151 | num_sanity_val_steps: 10 152 | val_check_interval: 5000 153 | accumulate_grad_batches: 1 154 | max_epochs: 10 155 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/data/datamodule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from functools import partial 3 | 4 | import torch 5 | import lightning.pytorch as pl 6 | from torch.utils.data import DataLoader, Dataset, IterableDataset 7 | 8 | from vidtok.modules.util import instantiate_from_config 9 | 10 | 11 | class WrappedDataset(Dataset): 12 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 13 | 14 | def __init__(self, dataset): 15 | self.data = dataset 16 | 17 | def __len__(self): 18 | return len(self.data) 19 | 20 | def __getitem__(self, idx): 21 | return self.data[idx] 22 | 23 | 24 | def worker_init_fn(_): 25 | worker_info = torch.utils.data.get_worker_info() 26 | 27 | dataset = worker_info.dataset 28 | worker_id = worker_info.id 29 | 30 | if isinstance(dataset, IterableDataset): 31 | split_size = dataset.num_records // worker_info.num_workers 32 | # reset num_records to the true number to retain reliable length information 33 | dataset.sample_ids = dataset.valid_ids[ 34 | worker_id * split_size : (worker_id + 1) * split_size 35 | ] 36 | current_id = np.random.choice(len(np.random.get_state()[1]), 1) 37 | return np.random.seed(np.random.get_state()[1][current_id] + worker_id) 38 | else: 39 | return np.random.seed(np.random.get_state()[1][0] + worker_id) 40 | 41 | 42 | class DataModuleFromConfig(pl.LightningDataModule): 43 | def __init__( 44 | self, 45 | batch_size, 46 | train=None, 47 | validation=None, 48 | test=None, 49 | predict=None, 50 | wrap=False, 51 | num_workers=None, 52 | pin_train_memory=True, 53 | is_iterable_dataset=False, 54 | shuffle_test_loader=False, 55 | use_worker_init_fn=False, 56 | shuffle_val_dataloader=False, 57 | ): 58 | super().__init__() 59 | self.batch_size = batch_size 60 | self.dataset_configs = dict() 61 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 62 | self.pin_train_memory = pin_train_memory 63 | self.is_iterable_dataset = is_iterable_dataset 64 | self.use_worker_init_fn = use_worker_init_fn 65 | if train is not None: 66 | self.dataset_configs["train"] = train 67 | self.train_dataloader = self._train_dataloader 68 | if validation is not None: 69 | self.dataset_configs["validation"] = validation 70 | 
self.val_dataloader = partial( 71 | self._val_dataloader, shuffle=shuffle_val_dataloader 72 | ) 73 | if test is not None: 74 | self.dataset_configs["test"] = test 75 | self.test_dataloader = partial( 76 | self._test_dataloader, shuffle=shuffle_test_loader 77 | ) 78 | if predict is not None: 79 | self.dataset_configs["predict"] = predict 80 | self.predict_dataloader = self._predict_dataloader 81 | self.wrap = wrap 82 | 83 | def prepare_data(self): 84 | for data_cfg in self.dataset_configs.values(): 85 | instantiate_from_config(data_cfg) 86 | 87 | def setup(self, stage=None): 88 | self.datasets = dict( 89 | (k, instantiate_from_config(self.dataset_configs[k])) 90 | for k in self.dataset_configs 91 | ) 92 | if self.wrap: 93 | for k in self.datasets: 94 | self.datasets[k] = WrappedDataset(self.datasets[k]) 95 | 96 | def _train_dataloader(self): 97 | if self.is_iterable_dataset or self.use_worker_init_fn: 98 | init_fn = worker_init_fn 99 | else: 100 | init_fn = None 101 | return DataLoader( 102 | self.datasets["train"], 103 | batch_size=self.batch_size, 104 | num_workers=self.num_workers, 105 | pin_memory=self.pin_train_memory, 106 | shuffle=False if self.is_iterable_dataset else True, 107 | worker_init_fn=init_fn, 108 | ) 109 | 110 | def _val_dataloader(self, shuffle=False): 111 | if self.is_iterable_dataset or self.use_worker_init_fn: 112 | init_fn = worker_init_fn 113 | else: 114 | init_fn = None 115 | return DataLoader( 116 | self.datasets["validation"], 117 | batch_size=self.batch_size, 118 | num_workers=self.num_workers, 119 | worker_init_fn=init_fn, 120 | shuffle=shuffle, 121 | ) 122 | 123 | def _test_dataloader(self, shuffle=False): 124 | if self.is_iterable_dataset or self.use_worker_init_fn: 125 | init_fn = worker_init_fn 126 | else: 127 | init_fn = None 128 | 129 | # do not shuffle dataloader for iterable dataset 130 | shuffle = shuffle and (not self.is_iterable_dataset) 131 | 132 | return DataLoader( 133 | self.datasets["test"], 134 | batch_size=self.batch_size, 135 | num_workers=self.num_workers, 136 | worker_init_fn=init_fn, 137 | shuffle=shuffle, 138 | ) 139 | 140 | def _predict_dataloader(self, shuffle=False): 141 | if self.is_iterable_dataset or self.use_worker_init_fn: 142 | init_fn = worker_init_fn 143 | else: 144 | init_fn = None 145 | return DataLoader( 146 | self.datasets["predict"], 147 | batch_size=self.batch_size, 148 | num_workers=self.num_workers, 149 | worker_init_fn=init_fn, 150 | ) 151 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/data/video_read.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import decord 4 | import numpy as np 5 | import torch 6 | 7 | from vidtok.modules.util import print0 8 | 9 | decord.bridge.set_bridge("torch") 10 | 11 | 12 | def sample_frames_with_fps( 13 | total_frames, 14 | video_fps, 15 | sample_num_frames, 16 | sample_fps, 17 | start_index=None 18 | ): 19 | """sample frames proportional to the length of the frames in one second 20 | e.g., 1s video has 30 frames, when 'fps'=3, we sample frames with spacing of 30/3=10 21 | return the frame indices 22 | 23 | Parameters 24 | ---------- 25 | total_frames : length of the video 26 | video_fps : original fps of the video 27 | sample_num_frames : number of frames to sample 28 | sample_fps : the fps to sample frames 29 | start_index : the starting frame index. 
If it is not None, it will be used as the starting frame index 30 | 31 | Returns 32 | ------- 33 | frame indices 34 | """ 35 | sample_num_frames = min(sample_num_frames, total_frames) 36 | interval = round(video_fps / sample_fps) 37 | frames_range = (sample_num_frames - 1) * interval + 1 38 | 39 | if start_index is not None: 40 | start = start_index 41 | elif total_frames - frames_range - 1 < 0: 42 | start = 0 43 | else: 44 | start = random.randint(0, total_frames - frames_range - 1) 45 | 46 | frame_idxs = np.linspace( 47 | start=start, stop=min(total_frames - 1, start + frames_range), num=sample_num_frames 48 | ).astype(int) 49 | 50 | return frame_idxs 51 | 52 | 53 | def read_frames_with_decord( 54 | video_path, 55 | sample_num_frames, 56 | sample_fps, 57 | start_index=None 58 | ) -> tuple[torch.Tensor, list[int]]: 59 | """read frames from video path using decord 60 | 61 | Parameters 62 | ---------- 63 | video_path : path to video 64 | sample_num_frames : number of frames to sample 65 | sample_fps : the fps to sample frames 66 | start_index : the starting frame index. If it is not None, it will be used as the starting frame index 67 | 68 | Returns 69 | ------- 70 | frames (tensor 0~1), frame indices 71 | """ 72 | video_reader = decord.VideoReader(video_path, num_threads=0) 73 | total_frames = len(video_reader) 74 | video_fps = video_reader.get_avg_fps() # note that the fps here is float. 75 | frame_idxs = sample_frames_with_fps( 76 | total_frames=total_frames, 77 | video_fps=video_fps, 78 | sample_num_frames=sample_num_frames, 79 | sample_fps=sample_fps, 80 | start_index=start_index 81 | ) 82 | frames = video_reader.get_batch(frame_idxs) 83 | frames = frames.float() / 255 84 | frames = frames.permute(0, 3, 1, 2) 85 | if (frames.shape[0] != sample_num_frames) or (len(frame_idxs) != sample_num_frames): 86 | print0(f"[bold yellow]\[vidtok.data.video_read][read_frames_with_decord][/bold yellow] Warning: need {sample_num_frames} frames, " 87 | f"but got {frames.shape[0]} frames, {len(frame_idxs)} frame indices, video_path={video_path}.") 88 | return frames, frame_idxs 89 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/modules/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def weights_init(m): 8 | classname = m.__class__.__name__ 9 | if classname.find("Conv") != -1: 10 | nn.init.normal_(m.weight.data, 0.0, 0.02) 11 | elif classname.find("BatchNorm") != -1: 12 | nn.init.normal_(m.weight.data, 1.0, 0.02) 13 | nn.init.constant_(m.bias.data, 0) 14 | 15 | 16 | class ActNorm(nn.Module): 17 | def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): 18 | assert affine 19 | super().__init__() 20 | self.logdet = logdet 21 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 22 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 23 | self.allow_reverse_init = allow_reverse_init 24 | 25 | self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) 26 | 27 | def initialize(self, input): 28 | with torch.no_grad(): 29 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 30 | mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) 31 | std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3) 32 | 33 | self.loc.data.copy_(-mean) 34 | self.scale.data.copy_(1 / (std + 1e-6)) 35 | 36 | def 
forward(self, input, reverse=False): 37 | if reverse: 38 | return self.reverse(input) 39 | if len(input.shape) == 2: 40 | input = input[:, :, None, None] 41 | squeeze = True 42 | else: 43 | squeeze = False 44 | 45 | _, _, height, width = input.shape 46 | 47 | if self.training and self.initialized.item() == 0: 48 | self.initialize(input) 49 | self.initialized.fill_(1) 50 | 51 | h = self.scale * (input + self.loc) 52 | 53 | if squeeze: 54 | h = h.squeeze(-1).squeeze(-1) 55 | 56 | if self.logdet: 57 | log_abs = torch.log(torch.abs(self.scale)) 58 | logdet = height * width * torch.sum(log_abs) 59 | logdet = logdet * torch.ones(input.shape[0]).to(input) 60 | return h, logdet 61 | 62 | return h 63 | 64 | def reverse(self, output): 65 | if self.training and self.initialized.item() == 0: 66 | if not self.allow_reverse_init: 67 | raise RuntimeError( 68 | "Initializing ActNorm in reverse direction is " 69 | "disabled by default. Use allow_reverse_init=True to enable." 70 | ) 71 | else: 72 | self.initialize(output) 73 | self.initialized.fill_(1) 74 | 75 | if len(output.shape) == 2: 76 | output = output[:, :, None, None] 77 | squeeze = True 78 | else: 79 | squeeze = False 80 | 81 | h = output / self.scale - self.loc 82 | 83 | if squeeze: 84 | h = h.squeeze(-1).squeeze(-1) 85 | return h 86 | 87 | 88 | class NLayerDiscriminator(nn.Module): 89 | """Defines a PatchGAN discriminator as in Pix2Pix.""" 90 | # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 91 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 92 | """Construct a PatchGAN discriminator 93 | Parameters: 94 | input_nc (int) -- the number of channels in input images 95 | ndf (int) -- the number of filters in the last conv layer 96 | n_layers (int) -- the number of conv layers in the discriminator 97 | """ 98 | super(NLayerDiscriminator, self).__init__() 99 | if not use_actnorm: 100 | norm_layer = nn.BatchNorm2d 101 | else: 102 | norm_layer = ActNorm 103 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 104 | use_bias = norm_layer.func != nn.BatchNorm2d 105 | else: 106 | use_bias = norm_layer != nn.BatchNorm2d 107 | 108 | kw = 4 109 | padw = 1 110 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 111 | nf_mult = 1 112 | nf_mult_prev = 1 113 | for n in range(1, n_layers): # gradually increase the number of filters 114 | nf_mult_prev = nf_mult 115 | nf_mult = min(2**n, 8) 116 | sequence += [ 117 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 118 | norm_layer(ndf * nf_mult), 119 | nn.LeakyReLU(0.2, True), 120 | ] 121 | 122 | nf_mult_prev = nf_mult 123 | nf_mult = min(2**n_layers, 8) 124 | sequence += [ 125 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 126 | norm_layer(ndf * nf_mult), 127 | nn.LeakyReLU(0.2, True), 128 | ] 129 | 130 | sequence += [ 131 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 132 | ] # output 1 channel prediction map 133 | self.main = nn.Sequential(*sequence) 134 | 135 | def forward(self, input): 136 | """Standard forward.""" 137 | return self.main(input) 138 | 139 | 140 | class NLayerDiscriminator3D(nn.Module): 141 | """Defines a 3D PatchGAN discriminator as in Pix2Pix but for 3D inputs.""" 142 | # https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/opensora/models/causalvideovae/model/losses/discriminator.py 143 | def 
__init__(self, input_nc=1, ndf=64, n_layers=3, use_actnorm=False): 144 | """ 145 | Construct a 3D PatchGAN discriminator 146 | 147 | Parameters: 148 | input_nc (int) -- the number of channels in input volumes 149 | ndf (int) -- the number of filters in the last conv layer 150 | n_layers (int) -- the number of conv layers in the discriminator 151 | use_actnorm (bool) -- flag to use actnorm instead of batchnorm 152 | """ 153 | super(NLayerDiscriminator3D, self).__init__() 154 | if not use_actnorm: 155 | norm_layer = nn.BatchNorm3d 156 | else: 157 | raise NotImplementedError("Not implemented.") 158 | if type(norm_layer) == functools.partial: 159 | use_bias = norm_layer.func != nn.BatchNorm3d 160 | else: 161 | use_bias = norm_layer != nn.BatchNorm3d 162 | 163 | kw = 3 164 | padw = 1 165 | sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 166 | nf_mult = 1 167 | nf_mult_prev = 1 168 | for n in range(1, n_layers): # gradually increase the number of filters 169 | nf_mult_prev = nf_mult 170 | nf_mult = min(2**n, 8) 171 | sequence += [ 172 | nn.Conv3d( 173 | ndf * nf_mult_prev, 174 | ndf * nf_mult, 175 | kernel_size=(kw, kw, kw), 176 | stride=(2 if n == 1 else 1, 2, 2), 177 | padding=padw, 178 | bias=use_bias, 179 | ), 180 | norm_layer(ndf * nf_mult), 181 | nn.LeakyReLU(0.2, True), 182 | ] 183 | 184 | nf_mult_prev = nf_mult 185 | nf_mult = min(2**n_layers, 8) 186 | sequence += [ 187 | nn.Conv3d( 188 | ndf * nf_mult_prev, ndf * nf_mult, kernel_size=(kw, kw, kw), stride=1, padding=padw, bias=use_bias 189 | ), 190 | norm_layer(ndf * nf_mult), 191 | nn.LeakyReLU(0.2, True), 192 | ] 193 | 194 | sequence += [ 195 | nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 196 | ] # output 1 channel prediction map 197 | self.main = nn.Sequential(*sequence) 198 | 199 | def forward(self, input): 200 | """Standard forward.""" 201 | return self.main(input) 202 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/modules/distributions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class DiagonalGaussianDistribution(object): 6 | def __init__(self, parameters, deterministic=False): 7 | self.parameters = parameters 8 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 9 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 10 | self.deterministic = deterministic 11 | self.std = torch.exp(0.5 * self.logvar) 12 | self.var = torch.exp(self.logvar) 13 | if self.deterministic: 14 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 15 | 16 | def sample(self): 17 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 18 | return x 19 | 20 | def kl(self, other=None): 21 | if self.deterministic: 22 | return torch.Tensor([0.0]) 23 | else: 24 | if other is None: 25 | return 0.5 * torch.sum( 26 | torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, 27 | dim=[1, 2, 3], 28 | ) 29 | else: 30 | return 0.5 * torch.sum( 31 | torch.pow(self.mean - other.mean, 2) / other.var 32 | + self.var / other.var 33 | - 1.0 34 | - self.logvar 35 | + other.logvar, 36 | dim=[1, 2, 3], 37 | ) 38 | 39 | def nll(self, sample, dims=[1, 2, 3]): 40 | if self.deterministic: 41 | return torch.Tensor([0.0]) 42 | logtwopi = np.log(2.0 * np.pi) 43 | return 0.5 * torch.sum( 44 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 45 | dim=dims, 46 | 
) 47 | 48 | def mode(self): 49 | return self.mean 50 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError("Decay must be between 0 and 1") 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer( 14 | "num_updates", 15 | torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int), 16 | ) 17 | 18 | for name, p in model.named_parameters(): 19 | if p.requires_grad: 20 | # remove as '.'-character is not allowed in buffers 21 | s_name = name.replace(".", "") 22 | self.m_name2s_name.update({name: s_name}) 23 | self.register_buffer(s_name, p.clone().detach().data) 24 | 25 | self.collected_params = [] 26 | 27 | def reset_num_updates(self): 28 | del self.num_updates 29 | self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) 30 | 31 | def forward(self, model): 32 | decay = self.decay 33 | 34 | if self.num_updates >= 0: 35 | self.num_updates += 1 36 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 37 | 38 | one_minus_decay = 1.0 - decay 39 | 40 | with torch.no_grad(): 41 | m_param = dict(model.named_parameters()) 42 | shadow_params = dict(self.named_buffers()) 43 | 44 | for key in m_param: 45 | if m_param[key].requires_grad: 46 | sname = self.m_name2s_name[key] 47 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 48 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 49 | else: 50 | assert not key in self.m_name2s_name 51 | 52 | def copy_to(self, model): 53 | m_param = dict(model.named_parameters()) 54 | shadow_params = dict(self.named_buffers()) 55 | for key in m_param: 56 | if m_param[key].requires_grad: 57 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 58 | else: 59 | assert not key in self.m_name2s_name 60 | 61 | def store(self, parameters): 62 | """ 63 | Save the current parameters for restoring later. 64 | Args: 65 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 66 | temporarily stored. 67 | """ 68 | self.collected_params = [param.clone() for param in parameters] 69 | 70 | def restore(self, parameters): 71 | """ 72 | Restore the parameters stored with the `store` method. 73 | Useful to validate the model with EMA parameters without affecting the 74 | original optimization process. Store the parameters before the 75 | `copy_to` method. After validation (or model saving), use this to 76 | restore the former parameters. 77 | Args: 78 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 79 | updated with the stored parameters. 
80 | """ 81 | for c_param, param in zip(self.collected_params, parameters): 82 | param.data.copy_(c_param.data) 83 | -------------------------------------------------------------------------------- /open_tokenizer/model/vidtok/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | from collections import namedtuple 4 | from tqdm import tqdm 5 | 6 | import requests 7 | import torch 8 | import torch.nn as nn 9 | from torchvision import models 10 | 11 | from .util import print0 12 | 13 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 14 | 15 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 16 | 17 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 18 | 19 | 20 | def download(url, local_path, chunk_size=1024): 21 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 22 | with requests.get(url, stream=True) as r: 23 | total_size = int(r.headers.get("content-length", 0)) 24 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 25 | with open(local_path, "wb") as f: 26 | for data in r.iter_content(chunk_size=chunk_size): 27 | if data: 28 | f.write(data) 29 | pbar.update(chunk_size) 30 | 31 | 32 | def md5_hash(path): 33 | with open(path, "rb") as f: 34 | content = f.read() 35 | return hashlib.md5(content).hexdigest() 36 | 37 | 38 | def get_ckpt_path(name, root, check=False): 39 | assert name in URL_MAP 40 | path = os.path.join(root, CKPT_MAP[name]) 41 | if os.path.exists(path) and not (check and not md5_hash(path) == MD5_MAP[name]): 42 | print0( 43 | "[bold cyan]\[vidtok.modules.lpips]\[get_ckpt_path][/bold cyan] Using existing path for {} model: {}".format( 44 | name, path 45 | ) 46 | ) 47 | return path 48 | 49 | # if not, download the model 50 | print0( 51 | "[bold cyan]\[vidtok.modules.lpips]\[get_ckpt_path][/bold cyan] Downloading {} model from {} to {}".format( 52 | name, URL_MAP[name], path 53 | ) 54 | ) 55 | download(URL_MAP[name], path) 56 | md5 = md5_hash(path) 57 | assert md5 == MD5_MAP[name], md5 58 | return path 59 | 60 | 61 | class LPIPS(nn.Module): 62 | # Learned perceptual metric 63 | def __init__(self, use_dropout=True): 64 | super().__init__() 65 | self.scaling_layer = ScalingLayer() 66 | self.chns = [64, 128, 256, 512, 512] # vgg16 features 67 | self.net = vgg16(pretrained=True, requires_grad=False) 68 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 69 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 70 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 71 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 72 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 73 | self.load_from_pretrained() 74 | for param in self.parameters(): 75 | param.requires_grad = False 76 | 77 | def load_from_pretrained(self, name="vgg_lpips"): 78 | ckpt = get_ckpt_path(name, "checkpoints/lpips") 79 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 80 | print0("[bold cyan]\[vidtok.modules.lpips][LPIPS][/bold cyan] loaded pretrained LPIPS loss from {}".format(ckpt)) 81 | 82 | def forward(self, input, target): 83 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 84 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 85 | feats0, feats1, diffs = {}, {}, {} 86 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 87 | for kk in range(len(self.chns)): 88 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), 
normalize_tensor(outs1[kk]) 89 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 90 | 91 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 92 | val = res[0] 93 | for l in range(1, len(self.chns)): 94 | val += res[l] 95 | return val 96 | 97 | 98 | class ScalingLayer(nn.Module): 99 | def __init__(self): 100 | super(ScalingLayer, self).__init__() 101 | self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) 102 | self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) 103 | 104 | def forward(self, inp): 105 | return (inp - self.shift) / self.scale 106 | 107 | 108 | class NetLinLayer(nn.Module): 109 | """A single linear layer which does a 1x1 conv""" 110 | 111 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 112 | super(NetLinLayer, self).__init__() 113 | layers = ( 114 | [ 115 | nn.Dropout(), 116 | ] 117 | if (use_dropout) 118 | else [] 119 | ) 120 | layers += [ 121 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 122 | ] 123 | self.model = nn.Sequential(*layers) 124 | 125 | 126 | class vgg16(torch.nn.Module): 127 | def __init__(self, requires_grad=False, pretrained=True): 128 | super(vgg16, self).__init__() 129 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 130 | self.slice1 = torch.nn.Sequential() 131 | self.slice2 = torch.nn.Sequential() 132 | self.slice3 = torch.nn.Sequential() 133 | self.slice4 = torch.nn.Sequential() 134 | self.slice5 = torch.nn.Sequential() 135 | self.N_slices = 5 136 | for x in range(4): 137 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 138 | for x in range(4, 9): 139 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 140 | for x in range(9, 16): 141 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 142 | for x in range(16, 23): 143 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 144 | for x in range(23, 30): 145 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 146 | if not requires_grad: 147 | for param in self.parameters(): 148 | param.requires_grad = False 149 | 150 | def forward(self, X): 151 | h = self.slice1(X) 152 | h_relu1_2 = h 153 | h = self.slice2(h) 154 | h_relu2_2 = h 155 | h = self.slice3(h) 156 | h_relu3_3 = h 157 | h = self.slice4(h) 158 | h_relu4_3 = h 159 | h = self.slice5(h) 160 | h_relu5_3 = h 161 | vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) 162 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 163 | return out 164 | 165 | 166 | def normalize_tensor(x, eps=1e-10): 167 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 168 | return x / (norm_factor + eps) 169 | 170 | 171 | def spatial_average(x, keepdim=True): 172 | return x.mean([2, 3], keepdim=keepdim) 173 | -------------------------------------------------------------------------------- /open_tokenizer/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image, ImageFile 4 | ImageFile.LOAD_TRUNCATED_IMAGES = True 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | from torchvision.transforms import InterpolationMode 10 | 11 | from .data_utils import load_video_from_path_decord 12 | 13 | class EvalT2IDataset(Dataset): 14 | def __init__(self, image_folder: str, data_path: str, image_size: int, scale: bool = True): 15 | super(EvalT2IDataset, 
self).__init__() 16 | 17 | self.image_folder = image_folder 18 | list_data_dict = json.load(open(data_path, "r")) 19 | self.list_data_dict = list_data_dict 20 | 21 | if scale: 22 | augmentations = transforms.Compose([ 23 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), 24 | transforms.ToTensor(), 25 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 26 | ]) 27 | else: 28 | augmentations = transforms.Compose([ 29 | transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), 30 | transforms.ToTensor(), 31 | # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 32 | ]) 33 | self.transform = augmentations 34 | 35 | def __len__(self): 36 | return len(self.list_data_dict) 37 | 38 | 39 | def __getitem__(self, i): 40 | item = self.list_data_dict[i] 41 | image_file = item["image_path"] 42 | image = Image.open(os.path.join(self.image_folder, image_file)).convert("RGB") 43 | image_tensor = self.transform(image) 44 | 45 | return image_tensor, image_file 46 | 47 | 48 | class EvalT2VDataset(Dataset): 49 | def __init__(self, image_folder: str, data_path: str, image_size: int, sequence_length: int = 17, fps: int =16, sampling="center"): 50 | super(EvalT2VDataset, self).__init__() 51 | 52 | self.image_folder = image_folder 53 | list_data_dict = json.load(open(data_path, "r")) 54 | self.list_data_dict = list_data_dict 55 | 56 | self.image_size = image_size 57 | self.sequence_length = sequence_length 58 | self.fps = fps 59 | self.sampling = sampling 60 | 61 | def __len__(self): 62 | return len(self.list_data_dict) 63 | 64 | 65 | def __getitem__(self, i): 66 | item = self.list_data_dict[i] 67 | video_path = os.path.join( 68 | item["video_path"] 69 | ) 70 | frames = load_video_from_path_decord( 71 | video_path, 72 | self.sampling, 73 | fps=self.fps, 74 | num_frm=self.sequence_length, 75 | crop_type="resize", 76 | crop_size=self.image_size 77 | ) 78 | 79 | vid_frm_array = ( 80 | torch.from_numpy(frames).float().permute(3, 0, 1, 2) 81 | ) 82 | # T H W 3 -> 3 T H W 83 | assert vid_frm_array.shape[1] == self.sequence_length 84 | video = (vid_frm_array / 255.0 * 2.0 - 1.0) 85 | return video, video_path 86 | 87 | 88 | 89 | class PairedVideoDataset(Dataset): 90 | def __init__(self, video_folder, image_size=256): 91 | super(PairedVideoDataset, self).__init__() 92 | 93 | gt_folder = os.path.join(video_folder, "gt") 94 | pd_folder = os.path.join(video_folder, "pd") 95 | 96 | gt_videos = [ 97 | os.path.join(gt_folder, v) for v in os.listdir(gt_folder) 98 | ] 99 | pd_videos = [ 100 | os.path.join(pd_folder, v) for v in os.listdir(pd_folder) 101 | ] 102 | 103 | self.gt_videos = sorted(gt_videos) 104 | self.pd_videos = sorted(pd_videos) 105 | 106 | self.image_size = image_size 107 | 108 | def __len__(self): 109 | return len(self.gt_videos) 110 | 111 | 112 | def __getitem__(self, i): 113 | gt_video_path = self.gt_videos[i] 114 | pd_video_path = self.pd_videos[i] 115 | 116 | gt_frames = load_video_from_path_decord( 117 | gt_video_path, 118 | "all", 119 | crop_type="resize", 120 | crop_size=self.image_size 121 | ) 122 | 123 | gt_video_tensor = ( 124 | torch.from_numpy(gt_frames).float().permute(3, 0, 1, 2) 125 | ) 126 | 127 | gt_video_tensor = (gt_video_tensor / 255.0 * 2.0 - 1.0) 128 | 129 | pd_frames = load_video_from_path_decord( 130 | pd_video_path, 131 | "all", 132 | crop_type="resize", 133 | crop_size=self.image_size 134 | ) 135 | 136 | pd_video_tensor = ( 137 | 
torch.from_numpy(pd_frames).float().permute(3, 0, 1, 2) 138 | ) 139 | 140 | pd_video_tensor = (pd_video_tensor / 255.0 * 2.0 - 1.0) 141 | 142 | return gt_video_tensor, pd_video_tensor 143 | -------------------------------------------------------------------------------- /open_tokenizer/utils/ddp_distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import subprocess 4 | 5 | 6 | def setup_for_distributed(is_master): 7 | """ 8 | This function disables printing when not in master process 9 | """ 10 | import builtins as __builtin__ 11 | builtin_print = __builtin__.print 12 | 13 | def print(*args, **kwargs): 14 | force = kwargs.pop('force', False) 15 | if is_master or force: 16 | builtin_print(*args, **kwargs) 17 | 18 | __builtin__.print = print 19 | 20 | def init_distributed_mode(args): 21 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 22 | args.rank = int(os.environ["RANK"]) 23 | args.world_size = int(os.environ['WORLD_SIZE']) 24 | args.gpu = int(os.environ['LOCAL_RANK']) 25 | args.dist_url = 'env://' 26 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 27 | elif 'SLURM_PROCID' in os.environ: 28 | proc_id = int(os.environ['SLURM_PROCID']) 29 | ntasks = int(os.environ['SLURM_NTASKS']) 30 | node_list = os.environ['SLURM_NODELIST'] 31 | num_gpus = torch.cuda.device_count() 32 | addr = subprocess.getoutput( 33 | 'scontrol show hostname {} | head -n1'.format(node_list)) 34 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') 35 | os.environ['MASTER_ADDR'] = addr 36 | os.environ['WORLD_SIZE'] = str(ntasks) 37 | os.environ['RANK'] = str(proc_id) 38 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 39 | os.environ['LOCAL_SIZE'] = str(num_gpus) 40 | args.dist_url = 'env://' 41 | args.world_size = ntasks 42 | args.rank = proc_id 43 | args.gpu = proc_id % num_gpus 44 | else: 45 | print('Not using distributed mode') 46 | args.distributed = False 47 | return 48 | 49 | args.distributed = True 50 | 51 | torch.cuda.set_device(args.gpu) 52 | args.dist_backend = 'nccl' 53 | print('| distributed init (rank {}): {}'.format( 54 | args.rank, args.dist_url), flush=True) 55 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 56 | world_size=args.world_size, rank=args.rank) 57 | torch.distributed.barrier() 58 | setup_for_distributed(args.rank == 0) -------------------------------------------------------------------------------- /open_tokenizer/utils/fvd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py 3 | """ 4 | from typing import Tuple 5 | import scipy 6 | import numpy as np 7 | 8 | 9 | def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float: 10 | mu_gen, sigma_gen = compute_stats(feats_fake) 11 | mu_real, sigma_real = compute_stats(feats_real) 12 | 13 | m = np.square(mu_gen - mu_real).sum() 14 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 15 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 16 | 17 | return float(fid) 18 | 19 | 20 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 21 | mu = feats.mean(axis=0) # [d] 22 | sigma = np.cov(feats, rowvar=False) # [d, d] 23 | 24 | return mu, sigma -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 240 3 | 4 | [build-system] 5 | requires = ["setuptools>=61.0"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "open_tokenizer" 10 | version = "1.0" 11 | description = "OpenTokenizer: a comprehensive comparison on open-sourced tokenizers" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: Apache Software License", 17 | ] 18 | 19 | [project.optional-dependencies] 20 | standalone = [ 21 | "shortuuid", 22 | "httpx==0.24.0", 23 | "einops", 24 | "ftfy", 25 | ] 26 | 27 | 28 | train = [ 29 | "open_tokenizer[standalone]", 30 | "numpy==1.26.1", 31 | "open_clip_torch", 32 | "fastapi", 33 | "markdown2[all]", 34 | "numpy", 35 | "requests", 36 | "sentencepiece", 37 | "torch==2.1.2", 38 | "torchvision==0.16.2", 39 | "uvicorn", 40 | "wandb", 41 | "deepspeed==0.14.4", 42 | "peft==0.4.0", 43 | "accelerate>=0.29.1", 44 | "tokenizers~=0.15.2", 45 | "transformers@git+https://github.com/huggingface/transformers.git@1c39974a4c4036fd641bc1191cc32799f85715a4", 46 | "bitsandbytes==0.41.0", 47 | "scikit-learn==1.2.2", 48 | "sentencepiece~=0.1.99", 49 | "einops==0.6.1", 50 | "einops-exts==0.0.4", 51 | "gradio_client==0.2.9", 52 | "urllib3<=2.0.0", 53 | "datasets==2.16.1", 54 | "pydantic==1.10.8", 55 | "timm", 56 | "hf_transfer", 57 | "opencv-python", 58 | "av", 59 | "decord", 60 | "tyro", 61 | "scipy", 62 | ] 63 | 64 | [project.urls] 65 | 66 | [tool.setuptools.packages.find] 67 | include = ["open_tokenizer*", "trl*"] 68 | exclude = [ 69 | "assets*", 70 | "benchmark*", 71 | "docs", 72 | "dist*", 73 | "playground*", 74 | "scripts*", 75 | "tests*", 76 | "checkpoints*", 77 | "project_checkpoints*", 78 | "debug_checkpoints*", 79 | "mlx_configs*", 80 | "wandb*", 81 | "notebooks*", 82 | ] 83 | 84 | [tool.wheel] 85 | exclude = [ 86 | "assets*", 87 | "benchmark*", 88 | "docs", 89 | "dist*", 90 | "playground*", 91 | "scripts*", 92 | "tests*", 93 | "checkpoints*", 94 | "project_checkpoints*", 95 | "debug_checkpoints*", 96 | "mlx_configs*", 97 | "wandb*", 98 | "notebooks*", 99 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.21.0 2 | aiohttp==3.9.5 3 | aiosignal==1.3.1 4 | albumentations==0.3.2 5 | annotated-types==0.7.0 6 | antlr4-python3-runtime==4.9.3 7 | anykeystore==0.2 8 | asn1crypto==1.5.1 9 | asttokens==2.4.1 10 | async-timeout==4.0.3 11 | attrs==21.2.0 12 | bidict==0.23.1 13 | blessed==1.20.0 14 | boto3==1.34.113 15 | botocore==1.34.113 16 | braceexpand==0.1.7 17 | cachetools==5.3.3 18 | certifi==2024.2.2 19 | cffi==1.16.0 20 | chardet==5.2.0 21 | charset-normalizer==3.3.2 22 | click==8.1.7 23 | clip==0.2.0 24 | clip-openai==1.0.post20230121 25 | cmake==3.29.3 26 | cramjam==2.8.3 27 | crcmod==1.7 28 | cryptacular==1.6.2 29 | cryptography==39.0.2 30 | cycler==0.12.1 31 | datasets 32 | diffusers==0.30.1 33 | decorator==5.1.1 34 | decord==0.6.0 35 | deepspeed==0.14.2 36 | defusedxml==0.7.1 37 | Deprecated==1.2.14 38 | descartes==1.1.0 39 | dill==0.3.8 40 | distlib==0.3.8 41 | distro-info==1.0 42 | dnspython==2.6.1 43 | docker-pycreds==0.4.0 44 | docstring_parser==0.16 45 | ecdsa==0.19.0 46 | einops==0.6.0 47 | exceptiongroup==1.2.1 48 | executing==2.0.1 49 | fairscale==0.4.13 50 | fastparquet==2024.5.0 51 | ffmpegcv==0.3.13 52 | 
filelock==3.14.0 53 | fire==0.6.0 54 | fonttools==4.51.0 55 | frozenlist==1.4.1 56 | fsspec==2023.6.0 57 | ftfy==6.2.0 58 | gitdb==4.0.11 59 | GitPython==3.1.43 60 | gpustat==1.1.1 61 | greenlet==3.0.3 62 | grpcio==1.64.0 63 | h11==0.14.0 64 | hjson==3.1.0 65 | huggingface-hub==0.23.2 66 | hupper==1.12.1 67 | idna==3.7 68 | imageio==2.34.1 69 | imgaug==0.2.6 70 | iniconfig==2.0.0 71 | ipaddress==1.0.23 72 | ipdb==0.13.13 73 | ipython==8.18.1 74 | jaxtyping==0.2.28 75 | jedi==0.19.1 76 | Jinja2==3.1.4 77 | jmespath==1.0.1 78 | joblib==1.4.2 79 | jsonargparse==4.14.1 80 | jsonlines==4.0.0 81 | kiwisolver==1.4.5 82 | kornia==0.7.2 83 | kornia_rs==0.1.3 84 | lazy_loader==0.4 85 | lightning==2.2.3 86 | lightning-utilities==0.11.2 87 | lit==18.1.6 88 | MarkupSafe==2.1.5 89 | matplotlib==3.5.3 90 | matplotlib-inline==0.1.7 91 | miscreant==0.3.0 92 | mpmath==1.3.0 93 | msgpack==1.0.8 94 | multidict==6.0.5 95 | multiprocess==0.70.16 96 | natsort==8.4.0 97 | networkx==3.2.1 98 | ninja==1.11.1.1 99 | numpy==1.24.4 100 | nuscenes-devkit==1.1.11 101 | oauthlib==3.2.2 102 | omegaconf==2.3.0 103 | open-clip-torch==2.24.0 104 | openai-clip 105 | opencv-python==4.9.0.80 106 | opencv-python-headless==3.4.18.65 107 | packaging==22.0 108 | pandas==1.5.3 109 | parquet==1.3.1 110 | parso==0.8.4 111 | PasteDeploy==3.1.0 112 | pathlib2==2.3.7.post1 113 | pathtools==0.1.2 114 | pbkdf2==1.3 115 | pexpect==4.9.0 116 | pillow==10.3.0 117 | plaster==1.1.2 118 | plaster-pastedeploy==1.0.1 119 | platformdirs==4.2.2 120 | plotly==5.22.0 121 | pluggy==1.5.0 122 | ply==3.11 123 | promise==2.3 124 | prompt-toolkit==3.0.43 125 | protobuf==3.20.3 126 | psutil==5.9.8 127 | ptyprocess==0.7.0 128 | pure-eval==0.2.2 129 | py==1.11.0 130 | py-cpuinfo==9.0.0 131 | py-spy==0.3.14 132 | pyarrow==11.0.0 133 | pyarrow-hotfix==0.6 134 | pyasn1==0.6.0 135 | pycocotools==2.0.7 136 | pycparser==2.22 137 | pycryptodomex==3.20.0 138 | pycurl==7.43.0.6 139 | pydantic==1.10.15 140 | pydantic_core==2.18.3 141 | Pygments==2.18.0 142 | PyJWT==2.8.0 143 | pynvml==11.5.0 144 | pyope==0.2.2 145 | pyOpenSSL==23.2.0 146 | pyparsing==3.1.2 147 | pyquaternion==0.9.9 148 | pyramid==2.0.2 149 | pyramid-mailer==0.15.1 150 | pytest==6.2.5 151 | python-consul==1.1.0 152 | python-dateutil==2.9.0.post0 153 | python-engineio==4.9.1 154 | python-etcd==0.4.5 155 | python-jose==3.3.0 156 | python-socketio==5.11.2 157 | python3-openid==3.2.0 158 | pytorch-extension==0.2 159 | pytorch-lightning==2.2.3 160 | pytz==2024.1 161 | PyYAML==6.0.1 162 | regex==2024.5.15 163 | repoze.sendmail==4.4.1 164 | requests==2.31.0 165 | requests-oauthlib==2.0.0 166 | rsa==4.9 167 | s3transfer==0.10.1 168 | safetensors==0.4.3 169 | schedule==1.2.2 170 | scikit-image==0.22.0 171 | scikit-learn==1.5.0 172 | scipy==1.13.1 173 | sentencepiece==0.2.0 174 | sentry-sdk==2.3.1 175 | setproctitle==1.3.3 176 | Shapely==1.8.5.post1 177 | shortuuid==1.0.13 178 | simple-websocket==1.0.0 179 | six==1.16.0 180 | smmap==5.0.1 181 | SQLAlchemy==2.0.30 182 | stack-data==0.6.3 183 | sympy==1.12 184 | taming-transformers-rom1504==0.0.6 185 | tenacity==8.3.0 186 | tensorboardX==2.6.2.2 187 | termcolor==2.4.0 188 | threadpoolctl==3.5.0 189 | thriftpy2==0.5.0 190 | tifffile==2024.5.22 191 | timm==1.0.3 192 | tokenizers==0.19.1 193 | toml==0.10.2 194 | tomli==2.0.1 195 | torch==2.2.1 196 | torch-fidelity==0.3.0 197 | torchmetrics==1.4.0.post0 198 | torchvision==0.17.1 199 | tox==3.28.0 200 | tqdm==4.66.4 201 | traitlets==5.14.3 202 | transaction==4.0 203 | transformers==4.41.1 204 | translationstring==1.4 
205 | triton==2.2.0 206 | typeguard==2.13.3 207 | typing_extensions==4.12.0 208 | tzdata==2024.1 209 | urllib3==1.26.18 210 | velruse==1.1.1 211 | venusian==3.1.0 212 | virtualenv==20.26.2 213 | wandb==0.17.0 214 | watchdog==4.0.1 215 | wcwidth==0.2.13 216 | webdataset==0.2.86 217 | WebOb==1.8.7 218 | websocket-client==1.8.0 219 | wrapt==1.16.0 220 | wsproto==1.2.0 221 | WTForms==3.1.2 222 | wtforms-recaptcha==0.3.2 223 | xformers==0.0.25 224 | xxhash==3.4.1 225 | yarl==1.9.4 226 | zope.deprecation==5.0 227 | zope.interface==6.4.post2 228 | zope.sqlalchemy==3.1 -------------------------------------------------------------------------------- /scripts/eval/image/cosmos.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/Cosmos-0.1-Tokenizer-DI8x8/" \ 6 | --vq_model_type "cosmos-i" \ 7 | --dataset_type "image" \ 8 | --dataset_name "laion" \ 9 | --save_dir "./tokenizers" \ 10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \ 11 | --image_folder "" \ 12 | --resolution 256 \ 13 | --video_fps 16 \ 14 | --sequence_length 33 15 | -------------------------------------------------------------------------------- /scripts/eval/image/emu3.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/Emu3-VisionTokenizer" \ 6 | --vq_model_type "emu3" \ 7 | --dataset_type "image" \ 8 | --dataset_name "laion" \ 9 | --save_dir "./tokenizers" \ 10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \ 11 | --image_folder "" \ 12 | --resolution 256 \ 13 | --video_fps 16 \ 14 | --sequence_length 33 15 | -------------------------------------------------------------------------------- /scripts/eval/image/llamagen.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/vq_ds16_c2i.pt" \ 6 | --vq_model_type "llamagen" \ 7 | --dataset_type "image" \ 8 | --dataset_name "laion" \ 9 | --save_dir "./tokenizers" \ 10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \ 11 | --image_folder "" \ 12 | --resolution 256 \ 13 | --video_fps 16 \ 14 | --sequence_length 33 15 | -------------------------------------------------------------------------------- /scripts/eval/image/omnitokenizer.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | --nnodes=1 --nproc_per_node=1 --master_port 23456 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/omnitokenizer_rq_down16_code16384_joint.ckpt" \ 6 | --dataset_type "image" \ 7 | --dataset_name "laion" \ 8 | --save_dir "./visualizations" \ 9 | --image_path "/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/jkwang/workspace/vtokenizer_bench/laion_pub.json" \ 10 | --image_folder "/inspire/hdd/ws-8207e9e2-e733-4eec-a475-cfa1c36480ba/embodied-multimodality/public/jkwang/workspace/vtokenizer_bench" \ 11 | --resolution 256 \ 12 | --video_fps 16 \ 13 | --sequence_length 33 14 | -------------------------------------------------------------------------------- /scripts/eval/image/showo.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | 
--nnodes=1 --nproc_per_node=4 --master_port 23456 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/magvitv2" \ 6 | --vq_model_type "show-o" \ 7 | --dataset_type "image" \ 8 | --dataset_name "laion" \ 9 | --save_dir "./tokenizers" \ 10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \ 11 | --image_folder "" \ 12 | --resolution 256 \ 13 | --video_fps 16 \ 14 | --sequence_length 33 15 | -------------------------------------------------------------------------------- /scripts/eval/image/tiktok.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/tokenizer_titok_sl256_vq8k_imagenet" \ 6 | --vq_model_type "tiktok" \ 7 | --dataset_type "image" \ 8 | --dataset_name "laion" \ 9 | --save_dir "./tokenizers" \ 10 | --image_path "/path_to_dir/tokenizer_bench/laion.json" \ 11 | --image_folder "" \ 12 | --resolution 256 \ 13 | --video_fps 16 \ 14 | --sequence_length 33 15 | -------------------------------------------------------------------------------- /scripts/eval/video/cosmos.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | --nnodes=1 --nproc_per_node=4 --master_port 23457 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/Cosmos-0.1-Tokenizer-DV4x8x8/" \ 6 | --vq_model_type "cosmos-v" \ 7 | --dataset_type "video" \ 8 | --dataset_name "openvid" \ 9 | --save_dir "./tokenizers" \ 10 | --video_path "/path_to_dir/tokenizer_bench/openvid.json" \ 11 | --video_folder "" \ 12 | --resolution 256 \ 13 | --video_fps 16 \ 14 | --sequence_length 16 15 | -------------------------------------------------------------------------------- /scripts/eval/video/emu3.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/Emu3-VisionTokenizer" \ 6 | --vq_model_type "emu3" \ 7 | --dataset_type "video" \ 8 | --dataset_name "openvid" \ 9 | --save_dir "./tokenizers" \ 10 | --video_path "/path_to_dir/tokenizer_bench/openvid.json" \ 11 | --video_folder "" \ 12 | --resolution 256 \ 13 | --video_fps 16 \ 14 | --sequence_length 16 \ 15 | --batch_size 2 -------------------------------------------------------------------------------- /scripts/eval/video/omnitokenizer.sh: -------------------------------------------------------------------------------- 1 | 2 | torchrun \ 3 | --nnodes=1 --nproc_per_node=4 --master_port 23456 \ 4 | eval/reconstruct.py \ 5 | --vq_model_ckpt "./checkpoints/omnitokenizer_rq_code16384_down16_joint2.ckpt" \ 6 | --dataset_type "video" \ 7 | --dataset_name "openvid" \ 8 | --save_dir "./tokenizers" \ 9 | --video_path "/path_to_dir/tokenizer_bench/openvid.json" \ 10 | --video_folder "" \ 11 | --resolution 256 \ 12 | --video_fps 16 \ 13 | --sequence_length 17 14 | -------------------------------------------------------------------------------- /src/1-cos.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/1-cos.gif -------------------------------------------------------------------------------- /src/1-emu3.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/1-emu3.gif -------------------------------------------------------------------------------- /src/1-omni.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/1-omni.gif -------------------------------------------------------------------------------- /src/2-cos.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/2-cos.gif -------------------------------------------------------------------------------- /src/2-emu3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/2-emu3.gif -------------------------------------------------------------------------------- /src/2-omni.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/2-omni.gif -------------------------------------------------------------------------------- /src/vis-img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wdrink/OpenTokenizer/0f298fe6b3397acc358f16207c51e6e2820f413b/src/vis-img.png --------------------------------------------------------------------------------