├── bitvae
│   ├── utils
│   │   ├── __init__.py
│   │   ├── random_numbers.npy
│   │   ├── generate_random_num.py
│   │   ├── logger.py
│   │   ├── init_models.py
│   │   ├── misc.py
│   │   ├── distributed.py
│   │   └── arguments.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── dataset_zoo.py
│   │   └── data.py
│   ├── modules
│   │   ├── quantizer
│   │   │   ├── __init__.py
│   │   │   ├── dynamic_resolution.py
│   │   │   └── multiscale_bsq.py
│   │   ├── cache
│   │   │   └── vgg.pth
│   │   ├── __init__.py
│   │   ├── conv.py
│   │   ├── loss.py
│   │   ├── normalization.py
│   │   └── lpips.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── fid.py
│   │   └── inception.py
│   └── models
│       ├── __init__.py
│       ├── discriminator.py
│       └── d_vae.py
├── .gitignore
├── scripts
│   ├── prepare.sh
│   └── release
│       ├── test_img_d16_stage1.sh
│       ├── test_img_d16_stage2.sh
│       ├── test_img_d32_stage1.sh
│       ├── test_img_d32_stage2.sh
│       ├── train_img_d32_stage1.sh
│       ├── train_img_d16_stage2.sh
│       ├── train_img_d32_stage2.sh
│       └── train_img_d16_stage1.sh
├── requirements.txt
├── LICENSE
├── README.md
├── eval.py
└── train.py
/bitvae/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/bitvae/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import ImageData, ImageDataset
--------------------------------------------------------------------------------
/bitvae/modules/quantizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .multiscale_bsq import MultiScaleBSQ
2 |
--------------------------------------------------------------------------------
/bitvae/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .fid import calculate_frechet_distance
2 | from .inception import InceptionV3
--------------------------------------------------------------------------------
/bitvae/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .discriminator import ImageDiscriminator
2 | from .d_vae import AutoEncoder as d_vae
--------------------------------------------------------------------------------
/bitvae/modules/cache/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/BitVAE/HEAD/bitvae/modules/cache/vgg.pth
--------------------------------------------------------------------------------
/bitvae/utils/random_numbers.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/BitVAE/HEAD/bitvae/utils/random_numbers.npy
--------------------------------------------------------------------------------
/bitvae/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .lpips import LPIPS
2 | from .normalization import Normalize
3 | from .conv import Conv
4 | # from .commitments import DiagonalGaussianDistribution
5 | from .loss import adopt_weight
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
2 | lightning_logs/
3 | .ipynb_checkpoints/
4 | *.egg-info
5 | *.pyc
6 | results*
7 | logs*
8 | recon.*
9 | dataset/
10 | **/span.log
11 | wandb/*
12 | *.png
13 | bitvae_results
14 | checkpoints
15 | labels
16 | .idea/
17 |
--------------------------------------------------------------------------------
/scripts/prepare.sh:
--------------------------------------------------------------------------------
1 | ### trainer
2 | sudo apt-get install libatlas-base-dev -y
3 | sudo pip3 uninstall numpy -y
4 | pip3 install -r requirements.txt
5 | sudo apt-get install ffmpeg libsm6 libxext6 -y
6 |
7 | pip3 install moviepy==2.0.0.dev2 imageio
8 | pip3 install seaborn
9 | pip3 install einx
--------------------------------------------------------------------------------
/bitvae/modules/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Conv(nn.Module):
5 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
6 | super().__init__()
7 |
8 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
9 | self.stride = stride
10 | self.kernel_size = kernel_size
11 |
12 | def forward(self, x):
13 | return self.conv(x)
--------------------------------------------------------------------------------
/bitvae/data/dataset_zoo.py:
--------------------------------------------------------------------------------
1 | DATASET_DICT = {
2 | "imagenet": {
3 | "train_label": "labels/imagenet/train.txt",
4 | "val_label": "labels/imagenet/val.txt",
5 | "dataset_path": "",
6 | "data_type": "image"
7 | },
8 | "openimages": {
9 | "train_label": "labels/openimages/train.txt", # 9M+ images
10 | "val_label": "labels/openimages/train.txt",
11 | "dataset_path": "",
12 | "data_type": "image"
13 | }
14 | }
--------------------------------------------------------------------------------
/bitvae/utils/generate_random_num.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | if __name__ == "__main__":
5 | # Generate random numbers and save them to a file
6 | np.random.seed(42) # Set the seed for reproducibility
7 | version = "v2" # ["v1", "v2"]
8 | if version == "v1":
9 | num_choices = 3
10 | save_path = "random_numbers.npy"
11 | elif version == "v2":
12 | num_choices = 45 # 3 or 45
13 | save_path = "random_numbers_v2.npy"
14 | else:
15 | raise ValueError("Invalid version")
16 | random_numbers = np.random.choice(list(range(num_choices)), size=500000)
17 | np.save(save_path, random_numbers)
18 | loaded_random_numbers = np.load(save_path)
19 | print(loaded_random_numbers[:100])
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # pytorch-lightning==1.8.6
2 | diffusers==0.29.1
3 | h5py==3.11.0
4 | einops==0.8.0
5 | ftfy==6.2.0
6 | imageio==2.34.1
7 | imageio-ffmpeg==0.5.1
8 | regex==2024.5.15
9 | scikit-video==1.1.11
10 | tqdm==4.66.4
11 | av==12.1.0
12 | beartype==0.18.5
13 | scikit-learn==1.5.0
14 | timm==1.0.7
15 | transformers==4.37.2
16 | decord==0.6.0
17 | fvcore==0.1.5.post20221221
18 | pycocotools==2.0.8
19 | cloudpickle==3.0.0
20 | omegaconf==2.3.0
21 | scikit-image==0.24.0
22 | lpips==0.1.4
23 | accelerate
24 | numpy==1.26.2
25 | # pip3 install git+https://github.com/nottombrown/imagenet_stubs
26 |
27 | fairscale==0.4.13
28 | # lightning==2.2.5
29 | opencv-python==4.10.0.84
30 | ipdb==0.13.13
31 |
32 | # torch==2.3.1
33 | # torchvision==0.18.1
34 | # torchaudio==2.3.1
35 |
36 |
37 | einx==0.3.0
--------------------------------------------------------------------------------
/bitvae/utils/logger.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/FoundationVision/LlamaGen/blob/main/utils/logger.py
2 | import logging
3 | import glob
4 | import os
5 | import torch.distributed as dist
6 |
7 | def create_logger(logging_dir):
8 | """
9 | Create a logger that writes to a log file and stdout.
10 | """
11 | if dist.get_rank() == 0: # real logger
12 | existing_logs = glob.glob(os.path.join(logging_dir, 'log_*.txt'))
13 | log_numbers = [int(log.split('.txt')[0].split('_')[-1]) for log in existing_logs]
14 | next_log_number = max(log_numbers) + 1 if log_numbers else 1
15 | logging.basicConfig(
16 | level=logging.INFO,
17 | format='%(asctime)s %(message)s',
18 | datefmt='%Y-%m-%d %H:%M:%S',
19 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log_{next_log_number}.txt")]
20 | )
21 | logger = logging.getLogger(__name__)
22 | else: # dummy logger (does nothing)
23 | logger = logging.getLogger(__name__)
24 | logger.addHandler(logging.NullHandler())
25 | return logger
--------------------------------------------------------------------------------
/bitvae/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def hinge_d_loss(logits_real, logits_fake):
6 | loss_real = torch.mean(F.relu(1. - logits_real))
7 | loss_fake = torch.mean(F.relu(1. + logits_fake))
8 | d_loss = 0.5 * (loss_real + loss_fake)
9 | return d_loss
10 |
11 | def vanilla_d_loss(logits_real, logits_fake):
12 | d_loss = 0.5 * (
13 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
14 | torch.mean(torch.nn.functional.softplus(logits_fake)))
15 | return d_loss
16 |
17 | def get_disc_loss(disc_loss_type):
18 | if disc_loss_type == 'vanilla':
19 | disc_loss = vanilla_d_loss
20 | elif disc_loss_type == 'hinge':
21 | disc_loss = hinge_d_loss
22 | return disc_loss
23 |
24 | def adopt_weight(global_step, threshold=0, value=0., warmup=0):
25 | if global_step < threshold or threshold < 0:
26 | weight = value
27 | else:
28 | weight = 1
29 | if global_step - threshold < warmup:
30 | weight = min((global_step - threshold) / warmup, 1)
31 | return weight
--------------------------------------------------------------------------------
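
A minimal usage sketch (not part of the repository) for `adopt_weight`: the GAN weight stays at `value` before the `threshold` step (e.g. `--discriminator_iter_start 50000` in the stage-1 scripts) and then ramps linearly to 1 over `warmup` steps. The step and warmup values below are illustrative only.

```python
# Illustrative only: how adopt_weight gates and warms up the GAN loss weight.
from bitvae.modules.loss import adopt_weight

for step in (0, 49_999, 50_000, 52_500, 55_000):
    w = adopt_weight(step, threshold=50_000, value=0.0, warmup=5_000)
    print(step, w)
# Prints 0.0 for steps up to and including the threshold, 0.5 halfway through
# the warmup window, and 1 once the warmup window has passed.
```
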
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 FoundationVision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/scripts/release/test_img_d16_stage1.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \
2 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
3 | --codebook_dim 16 --codebook_size 65536 \
4 | --vqgan_ckpt "bitvae_results/infinity_d16_stage1/checkpoints/model_step_499999.ckpt" \
5 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_256_dim16_stage1 --dataaug "resizecrop" --resolution 256 256 --num_workers 0 \
6 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \
7 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \
8 | $@
9 |
10 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \
11 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
12 | --codebook_dim 16 --codebook_size 65536 \
13 | --vqgan_ckpt "bitvae_results/infinity_d16_stage1/checkpoints/model_step_499999.ckpt" \
14 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_512_dim16_stage1 --dataaug "resizecrop" --resolution 512 512 --num_workers 0 \
15 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \
16 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \
17 | $@
18 |
--------------------------------------------------------------------------------
/scripts/release/test_img_d16_stage2.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \
2 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
3 | --codebook_dim 16 --codebook_size 65536 \
4 | --vqgan_ckpt "bitvae_results/infinity_d16_stage2/checkpoints/model_step_199999.ckpt" \
5 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_256_dim16_stage2 --dataaug "resizecrop" --resolution 256 256 --num_workers 0 \
6 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \
7 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \
8 | $@
9 |
10 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \
11 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
12 | --codebook_dim 16 --codebook_size 65536 \
13 | --vqgan_ckpt "bitvae_results/infinity_d16_stage2/checkpoints/model_step_199999.ckpt" \
14 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_512_dim16_stage2 --dataaug "resizecrop" --resolution 512 512 --num_workers 0 \
15 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \
16 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \
17 | $@
18 |
--------------------------------------------------------------------------------
/scripts/release/test_img_d32_stage1.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \
2 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
3 | --codebook_dim 32 --codebook_size 4294967296 \
4 | --vqgan_ckpt "bitvae_results/infinity_d32_stage1/checkpoints/model_step_499999.ckpt" \
5 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_256_dim32_stage1 --dataaug "resizecrop" --resolution 256 256 --num_workers 0 \
6 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \
7 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \
8 | $@
9 |
10 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \
11 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
12 | --codebook_dim 32 --codebook_size 4294967296 \
13 | --vqgan_ckpt "bitvae_results/infinity_d32_stage1/checkpoints/model_step_499999.ckpt" \
14 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_512_dim32_stage1 --dataaug "resizecrop" --resolution 512 512 --num_workers 0 \
15 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \
16 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \
17 | $@
--------------------------------------------------------------------------------
/scripts/release/test_img_d32_stage2.sh:
--------------------------------------------------------------------------------
1 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \
2 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
3 | --codebook_dim 32 --codebook_size 4294967296 \
4 | --vqgan_ckpt "bitvae_results/infinity_d32_stage2/checkpoints/model_step_199999.ckpt" \
5 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_256_dim32_stage2 --dataaug "resizecrop" --resolution 256 256 --num_workers 0 \
6 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \
7 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \
8 | $@
9 |
10 | python3 eval.py --tokenizer 'flux' --inference_type 'image' --patch_size 16 \
11 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
12 | --codebook_dim 32 --codebook_size 4294967296 \
13 | --vqgan_ckpt "bitvae_results/infinity_d32_stage2/checkpoints/model_step_199999.ckpt" \
14 | --batch_size 1 --dataset_list "imagenet" --save ./imagenet_512_dim32_stage2 --dataaug "resizecrop" --resolution 512 512 --num_workers 0 \
15 | --default_root_dir "test" --save_prediction --quantizer_type 'MultiScaleBSQ' \
16 | --new_quant --schedule_mode "dynamic" --remove_residual_detach --disable_codebook_usage \
17 | $@
--------------------------------------------------------------------------------
/scripts/release/train_img_d32_stage1.sh:
--------------------------------------------------------------------------------
1 | WORKER_GPU=8
2 | NODE_NUM=4
3 | NUM_WORKERS=12
4 |
5 | if [[ "$*" == *"--debug"* ]]; then
6 | WORKER_GPU=1
7 | NODE_NUM=1
8 | NUM_WORKERS=0
9 | fi
10 |
11 | torchrun \
12 | --nproc_per_node=$WORKER_GPU \
13 | --nnodes=$NODE_NUM --master_addr=$WORKER_0_HOST \
14 | --node_rank=$NODE_ID --master_port=$PORT \
15 | train.py --num_workers $NUM_WORKERS \
16 | --patch_size 16 \
17 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
18 | --codebook_dim 32 \
19 | --optim_type AdamW --lr 1e-4 --disable_sch --dis_lr_multiplier 1 --max_steps 500000 \
20 | --resolution 256 256 --batch_size 8 --dataset_list "openimages" --dataaug "resizecrop" \
21 | --disc_layers 3 --discriminator_iter_start 50000 \
22 | --l1_weight 1 --perceptual_weight 1 --image_disc_weight 1 --image_gan_weight 0.3 --gan_feat_weight 0 --lfq_weight 4 \
23 | --codebook_size 4294967296 --entropy_loss_weight 0.1 --diversity_gamma 1 \
24 | --default_root_dir "bitvae_results/infinity_d32_stage1" --log_every 20 --ckpt_every 10000 --visu_every 10000 \
25 | --new_quant --lr_drop 450000 \
26 | --remove_residual_detach --use_lecam_reg_zero --base_ch_disc 128 --dis_lr_multiplier 2.0 \
27 | --schedule_mode "dense" --use_stochastic_depth --drop_rate 0.5 --keep_last_quant --tokenizer 'flux' --quantizer_type 'MultiScaleBSQ' $@
--------------------------------------------------------------------------------
/bitvae/modules/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 |
6 |
7 | class Normalize(nn.Module):
8 | def __init__(self, in_channels, norm_type):
9 | super().__init__()
10 | assert norm_type in ['group', 'batch', "no"]
11 | if norm_type == 'group':
12 | if in_channels % 32 == 0:
13 | self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
14 | elif in_channels % 24 == 0:
15 | self.norm = nn.GroupNorm(num_groups=24, num_channels=in_channels, eps=1e-6, affine=True)
16 | else:
17 | raise NotImplementedError
18 | elif norm_type == 'batch':
19 | self.norm = nn.SyncBatchNorm(in_channels, track_running_stats=False) # setting track_running_stats=True causes an in-place gradient RuntimeError
20 | elif norm_type == 'no':
21 | self.norm = nn.Identity()
22 |
23 | def forward(self, x):
24 | assert x.ndim == 4
25 | x = self.norm(x)
26 | return x
27 |
28 | def l2norm(t):
29 | return F.normalize(t, dim=-1)
30 |
31 | class LayerNorm(nn.Module):
32 | def __init__(self, dim):
33 | super().__init__()
34 | self.gamma = nn.Parameter(torch.ones(dim))
35 | self.register_buffer("beta", torch.zeros(dim))
36 |
37 | def forward(self, x):
38 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
39 |
--------------------------------------------------------------------------------
/scripts/release/train_img_d16_stage2.sh:
--------------------------------------------------------------------------------
1 | WORKER_GPU=8
2 | NODE_NUM=4
3 | NUM_WORKERS=12
4 |
5 | if [[ "$*" == *"--debug"* ]]; then
6 | WORKER_GPU=1
7 | NODE_NUM=1
8 | NUM_WORKERS=0
9 | fi
10 |
11 | torchrun \
12 | --nproc_per_node=$WORKER_GPU \
13 | --nnodes=$NODE_NUM --master_addr=$WORKER_0_HOST \
14 | --node_rank=$NODE_ID --master_port=$PORT \
15 | train.py --num_workers $NUM_WORKERS \
16 | --patch_size 16 \
17 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
18 | --codebook_dim 16 \
19 | --optim_type AdamW --lr 1e-4 --disable_sch --dis_lr_multiplier 1 --max_steps 200000 \
20 | --resolution 1024 1024 --batch_size 8 --dataset_list "openimages" --dataaug "resizecrop" \
21 | --disc_layers 3 --discriminator_iter_start 0 \
22 | --l1_weight 1 --perceptual_weight 1 --image_disc_weight 1 --image_gan_weight 0.3 --gan_feat_weight 0 --lfq_weight 4 \
23 | --codebook_size 65536 --entropy_loss_weight 0.1 --diversity_gamma 1 \
24 | --default_root_dir "bitvae_results/infinity_d16_stage2" --log_every 20 --ckpt_every 5000 --visu_every 5000 \
25 | --new_quant --lr_drop 150000 \
26 | --remove_residual_detach --use_lecam_reg_zero --base_ch_disc 128 --dis_lr_multiplier 2.0 --use_checkpoint \
27 | --schedule_mode "dense" --use_stochastic_depth --drop_rate 0.5 --keep_last_quant --tokenizer 'flux' --quantizer_type 'MultiScaleBSQ' \
28 | --pretrained "bitvae_results/infinity_d16_stage1/checkpoints/model_step_499999.ckpt" --not_load_optimizer --multiscale_training $@
--------------------------------------------------------------------------------
/scripts/release/train_img_d32_stage2.sh:
--------------------------------------------------------------------------------
1 | WORKER_GPU=8
2 | NODE_NUM=4
3 | NUM_WORKERS=12
4 |
5 | if [[ "$*" == *"--debug"* ]]; then
6 | WORKER_GPU=1
7 | NODE_NUM=1
8 | NUM_WORKERS=0
9 | fi
10 |
11 | torchrun \
12 | --nproc_per_node=$WORKER_GPU \
13 | --nnodes=$NODE_NUM --master_addr=$WORKER_0_HOST \
14 | --node_rank=$NODE_ID --master_port=$PORT \
15 | train.py --num_workers $NUM_WORKERS \
16 | --patch_size 16 \
17 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
18 | --codebook_dim 32 \
19 | --optim_type AdamW --lr 1e-4 --disable_sch --dis_lr_multiplier 1 --max_steps 200000 \
20 | --resolution 1024 1024 --batch_size 8 --dataset_list "openimages" --dataaug "resizecrop" \
21 | --disc_layers 3 --discriminator_iter_start 0 \
22 | --l1_weight 1 --perceptual_weight 1 --image_disc_weight 1 --image_gan_weight 0.3 --gan_feat_weight 0 --lfq_weight 4 \
23 | --codebook_size 4294967296 --entropy_loss_weight 0.1 --diversity_gamma 1 \
24 | --default_root_dir "bitvae_results/infinity_d32_stage2" --log_every 20 --ckpt_every 5000 --visu_every 5000 \
25 | --new_quant --lr_drop 150000 \
26 | --remove_residual_detach --use_lecam_reg_zero --base_ch_disc 128 --dis_lr_multiplier 2.0 --use_checkpoint \
27 | --schedule_mode "dense" --use_stochastic_depth --drop_rate 0.5 --keep_last_quant --tokenizer 'flux' --quantizer_type 'MultiScaleBSQ' \
28 | --pretrained "bitvae_results/infinity_d32_stage1/checkpoints/model_step_499999.ckpt" --not_load_optimizer --multiscale_training $@
--------------------------------------------------------------------------------
/bitvae/utils/init_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from bitvae.utils.misc import is_torch_optimizer
6 |
7 | def load_unstrictly(state_dict, model, loaded_keys=[]):
8 | missing_keys = []
9 | for name, param in model.named_parameters():
10 | if name in state_dict:
11 | try:
12 | param.data.copy_(state_dict[name])
13 | except:
14 | # print(f"{name} mismatch: param {name}, shape {param.data.shape}, state_dict shape {state_dict[name].shape}")
15 | missing_keys.append(name)
16 | elif name not in loaded_keys:
17 | missing_keys.append(name)
18 | return model, missing_keys
19 |
20 | def resume_from_ckpt(state_dict, model_optims, load_optimizer=True):
21 | all_missing_keys = []
22 | # load weights first
23 | for k in model_optims:
24 | if model_optims[k] and (not is_torch_optimizer(model_optims[k])) and k in state_dict:
25 | model_optims[k], missing_keys = load_unstrictly(state_dict[k], model_optims[k])
26 | all_missing_keys += missing_keys
27 |
28 | if len(all_missing_keys) == 0 and load_optimizer:
29 | print("Loading optimizer states")
30 | for k in model_optims:
31 | if model_optims[k] and is_torch_optimizer(model_optims[k]) and k in state_dict:
32 | model_optims[k].load_state_dict(state_dict[k])
33 | else:
34 | print(f"missing weights: {all_missing_keys}, do not load optimzer states")
35 | return model_optims, state_dict["step"]
36 |
--------------------------------------------------------------------------------
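
A self-contained sketch of how `resume_from_ckpt` consumes a checkpoint dict. The toy module and the key names `"model"`/`"opt"` are my own illustration; only the `"step"` entry is something the code above actually requires.

```python
# Hypothetical usage of resume_from_ckpt; key names "model"/"opt" are illustrative.
import torch
import torch.nn as nn
from bitvae.utils.init_models import resume_from_ckpt

model = nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

# One state_dict per entry in model_optims, plus the global step counter.
ckpt = {"model": model.state_dict(), "opt": opt.state_dict(), "step": 1000}

model_optims = {"model": model, "opt": opt}
model_optims, step = resume_from_ckpt(ckpt, model_optims, load_optimizer=True)
print(step)  # 1000; optimizer state is loaded because no weights were missing
```
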
/scripts/release/train_img_d16_stage1.sh:
--------------------------------------------------------------------------------
1 | if [[ "$ARNOLD_DEVICE_TYPE" == *A100* ]]; then
2 | IB_HCA=mlx5
3 | export NCCL_IB_HCA=$IB_HCA
4 | else
5 | IB_HCA=$ARNOLD_RDMA_DEVICE:1
6 | fi
7 |
8 | if [[ "$RUNTIME_IDC_NAME" == *uswest2* ]]; then
9 | IDC_NAME=bond0
10 | export NCCL_SOCKET_IFNAME=$IDC_NAME
11 | else
12 | IDC_NAME=eth0
13 | fi
14 |
15 | port=$(echo "$ARNOLD_WORKER_0_PORT" | cut -d "," -f 1)
16 | echo $port
17 |
18 | export NCCL_DEBUG=WARN
19 | export NCCL_IB_DISABLE=0
20 | export NCCL_IB_GID_INDEX=3
21 |
22 | NUM_WORKERS=12
23 |
24 | if [[ "$*" == *"--debug"* ]]; then
25 | ARNOLD_WORKER_NUM=1
26 | ARNOLD_WORKER_GPU=1
27 | NUM_WORKERS=0
28 | fi
29 |
30 | torchrun \
31 | --nproc_per_node=$ARNOLD_WORKER_GPU \
32 | --nnodes=$ARNOLD_WORKER_NUM --master_addr=$ARNOLD_WORKER_0_HOST \
33 | --node-rank=$ARNOLD_ID --master_port=$port \
34 | train.py --num_workers $NUM_WORKERS \
35 | --patch_size 16 \
36 | --base_ch 128 --encoder_ch_mult 1 2 4 4 4 --decoder_ch_mult 1 2 4 4 4 \
37 | --codebook_dim 16 \
38 | --optim_type AdamW --lr 1e-4 --disable_sch --dis_lr_multiplier 1 --max_steps 500000 \
39 | --resolution 256 256 --batch_size 8 --dataset_list "imagenet" --dataaug "resizecrop" \
40 | --disc_layers 3 --discriminator_iter_start 50000 \
41 | --l1_weight 1 --perceptual_weight 1 --image_disc_weight 1 --image_gan_weight 0.3 --gan_feat_weight 0 --lfq_weight 4 \
42 | --codebook_size 65536 --entropy_loss_weight 0.1 --diversity_gamma 1 \
43 | --default_root_dir "bitvae_results/infinity_d16_stage1" --log_every 20 --ckpt_every 10000 --visu_every 10000 \
44 | --new_quant --lr_drop 450000 \
45 | --remove_residual_detach --use_lecam_reg_zero --base_ch_disc 128 --dis_lr_multiplier 2.0 \
46 | --schedule_mode "dense" --use_stochastic_depth --drop_rate 0.5 --keep_last_quant --tokenizer 'flux' --quantizer_type 'MultiScaleBSQ' $@
--------------------------------------------------------------------------------
/bitvae/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Foundationvision, Inc. All Rights Reserved
2 |
3 | import torch
4 | import torch.distributed as dist
5 | import imageio
6 | import os
7 | import random
8 |
9 | import math
10 | import numpy as np
11 | import skvideo.io
12 | from einops import rearrange
13 | import torch.optim as optim
14 |
15 | from contextlib import contextmanager
16 |
17 | ptdtype = {None: torch.float32, 'fp32': torch.float32, 'bf16': torch.bfloat16}
18 |
19 | def rank_zero_only(fn):
20 | def wrapped_fn(*args, **kwargs):
21 | if not dist.is_initialized() or dist.get_rank() == 0:
22 | return fn(*args, **kwargs)
23 | return wrapped_fn
24 |
25 | def is_torch_optimizer(obj):
26 | return isinstance(obj, optim.Optimizer)
27 |
28 | def rearranged_forward(x, func):
29 | x = rearrange(x, "B C H W -> B H W C")
30 | x = func(x)
31 | x = rearrange(x, "B H W C -> B C H W")
32 | return x
33 |
34 | def is_dtype_16(data):
35 | return data.dtype == torch.float16 or data.dtype == torch.bfloat16
36 |
37 | @contextmanager
38 | def set_tf32_flags(flag):
39 | old_matmul_flag = torch.backends.cuda.matmul.allow_tf32
40 | old_cudnn_flag = torch.backends.cudnn.allow_tf32
41 | torch.backends.cuda.matmul.allow_tf32 = flag
42 | torch.backends.cudnn.allow_tf32 = flag
43 | try:
44 | yield
45 | finally:
46 | # Restore the original flags
47 | torch.backends.cuda.matmul.allow_tf32 = old_matmul_flag
48 | torch.backends.cudnn.allow_tf32 = old_cudnn_flag
49 |
50 | def get_last_ckpt(root_dir):
51 | if not os.path.exists(root_dir): return None
52 | ckpt_files = {}
53 | for dirpath, dirnames, filenames in os.walk(root_dir):
54 | for filename in filenames:
55 | if filename.endswith('.ckpt'):
56 | num_iter = int(filename.split('.ckpt')[0].split('_')[-1])
57 | ckpt_files[num_iter]=os.path.join(dirpath, filename)
58 | iter_list = list(ckpt_files.keys())
59 | if len(iter_list) == 0: return None
60 | max_iter = max(iter_list)
61 | return ckpt_files[max_iter]
62 |
63 |
--------------------------------------------------------------------------------
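
A minimal sketch of the `set_tf32_flags` context manager defined above, which `lpips.py` uses to temporarily control TF32 matmuls and restore the previous backend flags afterwards:

```python
import torch
from bitvae.utils.misc import set_tf32_flags

print(torch.backends.cuda.matmul.allow_tf32)      # whatever the process default is
with set_tf32_flags(False):                       # force full-precision fp32 matmuls
    print(torch.backends.cuda.matmul.allow_tf32)  # False inside the context
print(torch.backends.cuda.matmul.allow_tf32)      # original value restored on exit
```
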
/bitvae/modules/quantizer/dynamic_resolution.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import tqdm
4 |
5 | vae_stride = 16
6 | ratio2hws = {
7 | 1.000: [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16),(20,20),(24,24),(32,32),(40,40),(48,48),(64,64),(80,80),(96,96),(128,128)],
8 | 1.250: [(1,1),(2,2),(3,3),(5,4),(10,8),(15,12),(20,16),(25,20),(30,24),(35,28),(45,36),(55,44),(70,56),(90,72),(110,88),(140,112)],
9 | 1.333: [(1,1),(2,2),(4,3),(8,6),(12,9),(16,12),(20,15),(24,18),(28,21),(36,27),(48,36),(60,45),(72,54),(96,72),(120,90),(144,108)],
10 | 1.500: [(1,1),(2,2),(3,2),(6,4),(9,6),(15,10),(21,14),(27,18),(33,22),(39,26),(48,32),(63,42),(78,52),(96,64),(126,84),(156,104)],
11 | 1.750: [(1,1),(2,2),(3,3),(7,4),(11,6),(14,8),(21,12),(28,16),(35,20),(42,24),(56,32),(70,40),(84,48),(112,64),(140,80),(168,96)],
12 | 2.000: [(1,1),(2,2),(4,2),(6,3),(10,5),(16,8),(22,11),(30,15),(38,19),(46,23),(60,30),(74,37),(90,45),(120,60),(148,74),(180,90)],
13 | 2.500: [(1,1),(2,2),(5,2),(10,4),(15,6),(20,8),(25,10),(30,12),(40,16),(50,20),(65,26),(80,32),(100,40),(130,52),(160,64),(200,80)],
14 | 3.000: [(1,1),(2,2),(6,2),(9,3),(15,5),(21,7),(27,9),(36,12),(45,15),(54,18),(72,24),(90,30),(111,37),(144,48),(180,60),(222,74)],
15 | }
16 | full_ratio2hws = {}
17 | for ratio, hws in ratio2hws.items():
18 | full_ratio2hws[ratio] = hws
19 | full_ratio2hws[int(1/ratio*1000)/1000] = [(item[1], item[0]) for item in hws]
20 |
21 | dynamic_resolution_h_w = {}
22 | predefined_HW_Scales_dynamic = {}
23 | aspect_ratio_scale_list = []
24 | bs_dict = {7: 8, 10: 4, 13: 1, 16: 1} # 256x256: batch=8, 512x512: batch=4, 1024x1024: batch=1 (bs=1 avoid OOM)
25 | for ratio in full_ratio2hws:
26 | dynamic_resolution_h_w[ratio] ={}
27 | for ind, leng in enumerate([7, 10, 13, 16]):
28 | h, w = full_ratio2hws[ratio][leng-1][0], full_ratio2hws[ratio][leng-1][1] # feature map size
29 | pixel = (h * vae_stride, w * vae_stride) # The original image (H, W)
30 | dynamic_resolution_h_w[ratio][pixel[1]] = {
31 | 'pixel': pixel,
32 | 'scales': full_ratio2hws[ratio][:leng]
33 | } # W as key
34 | predefined_HW_Scales_dynamic[(h, w)] = full_ratio2hws[ratio][:leng]
35 | # deal with aspect_ratio_scale_list
36 | info_dict = {"ratio": ratio, "h": h * vae_stride, "w": w * vae_stride, "bs": bs_dict[leng]}
37 | aspect_ratio_scale_list.append(info_dict)
38 |
--------------------------------------------------------------------------------
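
The tables built above can be queried as sketched below (the specific aspect ratio and width are just examples): `dynamic_resolution_h_w` is keyed first by aspect ratio and then by image width, while `predefined_HW_Scales_dynamic` is keyed by the final feature-map size.

```python
from bitvae.modules.quantizer.dynamic_resolution import (
    dynamic_resolution_h_w,
    predefined_HW_Scales_dynamic,
)

entry = dynamic_resolution_h_w[1.000][256]
print(entry["pixel"])   # (256, 256): image resolution for a 16x16 feature map
print(entry["scales"])  # [(1, 1), (2, 2), (4, 4), (6, 6), (8, 8), (12, 12), (16, 16)]

# The same multi-scale schedule, keyed by the final (h, w) of the feature map:
print(predefined_HW_Scales_dynamic[(16, 16)])
```
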
/bitvae/evaluation/fid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import linalg
3 |
4 |
5 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
6 | """Numpy implementation of the Frechet Distance.
7 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
8 | and X_2 ~ N(mu_2, C_2) is
9 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
10 |
11 | Stable version by Dougal J. Sutherland.
12 |
13 | Params:
14 | -- mu1 : Numpy array containing the activations of a layer of the
15 | inception net (like returned by the function 'get_predictions')
16 | for generated samples.
17 | -- mu2 : The sample mean over activations, precalculated on a
18 | representative data set.
19 | -- sigma1: The covariance matrix over activations for generated samples.
20 | -- sigma2: The covariance matrix over activations, precalculated on a
21 | representative data set.
22 |
23 | Returns:
24 | -- : The Frechet Distance.
25 | """
26 |
27 | mu1 = np.atleast_1d(mu1)
28 | mu2 = np.atleast_1d(mu2)
29 |
30 | sigma1 = np.atleast_2d(sigma1)
31 | sigma2 = np.atleast_2d(sigma2)
32 |
33 | assert (
34 | mu1.shape == mu2.shape
35 | ), "Training and test mean vectors have different lengths"
36 | assert (
37 | sigma1.shape == sigma2.shape
38 | ), "Training and test covariances have different dimensions"
39 |
40 | diff = mu1 - mu2
41 |
42 | # Product might be almost singular
43 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
44 | if not np.isfinite(covmean).all():
45 | msg = (
46 | "fid calculation produces singular product; "
47 | "adding %s to diagonal of cov estimates"
48 | ) % eps
49 | print(msg)
50 | offset = np.eye(sigma1.shape[0]) * eps
51 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
52 |
53 | # Numerical error might give slight imaginary component
54 | if np.iscomplexobj(covmean):
55 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
56 | m = np.max(np.abs(covmean.imag))
57 | raise ValueError("Imaginary component {}".format(m))
58 | covmean = covmean.real
59 |
60 | tr_covmean = np.trace(covmean)
61 |
62 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
--------------------------------------------------------------------------------
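
A toy end-to-end sketch of how `calculate_frechet_distance` is typically fed: random 64-dimensional features stand in for real Inception activations here, purely for illustration.

```python
import numpy as np
from bitvae.evaluation.fid import calculate_frechet_distance

rng = np.random.default_rng(0)
fake_feats = rng.normal(size=(1_000, 64))   # activations of reconstructions
real_feats = rng.normal(size=(1_000, 64))   # activations of reference images

mu1, sigma1 = fake_feats.mean(axis=0), np.cov(fake_feats, rowvar=False)
mu2, sigma2 = real_feats.mean(axis=0), np.cov(real_feats, rowvar=False)
print(calculate_frechet_distance(mu1, sigma1, mu2, sigma2))
```
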
/bitvae/utils/distributed.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/FoundationVision/LlamaGen/blob/main/utils/distributed.py
2 | import os
3 | import torch
4 | import subprocess
5 | import torch.distributed as dist
6 | from bitvae.utils.misc import rank_zero_only
7 |
8 | def setup_for_distributed(is_master):
9 | """
10 | This function disables printing when not in master process
11 | """
12 | import builtins as __builtin__
13 | builtin_print = __builtin__.print
14 |
15 | def print(*args, **kwargs):
16 | force = kwargs.pop('force', False)
17 | if is_master or force:
18 | builtin_print(*args, **kwargs)
19 |
20 | __builtin__.print = print
21 |
22 | def init_distributed_mode(args):
23 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
24 | args.rank = int(os.environ["RANK"])
25 | args.world_size = int(os.environ['WORLD_SIZE'])
26 | args.gpu = int(os.environ['LOCAL_RANK'])
27 | args.dist_url = 'env://'
28 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
29 | elif 'SLURM_PROCID' in os.environ:
30 | proc_id = int(os.environ['SLURM_PROCID'])
31 | ntasks = int(os.environ['SLURM_NTASKS'])
32 | node_list = os.environ['SLURM_NODELIST']
33 | num_gpus = torch.cuda.device_count()
34 | addr = subprocess.getoutput(
35 | 'scontrol show hostname {} | head -n1'.format(node_list))
36 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
37 | os.environ['MASTER_ADDR'] = addr
38 | os.environ['WORLD_SIZE'] = str(ntasks)
39 | os.environ['RANK'] = str(proc_id)
40 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
41 | os.environ['LOCAL_SIZE'] = str(num_gpus)
42 | args.dist_url = 'env://'
43 | args.world_size = ntasks
44 | args.rank = proc_id
45 | args.gpu = proc_id % num_gpus
46 | else:
47 | print('Not using distributed mode')
48 | args.distributed = False
49 | return
50 |
51 | args.distributed = True
52 |
53 | torch.cuda.set_device(args.gpu)
54 | args.dist_backend = 'nccl'
55 | print('| distributed init (rank {}): {}'.format(
56 | args.rank, args.dist_url), flush=True)
57 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
58 | world_size=args.world_size, rank=args.rank)
59 | torch.distributed.barrier()
60 | setup_for_distributed(args.rank == 0)
61 |
62 |
63 | def reduce_losses(loss_dict, dst=0):
64 | loss_names = list(loss_dict.keys())
65 | loss_tensor = torch.stack([loss_dict[name] for name in loss_names])
66 |
67 | dist.reduce(loss_tensor, dst=dst, op=dist.ReduceOp.SUM)
68 | # Only average the loss values on the destination rank
69 | if dist.get_rank() == dst:
70 | loss_tensor /= dist.get_world_size()
71 | averaged_losses = {name: loss_tensor[i].item() for i, name in enumerate(loss_names)}
72 | else:
73 | averaged_losses = {name: None for name in loss_names}
74 |
75 | return averaged_losses
76 |
77 | @rank_zero_only
78 | def average_losses(loss_dict_list):
79 | sum_dict = {}
80 | count_dict = {}
81 | for loss_dict in loss_dict_list:
82 | for key, value in loss_dict.items():
83 | if key in sum_dict:
84 | sum_dict[key] += value
85 | count_dict[key] += 1
86 | else:
87 | sum_dict[key] = value
88 | count_dict[key] = 1
89 |
90 | avg_dict = {key: sum_dict[key] / count_dict[key] for key in sum_dict}
91 | return avg_dict
--------------------------------------------------------------------------------
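
A minimal sketch of `average_losses`: it is wrapped in `rank_zero_only`, so without an initialized process group it simply runs on the calling process and averages the per-step loss dictionaries. The loss names and values below are illustrative.

```python
from bitvae.utils.distributed import average_losses

history = [
    {"recon": 0.5, "lpips": 0.25},
    {"recon": 1.0, "lpips": 0.75},
]
print(average_losses(history))  # {'recon': 0.75, 'lpips': 0.5}
```
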
/bitvae/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import torch.nn as nn
4 |
5 | from bitvae.modules.normalization import Normalize
6 | class DiscriminatorPool:
7 | def __init__(self, pool_size):
8 | self.pool_size = int(pool_size)
9 | self.num_imgs = 0
10 | self.images = []
11 |
12 | def query(self, images):
13 | if self.pool_size == 0:
14 | return images
15 |
16 | return_images = []
17 | for image in images:
18 | if self.num_imgs < self.pool_size:
19 | self.images.append(image)
20 | self.num_imgs += 1
21 | return_images.append(image)
22 | else:
23 | if random.uniform(0, 1) > 0.5:
24 | i = random.randint(0, self.pool_size - 1)
25 | tmp = self.images[i].clone()
26 | self.images[i] = image
27 | return_images.append(tmp)
28 | else:
29 | return_images.append(image)
30 | return torch.stack(return_images)
31 |
32 | class ImageDiscriminator(nn.Module):
33 | def __init__(self, args):
34 | super().__init__()
35 | self.discriminator = NLayerDiscriminator(ndf=args.base_ch_disc) # by default using PatchGAN
36 | self.disc_pool = args.disc_pool # by default "no"
37 | if args.disc_pool == "yes":
38 | self.real_pool = DiscriminatorPool(pool_size=args.batch_size[0] * args.disc_pool_size)
39 | self.fake_pool = DiscriminatorPool(pool_size=args.batch_size[0] * args.disc_pool_size)
40 |
41 | def forward(self, x, pool_name=None):
42 | if pool_name and self.disc_pool == "yes":
43 | assert pool_name in ["real", "fake"]
44 | if pool_name == "real":
45 | x = self.real_pool.query(x)
46 | elif pool_name == "fake":
47 | x = self.fake_pool.query(x)
48 | # by default without pool
49 | return self.discriminator(x)
50 |
51 | class NLayerDiscriminator(nn.Module):
52 | """Defines a PatchGAN discriminator as in Pix2Pix
53 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
54 | """
55 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
56 | """Construct a PatchGAN discriminator
57 | Parameters:
58 | input_nc (int) -- the number of channels in input images
59 | ndf (int) -- the number of filters in the last conv layer
60 | n_layers (int) -- the number of conv layers in the discriminator
61 | norm_layer -- normalization layer
62 | """
63 | super(NLayerDiscriminator, self).__init__()
64 | norm_type = "batch"
65 | use_bias = norm_type != "batch"
66 |
67 | kw = 4
68 | padw = 1
69 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
70 | nf_mult = 1
71 | nf_mult_prev = 1
72 | for n in range(1, n_layers): # gradually increase the number of filters
73 | nf_mult_prev = nf_mult
74 | nf_mult = min(2 ** n, 8)
75 | sequence += [
76 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
77 | Normalize(ndf * nf_mult, norm_type=norm_type),
78 | nn.LeakyReLU(0.2, True)
79 | ]
80 |
81 | nf_mult_prev = nf_mult
82 | nf_mult = min(2 ** n_layers, 8)
83 | sequence += [
84 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
85 | Normalize(ndf * nf_mult, norm_type=norm_type),
86 | nn.LeakyReLU(0.2, True)
87 | ]
88 |
89 | sequence += [
90 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
91 | self.main = nn.Sequential(*sequence)
92 |
93 | self.apply(self._init_weights)
94 |
95 | def _init_weights(self, module):
96 | if isinstance(module, nn.Conv2d):
97 | nn.init.normal_(module.weight.data, 0.0, 0.02)
98 | elif isinstance(module, nn.BatchNorm2d):
99 | nn.init.normal_(module.weight.data, 1.0, 0.02)
100 | nn.init.constant_(module.bias.data, 0)
101 |
102 | def forward(self, input):
103 | """Standard forward."""
104 | return self.main(input)
--------------------------------------------------------------------------------
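
A minimal shape-check sketch for the PatchGAN above (the input size is arbitrary; note that `Normalize` uses `SyncBatchNorm`, which falls back to ordinary batch norm when no distributed process group is initialized):

```python
import torch
from bitvae.models.discriminator import NLayerDiscriminator

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
logits = disc(torch.randn(2, 3, 256, 256))
print(logits.shape)  # torch.Size([2, 1, 30, 30]): a patch-wise prediction map, not a single logit
```
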
/README.md:
--------------------------------------------------------------------------------
1 | # Bitwise Visual Tokenizer
2 | Training and inference code for the bitwise visual tokenizer used by [Infinity](https://github.com/FoundationVision/Infinity).
3 |
4 | ### BitVAE Model Zoo
5 | We provide Infinity models for you to play with, which are hosted on Hugging Face and can be downloaded from the following links:
6 |
7 | ### Visual Tokenizer
8 |
9 | | vocabulary | stride | IN-256 rFID $\downarrow$ | IN-256 PSNR $\uparrow$ | IN-512 rFID $\downarrow$ | IN-512 PSNR $\uparrow$ | HF weights🤗 |
10 | |:----------:|:-----:|:--------:|:---------:|:-------:|:-------:|:------------------------------------------------------------------------------------|
11 | | $V_d=2^{16}$ | 16 | 1.22 | 20.9 | 0.31 | 22.6 | [infinity_vae_d16.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d16.pth) |
12 | | $V_d=2^{24}$ | 16 | 0.75 | 22.0 | 0.30 | 23.5 | [infinity_vae_d24.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d24.pth) |
13 | | $V_d=2^{32}$ | 16 | 0.61 | 22.7 | 0.23 | 24.4 | [infinity_vae_d32.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d32.pth) |
14 | | $V_d=2^{64}$ | 16 | 0.33 | 24.9 | 0.15 | 26.4 | [infinity_vae_d64.pth](https://huggingface.co/FoundationVision/infinity/blob/main/infinity_vae_d64.pth) |
15 | | $V_d=2^{32}$ | 16 | 0.75 | 21.9 | 0.32 | 23.6 | [infinity_vae_d32_reg.pth](https://huggingface.co/FoundationVision/Infinity/blob/main/infinity_vae_d32reg.pth) |
16 |
17 | ### Environment installation
18 | ```
19 | bash scripts/prepare.sh
20 | ```
21 |
22 | Download `checkpoints` and `labels` from [Google Drive](https://drive.google.com/drive/folders/15VCFUpcv1ktU7RR3Yw_LqI4vuR5Q2r0y?usp=sharing) and put them under the project folder. If you want to use our trained model weights, please also download `bitvae_results`.
23 | We expect the data to be organized as below.
24 | ```
25 | ${PROJECT_ROOT}
26 | -- bitvae
27 | -- bitvae_results
28 | -- Infinity_d16_stage1
29 | -- Infinity_d16_stage2
30 | -- Infinity_d32_stage1
31 | -- Infinity_d32_stage2
32 | -- checkpoints
33 | -- labels
34 | -- scripts
35 | -- test
36 | ...
37 | ```
38 |
39 |
40 | ### Training
41 | Before training, please generate `labels/openimages/train.txt` following the format of the provided `labels/imagenet/val_example.txt`, and replace the example paths with the real paths on your system.
42 |
43 | Tokenizer with hidden dimension 16
44 | ```
45 | bash scripts/release/train_img_d16_stage1.sh # stage 1: single-scale pre-training
46 | bash scripts/release/train_img_d16_stage2.sh # stage 2: multi-scale fine-tuning
47 | ```
48 | Tokenizer with hidden dimension 32
49 | ```
50 | bash scripts/release/train_img_d32_stage1.sh # stage 1: single-scale pre-training
51 | bash scripts/release/train_img_d32_stage2.sh # stage 2: multi-scale fine-tuning
52 | ```
53 |
54 | ### Testing & evaluation
55 | Before testing, please generate `labels/imagenet/val.txt` following the format of the provided `labels/imagenet/val_example.txt`, and replace the example paths with the real paths on your system.
56 |
57 | Tokenizer with hidden dimension 16
58 | ```
59 | bash scripts/release/test_img_d16_stage1.sh
60 | bash scripts/release/test_img_d16_stage2.sh
61 | ```
62 | Tokenizer with hidden dimension 32
63 | ```
64 | bash scripts/release/test_img_d32_stage1.sh
65 | bash scripts/release/test_img_d32_stage2.sh
66 | ```
67 | ### 📖 Citation
68 | If our work assists your research, feel free to give us a star ⭐ or cite us using:
69 |
70 | ```
71 | @misc{Infinity,
72 | title={Infinity: Scaling Bitwise AutoRegressive Modeling for High-Resolution Image Synthesis},
73 | author={Jian Han and Jinlai Liu and Yi Jiang and Bin Yan and Yuqi Zhang and Zehuan Yuan and Bingyue Peng and Xiaobing Liu},
74 | year={2024},
75 | eprint={2412.04431},
76 | archivePrefix={arXiv},
77 | primaryClass={cs.CV},
78 | url={https://arxiv.org/abs/2412.04431},
79 | }
80 | ```
81 |
82 | ```
83 | @misc{VAR,
84 | title={Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction},
85 | author={Keyu Tian and Yi Jiang and Zehuan Yuan and Bingyue Peng and Liwei Wang},
86 | year={2024},
87 | eprint={2404.02905},
88 | archivePrefix={arXiv},
89 | primaryClass={cs.CV},
90 | url={https://arxiv.org/abs/2404.02905},
91 | }
92 | ```
93 |
94 | ### License
95 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
96 |
--------------------------------------------------------------------------------
/bitvae/modules/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import os, hashlib
4 | import requests
5 | from tqdm import tqdm
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torchvision import models
10 | from collections import namedtuple
11 |
12 | from bitvae.utils.misc import set_tf32_flags
13 |
14 | URL_MAP = {
15 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
16 | }
17 |
18 | CKPT_MAP = {
19 | "vgg_lpips": "vgg.pth"
20 | }
21 |
22 | MD5_MAP = {
23 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
24 | }
25 |
26 | def download(url, local_path, chunk_size=1024):
27 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
28 | with requests.get(url, stream=True) as r:
29 | total_size = int(r.headers.get("content-length", 0))
30 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
31 | with open(local_path, "wb") as f:
32 | for data in r.iter_content(chunk_size=chunk_size):
33 | if data:
34 | f.write(data)
35 | pbar.update(chunk_size)
36 |
37 |
38 | def md5_hash(path):
39 | with open(path, "rb") as f:
40 | content = f.read()
41 | return hashlib.md5(content).hexdigest()
42 |
43 |
44 | def get_ckpt_path(name, root, check=False):
45 | assert name in URL_MAP
46 | path = os.path.join(root, CKPT_MAP[name])
47 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
48 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
49 | download(URL_MAP[name], path)
50 | md5 = md5_hash(path)
51 | assert md5 == MD5_MAP[name], md5
52 | return path
53 |
54 |
55 | class LPIPS(nn.Module):
56 | # Learned perceptual metric
57 | def __init__(self, use_dropout=True, upcast_tf32=False):
58 | super().__init__()
59 | self.upcast_tf32 = upcast_tf32
60 | self.scaling_layer = ScalingLayer()
61 | self.chns = [64, 128, 256, 512, 512] # vgg16 features
62 | self.net = vgg16(pretrained=True, requires_grad=False)
63 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
64 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
65 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
66 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
67 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
68 | self.load_from_pretrained()
69 | for param in self.parameters():
70 | param.requires_grad = False
71 |
72 | def load_from_pretrained(self, name="vgg_lpips"):
73 | ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
74 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu"), weights_only=True), strict=False)
75 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
76 |
77 | @classmethod
78 | def from_pretrained(cls, name="vgg_lpips"):
79 | if name is not "vgg_lpips":
80 | raise NotImplementedError
81 | model = cls()
82 | ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
83 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu"), weights_only=True), strict=False)
84 | return model
85 |
86 | def forward(self, input, target):
87 | with set_tf32_flags(not self.upcast_tf32):
88 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
89 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
90 | feats0, feats1, diffs = {}, {}, {}
91 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
92 | for kk in range(len(self.chns)):
93 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
94 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
95 |
96 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
97 | val = res[0]
98 | for l in range(1, len(self.chns)):
99 | # print(res[l].shape)
100 | val += res[l]
101 |
102 | return val
103 |
104 |
105 | class ScalingLayer(nn.Module):
106 | def __init__(self):
107 | super(ScalingLayer, self).__init__()
108 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
109 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
110 |
111 | def forward(self, inp):
112 | return (inp - self.shift) / self.scale
113 |
114 |
115 | class NetLinLayer(nn.Module):
116 | """ A single linear layer which does a 1x1 conv """
117 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
118 | super(NetLinLayer, self).__init__()
119 | layers = [nn.Dropout(), ] if (use_dropout) else []
120 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
121 | self.model = nn.Sequential(*layers)
122 |
123 |
124 | class vgg16(torch.nn.Module):
125 | def __init__(self, requires_grad=False, pretrained=True):
126 | super(vgg16, self).__init__()
127 | # load locally
128 | assert pretrained == True
129 | vgg_model = models.vgg16()
130 | vgg_model.load_state_dict(torch.load("checkpoints/vgg16-397923af.pth", weights_only=True))
131 | vgg_pretrained_features = vgg_model.features
132 |
133 | self.slice1 = torch.nn.Sequential()
134 | self.slice2 = torch.nn.Sequential()
135 | self.slice3 = torch.nn.Sequential()
136 | self.slice4 = torch.nn.Sequential()
137 | self.slice5 = torch.nn.Sequential()
138 | self.N_slices = 5
139 | for x in range(4):
140 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
141 | for x in range(4, 9):
142 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
143 | for x in range(9, 16):
144 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
145 | for x in range(16, 23):
146 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
147 | for x in range(23, 30):
148 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
149 | if not requires_grad:
150 | for param in self.parameters():
151 | param.requires_grad = False
152 |
153 | def forward(self, X):
154 | h = self.slice1(X)
155 | h_relu1_2 = h
156 | h = self.slice2(h)
157 | h_relu2_2 = h
158 | h = self.slice3(h)
159 | h_relu3_3 = h
160 | h = self.slice4(h)
161 | h_relu4_3 = h
162 | h = self.slice5(h)
163 | h_relu5_3 = h
164 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
165 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
166 | return out
167 |
168 |
169 | def normalize_tensor(x,eps=1e-10):
170 | # norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
171 | norm_factor = x.norm(p=2, dim=1, keepdim=True)
172 | return x/(norm_factor+eps)
173 |
174 |
175 | def spatial_average(x, keepdim=True):
176 | return x.mean([2,3],keepdim=keepdim)
177 |
--------------------------------------------------------------------------------
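
A minimal sketch of using the LPIPS module above, assuming the VGG16 weights from the README's `checkpoints` download (`checkpoints/vgg16-397923af.pth`) and the cached `bitvae/modules/cache/vgg.pth` are in place; inputs are expected in `[-1, 1]`.

```python
import torch
from bitvae.modules.lpips import LPIPS

lpips_fn = LPIPS().eval()
x = torch.rand(2, 3, 256, 256) * 2 - 1   # e.g. reconstructions, scaled to [-1, 1]
y = torch.rand(2, 3, 256, 256) * 2 - 1   # targets, scaled to [-1, 1]
with torch.no_grad():
    d = lpips_fn(x, y)                   # shape (2, 1, 1, 1): one distance per image
print(d.flatten())
```
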
/bitvae/utils/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from bitvae.models import d_vae
4 |
5 | def add_model_specific_args(args, parser):
6 | if args.tokenizer == "flux":
7 | parser = d_vae.add_model_specific_args(parser) # flux config
8 | d_vae_model = d_vae
9 | else:
10 | raise NotImplementedError
11 | return args, parser, d_vae_model
12 |
13 | class MainArgs:
14 | @staticmethod
15 | def add_main_args(parser):
16 | # training
17 | parser.add_argument('--max_steps', type=int, default=1e6)
18 | parser.add_argument('--log_every', type=int, default=1)
19 | parser.add_argument('--visu_every', type=int, default=1000)
20 | parser.add_argument('--ckpt_every', type=int, default=1000)
21 | parser.add_argument('--default_root_dir', type=str, required=True)
22 | parser.add_argument('--multiscale_training', action="store_true")
23 |
24 | # optimization
25 | parser.add_argument('--lr', type=float, default=1e-4)
26 | parser.add_argument('--beta1', type=float, default=0.9)
27 | parser.add_argument('--beta2', type=float, default=0.95)
28 | parser.add_argument('--warmup_steps', type=int, default=0)
29 | parser.add_argument('--optim_type', type=str, default="Adam", choices=["Adam", "AdamW"])
30 | parser.add_argument('--disc_optim_type', type=str, default=None, choices=[None, "rmsprop"])
31 | parser.add_argument('--lr_min', type=float, default=0.)
32 | parser.add_argument('--warmup_lr_init', type=float, default=0.)
33 | parser.add_argument('--max_grad_norm', type=float, default=1.0)
34 | parser.add_argument('--max_grad_norm_disc', type=float, default=1.0)
35 | parser.add_argument('--disable_sch', action="store_true")
36 |
37 | # basic d_vae config
38 | parser.add_argument('--patch_size', type=int, default=8)
39 | parser.add_argument('--codebook_dim', type=int, default=16)
40 | parser.add_argument('--quantizer_type', type=str, default='MultiScaleLFQ')
41 |
42 | parser.add_argument('--new_quant', action="store_true") # use new quantization (fix the potential bugs of the old quantizer)
43 | parser.add_argument('--use_decay_factor', action="store_true")
44 | parser.add_argument('--use_stochastic_depth', action="store_true")
45 | parser.add_argument("--drop_rate", type=float, default=0.0)
46 | parser.add_argument('--schedule_mode', type=str, default="original", choices=["original", "dynamic", "dense", "same1", "same2", "same3", "half", "dense_f8"])
47 | parser.add_argument('--lr_drop', nargs='*', type=int, default=None, help="A list of numeric values. Example: --lr_drop 270 300")
48 | parser.add_argument('--lr_drop_rate', type=float, default=0.1)
49 | parser.add_argument('--keep_first_quant', action="store_true")
50 | parser.add_argument('--keep_last_quant', action="store_true")
51 | parser.add_argument('--remove_residual_detach', action="store_true")
52 | parser.add_argument('--use_out_phi', action="store_true")
53 | parser.add_argument('--use_out_phi_res', action="store_true")
54 | parser.add_argument('--lecam_weight', type=float, default=0.05)
55 | parser.add_argument('--perceptual_model', type=str, default="vgg16", choices=["vgg16"])
56 | parser.add_argument('--base_ch_disc', type=int, default=64)
57 | parser.add_argument('--random_flip', action="store_true")
58 | parser.add_argument('--flip_prob', type=float, default=0.5)
59 | parser.add_argument('--flip_mode', type=str, default="stochastic", choices=["stochastic"])
60 | parser.add_argument('--max_flip_lvl', type=int, default=1)
61 | parser.add_argument('--not_load_optimizer', action="store_true")
62 | parser.add_argument('--use_lecam_reg_zero', action="store_true")
63 | parser.add_argument('--rm_downsample', action="store_true")
64 | parser.add_argument('--random_flip_1lvl', action="store_true")
65 | parser.add_argument('--flip_lvl_idx', type=int, default=0)
66 | parser.add_argument('--drop_when_test', action="store_true")
67 | parser.add_argument('--drop_lvl_idx', type=int, default=None)
68 | parser.add_argument('--drop_lvl_num', type=int, default=0)
69 | parser.add_argument('--compute_all_commitment', action="store_true")
70 | parser.add_argument('--disable_codebook_usage', action="store_true")
71 | parser.add_argument('--random_short_schedule', action="store_true")
72 | parser.add_argument('--short_schedule_prob', type=float, default=0.5)
73 | parser.add_argument('--disable_flip_prob', type=float, default=0.0)
74 | parser.add_argument('--zeta', type=float, default=1.0) # entropy penalty weight
75 | parser.add_argument('--disable_codebook_usage_bit', action="store_true")
76 | parser.add_argument('--gamma', type=float, default=1.0) # loss weight of H(E[p(c|u)])
77 | parser.add_argument('--uniform_short_schedule', action="store_true")
78 |
79 | # discriminator config
80 | parser.add_argument('--dis_warmup_steps', type=int, default=0)
81 | parser.add_argument('--dis_lr_multiplier', type=float, default=1.)
82 | parser.add_argument('--dis_minlr_multiplier', action="store_true")
83 | parser.add_argument('--disc_layers', type=int, default=3)
84 | parser.add_argument('--discriminator_iter_start', type=int, default=0)
85 | parser.add_argument('--disc_pretrain_iter', type=int, default=0)
86 | parser.add_argument('--disc_optim_steps', type=int, default=1)
87 | parser.add_argument('--disc_warmup', type=int, default=0)
88 | parser.add_argument('--disc_pool', type=str, default="no", choices=["no", "yes"])
89 | parser.add_argument('--disc_pool_size', type=int, default=1000)
90 |
91 | # loss
92 | parser.add_argument("--recon_loss_type", type=str, default='l1', choices=['l1', 'l2'])
93 | parser.add_argument('--image_gan_weight', type=float, default=1.0)
94 | parser.add_argument('--image_disc_weight', type=float, default=0.)
95 | parser.add_argument('--l1_weight', type=float, default=4.0)
96 | parser.add_argument('--gan_feat_weight', type=float, default=0.0)
97 | parser.add_argument('--perceptual_weight', type=float, default=0.0)
98 | parser.add_argument('--kl_weight', type=float, default=0.)
99 | parser.add_argument('--lfq_weight', type=float, default=0.)
100 | parser.add_argument('--entropy_loss_weight', type=float, default=0.1)
101 | parser.add_argument('--commitment_loss_weight', type=float, default=0.25)
102 | parser.add_argument('--diversity_gamma', type=float, default=1)
103 | parser.add_argument('--norm_type', type=str, default='group', choices=['batch', 'group', "no"])
104 | parser.add_argument('--disc_loss_type', type=str, default='hinge', choices=['hinge', 'vanilla'])
105 |
106 | # acceleration
107 | parser.add_argument('--use_checkpoint', action="store_true")
108 | parser.add_argument('--precision', type=str, default="fp32", choices=['fp32', 'bf16']) # disable fp16
109 | parser.add_argument('--encoder_dtype', type=str, default="fp32", choices=['fp32', 'bf16']) # disable fp16
110 | parser.add_argument('--upcast_tf32', action="store_true")
111 |
112 | # initialization
113 | parser.add_argument('--tokenizer', type=str, default='flux', choices=["flux"])
114 | parser.add_argument('--pretrained', type=str, default=None)
115 | parser.add_argument('--pretrained_mode', type=str, default="full", choices=['full'])
116 |
117 | # misc
118 | parser.add_argument('--debug', action='store_true')
119 | parser.add_argument('--seed', type=int, default=1234)
120 | parser.add_argument('--bucket_cap_mb', type=int, default=40) # DDP
121 | parser.add_argument('--manual_gc_interval', type=int, default=1000) # DDP
122 |
123 | return parser
--------------------------------------------------------------------------------
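Note on the quantizer weights above: --zeta is documented as the entropy penalty weight and --gamma as the loss weight of H(E[p(c|u)]). A common way such terms are combined in LFQ/BSQ-style quantizers is sketched below; this is only an illustration of the general form under that assumption (the actual combination lives in the MultiScaleBSQ implementation under bitvae.modules.quantizer, not reproduced here), and the tensor shape is hypothetical.

import torch
import torch.nn.functional as F

def entropy_penalty_sketch(logits, zeta=1.0, gamma=1.0):
    # logits: [N, codebook_size] code-assignment scores (hypothetical shape)
    probs = F.softmax(logits, dim=-1)                                            # p(c|u) per sample
    per_sample_entropy = -(probs * probs.clamp_min(1e-8).log()).sum(-1).mean()   # E[H(p(c|u))]
    avg_probs = probs.mean(dim=0)                                                # E[p(c|u)] over the batch
    batch_entropy = -(avg_probs * avg_probs.clamp_min(1e-8).log()).sum()         # H(E[p(c|u)])
    # encourage confident per-sample assignments (low E[H]) and diverse overall code usage (high H(E[p]))
    return zeta * per_sample_entropy - gamma * batch_entropy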
/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import json
4 | import re
5 | import torch
6 | import torch.nn.functional as F
7 | import argparse
8 | import numpy as np
9 | import torch.nn as nn
10 |
11 | import lpips
12 | from tqdm import tqdm
13 | from PIL import Image
14 | Image.MAX_IMAGE_PIXELS = None
15 |
16 | import torch.distributed as dist
17 | from torch.multiprocessing import spawn
18 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss
19 | from skimage.metrics import structural_similarity as ssim_loss
20 |
21 | from bitvae.data import ImageData
22 | from bitvae.utils.arguments import MainArgs, add_model_specific_args
23 | from bitvae.evaluation import calculate_frechet_distance
24 | from bitvae.evaluation import InceptionV3
25 |
26 | torch.set_num_threads(32)
27 |
28 |
29 | def calculate_batch_codebook_usage_percentage_bit(batch_encoding_indices):
30 | if isinstance(batch_encoding_indices, list):
31 | all_indices = []
32 | for one_encoding_indices in batch_encoding_indices:
33 | all_indices.append(one_encoding_indices.flatten(0, -2)) # [bhw, d]
34 | all_indices = torch.cat(all_indices, dim=0) # [sigma(bhw), d]
35 | else:
36 | # Flatten the batch of encoding indices into a single 1D tensor
37 | raise NotImplementedError
38 | all_indices = all_indices.detach().cpu()
39 |
40 | codebook_usage = torch.sum(all_indices, dim=0) # (d, )
41 |
42 | return codebook_usage, len(all_indices), all_indices.numpy()
43 |
44 | def default_parse_args():
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('--vqgan_ckpt', type=str, default=None)
47 | parser.add_argument('--inference_type', type=str, choices=["image"])
48 | parser.add_argument('--save', type=str, required=True)
49 | parser.add_argument('--device', type=str, default="cuda", choices=["cpu", "cuda"])
50 | parser = MainArgs.add_main_args(parser)
51 | parser = ImageData.add_data_specific_args(parser)
52 | args, unknown = parser.parse_known_args()
53 | args, parser, d_vae_model = add_model_specific_args(args, parser)
54 | args = parser.parse_args()
55 | return args, d_vae_model
56 |
57 |
58 | def setup(rank, world_size):
59 | os.environ['MASTER_ADDR'] = 'localhost'
60 | os.environ['MASTER_PORT'] = '12355'
61 | dist.init_process_group("nccl", rank=rank, world_size=world_size)
62 |
63 | def cleanup():
64 | dist.destroy_process_group()
65 |
66 | def main():
67 | args, d_vae_model = default_parse_args()
68 | os.makedirs(args.default_root_dir, exist_ok=True)
69 |
70 | # init resolution
71 | args.resolution = (args.resolution[0], args.resolution[0]) if len(args.resolution) == 1 else args.resolution # init resolution
72 |
73 | d_vae = None
74 | num_codes = None
75 |
76 | if args.tokenizer in ["flux"]:
77 | print('args: ',args)
78 | d_vae = d_vae_model(args)
79 | num_codes = args.codebook_size
80 | state_dict = torch.load(args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True)
81 | d_vae.load_state_dict(state_dict["vae"])
82 | else:
83 | raise NotImplementedError
84 |
85 |
86 | world_size = 1 if args.debug else torch.cuda.device_count()
87 | manager = torch.multiprocessing.Manager()
88 | return_dict = manager.dict()
89 |
90 | if args.debug:
91 | inference_eval(0, world_size, args, d_vae_model, d_vae, num_codes, return_dict)
92 | else:
93 | spawn(inference_eval, args=(world_size, args, d_vae_model, d_vae, num_codes, return_dict), nprocs=world_size, join=True)
94 |
95 | pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter, total_usage, total_usage_bit, total_num_token, all_bit_indices_cat = [], [], 0, 0, 0, 0, 0, 0, 0, 0, []
96 | for rank in range(world_size):
97 | pred_xs.append(return_dict[rank]['pred_xs'])
98 | pred_recs.append(return_dict[rank]['pred_recs'])
99 | lpips_alex += return_dict[rank]['lpips_alex']
100 | lpips_vgg += return_dict[rank]['lpips_vgg']
101 | ssim_value += return_dict[rank]['ssim_value']
102 | psnr_value += return_dict[rank]['psnr_value']
103 | num_iter += return_dict[rank]['num_iter']
104 | total_usage += return_dict[rank]['total_usage']
105 | if not args.disable_codebook_usage_bit:
106 | total_usage_bit += return_dict[rank]['total_usage_bit']
107 | total_num_token += return_dict[rank]['total_num_token']
108 | all_bit_indices_cat.append(return_dict[rank]['all_bit_indices_cat'])
109 | pred_xs = np.concatenate(pred_xs, 0)
110 | pred_recs = np.concatenate(pred_recs, 0)
111 |
112 | result_str = image_eval(pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter, total_usage, num_codes, total_usage_bit, total_num_token)
113 | print(result_str)
114 | # save result_str to exp_dir
115 | if args.tokenizer == "flux":
116 | basename = os.path.basename(args.vqgan_ckpt)
117 | match = re.search(r'model_step_(\d+)\.ckpt', basename)
118 | iter_num = match.group(1) if match else None
119 | ckpt_dir = os.path.dirname(args.vqgan_ckpt)
120 | save_dir = os.path.join(ckpt_dir, "evaluation")
121 | os.makedirs(save_dir, exist_ok=True)
122 | if args.random_flip:
123 | flip_prob = int(args.flip_prob * 10)
124 | result_name = os.path.join(save_dir, f"result_{args.dataset_list}_{iter_num}_{args.schedule_mode}_{args.resolution}_max_flip_lvl_{args.max_flip_lvl}_flip_prob_{flip_prob}.txt")
125 | elif args.random_flip_1lvl:
126 | result_name = os.path.join(save_dir, f"result_{args.dataset_list}_{iter_num}_{args.schedule_mode}_flip_lvl_{args.flip_lvl_idx}.txt")
127 | elif args.drop_when_test:
128 | result_name = os.path.join(save_dir, f"result_{args.dataset_list}_{iter_num}_{args.schedule_mode}_drop_lvl_idx_{args.drop_lvl_idx}_drop_lvl_num_{args.drop_lvl_num}.txt")
129 | else:
130 | result_name = os.path.join(save_dir, f"result_{args.dataset_list}_{iter_num}_{args.schedule_mode}_{args.resolution}.txt")
131 | else:
132 | raise NotImplementedError
133 | with open(result_name, "w") as f:
134 | f.write(result_str)
135 | # print('Usage = %.2f'%((total_usage > 0.).sum() / num_codes))
136 |
137 | def inference_eval(rank, world_size, args, d_vae_model, d_vae, num_codes, return_dict):
138 | # Don't remove this setup!!! dist.init_process_group is important for building loader (data.distributed.DistributedSampler)
139 | setup(rank, world_size)
140 |
141 | device = torch.device(f"cuda:{rank}")
142 |
143 | for param in d_vae.parameters():
144 | param.requires_grad = False
145 | d_vae.to(device).eval()
146 |
147 | save_dir = 'results/%s'%(args.save)
148 | print('generating and saving image to %s...'%save_dir)
149 | os.makedirs(save_dir, exist_ok=True)
150 |
151 | data = ImageData(args)
152 |
153 | loader = data.val_dataloader()
154 |
155 | dims = 2048
156 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
157 | inception_model = InceptionV3([block_idx]).to(device)
158 | inception_model.eval()
159 |
160 | loader_iter = iter(loader)
161 |
162 | pred_xs = []
163 | pred_recs = []
164 | all_bit_indices_cat = []
165 | # LPIPS score related
166 | loss_fn_alex = lpips.LPIPS(net='alex').to(device) # best forward scores
167 | loss_fn_vgg = lpips.LPIPS(net='vgg').to(device) # closer to "traditional" perceptual loss, when used for optimization
168 | lpips_alex = 0.0
169 | lpips_vgg = 0.0
170 |
171 | # SSIM score related
172 | ssim_value = 0.0
173 |
174 | # PSNR score related
175 | psnr_value = 0.0
176 |
177 | num_images = len(loader)
178 | print(f"Testing {num_images} files")
179 | num_iter = 0
180 |
181 | total_usage = 0.0
182 | total_usage_bit = 0.0
183 | total_num_token = 0
184 |
185 | for batch_idx in tqdm(range(num_images)):
186 | batch = next(loader_iter)
187 |
188 | with torch.no_grad():
189 | x = batch['image']
190 | if args.tokenizer in ["flux"]:
191 | torch.cuda.empty_cache()
192 | # x: [-1, 1]
193 | x_recons, vq_output = d_vae(x.to(device), 2, 0, is_train=False)
194 | x_recons = x_recons.cpu()
195 | else:
196 | raise NotImplementedError
197 |
198 | if not args.disable_codebook_usage_bit:
199 | bit_indices = vq_output["bit_encodings"]
200 | codebook_usage_bit, num_token, bit_indices_cat = calculate_batch_codebook_usage_percentage_bit(bit_indices)
201 | total_usage_bit += codebook_usage_bit
202 | total_num_token += num_token
203 | all_bit_indices_cat.append(bit_indices_cat)
204 |
205 | paths = batch["path"]
206 | assert len(paths) == x.shape[0]
207 |
208 | for p, input_ori, recon_ori in zip(paths, x, x_recons):
209 | path = os.path.join(save_dir, "input_recon", os.path.basename(p))
210 | os.makedirs(os.path.split(path)[0], exist_ok=True)
211 |
212 | input_ori = input_ori.unsqueeze(0).to(device)
213 | input_ = (input_ori + 1) / 2 # [-1, 1] -> [0, 1]
214 |
215 | pred_x = inception_model(input_)[0]
216 | pred_x = pred_x.squeeze(3).squeeze(2).cpu().numpy()
217 |
218 | recon_ori = recon_ori.unsqueeze(0).to(device)
219 | recon_ = (recon_ori + 1) / 2 # [-1, 1] -> [0, 1]
220 | # recon_ = recon_.permute(1, 2, 0).detach().cpu()
221 | with torch.no_grad():
222 | pred_rec = inception_model(recon_)[0]
223 | pred_rec = pred_rec.squeeze(3).squeeze(2).cpu().numpy()
224 |
225 | pred_xs.append(pred_x)
226 | pred_recs.append(pred_rec)
227 |
228 | # calculate lpips
229 | with torch.no_grad():
230 | lpips_alex += loss_fn_alex(input_ori, recon_ori).sum() # [-1, 1]
231 | lpips_vgg += loss_fn_vgg(input_ori, recon_ori).sum() # [-1, 1]
232 |
233 | #calculate PSNR and SSIM
234 | rgb_restored = (recon_ * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
235 | rgb_gt = (input_ * 255.0).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
236 | rgb_restored = rgb_restored.astype(np.float32) / 255.
237 | rgb_gt = rgb_gt.astype(np.float32) / 255.
238 | ssim_temp = 0
239 | psnr_temp = 0
240 | B, _, _, _ = rgb_restored.shape
241 | for i in range(B):
242 | rgb_restored_s, rgb_gt_s = rgb_restored[i], rgb_gt[i]
243 | with torch.no_grad():
244 | ssim_temp += ssim_loss(rgb_restored_s, rgb_gt_s, data_range=1.0, channel_axis=-1)
245 |                     psnr_temp += psnr_loss(rgb_gt_s, rgb_restored_s)  # per-image PSNR, consistent with the per-image SSIM above
246 | ssim_value += ssim_temp / B
247 | psnr_value += psnr_temp / B
248 | num_iter += 1
249 |
250 | pred_xs = np.concatenate(pred_xs, axis=0)
251 | pred_recs = np.concatenate(pred_recs, axis=0)
252 | temp_dict = {
253 | 'pred_xs':pred_xs,
254 | 'pred_recs':pred_recs,
255 | 'lpips_alex':lpips_alex.cpu(),
256 | 'lpips_vgg':lpips_vgg.cpu(),
257 | 'ssim_value': ssim_value,
258 | 'psnr_value': psnr_value,
259 | 'num_iter': num_iter,
260 | 'total_usage': total_usage,
261 | 'total_usage_bit': total_usage_bit,
262 | 'total_num_token': total_num_token,
263 | }
264 | if not args.disable_codebook_usage_bit:
265 | all_bit_indices_cat = np.concatenate(all_bit_indices_cat, axis=0)
266 | temp_dict['all_bit_indices_cat'] = all_bit_indices_cat
267 | return_dict[rank] = temp_dict
268 |
269 | if dist.is_initialized():
270 | dist.barrier()
271 | cleanup()
272 |
273 | def image_eval(pred_xs, pred_recs, lpips_alex, lpips_vgg, ssim_value, psnr_value, num_iter, total_usage, num_codes, total_usage_bit, total_num_token):
274 | mu_x = np.mean(pred_xs, axis=0)
275 | sigma_x = np.cov(pred_xs, rowvar=False)
276 | mu_rec = np.mean(pred_recs, axis=0)
277 | sigma_rec = np.cov(pred_recs, rowvar=False)
278 |
279 | fid_value = calculate_frechet_distance(mu_x, sigma_x, mu_rec, sigma_rec)
280 | lpips_alex_value = lpips_alex / num_iter
281 | lpips_vgg_value = lpips_vgg / num_iter
282 | ssim_value = ssim_value / num_iter
283 | psnr_value = psnr_value / num_iter
284 | if total_num_token != 0:
285 | bit_distribution = total_usage_bit / total_num_token
286 | bit_distribution_str = '\n'.join(f'{value:.4f}' for value in bit_distribution)
287 |
288 | # usage_0 = (total_usage > 0.).sum() / num_codes * 100
289 | # usage_10 = (total_usage > 10.).sum() / num_codes * 100
290 |
291 | result_str = f"""
292 | FID = {fid_value:.4f}
293 | LPIPS_VGG: {lpips_vgg_value.item():.4f}
294 | LPIPS_ALEX: {lpips_alex_value.item():.4f}
295 | SSIM: {ssim_value:.4f}
296 | PSNR: {psnr_value:.3f}
297 | """
298 | if total_num_token != 0:
299 | result_str += f"""
300 | Bit_Distribution: {bit_distribution_str}
301 | """
302 | # Usage(>0): {usage_0:.2f}%
303 | # Usage(>10): {usage_10:.2f}%
304 | return result_str
305 | if __name__ == '__main__':
306 | main()
--------------------------------------------------------------------------------
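The Bit_Distribution block written to the result file above comes from calculate_batch_codebook_usage_percentage_bit together with the total_usage_bit / total_num_token division in image_eval. A tiny self-contained sketch with made-up binary codes (not tied to a real checkpoint) of what that statistic means: each token is a d-dimensional binary code, usage is summed per bit over all tokens across scales, and dividing by the token count gives the fraction of tokens in which each bit fires.

import torch

d = 4  # toy code dimension; the real model uses args.codebook_dim bits
enc_scale1 = torch.randint(0, 2, (2, 2, 2, d)).float()   # [b, h, w, d] codes at one scale
enc_scale2 = torch.randint(0, 2, (2, 4, 4, d)).float()   # a finer scale

all_indices = torch.cat([e.flatten(0, -2) for e in (enc_scale1, enc_scale2)], dim=0)  # [sum(bhw), d]
codebook_usage = all_indices.sum(dim=0)                   # per-bit counts, shape (d,)
num_tokens = len(all_indices)

bit_distribution = codebook_usage / num_tokens            # fraction of tokens with each bit set
print('\n'.join(f'{v:.4f}' for v in bit_distribution.tolist()))  # same formatting used by image_eval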
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import math
4 | import glob
5 | import time
6 | import logging
7 | from copy import deepcopy
8 | import numpy as np
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | import torch.distributed as dist
13 |
14 | from torch.nn.parallel import DistributedDataParallel as DDP
15 | from timm.scheduler.cosine_lr import CosineLRScheduler
16 |
17 | from bitvae.utils.distributed import init_distributed_mode, reduce_losses, average_losses
18 | from bitvae.utils.logger import create_logger
19 |
20 | from bitvae.models import ImageDiscriminator
21 | from bitvae.data import ImageData
22 | from bitvae.modules.loss import get_disc_loss, adopt_weight
23 | from bitvae.utils.misc import get_last_ckpt
24 | from bitvae.utils.init_models import resume_from_ckpt
25 | from bitvae.utils.arguments import MainArgs, add_model_specific_args
26 |
27 | logger = logging.getLogger(__name__)
28 |
29 | def lecam_reg_zero(real_pred, fake_pred, thres=0.1):
30 | # avoid logits get too high
31 | assert real_pred.ndim == 0
32 | reg = torch.mean(F.relu(torch.abs(real_pred) - thres).pow(2)) + \
33 | torch.mean(F.relu(torch.abs(fake_pred) - thres).pow(2))
34 | return reg
35 |
36 |
37 | def main():
38 | parser = argparse.ArgumentParser()
39 | parser = MainArgs.add_main_args(parser)
40 | parser = ImageData.add_data_specific_args(parser)
41 | args, unknown = parser.parse_known_args()
42 | args, parser, d_vae_model = add_model_specific_args(args, parser)
43 | args = parser.parse_args()
44 |
45 | args.resolution = (args.resolution[0], args.resolution[0]) if len(args.resolution) == 1 else args.resolution # init resolution
46 |
47 | print(f"{args.default_root_dir=}")
48 |
49 | # Setup DDP:
50 | init_distributed_mode(args)
51 | rank = dist.get_rank()
52 | world_size = dist.get_world_size()
53 | device = rank % torch.cuda.device_count()
54 | torch.cuda.set_device(device)
55 |
56 | # Setup an experiment folder:
57 | if rank == 0:
58 |         os.makedirs(args.default_root_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
59 | checkpoint_dir = f"{args.default_root_dir}/checkpoints" # Stores saved model checkpoints
60 | os.makedirs(checkpoint_dir, exist_ok=True)
61 | logger = create_logger(args.default_root_dir)
62 | logger.info(f"Experiment directory created at {args.default_root_dir}")
63 |
64 | import wandb
65 | wandb_project = "VQVAE"
66 | wandb.init(
67 | project=wandb_project,
68 | name=os.path.basename(os.path.normpath(args.default_root_dir)),
69 | dir=args.default_root_dir,
70 | config=args,
71 | mode="offline" if args.debug else "online"
72 | )
73 | else:
74 | logger = create_logger(None)
75 |
76 | # init dataloader
77 | data = ImageData(args)
78 | dataloaders = data.train_dataloader()
79 | dataloader_iters = [iter(loader) for loader in dataloaders]
80 | data_epochs = [0 for _ in dataloaders]
81 |
82 | # init model
83 | d_vae = d_vae_model(args).to(device)
84 | d_vae.logger = logger
85 | image_disc = ImageDiscriminator(args).to(device)
86 |
87 | # init optimizers and schedulers
88 | if args.optim_type == "Adam":
89 | optim = torch.optim.Adam
90 | elif args.optim_type == "AdamW":
91 | optim = torch.optim.AdamW
92 | if args.disc_optim_type is None:
93 | disc_optim = optim
94 | elif args.disc_optim_type == "rmsprop":
95 | disc_optim = torch.optim.RMSprop
96 | opt_vae = optim(d_vae.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
97 | if disc_optim == torch.optim.RMSprop:
98 | opt_image_disc = disc_optim(image_disc.parameters(), lr=args.lr * args.dis_lr_multiplier)
99 | else:
100 | opt_image_disc = disc_optim(image_disc.parameters(), lr=args.lr * args.dis_lr_multiplier, betas=(args.beta1, args.beta2))
101 |
102 | lr_min = args.lr_min
103 | train_iters = args.max_steps
104 | warmup_steps = args.warmup_steps
105 | warmup_lr_init = args.warmup_lr_init
106 |
107 | if args.disable_sch:
108 | # scheduler_list = [None, None]
109 | sch_vae, sch_image_disc = None, None
110 |
111 | model_optims = {
112 | "vae" : d_vae,
113 | "image_disc" : image_disc,
114 | "opt_vae" : opt_vae,
115 | "opt_image_disc" : opt_image_disc,
116 | "sch_vae" : sch_vae,
117 | "sch_image_disc" : sch_image_disc,
118 | }
119 |
120 | # resume from default_root_dir
121 | ckpt_path = None
122 |     assert args.default_root_dir is not None # required argument
123 | ckpt_path = get_last_ckpt(args.default_root_dir)
124 | init_step = 0
125 | load_optimizer = not args.not_load_optimizer
126 | if ckpt_path:
127 | logger.info(f"Resuming from {ckpt_path}")
128 | state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
129 | model_optims, init_step = resume_from_ckpt(state_dict, model_optims, load_optimizer=True)
130 | # load pretrained weights
131 | elif args.pretrained is not None:
132 | state_dict = torch.load(args.pretrained, map_location="cpu", weights_only=True)
133 | if args.pretrained_mode == "full":
134 | model_optims, _ = resume_from_ckpt(state_dict, model_optims, load_optimizer=load_optimizer)
135 | logger.info(f"Successfully loaded ckpt {args.pretrained}, pretrained_mode {args.pretrained_mode}")
136 |
137 | d_vae = DDP(d_vae.to(device), device_ids=[args.gpu], bucket_cap_mb=args.bucket_cap_mb)
138 | image_disc = DDP(image_disc.to(device), device_ids=[args.gpu], bucket_cap_mb=args.bucket_cap_mb)
139 | disc_loss = get_disc_loss(args.disc_loss_type) # hinge loss by default
140 |
141 | if args.multiscale_training:
142 | scale_idx_list = np.load('bitvae/utils/random_numbers.npy') # load pre-computed scale_idx in each iteration
143 |
144 | start_time = time.time()
145 | for global_step in range(init_step, args.max_steps):
146 | loss_dicts = []
147 |
148 | if global_step == args.discriminator_iter_start - args.disc_pretrain_iter:
149 |             logging.info("discriminator begins pretraining")
150 | if global_step == args.discriminator_iter_start:
151 | log_str = "add GAN loss into training"
152 | if args.disc_pretrain_iter > 0:
153 | log_str += ", discriminator ends pretraining"
154 | logging.info(log_str)
155 |
156 | for idx in range(len(dataloader_iters)):
157 | try:
158 | _batch = next(dataloader_iters[idx])
159 | except StopIteration:
160 | data_epochs[idx] += 1
161 |             logger.info(f"Resetting the {idx}th dataloader for epoch {data_epochs[idx]}")
162 | dataloaders[idx].sampler.set_epoch(data_epochs[idx])
163 | dataloader_iters[idx] = iter(dataloaders[idx]) # update dataloader iter
164 | _batch = next(dataloader_iters[idx])
165 | except Exception as e:
166 | raise e
167 | x = _batch["image"]
168 | _type = _batch["type"][0]
169 |
170 | if args.multiscale_training:
171 | # data processing for multi-scale training
172 | scale_idx = scale_idx_list[global_step]
173 | if scale_idx == 0:
174 | # 256x256 batch=8
175 | x = F.interpolate(x, size=(256, 256), mode='area')
176 | elif scale_idx == 1:
177 | # 512x512 batch=4
178 | rdn_idx = torch.randperm(len(x))[:4] # without replacement
179 | x = x[rdn_idx]
180 | x = F.interpolate(x, size=(512, 512), mode='area')
181 | elif scale_idx == 2:
182 | # 1024x1024 batch=2
183 | rdn_idx = torch.randperm(len(x))[:2] # without replacement
184 | x = x[rdn_idx]
185 | else:
186 | raise ValueError(f"scale_idx {scale_idx} is not supported")
187 |
188 | if _type == "image":
189 | x_recon, flat_frames, flat_frames_recon, vae_loss_dict = d_vae(x, global_step, image_disc=image_disc)
190 | g_loss = sum(vae_loss_dict.values())
191 | opt_vae.zero_grad()
192 | g_loss.backward()
193 |
194 | if not ((global_step+1) % args.ckpt_every) == 0:
195 | if args.max_grad_norm > 0:
196 | torch.nn.utils.clip_grad_norm_(d_vae.parameters(), args.max_grad_norm)
197 | if not sch_vae is None:
198 | sch_vae.step(global_step)
199 | elif args.lr_drop and global_step in args.lr_drop:
200 | logger.info(f"multiply lr of VQ-VAE by {args.lr_drop_rate} at iteration {global_step}")
201 | for opt_vae_param_group in opt_vae.param_groups:
202 | opt_vae_param_group["lr"] = opt_vae_param_group["lr"] * args.lr_drop_rate
203 | opt_vae.step()
204 | opt_vae.zero_grad() # free memory
205 |
206 | disc_loss_dict = {}
207 | # disc_factor = 0 before (args.discriminator_iter_start - args.disc_pretrain_iter)
208 | disc_factor = adopt_weight(global_step, threshold=args.discriminator_iter_start - args.disc_pretrain_iter)
209 | discloss = d_image_loss = torch.tensor(0.).to(x.device)
210 | ### enable pool warmup
211 | for disc_step in range(args.disc_optim_steps): # train discriminator
212 | require_optim = False
213 | if _type == "image" and args.image_disc_weight > 0: # train image discriminator
214 | require_optim = True
215 | logits_image_real = image_disc(x, pool_name="real")
216 | logits_image_fake = image_disc(x_recon.detach(), pool_name="fake")
217 | d_image_loss = disc_loss(logits_image_real, logits_image_fake)
218 | disc_loss_dict["train/logits_image_real"] = logits_image_real.mean().detach()
219 | disc_loss_dict["train/logits_image_fake"] = logits_image_fake.mean().detach()
220 | disc_loss_dict["train/d_image_loss"] = d_image_loss.mean().detach()
221 | discloss = d_image_loss * args.image_disc_weight
222 | opt_discs, sch_discs = [opt_image_disc], [sch_image_disc]
223 | if global_step >= args.discriminator_iter_start and args.use_lecam_reg_zero:
224 | lecam_zero_loss = lecam_reg_zero(logits_image_real.mean(), logits_image_fake.mean())
225 | disc_loss_dict["train/lecam_zero_loss"] = lecam_zero_loss.mean().detach()
226 | discloss += lecam_zero_loss * args.lecam_weight
227 | discloss = disc_factor * discloss
228 |
229 | if require_optim:
230 | for opt_disc in opt_discs:
231 | opt_disc.zero_grad()
232 | discloss.backward()
233 |
234 | if not ((global_step+1) % args.ckpt_every) == 0:
235 | if args.max_grad_norm_disc > 0: # by default, 1.0
236 | torch.nn.utils.clip_grad_norm_(image_disc.parameters(), args.max_grad_norm_disc)
237 | for sch_disc in sch_discs:
238 | if not sch_disc is None:
239 | sch_disc.step(global_step)
240 | elif args.lr_drop and global_step in args.lr_drop:
241 | for opt_disc in opt_discs:
242 | logger.info(f"multiply lr of discriminator by {args.lr_drop_rate} at iteration {global_step}")
243 | for opt_disc_param_group in opt_disc.param_groups:
244 | opt_disc_param_group["lr"] = opt_disc_param_group["lr"] * args.lr_drop_rate
245 | for opt_disc in opt_discs:
246 | opt_disc.step()
247 | for opt_disc in opt_discs:
248 | opt_disc.zero_grad() # free memory
249 |
250 | loss_dict = {**vae_loss_dict, **disc_loss_dict}
251 | if (global_step+1) % args.log_every == 0:
252 | reduced_loss_dict = reduce_losses(loss_dict)
253 | else:
254 | reduced_loss_dict = {}
255 | loss_dicts.append(reduced_loss_dict)
256 |
257 | if (global_step+1) % args.log_every == 0:
258 | avg_loss_dict = average_losses(loss_dicts)
259 | torch.cuda.synchronize()
260 | end_time = time.time()
261 | iter_speed = (end_time - start_time) / args.log_every
262 | if rank == 0:
263 | for key, value in avg_loss_dict.items():
264 | wandb.log({key: value}, step=global_step)
265 | # writing logs
266 |             logger.info(f'global_step={global_step}, perceptual_loss={avg_loss_dict.get("train/perceptual_loss",0):.4f}, recon_loss={avg_loss_dict.get("train/recon_loss",0):.4f}, commitment_loss={avg_loss_dict.get("train/commitment_loss",0):.4f}, logit_r={avg_loss_dict.get("train/logits_image_real",0):.4f}, logit_f={avg_loss_dict.get("train/logits_image_fake",0):.4f}, L_disc={avg_loss_dict.get("train/d_image_loss",0):.4f}, iter_speed={iter_speed:.2f}s')
267 | start_time = time.time()
268 |
269 | if (global_step+1) % args.ckpt_every == 0 and global_step != init_step:
270 | if rank == 0:
271 | checkpoint_path = os.path.join(checkpoint_dir, f'model_step_{global_step}.ckpt')
272 | save_dict = {}
273 | for k in model_optims:
274 | save_dict[k] = None if model_optims[k] is None \
275 | else model_optims[k].module.state_dict() if hasattr(model_optims[k], "module") \
276 | else model_optims[k].state_dict()
277 | torch.save({
278 | 'step': global_step,
279 | **save_dict,
280 | }, checkpoint_path)
281 | logger.info(f'Checkpoint saved at step {global_step}')
282 |
283 |
284 | if __name__ == '__main__':
285 | main()
286 |
--------------------------------------------------------------------------------
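The lecam_reg_zero regularizer defined at the top of train.py only contributes once the discriminator's mean logits drift outside the [-thres, thres] band, which is how it keeps logits from growing unbounded. A tiny self-contained usage example (the function body is reproduced from above):

import torch
import torch.nn.functional as F

def lecam_reg_zero(real_pred, fake_pred, thres=0.1):
    # penalize the squared excess of |mean logit| over thres, for both real and fake predictions
    assert real_pred.ndim == 0
    reg = torch.mean(F.relu(torch.abs(real_pred) - thres).pow(2)) + \
        torch.mean(F.relu(torch.abs(fake_pred) - thres).pow(2))
    return reg

print(lecam_reg_zero(torch.tensor(0.05), torch.tensor(-0.08)))  # tensor(0.)  -> logits inside the band
print(lecam_reg_zero(torch.tensor(2.0), torch.tensor(-3.0)))    # 1.9**2 + 2.9**2 = tensor(12.0200)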
/bitvae/evaluation/inception.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | from scipy import linalg
6 | import numpy as np
7 |
8 | try:
9 | from torchvision.models.utils import load_state_dict_from_url
10 | except ImportError:
11 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
12 |
13 | # Inception weights ported to Pytorch from
14 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
15 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
16 |
17 | FID_WEIGHTS_PATH = "checkpoints/pt_inception-2015-12-05-6726825d.pth"
18 |
19 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
20 | """Numpy implementation of the Frechet Distance.
21 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
22 | and X_2 ~ N(mu_2, C_2) is
23 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
24 |
25 | Stable version by Dougal J. Sutherland.
26 |
27 | Params:
28 | -- mu1 : Numpy array containing the activations of a layer of the
29 | inception net (like returned by the function 'get_predictions')
30 | for generated samples.
31 |     -- mu2   : The sample mean over activations, precalculated on a
32 | representative data set.
33 | -- sigma1: The covariance matrix over activations for generated samples.
34 |     -- sigma2: The covariance matrix over activations, precalculated on a
35 | representative data set.
36 |
37 | Returns:
38 | -- : The Frechet Distance.
39 | """
40 |
41 | mu1 = np.atleast_1d(mu1)
42 | mu2 = np.atleast_1d(mu2)
43 |
44 | sigma1 = np.atleast_2d(sigma1)
45 | sigma2 = np.atleast_2d(sigma2)
46 |
47 | assert mu1.shape == mu2.shape, \
48 | 'Training and test mean vectors have different lengths'
49 | assert sigma1.shape == sigma2.shape, \
50 | 'Training and test covariances have different dimensions'
51 |
52 | diff = mu1 - mu2
53 |
54 | # Product might be almost singular
55 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
56 | if not np.isfinite(covmean).all():
57 | msg = ('fid calculation produces singular product; '
58 | 'adding %s to diagonal of cov estimates') % eps
59 | print(msg)
60 | offset = np.eye(sigma1.shape[0]) * eps
61 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
62 |
63 | # Numerical error might give slight imaginary component
64 | if np.iscomplexobj(covmean):
65 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
66 | m = np.max(np.abs(covmean.imag))
67 | raise ValueError('Imaginary component {}'.format(m))
68 | covmean = covmean.real
69 |
70 | tr_covmean = np.trace(covmean)
71 |
72 | return (diff.dot(diff) + np.trace(sigma1) +
73 | np.trace(sigma2) - 2 * tr_covmean)
74 |
75 | class InceptionV3(nn.Module):
76 | """Pretrained InceptionV3 network returning feature maps"""
77 |
78 | # Index of default block of inception to return,
79 | # corresponds to output of final average pooling
80 | DEFAULT_BLOCK_INDEX = 3
81 |
82 | # Maps feature dimensionality to their output blocks indices
83 | BLOCK_INDEX_BY_DIM = {
84 | 64: 0, # First max pooling features
85 |         192: 1,   # Second max pooling features
86 | 768: 2, # Pre-aux classifier features
87 | 2048: 3 # Final average pooling features
88 | }
89 |
90 | def __init__(self,
91 | output_blocks=[DEFAULT_BLOCK_INDEX],
92 | resize_input=True,
93 | normalize_input=True,
94 | requires_grad=False,
95 | use_fid_inception=True):
96 | """Build pretrained InceptionV3
97 |
98 | Parameters
99 | ----------
100 | output_blocks : list of int
101 | Indices of blocks to return features of. Possible values are:
102 | - 0: corresponds to output of first max pooling
103 | - 1: corresponds to output of second max pooling
104 | - 2: corresponds to output which is fed to aux classifier
105 | - 3: corresponds to output of final average pooling
106 | resize_input : bool
107 | If true, bilinearly resizes input to width and height 299 before
108 | feeding input to model. As the network without fully connected
109 | layers is fully convolutional, it should be able to handle inputs
110 | of arbitrary size, so resizing might not be strictly needed
111 | normalize_input : bool
112 | If true, scales the input from range (0, 1) to the range the
113 | pretrained Inception network expects, namely (-1, 1)
114 | requires_grad : bool
115 | If true, parameters of the model require gradients. Possibly useful
116 | for finetuning the network
117 | use_fid_inception : bool
118 | If true, uses the pretrained Inception model used in Tensorflow's
119 | FID implementation. If false, uses the pretrained Inception model
120 | available in torchvision. The FID Inception model has different
121 | weights and a slightly different structure from torchvision's
122 | Inception model. If you want to compute FID scores, you are
123 | strongly advised to set this parameter to true to get comparable
124 | results.
125 | """
126 | super(InceptionV3, self).__init__()
127 |
128 | self.resize_input = resize_input
129 | self.normalize_input = normalize_input
130 | self.output_blocks = sorted(output_blocks)
131 | self.last_needed_block = max(output_blocks)
132 |
133 | assert self.last_needed_block <= 3, \
134 | 'Last possible output block index is 3'
135 |
136 | self.blocks = nn.ModuleList()
137 |
138 | if use_fid_inception:
139 | inception = fid_inception_v3()
140 | else:
141 | inception = models.inception_v3(pretrained=True)
142 |
143 | # Block 0: input to maxpool1
144 | block0 = [
145 | inception.Conv2d_1a_3x3,
146 | inception.Conv2d_2a_3x3,
147 | inception.Conv2d_2b_3x3,
148 | nn.MaxPool2d(kernel_size=3, stride=2)
149 | ]
150 | self.blocks.append(nn.Sequential(*block0))
151 |
152 | # Block 1: maxpool1 to maxpool2
153 | if self.last_needed_block >= 1:
154 | block1 = [
155 | inception.Conv2d_3b_1x1,
156 | inception.Conv2d_4a_3x3,
157 | nn.MaxPool2d(kernel_size=3, stride=2)
158 | ]
159 | self.blocks.append(nn.Sequential(*block1))
160 |
161 | # Block 2: maxpool2 to aux classifier
162 | if self.last_needed_block >= 2:
163 | block2 = [
164 | inception.Mixed_5b,
165 | inception.Mixed_5c,
166 | inception.Mixed_5d,
167 | inception.Mixed_6a,
168 | inception.Mixed_6b,
169 | inception.Mixed_6c,
170 | inception.Mixed_6d,
171 | inception.Mixed_6e,
172 | ]
173 | self.blocks.append(nn.Sequential(*block2))
174 |
175 | # Block 3: aux classifier to final avgpool
176 | if self.last_needed_block >= 3:
177 | block3 = [
178 | inception.Mixed_7a,
179 | inception.Mixed_7b,
180 | inception.Mixed_7c,
181 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
182 | ]
183 | self.blocks.append(nn.Sequential(*block3))
184 |
185 | for param in self.parameters():
186 | param.requires_grad = requires_grad
187 |
188 | def forward(self, inp):
189 | """Get Inception feature maps
190 |
191 | Parameters
192 | ----------
193 | inp : torch.autograd.Variable
194 | Input tensor of shape Bx3xHxW. Values are expected to be in
195 | range (0, 1)
196 |
197 | Returns
198 | -------
199 | List of torch.autograd.Variable, corresponding to the selected output
200 | block, sorted ascending by index
201 | """
202 | outp = []
203 | x = inp
204 |
205 | if self.resize_input:
206 | x = F.interpolate(x,
207 | size=(299, 299),
208 | mode='bilinear',
209 | align_corners=False)
210 |
211 | if self.normalize_input:
212 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
213 |
214 | for idx, block in enumerate(self.blocks):
215 | x = block(x)
216 | if idx in self.output_blocks:
217 | outp.append(x)
218 |
219 | if idx == self.last_needed_block:
220 | break
221 |
222 | return outp
223 |
224 |
225 | def fid_inception_v3():
226 | """Build pretrained Inception model for FID computation
227 |
228 | The Inception model for FID computation uses a different set of weights
229 | and has a slightly different structure than torchvision's Inception.
230 |
231 | This method first constructs torchvision's Inception and then patches the
232 | necessary parts that are different in the FID Inception model.
233 | """
234 | inception = models.inception_v3(num_classes=1008,
235 | aux_logits=False,
236 | pretrained=False)
237 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
238 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
239 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
240 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
241 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
242 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
243 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
244 | inception.Mixed_7b = FIDInceptionE_1(1280)
245 | inception.Mixed_7c = FIDInceptionE_2(2048)
246 |
247 | # state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
248 | state_dict = torch.load(FID_WEIGHTS_PATH)
249 | inception.load_state_dict(state_dict)
250 | return inception
251 |
252 |
253 | class FIDInceptionA(models.inception.InceptionA):
254 | """InceptionA block patched for FID computation"""
255 | def __init__(self, in_channels, pool_features):
256 | super(FIDInceptionA, self).__init__(in_channels, pool_features)
257 |
258 | def forward(self, x):
259 | branch1x1 = self.branch1x1(x)
260 |
261 | branch5x5 = self.branch5x5_1(x)
262 | branch5x5 = self.branch5x5_2(branch5x5)
263 |
264 | branch3x3dbl = self.branch3x3dbl_1(x)
265 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
266 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
267 |
268 |         # Patch: Tensorflow's average pool does not use the padded zeros in
269 | # its average calculation
270 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
271 | count_include_pad=False)
272 | branch_pool = self.branch_pool(branch_pool)
273 |
274 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
275 | return torch.cat(outputs, 1)
276 |
277 |
278 | class FIDInceptionC(models.inception.InceptionC):
279 | """InceptionC block patched for FID computation"""
280 | def __init__(self, in_channels, channels_7x7):
281 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
282 |
283 | def forward(self, x):
284 | branch1x1 = self.branch1x1(x)
285 |
286 | branch7x7 = self.branch7x7_1(x)
287 | branch7x7 = self.branch7x7_2(branch7x7)
288 | branch7x7 = self.branch7x7_3(branch7x7)
289 |
290 | branch7x7dbl = self.branch7x7dbl_1(x)
291 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
292 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
293 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
294 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
295 |
296 |         # Patch: Tensorflow's average pool does not use the padded zeros in
297 | # its average calculation
298 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
299 | count_include_pad=False)
300 | branch_pool = self.branch_pool(branch_pool)
301 |
302 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
303 | return torch.cat(outputs, 1)
304 |
305 |
306 | class FIDInceptionE_1(models.inception.InceptionE):
307 | """First InceptionE block patched for FID computation"""
308 | def __init__(self, in_channels):
309 | super(FIDInceptionE_1, self).__init__(in_channels)
310 |
311 | def forward(self, x):
312 | branch1x1 = self.branch1x1(x)
313 |
314 | branch3x3 = self.branch3x3_1(x)
315 | branch3x3 = [
316 | self.branch3x3_2a(branch3x3),
317 | self.branch3x3_2b(branch3x3),
318 | ]
319 | branch3x3 = torch.cat(branch3x3, 1)
320 |
321 | branch3x3dbl = self.branch3x3dbl_1(x)
322 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
323 | branch3x3dbl = [
324 | self.branch3x3dbl_3a(branch3x3dbl),
325 | self.branch3x3dbl_3b(branch3x3dbl),
326 | ]
327 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
328 |
329 |         # Patch: Tensorflow's average pool does not use the padded zeros in
330 | # its average calculation
331 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
332 | count_include_pad=False)
333 | branch_pool = self.branch_pool(branch_pool)
334 |
335 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
336 | return torch.cat(outputs, 1)
337 |
338 |
339 | class FIDInceptionE_2(models.inception.InceptionE):
340 | """Second InceptionE block patched for FID computation"""
341 | def __init__(self, in_channels):
342 | super(FIDInceptionE_2, self).__init__(in_channels)
343 |
344 | def forward(self, x):
345 | branch1x1 = self.branch1x1(x)
346 |
347 | branch3x3 = self.branch3x3_1(x)
348 | branch3x3 = [
349 | self.branch3x3_2a(branch3x3),
350 | self.branch3x3_2b(branch3x3),
351 | ]
352 | branch3x3 = torch.cat(branch3x3, 1)
353 |
354 | branch3x3dbl = self.branch3x3dbl_1(x)
355 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
356 | branch3x3dbl = [
357 | self.branch3x3dbl_3a(branch3x3dbl),
358 | self.branch3x3dbl_3b(branch3x3dbl),
359 | ]
360 | branch3x3dbl = torch.cat(branch3x3dbl, 1)
361 |
362 | # Patch: The FID Inception model uses max pooling instead of average
363 | # pooling. This is likely an error in this specific Inception
364 | # implementation, as other Inception models use average pooling here
365 | # (which matches the description in the paper).
366 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
367 | branch_pool = self.branch_pool(branch_pool)
368 |
369 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
370 | return torch.cat(outputs, 1)
--------------------------------------------------------------------------------
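A quick sanity-check sketch of calculate_frechet_distance on well-conditioned toy Gaussians (5-D instead of the 2048-D final-avgpool Inception features that eval.py feeds it), just to illustrate the mean/covariance inputs it expects; the numbers are synthetic and carry no meaning for BitVAE itself.

import numpy as np
from bitvae.evaluation import calculate_frechet_distance

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(1000, 5))   # stand-in for "real" activations
y = rng.normal(0.5, 1.2, size=(1000, 5))   # stand-in for "reconstruction" activations

fid = calculate_frechet_distance(x.mean(axis=0), np.cov(x, rowvar=False),
                                 y.mean(axis=0), np.cov(y, rowvar=False))
print(f"FID = {fid:.4f}")                  # non-zero because the two Gaussians differ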
/bitvae/models/d_vae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import numpy as np
5 | from einops import rearrange
6 | from torch import Tensor, nn
7 | from torchvision import transforms
8 | import torch.utils.checkpoint as checkpoint
9 | import torch.nn.functional as F
10 |
11 | from bitvae.modules.quantizer import MultiScaleBSQ
12 | from bitvae.modules import Conv, adopt_weight, LPIPS, Normalize
13 | from bitvae.utils.misc import ptdtype
14 |
15 |
16 | def swish(x: Tensor) -> Tensor:
17 | try:
18 | return x * torch.sigmoid(x)
19 |     except:  # fallback: compute on CPU (e.g. after CUDA OOM) and move the result back to the original device
20 | device = x.device
21 | x = x.cpu().pin_memory()
22 | return (x*torch.sigmoid(x)).to(device=device)
23 |
24 | class ResnetBlock(nn.Module):
25 | def __init__(self, in_channels: int, out_channels: int, norm_type='group'):
26 | super().__init__()
27 | self.in_channels = in_channels
28 | out_channels = in_channels if out_channels is None else out_channels
29 | self.out_channels = out_channels
30 |
31 | self.norm1 = Normalize(in_channels, norm_type)
32 |
33 | self.conv1 = Conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
34 |
35 | self.norm2 = Normalize(out_channels, norm_type)
36 |
37 | self.conv2 = Conv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
38 |
39 | if self.in_channels != self.out_channels:
40 | self.nin_shortcut = Conv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
41 |
42 | def forward(self, x):
43 | h = x
44 | h = self.norm1(h)
45 | h = swish(h)
46 | h = self.conv1(h)
47 |
48 | h = self.norm2(h)
49 | h = swish(h)
50 | h = self.conv2(h)
51 |
52 | if self.in_channels != self.out_channels:
53 | x = self.nin_shortcut(x)
54 |
55 | return x + h
56 |
57 |
58 |
59 | class Downsample(nn.Module):
60 | def __init__(self, in_channels, spatial_down=False):
61 | super().__init__()
62 | assert spatial_down == True
63 | self.pad = (0, 1, 0, 1)
64 | self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
65 |
66 | def forward(self, x: Tensor):
67 | x = nn.functional.pad(x, self.pad, mode="constant", value=0)
68 | x = self.conv(x)
69 | return x
70 |
71 |
72 | class Upsample(nn.Module):
73 | def __init__(self, in_channels, spatial_up=False):
74 | super().__init__()
75 | assert spatial_up == True
76 |
77 | self.scale_factor = 2
78 | self.conv = Conv(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
79 |
80 | def forward(self, x: Tensor):
81 | x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
82 | x = self.conv(x)
83 | return x
84 |
85 |
86 | class Encoder(nn.Module):
87 | def __init__(
88 | self,
89 | ch: int,
90 | ch_mult: list[int],
91 | num_res_blocks: int,
92 | z_channels: int,
93 | in_channels = 3,
94 | patch_size=8,
95 | norm_type='group',
96 | use_checkpoint=False,
97 | ):
98 | super().__init__()
99 | self.max_down = np.log2(patch_size)
100 | self.ch = ch
101 | self.num_resolutions = len(ch_mult)
102 | self.num_res_blocks = num_res_blocks
103 | self.in_channels = in_channels
104 | self.use_checkpoint = use_checkpoint
105 | # downsampling
106 | # self.conv_in = Conv(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
107 | self.conv_in = Conv(in_channels, ch, kernel_size=3, stride=1, padding=1)
108 |
109 | in_ch_mult = (1,) + tuple(ch_mult)
110 | self.in_ch_mult = in_ch_mult
111 | self.down = nn.ModuleList()
112 | block_in = self.ch
113 | for i_level in range(self.num_resolutions):
114 | block = nn.ModuleList()
115 | attn = nn.ModuleList()
116 | block_in = ch * in_ch_mult[i_level]
117 | block_out = ch * ch_mult[i_level]
118 | for _ in range(self.num_res_blocks):
119 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type))
120 | block_in = block_out
121 | down = nn.Module()
122 | down.block = block
123 | down.attn = attn
124 |
125 | spatial_down = True if i_level < self.max_down else False
126 | if spatial_down:
127 | down.downsample = Downsample(block_in, spatial_down=spatial_down)
128 | self.down.append(down)
129 |
130 | # middle
131 | self.mid = nn.Module()
132 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type)
133 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type)
134 |
135 | # end
136 | self.norm_out = Normalize(block_in, norm_type)
137 | self.conv_out = Conv(block_in, z_channels, kernel_size=3, stride=1, padding=1)
138 |
139 | def forward(self, x, return_hidden=False):
140 | if not self.use_checkpoint:
141 | return self._forward(x, return_hidden=return_hidden)
142 | else:
143 | return checkpoint.checkpoint(self._forward, x, return_hidden, use_reentrant=False)
144 |
145 | def _forward(self, x: Tensor, return_hidden=False) -> Tensor:
146 | # downsampling
147 | h0 = self.conv_in(x)
148 | hs = [h0]
149 | for i_level in range(self.num_resolutions):
150 | for i_block in range(self.num_res_blocks):
151 | h = self.down[i_level].block[i_block](hs[-1])
152 | if len(self.down[i_level].attn) > 0:
153 | h = self.down[i_level].attn[i_block](h)
154 | hs.append(h)
155 | if hasattr(self.down[i_level], "downsample"):
156 | hs.append(self.down[i_level].downsample(hs[-1]))
157 |
158 | # middle
159 | h = hs[-1]
160 | hs_mid = [h]
161 | h = self.mid.block_1(h)
162 | h = self.mid.block_2(h)
163 | hs_mid.append(h)
164 | # end
165 | h = self.norm_out(h)
166 | h = swish(h)
167 | h = self.conv_out(h)
168 | if return_hidden:
169 | return h, hs, hs_mid
170 | else:
171 | return h
172 |
173 |
174 | class Decoder(nn.Module):
175 | def __init__(
176 | self,
177 | ch: int,
178 | ch_mult: list[int],
179 | num_res_blocks: int,
180 | z_channels: int,
181 | out_ch = 3,
182 | patch_size=8,
183 | norm_type="group",
184 | use_checkpoint=False,
185 | ):
186 | super().__init__()
187 | self.max_up = np.log2(patch_size)
188 | self.ch = ch
189 | self.num_resolutions = len(ch_mult)
190 | self.num_res_blocks = num_res_blocks
191 | self.ffactor = 2 ** (self.num_resolutions - 1)
192 | self.use_checkpoint = use_checkpoint
193 |
194 | # compute in_ch_mult, block_in and curr_res at lowest res
195 | block_in = ch * ch_mult[self.num_resolutions - 1]
196 |
197 | # z to block_in
198 | self.conv_in = Conv(z_channels, block_in, kernel_size=3, stride=1, padding=1)
199 |
200 | # middle
201 | self.mid = nn.Module()
202 | self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type)
203 | self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, norm_type=norm_type)
204 |
205 | # upsampling
206 | self.up = nn.ModuleList()
207 | for i_level in reversed(range(self.num_resolutions)):
208 | block = nn.ModuleList()
209 | attn = nn.ModuleList()
210 | block_out = ch * ch_mult[i_level]
211 | for _ in range(self.num_res_blocks + 1):
212 | block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, norm_type=norm_type))
213 | block_in = block_out
214 | up = nn.Module()
215 | up.block = block
216 | up.attn = attn
217 | # https://github.com/black-forest-labs/flux/blob/b4f689aaccd40de93429865793e84a734f4a6254/src/flux/modules/autoencoder.py#L228
218 | spatial_up = True if 1 <= i_level <= self.max_up else False
219 | if spatial_up:
220 | up.upsample = Upsample(block_in, spatial_up=spatial_up)
221 | self.up.insert(0, up) # prepend to get consistent order
222 |
223 | # end
224 | self.norm_out = Normalize(block_in, norm_type)
225 | self.conv_out = Conv(block_in, out_ch, kernel_size=3, stride=1, padding=1)
226 |
227 | def forward(self, z):
228 | if not self.use_checkpoint:
229 | return self._forward(z)
230 | else:
231 | return checkpoint.checkpoint(self._forward, z, use_reentrant=False)
232 |
233 | def _forward(self, z: Tensor) -> Tensor:
234 | # z to block_in
235 | h = self.conv_in(z)
236 |
237 | # middle
238 | h = self.mid.block_1(h)
239 | h = self.mid.block_2(h)
240 |
241 | # upsampling
242 | for i_level in reversed(range(self.num_resolutions)):
243 | for i_block in range(self.num_res_blocks + 1):
244 | h = self.up[i_level].block[i_block](h)
245 | if len(self.up[i_level].attn) > 0:
246 | h = self.up[i_level].attn[i_block](h)
247 | if hasattr(self.up[i_level], "upsample"):
248 | h = self.up[i_level].upsample(h)
249 |
250 | # end
251 | h = self.norm_out(h)
252 | h = swish(h)
253 | h = self.conv_out(h)
254 | return h
255 |
256 | class AutoEncoder(nn.Module):
257 | def __init__(self, args):
258 | super().__init__()
259 | self.args = args
260 | self.encoder = Encoder(
261 | ch=args.base_ch,
262 | ch_mult=args.encoder_ch_mult,
263 | num_res_blocks=args.num_res_blocks,
264 | z_channels=args.codebook_dim,
265 | patch_size=args.patch_size,
266 | use_checkpoint=args.use_checkpoint,
267 | )
268 | self.decoder = Decoder(
269 | ch=args.base_ch,
270 | ch_mult=args.decoder_ch_mult,
271 | num_res_blocks=args.num_res_blocks,
272 | z_channels=args.codebook_dim,
273 | patch_size=args.patch_size,
274 | use_checkpoint=args.use_checkpoint,
275 | )
276 |
277 | self.gan_feat_weight = args.gan_feat_weight
278 | self.recon_loss_type = args.recon_loss_type
279 | self.l1_weight = args.l1_weight
280 | self.kl_weight = args.kl_weight
281 | self.lfq_weight = args.lfq_weight
282 | self.image_gan_weight = args.image_gan_weight # image GAN loss weight
283 | self.perceptual_weight = args.perceptual_weight
284 |
285 | self.compute_all_commitment = args.compute_all_commitment # compute commitment between input and rq-output
286 |
287 | self.perceptual_model = LPIPS(upcast_tf32=args.upcast_tf32).eval()
288 |
289 | if args.quantizer_type == 'MultiScaleBSQ':
290 | self.quantizer = MultiScaleBSQ(
291 | dim = args.codebook_dim, # this is the input feature dimension, defaults to log2(codebook_size) if not defined
292 | entropy_loss_weight = args.entropy_loss_weight, # how much weight to place on entropy loss
293 | diversity_gamma = args.diversity_gamma, # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
294 | commitment_loss_weight=args.commitment_loss_weight, # loss weight of commitment loss
295 | new_quant=args.new_quant,
296 | use_decay_factor=args.use_decay_factor,
297 | use_stochastic_depth=args.use_stochastic_depth,
298 | drop_rate=args.drop_rate,
299 | schedule_mode=args.schedule_mode,
300 | keep_first_quant=args.keep_first_quant,
301 | keep_last_quant=args.keep_last_quant,
302 | remove_residual_detach=args.remove_residual_detach,
303 | use_out_phi=args.use_out_phi,
304 | use_out_phi_res=args.use_out_phi_res,
305 | random_flip = args.random_flip,
306 | flip_prob = args.flip_prob,
307 | flip_mode = args.flip_mode,
308 | max_flip_lvl = args.max_flip_lvl,
309 | random_flip_1lvl = args.random_flip_1lvl,
310 | flip_lvl_idx = args.flip_lvl_idx,
311 | drop_when_test = args.drop_when_test,
312 | drop_lvl_idx = args.drop_lvl_idx,
313 | drop_lvl_num = args.drop_lvl_num,
314 | random_short_schedule = args.random_short_schedule,
315 | short_schedule_prob = args.short_schedule_prob,
316 | disable_flip_prob = args.disable_flip_prob,
317 | zeta = args.zeta,
318 | gamma = args.gamma,
319 | uniform_short_schedule = args.uniform_short_schedule
320 | )
321 | else:
322 | raise NotImplementedError(f"{args.quantizer_type} not supported")
323 | self.commitment_loss_weight = args.commitment_loss_weight
324 |
325 | def forward(self, x, global_step, image_disc=None, is_train=True):
326 | assert x.ndim == 4 # assert input data is image
327 |
328 | enc_dtype = ptdtype[self.args.encoder_dtype]
329 |
330 | with torch.amp.autocast("cuda", dtype=enc_dtype):
331 | h = self.encoder(x, return_hidden=False) # B C H W
332 | h = h.to(dtype=torch.float32)
333 |
334 | # Multiscale LFQ
335 | z, all_indices, all_bit_indices, all_loss = self.quantizer(h)
336 | # print(torch.unique(torch.round(z * 10**4)/10**4)) # keep 4 decimal places
337 | x_recon = self.decoder(z)
338 | vq_output = {
339 |             "commitment_loss": torch.mean(all_loss) * self.lfq_weight, # the quantizer's combined loss (commitment loss plus entropy penalty), scaled by lfq_weight
340 | "encodings": all_indices,
341 | "bit_encodings": all_bit_indices,
342 | }
343 | if self.compute_all_commitment:
344 | # compute commitment loss between input and rq-output
345 | vq_output["all_commitment_loss"] = F.mse_loss(h, z.detach(), reduction="mean") * self.commitment_loss_weight * self.lfq_weight
346 | else:
347 | # disable backward prop
348 | vq_output["all_commitment_loss"] = F.mse_loss(h.detach(), z.detach(), reduction="mean") * self.commitment_loss_weight * self.lfq_weight
349 |
350 | assert x.shape == x_recon.shape, f"x.shape {x.shape}, x_recon.shape {x_recon.shape}"
351 |
352 | if is_train == False:
353 | return x_recon, vq_output
354 |
355 | if self.recon_loss_type == 'l1':
356 | recon_loss = F.l1_loss(x_recon, x) * self.l1_weight
357 | else:
358 | recon_loss = F.mse_loss(x_recon, x) * self.l1_weight
359 |
360 | flat_frames = x
361 | flat_frames_recon = x_recon
362 |
363 | perceptual_loss = self.perceptual_model(flat_frames, flat_frames_recon).mean() * self.perceptual_weight
364 |
365 | loss_dict = {
366 | "train/perceptual_loss": perceptual_loss,
367 | "train/recon_loss": recon_loss,
368 | "train/commitment_loss": vq_output['commitment_loss'],
369 | "train/all_commitment_loss": vq_output['all_commitment_loss'],
370 | }
371 |
372 | ### GAN loss
373 | disc_factor = adopt_weight(global_step, threshold=self.args.discriminator_iter_start, warmup=self.args.disc_warmup)
374 | if self.image_gan_weight > 0: # image GAN loss
375 | logits_image_fake = image_disc(flat_frames_recon)
376 | g_image_loss = -torch.mean(logits_image_fake) * self.image_gan_weight * disc_factor # disc_factor=0 before self.args.discriminator_iter_start
377 | loss_dict["train/g_image_loss"] = g_image_loss
378 |
379 | return (x_recon.detach(), flat_frames.detach(), flat_frames_recon.detach(), loss_dict)
380 |
381 | @staticmethod
382 | def add_model_specific_args(parent_parser):
383 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
384 | parser.add_argument("--codebook_size", type=int, default=16384)
385 |
386 | parser.add_argument("--base_ch", type=int, default=128)
387 | parser.add_argument("--num_res_blocks", type=int, default=2) # num_res_blocks for encoder, num_res_blocks+1 for decoder
388 | parser.add_argument("--encoder_ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 4])
389 | parser.add_argument("--decoder_ch_mult", type=int, nargs='+', default=[1, 1, 2, 2, 4])
390 | return parser
391 |
--------------------------------------------------------------------------------
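For orientation, a hedged sketch of how this AutoEncoder is driven at inference time, mirroring eval.py above; the args namespace (built by the parsers in bitvae/utils/arguments.py) and the checkpoint layout (a dict with a "vae" state dict, as saved by train.py) are assumptions taken from the surrounding files, not extra API.

import torch
from bitvae.models.d_vae import AutoEncoder

@torch.no_grad()
def reconstruct(args, ckpt_path, x):
    # args      : namespace produced by the project's argument parsers (quantizer_type='MultiScaleBSQ', etc.)
    # ckpt_path : checkpoint written by train.py, containing a "vae" state dict
    # x         : [B, 3, H, W] tensor scaled to [-1, 1]
    model = AutoEncoder(args).eval()
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict["vae"])
    # with is_train=False the forward pass returns (x_recon, vq_output) before any loss terms,
    # so global_step and image_disc are irrelevant here
    x_recon, vq_output = model(x, 0, is_train=False)
    return x_recon, vq_output["bit_encodings"]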
/bitvae/data/data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Foundationvision, Inc. All Rights Reserved
2 |
3 | import os
4 | import os.path as osp
5 | import random
6 | import argparse
7 | import numpy as np
8 | from PIL import Image
9 | Image.MAX_IMAGE_PIXELS = None
10 |
11 | import torch
12 | import torch.utils.data as data
13 | import torch.distributed as dist
14 | from torchvision import transforms
15 |
16 | from bitvae.data.dataset_zoo import DATASET_DICT
17 | from torchvision.transforms import InterpolationMode
18 | from bitvae.modules.quantizer.dynamic_resolution import dynamic_resolution_h_w
19 |
20 | def _pil_interp(method):
21 | if method == 'bicubic':
22 | return InterpolationMode.BICUBIC
23 | elif method == 'lanczos':
24 | return InterpolationMode.LANCZOS
25 | elif method == 'hamming':
26 | return InterpolationMode.HAMMING
27 | else:
28 | # default bilinear, do we want to allow nearest?
29 | return InterpolationMode.BILINEAR
30 |
31 | import timm.data.transforms as timm_transforms
32 | timm_transforms._pil_interp = _pil_interp
33 |
34 | def get_parent_dir(path):
35 | return osp.basename(osp.dirname(path))
36 |
37 | # Only used during inference (for benchmarks like tuchong, which allows different resolutions)
38 | class DynamicAspectRatioGroupedDataset(data.Dataset):
39 | """
40 |     Batch data that have similar aspect ratios together.
41 | In this implementation, images whose aspect ratio < (or >) 1 will
42 | be batched together.
43 | This improves training speed because the images then need less padding
44 | to form a batch.
45 |
46 | It assumes the underlying dataset produces dicts with "width" and "height" keys.
47 | It will then produce a list of original dicts with length = batch_size,
48 | all with similar aspect ratios.
49 | """
50 |
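    | # Hedged usage sketch (not in the original source): this wrapper is iterated
    | # directly, or driven by a DataLoader with batch_size=1 and collate_fn=collate_func;
    | # each yielded element is a list of per-image dicts with matching shapes.
    | # image_dataset below is a hypothetical name for an ImageDataset instance.
    | #
    | #     ds = DynamicAspectRatioGroupedDataset(image_dataset, batch_size=8, train=False)
    | #     batch_list = next(iter(ds))                 # list of dicts, one per image
    | #     batch = ds.collate_func([batch_list])       # {"image": (B, C, H, W), "label": [...], ...}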
51 | def __init__(self,
52 | dataset, batch_size,
53 | debug=False, seed=0, train=True, random_bucket_ratio=0.
54 | ):
55 | """
56 | Args:
57 | dataset: an iterable. Each element must be a dict with keys
58 | "width" and "height", which will be used to batch data.
59 | batch_size (int): base batch size; during training the per-bucket batch size is derived from it (see get_batch_size).
60 | """
61 | self.train = train
62 | self._idx = 0
63 | self.dataset = dataset
64 | self.batch_size = batch_size
65 | self.random_bucket_ratio = random_bucket_ratio
66 | self.aspect_ratio = list(dynamic_resolution_h_w.keys())
67 | num_buckets = len(self.aspect_ratio) * 3 # in each aspect-ratio, we have three scales
68 | self._buckets = [[] for _ in range(num_buckets)]
69 | self.debug = debug
70 | self.seed = seed
71 | if type(dataset) == ImageDataset:
72 | self.batch_factor = 0.6 # A100 config
73 | else:
74 | raise NotImplementedError
75 | # Buckets are indexed by (aspect-ratio id, scale id); each aspect ratio in
76 | # dynamic_resolution_h_w has three scales, hence len(aspect_ratio) * 3 buckets.
77 |
78 | def closest_id(self, v, type="width"):
79 | if type == "width":
80 | dist = np.array([abs(v - self.width[i]) for i in range(len(self.width))])
81 | return np.argmin(dist)
82 | else:
83 | dist = np.array([abs(v - self.aspect_ratio[i]) for i in range(len(self.aspect_ratio))])
84 | return np.argmin(dist)
85 |
86 | def __len__(self):
87 | return len(self.dataset) // self.batch_size ### an approximate value
88 |
89 |
90 | def collate_func(self, batch_list):
91 | batch = batch_list[0]
92 |
93 | return {
94 | "image": torch.stack([d["image"] for d in batch], dim=0),
95 | "label": [d["label"] for d in batch],
96 | "path": [d["path"] for d in batch],
97 | "height": [d["height"] for d in batch],
98 | "width": [d["width"] for d in batch],
99 | "type": [d["type"] for d in batch],
100 | }
101 |
    | # NOTE: uses self.width / self.max_width, which this class never defines; the
    | # method is only reached when train=True, whereas this wrapper is used for inference.
102 | def get_batch_size(self, w_id, ar_id):
103 | sel_ar = self.aspect_ratio[ar_id]
104 | sel_w = self.width[w_id]
105 | dynamic_batch_size = int(self.batch_size / sel_ar / ((sel_w / self.max_width)**2) / self.batch_factor)
106 | return max(dynamic_batch_size, 1)
107 |
108 | def get_ar_w_new(self, w, ar, strategy="closest"):
109 | if strategy == "closest":
110 | # get new ar
111 | dist = np.array([abs(ar - self.aspect_ratio[i]) for i in range(len(self.aspect_ratio))])
112 | ar_id = np.argmin(dist)
113 | ar_new = self.aspect_ratio[ar_id]
114 | # get new w
115 | w_list = list(dynamic_resolution_h_w[ar_new].keys())
116 | dist = np.array([abs(w - w_list[i]) for i in range(len(w_list))])
117 | w_id = np.argmin(dist)
118 | w_new = w_list[w_id]
119 | h_new = dynamic_resolution_h_w[ar_new][w_new]["pixel"][0]
120 | return w_new, h_new, w_id, ar_id
121 | elif strategy == "random":
122 | raise NotImplementedError
123 |
124 | def get_aug(self, w, h, strategy="closest"):
125 | sel_h, sel_w = h, w
126 | assert sel_h % 8 == 0 and sel_w % 8 == 0
127 | aug_shape = (sel_h, sel_w)
128 | if strategy == "closest":
129 | aug = transforms.Resize(aug_shape) # resize to the closest size
130 | elif strategy == "random":
131 | min_edge = min(sel_w, sel_h)
132 | aug = transforms.Compose([
133 | transforms.Resize(min_edge),
134 | transforms.CenterCrop((sel_h, sel_w)),
135 | ])
136 | return aug
137 |
138 | def __getitem__(self, idx):
139 | d = self.dataset.__getitem__(idx)
140 | w, h = d["width"], d["height"]
141 | ar = h / w
142 | strategy = "random" if random.random() < self.random_bucket_ratio else "closest"
143 | w_new, h_new, w_id, ar_id = self.get_ar_w_new(w, ar, strategy=strategy)
144 | aug = self.get_aug(w_new, h_new, strategy=strategy)
145 | images = d["image"]
146 | assert images.ndim == 3
147 | d["image"] = aug(images)
148 | assert (d["image"].shape[1] % 8 == 0) and (d["image"].shape[2] % 8 == 0)
149 |
150 | return [d]
151 |
152 | def __iter__(self):
153 | # if not self.debug:
154 | while True:
155 | if self.train:
156 | idx = random.randint(0, self.dataset.__len__()-1)
157 | else:
158 | idx = self._idx
159 | self._idx = (self._idx + 1) % self.dataset.__len__()
160 |
161 | d = self.dataset.__getitem__(idx)
162 | w, h = d["width"], d["height"]
163 | ar = h / w
164 | strategy = "random" if random.random() < self.random_bucket_ratio else "closest"
165 | w_new, h_new, w_id, ar_id = self.get_ar_w_new(w, ar, strategy=strategy)
166 | aug = self.get_aug(w_new, h_new, strategy=strategy)
167 | images = d["image"]
168 | assert images.ndim == 3
169 |
170 | d["image"] = aug(images)
171 | assert (d["image"].shape[1] % 8 == 0) and (d["image"].shape[2] % 8 == 0)
172 |
173 | bucket_id = ar_id * 3 + w_id # 3 scales per aspect ratio (matches num_buckets above); TODO: avoid hard-coding
174 | bucket = self._buckets[bucket_id]
175 | bucket.append(d)
176 | target_batch_size = self.get_batch_size(w_id, ar_id) if self.train else self.batch_size
177 | if len(bucket) == target_batch_size:
178 | data = bucket[:]
179 | # Clear bucket first, because code after yield is not
180 | # guaranteed to execute
181 | del bucket[:]
182 | yield data
183 |
184 |
185 | class AspectRatioGroupedDataset(data.IterableDataset):
186 | """
187 | Batch data that have similar aspect ratios together.
188 | In this implementation, each image is assigned to a bucket based on its
189 | closest predefined width and aspect ratio, and images are batched within a bucket.
190 | This improves training speed because the images then need little or no padding
191 | to form a batch.
192 |
193 | It assumes the underlying dataset produces dicts with "width" and "height" keys.
194 | It will then produce a list of original dicts with length = batch_size,
195 | all with similar aspect ratios.
196 | """
197 |
198 | def __init__(self,
199 | dataset, batch_size,
200 | width=[256, 320, 384, 448, 512], # A100 config
201 | aspect_ratio=[4/16, 6/16, 8/16, 9/16, 10/16, 12/16, 14/16, 1, 16/14, 16/12, 16/10, 16/9, 16/8, 16/6, 16/4], # A100 config
202 | max_resolution=512*512, # A100 config
203 | debug=False, seed=0, train=True, random_bucket_ratio=0.
204 | ):
205 | """
206 | Args:
207 | dataset: an iterable. Each element must be a dict with keys
208 | "width" and "height", which will be used to batch data.
209 | batch_size (int): base batch size; during training the per-bucket batch size is derived from it (see get_batch_size).
210 | """
211 | self.train = train
212 | self._idx = 0
213 | self.dataset = dataset
214 | self.batch_size = batch_size
215 | self.random_bucket_ratio = random_bucket_ratio
216 | num_buckets = len(width) * len(aspect_ratio)
217 | self.width = width
218 | self.max_width = max(width)
219 | self.aspect_ratio = aspect_ratio
220 | self._buckets = [[] for _ in range(num_buckets)]
221 | self.debug = debug
222 | self.seed = seed
223 | self.max_resolution = max_resolution
224 | if type(dataset) == ImageDataset:
225 | self.batch_factor = 0.6 # A100 config
226 | else:
227 | raise NotImplementedError
228 | # Buckets are indexed by (aspect-ratio id, width id), giving
229 | # len(aspect_ratio) * len(width) buckets in total.
230 |
231 | def closest_id(self, v, type="width"):
232 | if type == "width":
233 | dist = np.array([abs(v - self.width[i]) for i in range(len(self.width))])
234 | return np.argmin(dist)
235 | else:
236 | dist = np.array([abs(v - self.aspect_ratio[i]) for i in range(len(self.aspect_ratio))])
237 | return np.argmin(dist)
238 |
239 | def __len__(self):
240 | return len(self.dataset) // self.batch_size ### an approximate value
241 |
242 |
243 | def collate_func(self, batch_list):
244 | batch = batch_list[0]
245 |
246 | return {
247 | "image": torch.stack([d["image"] for d in batch], dim=0),
248 | "label": [d["label"] for d in batch],
249 | "path": [d["path"] for d in batch],
250 | "height": [d["height"] for d in batch],
251 | "width": [d["width"] for d in batch],
252 | "type": [d["type"] for d in batch],
253 | }
254 |
255 | def get_batch_size(self, w_id, ar_id):
256 | sel_ar = self.aspect_ratio[ar_id]
257 | sel_w = self.width[w_id]
258 | dynamic_batch_size = int(self.batch_size / sel_ar / ((sel_w / self.max_width)**2) / self.batch_factor)
259 | return max(dynamic_batch_size, 1)
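    | # Worked example (using the A100 defaults above): with, say, batch_size=32,
    | # sel_ar=1, sel_w=256, max_width=512 and batch_factor=0.6,
    | # dynamic_batch_size = int(32 / 1 / (256/512)**2 / 0.6) = int(213.33) = 213,
    | # i.e. lower-resolution buckets get proportionally larger batches.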
260 |
    | # Shrink the selected width until the resulting resolution fits within max_resolution.
261 | def memory_safety_guard(self, w_id, ar_id):
262 | while True:
263 | sel_ar = self.aspect_ratio[ar_id]
264 | sel_w = self.width[w_id]
265 | if self.max_resolution < 0 or (sel_ar * sel_w) * sel_w <= self.max_resolution:
266 | break
267 | else:
268 | w_id = w_id - 1
269 | return w_id, ar_id
270 |
271 | def get_ar_w_id(self, w, ar, strategy="closest"):
272 | if strategy == "closest":
273 | w_id = self.closest_id(w, type="width")
274 | ar_id = self.closest_id(ar, type="aspect_ratio")
275 | elif strategy == "random":
276 | h = w * ar
277 | ws = [_w for _w in self.width if _w <= w]
278 | _w = random.choice(ws) if len(ws) > 0 else self.width[0]
279 | w_id = self.width.index(_w)
280 | ars = [_ar for _ar in self.aspect_ratio if _ar * w < h]
281 | _ar = random.choice(ars) if len(ars) > 0 else self.aspect_ratio[0]
282 | ar_id = self.aspect_ratio.index(_ar)
283 | return self.memory_safety_guard(w_id, ar_id)
284 |
285 | def get_aug(self, w_id, ar_id, strategy="closest"):
286 | sel_ar = self.aspect_ratio[ar_id]
287 | sel_w = self.width[w_id]
288 | sel_h = int(sel_w * sel_ar)
289 | sel_h = (sel_h+4) - ((sel_h+4) % 8) # round to the nearest multiple of 8
290 | aug_shape = (sel_h, int(sel_w))
291 | if strategy == "closest":
292 | aug = transforms.Resize(aug_shape) # resize to the closest size
293 | elif strategy == "random":
294 | min_edge = min(sel_w, sel_h)
295 | aug = transforms.Compose([
296 | transforms.Resize(min_edge),
297 | transforms.CenterCrop((sel_h, sel_w)),
298 | ])
299 | return aug
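    | # Worked example: for sel_w=512 and sel_ar=9/16, sel_h = int(512 * 9/16) = 288,
    | # already a multiple of 8, so aug_shape = (288, 512). With the "closest" strategy
    | # the image is simply resized to that shape; with "random" it is resized so the
    | # short edge becomes min(512, 288) = 288 and then center-cropped to (288, 512).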
300 |
301 | def __iter__(self):
302 | # if not self.debug:
303 | while True:
304 | if self.train:
305 | idx = random.randint(0, self.dataset.__len__()-1)
306 | else:
307 | idx = self._idx
308 | self._idx = (self._idx + 1) % self.dataset.__len__()
309 |
310 | d = self.dataset.__getitem__(idx)
311 | w, h = d["width"], d["height"]
312 | ar = h / w
313 | strategy = "random" if random.random() < self.random_bucket_ratio else "closest"
314 | w_id, ar_id = self.get_ar_w_id(w, ar, strategy=strategy)
315 | aug = self.get_aug(w_id, ar_id, strategy=strategy)
316 | images = d["image"]
317 | assert images.ndim == 3
318 |
319 | d["image"] = aug(images)
320 | assert (d["image"].shape[1] % 8 == 0) and (d["image"].shape[2] % 8 == 0)
321 |
322 | bucket_id = ar_id * len(self.width) + w_id
323 | bucket = self._buckets[bucket_id]
324 | bucket.append(d)
325 | target_batch_size = self.get_batch_size(w_id, ar_id) if self.train else self.batch_size
326 | if len(bucket) == target_batch_size:
327 | data = bucket[:]
328 | # Clear bucket first, because code after yield is not
329 | # guaranteed to execute
330 | del bucket[:]
331 | yield data
332 |
333 |
334 | class ImageDataset(data.Dataset):
335 | """ Generic dataset for Images files stored in folders
336 | Returns BCHW Images in the range [-0.5, 0.5] """
337 |
338 | def __init__(self, data_folder, data_list, train=True, resolution=64, aug="resize"):
339 | """
340 | Args:
341 | data_folder: root folder containing the images.
342 | data_list: annotation file; each line is either
343 | "relative/path height width" or "relative/path label" (see __getitem__)
344 | """
345 | super().__init__()
346 | self.train = train
347 | self.data_folder = data_folder
348 | self.data_list = data_list
349 | self.resolution = resolution
350 |
351 | with open(self.data_list) as f:
352 | self.annotations = f.readlines()
353 |
354 | total_classes = 1000
355 | classes = []
356 | # from imagenet_stubs.imagenet_2012_labels import label_to_name
357 | # for i in range(total_classes):
358 | # classes.append(label_to_name(i))
359 |
360 | self.classes = classes
361 | self.class_to_label = {c: i for i, c in enumerate(self.classes)}
362 | self.label_to_class = {i: c for i, c in enumerate(self.classes)}
363 |
364 | crop_function = transforms.RandomCrop(resolution) if train else transforms.CenterCrop(resolution)
365 | flip_function = transforms.RandomHorizontalFlip() if train else transforms.Lambda(lambda x: x) # flip if train else no op
366 | if aug == "resizecrop":
367 | augmentations = transforms.Compose([
368 | transforms.Resize(min(resolution), interpolation=_pil_interp("bicubic")),
369 | crop_function,
370 | flip_function,
371 | transforms.ToTensor(),
372 | ])
373 | elif aug == "crop":
374 | augmentations = transforms.Compose([
375 | crop_function,
376 | flip_function,
377 | transforms.ToTensor(),
378 | ])
379 | elif aug == "keep":
380 | augmentations = transforms.Compose([
381 | flip_function,
382 | transforms.ToTensor(),
383 | ])
384 | else:
385 | raise NotImplementedError
386 |
387 | self.aug = aug
388 | self.augmentations = augmentations
389 |
390 |
391 | @property
392 | def n_classes(self):
393 | return len(self.classes)
394 |
395 | def __len__(self):
396 | return len(self.annotations)
397 |
398 | def __getitem__(self, idx):
399 | try:
400 | ann = self.annotations[idx].strip()
401 | try:
402 | img_path, height, width = ann.split()
403 | img_label = -1
404 |
405 | except:
406 | img_path, img_label = ann.split()
407 |
408 | full_img_path = os.path.join(self.data_folder, img_path)
409 |
410 | img = Image.open(full_img_path).convert('RGB')
411 | h, w = img.height, img.width
412 |
413 | img = self.augmentations(img) * 2.0 - 1.0
414 | if self.aug != "keep":
415 | assert img.shape[1] == self.resolution[0] and img.shape[2] == self.resolution[1]
416 |
417 | return {"image": img, "label": int(img_label), "path": img_path, "height": h, "width": w, "type": "image"}
418 | except Exception as e:
419 | print(f"Error in dataloader {e}")
420 | return self.__getitem__((idx+1) % self.__len__())
421 |
422 |
423 | class ImageData():
424 |
425 | def __init__(self, args, shuffle=True):
426 | super().__init__()
427 | self.args = args
428 | self.shuffle = shuffle
429 |
430 | @property
431 | def n_classes(self):
432 | dataset = self._dataset(True)
433 | return dataset[0].n_classes
434 |
435 | def _dataset(self, train):
436 | datasets = []
437 | for dataset, batch_size in zip(self.args.dataset_list, self.args.batch_size):
438 | dataset_path = DATASET_DICT[dataset]["dataset_path"]
439 | data_type = DATASET_DICT[dataset]["data_type"]
440 | train_label = DATASET_DICT[dataset]["train_label"]
441 | val_label = DATASET_DICT[dataset]["val_label"]
442 | if data_type == "image":
443 | dataset = ImageDataset(
444 | dataset_path, train_label if train else val_label, train=train, resolution=self.args.resolution, aug=self.args.dataaug
445 | )
446 |
447 | if self.args.multi_resolution:
448 | # assert len(self.args.data_path) == 1
449 | if train:
450 | dataset = AspectRatioGroupedDataset(
451 | dataset, batch_size=batch_size, debug=self.args.debug, train=train, random_bucket_ratio=self.args.random_bucket_ratio
452 | )
453 | else:
454 | dataset = DynamicAspectRatioGroupedDataset(
455 | dataset, batch_size=batch_size, debug=self.args.debug, train=train, random_bucket_ratio=self.args.random_bucket_ratio
456 | )
457 | datasets.append(dataset)
458 | return datasets
459 |
460 | def _dataloader(self, train):
461 | dataset = self._dataset(train)
462 | # print(self.args.batch_size)
463 | if isinstance(self.args.batch_size, int):
464 | self.args.batch_size = [self.args.batch_size]
465 |
466 | assert len(dataset) == len(self.args.batch_size)
467 | dataloaders = []
468 | for dset, d_batch_size in zip(dataset, self.args.batch_size):
469 | if dist.is_initialized():
470 | sampler = data.distributed.DistributedSampler(
471 | dset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
472 | )
473 | global_rank = dist.get_rank()
474 | else:
475 | sampler = None
476 | global_rank = None
477 |
478 | def seed_worker(worker_id):
479 | if global_rank is not None:  # give each (rank, worker) pair a distinct seed
480 | seed = self.args.num_workers * global_rank + worker_id
481 | else:
482 | seed = worker_id
483 | # print(f"Setting dataloader worker {worker_id} on GPU {global_rank} as seed {seed}")
484 | torch.manual_seed(seed)
485 | np.random.seed(seed)
486 | random.seed(seed)
487 |
488 | dataloader = data.DataLoader(
489 | dset,
490 | batch_size=d_batch_size if not self.args.multi_resolution else 1,
491 | num_workers=self.args.num_workers,
492 | pin_memory=True,
493 | worker_init_fn=seed_worker,
494 | sampler=sampler if not isinstance(dset, data.IterableDataset) else None,
495 | collate_fn=dset.collate_func if hasattr(dset, "collate_func") else None,
496 | shuffle=sampler is None and train,
497 | drop_last=True,
498 | )
499 |
500 | dataloaders.append(dataloader)
501 |
502 | return dataloaders
503 |
504 | def train_dataloader(self):
505 | return self._dataloader(True)
506 |
507 | def val_dataloader(self):
508 | return self._dataloader(False)[0]
509 |
510 | def test_dataloader(self):
511 | return self.val_dataloader()
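    | # Hedged usage sketch (args come from add_data_specific_args below; the real
    | # wiring lives in the training script, which is assumed here, not shown):
    | #
    | #     data = ImageData(args)
    | #     train_loaders = data.train_dataloader()   # one DataLoader per entry in args.dataset_list
    | #     val_loader = data.val_dataloader()        # DataLoader for the first dataset only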
512 |
513 |
514 | @staticmethod
515 | def add_data_specific_args(parent_parser):
516 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
517 | parser.add_argument('--data_path', type=str, nargs="+", default=[""])
518 | parser.add_argument('--data_type', type=str, nargs="+", default=[""])
519 | parser.add_argument('--dataset_list', type=str, nargs="+", default=[''])
520 | parser.add_argument('--dataaug', type=str, choices=["resize", "resizecrop", "crop", "keep"])  # note: ImageDataset implements only resizecrop/crop/keep
521 | parser.add_argument('--multi_resolution', action="store_true")
522 | parser.add_argument('--random_bucket_ratio', type=float, default=0.)
523 |
524 | parser.add_argument('--resolution', type=int, nargs="+", default=[512])
525 | parser.add_argument('--batch_size', type=int, nargs="+", default=[32])
526 | parser.add_argument('--num_workers', type=int, default=8)
527 | parser.add_argument('--image_channels', type=int, default=3)
528 |
529 | return parser
530 |
--------------------------------------------------------------------------------
/bitvae/modules/quantizer/multiscale_bsq.py:
--------------------------------------------------------------------------------
1 | """
2 | Binary Spherical Quantization
3 | Proposed in https://arxiv.org/abs/2406.07548
4 |
5 | In the simplest setup, each dimension is quantized into {-1, 1}.
6 | An entropy penalty is used to encourage utilization.
7 | """
8 |
9 | import random
10 | from math import log2, ceil
11 | from functools import partial, cache
12 | from collections import namedtuple
13 | from contextlib import nullcontext
14 |
15 | import torch.distributed as dist
16 | from torch.distributed import nn as dist_nn
17 |
18 | import torch
19 | from torch import nn, einsum
20 | import torch.nn.functional as F
21 | from torch.nn import Module
22 | from torch.amp import autocast
23 |
24 | from einops import rearrange, reduce, pack, unpack
25 |
26 | from einx import get_at
27 |
28 | from .dynamic_resolution import predefined_HW_Scales_dynamic
29 |
30 | # constants
31 |
32 | Return = namedtuple('Return', ['quantized', 'indices', 'bit_indices', 'entropy_aux_loss'])
33 |
34 | LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
35 |
36 | # distributed helpers
37 |
38 | @cache
39 | def is_distributed():
40 | return dist.is_initialized() and dist.get_world_size() > 1
41 |
42 | def maybe_distributed_mean(t):
43 | if not is_distributed():
44 | return t
45 |
46 | dist_nn.all_reduce(t)
47 | t = t / dist.get_world_size()
48 | return t
49 |
50 | # helper functions
51 |
52 | def exists(v):
53 | return v is not None
54 |
55 | def identity(t):
56 | return t
57 |
58 | def default(*args):
59 | for arg in args:
60 | if exists(arg):
61 | return arg() if callable(arg) else arg
62 | return None
63 |
64 | def round_up_multiple(num, mult):
65 | return ceil(num / mult) * mult
66 |
67 | def pack_one(t, pattern):
68 | return pack([t], pattern)
69 |
70 | def unpack_one(t, ps, pattern):
71 | return unpack(t, ps, pattern)[0]
72 |
73 | def l2norm(t):
74 | return F.normalize(t, dim = -1)
75 |
76 | # entropy
77 |
78 | def log(t, eps = 1e-5):
79 | return t.clamp(min = eps).log()
80 |
81 | def entropy(prob):
82 | return (-prob * log(prob)).sum(dim=-1)
83 |
84 | # cosine sim linear
85 |
86 | class CosineSimLinear(Module):
87 | def __init__(
88 | self,
89 | dim_in,
90 | dim_out,
91 | scale = 1.
92 | ):
93 | super().__init__()
94 | self.scale = scale
95 | self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
96 |
97 | def forward(self, x):
98 | x = F.normalize(x, dim = -1)
99 | w = F.normalize(self.weight, dim = 0)
100 | return (x @ w) * self.scale
101 |
102 |
103 | def get_latent2scale_schedule(T: int, H: int, W: int, mode="original"):
104 | assert mode in ["original", "dynamic", "dense", "same1", "same2", "same3", "half", "dense_f8"]
105 | predefined_HW_Scales = {
106 | # 256 * 256
107 | (32, 32): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 9), (13, 13), (18, 18), (24, 24), (32, 32)],
108 | (16, 16): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)],
109 | # 1024x1024
110 | (64, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (7, 7), (9, 9), (12, 12), (16, 16), (21, 21), (27, 27), (36, 36), (48, 48), (64, 64)],
111 |
112 | (36, 64): [(1, 1), (2, 2), (3, 3), (4, 4), (6, 6), (9, 12), (13, 16), (18, 24), (24, 32), (32, 48), (36, 64)],
113 | }
114 | if mode == "dynamic":
115 | predefined_HW_Scales.update(predefined_HW_Scales_dynamic)
116 | elif mode == "dense":
117 | predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16+1)]
118 | predefined_HW_Scales[(32, 32)] = predefined_HW_Scales[(16, 16)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
119 | predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)]
120 | elif mode == "dense_f8":
121 | # predefined_HW_Scales[(16, 16)] = [(x, x) for x in range(1, 16+1)]
122 | predefined_HW_Scales[(32, 32)] = [(x, x) for x in range(1, 16+1)] + [(20, 20), (24, 24), (28, 28), (32, 32)]
123 | predefined_HW_Scales[(64, 64)] = predefined_HW_Scales[(32, 32)] + [(40, 40), (48, 48), (56, 56), (64, 64)]
124 | predefined_HW_Scales[(128, 128)] = predefined_HW_Scales[(64, 64)] + [(80, 80), (96, 96), (112, 112), (128, 128)]
125 | elif mode.startswith("same"):
126 | num_quant = int(mode[len("same"):])
127 | predefined_HW_Scales[(16, 16)] = [(16, 16) for _ in range(num_quant)]
128 | predefined_HW_Scales[(32, 32)] = [(32, 32) for _ in range(num_quant)]
129 | predefined_HW_Scales[(64, 64)] = [(64, 64) for _ in range(num_quant)]
130 | elif mode == "half":
131 | predefined_HW_Scales[(32, 32)] = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (8, 8), (10, 10), (13, 13), (16, 16)]
132 | predefined_HW_Scales[(64, 64)] = [(1,1),(2,2),(4,4),(6,6),(8,8),(12,12),(16,16)]
133 |
134 | predefined_T_Scales = [1, 2, 3, 4, 5, 6, 7, 9, 11, 13, 15, 17, 17, 17, 17, 17]
135 | patch_THW_shape_per_scale = predefined_HW_Scales[(H, W)]
136 | if len(predefined_T_Scales) < len(patch_THW_shape_per_scale):
137 | # print("warning: the length of predefined_T_Scales is less than the length of patch_THW_shape_per_scale!")
138 | predefined_T_Scales += [predefined_T_Scales[-1]] * (len(patch_THW_shape_per_scale) - len(predefined_T_Scales))
139 | patch_THW_shape_per_scale = [(min(T, t), h, w ) for (h, w), t in zip(patch_THW_shape_per_scale, predefined_T_Scales[:len(patch_THW_shape_per_scale)])]
140 | return patch_THW_shape_per_scale
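    | # Example: get_latent2scale_schedule(T=1, H=16, W=16, mode="original") returns
    | # [(1,1,1), (1,2,2), (1,3,3), (1,4,4), (1,5,5), (1,6,6), (1,8,8), (1,10,10), (1,13,13), (1,16,16)],
    | # i.e. the predefined (16, 16) ladder with every temporal extent clamped to T=1.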
141 |
142 |
143 | class MultiScaleBSQ(Module):
144 | """ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """
145 |
146 | def __init__(
147 | self,
148 | *,
149 | dim,
150 | soft_clamp_input_value = None,
151 | aux_loss = False, # intermediate auxiliary loss
152 | use_decay_factor=False,
153 | use_stochastic_depth=False,
154 | drop_rate=0.,
155 | schedule_mode="original", # ["original", "dynamic", "dense"]
156 | keep_first_quant=False,
157 | keep_last_quant=False,
158 | remove_residual_detach=False,
159 | random_flip = False,
160 | flip_prob = 0.5,
161 | flip_mode = "stochastic", # "stochastic", "deterministic"
162 | max_flip_lvl = 1,
163 | random_flip_1lvl = False, # random flip one level each time
164 | flip_lvl_idx = None,
165 | drop_when_test=False,
166 | drop_lvl_idx=None,
167 | drop_lvl_num=0,
168 | random_short_schedule = False, # randomly use short schedule (schedule for images of 256x256)
169 | short_schedule_prob = 0.5,
170 | disable_flip_prob = 0.0, # disable random flip in this image
171 | uniform_short_schedule = False,
172 | **kwargs
173 | ):
174 | super().__init__()
175 | codebook_dim = dim
176 |
177 | requires_projection = codebook_dim != dim
178 | self.project_in = nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
179 | self.project_out = nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
180 | self.has_projections = requires_projection
181 | self.layernorm = nn.Identity()
182 | self.use_stochastic_depth = use_stochastic_depth
183 | self.drop_rate = drop_rate
184 | self.remove_residual_detach = remove_residual_detach
185 | self.random_flip = random_flip
186 | self.flip_prob = flip_prob
187 | self.flip_mode = flip_mode
188 | self.max_flip_lvl = max_flip_lvl
189 | self.random_flip_1lvl = random_flip_1lvl
190 | self.flip_lvl_idx = flip_lvl_idx
191 | assert not (random_flip and random_flip_1lvl) # the two flip modes are mutually exclusive
192 | self.disable_flip_prob = disable_flip_prob
193 |
194 | self.drop_when_test = drop_when_test
195 | self.drop_lvl_idx = drop_lvl_idx
196 | self.drop_lvl_num = drop_lvl_num
197 | if self.drop_when_test:
198 | assert drop_lvl_idx is not None
199 | assert drop_lvl_num > 0
200 | self.random_short_schedule = random_short_schedule
201 | self.short_schedule_prob = short_schedule_prob
202 | self.full2short = {7:7, 10:7, 13:7, 16:16, 20:16, 24:16}
203 | self.full2short_f8 = {20:20, 24:20, 28:20}
204 | self.uniform_short_schedule = uniform_short_schedule
205 | assert not (self.random_short_schedule and self.uniform_short_schedule)
206 |
207 | self.lfq = BSQ(
208 | dim = codebook_dim,
209 | codebook_scale = 1,
210 | soft_clamp_input_value = soft_clamp_input_value,
211 | **kwargs
212 | )
213 |
214 | self.z_interplote_up = 'trilinear'
215 | self.z_interplote_down = 'area'
216 |
217 | self.use_decay_factor = use_decay_factor
218 | self.schedule_mode = schedule_mode
219 | self.keep_first_quant = keep_first_quant
220 | self.keep_last_quant = keep_last_quant
221 | if self.use_stochastic_depth and self.drop_rate > 0:
222 | assert self.keep_first_quant or self.keep_last_quant
223 |
224 | @property
225 | def codebooks(self):
226 | return self.lfq.codebook
227 |
228 | def get_codes_from_indices(self, indices_list):
229 | all_codes = []
230 | for indices in indices_list:
231 | codes = self.lfq.indices_to_codes(indices)
232 | all_codes.append(codes)
233 | _, _, T, H, W = all_codes[-1].size()
234 | summed_codes = 0
235 | for code in all_codes:
236 | summed_codes += F.interpolate(code, size=(T, H, W), mode=self.z_interplote_up)
237 | return summed_codes
238 |
239 | def get_output_from_indices(self, indices):
240 | codes = self.get_codes_from_indices(indices)
241 | codes_summed = reduce(codes, 'q ... -> ...', 'sum')
242 | return self.project_out(codes_summed)
243 |
244 | def flip_quant(self, x):
245 | if self.flip_mode == 'stochastic':
246 | flip_mask = torch.rand_like(x) < self.flip_prob
247 | else:
248 | raise NotImplementedError
249 | x = x.clone()
250 | x[flip_mask] = -x[flip_mask]
251 | return x
252 |
253 | def forward(
254 | self,
255 | x,
256 | mask = None,
257 | return_all_codes = False,
258 | ):
259 | if x.ndim == 4:
260 | x = x.unsqueeze(2)
261 | B, C, T, H, W = x.size()
262 |
263 | if self.schedule_mode.startswith("same"):
264 | scale_num = int(self.schedule_mode[len("same"):])
265 | assert T == 1
266 | scale_schedule = [(1, H, W)] * scale_num
267 | else:
268 | scale_schedule = get_latent2scale_schedule(T, H, W, mode=self.schedule_mode)
269 | scale_num = len(scale_schedule)
270 |
271 | if self.uniform_short_schedule:
272 | scale_num_short = self.full2short_f8[scale_num] if self.schedule_mode == "dense_f8" else self.full2short[scale_num]
273 | scale_num = random.randint(scale_num_short, scale_num)
274 | scale_schedule = scale_schedule[:scale_num]
275 | elif self.random_short_schedule and random.random() < self.short_schedule_prob:
276 | if self.schedule_mode == "dense_f8":
277 | scale_num = self.full2short_f8[scale_num]
278 | else:
279 | scale_num = self.full2short[scale_num]
280 | scale_schedule = scale_schedule[:scale_num]
281 |
282 | # x = self.project_in(x)
283 | x = x.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c)
284 | x = self.project_in(x)
285 | x = x.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w)
286 | x = self.layernorm(x)
287 |
288 | quantized_out = 0.
289 | residual = x
290 |
291 | all_losses = []
292 | all_indices = []
293 | all_bit_indices = []
294 |
295 | # go through the layers
296 | out_fact = init_out_fact = 1.0
297 | # residual_list = []
298 | # interpolate_residual_list = []
299 | # quantized_list = []
300 | if self.drop_when_test:
301 | drop_lvl_start = self.drop_lvl_idx
302 | drop_lvl_end = self.drop_lvl_idx + self.drop_lvl_num
303 | disable_flip = random.random() < self.disable_flip_prob # disable random flip for this image
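    | # Multi-scale residual quantization (summary of the loop below): at each scale
    | # (pt, ph, pw) the running residual is downsampled, quantized by BSQ, upsampled
    | # back to (T, H, W), subtracted from the residual, and accumulated into
    | # quantized_out, so coarse scales capture global structure and later scales
    | # refine the remaining detail.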
304 | with autocast('cuda', enabled = False):
305 | for si, (pt, ph, pw) in enumerate(scale_schedule):
306 |
307 | out_fact = max(0.1, out_fact) if self.use_decay_factor else init_out_fact
308 | if (pt, ph, pw) != (T, H, W):
309 | interpolate_residual = F.interpolate(residual, size=(pt, ph, pw), mode=self.z_interplote_down)
310 | else:
311 | interpolate_residual = residual
312 | # residual_list.append(torch.norm(residual.detach(), dim=1).mean())
313 | # interpolate_residual_list.append(torch.norm(interpolate_residual.detach(), dim=1).mean())
314 | if self.training and self.use_stochastic_depth and random.random() < self.drop_rate:
315 | if (si == 0 and self.keep_first_quant) or (si == scale_num - 1 and self.keep_last_quant):
316 | quantized, indices, bit_indices, loss = self.lfq(interpolate_residual)
317 | if self.random_flip and si < self.max_flip_lvl and (not disable_flip):
318 | quantized = self.flip_quant(quantized)
319 | quantized = quantized * out_fact
320 | all_indices.append(indices)
321 | all_losses.append(loss)
322 | all_bit_indices.append(bit_indices)
323 | else:
324 | quantized = torch.zeros_like(interpolate_residual)
325 | elif self.drop_when_test and drop_lvl_start <= si < drop_lvl_end:
326 | continue
327 | else:
328 | # residual_norm = torch.norm(interpolate_residual.detach(), dim=1) # (b, t, h, w)
329 | # print(si, residual_norm.min(), residual_norm.max(), residual_norm.mean())
330 | quantized, indices, bit_indices, loss = self.lfq(interpolate_residual)
331 | if self.random_flip and si < self.max_flip_lvl and (not disable_flip):
332 | quantized = self.flip_quant(quantized)
333 | if self.random_flip_1lvl and si == self.flip_lvl_idx and (not disable_flip):
334 | quantized = self.flip_quant(quantized)
335 | quantized = quantized * out_fact
336 | all_indices.append(indices)
337 | all_losses.append(loss)
338 | all_bit_indices.append(bit_indices)
339 | # quantized_list.append(torch.norm(quantized.detach(), dim=1).mean())
340 | if (pt, ph, pw) != (T, H, W):
341 | quantized = F.interpolate(quantized, size=(T, H, W), mode=self.z_interplote_up).contiguous()
342 |
343 | if self.remove_residual_detach:
344 | residual = residual - quantized
345 | else:
346 | residual = residual - quantized.detach()
347 | quantized_out = quantized_out + quantized
348 |
349 | if self.use_decay_factor:
350 | out_fact -= 0.1
351 | # print("residual_list:", residual_list)
352 | # print("interpolate_residual_list:", interpolate_residual_list)
353 | # print("quantized_list:", quantized_list)
354 | # import ipdb; ipdb.set_trace()
355 | # project out, if needed
356 | quantized_out = quantized_out.permute(0, 2, 3, 4, 1).contiguous() # (b, c, t, h, w) => (b, t, h, w, c)
357 | quantized_out = self.project_out(quantized_out)
358 | quantized_out = quantized_out.permute(0, 4, 1, 2, 3).contiguous() # (b, t, h, w, c) => (b, c, t, h, w)
359 |
360 | # image
361 | if quantized_out.size(2) == 1:
362 | quantized_out = quantized_out.squeeze(2)
363 |
364 | # stack all losses and indices
365 |
366 | all_losses = torch.stack(all_losses, dim = -1)
367 |
368 | ret = (quantized_out, all_indices, all_bit_indices, all_losses)
369 |
370 | if not return_all_codes:
371 | return ret
372 |
373 | # whether to return all codes from all codebooks across layers
374 | all_codes = self.get_codes_from_indices(all_indices)
375 |
376 | # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
377 |
378 | return (*ret, all_codes)
379 |
380 |
381 | class BSQ(Module):
382 | def __init__(
383 | self,
384 | *,
385 | dim = None,
386 | entropy_loss_weight = 0.1,
387 | commitment_loss_weight = 0.25,
388 | diversity_gamma = 1.,
389 | straight_through_activation = nn.Identity(),
390 | num_codebooks = 1,
391 | keep_num_codebooks_dim = None,
392 | codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
393 | frac_per_sample_entropy = 1., # make less than 1. to only use a random fraction of the probs for per sample entropy
394 | has_projections = None,
395 | projection_has_bias = True,
396 | soft_clamp_input_value = None,
397 | cosine_sim_project_in = False,
398 | cosine_sim_project_in_scale = None,
399 | channel_first = None,
400 | experimental_softplus_entropy_loss = False,
401 | entropy_loss_offset = 5., # how much to shift the loss before softplus
402 | spherical = True, # from https://arxiv.org/abs/2406.07548
403 | force_quantization_f32 = True, # will force the quantization step to be full precision
404 | inv_temperature = 100.0,
405 | gamma0=1.0, gamma=1.0, zeta=1.0,
406 | new_quant = False, # new quant function,
407 | use_out_phi = False, # use output phi network
408 | use_out_phi_res = False, # residual out phi
409 | ):
410 | super().__init__()
411 |
412 | # some assert validations
413 | assert exists(dim), 'dim must be specified for BSQ'
414 |
415 | codebook_dim = dim
416 | codebook_dims = codebook_dim * num_codebooks
417 | dim = default(dim, codebook_dims)
418 | self.codebook_dims = codebook_dims
419 |
420 | has_projections = default(has_projections, dim != codebook_dims)
421 |
422 | if cosine_sim_project_in:
423 | cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
424 | project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
425 | else:
426 | project_in_klass = partial(nn.Linear, bias = projection_has_bias)
427 |
428 | self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity() # nn.Identity()
429 | self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity() # nn.Identity()
430 | self.has_projections = has_projections
431 |
432 | self.out_phi = nn.Linear(codebook_dims, codebook_dims) if use_out_phi else nn.Identity()
433 | self.use_out_phi_res = use_out_phi_res
434 | if self.use_out_phi_res:
435 | self.out_phi_scale = nn.Parameter(torch.zeros(codebook_dims), requires_grad=True) # init as zero
436 |
437 | self.dim = dim
438 | self.codebook_dim = codebook_dim
439 | self.num_codebooks = num_codebooks
440 |
441 | keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
442 | assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
443 | self.keep_num_codebooks_dim = keep_num_codebooks_dim
444 |
445 | # channel first
446 |
447 | self.channel_first = channel_first
448 |
449 | # straight through activation
450 |
451 | self.activation = straight_through_activation
452 |
453 | # For BSQ (binary spherical quantization)
454 | if not spherical:
455 | raise ValueError("For BSQ, spherical must be True.")
456 | self.persample_entropy_compute = 'analytical'
457 | self.inv_temperature = inv_temperature
458 | self.gamma0 = gamma0 # weight for the per-sample entropy term
459 | self.gamma = gamma # weight for the codebook (batch-average) entropy term
460 | self.zeta = zeta # weight for the entire entropy penalty
461 | self.new_quant = new_quant
462 |
463 | # entropy aux loss related weights
464 |
465 | assert 0 < frac_per_sample_entropy <= 1.
466 | self.frac_per_sample_entropy = frac_per_sample_entropy
467 |
468 | self.diversity_gamma = diversity_gamma
469 | self.entropy_loss_weight = entropy_loss_weight
470 |
471 | # codebook scale
472 |
473 | self.codebook_scale = codebook_scale
474 |
475 | # commitment loss
476 |
477 | self.commitment_loss_weight = commitment_loss_weight
478 |
479 | # whether to soft clamp the input value from -value to value
480 |
481 | self.soft_clamp_input_value = soft_clamp_input_value
482 | assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
483 |
484 | # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
485 |
486 | self.entropy_loss_offset = entropy_loss_offset
487 | self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
488 |
489 | # for no auxiliary loss, during inference
490 |
491 | self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
492 | self.register_buffer('zero', torch.tensor(0.), persistent = False)
493 |
494 | # whether to force quantization step to be f32
495 |
496 | self.force_quantization_f32 = force_quantization_f32
497 |
498 | def bits_to_codes(self, bits):
499 | return bits * self.codebook_scale * 2 - self.codebook_scale
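    | # Maps bits in {0, 1} to codes in {-codebook_scale, +codebook_scale};
    | # e.g. with codebook_scale=1, bits [0, 1, 1] -> codes [-1, 1, 1].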
500 |
501 | # @property
502 | # def dtype(self):
503 | # return self.codebook.dtype
504 |
505 | def indices_to_codes(
506 | self,
507 | indices,
508 | label_type = 'int_label',
509 | project_out = True
510 | ):
511 | assert label_type in ['int_label', 'bit_label']
512 | is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
513 | should_transpose = default(self.channel_first, is_img_or_video)
514 |
515 | if not self.keep_num_codebooks_dim:
516 | if label_type == 'int_label':
517 | indices = rearrange(indices, '... -> ... 1')
518 | else:
519 | indices = indices.unsqueeze(-2)
520 |
521 | # indices to codes, which are bits of either -1 or 1
522 |
523 | if label_type == 'int_label':
524 | assert indices[..., None].int().min() > 0
525 | bits = ((indices[..., None].int() & self.mask) != 0).float() # .to(self.dtype)
526 | else:
527 | bits = indices
528 |
529 | codes = self.bits_to_codes(bits)
530 |
531 | codes = l2norm(codes) # must normalize when using BSQ
532 |
533 | codes = rearrange(codes, '... c d -> ... (c d)')
534 |
535 | # whether to project codes out to original dimensions
536 | # if the input feature dimensions were not log2(codebook size)
537 |
538 | if project_out:
539 | codes = self.project_out(codes)
540 |
541 | # rearrange codes back to original shape
542 |
543 | if should_transpose:
544 | codes = rearrange(codes, 'b ... d -> b d ...')
545 |
546 | return codes
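    | # Worked example (assuming codebook_dim=4, codebook_scale=1, no projection):
    | # int_label 5 -> bits [0, 1, 0, 1] -> codes [-1, 1, -1, 1] -> l2norm
    | # -> [-0.5, 0.5, -0.5, 0.5], i.e. a point on the unit sphere with +-1/sqrt(d) entries.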
547 |
548 | def quantize(self, z):
549 | assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
550 |
551 | zhat = torch.where(z > 0,
552 | torch.tensor(1, dtype=z.dtype, device=z.device),
553 | torch.tensor(-1, dtype=z.dtype, device=z.device))
554 | return z + (zhat - z).detach()
555 |
556 | def quantize_new(self, z):
557 | assert z.shape[-1] == self.codebook_dims, f"Expected {self.codebook_dims} dimensions, got {z.shape[-1]}"
558 |
559 | zhat = torch.where(z > 0,
560 | torch.tensor(1, dtype=z.dtype, device=z.device),
561 | torch.tensor(-1, dtype=z.dtype, device=z.device))
562 |
563 | q_scale = 1. / (self.codebook_dims ** 0.5)
564 | zhat = q_scale * zhat # on unit sphere
565 |
566 | return z + (zhat - z).detach()
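    | # Both quantize() and quantize_new() use the straight-through estimator:
    | # the forward value is the (optionally rescaled) sign of z, while
    | # z + (zhat - z).detach() makes the backward pass behave like the identity,
    | # so gradients flow to z unchanged. quantize_new additionally divides by
    | # sqrt(codebook_dims) so that the quantized vector lies on the unit sphere.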
567 |
568 | def soft_entropy_loss(self, z):
569 | if self.persample_entropy_compute == 'analytical':
570 | # if self.l2_norm:
571 | p = torch.sigmoid(-4 * z / (self.codebook_dims ** 0.5) * self.inv_temperature)
572 | # else:
573 | # p = torch.sigmoid(-4 * z * self.inv_temperature)
574 | prob = torch.stack([p, 1-p], dim=-1) # (b, h, w, 18, 2)
575 | per_sample_entropy = self.get_entropy(prob, dim=-1, normalize=False).sum(dim=-1).mean() # (b,h,w,18)->(b,h,w)->scalar
576 | else:
577 | raise NotImplementedError(f"persample_entropy_compute='{self.persample_entropy_compute}' is not supported")
578 |
579 | # macro average of the probability of each subgroup
580 | avg_prob = reduce(prob, '... g d ->g d', 'mean') # (18, 2)
581 | codebook_entropy = self.get_entropy(avg_prob, dim=-1, normalize=False)
582 |
583 | # the approximation of the entropy is the sum of the entropy of each subgroup
584 | return per_sample_entropy, codebook_entropy.sum(), avg_prob
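    | # In forward() the two terms are combined as
    | #     entropy_penalty = gamma0 * per_sample_entropy - gamma * codebook_entropy,
    | # so minimizing the loss drives per-sample entropy down (confident, near-binary
    | # bits) while pushing the batch-averaged codebook entropy up (uniform bit usage).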
585 |
586 | def get_entropy(self, count, dim=-1, eps=1e-4, normalize=True):
587 | if normalize: # normalize raw counts into probabilities
588 | probs = (count + eps) / (count + eps).sum(dim=dim, keepdim=True)
589 | else: # callers in this file pass normalize=False, so `count` is already a probability
590 | probs = count
591 | H = -(probs * torch.log(probs + 1e-8)).sum(dim=dim)
592 | return H
593 |
594 | def forward(
595 | self,
596 | x,
597 | return_loss_breakdown = False,
598 | mask = None,
599 | entropy_weight=0.1
600 | ):
601 | """
602 | einstein notation
603 | b - batch
604 | n - sequence (or flattened spatial dimensions)
605 | d - feature dimension, which is also log2(codebook size)
606 | c - number of codebooks
607 | """
608 |
609 | is_img_or_video = x.ndim >= 4
610 | should_transpose = default(self.channel_first, is_img_or_video)
611 |
612 | # standardize image or video into (batch, seq, dimension)
613 |
614 | if should_transpose:
615 | x = rearrange(x, 'b d ... -> b ... d')
616 | x, ps = pack_one(x, 'b * d') # x.shape [b, hwt, c]
617 |
618 | assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
619 |
620 | x = self.project_in(x)
621 |
622 | # split out number of codebooks
623 |
624 | x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
625 |
626 | x = l2norm(x)
627 |
628 | # whether to force quantization step to be full precision or not
629 |
630 | force_f32 = self.force_quantization_f32
631 |
632 | quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
633 |
634 | indices = None
635 | with quantization_context():
636 |
637 | if force_f32:
638 | orig_dtype = x.dtype
639 | x = x.float()
640 |
641 | # use straight-through gradients (optionally with custom activation fn) if training
642 | if self.new_quant:
643 | quantized = self.quantize_new(x)
644 | else:
645 | quantized = self.quantize(x)
646 | q_scale = 1. / (self.codebook_dims ** 0.5)
647 | quantized = q_scale * quantized # on unit sphere
648 |
649 | # calculate indices
650 | bit_indices = (quantized > 0).int()
651 |
652 | # entropy aux loss
653 | if self.training:
654 | persample_entropy, cb_entropy, avg_prob = self.soft_entropy_loss(x) # compute entropy
655 | entropy_penalty = self.gamma0 * persample_entropy - self.gamma * cb_entropy
656 | else:
657 | # if not training, just return dummy 0
658 | entropy_penalty = persample_entropy = cb_entropy = self.zero
659 |
660 | # commit loss
661 |
662 | if self.training and self.commitment_loss_weight > 0.:
663 |
664 | commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none')
665 |
666 | if exists(mask):
667 | commit_loss = commit_loss[mask]
668 |
669 | commit_loss = commit_loss.mean()
670 | else:
671 | commit_loss = self.zero
672 |
673 | # input back to original dtype if needed
674 |
675 | if force_f32:
676 | x = x.type(orig_dtype)
677 |
678 | # merge back codebook dim
679 | x = quantized # rename quantized to x for output
680 |
681 | if self.use_out_phi_res:
682 | x = x + self.out_phi_scale * self.out_phi(x) # apply out_phi on quant output as residual
683 | else:
684 | x = self.out_phi(x) # apply out_phi on quant output
685 |
686 | x = rearrange(x, 'b n c d -> b n (c d)')
687 |
688 | # project out to feature dimension if needed
689 |
690 | x = self.project_out(x)
691 |
692 | # reconstitute image or video dimensions
693 |
694 | if should_transpose:
695 | x = unpack_one(x, ps, 'b * d')
696 | x = rearrange(x, 'b ... d -> b d ...')
697 |
698 | bit_indices = unpack_one(bit_indices, ps, 'b * c d')
699 |
700 | # whether to remove single codebook dim
701 |
702 | if not self.keep_num_codebooks_dim:
703 | bit_indices = rearrange(bit_indices, '... 1 d -> ... d')
704 |
705 | # complete aux loss
706 |
707 | aux_loss = commit_loss * self.commitment_loss_weight + (self.zeta * entropy_penalty / self.inv_temperature)*entropy_weight
708 | # returns
709 |
710 | ret = Return(x, indices, bit_indices, aux_loss) # indices stays None here; callers rely on bit_indices
711 |
712 | if not return_loss_breakdown:
713 | return ret
714 |
715 | return ret, LossBreakdown(persample_entropy, cb_entropy, commit_loss)
716 |
717 |
--------------------------------------------------------------------------------