├── LICENSE
├── README.md
├── demo
│   ├── run_fractalgen.ipynb
│   └── visual.gif
├── engine_fractalgen.py
├── environment.yaml
├── fid_stats
│   ├── adm_in256_stats.npz
│   └── adm_in64_stats.npz
├── main_fractalgen.py
├── models
│   ├── ar.py
│   ├── fractalgen.py
│   ├── mar.py
│   └── pixelloss.py
└── util
    ├── crop.py
    ├── download.py
    ├── lr_sched.py
    ├── misc.py
    └── visualize.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Tianhong Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fractal Generative Models

[![arXiv](https://img.shields.io/badge/arXiv%20paper-2502.17437-b31b1b.svg)](https://arxiv.org/abs/2502.17437)
[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](http://colab.research.google.com/github/LTH14/fractalgen/blob/main/demo/run_fractalgen.ipynb)

This is a PyTorch/GPU implementation of the paper [Fractal Generative Models](https://arxiv.org/abs/2502.17437):

```
@article{li2025fractal,
  title={Fractal Generative Models},
  author={Li, Tianhong and Sun, Qinyi and Fan, Lijie and He, Kaiming},
  journal={arXiv preprint arXiv:2502.17437},
  year={2025}
}
```

FractalGen enables pixel-by-pixel high-resolution image generation for the first time. This repo contains:

* 🪐 A simple PyTorch implementation of [Fractal Generative Model](models/fractalgen.py).
* ⚡️ Pre-trained pixel-by-pixel generation models trained on ImageNet 64x64 and 256x256.
* 💥 A self-contained [Colab notebook](http://colab.research.google.com/github/LTH14/fractalgen/blob/main/demo/run_fractalgen.ipynb) for running the pre-trained models.
* 🛸 A [training and evaluation script](main_fractalgen.py) using PyTorch DDP.

## Preparation

### Dataset
Download the [ImageNet](http://image-net.org/download) dataset and place it in your `IMAGENET_PATH`.
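`main_fractalgen.py` builds its datasets with `torchvision.datasets.ImageFolder` on `os.path.join(args.data_path, 'train')` and `os.path.join(args.data_path, 'val')`, so `IMAGENET_PATH` needs `train/` and `val/` directories with one subfolder per class. The official ImageNet validation images typically ship as a flat directory, so they usually have to be sorted into per-class folders first. The snippet below is a minimal pre-flight sketch, not part of the repo: the helper name `check_imagefolder_layout` is hypothetical, and the expected count of 1000 classes is taken from the `--class_num` default.

```python
import os
import tempfile


def check_imagefolder_layout(root, splits=("train", "val"), expected_classes=1000):
    """Hypothetical pre-flight check, not part of the fractalgen repo.

    torchvision's ImageFolder treats each subdirectory of a split as one class,
    so every split should hold `expected_classes` subdirectories (1000 for ImageNet-1k).
    Returns {split: number_of_class_folders} and prints a warning on a mismatch.
    """
    counts = {}
    for split in splits:
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        class_dirs = [d for d in os.listdir(split_dir)
                      if os.path.isdir(os.path.join(split_dir, d))]
        counts[split] = len(class_dirs)
        if counts[split] != expected_classes:
            print(f"warning: {split_dir} has {counts[split]} class folders, "
                  f"expected {expected_classes}")
    return counts


if __name__ == "__main__":
    # Demo on a tiny throwaway tree with three classes per split.
    with tempfile.TemporaryDirectory() as tmp:
        for split in ("train", "val"):
            for wnid in ("n01440764", "n01443537", "n01484850"):
                os.makedirs(os.path.join(tmp, split, wnid))
        print(check_imagefolder_layout(tmp, expected_classes=3))
```

On a real machine, call `check_imagefolder_layout(os.environ["IMAGENET_PATH"])` before launching `torchrun`; a `val` count of 0 usually means the validation images are still flat.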

### Installation

Download the code:
```
git clone https://github.com/LTH14/fractalgen.git
cd fractalgen
```

A suitable [conda](https://conda.io/) environment named `fractalgen` can be created and activated with:

```
conda env create -f environment.yaml
conda activate fractalgen
```

Download pre-trained models:

```
python util/download.py
```

For convenience, our pre-trained models can be downloaded directly here as well:

| Model | FID-50K | Inception Score | #params |
|-------|---------|-----------------|---------|
| [FractalAR (IN64)](https://www.dropbox.com/scl/fi/n25tbij7aqkwo1ypqhz72/checkpoint-last.pth?rlkey=2czevgex3ocg2ae8zde3xpb3f&st=mj0subup&dl=0) | 5.30 | 56.8 | 432M |
| [FractalMAR (IN64)](https://www.dropbox.com/scl/fi/lh7fmv48pusujd6m4kcdn/checkpoint-last.pth?rlkey=huihey61ok32h28o3tbbq6ek9&st=fxtoawba&dl=0) | 2.72 | 87.9 | 432M |
| [FractalMAR-Base (IN256)](https://www.dropbox.com/scl/fi/zrdm7853ih4tcv98wmzhe/checkpoint-last.pth?rlkey=htq9yuzovet7d6ioa64s1xxd0&st=4c4d93vs&dl=0) | 11.80 | 274.3 | 186M |
| [FractalMAR-Large (IN256)](https://www.dropbox.com/scl/fi/y1k05xx7ry8521ckxkqgt/checkpoint-last.pth?rlkey=wolq4krdq7z7eyjnaw5ndhq6k&st=vjeu5uzo&dl=0) | 7.30 | 334.9 | 438M |
| [FractalMAR-Huge (IN256)](https://www.dropbox.com/scl/fi/t2rru8xr6wm23yvxskpww/checkpoint-last.pth?rlkey=dn9ss9zw4zsnckf6bat9hss6h&st=y7w921zo&dl=0) | 6.15 | 348.9 | 848M |

## Usage

### Demo
Run our interactive visualization [demo](http://colab.research.google.com/github/LTH14/fractalgen/blob/main/demo/run_fractalgen.ipynb) in a Colab notebook!

### Training
The training scripts below have been tested on 4x8 H100 GPUs (4 nodes with 8 GPUs each).
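The `--blr` flag in the training commands is a base rate, not the rate the optimizer uses: `main_fractalgen.py` sets `args.lr = args.blr * eff_batch_size / 256`, where `eff_batch_size` is `--batch_size` times the number of processes, unless `--lr` is passed explicitly. The sketch below reproduces that arithmetic for the 4x8-GPU setup; `effective_lr` is a hypothetical helper, not a function in the repo.

```python
def effective_lr(blr, batch_size_per_gpu, num_gpus, lr=None):
    """Hypothetical mirror of the LR scaling rule in main_fractalgen.py.

    Returns (effective_batch_size, actual_lr). An explicit `lr` wins,
    matching the `if args.lr is None` branch in the script.
    """
    eff_batch_size = batch_size_per_gpu * num_gpus
    if lr is None:  # only the base rate (blr) was given
        lr = blr * eff_batch_size / 256
    return eff_batch_size, lr


if __name__ == "__main__":
    # 4 nodes x 8 GPUs = 32 processes, as in the commands below.
    for name, bsz in [("FractalAR/FractalMAR IN64", 64), ("FractalMAR-L IN256", 32)]:
        eff, lr = effective_lr(5.0e-5, bsz, 32)
        print(f"{name}: effective batch {eff}, lr {lr:.1e}")
```

One consequence of this rule: running the same command on fewer GPUs keeps `--batch_size` per GPU fixed, so both the effective batch size and the actual learning rate shrink proportionally.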

Example script for training FractalAR on ImageNet 64x64 for 800 epochs:
```
torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main_fractalgen.py \
--model fractalar_in64 --img_size 64 --num_conds 1 \
--batch_size 64 --eval_freq 40 --save_last_freq 10 \
--epochs 800 --warmup_epochs 40 \
--blr 5.0e-5 --weight_decay 0.05 --attn_dropout 0.1 --proj_dropout 0.1 --lr_schedule cosine \
--gen_bsz 256 --num_images 8000 --num_iter_list 64,16 --cfg 11.0 --cfg_schedule linear --temperature 1.03 \
--output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
--data_path ${IMAGENET_PATH} --grad_checkpointing --online_eval
```

Example script for training FractalMAR on ImageNet 64x64 for 800 epochs:
```
torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main_fractalgen.py \
--model fractalmar_in64 --img_size 64 --num_conds 5 \
--batch_size 64 --eval_freq 40 --save_last_freq 10 \
--epochs 800 --warmup_epochs 40 \
--blr 5.0e-5 --weight_decay 0.05 --attn_dropout 0.1 --proj_dropout 0.1 --lr_schedule cosine \
--gen_bsz 256 --num_images 8000 --num_iter_list 64,16 --cfg 6.5 --cfg_schedule linear --temperature 1.02 \
--output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
--data_path ${IMAGENET_PATH} --grad_checkpointing --online_eval
```

Example script for training FractalMAR-L on ImageNet 256x256 for 800 epochs:
```
torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main_fractalgen.py \
--model fractalmar_large_in256 --img_size 256 --num_conds 5 --guiding_pixel \
--batch_size 32 --eval_freq 40 --save_last_freq 10 \
--epochs 800 --warmup_epochs 40 \
--blr 5.0e-5 --weight_decay 0.05 --attn_dropout 0.1 --proj_dropout 0.1 --lr_schedule cosine \
--gen_bsz 256 --num_images 8000 --num_iter_list 64,16,16 --cfg 21.0 --cfg_schedule linear --temperature 1.1 \
--output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
--data_path ${IMAGENET_PATH} --grad_checkpointing --online_eval
```

### Evaluation

Evaluate the pre-trained FractalAR on unconditional likelihood estimation for ImageNet 64x64 (single GPU):
```
torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 \
main_fractalgen.py \
--model fractalar_in64 --img_size 64 --num_conds 1 \
--nll_bsz 128 --nll_forward_number 1 \
--output_dir pretrained_models/fractalar_in64 \
--resume pretrained_models/fractalar_in64 \
--data_path ${IMAGENET_PATH} --seed 0 --evaluate_nll
```

Evaluate the pre-trained FractalMAR on unconditional likelihood estimation for ImageNet 64x64 (single GPU):
```
torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 \
main_fractalgen.py \
--model fractalmar_in64 --img_size 64 --num_conds 5 \
--nll_bsz 128 --nll_forward_number 10 \
--output_dir pretrained_models/fractalmar_in64 \
--resume pretrained_models/fractalmar_in64 \
--data_path ${IMAGENET_PATH} --seed 0 --evaluate_nll
```

Evaluate the pre-trained FractalAR on ImageNet 64x64 class-conditional generation:
```
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_fractalgen.py \
--model fractalar_in64 --img_size 64 --num_conds 1 \
--gen_bsz 512 --num_images 50000 \
--num_iter_list 64,16 --cfg 11.0 --cfg_schedule linear --temperature 1.03 \
--output_dir pretrained_models/fractalar_in64 \
--resume pretrained_models/fractalar_in64 \
--data_path ${IMAGENET_PATH} --seed 0 --evaluate_gen
```

Evaluate the pre-trained FractalMAR on ImageNet 64x64 class-conditional generation:
```
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_fractalgen.py \
--model fractalmar_in64 --img_size 64 --num_conds 5 \
--gen_bsz 1024 --num_images 50000 \
--num_iter_list 64,16 --cfg 6.5 --cfg_schedule linear --temperature 1.02 \
--output_dir pretrained_models/fractalmar_in64 \
--resume pretrained_models/fractalmar_in64 \
--data_path ${IMAGENET_PATH} --seed 0 --evaluate_gen
```

Evaluate the pre-trained FractalMAR-Huge on ImageNet 256x256 class-conditional generation:
```
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_fractalgen.py \
--model fractalmar_huge_in256 --img_size 256 --num_conds 5 --guiding_pixel \
--gen_bsz 1024 --num_images 50000 \
--num_iter_list 64,16,16 --cfg 19.0 --cfg_schedule linear --temperature 1.1 \
--output_dir pretrained_models/fractalmar_huge_in256 \
--resume pretrained_models/fractalmar_huge_in256 \
--data_path ${IMAGENET_PATH} --seed 0 --evaluate_gen
```

For ImageNet 256x256, the classifier-free guidance scales (`--cfg`) that achieve the best FID are `29.0` for FractalMAR-Base and `21.0` for FractalMAR-Large; the FractalMAR-Huge command above uses `19.0`.

## Acknowledgements

We thank Google TPU Research Cloud (TRC) for granting us access to TPUs, and Google Cloud Platform for supporting GPU resources.

## Contact

If you have any questions, feel free to contact me by email (tianhong@mit.edu). Enjoy!
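The `--evaluate_nll` commands in the Evaluation section print a dataset-level bits-per-dimension (BPD) figure. In `compute_nll` (`engine_fractalgen.py`), each batch's loss is averaged over `--nll_forward_number` forward passes (10 for FractalMAR, since each pass samples a random masking ratio; 1 for FractalAR), divided by ln 2, and weighted by batch size. The sketch below replays that bookkeeping in plain Python. It assumes the model's loss is a per-dimension negative log-likelihood in nats (which is what dividing by ln 2 implies); `estimate_bpd` and the noisy stand-in loss are hypothetical.

```python
import math
import random


def estimate_bpd(batches, loss_fn, num_passes):
    """Hypothetical replay of the averaging done in compute_nll.

    Assumes loss_fn(batch) returns the batch's mean per-dimension NLL in nats.
    Each batch loss is averaged over num_passes calls, converted to bits by
    dividing by ln(2), and weighted by batch size in the final average.
    """
    total_bpd, total_samples = 0.0, 0
    for batch in batches:
        loss = sum(loss_fn(batch) for _ in range(num_passes)) / num_passes
        bpd = loss / math.log(2)
        total_bpd += bpd * len(batch)
        total_samples += len(batch)
    return total_bpd / total_samples


if __name__ == "__main__":
    rng = random.Random(0)
    batches = [list(range(128))] * 10

    def noisy_loss(batch):
        # Made-up stand-in for a MAR-style estimate: noisy around 2.5 nats/dim.
        return 2.5 + rng.gauss(0.0, 0.2)

    for n in (1, 10):
        print(f"passes={n}: {estimate_bpd(batches, noisy_loss, n):.4f} bits/dim")
```

Averaging more passes reduces the variance of each batch estimate without changing its expectation, which is why the deterministic AR model gains nothing from a value above 1.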
--------------------------------------------------------------------------------
/demo/visual.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LTH14/fractalgen/c7d099043dd987ae7742a7f8c8f1cab71023ba0e/demo/visual.gif
--------------------------------------------------------------------------------
/engine_fractalgen.py:
--------------------------------------------------------------------------------
import math
import sys
import os
import time
import shutil
from typing import Iterable

import torch
import torch.nn as nn
import numpy as np
import cv2

import util.misc as misc
import util.lr_sched as lr_sched
import torch_fidelity


def train_one_epoch(model, data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, log_writer=None, args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, (samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # per iteration (instead of per epoch) lr scheduler
        lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # forward
        with torch.cuda.amp.autocast():
            loss = model(samples, labels)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
        optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None:
            # Use epoch_1000x as the x-axis in TensorBoard to calibrate curves.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def compute_nll(model: torch.nn.Module, data_loader: Iterable, device: torch.device, N: int):
    model.eval()
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = ''
    print_freq = 20

    total_samples = 0
    total_bpd = 0.0

    for _, (samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        samples = samples.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        loss = 0.0
        # Average multiple forward passes for a stable NLL estimate.
        for _ in range(N):
            with torch.cuda.amp.autocast():
                with torch.no_grad():
                    one_loss = model(samples, labels)
            loss += one_loss
        loss /= N
        loss_value = loss.item()

        # convert loss to bits/dim
        bpd_value = loss_value / math.log(2)
        total_samples += samples.size(0)
        total_bpd += bpd_value * samples.size(0)

        torch.cuda.synchronize()
        metric_logger.update(bpd=bpd_value)

    print("BPD: {:.5f}".format(total_bpd / total_samples))


def evaluate(model_without_ddp, args, epoch, batch_size=64, log_writer=None):
    model_without_ddp.eval()
    world_size = misc.get_world_size()
    local_rank = misc.get_rank()
    num_steps = args.num_images // (batch_size * world_size) + 1

    # Construct the folder name for saving generated images.
    save_folder = os.path.join(
        args.output_dir,
        "ariter{}-temp{}-{}cfg{}-filter{}-image{}".format(
            args.num_iter_list, args.temperature, args.cfg_schedule,
            args.cfg, args.filter_threshold, args.num_images
        )
    )
    if args.evaluate_gen:
        save_folder += "_evaluate"
    print("Save to:", save_folder)
    if misc.get_rank() == 0 and not os.path.exists(save_folder):
        os.makedirs(save_folder)

    # Ensure that the number of images per class is equal.
    class_num = args.class_num
    assert args.num_images % class_num == 0, "Number of images per class must be the same"
    class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
    class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])

    used_time = 0.0
    gen_img_cnt = 0

    for i in range(num_steps):
        print("Generation step {}/{}".format(i, num_steps))

        start_idx = world_size * batch_size * i + local_rank * batch_size
        end_idx = start_idx + batch_size
        labels_gen = class_label_gen_world[start_idx:end_idx]
        labels_gen = torch.Tensor(labels_gen).long().cuda()

        torch.cuda.synchronize()
        start_time = time.time()

        # generation
        with torch.no_grad():
            with torch.cuda.amp.autocast():
                class_embedding = model_without_ddp.class_emb(labels_gen)
                if not args.cfg == 1.0:
                    # Concatenate fake latent for classifier-free guidance.
                    class_embedding = torch.cat(
                        [class_embedding, model_without_ddp.fake_latent.repeat(batch_size, 1)],
                        dim=0
                    )
                sampled_images = model_without_ddp.sample(
                    cond_list=[class_embedding for _ in range(args.num_conds)],
                    num_iter_list=[int(num_iter) for num_iter in args.num_iter_list.split(",")],
                    cfg=args.cfg, cfg_schedule=args.cfg_schedule,
                    temperature=args.temperature,
                    filter_threshold=args.filter_threshold,
                    fractal_level=0
                )

        # Measure generation speed (skip first batch).
        torch.cuda.synchronize()
        batch_time = time.time() - start_time
        if i >= 1:
            used_time += batch_time
            gen_img_cnt += batch_size
            print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))

        torch.distributed.barrier()

        # Denormalize images.
        pix_mean = torch.Tensor([0.485, 0.456, 0.406]).cuda().view(1, -1, 1, 1)
        pix_std = torch.Tensor([0.229, 0.224, 0.225]).cuda().view(1, -1, 1, 1)
        sampled_images = sampled_images * pix_std + pix_mean
        sampled_images = sampled_images.detach().cpu()

        # distributed save images
        for b_id in range(sampled_images.size(0)):
            img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id
            if img_id >= args.num_images:
                break
            gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
            gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
            cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)

    torch.distributed.barrier()
    time.sleep(10)

    # compute FID and IS
    if log_writer is not None:
        if args.img_size == 64:
            fid_statistics_file = 'fid_stats/adm_in64_stats.npz'
        elif args.img_size == 256:
            fid_statistics_file = 'fid_stats/adm_in256_stats.npz'
        else:
            raise NotImplementedError
        metrics_dict = torch_fidelity.calculate_metrics(
            input1=save_folder,
            input2=None,
            fid_statistics_file=fid_statistics_file,
            cuda=True,
            isc=True,
            fid=True,
            kid=False,
            prc=False,
            verbose=False,
        )
        fid = metrics_dict['frechet_inception_distance']
        inception_score = metrics_dict['inception_score_mean']
        postfix = "_cfg{}".format(args.cfg)
        log_writer.add_scalar('fid{}'.format(postfix), fid, epoch)
        log_writer.add_scalar('is{}'.format(postfix), inception_score, epoch)
        print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
        if not args.evaluate_gen:
            # remove the temporary folder used for online eval
            shutil.rmtree(save_folder)

    torch.distributed.barrier()
    time.sleep(10)
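In `evaluate()` above, every class receives `num_images / class_num` labels, each rank takes a disjoint batch-sized slice per step, the zero padding only keeps the final slices full, and the `img_id >= args.num_images` check discards the surplus. The sketch below replays that indexing on the CPU with NumPy so the per-class balance can be checked for a given `--gen_bsz` and GPU count. `labels_for_rank` and `saved_label_counts` are hypothetical helpers, and the sketch assumes `model_without_ddp.sample` returns exactly `batch_size` images per call.

```python
from collections import Counter

import numpy as np


def labels_for_rank(step, rank, world_size, batch_size, num_images, class_num, pad=50000):
    """Hypothetical CPU replay of the label slicing in evaluate()."""
    assert num_images % class_num == 0, "Number of images per class must be the same"
    labels = np.arange(0, class_num).repeat(num_images // class_num)
    # Zero padding keeps the last step's slice full; those extra images are never saved.
    labels = np.hstack([labels, np.zeros(pad)])
    start = world_size * batch_size * step + rank * batch_size
    return labels[start:start + batch_size].astype(np.int64)


def saved_label_counts(world_size, batch_size, num_images, class_num):
    """Count the labels of the images evaluate() would actually write to disk."""
    num_steps = num_images // (batch_size * world_size) + 1
    counts = Counter()
    for step in range(num_steps):
        for rank in range(world_size):
            labels = labels_for_rank(step, rank, world_size, batch_size, num_images, class_num)
            for b, label in enumerate(labels):
                # Same formula as evaluate(), assuming sample() returns batch_size images.
                img_id = step * batch_size * world_size + rank * batch_size + b
                if img_id >= num_images:
                    break
                counts[int(label)] += 1
    return counts


if __name__ == "__main__":
    # Tiny setup: 2 ranks, batch 4, 12 images over 3 classes.
    print(saved_label_counts(world_size=2, batch_size=4, num_images=12, class_num=3))
```

Because the saved `img_id` equals the slice index into the label array, image `k` always carries label `k // (num_images // class_num)`, independent of the GPU count.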
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
name: fractalgen
channels:
  - pytorch
  - defaults
  - nvidia
dependencies:
  - python=3.8.5
  - pip=20.3
  - pytorch-cuda=11.8
  - pytorch=2.2.2
  - torchvision=0.17.2
  - numpy=1.22
  - pip:
    - opencv-python==4.1.2.30
    - timm==0.9.12
    - tensorboard==2.10.0
    - scipy==1.9.1
    - gdown==5.2.0
    - -e git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity
--------------------------------------------------------------------------------
/fid_stats/adm_in256_stats.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LTH14/fractalgen/c7d099043dd987ae7742a7f8c8f1cab71023ba0e/fid_stats/adm_in256_stats.npz
--------------------------------------------------------------------------------
/fid_stats/adm_in64_stats.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LTH14/fractalgen/c7d099043dd987ae7742a7f8c8f1cab71023ba0e/fid_stats/adm_in64_stats.npz
--------------------------------------------------------------------------------
/main_fractalgen.py:
--------------------------------------------------------------------------------
import argparse
import datetime
import numpy as np
import os
import time
from pathlib import Path

import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from util.crop import center_crop_arr
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

from models import fractalgen
from engine_fractalgen import train_one_epoch, compute_nll, evaluate


def get_args_parser():
    parser = argparse.ArgumentParser('Fractal Generative Models', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size = batch_size * # GPUs)')
    parser.add_argument('--epochs', default=400, type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='Folder that contains checkpoint to resume from')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='Starting epoch')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for faster GPU transfers')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # Model parameters
    parser.add_argument('--model', default='fractalmar_in64', type=str, metavar='MODEL',
                        help='Name of the model to train')
    parser.add_argument('--img_size', default=64, type=int, help='Image size')

    # Generation parameters
    parser.add_argument('--num_iter_list', default='64,16', type=str,
                        help='Number of autoregressive iterations for each fractal level')
    parser.add_argument('--num_images', default=50000, type=int,
                        help='Number of images to generate')
    parser.add_argument('--cfg', default=1.0, type=float,
                        help='Classifier-free guidance factor')
    parser.add_argument('--cfg_schedule', default='linear', type=str)
    parser.add_argument('--temperature', default=1.0, type=float,
                        help='Sampling temperature')
    parser.add_argument('--filter_threshold', default=1e-4, type=float,
                        help='Filter threshold for low probability tokens in cfg')
    parser.add_argument('--label_drop_prob', default=0.1, type=float)
    parser.add_argument('--eval_freq', type=int, default=40,
                        help='Frequency (in epochs) for evaluation')
    parser.add_argument('--save_last_freq', type=int, default=5,
                        help='Frequency (in epochs) to save checkpoints')
    parser.add_argument('--online_eval', action='store_true')
    parser.add_argument('--evaluate_gen', action='store_true')
    parser.add_argument('--evaluate_nll', action='store_true')
    parser.add_argument('--gen_bsz', type=int, default=1024,
                        help='Generation batch size')
    parser.add_argument('--nll_bsz', type=int, default=128,
                        help='NLL evaluation batch size')
    parser.add_argument('--nll_forward_number', type=int, default=1,
                        help='Number of forward passes used to evaluate the NLL for each data sample. '
                             'This does not affect the NLL of the AR model, but for the MAR model, multiple passes '
                             '(each randomly sampling a masking ratio) give a more accurate NLL estimate.'
                        )
    # Optimizer parameters
    parser.add_argument('--weight_decay', type=float, default=0.05,
                        help='Weight decay (default: 0.05)')
    parser.add_argument('--grad_checkpointing', action='store_true')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='Learning rate (absolute)')
    parser.add_argument('--blr', type=float, default=5e-5, metavar='LR',
                        help='Base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='Minimum LR for cyclic schedulers that hit 0')
    parser.add_argument('--lr_schedule', type=str, default='cosine',
                        help='Learning rate schedule')
    parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
                        help='Epochs to warm up LR')

    # Fractal generator parameters
    parser.add_argument('--guiding_pixel', action='store_true',
                        help='Use guiding pixels')
    parser.add_argument('--num_conds', type=int, default=1,
                        help='Number of conditions to use')
    parser.add_argument('--r_weight', type=float, default=5.0,
                        help='Loss weight on the red channel')
    parser.add_argument('--grad_clip', type=float, default=3.0,
                        help='Gradient clipping value')
    parser.add_argument('--attn_dropout', type=float, default=0.1,
                        help='Attention dropout rate')
    parser.add_argument('--proj_dropout', type=float, default=0.1,
                        help='Projection dropout rate')

    # Dataset parameters
    parser.add_argument('--data_path', default='./data/imagenet', type=str,
                        help='Path to the dataset')
    parser.add_argument('--class_num', default=1000, type=int)
    parser.add_argument('--output_dir', default='./output_dir',
                        help='Directory to save outputs (empty for no saving)')
    parser.add_argument('--device', default='cuda',
                        help='Device to use for training/testing')

    # Distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='Number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='URL used to set up distributed training')

    return parser


def main(args):
    misc.init_distributed_mode(args)
    print('Job directory:', os.path.dirname(os.path.realpath(__file__)))
    print("Arguments:\n{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # Set seeds for reproducibility
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()

    # Set up TensorBoard logging (only on main process)
    if global_rank == 0 and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.output_dir)
    else:
        log_writer = None

    # Data augmentation transforms (following DiT and ADM)
    transform_train = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    transform_val = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
    dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)

    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    print("Sampler_train =", sampler_train)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, shuffle=True,
        batch_size=args.nll_bsz,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )

    # Create fractal generative model
    model = fractalgen.__dict__[args.model](
        label_drop_prob=args.label_drop_prob,
        class_num=args.class_num,
        attn_dropout=args.attn_dropout,
        proj_dropout=args.proj_dropout,
        guiding_pixel=args.guiding_pixel,
        num_conds=args.num_conds,
        r_weight=args.r_weight,
        grad_checkpointing=args.grad_checkpointing
    )

    print("Model =", model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Number of trainable parameters: {:.2f}M".format(n_params / 1e6))

    model.to(device)

    eff_batch_size = args.batch_size * misc.get_world_size()
    if args.lr is None:  # only base_lr (blr) is specified
        args.lr = args.blr * eff_batch_size / 256

    print("Base lr: {:.2e}".format(args.lr * 256 / eff_batch_size))
    print("Actual lr: {:.2e}".format(args.lr))
    print("Effective batch size: %d" % eff_batch_size)

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
    model_without_ddp = model.module

    # Set up optimizer with weight decay adjustment for bias and norm layers
    param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)
    loss_scaler = NativeScaler()

    # Resume from checkpoint if provided
    checkpoint_path = os.path.join(args.resume, "checkpoint-last.pth") if args.resume else None
    if checkpoint_path and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resumed checkpoint from", args.resume)

        if 'optimizer' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("Loaded optimizer & scaler state!")
        del checkpoint
    else:
        print("Training from scratch")

    # Evaluation modes
    if args.evaluate_gen:
        torch.cuda.empty_cache()
        evaluate(model_without_ddp, args, 0, batch_size=args.gen_bsz, log_writer=log_writer)
        return

    if args.evaluate_nll:
        torch.cuda.empty_cache()
        compute_nll(model, data_loader_val, device, N=args.nll_forward_number)
        return

    # Training loop
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(
            model, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args
        )

        # Save checkpoint periodically
        if epoch % args.save_last_freq == 0 or epoch + 1 == args.epochs:
            misc.save_model(
                args=args,
                model_without_ddp=model_without_ddp,
                optimizer=optimizer,
                loss_scaler=loss_scaler,
                epoch=epoch,
                epoch_name="last"
            )

        # Perform online evaluation at specified intervals
        if args.online_eval and (epoch % args.eval_freq == 0 or epoch + 1 == args.epochs):
            torch.cuda.empty_cache()
            evaluate(model_without_ddp, args, epoch, batch_size=args.gen_bsz, log_writer=log_writer)
            torch.cuda.empty_cache()

        if misc.is_main_process() and log_writer is not None:
            log_writer.flush()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time:', total_time_str)


if __name__ == '__main__':
    args = get_args_parser().parse_args()
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
--------------------------------------------------------------------------------
/models/ar.py:
--------------------------------------------------------------------------------
# Modified from:
#   LlamaGen: https://github.com/FoundationVision/LlamaGen/blob/main/autoregressive/models/gpt.py
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.nn import functional as F
from util.visualize import visualize_patch
import math


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(torch.nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
37 | """ 38 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 39 | super(DropPath, self).__init__() 40 | self.drop_prob = drop_prob 41 | self.scale_by_keep = scale_by_keep 42 | 43 | def forward(self, x): 44 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 45 | 46 | def extra_repr(self): 47 | return f'drop_prob={round(self.drop_prob,3):0.3f}' 48 | 49 | 50 | def find_multiple(n: int, k: int): 51 | if n % k == 0: 52 | return n 53 | return n + k - (n % k) 54 | 55 | 56 | @dataclass 57 | class ModelArgs: 58 | dim: int = 4096 59 | n_layer: int = 32 60 | n_head: int = 32 61 | n_kv_head: Optional[int] = None 62 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 63 | ffn_dim_multiplier: Optional[float] = None 64 | rope_base: float = 10000 65 | norm_eps: float = 1e-5 66 | initializer_range: float = 0.02 67 | 68 | token_dropout_p: float = 0.1 69 | attn_dropout_p: float = 0.0 70 | resid_dropout_p: float = 0.1 71 | ffn_dropout_p: float = 0.1 72 | drop_path_rate: float = 0.0 73 | 74 | num_classes: int = 1000 75 | caption_dim: int = 2048 76 | class_dropout_prob: float = 0.1 77 | model_type: str = 'c2i' 78 | 79 | vocab_size: int = 16384 80 | cls_token_num: int = 1 81 | block_size: int = 256 82 | max_batch_size: int = 32 83 | max_seq_len: int = 2048 84 | 85 | 86 | ################################################################################# 87 | # Embedding Layers for Class Labels # 88 | ################################################################################# 89 | class LabelEmbedder(nn.Module): 90 | """ 91 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 
92 | """ 93 | 94 | def __init__(self, num_classes, hidden_size, dropout_prob): 95 | super().__init__() 96 | use_cfg_embedding = dropout_prob > 0 97 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 98 | self.num_classes = num_classes 99 | self.dropout_prob = dropout_prob 100 | 101 | def token_drop(self, labels, force_drop_ids=None): 102 | """ 103 | Drops labels to enable classifier-free guidance. 104 | """ 105 | if force_drop_ids is None: 106 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 107 | else: 108 | drop_ids = force_drop_ids == 1 109 | labels = torch.where(drop_ids, self.num_classes, labels) 110 | return labels 111 | 112 | def forward(self, labels, train, force_drop_ids=None): 113 | use_dropout = self.dropout_prob > 0 114 | if (train and use_dropout) or (force_drop_ids is not None): 115 | labels = self.token_drop(labels, force_drop_ids) 116 | embeddings = self.embedding_table(labels).unsqueeze(1) 117 | return embeddings 118 | 119 | 120 | ################################################################################# 121 | # GPT Model # 122 | ################################################################################# 123 | class RMSNorm(torch.nn.Module): 124 | def __init__(self, dim: int, eps: float = 1e-5): 125 | super().__init__() 126 | self.eps = eps 127 | self.weight = nn.Parameter(torch.ones(dim)) 128 | 129 | def _norm(self, x): 130 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 131 | 132 | def forward(self, x): 133 | output = self._norm(x.float()).type_as(x) 134 | return output * self.weight 135 | 136 | 137 | class FeedForward(nn.Module): 138 | def __init__(self, config: ModelArgs): 139 | super().__init__() 140 | hidden_dim = 4 * config.dim 141 | hidden_dim = int(2 * hidden_dim / 3) 142 | # custom dim factor multiplier 143 | if config.ffn_dim_multiplier is not None: 144 | hidden_dim = int(config.ffn_dim_multiplier * hidden_dim) 145 | hidden_dim = 
find_multiple(hidden_dim, config.multiple_of) 146 | 147 | self.w1 = nn.Linear(config.dim, hidden_dim, bias=False) 148 | self.w3 = nn.Linear(config.dim, hidden_dim, bias=False) 149 | self.w2 = nn.Linear(hidden_dim, config.dim, bias=False) 150 | self.ffn_dropout = nn.Dropout(config.ffn_dropout_p) 151 | 152 | def forward(self, x): 153 | return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x))) 154 | 155 | 156 | class KVCache(nn.Module): 157 | def __init__(self, max_batch_size, max_seq_length, n_head, head_dim): 158 | super().__init__() 159 | cache_shape = (max_batch_size, n_head, max_seq_length, head_dim) 160 | self.register_buffer('k_cache', torch.zeros(cache_shape)) 161 | self.register_buffer('v_cache', torch.zeros(cache_shape)) 162 | 163 | def update(self, input_pos, k_val, v_val): 164 | # input_pos: [S], k_val: [B, H, S, D] 165 | k_out = self.k_cache 166 | v_out = self.v_cache 167 | k_out[:, :, input_pos] = k_val.to(k_out.dtype) 168 | v_out[:, :, input_pos] = v_val.to(v_out.dtype) 169 | 170 | return k_out, v_out 171 | 172 | 173 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: 174 | L, S = query.size(-2), key.size(-2) 175 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 176 | attn_bias = torch.zeros(L, S, dtype=query.dtype).cuda() 177 | if is_causal: 178 | assert attn_mask is None 179 | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda() 180 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 181 | attn_bias = attn_bias.to(query.dtype) 182 | 183 | if attn_mask is not None: 184 | if attn_mask.dtype == torch.bool: 185 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) 186 | else: 187 | attn_bias += attn_mask 188 | with torch.cuda.amp.autocast(enabled=False): 189 | attn_weight = query.float() @ key.float().transpose(-2, -1) * scale_factor 190 | attn_weight += attn_bias 191 | attn_weight = torch.softmax(attn_weight,
dim=-1) 192 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) 193 | return attn_weight @ value 194 | 195 | 196 | class Attention(nn.Module): 197 | def __init__(self, config: ModelArgs): 198 | super().__init__() 199 | assert config.dim % config.n_head == 0 200 | self.dim = config.dim 201 | self.head_dim = config.dim // config.n_head 202 | self.n_head = config.n_head 203 | self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head 204 | total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim 205 | 206 | # key, query, value projections for all heads, but in a batch 207 | self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False) 208 | self.wo = nn.Linear(config.dim, config.dim, bias=False) 209 | self.kv_cache = None 210 | 211 | # regularization 212 | self.attn_dropout_p = config.attn_dropout_p 213 | self.resid_dropout = nn.Dropout(config.resid_dropout_p) 214 | 215 | def forward( 216 | self, x: torch.Tensor, freqs_cis=None, input_pos=None, mask=None 217 | ): 218 | bsz, seqlen, _ = x.shape 219 | kv_size = self.n_kv_head * self.head_dim 220 | xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 221 | 222 | xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) 223 | xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim) 224 | xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim) 225 | 226 | xq = apply_rotary_emb(xq, freqs_cis) 227 | xk = apply_rotary_emb(xk, freqs_cis) 228 | 229 | xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) 230 | 231 | if self.kv_cache is not None: 232 | keys, values = self.kv_cache.update(input_pos, xk, xv) 233 | else: 234 | keys, values = xk, xv 235 | keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1) 236 | values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1) 237 | 238 | output = scaled_dot_product_attention( 239 | xq, keys, values, 240 | attn_mask=mask, 241 | is_causal=True if mask is None else False, # is_causal=False is for 
KV cache 242 | dropout_p=self.attn_dropout_p if self.training else 0) 243 | 244 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 245 | 246 | output = self.resid_dropout(self.wo(output)) 247 | return output 248 | 249 | 250 | class TransformerBlock(nn.Module): 251 | def __init__(self, config: ModelArgs, drop_path: float): 252 | super().__init__() 253 | self.attention = Attention(config) 254 | self.feed_forward = FeedForward(config) 255 | self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) 256 | self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) 257 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 258 | 259 | def forward( 260 | self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None): 261 | h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask)) 262 | out = h + self.drop_path(self.feed_forward(self.ffn_norm(h))) 263 | return out 264 | 265 | 266 | ################################################################################# 267 | # Rotary Positional Embedding Functions # 268 | ################################################################################# 269 | # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py 270 | def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120): 271 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) 272 | t = torch.arange(seq_len, device=freqs.device) 273 | freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2) 274 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 275 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) # (seq_len, head_dim // 2, 2) 276 | cond_cache = torch.cat( 277 | [torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2) 278 | return cond_cache 279 | 280 | 281 | def precompute_freqs_cis_2d(grid_size: int, n_elem:
int, base: int = 10000, cls_token_num=120): 282 | # split the dimension into half, one for x and one for y 283 | half_dim = n_elem // 2 284 | freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim)) 285 | t = torch.arange(grid_size, device=freqs.device) 286 | freqs = torch.outer(t, freqs) # (grid_size, head_dim // 4) 287 | freqs_grid = torch.concat([ 288 | freqs[:, None, :].expand(-1, grid_size, -1), 289 | freqs[None, :, :].expand(grid_size, -1, -1), 290 | ], dim=-1) # (grid_size, grid_size, head_dim // 2) 291 | cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)], 292 | dim=-1) # (grid_size, grid_size, head_dim // 2, 2) 293 | cache = cache_grid.flatten(0, 1) 294 | cond_cache = torch.cat( 295 | [torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2) 296 | return cond_cache 297 | 298 | 299 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor): 300 | # x: (bs, seq_len, n_head, head_dim) 301 | # freqs_cis (seq_len, head_dim // 2, 2) 302 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2) 303 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2) 304 | x_out2 = torch.stack([ 305 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 306 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], 307 | ], dim=-1) 308 | x_out2 = x_out2.flatten(3) 309 | return x_out2.type_as(x) 310 | 311 | 312 | class AR(nn.Module): 313 | def __init__(self, seq_len, patch_size, cond_embed_dim, embed_dim, num_blocks, num_heads, 314 | grad_checkpointing=False, **kwargs): 315 | super().__init__() 316 | 317 | self.seq_len = seq_len 318 | self.patch_size = patch_size 319 | 320 | self.grad_checkpointing = grad_checkpointing 321 | 322 | # -------------------------------------------------------------------------- 323 | # network 324 | self.patch_emb =
nn.Linear(3 * patch_size ** 2, embed_dim, bias=True) 325 | self.patch_emb_ln = nn.LayerNorm(embed_dim, eps=1e-6) 326 | self.pos_embed_learned = nn.Parameter(torch.zeros(1, seq_len+1, embed_dim)) 327 | self.cond_emb = nn.Linear(cond_embed_dim, embed_dim, bias=True) 328 | 329 | self.config = model_args = ModelArgs(dim=embed_dim, n_head=num_heads) 330 | self.blocks = nn.ModuleList([TransformerBlock(config=model_args, drop_path=0.0) for _ in range(num_blocks)]) 331 | 332 | # 2d rotary pos embedding 333 | grid_size = int(seq_len ** 0.5) 334 | assert grid_size * grid_size == seq_len 335 | self.freqs_cis = precompute_freqs_cis_2d(grid_size, model_args.dim // model_args.n_head, 336 | model_args.rope_base, cls_token_num=1).cuda() 337 | 338 | # KVCache 339 | self.max_batch_size = -1 340 | self.max_seq_length = -1 341 | 342 | self.norm = nn.LayerNorm(embed_dim, eps=1e-6) 343 | 344 | self.initialize_weights() 345 | 346 | def initialize_weights(self): 347 | # parameters 348 | torch.nn.init.normal_(self.pos_embed_learned, std=.02) 349 | 350 | # initialize nn.Linear and nn.LayerNorm 351 | self.apply(self._init_weights) 352 | 353 | def _init_weights(self, m): 354 | if isinstance(m, nn.Linear): 355 | # we use xavier_uniform following official JAX ViT: 356 | torch.nn.init.xavier_uniform_(m.weight) 357 | if isinstance(m, nn.Linear) and m.bias is not None: 358 | nn.init.constant_(m.bias, 0) 359 | elif isinstance(m, nn.LayerNorm): 360 | if m.bias is not None: 361 | nn.init.constant_(m.bias, 0) 362 | if m.weight is not None: 363 | nn.init.constant_(m.weight, 1.0) 364 | 365 | def setup_caches(self, max_batch_size, max_seq_length): 366 | # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: 367 | # return 368 | head_dim = self.config.dim // self.config.n_head 369 | max_seq_length = find_multiple(max_seq_length, 8) 370 | self.max_seq_length = max_seq_length 371 | self.max_batch_size = max_batch_size 372 | for b in self.blocks: 373 | b.attention.kv_cache = 
KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim) 374 | 375 | causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) 376 | self.causal_mask = causal_mask 377 | grid_size = int(self.seq_len ** 0.5) 378 | assert grid_size * grid_size == self.seq_len 379 | self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head, 380 | self.config.rope_base, 1) 381 | 382 | def patchify(self, x): 383 | bsz, c, h, w = x.shape 384 | p = self.patch_size 385 | h_, w_ = h // p, w // p 386 | 387 | x = x.reshape(bsz, c, h_, p, w_, p) 388 | x = torch.einsum('nchpwq->nhwcpq', x) 389 | x = x.reshape(bsz, h_ * w_, c * p ** 2) 390 | return x # [n, l, d] 391 | 392 | def unpatchify(self, x): 393 | bsz = x.shape[0] 394 | p = self.patch_size 395 | h_, w_ = int(np.sqrt(self.seq_len)), int(np.sqrt(self.seq_len)) 396 | 397 | x = x.reshape(bsz, h_, w_, 3, p, p) 398 | x = torch.einsum('nhwcpq->nchpwq', x) 399 | x = x.reshape(bsz, 3, h_ * p, w_ * p) 400 | return x # [n, 3, h, w] 401 | 402 | def predict(self, x, cond_list, input_pos=None): 403 | x = self.patch_emb(x) 404 | x = torch.cat([self.cond_emb(cond_list[0]).unsqueeze(1).repeat(1, 1, 1), x], dim=1) 405 | 406 | # position embedding 407 | x = x + self.pos_embed_learned[:, :x.shape[1]] 408 | x = self.patch_emb_ln(x) 409 | 410 | if input_pos is not None: 411 | # use kv cache 412 | freqs_cis = self.freqs_cis[input_pos] 413 | mask = self.causal_mask[input_pos] 414 | x = x[:, input_pos] 415 | else: 416 | # training 417 | freqs_cis = self.freqs_cis[:x.shape[1]] 418 | mask = None 419 | 420 | # apply Transformer blocks 421 | if self.grad_checkpointing and not torch.jit.is_scripting() and self.training: 422 | for block in self.blocks: 423 | x = checkpoint(block, x, freqs_cis, input_pos, mask) 424 | else: 425 | for block in self.blocks: 426 | x = block(x, freqs_cis, input_pos, mask) 427 | x = self.norm(x) 428 | 429 | # return middle condition 430 | if input_pos is not 
None: 431 | middle_cond = x[:, 0] 432 | else: 433 | middle_cond = x[:, :-1] 434 | 435 | return [middle_cond] 436 | 437 | def forward(self, imgs, cond_list): 438 | """ training """ 439 | # patchify to get gt 440 | patches = self.patchify(imgs) 441 | mask = torch.ones(patches.size(0), patches.size(1)).to(patches.device) 442 | 443 | # get condition for next level 444 | cond_list_next = self.predict(patches, cond_list) 445 | 446 | # reshape conditions and patches for next level 447 | for cond_idx in range(len(cond_list_next)): 448 | cond_list_next[cond_idx] = cond_list_next[cond_idx].reshape(cond_list_next[cond_idx].size(0) * cond_list_next[cond_idx].size(1), -1) 449 | 450 | patches = patches.reshape(patches.size(0) * patches.size(1), -1) 451 | patches = patches.reshape(patches.size(0), 3, self.patch_size, self.patch_size) 452 | 453 | return patches, cond_list_next, 0 454 | 455 | def sample(self, cond_list, num_iter, cfg, cfg_schedule, temperature, filter_threshold, next_level_sample_function, 456 | visualize=False): 457 | """ generation """ 458 | if cfg == 1.0: 459 | bsz = cond_list[0].size(0) 460 | else: 461 | bsz = cond_list[0].size(0) // 2 462 | 463 | patches = torch.zeros(bsz, self.seq_len, 3 * self.patch_size**2).cuda() 464 | num_iter = self.seq_len 465 | 466 | device = cond_list[0].device 467 | with torch.device(device): 468 | self.setup_caches(max_batch_size=cond_list[0].size(0), max_seq_length=num_iter) 469 | 470 | # sample 471 | for step in range(num_iter): 472 | cur_patches = patches.clone() 473 | 474 | if not cfg == 1.0: 475 | patches = torch.cat([patches, patches], dim=0) 476 | 477 | # get next level conditions 478 | cond_list_next = self.predict(patches, cond_list, input_pos=torch.Tensor([step]).int()) 479 | # cfg schedule 480 | if cfg_schedule == "linear": 481 | cfg_iter = 1 + (cfg - 1) * (step + 1) / self.seq_len 482 | else: 483 | cfg_iter = cfg 484 | sampled_patches = next_level_sample_function(cond_list=cond_list_next, cfg=cfg_iter, 485 | 
temperature=temperature, filter_threshold=filter_threshold) 486 | sampled_patches = sampled_patches.reshape(sampled_patches.size(0), -1) 487 | 488 | cur_patches[:, step] = sampled_patches.to(cur_patches.dtype) 489 | patches = cur_patches.clone() 490 | 491 | # visualize generation process for colab 492 | if visualize: 493 | visualize_patch(self.unpatchify(patches)) 494 | 495 | # clean up kv cache 496 | for b in self.blocks: 497 | b.attention.kv_cache = None 498 | patches = self.unpatchify(patches) 499 | return patches 500 | -------------------------------------------------------------------------------- /models/fractalgen.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.ar import AR 7 | from models.mar import MAR 8 | from models.pixelloss import PixelLoss 9 | 10 | 11 | class FractalGen(nn.Module): 12 | """ Fractal Generative Model""" 13 | 14 | def __init__(self, 15 | img_size_list, 16 | embed_dim_list, 17 | num_blocks_list, 18 | num_heads_list, 19 | generator_type_list, 20 | label_drop_prob=0.1, 21 | class_num=1000, 22 | attn_dropout=0.1, 23 | proj_dropout=0.1, 24 | guiding_pixel=False, 25 | num_conds=1, 26 | r_weight=1.0, 27 | grad_checkpointing=False, 28 | fractal_level=0): 29 | super().__init__() 30 | 31 | # -------------------------------------------------------------------------- 32 | # fractal specifics 33 | self.fractal_level = fractal_level 34 | self.num_fractal_levels = len(img_size_list) 35 | 36 | # -------------------------------------------------------------------------- 37 | # Class embedding for the first fractal level 38 | if self.fractal_level == 0: 39 | self.num_classes = class_num 40 | self.class_emb = nn.Embedding(class_num, embed_dim_list[0]) 41 | self.label_drop_prob = label_drop_prob 42 | self.fake_latent = nn.Parameter(torch.zeros(1, embed_dim_list[0])) 43 | torch.nn.init.normal_(self.class_emb.weight, 
std=0.02) 44 | torch.nn.init.normal_(self.fake_latent, std=0.02) 45 | 46 | # -------------------------------------------------------------------------- 47 | # Generator for the current level 48 | if generator_type_list[fractal_level] == "ar": 49 | generator = AR 50 | elif generator_type_list[fractal_level] == "mar": 51 | generator = MAR 52 | else: 53 | raise NotImplementedError 54 | self.generator = generator( 55 | seq_len=(img_size_list[fractal_level] // img_size_list[fractal_level+1]) ** 2, 56 | patch_size=img_size_list[fractal_level+1], 57 | cond_embed_dim=embed_dim_list[fractal_level-1] if fractal_level > 0 else embed_dim_list[0], 58 | embed_dim=embed_dim_list[fractal_level], 59 | num_blocks=num_blocks_list[fractal_level], 60 | num_heads=num_heads_list[fractal_level], 61 | attn_dropout=attn_dropout, 62 | proj_dropout=proj_dropout, 63 | guiding_pixel=guiding_pixel if fractal_level > 0 else False, 64 | num_conds=num_conds, 65 | grad_checkpointing=grad_checkpointing, 66 | ) 67 | 68 | # -------------------------------------------------------------------------- 69 | # Build the next fractal level recursively 70 | if self.fractal_level < self.num_fractal_levels - 2: 71 | self.next_fractal = FractalGen( 72 | img_size_list=img_size_list, 73 | embed_dim_list=embed_dim_list, 74 | num_blocks_list=num_blocks_list, 75 | num_heads_list=num_heads_list, 76 | generator_type_list=generator_type_list, 77 | label_drop_prob=label_drop_prob, 78 | class_num=class_num, 79 | attn_dropout=attn_dropout, 80 | proj_dropout=proj_dropout, 81 | guiding_pixel=guiding_pixel, 82 | num_conds=num_conds, 83 | r_weight=r_weight, 84 | grad_checkpointing=grad_checkpointing, 85 | fractal_level=fractal_level+1 86 | ) 87 | else: 88 | # The final fractal level uses PixelLoss. 
89 | self.next_fractal = PixelLoss( 90 | c_channels=embed_dim_list[fractal_level], 91 | depth=num_blocks_list[fractal_level+1], 92 | width=embed_dim_list[fractal_level+1], 93 | num_heads=num_heads_list[fractal_level+1], 94 | r_weight=r_weight, 95 | ) 96 | 97 | def forward(self, imgs, cond_list): 98 | """ 99 | Forward pass to get loss recursively. 100 | """ 101 | if self.fractal_level == 0: 102 | # Compute class embedding conditions. 103 | class_embedding = self.class_emb(cond_list) 104 | if self.training: 105 | # Randomly drop labels according to label_drop_prob. 106 | drop_latent_mask = (torch.rand(cond_list.size(0)) < self.label_drop_prob).unsqueeze(-1).cuda().to(class_embedding.dtype) 107 | class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding 108 | else: 109 | # For evaluation (unconditional NLL), use a constant mask. 110 | drop_latent_mask = torch.ones(cond_list.size(0)).unsqueeze(-1).cuda().to(class_embedding.dtype) 111 | class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding 112 | cond_list = [class_embedding for _ in range(5)] 113 | 114 | # Get image patches and conditions for the next level 115 | imgs, cond_list, guiding_pixel_loss = self.generator(imgs, cond_list) 116 | # Compute loss recursively from the next fractal level. 117 | loss = self.next_fractal(imgs, cond_list) 118 | return loss + guiding_pixel_loss 119 | 120 | def sample(self, cond_list, num_iter_list, cfg, cfg_schedule, temperature, filter_threshold, fractal_level, 121 | visualize=False): 122 | """ 123 | Generate samples recursively. 
124 | """ 125 | if fractal_level < self.num_fractal_levels - 2: 126 | next_level_sample_function = partial( 127 | self.next_fractal.sample, 128 | num_iter_list=num_iter_list, 129 | cfg_schedule="constant", 130 | fractal_level=fractal_level + 1 131 | ) 132 | else: 133 | next_level_sample_function = self.next_fractal.sample 134 | 135 | # Recursively sample using the current generator. 136 | return self.generator.sample( 137 | cond_list, num_iter_list[fractal_level], cfg, cfg_schedule, 138 | temperature, filter_threshold, next_level_sample_function, visualize 139 | ) 140 | 141 | 142 | def fractalar_in64(**kwargs): 143 | model = FractalGen( 144 | img_size_list=(64, 4, 1), 145 | embed_dim_list=(1024, 512, 128), 146 | num_blocks_list=(32, 8, 3), 147 | num_heads_list=(16, 8, 4), 148 | generator_type_list=("ar", "ar", "ar"), 149 | fractal_level=0, 150 | **kwargs) 151 | return model 152 | 153 | 154 | def fractalmar_in64(**kwargs): 155 | model = FractalGen( 156 | img_size_list=(64, 4, 1), 157 | embed_dim_list=(1024, 512, 128), 158 | num_blocks_list=(32, 8, 3), 159 | num_heads_list=(16, 8, 4), 160 | generator_type_list=("mar", "mar", "ar"), 161 | fractal_level=0, 162 | **kwargs) 163 | return model 164 | 165 | 166 | def fractalmar_base_in256(**kwargs): 167 | model = FractalGen( 168 | img_size_list=(256, 16, 4, 1), 169 | embed_dim_list=(768, 384, 192, 64), 170 | num_blocks_list=(24, 6, 3, 1), 171 | num_heads_list=(12, 6, 3, 4), 172 | generator_type_list=("mar", "mar", "mar", "ar"), 173 | fractal_level=0, 174 | **kwargs) 175 | return model 176 | 177 | 178 | def fractalmar_large_in256(**kwargs): 179 | model = FractalGen( 180 | img_size_list=(256, 16, 4, 1), 181 | embed_dim_list=(1024, 512, 256, 64), 182 | num_blocks_list=(32, 8, 4, 1), 183 | num_heads_list=(16, 8, 4, 4), 184 | generator_type_list=("mar", "mar", "mar", "ar"), 185 | fractal_level=0, 186 | **kwargs) 187 | return model 188 | 189 | 190 | def fractalmar_huge_in256(**kwargs): 191 | model = FractalGen( 192 | 
img_size_list=(256, 16, 4, 1), 193 | embed_dim_list=(1280, 640, 320, 64), 194 | num_blocks_list=(40, 10, 5, 1), 195 | num_heads_list=(16, 8, 4, 4), 196 | generator_type_list=("mar", "mar", "mar", "ar"), 197 | fractal_level=0, 198 | **kwargs) 199 | return model 200 | -------------------------------------------------------------------------------- /models/mar.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import math 4 | import numpy as np 5 | import scipy.stats as stats 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.checkpoint import checkpoint 9 | from util.visualize import visualize_patch 10 | 11 | from timm.models.vision_transformer import DropPath, Mlp 12 | from models.pixelloss import PixelLoss 13 | 14 | 15 | def mask_by_order(mask_len, order, bsz, seq_len): 16 | masking = torch.zeros(bsz, seq_len).cuda() 17 | masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool() 18 | return masking 19 | 20 | 21 | class Attention(nn.Module): 22 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 23 | super().__init__() 24 | self.num_heads = num_heads 25 | head_dim = dim // num_heads 26 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 27 | self.scale = qk_scale or head_dim ** -0.5 28 | 29 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 30 | self.attn_drop = nn.Dropout(attn_drop) 31 | self.proj = nn.Linear(dim, dim) 32 | self.proj_drop = nn.Dropout(proj_drop) 33 | 34 | def forward(self, x): 35 | B, N, C = x.shape 36 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 37 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 38 | 39 | with torch.cuda.amp.autocast(enabled=False): 40 | attn = (q.float() @ k.float().transpose(-2, -1)) * 
self.scale 41 | 42 | attn = attn - torch.max(attn, dim=-1, keepdim=True)[0] 43 | attn = attn.softmax(dim=-1) 44 | attn = self.attn_drop(attn) 45 | 46 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 47 | x = self.proj(x) 48 | x = self.proj_drop(x) 49 | return x 50 | 51 | 52 | class Block(nn.Module): 53 | 54 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, proj_drop=0., attn_drop=0., 55 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 56 | super().__init__() 57 | self.norm1 = norm_layer(dim) 58 | self.attn = Attention( 59 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop) 60 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 61 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 62 | self.norm2 = norm_layer(dim) 63 | mlp_hidden_dim = int(dim * mlp_ratio) 64 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) 65 | 66 | def forward(self, x): 67 | x = x + self.drop_path(self.attn(self.norm1(x))) 68 | x = x + self.drop_path(self.mlp(self.norm2(x))) 69 | return x 70 | 71 | 72 | class MAR(nn.Module): 73 | def __init__(self, seq_len, patch_size, cond_embed_dim, embed_dim, num_blocks, num_heads, attn_dropout, proj_dropout, 74 | num_conds=1, guiding_pixel=False, grad_checkpointing=False 75 | ): 76 | super().__init__() 77 | 78 | self.seq_len = seq_len 79 | self.patch_size = patch_size 80 | 81 | self.num_conds = num_conds 82 | self.guiding_pixel = guiding_pixel 83 | 84 | self.grad_checkpointing = grad_checkpointing 85 | 86 | # -------------------------------------------------------------------------- 87 | # variant masking ratio 88 | self.mask_ratio_generator = stats.truncnorm(-4, 0, loc=1.0, scale=0.25) 89 | 90 | # -------------------------------------------------------------------------- 91 | # network 92 | self.mask_token = nn.Parameter(torch.zeros(1, 1, 
embed_dim)) 93 | self.patch_emb = nn.Linear(3 * patch_size ** 2, embed_dim, bias=True) 94 | self.patch_emb_ln = nn.LayerNorm(embed_dim, eps=1e-6) 95 | self.cond_emb = nn.Linear(cond_embed_dim, embed_dim, bias=True) 96 | if self.guiding_pixel: 97 | self.pix_proj = nn.Linear(3, embed_dim, bias=True) 98 | self.pos_embed_learned = nn.Parameter(torch.zeros(1, seq_len+num_conds+self.guiding_pixel, embed_dim)) 99 | 100 | self.blocks = nn.ModuleList([ 101 | Block(embed_dim, num_heads, mlp_ratio=4., 102 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 103 | proj_drop=proj_dropout, attn_drop=attn_dropout) 104 | for _ in range(num_blocks) 105 | ]) 106 | self.norm = nn.LayerNorm(embed_dim, eps=1e-6) 107 | 108 | self.initialize_weights() 109 | 110 | if self.guiding_pixel: 111 | self.guiding_pixel_loss = PixelLoss( 112 | c_channels=cond_embed_dim, 113 | width=128, 114 | depth=3, 115 | num_heads=4, 116 | r_weight=5.0 117 | ) 118 | 119 | def initialize_weights(self): 120 | # parameters 121 | torch.nn.init.normal_(self.mask_token, std=.02) 122 | torch.nn.init.normal_(self.pos_embed_learned, std=.02) 123 | 124 | # initialize nn.Linear and nn.LayerNorm 125 | self.apply(self._init_weights) 126 | 127 | def _init_weights(self, m): 128 | if isinstance(m, nn.Linear): 129 | # we use xavier_uniform following official JAX ViT: 130 | torch.nn.init.xavier_uniform_(m.weight) 131 | if isinstance(m, nn.Linear) and m.bias is not None: 132 | nn.init.constant_(m.bias, 0) 133 | elif isinstance(m, nn.LayerNorm): 134 | if m.bias is not None: 135 | nn.init.constant_(m.bias, 0) 136 | if m.weight is not None: 137 | nn.init.constant_(m.weight, 1.0) 138 | 139 | def patchify(self, x): 140 | bsz, c, h, w = x.shape 141 | p = self.patch_size 142 | h_, w_ = h // p, w // p 143 | 144 | x = x.reshape(bsz, c, h_, p, w_, p) 145 | x = torch.einsum('nchpwq->nhwcpq', x) 146 | x = x.reshape(bsz, h_ * w_, c * p ** 2) 147 | return x # [n, l, d] 148 | 149 | def unpatchify(self, x): 150 | bsz = x.shape[0] 151 | p 
= self.patch_size 152 | h_, w_ = int(np.sqrt(self.seq_len)), int(np.sqrt(self.seq_len)) 153 | 154 | x = x.reshape(bsz, h_, w_, 3, p, p) 155 | x = torch.einsum('nhwcpq->nchpwq', x) 156 | x = x.reshape(bsz, 3, h_ * p, w_ * p) 157 | return x # [n, 3, h, w] 158 | 159 | def sample_orders(self, bsz): 160 | orders = torch.argsort(torch.rand(bsz, self.seq_len).cuda(), dim=1).long() 161 | return orders 162 | 163 | def random_masking_uniform(self, x, orders): 164 | bsz, seq_len, embed_dim = x.shape 165 | num_masked_tokens = np.random.randint(seq_len) + 1 166 | mask = torch.zeros(bsz, seq_len, device=x.device) 167 | mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens], 168 | src=torch.ones(bsz, seq_len, device=x.device)) 169 | return mask 170 | 171 | def random_masking(self, x, orders): 172 | bsz, seq_len, embed_dim = x.shape 173 | mask_rates = self.mask_ratio_generator.rvs(bsz) 174 | num_masked_tokens = torch.Tensor(np.ceil(seq_len * mask_rates)).cuda() 175 | expanded_indices = torch.arange(seq_len, device=x.device).expand(bsz, seq_len) 176 | sorted_orders = torch.argsort(orders, dim=-1) 177 | mask = (expanded_indices < num_masked_tokens[:, None]).float() 178 | mask = torch.scatter(torch.zeros_like(mask), dim=-1, index=sorted_orders, src=mask) 179 | 180 | return mask 181 | 182 | def predict(self, x, mask, cond_list): 183 | x = self.patch_emb(x) 184 | 185 | # prepend conditions from prev generator 186 | for i in range(self.num_conds): 187 | x = torch.cat([self.cond_emb(cond_list[i]).unsqueeze(1), x], dim=1) 188 | 189 | # prepend guiding pixel 190 | if self.guiding_pixel: 191 | x = torch.cat([self.pix_proj(cond_list[-1]).unsqueeze(1), x], dim=1) 192 | 193 | # masking 194 | mask_with_cond = torch.cat([torch.zeros(x.size(0), self.num_conds+self.guiding_pixel, device=x.device), mask], dim=1).bool() 195 | x = torch.where(mask_with_cond.unsqueeze(-1), self.mask_token.to(x.dtype), x) 196 | 197 | # position embedding 198 | x = x + self.pos_embed_learned 199 | x = 
self.patch_emb_ln(x) 200 | 201 | # apply Transformer blocks 202 | if self.grad_checkpointing and not torch.jit.is_scripting() and self.training: 203 | for block in self.blocks: 204 | x = checkpoint(block, x) 205 | else: 206 | for block in self.blocks: 207 | x = block(x) 208 | x = self.norm(x) 209 | 210 | # return 5 conditions: middle, top, right, bottom, left 211 | middle_cond = x[:, self.num_conds+self.guiding_pixel:] 212 | bsz, seq_len, c = middle_cond.size() 213 | h = int(np.sqrt(seq_len)) 214 | w = int(np.sqrt(seq_len)) 215 | top_cond = middle_cond.reshape(bsz, h, w, c) 216 | top_cond = torch.cat([torch.zeros(bsz, 1, w, c, device=top_cond.device), top_cond[:, :-1]], dim=1) 217 | top_cond = top_cond.reshape(bsz, seq_len, c) 218 | 219 | right_cond = middle_cond.reshape(bsz, h, w, c) 220 | right_cond = torch.cat([right_cond[:, :, 1:], torch.zeros(bsz, h, 1, c, device=right_cond.device)], dim=2) 221 | right_cond = right_cond.reshape(bsz, seq_len, c) 222 | 223 | bottom_cond = middle_cond.reshape(bsz, h, w, c) 224 | bottom_cond = torch.cat([bottom_cond[:, 1:], torch.zeros(bsz, 1, w, c, device=bottom_cond.device)], dim=1) 225 | bottom_cond = bottom_cond.reshape(bsz, seq_len, c) 226 | 227 | left_cond = middle_cond.reshape(bsz, h, w, c) 228 | left_cond = torch.cat([torch.zeros(bsz, h, 1, c, device=left_cond.device), left_cond[:, :, :-1]], dim=2) 229 | left_cond = left_cond.reshape(bsz, seq_len, c) 230 | 231 | return [middle_cond, top_cond, right_cond, bottom_cond, left_cond] 232 | 233 | def forward(self, imgs, cond_list): 234 | """ training """ 235 | # patchify to get gt 236 | patches = self.patchify(imgs) 237 | 238 | # mask tokens 239 | orders = self.sample_orders(bsz=patches.size(0)) 240 | if self.training: 241 | mask = self.random_masking(patches, orders) 242 | else: 243 | # uniform random masking for NLL computation 244 | mask = self.random_masking_uniform(patches, orders) 245 | 246 | # guiding pixel 247 | if self.guiding_pixel: 248 | guiding_pixels = 
imgs.mean(-1).mean(-1) 249 | guiding_pixel_loss = self.guiding_pixel_loss(guiding_pixels, cond_list) 250 | cond_list.append(guiding_pixels) 251 | else: 252 | guiding_pixel_loss = torch.Tensor([0]).cuda().mean() 253 | 254 | # get condition for next level 255 | cond_list_next = self.predict(patches, mask, cond_list) 256 | 257 | # only keep those conditions and patches on mask 258 | for cond_idx in range(len(cond_list_next)): 259 | cond_list_next[cond_idx] = cond_list_next[cond_idx].reshape(cond_list_next[cond_idx].size(0) * cond_list_next[cond_idx].size(1), -1) 260 | cond_list_next[cond_idx] = cond_list_next[cond_idx][mask.reshape(-1).bool()] 261 | 262 | patches = patches.reshape(patches.size(0) * patches.size(1), -1) 263 | patches = patches[mask.reshape(-1).bool()] 264 | patches = patches.reshape(patches.size(0), 3, self.patch_size, self.patch_size) 265 | 266 | return patches, cond_list_next, guiding_pixel_loss 267 | 268 | def sample(self, cond_list, num_iter, cfg, cfg_schedule, temperature, filter_threshold, next_level_sample_function, 269 | visualize=False): 270 | """ generation """ 271 | if cfg == 1.0: 272 | bsz = cond_list[0].size(0) 273 | else: 274 | bsz = cond_list[0].size(0) // 2 275 | 276 | # sample the guiding pixel 277 | if self.guiding_pixel: 278 | sampled_pixels = self.guiding_pixel_loss.sample(cond_list, temperature, cfg, filter_threshold) 279 | if not cfg == 1.0: 280 | sampled_pixels = torch.cat([sampled_pixels, sampled_pixels], dim=0) 281 | cond_list.append(sampled_pixels) 282 | 283 | # init token mask 284 | mask = torch.ones(bsz, self.seq_len).cuda() 285 | patches = torch.zeros(bsz, self.seq_len, 3 * self.patch_size**2).cuda() 286 | orders = self.sample_orders(bsz) 287 | num_iter = min(self.seq_len, num_iter) 288 | 289 | # sample image 290 | for step in range(num_iter): 291 | cur_patches = patches.clone() 292 | 293 | if not cfg == 1.0: 294 | patches = torch.cat([patches, patches], dim=0) 295 | mask = torch.cat([mask, mask], dim=0) 296 | 297 | # get 
next level conditions 298 | cond_list_next = self.predict(patches, mask, cond_list) 299 | 300 | # mask ratio for the next round, following MAR. 301 | mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter) 302 | mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda() 303 | 304 | # masks out at least one for the next iteration 305 | mask_len = torch.maximum(torch.Tensor([1]).cuda(), 306 | torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len)) 307 | 308 | # get masking for next iteration and locations to be predicted in this iteration 309 | mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len) 310 | if step >= num_iter - 1: 311 | mask_to_pred = mask[:bsz].bool() 312 | else: 313 | mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool()) 314 | mask = mask_next 315 | if not cfg == 1.0: 316 | mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0) 317 | 318 | # sample token latents for this step 319 | for cond_idx in range(len(cond_list_next)): 320 | cond_list_next[cond_idx] = cond_list_next[cond_idx][mask_to_pred.nonzero(as_tuple=True)] 321 | 322 | # cfg schedule 323 | if cfg_schedule == "linear": 324 | cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len 325 | else: 326 | cfg_iter = cfg 327 | sampled_patches = next_level_sample_function(cond_list=cond_list_next, cfg=cfg_iter, 328 | temperature=temperature, filter_threshold=filter_threshold) 329 | sampled_patches = sampled_patches.reshape(sampled_patches.size(0), -1) 330 | 331 | if not cfg == 1.0: 332 | mask_to_pred, _ = mask_to_pred.chunk(2, dim=0) 333 | 334 | cur_patches[mask_to_pred.nonzero(as_tuple=True)] = sampled_patches.to(cur_patches.dtype) 335 | patches = cur_patches.clone() 336 | 337 | # visualize generation process for colab 338 | if visualize: 339 | visualize_patch(self.unpatchify(patches)) 340 | 341 | patches = self.unpatchify(patches) 342 | return patches 343 | 
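The `sample()` loop in models/mar.py fills in masked patches on a cosine schedule: after step `s` of `num_iter`, `floor(seq_len * cos(pi/2 * (s + 1) / num_iter))` patches should stay masked. That count is clamped with `max(1, min(remaining - 1, ...))`, so when possible each step predicts at least one patch and at least one patch stays masked; the final step predicts everything that is left. The sketch below is a plain-Python restatement of only that bookkeeping (no model, no CFG). `tokens_per_step` is a hypothetical helper, not part of the repository, and it assumes the per-step count is shared by the whole batch, as it is in the code (`mask_len[0]`).

```python
import math
from typing import List


def tokens_per_step(seq_len: int, num_iter: int) -> List[int]:
    """Count how many patches are predicted at each MAR-style sampling step.

    Hypothetical helper: replays only the mask bookkeeping of sample()
    (cosine schedule + clamping), not the model itself.
    """
    num_iter = min(seq_len, num_iter)
    remaining = seq_len  # patches still masked; everything is masked at the start
    counts = []
    for step in range(num_iter):
        # number of patches that should remain masked after this step
        mask_ratio = math.cos(math.pi / 2.0 * (step + 1) / num_iter)
        mask_len = math.floor(seq_len * mask_ratio)
        # keep at least one patch masked, predict at least one when possible
        mask_len = max(1, min(remaining - 1, mask_len))
        if step >= num_iter - 1:
            counts.append(remaining)  # last step predicts everything left
            remaining = 0
        else:
            counts.append(remaining - mask_len)
            remaining = mask_len
    return counts


if __name__ == "__main__":
    print(tokens_per_step(4, 4))  # [1, 1, 1, 1]
```

Because the cosine curve is flat near zero and steep near `pi/2`, early steps predict very few patches and later steps predict more; for `seq_len=256, num_iter=64` the first step predicts 1 patch and the last predicts 6.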
-------------------------------------------------------------------------------- /models/pixelloss.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | 7 | from timm.models.vision_transformer import DropPath, Mlp 8 | 9 | 10 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: 11 | L, S = query.size(-2), key.size(-2) 12 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 13 | attn_bias = torch.zeros(L, S, dtype=query.dtype).cuda() 14 | if is_causal: 15 | assert attn_mask is None 16 | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda() 17 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 18 | attn_bias.to(query.dtype) 19 | 20 | if attn_mask is not None: 21 | if attn_mask.dtype == torch.bool: 22 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) 23 | else: 24 | attn_bias += attn_mask 25 | with torch.cuda.amp.autocast(enabled=False): 26 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 27 | attn_weight += attn_bias 28 | attn_weight = torch.softmax(attn_weight, dim=-1) 29 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) 30 | return attn_weight @ value 31 | 32 | 33 | class CausalAttention(nn.Module): 34 | def __init__( 35 | self, 36 | dim: int, 37 | num_heads: int = 8, 38 | qkv_bias: bool = False, 39 | qk_norm: bool = False, 40 | attn_drop: float = 0.0, 41 | proj_drop: float = 0.0, 42 | norm_layer: nn.Module = nn.LayerNorm 43 | ) -> None: 44 | super().__init__() 45 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 46 | self.num_heads = num_heads 47 | self.head_dim = dim // num_heads 48 | self.scale = self.head_dim ** -0.5 49 | 50 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 51 | self.q_norm = norm_layer(self.head_dim) if qk_norm else 
nn.Identity() 52 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 53 | self.attn_drop = nn.Dropout(attn_drop) 54 | self.proj = nn.Linear(dim, dim) 55 | self.proj_drop = nn.Dropout(proj_drop) 56 | 57 | def forward(self, x: torch.Tensor) -> torch.Tensor: 58 | B, N, C = x.shape 59 | qkv = ( 60 | self.qkv(x) 61 | .reshape(B, N, 3, self.num_heads, self.head_dim) 62 | .permute(2, 0, 3, 1, 4) 63 | ) 64 | q, k, v = qkv.unbind(0) 65 | q, k = self.q_norm(q), self.k_norm(k) 66 | 67 | x = scaled_dot_product_attention( 68 | q, 69 | k, 70 | v, 71 | dropout_p=self.attn_drop.p if self.training else 0.0, 72 | is_causal=True 73 | ) 74 | 75 | x = x.transpose(1, 2).reshape(B, N, C) 76 | x = self.proj(x) 77 | x = self.proj_drop(x) 78 | return x 79 | 80 | 81 | class CausalBlock(nn.Module): 82 | 83 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, proj_drop=0., attn_drop=0., 84 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 85 | super().__init__() 86 | self.norm1 = norm_layer(dim) 87 | self.attn = CausalAttention( 88 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) 89 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 90 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 91 | self.norm2 = norm_layer(dim) 92 | mlp_hidden_dim = int(dim * mlp_ratio) 93 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop) 94 | 95 | def forward(self, x): 96 | x = x + self.drop_path(self.attn(self.norm1(x))) 97 | x = x + self.drop_path(self.mlp(self.norm2(x))) 98 | return x 99 | 100 | 101 | class MlmLayer(nn.Module): 102 | 103 | def __init__(self, vocab_size): 104 | super().__init__() 105 | self.bias = nn.Parameter(torch.zeros(1, vocab_size)) 106 | 107 | def forward(self, x, word_embeddings): 108 | word_embeddings = word_embeddings.transpose(0, 1) 109 | logits = torch.matmul(x, word_embeddings) 110 | logits = logits + self.bias 111 | return logits 112 | 113 | 114 | class PixelLoss(nn.Module): 115 | def __init__(self, c_channels, width, depth, num_heads, r_weight=1.0): 116 | super().__init__() 117 | 118 | self.pix_mean = torch.Tensor([0.485, 0.456, 0.406]) 119 | self.pix_std = torch.Tensor([0.229, 0.224, 0.225]) 120 | 121 | self.cond_proj = nn.Linear(c_channels, width) 122 | self.r_codebook = nn.Embedding(256, width) 123 | self.g_codebook = nn.Embedding(256, width) 124 | self.b_codebook = nn.Embedding(256, width) 125 | 126 | self.ln = nn.LayerNorm(width, eps=1e-6) 127 | self.blocks = nn.ModuleList([ 128 | CausalBlock(width, num_heads=num_heads, mlp_ratio=4.0, 129 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), 130 | proj_drop=0, attn_drop=0) 131 | for _ in range(depth) 132 | ]) 133 | self.norm = nn.LayerNorm(width, eps=1e-6) 134 | 135 | self.r_weight = r_weight 136 | self.r_mlm = MlmLayer(256) 137 | self.g_mlm = MlmLayer(256) 138 | self.b_mlm = MlmLayer(256) 139 | 140 | self.criterion = torch.nn.CrossEntropyLoss(reduction="none") 141 | 142 | self.initialize_weights() 143 | 144 | def initialize_weights(self): 145 | # parameters 146 | torch.nn.init.normal_(self.r_codebook.weight, std=.02) 147 | torch.nn.init.normal_(self.g_codebook.weight, std=.02) 148 | 
torch.nn.init.normal_(self.b_codebook.weight, std=.02) 149 | 150 | # initialize nn.Linear and nn.LayerNorm 151 | self.apply(self._init_weights) 152 | 153 | def _init_weights(self, m): 154 | if isinstance(m, nn.Linear): 155 | # we use xavier_uniform following official JAX ViT: 156 | torch.nn.init.xavier_uniform_(m.weight) 157 | if isinstance(m, nn.Linear) and m.bias is not None: 158 | nn.init.constant_(m.bias, 0) 159 | elif isinstance(m, nn.LayerNorm): 160 | if m.bias is not None: 161 | nn.init.constant_(m.bias, 0) 162 | if m.weight is not None: 163 | nn.init.constant_(m.weight, 1.0) 164 | 165 | def predict(self, target, cond_list): 166 | target = target.reshape(target.size(0), target.size(1)) 167 | # back to [0, 255] (un-normalize to [0, 1] here, scale by 255 below) 168 | mean = self.pix_mean.cuda().unsqueeze(0) 169 | std = self.pix_std.cuda().unsqueeze(0) 170 | target = target * std + mean 171 | # add very small noise to avoid pixel distribution inconsistency caused by banker's rounding 172 | target = (target * 255 + 1e-2 * torch.randn_like(target)).round().long() 173 | 174 | # take only the middle condition 175 | cond = cond_list[0] 176 | x = torch.cat( 177 | [self.cond_proj(cond).unsqueeze(1), self.r_codebook(target[:, 0:1]), self.g_codebook(target[:, 1:2]), 178 | self.b_codebook(target[:, 2:3])], dim=1) 179 | x = self.ln(x) 180 | 181 | for block in self.blocks: 182 | x = block(x) 183 | 184 | x = self.norm(x) 185 | with torch.cuda.amp.autocast(enabled=False): 186 | r_logits = self.r_mlm(x[:, 0], self.r_codebook.weight) 187 | g_logits = self.g_mlm(x[:, 1], self.g_codebook.weight) 188 | b_logits = self.b_mlm(x[:, 2], self.b_codebook.weight) 189 | 190 | logits = torch.cat([r_logits.unsqueeze(1), g_logits.unsqueeze(1), b_logits.unsqueeze(1)], dim=1) 191 | return logits, target 192 | 193 | def forward(self, target, cond_list): 194 | """ training """ 195 | logits, target = self.predict(target, cond_list) 196 | loss_r = self.criterion(logits[:, 0], target[:, 0]) 197 | loss_g = 
self.criterion(logits[:, 1], target[:,
1]) 198 | loss_b = self.criterion(logits[:, 2], target[:, 2]) 199 | 200 | if self.training: 201 | loss = (self.r_weight * loss_r + loss_g + loss_b) / (self.r_weight + 2) 202 | else: 203 | # for NLL computation 204 | loss = (loss_r + loss_g + loss_b) / 3 205 | 206 | return loss.mean() 207 | 208 | def sample(self, cond_list, temperature, cfg, filter_threshold=0): 209 | """ generation """ 210 | if cfg == 1.0: 211 | bsz = cond_list[0].size(0) 212 | else: 213 | bsz = cond_list[0].size(0) // 2 214 | pixel_values = torch.zeros(bsz, 3).cuda() 215 | 216 | for i in range(3): 217 | if cfg == 1.0: 218 | logits, _ = self.predict(pixel_values, cond_list) 219 | else: 220 | logits, _ = self.predict(torch.cat([pixel_values, pixel_values], dim=0), cond_list) 221 | logits = logits[:, i] 222 | logits = logits * temperature 223 | 224 | if not cfg == 1.0: 225 | cond_logits = logits[:bsz] 226 | uncond_logits = logits[bsz:] 227 | 228 | # suppress very unlikely conditional tokens: raise their unconditional logits so CFG cannot amplify them 229 | cond_probs = torch.softmax(cond_logits, dim=-1) 230 | mask = cond_probs < filter_threshold 231 | uncond_logits[mask] = torch.max( 232 | uncond_logits, 233 | cond_logits - torch.max(cond_logits, dim=-1, keepdim=True)[0] + torch.max(uncond_logits, dim=-1, keepdim=True)[0] 234 | )[mask] 235 | 236 | logits = uncond_logits + cfg * (cond_logits - uncond_logits) 237 | 238 | # get token prediction 239 | probs = torch.softmax(logits, dim=-1) 240 | sampled_ids = torch.multinomial(probs, num_samples=1).reshape(-1) 241 | pixel_values[:, i] = (sampled_ids.float() / 255 - self.pix_mean[i]) / self.pix_std[i] 242 | 243 | # returned values are in the normalized (ImageNet mean/std) space, the same space as the training targets 244 | return pixel_values 245 | -------------------------------------------------------------------------------- /util/crop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL 
import Image 3 | 4 | 5 | def center_crop_arr(pil_image, image_size): 6 | """ 7 | Center cropping implementation from ADM.
8 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 9 | """ 10 | while min(*pil_image.size) >= 2 * image_size: 11 | pil_image = pil_image.resize( 12 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 13 | ) 14 | 15 | scale = image_size / min(*pil_image.size) 16 | pil_image = pil_image.resize( 17 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 18 | ) 19 | 20 | arr = np.array(pil_image) 21 | crop_y = (arr.shape[0] - image_size) // 2 22 | crop_x = (arr.shape[1] - image_size) // 2 23 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 24 | -------------------------------------------------------------------------------- /util/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import requests 4 | 5 | 6 | def download_pretrained_fractalar_in64(overwrite=False): 7 | download_path = "pretrained_models/fractalar_in64/checkpoint-last.pth" 8 | if not os.path.exists(download_path) or overwrite: 9 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 10 | os.makedirs("pretrained_models/fractalar_in64", exist_ok=True) 11 | r = requests.get("https://www.dropbox.com/scl/fi/n25tbij7aqkwo1ypqhz72/checkpoint-last.pth?rlkey=2czevgex3ocg2ae8zde3xpb3f&st=mj0subup&dl=0", stream=True, headers=headers) 12 | print("Downloading FractalAR on ImageNet 64x64...") 13 | with open(download_path, 'wb') as f: 14 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1688): 15 | if chunk: 16 | f.write(chunk) 17 | 18 | 19 | def download_pretrained_fractalmar_in64(overwrite=False): 20 | download_path = "pretrained_models/fractalmar_in64/checkpoint-last.pth" 21 | if not os.path.exists(download_path) or overwrite: 22 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 23 | os.makedirs("pretrained_models/fractalmar_in64", exist_ok=True) 24 | r = 
requests.get("https://www.dropbox.com/scl/fi/lh7fmv48pusujd6m4kcdn/checkpoint-last.pth?rlkey=huihey61ok32h28o3tbbq6ek9&st=fxtoawba&dl=0", stream=True, headers=headers) 25 | print("Downloading FractalMAR on ImageNet 64x64...") 26 | with open(download_path, 'wb') as f: 27 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1650): 28 | if chunk: 29 | f.write(chunk) 30 | 31 | 32 | def download_pretrained_fractalmar_base_in256(overwrite=False): 33 | download_path = "pretrained_models/fractalmar_base_in256/checkpoint-last.pth" 34 | if not os.path.exists(download_path) or overwrite: 35 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 36 | os.makedirs("pretrained_models/fractalmar_base_in256", exist_ok=True) 37 | r = requests.get("https://www.dropbox.com/scl/fi/zrdm7853ih4tcv98wmzhe/checkpoint-last.pth?rlkey=htq9yuzovet7d6ioa64s1xxd0&st=4c4d93vs&dl=0", stream=True, headers=headers) 38 | print("Downloading FractalMAR-Base on ImageNet 256x256...") 39 | with open(download_path, 'wb') as f: 40 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=712): 41 | if chunk: 42 | f.write(chunk) 43 | 44 | 45 | def download_pretrained_fractalmar_large_in256(overwrite=False): 46 | download_path = "pretrained_models/fractalmar_large_in256/checkpoint-last.pth" 47 | if not os.path.exists(download_path) or overwrite: 48 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 49 | os.makedirs("pretrained_models/fractalmar_large_in256", exist_ok=True) 50 | r = requests.get("https://www.dropbox.com/scl/fi/y1k05xx7ry8521ckxkqgt/checkpoint-last.pth?rlkey=wolq4krdq7z7eyjnaw5ndhq6k&st=vjeu5uzo&dl=0", stream=True, headers=headers) 51 | print("Downloading FractalMAR-Large on ImageNet 256x256...") 52 | with open(download_path, 'wb') as f: 53 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1669): 54 | if chunk: 55 | f.write(chunk) 56 | 57 | 58 | def download_pretrained_fractalmar_huge_in256(overwrite=False): 59 | download_path = 
"pretrained_models/fractalmar_huge_in256/checkpoint-last.pth" 60 | if not os.path.exists(download_path) or overwrite: 61 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'} 62 | os.makedirs("pretrained_models/fractalmar_huge_in256", exist_ok=True) 63 | r = requests.get("https://www.dropbox.com/scl/fi/t2rru8xr6wm23yvxskpww/checkpoint-last.pth?rlkey=dn9ss9zw4zsnckf6bat9hss6h&st=y7w921zo&dl=0", stream=True, headers=headers) 64 | print("Downloading FractalMAR-Huge on ImageNet 256x256...") 65 | with open(download_path, 'wb') as f: 66 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=3243): 67 | if chunk: 68 | f.write(chunk) 69 | 70 | 71 | if __name__ == "__main__": 72 | download_pretrained_fractalar_in64() 73 | download_pretrained_fractalmar_in64() 74 | download_pretrained_fractalmar_base_in256() 75 | download_pretrained_fractalmar_large_in256() 76 | download_pretrained_fractalmar_huge_in256() 77 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, epoch, args): 5 | """Decay the learning rate with half-cycle cosine after warmup""" 6 | if epoch < args.warmup_epochs: 7 | lr = args.lr * epoch / args.warmup_epochs 8 | else: 9 | if args.lr_schedule == "constant": 10 | lr = args.lr 11 | elif args.lr_schedule == "cosine": 12 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 13 | (1. 
+ math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 14 | else: 15 | raise NotImplementedError 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | import datetime 3 | import os 4 | import time 5 | from collections import defaultdict, deque 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.distributed as dist 10 | TORCH_MAJOR = int(torch.__version__.split('.')[0]) 11 | TORCH_MINOR = int(torch.__version__.split('.')[1]) 12 | 13 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8: 14 | from torch._six import inf 15 | else: 16 | from torch import inf 17 | 18 | 19 | class SmoothedValue(object): 20 | """Track a series of values and provide access to smoothed values over a 21 | window or the global series average. 22 | """ 23 | 24 | def __init__(self, window_size=20, fmt=None): 25 | if fmt is None: 26 | fmt = "{median:.4f} ({global_avg:.4f})" 27 | self.deque = deque(maxlen=window_size) 28 | self.total = 0.0 29 | self.count = 0 30 | self.fmt = fmt 31 | 32 | def update(self, value, n=1): 33 | self.deque.append(value) 34 | self.count += n 35 | self.total += value * n 36 | 37 | def synchronize_between_processes(self): 38 | """ 39 | Warning: does not synchronize the deque! 
40 | """ 41 | if not is_dist_avail_and_initialized(): 42 | return 43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 44 | dist.barrier() 45 | dist.all_reduce(t) 46 | t = t.tolist() 47 | self.count = int(t[0]) 48 | self.total = t[1] 49 | 50 | @property 51 | def median(self): 52 | d = torch.tensor(list(self.deque)) 53 | return d.median().item() 54 | 55 | @property 56 | def avg(self): 57 | d = torch.tensor(list(self.deque), dtype=torch.float32) 58 | return d.mean().item() 59 | 60 | @property 61 | def global_avg(self): 62 | return self.total / self.count 63 | 64 | @property 65 | def max(self): 66 | return max(self.deque) 67 | 68 | @property 69 | def value(self): 70 | return self.deque[-1] 71 | 72 | def __str__(self): 73 | return self.fmt.format( 74 | median=self.median, 75 | avg=self.avg, 76 | global_avg=self.global_avg, 77 | max=self.max, 78 | value=self.value) 79 | 80 | 81 | class MetricLogger(object): 82 | def __init__(self, delimiter="\t"): 83 | self.meters = defaultdict(SmoothedValue) 84 | self.delimiter = delimiter 85 | 86 | def update(self, **kwargs): 87 | for k, v in kwargs.items(): 88 | if v is None: 89 | continue 90 | if isinstance(v, torch.Tensor): 91 | v = v.item() 92 | assert isinstance(v, (float, int)) 93 | self.meters[k].update(v) 94 | 95 | def __getattr__(self, attr): 96 | if attr in self.meters: 97 | return self.meters[attr] 98 | if attr in self.__dict__: 99 | return self.__dict__[attr] 100 | raise AttributeError("'{}' object has no attribute '{}'".format( 101 | type(self).__name__, attr)) 102 | 103 | def __str__(self): 104 | loss_str = [] 105 | for name, meter in self.meters.items(): 106 | loss_str.append( 107 | "{}: {}".format(name, str(meter)) 108 | ) 109 | return self.delimiter.join(loss_str) 110 | 111 | def synchronize_between_processes(self): 112 | for meter in self.meters.values(): 113 | meter.synchronize_between_processes() 114 | 115 | def add_meter(self, name, meter): 116 | self.meters[name] = meter 117 | 118 | 
def log_every(self, iterable, print_freq, header=None): 119 | i = 0 120 | if not header: 121 | header = '' 122 | start_time = time.time() 123 | end = time.time() 124 | iter_time = SmoothedValue(fmt='{avg:.4f}') 125 | data_time = SmoothedValue(fmt='{avg:.4f}') 126 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 127 | log_msg = [ 128 | header, 129 | '[{0' + space_fmt + '}/{1}]', 130 | 'eta: {eta}', 131 | '{meters}', 132 | 'time: {time}', 133 | 'data: {data}' 134 | ] 135 | if torch.cuda.is_available(): 136 | log_msg.append('max mem: {memory:.0f}') 137 | log_msg = self.delimiter.join(log_msg) 138 | MB = 1024.0 * 1024.0 139 | for obj in iterable: 140 | data_time.update(time.time() - end) 141 | yield obj 142 | iter_time.update(time.time() - end) 143 | if i % print_freq == 0 or i == len(iterable) - 1: 144 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 145 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 146 | if torch.cuda.is_available(): 147 | print(log_msg.format( 148 | i, len(iterable), eta=eta_string, 149 | meters=str(self), 150 | time=str(iter_time), data=str(data_time), 151 | memory=torch.cuda.max_memory_allocated() / MB)) 152 | else: 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time))) 157 | i += 1 158 | end = time.time() 159 | total_time = time.time() - start_time 160 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 161 | print('{} Total time: {} ({:.4f} s / it)'.format( 162 | header, total_time_str, total_time / len(iterable))) 163 | 164 | 165 | def setup_for_distributed(is_master): 166 | """ 167 | This function disables printing when not in master process 168 | """ 169 | builtin_print = builtins.print 170 | 171 | def print(*args, **kwargs): 172 | force = kwargs.pop('force', False) 173 | force = force or (get_world_size() > 8) 174 | if is_master or force: 175 | now = datetime.datetime.now().time() 176 | builtin_print('[{}] 
'.format(now), end='') # print with time stamp 177 | builtin_print(*args, **kwargs) 178 | 179 | builtins.print = print 180 | 181 | 182 | def is_dist_avail_and_initialized(): 183 | if not dist.is_available(): 184 | return False 185 | if not dist.is_initialized(): 186 | return False 187 | return True 188 | 189 | 190 | def get_world_size(): 191 | if not is_dist_avail_and_initialized(): 192 | return 1 193 | return dist.get_world_size() 194 | 195 | 196 | def get_rank(): 197 | if not is_dist_avail_and_initialized(): 198 | return 0 199 | return dist.get_rank() 200 | 201 | 202 | def is_main_process(): 203 | return get_rank() == 0 204 | 205 | 206 | def save_on_master(*args, **kwargs): 207 | if is_main_process(): 208 | torch.save(*args, **kwargs) 209 | 210 | 211 | def init_distributed_mode(args): 212 | if args.dist_on_itp: 213 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 214 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 215 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 216 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 217 | os.environ['LOCAL_RANK'] = str(args.gpu) 218 | os.environ['RANK'] = str(args.rank) 219 | os.environ['WORLD_SIZE'] = str(args.world_size) 220 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 221 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 222 | args.rank = int(os.environ["RANK"]) 223 | args.world_size = int(os.environ['WORLD_SIZE']) 224 | args.gpu = int(os.environ['LOCAL_RANK']) 225 | elif 'SLURM_PROCID' in os.environ: 226 | args.rank = int(os.environ['SLURM_PROCID']) 227 | args.gpu = args.rank % torch.cuda.device_count() 228 | else: 229 | print('Not using distributed mode') 230 | setup_for_distributed(is_master=True) # hack 231 | args.distributed = False 232 | return 233 | 234 | args.distributed = True 235 | 236 | torch.cuda.set_device(args.gpu) 237 | args.dist_backend = 'nccl' 238 | print('| distributed init (rank {}): {}, gpu {}'.format( 
239 | args.rank, args.dist_url, args.gpu), flush=True) 240 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 241 | world_size=args.world_size, rank=args.rank) 242 | torch.distributed.barrier() 243 | setup_for_distributed(args.rank == 0) 244 | 245 | 246 | class NativeScalerWithGradNormCount: 247 | state_dict_key = "amp_scaler" 248 | 249 | def __init__(self): 250 | self._scaler = torch.cuda.amp.GradScaler() 251 | 252 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 253 | self._scaler.scale(loss).backward(create_graph=create_graph) 254 | if update_grad: 255 | if clip_grad is not None: 256 | assert parameters is not None 257 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 258 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 259 | else: 260 | self._scaler.unscale_(optimizer) 261 | norm = get_grad_norm_(parameters) 262 | self._scaler.step(optimizer) 263 | self._scaler.update() 264 | else: 265 | norm = None 266 | return norm 267 | 268 | def state_dict(self): 269 | return self._scaler.state_dict() 270 | 271 | def load_state_dict(self, state_dict): 272 | self._scaler.load_state_dict(state_dict) 273 | 274 | 275 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 276 | if isinstance(parameters, torch.Tensor): 277 | parameters = [parameters] 278 | parameters = [p for p in parameters if p.grad is not None] 279 | norm_type = float(norm_type) 280 | if len(parameters) == 0: 281 | return torch.tensor(0.) 
282 | device = parameters[0].grad.device 283 | if norm_type == inf: 284 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 285 | else: 286 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 287 | return total_norm 288 | 289 | 290 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()): 291 | decay = [] 292 | no_decay = [] 293 | for name, param in model.named_parameters(): 294 | if not param.requires_grad: 295 | continue # frozen weights 296 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name: 297 | no_decay.append(param) # no weight decay on bias, norm and diffloss 298 | else: 299 | decay.append(param) 300 | return [ 301 | {'params': no_decay, 'weight_decay': 0.}, 302 | {'params': decay, 'weight_decay': weight_decay}] 303 | 304 | 305 | def save_model(args, epoch, model_without_ddp, optimizer, loss_scaler, epoch_name=None): 306 | if epoch_name is None: 307 | epoch_name = str(epoch) 308 | output_dir = Path(args.output_dir) 309 | checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name) 310 | 311 | to_save = { 312 | 'model': model_without_ddp.state_dict(), 313 | 'optimizer': optimizer.state_dict(), 314 | 'epoch': epoch, 315 | 'scaler': loss_scaler.state_dict(), 316 | 'args': args, 317 | } 318 | save_on_master(to_save, checkpoint_path) 319 | 320 | 321 | def all_reduce_mean(x): 322 | world_size = get_world_size() 323 | if world_size > 1: 324 | x_reduce = torch.tensor(x).cuda() 325 | dist.all_reduce(x_reduce) 326 | x_reduce /= world_size 327 | return x_reduce.item() 328 | else: 329 | return x -------------------------------------------------------------------------------- /util/visualize.py: -------------------------------------------------------------------------------- 1 | from torchvision.utils import save_image 2 | from PIL import Image 3 | import torch 4 | 5 | 6 | def visualize_patch(viz_patches): 7 | from 
IPython.display import display, clear_output 8 | pix_mean = torch.Tensor([0.485, 0.456, 0.406]).cuda().view(1, -1, 1, 1) 9 | pix_std = torch.Tensor([0.229, 0.224, 0.225]).cuda().view(1, -1, 1, 1) 10 | viz_patches = viz_patches * pix_std + pix_mean 11 | img_size = viz_patches.size(2) 12 | if img_size < 256: 13 | viz_patches = torch.nn.functional.interpolate(viz_patches, scale_factor=256 // img_size, mode="nearest") 14 | save_image(viz_patches, "samples.png", nrow=4, normalize=True, value_range=(0, 1)) 15 | sampled_patches_viz = Image.open("samples.png") 16 | clear_output(wait=True) 17 | display(sampled_patches_viz) 18 | --------------------------------------------------------------------------------
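In `PixelLoss.sample()` (models/pixelloss.py), classifier-free guidance is applied per color channel as `uncond + cfg * (cond - uncond)`. Before that, wherever the conditional probability falls below `filter_threshold`, the unconditional logit is raised to at least the conditional logit re-aligned to the unconditional maximum. Since the guided logit equals `cfg * cond - (cfg - 1) * uncond`, a very low unconditional logit would otherwise let CFG (with `cfg > 1`) boost a value that the conditional model also considers unlikely. The following is a minimal plain-Python sketch of that step for a single row of logits; `softmax` and `guided_logits` are illustrative names, not functions from the repository, and temperature scaling is omitted.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # numerically stable softmax over one row of logits
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def guided_logits(cond: List[float], uncond: List[float], cfg: float,
                  filter_threshold: float) -> List[float]:
    """CFG on one row of logits with the low-probability filter (sketch)."""
    cond_probs = softmax(cond)
    cond_max, uncond_max = max(cond), max(uncond)
    filtered = []
    for c, u, p in zip(cond, uncond, cond_probs):
        if p < filter_threshold:
            # raise the unconditional logit to the conditional one,
            # shifted so both rows share the same maximum
            u = max(u, c - cond_max + uncond_max)
        filtered.append(u)
    # standard classifier-free guidance on the (possibly filtered) logits
    return [u + cfg * (c - u) for c, u in zip(cond, filtered)]


if __name__ == "__main__":
    print(guided_logits([0.0, -10.0], [0.0, -100.0], cfg=2.0, filter_threshold=1e-3))  # [0.0, -10.0]
    print(guided_logits([0.0, -10.0], [0.0, -100.0], cfg=2.0, filter_threshold=0.0))   # [0.0, 80.0]
```

With the filter, the unlikely second value keeps a guided logit of -10; without it (`filter_threshold=0`), its guided logit jumps to 80 and it would dominate sampling.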