├── assets ├── IBQ-teaser.png ├── comparsion.png ├── IBQ-gradient-flow.png ├── Open-MAGVIT2-teaser.png ├── Pretrain_comparison.png └── Open-MAGVIT2-framework.png ├── src ├── IBQ │ ├── modules │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── segmentation.py │ │ │ └── lpips.py │ │ ├── autoencoder │ │ │ └── lpips │ │ │ │ └── vgg.pth │ │ ├── scheduler │ │ │ └── lr_scheduler.py │ │ ├── discriminator │ │ │ └── model.py │ │ ├── ema.py │ │ └── util.py │ ├── models │ │ └── dummy_cond_stage.py │ ├── lr_scheduler.py │ ├── data │ │ ├── helper_types.py │ │ └── base.py │ └── util.py └── Open_MAGVIT2 │ ├── modules │ ├── losses │ │ ├── __init__.py │ │ ├── segmentation.py │ │ └── lpips.py │ ├── autoencoder │ │ └── lpips │ │ │ └── vgg.pth │ ├── scheduler │ │ └── lr_scheduler.py │ └── ema.py │ ├── models │ └── dummy_cond_stage.py │ ├── lr_scheduler.py │ ├── data │ ├── helper_types.py │ ├── base.py │ ├── functional.py │ ├── prepare_pretrain.py │ └── volume_transforms.py │ └── util.py ├── scripts ├── evaluation │ ├── evaluation_video.sh │ ├── evaluation_128.sh │ ├── evaluation_original.sh │ └── evaluation_256.sh ├── inference │ ├── reconstruct_video.sh │ ├── reconstruct_image.sh │ └── generate.sh ├── train_tokenizer │ ├── IBQ │ │ ├── run_262144.sh │ │ ├── run_16384.sh │ │ └── pretrain_256.sh │ └── Open-MAGVIT2 │ │ ├── run_video.sh │ │ ├── run_128_L.sh │ │ ├── run_256_L.sh │ │ └── pretrain_256.sh └── train_autogressive │ └── run.sh ├── requirements.txt ├── combine_npz.py ├── metrics └── fid.py ├── configs ├── Open-MAGVIT2 │ ├── gpu │ │ ├── imagenet_lfqgan_256_L.yaml │ │ ├── imagenet_lfqgan_128_L.yaml │ │ ├── ucf101_lfqfan_128_L.yaml │ │ ├── imagenet_conditional_llama_B.yaml │ │ ├── imagenet_conditional_llama_L.yaml │ │ ├── imagenet_conditional_llama_XL.yaml │ │ ├── pretrain_lfqgan_256_262144.yaml │ │ └── pretrain_lfqgan_256_16384.yaml │ └── npu │ │ ├── imagenet_lfqgan_256_L.yaml │ │ ├── imagenet_lfqgan_128_L.yaml │ │ ├── ucf101_lfqgan_128_L.yaml │ │ ├── imagenet_conditional_llama_B.yaml │ │ ├── imagenet_conditional_llama_L.yaml │ │ ├── imagenet_conditional_llama_XL.yaml │ │ ├── pretrain_lfqgan_256_16384.yaml │ │ └── pretrain_lfqgan_256_262144.yaml └── IBQ │ ├── gpu │ ├── imagenet_ibqgan_8192.yaml │ ├── imagenet_ibqgan_1024.yaml │ ├── imagenet_ibqgan_16384.yaml │ ├── imagenet_ibqgan_262144.yaml │ ├── imagenet_conditional_llama_B.yaml │ ├── imagenet_conditional_llama_L.yaml │ ├── imagenet_conditional_llama_XL.yaml │ ├── imagenet_conditional_llama_XXL.yaml │ ├── pretrain_ibqgan_262144.yaml │ └── pretrain_ibqgan_16384.yaml │ └── npu │ ├── imagenet_ibqgan_8192.yaml │ ├── imagenet_ibqgan_1024.yaml │ ├── imagenet_ibqgan_16384.yaml │ ├── imagenet_ibqgan_262144.yaml │ ├── imagenet_conditional_llama_B.yaml │ ├── imagenet_conditional_llama_L.yaml │ ├── imagenet_conditional_llama_XL.yaml │ ├── imagenet_conditional_llama_XXL.yaml │ ├── pretrain_ibqgan_16384.yaml │ └── pretrain_ibqgan_262144.yaml ├── main.py └── reconstruct_image.py /assets/IBQ-teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Voken/HEAD/assets/IBQ-teaser.png -------------------------------------------------------------------------------- /assets/comparsion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Voken/HEAD/assets/comparsion.png -------------------------------------------------------------------------------- /src/IBQ/modules/losses/__init__.py: 
-------------------------------------------------------------------------------- 1 | from src.IBQ.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /assets/IBQ-gradient-flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Voken/HEAD/assets/IBQ-gradient-flow.png -------------------------------------------------------------------------------- /assets/Open-MAGVIT2-teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Voken/HEAD/assets/Open-MAGVIT2-teaser.png -------------------------------------------------------------------------------- /assets/Pretrain_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Voken/HEAD/assets/Pretrain_comparison.png -------------------------------------------------------------------------------- /assets/Open-MAGVIT2-framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Voken/HEAD/assets/Open-MAGVIT2-framework.png -------------------------------------------------------------------------------- /src/Open_MAGVIT2/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from src.Open_MAGVIT2.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- /src/IBQ/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Voken/HEAD/src/IBQ/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /src/Open_MAGVIT2/modules/autoencoder/lpips/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SEED-Voken/HEAD/src/Open_MAGVIT2/modules/autoencoder/lpips/vgg.pth -------------------------------------------------------------------------------- /scripts/evaluation/evaluation_video.sh: -------------------------------------------------------------------------------- 1 | ### Video Evaluation Scripts 2 | 3 | ## Open-MAGVIT2 262144 Video Version 4 | python evaluation_video.py \ 5 | --config_file "configs/Open-MAGVIT2/npu/ucf101_lfqgan_128_L.yaml" \ 6 | --ckpt_path "../upload_ckpts/Open-MAGVIT2/video/ucf101_lfqgan_128_262144.ckpt" -------------------------------------------------------------------------------- /scripts/inference/reconstruct_video.sh: -------------------------------------------------------------------------------- 1 | ### Open-MAGVIT2 Video Version Reconstruction 2 | python reconstruct_video.py \ 3 | --config_file "configs/Open-MAGVIT2/npu/ucf101_lfqgan_128_L.yaml" \ 4 | --ckpt_path "../upload_ckpts/Open-MAGVIT2/video/ucf101_lfqgan_128_262144.ckpt" \ 5 | --version "video_visualize" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.1.1 --index-url https://download.pytorch.org/whl/cu118 2 | torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cu118 3 | lightning==2.2.0 4 | jsonargparse[signatures]>=4.27.7 5 | 
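# Note: the torch / torchvision pins above pull CUDA 11.8 wheels via --index-url (cu118); point the index at a different CUDA build if your driver requires it.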
tensorboard 6 | tensorboardx 7 | albumentations==1.4.4 8 | omegaconf 9 | einops 10 | requests 11 | transformers==4.37.2 12 | lpips 13 | av 14 | decord 15 | -------------------------------------------------------------------------------- /scripts/train_tokenizer/IBQ/run_262144.sh: -------------------------------------------------------------------------------- 1 | export MASTER_ADDR=${1:-localhost} 2 | export MASTER_PORT=${2:-10055} 3 | export NODE_RANK=${3:-0} 4 | export OMP_NUM_THREADS=6 5 | 6 | export MASTER_ADDR=$MASTER_ADDR 7 | export MASTER_PORT=$MASTER_PORT 8 | 9 | echo $MASTER_ADDR 10 | echo $MASTER_PORT 11 | 12 | ##NPU 13 | NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/npu/imagenet_ibqgan_262144.yaml 14 | 15 | ###GPU 16 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/gpu/imagenet_ibqgan_262144.yaml -------------------------------------------------------------------------------- /scripts/evaluation/evaluation_128.sh: -------------------------------------------------------------------------------- 1 | ## GPU and NPU can use the same config for evaluation 2 | # python evaluation_image.py --config_file configs/Open-MAGVIT2/gpu/imagenet_lfqgan_128_L.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/in1k_128_L/imagenet_128_L.ckpt --image_size 128 --model Open-MAGVIT2 3 | 4 | ##NPU 5 | ##Open-MAGVIT2 6 | python evaluation_image.py --config_file configs/Open-MAGVIT2/npu/imagenet_lfqgan_128_L.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/in1k_128_L/imagenet_128_L.ckpt --image_size 128 --model Open-MAGVIT2 -------------------------------------------------------------------------------- /scripts/train_tokenizer/IBQ/run_16384.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export MASTER_ADDR=${1:-localhost} 3 | export MASTER_PORT=${2:-10055} 4 | export NODE_RANK=${3:-0} 5 | export OMP_NUM_THREADS=6 6 | 7 | export MASTER_ADDR=$MASTER_ADDR 8 | export MASTER_PORT=$MASTER_PORT 9 | 10 | echo $MASTER_ADDR 11 | echo $MASTER_PORT 12 | 13 | ##NPU 14 | NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/npu/imagenet_ibqgan_16384.yaml 15 | 16 | ###GPU 17 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/gpu/imagenet_ibqgan_16384.yaml -------------------------------------------------------------------------------- /src/Open_MAGVIT2/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /scripts/train_tokenizer/Open-MAGVIT2/run_video.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | export MASTER_ADDR=${1:-localhost} 3 | export MASTER_PORT=${2:-10055} 4 | export NODE_RANK=${3:-0} 5 | export OMP_NUM_THREADS=6 6 | 7 | export MASTER_ADDR=$MASTER_ADDR 8 | export MASTER_PORT=$MASTER_PORT 9 | 10 | echo $MASTER_ADDR 11 | echo $MASTER_PORT 12 | 13 | ##NPU 14 | NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/npu/ucf101_lfqgan_128_L.yaml 15 | 16 | ###GPU 17 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/gpu/ucf101_lfqgan_128_L.yaml -------------------------------------------------------------------------------- /scripts/train_tokenizer/Open-MAGVIT2/run_128_L.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export MASTER_ADDR=${1:-localhost} 3 | export MASTER_PORT=${2:-10055} 4 | export NODE_RANK=${3:-0} 5 | export OMP_NUM_THREADS=6 6 | 7 | export MASTER_ADDR=$MASTER_ADDR 8 | export MASTER_PORT=$MASTER_PORT 9 | 10 | echo $MASTER_ADDR 11 | echo $MASTER_PORT 12 | 13 | ##NPU 14 | NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/npu/imagenet_lfqgan_128_L.yaml 15 | 16 | ###GPU 17 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/gpu/imagenet_lfqgan_128_L.yaml -------------------------------------------------------------------------------- /scripts/train_tokenizer/Open-MAGVIT2/run_256_L.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export MASTER_ADDR=${1:-localhost} 3 | export MASTER_PORT=${2:-10055} 4 | export NODE_RANK=${3:-0} 5 | export OMP_NUM_THREADS=6 6 | 7 | export MASTER_ADDR=$MASTER_ADDR 8 | export MASTER_PORT=$MASTER_PORT 9 | 10 | echo $MASTER_ADDR 11 | echo $MASTER_PORT 12 | 13 | ##NPU 14 | NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/npu/imagenet_lfqgan_256_L.yaml 15 | 16 | ###GPU 17 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml -------------------------------------------------------------------------------- /src/IBQ/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /scripts/train_tokenizer/IBQ/pretrain_256.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | export MASTER_ADDR=${1:-localhost} 3 | export MASTER_PORT=${2:-10055} 4 | export NODE_RANK=${3:-0} 5 | export OMP_NUM_THREADS=6 6 | 7 | export MASTER_ADDR=$MASTER_ADDR 8 | export MASTER_PORT=$MASTER_PORT 9 | 10 | echo $MASTER_ADDR 11 | echo $MASTER_PORT 12 | 13 | ## Pretrain 262144 14 | NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/npu/pretrain_ibqgan_256_262144.yaml 15 | 16 | ## Pretrain 16384 17 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/gpu/pretrain_ibqgan_256_16384.yaml 18 | 19 | -------------------------------------------------------------------------------- /scripts/train_tokenizer/Open-MAGVIT2/pretrain_256.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | export MASTER_ADDR=${1:-localhost} 3 | export MASTER_PORT=${2:-10055} 4 | export NODE_RANK=${3:-0} 5 | export OMP_NUM_THREADS=6 6 | 7 | export MASTER_ADDR=$MASTER_ADDR 8 | export MASTER_PORT=$MASTER_PORT 9 | 10 | echo $MASTER_ADDR 11 | echo $MASTER_PORT 12 | 13 | ## Pretrain 262144 14 | NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/npu/pretrain_lfqgan_256_262144.yaml 15 | 16 | ## Pretrain 16384 17 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/gpu/pretrain_lfqgan_256_16384.yaml 18 | 19 | -------------------------------------------------------------------------------- /combine_npz.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | 5 | def get_parser(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--logdir", required=True) 8 | return parser 9 | 10 | def combine_npz(logdir): 11 | save_npzs = [npz for npz in os.listdir(logdir)] 12 | 13 | npzs = [] 14 | 15 | for save_npz in save_npzs: 16 | tem_npz = np.load(os.path.join(logdir, save_npz)) 17 | data = tem_npz["arr_0"] 18 | npzs.append(data) 19 | 20 | save_npz = np.vstack(npzs) 21 | np.random.shuffle(save_npz) 22 | np.savez(os.path.join(logdir, "sample.npz"), save_npz) 23 | 24 | if __name__ == "__main__": 25 | parser = get_parser() 26 | args = parser.parse_args() 27 | combine_npz(args.logdir) -------------------------------------------------------------------------------- /src/Open_MAGVIT2/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /src/IBQ/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = 
F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /src/Open_MAGVIT2/modules/scheduler/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | 5 | # step scheduler 6 | def fn_LinearWarmup(warmup_steps, step): 7 | if step < warmup_steps: # linear warmup 8 | return float(step) / float(max(1, warmup_steps)) 9 | else: 10 | return 1.0 11 | 12 | def Scheduler_LinearWarmup(warmup_steps): 13 | return partial(fn_LinearWarmup, warmup_steps) 14 | 15 | 16 | def fn_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min, step): 17 | if step < warmup_steps: # linear warmup 18 | return float(step) / float(max(1, warmup_steps)) 19 | else: # cosine learning rate schedule 20 | multipler = 0.5 * (math.cos((step - warmup_steps) / (max_steps - warmup_steps) * math.pi) + 1) 21 | return max(multipler, multipler_min) 22 | 23 | def Scheduler_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min): 24 | return partial(fn_LinearWarmup_CosineDecay, warmup_steps, max_steps, multipler_min) -------------------------------------------------------------------------------- /src/IBQ/modules/scheduler/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from functools import partial 4 | 5 | # step scheduler 6 | def fn_LinearWarmup(warmup_steps, step): 7 | if step < warmup_steps: # linear warmup 8 | return float(step) / float(max(1, warmup_steps)) 9 | else: 10 | return 1.0 11 | 12 | def Scheduler_LinearWarmup(warmup_steps): 13 | return partial(fn_LinearWarmup, warmup_steps) 14 | 15 | 16 | def fn_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min, step): 17 | if step < warmup_steps: # linear warmup 18 | return float(step) / float(max(1, warmup_steps)) 19 | else: # cosine learning rate schedule 20 | multipler = 0.5 * (math.cos((step - warmup_steps) / (max_steps - warmup_steps) * math.pi) + 1) 21 | return max(multipler, multipler_min) 22 | 23 | def Scheduler_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min): 24 | return partial(fn_LinearWarmup_CosineDecay, warmup_steps, max_steps, multipler_min) -------------------------------------------------------------------------------- /scripts/evaluation/evaluation_original.sh: -------------------------------------------------------------------------------- 1 | ### Evaluate Open-MAGVIT2 Pretrain 262144 NPU 2 | python evaluation_original_reso.py --config_file configs/Open-MAGVIT2/npu/pretrain_lfqgan_256_262144.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/pretrain_256_262144/pretrain256_262144.ckpt --original_reso --model Open-MAGVIT2 --batch_size 1 3 | 4 | ### Evaluate Open-MAGVIT2 Pretrain 16384 NPU 5 | # python evaluation_original_reso.py --config_file 
configs/Open-MAGVIT2/npu/pretrain_lfqgan_256_16384.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/pretrain_256_16384/pretrain256_16384.ckpt --model Open-MAGVIT2 --original_reso --batch_size 1 6 | 7 | ### Evaluate Open-MAGVIT2 Pretrain 262144 GPU 8 | # python evaluation_original_reso.py --config_file configs/Open-MAGVIT2/gpu/pretrain_lfqgan_256_262144.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/pretrain_256_262144/pretrain256_262144.ckpt --original_reso --model Open-MAGVIT2 --batch_size 1 9 | 10 | ### Evaluate Open-MAGVIT2 Pretrain 16384 GPU 11 | # python evaluation_original_reso.py --config_file configs/Open-MAGVIT2/gpu/pretrain_lfqgan_256_16384.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/pretrain_256_16384/pretrain256_16384.ckpt --original_reso --model Open-MAGVIT2 --batch_size 1 -------------------------------------------------------------------------------- /src/Open_MAGVIT2/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /scripts/inference/reconstruct_image.sh: -------------------------------------------------------------------------------- 1 | ## NPU 2 | 3 | ##Open-MAGVIT2 4 | # python reconstruct_image.py \ 5 | # --config_file "configs/Open-MAGVIT2/npu/imagenet_lfqgan_256_L.yaml" \ 6 | # --ckpt_path ../upload_ckpts/Open-MAGVIT2/in1k_256_L/imagenet_256_L.ckpt \ 7 | # --save_dir "./visualize" \ 8 | # --version "1k_Open_MAGVIT2" \ 9 | # --image_num 50 \ 10 | # --image_size 256 \ 11 | # --model Open-MAGVIT2 \ 12 | 13 | # ##IBQ 14 | # python reconstruct_image.py \ 15 | # --config_file "configs/IBQ/npu/imagenet_ibqgan_262144.yaml" \ 16 | # --ckpt_path ../upload_ckpts/IBQ/in1k_262144/imagenet256_262144.ckpt \ 17 | # --save_dir "./visualize" \ 18 | # --version "1k_IBQ" \ 19 | # --image_num 50 \ 20 | # --image_size 256 \ 21 | # --model IBQ \ 22 | 23 | ##GPU 24 | # python reconstruct.py \ 25 | # --config_file "configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml" \ 26 | # --ckpt_path ../upload_ckpts/Open-MAGVIT2/in1k_256_L/imagenet_256_L.ckpt \ 27 | # --save_dir "./visualize" \ 28 | # --version "1k_Open_MAGVIT2" \ 29 | # --image_num 50 \ 30 | # --image_size 256 \ 31 | # --model Open-MAGVIT2 \ 32 | 33 | # python reconstruct.py \ 34 | # --config_file "configs/IBQ/gpu/imagenet_ibqgan_262144.yaml" \ 35 | # --ckpt_path 
../upload_ckpts/IBQ/in1k_262144/imagenet256_262144.ckpt \ 36 | # --save_dir "./visualize" \ 37 | # --version "1k_IBQ" \ 38 | # --image_num 50 \ 39 | # --image_size 256 \ 40 | # --model IBQ \ -------------------------------------------------------------------------------- /src/IBQ/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /scripts/inference/generate.sh: -------------------------------------------------------------------------------- 1 | ##Open-MAGVIT2 2 | # python generate.py \ 3 | # --ckpt "../upload_ckpts/Open-MAGVIT2/AR_256_XL/AR_256_XL.ckpt" \ 4 | # -o "./visualize" \ 5 | # --config "configs/Open-MAGVIT2/npu/imagenet_conditional_llama_XL.yaml" \ 6 | # -k "0,0" \ 7 | # -p "0.96,0.96" \ 8 | # --token_factorization \ 9 | # -n 8 \ 10 | # -t "1.0,1.0" \ 11 | # --classes "207" \ 12 | # --batch_size 8 \ 13 | # --cfg_scale "4.0,4.0" \ 14 | # --model Open-MAGVIT2 15 | 16 | ##GPU 17 | # python generate.py \ 18 | # --ckpt "../upload_ckpts/Open-MAGVIT2/AR_256_XL/AR_256_XL.ckpt" \ 19 | # -o "./visualize" \ 20 | # --config "configs/Open-MAGVIT2/gpu/imagenet_conditional_llama_XL.yaml" \ 21 | # -k "0,0" \ 22 | # -p "0.96,0.96" \ 23 | # --token_factorization \ 24 | # -n 8 \ 25 | # -t "1.0,1.0" \ 26 | # --classes "207" \ 27 | # --batch_size 8 \ 28 | # --cfg_scale "4.0,4.0" \ 29 | # --model Open-MAGVIT2 30 | 31 | ##IBQ 32 | # python generate.py \ 33 | # --ckpt "../upload_ckpts/IBQ/AR_256_XXL/AR_256_XXL.ckpt" \ 34 | # -o "./visualize" \ 35 | # --config "configs/IBQ/npu/imagenet_conditional_llama_XXL.yaml" \ 36 | # -k 0 \ 37 | # -p 0.96 \ 38 | # -n 8 \ 39 | # -t 1.0 \ 40 | # --classes "207" \ 41 | # --batch_size 8 \ 42 | # --cfg_scale 4.0 \ 43 | # --model IBQ 44 | 45 | # python generate.py \ 46 | # --ckpt "../upload_ckpts/IBQ/AR_256_XXL/AR_256_XXL.ckpt" \ 47 | # -o "./visualize" \ 48 | # --config "configs/IBQ/gpu/imagenet_conditional_llama_XXL.yaml" \ 49 | # -k 0 \ 50 | # -p 0.96 \ 51 | # -n 8 \ 52 | # -t 1.0 \ 53 | # --classes "207" \ 54 | # --batch_size 8 \ 55 | # --cfg_scale 4.0 \ 56 | # --model IBQ -------------------------------------------------------------------------------- /src/Open_MAGVIT2/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | 
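# Shared typing helpers for the data pipelines: an Image alias (Tensor or PIL image), a BoundingBox tuple (x0, y0, w, h), crop/split Literal types, and NamedTuple metadata records whose optional COCO/Flickr-style fields default to None.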
from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /src/IBQ/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /scripts/train_autogressive/run.sh: -------------------------------------------------------------------------------- 1 | export MASTER_ADDR=${1:-localhost} 2 | export MASTER_PORT=${2:-10055} 3 | export NODE_RANK=${3:-0} 4 | 5 | export OMP_NUM_THREADS=6 6 | export MASTER_ADDR=$MASTER_ADDR 7 | export MASTER_PORT=$MASTER_PORT 8 | 9 | echo $MASTER_ADDR 10 | echo $MASTER_PORT 11 | 12 | # NPU Open-MAGVIT2 13 | 
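# Usage: bash run.sh [MASTER_ADDR] [MASTER_PORT] [NODE_RANK]; defaults are localhost / 10055 / 0 via the ${n:-...} fallbacks above.
# Keep exactly one `main.py fit` line below uncommented to choose the AR transformer size (B / L / XL / XXL) and the backend (GPU or NPU) config.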
NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/npu/imagenet_conditional_llama_XL.yaml 14 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/npu/imagenet_conditional_llama_L.yaml 15 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/npu/imagenet_conditional_llama_B.yaml 16 | 17 | # GPU 18 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/gpu/imagenet_conditional_llama_XL.yaml 19 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/gpu/imagenet_conditional_llama_L.yaml 20 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/Open-MAGVIT2/gpu/imagenet_conditional_llama_B.yaml 21 | 22 | # IBQ NPU 23 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/npu/imagenet_conditional_llama_XXL.yaml 24 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/npu/imagenet_conditional_llama_XL.yaml 25 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/npu/imagenet_conditional_llama_L.yaml 26 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/npu/imagenet_conditional_llama_B.yaml 27 | 28 | # GPU 29 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/gpu/imagenet_conditional_llama_XXL.yaml 30 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/gpu/imagenet_conditional_llama_XL.yaml 31 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/gpu/imagenet_conditional_llama_L.yaml 32 | # NODE_RANK=$NODE_RANK python main.py fit --config configs/IBQ/gpu/imagenet_conditional_llama_B.yaml 33 | -------------------------------------------------------------------------------- /metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import linalg 3 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 4 | """Numpy implementation of the Frechet Distance. 5 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 6 | and X_2 ~ N(mu_2, C_2) is 7 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 8 | 9 | Stable version by Dougal J. Sutherland. 10 | 11 | Params: 12 | -- mu1 : Numpy array containing the activations of a layer of the 13 | inception net (like returned by the function 'get_predictions') 14 | for generated samples. 15 | -- mu2 : The sample mean over activations, precalculated on an 16 | representative data set. 17 | -- sigma1: The covariance matrix over activations for generated samples. 18 | -- sigma2: The covariance matrix over activations, precalculated on an 19 | representative data set. 20 | 21 | Returns: 22 | -- : The Frechet Distance. 
23 | """ 24 | mu1 = np.atleast_1d(mu1) 25 | mu2 = np.atleast_1d(mu2) 26 | 27 | sigma1 = np.atleast_2d(sigma1) 28 | sigma2 = np.atleast_2d(sigma2) 29 | 30 | assert ( 31 | mu1.shape == mu2.shape 32 | ), "Training and test mean vectors have different lengths" 33 | assert ( 34 | sigma1.shape == sigma2.shape 35 | ), "Training and test covariances have different dimensions" 36 | 37 | diff = mu1 - mu2 38 | 39 | # Product might be almost singular 40 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 41 | if not np.isfinite(covmean).all(): 42 | msg = ( 43 | "fid calculation produces singular product; " 44 | "adding %s to diagonal of cov estimates" 45 | ) % eps 46 | print(msg) 47 | offset = np.eye(sigma1.shape[0]) * eps 48 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 49 | 50 | # Numerical error might give slight imaginary component 51 | if np.iscomplexobj(covmean): 52 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 53 | m = np.max(np.abs(covmean.imag)) 54 | raise ValueError("Imaginary component {}".format(m)) 55 | covmean = covmean.real 56 | 57 | tr_covmean = np.trace(covmean) 58 | 59 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 60 | 61 | 62 | -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: 16-mixed 8 | max_epochs: 270 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | log_every_n_steps: 100 12 | callbacks: 13 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: "../../checkpoints/vqgan/test" 16 | save_top_k: -1 # save all checkpoints 17 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 18 | init_args: 19 | logging_interval: step 20 | logger: 21 | class_path: lightning.pytorch.loggers.TensorBoardLogger 22 | init_args: 23 | save_dir: "../../results/vqgan/" 24 | version: "test" 25 | name: 26 | 27 | model: 28 | class_path: src.Open_MAGVIT2.models.lfqgan.VQModel 29 | init_args: 30 | ddconfig: 31 | double_z: False 32 | z_channels: 18 33 | resolution: 128 34 | in_channels: 3 35 | out_ch: 3 36 | ch: 128 37 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 38 | num_res_blocks: 4 39 | 40 | lossconfig: 41 | target: src.Open_MAGVIT2.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 42 | params: 43 | disc_conditional: False 44 | disc_in_channels: 3 45 | disc_start: 0 # from 0 epoch 46 | disc_weight: 0.8 47 | gen_loss_weight: 0.1 48 | lecam_loss_weight: 0.05 49 | codebook_weight: 0.1 50 | commit_weight: 0.25 51 | codebook_enlarge_ratio: 0 52 | codebook_enlarge_steps: 2000 53 | 54 | n_embed: 262144 55 | embed_dim: 18 56 | learning_rate: 1e-4 57 | sample_minimization_weight: 1.0 58 | batch_maximization_weight: 1.0 59 | scheduler_type: "None" 60 | use_ema: True 61 | resume_lr: 62 | lr_drop_epoch: [200, 250] 63 | 64 | data: 65 | class_path: main.DataModuleFromConfig 66 | init_args: 67 | batch_size: 8 68 | num_workers: 16 69 | train: 70 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 71 | params: 72 | config: 73 | size: 256 74 | subset: 75 | validation: 76 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 77 | params: 78 | config: 79 | size: 256 80 | subset: 81 | test: 82 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 83 | 
params: 84 | config: 85 | size: 256 86 | subset: 87 | 88 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/npu/imagenet_lfqgan_256_L.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: bf16-mixed 8 | max_epochs: 270 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | log_every_n_steps: 100 12 | callbacks: 13 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: "../../checkpoints/vqgan/test" 16 | save_top_k: -1 # save all checkpoints 17 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 18 | init_args: 19 | logging_interval: step 20 | logger: 21 | class_path: lightning.pytorch.loggers.TensorBoardLogger 22 | init_args: 23 | save_dir: "../../results/vqgan/" 24 | version: "test" 25 | name: 26 | 27 | model: 28 | class_path: src.Open_MAGVIT2.models.lfqgan.VQModel 29 | init_args: 30 | ddconfig: 31 | double_z: False 32 | z_channels: 18 33 | resolution: 128 34 | in_channels: 3 35 | out_ch: 3 36 | ch: 128 37 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 38 | num_res_blocks: 4 39 | 40 | lossconfig: 41 | target: src.Open_MAGVIT2.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 42 | params: 43 | disc_conditional: False 44 | disc_in_channels: 3 45 | disc_start: 0 # from 0 epoch 46 | disc_weight: 0.8 47 | gen_loss_weight: 0.1 48 | lecam_loss_weight: 0.05 49 | codebook_weight: 0.1 50 | commit_weight: 0.25 51 | codebook_enlarge_ratio: 0 52 | codebook_enlarge_steps: 2000 53 | 54 | n_embed: 262144 55 | embed_dim: 18 56 | learning_rate: 1e-4 57 | sample_minimization_weight: 1.0 58 | batch_maximization_weight: 1.0 59 | scheduler_type: "None" 60 | use_ema: True 61 | resume_lr: 62 | lr_drop_epoch: [200, 250] 63 | 64 | data: 65 | class_path: main.DataModuleFromConfig 66 | init_args: 67 | batch_size: 8 68 | num_workers: 16 69 | train: 70 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 71 | params: 72 | config: 73 | size: 256 74 | subset: 75 | validation: 76 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 77 | params: 78 | config: 79 | size: 256 80 | subset: 81 | test: 82 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | subset: 87 | 88 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/npu/imagenet_lfqgan_128_L.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: bf16-mixed 8 | max_epochs: 350 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | log_every_n_steps: 100 12 | callbacks: 13 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: "../../checkpoints/vqgan/test" 16 | save_top_k: -1 # save all checkpoints 17 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 18 | init_args: 19 | logging_interval: step 20 | logger: 21 | class_path: lightning.pytorch.loggers.TensorBoardLogger 22 | init_args: 23 | save_dir: "../../results/vqgan/" 24 | version: "test" 25 | name: 26 | 27 | model: 28 | class_path: src.Open_MAGVIT2.models.lfqgan.VQModel 29 | init_args: 30 | ddconfig: 31 | 
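# Encoder/decoder settings for the 128px tokenizer: ch_mult below has 4 entries, i.e. 3 downsampling stages (num_down = len(ch_mult)-1), so a 128x128 input yields a 16x16 token grid; the 256px config reaches the same grid with one extra stage.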
double_z: False 32 | z_channels: 18 33 | resolution: 128 34 | in_channels: 3 35 | out_ch: 3 36 | ch: 128 37 | ch_mult: [1,2,2,4] # num_down = len(ch_mult)-1 38 | num_res_blocks: 4 39 | 40 | lossconfig: 41 | target: src.Open_MAGVIT2.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 42 | params: 43 | disc_conditional: False 44 | disc_in_channels: 3 45 | disc_start: 0 # from 0 epoch 46 | disc_weight: 0.8 47 | gen_loss_weight: 0.1 #see if it is reuqired to change 48 | lecam_loss_weight: 0.01 49 | codebook_weight: 0.1 50 | commit_weight: 0.25 51 | codebook_enlarge_ratio: 0 52 | codebook_enlarge_steps: 2000 53 | 54 | n_embed: 262144 55 | embed_dim: 18 56 | learning_rate: 1e-4 57 | sample_minimization_weight: 1.0 58 | batch_maximization_weight: 1.0 59 | scheduler_type: "None" 60 | use_ema: True 61 | resume_lr: 62 | lr_drop_epoch: [250, 300] 63 | 64 | data: 65 | class_path: main.DataModuleFromConfig 66 | init_args: 67 | batch_size: 8 68 | num_workers: 16 69 | train: 70 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 71 | params: 72 | config: 73 | size: 128 74 | subset: 75 | validation: 76 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 77 | params: 78 | config: 79 | size: 128 80 | subset: 81 | test: 82 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 128 86 | subset: 87 | 88 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/gpu/imagenet_lfqgan_128_L.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: 16-mixed 8 | max_epochs: 350 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | log_every_n_steps: 100 12 | callbacks: 13 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: "../../checkpoints/vqgan/test" # Please specify your own path 16 | save_top_k: -1 # save all checkpoints 17 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 18 | init_args: 19 | logging_interval: step 20 | logger: 21 | class_path: lightning.pytorch.loggers.TensorBoardLogger 22 | init_args: 23 | save_dir: "../../results/vqgan/" #Please specify your own path 24 | version: "test" 25 | name: 26 | 27 | model: 28 | class_path: src.Open_MAGVIT2.models.lfqgan.VQModel 29 | init_args: 30 | ddconfig: 31 | double_z: False 32 | z_channels: 18 33 | resolution: 128 34 | in_channels: 3 35 | out_ch: 3 36 | ch: 128 37 | ch_mult: [1,2,2,4] # num_down = len(ch_mult)-1 38 | num_res_blocks: 4 39 | 40 | lossconfig: 41 | target: src.Open_MAGVIT2.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 42 | params: 43 | disc_conditional: False 44 | disc_in_channels: 3 45 | disc_start: 0 # from 0 epoch 46 | disc_weight: 0.8 47 | gen_loss_weight: 0.1 48 | lecam_loss_weight: 0.01 49 | codebook_weight: 0.1 50 | commit_weight: 0.25 51 | codebook_enlarge_ratio: 0 52 | codebook_enlarge_steps: 2000 53 | 54 | n_embed: 262144 55 | embed_dim: 18 56 | learning_rate: 1e-4 57 | sample_minimization_weight: 1.0 58 | batch_maximization_weight: 1.0 59 | scheduler_type: "None" 60 | use_ema: True 61 | resume_lr: 62 | lr_drop_epoch: [250, 300] 63 | 64 | 65 | data: 66 | class_path: main.DataModuleFromConfig 67 | init_args: 68 | batch_size: 8 69 | num_workers: 16 70 | train: 71 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 72 | params: 73 | config: 74 | size: 128 75 | 
subset: 76 | validation: 77 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 128 81 | subset: 82 | test: 83 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 84 | params: 85 | config: 86 | size: 128 87 | subset: 88 | 89 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/imagenet_ibqgan_8192.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: bf16-mixed 8 | max_epochs: 280 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | callbacks: 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | dirpath: "../../checkpoints/vqgan/test" 15 | save_top_k: -1 16 | save_last: True 17 | monitor: "train/perceptual_loss" 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.IBQ.models.ibqgan.IBQ 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 256 34 | resolution: 256 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 #not adopt from showo 40 | attn_resolutions: [16] #not adopt from showo 41 | dropout: 0.0 42 | 43 | lossconfig: 44 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 45 | params: 46 | disc_conditional: False 47 | disc_in_channels: 3 48 | disc_start: 0 # from 0 epoch 49 | 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 8192 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | resume_lr: 67 | lr_drop_epoch: [250] 68 | 69 | 70 | data: 71 | class_path: main.DataModuleFromConfig 72 | init_args: 73 | batch_size: 4 74 | num_workers: 16 75 | train: 76 | target: src.IBQ.data.imagenet.ImageNetTrain 77 | params: 78 | config: 79 | size: 256 80 | subset: 81 | validation: 82 | target: src.IBQ.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | test: 87 | target: src.IBQ.data.imagenet.ImageNetValidation 88 | params: 89 | config: 90 | size: 256 91 | 92 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/imagenet_ibqgan_8192.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: bf16-mixed 8 | max_epochs: 280 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | callbacks: 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | dirpath: "../../checkpoints/vqgan/test" 15 | save_top_k: -1 16 | save_last: True 17 | monitor: "train/perceptual_loss" 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | 
init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.IBQ.models.ibqgan.IBQ 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 256 34 | resolution: 256 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 #not adopt from showo 40 | attn_resolutions: [16] #not adopt from showo 41 | dropout: 0.0 42 | 43 | lossconfig: 44 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 45 | params: 46 | disc_conditional: False 47 | disc_in_channels: 3 48 | disc_start: 0 # from 0 epoch 49 | 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 8192 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | resume_lr: 67 | lr_drop_epoch: [250] 68 | 69 | 70 | data: 71 | class_path: main.DataModuleFromConfig 72 | init_args: 73 | batch_size: 4 74 | num_workers: 16 75 | train: 76 | target: src.IBQ.data.imagenet.ImageNetTrain 77 | params: 78 | config: 79 | size: 256 80 | subset: 81 | validation: 82 | target: src.IBQ.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | test: 87 | target: src.IBQ.data.imagenet.ImageNetValidation 88 | params: 89 | config: 90 | size: 256 91 | 92 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/imagenet_ibqgan_1024.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: 16-mixed 8 | max_epochs: 330 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | callbacks: 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | dirpath: "../../checkpoints/vqgan/test" 15 | save_top_k: -1 16 | save_last: True 17 | monitor: "train/perceptual_loss" 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.IBQ.models.ibqgan.IBQ 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 256 34 | resolution: 256 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 #not adopt from showo 40 | attn_resolutions: [16] #not adopt from showo 41 | dropout: 0.0 42 | 43 | lossconfig: 44 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 45 | params: 46 | disc_conditional: False 47 | disc_in_channels: 3 48 | disc_start: 0 # from 0 epoch 49 | 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 1024 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | 
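# Entropy regularization flags (interpretation based on the parameter names; see the vqperceptual loss referenced above for the exact form): sample_minimization_weight sharpens each per-sample code assignment, batch_maximization_weight spreads usage across the whole codebook, and entropy_temperature scales the logits before the entropy is computed.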
use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | resume_lr: 67 | lr_drop_epoch: [250, 300] 68 | 69 | 70 | data: 71 | class_path: main.DataModuleFromConfig 72 | init_args: 73 | batch_size: 4 74 | num_workers: 16 75 | train: 76 | target: src.IBQ.data.imagenet.ImageNetTrain 77 | params: 78 | config: 79 | size: 256 80 | subset: 81 | validation: 82 | target: src.IBQ.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | test: 87 | target: src.IBQ.data.imagenet.ImageNetValidation 88 | params: 89 | config: 90 | size: 256 91 | 92 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/imagenet_ibqgan_16384.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: 16-mixed 8 | max_epochs: 330 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | callbacks: 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | dirpath: "../../checkpoints/vqgan/test" 15 | save_top_k: -1 16 | save_last: True 17 | monitor: "train/perceptual_loss" 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.IBQ.models.ibqgan.IBQ 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 256 34 | resolution: 256 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 #not adopt from showo 40 | attn_resolutions: [16] #not adopt from showo 41 | dropout: 0.0 42 | 43 | lossconfig: 44 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 45 | params: 46 | disc_conditional: False 47 | disc_in_channels: 3 48 | disc_start: 0 # from 0 epoch 49 | 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 16384 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | resume_lr: 67 | lr_drop_epoch: [250, 300] 68 | 69 | 70 | data: 71 | class_path: main.DataModuleFromConfig 72 | init_args: 73 | batch_size: 4 74 | num_workers: 16 75 | train: 76 | target: src.IBQ.data.imagenet.ImageNetTrain 77 | params: 78 | config: 79 | size: 256 80 | subset: 81 | validation: 82 | target: src.IBQ.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | test: 87 | target: src.IBQ.data.imagenet.ImageNetValidation 88 | params: 89 | config: 90 | size: 256 91 | 92 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/imagenet_ibqgan_1024.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | 
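# 8 devices per node x 8 nodes = 64 NPUs; with the per-device batch_size of 4 set under data below (Lightning treats DataLoader batch_size as per process), the effective global batch is 256.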
num_nodes: 8 7 | precision: bf16-mixed 8 | max_epochs: 330 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | callbacks: 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | dirpath: "../../checkpoints/vqgan/test" 15 | save_top_k: -1 16 | save_last: True 17 | monitor: "train/perceptual_loss" 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.IBQ.models.ibqgan.IBQ 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 256 34 | resolution: 256 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 #not adopt from showo 40 | attn_resolutions: [16] #not adopt from showo 41 | dropout: 0.0 42 | 43 | lossconfig: 44 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 45 | params: 46 | disc_conditional: False 47 | disc_in_channels: 3 48 | disc_start: 0 # from 0 epoch 49 | 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 1024 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | resume_lr: 67 | lr_drop_epoch: [250, 300] 68 | 69 | 70 | data: 71 | class_path: main.DataModuleFromConfig 72 | init_args: 73 | batch_size: 4 74 | num_workers: 16 75 | train: 76 | target: src.IBQ.data.imagenet.ImageNetTrain 77 | params: 78 | config: 79 | size: 256 80 | subset: 81 | validation: 82 | target: src.IBQ.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | test: 87 | target: src.IBQ.data.imagenet.ImageNetValidation 88 | params: 89 | config: 90 | size: 256 91 | 92 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/imagenet_ibqgan_262144.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: 16-mixed 8 | max_epochs: 330 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | callbacks: 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | dirpath: "../../checkpoints/vqgan/test" 15 | save_top_k: -1 16 | save_last: True 17 | monitor: "train/perceptual_loss" 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.IBQ.models.ibqgan.IBQ 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 256 34 | resolution: 256 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 #not adopt from showo 40 | attn_resolutions: [16] #not adopt from showo 41 | dropout: 0.0 42 | 43 | lossconfig: 44 | target: 
src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 45 | params: 46 | disc_conditional: False 47 | disc_in_channels: 3 48 | disc_start: 0 # from 0 epoch 49 | disc_weight: 0.4 # default 0.4 50 | quant_loss_weight: 1.0 # default 1.0 51 | entropy_loss_weight: 0.05 # default 0.1 52 | gen_loss_weight: 0.1 53 | lecam_loss_weight: 0.05 54 | 55 | n_embed: 262144 56 | embed_dim: 256 57 | learning_rate: 1e-4 58 | l2_normalize: False 59 | use_entropy_loss: True 60 | sample_minimization_weight: 1.0 61 | batch_maximization_weight: 1.0 62 | entropy_temperature: 0.01 # default 0.01 63 | beta: 0.25 64 | use_ema: True 65 | resume_lr: 66 | lr_drop_epoch: [250, 300] 67 | 68 | 69 | data: 70 | class_path: main.DataModuleFromConfig 71 | init_args: 72 | batch_size: 4 73 | num_workers: 16 74 | train: 75 | target: src.IBQ.data.imagenet.ImageNetTrain 76 | params: 77 | config: 78 | size: 256 79 | subset: 80 | validation: 81 | target: src.IBQ.data.imagenet.ImageNetValidation 82 | params: 83 | config: 84 | size: 256 85 | test: 86 | target: src.IBQ.data.imagenet.ImageNetValidation 87 | params: 88 | config: 89 | size: 256 90 | 91 | ckpt_path: null # to resume 92 | -------------------------------------------------------------------------------- /configs/IBQ/npu/imagenet_ibqgan_16384.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: bf16-mixed 8 | max_epochs: 330 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | callbacks: 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | dirpath: "../../checkpoints/vqgan/test" 15 | save_top_k: -1 16 | save_last: True 17 | monitor: "train/perceptual_loss" 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.IBQ.models.ibqgan.IBQ 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 256 34 | resolution: 256 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 #not adopt from showo 40 | attn_resolutions: [16] #not adopt from showo 41 | dropout: 0.0 42 | 43 | lossconfig: 44 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 45 | params: 46 | disc_conditional: False 47 | disc_in_channels: 3 48 | disc_start: 0 # from 0 epoch 49 | 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 16384 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | resume_lr: 67 | lr_drop_epoch: [250, 300] 68 | 69 | 70 | data: 71 | class_path: main.DataModuleFromConfig 72 | init_args: 73 | batch_size: 4 74 | num_workers: 16 75 | train: 76 | target: src.IBQ.data.imagenet.ImageNetTrain 77 | params: 78 | config: 79 | size: 256 80 | subset: 81 | validation: 82 | target: src.IBQ.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | test: 87 | target: 
src.IBQ.data.imagenet.ImageNetValidation 88 | params: 89 | config: 90 | size: 256 91 | 92 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/imagenet_ibqgan_262144.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: bf16-mixed 8 | max_epochs: 330 9 | check_val_every_n_epoch: 1 10 | num_sanity_val_steps: -1 11 | callbacks: 12 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 13 | init_args: 14 | dirpath: "../../checkpoints/vqgan/test" 15 | save_top_k: -1 16 | save_last: True 17 | monitor: "train/perceptual_loss" 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.IBQ.models.ibqgan.IBQ 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 256 34 | resolution: 256 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 #not adopt from showo 40 | attn_resolutions: [16] #not adopt from showo 41 | dropout: 0.0 42 | 43 | lossconfig: 44 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 45 | params: 46 | disc_conditional: False 47 | disc_in_channels: 3 48 | disc_start: 0 # from 0 epoch 49 | disc_weight: 0.4 # default 0.4 50 | quant_loss_weight: 1.0 # default 1.0 51 | entropy_loss_weight: 0.05 # default 0.1 52 | gen_loss_weight: 0.1 53 | lecam_loss_weight: 0.05 54 | 55 | n_embed: 262144 56 | embed_dim: 256 57 | learning_rate: 1e-4 58 | l2_normalize: False 59 | use_entropy_loss: True 60 | sample_minimization_weight: 1.0 61 | batch_maximization_weight: 1.0 62 | entropy_temperature: 0.01 # default 0.01 63 | beta: 0.25 64 | use_ema: True 65 | resume_lr: 66 | lr_drop_epoch: [250, 300] 67 | 68 | 69 | data: 70 | class_path: main.DataModuleFromConfig 71 | init_args: 72 | batch_size: 4 73 | num_workers: 16 74 | train: 75 | target: src.IBQ.data.imagenet.ImageNetTrain 76 | params: 77 | config: 78 | size: 256 79 | subset: 80 | validation: 81 | target: src.IBQ.data.imagenet.ImageNetValidation 82 | params: 83 | config: 84 | size: 256 85 | test: 86 | target: src.IBQ.data.imagenet.ImageNetValidation 87 | params: 88 | config: 89 | size: 256 90 | 91 | ckpt_path: null # to resume 92 | -------------------------------------------------------------------------------- /src/IBQ/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | from src.IBQ.modules.util import ActNorm 5 | 6 | 7 | def weights_init(m): 8 | classname = m.__class__.__name__ 9 | if classname.find('Conv') != -1: 10 | nn.init.normal_(m.weight.data, 0.0, 0.02) 11 | elif classname.find('BatchNorm') != -1: 12 | nn.init.normal_(m.weight.data, 1.0, 0.02) 13 | nn.init.constant_(m.bias.data, 0) 14 | 15 | 16 | class NLayerDiscriminator(nn.Module): 17 | """Defines a PatchGAN discriminator as in Pix2Pix 18 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 19 | """ 20 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 21 | """Construct a PatchGAN discriminator 22 | 
Parameters: 23 | input_nc (int) -- the number of channels in input images 24 | ndf (int) -- the number of filters in the last conv layer 25 | n_layers (int) -- the number of conv layers in the discriminator 26 | norm_layer -- normalization layer 27 | """ 28 | super(NLayerDiscriminator, self).__init__() 29 | if not use_actnorm: 30 | norm_layer = nn.BatchNorm2d 31 | else: 32 | norm_layer = ActNorm 33 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 34 | use_bias = norm_layer.func != nn.BatchNorm2d 35 | else: 36 | use_bias = norm_layer != nn.BatchNorm2d 37 | 38 | kw = 4 39 | padw = 1 40 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 41 | nf_mult = 1 42 | nf_mult_prev = 1 43 | for n in range(1, n_layers): # gradually increase the number of filters 44 | nf_mult_prev = nf_mult 45 | nf_mult = min(2 ** n, 8) 46 | sequence += [ 47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 48 | norm_layer(ndf * nf_mult), 49 | nn.LeakyReLU(0.2, True) 50 | ] 51 | 52 | nf_mult_prev = nf_mult 53 | nf_mult = min(2 ** n_layers, 8) 54 | sequence += [ 55 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 56 | norm_layer(ndf * nf_mult), 57 | nn.LeakyReLU(0.2, True) 58 | ] 59 | 60 | sequence += [ 61 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 62 | self.main = nn.Sequential(*sequence) 63 | 64 | def forward(self, input): 65 | """Standard forward.""" 66 | return self.main(input) 67 | -------------------------------------------------------------------------------- /src/Open_MAGVIT2/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | import torchvision.transforms as T 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | class ImagePaths(Dataset): 23 | def __init__(self, paths, original_reso = False, size=None, random_crop=False, labels=None): 24 | self.size = size 25 | self.random_crop = random_crop 26 | self.original_reso = original_reso 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = T.Resize(self.size) 34 | if not self.random_crop: 35 | self.cropper = T.CenterCrop((self.size, self.size)) 36 | else: 37 | self.cropper = T.RandomCrop((self.size, self.size)) 38 | self.preprocessor = T.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | if not 
self.original_reso: 50 | image = self.preprocessor(image) 51 | image = np.array(image) 52 | image = (image/127.5 - 1.0).astype(np.float32) 53 | return image 54 | 55 | def __getitem__(self, i): 56 | example = dict() 57 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 58 | for k in self.labels: 59 | example[k] = self.labels[k][i] 60 | return example 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /scripts/evaluation/evaluation_256.sh: -------------------------------------------------------------------------------- 1 | ## GPU and NPU can use the same config for evaluation 2 | ##Open-MAGVIT2 GPU 3 | # python evaluation_image.py --config_file configs/Open-MAGVIT2/gpu/imagenet_lfqgan_256_L.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/in1k_256_L/imagenet_256_L.ckpt --image_size 256 --model Open-MAGVIT2 4 | 5 | ##NPU 6 | ##Open-MAGVIT2 7 | # python evaluation_image.py --config_file configs/Open-MAGVIT2/npu/imagenet_lfqgan_256_L.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/in1k_256_L/imagenet_256_L.ckpt --image_size 256 --model Open-MAGVIT2 8 | 9 | ##NPU 10 | ##Open-MAGVIT2 pretrain 262144 11 | # python evaluation_image.py --config_file configs/Open-MAGVIT2/npu/pretrain_lfqgan_256_262144.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/pretrain_256_262144/pretrain256_262144.ckpt --image_size 256 --model Open-MAGVIT2 12 | 13 | ##NPU 14 | ##Open-MAGVIT2 pretrain 16384 15 | # python evaluation_image.py --config_file configs/Open-MAGVIT2/npu/pretrain_lfqgan_256_16384.yaml --ckpt_path ../upload_ckpts/Open-MAGVIT2/pretrain_256_16384/pretrain256_16384.ckpt --image_size 256 --model Open-MAGVIT2 16 | 17 | 18 | ##IBQ NPU 19 | ## 16384 20 | # python evaluation_image.py --config_file configs/IBQ/npu/imagenet_ibqgan_16384.yaml --ckpt_path ../upload_ckpts/IBQ/in1k_16384/imagenet256_16384.ckpt --image_size 256 --model IBQ 21 | 22 | ## 262144 23 | # python evaluation_image.py --config_file configs/IBQ/npu/imagenet_ibqgan_262144.yaml --ckpt_path ../upload_ckpts/IBQ/in1k_262144/imagenet256_262144.ckpt --image_size 256 --model IBQ 24 | 25 | ## Pretrain 262144 26 | # python evaluation_image.py --config_file configs/IBQ/npu/pretrain_ibqgan_262144.yaml --ckpt_path ../upload_ckpts/IBQ/pretrain_262144/pretrain256_262144.ckpt --image_size 256 --model IBQ 27 | 28 | 29 | ## 8192 30 | # python evaluation_image.py --config_file configs/IBQ/npu/imagenet_ibqgan_8192.yaml --ckpt_path ../upload_ckpts/IBQ/in1k_8192/imagenet256_8192.ckpt --image_size 256 --model IBQ 31 | 32 | ##1024 33 | # python evaluation_image.py --config_file configs/IBQ/npu/imagenet_ibqgan_1024.yaml --ckpt_path ../upload_ckpts/IBQ/in1k_1024/imagenet256_1024.ckpt --image_size 256 --model IBQ 34 | 35 | ## IBQ GPU 36 | ## 16384 37 | # python evaluation_image.py --config_file configs/IBQ/gpu/imagenet_ibqgan_16384.yaml --ckpt_path ../upload_ckpts/IBQ/in1k_16384/imagenet256_16384.ckpt --image_size 256 --model IBQ 38 | 39 | ## 262144 40 | # python evaluation_image.py --config_file configs/IBQ/gpu/imagenet_ibqgan_262144.yaml --ckpt_path ../upload_ckpts/IBQ/in1k_262144/imagenet256_262144.ckpt --image_size 256 
--model IBQ 41 | 42 | ## 8192 43 | # python evaluation_image.py --config_file configs/IBQ/gpu/imagenet_ibqgan_8192.yaml --ckpt_path ../upload_ckpts/IBQ/in1k_8192/imagenet256_8192.ckpt --image_size 256 --model IBQ 44 | 45 | ## 1024 46 | # python evaluation_image.py --config_file configs/IBQ/gpu/imagenet_ibqgan_1024.yaml --ckpt_path ../upload_ckpts/IBQ/in1k_1024/imagenet256_1024.ckpt --image_size 256 --model IBQ -------------------------------------------------------------------------------- /src/IBQ/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, original_reso=False, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | self.original_reso = original_reso 28 | 29 | self.labels = dict() if labels is None else labels 30 | self.labels["file_path_"] = paths 31 | self._length = len(paths) 32 | 33 | if self.size is not None and self.size > 0: 34 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 35 | if not self.random_crop: 36 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 37 | else: 38 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 39 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 40 | else: 41 | self.preprocessor = lambda **kwargs: kwargs 42 | 43 | def __len__(self): 44 | return self._length 45 | 46 | def preprocess_image(self, image_path): 47 | image = Image.open(image_path) 48 | if not image.mode == "RGB": 49 | image = image.convert("RGB") 50 | image = np.array(image).astype(np.uint8) 51 | if not self.original_reso: 52 | image = self.preprocessor(image=image)["image"] 53 | image = (image/127.5 - 1.0).astype(np.float32) 54 | return image 55 | 56 | def __getitem__(self, i): 57 | example = dict() 58 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 59 | for k in self.labels: 60 | example[k] = self.labels[k][i] 61 | return example 62 | 63 | 64 | class NumpyPaths(ImagePaths): 65 | def preprocess_image(self, image_path): 66 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 67 | image = np.transpose(image, (1,2,0)) 68 | image = Image.fromarray(image, mode="RGB") 69 | image = np.array(image).astype(np.uint8) 70 | image = self.preprocessor(image=image)["image"] 71 | image = (image/127.5 - 1.0).astype(np.float32) 72 | return image 73 | -------------------------------------------------------------------------------- /src/Open_MAGVIT2/data/functional.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import cv2 3 | import numpy as np 4 | import PIL 5 | import torch 6 | 7 | 8 | def _is_tensor_clip(clip): 9 | return torch.is_tensor(clip) and 
clip.ndimension() == 4 10 | 11 | 12 | def crop_clip(clip, min_h, min_w, h, w): 13 | if isinstance(clip[0], np.ndarray): 14 | cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] 15 | 16 | elif isinstance(clip[0], PIL.Image.Image): 17 | cropped = [ 18 | img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip 19 | ] 20 | else: 21 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 22 | 'but got list of {0}'.format(type(clip[0]))) 23 | return cropped 24 | 25 | 26 | def resize_clip(clip, size, interpolation='bilinear'): 27 | if isinstance(clip[0], np.ndarray): 28 | if isinstance(size, numbers.Number): 29 | im_h, im_w, im_c = clip[0].shape 30 | # Min spatial dim already matches minimal size 31 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 32 | and im_h == size): 33 | return clip 34 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 35 | size = (new_w, new_h) 36 | else: 37 | size = size[0], size[1] 38 | if interpolation == 'bilinear': 39 | np_inter = cv2.INTER_LINEAR 40 | else: 41 | np_inter = cv2.INTER_NEAREST 42 | scaled = [ 43 | cv2.resize(img, size, interpolation=np_inter) for img in clip 44 | ] 45 | elif isinstance(clip[0], PIL.Image.Image): 46 | if isinstance(size, numbers.Number): 47 | im_w, im_h = clip[0].size 48 | # Min spatial dim already matches minimal size 49 | if (im_w <= im_h and im_w == size) or (im_h <= im_w 50 | and im_h == size): 51 | return clip 52 | new_h, new_w = get_resize_sizes(im_h, im_w, size) 53 | size = (new_w, new_h) 54 | else: 55 | size = size[1], size[0] 56 | if interpolation == 'bilinear': 57 | pil_inter = PIL.Image.BILINEAR 58 | else: 59 | pil_inter = PIL.Image.NEAREST 60 | scaled = [img.resize(size, pil_inter) for img in clip] 61 | else: 62 | raise TypeError('Expected numpy.ndarray or PIL.Image' + 63 | 'but got list of {0}'.format(type(clip[0]))) 64 | return scaled 65 | 66 | 67 | def get_resize_sizes(im_h, im_w, size): 68 | if im_w < im_h: 69 | ow = size 70 | oh = int(size * im_h / im_w) 71 | else: 72 | oh = size 73 | ow = int(size * im_w / im_h) 74 | return oh, ow 75 | 76 | 77 | def normalize(clip, mean, std, inplace=False): 78 | if not _is_tensor_clip(clip): 79 | raise TypeError('tensor is not a torch clip.') 80 | 81 | if not inplace: 82 | clip = clip.clone() 83 | 84 | dtype = clip.dtype 85 | mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) 86 | std = torch.as_tensor(std, dtype=dtype, device=clip.device) 87 | clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) 88 | 89 | return clip -------------------------------------------------------------------------------- /src/IBQ/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.999, use_num_upates=False): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1, dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | # remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.', '') 20 | self.m_name2s_name.update({name: s_name}) 21 | self.register_buffer(s_name, p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def 
reset_num_updates(self): 26 | del self.num_updates 27 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 28 | 29 | def forward(self, model): 30 | decay = self.decay 31 | 32 | if self.num_updates >= 0: 33 | self.num_updates += 1 34 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 35 | 36 | one_minus_decay = 1.0 - decay 37 | 38 | with torch.no_grad(): 39 | m_param = dict(model.named_parameters()) 40 | shadow_params = dict(self.named_buffers()) 41 | 42 | for key in m_param: 43 | if m_param[key].requires_grad: 44 | sname = self.m_name2s_name[key] 45 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 46 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 47 | else: 48 | assert not key in self.m_name2s_name 49 | 50 | def copy_to(self, model): 51 | m_param = dict(model.named_parameters()) 52 | shadow_params = dict(self.named_buffers()) 53 | for key in m_param: 54 | if m_param[key].requires_grad: 55 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def store(self, parameters): 60 | """ 61 | Save the current parameters for restoring later. 62 | Args: 63 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 64 | temporarily stored. 65 | """ 66 | self.collected_params = [param.clone() for param in parameters] 67 | 68 | def restore(self, parameters): 69 | """ 70 | Restore the parameters stored with the `store` method. 71 | Useful to validate the model with EMA parameters without affecting the 72 | original optimization process. Store the parameters before the 73 | `copy_to` method. After validation (or model saving), use this to 74 | restore the former parameters. 75 | Args: 76 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 77 | updated with the stored parameters. 
78 | """ 79 | for c_param, param in zip(self.collected_params, parameters): 80 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /src/Open_MAGVIT2/modules/ema.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to 3 | https://github.com/Stability-AI/stablediffusion/blob/main/ldm/modules/ema.py 4 | """ 5 | 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | class LitEma(nn.Module): 12 | def __init__(self, model, decay=0.999, use_num_upates=False): 13 | super().__init__() 14 | if decay < 0.0 or decay > 1.0: 15 | raise ValueError('Decay must be between 0 and 1') 16 | 17 | self.m_name2s_name = {} 18 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 19 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates 20 | else torch.tensor(-1, dtype=torch.int)) 21 | 22 | for name, p in model.named_parameters(): 23 | if p.requires_grad: 24 | # remove as '.'-character is not allowed in buffers 25 | s_name = name.replace('.', '') 26 | self.m_name2s_name.update({name: s_name}) 27 | self.register_buffer(s_name, p.clone().detach().data) 28 | 29 | self.collected_params = [] 30 | 31 | def reset_num_updates(self): 32 | del self.num_updates 33 | self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) 34 | 35 | def forward(self, model): 36 | decay = self.decay 37 | 38 | if self.num_updates >= 0: 39 | self.num_updates += 1 40 | decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) 41 | 42 | one_minus_decay = 1.0 - decay 43 | 44 | with torch.no_grad(): 45 | m_param = dict(model.named_parameters()) 46 | shadow_params = dict(self.named_buffers()) 47 | 48 | for key in m_param: 49 | if m_param[key].requires_grad: 50 | sname = self.m_name2s_name[key] 51 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 52 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 53 | else: 54 | assert not key in self.m_name2s_name 55 | 56 | def copy_to(self, model): 57 | m_param = dict(model.named_parameters()) 58 | shadow_params = dict(self.named_buffers()) 59 | for key in m_param: 60 | if m_param[key].requires_grad: 61 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 62 | else: 63 | assert not key in self.m_name2s_name 64 | 65 | def store(self, parameters): 66 | """ 67 | Save the current parameters for restoring later. 68 | Args: 69 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 70 | temporarily stored. 71 | """ 72 | self.collected_params = [param.clone() for param in parameters] 73 | 74 | def restore(self, parameters): 75 | """ 76 | Restore the parameters stored with the `store` method. 77 | Useful to validate the model with EMA parameters without affecting the 78 | original optimization process. Store the parameters before the 79 | `copy_to` method. After validation (or model saving), use this to 80 | restore the former parameters. 81 | Args: 82 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 83 | updated with the stored parameters. 
84 | """ 85 | for c_param, param in zip(self.collected_params, parameters): 86 | param.data.copy_(c_param.data) -------------------------------------------------------------------------------- /configs/IBQ/gpu/imagenet_conditional_llama_B.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp_find_unused_parameters_true 6 | devices: 8 7 | num_nodes: 4 8 | precision: 16-mixed 9 | max_epochs: 300 10 | check_val_every_n_epoch: 1 11 | num_sanity_val_steps: 0 12 | gradient_clip_val: 1.0 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: 20 18 | monitor: "train/loss" 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan/" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.cond_transformer_llama.Net2NetTransformer 31 | init_args: 32 | learning_rate: 3e-4 33 | first_stage_key: image 34 | cond_stage_key: class_label 35 | weight_decay: 5e-2 36 | wpe: 0.1 #learning rate decay 37 | wp: 6 38 | wp0: 0.005 39 | twde: 0 40 | transformer_config: 41 | target: src.IBQ.modules.transformer.llama.GPT 42 | params: 43 | vocab_size: 16384 # 262144 tokens 44 | block_size: 256 45 | n_layer: 16 46 | n_head: 16 47 | n_embd: 1024 48 | cond_dim: 1024 49 | resid_dropout_p: 0.1 50 | ffn_dropout_p: 0.1 51 | token_drop: 0.1 52 | drop_path_rate: 0.0 ##not using droppath rate 53 | alng: 1e-3 54 | class_num: 1000 #class tokens 55 | first_stage_config: 56 | target: src.IBQ.models.ibqgan.IBQ 57 | params: 58 | ckpt_path: #specify your path for tokenizer # FID: 1.37 59 | n_embed: 16384 60 | embed_dim: 256 61 | learning_rate: 1e-4 62 | l2_normalize: False 63 | use_entropy_loss: True 64 | sample_minimization_weight: 1.0 65 | batch_maximization_weight: 1.0 66 | entropy_temperature: 0.01 # default 0.01 67 | beta: 0.25 68 | use_ema: True 69 | stage: transformer 70 | ddconfig: 71 | double_z: False 72 | z_channels: 256 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 78 | num_res_blocks: 4 79 | attn_resolutions: [16] 80 | dropout: 0.0 81 | lossconfig: 82 | target: src.IBQ.modules.losses.DummyLoss 83 | cond_stage_config: 84 | target: src.IBQ.modules.util.Labelator 85 | params: 86 | n_classes: 1000 87 | 88 | data: 89 | class_path: main.DataModuleFromConfig 90 | init_args: 91 | batch_size: 24 92 | num_workers: 16 93 | train: 94 | target: src.IBQ.data.imagenet.ImageNetTrain 95 | params: 96 | config: 97 | size: 256 98 | validation: 99 | target: src.IBQ.data.imagenet.ImageNetValidation 100 | params: 101 | config: 102 | size: 256 103 | subset: 104 | test: 105 | target: src.IBQ.data.imagenet.ImageNetValidation 106 | params: 107 | config: 108 | size: 256 109 | subset: 110 | 111 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/imagenet_conditional_llama_B.yaml: -------------------------------------------------------------------------------- 1 | # refer to 
https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: npu 5 | strategy: ddp_find_unused_parameters_true 6 | devices: 8 7 | num_nodes: 4 8 | precision: bf16-mixed 9 | max_epochs: 300 10 | check_val_every_n_epoch: 1 11 | num_sanity_val_steps: 0 12 | gradient_clip_val: 1.0 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: 20 18 | monitor: "train/loss" 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan/" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.cond_transformer_llama.Net2NetTransformer 31 | init_args: 32 | learning_rate: 3e-4 33 | first_stage_key: image 34 | cond_stage_key: class_label 35 | weight_decay: 5e-2 36 | wpe: 0.1 #learning rate decay 37 | wp: 6 38 | wp0: 0.005 39 | twde: 0 40 | transformer_config: 41 | target: src.IBQ.modules.transformer.llama.GPT 42 | params: 43 | vocab_size: 16384 # 262144 tokens 44 | block_size: 256 45 | n_layer: 16 46 | n_head: 16 47 | n_embd: 1024 48 | cond_dim: 1024 49 | resid_dropout_p: 0.1 50 | ffn_dropout_p: 0.1 51 | token_drop: 0.1 52 | drop_path_rate: 0.0 ##not using droppath rate 53 | alng: 1e-3 54 | class_num: 1000 #class tokens 55 | first_stage_config: 56 | target: src.IBQ.models.ibqgan.IBQ 57 | params: 58 | ckpt_path: # specify your path for tokenizer FID: 1.37 59 | n_embed: 16384 60 | embed_dim: 256 61 | learning_rate: 1e-4 62 | l2_normalize: False 63 | use_entropy_loss: True 64 | sample_minimization_weight: 1.0 65 | batch_maximization_weight: 1.0 66 | entropy_temperature: 0.01 # default 0.01 67 | beta: 0.25 68 | use_ema: True 69 | stage: transformer 70 | ddconfig: 71 | double_z: False 72 | z_channels: 256 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 78 | num_res_blocks: 4 79 | attn_resolutions: [16] 80 | dropout: 0.0 81 | lossconfig: 82 | target: src.IBQ.modules.losses.DummyLoss 83 | cond_stage_config: 84 | target: src.IBQ.modules.util.Labelator 85 | params: 86 | n_classes: 1000 87 | 88 | data: 89 | class_path: main.DataModuleFromConfig 90 | init_args: 91 | batch_size: 24 92 | num_workers: 16 93 | train: 94 | target: src.IBQ.data.imagenet.ImageNetTrain 95 | params: 96 | config: 97 | size: 256 98 | validation: 99 | target: src.IBQ.data.imagenet.ImageNetValidation 100 | params: 101 | config: 102 | size: 256 103 | subset: 104 | test: 105 | target: src.IBQ.data.imagenet.ImageNetValidation 106 | params: 107 | config: 108 | size: 256 109 | subset: 110 | 111 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/imagenet_conditional_llama_L.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp_find_unused_parameters_true 6 | devices: 8 7 | num_nodes: 8 8 | precision: 16-mixed 9 | max_epochs: 350 10 | check_val_every_n_epoch: 1 11 | num_sanity_val_steps: 
0 12 | gradient_clip_val: 1.0 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: 20 18 | monitor: "train/loss" 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.cond_transformer_llama.Net2NetTransformer 31 | init_args: 32 | learning_rate: 3e-4 33 | first_stage_key: image 34 | cond_stage_key: class_label 35 | weight_decay: 5e-2 36 | wpe: 0.1 #learning rate decay #1B can be 0.01 37 | wp: 7 38 | wp0: 0.005 39 | twde: 0 40 | transformer_config: 41 | target: src.IBQ.modules.transformer.llama.GPT 42 | params: 43 | vocab_size: 16384 # 262144 tokens 44 | block_size: 256 45 | n_layer: 20 46 | n_head: 20 47 | n_embd: 1280 48 | cond_dim: 1280 49 | resid_dropout_p: 0.1 50 | ffn_dropout_p: 0.1 51 | token_drop: 0.1 52 | drop_path_rate: 0.0 ##not using droppath rate 53 | alng: 1e-3 54 | class_num: 1000 #class tokens 55 | first_stage_config: 56 | target: src.IBQ.models.ibqgan.IBQ 57 | params: 58 | ckpt_path: # specify your path for tokenizer FID: 1.37 59 | n_embed: 16384 60 | embed_dim: 256 61 | learning_rate: 1e-4 62 | l2_normalize: False 63 | use_entropy_loss: True 64 | sample_minimization_weight: 1.0 65 | batch_maximization_weight: 1.0 66 | entropy_temperature: 0.01 # default 0.01 67 | beta: 0.25 68 | use_ema: True 69 | stage: transformer 70 | ddconfig: 71 | double_z: False 72 | z_channels: 256 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 78 | num_res_blocks: 4 79 | attn_resolutions: [16] 80 | dropout: 0.0 81 | lossconfig: 82 | target: src.IBQ.modules.losses.DummyLoss 83 | cond_stage_config: 84 | target: src.IBQ.modules.util.Labelator 85 | params: 86 | n_classes: 1000 87 | 88 | data: 89 | class_path: main.DataModuleFromConfig 90 | init_args: 91 | batch_size: 12 92 | num_workers: 16 93 | train: 94 | target: src.IBQ.data.imagenet.ImageNetTrain 95 | params: 96 | config: 97 | size: 256 98 | validation: 99 | target: src.IBQ.data.imagenet.ImageNetValidation 100 | params: 101 | config: 102 | size: 256 103 | subset: 104 | test: 105 | target: src.IBQ.data.imagenet.ImageNetValidation 106 | params: 107 | config: 108 | size: 256 109 | subset: 110 | 111 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/imagenet_conditional_llama_XL.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp_find_unused_parameters_true 6 | devices: 8 7 | num_nodes: 8 8 | precision: 16-mixed 9 | max_epochs: 400 10 | check_val_every_n_epoch: 1 11 | num_sanity_val_steps: 0 12 | gradient_clip_val: 1.0 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: 20 18 | monitor: "train/loss" 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: 
lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.cond_transformer_llama.Net2NetTransformer 31 | init_args: 32 | learning_rate: 3e-4 33 | first_stage_key: image 34 | cond_stage_key: class_label 35 | weight_decay: 5e-2 36 | wpe: 0.01 #learning rate decay #1B can be 0.01 37 | wp: 8 38 | wp0: 0.005 39 | twde: 0 40 | transformer_config: 41 | target: src.IBQ.modules.transformer.llama.GPT 42 | params: 43 | vocab_size: 16384 # 262144 tokens 44 | block_size: 256 45 | n_layer: 24 46 | n_head: 24 47 | n_embd: 1536 48 | cond_dim: 1536 49 | resid_dropout_p: 0.1 50 | ffn_dropout_p: 0.1 51 | token_drop: 0.1 52 | drop_path_rate: 0.0 ##not using droppath rate 53 | alng: 1e-4 54 | class_num: 1000 #class tokens 55 | first_stage_config: 56 | target: src.IBQ.models.ibqgan.IBQ 57 | params: 58 | ckpt_path: # specify your path for tokenizer FID: 1.37 59 | n_embed: 16384 60 | embed_dim: 256 61 | learning_rate: 1e-4 62 | l2_normalize: False 63 | use_entropy_loss: True 64 | sample_minimization_weight: 1.0 65 | batch_maximization_weight: 1.0 66 | entropy_temperature: 0.01 # default 0.01 67 | beta: 0.25 68 | use_ema: True 69 | stage: transformer 70 | ddconfig: 71 | double_z: False 72 | z_channels: 256 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 78 | num_res_blocks: 4 79 | attn_resolutions: [16] 80 | dropout: 0.0 81 | lossconfig: 82 | target: src.IBQ.modules.losses.DummyLoss 83 | cond_stage_config: 84 | target: src.IBQ.modules.util.Labelator 85 | params: 86 | n_classes: 1000 87 | 88 | data: 89 | class_path: main.DataModuleFromConfig 90 | init_args: 91 | batch_size: 12 92 | num_workers: 16 93 | train: 94 | target: src.IBQ.data.imagenet.ImageNetTrain 95 | params: 96 | config: 97 | size: 256 98 | validation: 99 | target: src.IBQ.data.imagenet.ImageNetValidation 100 | params: 101 | config: 102 | size: 256 103 | subset: 104 | test: 105 | target: src.IBQ.data.imagenet.ImageNetValidation 106 | params: 107 | config: 108 | size: 256 109 | subset: 110 | 111 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/imagenet_conditional_llama_L.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: npu 5 | strategy: ddp_find_unused_parameters_true 6 | devices: 8 7 | num_nodes: 8 8 | precision: bf16-mixed 9 | max_epochs: 350 10 | check_val_every_n_epoch: 1 11 | num_sanity_val_steps: 0 12 | gradient_clip_val: 1.0 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: 20 18 | monitor: "train/loss" 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.cond_transformer_llama.Net2NetTransformer 31 | init_args: 32 | learning_rate: 3e-4 33 | first_stage_key: image 34 | cond_stage_key: class_label 35 | weight_decay: 5e-2 36 | wpe: 0.1 #learning rate 
decay #1B can be 0.01 37 | wp: 7 38 | wp0: 0.005 39 | twde: 0 40 | transformer_config: 41 | target: src.IBQ.modules.transformer.llama.GPT 42 | params: 43 | vocab_size: 16384 # 262144 tokens 44 | block_size: 256 45 | n_layer: 20 46 | n_head: 20 47 | n_embd: 1280 48 | cond_dim: 1280 49 | resid_dropout_p: 0.1 50 | ffn_dropout_p: 0.1 51 | token_drop: 0.1 52 | drop_path_rate: 0.0 ##not using droppath rate 53 | alng: 1e-3 54 | class_num: 1000 #class tokens 55 | first_stage_config: 56 | target: src.IBQ.models.ibqgan.IBQ 57 | params: 58 | ckpt_path: #specify your path for tokenizer FID: 1.37 59 | n_embed: 16384 60 | embed_dim: 256 61 | learning_rate: 1e-4 62 | l2_normalize: False 63 | use_entropy_loss: True 64 | sample_minimization_weight: 1.0 65 | batch_maximization_weight: 1.0 66 | entropy_temperature: 0.01 # default 0.01 67 | beta: 0.25 68 | use_ema: True 69 | stage: transformer 70 | ddconfig: 71 | double_z: False 72 | z_channels: 256 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 78 | num_res_blocks: 4 79 | attn_resolutions: [16] 80 | dropout: 0.0 81 | lossconfig: 82 | target: src.IBQ.modules.losses.DummyLoss 83 | cond_stage_config: 84 | target: src.IBQ.modules.util.Labelator 85 | params: 86 | n_classes: 1000 87 | 88 | data: 89 | class_path: main.DataModuleFromConfig 90 | init_args: 91 | batch_size: 12 92 | num_workers: 16 93 | train: 94 | target: src.IBQ.data.imagenet.ImageNetTrain 95 | params: 96 | config: 97 | size: 256 98 | validation: 99 | target: src.IBQ.data.imagenet.ImageNetValidation 100 | params: 101 | config: 102 | size: 256 103 | subset: 104 | test: 105 | target: src.IBQ.data.imagenet.ImageNetValidation 106 | params: 107 | config: 108 | size: 256 109 | subset: 110 | 111 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/imagenet_conditional_llama_XL.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: npu 5 | strategy: ddp_find_unused_parameters_true 6 | devices: 8 7 | num_nodes: 8 8 | precision: bf16-mixed 9 | max_epochs: 400 10 | check_val_every_n_epoch: 1 11 | num_sanity_val_steps: 0 12 | gradient_clip_val: 1.0 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: 20 18 | monitor: "train/loss" 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.cond_transformer_llama.Net2NetTransformer 31 | init_args: 32 | learning_rate: 3e-4 33 | first_stage_key: image 34 | cond_stage_key: class_label 35 | weight_decay: 5e-2 36 | wpe: 0.01 #learning rate decay #1B can be 0.01 37 | wp: 8 38 | wp0: 0.005 39 | twde: 0 40 | transformer_config: 41 | target: src.IBQ.modules.transformer.llama.GPT 42 | params: 43 | vocab_size: 16384 # 262144 tokens 44 | block_size: 256 45 | n_layer: 24 46 | n_head: 24 47 | n_embd: 1536 48 | cond_dim: 1536 49 | resid_dropout_p: 0.1 50 | ffn_dropout_p: 0.1 51 | token_drop: 0.1 52 | drop_path_rate: 0.0 
##not using droppath rate 53 | alng: 1e-4 54 | class_num: 1000 #class tokens 55 | first_stage_config: 56 | target: src.IBQ.models.ibqgan.IBQ 57 | params: 58 | ckpt_path: #specify your path for tokenizer FID: 1.37 59 | n_embed: 16384 60 | embed_dim: 256 61 | learning_rate: 1e-4 62 | l2_normalize: False 63 | use_entropy_loss: True 64 | sample_minimization_weight: 1.0 65 | batch_maximization_weight: 1.0 66 | entropy_temperature: 0.01 # default 0.01 67 | beta: 0.25 68 | use_ema: True 69 | stage: transformer 70 | ddconfig: 71 | double_z: False 72 | z_channels: 256 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 78 | num_res_blocks: 4 79 | attn_resolutions: [16] 80 | dropout: 0.0 81 | lossconfig: 82 | target: src.IBQ.modules.losses.DummyLoss 83 | cond_stage_config: 84 | target: src.IBQ.modules.util.Labelator 85 | params: 86 | n_classes: 1000 87 | 88 | data: 89 | class_path: main.DataModuleFromConfig 90 | init_args: 91 | batch_size: 12 92 | num_workers: 16 93 | train: 94 | target: src.IBQ.data.imagenet.ImageNetTrain 95 | params: 96 | config: 97 | size: 256 98 | validation: 99 | target: src.IBQ.data.imagenet.ImageNetValidation 100 | params: 101 | config: 102 | size: 256 103 | subset: 104 | test: 105 | target: src.IBQ.data.imagenet.ImageNetValidation 106 | params: 107 | config: 108 | size: 256 109 | subset: 110 | 111 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/imagenet_conditional_llama_XXL.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp_find_unused_parameters_true 6 | devices: 8 7 | num_nodes: 12 8 | precision: 16-mixed 9 | max_epochs: 450 10 | check_val_every_n_epoch: 1 11 | num_sanity_val_steps: 0 12 | gradient_clip_val: 1.0 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: 20 18 | monitor: "train/loss" 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.cond_transformer_llama.Net2NetTransformer 31 | init_args: 32 | learning_rate: 3e-4 33 | first_stage_key: image 34 | cond_stage_key: class_label 35 | weight_decay: 5e-2 36 | wpe: 0.01 #learning rate decay #1B can be 0.01 37 | wp: 9 38 | wp0: 0.005 39 | twde: 0.08 40 | transformer_config: 41 | target: src.IBQ.modules.transformer.llama.GPT 42 | params: 43 | vocab_size: 16384 # 262144 tokens 44 | block_size: 256 45 | n_layer: 30 46 | n_head: 30 47 | n_embd: 1920 48 | cond_dim: 1920 49 | resid_dropout_p: 0.1 50 | ffn_dropout_p: 0.1 51 | token_drop: 0.1 52 | drop_path_rate: 0.0 ##not using droppath rate 53 | alng: 1e-5 54 | class_num: 1000 #class tokens 55 | first_stage_config: 56 | target: src.IBQ.models.ibqgan.IBQ 57 | params: 58 | ckpt_path: # specify your path for tokenizer FID: 1.37 59 | n_embed: 16384 60 | embed_dim: 256 61 | learning_rate: 1e-4 62 | l2_normalize: False 63 | use_entropy_loss: True 64 | sample_minimization_weight: 1.0 65 | 
batch_maximization_weight: 1.0 66 | entropy_temperature: 0.01 # default 0.01 67 | beta: 0.25 68 | use_ema: True 69 | stage: transformer 70 | ddconfig: 71 | double_z: False 72 | z_channels: 256 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 78 | num_res_blocks: 4 79 | attn_resolutions: [16] 80 | dropout: 0.0 81 | lossconfig: 82 | target: src.IBQ.modules.losses.DummyLoss 83 | cond_stage_config: 84 | target: src.IBQ.modules.util.Labelator 85 | params: 86 | n_classes: 1000 87 | 88 | data: 89 | class_path: main.DataModuleFromConfig 90 | init_args: 91 | batch_size: 8 92 | num_workers: 16 93 | train: 94 | target: src.IBQ.data.imagenet.ImageNetTrain 95 | params: 96 | config: 97 | size: 256 98 | validation: 99 | target: src.IBQ.data.imagenet.ImageNetValidation 100 | params: 101 | config: 102 | size: 256 103 | subset: 104 | test: 105 | target: src.IBQ.data.imagenet.ImageNetValidation 106 | params: 107 | config: 108 | size: 256 109 | subset: 110 | 111 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/imagenet_conditional_llama_XXL.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: npu 5 | strategy: ddp_find_unused_parameters_true 6 | devices: 8 7 | num_nodes: 12 8 | precision: bf16-mixed 9 | max_epochs: 450 10 | check_val_every_n_epoch: 1 11 | num_sanity_val_steps: 0 12 | gradient_clip_val: 1.0 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: 20 18 | monitor: "train/loss" 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.cond_transformer_llama.Net2NetTransformer 31 | init_args: 32 | learning_rate: 3e-4 33 | first_stage_key: image 34 | cond_stage_key: class_label 35 | weight_decay: 5e-2 36 | wpe: 0.01 #learning rate decay #1B can be 0.01 37 | wp: 9 38 | wp0: 0.005 39 | twde: 0.08 40 | transformer_config: 41 | target: src.IBQ.modules.transformer.llama.GPT 42 | params: 43 | vocab_size: 16384 # 262144 tokens 44 | block_size: 256 45 | n_layer: 30 46 | n_head: 30 47 | n_embd: 1920 48 | cond_dim: 1920 49 | resid_dropout_p: 0.1 50 | ffn_dropout_p: 0.1 51 | token_drop: 0.1 52 | drop_path_rate: 0.0 ##not using droppath rate 53 | alng: 1e-5 54 | class_num: 1000 #class tokens 55 | first_stage_config: 56 | target: src.IBQ.models.ibqgan.IBQ 57 | params: 58 | ckpt_path: #specify your path for tokenizer FID: 1.37 59 | n_embed: 16384 60 | embed_dim: 256 61 | learning_rate: 1e-4 62 | l2_normalize: False 63 | use_entropy_loss: True 64 | sample_minimization_weight: 1.0 65 | batch_maximization_weight: 1.0 66 | entropy_temperature: 0.01 # default 0.01 67 | beta: 0.25 68 | use_ema: True 69 | stage: transformer 70 | ddconfig: 71 | double_z: False 72 | z_channels: 256 73 | resolution: 256 74 | in_channels: 3 75 | out_ch: 3 76 | ch: 128 77 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 78 | num_res_blocks: 4 79 | attn_resolutions: [16] 80 | 
dropout: 0.0 81 | lossconfig: 82 | target: src.IBQ.modules.losses.DummyLoss 83 | cond_stage_config: 84 | target: src.IBQ.modules.util.Labelator 85 | params: 86 | n_classes: 1000 87 | 88 | data: 89 | class_path: main.DataModuleFromConfig 90 | init_args: 91 | batch_size: 8 92 | num_workers: 16 93 | train: 94 | target: src.IBQ.data.imagenet.ImageNetTrain 95 | params: 96 | config: 97 | size: 256 98 | validation: 99 | target: src.IBQ.data.imagenet.ImageNetValidation 100 | params: 101 | config: 102 | size: 256 103 | subset: 104 | test: 105 | target: src.IBQ.data.imagenet.ImageNetValidation 106 | params: 107 | config: 108 | size: 256 109 | subset: 110 | 111 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/gpu/ucf101_lfqfan_128_L.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: 16-mixed 8 | max_epochs: 2000 9 | check_val_every_n_epoch: 20 10 | num_sanity_val_steps: -1 11 | log_every_n_steps: 100 12 | callbacks: 13 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: "../../checkpoints/vqgan/test" 16 | save_top_k: -1 # save all checkpoints 17 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 18 | init_args: 19 | logging_interval: step 20 | logger: 21 | class_path: lightning.pytorch.loggers.TensorBoardLogger 22 | init_args: 23 | save_dir: "../../results/vqgan/" 24 | version: "test" 25 | name: 26 | 27 | model: 28 | class_path: src.Open_MAGVIT2.models.video_lfqgan.VQModel 29 | init_args: 30 | ddconfig: 31 | double_z: False 32 | z_channels: 18 33 | resolution: 128 34 | in_channels: 3 35 | out_ch: 3 36 | ch: 128 37 | ch_mult: [1,2,2,4] # num_down = len(ch_mult)-1 38 | num_res_blocks: 4 39 | 40 | lossconfig: 41 | target: src.Open_MAGVIT2.modules.losses.video_vqperceptual.VQLPIPSWithDiscriminator 42 | params: 43 | disc_conditional: False 44 | disc_in_channels: 3 45 | disc_num_layers: 3 46 | disc_start: 0 # from 0 epoch 47 | disc_weight: 0.8 48 | gen_loss_weight: 0.3 #see if it is reuqired to change 49 | lecam_loss_weight: 0.05 ##increase to 0.1 50 | codebook_weight: 0.1 51 | commit_weight: 0.25 52 | codebook_enlarge_ratio: 0 53 | codebook_enlarge_steps: 2000 54 | 55 | n_embed: 262144 56 | embed_dim: 18 57 | learning_rate: 1e-4 58 | sample_minimization_weight: 1.0 59 | batch_maximization_weight: 1.0 60 | scheduler_type: "None" 61 | use_ema: True 62 | image_pretrain_path: ../upload_ckpts/Open-MAGVIT2/in1k_128_L/imagenet_128_L.ckpt 63 | sche_type: cos 64 | wpe: 0.01 ## learning rate decay to final lr to 0 65 | wp: 2 ##one epoch for linear warmup 66 | wp0: 0.0 ##for warmup from zero 67 | max_iter: 68 | wp_iter: 69 | resume_lr: 70 | 71 | data: 72 | class_path: main.DataModuleFromConfig 73 | init_args: 74 | batch_size: 2 75 | num_workers: 16 76 | train: 77 | target: src.Open_MAGVIT2.data.ucf101.VideoDataset 78 | params: 79 | config: 80 | data_folder: ../../data/UCF-101 81 | size: 128 82 | mode: train 83 | sequence_length: 17 84 | sample_every_n_frames: 1 85 | frame_sample_rate: 4 86 | subset: 87 | validation: 88 | target: src.Open_MAGVIT2.data.ucf101.VideoDataset 89 | params: 90 | config: 91 | data_folder: ../../data/UCF-101 92 | size: 128 93 | mode: validation 94 | sequence_length: 17 95 | sample_every_n_frames: 1 96 | frame_sample_rate: 4 97 | subset: 98 | test: 99 | target: 
src.Open_MAGVIT2.data.ucf101.VideoDataset 100 | params: 101 | config: 102 | data_folder: ../../data/UCF-101 103 | size: 128 104 | mode: test 105 | sequence_length: 17 106 | sample_every_n_frames: 1 107 | frame_sample_rate: 4 108 | subset: 109 | 110 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/npu/ucf101_lfqgan_128_L.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 8 7 | precision: bf16-mixed 8 | max_epochs: 2000 9 | check_val_every_n_epoch: 20 10 | num_sanity_val_steps: -1 11 | log_every_n_steps: 100 12 | callbacks: 13 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 14 | init_args: 15 | dirpath: "../../checkpoints/vqgan/test" 16 | save_top_k: -1 # save all checkpoints 17 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 18 | init_args: 19 | logging_interval: step 20 | logger: 21 | class_path: lightning.pytorch.loggers.TensorBoardLogger 22 | init_args: 23 | save_dir: "../../results/vqgan/" 24 | version: "test" 25 | name: 26 | 27 | model: 28 | class_path: src.Open_MAGVIT2.models.video_lfqgan.VQModel 29 | init_args: 30 | ddconfig: 31 | double_z: False 32 | z_channels: 18 33 | resolution: 128 34 | in_channels: 3 35 | out_ch: 3 36 | ch: 128 37 | ch_mult: [1,2,2,4] # num_down = len(ch_mult)-1 38 | num_res_blocks: 4 39 | 40 | lossconfig: 41 | target: src.Open_MAGVIT2.modules.losses.video_vqperceptual.VQLPIPSWithDiscriminator 42 | params: 43 | disc_conditional: False 44 | disc_in_channels: 3 45 | disc_num_layers: 3 46 | disc_start: 0 # from 0 epoch 47 | disc_weight: 0.8 48 | gen_loss_weight: 0.3 #see if it is reuqired to change 49 | lecam_loss_weight: 0.05 ##increase to 0.1 50 | codebook_weight: 0.1 51 | commit_weight: 0.25 52 | codebook_enlarge_ratio: 0 53 | codebook_enlarge_steps: 2000 54 | 55 | n_embed: 262144 56 | embed_dim: 18 57 | learning_rate: 1e-4 58 | sample_minimization_weight: 1.0 59 | batch_maximization_weight: 1.0 60 | scheduler_type: "None" 61 | use_ema: True 62 | image_pretrain_path: ../upload_ckpts/Open-MAGVIT2/in1k_128_L/imagenet_128_L.ckpt 63 | sche_type: cos 64 | wpe: 0.01 ## learning rate decay to final lr to 0 65 | wp: 2 ##one epoch for linear warmup 66 | wp0: 0.0 ##for warmup from zero 67 | max_iter: 68 | wp_iter: 69 | resume_lr: 70 | 71 | data: 72 | class_path: main.DataModuleFromConfig 73 | init_args: 74 | batch_size: 2 75 | num_workers: 16 76 | train: 77 | target: src.Open_MAGVIT2.data.ucf101.VideoDataset 78 | params: 79 | config: 80 | data_folder: ../../data/UCF-101 81 | size: 128 82 | mode: train 83 | sequence_length: 17 84 | sample_every_n_frames: 1 85 | frame_sample_rate: 4 86 | subset: 87 | validation: 88 | target: src.Open_MAGVIT2.data.ucf101.VideoDataset 89 | params: 90 | config: 91 | data_folder: ../../data/UCF-101 92 | size: 128 93 | mode: validation 94 | sequence_length: 17 95 | sample_every_n_frames: 1 96 | frame_sample_rate: 4 97 | subset: 98 | test: 99 | target: src.Open_MAGVIT2.data.ucf101.VideoDataset 100 | params: 101 | config: 102 | data_folder: ../../data/UCF-101 103 | size: 128 104 | mode: test 105 | sequence_length: 17 106 | sample_every_n_frames: 1 107 | frame_sample_rate: 4 108 | subset: 109 | 110 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/gpu/imagenet_conditional_llama_B.yaml: 
-------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp_find_unused_parameters_true 6 | # strategy: 7 | # class_path: lightning.pytorch.strategies.FSDPStrategy 8 | # init_args: 9 | # sharding_strategy: "SHARD_GRAD_OP" 10 | devices: 8 11 | num_nodes: 4 12 | precision: 16-mixed 13 | max_epochs: 300 14 | check_val_every_n_epoch: 1 15 | num_sanity_val_steps: -1 16 | gradient_clip_val: 1.0 17 | callbacks: 18 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 19 | init_args: 20 | dirpath: "../../checkpoints/vqgan/test" 21 | save_top_k: 20 22 | monitor: "train/loss" 23 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 24 | init_args: 25 | logging_interval: step 26 | logger: 27 | class_path: lightning.pytorch.loggers.TensorBoardLogger 28 | init_args: 29 | save_dir: ../../results/vqgan" 30 | version: "test" 31 | name: 32 | 33 | model: 34 | class_path: src.Open_MAGVIT2.models.cond_transformer_gpt.Net2NetTransformer 35 | init_args: 36 | learning_rate: 3e-4 37 | first_stage_key: image 38 | cond_stage_key: class_label 39 | token_factorization: True 40 | weight_decay: 5e-2 41 | wpe: 0.1 ## learning rate decay 42 | wp: 6 ##no warmup 43 | wp0: 0.005 ##for warmup 44 | twde: 0 45 | transformer_config: 46 | target: src.Open_MAGVIT2.modules.transformer.gpt.GPT 47 | params: 48 | # vocab_size: 262144 # 262144 tokens 49 | vocab_size: 512 50 | block_size: 256 51 | spatial_n_layer: 24 ## follow LlamaGen 52 | factorized_n_layer: 2 53 | factorized_bits: [6, 12] #asymmetrical head 54 | n_head: 16 55 | dim: 1024 56 | cond_dim: 1024 57 | token_drop: 0.1 58 | resid_dropout_p: 0.1 59 | token_factorization: True 60 | class_num: 1000 #class tokens 61 | first_stage_config: 62 | target: src.Open_MAGVIT2.models.lfqgan.VQModel 63 | params: 64 | ckpt_path: #specify your path for tokenizer FID: 1.17 65 | n_embed: 262144 66 | embed_dim: 18 67 | learning_rate: 1e-4 68 | sample_minimization_weight: 1.0 69 | batch_maximization_weight: 1.0 70 | scheduler_type: "None" 71 | use_ema: False 72 | stage: "transformer" 73 | token_factorization: True 74 | factorized_bits: [6, 12] 75 | ddconfig: 76 | double_z: False 77 | z_channels: 18 78 | resolution: 128 79 | in_channels: 3 80 | out_ch: 3 81 | ch: 128 82 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 83 | num_res_blocks: 4 84 | lossconfig: 85 | target: src.Open_MAGVIT2.modules.losses.DummyLoss 86 | cond_stage_config: 87 | target: src.Open_MAGVIT2.modules.util.Labelator 88 | params: 89 | n_classes: 1000 90 | permuter_config: 91 | target: src.Open_MAGVIT2.modules.transformer.permuter.ShiftPermuter 92 | params: 93 | shift_pos: 1000 # num_classes 94 | 95 | data: 96 | class_path: main.DataModuleFromConfig 97 | init_args: 98 | batch_size: 24 99 | num_workers: 16 100 | train: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | validation: 107 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 108 | params: 109 | config: 110 | size: 256 111 | subset: 112 | test: 113 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 114 | params: 115 | config: 116 | size: 256 117 | subset: 118 | 119 | ckpt_path: null # to resume -------------------------------------------------------------------------------- 
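Note: the conditional-transformer configs in this directory mix two conventions. The top-level `trainer`, `model`, and `data` sections follow LightningCLI's `class_path`/`init_args` scheme, while nested blocks such as `lossconfig`, `cond_stage_config`, `permuter_config`, and the tokenizer's `first_stage_config` use the repo's `target`/`params` scheme, resolved by `instantiate_from_config` in `main.py`. The sketch below mirrors that helper and shows how the `cond_stage_config` block from the B config above would be resolved; it is illustrative only and assumes the repo root is on `PYTHONPATH`.

```python
import importlib

def get_obj_from_str(string):
    # "pkg.module.Class" -> the class object (mirrors main.py)
    module, cls = string.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config):
    # Resolve a {"target": ..., "params": {...}} block into an object
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))

# The cond_stage_config block from the config above:
cond_stage_config = {
    "target": "src.Open_MAGVIT2.modules.util.Labelator",
    "params": {"n_classes": 1000},
}
# labelator = instantiate_from_config(cond_stage_config)  # requires the repo on PYTHONPATH
# labelator.encode(class_labels) then wraps class indices into conditioning tokens
```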
/configs/Open-MAGVIT2/gpu/imagenet_conditional_llama_L.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp_find_unused_parameters_true 6 | # strategy: 7 | # class_path: lightning.pytorch.strategies.FSDPStrategy 8 | # init_args: 9 | # sharding_strategy: "SHARD_GRAD_OP" 10 | devices: 8 11 | num_nodes: 8 12 | precision: 16-mixed 13 | max_epochs: 300 14 | check_val_every_n_epoch: 1 15 | num_sanity_val_steps: -1 16 | gradient_clip_val: 1.0 17 | callbacks: 18 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 19 | init_args: 20 | dirpath: "../../checkpoints/vqgan/test" 21 | save_top_k: 20 22 | monitor: "train/loss" 23 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 24 | init_args: 25 | logging_interval: step 26 | logger: 27 | class_path: lightning.pytorch.loggers.TensorBoardLogger 28 | init_args: 29 | save_dir: ../../results/vqgan" 30 | version: "test" 31 | name: 32 | 33 | model: 34 | class_path: src.Open_MAGVIT2.models.cond_transformer_gpt.Net2NetTransformer 35 | init_args: 36 | learning_rate: 3e-4 37 | first_stage_key: image 38 | cond_stage_key: class_label 39 | token_factorization: True 40 | weight_decay: 5e-2 41 | wpe: 0.1 ## learning rate decay 42 | wp: 6 ##no warmup 43 | wp0: 0.005 ##for warmup 44 | twde: 0 45 | transformer_config: 46 | target: src.Open_MAGVIT2.modules.transformer.gpt.GPT 47 | params: 48 | # vocab_size: 262144 # 262144 tokens 49 | vocab_size: 512 50 | block_size: 256 51 | spatial_n_layer: 36 ## follow LlamaGen 52 | factorized_n_layer: 3 53 | factorized_bits: [6, 12] #asymmetrical head 54 | n_head: 20 55 | dim: 1280 56 | cond_dim: 1280 57 | token_drop: 0.1 58 | resid_dropout_p: 0.1 59 | token_factorization: True 60 | class_num: 1000 #class tokens 61 | first_stage_config: 62 | target: src.Open_MAGVIT2.models.lfqgan.VQModel 63 | params: 64 | ckpt_path: #specify your path for tokenizer FID: 1.17 65 | n_embed: 262144 66 | embed_dim: 18 67 | learning_rate: 1e-4 68 | sample_minimization_weight: 1.0 69 | batch_maximization_weight: 1.0 70 | scheduler_type: "None" 71 | use_ema: False 72 | stage: "transformer" 73 | token_factorization: True 74 | factorized_bits: [6, 12] 75 | ddconfig: 76 | double_z: False 77 | z_channels: 18 78 | resolution: 128 79 | in_channels: 3 80 | out_ch: 3 81 | ch: 128 82 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 83 | num_res_blocks: 4 84 | lossconfig: 85 | target: src.Open_MAGVIT2.modules.losses.DummyLoss 86 | cond_stage_config: 87 | target: src.Open_MAGVIT2.modules.util.Labelator 88 | params: 89 | n_classes: 1000 90 | permuter_config: 91 | target: src.Open_MAGVIT2.modules.transformer.permuter.ShiftPermuter 92 | params: 93 | shift_pos: 1000 # num_classes 94 | 95 | data: 96 | class_path: main.DataModuleFromConfig 97 | init_args: 98 | batch_size: 12 99 | num_workers: 16 100 | train: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | validation: 107 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 108 | params: 109 | config: 110 | size: 256 111 | subset: 112 | test: 113 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 114 | params: 115 | config: 116 | size: 256 117 | subset: 118 | 119 | ckpt_path: null # to resume 
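One further reading note on the GPU conditional_llama configs: the per-device `batch_size` is scaled down as the transformer and the node count grow (24 for B on 4 nodes, 12 for L on 8 nodes, and 8 for the XL config that follows, on 12 nodes), so the effective global batch stays at 768 samples per optimizer step, assuming no gradient accumulation is configured elsewhere. A quick check:

```python
# Effective global batch = batch_size * devices * num_nodes,
# with values taken from the three GPU conditional_llama configs.
configs = {
    "B": {"batch_size": 24, "devices": 8, "num_nodes": 4},
    "L": {"batch_size": 12, "devices": 8, "num_nodes": 8},
    "XL": {"batch_size": 8, "devices": 8, "num_nodes": 12},
}
for name, c in configs.items():
    print(name, c["batch_size"] * c["devices"] * c["num_nodes"])  # 768 for each
```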
-------------------------------------------------------------------------------- /configs/Open-MAGVIT2/gpu/imagenet_conditional_llama_XL.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: gpu 5 | strategy: ddp_find_unused_parameters_true 6 | # strategy: 7 | # class_path: lightning.pytorch.strategies.FSDPStrategy 8 | # init_args: 9 | # sharding_strategy: "SHARD_GRAD_OP" 10 | devices: 8 11 | num_nodes: 12 12 | precision: 16-mixed 13 | max_epochs: 350 14 | check_val_every_n_epoch: 1 15 | num_sanity_val_steps: -1 16 | gradient_clip_val: 1.0 17 | callbacks: 18 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 19 | init_args: 20 | dirpath: "../../checkpoints/vqgan/test" 21 | save_top_k: 20 22 | monitor: "train/loss" 23 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 24 | init_args: 25 | logging_interval: step 26 | logger: 27 | class_path: lightning.pytorch.loggers.TensorBoardLogger 28 | init_args: 29 | save_dir: ../../results/vqgan" 30 | version: "test" 31 | name: 32 | 33 | model: 34 | class_path: src.Open_MAGVIT2.models.cond_transformer_gpt.Net2NetTransformer 35 | init_args: 36 | learning_rate: 2.4e-4 37 | first_stage_key: image 38 | cond_stage_key: class_label 39 | token_factorization: True 40 | weight_decay: 5e-2 41 | wpe: 0.1 ## learning rate decay 42 | wp: 7 ##no warmup 43 | wp0: 0.005 ##for warmup 44 | twde: 0 45 | transformer_config: 46 | target: src.Open_MAGVIT2.modules.transformer.gpt.GPT 47 | params: 48 | # vocab_size: 262144 # 262144 tokens 49 | vocab_size: 512 50 | block_size: 256 51 | spatial_n_layer: 48 ## follow LlamaGen 52 | factorized_n_layer: 4 53 | factorized_bits: [6, 12] #asymmetrical head 54 | n_head: 24 55 | dim: 1536 56 | cond_dim: 1536 57 | token_drop: 0.1 58 | resid_dropout_p: 0.1 59 | token_factorization: True 60 | class_num: 1000 #class tokens 61 | first_stage_config: 62 | target: src.Open_MAGVIT2.models.lfqgan.VQModel 63 | params: 64 | ckpt_path: #specify your path for tokenizer FID: 1.17 65 | n_embed: 262144 66 | embed_dim: 18 67 | learning_rate: 1e-4 68 | sample_minimization_weight: 1.0 69 | batch_maximization_weight: 1.0 70 | scheduler_type: "None" 71 | use_ema: False 72 | stage: "transformer" 73 | token_factorization: True 74 | factorized_bits: [6, 12] 75 | ddconfig: 76 | double_z: False 77 | z_channels: 18 78 | resolution: 128 79 | in_channels: 3 80 | out_ch: 3 81 | ch: 128 82 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 83 | num_res_blocks: 4 84 | lossconfig: 85 | target: src.Open_MAGVIT2.modules.losses.DummyLoss 86 | cond_stage_config: 87 | target: src.Open_MAGVIT2.modules.util.Labelator 88 | params: 89 | n_classes: 1000 90 | permuter_config: 91 | target: src.Open_MAGVIT2.modules.transformer.permuter.ShiftPermuter 92 | params: 93 | shift_pos: 1000 # num_classes 94 | 95 | data: 96 | class_path: main.DataModuleFromConfig 97 | init_args: 98 | batch_size: 8 99 | num_workers: 16 100 | train: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | validation: 107 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 108 | params: 109 | config: 110 | size: 256 111 | subset: 112 | test: 113 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 114 | params: 115 | config: 116 | size: 256 117 | subset: 
118 | 119 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/npu/imagenet_conditional_llama_B.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: npu 5 | strategy: ddp_find_unused_parameters_true 6 | # strategy: 7 | # class_path: lightning.pytorch.strategies.FSDPStrategy 8 | # init_args: 9 | # sharding_strategy: "SHARD_GRAD_OP" 10 | devices: 8 11 | num_nodes: 4 12 | precision: bf16-mixed 13 | max_epochs: 300 14 | check_val_every_n_epoch: 1 15 | num_sanity_val_steps: -1 16 | gradient_clip_val: 1.0 17 | callbacks: 18 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 19 | init_args: 20 | dirpath: "../../checkpoints/vqgan/test" 21 | save_top_k: 20 22 | monitor: "train/loss" 23 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 24 | init_args: 25 | logging_interval: step 26 | logger: 27 | class_path: lightning.pytorch.loggers.TensorBoardLogger 28 | init_args: 29 | save_dir: ../../results/vqgan" 30 | version: "test" 31 | name: 32 | 33 | model: 34 | class_path: src.Open_MAGVIT2.models.cond_transformer_gpt.Net2NetTransformer 35 | init_args: 36 | learning_rate: 3e-4 37 | first_stage_key: image 38 | cond_stage_key: class_label 39 | token_factorization: True 40 | weight_decay: 5e-2 41 | wpe: 0.1 ## learning rate decay 42 | wp: 6 ##no warmup 43 | wp0: 0.005 ##for warmup 44 | twde: 0 45 | transformer_config: 46 | target: src.Open_MAGVIT2.modules.transformer.gpt.GPT 47 | params: 48 | # vocab_size: 262144 # 262144 tokens 49 | vocab_size: 512 50 | block_size: 256 51 | spatial_n_layer: 24 ## follow LlamaGen 52 | factorized_n_layer: 2 53 | factorized_bits: [6, 12] #asymmetrical head 54 | n_head: 16 55 | dim: 1024 56 | cond_dim: 1024 57 | token_drop: 0.1 58 | resid_dropout_p: 0.1 59 | token_factorization: True 60 | class_num: 1000 #class tokens 61 | first_stage_config: 62 | target: src.Open_MAGVIT2.models.lfqgan.VQModel 63 | params: 64 | ckpt_path: #specify your path for tokenizer FID: 1.17 65 | n_embed: 262144 66 | embed_dim: 18 67 | learning_rate: 1e-4 68 | sample_minimization_weight: 1.0 69 | batch_maximization_weight: 1.0 70 | scheduler_type: "None" 71 | use_ema: False 72 | stage: "transformer" 73 | token_factorization: True 74 | factorized_bits: [6, 12] 75 | ddconfig: 76 | double_z: False 77 | z_channels: 18 78 | resolution: 128 79 | in_channels: 3 80 | out_ch: 3 81 | ch: 128 82 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 83 | num_res_blocks: 4 84 | lossconfig: 85 | target: src.Open_MAGVIT2.modules.losses.DummyLoss 86 | cond_stage_config: 87 | target: src.Open_MAGVIT2.modules.util.Labelator 88 | params: 89 | n_classes: 1000 90 | permuter_config: 91 | target: src.Open_MAGVIT2.modules.transformer.permuter.ShiftPermuter 92 | params: 93 | shift_pos: 1000 # num_classes 94 | 95 | data: 96 | class_path: main.DataModuleFromConfig 97 | init_args: 98 | batch_size: 24 99 | num_workers: 16 100 | train: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | validation: 107 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 108 | params: 109 | config: 110 | size: 256 111 | subset: 112 | test: 113 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 114 | params: 115 | 
config: 116 | size: 256 117 | subset: 118 | 119 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/npu/imagenet_conditional_llama_L.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: npu 5 | strategy: ddp_find_unused_parameters_true 6 | # strategy: 7 | # class_path: lightning.pytorch.strategies.FSDPStrategy 8 | # init_args: 9 | # sharding_strategy: "SHARD_GRAD_OP" 10 | devices: 8 11 | num_nodes: 8 12 | precision: bf16-mixed 13 | max_epochs: 300 14 | check_val_every_n_epoch: 1 15 | num_sanity_val_steps: -1 16 | gradient_clip_val: 1.0 17 | callbacks: 18 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 19 | init_args: 20 | dirpath: "../../checkpoints/vqgan/test" 21 | save_top_k: 20 22 | monitor: "train/loss" 23 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 24 | init_args: 25 | logging_interval: step 26 | logger: 27 | class_path: lightning.pytorch.loggers.TensorBoardLogger 28 | init_args: 29 | save_dir: ../../results/vqgan" 30 | version: "test" 31 | name: 32 | 33 | model: 34 | class_path: src.Open_MAGVIT2.models.cond_transformer_gpt.Net2NetTransformer 35 | init_args: 36 | learning_rate: 3e-4 37 | first_stage_key: image 38 | cond_stage_key: class_label 39 | token_factorization: True 40 | weight_decay: 5e-2 41 | wpe: 0.1 ## learning rate decay 42 | wp: 6 ##no warmup 43 | wp0: 0.005 ##for warmup 44 | twde: 0 45 | transformer_config: 46 | target: src.Open_MAGVIT2.modules.transformer.gpt.GPT 47 | params: 48 | # vocab_size: 262144 # 262144 tokens 49 | vocab_size: 512 50 | block_size: 256 51 | spatial_n_layer: 36 ## follow LlamaGen 52 | factorized_n_layer: 3 53 | factorized_bits: [6, 12] #asymmetrical head 54 | n_head: 20 55 | dim: 1280 56 | cond_dim: 1280 57 | token_drop: 0.1 58 | resid_dropout_p: 0.1 59 | token_factorization: True 60 | class_num: 1000 #class tokens 61 | first_stage_config: 62 | target: src.Open_MAGVIT2.models.lfqgan.VQModel 63 | params: 64 | ckpt_path: #specify your path for tokenizer FID: 1.17 65 | n_embed: 262144 66 | embed_dim: 18 67 | learning_rate: 1e-4 68 | sample_minimization_weight: 1.0 69 | batch_maximization_weight: 1.0 70 | scheduler_type: "None" 71 | use_ema: False 72 | stage: "transformer" 73 | token_factorization: True 74 | factorized_bits: [6, 12] 75 | ddconfig: 76 | double_z: False 77 | z_channels: 18 78 | resolution: 128 79 | in_channels: 3 80 | out_ch: 3 81 | ch: 128 82 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 83 | num_res_blocks: 4 84 | lossconfig: 85 | target: src.Open_MAGVIT2.modules.losses.DummyLoss 86 | cond_stage_config: 87 | target: src.Open_MAGVIT2.modules.util.Labelator 88 | params: 89 | n_classes: 1000 90 | permuter_config: 91 | target: src.Open_MAGVIT2.modules.transformer.permuter.ShiftPermuter 92 | params: 93 | shift_pos: 1000 # num_classes 94 | 95 | data: 96 | class_path: main.DataModuleFromConfig 97 | init_args: 98 | batch_size: 12 99 | num_workers: 16 100 | train: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | validation: 107 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 108 | params: 109 | config: 110 | size: 256 111 | subset: 112 | test: 113 | target: 
src.Open_MAGVIT2.data.imagenet.ImageNetValidation 114 | params: 115 | config: 116 | size: 256 117 | subset: 118 | 119 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/npu/imagenet_conditional_llama_XL.yaml: -------------------------------------------------------------------------------- 1 | # refer to https://app.koofr.net/links/90cbd5aa-ef70-4f5e-99bc-f12e5a89380e?path=%2F2021-04-03T19-39-50_cin_transformer%2Fconfigs%2F2021-04-03T19-39-50-project.yaml 2 | seed_everything: true 3 | trainer: 4 | accelerator: npu 5 | strategy: ddp_find_unused_parameters_true 6 | # strategy: 7 | # class_path: lightning.pytorch.strategies.FSDPStrategy 8 | # init_args: 9 | # sharding_strategy: "SHARD_GRAD_OP" 10 | devices: 8 11 | num_nodes: 12 12 | precision: bf16-mixed 13 | max_epochs: 350 14 | check_val_every_n_epoch: 1 15 | num_sanity_val_steps: -1 16 | gradient_clip_val: 1.0 17 | callbacks: 18 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 19 | init_args: 20 | dirpath: "../../checkpoints/vqgan/test" 21 | save_top_k: 20 22 | monitor: "train/loss" 23 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 24 | init_args: 25 | logging_interval: step 26 | logger: 27 | class_path: lightning.pytorch.loggers.TensorBoardLogger 28 | init_args: 29 | save_dir: ../../results/vqgan" 30 | version: "test" 31 | name: 32 | 33 | model: 34 | class_path: src.Open_MAGVIT2.models.cond_transformer_gpt.Net2NetTransformer 35 | init_args: 36 | learning_rate: 2.4e-4 37 | first_stage_key: image 38 | cond_stage_key: class_label 39 | token_factorization: True 40 | weight_decay: 5e-2 41 | wpe: 0.1 ## learning rate decay 42 | wp: 7 ##no warmup 43 | wp0: 0.005 ##for warmup 44 | twde: 0 45 | transformer_config: 46 | target: src.Open_MAGVIT2.modules.transformer.gpt.GPT 47 | params: 48 | # vocab_size: 262144 # 262144 tokens 49 | vocab_size: 512 50 | block_size: 256 51 | spatial_n_layer: 48 ## follow LlamaGen 52 | factorized_n_layer: 4 53 | factorized_bits: [6, 12] #asymmetrical head 54 | n_head: 24 55 | dim: 1536 56 | cond_dim: 1536 57 | token_drop: 0.1 58 | resid_dropout_p: 0.1 59 | token_factorization: True 60 | class_num: 1000 #class tokens 61 | first_stage_config: 62 | target: src.Open_MAGVIT2.models.lfqgan.VQModel 63 | params: 64 | ckpt_path: #specify your path for tokenizer FID: 1.17 65 | n_embed: 262144 66 | embed_dim: 18 67 | learning_rate: 1e-4 68 | sample_minimization_weight: 1.0 69 | batch_maximization_weight: 1.0 70 | scheduler_type: "None" 71 | use_ema: False 72 | stage: "transformer" 73 | token_factorization: True 74 | factorized_bits: [6, 12] 75 | ddconfig: 76 | double_z: False 77 | z_channels: 18 78 | resolution: 128 79 | in_channels: 3 80 | out_ch: 3 81 | ch: 128 82 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 83 | num_res_blocks: 4 84 | lossconfig: 85 | target: src.Open_MAGVIT2.modules.losses.DummyLoss 86 | cond_stage_config: 87 | target: src.Open_MAGVIT2.modules.util.Labelator 88 | params: 89 | n_classes: 1000 90 | permuter_config: 91 | target: src.Open_MAGVIT2.modules.transformer.permuter.ShiftPermuter 92 | params: 93 | shift_pos: 1000 # num_classes 94 | 95 | data: 96 | class_path: main.DataModuleFromConfig 97 | init_args: 98 | batch_size: 8 99 | num_workers: 16 100 | train: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetTrain 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | validation: 107 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 108 | params: 109 | config: 110 | size: 256 
111 | subset: 112 | test: 113 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 114 | params: 115 | config: 116 | size: 256 117 | subset: 118 | 119 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/pretrain_ibqgan_262144.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: 16-mixed 8 | max_steps: 1500000 9 | check_val_every_n_epoch: null 10 | val_check_interval: 5005 ## one imagenet epoch length 11 | num_sanity_val_steps: -1 12 | log_every_n_steps: 100 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: -1 18 | save_last: True 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan/" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.ibqgan.IBQ 31 | init_args: 32 | ddconfig: 33 | double_z: False 34 | z_channels: 256 35 | resolution: 256 36 | in_channels: 3 37 | out_ch: 3 38 | ch: 128 39 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 40 | num_res_blocks: 4 #not adopt from showo 41 | attn_resolutions: [16] #not adopt from showo 42 | dropout: 0.0 43 | 44 | lossconfig: 45 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 46 | params: 47 | disc_conditional: False 48 | disc_in_channels: 3 49 | disc_start: 0 # from 0 epoch 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 262144 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | use_shared_epoch: True 67 | resume_lr: 68 | sche_type: 69 | wpe: 0.01 ## learning rate decay to zero 70 | wp: 1 ##one epoch for linear warmup 71 | wp0: 0.0 ##for warmup #from zero to lr 72 | max_iter: 1500000 73 | wp_iter: 5000 74 | lr_drop_iter: [50, 100] 75 | 76 | data: 77 | class_path: main.DataModuleFromConfig 78 | init_args: 79 | batch_size: 8 80 | num_workers: 16 81 | train: 82 | target: src.Open_MAGVIT2.data.pretrain.LAIONCombineTrain 83 | params: 84 | config: 85 | size: 256 86 | subset: 87 | filter_path: ["../../data/laion-aesthetic-v2_filter_keys.json", "../../data/JourneyDB_filter_keys.json", "../../data/laion-aesthetic_v1_filter_keys.json", "../../data/laion-hd_sub_filter_keys_2.json", "../../data/capfusion_filter_keys.json"] 88 | sample_json_path: ["../../data/capfusion_samples.json","../../data/laion-coco_samples.json", "../../data/cc15m_samples_2.json", "../../data/laion-aesthetic-v2_samples.json", "../../data/JourneyDB_samples.json", "../../data/laion-aesthetic_v1_samples.json", "../../data/laion-hd_sub_samples_2.json"] 89 | sample_coco_urls: ../../data/laion-coco_sample_urls_20M.txt 90 | sample_hd_urls: ../../data/laion-hd_sample_urls_30M_2.txt 91 | data_dir: ["../../data/CapFusion-120M", "../../data/LAION-COCO-Recaption", "../../data/CC12M/webdataset/gcc12m_shards", "../../data/Laion-aesthetic-v2/data", 
"../../data/CC3M/webdataset/gcc3m_shards", "../../data/JourneyDB/wds", "../../data/laion-aesthetics-12M/webdataset_train", "../../data/laion-hd/webdataset_train/"] 92 | image_key: [jpg, jpeg.jpg, "jpg.jpg"] 93 | enable_image: True 94 | validation: 95 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 96 | params: 97 | config: 98 | size: 256 99 | subset: 100 | test: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | 107 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/pretrain_ibqgan_16384.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: bf16-mixed 8 | max_steps: 1500000 9 | check_val_every_n_epoch: null 10 | val_check_interval: 5005 ## one imagenet epoch length 11 | num_sanity_val_steps: -1 12 | log_every_n_steps: 100 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: -1 18 | save_last: True 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan/" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.ibqgan.IBQ 31 | init_args: 32 | ddconfig: 33 | double_z: False 34 | z_channels: 256 35 | resolution: 256 36 | in_channels: 3 37 | out_ch: 3 38 | ch: 128 39 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 40 | num_res_blocks: 4 #not adopt from showo 41 | attn_resolutions: [16] #not adopt from showo 42 | dropout: 0.0 43 | 44 | lossconfig: 45 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 46 | params: 47 | disc_conditional: False 48 | disc_in_channels: 3 49 | disc_start: 0 # from 0 epoch 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 262144 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | use_shared_epoch: True 67 | resume_lr: 68 | sche_type: 69 | wpe: 0.01 ## learning rate decay to zero 70 | wp: 1 ##one epoch for linear warmup 71 | wp0: 0.0 ##for warmup #from zero to lr 72 | max_iter: 1500000 73 | wp_iter: 5000 74 | lr_drop_iter: [50, 100] 75 | 76 | data: 77 | class_path: main.DataModuleFromConfig 78 | init_args: 79 | batch_size: 8 80 | num_workers: 16 81 | train: 82 | target: src.Open_MAGVIT2.data.pretrain.LAIONCombineTrain 83 | params: 84 | config: 85 | size: 256 86 | subset: 87 | filter_path: ["../../data/laion-aesthetic-v2_filter_keys.json", "../../data/JourneyDB_filter_keys.json", "../../data/laion-aesthetic_v1_filter_keys.json", "../../data/laion-hd_sub_filter_keys_2.json", "../../data/capfusion_filter_keys.json"] 88 | sample_json_path: ["../../data/capfusion_samples.json","../../data/laion-coco_samples.json", "../../data/cc15m_samples_2.json", "../../data/laion-aesthetic-v2_samples.json", "../../data/JourneyDB_samples.json", 
"../../data/laion-aesthetic_v1_samples.json", "../../data/laion-hd_sub_samples_2.json"] 89 | sample_coco_urls: ../../data/laion-coco_sample_urls_20M.txt 90 | sample_hd_urls: ../../data/laion-hd_sample_urls_30M_2.txt 91 | data_dir: ["../../data/CapFusion-120M", "../../data/LAION-COCO-Recaption", "../../data/CC12M/webdataset/gcc12m_shards", "../../data/Laion-aesthetic-v2/data", "../../data/CC3M/webdataset/gcc3m_shards", "../../data/JourneyDB/wds", "../../data/laion-aesthetics-12M/webdataset_train", "../../data/laion-hd/webdataset_train/"] 92 | image_key: [jpg, jpeg.jpg, "jpg.jpg"] 93 | enable_image: True 94 | validation: 95 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 96 | params: 97 | config: 98 | size: 256 99 | subset: 100 | test: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | 107 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/npu/pretrain_ibqgan_262144.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: bf16-mixed 8 | max_steps: 1500000 9 | check_val_every_n_epoch: null 10 | val_check_interval: 5005 ## one imagenet epoch length 11 | num_sanity_val_steps: -1 12 | log_every_n_steps: 100 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: -1 18 | save_last: True 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan/" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.ibqgan.IBQ 31 | init_args: 32 | ddconfig: 33 | double_z: False 34 | z_channels: 256 35 | resolution: 256 36 | in_channels: 3 37 | out_ch: 3 38 | ch: 128 39 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 40 | num_res_blocks: 4 #not adopt from showo 41 | attn_resolutions: [16] #not adopt from showo 42 | dropout: 0.0 43 | 44 | lossconfig: 45 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 46 | params: 47 | disc_conditional: False 48 | disc_in_channels: 3 49 | disc_start: 0 # from 0 epoch 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 262144 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | use_shared_epoch: True 67 | resume_lr: 68 | sche_type: 69 | wpe: 0.01 ## learning rate decay to zero 70 | wp: 1 ##one epoch for linear warmup 71 | wp0: 0.0 ##for warmup #from zero to lr 72 | max_iter: 1500000 73 | wp_iter: 5000 74 | lr_drop_iter: [50, 100] 75 | 76 | data: 77 | class_path: main.DataModuleFromConfig 78 | init_args: 79 | batch_size: 8 80 | num_workers: 16 81 | train: 82 | target: src.Open_MAGVIT2.data.pretrain.LAIONCombineTrain 83 | params: 84 | config: 85 | size: 256 86 | subset: 87 | filter_path: ["../../data/laion-aesthetic-v2_filter_keys.json", "../../data/JourneyDB_filter_keys.json", 
"../../data/laion-aesthetic_v1_filter_keys.json", "../../data/laion-hd_sub_filter_keys_2.json", "../../data/capfusion_filter_keys.json"] 88 | sample_json_path: ["../../data/capfusion_samples.json","../../data/laion-coco_samples.json", "../../data/cc15m_samples_2.json", "../../data/laion-aesthetic-v2_samples.json", "../../data/JourneyDB_samples.json", "../../data/laion-aesthetic_v1_samples.json", "../../data/laion-hd_sub_samples_2.json"] 89 | sample_coco_urls: ../../data/laion-coco_sample_urls_20M.txt 90 | sample_hd_urls: ../../data/laion-hd_sample_urls_30M_2.txt 91 | data_dir: ["../../data/CapFusion-120M", "../../data/LAION-COCO-Recaption", "../../data/CC12M/webdataset/gcc12m_shards", "../../data/Laion-aesthetic-v2/data", "../../data/CC3M/webdataset/gcc3m_shards", "../../data/JourneyDB/wds", "../../data/laion-aesthetics-12M/webdataset_train", "../../data/laion-hd/webdataset_train/"] 92 | image_key: [jpg, jpeg.jpg, "jpg.jpg"] 93 | enable_image: True 94 | validation: 95 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 96 | params: 97 | config: 98 | size: 256 99 | subset: 100 | test: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | 107 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/IBQ/gpu/pretrain_ibqgan_16384.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: 16-mixed 8 | max_steps: 1500000 9 | check_val_every_n_epoch: null 10 | val_check_interval: 5005 ## one imagenet epoch length 11 | num_sanity_val_steps: -1 12 | log_every_n_steps: 100 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: -1 18 | save_last: True 19 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 20 | init_args: 21 | logging_interval: step 22 | logger: 23 | class_path: lightning.pytorch.loggers.TensorBoardLogger 24 | init_args: 25 | save_dir: "../../results/vqgan/" 26 | version: "test" 27 | name: 28 | 29 | model: 30 | class_path: src.IBQ.models.ibqgan.IBQ 31 | init_args: 32 | ddconfig: 33 | double_z: False 34 | z_channels: 256 35 | resolution: 256 36 | in_channels: 3 37 | out_ch: 3 38 | ch: 128 39 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 40 | num_res_blocks: 4 #not adopt from showo 41 | attn_resolutions: [16] #not adopt from showo 42 | dropout: 0.0 43 | 44 | lossconfig: 45 | target: src.IBQ.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 46 | params: 47 | disc_conditional: False 48 | disc_in_channels: 3 49 | disc_start: 0 # from 0 epoch 50 | disc_weight: 0.4 # default 0.4 51 | quant_loss_weight: 1.0 # default 1.0 52 | entropy_loss_weight: 0.05 # default 0.1 53 | gen_loss_weight: 0.1 54 | lecam_loss_weight: 0.05 55 | 56 | n_embed: 16384 57 | embed_dim: 256 58 | learning_rate: 1e-4 59 | l2_normalize: False 60 | use_entropy_loss: True 61 | sample_minimization_weight: 1.0 62 | batch_maximization_weight: 1.0 63 | entropy_temperature: 0.01 # default 0.01 64 | beta: 0.25 65 | use_ema: True 66 | use_shared_epoch: True 67 | resume_lr: 68 | sche_type: 69 | wpe: 0.01 ## learning rate decay to zero 70 | wp: 1 ##one epoch for linear warmup 71 | wp0: 0.0 ##for warmup #from zero to lr 72 | max_iter: 1500000 73 | wp_iter: 5000 74 | lr_drop_iter: [50, 100] 75 | 
76 | data: 77 | class_path: main.DataModuleFromConfig 78 | init_args: 79 | batch_size: 8 80 | num_workers: 16 81 | train: 82 | target: src.Open_MAGVIT2.data.pretrain.LAIONCombineTrain 83 | params: 84 | config: 85 | size: 256 86 | subset: 87 | filter_path: ["../../data/laion-aesthetic-v2_filter_keys.json", "../../data/JourneyDB_filter_keys.json", "../../data/laion-aesthetic_v1_filter_keys.json", "../../data/laion-hd_sub_filter_keys_2.json", "../../data/capfusion_filter_keys.json"] 88 | sample_json_path: ["../../data/capfusion_samples.json","../../data/laion-coco_samples.json", "../../data/cc15m_samples_2.json", "../../data/laion-aesthetic-v2_samples.json", "../../data/JourneyDB_samples.json", "../../data/laion-aesthetic_v1_samples.json", "../../data/laion-hd_sub_samples_2.json"] 89 | sample_coco_urls: ../../data/laion-coco_sample_urls_20M.txt 90 | sample_hd_urls: ../../data/laion-hd_sample_urls_30M_2.txt 91 | data_dir: ["../../data/CapFusion-120M", "../../data/LAION-COCO-Recaption", "../../data/CC12M/webdataset/gcc12m_shards", "../../data/Laion-aesthetic-v2/data", "../../data/CC3M/webdataset/gcc3m_shards", "../../data/JourneyDB/wds", "../../data/laion-aesthetics-12M/webdataset_train", "../../data/laion-hd/webdataset_train/"] 92 | image_key: [jpg, jpeg.jpg, "jpg.jpg"] 93 | enable_image: True 94 | validation: 95 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 96 | params: 97 | config: 98 | size: 256 99 | subset: 100 | test: 101 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 102 | params: 103 | config: 104 | size: 256 105 | subset: 106 | 107 | ckpt_path: null # to resume 108 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime, glob, importlib 2 | from torch.utils.data import random_split, DataLoader, Dataset 3 | 4 | import lightning as L 5 | from lightning.pytorch.cli import LightningCLI 6 | from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor 7 | from lightning import seed_everything 8 | 9 | from torch.utils.data.dataloader import default_collate as custom_collate 10 | 11 | import torch 12 | torch.set_float32_matmul_precision("high") 13 | torch.backends.cudnn.deterministic = True #True 14 | torch.backends.cudnn.benchmark = False #False 15 | 16 | def get_obj_from_str(string, reload=False): 17 | module, cls = string.rsplit(".", 1) 18 | if reload: 19 | module_imp = importlib.import_module(module) 20 | importlib.reload(module_imp) 21 | return getattr(importlib.import_module(module, package=None), cls) 22 | 23 | 24 | def instantiate_from_config(config): 25 | if not "target" in config: 26 | raise KeyError("Expected key `target` to instantiate.") 27 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 28 | 29 | 30 | class WrappedDataset(Dataset): 31 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 32 | def __init__(self, dataset): 33 | self.data = dataset 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | def __getitem__(self, idx): 39 | return self.data[idx] 40 | 41 | 42 | class DataModuleFromConfig(L.LightningDataModule): 43 | def __init__(self, batch_size, train=None, validation=None, test=None, 44 | wrap=False, num_workers=None): 45 | super().__init__() 46 | self.batch_size = batch_size 47 | self.dataset_configs = dict() 48 | self.num_workers = num_workers if num_workers is not None else batch_size*2 49 | 
if train is not None: 50 | self.dataset_configs["train"] = train 51 | self.train_dataloader = self._train_dataloader 52 | if validation is not None: 53 | self.dataset_configs["validation"] = validation 54 | self.val_dataloader = self._val_dataloader 55 | if test is not None: 56 | self.dataset_configs["test"] = test 57 | self.test_dataloader = self._test_dataloader 58 | self.wrap = wrap 59 | 60 | def prepare_data(self): 61 | for data_cfg in self.dataset_configs.values(): 62 | instantiate_from_config(data_cfg) 63 | 64 | def setup(self, stage=None): 65 | self.datasets = dict() 66 | for k in self.dataset_configs: 67 | if "pretrain" not in self.dataset_configs[k]["target"]: ##laion should use webdataset 68 | self.datasets[k] = instantiate_from_config(self.dataset_configs[k]) 69 | else: 70 | self.datasets[k] = instantiate_from_config(self.dataset_configs[k]).create_dataset() 71 | if self.wrap: 72 | for k in self.datasets: 73 | self.datasets[k] = WrappedDataset(self.datasets[k]) 74 | 75 | def _train_dataloader(self): 76 | """ 77 | laion serves as the train loader 78 | """ 79 | if "pretrain" in self.dataset_configs["train"]["target"]: ## webdataset no need for shuffle=True 80 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True) 81 | else: 82 | return DataLoader(self.datasets["train"], batch_size=self.batch_size, 83 | num_workers=self.num_workers, shuffle=True, collate_fn=custom_collate, pin_memory=True) 84 | 85 | def _val_dataloader(self): 86 | return DataLoader(self.datasets["validation"], 87 | batch_size=self.batch_size, 88 | num_workers=self.num_workers, collate_fn=custom_collate, shuffle=False, pin_memory=True) 89 | 90 | def _test_dataloader(self): 91 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 92 | num_workers=self.num_workers, collate_fn=custom_collate, shuffle=False, pin_memory=True) 93 | 94 | def main(): 95 | cli = LightningCLI( 96 | save_config_kwargs={"overwrite": True}, 97 | ) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/gpu/pretrain_lfqgan_256_262144.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: 16-mixed 8 | max_steps: 1500000 9 | check_val_every_n_epoch: null 10 | val_check_interval: 5005 ## one imagenet epoch length 11 | num_sanity_val_steps: -1 12 | log_every_n_steps: 100 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: -1 # save all checkpoints 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.Open_MAGVIT2.models.lfqgan_pretrain.VQModel 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 18 #18 34 | resolution: 128 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 40 | 41 | lossconfig: 42 | target: src.Open_MAGVIT2.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 43 | params: 44 | disc_conditional: False 45 | disc_in_channels: 3 
46 | disc_start: 0 # from 0 epoch #70000 is ok 47 | disc_num_layers: 3 48 | disc_weight: 0.8 49 | gen_loss_weight: 0.1 #using 0.1 for more training stability 50 | lecam_loss_weight: 0.05 51 | codebook_weight: 0.1 #can be lowered to 0.05 52 | commit_weight: 0.25 53 | codebook_enlarge_ratio: 0 54 | codebook_enlarge_steps: 2000 55 | disc_loss: hinge 56 | disc_num_channels: 3 57 | disc_num_stages: 3 58 | disc_hidden_channels: 128 59 | blur_resample: True 60 | blur_kernel_size: 4 61 | 62 | n_embed: 262144 #262144 63 | embed_dim: 18 #18 64 | learning_rate: 1e-4 65 | sample_minimization_weight: 1.0 66 | batch_maximization_weight: 1.0 67 | scheduler_type: "None" 68 | use_ema: True 69 | use_shared_epoch: True 70 | sche_type: 71 | wpe: 0.01 ## learning rate decay to zero 72 | wp: 1 ##one epoch for linear warmup 73 | wp0: 0.0 ##for warmup #from zero to lr 74 | max_iter: 75 | wp_iter: 76 | lr_drop_iter: [800000, 1000000] 77 | 78 | data: 79 | class_path: main.DataModuleFromConfig 80 | init_args: 81 | batch_size: 8 82 | num_workers: 16 83 | train: 84 | target: src.Open_MAGVIT2.data.pretrain.LAIONCombineTrain 85 | params: 86 | config: 87 | size: 256 88 | subset: 89 | filter_path: ["../../data/laion-aesthetic-v2_filter_keys.json", "../../data/JourneyDB_filter_keys.json", "../../data/laion-aesthetic_v1_filter_keys.json", "../../data/laion-hd_sub_filter_keys_2.json"] 90 | sample_json_path: ["../../data/laion-coco_samples.json", "../../data/cc15m_samples_2.json", "../../data/laion-aesthetic-v2_samples.json", "../../data/JourneyDB_samples.json", "../../data/laion-aesthetic_v1_samples.json", "../../data/laion-hd_sub_samples_2.json"] 91 | sample_coco_urls: ../../data/laion-coco_sample_urls_20M.txt 92 | sample_hd_urls: ../../data/laion-hd_sample_urls_30M_2.txt 93 | data_dir: ["../../data/LAION-COCO-Recaption", "../../data/CC12M/webdataset/gcc12m_shards", "../../data/Laion-aesthetic-v2/data", "../../data/CC3M/webdataset/gcc3m_shards", "../../data/JourneyDB/wds", "../../data/laion-aesthetics-12M/webdataset_train", "../../data/laion-hd/webdataset_train/"] 94 | image_key: [jpg, jpeg.jpg, "jpg.jpg"] 95 | enable_image: True 96 | validation: 97 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 98 | params: 99 | config: 100 | size: 256 101 | subset: 102 | test: 103 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 104 | params: 105 | config: 106 | size: 256 107 | subset: 108 | 109 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/npu/pretrain_lfqgan_256_16384.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: bf16-mixed 8 | max_steps: 1500000 9 | check_val_every_n_epoch: null 10 | val_check_interval: 5005 ## one imagenet epoch length 11 | num_sanity_val_steps: -1 12 | log_every_n_steps: 100 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: -1 # save all checkpoints 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.Open_MAGVIT2.models.lfqgan_pretrain.VQModel 30 | init_args: 
31 | ddconfig: 32 | double_z: False 33 | z_channels: 14 #18 34 | resolution: 128 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 40 | 41 | lossconfig: 42 | target: src.Open_MAGVIT2.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 43 | params: 44 | disc_conditional: False 45 | disc_in_channels: 3 46 | disc_start: 0 # from 0 epoch #70000 is ok 47 | disc_num_layers: 4 48 | disc_weight: 0.8 49 | gen_loss_weight: 0.1 #using 0.1 for more training stability 50 | lecam_loss_weight: 0.05 51 | codebook_weight: 0.1 #can be lowered to 0.05 52 | commit_weight: 0.25 53 | codebook_enlarge_ratio: 0 54 | codebook_enlarge_steps: 2000 55 | disc_loss: hinge 56 | disc_num_channels: 3 57 | disc_num_stages: 3 58 | disc_hidden_channels: 128 59 | blur_resample: True 60 | blur_kernel_size: 4 61 | use_blur: True 62 | 63 | n_embed: 16384 #262144 64 | embed_dim: 14 #18 65 | learning_rate: 1e-4 66 | sample_minimization_weight: 1.0 67 | batch_maximization_weight: 1.0 68 | scheduler_type: "None" 69 | use_ema: True 70 | use_shared_epoch: True 71 | sche_type: cos 72 | wpe: 0.01 ## learning rate decay to zero 73 | wp: 1 ##one epoch for linear warmup 74 | wp0: 0.0 ##for warmup #from zero to lr 75 | max_iter: 1500000 76 | wp_iter: 5000 77 | 78 | data: 79 | class_path: main.DataModuleFromConfig 80 | init_args: 81 | batch_size: 8 82 | num_workers: 16 83 | train: 84 | target: src.Open_MAGVIT2.data.pretrain.LAIONCombineTrain 85 | params: 86 | config: 87 | size: 256 88 | subset: 89 | filter_path: ["../../data/laion-aesthetic-v2_filter_keys.json", "../../data/JourneyDB_filter_keys.json", "../../data/laion-aesthetic_v1_filter_keys.json", "../../data/laion-hd_sub_filter_keys_2.json"] 90 | sample_json_path: ["../../data/laion-coco_samples.json", "../../data/cc15m_samples_2.json", "../../data/laion-aesthetic-v2_samples.json", "../../data/JourneyDB_samples.json", "../../data/laion-aesthetic_v1_samples.json", "../../data/laion-hd_sub_samples_2.json"] 91 | sample_coco_urls: ../../data/laion-coco_sample_urls_20M.txt 92 | sample_hd_urls: ../../data/laion-hd_sample_urls_30M_2.txt 93 | data_dir: ["../../data/LAION-COCO-Recaption", "../../data/CC12M/webdataset/gcc12m_shards", "../../data/Laion-aesthetic-v2/data", "../../data/CC3M/webdataset/gcc3m_shards", "../../data/JourneyDB/wds", "../../data/laion-aesthetics-12M/webdataset_train", "../../data/laion-hd/webdataset_train/"] 94 | image_key: [jpg, jpeg.jpg, "jpg.jpg"] 95 | enable_image: True 96 | validation: 97 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 98 | params: 99 | config: 100 | size: 256 101 | subset: 102 | test: 103 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 104 | params: 105 | config: 106 | size: 256 107 | subset: 108 | 109 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/npu/pretrain_lfqgan_256_262144.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: npu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | num_nodes: 4 7 | precision: bf16-mixed 8 | max_steps: 1500000 9 | check_val_every_n_epoch: null 10 | val_check_interval: 5005 ## one imagenet epoch length 11 | num_sanity_val_steps: -1 12 | log_every_n_steps: 100 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: -1 # save all 
checkpoints 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.Open_MAGVIT2.models.lfqgan_pretrain.VQModel 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 18 #18 34 | resolution: 128 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 40 | 41 | lossconfig: 42 | target: src.Open_MAGVIT2.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 43 | params: 44 | disc_conditional: False 45 | disc_in_channels: 3 46 | disc_start: 0 # from 0 epoch #70000 is ok 47 | disc_num_layers: 3 48 | disc_weight: 0.8 49 | gen_loss_weight: 0.1 #using 0.1 for more training stability 50 | lecam_loss_weight: 0.05 51 | codebook_weight: 0.1 #can be lowered to 0.05 52 | commit_weight: 0.25 53 | codebook_enlarge_ratio: 0 54 | codebook_enlarge_steps: 2000 55 | disc_loss: hinge 56 | disc_num_channels: 3 57 | disc_num_stages: 3 58 | disc_hidden_channels: 128 59 | blur_resample: True 60 | blur_kernel_size: 4 61 | 62 | n_embed: 262144 #262144 63 | embed_dim: 18 #18 64 | learning_rate: 1e-4 65 | sample_minimization_weight: 1.0 66 | batch_maximization_weight: 1.0 67 | scheduler_type: "None" 68 | use_ema: True 69 | use_shared_epoch: True 70 | sche_type: 71 | wpe: 0.01 ## learning rate decay to zero 72 | wp: 1 ##one epoch for linear warmup 73 | wp0: 0.0 ##for warmup #from zero to lr 74 | max_iter: 75 | wp_iter: 76 | lr_drop_iter: [800000, 1000000] 77 | 78 | data: 79 | class_path: main.DataModuleFromConfig 80 | init_args: 81 | batch_size: 8 82 | num_workers: 16 83 | train: 84 | target: src.Open_MAGVIT2.data.pretrain.LAIONCombineTrain 85 | params: 86 | config: 87 | size: 256 88 | subset: 89 | filter_path: ["../../data/laion-aesthetic-v2_filter_keys.json", "../../data/JourneyDB_filter_keys.json", "../../data/laion-aesthetic_v1_filter_keys.json", "../../data/laion-hd_sub_filter_keys_2.json"] 90 | sample_json_path: ["../../data/laion-coco_samples.json", "../../data/cc15m_samples_2.json", "../../data/laion-aesthetic-v2_samples.json", "../../data/JourneyDB_samples.json", "../../data/laion-aesthetic_v1_samples.json", "../../data/laion-hd_sub_samples_2.json"] 91 | sample_coco_urls: ../../data/laion-coco_sample_urls_20M.txt 92 | sample_hd_urls: ../../data/laion-hd_sample_urls_30M_2.txt 93 | data_dir: ["../../data/LAION-COCO-Recaption", "../../data/CC12M/webdataset/gcc12m_shards", "../../data/Laion-aesthetic-v2/data", "../../data/CC3M/webdataset/gcc3m_shards", "../../data/JourneyDB/wds", "../../data/laion-aesthetics-12M/webdataset_train", "../../data/laion-hd/webdataset_train/"] 94 | image_key: [jpg, jpeg.jpg, "jpg.jpg"] 95 | enable_image: True 96 | validation: 97 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 98 | params: 99 | config: 100 | size: 256 101 | subset: 102 | test: 103 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 104 | params: 105 | config: 106 | size: 256 107 | subset: 108 | 109 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /configs/Open-MAGVIT2/gpu/pretrain_lfqgan_256_16384.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: true 2 | trainer: 3 | accelerator: gpu 4 | strategy: ddp_find_unused_parameters_true 5 | devices: 8 6 | 
num_nodes: 4 7 | precision: 16-mixed 8 | max_steps: 1500000 9 | check_val_every_n_epoch: null 10 | val_check_interval: 5005 ## one imagenet epoch length 11 | num_sanity_val_steps: -1 12 | log_every_n_steps: 100 13 | callbacks: 14 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 15 | init_args: 16 | dirpath: "../../checkpoints/vqgan/test" 17 | save_top_k: -1 # save all checkpoints 18 | - class_path: lightning.pytorch.callbacks.LearningRateMonitor 19 | init_args: 20 | logging_interval: step 21 | logger: 22 | class_path: lightning.pytorch.loggers.TensorBoardLogger 23 | init_args: 24 | save_dir: "../../results/vqgan/" 25 | version: "test" 26 | name: 27 | 28 | model: 29 | class_path: src.Open_MAGVIT2.models.lfqgan_pretrain.VQModel 30 | init_args: 31 | ddconfig: 32 | double_z: False 33 | z_channels: 14 #18 34 | resolution: 128 35 | in_channels: 3 36 | out_ch: 3 37 | ch: 128 38 | ch_mult: [1,1,2,2,4] # num_down = len(ch_mult)-1 39 | num_res_blocks: 4 40 | 41 | lossconfig: 42 | target: src.Open_MAGVIT2.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 43 | params: 44 | disc_conditional: False 45 | disc_in_channels: 3 46 | disc_start: 0 # from 0 epoch #70000 is ok 47 | disc_num_layers: 4 48 | disc_weight: 0.8 49 | gen_loss_weight: 0.1 #using 0.1 for more training stability 50 | lecam_loss_weight: 0.05 51 | codebook_weight: 0.1 #can be lowered to 0.05 52 | commit_weight: 0.25 53 | codebook_enlarge_ratio: 0 54 | codebook_enlarge_steps: 2000 55 | disc_loss: hinge 56 | disc_num_channels: 3 57 | disc_num_stages: 3 58 | disc_hidden_channels: 128 59 | blur_resample: True 60 | blur_kernel_size: 4 61 | use_blur: True 62 | 63 | n_embed: 16384 #262144 64 | embed_dim: 14 #18 65 | learning_rate: 1e-4 66 | sample_minimization_weight: 1.0 67 | batch_maximization_weight: 1.0 68 | scheduler_type: "None" 69 | use_ema: True 70 | use_shared_epoch: True 71 | sche_type: cos 72 | wpe: 0.01 ## learning rate decay to zero 73 | wp: 1 ##one epoch for linear warmup 74 | wp0: 0.0 ##for warmup #from zero to lr 75 | max_iter: 1500000 76 | wp_iter: 5000 77 | 78 | data: 79 | class_path: main.DataModuleFromConfig 80 | init_args: 81 | batch_size: 8 82 | num_workers: 16 83 | train: 84 | target: src.Open_MAGVIT2.data.pretrain.LAIONCombineTrain 85 | params: 86 | config: 87 | size: 256 88 | subset: 89 | filter_path: ["../../data/laion-aesthetic-v2_filter_keys.json", "../../data/JourneyDB_filter_keys.json", "../../data/laion-aesthetic_v1_filter_keys.json", "../../data/laion-hd_sub_filter_keys_2.json"] 90 | sample_json_path: ["../../data/laion-coco_samples.json", "../../data/cc15m_samples_2.json", "../../data/laion-aesthetic-v2_samples.json", "../../data/JourneyDB_samples.json", "../../data/laion-aesthetic_v1_samples.json", "../../data/laion-hd_sub_samples_2.json"] 91 | sample_coco_urls: ../../data/laion-coco_sample_urls_20M.txt #please specify your path 92 | sample_hd_urls: ../../data/laion-hd_sample_urls_30M_2.txt ##please specify your path 93 | data_dir: ["../../data/LAION-COCO-Recaption", "../../data/CC12M/webdataset/gcc12m_shards", "../../data/Laion-aesthetic-v2/data", "../../data/CC3M/webdataset/gcc3m_shards", "../../data/public_datasets/JourneyDB/wds", "../../data/laion-aesthetics-12M/webdataset_train", "../../public_datasets/laion-hd/webdataset_train/"] 94 | image_key: [jpg, jpeg.jpg, "jpg.jpg"] 95 | enable_image: True 96 | validation: 97 | target: src.Open_MAGVIT2.data.imagenet.ImageNetValidation 98 | params: 99 | config: 100 | size: 256 101 | subset: 102 | test: 103 | target: 
src.Open_MAGVIT2.data.imagenet.ImageNetValidation 104 | params: 105 | config: 106 | size: 256 107 | subset: 108 | 109 | ckpt_path: null # to resume -------------------------------------------------------------------------------- /src/Open_MAGVIT2/data/prepare_pretrain.py: -------------------------------------------------------------------------------- 1 | """ 2 | Using WebDataset in Lightning 3 | requires preparing sample.json and filter_keys.json first. 4 | We use a subset of 5 | LAION-COCO, CC15M, LAION-Aesthetic-umap, 6 | LAION-Aesthetic-v2, JourneyDB, LAION-HD 7 | """ 8 | import webdataset as wds 9 | from PIL import Image 10 | import io 11 | from torch.utils.data import DataLoader, default_collate 12 | import torchvision.transforms as T 13 | import os 14 | import json 15 | from omegaconf import OmegaConf 16 | from tqdm import tqdm 17 | import os 18 | import tarfile 19 | import pandas as pd 20 | from PIL import Image 21 | import io 22 | import json 23 | import multiprocessing as mp 24 | import datetime 25 | import warnings 26 | warnings.simplefilter("always") 27 | 28 | def check_image(image_data, filter=True): 29 | save_image = False 30 | try: 31 | with warnings.catch_warnings(record=True) as w: 32 | image = Image.open(io.BytesIO(image_data)) 33 | if filter: 34 | width, height = image.size 35 | if width >= 512 and height >= 512: ## filter low resolution and aspect ratio > 2 36 | horizontal_aspect_ratio = width // height 37 | vertical_aspect_ratio = height // width 38 | if horizontal_aspect_ratio > 2 or vertical_aspect_ratio > 2: 39 | save_image = False 40 | else: 41 | save_image = True 42 | else: 43 | save_image = True 44 | if w: 45 | save_image = False 46 | print(f"warning: {w[0].message}") 47 | elif not filter: 48 | save_image = True 49 | return save_image 50 | except Exception as e: 51 | print(f"Error details: {str(e)}") 52 | save_image = False 53 | return save_image 54 | 55 | def check_tar_file(args): 56 | tar_dict = dict() 57 | filter_keys = dict() 58 | bad_tar_file = [] 59 | tar_paths, num_processes_idx, unit = args[0], args[1], args[2] 60 | for idx, tar_path in enumerate(tar_paths): 61 | cnt = 0 62 | temp_filter_keys = [] 63 | with tarfile.open(os.path.join(tar_path), "r") as tar: 64 | try: 65 | members = tar.getmembers() 66 | except Exception as e: 67 | print(f"Error details: {str(e)}") 68 | print("skipping " + tar_path) 69 | bad_tar_file.append(tar_path) 70 | continue 71 | for member in members: 72 | if member.isfile(): 73 | name = member.name 74 | if name.endswith(".jpg"): 75 | image_data = tar.extractfile(member).read() 76 | check = check_image(image_data, filter=False) 77 | if check: 78 | name, ext = os.path.splitext(name) #name, jpg 79 | cnt +=1 80 | else: 81 | temp_filter_keys.append(name) 82 | continue 83 | filter_keys[tar_path] = temp_filter_keys 84 | tar_dict[tar_path] = cnt 85 | print(f"[{datetime.datetime.now()}] finished checking tar {(num_processes_idx * unit + idx)}") 86 | return tar_dict, filter_keys, bad_tar_file 87 | 88 | if __name__ == "__main__": 89 | ### The datasets should be in the format 90 | ### {1..n}.tar 91 | 92 | TAR_DIR = "../../data/tar_dirs" ##please specify your own datasets 93 | filter_keys = dict() 94 | tar_dicts = dict() 95 | bad_tar_files = [] 96 | 97 | tar_paths = [os.path.join(TAR_DIR, name) for name in os.listdir(TAR_DIR) if name.endswith("tar")] 98 | num_processes = int(max(mp.cpu_count(), 4) * 0.8) 99 | unit = len(tar_paths) // num_processes + 1 100 | work_list = [(tar_paths[idx*unit:(idx+1)*unit], idx, unit) for idx in range(num_processes)] 101 | with mp.Pool(processes=num_processes) as pool:
102 | result = pool.map(check_tar_file, work_list) 103 | 104 | for sublist in result: 105 | tar_dict, filter_key, bad_tar_file = sublist[0], sublist[1], sublist[2] 106 | tar_dicts.update(tar_dict) 107 | filter_keys.update(filter_key) 108 | bad_tar_files = bad_tar_files + bad_tar_file 109 | 110 | with open("../../data/samples.json", "w") as f: 111 | json.dump(tar_dicts, f) 112 | 113 | with open("../../data/filter_keys.json", "w") as f: 114 | json.dump(filter_keys, f) 115 | -------------------------------------------------------------------------------- /src/IBQ/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 
77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | 132 | def requires_grad(model, flag=True): 133 | """ 134 | Set requires_grad flag for all parameters in a model. 135 | """ 136 | for p in model.parameters(): 137 | p.requires_grad = flag -------------------------------------------------------------------------------- /src/Open_MAGVIT2/data/volume_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def convert_img(img): 7 | """Converts (H, W, C) numpy.ndarray to (C, W, H) format 8 | """ 9 | if len(img.shape) == 3: 10 | img = img.transpose(2, 0, 1) 11 | if len(img.shape) == 2: 12 | img = np.expand_dims(img, 0) 13 | return img 14 | 15 | 16 | class ClipToTensor(object): 17 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 18 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 19 | """ 20 | 21 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 22 | self.channel_nb = channel_nb 23 | self.div_255 = div_255 24 | self.numpy = numpy 25 | 26 | def __call__(self, clip): 27 | """ 28 | Args: clip (list of numpy.ndarray): clip (list of images) 29 | to be converted to tensor. 
30 | """ 31 | # Retrieve shape 32 | if isinstance(clip[0], np.ndarray): 33 | h, w, ch = clip[0].shape 34 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 35 | ch) 36 | elif isinstance(clip[0], Image.Image): 37 | w, h = clip[0].size 38 | else: 39 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 40 | but got list of {0}'.format(type(clip[0]))) 41 | 42 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 43 | 44 | # Convert 45 | for img_idx, img in enumerate(clip): 46 | if isinstance(img, np.ndarray): 47 | pass 48 | elif isinstance(img, Image.Image): 49 | img = np.array(img, copy=False) 50 | else: 51 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 52 | but got list of {0}'.format(type(clip[0]))) 53 | img = convert_img(img) 54 | np_clip[:, img_idx, :, :] = img 55 | if self.numpy: 56 | if self.div_255: 57 | np_clip = np_clip / 255.0 58 | return np_clip 59 | 60 | else: 61 | tensor_clip = torch.from_numpy(np_clip) 62 | 63 | if not isinstance(tensor_clip, torch.FloatTensor): 64 | tensor_clip = tensor_clip.float() 65 | if self.div_255: 66 | tensor_clip = torch.div(tensor_clip, 255) 67 | return tensor_clip 68 | 69 | 70 | # Note this norms data to -1/1 71 | class ClipToTensor_K(object): 72 | """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] 73 | to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] 74 | """ 75 | 76 | def __init__(self, channel_nb=3, div_255=True, numpy=False): 77 | self.channel_nb = channel_nb 78 | self.div_255 = div_255 79 | self.numpy = numpy 80 | 81 | def __call__(self, clip): 82 | """ 83 | Args: clip (list of numpy.ndarray): clip (list of images) 84 | to be converted to tensor. 85 | """ 86 | # Retrieve shape 87 | if isinstance(clip[0], np.ndarray): 88 | h, w, ch = clip[0].shape 89 | assert ch == self.channel_nb, 'Got {0} instead of 3 channels'.format( 90 | ch) 91 | elif isinstance(clip[0], Image.Image): 92 | w, h = clip[0].size 93 | else: 94 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 95 | but got list of {0}'.format(type(clip[0]))) 96 | 97 | np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) 98 | 99 | # Convert 100 | for img_idx, img in enumerate(clip): 101 | if isinstance(img, np.ndarray): 102 | pass 103 | elif isinstance(img, Image.Image): 104 | img = np.array(img, copy=False) 105 | else: 106 | raise TypeError('Expected numpy.ndarray or PIL.Image\ 107 | but got list of {0}'.format(type(clip[0]))) 108 | img = convert_img(img) 109 | np_clip[:, img_idx, :, :] = img 110 | if self.numpy: 111 | if self.div_255: 112 | np_clip = (np_clip - 127.5) / 127.5 113 | return np_clip 114 | 115 | else: 116 | tensor_clip = torch.from_numpy(np_clip) 117 | 118 | if not isinstance(tensor_clip, torch.FloatTensor): 119 | tensor_clip = tensor_clip.float() 120 | if self.div_255: 121 | tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) 122 | return tensor_clip 123 | 124 | 125 | class ToTensor(object): 126 | """Converts numpy array to tensor 127 | """ 128 | 129 | def __call__(self, array): 130 | tensor = torch.from_numpy(array) 131 | return tensor 132 | -------------------------------------------------------------------------------- /src/Open_MAGVIT2/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | 
from src.Open_MAGVIT2.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "src/Open_MAGVIT2/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 
| self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) -------------------------------------------------------------------------------- /reconstruct_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Image Reconstruction code 3 | """ 4 | import os 5 | import sys 6 | sys.path.append(os.getcwd()) 7 | import torch 8 | from omegaconf import OmegaConf 9 | import importlib 10 | import numpy as np 11 | from PIL import Image 12 | from tqdm import tqdm 13 | from src.Open_MAGVIT2.models.lfqgan import VQModel 14 | from src.IBQ.models.ibqgan import IBQ 15 | import argparse 16 | try: 17 | import torch_npu 18 | except: 19 | pass 20 | 21 | if hasattr(torch, "npu"): 22 | DEVICE = torch.device("npu:0" if torch_npu.npu.is_available() else "cpu") 23 | else: 24 | DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 25 | 26 | ## for different model configuration 27 | MODEL_TYPE = { 28 | "Open-MAGVIT2": VQModel, 29 | "IBQ": IBQ 30 | } 31 | 32 | def load_vqgan_new(config, model_type, ckpt_path=None, is_gumbel=False): 33 | model = MODEL_TYPE[model_type](**config.model.init_args) 34 | if ckpt_path is not None: 35 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 36 | missing, unexpected = model.load_state_dict(sd, strict=False) 37 | return model.eval() 38 | 39 | def get_obj_from_str(string, reload=False): 40 | print(string) 41 | module, cls = string.rsplit(".", 1) 42 | if reload: 43 | module_imp = importlib.import_module(module) 44 | importlib.reload(module_imp) 45 | return getattr(importlib.import_module(module, package=None), cls) 46 | 47 | def instantiate_from_config(config): 48 | if not "class_path" in config: 49 | raise KeyError("Expected key `class_path` to instantiate.") 50 | return get_obj_from_str(config["class_path"])(**config.get("init_args", dict())) 51 | 52 | def custom_to_pil(x): 53 | x = x.detach().cpu() 54 | x = torch.clamp(x, -1., 1.) 55 | x = (x + 1.)/2. 
56 | x = x.permute(1,2,0).numpy() 57 | x = (255*x).astype(np.uint8) 58 | x = Image.fromarray(x) 59 | if not x.mode == "RGB": 60 | x = x.convert("RGB") 61 | return x 62 | 63 | def main(args): 64 | config_file = args.config_file 65 | configs = OmegaConf.load(config_file) 66 | configs.data.init_args.batch_size = args.batch_size # change the batch size 67 | configs.data.init_args.test.params.config.size = args.image_size #using test to inference 68 | configs.data.init_args.test.params.config.subset = args.subset #using the specific data for comparsion 69 | 70 | model = load_vqgan_new(configs, args.model, args.ckpt_path).to(DEVICE) 71 | 72 | visualize_dir = args.save_dir 73 | visualize_version = args.version 74 | visualize_original = os.path.join(visualize_dir, visualize_version, "original_{}".format(args.image_size)) 75 | visualize_rec = os.path.join(visualize_dir, visualize_version, "rec_{}".format(args.image_size)) 76 | if not os.path.exists(visualize_original): 77 | os.makedirs(visualize_original, exist_ok=True) 78 | 79 | if not os.path.exists(visualize_rec): 80 | os.makedirs(visualize_rec, exist_ok=True) 81 | 82 | dataset = instantiate_from_config(configs.data) 83 | dataset.prepare_data() 84 | dataset.setup() 85 | 86 | count = 0 87 | with torch.no_grad(): 88 | for idx, batch in tqdm(enumerate(dataset._val_dataloader())): 89 | if count > args.image_num: 90 | break 91 | images = batch["image"].permute(0, 3, 1, 2).to(DEVICE) 92 | 93 | count += images.shape[0] 94 | if model.use_ema: 95 | with model.ema_scope(): 96 | if args.model == "Open-MAGVIT2": 97 | quant, diff, indices, _ = model.encode(images) 98 | elif args.model == "IBQ": 99 | quant, qloss, (_, _, indices) = model.encode(images) 100 | reconstructed_images = model.decode(quant) 101 | else: 102 | if args.model == "Open-MAGVIT2": 103 | quant, diff, indices, _ = model.encode(images) 104 | elif args.model == "IBQ": 105 | quant, qloss, (_, _, indices) = model.encode(images) 106 | reconstructed_images = model.decode(quant) 107 | 108 | image = images[0] 109 | reconstructed_image = reconstructed_images[0] 110 | 111 | image = custom_to_pil(image) 112 | reconstructed_image = custom_to_pil(reconstructed_image) 113 | 114 | image.save(os.path.join(visualize_original, "{}.png".format(idx))) 115 | reconstructed_image.save(os.path.join(visualize_rec, "{}.png".format(idx))) 116 | 117 | 118 | def get_args(): 119 | parser = argparse.ArgumentParser(description="inference parameters") 120 | parser.add_argument("--config_file", required=True, type=str) 121 | parser.add_argument("--ckpt_path", required=True, type=str) 122 | parser.add_argument("--image_size", default=256, type=int) 123 | parser.add_argument("--batch_size", default=1, type=int) ## inference only using 1 batch size 124 | parser.add_argument("--image_num", default=50, type=int) 125 | parser.add_argument("--subset", default=None) 126 | parser.add_argument("--version", type=str, required=True) 127 | parser.add_argument("--save_dir", type=str, required=True) 128 | parser.add_argument("--model", choices=["Open-MAGVIT2", "IBQ"]) 129 | 130 | return parser.parse_args() 131 | 132 | if __name__ == "__main__": 133 | args = get_args() 134 | main(args) -------------------------------------------------------------------------------- /src/IBQ/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from 
torchvision import models 6 | from collections import namedtuple 7 | 8 | from src.IBQ.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "src/IBQ/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), 
vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /src/Open_MAGVIT2/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 
73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /src/IBQ/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model 
from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 | ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | --------------------------------------------------------------------------------