├── bitvae ├── utils │ ├── __init__.py │ ├── random_numbers.npy │ ├── generate_random_num.py │ ├── logger.py │ ├── init_models.py │ ├── misc.py │ ├── distributed.py │ └── arguments.py ├── data │ ├── __init__.py │ ├── dataset_zoo.py │ └── data.py ├── modules │ ├── quantizer │ │ ├── __init__.py │ │ ├── dynamic_resolution.py │ │ └── multiscale_bsq.py │ ├── cache │ │ └── vgg.pth │ ├── __init__.py │ ├── conv.py │ ├── loss.py │ ├── normalization.py │ └── lpips.py ├── evaluation │ ├── __init__.py │ ├── fid.py │ └── inception.py └── models │ ├── __init__.py │ ├── discriminator.py │ └── d_vae.py ├── .gitignore ├── scripts ├── prepare.sh └── release │ ├── test_img_d16_stage1.sh │ ├── test_img_d16_stage2.sh │ ├── test_img_d32_stage1.sh │ ├── test_img_d32_stage2.sh │ ├── train_img_d32_stage1.sh │ ├── train_img_d16_stage2.sh │ ├── train_img_d32_stage2.sh │ └── train_img_d16_stage1.sh ├── requirements.txt ├── LICENSE ├── README.md ├── eval.py └── train.py /bitvae/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /bitvae/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import ImageData, ImageDataset -------------------------------------------------------------------------------- /bitvae/modules/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .multiscale_bsq import MultiScaleBSQ 2 | -------------------------------------------------------------------------------- /bitvae/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .fid import calculate_frechet_distance 2 | from .inception import InceptionV3 -------------------------------------------------------------------------------- /bitvae/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .discriminator import ImageDiscriminator 2 | from .d_vae import AutoEncoder as d_vae -------------------------------------------------------------------------------- /bitvae/modules/cache/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/BitVAE/HEAD/bitvae/modules/cache/vgg.pth -------------------------------------------------------------------------------- /bitvae/utils/random_numbers.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/BitVAE/HEAD/bitvae/utils/random_numbers.npy -------------------------------------------------------------------------------- /bitvae/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .lpips import LPIPS 2 | from .normalization import Normalize 3 | from .conv import Conv 4 | # from .commitments import DiagonalGaussianDistribution 5 | from .loss import adopt_weight -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | lightning_logs/ 3 | .ipynb_checkpoints/ 4 | *.egg-info 5 | .pyc 6 | results* 7 | logs* 8 | recon.* 9 | dataset/ 10 | **/span.log 11 | wandb/* 12 | *.png 13 | bitvae_results 14 | checkpoints 15 | labels 16 | .idea/ 17 | 
-------------------------------------------------------------------------------- /scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | ### trainer 2 | sudo apt-get install libatlas-base-dev -y 3 | sudo pip3 uninstall numpy -y 4 | pip3 install -r requirements.txt 5 | sudo apt-get install ffmpeg libsm6 libxext6 -y 6 | 7 | pip3 install moviepy==2.0.0.dev2 imageio 8 | pip3 install seaborn 9 | pip3 install einx -------------------------------------------------------------------------------- /bitvae/modules/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Conv(nn.Module): 5 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0): 6 | super().__init__() 7 | 8 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding) 9 | self.stride = stride 10 | self.kernel_size = kernel_size 11 | 12 | def forward(self, x): 13 | return self.conv(x) -------------------------------------------------------------------------------- /bitvae/data/dataset_zoo.py: -------------------------------------------------------------------------------- 1 | DATASET_DICT = { 2 | "imagenet": { 3 | "train_label": "labels/imagenet/train.txt", 4 | "val_label": "labels/imagenet/val.txt", 5 | "dataset_path": "", 6 | "data_type": "image" 7 | }, 8 | "openimages": { 9 | "train_label": "labels/openimages/train.txt", # 9M+ images 10 | "val_label": "labels/openimages/train.txt", 11 | "dataset_path": "", 12 | "data_type": "image" 13 | } 14 | } -------------------------------------------------------------------------------- /bitvae/utils/generate_random_num.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | if __name__ == "__main__": 5 | # Generate random numbers and save them to a file 6 | np.random.seed(42) # Set the seed for reproducibility 7 | version = "v2" # ["v1", "v2"] 8 | if version == "v1": 9 | num_choices = 3 10 | save_path = "random_numbers.npy" 11 | elif version == "v2": 12 | num_choices = 45 # 3 or 45 13 | save_path = "random_numbers_v2.npy" 14 | else: 15 | raise ValueError("Invalid version") 16 | random_numbers = np.random.choice(list(range(num_choices)), size=500000) 17 | np.save(save_path, random_numbers) 18 | loaded_random_numbers = np.load(save_path) 19 | print(loaded_random_numbers[:100]) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # pytorch-lightning==1.8.6 2 | diffusers==0.29.1 3 | h5py==3.11.0 4 | einops==0.8.0 5 | ftfy==6.2.0 6 | imageio==2.34.1 7 | imageio-ffmpeg==0.5.1 8 | regex==2024.5.15 9 | scikit-video==1.1.11 10 | tqdm==4.66.4 11 | av==12.1.0 12 | beartype==0.18.5 13 | scikit-learn==1.5.0 14 | timm==1.0.7 15 | transformers==4.37.2 16 | decord==0.6.0 17 | fvcore==0.1.5.post20221221 18 | pycocotools==2.0.8 19 | cloudpickle==3.0.0 20 | omegaconf==2.3.0 21 | scikit-image==0.24.0 22 | lpips==0.1.4 23 | accelerate 24 | numpy==1.26.2 25 | # pip3 install git+https://github.com/nottombrown/imagenet_stubs 26 | 27 | fairscale==0.4.13 28 | # lightning==2.2.5 29 | opencv-python==4.10.0.84 30 | ipdb==0.13.13 31 | 32 | # torch==2.3.1 33 | # torchvision==0.18.1 34 | # torchaudio==2.3.1 35 | 36 | 37 | einx==0.3.0 -------------------------------------------------------------------------------- 
/bitvae/utils/logger.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/FoundationVision/LlamaGen/blob/main/utils/logger.py 2 | import logging 3 | import glob 4 | import os 5 | import torch.distributed as dist 6 | 7 | def create_logger(logging_dir): 8 | """ 9 | Create a logger that writes to a log file and stdout. 10 | """ 11 | if dist.get_rank() == 0: # real logger 12 | existing_logs = glob.glob(os.path.join(logging_dir, 'log_*.txt')) 13 | log_numbers = [int(log.split('.txt')[0].split('_')[-1]) for log in existing_logs] 14 | next_log_number = max(log_numbers) + 1 if log_numbers else 1 15 | logging.basicConfig( 16 | level=logging.INFO, 17 | format='%(asctime)s %(message)s', 18 | datefmt='%Y-%m-%d %H:%M:%S', 19 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log_{next_log_number}.txt")] 20 | ) 21 | logger = logging.getLogger(__name__) 22 | else: # dummy logger (does nothing) 23 | logger = logging.getLogger(__name__) 24 | logger.addHandler(logging.NullHandler()) 25 | return logger -------------------------------------------------------------------------------- /bitvae/modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def hinge_d_loss(logits_real, logits_fake): 6 | loss_real = torch.mean(F.relu(1. - logits_real)) 7 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 8 | d_loss = 0.5 * (loss_real + loss_fake) 9 | return d_loss 10 | 11 | def vanilla_d_loss(logits_real, logits_fake): 12 | d_loss = 0.5 * ( 13 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 14 | torch.mean(torch.nn.functional.softplus(logits_fake))) 15 | return d_loss 16 | 17 | def get_disc_loss(disc_loss_type): 18 | if disc_loss_type == 'vanilla': 19 | disc_loss = vanilla_d_loss 20 | elif disc_loss_type == 'hinge': 21 | disc_loss = hinge_d_loss 22 | return disc_loss 23 | 24 | def adopt_weight(global_step, threshold=0, value=0., warmup=0): 25 | if global_step < threshold or threshold < 0: 26 | weight = value 27 | else: 28 | weight = 1 29 | if global_step - threshold < warmup: 30 | weight = min((global_step - threshold) / warmup, 1) 31 | return weight -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 FoundationVision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /scripts/release/test_img_d16_stage1.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \ 2 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 3 | --codebook_dim 16 --codebook_size 65536 \ 4 | --vqgan_ckpt "bitvae_results/infinity_d16_stage1/checkpoints/model_step_499999.ckpt" \ 5 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_256_dim16_stage1 --dataaug "resizecrop" --resolution 256 256 --num_workers 0 \ 6 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \ 7 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \ 8 | $@ 9 | 10 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \ 11 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 12 | --codebook_dim 16 --codebook_size 65536 \ 13 | --vqgan_ckpt "bitvae_results/infinity_d16_stage1/checkpoints/model_step_499999.ckpt" \ 14 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_512_dim16_stage1 --dataaug "resizecrop" --resolution 512 512 --num_workers 0 \ 15 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \ 16 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \ 17 | $@ 18 | -------------------------------------------------------------------------------- /scripts/release/test_img_d16_stage2.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \ 2 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 3 | --codebook_dim 16 --codebook_size 65536 \ 4 | --vqgan_ckpt "bitvae_results/infinity_d16_stage2/checkpoints/model_step_199999.ckpt" \ 5 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_256_dim16_stage2 --dataaug "resizecrop" --resolution 256 256 --num_workers 0 \ 6 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \ 7 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \ 8 | $@ 9 | 10 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \ 11 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 12 | --codebook_dim 16 --codebook_size 65536 \ 13 | --vqgan_ckpt "bitvae_results/infinity_d16_stage2/checkpoints/model_step_199999.ckpt" \ 14 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_512_dim16_stage2 --dataaug "resizecrop" --resolution 512 512 --num_workers 0 \ 15 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \ 16 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \ 17 | $@ 18 | -------------------------------------------------------------------------------- /scripts/release/test_img_d32_stage1.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \ 2 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 
2 4 4 4 \ 3 | --codebook_dim 32 --codebook_size 4294967296 \ 4 | --vqgan_ckpt "bitvae_results/infinity_d32_stage1/checkpoints/model_step_499999.ckpt" \ 5 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_256_dim32_stage1 --dataaug "resizecrop" --resolution 256 256 --num_workers 0 \ 6 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \ 7 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \ 8 | $@ 9 | 10 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \ 11 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 12 | --codebook_dim 32 --codebook_size 4294967296 \ 13 | --vqgan_ckpt "bitvae_results/infinity_d32_stage1/checkpoints/model_step_499999.ckpt" \ 14 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_512_dim32_stage1 --dataaug "resizecrop" --resolution 512 512 --num_workers 0 \ 15 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \ 16 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \ 17 | $@ -------------------------------------------------------------------------------- /scripts/release/test_img_d32_stage2.sh: -------------------------------------------------------------------------------- 1 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \ 2 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 3 | --codebook_dim 32 --codebook_size 4294967296 \ 4 | --vqgan_ckpt "bitvae_results/infinity_d32_stage2/checkpoints/model_step_199999.ckpt" \ 5 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_256_dim32_stage2 --dataaug "resizecrop" --resolution 256 256 --num_workers 0 \ 6 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \ 7 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \ 8 | $@ 9 | 10 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \ 11 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 12 | --codebook_dim 32 --codebook_size 4294967296 \ 13 | --vqgan_ckpt "bitvae_results/infinity_d32_stage2/checkpoints/model_step_199999.ckpt" \ 14 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_512_dim32_stage2 --dataaug "resizecrop" --resolution 512 512 --num_workers 0 \ 15 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \ 16 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \ 17 | $@ -------------------------------------------------------------------------------- /scripts/release/train_img_d32_stage1.sh: -------------------------------------------------------------------------------- 1 | WORKER_GPU=8 2 | NODE_NUM=4 3 | NUM_WORKERS=12 4 | 5 | if [[ "$*" == *"--debug"* ]]; then 6 | WORKER_GPU=1 7 | NODE_NUM=1 8 | NUM_WORKERS=0 9 | fi 10 | 11 | torchrun \ 12 | --nproc_per_node=$WORKER_GPU \ 13 | --nnodes=$NODE_NUM --master_addr=$WORKER_0_HOST \ 14 | --node_rank=$NODE_ID --master_port=$PORT \ 15 | train.py --num_workers $NUM_WORKERS \ 16 | --patch_size 16 \ 17 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 18 | --codebook_dim 32 \ 19 | --optim_type AdamW --lr 1e-4 --disable_sch --dis_lr_multiplier 1 --max_steps 500000 \ 20 | --resolution 256 256 --batch_size 8 --dataset_list "openimages" --dataaug "resizecrop" \ 21 | --disc_layers 3 --discriminator_iter_start 50000 \ 22 | --l1_weight 1 --perceptual_weight 1 
--image_disc_weight 1 --image_gan_weight 0.3 --gan_feat_weight 0 --lfq_weight 4 \ 23 | --codebook_size 4294967296 --entropy_loss_weight 0.1 --diversity_gamma 1 \ 24 | --default_root_dir "bitvae_results/infinity_d32_stage1" --log_every 20 --ckpt_every 10000 --visu_every 10000 \ 25 | --new_quant --lr_drop 450000 \ 26 | --remove_residual_detach --use_lecam_reg_zero --base_ch_disc 128 --dis_lr_multiplier 2.0 \ 27 | --schedule_mode "dense" --use_stochastic_depth --drop_rate 0.5 --keep_last_quant --tokenizer 'flux' --quantizer_type 'MultiScaleBSQ' $@ -------------------------------------------------------------------------------- /bitvae/modules/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | 7 | class Normalize(nn.Module): 8 | def __init__(self, in_channels, norm_type): 9 | super().__init__() 10 | assert norm_type in ['group', 'batch', "no"] 11 | if norm_type == 'group': 12 | if in_channels % 32 == 0: 13 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 14 | elif in_channels % 24 == 0: 15 | self.norm = nn.GroupNorm(num_groups=24, num_channels=in_channels, eps=1e-6, affine=True) 16 | else: 17 | raise NotImplementedError 18 | elif norm_type == 'batch': 19 | self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False) # Runtime Error: grad inplace if set track_running_stats to True 20 | elif norm_type == 'no': 21 | self.norm = nn.Identity() 22 | 23 | def forward(self, x): 24 | assert x.ndim == 4 25 | x = self.norm(x) 26 | return x 27 | 28 | def l2norm(t): 29 | return F.normalize(t, dim=-1) 30 | 31 | class LayerNorm(nn.Module): 32 | def __init__(self, dim): 33 | super().__init__() 34 | self.gamma = nn.Parameter(torch.ones(dim)) 35 | self.register_buffer("beta", torch.zeros(dim)) 36 | 37 | def forward(self, x): 38 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 39 | -------------------------------------------------------------------------------- /scripts/release/train_img_d16_stage2.sh: -------------------------------------------------------------------------------- 1 | WORKER_GPU=8 2 | NODE_NUM=4 3 | NUM_WORKERS=12 4 | 5 | if [[ "$*" == *"--debug"* ]]; then 6 | WORKER_GPU=1 7 | NODE_NUM=1 8 | NUM_WORKERS=0 9 | fi 10 | 11 | torchrun \ 12 | --nproc_per_node=$WORKER_GPU \ 13 | --nnodes=$NODE_NUM --master_addr=$WORKER_0_HOST \ 14 | --node_rank=$NODE_ID --master_port=$PORT \ 15 | train.py --num_workers $NUM_WORKERS \ 16 | --patch_size 16 \ 17 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 18 | --codebook_dim 16 \ 19 | --optim_type AdamW --lr 1e-4 --disable_sch --dis_lr_multiplier 1 --max_steps 200000 \ 20 | --resolution 1024 1024 --batch_size 8 --dataset_list "openimages" --dataaug "resizecrop" \ 21 | --disc_layers 3 --discriminator_iter_start 0 \ 22 | --l1_weight 1 --perceptual_weight 1 --image_disc_weight 1 --image_gan_weight 0.3 --gan_feat_weight 0 --lfq_weight 4 \ 23 | --codebook_size 65536 --entropy_loss_weight 0.1 --diversity_gamma 1 \ 24 | --default_root_dir "bitvae_results/infinity_d16_stage2" --log_every 20 --ckpt_every 5000 --visu_every 5000 \ 25 | --new_quant --lr_drop 150000 \ 26 | --remove_residual_detach --use_lecam_reg_zero --base_ch_disc 128 --dis_lr_multiplier 2.0 --use_checkpoint \ 27 | --schedule_mode "dense" --use_stochastic_depth --drop_rate 0.5 --keep_last_quant --tokenizer 'flux' --quantizer_type 'MultiScaleBSQ' \ 28 | 
--pretrained "bitvae_results/infinity_d16_stage1/checkpoints/model_step_499999.ckpt" --not_load_optimizer --multiscale_training $@ -------------------------------------------------------------------------------- /scripts/release/train_img_d32_stage2.sh: -------------------------------------------------------------------------------- 1 | WORKER_GPU=8 2 | NODE_NUM=4 3 | NUM_WORKERS=12 4 | 5 | if [[ "$*" == *"--debug"* ]]; then 6 | WORKER_GPU=1 7 | NODE_NUM=1 8 | NUM_WORKERS=0 9 | fi 10 | 11 | torchrun \ 12 | --nproc_per_node=$WORKER_GPU \ 13 | --nnodes=$NODE_NUM --master_addr=$WORKER_0_HOST \ 14 | --node_rank=$NODE_ID --master_port=$PORT \ 15 | train.py --num_workers $NUM_WORKERS \ 16 | --patch_size 16 \ 17 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 18 | --codebook_dim 32 \ 19 | --optim_type AdamW --lr 1e-4 --disable_sch --dis_lr_multiplier 1 --max_steps 200000 \ 20 | --resolution 1024 1024 --batch_size 8 --dataset_list "openimages" --dataaug "resizecrop" \ 21 | --disc_layers 3 --discriminator_iter_start 0 \ 22 | --l1_weight 1 --perceptual_weight 1 --image_disc_weight 1 --image_gan_weight 0.3 --gan_feat_weight 0 --lfq_weight 4 \ 23 | --codebook_size 4294967296 --entropy_loss_weight 0.1 --diversity_gamma 1 \ 24 | --default_root_dir "bitvae_results/infinity_d32_stage2" --log_every 20 --ckpt_every 5000 --visu_every 5000 \ 25 | --new_quant --lr_drop 150000 \ 26 | --remove_residual_detach --use_lecam_reg_zero --base_ch_disc 128 --dis_lr_multiplier 2.0 --use_checkpoint \ 27 | --schedule_mode "dense" --use_stochastic_depth --drop_rate 0.5 --keep_last_quant --tokenizer 'flux' --quantizer_type 'MultiScaleBSQ' \ 28 | --pretrained "bitvae_results/infinity_d32_stage1/checkpoints/model_step_499999.ckpt" --not_load_optimizer --multiscale_training $@ -------------------------------------------------------------------------------- /bitvae/utils/init_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from bitvae.utils.misc import is_torch_optimizer 6 | 7 | def load_unstrictly(state_dict, model, loaded_keys=[]): 8 | missing_keys = [] 9 | for name, param in model.named_parameters(): 10 | if name in state_dict: 11 | try: 12 | param.data.copy_(state_dict[name]) 13 | except: 14 | # print(f"{name} mismatch: param {name}, shape {param.data.shape}, state_dict shape {state_dict[name].shape}") 15 | missing_keys.append(name) 16 | elif name not in loaded_keys: 17 | missing_keys.append(name) 18 | return model, missing_keys 19 | 20 | def resume_from_ckpt(state_dict, model_optims, load_optimizer=True): 21 | all_missing_keys = [] 22 | # load weights first 23 | for k in model_optims: 24 | if model_optims[k] and (not is_torch_optimizer(model_optims[k])) and k in state_dict: 25 | model_optims[k], missing_keys = load_unstrictly(state_dict[k], model_optims[k]) 26 | all_missing_keys += missing_keys 27 | 28 | if len(all_missing_keys) == 0 and load_optimizer: 29 | print("Loading optimizer states") 30 | for k in model_optims: 31 | if model_optims[k] and is_torch_optimizer(model_optims[k]) and k in state_dict: 32 | model_optims[k].load_state_dict(state_dict[k]) 33 | else: 34 | print(f"missing weights: {all_missing_keys}, do not load optimzer states") 35 | return model_optims, state_dict["step"] 36 | -------------------------------------------------------------------------------- /scripts/release/train_img_d16_stage1.sh: 
-------------------------------------------------------------------------------- 1 | if [[ "$ARNOLD_DEVICE_TYPE" == *A100* ]]; then 2 | IB_HCA=mlx5 3 | export NCCL_IB_HCA=$IB_HCA 4 | else 5 | IB_HCA=$ARNOLD_RDMA_DEVICE:1 6 | fi 7 | 8 | if [[ "$RUNTIME_IDC_NAME" == *uswest2* ]]; then 9 | IDC_NAME=bond0 10 | export NCCL_SOCKET_IFNAME=$IDC_NAME 11 | else 12 | IDC_NAME=eth0 13 | fi 14 | 15 | port=$(echo "$ARNOLD_WORKER_0_PORT" | cut -d "," -f 1) 16 | echo $port 17 | 18 | export NCCL_DEBUG=WARN 19 | export NCCL_IB_DISABLE=0 20 | export NCCL_IB_GID_INDEX=3 21 | 22 | NUM_WORKERS=12 23 | 24 | if [[ "$*" == *"--debug"* ]]; then 25 | ARNOLD_WORKER_NUM=1 26 | ARNOLD_WORKER_GPU=1 27 | NUM_WORKERS=0 28 | fi 29 | 30 | torchrun \ 31 | --nproc_per_node=$ARNOLD_WORKER_GPU \ 32 | --nnodes=$ARNOLD_WORKER_NUM --master_addr=$ARNOLD_WORKER_0_HOST \ 33 | --node-rank=$ARNOLD_ID --master_port=$PORT \ 34 | train.py --num_workers $NUM_WORKERS \ 35 | --patch_size 16 \ 36 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \ 37 | --codebook_dim 16 \ 38 | --optim_type AdamW --lr 1e-4 --disable_sch --dis_lr_multiplier 1 --max_steps 500000 \ 39 | --resolution 256 256 --batch_size 8 --dataset_list "imagenet" --dataaug "resizecrop" \ 40 | --disc_layers 3 --discriminator_iter_start 50000 \ 41 | --l1_weight 1 --perceptual_weight 1 --image_disc_weight 1 --image_gan_weight 0.3 --gan_feat_weight 0 --lfq_weight 4 \ 42 | --codebook_size 65536 --entropy_loss_weight 0.1 --diversity_gamma 1 \ 43 | --default_root_dir "bitvae_results/infinity_d16_stage1" --log_every 20 --ckpt_every 10000 --visu_every 10000 \ 44 | --new_quant --lr_drop 450000 \ 45 | --remove_residual_detach --use_lecam_reg_zero --base_ch_disc 128 --dis_lr_multiplier 2.0 \ 46 | --schedule_mode "dense" --use_stochastic_depth --drop_rate 0.5 --keep_last_quant --tokenizer 'flux' --quantizer_type 'MultiScaleBSQ' $@ -------------------------------------------------------------------------------- /bitvae/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Foundationvision, Inc. 
All Rights Reserved 2 | 3 | import torch 4 | import torch.distributed as dist 5 | import imageio 6 | import os 7 | import random 8 | 9 | import math 10 | import numpy as np 11 | import skvideo.io 12 | from einops import rearrange 13 | import torch.optim as optim 14 | 15 | from contextlib import contextmanager 16 | 17 | ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16} 18 | 19 | def rank_zero_only(fn): 20 | def wrapped_fn(*args, **kwargs): 21 | if not dist.is_initialized() or dist.get_rank() == 0: 22 | return fn(*args, **kwargs) 23 | return wrapped_fn 24 | 25 | def is_torch_optimizer(obj): 26 | return isinstance(obj, optim.Optimizer) 27 | 28 | def rearranged_forward(x, func): 29 | x = rearrange(x, "B C H W -> B H W C") 30 | x = func(x) 31 | x = rearrange(x, "B H W C -> B C H W") 32 | return x 33 | 34 | def is_dtype_16(data): 35 | return data.dtype == torch.float16 or data.dtype == torch.bfloat16 36 | 37 | @contextmanager 38 | def set_tf32_flags(flag): 39 | old_matmul_flag = torch.backends.cuda.matmul.allow_tf32 40 | old_cudnn_flag = torch.backends.cudnn.allow_tf32 41 | torch.backends.cuda.matmul.allow_tf32 = flag 42 | torch.backends.cudnn.allow_tf32 = flag 43 | try: 44 | yield 45 | finally: 46 | # Restore the original flags 47 | torch.backends.cuda.matmul.allow_tf32 = old_matmul_flag 48 | torch.backends.cudnn.allow_tf32 = old_cudnn_flag 49 | 50 | def get_last_ckpt(root_dir): 51 | if not os.path.exists(root_dir): return None 52 | ckpt_files = {} 53 | for dirpath, dirnames, filenames in os.walk(root_dir): 54 | for filename in filenames: 55 | if filename.endswith('.ckpt'): 56 | num_iter = int(filename.split('.ckpt')[0].split('_')[-1]) 57 | ckpt_files[num_iter]=os.path.join(dirpath, filename) 58 | iter_list = list(ckpt_files.keys()) 59 | if len(iter_list) == 0: return None 60 | max_iter = max(iter_list) 61 | return ckpt_files[max_iter] 62 | 63 | -------------------------------------------------------------------------------- /bitvae/modules/quantizer/dynamic_resolution.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import tqdm 4 | 5 | vae_stride = 16 6 | ratio2hws = { 7 | 1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64),(80,80),(96,96),(128,128)], 8 | 1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56),(90,72),(110,88),(140,112)], 9 | 1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54),(96,72),(120,90),(144,108)], 10 | 1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52),(96,64),(126,84),(156,104)], 11 | 1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48),(112,64),(140,80),(168,96)], 12 | 2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45),(120,60),(148,74),(180,90)], 13 | 2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40),(130,52),(160,64),(200,80)], 14 | 3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37),(144,48),(180,60),(222,74)], 15 | } 16 | full_ratio2hws = {} 17 | for ratio, hws in ratio2hws.items(): 18 | full_ratio2hws[ratio] = hws 19 | full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws] 20 | 21 | dynamic_resolution_h_w = {} 22 | 
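# The tables above and the loop below tie scale schedules to image resolutions: a schedule
# of length `leng` ends at feature map full_ratio2hws[ratio][leng-1], and the image size is
# that (h, w) times vae_stride. For ratio 1.000:
#   7 scales  -> (16, 16)  latents -> 256 x 256 pixels
#   10 scales -> (32, 32)  latents -> 512 x 512 pixels
#   13 scales -> (64, 64)  latents -> 1024 x 1024 pixels
#   16 scales -> (128, 128) latents -> 2048 x 2048 pixels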
predefined_HW_Scales_dynamic = {} 23 | aspect_ratio_scale_list = [] 24 | bs_dict = {7: 8, 10: 4, 13: 1, 16: 1} # 256x256: batch=8, 512x512: batch=4, 1024x1024: batch=1 (bs=1 avoid OOM) 25 | for ratio in full_ratio2hws: 26 | dynamic_resolution_h_w[ratio] ={} 27 | for ind, leng in enumerate([7, 10, 13, 16]): 28 | h, w = full_ratio2hws[ratio][leng-1][0], full_ratio2hws[ratio][leng-1][1] # feature map size 29 | pixel = (h * vae_stride, w * vae_stride) # The original image (H, W) 30 | dynamic_resolution_h_w[ratio][pixel[1]] = { 31 | 'pixel': pixel, 32 | 'scales': full_ratio2hws[ratio][:leng] 33 | } # W as key 34 | predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng] 35 | # deal with aspect_ratio_scale_list 36 | info_dict = {"ratio": ratio, "h": h * vae_stride, "w": w * vae_stride, "bs": bs_dict[leng]} 37 | aspect_ratio_scale_list.append(info_dict) 38 | -------------------------------------------------------------------------------- /bitvae/evaluation/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | 4 | 5 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 6 | """Numpy implementation of the Frechet Distance. 7 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 8 | and X_2 ~ N(mu_2, C_2) is 9 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 10 | 11 | Stable version by Dougal J. Sutherland. 12 | 13 | Params: 14 | -- mu1 : Numpy array containing the activations of a layer of the 15 | inception net (like returned by the function 'get_predictions') 16 | for generated samples. 17 | -- mu2 : The sample mean over activations, precalculated on an 18 | representative data set. 19 | -- sigma1: The covariance matrix over activations for generated samples. 20 | -- sigma2: The covariance matrix over activations, precalculated on an 21 | representative data set. 22 | 23 | Returns: 24 | -- : The Frechet Distance. 
25 | """ 26 | 27 | mu1 = np.atleast_1d(mu1) 28 | mu2 = np.atleast_1d(mu2) 29 | 30 | sigma1 = np.atleast_2d(sigma1) 31 | sigma2 = np.atleast_2d(sigma2) 32 | 33 | assert ( 34 | mu1.shape == mu2.shape 35 | ), "Training and test mean vectors have different lengths" 36 | assert ( 37 | sigma1.shape == sigma2.shape 38 | ), "Training and test covariances have different dimensions" 39 | 40 | diff = mu1 - mu2 41 | 42 | # Product might be almost singular 43 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 44 | if not np.isfinite(covmean).all(): 45 | msg = ( 46 | "fid calculation produces singular product; " 47 | "adding %s to diagonal of cov estimates" 48 | ) % eps 49 | print(msg) 50 | offset = np.eye(sigma1.shape[0]) * eps 51 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 52 | 53 | # Numerical error might give slight imaginary component 54 | if np.iscomplexobj(covmean): 55 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 56 | m = np.max(np.abs(covmean.imag)) 57 | raise ValueError("Imaginary component {}".format(m)) 58 | covmean = covmean.real 59 | 60 | tr_covmean = np.trace(covmean) 61 | 62 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean -------------------------------------------------------------------------------- /bitvae/utils/distributed.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/FoundationVision/LlamaGen/blob/main/utils/distributed.py 2 | import os 3 | import torch 4 | import subprocess 5 | import torch.distributed as dist 6 | from bitvae.utils.misc import rank_zero_only 7 | 8 | def setup_for_distributed(is_master): 9 | """ 10 | This function disables printing when not in master process 11 | """ 12 | import builtins as __builtin__ 13 | builtin_print = __builtin__.print 14 | 15 | def print(*args, **kwargs): 16 | force = kwargs.pop('force', False) 17 | if is_master or force: 18 | builtin_print(*args, **kwargs) 19 | 20 | __builtin__.print = print 21 | 22 | def init_distributed_mode(args): 23 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 24 | args.rank = int(os.environ["RANK"]) 25 | args.world_size = int(os.environ['WORLD_SIZE']) 26 | args.gpu = int(os.environ['LOCAL_RANK']) 27 | args.dist_url = 'env://' 28 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 29 | elif 'SLURM_PROCID' in os.environ: 30 | proc_id = int(os.environ['SLURM_PROCID']) 31 | ntasks = int(os.environ['SLURM_NTASKS']) 32 | node_list = os.environ['SLURM_NODELIST'] 33 | num_gpus = torch.cuda.device_count() 34 | addr = subprocess.getoutput( 35 | 'scontrol show hostname {} | head -n1'.format(node_list)) 36 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') 37 | os.environ['MASTER_ADDR'] = addr 38 | os.environ['WORLD_SIZE'] = str(ntasks) 39 | os.environ['RANK'] = str(proc_id) 40 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 41 | os.environ['LOCAL_SIZE'] = str(num_gpus) 42 | args.dist_url = 'env://' 43 | args.world_size = ntasks 44 | args.rank = proc_id 45 | args.gpu = proc_id % num_gpus 46 | else: 47 | print('Not using distributed mode') 48 | args.distributed = False 49 | return 50 | 51 | args.distributed = True 52 | 53 | torch.cuda.set_device(args.gpu) 54 | args.dist_backend = 'nccl' 55 | print('| distributed init (rank {}): {}'.format( 56 | args.rank, args.dist_url), flush=True) 57 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 58 | world_size=args.world_size, rank=args.rank) 59 | 
torch.distributed.barrier() 60 | setup_for_distributed(args.rank == 0) 61 | 62 | 63 | def reduce_losses(loss_dict, dst=0): 64 | loss_names = list(loss_dict.keys()) 65 | loss_tensor = torch.stack([loss_dict[name] for name in loss_names]) 66 | 67 | dist.reduce(loss_tensor, dst=dst, op=dist.ReduceOp.SUM) 68 | # Only average the loss values on the destination rank 69 | if dist.get_rank() == dst: 70 | loss_tensor /= dist.get_world_size() 71 | averaged_losses = {name: loss_tensor[i].item() for i, name in enumerate(loss_names)} 72 | else: 73 | averaged_losses = {name: None for name in loss_names} 74 | 75 | return averaged_losses 76 | 77 | @rank_zero_only 78 | def average_losses(loss_dict_list): 79 | sum_dict = {} 80 | count_dict = {} 81 | for loss_dict in loss_dict_list: 82 | for key, value in loss_dict.items(): 83 | if key in sum_dict: 84 | sum_dict[key] += value 85 | count_dict[key] += 1 86 | else: 87 | sum_dict[key] = value 88 | count_dict[key] = 1 89 | 90 | avg_dict = {key: sum_dict[key] / count_dict[key] for key in sum_dict} 91 | return avg_dict -------------------------------------------------------------------------------- /bitvae/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch.nn as nn 3 | 4 | from bitvae.modules.normalization import Normalize 5 | 6 | class DiscriminatorPool: 7 | def __init__(self, pool_size): 8 | self.pool_size = int(pool_size) 9 | self.num_imgs = 0 10 | self.images = [] 11 | 12 | def query(self, images): 13 | if self.pool_size == 0: 14 | return images 15 | 16 | return_images = [] 17 | for image in images: 18 | if self.num_imgs < self.pool_size: 19 | self.images.append(image) 20 | self.num_imgs += 1 21 | return_images.append(image) 22 | else: 23 | if random.uniform(0, 1) > 0.5: 24 | i = random.randint(0, self.pool_size - 1) 25 | tmp = self.images[i].clone() 26 | self.images[i] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return torch.stack(return_images) 31 | 32 | class ImageDiscriminator(nn.Module): 33 | def __init__(self, args): 34 | super().__init__() 35 | self.discriminator = NLayerDiscriminator(ndf=args.base_ch_disc) # by default using PatchGAN 36 | self.disc_pool = args.disc_pool # be default "no" 37 | if args.disc_pool == "yes": 38 | self.real_pool = DiscriminatorPool(pool_size=args.batch_size[0] * args.disc_pool_size) 39 | self.fake_pool = DiscriminatorPool(pool_size=args.batch_size[0] * args.disc_pool_size) 40 | 41 | def forward(self, x, pool_name=None): 42 | if pool_name and self.disc_pool == "yes": 43 | assert pool_name in ["real", "fake"] 44 | if pool_name == "real": 45 | x = self.real_pool.query(x) 46 | elif pool_name == "fake": 47 | x = self.fake_pool.query(x) 48 | # by default without pool 49 | return self.discriminator(x) 50 | 51 | class NLayerDiscriminator(nn.Module): 52 | """Defines a PatchGAN discriminator as in Pix2Pix 53 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 54 | """ 55 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 56 | """Construct a PatchGAN discriminator 57 | Parameters: 58 | input_nc (int) -- the number of channels in input images 59 | ndf (int) -- the number of filters in the last conv layer 60 | n_layers (int) -- the number of conv layers in the discriminator 61 | norm_layer -- normalization layer 62 | """ 63 | super(NLayerDiscriminator, self).__init__() 64 | norm_type = "batch" 65 | use_bias = norm_type != "batch" 66 | 67 | 
kw = 4 68 | padw = 1 69 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 70 | nf_mult = 1 71 | nf_mult_prev = 1 72 | for n in range(1, n_layers): # gradually increase the number of filters 73 | nf_mult_prev = nf_mult 74 | nf_mult = min(2 ** n, 8) 75 | sequence += [ 76 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 77 | Normalize(ndf * nf_mult, norm_type=norm_type), 78 | nn.LeakyReLU(0.2, True) 79 | ] 80 | 81 | nf_mult_prev = nf_mult 82 | nf_mult = min(2 ** n_layers, 8) 83 | sequence += [ 84 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 85 | Normalize(ndf * nf_mult, norm_type=norm_type), 86 | nn.LeakyReLU(0.2, True) 87 | ] 88 | 89 | sequence += [ 90 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 91 | self.main = nn.Sequential(*sequence) 92 | 93 | self.apply(self._init_weights) 94 | 95 | def _init_weights(self, module): 96 | if isinstance(module, nn.Conv2d): 97 | nn.init.normal_(module.weight.data, 0.0, 0.02) 98 | elif isinstance(module, nn.BatchNorm2d): 99 | nn.init.normal_(module.weight.data, 1.0, 0.02) 100 | nn.init.constant_(module.bias.data, 0) 101 | 102 | def forward(self, input): 103 | """Standard forward.""" 104 | return self.main(input) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bitwise Visual Tokenizer 2 | The training and inference code of bitwise tokenizer used by [Infinity](https://github.com/FoundationVision/Infinity). 3 | 4 | ### BitVAE Model ZOO 5 | We provide Infinity models for you to play with, which are on or can be downloaded from the following links: 6 | 7 | ### Visual Tokenizer 8 | 9 | | vocabulary | stride | IN-256 rFID $\downarrow$ | IN-256 PSNR $\uparrow$ | IN-512 rFID $\downarrow$ | IN-512 PSNR $\uparrow$ | HF weights🤗 | 10 | |:----------:|:-----:|:--------:|:---------:|:-------:|:-------:|:------------------------------------------------------------------------------------| 11 | | $V_d=2^{16}$ | 16 | 1.22 | 20.9 | 0.31 | 22.6 | [infinity_vae_d16.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d16.pth) | 12 | | $V_d=2^{24}$ | 16 | 0.75 | 22.0 | 0.30 | 23.5 | [infinity_vae_d24.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d24.pth) | 13 | | $V_d=2^{32}$ | 16 | 0.61 | 22.7 | 0.23 | 24.4 | [infinity_vae_d32.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d32.pth) | 14 | | $V_d=2^{64}$ | 16 | 0.33 | 24.9 | 0.15 | 26.4 | [infinity_vae_d64.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d64.pth) | 15 | | $V_d=2^{32}$ | 16 | 0.75 | 21.9 | 0.32 | 23.6 | [infinity_vae_d32_reg.pth](https://huggingface.co/FoundationVision/Infinity/blob/main/infinity_vae_d32reg.pth) | 16 | 17 | ### Environment installation 18 | ``` 19 | bash scripts/prepare.sh 20 | ``` 21 | 22 | Download `checkpoints` and `labels` from [Google Drive](https://drive.google.com/drive/folders/15VCFUpcv1ktU7RR3Yw_LqI4vuR5Q2r0y?usp=sharing) and put them under the project folder. If you want to use our trained model weights, please also download `bitvae_results`. 23 | We expect that the data is organized as below. 
24 | ``` 25 | ${PROJECT_ROOT} 26 | -- bitvae 27 | -- bitvae_results 28 | -- Infinity_d16_stage1 29 | -- Infinity_d16_stage2 30 | -- Infinity_d32_stage1 31 | -- Infinity_d32_stage2 32 | -- checkpoints 33 | -- labels 34 | -- scripts 35 | -- test 36 | ... 37 | ``` 38 | 39 | 40 | ### Training 41 | Before training, please generate a `labels/openimages/train.txt` according to our provided `labels/imagenet/val_example.txt`. please replace with the real path on your system. 42 | 43 | Tokenizer with hidden dimension 16 44 | ``` 45 | bash scripts/release/train_img_d16_stage1.sh # stage 1: single-scale pre-training 46 | bash scripts/release/train_img_d16_stage2.sh # stage 2: multi-scale fine-tuning 47 | ``` 48 | Tokenizer with hidden dimension 32 49 | ``` 50 | bash scripts/release/train_img_d32_stage1.sh # stage 1: single-scale pre-training 51 | bash scripts/release/train_img_d32_stage2.sh # stage 2: multi-scale fine-tuning 52 | ``` 53 | 54 | ### Testing & evaluation 55 | Before testing, please generate a `labels/imagenet/val.txt` according to our provided `labels/imagenet/val_example.txt`. please replace with the real path on your system. 56 | 57 | Tokenizer with hidden dimension 16 58 | ``` 59 | bash scripts/release/test_img_d16_stage1.sh 60 | bash scripts/release/test_img_d16_stage2.sh 61 | ``` 62 | Tokenizer with hidden dimension 32 63 | ``` 64 | bash scripts/release/test_img_d32_stage1.sh 65 | bash scripts/release/test_img_d32_stage2.sh 66 | ``` 67 | ### 📖 Citation 68 | If our work assists your research, feel free to give us a star ⭐ or cite us using: 69 | 70 | ``` 71 | @misc{Infinity, 72 | title={Infinity: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis}, 73 | author={Jian Han and Jinlai Liu and Yi Jiang and Bin Yan and Yuqi Zhang and Zehuan Yuan and Bingyue Peng and Xiaobing Liu}, 74 | year={2024}, 75 | eprint={2412.04431}, 76 | archivePrefix={arXiv}, 77 | primaryClass={cs.CV}, 78 | url={https://arxiv.org/abs/2412.04431}, 79 | } 80 | ``` 81 | 82 | ``` 83 | @misc{VAR, 84 | title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction}, 85 | author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang}, 86 | year={2024}, 87 | eprint={2404.02905}, 88 | archivePrefix={arXiv}, 89 | primaryClass={cs.CV}, 90 | url={https://arxiv.org/abs/2404.02905}, 91 | } 92 | ``` 93 | 94 | ### License 95 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 
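As a programmatic complement to the test scripts above, the sketch below mirrors the checkpoint-loading logic in `eval.py`. It is only an outline: the checkpoint path is illustrative, and the model itself is built from the same argparse configuration that the release scripts pass to `eval.py`.

```python
# Hedged sketch of checkpoint loading, following eval.py; the path is illustrative.
import torch

ckpt_path = "bitvae_results/infinity_d16_stage1/checkpoints/model_step_499999.ckpt"
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
print(list(state_dict.keys()))  # eval.py reads the tokenizer weights from state_dict["vae"]

# Model construction goes through the argparse config, as in eval.py:
#   from bitvae.models import d_vae           # AutoEncoder
#   vae = d_vae(args)
#   vae.load_state_dict(state_dict["vae"])
```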
96 | -------------------------------------------------------------------------------- /bitvae/modules/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import os, hashlib 4 | import requests 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torchvision import models 10 | from collections import namedtuple 11 | 12 | from bitvae.utils.misc import set_tf32_flags 13 | 14 | URL_MAP = { 15 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 16 | } 17 | 18 | CKPT_MAP = { 19 | "vgg_lpips": "vgg.pth" 20 | } 21 | 22 | MD5_MAP = { 23 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 24 | } 25 | 26 | def download(url, local_path, chunk_size=1024): 27 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 28 | with requests.get(url, stream=True) as r: 29 | total_size = int(r.headers.get("content-length", 0)) 30 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 31 | with open(local_path, "wb") as f: 32 | for data in r.iter_content(chunk_size=chunk_size): 33 | if data: 34 | f.write(data) 35 | pbar.update(chunk_size) 36 | 37 | 38 | def md5_hash(path): 39 | with open(path, "rb") as f: 40 | content = f.read() 41 | return hashlib.md5(content).hexdigest() 42 | 43 | 44 | def get_ckpt_path(name, root, check=False): 45 | assert name in URL_MAP 46 | path = os.path.join(root, CKPT_MAP[name]) 47 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 48 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 49 | download(URL_MAP[name], path) 50 | md5 = md5_hash(path) 51 | assert md5 == MD5_MAP[name], md5 52 | return path 53 | 54 | 55 | class LPIPS(nn.Module): 56 | # Learned perceptual metric 57 | def __init__(self, use_dropout=True, upcast_tf32=False): 58 | super().__init__() 59 | self.upcast_tf32 = upcast_tf32 60 | self.scaling_layer = ScalingLayer() 61 | self.chns = [64, 128, 256, 512, 512] # vg16 features 62 | self.net = vgg16(pretrained=True, requires_grad=False) 63 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 64 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 65 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 66 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 67 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 68 | self.load_from_pretrained() 69 | for param in self.parameters(): 70 | param.requires_grad = False 71 | 72 | def load_from_pretrained(self, name="vgg_lpips"): 73 | ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")) 74 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu"), weights_only=True), strict=False) 75 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 76 | 77 | @classmethod 78 | def from_pretrained(cls, name="vgg_lpips"): 79 | if name is not "vgg_lpips": 80 | raise NotImplementedError 81 | model = cls() 82 | ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")) 83 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu"), weights_only=True), strict=False) 84 | return model 85 | 86 | def forward(self, input, target): 87 | with set_tf32_flags(not self.upcast_tf32): 88 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 89 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 90 | 
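            # compare unit-normalized features from the five VGG slices: per-layer squared
            # differences pass through the learned 1x1 'lin' heads and are spatially averaged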
feats0, feats1, diffs = {}, {}, {} 91 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 92 | for kk in range(len(self.chns)): 93 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 94 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 95 | 96 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 97 | val = res[0] 98 | for l in range(1, len(self.chns)): 99 | # print(res[l].shape) 100 | val += res[l] 101 | 102 | return val 103 | 104 | 105 | class ScalingLayer(nn.Module): 106 | def __init__(self): 107 | super(ScalingLayer, self).__init__() 108 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 109 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 110 | 111 | def forward(self, inp): 112 | return (inp - self.shift) / self.scale 113 | 114 | 115 | class NetLinLayer(nn.Module): 116 | """ A single linear layer which does a 1x1 conv """ 117 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 118 | super(NetLinLayer, self).__init__() 119 | layers = [nn.Dropout(), ] if (use_dropout) else [] 120 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 121 | self.model = nn.Sequential(*layers) 122 | 123 | 124 | class vgg16(torch.nn.Module): 125 | def __init__(self, requires_grad=False, pretrained=True): 126 | super(vgg16, self).__init__() 127 | # load locally 128 | assert pretrained == True 129 | vgg_model = models.vgg16() 130 | vgg_model.load_state_dict(torch.load("checkpoints/vgg16-397923af.pth", weights_only=True)) 131 | vgg_pretrained_features = vgg_model.features 132 | 133 | self.slice1 = torch.nn.Sequential() 134 | self.slice2 = torch.nn.Sequential() 135 | self.slice3 = torch.nn.Sequential() 136 | self.slice4 = torch.nn.Sequential() 137 | self.slice5 = torch.nn.Sequential() 138 | self.N_slices = 5 139 | for x in range(4): 140 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 141 | for x in range(4, 9): 142 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 143 | for x in range(9, 16): 144 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 145 | for x in range(16, 23): 146 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 147 | for x in range(23, 30): 148 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 149 | if not requires_grad: 150 | for param in self.parameters(): 151 | param.requires_grad = False 152 | 153 | def forward(self, X): 154 | h = self.slice1(X) 155 | h_relu1_2 = h 156 | h = self.slice2(h) 157 | h_relu2_2 = h 158 | h = self.slice3(h) 159 | h_relu3_3 = h 160 | h = self.slice4(h) 161 | h_relu4_3 = h 162 | h = self.slice5(h) 163 | h_relu5_3 = h 164 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 165 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 166 | return out 167 | 168 | 169 | def normalize_tensor(x,eps=1e-10): 170 | # norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 171 | norm_factor = x.norm(p=2, dim=1, keepdim=True) 172 | return x/(norm_factor+eps) 173 | 174 | 175 | def spatial_average(x, keepdim=True): 176 | return x.mean([2,3],keepdim=keepdim) 177 | -------------------------------------------------------------------------------- /bitvae/utils/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from bitvae.models import d_vae 4 | 5 | def add_model_specific_args(args, parser): 6 | 
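    # dispatch to the tokenizer-specific argument group; only the 'flux' tokenizer
    # (bitvae.models.d_vae.AutoEncoder) is wired up in this release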
if args.tokenizer == "flux": 7 | parser = d_vae.add_model_specific_args(parser) # flux config 8 | d_vae_model = d_vae 9 | else: 10 | raise NotImplementedError 11 | return args, parser, d_vae_model 12 | 13 | class MainArgs: 14 | @staticmethod 15 | def add_main_args(parser): 16 | # training 17 | parser.add_argument('--max_steps', type=int, default=1e6) 18 | parser.add_argument('--log_every', type=int, default=1) 19 | parser.add_argument('--visu_every', type=int, default=1000) 20 | parser.add_argument('--ckpt_every', type=int, default=1000) 21 | parser.add_argument('--default_root_dir', type=str, required=True) 22 | parser.add_argument('--multiscale_training', action="store_true") 23 | 24 | # optimization 25 | parser.add_argument('--lr', type=float, default=1e-4) 26 | parser.add_argument('--beta1', type=float, default=0.9) 27 | parser.add_argument('--beta2', type=float, default=0.95) 28 | parser.add_argument('--warmup_steps', type=int, default=0) 29 | parser.add_argument('--optim_type', type=str, default="Adam", choices=["Adam", "AdamW"]) 30 | parser.add_argument('--disc_optim_type', type=str, default=None, choices=[None, "rmsprop"]) 31 | parser.add_argument('--lr_min', type=float, default=0.) 32 | parser.add_argument('--warmup_lr_init', type=float, default=0.) 33 | parser.add_argument('--max_grad_norm', type=float, default=1.0) 34 | parser.add_argument('--max_grad_norm_disc', type=float, default=1.0) 35 | parser.add_argument('--disable_sch', action="store_true") 36 | 37 | # basic d_vae config 38 | parser.add_argument('--patch_size', type=int, default=8) 39 | parser.add_argument('--codebook_dim', type=int, default=16) 40 | parser.add_argument('--quantizer_type', type=str, default='MultiScaleLFQ') 41 | 42 | parser.add_argument('--new_quant', action="store_true") # use new quantization (fix the potential bugs of the old quantizer) 43 | parser.add_argument('--use_decay_factor', action="store_true") 44 | parser.add_argument('--use_stochastic_depth', action="store_true") 45 | parser.add_argument("--drop_rate", type=float, default=0.0) 46 | parser.add_argument('--schedule_mode', type=str, default="original", choices=["original", "dynamic", "dense", "same1", "same2", "same3", "half", "dense_f8"]) 47 | parser.add_argument('--lr_drop', nargs='*', type=int, default=None, help="A list of numeric values. 
Example: --values 270 300") 48 | parser.add_argument('--lr_drop_rate', type=float, default=0.1) 49 | parser.add_argument('--keep_first_quant', action="store_true") 50 | parser.add_argument('--keep_last_quant', action="store_true") 51 | parser.add_argument('--remove_residual_detach', action="store_true") 52 | parser.add_argument('--use_out_phi', action="store_true") 53 | parser.add_argument('--use_out_phi_res', action="store_true") 54 | parser.add_argument('--lecam_weight', type=float, default=0.05) 55 | parser.add_argument('--perceptual_model', type=str, default="vgg16", choices=["vgg16"]) 56 | parser.add_argument('--base_ch_disc', type=int, default=64) 57 | parser.add_argument('--random_flip', action="store_true") 58 | parser.add_argument('--flip_prob', type=float, default=0.5) 59 | parser.add_argument('--flip_mode', type=str, default="stochastic", choices=["stochastic"]) 60 | parser.add_argument('--max_flip_lvl', type=int, default=1) 61 | parser.add_argument('--not_load_optimizer', action="store_true") 62 | parser.add_argument('--use_lecam_reg_zero', action="store_true") 63 | parser.add_argument('--rm_downsample', action="store_true") 64 | parser.add_argument('--random_flip_1lvl', action="store_true") 65 | parser.add_argument('--flip_lvl_idx', type=int, default=0) 66 | parser.add_argument('--drop_when_test', action="store_true") 67 | parser.add_argument('--drop_lvl_idx', type=int, default=None) 68 | parser.add_argument('--drop_lvl_num', type=int, default=0) 69 | parser.add_argument('--compute_all_commitment', action="store_true") 70 | parser.add_argument('--disable_codebook_usage', action="store_true") 71 | parser.add_argument('--random_short_schedule', action="store_true") 72 | parser.add_argument('--short_schedule_prob', type=float, default=0.5) 73 | parser.add_argument('--disable_flip_prob', type=float, default=0.0) 74 | parser.add_argument('--zeta', type=float, default=1.0) # entropy penalty weight 75 | parser.add_argument('--disable_codebook_usage_bit', action="store_true") 76 | parser.add_argument('--gamma', type=float, default=1.0) # loss weight of H(E[p(c|u)]) 77 | parser.add_argument('--uniform_short_schedule', action="store_true") 78 | 79 | # discriminator config 80 | parser.add_argument('--dis_warmup_steps', type=int, default=0) 81 | parser.add_argument('--dis_lr_multiplier', type=float, default=1.) 82 | parser.add_argument('--dis_minlr_multiplier', action="store_true") 83 | parser.add_argument('--disc_layers', type=int, default=3) 84 | parser.add_argument('--discriminator_iter_start', type=int, default=0) 85 | parser.add_argument('--disc_pretrain_iter', type=int, default=0) 86 | parser.add_argument('--disc_optim_steps', type=int, default=1) 87 | parser.add_argument('--disc_warmup', type=int, default=0) 88 | parser.add_argument('--disc_pool', type=str, default="no", choices=["no", "yes"]) 89 | parser.add_argument('--disc_pool_size', type=int, default=1000) 90 | 91 | # loss 92 | parser.add_argument("--recon_loss_type", type=str, default='l1', choices=['l1', 'l2']) 93 | parser.add_argument('--image_gan_weight', type=float, default=1.0) 94 | parser.add_argument('--image_disc_weight', type=float, default=0.) 95 | parser.add_argument('--l1_weight', type=float, default=4.0) 96 | parser.add_argument('--gan_feat_weight', type=float, default=0.0) 97 | parser.add_argument('--perceptual_weight', type=float, default=0.0) 98 | parser.add_argument('--kl_weight', type=float, default=0.) 99 | parser.add_argument('--lfq_weight', type=float, default=0.) 
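        # the release scripts override most of these defaults, e.g. train_img_*_stage*.sh
        # passes --l1_weight 1 --perceptual_weight 1 --image_gan_weight 0.3 --lfq_weight 4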
100 | parser.add_argument('--entropy_loss_weight', type=float, default=0.1) 101 | parser.add_argument('--commitment_loss_weight', type=float, default=0.25) 102 | parser.add_argument('--diversity_gamma', type=float, default=1) 103 | parser.add_argument('--norm_type', type=str, default='group', choices=['batch', 'group', "no"]) 104 | parser.add_argument('--disc_loss_type', type=str, default='hinge', choices=['hinge', 'vanilla']) 105 | 106 | # acceleration 107 | parser.add_argument('--use_checkpoint', action="store_true") 108 | parser.add_argument('--precision', type=str, default="fp32", choices=['fp32', 'bf16']) # disable fp16 109 | parser.add_argument('--encoder_dtype', type=str, default="fp32", choices=['fp32', 'bf16']) # disable fp16 110 | parser.add_argument('--upcast_tf32', action="store_true") 111 | 112 | # initialization 113 | parser.add_argument('--tokenizer', type=str, default='flux', choices=["flux"]) 114 | parser.add_argument('--pretrained', type=str, default=None) 115 | parser.add_argument('--pretrained_mode', type=str, default="full", choices=['full']) 116 | 117 | # misc 118 | parser.add_argument('--debug', action='store_true') 119 | parser.add_argument('--seed', type=int, default=1234) 120 | parser.add_argument('--bucket_cap_mb', type=int, default=40) # DDP 121 | parser.add_argument('--manual_gc_interval', type=int, default=1000) # DDP 122 | 123 | return parser -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import json 4 | import re 5 | import torch 6 | import torch.nn.functional as F 7 | import argparse 8 | import numpy as np 9 | import torch.nn as nn 10 | 11 | import lpips 12 | from tqdm import tqdm 13 | from PIL import Image 14 | Image.MAX_IMAGE_PIXELS = None 15 | 16 | import torch.distributed as dist 17 | from torch.multiprocessing import spawn 18 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 19 | from skimage.metrics import structural_similarity as ssim_loss 20 | 21 | from bitvae.data import ImageData 22 | from bitvae.utils.arguments import MainArgs, add_model_specific_args 23 | from bitvae.evaluation import calculate_frechet_distance 24 | from bitvae.evaluation import InceptionV3 25 | 26 | torch.set_num_threads(32) 27 | 28 | 29 | def calculate_batch_codebook_usage_percentage_bit(batch_encoding_indices): 30 | if isinstance(batch_encoding_indices, list): 31 | all_indices = [] 32 | for one_encoding_indices in batch_encoding_indices: 33 | all_indices.append(one_encoding_indices.flatten(0, -2)) # [bhw, d] 34 | all_indices = torch.cat(all_indices, dim=0) # [sigma(bhw), d] 35 | else: 36 | # Flatten the batch of encoding indices into a single 1D tensor 37 | raise NotImplementedError 38 | all_indices = all_indices.detach().cpu() 39 | 40 | codebook_usage = torch.sum(all_indices, dim=0) # (d, ) 41 | 42 | return codebook_usage, len(all_indices), all_indices.numpy() 43 | 44 | def default_parse_args(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--vqgan_ckpt', type=str, default=None) 47 | parser.add_argument('--inference_type', type=str, choices=["image"]) 48 | parser.add_argument('--save', type=str, required=True) 49 | parser.add_argument('--device', type=str, default="cuda", choices=["cpu", "cuda"]) 50 | parser = MainArgs.add_main_args(parser) 51 | parser = ImageData.add_data_specific_args(parser) 52 | args, unknown = parser.parse_known_args() 53 | args, parser, d_vae_model = 
add_model_specific_args(args, parser) 54 | args = parser.parse_args() 55 | return args, d_vae_model 56 | 57 | 58 | def setup(rank, world_size): 59 | os.environ['MASTER_ADDR'] = 'localhost' 60 | os.environ['MASTER_PORT'] = '12355' 61 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 62 | 63 | def cleanup(): 64 | dist.destroy_process_group() 65 | 66 | def main(): 67 | args, d_vae_model = default_parse_args() 68 | os.makedirs(args.default_root_dir, exist_ok=True) 69 | 70 | # init resolution 71 | args.resolution = (args.resolution[0], args.resolution[0]) if len(args.resolution) == 1 else args.resolution # init resolution 72 | 73 | d_vae = None 74 | num_codes = None 75 | 76 | if args.tokenizer in ["flux"]: 77 | print('args: ',args) 78 | d_vae = d_vae_model(args) 79 | num_codes = args.codebook_size 80 | state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True) 81 | d_vae.load_state_dict(state_dict["vae"]) 82 | else: 83 | raise NotImplementedError 84 | 85 | 86 | world_size = 1 if args.debug else torch.cuda.device_count() 87 | manager = torch.multiprocessing.Manager() 88 | return_dict = manager.dict() 89 | 90 | if args.debug: 91 | inference_eval(0, world_size, args, d_vae_model, d_vae, num_codes, return_dict) 92 | else: 93 | spawn(inference_eval, args=(world_size, args, d_vae_model, d_vae, num_codes, return_dict), nprocs=world_size, join=True) 94 | 95 | pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter, total_usage, total_usage_bit, total_num_token, all_bit_indices_cat = [], [], 0, 0, 0, 0, 0, 0, 0, 0, [] 96 | for rank in range(world_size): 97 | pred_xs.append(return_dict[rank]['pred_xs']) 98 | pred_recs.append(return_dict[rank]['pred_recs']) 99 | lpips_alex += return_dict[rank]['lpips_alex'] 100 | lpips_vgg += return_dict[rank]['lpips_vgg'] 101 | ssim_value += return_dict[rank]['ssim_value'] 102 | psnr_value += return_dict[rank]['psnr_value'] 103 | num_iter += return_dict[rank]['num_iter'] 104 | total_usage += return_dict[rank]['total_usage'] 105 | if not args.disable_codebook_usage_bit: 106 | total_usage_bit += return_dict[rank]['total_usage_bit'] 107 | total_num_token += return_dict[rank]['total_num_token'] 108 | all_bit_indices_cat.append(return_dict[rank]['all_bit_indices_cat']) 109 | pred_xs = np.concatenate(pred_xs, 0) 110 | pred_recs = np.concatenate(pred_recs, 0) 111 | 112 | result_str = image_eval(pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter, total_usage, num_codes, total_usage_bit, total_num_token) 113 | print(result_str) 114 | # save result_str to exp_dir 115 | if args.tokenizer == "flux": 116 | basename = os.path.basename(args.vqgan_ckpt) 117 | match = re.search(r'model_step_(\d+)\.ckpt', basename) 118 | iter_num = match.group(1) if match else None 119 | ckpt_dir = os.path.dirname(args.vqgan_ckpt) 120 | save_dir = os.path.join(ckpt_dir, "evaluation") 121 | os.makedirs(save_dir, exist_ok=True) 122 | if args.random_flip: 123 | flip_prob = int(args.flip_prob * 10) 124 | result_name = os.path.join(save_dir, f"result_{args.dataset_list}_{iter_num}_{args.schedule_mode}_{args.resolution}_max_flip_lvl_{args.max_flip_lvl}_flip_prob_{flip_prob}.txt") 125 | elif args.random_flip_1lvl: 126 | result_name = os.path.join(save_dir, f"result_{args.dataset_list}_{iter_num}_{args.schedule_mode}_flip_lvl_{args.flip_lvl_idx}.txt") 127 | elif args.drop_when_test: 128 | result_name = os.path.join(save_dir, 
f"result_{args.dataset_list}_{iter_num}_{args.schedule_mode}_drop_lvl_idx_{args.drop_lvl_idx}_drop_lvl_num_{args.drop_lvl_num}.txt") 129 | else: 130 | result_name = os.path.join(save_dir, f"result_{args.dataset_list}_{iter_num}_{args.schedule_mode}_{args.resolution}.txt") 131 | else: 132 | raise NotImplementedError 133 | with open(result_name, "w") as f: 134 | f.write(result_str) 135 | # print('Usage = %.2f'%((total_usage > 0.).sum() / num_codes)) 136 | 137 | def inference_eval(rank, world_size, args, d_vae_model, d_vae, num_codes, return_dict): 138 | # Don't remove this setup!!! dist.init_process_group is important for building loader (data.distributed.DistributedSampler) 139 | setup(rank, world_size) 140 | 141 | device = torch.device(f"cuda:{rank}") 142 | 143 | for param in d_vae.parameters(): 144 | param.requires_grad = False 145 | d_vae.to(device).eval() 146 | 147 | save_dir = 'results/%s'%(args.save) 148 | print('generating and saving image to %s...'%save_dir) 149 | os.makedirs(save_dir, exist_ok=True) 150 | 151 | data = ImageData(args) 152 | 153 | loader = data.val_dataloader() 154 | 155 | dims = 2048 156 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 157 | inception_model = InceptionV3([block_idx]).to(device) 158 | inception_model.eval() 159 | 160 | loader_iter = iter(loader) 161 | 162 | pred_xs = [] 163 | pred_recs = [] 164 | all_bit_indices_cat = [] 165 | # LPIPS score related 166 | loss_fn_alex = lpips.LPIPS(net='alex').to(device) # best forward scores 167 | loss_fn_vgg = lpips.LPIPS(net='vgg').to(device) # closer to "traditional" perceptual loss, when used for optimization 168 | lpips_alex = 0.0 169 | lpips_vgg = 0.0 170 | 171 | # SSIM score related 172 | ssim_value = 0.0 173 | 174 | # PSNR score related 175 | psnr_value = 0.0 176 | 177 | num_images = len(loader) 178 | print(f"Testing {num_images} files") 179 | num_iter = 0 180 | 181 | total_usage = 0.0 182 | total_usage_bit = 0.0 183 | total_num_token = 0 184 | 185 | for batch_idx in tqdm(range(num_images)): 186 | batch = next(loader_iter) 187 | 188 | with torch.no_grad(): 189 | x = batch['image'] 190 | if args.tokenizer in ["flux"]: 191 | torch.cuda.empty_cache() 192 | # x: [-1, 1] 193 | x_recons, vq_output = d_vae(x.to(device), 2, 0, is_train=False) 194 | x_recons = x_recons.cpu() 195 | else: 196 | raise NotImplementedError 197 | 198 | if not args.disable_codebook_usage_bit: 199 | bit_indices = vq_output["bit_encodings"] 200 | codebook_usage_bit, num_token, bit_indices_cat = calculate_batch_codebook_usage_percentage_bit(bit_indices) 201 | total_usage_bit += codebook_usage_bit 202 | total_num_token += num_token 203 | all_bit_indices_cat.append(bit_indices_cat) 204 | 205 | paths = batch["path"] 206 | assert len(paths) == x.shape[0] 207 | 208 | for p, input_ori, recon_ori in zip(paths, x, x_recons): 209 | path = os.path.join(save_dir, "input_recon", os.path.basename(p)) 210 | os.makedirs(os.path.split(path)[0], exist_ok=True) 211 | 212 | input_ori = input_ori.unsqueeze(0).to(device) 213 | input_ = (input_ori + 1) / 2 # [-1, 1] -> [0, 1] 214 | 215 | pred_x = inception_model(input_)[0] 216 | pred_x = pred_x.squeeze(3).squeeze(2).cpu().numpy() 217 | 218 | recon_ori = recon_ori.unsqueeze(0).to(device) 219 | recon_ = (recon_ori + 1) / 2 # [-1, 1] -> [0, 1] 220 | # recon_ = recon_.permute(1, 2, 0).detach().cpu() 221 | with torch.no_grad(): 222 | pred_rec = inception_model(recon_)[0] 223 | pred_rec = pred_rec.squeeze(3).squeeze(2).cpu().numpy() 224 | 225 | pred_xs.append(pred_x) 226 | pred_recs.append(pred_rec) 227 | 228 | # 
calculate lpips 229 | with torch.no_grad(): 230 | lpips_alex += loss_fn_alex(input_ori, recon_ori).sum() # [-1, 1] 231 | lpips_vgg += loss_fn_vgg(input_ori, recon_ori).sum() # [-1, 1] 232 | 233 | #calculate PSNR and SSIM 234 | rgb_restored = (recon_ * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 235 | rgb_gt = (input_ * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 236 | rgb_restored = rgb_restored.astype(np.float32) / 255. 237 | rgb_gt = rgb_gt.astype(np.float32) / 255. 238 | ssim_temp = 0 239 | psnr_temp = 0 240 | B, _, _, _ = rgb_restored.shape 241 | for i in range(B): 242 | rgb_restored_s, rgb_gt_s = rgb_restored[i], rgb_gt[i] 243 | with torch.no_grad(): 244 | ssim_temp += ssim_loss(rgb_restored_s, rgb_gt_s, data_range=1.0, channel_axis=-1) 245 | psnr_temp += psnr_loss(rgb_gt, rgb_restored) 246 | ssim_value += ssim_temp / B 247 | psnr_value += psnr_temp / B 248 | num_iter += 1 249 | 250 | pred_xs = np.concatenate(pred_xs, axis=0) 251 | pred_recs = np.concatenate(pred_recs, axis=0) 252 | temp_dict = { 253 | 'pred_xs':pred_xs, 254 | 'pred_recs':pred_recs, 255 | 'lpips_alex':lpips_alex.cpu(), 256 | 'lpips_vgg':lpips_vgg.cpu(), 257 | 'ssim_value': ssim_value, 258 | 'psnr_value': psnr_value, 259 | 'num_iter': num_iter, 260 | 'total_usage': total_usage, 261 | 'total_usage_bit': total_usage_bit, 262 | 'total_num_token': total_num_token, 263 | } 264 | if not args.disable_codebook_usage_bit: 265 | all_bit_indices_cat = np.concatenate(all_bit_indices_cat, axis=0) 266 | temp_dict['all_bit_indices_cat'] = all_bit_indices_cat 267 | return_dict[rank] = temp_dict 268 | 269 | if dist.is_initialized(): 270 | dist.barrier() 271 | cleanup() 272 | 273 | def image_eval(pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter, total_usage, num_codes, total_usage_bit, total_num_token): 274 | mu_x = np.mean(pred_xs, axis=0) 275 | sigma_x = np.cov(pred_xs, rowvar=False) 276 | mu_rec = np.mean(pred_recs, axis=0) 277 | sigma_rec = np.cov(pred_recs, rowvar=False) 278 | 279 | fid_value = calculate_frechet_distance(mu_x, sigma_x, mu_rec, sigma_rec) 280 | lpips_alex_value = lpips_alex / num_iter 281 | lpips_vgg_value = lpips_vgg / num_iter 282 | ssim_value = ssim_value / num_iter 283 | psnr_value = psnr_value / num_iter 284 | if total_num_token != 0: 285 | bit_distribution = total_usage_bit / total_num_token 286 | bit_distribution_str = '\n'.join(f'{value:.4f}' for value in bit_distribution) 287 | 288 | # usage_0 = (total_usage > 0.).sum() / num_codes * 100 289 | # usage_10 = (total_usage > 10.).sum() / num_codes * 100 290 | 291 | result_str = f""" 292 | FID = {fid_value:.4f} 293 | LPIPS_VGG: {lpips_vgg_value.item():.4f} 294 | LPIPS_ALEX: {lpips_alex_value.item():.4f} 295 | SSIM: {ssim_value:.4f} 296 | PSNR: {psnr_value:.3f} 297 | """ 298 | if total_num_token != 0: 299 | result_str += f""" 300 | Bit_Distribution: {bit_distribution_str} 301 | """ 302 | # Usage(>0): {usage_0:.2f}% 303 | # Usage(>10): {usage_10:.2f}% 304 | return result_str 305 | if __name__ == '__main__': 306 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import math 4 | import glob 5 | import time 6 | import logging 7 | from copy import deepcopy 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.distributed as dist 13 | 14 | from torch.nn.parallel import DistributedDataParallel 
as DDP 15 | from timm.scheduler.cosine_lr import CosineLRScheduler 16 | 17 | from bitvae.utils.distributed import init_distributed_mode, reduce_losses, average_losses 18 | from bitvae.utils.logger import create_logger 19 | 20 | from bitvae.models import ImageDiscriminator 21 | from bitvae.data import ImageData 22 | from bitvae.modules.loss import get_disc_loss, adopt_weight 23 | from bitvae.utils.misc import get_last_ckpt 24 | from bitvae.utils.init_models import resume_from_ckpt 25 | from bitvae.utils.arguments import MainArgs, add_model_specific_args 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | def lecam_reg_zero(real_pred, fake_pred, thres=0.1): 30 | # avoid logits get too high 31 | assert real_pred.ndim == 0 32 | reg = torch.mean(F.relu(torch.abs(real_pred) - thres).pow(2)) + \ 33 | torch.mean(F.relu(torch.abs(fake_pred) - thres).pow(2)) 34 | return reg 35 | 36 | 37 | def main(): 38 | parser = argparse.ArgumentParser() 39 | parser = MainArgs.add_main_args(parser) 40 | parser = ImageData.add_data_specific_args(parser) 41 | args, unknown = parser.parse_known_args() 42 | args, parser, d_vae_model = add_model_specific_args(args, parser) 43 | args = parser.parse_args() 44 | 45 | args.resolution = (args.resolution[0], args.resolution[0]) if len(args.resolution) == 1 else args.resolution # init resolution 46 | 47 | print(f"{args.default_root_dir=}") 48 | 49 | # Setup DDP: 50 | init_distributed_mode(args) 51 | rank = dist.get_rank() 52 | world_size = dist.get_world_size() 53 | device = rank % torch.cuda.device_count() 54 | torch.cuda.set_device(device) 55 | 56 | # Setup an experiment folder: 57 | if rank == 0: 58 | os.makedirs(args.default_root_dir, exist_ok=True) # Make results folder (holds all experiment subfolders 59 | checkpoint_dir = f"{args.default_root_dir}/checkpoints" # Stores saved model checkpoints 60 | os.makedirs(checkpoint_dir, exist_ok=True) 61 | logger = create_logger(args.default_root_dir) 62 | logger.info(f"Experiment directory created at {args.default_root_dir}") 63 | 64 | import wandb 65 | wandb_project = "VQVAE" 66 | wandb.init( 67 | project=wandb_project, 68 | name=os.path.basename(os.path.normpath(args.default_root_dir)), 69 | dir=args.default_root_dir, 70 | config=args, 71 | mode="offline" if args.debug else "online" 72 | ) 73 | else: 74 | logger = create_logger(None) 75 | 76 | # init dataloader 77 | data = ImageData(args) 78 | dataloaders = data.train_dataloader() 79 | dataloader_iters = [iter(loader) for loader in dataloaders] 80 | data_epochs = [0 for _ in dataloaders] 81 | 82 | # init model 83 | d_vae = d_vae_model(args).to(device) 84 | d_vae.logger = logger 85 | image_disc = ImageDiscriminator(args).to(device) 86 | 87 | # init optimizers and schedulers 88 | if args.optim_type == "Adam": 89 | optim = torch.optim.Adam 90 | elif args.optim_type == "AdamW": 91 | optim = torch.optim.AdamW 92 | if args.disc_optim_type is None: 93 | disc_optim = optim 94 | elif args.disc_optim_type == "rmsprop": 95 | disc_optim = torch.optim.RMSprop 96 | opt_vae = optim(d_vae.parameters(), lr=args.lr, betas=(args.beta1, args.beta2)) 97 | if disc_optim == torch.optim.RMSprop: 98 | opt_image_disc = disc_optim(image_disc.parameters(), lr=args.lr * args.dis_lr_multiplier) 99 | else: 100 | opt_image_disc = disc_optim(image_disc.parameters(), lr=args.lr * args.dis_lr_multiplier, betas=(args.beta1, args.beta2)) 101 | 102 | lr_min = args.lr_min 103 | train_iters = args.max_steps 104 | warmup_steps = args.warmup_steps 105 | warmup_lr_init = args.warmup_lr_init 106 | 107 | if 
args.disable_sch: 108 | # scheduler_list = [None, None] 109 | sch_vae, sch_image_disc = None, None 110 | 111 | model_optims = { 112 | "vae" : d_vae, 113 | "image_disc" : image_disc, 114 | "opt_vae" : opt_vae, 115 | "opt_image_disc" : opt_image_disc, 116 | "sch_vae" : sch_vae, 117 | "sch_image_disc" : sch_image_disc, 118 | } 119 | 120 | # resume from default_root_dir 121 | ckpt_path = None 122 | assert not args.default_root_dir is None # required argument 123 | ckpt_path = get_last_ckpt(args.default_root_dir) 124 | init_step = 0 125 | load_optimizer = not args.not_load_optimizer 126 | if ckpt_path: 127 | logger.info(f"Resuming from {ckpt_path}") 128 | state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True) 129 | model_optims, init_step = resume_from_ckpt(state_dict, model_optims, load_optimizer=True) 130 | # load pretrained weights 131 | elif args.pretrained is not None: 132 | state_dict = torch.load(args.pretrained, map_location="cpu", weights_only=True) 133 | if args.pretrained_mode == "full": 134 | model_optims, _ = resume_from_ckpt(state_dict, model_optims, load_optimizer=load_optimizer) 135 | logger.info(f"Successfully loaded ckpt {args.pretrained}, pretrained_mode {args.pretrained_mode}") 136 | 137 | d_vae = DDP(d_vae.to(device), device_ids=[args.gpu], bucket_cap_mb=args.bucket_cap_mb) 138 | image_disc = DDP(image_disc.to(device), device_ids=[args.gpu], bucket_cap_mb=args.bucket_cap_mb) 139 | disc_loss = get_disc_loss(args.disc_loss_type) # hinge loss by default 140 | 141 | if args.multiscale_training: 142 | scale_idx_list = np.load('bitvae/utils/random_numbers.npy') # load pre-computed scale_idx in each iteration 143 | 144 | start_time = time.time() 145 | for global_step in range(init_step, args.max_steps): 146 | loss_dicts = [] 147 | 148 | if global_step == args.discriminator_iter_start - args.disc_pretrain_iter: 149 | logging.info(f"discriminator begins pretraining ") 150 | if global_step == args.discriminator_iter_start: 151 | log_str = "add GAN loss into training" 152 | if args.disc_pretrain_iter > 0: 153 | log_str += ", discriminator ends pretraining" 154 | logging.info(log_str) 155 | 156 | for idx in range(len(dataloader_iters)): 157 | try: 158 | _batch = next(dataloader_iters[idx]) 159 | except StopIteration: 160 | data_epochs[idx] += 1 161 | logger.info(f"Reset the {idx}th dataloader as epoch {data_epochs[idx]}") 162 | dataloaders[idx].sampler.set_epoch(data_epochs[idx]) 163 | dataloader_iters[idx] = iter(dataloaders[idx]) # update dataloader iter 164 | _batch = next(dataloader_iters[idx]) 165 | except Exception as e: 166 | raise e 167 | x = _batch["image"] 168 | _type = _batch["type"][0] 169 | 170 | if args.multiscale_training: 171 | # data processing for multi-scale training 172 | scale_idx = scale_idx_list[global_step] 173 | if scale_idx == 0: 174 | # 256x256 batch=8 175 | x = F.interpolate(x, size=(256, 256), mode='area') 176 | elif scale_idx == 1: 177 | # 512x512 batch=4 178 | rdn_idx = torch.randperm(len(x))[:4] # without replacement 179 | x = x[rdn_idx] 180 | x = F.interpolate(x, size=(512, 512), mode='area') 181 | elif scale_idx == 2: 182 | # 1024x1024 batch=2 183 | rdn_idx = torch.randperm(len(x))[:2] # without replacement 184 | x = x[rdn_idx] 185 | else: 186 | raise ValueError(f"scale_idx {scale_idx} is not supported") 187 | 188 | if _type == "image": 189 | x_recon, flat_frames, flat_frames_recon, vae_loss_dict = d_vae(x, global_step, image_disc=image_disc) 190 | g_loss = sum(vae_loss_dict.values()) 191 | opt_vae.zero_grad() 192 | g_loss.backward() 
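            # Descriptive note on the block below: on iterations that are not checkpoint
            # steps ((global_step+1) % ckpt_every != 0) the VAE gradients are clipped to
            # --max_grad_norm and the LR scheduler (if any) is stepped; on checkpoint
            # steps that also hit an --lr_drop milestone, the VAE LR is instead multiplied
            # by --lr_drop_rate. The optimizer step and zero_grad always run afterwards.
            # In this snapshot only the --disable_sch path assigns sch_vae/sch_image_disc;
            # if a scheduler were enabled, something along the lines of the imported timm
            # CosineLRScheduler(opt_vae, t_initial=args.max_steps, lr_min=args.lr_min,
            # warmup_t=args.warmup_steps, warmup_lr_init=args.warmup_lr_init) would have
            # to be constructed earlier (an assumption, not code from this repository).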
193 | 194 | if not ((global_step+1) % args.ckpt_every) == 0: 195 | if args.max_grad_norm > 0: 196 | torch.nn.utils.clip_grad_norm_(d_vae.parameters(), args.max_grad_norm) 197 | if not sch_vae is None: 198 | sch_vae.step(global_step) 199 | elif args.lr_drop and global_step in args.lr_drop: 200 | logger.info(f"multiply lr of VQ-VAE by {args.lr_drop_rate} at iteration {global_step}") 201 | for opt_vae_param_group in opt_vae.param_groups: 202 | opt_vae_param_group["lr"] = opt_vae_param_group["lr"] * args.lr_drop_rate 203 | opt_vae.step() 204 | opt_vae.zero_grad() # free memory 205 | 206 | disc_loss_dict = {} 207 | # disc_factor = 0 before (args.discriminator_iter_start - args.disc_pretrain_iter) 208 | disc_factor = adopt_weight(global_step, threshold=args.discriminator_iter_start - args.disc_pretrain_iter) 209 | discloss = d_image_loss = torch.tensor(0.).to(x.device) 210 | ### enable pool warmup 211 | for disc_step in range(args.disc_optim_steps): # train discriminator 212 | require_optim = False 213 | if _type == "image" and args.image_disc_weight > 0: # train image discriminator 214 | require_optim = True 215 | logits_image_real = image_disc(x, pool_name="real") 216 | logits_image_fake = image_disc(x_recon.detach(), pool_name="fake") 217 | d_image_loss = disc_loss(logits_image_real, logits_image_fake) 218 | disc_loss_dict["train/logits_image_real"] = logits_image_real.mean().detach() 219 | disc_loss_dict["train/logits_image_fake"] = logits_image_fake.mean().detach() 220 | disc_loss_dict["train/d_image_loss"] = d_image_loss.mean().detach() 221 | discloss = d_image_loss * args.image_disc_weight 222 | opt_discs, sch_discs = [opt_image_disc], [sch_image_disc] 223 | if global_step >= args.discriminator_iter_start and args.use_lecam_reg_zero: 224 | lecam_zero_loss = lecam_reg_zero(logits_image_real.mean(), logits_image_fake.mean()) 225 | disc_loss_dict["train/lecam_zero_loss"] = lecam_zero_loss.mean().detach() 226 | discloss += lecam_zero_loss * args.lecam_weight 227 | discloss = disc_factor * discloss 228 | 229 | if require_optim: 230 | for opt_disc in opt_discs: 231 | opt_disc.zero_grad() 232 | discloss.backward() 233 | 234 | if not ((global_step+1) % args.ckpt_every) == 0: 235 | if args.max_grad_norm_disc > 0: # by default, 1.0 236 | torch.nn.utils.clip_grad_norm_(image_disc.parameters(), args.max_grad_norm_disc) 237 | for sch_disc in sch_discs: 238 | if not sch_disc is None: 239 | sch_disc.step(global_step) 240 | elif args.lr_drop and global_step in args.lr_drop: 241 | for opt_disc in opt_discs: 242 | logger.info(f"multiply lr of discriminator by {args.lr_drop_rate} at iteration {global_step}") 243 | for opt_disc_param_group in opt_disc.param_groups: 244 | opt_disc_param_group["lr"] = opt_disc_param_group["lr"] * args.lr_drop_rate 245 | for opt_disc in opt_discs: 246 | opt_disc.step() 247 | for opt_disc in opt_discs: 248 | opt_disc.zero_grad() # free memory 249 | 250 | loss_dict = {**vae_loss_dict, **disc_loss_dict} 251 | if (global_step+1) % args.log_every == 0: 252 | reduced_loss_dict = reduce_losses(loss_dict) 253 | else: 254 | reduced_loss_dict = {} 255 | loss_dicts.append(reduced_loss_dict) 256 | 257 | if (global_step+1) % args.log_every == 0: 258 | avg_loss_dict = average_losses(loss_dicts) 259 | torch.cuda.synchronize() 260 | end_time = time.time() 261 | iter_speed = (end_time - start_time) / args.log_every 262 | if rank == 0: 263 | for key, value in avg_loss_dict.items(): 264 | wandb.log({key: value}, step=global_step) 265 | # writing logs 266 | logger.info(f'global_step={global_step}, 
perceptual_loss={avg_loss_dict.get("train/perceptual_loss",0):.4f}, recon_loss={avg_loss_dict.get("train/recon_loss",0):.4f}, commitment_loss={avg_loss_dict.get("train/commitment_loss",0):.4f}, logit_r={avg_loss_dict.get("train/logits_image_real",0):.4f}, logit_f={avg_loss_dict.get("train/logits_image_fake",0):.4f}, L_disc={avg_loss_dict.get("train/d_image_loss",0):.4f}, iter_speed={iter_speed:.2f}s') 267 | start_time = time.time() 268 | 269 | if (global_step+1) % args.ckpt_every == 0 and global_step != init_step: 270 | if rank == 0: 271 | checkpoint_path = os.path.join(checkpoint_dir, f'model_step_{global_step}.ckpt') 272 | save_dict = {} 273 | for k in model_optims: 274 | save_dict[k] = None if model_optims[k] is None \ 275 | else model_optims[k].module.state_dict() if hasattr(model_optims[k], "module") \ 276 | else model_optims[k].state_dict() 277 | torch.save({ 278 | 'step': global_step, 279 | **save_dict, 280 | }, checkpoint_path) 281 | logger.info(f'Checkpoint saved at step {global_step}') 282 | 283 | 284 | if __name__ == '__main__': 285 | main() 286 | -------------------------------------------------------------------------------- /bitvae/evaluation/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from scipy import linalg 6 | import numpy as np 7 | 8 | try: 9 | from torchvision.models.utils import load_state_dict_from_url 10 | except ImportError: 11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 12 | 13 | # Inception weights ported to Pytorch from 14 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 15 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 16 | 17 | FID_WEIGHTS_PATH = "checkpoints/pt_inception-2015-12-05-6726825d.pth" 18 | 19 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 20 | """Numpy implementation of the Frechet Distance. 21 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 22 | and X_2 ~ N(mu_2, C_2) is 23 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 24 | 25 | Stable version by Dougal J. Sutherland. 26 | 27 | Params: 28 | -- mu1 : Numpy array containing the activations of a layer of the 29 | inception net (like returned by the function 'get_predictions') 30 | for generated samples. 31 | -- mu2 : The sample mean over activations, precalculated on a 32 | representative data set. 33 | -- sigma1: The covariance matrix over activations for generated samples. 34 | -- sigma2: The covariance matrix over activations, precalculated on a 35 | representative data set. 36 | 37 | Returns: 38 | -- : The Frechet Distance. 
39 | """ 40 | 41 | mu1 = np.atleast_1d(mu1) 42 | mu2 = np.atleast_1d(mu2) 43 | 44 | sigma1 = np.atleast_2d(sigma1) 45 | sigma2 = np.atleast_2d(sigma2) 46 | 47 | assert mu1.shape == mu2.shape, \ 48 | 'Training and test mean vectors have different lengths' 49 | assert sigma1.shape == sigma2.shape, \ 50 | 'Training and test covariances have different dimensions' 51 | 52 | diff = mu1 - mu2 53 | 54 | # Product might be almost singular 55 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 56 | if not np.isfinite(covmean).all(): 57 | msg = ('fid calculation produces singular product; ' 58 | 'adding %s to diagonal of cov estimates') % eps 59 | print(msg) 60 | offset = np.eye(sigma1.shape[0]) * eps 61 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 62 | 63 | # Numerical error might give slight imaginary component 64 | if np.iscomplexobj(covmean): 65 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 66 | m = np.max(np.abs(covmean.imag)) 67 | raise ValueError('Imaginary component {}'.format(m)) 68 | covmean = covmean.real 69 | 70 | tr_covmean = np.trace(covmean) 71 | 72 | return (diff.dot(diff) + np.trace(sigma1) + 73 | np.trace(sigma2) - 2 * tr_covmean) 74 | 75 | class InceptionV3(nn.Module): 76 | """Pretrained InceptionV3 network returning feature maps""" 77 | 78 | # Index of default block of inception to return, 79 | # corresponds to output of final average pooling 80 | DEFAULT_BLOCK_INDEX = 3 81 | 82 | # Maps feature dimensionality to their output blocks indices 83 | BLOCK_INDEX_BY_DIM = { 84 | 64: 0, # First max pooling features 85 | 192: 1, # Second max pooling featurs 86 | 768: 2, # Pre-aux classifier features 87 | 2048: 3 # Final average pooling features 88 | } 89 | 90 | def __init__(self, 91 | output_blocks=[DEFAULT_BLOCK_INDEX], 92 | resize_input=True, 93 | normalize_input=True, 94 | requires_grad=False, 95 | use_fid_inception=True): 96 | """Build pretrained InceptionV3 97 | 98 | Parameters 99 | ---------- 100 | output_blocks : list of int 101 | Indices of blocks to return features of. Possible values are: 102 | - 0: corresponds to output of first max pooling 103 | - 1: corresponds to output of second max pooling 104 | - 2: corresponds to output which is fed to aux classifier 105 | - 3: corresponds to output of final average pooling 106 | resize_input : bool 107 | If true, bilinearly resizes input to width and height 299 before 108 | feeding input to model. As the network without fully connected 109 | layers is fully convolutional, it should be able to handle inputs 110 | of arbitrary size, so resizing might not be strictly needed 111 | normalize_input : bool 112 | If true, scales the input from range (0, 1) to the range the 113 | pretrained Inception network expects, namely (-1, 1) 114 | requires_grad : bool 115 | If true, parameters of the model require gradients. Possibly useful 116 | for finetuning the network 117 | use_fid_inception : bool 118 | If true, uses the pretrained Inception model used in Tensorflow's 119 | FID implementation. If false, uses the pretrained Inception model 120 | available in torchvision. The FID Inception model has different 121 | weights and a slightly different structure from torchvision's 122 | Inception model. If you want to compute FID scores, you are 123 | strongly advised to set this parameter to true to get comparable 124 | results. 
125 | """ 126 | super(InceptionV3, self).__init__() 127 | 128 | self.resize_input = resize_input 129 | self.normalize_input = normalize_input 130 | self.output_blocks = sorted(output_blocks) 131 | self.last_needed_block = max(output_blocks) 132 | 133 | assert self.last_needed_block <= 3, \ 134 | 'Last possible output block index is 3' 135 | 136 | self.blocks = nn.ModuleList() 137 | 138 | if use_fid_inception: 139 | inception = fid_inception_v3() 140 | else: 141 | inception = models.inception_v3(pretrained=True) 142 | 143 | # Block 0: input to maxpool1 144 | block0 = [ 145 | inception.Conv2d_1a_3x3, 146 | inception.Conv2d_2a_3x3, 147 | inception.Conv2d_2b_3x3, 148 | nn.MaxPool2d(kernel_size=3, stride=2) 149 | ] 150 | self.blocks.append(nn.Sequential(*block0)) 151 | 152 | # Block 1: maxpool1 to maxpool2 153 | if self.last_needed_block >= 1: 154 | block1 = [ 155 | inception.Conv2d_3b_1x1, 156 | inception.Conv2d_4a_3x3, 157 | nn.MaxPool2d(kernel_size=3, stride=2) 158 | ] 159 | self.blocks.append(nn.Sequential(*block1)) 160 | 161 | # Block 2: maxpool2 to aux classifier 162 | if self.last_needed_block >= 2: 163 | block2 = [ 164 | inception.Mixed_5b, 165 | inception.Mixed_5c, 166 | inception.Mixed_5d, 167 | inception.Mixed_6a, 168 | inception.Mixed_6b, 169 | inception.Mixed_6c, 170 | inception.Mixed_6d, 171 | inception.Mixed_6e, 172 | ] 173 | self.blocks.append(nn.Sequential(*block2)) 174 | 175 | # Block 3: aux classifier to final avgpool 176 | if self.last_needed_block >= 3: 177 | block3 = [ 178 | inception.Mixed_7a, 179 | inception.Mixed_7b, 180 | inception.Mixed_7c, 181 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 182 | ] 183 | self.blocks.append(nn.Sequential(*block3)) 184 | 185 | for param in self.parameters(): 186 | param.requires_grad = requires_grad 187 | 188 | def forward(self, inp): 189 | """Get Inception feature maps 190 | 191 | Parameters 192 | ---------- 193 | inp : torch.autograd.Variable 194 | Input tensor of shape Bx3xHxW. Values are expected to be in 195 | range (0, 1) 196 | 197 | Returns 198 | ------- 199 | List of torch.autograd.Variable, corresponding to the selected output 200 | block, sorted ascending by index 201 | """ 202 | outp = [] 203 | x = inp 204 | 205 | if self.resize_input: 206 | x = F.interpolate(x, 207 | size=(299, 299), 208 | mode='bilinear', 209 | align_corners=False) 210 | 211 | if self.normalize_input: 212 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 213 | 214 | for idx, block in enumerate(self.blocks): 215 | x = block(x) 216 | if idx in self.output_blocks: 217 | outp.append(x) 218 | 219 | if idx == self.last_needed_block: 220 | break 221 | 222 | return outp 223 | 224 | 225 | def fid_inception_v3(): 226 | """Build pretrained Inception model for FID computation 227 | 228 | The Inception model for FID computation uses a different set of weights 229 | and has a slightly different structure than torchvision's Inception. 230 | 231 | This method first constructs torchvision's Inception and then patches the 232 | necessary parts that are different in the FID Inception model. 
233 | """ 234 | inception = models.inception_v3(num_classes=1008, 235 | aux_logits=False, 236 | pretrained=False) 237 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 238 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 239 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 240 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 241 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 242 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 243 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 244 | inception.Mixed_7b = FIDInceptionE_1(1280) 245 | inception.Mixed_7c = FIDInceptionE_2(2048) 246 | 247 | # state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 248 | state_dict = torch.load(FID_WEIGHTS_PATH) 249 | inception.load_state_dict(state_dict) 250 | return inception 251 | 252 | 253 | class FIDInceptionA(models.inception.InceptionA): 254 | """InceptionA block patched for FID computation""" 255 | def __init__(self, in_channels, pool_features): 256 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 257 | 258 | def forward(self, x): 259 | branch1x1 = self.branch1x1(x) 260 | 261 | branch5x5 = self.branch5x5_1(x) 262 | branch5x5 = self.branch5x5_2(branch5x5) 263 | 264 | branch3x3dbl = self.branch3x3dbl_1(x) 265 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 266 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 267 | 268 | # Patch: Tensorflow's average pool does not use the padded zero's in 269 | # its average calculation 270 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 271 | count_include_pad=False) 272 | branch_pool = self.branch_pool(branch_pool) 273 | 274 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 275 | return torch.cat(outputs, 1) 276 | 277 | 278 | class FIDInceptionC(models.inception.InceptionC): 279 | """InceptionC block patched for FID computation""" 280 | def __init__(self, in_channels, channels_7x7): 281 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 282 | 283 | def forward(self, x): 284 | branch1x1 = self.branch1x1(x) 285 | 286 | branch7x7 = self.branch7x7_1(x) 287 | branch7x7 = self.branch7x7_2(branch7x7) 288 | branch7x7 = self.branch7x7_3(branch7x7) 289 | 290 | branch7x7dbl = self.branch7x7dbl_1(x) 291 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 292 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 293 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 294 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 295 | 296 | # Patch: Tensorflow's average pool does not use the padded zero's in 297 | # its average calculation 298 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 299 | count_include_pad=False) 300 | branch_pool = self.branch_pool(branch_pool) 301 | 302 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 303 | return torch.cat(outputs, 1) 304 | 305 | 306 | class FIDInceptionE_1(models.inception.InceptionE): 307 | """First InceptionE block patched for FID computation""" 308 | def __init__(self, in_channels): 309 | super(FIDInceptionE_1, self).__init__(in_channels) 310 | 311 | def forward(self, x): 312 | branch1x1 = self.branch1x1(x) 313 | 314 | branch3x3 = self.branch3x3_1(x) 315 | branch3x3 = [ 316 | self.branch3x3_2a(branch3x3), 317 | self.branch3x3_2b(branch3x3), 318 | ] 319 | branch3x3 = torch.cat(branch3x3, 1) 320 | 321 | branch3x3dbl = self.branch3x3dbl_1(x) 322 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 323 | branch3x3dbl = [ 324 | self.branch3x3dbl_3a(branch3x3dbl), 
325 | self.branch3x3dbl_3b(branch3x3dbl), 326 | ] 327 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 328 | 329 | # Patch: Tensorflow's average pool does not use the padded zero's in 330 | # its average calculation 331 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 332 | count_include_pad=False) 333 | branch_pool = self.branch_pool(branch_pool) 334 | 335 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 336 | return torch.cat(outputs, 1) 337 | 338 | 339 | class FIDInceptionE_2(models.inception.InceptionE): 340 | """Second InceptionE block patched for FID computation""" 341 | def __init__(self, in_channels): 342 | super(FIDInceptionE_2, self).__init__(in_channels) 343 | 344 | def forward(self, x): 345 | branch1x1 = self.branch1x1(x) 346 | 347 | branch3x3 = self.branch3x3_1(x) 348 | branch3x3 = [ 349 | self.branch3x3_2a(branch3x3), 350 | self.branch3x3_2b(branch3x3), 351 | ] 352 | branch3x3 = torch.cat(branch3x3, 1) 353 | 354 | branch3x3dbl = self.branch3x3dbl_1(x) 355 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 356 | branch3x3dbl = [ 357 | self.branch3x3dbl_3a(branch3x3dbl), 358 | self.branch3x3dbl_3b(branch3x3dbl), 359 | ] 360 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 361 | 362 | # Patch: The FID Inception model uses max pooling instead of average 363 | # pooling. This is likely an error in this specific Inception 364 | # implementation, as other Inception models use average pooling here 365 | # (which matches the description in the paper). 366 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 367 | branch_pool = self.branch_pool(branch_pool) 368 | 369 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 370 | return torch.cat(outputs, 1) -------------------------------------------------------------------------------- /bitvae/models/d_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import numpy as np 5 | from einops import rearrange 6 | from torch import Tensor, nn 7 | from torchvision import transforms 8 | import torch.utils.checkpoint as checkpoint 9 | import torch.nn.functional as F 10 | 11 | from bitvae.modules.quantizer import MultiScaleBSQ 12 | from bitvae.modules import Conv, adopt_weight, LPIPS, Normalize 13 | from bitvae.utils.misc import ptdtype 14 | 15 | 16 | def swish(x: Tensor) -> Tensor: 17 | try: 18 | return x * torch.sigmoid(x) 19 | except: 20 | device = x.device 21 | x = x.cpu().pin_memory() 22 | return (x*torch.sigmoid(x)).to(device=device) 23 | 24 | class ResnetBlock(nn.Module): 25 | def __init__(self, in_channels: int, out_channels: int, norm_type='group'): 26 | super().__init__() 27 | self.in_channels = in_channels 28 | out_channels = in_channels if out_channels is None else out_channels 29 | self.out_channels = out_channels 30 | 31 | self.norm1 = Normalize(in_channels, norm_type) 32 | 33 | self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 34 | 35 | self.norm2 = Normalize(out_channels, norm_type) 36 | 37 | self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 38 | 39 | if self.in_channels != self.out_channels: 40 | self.nin_shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 41 | 42 | def forward(self, x): 43 | h = x 44 | h = self.norm1(h) 45 | h = swish(h) 46 | h = self.conv1(h) 47 | 48 | h = self.norm2(h) 49 | h = swish(h) 50 | h = self.conv2(h) 51 | 52 | if self.in_channels != self.out_channels: 53 | x = self.nin_shortcut(x) 54 | 55 
| return x + h 56 | 57 | 58 | 59 | class Downsample(nn.Module): 60 | def __init__(self, in_channels, spatial_down=False): 61 | super().__init__() 62 | assert spatial_down == True 63 | self.pad = (0, 1, 0, 1) 64 | self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=2, padding=0) 65 | 66 | def forward(self, x: Tensor): 67 | x = nn.functional.pad(x, self.pad, mode="constant", value=0) 68 | x = self.conv(x) 69 | return x 70 | 71 | 72 | class Upsample(nn.Module): 73 | def __init__(self, in_channels, spatial_up=False): 74 | super().__init__() 75 | assert spatial_up == True 76 | 77 | self.scale_factor = 2 78 | self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, padding=1) 79 | 80 | def forward(self, x: Tensor): 81 | x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest") 82 | x = self.conv(x) 83 | return x 84 | 85 | 86 | class Encoder(nn.Module): 87 | def __init__( 88 | self, 89 | ch: int, 90 | ch_mult: list[int], 91 | num_res_blocks: int, 92 | z_channels: int, 93 | in_channels = 3, 94 | patch_size=8, 95 | norm_type='group', 96 | use_checkpoint=False, 97 | ): 98 | super().__init__() 99 | self.max_down = np.log2(patch_size) 100 | self.ch = ch 101 | self.num_resolutions = len(ch_mult) 102 | self.num_res_blocks = num_res_blocks 103 | self.in_channels = in_channels 104 | self.use_checkpoint = use_checkpoint 105 | # downsampling 106 | # self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 107 | self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1) 108 | 109 | in_ch_mult = (1,) + tuple(ch_mult) 110 | self.in_ch_mult = in_ch_mult 111 | self.down = nn.ModuleList() 112 | block_in = self.ch 113 | for i_level in range(self.num_resolutions): 114 | block = nn.ModuleList() 115 | attn = nn.ModuleList() 116 | block_in = ch * in_ch_mult[i_level] 117 | block_out = ch * ch_mult[i_level] 118 | for _ in range(self.num_res_blocks): 119 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type)) 120 | block_in = block_out 121 | down = nn.Module() 122 | down.block = block 123 | down.attn = attn 124 | 125 | spatial_down = True if i_level < self.max_down else False 126 | if spatial_down: 127 | down.downsample = Downsample(block_in, spatial_down=spatial_down) 128 | self.down.append(down) 129 | 130 | # middle 131 | self.mid = nn.Module() 132 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type) 133 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type) 134 | 135 | # end 136 | self.norm_out = Normalize(block_in, norm_type) 137 | self.conv_out = Conv(block_in, z_channels, kernel_size=3, stride=1, padding=1) 138 | 139 | def forward(self, x, return_hidden=False): 140 | if not self.use_checkpoint: 141 | return self._forward(x, return_hidden=return_hidden) 142 | else: 143 | return checkpoint.checkpoint(self._forward, x, return_hidden, use_reentrant=False) 144 | 145 | def _forward(self, x: Tensor, return_hidden=False) -> Tensor: 146 | # downsampling 147 | h0 = self.conv_in(x) 148 | hs = [h0] 149 | for i_level in range(self.num_resolutions): 150 | for i_block in range(self.num_res_blocks): 151 | h = self.down[i_level].block[i_block](hs[-1]) 152 | if len(self.down[i_level].attn) > 0: 153 | h = self.down[i_level].attn[i_block](h) 154 | hs.append(h) 155 | if hasattr(self.down[i_level], "downsample"): 156 | hs.append(self.down[i_level].downsample(hs[-1])) 157 | 158 | # middle 159 | h = hs[-1] 160 | hs_mid = [h] 161 | h = 
self.mid.block_1(h) 162 | h = self.mid.block_2(h) 163 | hs_mid.append(h) 164 | # end 165 | h = self.norm_out(h) 166 | h = swish(h) 167 | h = self.conv_out(h) 168 | if return_hidden: 169 | return h, hs, hs_mid 170 | else: 171 | return h 172 | 173 | 174 | class Decoder(nn.Module): 175 | def __init__( 176 | self, 177 | ch: int, 178 | ch_mult: list[int], 179 | num_res_blocks: int, 180 | z_channels: int, 181 | out_ch = 3, 182 | patch_size=8, 183 | norm_type="group", 184 | use_checkpoint=False, 185 | ): 186 | super().__init__() 187 | self.max_up = np.log2(patch_size) 188 | self.ch = ch 189 | self.num_resolutions = len(ch_mult) 190 | self.num_res_blocks = num_res_blocks 191 | self.ffactor = 2 ** (self.num_resolutions - 1) 192 | self.use_checkpoint = use_checkpoint 193 | 194 | # compute in_ch_mult, block_in and curr_res at lowest res 195 | block_in = ch * ch_mult[self.num_resolutions - 1] 196 | 197 | # z to block_in 198 | self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1) 199 | 200 | # middle 201 | self.mid = nn.Module() 202 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type) 203 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type) 204 | 205 | # upsampling 206 | self.up = nn.ModuleList() 207 | for i_level in reversed(range(self.num_resolutions)): 208 | block = nn.ModuleList() 209 | attn = nn.ModuleList() 210 | block_out = ch * ch_mult[i_level] 211 | for _ in range(self.num_res_blocks + 1): 212 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type)) 213 | block_in = block_out 214 | up = nn.Module() 215 | up.block = block 216 | up.attn = attn 217 | # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228 218 | spatial_up = True if 1 <= i_level <= self.max_up else False 219 | if spatial_up: 220 | up.upsample = Upsample(block_in, spatial_up=spatial_up) 221 | self.up.insert(0, up) # prepend to get consistent order 222 | 223 | # end 224 | self.norm_out = Normalize(block_in, norm_type) 225 | self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1) 226 | 227 | def forward(self, z): 228 | if not self.use_checkpoint: 229 | return self._forward(z) 230 | else: 231 | return checkpoint.checkpoint(self._forward, z, use_reentrant=False) 232 | 233 | def _forward(self, z: Tensor) -> Tensor: 234 | # z to block_in 235 | h = self.conv_in(z) 236 | 237 | # middle 238 | h = self.mid.block_1(h) 239 | h = self.mid.block_2(h) 240 | 241 | # upsampling 242 | for i_level in reversed(range(self.num_resolutions)): 243 | for i_block in range(self.num_res_blocks + 1): 244 | h = self.up[i_level].block[i_block](h) 245 | if len(self.up[i_level].attn) > 0: 246 | h = self.up[i_level].attn[i_block](h) 247 | if hasattr(self.up[i_level], "upsample"): 248 | h = self.up[i_level].upsample(h) 249 | 250 | # end 251 | h = self.norm_out(h) 252 | h = swish(h) 253 | h = self.conv_out(h) 254 | return h 255 | 256 | class AutoEncoder(nn.Module): 257 | def __init__(self, args): 258 | super().__init__() 259 | self.args = args 260 | self.encoder = Encoder( 261 | ch=args.base_ch, 262 | ch_mult=args.encoder_ch_mult, 263 | num_res_blocks=args.num_res_blocks, 264 | z_channels=args.codebook_dim, 265 | patch_size=args.patch_size, 266 | use_checkpoint=args.use_checkpoint, 267 | ) 268 | self.decoder = Decoder( 269 | ch=args.base_ch, 270 | ch_mult=args.decoder_ch_mult, 271 | num_res_blocks=args.num_res_blocks, 272 | 
z_channels=args.codebook_dim, 273 | patch_size=args.patch_size, 274 | use_checkpoint=args.use_checkpoint, 275 | ) 276 | 277 | self.gan_feat_weight = args.gan_feat_weight 278 | self.recon_loss_type = args.recon_loss_type 279 | self.l1_weight = args.l1_weight 280 | self.kl_weight = args.kl_weight 281 | self.lfq_weight = args.lfq_weight 282 | self.image_gan_weight = args.image_gan_weight # image GAN loss weight 283 | self.perceptual_weight = args.perceptual_weight 284 | 285 | self.compute_all_commitment = args.compute_all_commitment # compute commitment between input and rq-output 286 | 287 | self.perceptual_model = LPIPS(upcast_tf32=args.upcast_tf32).eval() 288 | 289 | if args.quantizer_type == 'MultiScaleBSQ': 290 | self.quantizer = MultiScaleBSQ( 291 | dim = args.codebook_dim, # this is the input feature dimension, defaults to log2(codebook_size) if not defined 292 | entropy_loss_weight = args.entropy_loss_weight, # how much weight to place on entropy loss 293 | diversity_gamma = args.diversity_gamma, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894 294 | commitment_loss_weight=args.commitment_loss_weight, # loss weight of commitment loss 295 | new_quant=args.new_quant, 296 | use_decay_factor=args.use_decay_factor, 297 | use_stochastic_depth=args.use_stochastic_depth, 298 | drop_rate=args.drop_rate, 299 | schedule_mode=args.schedule_mode, 300 | keep_first_quant=args.keep_first_quant, 301 | keep_last_quant=args.keep_last_quant, 302 | remove_residual_detach=args.remove_residual_detach, 303 | use_out_phi=args.use_out_phi, 304 | use_out_phi_res=args.use_out_phi_res, 305 | random_flip = args.random_flip, 306 | flip_prob = args.flip_prob, 307 | flip_mode = args.flip_mode, 308 | max_flip_lvl = args.max_flip_lvl, 309 | random_flip_1lvl = args.random_flip_1lvl, 310 | flip_lvl_idx = args.flip_lvl_idx, 311 | drop_when_test = args.drop_when_test, 312 | drop_lvl_idx = args.drop_lvl_idx, 313 | drop_lvl_num = args.drop_lvl_num, 314 | random_short_schedule = args.random_short_schedule, 315 | short_schedule_prob = args.short_schedule_prob, 316 | disable_flip_prob = args.disable_flip_prob, 317 | zeta = args.zeta, 318 | gamma = args.gamma, 319 | uniform_short_schedule = args.uniform_short_schedule 320 | ) 321 | else: 322 | raise NotImplementedError(f"{args.quantizer_type} not supported") 323 | self.commitment_loss_weight = args.commitment_loss_weight 324 | 325 | def forward(self, x, global_step, image_disc=None, is_train=True): 326 | assert x.ndim == 4 # assert input data is image 327 | 328 | enc_dtype = ptdtype[self.args.encoder_dtype] 329 | 330 | with torch.amp.autocast("cuda", dtype=enc_dtype): 331 | h = self.encoder(x, return_hidden=False) # B C H W 332 | h = h.to(dtype=torch.float32) 333 | 334 | # Multiscale LFQ 335 | z, all_indices, all_bit_indices, all_loss = self.quantizer(h) 336 | # print(torch.unique(torch.round(z * 10**4)/10**4)) # keep 4 decimal places 337 | x_recon = self.decoder(z) 338 | vq_output = { 339 | "commitment_loss": torch.mean(all_loss) * self.lfq_weight, # here commitment loss is sum of commitment loss and entropy penalty 340 | "encodings": all_indices, 341 | "bit_encodings": all_bit_indices, 342 | } 343 | if self.compute_all_commitment: 344 | # compute commitment loss between input and rq-output 345 | vq_output["all_commitment_loss"] = F.mse_loss(h, z.detach(), reduction="mean") * self.commitment_loss_weight * self.lfq_weight 346 | else: 347 | # disable backward prop 348 | vq_output["all_commitment_loss"] = 
F.mse_loss(h.detach(), z.detach(), reduction="mean") * self.commitment_loss_weight * self.lfq_weight 349 | 350 | assert x.shape == x_recon.shape, f"x.shape {x.shape}, x_recon.shape {x_recon.shape}" 351 | 352 | if is_train == False: 353 | return x_recon, vq_output 354 | 355 | if self.recon_loss_type == 'l1': 356 | recon_loss = F.l1_loss(x_recon, x) * self.l1_weight 357 | else: 358 | recon_loss = F.mse_loss(x_recon, x) * self.l1_weight 359 | 360 | flat_frames = x 361 | flat_frames_recon = x_recon 362 | 363 | perceptual_loss = self.perceptual_model(flat_frames, flat_frames_recon).mean() * self.perceptual_weight 364 | 365 | loss_dict = { 366 | "train/perceptual_loss": perceptual_loss, 367 | "train/recon_loss": recon_loss, 368 | "train/commitment_loss": vq_output['commitment_loss'], 369 | "train/all_commitment_loss": vq_output['all_commitment_loss'], 370 | } 371 | 372 | ### GAN loss 373 | disc_factor = adopt_weight(global_step, threshold=self.args.discriminator_iter_start, warmup=self.args.disc_warmup) 374 | if self.image_gan_weight > 0: # image GAN loss 375 | logits_image_fake = image_disc(flat_frames_recon) 376 | g_image_loss = -torch.mean(logits_image_fake) * self.image_gan_weight * disc_factor # disc_factor=0 before self.args.discriminator_iter_start 377 | loss_dict["train/g_image_loss"] = g_image_loss 378 | 379 | return (x_recon.detach(), flat_frames.detach(), flat_frames_recon.detach(), loss_dict) 380 | 381 | @staticmethod 382 | def add_model_specific_args(parent_parser): 383 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 384 | parser.add_argument("--codebook_size", type=int, default=16384) 385 | 386 | parser.add_argument("--base_ch", type=int, default=128) 387 | parser.add_argument("--num_res_blocks", type=int, default=2) # num_res_blocks for encoder, num_res_blocks+1 for decoder 388 | parser.add_argument("--encoder_ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 4]) 389 | parser.add_argument("--decoder_ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 4]) 390 | return parser 391 | -------------------------------------------------------------------------------- /bitvae/data/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Foundationvision, Inc. All Rights Reserved 2 | 3 | import os 4 | import os.path as osp 5 | import random 6 | import argparse 7 | import numpy as np 8 | from PIL import Image 9 | Image.MAX_IMAGE_PIXELS = None 10 | 11 | import torch 12 | import torch.utils.data as data 13 | import torch.distributed as dist 14 | from torchvision import transforms 15 | 16 | from bitvae.data.dataset_zoo import DATASET_DICT 17 | from torchvision.transforms import InterpolationMode 18 | from bitvae.modules.quantizer.dynamic_resolution import dynamic_resolution_h_w 19 | 20 | def _pil_interp(method): 21 | if method == 'bicubic': 22 | return InterpolationMode.BICUBIC 23 | elif method == 'lanczos': 24 | return InterpolationMode.LANCZOS 25 | elif method == 'hamming': 26 | return InterpolationMode.HAMMING 27 | else: 28 | # default bilinear, do we want to allow nearest? 
29 | return InterpolationMode.BILINEAR 30 | 31 | import timm.data.transforms as timm_transforms 32 | timm_transforms._pil_interp = _pil_interp 33 | 34 | def get_parent_dir(path): 35 | return osp.basename(osp.dirname(path)) 36 | 37 | # Only used during inference (for benchmarks like tuchong, which allows different resolutions) 38 | class DynamicAspectRatioGroupedDataset(data.Dataset): 39 | """ 40 | Batch data that have similar aspect ratio together. 41 | In this implementation, images whose aspect ratio < (or >) 1 will 42 | be batched together. 43 | This improves training speed because the images then need less padding 44 | to form a batch. 45 | 46 | It assumes the underlying dataset produces dicts with "width" and "height" keys. 47 | It will then produce a list of original dicts with length = batch_size, 48 | all with similar aspect ratios. 49 | """ 50 | 51 | def __init__(self, 52 | dataset, batch_size, 53 | debug=False, seed=0, train=True, random_bucket_ratio=0. 54 | ): 55 | """ 56 | Args: 57 | dataset: an iterable. Each element must be a dict with keys 58 | "width" and "height", which will be used to batch data. 59 | batch_size (int): 60 | """ 61 | self.train = train 62 | self._idx = 0 63 | self.dataset = dataset 64 | self.batch_size = batch_size 65 | self.random_bucket_ratio = random_bucket_ratio 66 | self.aspect_ratio = list(dynamic_resolution_h_w.keys()) 67 | num_buckets = len(self.aspect_ratio) * 3 # in each aspect-ratio, we have three scales 68 | self._buckets = [[] for _ in range(num_buckets)] 69 | self.debug = debug 70 | self.seed = seed 71 | if type(dataset) == ImageDataset: 72 | self.batch_factor = 0.6 # A100 config 73 | else: 74 | raise NotImplementedError 75 | # Hard-coded two aspect ratio groups: w > h and w < h. 76 | # Can add support for more aspect ratio groups, but doesn't seem useful 77 | 78 | def closest_id(self, v, type="width"): 79 | if type == "width": 80 | dist = np.array([abs(v - self.width[i]) for i in range(len(self.width))]) 81 | return np.argmin(dist) 82 | else: 83 | dist = np.array([abs(v - self.aspect_ratio[i]) for i in range(len(self.aspect_ratio))]) 84 | return np.argmin(dist) 85 | 86 | def __len__(self): 87 | return len(self.dataset) // self.batch_size ### an approximate value 88 | 89 | 90 | def collate_func(self, batch_list): 91 | batch = batch_list[0] 92 | 93 | return { 94 | "image": torch.stack([d["image"] for d in batch], dim=0), 95 | "label": [d["label"] for d in batch], 96 | "path": [d["path"] for d in batch], 97 | "height": [d["height"] for d in batch], 98 | "width": [d["width"] for d in batch], 99 | "type": [d["type"] for d in batch], 100 | } 101 | 102 | def get_batch_size(self, w_id, ar_id): 103 | sel_ar = self.aspect_ratio[ar_id] 104 | sel_w = self.width[w_id] 105 | dynamic_batch_size = int(self.batch_size / sel_ar / ((sel_w / self.max_width)**2) / self.batch_factor) 106 | return max(dynamic_batch_size, 1) 107 | 108 | def get_ar_w_new(self, w, ar, strategy="closest"): 109 | if strategy == "closest": 110 | # get new ar 111 | dist = np.array([abs(ar - self.aspect_ratio[i]) for i in range(len(self.aspect_ratio))]) 112 | ar_id = np.argmin(dist) 113 | ar_new = self.aspect_ratio[ar_id] 114 | # get new w 115 | w_list = list(dynamic_resolution_h_w[ar_new].keys()) 116 | dist = np.array([abs(w - w_list[i]) for i in range(len(w_list))]) 117 | w_id = np.argmin(dist) 118 | w_new = w_list[w_id] 119 | h_new = dynamic_resolution_h_w[ar_new][w_new]["pixel"][0] 120 | return w_new, h_new, w_id, ar_id 121 | elif strategy == "random": 122 | raise NotImplementedError 
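    # Descriptive note on the bucketing flow: __getitem__ and __iter__ read each sample's
    # (width, height), compute ar = h / w, snap it to the closest aspect-ratio key of
    # dynamic_resolution_h_w and the closest width within that key (get_ar_w_new above
    # returns w_new, h_new plus the bucket indices), and get_aug then resizes (or resizes
    # and center-crops under the "random" strategy) so every image in a bucket shares one
    # resolution divisible by 8. __iter__ groups samples by bucket_id = ar_id * 3 + w_id
    # and yields a batch once the bucket is full. A rough illustration with made-up
    # numbers (the real sizes come from dynamic_resolution_h_w, so the exact outputs
    # below are assumptions):
    #   w_new, h_new, w_id, ar_id = self.get_ar_w_new(w=720, ar=480 / 720)
    #   aug = self.get_aug(w_new, h_new)   # e.g. transforms.Resize((h_new, w_new))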
123 | 124 | def get_aug(self, w, h, strategy="closest"): 125 | sel_h, sel_w = h, w 126 | assert sel_h % 8 == 0 and sel_w % 8 == 0 127 | aug_shape = (sel_h, sel_w) 128 | if strategy == "closest": 129 | aug = transforms.Resize(aug_shape) # resize to the closest size 130 | elif strategy == "random": 131 | min_edge = min(sel_w, sel_h) 132 | aug = transforms.Compose([ 133 | transforms.Resize(min_edge), 134 | transforms.CenterCrop((sel_h, sel_w)), 135 | ]) 136 | return aug 137 | 138 | def __getitem__(self, idx): 139 | d = self.dataset.__getitem__(idx) 140 | w, h = d["width"], d["height"] 141 | ar = h / w 142 | strategy = "random" if random.random() < self.random_bucket_ratio else "closest" 143 | w_new, h_new, w_id, ar_id = self.get_ar_w_new(w, ar, strategy=strategy) 144 | aug = self.get_aug(w_new, h_new, strategy=strategy) 145 | images = d["image"] 146 | assert images.ndim == 3 147 | d["image"] = aug(images) 148 | assert (d["image"].shape[1] % 8 == 0) and (d["image"].shape[2] % 8 == 0) 149 | 150 | return [d] 151 | 152 | def __iter__(self): 153 | # if not self.debug: 154 | while True: 155 | if self.train: 156 | idx = random.randint(0, self.dataset.__len__()-1) 157 | else: 158 | idx = self._idx 159 | self._idx = (self._idx + 1) % self.dataset.__len__() 160 | 161 | d = self.dataset.__getitem__(idx) 162 | w, h = d["width"], d["height"] 163 | ar = h / w 164 | strategy = "random" if random.random() < self.random_bucket_ratio else "closest" 165 | w_new, h_new, w_id, ar_id = self.get_ar_w_new(w, ar, strategy=strategy) 166 | aug = self.get_aug(w_new, h_new, strategy=strategy) 167 | images = d["image"] 168 | assert images.ndim == 3 169 | 170 | d["image"] = aug(images) 171 | assert (d["image"].shape[1] % 8 == 0) and (d["image"].shape[2] % 8 == 0) 172 | 173 | bucket_id = ar_id * 3 + w_id # TODO: fix this hardcode 3 174 | bucket = self._buckets[bucket_id] 175 | bucket.append(d) 176 | target_batch_size = self.get_batch_size(w_id, ar_id) if self.train else self.batch_size 177 | if len(bucket) == target_batch_size: 178 | data = bucket[:] 179 | # Clear bucket first, because code after yield is not 180 | # guaranteed to execute 181 | del bucket[:] 182 | yield data 183 | 184 | 185 | class AspectRatioGroupedDataset(data.IterableDataset): 186 | """ 187 | Batch data that have similar aspect ratio together. 188 | In this implementation, images whose aspect ratio < (or >) 1 will 189 | be batched together. 190 | This improves training speed because the images then need less padding 191 | to form a batch. 192 | 193 | It assumes the underlying dataset produces dicts with "width" and "height" keys. 194 | It will then produce a list of original dicts with length = batch_size, 195 | all with similar aspect ratios. 196 | """ 197 | 198 | def __init__(self, 199 | dataset, batch_size, 200 | width=[256, 320, 384, 448, 512], # A100 config 201 | aspect_ratio=[4/16, 6/16, 8/16, 9/16, 10/16, 12/16, 14/16, 1, 16/14, 16/12, 16/10, 16/9, 16/8, 16/6, 16/4], # A100 config 202 | max_resolution=512*512, # A100 config 203 | debug=False, seed=0, train=True, random_bucket_ratio=0. 204 | ): 205 | """ 206 | Args: 207 | dataset: an iterable. Each element must be a dict with keys 208 | "width" and "height", which will be used to batch data. 
209 | batch_size (int): 210 | """ 211 | self.train = train 212 | self._idx = 0 213 | self.dataset = dataset 214 | self.batch_size = batch_size 215 | self.random_bucket_ratio = random_bucket_ratio 216 | num_buckets = len(width) * len(aspect_ratio) 217 | self.width = width 218 | self.max_width = max(width) 219 | self.aspect_ratio = aspect_ratio 220 | self._buckets = [[] for _ in range(num_buckets)] 221 | self.debug = debug 222 | self.seed = seed 223 | self.max_resolution = max_resolution 224 | if type(dataset) == ImageDataset: 225 | self.batch_factor = 0.6 # A100 config 226 | else: 227 | raise NotImplementedError 228 | # Hard-coded two aspect ratio groups: w > h and w < h. 229 | # Can add support for more aspect ratio groups, but doesn't seem useful 230 | 231 | def closest_id(self, v, type="width"): 232 | if type == "width": 233 | dist = np.array([abs(v - self.width[i]) for i in range(len(self.width))]) 234 | return np.argmin(dist) 235 | else: 236 | dist = np.array([abs(v - self.aspect_ratio[i]) for i in range(len(self.aspect_ratio))]) 237 | return np.argmin(dist) 238 | 239 | def __len__(self): 240 | return len(self.dataset) // self.batch_size ### an approximate value 241 | 242 | 243 | def collate_func(self, batch_list): 244 | batch = batch_list[0] 245 | 246 | return { 247 | "image": torch.stack([d["image"] for d in batch], dim=0), 248 | "label": [d["label"] for d in batch], 249 | "path": [d["path"] for d in batch], 250 | "height": [d["height"] for d in batch], 251 | "width": [d["width"] for d in batch], 252 | "type": [d["type"] for d in batch], 253 | } 254 | 255 | def get_batch_size(self, w_id, ar_id): 256 | sel_ar = self.aspect_ratio[ar_id] 257 | sel_w = self.width[w_id] 258 | dynamic_batch_size = int(self.batch_size / sel_ar / ((sel_w / self.max_width)**2) / self.batch_factor) 259 | return max(dynamic_batch_size, 1) 260 | 261 | def memory_safty_guard(self, w_id, ar_id): 262 | while True: 263 | sel_ar = self.aspect_ratio[ar_id] 264 | sel_w = self.width[w_id] 265 | if self.max_resolution < 0 or (sel_ar * sel_w) * sel_w <= self.max_resolution: 266 | break 267 | else: 268 | w_id = w_id - 1 269 | return w_id, ar_id 270 | 271 | def get_ar_w_id(self, w, ar, strategy="closest"): 272 | if strategy == "closest": 273 | w_id = self.closest_id(w, type="width") 274 | ar_id = self.closest_id(ar, type="aspect_ratio") 275 | elif strategy == "random": 276 | h = w * ar 277 | ws = [_w for _w in self.width if _w <= w] 278 | _w = random.choice(ws) if len(ws) > 0 else self.width[0] 279 | w_id = self.width.index(_w) 280 | ars = [_ar for _ar in self.aspect_ratio if _ar * w < h] 281 | _ar = random.choice(ars) if len(ars) > 0 else self.aspect_ratio[0] 282 | ar_id = self.aspect_ratio.index(_ar) 283 | return self.memory_safty_guard(w_id, ar_id) 284 | 285 | def get_aug(self, w_id, ar_id, strategy="closest"): 286 | sel_ar = self.aspect_ratio[ar_id] 287 | sel_w = self.width[w_id] 288 | sel_h = int(sel_w * sel_ar) 289 | sel_h = (sel_h+4) - ((sel_h+4) % 8) # round by 8 290 | aug_shape = (sel_h, int(sel_w)) 291 | if strategy == "closest": 292 | aug = transforms.Resize(aug_shape) # resize to the closest size 293 | elif strategy == "random": 294 | min_edge = min(sel_w, sel_h) 295 | aug = transforms.Compose([ 296 | transforms.Resize(min_edge), 297 | transforms.CenterCrop((sel_h, sel_w)), 298 | ]) 299 | return aug 300 | 301 | def __iter__(self): 302 | # if not self.debug: 303 | while True: 304 | if self.train: 305 | idx = random.randint(0, self.dataset.__len__()-1) 306 | else: 307 | idx = self._idx 308 | self._idx = (self._idx 
+ 1) % self.dataset.__len__() 309 | 310 | d = self.dataset.__getitem__(idx) 311 | w, h = d["width"], d["height"] 312 | ar = h / w 313 | strategy = "random" if random.random() < self.random_bucket_ratio else "closest" 314 | w_id, ar_id = self.get_ar_w_id(w, ar, strategy=strategy) 315 | aug = self.get_aug(w_id, ar_id, strategy=strategy) 316 | images = d["image"] 317 | assert images.ndim == 3 318 | 319 | d["image"] = aug(images) 320 | assert (d["image"].shape[1] % 8 == 0) and (d["image"].shape[2] % 8 == 0) 321 | 322 | bucket_id = ar_id * len(self.width) + w_id 323 | bucket = self._buckets[bucket_id] 324 | bucket.append(d) 325 | target_batch_size = self.get_batch_size(w_id, ar_id) if self.train else self.batch_size 326 | if len(bucket) == target_batch_size: 327 | data = bucket[:] 328 | # Clear bucket first, because code after yield is not 329 | # guaranteed to execute 330 | del bucket[:] 331 | yield data 332 | 333 | 334 | class ImageDataset(data.Dataset): 335 | """ Generic dataset for Images files stored in folders 336 | Returns BCHW Images in the range [-0.5, 0.5] """ 337 | 338 | def __init__(self, data_folder, data_list, train=True, resolution=64, aug="resize"): 339 | """ 340 | Args: 341 | data_folder: path to the folder with images. The folder 342 | should contain a 'train' and a 'test' directory, 343 | each with corresponding images stored 344 | """ 345 | super().__init__() 346 | self.train = train 347 | self.data_folder = data_folder 348 | self.data_list = data_list 349 | self.resolution = resolution 350 | 351 | with open(self.data_list) as f: 352 | self.annotations = f.readlines() 353 | 354 | total_classes = 1000 355 | classes = [] 356 | # from imagenet_stubs.imagenet_2012_labels import label_to_name 357 | # for i in range(total_classes): 358 | # classes.append(label_to_name(i)) 359 | 360 | self.classes = classes 361 | self.class_to_label = {c: i for i, c in enumerate(self.classes)} 362 | self.label_to_class = {i: c for i, c in enumerate(self.classes)} 363 | 364 | crop_function = transforms.RandomCrop(resolution) if train else transforms.CenterCrop(resolution) 365 | flip_function = transforms.RandomHorizontalFlip() if train else transforms.Lambda(lambda x: x) # flip if train else no op 366 | if aug == "resizecrop": 367 | augmentations = transforms.Compose([ 368 | transforms.Resize(min(resolution), interpolation=_pil_interp("bicubic")), 369 | crop_function, 370 | flip_function, 371 | transforms.ToTensor(), 372 | ]) 373 | elif aug == "crop": 374 | augmentations = transforms.Compose([ 375 | crop_function, 376 | flip_function, 377 | transforms.ToTensor(), 378 | ]) 379 | elif aug == "keep": 380 | augmentations = transforms.Compose([ 381 | flip_function, 382 | transforms.ToTensor(), 383 | ]) 384 | else: 385 | raise NotImplementedError 386 | 387 | self.aug = aug 388 | self.augmentations = augmentations 389 | 390 | 391 | @property 392 | def n_classes(self): 393 | return len(self.classes) 394 | 395 | def __len__(self): 396 | return len(self.annotations) 397 | 398 | def __getitem__(self, idx): 399 | try: 400 | ann = self.annotations[idx].strip() 401 | try: 402 | img_path, height, width = ann.split() 403 | img_label = -1 404 | 405 | except: 406 | img_path, img_label = ann.split() 407 | 408 | full_img_path = os.path.join(self.data_folder, img_path) 409 | 410 | img = Image.open(full_img_path).convert('RGB') 411 | h, w = img.height, img.width 412 | 413 | img = self.augmentations(img) * 2.0 - 1.0 414 | if self.aug != "keep": 415 | assert img.shape[1] == self.resolution[0] and img.shape[2] == 
self.resolution[1] 416 | 417 | return {"image": img, "label": int(img_label), "path": img_path, "height": h, "width": w, "type": "image"} 418 | except Exception as e: 419 | print(f"Error in dataloader {e}") 420 | return self.__getitem__((idx+1) % self.__len__()) 421 | 422 | 423 | class ImageData(): 424 | 425 | def __init__(self, args, shuffle=True): 426 | super().__init__() 427 | self.args = args 428 | self.shuffle = shuffle 429 | 430 | @property 431 | def n_classes(self): 432 | dataset = self._dataset(True) 433 | return dataset[0].n_classes 434 | 435 | def _dataset(self, train): 436 | datasets = [] 437 | for dataset, batch_size in zip(self.args.dataset_list, self.args.batch_size): 438 | dataset_path = DATASET_DICT[dataset]["dataset_path"] 439 | data_type = DATASET_DICT[dataset]["data_type"] 440 | train_label = DATASET_DICT[dataset]["train_label"] 441 | val_label = DATASET_DICT[dataset]["val_label"] 442 | if data_type == "image": 443 | dataset = ImageDataset( 444 | dataset_path, train_label if train else val_label, train=train, resolution=self.args.resolution, aug=self.args.dataaug 445 | ) 446 | 447 | if self.args.multi_resolution: 448 | # assert len(self.args.data_path) == 1 449 | if train: 450 | dataset = AspectRatioGroupedDataset( 451 | dataset, batch_size=batch_size, debug=self.args.debug, train=train, random_bucket_ratio=self.args.random_bucket_ratio 452 | ) 453 | else: 454 | dataset = DynamicAspectRatioGroupedDataset( 455 | dataset, batch_size=batch_size, debug=self.args.debug, train=train, random_bucket_ratio=self.args.random_bucket_ratio 456 | ) 457 | datasets.append(dataset) 458 | return datasets 459 | 460 | def _dataloader(self, train): 461 | dataset = self._dataset(train) 462 | # print(self.args.batch_size) 463 | if isinstance(self.args.batch_size, int): 464 | self.args.batch_size = [self.args.batch_size] 465 | 466 | assert len(dataset) == len(self.args.batch_size) 467 | dataloaders = [] 468 | for dset, d_batch_size in zip(dataset, self.args.batch_size): 469 | if dist.is_initialized(): 470 | sampler = data.distributed.DistributedSampler( 471 | dset, num_replicas=dist.get_world_size(), rank=dist.get_rank() 472 | ) 473 | global_rank = dist.get_rank() 474 | else: 475 | sampler = None 476 | global_rank = None 477 | 478 | def seed_worker(worker_id): 479 | if global_rank: 480 | seed = self.args.num_workers * global_rank + worker_id 481 | else: 482 | seed = worker_id 483 | # print(f"Setting dataloader worker {worker_id} on GPU {global_rank} as seed {seed}") 484 | torch.manual_seed(seed) 485 | np.random.seed(seed) 486 | random.seed(seed) 487 | 488 | dataloader = data.DataLoader( 489 | dset, 490 | batch_size=d_batch_size if not self.args.multi_resolution else 1, 491 | num_workers=self.args.num_workers, 492 | pin_memory=True, 493 | worker_init_fn=seed_worker, 494 | sampler=sampler if not isinstance(dset, data.IterableDataset) else None, 495 | collate_fn=dset.collate_func if hasattr(dset, "collate_func") else None, 496 | shuffle=sampler is None and train, 497 | drop_last=True, 498 | ) 499 | 500 | dataloaders.append(dataloader) 501 | 502 | return dataloaders 503 | 504 | def train_dataloader(self): 505 | return self._dataloader(True) 506 | 507 | def val_dataloader(self): 508 | return self._dataloader(False)[0] 509 | 510 | def test_dataloader(self): 511 | return self.val_dataloader() 512 | 513 | 514 | @staticmethod 515 | def add_data_specific_args(parent_parser): 516 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 517 | parser.add_argument('--data_path', type=str, 
nargs="+", default=[""]) 518 | parser.add_argument('--data_type', type=str, nargs="+", default=[""]) 519 | parser.add_argument('--dataset_list', type=str, nargs="+", default=['']) 520 | parser.add_argument('--dataaug', type=str, choices=["resize", "resizecrop", "crop", "keep"]) 521 | parser.add_argument('--multi_resolution', action="store_true") 522 | parser.add_argument('--random_bucket_ratio', type=float, default=0.) 523 | 524 | parser.add_argument('--resolution', type=int, nargs="+", default=[512]) 525 | parser.add_argument('--batch_size', type=int, nargs="+", default=[32]) 526 | parser.add_argument('--num_workers', type=int, default=8) 527 | parser.add_argument('--image_channels', type=int, default=3) 528 | 529 | return parser 530 | -------------------------------------------------------------------------------- /bitvae/modules/quantizer/multiscale_bsq.py: -------------------------------------------------------------------------------- 1 | """ 2 | Binary Spherical Quantization 3 | Proposed in https://arxiv.org/abs/2406.07548 4 | 5 | In the simplest setup, each dimension is quantized into {-1, 1}. 6 | An entropy penalty is used to encourage utilization. 7 | """ 8 | 9 | import random 10 | from math import log2, ceil 11 | from functools import partial, cache 12 | from collections import namedtuple 13 | from contextlib import nullcontext 14 | 15 | import torch.distributed as dist 16 | from torch.distributed import nn as dist_nn 17 | 18 | import torch 19 | from torch import nn, einsum 20 | import torch.nn.functional as F 21 | from torch.nn import Module 22 | from torch.amp import autocast 23 | 24 | from einops import rearrange, reduce, pack, unpack 25 | 26 | from einx import get_at 27 | 28 | from .dynamic_resolution import predefined_HW_Scales_dynamic 29 | 30 | # constants 31 | 32 | Return = namedtuple('Return', ['quantized', 'indices', 'bit_indices', 'entropy_aux_loss']) 33 | 34 | LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment']) 35 | 36 | # distributed helpers 37 | 38 | @cache 39 | def is_distributed(): 40 | return dist.is_initialized() and dist.get_world_size() > 1 41 | 42 | def maybe_distributed_mean(t): 43 | if not is_distributed(): 44 | return t 45 | 46 | dist_nn.all_reduce(t) 47 | t = t / dist.get_world_size() 48 | return t 49 | 50 | # helper functions 51 | 52 | def exists(v): 53 | return v is not None 54 | 55 | def identity(t): 56 | return t 57 | 58 | def default(*args): 59 | for arg in args: 60 | if exists(arg): 61 | return arg() if callable(arg) else arg 62 | return None 63 | 64 | def round_up_multiple(num, mult): 65 | return ceil(num / mult) * mult 66 | 67 | def pack_one(t, pattern): 68 | return pack([t], pattern) 69 | 70 | def unpack_one(t, ps, pattern): 71 | return unpack(t, ps, pattern)[0] 72 | 73 | def l2norm(t): 74 | return F.normalize(t, dim = -1) 75 | 76 | # entropy 77 | 78 | def log(t, eps = 1e-5): 79 | return t.clamp(min = eps).log() 80 | 81 | def entropy(prob): 82 | return (-prob * log(prob)).sum(dim=-1) 83 | 84 | # cosine sim linear 85 | 86 | class CosineSimLinear(Module): 87 | def __init__( 88 | self, 89 | dim_in, 90 | dim_out, 91 | scale = 1. 
92 | ): 93 | super().__init__() 94 | self.scale = scale 95 | self.weight = nn.Parameter(torch.randn(dim_in, dim_out)) 96 | 97 | def forward(self, x): 98 | x = F.normalize(x, dim = -1) 99 | w = F.normalize(self.weight, dim = 0) 100 | return (x @ w) * self.scale 101 | 102 | 103 | def get_latent2scale_schedule(T: int, H: int, W: int, mode="original"): 104 | assert mode in ["original", "dynamic", "dense", "same1", "same2", "same3", "half", "dense_f8"] 105 | predefined_HW_Scales = { 106 | # 256 * 256 107 | (32, 32): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 9), (13, 13), (18, 18), (24, 24), (32, 32)], 108 | (16, 16): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)], 109 | # 1024x1024 110 | (64, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (7, 7), (9, 9), (12, 12), (16, 16), (21, 21), (27, 27), (36, 36), (48, 48), (64, 64)], 111 | 112 | (36, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 12), (13, 16), (18, 24), (24, 32), (32, 48), (36, 64)], 113 | } 114 | if mode == "dynamic": 115 | predefined_HW_Scales.update(predefined_HW_Scales_dynamic) 116 | elif mode == "dense": 117 | predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16+1)] 118 | predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [(20, 20), (24, 24), (28, 28), (32, 32)] 119 | predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)] 120 | elif mode == "dense_f8": 121 | # predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16+1)] 122 | predefined_HW_Scales[(32, 32)] = [(x, x) for x in range(1, 16+1)] + [(20, 20), (24, 24), (28, 28), (32, 32)] 123 | predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)] 124 | predefined_HW_Scales[(128, 128)] = predefined_HW_Scales[(64, 64)] + [(80, 80), (96, 96), (112, 112), (128, 128)] 125 | elif mode.startswith("same"): 126 | num_quant = int(mode[len("same"):]) 127 | predefined_HW_Scales[(16, 16)] = [(16, 16) for _ in range(num_quant)] 128 | predefined_HW_Scales[(32, 32)] = [(32, 32) for _ in range(num_quant)] 129 | predefined_HW_Scales[(64, 64)] = [(64, 64) for _ in range(num_quant)] 130 | elif mode == "half": 131 | predefined_HW_Scales[(32, 32)] = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)] 132 | predefined_HW_Scales[(64, 64)] = [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16)] 133 | 134 | predefined_T_Scales = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 17, 17, 17, 17] 135 | patch_THW_shape_per_scale = predefined_HW_Scales[(H, W)] 136 | if len(predefined_T_Scales) < len(patch_THW_shape_per_scale): 137 | # print("warning: the length of predefined_T_Scales is less than the length of patch_THW_shape_per_scale!") 138 | predefined_T_Scales += [predefined_T_Scales[-1]] * (len(patch_THW_shape_per_scale) - len(predefined_T_Scales)) 139 | patch_THW_shape_per_scale = [(min(T, t), h, w ) for (h, w), t in zip(patch_THW_shape_per_scale, predefined_T_Scales[:len(patch_THW_shape_per_scale)])] 140 | return patch_THW_shape_per_scale 141 | 142 | 143 | class MultiScaleBSQ(Module): 144 | """ Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf """ 145 | 146 | def __init__( 147 | self, 148 | *, 149 | dim, 150 | soft_clamp_input_value = None, 151 | aux_loss = False, # intermediate auxiliary loss 152 | use_decay_factor=False, 153 | use_stochastic_depth=False, 154 | drop_rate=0., 155 | schedule_mode="original", # ["original", "dynamic", "dense"] 156 | keep_first_quant=False, 157 | keep_last_quant=False, 158 | remove_residual_detach=False, 159 | random_flip = False, 160 | flip_prob = 0.5, 161 | flip_mode = "stochastic", # "stochastic", "deterministic" 162 | max_flip_lvl = 1, 163 | random_flip_1lvl = False, # random flip one level each time 164 | flip_lvl_idx = None, 165 | drop_when_test=False, 166 | drop_lvl_idx=None, 167 | drop_lvl_num=0, 168 | random_short_schedule = False, # randomly use short schedule (schedule for images of 256x256) 169 | short_schedule_prob = 0.5, 170 | disable_flip_prob = 0.0, # disable random flip in this image 171 | uniform_short_schedule = False, 172 | **kwargs 173 | ): 174 | super().__init__() 175 | codebook_dim = dim 176 | 177 | requires_projection = codebook_dim != dim 178 | self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() 179 | self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() 180 | self.has_projections = requires_projection 181 | self.layernorm = nn.Identity() 182 | self.use_stochastic_depth = use_stochastic_depth 183 | self.drop_rate = drop_rate 184 | self.remove_residual_detach = remove_residual_detach 185 | self.random_flip = random_flip 186 | self.flip_prob = flip_prob 187 | self.flip_mode = flip_mode 188 | self.max_flip_lvl = max_flip_lvl 189 | self.random_flip_1lvl = random_flip_1lvl 190 | self.flip_lvl_idx = flip_lvl_idx 191 | assert (random_flip and random_flip_1lvl) == False 192 | self.disable_flip_prob = disable_flip_prob 193 | 194 | self.drop_when_test = drop_when_test 195 | self.drop_lvl_idx = drop_lvl_idx 196 | self.drop_lvl_num = drop_lvl_num 197 | if self.drop_when_test: 198 | assert drop_lvl_idx is not None 199 | assert drop_lvl_num > 0 200 | self.random_short_schedule = random_short_schedule 201 | self.short_schedule_prob = short_schedule_prob 202 | self.full2short = {7:7, 10:7, 13:7, 16:16, 20:16, 24:16} 203 | self.full2short_f8 = {20:20, 24:20, 28:20} 204 | self.uniform_short_schedule = uniform_short_schedule 205 | assert not (self.random_short_schedule and self.uniform_short_schedule) 206 | 207 | self.lfq = BSQ( 208 | dim = codebook_dim, 209 | codebook_scale = 1, 210 | soft_clamp_input_value = soft_clamp_input_value, 211 | **kwargs 212 | ) 213 | 214 | self.z_interplote_up = 'trilinear' 215 | self.z_interplote_down = 'area' 216 | 217 | self.use_decay_factor = use_decay_factor 218 | self.schedule_mode = schedule_mode 219 | self.keep_first_quant = keep_first_quant 220 | self.keep_last_quant = keep_last_quant 221 | if self.use_stochastic_depth and self.drop_rate > 0: 222 | assert self.keep_first_quant or self.keep_last_quant 223 | 224 | @property 225 | def codebooks(self): 226 | return self.lfq.codebook 227 | 228 | def get_codes_from_indices(self, indices_list): 229 | all_codes = [] 230 | for indices in indices_list: 231 | codes = self.lfq.indices_to_codes(indices) 232 | all_codes.append(codes) 233 | _, _, T, H, W = all_codes[-1].size() 234 | summed_codes = 0 235 | for code in all_codes: 236 | summed_codes += F.interpolate(code, size=(T, H, W), mode=self.z_interplote_up) 237 | return summed_codes 238 | 239 | def get_output_from_indices(self, indices): 240 | codes = 
self.get_codes_from_indices(indices) 241 | codes_summed = reduce(codes, 'q ... -> ...', 'sum') 242 | return self.project_out(codes_summed) 243 | 244 | def flip_quant(self, x): 245 | if self.flip_mode == 'stochastic': 246 | flip_mask = torch.rand_like(x) < self.flip_prob 247 | else: 248 | raise NotImplementedError 249 | x = x.clone() 250 | x[flip_mask] = -x[flip_mask] 251 | return x 252 | 253 | def forward( 254 | self, 255 | x, 256 | mask = None, 257 | return_all_codes = False, 258 | ): 259 | if x.ndim == 4: 260 | x = x.unsqueeze(2) 261 | B, C, T, H, W = x.size() 262 | 263 | if self.schedule_mode.startswith("same"): 264 | scale_num = int(self.schedule_mode[len("same"):]) 265 | assert T == 1 266 | scale_schedule = [(1, H, W)] * scale_num 267 | else: 268 | scale_schedule = get_latent2scale_schedule(T, H, W, mode=self.schedule_mode) 269 | scale_num = len(scale_schedule) 270 | 271 | if self.uniform_short_schedule: 272 | scale_num_short = self.full2short_f8[scale_num] if self.schedule_mode == "dense_f8" else self.full2short[scale_num] 273 | scale_num = random.randint(scale_num_short, scale_num) 274 | scale_schedule = scale_schedule[:scale_num] 275 | elif self.random_short_schedule and random.random() < self.short_schedule_prob: 276 | if self.schedule_mode == "dense_f8": 277 | scale_num = self.full2short_f8[scale_num] 278 | else: 279 | scale_num = self.full2short[scale_num] 280 | scale_schedule = scale_schedule[:scale_num] 281 | 282 | # x = self.project_in(x) 283 | x = x.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c) 284 | x = self.project_in(x) 285 | x = x.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w) 286 | x = self.layernorm(x) 287 | 288 | quantized_out = 0. 289 | residual = x 290 | 291 | all_losses = [] 292 | all_indices = [] 293 | all_bit_indices = [] 294 | 295 | # go through the layers 296 | out_fact = init_out_fact = 1.0 297 | # residual_list = [] 298 | # interpolate_residual_list = [] 299 | # quantized_list = [] 300 | if self.drop_when_test: 301 | drop_lvl_start = self.drop_lvl_idx 302 | drop_lvl_end = self.drop_lvl_idx + self.drop_lvl_num 303 | disable_flip = True if random.random() < self.disable_flip_prob else False # disable random flip in this image 304 | with autocast('cuda', enabled = False): 305 | for si, (pt, ph, pw) in enumerate(scale_schedule): 306 | 307 | out_fact = max(0.1, out_fact) if self.use_decay_factor else init_out_fact 308 | if (pt, ph, pw) != (T, H, W): 309 | interpolate_residual = F.interpolate(residual, size=(pt, ph, pw), mode=self.z_interplote_down) 310 | else: 311 | interpolate_residual = residual 312 | # residual_list.append(torch.norm(residual.detach(), dim=1).mean()) 313 | # interpolate_residual_list.append(torch.norm(interpolate_residual.detach(), dim=1).mean()) 314 | if self.training and self.use_stochastic_depth and random.random() < self.drop_rate: 315 | if (si == 0 and self.keep_first_quant) or (si == scale_num - 1 and self.keep_last_quant): 316 | quantized, indices, bit_indices, loss = self.lfq(interpolate_residual) 317 | if self.random_flip and si < self.max_flip_lvl and (not disable_flip): 318 | quantized = self.flip_quant(quantized) 319 | quantized = quantized * out_fact 320 | all_indices.append(indices) 321 | all_losses.append(loss) 322 | all_bit_indices.append(bit_indices) 323 | else: 324 | quantized = torch.zeros_like(interpolate_residual) 325 | elif self.drop_when_test and drop_lvl_start <= si < drop_lvl_end: 326 | continue 327 | else: 328 | # residual_norm = 
torch.norm(interpolate_residual.detach(), dim=1) # (b, t, h, w) 329 | # print(si, residual_norm.min(), residual_norm.max(), residual_norm.mean()) 330 | quantized, indices, bit_indices, loss = self.lfq(interpolate_residual) 331 | if self.random_flip and si < self.max_flip_lvl and (not disable_flip): 332 | quantized = self.flip_quant(quantized) 333 | if self.random_flip_1lvl and si == self.flip_lvl_idx and (not disable_flip): 334 | quantized = self.flip_quant(quantized) 335 | quantized = quantized * out_fact 336 | all_indices.append(indices) 337 | all_losses.append(loss) 338 | all_bit_indices.append(bit_indices) 339 | # quantized_list.append(torch.norm(quantized.detach(), dim=1).mean()) 340 | if (pt, ph, pw) != (T, H, W): 341 | quantized = F.interpolate(quantized, size=(T, H, W), mode=self.z_interplote_up).contiguous() 342 | 343 | if self.remove_residual_detach: 344 | residual = residual - quantized 345 | else: 346 | residual = residual - quantized.detach() 347 | quantized_out = quantized_out + quantized 348 | 349 | if self.use_decay_factor: 350 | out_fact -= 0.1 351 | # print("residual_list:", residual_list) 352 | # print("interpolate_residual_list:", interpolate_residual_list) 353 | # print("quantized_list:", quantized_list) 354 | # import ipdb; ipdb.set_trace() 355 | # project out, if needed 356 | quantized_out = quantized_out.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c) 357 | quantized_out = self.project_out(quantized_out) 358 | quantized_out = quantized_out.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w) 359 | 360 | # image 361 | if quantized_out.size(2) == 1: 362 | quantized_out = quantized_out.squeeze(2) 363 | 364 | # stack all losses and indices 365 | 366 | all_losses = torch.stack(all_losses, dim = -1) 367 | 368 | ret = (quantized_out, all_indices, all_bit_indices, all_losses) 369 | 370 | if not return_all_codes: 371 | return ret 372 | 373 | # whether to return all codes from all codebooks across layers 374 | all_codes = self.get_codes_from_indices(all_indices) 375 | 376 | # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) 377 | 378 | return (*ret, all_codes) 379 | 380 | 381 | class BSQ(Module): 382 | def __init__( 383 | self, 384 | *, 385 | dim = None, 386 | entropy_loss_weight = 0.1, 387 | commitment_loss_weight = 0.25, 388 | diversity_gamma = 1., 389 | straight_through_activation = nn.Identity(), 390 | num_codebooks = 1, 391 | keep_num_codebooks_dim = None, 392 | codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer 393 | frac_per_sample_entropy = 1., # make less than 1. 
to only use a random fraction of the probs for per sample entropy 394 | has_projections = None, 395 | projection_has_bias = True, 396 | soft_clamp_input_value = None, 397 | cosine_sim_project_in = False, 398 | cosine_sim_project_in_scale = None, 399 | channel_first = None, 400 | experimental_softplus_entropy_loss = False, 401 | entropy_loss_offset = 5., # how much to shift the loss before softplus 402 | spherical = True, # from https://arxiv.org/abs/2406.07548 403 | force_quantization_f32 = True, # will force the quantization step to be full precision 404 | inv_temperature = 100.0, 405 | gamma0=1.0, gamma=1.0, zeta=1.0, 406 | new_quant = False, # new quant function, 407 | use_out_phi = False, # use output phi network 408 | use_out_phi_res = False, # residual out phi 409 | ): 410 | super().__init__() 411 | 412 | # some assert validations 413 | assert exists(dim) , 'dim must be specified for BSQ' 414 | 415 | codebook_dim = dim 416 | codebook_dims = codebook_dim * num_codebooks 417 | dim = default(dim, codebook_dims) 418 | self.codebook_dims = codebook_dims 419 | 420 | has_projections = default(has_projections, dim != codebook_dims) 421 | 422 | if cosine_sim_project_in: 423 | cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale) 424 | project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in) 425 | else: 426 | project_in_klass = partial(nn.Linear, bias = projection_has_bias) 427 | 428 | self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity() # nn.Identity() 429 | self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity() # nn.Identity() 430 | self.has_projections = has_projections 431 | 432 | self.out_phi = nn.Linear(codebook_dims, codebook_dims) if use_out_phi else nn.Identity() 433 | self.use_out_phi_res = use_out_phi_res 434 | if self.use_out_phi_res: 435 | self.out_phi_scale = nn.Parameter(torch.zeros(codebook_dims), requires_grad=True) # init as zero 436 | 437 | self.dim = dim 438 | self.codebook_dim = codebook_dim 439 | self.num_codebooks = num_codebooks 440 | 441 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) 442 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim) 443 | self.keep_num_codebooks_dim = keep_num_codebooks_dim 444 | 445 | # channel first 446 | 447 | self.channel_first = channel_first 448 | 449 | # straight through activation 450 | 451 | self.activation = straight_through_activation 452 | 453 | # For BSQ (binary spherical quantization) 454 | if not spherical: 455 | raise ValueError("For BSQ, spherical must be True.") 456 | self.persample_entropy_compute = 'analytical' 457 | self.inv_temperature = inv_temperature 458 | self.gamma0 = gamma0 # loss weight for entropy penalty 459 | self.gamma = gamma # loss weight for entropy penalty 460 | self.zeta = zeta # loss weight for entire entropy penalty 461 | self.new_quant = new_quant 462 | 463 | # entropy aux loss related weights 464 | 465 | assert 0 < frac_per_sample_entropy <= 1. 
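        # Editorial summary added for clarity (not in the original source) of how
        # these weights enter the auxiliary loss (see soft_entropy_loss() and
        # forward() below):
        #   p_i             = sigmoid(-4 * z_i / sqrt(d) * inv_temperature), per bit i
        #   H_sample        = mean over tokens of sum_i H([p_i, 1 - p_i])
        #   H_codebook      = sum_i H(mean over tokens of [p_i, 1 - p_i])
        #   entropy_penalty = gamma0 * H_sample - gamma * H_codebook
        #   aux_loss        = commitment_loss_weight * MSE(z, quantized.detach())
        #                     + entropy_weight * zeta * entropy_penalty / inv_temperature
        # Minimizing H_sample pushes each bit toward a confident +1/-1 decision,
        # while the -H_codebook term rewards balanced usage of both signs across
        # the batch.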
466 | self.frac_per_sample_entropy = frac_per_sample_entropy 467 | 468 | self.diversity_gamma = diversity_gamma 469 | self.entropy_loss_weight = entropy_loss_weight 470 | 471 | # codebook scale 472 | 473 | self.codebook_scale = codebook_scale 474 | 475 | # commitment loss 476 | 477 | self.commitment_loss_weight = commitment_loss_weight 478 | 479 | # whether to soft clamp the input value from -value to value 480 | 481 | self.soft_clamp_input_value = soft_clamp_input_value 482 | assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale 483 | 484 | # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions) 485 | 486 | self.entropy_loss_offset = entropy_loss_offset 487 | self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss 488 | 489 | # for no auxiliary loss, during inference 490 | 491 | self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1)) 492 | self.register_buffer('zero', torch.tensor(0.), persistent = False) 493 | 494 | # whether to force quantization step to be f32 495 | 496 | self.force_quantization_f32 = force_quantization_f32 497 | 498 | def bits_to_codes(self, bits): 499 | return bits * self.codebook_scale * 2 - self.codebook_scale 500 | 501 | # @property 502 | # def dtype(self): 503 | # return self.codebook.dtype 504 | 505 | def indices_to_codes( 506 | self, 507 | indices, 508 | label_type = 'int_label', 509 | project_out = True 510 | ): 511 | assert label_type in ['int_label', 'bit_label'] 512 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) 513 | should_transpose = default(self.channel_first, is_img_or_video) 514 | 515 | if not self.keep_num_codebooks_dim: 516 | if label_type == 'int_label': 517 | indices = rearrange(indices, '... -> ... 1') 518 | else: 519 | indices = indices.unsqueeze(-2) 520 | 521 | # indices to codes, which are bits of either -1 or 1 522 | 523 | if label_type == 'int_label': 524 | assert indices[..., None].int().min() > 0 525 | bits = ((indices[..., None].int() & self.mask) != 0).float() # .to(self.dtype) 526 | else: 527 | bits = indices 528 | 529 | codes = self.bits_to_codes(bits) 530 | 531 | codes = l2norm(codes) # must normalize when using BSQ 532 | 533 | codes = rearrange(codes, '... c d -> ... (c d)') 534 | 535 | # whether to project codes out to original dimensions 536 | # if the input feature dimensions were not log2(codebook size) 537 | 538 | if project_out: 539 | codes = self.project_out(codes) 540 | 541 | # rearrange codes back to original shape 542 | 543 | if should_transpose: 544 | codes = rearrange(codes, 'b ... d -> b d ...') 545 | 546 | return codes 547 | 548 | def quantize(self, z): 549 | assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}" 550 | 551 | zhat = torch.where(z > 0, 552 | torch.tensor(1, dtype=z.dtype, device=z.device), 553 | torch.tensor(-1, dtype=z.dtype, device=z.device)) 554 | return z + (zhat - z).detach() 555 | 556 | def quantize_new(self, z): 557 | assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}" 558 | 559 | zhat = torch.where(z > 0, 560 | torch.tensor(1, dtype=z.dtype, device=z.device), 561 | torch.tensor(-1, dtype=z.dtype, device=z.device)) 562 | 563 | q_scale = 1. 
/ (self.codebook_dims ** 0.5) 564 | zhat = q_scale * zhat # on unit sphere 565 | 566 | return z + (zhat - z).detach() 567 | 568 | def soft_entropy_loss(self, z): 569 | if self.persample_entropy_compute == 'analytical': 570 | # if self.l2_norm: 571 | p = torch.sigmoid(-4 * z / (self.codebook_dims ** 0.5) * self.inv_temperature) 572 | # else: 573 | # p = torch.sigmoid(-4 * z * self.inv_temperature) 574 | prob = torch.stack([p, 1-p], dim=-1) # (b, h, w, 18, 2) 575 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() # (b,h,w,18)->(b,h,w)->scalar 576 | else: 577 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() 578 | 579 | # macro average of the probability of each subgroup 580 | avg_prob = reduce(prob, '... g d ->g d', 'mean') # (18, 2) 581 | codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False) 582 | 583 | # the approximation of the entropy is the sum of the entropy of each subgroup 584 | return per_sample_entropy, codebook_entropy.sum(), avg_prob 585 | 586 | def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True): 587 | if normalize: # False 588 | probs = (count + eps) / (count + eps).sum(dim=dim, keepdim =True) 589 | else: # True 590 | probs = count 591 | H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim) 592 | return H 593 | 594 | def forward( 595 | self, 596 | x, 597 | return_loss_breakdown = False, 598 | mask = None, 599 | entropy_weight=0.1 600 | ): 601 | """ 602 | einstein notation 603 | b - batch 604 | n - sequence (or flattened spatial dimensions) 605 | d - feature dimension, which is also log2(codebook size) 606 | c - number of codebook dim 607 | """ 608 | 609 | is_img_or_video = x.ndim >= 4 610 | should_transpose = default(self.channel_first, is_img_or_video) 611 | 612 | # standardize image or video into (batch, seq, dimension) 613 | 614 | if should_transpose: 615 | x = rearrange(x, 'b d ... -> b ... d') 616 | x, ps = pack_one(x, 'b * d') # x.shape [b, hwt, c] 617 | 618 | assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}' 619 | 620 | x = self.project_in(x) 621 | 622 | # split out number of codebooks 623 | 624 | x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks) 625 | 626 | x = l2norm(x) 627 | 628 | # whether to force quantization step to be full precision or not 629 | 630 | force_f32 = self.force_quantization_f32 631 | 632 | quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext 633 | 634 | indices = None 635 | with quantization_context(): 636 | 637 | if force_f32: 638 | orig_dtype = x.dtype 639 | x = x.float() 640 | 641 | # use straight-through gradients (optionally with custom activation fn) if training 642 | if self.new_quant: 643 | quantized = self.quantize_new(x) 644 | else: 645 | quantized = self.quantize(x) 646 | q_scale = 1. 
/ (self.codebook_dims ** 0.5) 647 | quantized = q_scale * quantized # on unit sphere 648 | 649 | # calculate indices 650 | bit_indices = (quantized > 0).int() 651 | 652 | # entropy aux loss 653 | if self.training: 654 | persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(x) # compute entropy 655 | entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy 656 | else: 657 | # if not training, just return dummy 0 658 | entropy_penalty = persample_entropy = cb_entropy = self.zero 659 | 660 | # commit loss 661 | 662 | if self.training and self.commitment_loss_weight > 0.: 663 | 664 | commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none') 665 | 666 | if exists(mask): 667 | commit_loss = commit_loss[mask] 668 | 669 | commit_loss = commit_loss.mean() 670 | else: 671 | commit_loss = self.zero 672 | 673 | # input back to original dtype if needed 674 | 675 | if force_f32: 676 | x = x.type(orig_dtype) 677 | 678 | # merge back codebook dim 679 | x = quantized # rename quantized to x for output 680 | 681 | if self.use_out_phi_res: 682 | x = x + self.out_phi_scale * self.out_phi(x) # apply out_phi on quant output as residual 683 | else: 684 | x = self.out_phi(x) # apply out_phi on quant output 685 | 686 | x = rearrange(x, 'b n c d -> b n (c d)') 687 | 688 | # project out to feature dimension if needed 689 | 690 | x = self.project_out(x) 691 | 692 | # reconstitute image or video dimensions 693 | 694 | if should_transpose: 695 | x = unpack_one(x, ps, 'b * d') 696 | x = rearrange(x, 'b ... d -> b d ...') 697 | 698 | bit_indices = unpack_one(bit_indices, ps, 'b * c d') 699 | 700 | # whether to remove single codebook dim 701 | 702 | if not self.keep_num_codebooks_dim: 703 | bit_indices = rearrange(bit_indices, '... 1 d -> ... d') 704 | 705 | # complete aux loss 706 | 707 | aux_loss = commit_loss * self.commitment_loss_weight + (self.zeta * entropy_penalty / self.inv_temperature)*entropy_weight 708 | # returns 709 | 710 | ret = Return(x, indices, bit_indices, aux_loss) 711 | 712 | if not return_loss_breakdown: 713 | return ret 714 | 715 | return ret, LossBreakdown(persample_entropy, cb_entropy, commit_loss) 716 | 717 | --------------------------------------------------------------------------------
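Editorial addition: a minimal usage sketch for the multi-scale quantizer defined above. This is not a file from the repository; the import path, the latent width (dim=32), the 16x16 spatial size, and the script name are illustrative assumptions, and new_quant is simply forwarded to the inner BSQ quantizer.

# illustrative_bsq_usage.py -- hypothetical example, not part of BitVAE
import torch

from bitvae.modules.quantizer.multiscale_bsq import MultiScaleBSQ

# schedule_mode="original" uses the predefined 10-step schedule
# (1x1 -> 2x2 -> ... -> 16x16) for a 16x16 latent map.
quantizer = MultiScaleBSQ(dim=32, schedule_mode="original", new_quant=True)
quantizer.eval()  # skip the training-only entropy/commitment terms

latent = torch.randn(2, 32, 16, 16)  # (B, C, H, W) encoder latent
with torch.no_grad():
    quantized, indices, bit_indices, aux_loss = quantizer(latent)

print(quantized.shape)   # torch.Size([2, 32, 16, 16]), same shape as the input
print(len(bit_indices))  # one binary code tensor per scale in the schedule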