├── LICENSE
├── README.md
├── demo
│   ├── run_fractalgen.ipynb
│   └── visual.gif
├── engine_fractalgen.py
├── environment.yaml
├── fid_stats
│   ├── adm_in256_stats.npz
│   └── adm_in64_stats.npz
├── main_fractalgen.py
├── models
│   ├── ar.py
│   ├── fractalgen.py
│   ├── mar.py
│   └── pixelloss.py
└── util
    ├── crop.py
    ├── download.py
    ├── lr_sched.py
    ├── misc.py
    └── visualize.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Tianhong Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fractal Generative Models
2 |
3 | [arXiv](https://arxiv.org/abs/2502.17437)
4 | [Colab](http://colab.research.google.com/github/LTH14/fractalgen/blob/main/demo/run_fractalgen.ipynb)
5 |
6 |
7 |
8 |
9 |
10 | This is a PyTorch/GPU implementation of the paper [Fractal Generative Models](https://arxiv.org/abs/2502.17437):
11 |
12 | ```
13 | @article{li2025fractal,
14 | title={Fractal Generative Models},
15 | author={Li, Tianhong and Sun, Qinyi and Fan, Lijie and He, Kaiming},
16 | journal={arXiv preprint arXiv:2502.17437},
17 | year={2025}
18 | }
19 | ```
20 |
21 | FractalGen enables pixel-by-pixel high-resolution image generation for the first time. This repo contains:
22 |
23 | * 🪐 A simple PyTorch implementation of [Fractal Generative Model](models/fractalgen.py).
24 | * ⚡️ Pre-trained pixel-by-pixel generation models trained on ImageNet 64x64 and 256x256.
25 | * 💥 A self-contained [Colab notebook](http://colab.research.google.com/github/LTH14/fractalgen/blob/main/demo/run_fractalgen.ipynb) for running the pre-trained models.
26 | * 🛸 A [training and evaluation script](main_fractalgen.py) using PyTorch DDP.
27 |
28 | ## Preparation
29 |
30 | ### Dataset
31 | Download the [ImageNet](http://image-net.org/download) dataset and place it in your `IMAGENET_PATH`. As shown below, the loader expects the standard split subfolders.
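
`main_fractalgen.py` reads the data with `torchvision.datasets.ImageFolder`, so `IMAGENET_PATH` should contain:

```
IMAGENET_PATH/
├── train/
└── val/
```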
32 |
33 | ### Installation
34 |
35 | Download the code:
36 | ```
37 | git clone https://github.com/LTH14/fractalgen.git
38 | cd fractalgen
39 | ```
40 |
41 | A suitable [conda](https://conda.io/) environment named `fractalgen` can be created and activated with:
42 |
43 | ```
44 | conda env create -f environment.yaml
45 | conda activate fractalgen
46 | ```
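
Optionally, verify the environment from a Python shell (a quick sanity check):

```
import torch
print(torch.__version__, torch.cuda.is_available())  # expect a CUDA-enabled build
```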
47 |
48 | Download pre-trained models:
49 |
50 | ```
51 | python util/download.py
52 | ```
53 |
54 | For convenience, our pre-trained models can be downloaded directly here as well:
55 |
56 | | Model | FID-50K | Inception Score | #params |
57 | |-------------------------------------------------------------------------------------------------------------------------------------------------------|----------|-----------------|-----------|
58 | | [FractalAR (IN64)](https://www.dropbox.com/scl/fi/n25tbij7aqkwo1ypqhz72/checkpoint-last.pth?rlkey=2czevgex3ocg2ae8zde3xpb3f&st=mj0subup&dl=0) | 5.30 | 56.8 | 432M |
59 | | [FractalMAR (IN64)](https://www.dropbox.com/scl/fi/lh7fmv48pusujd6m4kcdn/checkpoint-last.pth?rlkey=huihey61ok32h28o3tbbq6ek9&st=fxtoawba&dl=0) | 2.72 | 87.9 | 432M |
60 | | [FractalMAR-Base (IN256)](https://www.dropbox.com/scl/fi/zrdm7853ih4tcv98wmzhe/checkpoint-last.pth?rlkey=htq9yuzovet7d6ioa64s1xxd0&st=4c4d93vs&dl=0) | 11.80 | 274.3 | 186M |
61 | | [FractalMAR-Large (IN256)](https://www.dropbox.com/scl/fi/y1k05xx7ry8521ckxkqgt/checkpoint-last.pth?rlkey=wolq4krdq7z7eyjnaw5ndhq6k&st=vjeu5uzo&dl=0) | 7.30 | 334.9 | 438M |
62 | | [FractalMAR-Huge (IN256)](https://www.dropbox.com/scl/fi/t2rru8xr6wm23yvxskpww/checkpoint-last.pth?rlkey=dn9ss9zw4zsnckf6bat9hss6h&st=y7w921zo&dl=0) | 6.15 | 348.9 | 848M |
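
The checkpoints are standard PyTorch files with the weights stored under the `model` key (see the resume logic in [main_fractalgen.py](main_fractalgen.py)). A minimal loading sketch, assuming FractalAR (IN64) has been saved to `pretrained_models/fractalar_in64/` (note that model construction currently requires a CUDA device, since the rotary-embedding tables are built on GPU):

```
import torch
from models import fractalgen

# Build the model with the flags used for this checkpoint (num_conds=1 for FractalAR IN64).
model = fractalgen.fractalar_in64(num_conds=1)
ckpt = torch.load("pretrained_models/fractalar_in64/checkpoint-last.pth", map_location="cpu")
model.load_state_dict(ckpt["model"])
model.eval()
```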
63 |
64 | ## Usage
65 |
66 | ### Demo
67 | Run our interactive visualization [demo](http://colab.research.google.com/github/LTH14/fractalgen/blob/main/demo/run_fractalgen.ipynb) using Colab notebook!
68 |
69 | ### Training
70 | The training scripts below have been tested on 32 H100 GPUs (4 nodes x 8 GPUs each).
71 |
72 | Example script for training FractalAR on ImageNet 64x64 for 800 epochs:
73 | ```
74 | torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
75 | main_fractalgen.py \
76 | --model fractalar_in64 --img_size 64 --num_conds 1 \
77 | --batch_size 64 --eval_freq 40 --save_last_freq 10 \
78 | --epochs 800 --warmup_epochs 40 \
79 | --blr 5.0e-5 --weight_decay 0.05 --attn_dropout 0.1 --proj_dropout 0.1 --lr_schedule cosine \
80 | --gen_bsz 256 --num_images 8000 --num_iter_list 64,16 --cfg 11.0 --cfg_schedule linear --temperature 1.03 \
81 | --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
82 | --data_path ${IMAGENET_PATH} --grad_checkpointing --online_eval
83 | ```
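
For reference, `main_fractalgen.py` derives the absolute learning rate as `lr = blr * effective_batch_size / 256`, so the settings above give:

```
# 4 nodes x 8 GPUs x batch_size 64 per GPU:
eff_batch_size = 4 * 8 * 64             # 2048
lr = 5.0e-5 * eff_batch_size / 256      # 4.0e-4 absolute learning rate
```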
84 |
85 | Example script for training FractalMAR on ImageNet 64x64 for 800 epochs:
86 | ```
87 | torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
88 | main_fractalgen.py \
89 | --model fractalmar_in64 --img_size 64 --num_conds 5 \
90 | --batch_size 64 --eval_freq 40 --save_last_freq 10 \
91 | --epochs 800 --warmup_epochs 40 \
92 | --blr 5.0e-5 --weight_decay 0.05 --attn_dropout 0.1 --proj_dropout 0.1 --lr_schedule cosine \
93 | --gen_bsz 256 --num_images 8000 --num_iter_list 64,16 --cfg 6.5 --cfg_schedule linear --temperature 1.02 \
94 | --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
95 | --data_path ${IMAGENET_PATH} --grad_checkpointing --online_eval
96 | ```
97 |
98 | Example script for training FractalMAR-L on ImageNet 256x256 for 800 epochs:
99 | ```
100 | torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
101 | main_fractalgen.py \
102 | --model fractalmar_large_in256 --img_size 256 --num_conds 5 --guiding_pixel \
103 | --batch_size 32 --eval_freq 40 --save_last_freq 10 \
104 | --epochs 800 --warmup_epochs 40 \
105 | --blr 5.0e-5 --weight_decay 0.05 --attn_dropout 0.1 --proj_dropout 0.1 --lr_schedule cosine \
106 | --gen_bsz 256 --num_images 8000 --num_iter_list 64,16,16 --cfg 21.0 --cfg_schedule linear --temperature 1.1 \
107 | --output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
108 | --data_path ${IMAGENET_PATH} --grad_checkpointing --online_eval
109 | ```
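
`--num_iter_list` carries one entry per recursive generator level. A sketch of where the per-level sequence lengths come from, using the constants in [models/fractalgen.py](models/fractalgen.py):

```
# Each level generates (img_size[i] / img_size[i+1])^2 patches of the next level.
img_size_list = (256, 16, 4, 1)  # fractalmar_*_in256
seq_lens = [(img_size_list[i] // img_size_list[i + 1]) ** 2 for i in range(3)]
print(seq_lens)  # [256, 16, 16]; the first MAR level covers its 256 positions in 64 masked iterations
```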
110 |
111 | ### Evaluation
112 |
113 | Evaluate pre-trained FractalAR on ImageNet 64x64 unconditional likelihood estimation (single GPU):
114 | ```
115 | torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 \
116 | main_fractalgen.py \
117 | --model fractalar_in64 --img_size 64 --num_conds 1 \
118 | --nll_bsz 128 --nll_forward_number 1 \
119 | --output_dir pretrained_models/fractalar_in64 \
120 | --resume pretrained_models/fractalar_in64 \
121 | --data_path ${IMAGENET_PATH} --seed 0 --evaluate_nll
122 | ```
123 |
124 | Evaluate pre-trained FractalMAR on ImageNet 64x64 unconditional likelihood estimation (single GPU):
125 | ```
126 | torchrun --nproc_per_node=1 --nnodes=1 --node_rank=0 \
127 | main_fractalgen.py \
128 | --model fractalmar_in64 --img_size 64 --num_conds 5 \
129 | --nll_bsz 128 --nll_forward_number 10 \
130 | --output_dir pretrained_models/fractalmar_in64 \
131 | --resume pretrained_models/fractalmar_in64 \
132 | --data_path ${IMAGENET_PATH} --seed 0 --evaluate_nll
133 | ```
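
The reported metric is bits/dim: `compute_nll` in [engine_fractalgen.py](engine_fractalgen.py) averages the per-dimension NLL (in nats) over `--nll_forward_number` forward passes and converts it as follows:

```
import math

def nats_to_bits_per_dim(nll_nats: float) -> float:
    # Conversion applied per batch in compute_nll.
    return nll_nats / math.log(2)
```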
134 |
135 | Evaluate pre-trained FractalAR on ImageNet 64x64 class-conditional generation:
136 | ```
137 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
138 | main_fractalgen.py \
139 | --model fractalar_in64 --img_size 64 --num_conds 1 \
140 | --gen_bsz 512 --num_images 50000 \
141 | --num_iter_list 64,16 --cfg 11.0 --cfg_schedule linear --temperature 1.03 \
142 | --output_dir pretrained_models/fractalar_in64 \
143 | --resume pretrained_models/fractalar_in64 \
144 | --data_path ${IMAGENET_PATH} --seed 0 --evaluate_gen
145 | ```
146 |
147 | Evaluate pre-trained FractalMAR on ImageNet 64x64 class-conditional generation:
148 | ```
149 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
150 | main_fractalgen.py \
151 | --model fractalmar_in64 --img_size 64 --num_conds 5 \
152 | --gen_bsz 1024 --num_images 50000 \
153 | --num_iter_list 64,16 --cfg 6.5 --cfg_schedule linear --temperature 1.02 \
154 | --output_dir pretrained_models/fractalmar_in64 \
155 | --resume pretrained_models/fractalmar_in64 \
156 | --data_path ${IMAGENET_PATH} --seed 0 --evaluate_gen
157 | ```
158 |
159 | Evaluate pre-trained FractalMAR-Huge on ImageNet 256x256 class-conditional generation:
160 | ```
161 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
162 | main_fractalgen.py \
163 | --model fractalmar_huge_in256 --img_size 256 --num_conds 5 --guiding_pixel \
164 | --gen_bsz 1024 --num_images 50000 \
165 | --num_iter_list 64,16,16 --cfg 19.0 --cfg_schedule linear --temperature 1.1 \
166 | --output_dir pretrained_models/fractalmar_huge_in256 \
167 | --resume pretrained_models/fractalmar_huge_in256 \
168 | --data_path ${IMAGENET_PATH} --seed 0 --evaluate_gen
169 | ```
170 |
171 | For ImageNet 256x256, the `--cfg` values that achieve the best FID are `29.0` for FractalMAR-Base and `21.0` for FractalMAR-Large.
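
Generated samples are scored with `torch_fidelity` against the ADM reference statistics in `fid_stats/`, exactly as in `evaluate` in [engine_fractalgen.py](engine_fractalgen.py). A standalone sketch over a folder of generated images (the folder path is a placeholder):

```
import torch_fidelity

metrics = torch_fidelity.calculate_metrics(
    input1="path/to/generated_images",  # placeholder: folder of generated .png samples
    input2=None,
    fid_statistics_file="fid_stats/adm_in64_stats.npz",
    cuda=True, isc=True, fid=True, kid=False, prc=False, verbose=False,
)
print(metrics["frechet_inception_distance"], metrics["inception_score_mean"])
```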
172 |
173 | ## Acknowledgements
174 |
175 | We thank Google TPU Research Cloud (TRC) for granting us access to TPUs, and Google Cloud Platform for supporting GPU resources.
176 |
177 | ## Contact
178 |
179 | If you have any questions, feel free to contact me through email (tianhong@mit.edu). Enjoy!
180 |
--------------------------------------------------------------------------------
/demo/visual.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LTH14/fractalgen/c7d099043dd987ae7742a7f8c8f1cab71023ba0e/demo/visual.gif
--------------------------------------------------------------------------------
/engine_fractalgen.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | import os
4 | import time
5 | import shutil
6 | from typing import Iterable
7 |
8 | import torch
9 | import torch.nn as nn
10 | import numpy as np
11 | import cv2
12 |
13 | import util.misc as misc
14 | import util.lr_sched as lr_sched
15 | import torch_fidelity
16 |
17 |
18 | def train_one_epoch(model, data_loader: Iterable, optimizer: torch.optim.Optimizer,
19 | device: torch.device, epoch: int, loss_scaler, log_writer=None, args=None):
20 | model.train(True)
21 | metric_logger = misc.MetricLogger(delimiter=" ")
22 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
23 | header = 'Epoch: [{}]'.format(epoch)
24 | print_freq = 20
25 |
26 | optimizer.zero_grad()
27 |
28 | if log_writer is not None:
29 | print('log_dir: {}'.format(log_writer.log_dir))
30 |
31 | for data_iter_step, (samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
32 | # per iteration (instead of per epoch) lr scheduler
33 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
34 |
35 | samples = samples.to(device, non_blocking=True)
36 | labels = labels.to(device, non_blocking=True)
37 |
38 | # forward
39 | with torch.cuda.amp.autocast():
40 | loss = model(samples, labels)
41 |
42 | loss_value = loss.item()
43 | if not math.isfinite(loss_value):
44 | print("Loss is {}, stopping training".format(loss_value))
45 | sys.exit(1)
46 |
47 | loss_scaler(loss, optimizer, clip_grad=args.grad_clip, parameters=model.parameters(), update_grad=True)
48 | optimizer.zero_grad()
49 |
50 | torch.cuda.synchronize()
51 |
52 | metric_logger.update(loss=loss_value)
53 | lr = optimizer.param_groups[0]["lr"]
54 | metric_logger.update(lr=lr)
55 |
56 | loss_value_reduce = misc.all_reduce_mean(loss_value)
57 | if log_writer is not None:
58 |             # Use epoch_1000x as the TensorBoard x-axis so curves stay comparable across batch sizes.
59 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
60 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
61 | log_writer.add_scalar('lr', lr, epoch_1000x)
62 |
63 | # gather the stats from all processes
64 | metric_logger.synchronize_between_processes()
65 | print("Averaged stats:", metric_logger)
66 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
67 |
68 |
69 | def compute_nll(model: torch.nn.Module, data_loader: Iterable, device: torch.device, N: int):
70 | model.eval()
71 | metric_logger = misc.MetricLogger(delimiter=" ")
72 | header = ''
73 | print_freq = 20
74 |
75 | total_samples = 0
76 | total_bpd = 0.0
77 |
78 | for _, (samples, labels) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
79 | samples = samples.to(device, non_blocking=True)
80 | labels = labels.to(device, non_blocking=True)
81 |
82 | loss = 0.0
83 | # Average multiple forward passes for a stable NLL estimate.
84 | for _ in range(N):
85 | with torch.cuda.amp.autocast():
86 | with torch.no_grad():
87 | one_loss = model(samples, labels)
88 | loss += one_loss
89 | loss /= N
90 | loss_value = loss.item()
91 |
92 | # convert loss to bits/dim
93 | bpd_value = loss_value / math.log(2)
94 | total_samples += samples.size(0)
95 | total_bpd += bpd_value * samples.size(0)
96 |
97 | torch.cuda.synchronize()
98 | metric_logger.update(bpd=bpd_value)
99 |
100 | print("BPD: {:.5f}".format(total_bpd / total_samples))
101 |
102 |
103 | def evaluate(model_without_ddp, args, epoch, batch_size=64, log_writer=None):
104 | model_without_ddp.eval()
105 | world_size = misc.get_world_size()
106 | local_rank = misc.get_rank()
107 | num_steps = args.num_images // (batch_size * world_size) + 1
108 |
109 | # Construct the folder name for saving generated images.
110 | save_folder = os.path.join(
111 | args.output_dir,
112 | "ariter{}-temp{}-{}cfg{}-filter{}-image{}".format(
113 | args.num_iter_list, args.temperature, args.cfg_schedule,
114 | args.cfg, args.filter_threshold, args.num_images
115 | )
116 | )
117 | if args.evaluate_gen:
118 | save_folder += "_evaluate"
119 | print("Save to:", save_folder)
120 | if misc.get_rank() == 0 and not os.path.exists(save_folder):
121 | os.makedirs(save_folder)
122 |
123 | # Ensure that the number of images per class is equal.
124 | class_num = args.class_num
125 |     assert args.num_images % class_num == 0, "num_images must be divisible by class_num"
126 | class_label_gen_world = np.arange(0, class_num).repeat(args.num_images // class_num)
127 |     class_label_gen_world = np.hstack([class_label_gen_world, np.zeros(50000)])  # pad so the last partial batch can still index labels
128 |
129 | used_time = 0.0
130 | gen_img_cnt = 0
131 |
132 | for i in range(num_steps):
133 | print("Generation step {}/{}".format(i, num_steps))
134 |
135 | start_idx = world_size * batch_size * i + local_rank * batch_size
136 | end_idx = start_idx + batch_size
137 | labels_gen = class_label_gen_world[start_idx:end_idx]
138 | labels_gen = torch.Tensor(labels_gen).long().cuda()
139 |
140 | torch.cuda.synchronize()
141 | start_time = time.time()
142 |
143 | # generation
144 | with torch.no_grad():
145 | with torch.cuda.amp.autocast():
146 | class_embedding = model_without_ddp.class_emb(labels_gen)
147 | if not args.cfg == 1.0:
148 | # Concatenate fake latent for classifier-free guidance.
149 | class_embedding = torch.cat(
150 | [class_embedding, model_without_ddp.fake_latent.repeat(batch_size, 1)],
151 | dim=0
152 | )
153 | sampled_images = model_without_ddp.sample(
154 | cond_list=[class_embedding for _ in range(args.num_conds)],
155 | num_iter_list=[int(num_iter) for num_iter in args.num_iter_list.split(",")],
156 | cfg=args.cfg, cfg_schedule=args.cfg_schedule,
157 | temperature=args.temperature,
158 | filter_threshold=args.filter_threshold,
159 | fractal_level=0
160 | )
161 |
162 | # Measure generation speed (skip first batch).
163 | torch.cuda.synchronize()
164 | batch_time = time.time() - start_time
165 | if i >= 1:
166 | used_time += batch_time
167 | gen_img_cnt += batch_size
168 | print("Generating {} images takes {:.5f} seconds, {:.5f} sec per image".format(gen_img_cnt, used_time, used_time / gen_img_cnt))
169 |
170 | torch.distributed.barrier()
171 |
172 | # Denormalize images.
173 | pix_mean = torch.Tensor([0.485, 0.456, 0.406]).cuda().view(1, -1, 1, 1)
174 | pix_std = torch.Tensor([0.229, 0.224, 0.225]).cuda().view(1, -1, 1, 1)
175 | sampled_images = sampled_images * pix_std + pix_mean
176 | sampled_images = sampled_images.detach().cpu()
177 |
178 | # distributed save images
179 | for b_id in range(sampled_images.size(0)):
180 |             img_id = i * sampled_images.size(0) * world_size + local_rank * sampled_images.size(0) + b_id  # global image index across ranks and steps
181 | if img_id >= args.num_images:
182 | break
183 | gen_img = np.round(np.clip(sampled_images[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255))
184 | gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
185 | cv2.imwrite(os.path.join(save_folder, '{}.png'.format(str(img_id).zfill(5))), gen_img)
186 |
187 | torch.distributed.barrier()
188 | time.sleep(10)
189 |
190 | # compute FID and IS
191 | if log_writer is not None:
192 | if args.img_size == 64:
193 | fid_statistics_file = 'fid_stats/adm_in64_stats.npz'
194 | elif args.img_size == 256:
195 | fid_statistics_file = 'fid_stats/adm_in256_stats.npz'
196 | else:
197 | raise NotImplementedError
198 | metrics_dict = torch_fidelity.calculate_metrics(
199 | input1=save_folder,
200 | input2=None,
201 | fid_statistics_file=fid_statistics_file,
202 | cuda=True,
203 | isc=True,
204 | fid=True,
205 | kid=False,
206 | prc=False,
207 | verbose=False,
208 | )
209 | fid = metrics_dict['frechet_inception_distance']
210 | inception_score = metrics_dict['inception_score_mean']
211 | postfix = "_cfg{}".format(args.cfg)
212 | log_writer.add_scalar('fid{}'.format(postfix), fid, epoch)
213 | log_writer.add_scalar('is{}'.format(postfix), inception_score, epoch)
214 | print("FID: {:.4f}, Inception Score: {:.4f}".format(fid, inception_score))
215 | if not args.evaluate_gen:
216 |             # remove the temporary saving folder used for online eval
217 | shutil.rmtree(save_folder)
218 |
219 | torch.distributed.barrier()
220 | time.sleep(10)
221 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: fractalgen
2 | channels:
3 | - pytorch
4 | - defaults
5 | - nvidia
6 | dependencies:
7 | - python=3.8.5
8 | - pip=20.3
9 | - pytorch-cuda=11.8
10 | - pytorch=2.2.2
11 | - torchvision=0.17.2
12 | - numpy=1.22
13 | - pip:
14 | - opencv-python==4.1.2.30
15 | - timm==0.9.12
16 | - tensorboard==2.10.0
17 | - scipy==1.9.1
18 | - gdown==5.2.0
19 | - -e git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity
20 |
--------------------------------------------------------------------------------
/fid_stats/adm_in256_stats.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LTH14/fractalgen/c7d099043dd987ae7742a7f8c8f1cab71023ba0e/fid_stats/adm_in256_stats.npz
--------------------------------------------------------------------------------
/fid_stats/adm_in64_stats.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LTH14/fractalgen/c7d099043dd987ae7742a7f8c8f1cab71023ba0e/fid_stats/adm_in64_stats.npz
--------------------------------------------------------------------------------
/main_fractalgen.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import numpy as np
4 | import os
5 | import time
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.backends.cudnn as cudnn
10 | from torch.utils.tensorboard import SummaryWriter
11 | import torchvision.transforms as transforms
12 | import torchvision.datasets as datasets
13 |
14 | from util.crop import center_crop_arr
15 | import util.misc as misc
16 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
17 |
18 | from models import fractalgen
19 | from engine_fractalgen import train_one_epoch, compute_nll, evaluate
20 |
21 |
22 | def get_args_parser():
23 | parser = argparse.ArgumentParser('Fractal Generative Models', add_help=False)
24 | parser.add_argument('--batch_size', default=64, type=int,
25 | help='Batch size per GPU (effective batch size = batch_size * # GPUs)')
26 | parser.add_argument('--epochs', default=400, type=int)
27 | parser.add_argument('--seed', default=0, type=int)
28 | parser.add_argument('--resume', default='',
29 | help='Folder that contains checkpoint to resume from')
30 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
31 | help='Starting epoch')
32 | parser.add_argument('--num_workers', default=10, type=int)
33 | parser.add_argument('--pin_mem', action='store_true',
34 | help='Pin CPU memory in DataLoader for faster GPU transfers')
35 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
36 | parser.set_defaults(pin_mem=True)
37 |
38 | # Model parameters
39 | parser.add_argument('--model', default='fractalmar_in64', type=str, metavar='MODEL',
40 | help='Name of the model to train')
41 | parser.add_argument('--img_size', default=64, type=int, help='Image size')
42 |
43 | # Generation parameters
44 | parser.add_argument('--num_iter_list', default='64,16', type=str,
45 | help='Number of autoregressive iterations for each fractal level')
46 | parser.add_argument('--num_images', default=50000, type=int,
47 | help='Number of images to generate')
48 | parser.add_argument('--cfg', default=1.0, type=float,
49 | help='Classifier-free guidance factor')
50 | parser.add_argument('--cfg_schedule', default='linear', type=str)
51 | parser.add_argument('--temperature', default=1.0, type=float,
52 | help='Sampling temperature')
53 | parser.add_argument('--filter_threshold', default=1e-4, type=float,
54 | help='Filter threshold for low probability tokens in cfg')
55 | parser.add_argument('--label_drop_prob', default=0.1, type=float)
56 | parser.add_argument('--eval_freq', type=int, default=40,
57 | help='Frequency (in epochs) for evaluation')
58 | parser.add_argument('--save_last_freq', type=int, default=5,
59 | help='Frequency (in epochs) to save checkpoints')
60 | parser.add_argument('--online_eval', action='store_true')
61 | parser.add_argument('--evaluate_gen', action='store_true')
62 | parser.add_argument('--evaluate_nll', action='store_true')
63 | parser.add_argument('--gen_bsz', type=int, default=1024,
64 | help='Generation batch size')
65 | parser.add_argument('--nll_bsz', type=int, default=128,
66 | help='NLL evaluation batch size')
67 | parser.add_argument('--nll_forward_number', type=int, default=1,
68 | help='Number of forward passes used to evaluate the NLL for each data sample. '
69 |                              'This does not affect the NLL of the AR model, but for the MAR model, multiple passes (each '
70 |                              'randomly sampling a masking ratio) result in a more accurate NLL estimate.'
71 | )
72 | # Optimizer parameters
73 | parser.add_argument('--weight_decay', type=float, default=0.05,
74 | help='Weight decay (default: 0.05)')
75 | parser.add_argument('--grad_checkpointing', action='store_true')
76 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
77 | help='Learning rate (absolute)')
78 | parser.add_argument('--blr', type=float, default=5e-5, metavar='LR',
79 | help='Base learning rate: absolute_lr = base_lr * total_batch_size / 256')
80 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
81 | help='Minimum LR for cyclic schedulers that hit 0')
82 | parser.add_argument('--lr_schedule', type=str, default='cosine',
83 | help='Learning rate schedule')
84 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N',
85 | help='Epochs to warm up LR')
86 |
87 | # Fractal generator parameters
88 | parser.add_argument('--guiding_pixel', action='store_true',
89 | help='Use guiding pixels')
90 | parser.add_argument('--num_conds', type=int, default=1,
91 | help='Number of conditions to use')
92 | parser.add_argument('--r_weight', type=float, default=5.0,
93 | help='Loss weight on the red channel')
94 | parser.add_argument('--grad_clip', type=float, default=3.0,
95 | help='Gradient clipping value')
96 | parser.add_argument('--attn_dropout', type=float, default=0.1,
97 | help='Attention dropout rate')
98 | parser.add_argument('--proj_dropout', type=float, default=0.1,
99 | help='Projection dropout rate')
100 |
101 | # Dataset parameters
102 | parser.add_argument('--data_path', default='./data/imagenet', type=str,
103 | help='Path to the dataset')
104 | parser.add_argument('--class_num', default=1000, type=int)
105 | parser.add_argument('--output_dir', default='./output_dir',
106 | help='Directory to save outputs (empty for no saving)')
107 | parser.add_argument('--device', default='cuda',
108 | help='Device to use for training/testing')
109 |
110 | # Distributed training parameters
111 | parser.add_argument('--world_size', default=1, type=int,
112 | help='Number of distributed processes')
113 | parser.add_argument('--local_rank', default=-1, type=int)
114 | parser.add_argument('--dist_on_itp', action='store_true')
115 | parser.add_argument('--dist_url', default='env://',
116 | help='URL used to set up distributed training')
117 |
118 | return parser
119 |
120 |
121 | def main(args):
122 | misc.init_distributed_mode(args)
123 | print('Job directory:', os.path.dirname(os.path.realpath(__file__)))
124 | print("Arguments:\n{}".format(args).replace(', ', ',\n'))
125 |
126 | device = torch.device(args.device)
127 |
128 | # Set seeds for reproducibility
129 | seed = args.seed + misc.get_rank()
130 | torch.manual_seed(seed)
131 | np.random.seed(seed)
132 |
133 | cudnn.benchmark = True
134 |
135 | num_tasks = misc.get_world_size()
136 | global_rank = misc.get_rank()
137 |
138 | # Set up TensorBoard logging (only on main process)
139 | if global_rank == 0 and args.output_dir is not None:
140 | os.makedirs(args.output_dir, exist_ok=True)
141 | log_writer = SummaryWriter(log_dir=args.output_dir)
142 | else:
143 | log_writer = None
144 |
145 | # Data augmentation transforms (following DiT and ADM)
146 | transform_train = transforms.Compose([
147 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
148 | transforms.RandomHorizontalFlip(),
149 | transforms.ToTensor(),
150 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
151 | ])
152 | transform_val = transforms.Compose([
153 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
154 | transforms.ToTensor(),
155 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
156 | ])
157 |
158 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, 'train'), transform=transform_train)
159 | dataset_val = datasets.ImageFolder(os.path.join(args.data_path, 'val'), transform=transform_val)
160 |
161 | sampler_train = torch.utils.data.DistributedSampler(
162 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
163 | )
164 | print("Sampler_train =", sampler_train)
165 |
166 | data_loader_train = torch.utils.data.DataLoader(
167 | dataset_train, sampler=sampler_train,
168 | batch_size=args.batch_size,
169 | num_workers=args.num_workers,
170 | pin_memory=args.pin_mem,
171 | drop_last=True,
172 | )
173 | data_loader_val = torch.utils.data.DataLoader(
174 | dataset_val, shuffle=True,
175 | batch_size=args.nll_bsz,
176 | num_workers=args.num_workers,
177 | pin_memory=args.pin_mem,
178 | drop_last=False,
179 | )
180 |
181 | # Create fractal generative model
182 | model = fractalgen.__dict__[args.model](
183 | label_drop_prob=args.label_drop_prob,
184 | class_num=args.class_num,
185 | attn_dropout=args.attn_dropout,
186 | proj_dropout=args.proj_dropout,
187 | guiding_pixel=args.guiding_pixel,
188 | num_conds=args.num_conds,
189 | r_weight=args.r_weight,
190 | grad_checkpointing=args.grad_checkpointing
191 | )
192 |
193 | print("Model =", model)
194 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
195 | print("Number of trainable parameters: {:.2f}M".format(n_params / 1e6))
196 |
197 | model.to(device)
198 |
199 | eff_batch_size = args.batch_size * misc.get_world_size()
200 | if args.lr is None: # only base_lr (blr) is specified
201 | args.lr = args.blr * eff_batch_size / 256
202 |
203 | print("Base lr: {:.2e}".format(args.lr * 256 / eff_batch_size))
204 | print("Actual lr: {:.2e}".format(args.lr))
205 | print("Effective batch size: %d" % eff_batch_size)
206 |
207 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
208 | model_without_ddp = model.module
209 |
210 | # Set up optimizer with weight decay adjustment for bias and norm layers
211 | param_groups = misc.add_weight_decay(model_without_ddp, args.weight_decay)
212 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
213 | print(optimizer)
214 | loss_scaler = NativeScaler()
215 |
216 | # Resume from checkpoint if provided
217 | checkpoint_path = os.path.join(args.resume, "checkpoint-last.pth") if args.resume else None
218 | if checkpoint_path and os.path.exists(checkpoint_path):
219 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
220 | model_without_ddp.load_state_dict(checkpoint['model'])
221 | print("Resumed checkpoint from", args.resume)
222 |
223 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
224 | optimizer.load_state_dict(checkpoint['optimizer'])
225 | args.start_epoch = checkpoint['epoch'] + 1
226 | if 'scaler' in checkpoint:
227 | loss_scaler.load_state_dict(checkpoint['scaler'])
228 | print("Loaded optimizer & scaler state!")
229 | del checkpoint
230 | else:
231 | print("Training from scratch")
232 |
233 | # Evaluation modes
234 | if args.evaluate_gen:
235 | torch.cuda.empty_cache()
236 | evaluate(model_without_ddp, args, 0, batch_size=args.gen_bsz, log_writer=log_writer)
237 | return
238 |
239 | if args.evaluate_nll:
240 | torch.cuda.empty_cache()
241 | compute_nll(model, data_loader_val, device, N=args.nll_forward_number)
242 | return
243 |
244 | # Training loop
245 | print(f"Start training for {args.epochs} epochs")
246 | start_time = time.time()
247 | for epoch in range(args.start_epoch, args.epochs):
248 | if args.distributed:
249 | data_loader_train.sampler.set_epoch(epoch)
250 |
251 | train_one_epoch(
252 | model, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args
253 | )
254 |
255 | # Save checkpoint periodically
256 | if epoch % args.save_last_freq == 0 or epoch + 1 == args.epochs:
257 | misc.save_model(
258 | args=args,
259 | model_without_ddp=model_without_ddp,
260 | optimizer=optimizer,
261 | loss_scaler=loss_scaler,
262 | epoch=epoch,
263 | epoch_name="last"
264 | )
265 |
266 | # Perform online evaluation at specified intervals
267 | if args.online_eval and (epoch % args.eval_freq == 0 or epoch + 1 == args.epochs):
268 | torch.cuda.empty_cache()
269 | evaluate(model_without_ddp, args, epoch, batch_size=args.gen_bsz, log_writer=log_writer)
270 | torch.cuda.empty_cache()
271 |
272 | if misc.is_main_process() and log_writer is not None:
273 | log_writer.flush()
274 |
275 | total_time = time.time() - start_time
276 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
277 | print('Training time:', total_time_str)
278 |
279 |
280 | if __name__ == '__main__':
281 | args = get_args_parser().parse_args()
282 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
283 | main(args)
284 |
--------------------------------------------------------------------------------
/models/ar.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # LlamaGen: https://github.com/FoundationVision/LlamaGen/blob/main/autoregressive/models/gpt.py
3 | from dataclasses import dataclass
4 | from typing import Optional
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.checkpoint import checkpoint
10 | from torch.nn import functional as F
11 | from util.visualize import visualize_patch
12 | import math
13 |
14 |
15 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
16 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
17 |
18 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
19 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
20 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
21 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
22 | 'survival rate' as the argument.
23 |
24 | """
25 | if drop_prob == 0. or not training:
26 | return x
27 | keep_prob = 1 - drop_prob
28 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
29 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
30 | if keep_prob > 0.0 and scale_by_keep:
31 | random_tensor.div_(keep_prob)
32 | return x * random_tensor
33 |
34 |
35 | class DropPath(torch.nn.Module):
36 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
37 | """
38 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
39 | super(DropPath, self).__init__()
40 | self.drop_prob = drop_prob
41 | self.scale_by_keep = scale_by_keep
42 |
43 | def forward(self, x):
44 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
45 |
46 | def extra_repr(self):
47 | return f'drop_prob={round(self.drop_prob,3):0.3f}'
48 |
49 |
50 | def find_multiple(n: int, k: int):
51 | if n % k == 0:
52 | return n
53 | return n + k - (n % k)
54 |
55 |
56 | @dataclass
57 | class ModelArgs:
58 | dim: int = 4096
59 | n_layer: int = 32
60 | n_head: int = 32
61 | n_kv_head: Optional[int] = None
62 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
63 | ffn_dim_multiplier: Optional[float] = None
64 | rope_base: float = 10000
65 | norm_eps: float = 1e-5
66 | initializer_range: float = 0.02
67 |
68 | token_dropout_p: float = 0.1
69 | attn_dropout_p: float = 0.0
70 | resid_dropout_p: float = 0.1
71 | ffn_dropout_p: float = 0.1
72 | drop_path_rate: float = 0.0
73 |
74 | num_classes: int = 1000
75 | caption_dim: int = 2048
76 | class_dropout_prob: float = 0.1
77 | model_type: str = 'c2i'
78 |
79 | vocab_size: int = 16384
80 | cls_token_num: int = 1
81 | block_size: int = 256
82 | max_batch_size: int = 32
83 | max_seq_len: int = 2048
84 |
85 |
86 | #################################################################################
87 | # Embedding Layers for Class Labels #
88 | #################################################################################
89 | class LabelEmbedder(nn.Module):
90 | """
91 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
92 | """
93 |
94 | def __init__(self, num_classes, hidden_size, dropout_prob):
95 | super().__init__()
96 | use_cfg_embedding = dropout_prob > 0
97 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
98 | self.num_classes = num_classes
99 | self.dropout_prob = dropout_prob
100 |
101 | def token_drop(self, labels, force_drop_ids=None):
102 | """
103 | Drops labels to enable classifier-free guidance.
104 | """
105 | if force_drop_ids is None:
106 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
107 | else:
108 | drop_ids = force_drop_ids == 1
109 | labels = torch.where(drop_ids, self.num_classes, labels)
110 | return labels
111 |
112 | def forward(self, labels, train, force_drop_ids=None):
113 | use_dropout = self.dropout_prob > 0
114 | if (train and use_dropout) or (force_drop_ids is not None):
115 | labels = self.token_drop(labels, force_drop_ids)
116 | embeddings = self.embedding_table(labels).unsqueeze(1)
117 | return embeddings
118 |
119 |
120 | #################################################################################
121 | # GPT Model #
122 | #################################################################################
123 | class RMSNorm(torch.nn.Module):
124 | def __init__(self, dim: int, eps: float = 1e-5):
125 | super().__init__()
126 | self.eps = eps
127 | self.weight = nn.Parameter(torch.ones(dim))
128 |
129 | def _norm(self, x):
130 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
131 |
132 | def forward(self, x):
133 | output = self._norm(x.float()).type_as(x)
134 | return output * self.weight
135 |
136 |
137 | class FeedForward(nn.Module):
138 | def __init__(self, config: ModelArgs):
139 | super().__init__()
140 | hidden_dim = 4 * config.dim
141 | hidden_dim = int(2 * hidden_dim / 3)
142 | # custom dim factor multiplier
143 | if config.ffn_dim_multiplier is not None:
144 | hidden_dim = int(config.ffn_dim_multiplier * hidden_dim)
145 | hidden_dim = find_multiple(hidden_dim, config.multiple_of)
146 |
147 | self.w1 = nn.Linear(config.dim, hidden_dim, bias=False)
148 | self.w3 = nn.Linear(config.dim, hidden_dim, bias=False)
149 | self.w2 = nn.Linear(hidden_dim, config.dim, bias=False)
150 | self.ffn_dropout = nn.Dropout(config.ffn_dropout_p)
151 |
152 | def forward(self, x):
153 | return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
154 |
155 |
156 | class KVCache(nn.Module):
157 | def __init__(self, max_batch_size, max_seq_length, n_head, head_dim):
158 | super().__init__()
159 | cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
160 | self.register_buffer('k_cache', torch.zeros(cache_shape))
161 | self.register_buffer('v_cache', torch.zeros(cache_shape))
162 |
163 | def update(self, input_pos, k_val, v_val):
164 | # input_pos: [S], k_val: [B, H, S, D]
165 | k_out = self.k_cache
166 | v_out = self.v_cache
167 | k_out[:, :, input_pos] = k_val.to(k_out.dtype)
168 | v_out[:, :, input_pos] = v_val.to(k_out.dtype)
169 |
170 | return k_out, v_out
171 |
172 |
173 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
174 | L, S = query.size(-2), key.size(-2)
175 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
176 | attn_bias = torch.zeros(L, S, dtype=query.dtype).cuda()
177 | if is_causal:
178 | assert attn_mask is None
179 | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda()
180 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
181 |         attn_bias = attn_bias.to(query.dtype)  # assign: .to() is not in-place
182 |
183 | if attn_mask is not None:
184 | if attn_mask.dtype == torch.bool:
185 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
186 | else:
187 | attn_bias += attn_mask
188 | with torch.cuda.amp.autocast(enabled=False):
189 | attn_weight = query.float() @ key.float().transpose(-2, -1) * scale_factor
190 | attn_weight += attn_bias
191 | attn_weight = torch.softmax(attn_weight, dim=-1)
192 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
193 | return attn_weight @ value
194 |
195 |
196 | class Attention(nn.Module):
197 | def __init__(self, config: ModelArgs):
198 | super().__init__()
199 | assert config.dim % config.n_head == 0
200 | self.dim = config.dim
201 | self.head_dim = config.dim // config.n_head
202 | self.n_head = config.n_head
203 | self.n_kv_head = config.n_kv_head if config.n_kv_head is not None else config.n_head
204 | total_kv_dim = (self.n_head + 2 * self.n_kv_head) * self.head_dim
205 |
206 | # key, query, value projections for all heads, but in a batch
207 | self.wqkv = nn.Linear(config.dim, total_kv_dim, bias=False)
208 | self.wo = nn.Linear(config.dim, config.dim, bias=False)
209 | self.kv_cache = None
210 |
211 | # regularization
212 | self.attn_dropout_p = config.attn_dropout_p
213 | self.resid_dropout = nn.Dropout(config.resid_dropout_p)
214 |
215 | def forward(
216 | self, x: torch.Tensor, freqs_cis=None, input_pos=None, mask=None
217 | ):
218 | bsz, seqlen, _ = x.shape
219 | kv_size = self.n_kv_head * self.head_dim
220 | xq, xk, xv = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
221 |
222 | xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
223 | xk = xk.view(bsz, seqlen, self.n_kv_head, self.head_dim)
224 | xv = xv.view(bsz, seqlen, self.n_kv_head, self.head_dim)
225 |
226 | xq = apply_rotary_emb(xq, freqs_cis)
227 | xk = apply_rotary_emb(xk, freqs_cis)
228 |
229 | xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
230 |
231 | if self.kv_cache is not None:
232 | keys, values = self.kv_cache.update(input_pos, xk, xv)
233 | else:
234 | keys, values = xk, xv
235 | keys = keys.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
236 | values = values.repeat_interleave(self.n_head // self.n_kv_head, dim=1)
237 |
238 | output = scaled_dot_product_attention(
239 | xq, keys, values,
240 | attn_mask=mask,
241 | is_causal=True if mask is None else False, # is_causal=False is for KV cache
242 | dropout_p=self.attn_dropout_p if self.training else 0)
243 |
244 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
245 |
246 | output = self.resid_dropout(self.wo(output))
247 | return output
248 |
249 |
250 | class TransformerBlock(nn.Module):
251 | def __init__(self, config: ModelArgs, drop_path: float):
252 | super().__init__()
253 | self.attention = Attention(config)
254 | self.feed_forward = FeedForward(config)
255 | self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
256 | self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
257 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
258 |
259 | def forward(
260 | self, x: torch.Tensor, freqs_cis: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
261 | h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis, start_pos, mask))
262 | out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
263 | return out
264 |
265 |
266 | #################################################################################
267 | # Rotary Positional Embedding Functions #
268 | #################################################################################
269 | # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
270 | def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000, cls_token_num=120):
271 | freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
272 | t = torch.arange(seq_len, device=freqs.device)
273 | freqs = torch.outer(t, freqs) # (seq_len, head_dim // 2)
274 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
275 |     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)  # (seq_len, head_dim // 2, 2)
276 | cond_cache = torch.cat(
277 | [torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+seq_len, head_dim // 2, 2)
278 | return cond_cache
279 |
280 |
281 | def precompute_freqs_cis_2d(grid_size: int, n_elem: int, base: int = 10000, cls_token_num=120):
282 | # split the dimension into half, one for x and one for y
283 | half_dim = n_elem // 2
284 | freqs = 1.0 / (base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim))
285 | t = torch.arange(grid_size, device=freqs.device)
286 | freqs = torch.outer(t, freqs) # (grid_size, head_dim // 2)
287 | freqs_grid = torch.concat([
288 | freqs[:, None, :].expand(-1, grid_size, -1),
289 | freqs[None, :, :].expand(grid_size, -1, -1),
290 | ], dim=-1) # (grid_size, grid_size, head_dim // 2)
291 | cache_grid = torch.stack([torch.cos(freqs_grid), torch.sin(freqs_grid)],
292 | dim=-1) # (grid_size, grid_size, head_dim // 2, 2)
293 | cache = cache_grid.flatten(0, 1)
294 | cond_cache = torch.cat(
295 | [torch.zeros(cls_token_num, n_elem // 2, 2), cache]) # (cls_token_num+grid_size**2, head_dim // 2, 2)
296 | return cond_cache
297 |
298 |
299 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
300 | # x: (bs, seq_len, n_head, head_dim)
301 | # freqs_cis (seq_len, head_dim // 2, 2)
302 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # (bs, seq_len, n_head, head_dim//2, 2)
303 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) # (1, seq_len, 1, head_dim//2, 2)
304 | x_out2 = torch.stack([
305 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
306 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
307 | ], dim=-1)
308 | x_out2 = x_out2.flatten(3)
309 | return x_out2.type_as(x)
310 |
311 |
312 | class AR(nn.Module):
313 | def __init__(self, seq_len, patch_size, cond_embed_dim, embed_dim, num_blocks, num_heads,
314 | grad_checkpointing=False, **kwargs):
315 | super().__init__()
316 |
317 | self.seq_len = seq_len
318 | self.patch_size = patch_size
319 |
320 | self.grad_checkpointing = grad_checkpointing
321 |
322 | # --------------------------------------------------------------------------
323 | # network
324 | self.patch_emb = nn.Linear(3 * patch_size ** 2, embed_dim, bias=True)
325 | self.patch_emb_ln = nn.LayerNorm(embed_dim, eps=1e-6)
326 | self.pos_embed_learned = nn.Parameter(torch.zeros(1, seq_len+1, embed_dim))
327 | self.cond_emb = nn.Linear(cond_embed_dim, embed_dim, bias=True)
328 |
329 | self.config = model_args = ModelArgs(dim=embed_dim, n_head=num_heads)
330 | self.blocks = nn.ModuleList([TransformerBlock(config=model_args, drop_path=0.0) for _ in range(num_blocks)])
331 |
332 | # 2d rotary pos embedding
333 | grid_size = int(seq_len ** 0.5)
334 | assert grid_size * grid_size == seq_len
335 | self.freqs_cis = precompute_freqs_cis_2d(grid_size, model_args.dim // model_args.n_head,
336 | model_args.rope_base, cls_token_num=1).cuda()
337 |
338 | # KVCache
339 | self.max_batch_size = -1
340 | self.max_seq_length = -1
341 |
342 | self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
343 |
344 | self.initialize_weights()
345 |
346 | def initialize_weights(self):
347 | # parameters
348 | torch.nn.init.normal_(self.pos_embed_learned, std=.02)
349 |
350 | # initialize nn.Linear and nn.LayerNorm
351 | self.apply(self._init_weights)
352 |
353 | def _init_weights(self, m):
354 | if isinstance(m, nn.Linear):
355 | # we use xavier_uniform following official JAX ViT:
356 | torch.nn.init.xavier_uniform_(m.weight)
357 | if isinstance(m, nn.Linear) and m.bias is not None:
358 | nn.init.constant_(m.bias, 0)
359 | elif isinstance(m, nn.LayerNorm):
360 | if m.bias is not None:
361 | nn.init.constant_(m.bias, 0)
362 | if m.weight is not None:
363 | nn.init.constant_(m.weight, 1.0)
364 |
365 | def setup_caches(self, max_batch_size, max_seq_length):
366 | # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
367 | # return
368 | head_dim = self.config.dim // self.config.n_head
369 | max_seq_length = find_multiple(max_seq_length, 8)
370 | self.max_seq_length = max_seq_length
371 | self.max_batch_size = max_batch_size
372 | for b in self.blocks:
373 | b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_head, head_dim)
374 |
375 | causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
376 | self.causal_mask = causal_mask
377 | grid_size = int(self.seq_len ** 0.5)
378 | assert grid_size * grid_size == self.seq_len
379 | self.freqs_cis = precompute_freqs_cis_2d(grid_size, self.config.dim // self.config.n_head,
380 | self.config.rope_base, 1)
381 |
382 | def patchify(self, x):
383 | bsz, c, h, w = x.shape
384 | p = self.patch_size
385 | h_, w_ = h // p, w // p
386 |
387 | x = x.reshape(bsz, c, h_, p, w_, p)
388 | x = torch.einsum('nchpwq->nhwcpq', x)
389 | x = x.reshape(bsz, h_ * w_, c * p ** 2)
390 | return x # [n, l, d]
391 |
392 | def unpatchify(self, x):
393 | bsz = x.shape[0]
394 | p = self.patch_size
395 | h_, w_ = int(np.sqrt(self.seq_len)), int(np.sqrt(self.seq_len))
396 |
397 | x = x.reshape(bsz, h_, w_, 3, p, p)
398 | x = torch.einsum('nhwcpq->nchpwq', x)
399 | x = x.reshape(bsz, 3, h_ * p, w_ * p)
400 | return x # [n, 3, h, w]
401 |
402 | def predict(self, x, cond_list, input_pos=None):
403 | x = self.patch_emb(x)
404 | x = torch.cat([self.cond_emb(cond_list[0]).unsqueeze(1).repeat(1, 1, 1), x], dim=1)
405 |
406 | # position embedding
407 | x = x + self.pos_embed_learned[:, :x.shape[1]]
408 | x = self.patch_emb_ln(x)
409 |
410 | if input_pos is not None:
411 | # use kv cache
412 | freqs_cis = self.freqs_cis[input_pos]
413 | mask = self.causal_mask[input_pos]
414 | x = x[:, input_pos]
415 | else:
416 | # training
417 | freqs_cis = self.freqs_cis[:x.shape[1]]
418 | mask = None
419 |
420 | # apply Transformer blocks
421 | if self.grad_checkpointing and not torch.jit.is_scripting() and self.training:
422 | for block in self.blocks:
423 | x = checkpoint(block, x, freqs_cis, input_pos, mask)
424 | else:
425 | for block in self.blocks:
426 | x = block(x, freqs_cis, input_pos, mask)
427 | x = self.norm(x)
428 |
429 | # return middle condition
430 | if input_pos is not None:
431 | middle_cond = x[:, 0]
432 | else:
433 | middle_cond = x[:, :-1]
434 |
435 | return [middle_cond]
436 |
437 | def forward(self, imgs, cond_list):
438 | """ training """
439 | # patchify to get gt
440 | patches = self.patchify(imgs)
441 | mask = torch.ones(patches.size(0), patches.size(1)).to(patches.device)
442 |
443 | # get condition for next level
444 | cond_list_next = self.predict(patches, cond_list)
445 |
446 | # reshape conditions and patches for next level
447 | for cond_idx in range(len(cond_list_next)):
448 | cond_list_next[cond_idx] = cond_list_next[cond_idx].reshape(cond_list_next[cond_idx].size(0) * cond_list_next[cond_idx].size(1), -1)
449 |
450 | patches = patches.reshape(patches.size(0) * patches.size(1), -1)
451 | patches = patches.reshape(patches.size(0), 3, self.patch_size, self.patch_size)
452 |
453 | return patches, cond_list_next, 0
454 |
455 | def sample(self, cond_list, num_iter, cfg, cfg_schedule, temperature, filter_threshold, next_level_sample_function,
456 | visualize=False):
457 | """ generation """
458 | if cfg == 1.0:
459 | bsz = cond_list[0].size(0)
460 | else:
461 | bsz = cond_list[0].size(0) // 2
462 |
463 | patches = torch.zeros(bsz, self.seq_len, 3 * self.patch_size**2).cuda()
464 |         num_iter = self.seq_len  # AR decodes one patch per step, so num_iter always equals seq_len
465 |
466 | device = cond_list[0].device
467 | with torch.device(device):
468 | self.setup_caches(max_batch_size=cond_list[0].size(0), max_seq_length=num_iter)
469 |
470 | # sample
471 | for step in range(num_iter):
472 | cur_patches = patches.clone()
473 |
474 | if not cfg == 1.0:
475 | patches = torch.cat([patches, patches], dim=0)
476 |
477 | # get next level conditions
478 | cond_list_next = self.predict(patches, cond_list, input_pos=torch.Tensor([step]).int())
479 | # cfg schedule
480 | if cfg_schedule == "linear":
481 | cfg_iter = 1 + (cfg - 1) * (step + 1) / self.seq_len
482 | else:
483 | cfg_iter = cfg
484 | sampled_patches = next_level_sample_function(cond_list=cond_list_next, cfg=cfg_iter,
485 | temperature=temperature, filter_threshold=filter_threshold)
486 | sampled_patches = sampled_patches.reshape(sampled_patches.size(0), -1)
487 |
488 | cur_patches[:, step] = sampled_patches.to(cur_patches.dtype)
489 | patches = cur_patches.clone()
490 |
491 | # visualize generation process for colab
492 | if visualize:
493 | visualize_patch(self.unpatchify(patches))
494 |
495 | # clean up kv cache
496 | for b in self.blocks:
497 | b.attention.kv_cache = None
498 | patches = self.unpatchify(patches)
499 | return patches
500 |
--------------------------------------------------------------------------------
/models/fractalgen.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from models.ar import AR
7 | from models.mar import MAR
8 | from models.pixelloss import PixelLoss
9 |
10 |
11 | class FractalGen(nn.Module):
12 | """ Fractal Generative Model"""
13 |
14 | def __init__(self,
15 | img_size_list,
16 | embed_dim_list,
17 | num_blocks_list,
18 | num_heads_list,
19 | generator_type_list,
20 | label_drop_prob=0.1,
21 | class_num=1000,
22 | attn_dropout=0.1,
23 | proj_dropout=0.1,
24 | guiding_pixel=False,
25 | num_conds=1,
26 | r_weight=1.0,
27 | grad_checkpointing=False,
28 | fractal_level=0):
29 | super().__init__()
30 |
31 | # --------------------------------------------------------------------------
32 | # fractal specifics
33 | self.fractal_level = fractal_level
34 | self.num_fractal_levels = len(img_size_list)
35 |
36 | # --------------------------------------------------------------------------
37 | # Class embedding for the first fractal level
38 | if self.fractal_level == 0:
39 | self.num_classes = class_num
40 | self.class_emb = nn.Embedding(class_num, embed_dim_list[0])
41 | self.label_drop_prob = label_drop_prob
42 | self.fake_latent = nn.Parameter(torch.zeros(1, embed_dim_list[0]))
43 | torch.nn.init.normal_(self.class_emb.weight, std=0.02)
44 | torch.nn.init.normal_(self.fake_latent, std=0.02)
45 |
46 | # --------------------------------------------------------------------------
47 | # Generator for the current level
48 | if generator_type_list[fractal_level] == "ar":
49 | generator = AR
50 | elif generator_type_list[fractal_level] == "mar":
51 | generator = MAR
52 | else:
53 | raise NotImplementedError
54 | self.generator = generator(
55 | seq_len=(img_size_list[fractal_level] // img_size_list[fractal_level+1]) ** 2,
56 | patch_size=img_size_list[fractal_level+1],
57 | cond_embed_dim=embed_dim_list[fractal_level-1] if fractal_level > 0 else embed_dim_list[0],
58 | embed_dim=embed_dim_list[fractal_level],
59 | num_blocks=num_blocks_list[fractal_level],
60 | num_heads=num_heads_list[fractal_level],
61 | attn_dropout=attn_dropout,
62 | proj_dropout=proj_dropout,
63 | guiding_pixel=guiding_pixel if fractal_level > 0 else False,
64 | num_conds=num_conds,
65 | grad_checkpointing=grad_checkpointing,
66 | )
67 |
68 | # --------------------------------------------------------------------------
69 | # Build the next fractal level recursively
70 | if self.fractal_level < self.num_fractal_levels - 2:
71 | self.next_fractal = FractalGen(
72 | img_size_list=img_size_list,
73 | embed_dim_list=embed_dim_list,
74 | num_blocks_list=num_blocks_list,
75 | num_heads_list=num_heads_list,
76 | generator_type_list=generator_type_list,
77 | label_drop_prob=label_drop_prob,
78 | class_num=class_num,
79 | attn_dropout=attn_dropout,
80 | proj_dropout=proj_dropout,
81 | guiding_pixel=guiding_pixel,
82 | num_conds=num_conds,
83 | r_weight=r_weight,
84 | grad_checkpointing=grad_checkpointing,
85 | fractal_level=fractal_level+1
86 | )
87 | else:
88 | # The final fractal level uses PixelLoss.
89 | self.next_fractal = PixelLoss(
90 | c_channels=embed_dim_list[fractal_level],
91 | depth=num_blocks_list[fractal_level+1],
92 | width=embed_dim_list[fractal_level+1],
93 | num_heads=num_heads_list[fractal_level+1],
94 | r_weight=r_weight,
95 | )
96 |
97 | def forward(self, imgs, cond_list):
98 | """
99 | Forward pass to get loss recursively.
100 | """
101 | if self.fractal_level == 0:
102 | # Compute class embedding conditions.
103 | class_embedding = self.class_emb(cond_list)
104 | if self.training:
105 | # Randomly drop labels according to label_drop_prob.
106 | drop_latent_mask = (torch.rand(cond_list.size(0)) < self.label_drop_prob).unsqueeze(-1).cuda().to(class_embedding.dtype)
107 | class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
108 | else:
109 | # For evaluation (unconditional NLL), use a constant mask.
110 | drop_latent_mask = torch.ones(cond_list.size(0)).unsqueeze(-1).cuda().to(class_embedding.dtype)
111 | class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
112 |             cond_list = [class_embedding for _ in range(5)]  # replicate the class embedding across condition slots
113 |
114 | # Get image patches and conditions for the next level
115 | imgs, cond_list, guiding_pixel_loss = self.generator(imgs, cond_list)
116 | # Compute loss recursively from the next fractal level.
117 | loss = self.next_fractal(imgs, cond_list)
118 | return loss + guiding_pixel_loss
119 |
120 | def sample(self, cond_list, num_iter_list, cfg, cfg_schedule, temperature, filter_threshold, fractal_level,
121 | visualize=False):
122 | """
123 | Generate samples recursively.
124 | """
125 | if fractal_level < self.num_fractal_levels - 2:
126 | next_level_sample_function = partial(
127 | self.next_fractal.sample,
128 | num_iter_list=num_iter_list,
129 | cfg_schedule="constant",
130 | fractal_level=fractal_level + 1
131 | )
132 | else:
133 | next_level_sample_function = self.next_fractal.sample
134 |
135 |         # Sample with the current-level generator, recursing via next_level_sample_function.
136 | return self.generator.sample(
137 | cond_list, num_iter_list[fractal_level], cfg, cfg_schedule,
138 | temperature, filter_threshold, next_level_sample_function, visualize
139 | )
140 |
141 |
142 | def fractalar_in64(**kwargs):
143 | model = FractalGen(
144 | img_size_list=(64, 4, 1),
145 | embed_dim_list=(1024, 512, 128),
146 | num_blocks_list=(32, 8, 3),
147 | num_heads_list=(16, 8, 4),
148 | generator_type_list=("ar", "ar", "ar"),
149 | fractal_level=0,
150 | **kwargs)
151 | return model
152 |
153 |
154 | def fractalmar_in64(**kwargs):
155 | model = FractalGen(
156 | img_size_list=(64, 4, 1),
157 | embed_dim_list=(1024, 512, 128),
158 | num_blocks_list=(32, 8, 3),
159 | num_heads_list=(16, 8, 4),
160 | generator_type_list=("mar", "mar", "ar"),
161 | fractal_level=0,
162 | **kwargs)
163 | return model
164 |
165 |
166 | def fractalmar_base_in256(**kwargs):
167 | model = FractalGen(
168 | img_size_list=(256, 16, 4, 1),
169 | embed_dim_list=(768, 384, 192, 64),
170 | num_blocks_list=(24, 6, 3, 1),
171 | num_heads_list=(12, 6, 3, 4),
172 | generator_type_list=("mar", "mar", "mar", "ar"),
173 | fractal_level=0,
174 | **kwargs)
175 | return model
176 |
177 |
178 | def fractalmar_large_in256(**kwargs):
179 | model = FractalGen(
180 | img_size_list=(256, 16, 4, 1),
181 | embed_dim_list=(1024, 512, 256, 64),
182 | num_blocks_list=(32, 8, 4, 1),
183 | num_heads_list=(16, 8, 4, 4),
184 | generator_type_list=("mar", "mar", "mar", "ar"),
185 | fractal_level=0,
186 | **kwargs)
187 | return model
188 |
189 |
190 | def fractalmar_huge_in256(**kwargs):
191 | model = FractalGen(
192 | img_size_list=(256, 16, 4, 1),
193 | embed_dim_list=(1280, 640, 320, 64),
194 | num_blocks_list=(40, 10, 5, 1),
195 | num_heads_list=(16, 8, 4, 4),
196 | generator_type_list=("mar", "mar", "mar", "ar"),
197 | fractal_level=0,
198 | **kwargs)
199 | return model
200 |
--------------------------------------------------------------------------------
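Note: `FractalGen` builds itself recursively: every level except the last owns an `AR` or `MAR` generator and instantiates the next level, and the recursion terminates in a `PixelLoss` head. Below is a minimal sketch of walking that structure; the constructor keyword values are placeholders for illustration, not recommended settings.

```
# Sketch: inspect the recursive structure of fractalar_in64.
# The keyword values here are assumptions, not tuned settings.
from models.fractalgen import FractalGen, fractalar_in64

model = fractalar_in64(
    label_drop_prob=0.1, class_num=1000,
    attn_dropout=0.1, proj_dropout=0.1,
    guiding_pixel=False, num_conds=1,
    r_weight=5.0, grad_checkpointing=False,
)

level, depth = model, 0
while isinstance(level, FractalGen):
    print(f"level {depth}: generator = {type(level.generator).__name__}")
    level, depth = level.next_fractal, depth + 1
print(f"level {depth}: {type(level).__name__}")  # PixelLoss at the last level
```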
/models/mar.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import math
4 | import numpy as np
5 | import scipy.stats as stats
6 | import torch
7 | import torch.nn as nn
8 | from torch.utils.checkpoint import checkpoint
9 | from util.visualize import visualize_patch
10 |
11 | from timm.models.vision_transformer import DropPath, Mlp
12 | from models.pixelloss import PixelLoss
13 |
14 |
15 | def mask_by_order(mask_len, order, bsz, seq_len):
16 | masking = torch.zeros(bsz, seq_len).cuda()
17 | masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
18 | return masking
19 |
20 |
21 | class Attention(nn.Module):
22 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
23 | super().__init__()
24 | self.num_heads = num_heads
25 | head_dim = dim // num_heads
26 |         # NOTE: the scale factor was wrong in the original version; it can be set manually to stay compatible with previous weights
27 | self.scale = qk_scale or head_dim ** -0.5
28 |
29 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
30 | self.attn_drop = nn.Dropout(attn_drop)
31 | self.proj = nn.Linear(dim, dim)
32 | self.proj_drop = nn.Dropout(proj_drop)
33 |
34 | def forward(self, x):
35 | B, N, C = x.shape
36 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
37 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
38 |
39 | with torch.cuda.amp.autocast(enabled=False):
40 | attn = (q.float() @ k.float().transpose(-2, -1)) * self.scale
41 |
42 | attn = attn - torch.max(attn, dim=-1, keepdim=True)[0]
43 | attn = attn.softmax(dim=-1)
44 | attn = self.attn_drop(attn)
45 |
46 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
47 | x = self.proj(x)
48 | x = self.proj_drop(x)
49 | return x
50 |
51 |
52 | class Block(nn.Module):
53 |
54 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, proj_drop=0., attn_drop=0.,
55 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
56 | super().__init__()
57 | self.norm1 = norm_layer(dim)
58 | self.attn = Attention(
59 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
60 |         # NOTE: drop path for stochastic depth; unclear whether this is better than dropout here
61 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
62 | self.norm2 = norm_layer(dim)
63 | mlp_hidden_dim = int(dim * mlp_ratio)
64 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
65 |
66 | def forward(self, x):
67 | x = x + self.drop_path(self.attn(self.norm1(x)))
68 | x = x + self.drop_path(self.mlp(self.norm2(x)))
69 | return x
70 |
71 |
72 | class MAR(nn.Module):
73 | def __init__(self, seq_len, patch_size, cond_embed_dim, embed_dim, num_blocks, num_heads, attn_dropout, proj_dropout,
74 | num_conds=1, guiding_pixel=False, grad_checkpointing=False
75 | ):
76 | super().__init__()
77 |
78 | self.seq_len = seq_len
79 | self.patch_size = patch_size
80 |
81 | self.num_conds = num_conds
82 | self.guiding_pixel = guiding_pixel
83 |
84 | self.grad_checkpointing = grad_checkpointing
85 |
86 | # --------------------------------------------------------------------------
87 |         # varying masking ratio: a truncated Gaussian on [0, 1] centered at a 100% masking ratio with std 0.25
88 | self.mask_ratio_generator = stats.truncnorm(-4, 0, loc=1.0, scale=0.25)
89 |
90 | # --------------------------------------------------------------------------
91 | # network
92 | self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
93 | self.patch_emb = nn.Linear(3 * patch_size ** 2, embed_dim, bias=True)
94 | self.patch_emb_ln = nn.LayerNorm(embed_dim, eps=1e-6)
95 | self.cond_emb = nn.Linear(cond_embed_dim, embed_dim, bias=True)
96 | if self.guiding_pixel:
97 | self.pix_proj = nn.Linear(3, embed_dim, bias=True)
98 | self.pos_embed_learned = nn.Parameter(torch.zeros(1, seq_len+num_conds+self.guiding_pixel, embed_dim))
99 |
100 | self.blocks = nn.ModuleList([
101 | Block(embed_dim, num_heads, mlp_ratio=4.,
102 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
103 | proj_drop=proj_dropout, attn_drop=attn_dropout)
104 | for _ in range(num_blocks)
105 | ])
106 | self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
107 |
108 | self.initialize_weights()
109 |
110 | if self.guiding_pixel:
111 | self.guiding_pixel_loss = PixelLoss(
112 | c_channels=cond_embed_dim,
113 | width=128,
114 | depth=3,
115 | num_heads=4,
116 | r_weight=5.0
117 | )
118 |
119 | def initialize_weights(self):
120 | # parameters
121 | torch.nn.init.normal_(self.mask_token, std=.02)
122 | torch.nn.init.normal_(self.pos_embed_learned, std=.02)
123 |
124 | # initialize nn.Linear and nn.LayerNorm
125 | self.apply(self._init_weights)
126 |
127 | def _init_weights(self, m):
128 | if isinstance(m, nn.Linear):
129 |             # we use xavier_uniform following the official JAX ViT
130 | torch.nn.init.xavier_uniform_(m.weight)
131 | if isinstance(m, nn.Linear) and m.bias is not None:
132 | nn.init.constant_(m.bias, 0)
133 | elif isinstance(m, nn.LayerNorm):
134 | if m.bias is not None:
135 | nn.init.constant_(m.bias, 0)
136 | if m.weight is not None:
137 | nn.init.constant_(m.weight, 1.0)
138 |
139 | def patchify(self, x):
140 | bsz, c, h, w = x.shape
141 | p = self.patch_size
142 | h_, w_ = h // p, w // p
143 |
144 | x = x.reshape(bsz, c, h_, p, w_, p)
145 | x = torch.einsum('nchpwq->nhwcpq', x)
146 | x = x.reshape(bsz, h_ * w_, c * p ** 2)
147 | return x # [n, l, d]
148 |
149 | def unpatchify(self, x):
150 | bsz = x.shape[0]
151 | p = self.patch_size
152 | h_, w_ = int(np.sqrt(self.seq_len)), int(np.sqrt(self.seq_len))
153 |
154 | x = x.reshape(bsz, h_, w_, 3, p, p)
155 | x = torch.einsum('nhwcpq->nchpwq', x)
156 | x = x.reshape(bsz, 3, h_ * p, w_ * p)
157 | return x # [n, 3, h, w]
158 |
159 | def sample_orders(self, bsz):
160 | orders = torch.argsort(torch.rand(bsz, self.seq_len).cuda(), dim=1).long()
161 | return orders
162 |
163 | def random_masking_uniform(self, x, orders):
164 | bsz, seq_len, embed_dim = x.shape
165 | num_masked_tokens = np.random.randint(seq_len) + 1
166 | mask = torch.zeros(bsz, seq_len, device=x.device)
167 | mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
168 | src=torch.ones(bsz, seq_len, device=x.device))
169 | return mask
170 |
171 | def random_masking(self, x, orders):
172 | bsz, seq_len, embed_dim = x.shape
173 | mask_rates = self.mask_ratio_generator.rvs(bsz)
174 | num_masked_tokens = torch.Tensor(np.ceil(seq_len * mask_rates)).cuda()
175 | expanded_indices = torch.arange(seq_len, device=x.device).expand(bsz, seq_len)
176 | sorted_orders = torch.argsort(orders, dim=-1)
177 | mask = (expanded_indices < num_masked_tokens[:, None]).float()
178 | mask = torch.scatter(torch.zeros_like(mask), dim=-1, index=sorted_orders, src=mask)
179 |
180 | return mask
181 |
182 | def predict(self, x, mask, cond_list):
183 | x = self.patch_emb(x)
184 |
185 |         # prepend conditions from the previous-level generator
186 | for i in range(self.num_conds):
187 | x = torch.cat([self.cond_emb(cond_list[i]).unsqueeze(1), x], dim=1)
188 |
189 | # prepend guiding pixel
190 | if self.guiding_pixel:
191 | x = torch.cat([self.pix_proj(cond_list[-1]).unsqueeze(1), x], dim=1)
192 |
193 | # masking
194 | mask_with_cond = torch.cat([torch.zeros(x.size(0), self.num_conds+self.guiding_pixel, device=x.device), mask], dim=1).bool()
195 | x = torch.where(mask_with_cond.unsqueeze(-1), self.mask_token.to(x.dtype), x)
196 |
197 | # position embedding
198 | x = x + self.pos_embed_learned
199 | x = self.patch_emb_ln(x)
200 |
201 | # apply Transformer blocks
202 | if self.grad_checkpointing and not torch.jit.is_scripting() and self.training:
203 | for block in self.blocks:
204 | x = checkpoint(block, x)
205 | else:
206 | for block in self.blocks:
207 | x = block(x)
208 | x = self.norm(x)
209 |
210 | # return 5 conditions: middle, top, right, bottom, left
211 | middle_cond = x[:, self.num_conds+self.guiding_pixel:]
212 | bsz, seq_len, c = middle_cond.size()
213 | h = int(np.sqrt(seq_len))
214 | w = int(np.sqrt(seq_len))
215 | top_cond = middle_cond.reshape(bsz, h, w, c)
216 | top_cond = torch.cat([torch.zeros(bsz, 1, w, c, device=top_cond.device), top_cond[:, :-1]], dim=1)
217 | top_cond = top_cond.reshape(bsz, seq_len, c)
218 |
219 | right_cond = middle_cond.reshape(bsz, h, w, c)
220 | right_cond = torch.cat([right_cond[:, :, 1:], torch.zeros(bsz, h, 1, c, device=right_cond.device)], dim=2)
221 | right_cond = right_cond.reshape(bsz, seq_len, c)
222 |
223 | bottom_cond = middle_cond.reshape(bsz, h, w, c)
224 | bottom_cond = torch.cat([bottom_cond[:, 1:], torch.zeros(bsz, 1, w, c, device=bottom_cond.device)], dim=1)
225 | bottom_cond = bottom_cond.reshape(bsz, seq_len, c)
226 |
227 | left_cond = middle_cond.reshape(bsz, h, w, c)
228 | left_cond = torch.cat([torch.zeros(bsz, h, 1, c, device=left_cond.device), left_cond[:, :, :-1]], dim=2)
229 | left_cond = left_cond.reshape(bsz, seq_len, c)
230 |
231 | return [middle_cond, top_cond, right_cond, bottom_cond, left_cond]
232 |
233 | def forward(self, imgs, cond_list):
234 | """ training """
235 | # patchify to get gt
236 | patches = self.patchify(imgs)
237 |
238 | # mask tokens
239 | orders = self.sample_orders(bsz=patches.size(0))
240 | if self.training:
241 | mask = self.random_masking(patches, orders)
242 | else:
243 | # uniform random masking for NLL computation
244 | mask = self.random_masking_uniform(patches, orders)
245 |
246 | # guiding pixel
247 | if self.guiding_pixel:
248 | guiding_pixels = imgs.mean(-1).mean(-1)
249 | guiding_pixel_loss = self.guiding_pixel_loss(guiding_pixels, cond_list)
250 | cond_list.append(guiding_pixels)
251 | else:
252 | guiding_pixel_loss = torch.Tensor([0]).cuda().mean()
253 |
254 | # get condition for next level
255 | cond_list_next = self.predict(patches, mask, cond_list)
256 |
257 |         # keep only the conditions and patches at masked locations
258 | for cond_idx in range(len(cond_list_next)):
259 | cond_list_next[cond_idx] = cond_list_next[cond_idx].reshape(cond_list_next[cond_idx].size(0) * cond_list_next[cond_idx].size(1), -1)
260 | cond_list_next[cond_idx] = cond_list_next[cond_idx][mask.reshape(-1).bool()]
261 |
262 | patches = patches.reshape(patches.size(0) * patches.size(1), -1)
263 | patches = patches[mask.reshape(-1).bool()]
264 | patches = patches.reshape(patches.size(0), 3, self.patch_size, self.patch_size)
265 |
266 | return patches, cond_list_next, guiding_pixel_loss
267 |
268 | def sample(self, cond_list, num_iter, cfg, cfg_schedule, temperature, filter_threshold, next_level_sample_function,
269 | visualize=False):
270 | """ generation """
271 | if cfg == 1.0:
272 | bsz = cond_list[0].size(0)
273 | else:
274 | bsz = cond_list[0].size(0) // 2
275 |
276 | # sample the guiding pixel
277 | if self.guiding_pixel:
278 | sampled_pixels = self.guiding_pixel_loss.sample(cond_list, temperature, cfg, filter_threshold)
279 | if not cfg == 1.0:
280 | sampled_pixels = torch.cat([sampled_pixels, sampled_pixels], dim=0)
281 | cond_list.append(sampled_pixels)
282 |
283 | # init token mask
284 | mask = torch.ones(bsz, self.seq_len).cuda()
285 | patches = torch.zeros(bsz, self.seq_len, 3 * self.patch_size**2).cuda()
286 | orders = self.sample_orders(bsz)
287 | num_iter = min(self.seq_len, num_iter)
288 |
289 | # sample image
290 | for step in range(num_iter):
291 | cur_patches = patches.clone()
292 |
293 | if not cfg == 1.0:
294 | patches = torch.cat([patches, patches], dim=0)
295 | mask = torch.cat([mask, mask], dim=0)
296 |
297 | # get next level conditions
298 | cond_list_next = self.predict(patches, mask, cond_list)
299 |
300 | # mask ratio for the next round, following MAR.
301 | mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
302 | mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
303 |
304 |             # keep at least one token masked for the next iteration
305 | mask_len = torch.maximum(torch.Tensor([1]).cuda(),
306 | torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
307 |
308 | # get masking for next iteration and locations to be predicted in this iteration
309 | mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
310 | if step >= num_iter - 1:
311 | mask_to_pred = mask[:bsz].bool()
312 | else:
313 | mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
314 | mask = mask_next
315 | if not cfg == 1.0:
316 | mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
317 |
318 |             # gather the next-level conditions at the locations to be predicted this step
319 | for cond_idx in range(len(cond_list_next)):
320 | cond_list_next[cond_idx] = cond_list_next[cond_idx][mask_to_pred.nonzero(as_tuple=True)]
321 |
322 | # cfg schedule
323 | if cfg_schedule == "linear":
324 | cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
325 | else:
326 | cfg_iter = cfg
327 | sampled_patches = next_level_sample_function(cond_list=cond_list_next, cfg=cfg_iter,
328 | temperature=temperature, filter_threshold=filter_threshold)
329 | sampled_patches = sampled_patches.reshape(sampled_patches.size(0), -1)
330 |
331 | if not cfg == 1.0:
332 | mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
333 |
334 | cur_patches[mask_to_pred.nonzero(as_tuple=True)] = sampled_patches.to(cur_patches.dtype)
335 | patches = cur_patches.clone()
336 |
337 |             # visualize the generation process (for the Colab demo)
338 | if visualize:
339 | visualize_patch(self.unpatchify(patches))
340 |
341 | patches = self.unpatchify(patches)
342 | return patches
343 |
--------------------------------------------------------------------------------
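The sampling loop in `MAR.sample` above reveals patches under a cosine schedule: early iterations predict few patches and later ones predict many. A standalone sketch of that schedule follows (pure NumPy, no model state; `seq_len` and `num_iter` are example values):

```
import math
import numpy as np

seq_len, num_iter = 256, 64   # example: a 64x64 image as a 16x16 grid of 4x4 patches
num_masked = seq_len          # everything starts masked
for step in range(num_iter):
    mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
    # keep at least one token masked until the final step, mirroring sample()
    mask_len = max(1, min(num_masked - 1, int(np.floor(seq_len * mask_ratio))))
    if step == num_iter - 1:
        mask_len = 0          # the last step predicts every remaining patch
    print(f"step {step:2d}: predict {num_masked - mask_len} patches")
    num_masked = mask_len
```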
/models/pixelloss.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import math
4 | import torch
5 | import torch.nn as nn
6 |
7 | from timm.models.vision_transformer import DropPath, Mlp
8 |
9 |
10 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
11 | L, S = query.size(-2), key.size(-2)
12 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
13 | attn_bias = torch.zeros(L, S, dtype=query.dtype).cuda()
14 | if is_causal:
15 | assert attn_mask is None
16 | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda()
17 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
18 |         attn_bias = attn_bias.to(query.dtype)
19 |
20 | if attn_mask is not None:
21 | if attn_mask.dtype == torch.bool:
22 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
23 | else:
24 | attn_bias += attn_mask
25 | with torch.cuda.amp.autocast(enabled=False):
26 | attn_weight = query @ key.transpose(-2, -1) * scale_factor
27 | attn_weight += attn_bias
28 | attn_weight = torch.softmax(attn_weight, dim=-1)
29 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
30 | return attn_weight @ value
31 |
32 |
33 | class CausalAttention(nn.Module):
34 | def __init__(
35 | self,
36 | dim: int,
37 | num_heads: int = 8,
38 | qkv_bias: bool = False,
39 | qk_norm: bool = False,
40 | attn_drop: float = 0.0,
41 | proj_drop: float = 0.0,
42 | norm_layer: nn.Module = nn.LayerNorm
43 | ) -> None:
44 | super().__init__()
45 | assert dim % num_heads == 0, "dim should be divisible by num_heads"
46 | self.num_heads = num_heads
47 | self.head_dim = dim // num_heads
48 | self.scale = self.head_dim ** -0.5
49 |
50 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
51 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
52 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
53 | self.attn_drop = nn.Dropout(attn_drop)
54 | self.proj = nn.Linear(dim, dim)
55 | self.proj_drop = nn.Dropout(proj_drop)
56 |
57 | def forward(self, x: torch.Tensor) -> torch.Tensor:
58 | B, N, C = x.shape
59 | qkv = (
60 | self.qkv(x)
61 | .reshape(B, N, 3, self.num_heads, self.head_dim)
62 | .permute(2, 0, 3, 1, 4)
63 | )
64 | q, k, v = qkv.unbind(0)
65 | q, k = self.q_norm(q), self.k_norm(k)
66 |
67 | x = scaled_dot_product_attention(
68 | q,
69 | k,
70 | v,
71 | dropout_p=self.attn_drop.p if self.training else 0.0,
72 | is_causal=True
73 | )
74 |
75 | x = x.transpose(1, 2).reshape(B, N, C)
76 | x = self.proj(x)
77 | x = self.proj_drop(x)
78 | return x
79 |
80 |
81 | class CausalBlock(nn.Module):
82 |
83 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, proj_drop=0., attn_drop=0.,
84 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
85 | super().__init__()
86 | self.norm1 = norm_layer(dim)
87 | self.attn = CausalAttention(
88 | dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
89 |         # NOTE: drop path for stochastic depth; unclear whether this is better than dropout here
90 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
91 | self.norm2 = norm_layer(dim)
92 | mlp_hidden_dim = int(dim * mlp_ratio)
93 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=proj_drop)
94 |
95 | def forward(self, x):
96 | x = x + self.drop_path(self.attn(self.norm1(x)))
97 | x = x + self.drop_path(self.mlp(self.norm2(x)))
98 | return x
99 |
100 |
101 | class MlmLayer(nn.Module):
102 |
103 | def __init__(self, vocab_size):
104 | super().__init__()
105 | self.bias = nn.Parameter(torch.zeros(1, vocab_size))
106 |
107 | def forward(self, x, word_embeddings):
108 | word_embeddings = word_embeddings.transpose(0, 1)
109 | logits = torch.matmul(x, word_embeddings)
110 | logits = logits + self.bias
111 | return logits
112 |
113 |
114 | class PixelLoss(nn.Module):
115 | def __init__(self, c_channels, width, depth, num_heads, r_weight=1.0):
116 | super().__init__()
117 |
118 | self.pix_mean = torch.Tensor([0.485, 0.456, 0.406])
119 | self.pix_std = torch.Tensor([0.229, 0.224, 0.225])
120 |
121 | self.cond_proj = nn.Linear(c_channels, width)
122 | self.r_codebook = nn.Embedding(256, width)
123 | self.g_codebook = nn.Embedding(256, width)
124 | self.b_codebook = nn.Embedding(256, width)
125 |
126 | self.ln = nn.LayerNorm(width, eps=1e-6)
127 | self.blocks = nn.ModuleList([
128 | CausalBlock(width, num_heads=num_heads, mlp_ratio=4.0,
129 | qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
130 | proj_drop=0, attn_drop=0)
131 | for _ in range(depth)
132 | ])
133 | self.norm = nn.LayerNorm(width, eps=1e-6)
134 |
135 | self.r_weight = r_weight
136 | self.r_mlm = MlmLayer(256)
137 | self.g_mlm = MlmLayer(256)
138 | self.b_mlm = MlmLayer(256)
139 |
140 | self.criterion = torch.nn.CrossEntropyLoss(reduction="none")
141 |
142 | self.initialize_weights()
143 |
144 | def initialize_weights(self):
145 | # parameters
146 | torch.nn.init.normal_(self.r_codebook.weight, std=.02)
147 | torch.nn.init.normal_(self.g_codebook.weight, std=.02)
148 | torch.nn.init.normal_(self.b_codebook.weight, std=.02)
149 |
150 | # initialize nn.Linear and nn.LayerNorm
151 | self.apply(self._init_weights)
152 |
153 | def _init_weights(self, m):
154 | if isinstance(m, nn.Linear):
155 |             # we use xavier_uniform following the official JAX ViT
156 | torch.nn.init.xavier_uniform_(m.weight)
157 | if isinstance(m, nn.Linear) and m.bias is not None:
158 | nn.init.constant_(m.bias, 0)
159 | elif isinstance(m, nn.LayerNorm):
160 | if m.bias is not None:
161 | nn.init.constant_(m.bias, 0)
162 | if m.weight is not None:
163 | nn.init.constant_(m.weight, 1.0)
164 |
165 | def predict(self, target, cond_list):
166 | target = target.reshape(target.size(0), target.size(1))
167 | # back to [0, 255]
168 | mean = self.pix_mean.cuda().unsqueeze(0)
169 | std = self.pix_std.cuda().unsqueeze(0)
170 | target = target * std + mean
171 |         # add very small noise to avoid pixel-distribution inconsistency caused by banker's rounding
172 | target = (target * 255 + 1e-2 * torch.randn_like(target)).round().long()
173 |
174 | # take only the middle condition
175 | cond = cond_list[0]
176 | x = torch.cat(
177 | [self.cond_proj(cond).unsqueeze(1), self.r_codebook(target[:, 0:1]), self.g_codebook(target[:, 1:2]),
178 | self.b_codebook(target[:, 2:3])], dim=1)
179 | x = self.ln(x)
180 |
181 | for block in self.blocks:
182 | x = block(x)
183 |
184 | x = self.norm(x)
185 | with torch.cuda.amp.autocast(enabled=False):
186 | r_logits = self.r_mlm(x[:, 0], self.r_codebook.weight)
187 | g_logits = self.g_mlm(x[:, 1], self.g_codebook.weight)
188 | b_logits = self.b_mlm(x[:, 2], self.b_codebook.weight)
189 |
190 | logits = torch.cat([r_logits.unsqueeze(1), g_logits.unsqueeze(1), b_logits.unsqueeze(1)], dim=1)
191 | return logits, target
192 |
193 | def forward(self, target, cond_list):
194 | """ training """
195 | logits, target = self.predict(target, cond_list)
196 | loss_r = self.criterion(logits[:, 0], target[:, 0])
197 | loss_g = self.criterion(logits[:, 1], target[:, 1])
198 | loss_b = self.criterion(logits[:, 2], target[:, 2])
199 |
200 | if self.training:
201 | loss = (self.r_weight * loss_r + loss_g + loss_b) / (self.r_weight + 2)
202 | else:
203 | # for NLL computation
204 | loss = (loss_r + loss_g + loss_b) / 3
205 |
206 | return loss.mean()
207 |
208 | def sample(self, cond_list, temperature, cfg, filter_threshold=0):
209 | """ generation """
210 | if cfg == 1.0:
211 | bsz = cond_list[0].size(0)
212 | else:
213 | bsz = cond_list[0].size(0) // 2
214 | pixel_values = torch.zeros(bsz, 3).cuda()
215 |
216 | for i in range(3):
217 | if cfg == 1.0:
218 | logits, _ = self.predict(pixel_values, cond_list)
219 | else:
220 | logits, _ = self.predict(torch.cat([pixel_values, pixel_values], dim=0), cond_list)
221 | logits = logits[:, i]
222 | logits = logits * temperature
223 |
224 | if not cfg == 1.0:
225 | cond_logits = logits[:bsz]
226 | uncond_logits = logits[bsz:]
227 |
228 |                 # suppress tokens whose conditional probability falls below the filter threshold
229 | cond_probs = torch.softmax(cond_logits, dim=-1)
230 | mask = cond_probs < filter_threshold
231 | uncond_logits[mask] = torch.max(
232 | uncond_logits,
233 | cond_logits - torch.max(cond_logits, dim=-1, keepdim=True)[0] + torch.max(uncond_logits, dim=-1, keepdim=True)[0]
234 | )[mask]
235 |
236 | logits = uncond_logits + cfg * (cond_logits - uncond_logits)
237 |
238 | # get token prediction
239 | probs = torch.softmax(logits, dim=-1)
240 | sampled_ids = torch.multinomial(probs, num_samples=1).reshape(-1)
241 | pixel_values[:, i] = (sampled_ids.float() / 255 - self.pix_mean[i]) / self.pix_std[i]
242 |
243 |         # return pixel values in normalized (mean/std) space
244 | return pixel_values
245 |
--------------------------------------------------------------------------------
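`PixelLoss.sample` applies classifier-free guidance per color channel: `logits = uncond_logits + cfg * (cond_logits - uncond_logits)`. A tiny numeric sketch of how the guidance scale sharpens the distribution (the logit values are made up):

```
import torch

cond_logits = torch.tensor([[4.0, 2.0, 0.0]])    # made-up conditional logits
uncond_logits = torch.tensor([[2.0, 2.0, 2.0]])  # made-up unconditional logits

for cfg in (1.0, 2.0, 4.0):
    logits = uncond_logits + cfg * (cond_logits - uncond_logits)
    probs = torch.softmax(logits, dim=-1).squeeze()
    print(cfg, [round(p, 3) for p in probs.tolist()])
# larger cfg pushes probability mass toward the conditionally preferred token
```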
/util/crop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 |
5 | def center_crop_arr(pil_image, image_size):
6 | """
7 | Center cropping implementation from ADM.
8 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
9 | """
10 | while min(*pil_image.size) >= 2 * image_size:
11 | pil_image = pil_image.resize(
12 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
13 | )
14 |
15 | scale = image_size / min(*pil_image.size)
16 | pil_image = pil_image.resize(
17 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
18 | )
19 |
20 | arr = np.array(pil_image)
21 | crop_y = (arr.shape[0] - image_size) // 2
22 | crop_x = (arr.shape[1] - image_size) // 2
23 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
24 |
--------------------------------------------------------------------------------
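A usage sketch for `center_crop_arr`; wiring it into a torchvision transform is an assumption about intended use, and `example.jpg` is a hypothetical input:

```
from PIL import Image
from torchvision import transforms
from util.crop import center_crop_arr

transform = transforms.Compose([
    transforms.Lambda(lambda img: center_crop_arr(img, 256)),
    transforms.ToTensor(),
])
x = transform(Image.open("example.jpg").convert("RGB"))  # hypothetical file
print(x.shape)  # torch.Size([3, 256, 256])
```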
/util/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import requests
4 |
5 |
6 | def download_pretrained_fractalar_in64(overwrite=False):
7 | download_path = "pretrained_models/fractalar_in64/checkpoint-last.pth"
8 | if not os.path.exists(download_path) or overwrite:
9 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
10 | os.makedirs("pretrained_models/fractalar_in64", exist_ok=True)
11 | r = requests.get("https://www.dropbox.com/scl/fi/n25tbij7aqkwo1ypqhz72/checkpoint-last.pth?rlkey=2czevgex3ocg2ae8zde3xpb3f&st=mj0subup&dl=0", stream=True, headers=headers)
12 | print("Downloading FractalAR on ImageNet 64x64...")
13 | with open(download_path, 'wb') as f:
14 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1688):
15 | if chunk:
16 | f.write(chunk)
17 |
18 |
19 | def download_pretrained_fractalmar_in64(overwrite=False):
20 | download_path = "pretrained_models/fractalmar_in64/checkpoint-last.pth"
21 | if not os.path.exists(download_path) or overwrite:
22 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
23 | os.makedirs("pretrained_models/fractalmar_in64", exist_ok=True)
24 | r = requests.get("https://www.dropbox.com/scl/fi/lh7fmv48pusujd6m4kcdn/checkpoint-last.pth?rlkey=huihey61ok32h28o3tbbq6ek9&st=fxtoawba&dl=0", stream=True, headers=headers)
25 | print("Downloading FractalMAR on ImageNet 64x64...")
26 | with open(download_path, 'wb') as f:
27 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1650):
28 | if chunk:
29 | f.write(chunk)
30 |
31 |
32 | def download_pretrained_fractalmar_base_in256(overwrite=False):
33 | download_path = "pretrained_models/fractalmar_base_in256/checkpoint-last.pth"
34 | if not os.path.exists(download_path) or overwrite:
35 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
36 | os.makedirs("pretrained_models/fractalmar_base_in256", exist_ok=True)
37 | r = requests.get("https://www.dropbox.com/scl/fi/zrdm7853ih4tcv98wmzhe/checkpoint-last.pth?rlkey=htq9yuzovet7d6ioa64s1xxd0&st=4c4d93vs&dl=0", stream=True, headers=headers)
38 | print("Downloading FractalMAR-Base on ImageNet 256x256...")
39 | with open(download_path, 'wb') as f:
40 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=712):
41 | if chunk:
42 | f.write(chunk)
43 |
44 |
45 | def download_pretrained_fractalmar_large_in256(overwrite=False):
46 | download_path = "pretrained_models/fractalmar_large_in256/checkpoint-last.pth"
47 | if not os.path.exists(download_path) or overwrite:
48 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
49 | os.makedirs("pretrained_models/fractalmar_large_in256", exist_ok=True)
50 | r = requests.get("https://www.dropbox.com/scl/fi/y1k05xx7ry8521ckxkqgt/checkpoint-last.pth?rlkey=wolq4krdq7z7eyjnaw5ndhq6k&st=vjeu5uzo&dl=0", stream=True, headers=headers)
51 | print("Downloading FractalMAR-Large on ImageNet 256x256...")
52 | with open(download_path, 'wb') as f:
53 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=1669):
54 | if chunk:
55 | f.write(chunk)
56 |
57 |
58 | def download_pretrained_fractalmar_huge_in256(overwrite=False):
59 | download_path = "pretrained_models/fractalmar_huge_in256/checkpoint-last.pth"
60 | if not os.path.exists(download_path) or overwrite:
61 | headers = {'user-agent': 'Wget/1.16 (linux-gnu)'}
62 | os.makedirs("pretrained_models/fractalmar_huge_in256", exist_ok=True)
63 | r = requests.get("https://www.dropbox.com/scl/fi/t2rru8xr6wm23yvxskpww/checkpoint-last.pth?rlkey=dn9ss9zw4zsnckf6bat9hss6h&st=y7w921zo&dl=0", stream=True, headers=headers)
64 | print("Downloading FractalMAR-Huge on ImageNet 256x256...")
65 | with open(download_path, 'wb') as f:
66 | for chunk in tqdm(r.iter_content(chunk_size=1024*1024), unit="MB", total=3243):
67 | if chunk:
68 | f.write(chunk)
69 |
70 |
71 | if __name__ == "__main__":
72 | download_pretrained_fractalar_in64()
73 | download_pretrained_fractalmar_in64()
74 | download_pretrained_fractalmar_base_in256()
75 | download_pretrained_fractalmar_large_in256()
76 | download_pretrained_fractalmar_huge_in256()
77 |
--------------------------------------------------------------------------------
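Each helper can also be called on its own to fetch a single checkpoint:

```
from util.download import download_pretrained_fractalmar_in64

# skips the download if pretrained_models/fractalmar_in64/checkpoint-last.pth already exists
download_pretrained_fractalmar_in64(overwrite=False)
```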
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | def adjust_learning_rate(optimizer, epoch, args):
5 | """Decay the learning rate with half-cycle cosine after warmup"""
6 | if epoch < args.warmup_epochs:
7 | lr = args.lr * epoch / args.warmup_epochs
8 | else:
9 | if args.lr_schedule == "constant":
10 | lr = args.lr
11 | elif args.lr_schedule == "cosine":
12 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
13 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
14 | else:
15 | raise NotImplementedError
16 | for param_group in optimizer.param_groups:
17 | if "lr_scale" in param_group:
18 | param_group["lr"] = lr * param_group["lr_scale"]
19 | else:
20 | param_group["lr"] = lr
21 | return lr
22 |
--------------------------------------------------------------------------------
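A sketch of the warmup-plus-cosine schedule produced by `adjust_learning_rate`; the `args` fields mirror what the function reads, and the concrete values are assumptions:

```
from types import SimpleNamespace

import torch

from util.lr_sched import adjust_learning_rate

args = SimpleNamespace(lr=1e-4, min_lr=0.0, warmup_epochs=10,
                       epochs=100, lr_schedule="cosine")  # assumed values
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in (0, 5, 10, 55, 100):
    print(epoch, adjust_learning_rate(opt, epoch, args))
# linear warmup to args.lr over 10 epochs, then half-cycle cosine decay to min_lr
```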
/util/misc.py:
--------------------------------------------------------------------------------
1 | import builtins
2 | import datetime
3 | import os
4 | import time
5 | from collections import defaultdict, deque
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.distributed as dist
10 | TORCH_MAJOR = int(torch.__version__.split('.')[0])
11 | TORCH_MINOR = int(torch.__version__.split('.')[1])
12 |
13 | if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
14 | from torch._six import inf
15 | else:
16 | from torch import inf
17 |
18 |
19 | class SmoothedValue(object):
20 | """Track a series of values and provide access to smoothed values over a
21 | window or the global series average.
22 | """
23 |
24 | def __init__(self, window_size=20, fmt=None):
25 | if fmt is None:
26 | fmt = "{median:.4f} ({global_avg:.4f})"
27 | self.deque = deque(maxlen=window_size)
28 | self.total = 0.0
29 | self.count = 0
30 | self.fmt = fmt
31 |
32 | def update(self, value, n=1):
33 | self.deque.append(value)
34 | self.count += n
35 | self.total += value * n
36 |
37 | def synchronize_between_processes(self):
38 | """
39 | Warning: does not synchronize the deque!
40 | """
41 | if not is_dist_avail_and_initialized():
42 | return
43 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
44 | dist.barrier()
45 | dist.all_reduce(t)
46 | t = t.tolist()
47 | self.count = int(t[0])
48 | self.total = t[1]
49 |
50 | @property
51 | def median(self):
52 | d = torch.tensor(list(self.deque))
53 | return d.median().item()
54 |
55 | @property
56 | def avg(self):
57 | d = torch.tensor(list(self.deque), dtype=torch.float32)
58 | return d.mean().item()
59 |
60 | @property
61 | def global_avg(self):
62 | return self.total / self.count
63 |
64 | @property
65 | def max(self):
66 | return max(self.deque)
67 |
68 | @property
69 | def value(self):
70 | return self.deque[-1]
71 |
72 | def __str__(self):
73 | return self.fmt.format(
74 | median=self.median,
75 | avg=self.avg,
76 | global_avg=self.global_avg,
77 | max=self.max,
78 | value=self.value)
79 |
80 |
81 | class MetricLogger(object):
82 | def __init__(self, delimiter="\t"):
83 | self.meters = defaultdict(SmoothedValue)
84 | self.delimiter = delimiter
85 |
86 | def update(self, **kwargs):
87 | for k, v in kwargs.items():
88 | if v is None:
89 | continue
90 | if isinstance(v, torch.Tensor):
91 | v = v.item()
92 | assert isinstance(v, (float, int))
93 | self.meters[k].update(v)
94 |
95 | def __getattr__(self, attr):
96 | if attr in self.meters:
97 | return self.meters[attr]
98 | if attr in self.__dict__:
99 | return self.__dict__[attr]
100 | raise AttributeError("'{}' object has no attribute '{}'".format(
101 | type(self).__name__, attr))
102 |
103 | def __str__(self):
104 | loss_str = []
105 | for name, meter in self.meters.items():
106 | loss_str.append(
107 | "{}: {}".format(name, str(meter))
108 | )
109 | return self.delimiter.join(loss_str)
110 |
111 | def synchronize_between_processes(self):
112 | for meter in self.meters.values():
113 | meter.synchronize_between_processes()
114 |
115 | def add_meter(self, name, meter):
116 | self.meters[name] = meter
117 |
118 | def log_every(self, iterable, print_freq, header=None):
119 | i = 0
120 | if not header:
121 | header = ''
122 | start_time = time.time()
123 | end = time.time()
124 | iter_time = SmoothedValue(fmt='{avg:.4f}')
125 | data_time = SmoothedValue(fmt='{avg:.4f}')
126 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
127 | log_msg = [
128 | header,
129 | '[{0' + space_fmt + '}/{1}]',
130 | 'eta: {eta}',
131 | '{meters}',
132 | 'time: {time}',
133 | 'data: {data}'
134 | ]
135 | if torch.cuda.is_available():
136 | log_msg.append('max mem: {memory:.0f}')
137 | log_msg = self.delimiter.join(log_msg)
138 | MB = 1024.0 * 1024.0
139 | for obj in iterable:
140 | data_time.update(time.time() - end)
141 | yield obj
142 | iter_time.update(time.time() - end)
143 | if i % print_freq == 0 or i == len(iterable) - 1:
144 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
145 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
146 | if torch.cuda.is_available():
147 | print(log_msg.format(
148 | i, len(iterable), eta=eta_string,
149 | meters=str(self),
150 | time=str(iter_time), data=str(data_time),
151 | memory=torch.cuda.max_memory_allocated() / MB))
152 | else:
153 | print(log_msg.format(
154 | i, len(iterable), eta=eta_string,
155 | meters=str(self),
156 | time=str(iter_time), data=str(data_time)))
157 | i += 1
158 | end = time.time()
159 | total_time = time.time() - start_time
160 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
161 | print('{} Total time: {} ({:.4f} s / it)'.format(
162 | header, total_time_str, total_time / len(iterable)))
163 |
164 |
165 | def setup_for_distributed(is_master):
166 | """
167 |     This function disables printing when not in the master process.
168 | """
169 | builtin_print = builtins.print
170 |
171 | def print(*args, **kwargs):
172 | force = kwargs.pop('force', False)
173 | force = force or (get_world_size() > 8)
174 | if is_master or force:
175 | now = datetime.datetime.now().time()
176 | builtin_print('[{}] '.format(now), end='') # print with time stamp
177 | builtin_print(*args, **kwargs)
178 |
179 | builtins.print = print
180 |
181 |
182 | def is_dist_avail_and_initialized():
183 | if not dist.is_available():
184 | return False
185 | if not dist.is_initialized():
186 | return False
187 | return True
188 |
189 |
190 | def get_world_size():
191 | if not is_dist_avail_and_initialized():
192 | return 1
193 | return dist.get_world_size()
194 |
195 |
196 | def get_rank():
197 | if not is_dist_avail_and_initialized():
198 | return 0
199 | return dist.get_rank()
200 |
201 |
202 | def is_main_process():
203 | return get_rank() == 0
204 |
205 |
206 | def save_on_master(*args, **kwargs):
207 | if is_main_process():
208 | torch.save(*args, **kwargs)
209 |
210 |
211 | def init_distributed_mode(args):
212 | if args.dist_on_itp:
213 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
214 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
215 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
216 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
217 | os.environ['LOCAL_RANK'] = str(args.gpu)
218 | os.environ['RANK'] = str(args.rank)
219 | os.environ['WORLD_SIZE'] = str(args.world_size)
220 |     # launched via torchrun / torch.distributed.run, which sets RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT, LOCAL_RANK
221 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
222 | args.rank = int(os.environ["RANK"])
223 | args.world_size = int(os.environ['WORLD_SIZE'])
224 | args.gpu = int(os.environ['LOCAL_RANK'])
225 | elif 'SLURM_PROCID' in os.environ:
226 | args.rank = int(os.environ['SLURM_PROCID'])
227 | args.gpu = args.rank % torch.cuda.device_count()
228 | else:
229 | print('Not using distributed mode')
230 | setup_for_distributed(is_master=True) # hack
231 | args.distributed = False
232 | return
233 |
234 | args.distributed = True
235 |
236 | torch.cuda.set_device(args.gpu)
237 | args.dist_backend = 'nccl'
238 | print('| distributed init (rank {}): {}, gpu {}'.format(
239 | args.rank, args.dist_url, args.gpu), flush=True)
240 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
241 | world_size=args.world_size, rank=args.rank)
242 | torch.distributed.barrier()
243 | setup_for_distributed(args.rank == 0)
244 |
245 |
246 | class NativeScalerWithGradNormCount:
247 | state_dict_key = "amp_scaler"
248 |
249 | def __init__(self):
250 | self._scaler = torch.cuda.amp.GradScaler()
251 |
252 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
253 | self._scaler.scale(loss).backward(create_graph=create_graph)
254 | if update_grad:
255 | if clip_grad is not None:
256 | assert parameters is not None
257 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
258 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
259 | else:
260 | self._scaler.unscale_(optimizer)
261 | norm = get_grad_norm_(parameters)
262 | self._scaler.step(optimizer)
263 | self._scaler.update()
264 | else:
265 | norm = None
266 | return norm
267 |
268 | def state_dict(self):
269 | return self._scaler.state_dict()
270 |
271 | def load_state_dict(self, state_dict):
272 | self._scaler.load_state_dict(state_dict)
273 |
274 |
275 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
276 | if isinstance(parameters, torch.Tensor):
277 | parameters = [parameters]
278 | parameters = [p for p in parameters if p.grad is not None]
279 | norm_type = float(norm_type)
280 | if len(parameters) == 0:
281 | return torch.tensor(0.)
282 | device = parameters[0].grad.device
283 | if norm_type == inf:
284 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
285 | else:
286 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
287 | return total_norm
288 |
289 |
290 | def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
291 | decay = []
292 | no_decay = []
293 | for name, param in model.named_parameters():
294 | if not param.requires_grad:
295 | continue # frozen weights
296 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
297 | no_decay.append(param) # no weight decay on bias, norm and diffloss
298 | else:
299 | decay.append(param)
300 | return [
301 | {'params': no_decay, 'weight_decay': 0.},
302 | {'params': decay, 'weight_decay': weight_decay}]
303 |
304 |
305 | def save_model(args, epoch, model_without_ddp, optimizer, loss_scaler, epoch_name=None):
306 | if epoch_name is None:
307 | epoch_name = str(epoch)
308 | output_dir = Path(args.output_dir)
309 | checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)
310 |
311 | to_save = {
312 | 'model': model_without_ddp.state_dict(),
313 | 'optimizer': optimizer.state_dict(),
314 | 'epoch': epoch,
315 | 'scaler': loss_scaler.state_dict(),
316 | 'args': args,
317 | }
318 | save_on_master(to_save, checkpoint_path)
319 |
320 |
321 | def all_reduce_mean(x):
322 | world_size = get_world_size()
323 | if world_size > 1:
324 | x_reduce = torch.tensor(x).cuda()
325 | dist.all_reduce(x_reduce)
326 | x_reduce /= world_size
327 | return x_reduce.item()
328 | else:
329 | return x
--------------------------------------------------------------------------------
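`NativeScalerWithGradNormCount` wraps `torch.cuda.amp.GradScaler` and returns the gradient norm whenever it steps the optimizer. A minimal training-step sketch (assumes a CUDA device; the toy model and loss are placeholders):

```
import torch

from util.misc import NativeScalerWithGradNormCount, add_weight_decay

model = torch.nn.Linear(8, 8).cuda()  # toy model
optimizer = torch.optim.AdamW(add_weight_decay(model, weight_decay=0.05), lr=1e-4)
loss_scaler = NativeScalerWithGradNormCount()

x = torch.randn(4, 8, device="cuda")
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = model(x).pow(2).mean()     # placeholder loss
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                        parameters=model.parameters(), update_grad=True)
print(float(loss), float(grad_norm))
```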
/util/visualize.py:
--------------------------------------------------------------------------------
1 | from torchvision.utils import save_image
2 | from PIL import Image
3 | import torch
4 |
5 |
6 | def visualize_patch(viz_patches):
7 | from IPython.display import display, clear_output
8 | pix_mean = torch.Tensor([0.485, 0.456, 0.406]).cuda().view(1, -1, 1, 1)
9 | pix_std = torch.Tensor([0.229, 0.224, 0.225]).cuda().view(1, -1, 1, 1)
10 | viz_patches = viz_patches * pix_std + pix_mean
11 | img_size = viz_patches.size(2)
12 | if img_size < 256:
13 | viz_patches = torch.nn.functional.interpolate(viz_patches, scale_factor=256 // img_size, mode="nearest")
14 | save_image(viz_patches, "samples.png", nrow=4, normalize=True, value_range=(0, 1))
15 | sampled_patches_viz = Image.open("samples.png")
16 | clear_output(wait=True)
17 | display(sampled_patches_viz)
18 |
--------------------------------------------------------------------------------
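`visualize_patch` is meant for notebook use; a sketch with random noise standing in for generated patches (assumes a CUDA device and an IPython environment):

```
import torch

from util.visualize import visualize_patch

# random noise in normalized (mean/std) space stands in for model output
fake_patches = torch.randn(16, 3, 64, 64).cuda()
visualize_patch(fake_patches)  # writes samples.png and displays it inline
```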