├── SphereAR
│   ├── __init__.py
│   ├── gan
│   │   ├── __init__.py
│   │   ├── discriminator_stylegan.py
│   │   ├── discriminator_patchgan.py
│   │   ├── gan_loss.py
│   │   └── lpips.py
│   ├── utils.py
│   ├── psd.py
│   ├── sampling.py
│   ├── dataset.py
│   ├── diff_head.py
│   ├── vae.py
│   └── model.py
├── figures
│   ├── grid.jpg
│   ├── overview.png
│   └── fid_vs_params.png
├── .gitignore
├── README.md
├── sample_ddp.py
├── train.py
└── evaluator.py
/SphereAR/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/SphereAR/gan/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/figures/grid.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guolinke/SphereAR/HEAD/figures/grid.jpg
--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guolinke/SphereAR/HEAD/figures/overview.png
--------------------------------------------------------------------------------
/figures/fid_vs_params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guolinke/SphereAR/HEAD/figures/fid_vs_params.png
--------------------------------------------------------------------------------
/SphereAR/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import torch
4 | import torch.distributed as dist
5 | from torch.nn import functional as F
6 | 
7 | 
8 | def create_logger(logging_dir):
9 |     """
10 |     Create a logger that writes to a log file and stdout.
11 |     """
12 |     if dist.get_rank() == 0:  # real logger
13 |         logging.basicConfig(
14 |             level=logging.INFO,
15 |             format="[\033[34m%(asctime)s\033[0m] %(message)s",
16 |             datefmt="%Y-%m-%d %H:%M:%S",
17 |             handlers=[
18 |                 logging.StreamHandler(),
19 |                 logging.FileHandler(f"{logging_dir}/log.txt"),
20 |             ],
21 |         )
22 |         logger = logging.getLogger(__name__)
23 |     else:  # dummy logger (does nothing)
24 |         logger = logging.getLogger(__name__)
25 |         logger.addHandler(logging.NullHandler())
26 |     return logger
27 | 
28 | 
29 | @torch.no_grad()
30 | def update_ema(ema_model, model, decay=0.9999):
31 |     """
32 |     Step the EMA model towards the current model.
33 |     """
34 |     ema_ps = []
35 |     ps = []
36 | 
37 |     for e, m in zip(ema_model.parameters(), model.parameters()):
38 |         if m.requires_grad:
39 |             ema_ps.append(e)
40 |             ps.append(m)
41 |     torch._foreach_lerp_(ema_ps, ps, 1.0 - decay)
42 | 
43 | 
44 | @torch.no_grad()
45 | def sync_frozen_params_once(ema_model, model):
46 |     for e, m in zip(ema_model.parameters(), model.parameters()):
47 |         if not m.requires_grad:
48 |             e.copy_(m)
49 | 
50 | 
51 | def requires_grad(model, flag=True):
52 |     """
53 |     Set requires_grad flag for all parameters in a model.
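    Typically used to freeze a module wholesale (e.g., an EMA copy maintained by update_ema above).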
54 | """ 55 | for p in model.parameters(): 56 | p.requires_grad = flag 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.tfevents.* 3 | # JetBrains PyCharm IDE 4 | .idea/ 5 | 6 | # Compress files 7 | .gz 8 | .tar.gz 9 | .zip 10 | .data 11 | .7z 12 | .7zip 13 | .sdf 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # macOS dir files 24 | .DS_Store 25 | 26 | # Distribution / packaging 27 | .Python 28 | env/ 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | *.egg-info/ 42 | .installed.args 43 | *.egg 44 | 45 | # Checkpoints 46 | checkpoints 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | .hypothesis/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | 77 | # Flask stuff: 78 | instance/ 79 | .webassets-cache 80 | 81 | # Scrapy stuff: 82 | .scrapy 83 | 84 | # Sphinx documentation 85 | docs/_build/ 86 | 87 | # PyBuilder 88 | target/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # dotenv 103 | .env 104 | 105 | # virtualenv 106 | .venv 107 | venv/ 108 | ENV/ 109 | 110 | # Spyder project settings 111 | .spyderproject 112 | .spyproject 113 | 114 | # Rope project settings 115 | .ropeproject 116 | 117 | # mypy 118 | .mypy_cache/ 119 | 120 | # VSCODE 121 | .vscode/ftp-sync.json 122 | .vscode/settings.json 123 | 124 | # too big to git 125 | *.lmdb 126 | *.sto 127 | *.pt 128 | *.pkl 129 | 130 | # pytest 131 | .pytest_cache 132 | test/.pytest_cache 133 | /local* 134 | /_* 135 | weights 136 | wandb 137 | weights* 138 | ft_weights* 139 | finetune_scripts/*.dict 140 | *.pdf -------------------------------------------------------------------------------- /SphereAR/psd.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.distributions import Beta 5 | 6 | 7 | def l2_norm(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: 8 | return x / torch.clamp(x.norm(dim=-1, keepdim=True), min=eps) 9 | 10 | 11 | class PowerSphericalDistribution: 12 | def __init__(self, mu: torch.Tensor, kappa: torch.Tensor, eps: float = 1e-7): 13 | self.eps = eps 14 | self.mu = l2_norm(mu, eps) # [..., m] 15 | self.kappa = torch.clamp(kappa, min=0.0) 16 | 17 | self.m = self.mu.shape[-1] 18 | self.d = self.m - 1 19 | beta_const = 0.5 * self.d 20 | self.alpha = self.kappa + beta_const # [...,] 21 | self.beta = torch.as_tensor( 22 | beta_const, dtype=self.kappa.dtype, device=self.kappa.device 23 | ).expand_as(self.kappa) 24 | 25 | def _log_normalizer(self) -> torch.Tensor: 26 | # log N_X(κ,d) = -[ (α+β)log 2 + β log π + lgamma(α) - lgamma(α+β) ] 27 | return ( 28 | -(self.alpha + self.beta) * 
math.log(2.0) 29 | - self.beta * math.log(math.pi) 30 | - torch.lgamma(self.alpha) 31 | + torch.lgamma(self.alpha + self.beta) 32 | ) 33 | 34 | def log_prob(self, x: torch.Tensor) -> torch.Tensor: 35 | dot = (self.mu * x).sum(dim=-1).clamp(-1.0, 1.0) 36 | return self._log_normalizer() + self.kappa * torch.log1p(dot) 37 | 38 | def entropy(self) -> torch.Tensor: 39 | # H = -[ log N_X + κ ( log 2 + ψ(α) - ψ(α+β) ) ] 40 | return -( 41 | self._log_normalizer() 42 | + self.kappa 43 | * ( 44 | math.log(2.0) 45 | + (torch.digamma(self.alpha) - torch.digamma(self.alpha + self.beta)) 46 | ) 47 | ) 48 | 49 | def kl_to_uniform(self) -> torch.Tensor: 50 | # KL(q || U(S^{d})) = -H(q) + log |S^{d}| 51 | d = torch.as_tensor(self.d, dtype=self.kappa.dtype, device=self.kappa.device) 52 | log_area = ( 53 | math.log(2.0) 54 | + 0.5 * (d + 1.0) * math.log(math.pi) 55 | - torch.lgamma(0.5 * (d + 1.0)) 56 | ) 57 | return -self.entropy() + log_area 58 | 59 | def rsample(self): 60 | Z = Beta(self.alpha, self.beta).rsample() # [*S, *B] 61 | t = (2.0 * Z - 1.0).unsqueeze(-1) # [*S, *B, 1] 62 | 63 | # 2) v ~ U(S^{m-2}) 64 | v = torch.randn( 65 | *self.mu.shape[:-1], 66 | self.m - 1, 67 | device=self.mu.device, 68 | dtype=self.mu.dtype, 69 | ) # [*S, *B, m-1] 70 | v = l2_norm(v, self.eps) 71 | 72 | y = torch.cat( 73 | [t, torch.sqrt(torch.clamp(1 - t**2, min=0.0)) * v], dim=-1 74 | ) # [*S, *B, m] 75 | 76 | e1 = torch.zeros_like(self.mu) 77 | e1[..., 0] = 1.0 78 | u = l2_norm(e1 - self.mu, self.eps) 79 | if u.dim() < y.dim(): 80 | u = u.view((1,) * (y.dim() - u.dim()) + u.shape) 81 | z = y - 2.0 * (y * u).sum(dim=-1, keepdim=True) * u 82 | 83 | parallel = (self.mu - e1).abs().sum(dim=-1, keepdim=True) < 1e-6 84 | if parallel.any(): 85 | p = parallel 86 | if p.dim() < y.dim() - 1: 87 | p = p.view((1,) * (y.dim() - 1 - p.dim()) + p.shape) 88 | z = torch.where(p, y, z) 89 | return z 90 | -------------------------------------------------------------------------------- /SphereAR/sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_score_from_velocity(velocity, x, t): 5 | alpha_t, d_alpha_t = t, 1 6 | sigma_t, d_sigma_t = 1 - t, -1 7 | mean = x 8 | reverse_alpha_ratio = alpha_t / d_alpha_t 9 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t 10 | score = (reverse_alpha_ratio * velocity - mean) / var 11 | return score 12 | 13 | 14 | def get_velocity_from_cfg(velocity, cfg, cfg_mult): 15 | if cfg_mult == 2: 16 | cond_v, uncond_v = torch.chunk(velocity, 2, dim=0) 17 | velocity = uncond_v + cfg * (cond_v - uncond_v) 18 | return velocity 19 | 20 | 21 | @torch.compile() 22 | def euler_step(x, v, dt: float, cfg: float, cfg_mult: int): 23 | with torch.amp.autocast("cuda", enabled=False): 24 | v = v.to(torch.float32) 25 | v = get_velocity_from_cfg(v, cfg, cfg_mult) 26 | x = x + v * dt 27 | return x 28 | 29 | 30 | @torch.compile() 31 | def euler_maruyama_step(x, v, t, dt: float, cfg: float, cfg_mult: int): 32 | with torch.amp.autocast("cuda", enabled=False): 33 | v = v.to(torch.float32) 34 | v = get_velocity_from_cfg(v, cfg, cfg_mult) 35 | score = get_score_from_velocity(v, x, t) 36 | drift = v + (1 - t) * score 37 | noise_scale = (2.0 * (1.0 - t) * dt) ** 0.5 38 | x = x + drift * dt + noise_scale * torch.randn_like(x) 39 | return x 40 | 41 | 42 | def euler_maruyama( 43 | input_dim, 44 | forward_fn, 45 | c: torch.Tensor, 46 | cfg: float = 1.0, 47 | num_sampling_steps: int = 20, 48 | last_step_size: float = 0.04, 49 | ): 50 | cfg_mult = 1 51 | if 
cfg > 1.0: 52 | cfg_mult += 1 53 | 54 | x_shape = list(c.shape) 55 | x_shape[0] = x_shape[0] // cfg_mult 56 | x_shape[-1] = input_dim 57 | x = torch.randn(x_shape, device=c.device) 58 | dt = (1.0 - last_step_size) / num_sampling_steps 59 | t = torch.tensor( 60 | 0.0, device=c.device, dtype=torch.float32 61 | ) # use tensor to avoid compile warning 62 | t_batch = torch.zeros(c.shape[0], device=c.device) 63 | for _ in range(num_sampling_steps): 64 | t_batch[:] = t 65 | combined = torch.cat([x] * cfg_mult, dim=0) 66 | v = forward_fn( 67 | combined, 68 | t_batch, 69 | c, 70 | ) 71 | x = euler_maruyama_step(x, v, t, dt, cfg, cfg_mult) 72 | t += dt 73 | 74 | combined = torch.cat([x] * cfg_mult, dim=0) 75 | t_batch[:] = 1 - last_step_size 76 | v = forward_fn( 77 | combined, 78 | t_batch, 79 | c, 80 | ) 81 | x = euler_step(x, v, last_step_size, cfg, cfg_mult) 82 | 83 | return torch.cat([x] * cfg_mult, dim=0) 84 | 85 | 86 | def euler( 87 | input_dim, 88 | forward_fn, 89 | c, 90 | cfg: float = 1.0, 91 | num_sampling_steps: int = 50, 92 | ): 93 | cfg_mult = 1 94 | if cfg > 1.0: 95 | cfg_mult = 2 96 | 97 | x_shape = list(c.shape) 98 | x_shape[0] = x_shape[0] // cfg_mult 99 | x_shape[-1] = input_dim 100 | x = torch.randn(x_shape, device=c.device) 101 | dt = 1.0 / num_sampling_steps 102 | t = 0 103 | t_batch = torch.zeros(c.shape[0], device=c.device) 104 | for _ in range(num_sampling_steps): 105 | t_batch[:] = t 106 | combined = torch.cat([x] * cfg_mult, dim=0) 107 | v = forward_fn(combined, t_batch, c) 108 | x = euler_step(x, v, dt, cfg, cfg_mult) 109 | t += dt 110 | 111 | return torch.cat([x] * cfg_mult, dim=0) 112 | -------------------------------------------------------------------------------- /SphereAR/gan/discriminator_stylegan.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py 3 | # stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py 4 | # maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | 9 | try: 10 | from kornia.filters import filter2d 11 | except: 12 | pass 13 | 14 | 15 | class Discriminator(nn.Module): 16 | def __init__( 17 | self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256 18 | ): 19 | super().__init__() 20 | channels = { 21 | 4: 512, 22 | 8: 512, 23 | 16: 512, 24 | 32: 512, 25 | 64: 256 * channel_multiplier, 26 | 128: 128 * channel_multiplier, 27 | 256: 64 * channel_multiplier, 28 | 512: 32 * channel_multiplier, 29 | 1024: 16 * channel_multiplier, 30 | } 31 | 32 | log_size = int(math.log(image_size, 2)) 33 | in_channel = channels[image_size] 34 | 35 | blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()] 36 | for i in range(log_size, 2, -1): 37 | out_channel = channels[2 ** (i - 1)] 38 | blocks.append(DiscriminatorBlock(in_channel, out_channel)) 39 | in_channel = out_channel 40 | self.blocks = nn.ModuleList(blocks) 41 | 42 | self.final_conv = nn.Sequential( 43 | nn.Conv2d(in_channel, channels[4], 3, padding=1), 44 | leaky_relu(), 45 | ) 46 | self.final_linear = nn.Sequential( 47 | nn.Linear(channels[4] * 4 * 4, channels[4]), 48 | leaky_relu(), 49 | nn.Linear(channels[4], 1), 50 | ) 51 | 52 | def forward(self, x): 53 | for block in self.blocks: 54 | x = block(x) 55 | x = self.final_conv(x) 56 | x = x.reshape(x.shape[0], -1) 57 | x = 
self.final_linear(x) 58 | return x 59 | 60 | 61 | class DiscriminatorBlock(nn.Module): 62 | def __init__(self, input_channels, filters, downsample=True): 63 | super().__init__() 64 | self.conv_res = nn.Conv2d( 65 | input_channels, filters, 1, stride=(2 if downsample else 1) 66 | ) 67 | 68 | self.net = nn.Sequential( 69 | nn.Conv2d(input_channels, filters, 3, padding=1), 70 | leaky_relu(), 71 | nn.Conv2d(filters, filters, 3, padding=1), 72 | leaky_relu(), 73 | ) 74 | 75 | self.downsample = ( 76 | nn.Sequential(Blur(), nn.Conv2d(filters, filters, 3, padding=1, stride=2)) 77 | if downsample 78 | else None 79 | ) 80 | 81 | def forward(self, x): 82 | res = self.conv_res(x) 83 | x = self.net(x) 84 | if exists(self.downsample): 85 | x = self.downsample(x) 86 | x = (x + res) * (1 / math.sqrt(2)) 87 | return x 88 | 89 | 90 | class Blur(nn.Module): 91 | def __init__(self): 92 | super().__init__() 93 | f = torch.Tensor([1, 2, 1]) 94 | self.register_buffer("f", f) 95 | 96 | def forward(self, x): 97 | f = self.f 98 | f = f[None, None, :] * f[None, :, None] 99 | return filter2d(x, f, normalized=True) 100 | 101 | 102 | def leaky_relu(p=0.2): 103 | return nn.LeakyReLU(p, inplace=True) 104 | 105 | 106 | def exists(val): 107 | return val is not None 108 | -------------------------------------------------------------------------------- /SphereAR/gan/discriminator_patchgan.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # taming-transformers: https://github.com/CompVis/taming-transformers 3 | import functools 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class NLayerDiscriminator(nn.Module): 9 | """Defines a PatchGAN discriminator as in Pix2Pix 10 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 11 | """ 12 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 13 | """Construct a PatchGAN discriminator 14 | Parameters: 15 | input_nc (int) -- the number of channels in input images 16 | ndf (int) -- the number of filters in the last conv layer 17 | n_layers (int) -- the number of conv layers in the discriminator 18 | norm_layer -- normalization layer 19 | """ 20 | super(NLayerDiscriminator, self).__init__() 21 | if not use_actnorm: 22 | norm_layer = nn.BatchNorm2d 23 | else: 24 | norm_layer = ActNorm 25 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 26 | use_bias = norm_layer.func != nn.BatchNorm2d 27 | else: 28 | use_bias = norm_layer != nn.BatchNorm2d 29 | 30 | kw = 4 31 | padw = 1 32 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 33 | nf_mult = 1 34 | nf_mult_prev = 1 35 | for n in range(1, n_layers): # gradually increase the number of filters 36 | nf_mult_prev = nf_mult 37 | nf_mult = min(2 ** n, 8) 38 | sequence += [ 39 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 40 | norm_layer(ndf * nf_mult), 41 | nn.LeakyReLU(0.2, True) 42 | ] 43 | 44 | nf_mult_prev = nf_mult 45 | nf_mult = min(2 ** n_layers, 8) 46 | sequence += [ 47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 48 | norm_layer(ndf * nf_mult), 49 | nn.LeakyReLU(0.2, True) 50 | ] 51 | 52 | sequence += [ 53 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 54 | self.main = nn.Sequential(*sequence) 55 | 56 | self.apply(self._init_weights) 57 
| 58 | def _init_weights(self, module): 59 | if isinstance(module, nn.Conv2d): 60 | nn.init.normal_(module.weight.data, 0.0, 0.02) 61 | elif isinstance(module, nn.BatchNorm2d): 62 | nn.init.normal_(module.weight.data, 1.0, 0.02) 63 | nn.init.constant_(module.bias.data, 0) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | 69 | 70 | class ActNorm(nn.Module): 71 | def __init__(self, num_features, logdet=False, affine=True, 72 | allow_reverse_init=False): 73 | assert affine 74 | super().__init__() 75 | self.logdet = logdet 76 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 77 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 78 | self.allow_reverse_init = allow_reverse_init 79 | 80 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 81 | 82 | def initialize(self, input): 83 | with torch.no_grad(): 84 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 85 | mean = ( 86 | flatten.mean(1) 87 | .unsqueeze(1) 88 | .unsqueeze(2) 89 | .unsqueeze(3) 90 | .permute(1, 0, 2, 3) 91 | ) 92 | std = ( 93 | flatten.std(1) 94 | .unsqueeze(1) 95 | .unsqueeze(2) 96 | .unsqueeze(3) 97 | .permute(1, 0, 2, 3) 98 | ) 99 | 100 | self.loc.data.copy_(-mean) 101 | self.scale.data.copy_(1 / (std + 1e-6)) 102 | 103 | def forward(self, input, reverse=False): 104 | if reverse: 105 | return self.reverse(input) 106 | if len(input.shape) == 2: 107 | input = input[:,:,None,None] 108 | squeeze = True 109 | else: 110 | squeeze = False 111 | 112 | _, _, height, width = input.shape 113 | 114 | if self.training and self.initialized.item() == 0: 115 | self.initialize(input) 116 | self.initialized.fill_(1) 117 | 118 | h = self.scale * (input + self.loc) 119 | 120 | if squeeze: 121 | h = h.squeeze(-1).squeeze(-1) 122 | 123 | if self.logdet: 124 | log_abs = torch.log(torch.abs(self.scale)) 125 | logdet = height*width*torch.sum(log_abs) 126 | logdet = logdet * torch.ones(input.shape[0]).to(input) 127 | return h, logdet 128 | 129 | return h 130 | 131 | def reverse(self, output): 132 | if self.training and self.initialized.item() == 0: 133 | if not self.allow_reverse_init: 134 | raise RuntimeError( 135 | "Initializing ActNorm in reverse direction is " 136 | "disabled by default. Use allow_reverse_init=True to enable." 
137 |                 )
138 |             else:
139 |                 self.initialize(output)
140 |                 self.initialized.fill_(1)
141 | 
142 |         if len(output.shape) == 2:
143 |             output = output[:,:,None,None]
144 |             squeeze = True
145 |         else:
146 |             squeeze = False
147 | 
148 |         h = output / self.scale - self.loc
149 | 
150 |         if squeeze:
151 |             h = h.squeeze(-1).squeeze(-1)
152 |         return h
--------------------------------------------------------------------------------
/SphereAR/gan/gan_loss.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # taming-transformers: https://github.com/CompVis/taming-transformers
3 | # muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch._dynamo as dynamo
8 | 
9 | from .lpips import LPIPS
10 | from .discriminator_patchgan import (
11 |     NLayerDiscriminator as PatchGANDiscriminator,
12 | )
13 | from .discriminator_stylegan import (
14 |     Discriminator as StyleGANDiscriminator,
15 | )
16 | 
17 | 
18 | def hinge_d_loss(logits_real, logits_fake):
19 |     loss_real = torch.mean(F.relu(1.0 - logits_real))
20 |     loss_fake = torch.mean(F.relu(1.0 + logits_fake))
21 |     d_loss = 0.5 * (loss_real + loss_fake)
22 |     return d_loss
23 | 
24 | 
25 | def vanilla_d_loss(logits_real, logits_fake):
26 |     loss_real = torch.mean(F.softplus(-logits_real))
27 |     loss_fake = torch.mean(F.softplus(logits_fake))
28 |     d_loss = 0.5 * (loss_real + loss_fake)
29 |     return d_loss
30 | 
31 | 
32 | def non_saturating_d_loss(logits_real, logits_fake):
33 |     loss_real = torch.mean(
34 |         F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real))
35 |     )
36 |     loss_fake = torch.mean(
37 |         F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake))
38 |     )
39 |     d_loss = 0.5 * (loss_real + loss_fake)
40 |     return d_loss
41 | 
42 | 
43 | def hinge_gen_loss(logit_fake):
44 |     return -torch.mean(logit_fake)
45 | 
46 | 
47 | def non_saturating_gen_loss(logit_fake):
48 |     return torch.mean(
49 |         F.binary_cross_entropy_with_logits(logit_fake, torch.ones_like(logit_fake))
50 |     )
51 | 
52 | 
53 | def adopt_weight(weight, global_step, threshold=0, value=0.0):
54 |     if global_step < threshold:
55 |         weight = value
56 |     return weight
57 | 
58 | 
59 | class GANLoss(nn.Module):
60 |     def __init__(
61 |         self,
62 |         disc_start,
63 |         disc_loss="hinge",
64 |         disc_dim=64,
65 |         disc_type="patchgan",
66 |         image_size=256,
67 |         disc_num_layers=3,
68 |         disc_in_channels=3,
69 |         disc_weight=0.5,
70 |         gen_adv_loss="hinge",
71 |         reconstruction_loss="l2",
72 |     ):
73 |         super().__init__()
74 |         # discriminator loss
75 |         assert disc_type in ["patchgan", "stylegan"]
76 |         assert disc_loss in ["hinge", "vanilla", "non-saturating"]
77 |         if disc_type == "patchgan":
78 |             self.discriminator = PatchGANDiscriminator(
79 |                 input_nc=disc_in_channels,
80 |                 n_layers=disc_num_layers,
81 |                 ndf=disc_dim,
82 |             )
83 |         elif disc_type == "stylegan":
84 |             self.discriminator = StyleGANDiscriminator(
85 |                 input_nc=disc_in_channels,
86 |                 image_size=image_size,
87 |             )
88 |         else:
89 |             raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
90 |         if disc_loss == "hinge":
91 |             self.disc_loss = hinge_d_loss
92 |         elif disc_loss == "vanilla":
93 |             self.disc_loss = vanilla_d_loss
94 |         elif disc_loss == "non-saturating":
95 |             self.disc_loss = non_saturating_d_loss
96 |         else:
97 |             raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
98 |         self.discriminator_iter_start = disc_start
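        # adopt_weight() above keeps the adversarial weight at 0 until global_step reaches disc_start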
99 | self.disc_weight = disc_weight 100 | 101 | assert gen_adv_loss in ["hinge", "non-saturating"] 102 | # gen_adv_loss 103 | if gen_adv_loss == "hinge": 104 | self.gen_adv_loss = hinge_gen_loss 105 | elif gen_adv_loss == "non-saturating": 106 | self.gen_adv_loss = non_saturating_gen_loss 107 | else: 108 | raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.") 109 | 110 | # perceptual loss 111 | self.perceptual_loss = LPIPS().eval() 112 | 113 | # reconstruction loss 114 | if reconstruction_loss == "l1": 115 | self.rec_loss = F.l1_loss 116 | elif reconstruction_loss == "l2": 117 | self.rec_loss = F.mse_loss 118 | else: 119 | raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.") 120 | 121 | def forward( 122 | self, 123 | inputs, 124 | reconstructions, 125 | optimizer_idx, 126 | global_step, 127 | ): 128 | # inputs = inputs.contiguous() 129 | # reconstructions = reconstructions.contiguous() 130 | # generator update 131 | if optimizer_idx == 0: 132 | # reconstruction loss 133 | rec_loss = self.rec_loss(inputs, reconstructions) 134 | 135 | # perceptual loss 136 | p_loss = self.perceptual_loss(inputs, reconstructions) 137 | p_loss = torch.mean(p_loss) 138 | 139 | # discriminator loss 140 | logits_fake = self.discriminator(reconstructions) 141 | generator_adv_loss = self.gen_adv_loss(logits_fake) 142 | 143 | disc_weight = adopt_weight( 144 | self.disc_weight, global_step, threshold=self.discriminator_iter_start 145 | ) 146 | 147 | return ( 148 | rec_loss, 149 | p_loss, 150 | disc_weight * generator_adv_loss, 151 | ) 152 | 153 | # discriminator update 154 | if optimizer_idx == 1: 155 | logits_real = self.discriminator(inputs.detach()) 156 | logits_fake = self.discriminator(reconstructions.detach()) 157 | 158 | disc_weight = adopt_weight( 159 | self.disc_weight, global_step, threshold=self.discriminator_iter_start 160 | ) 161 | d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake) 162 | 163 | return d_adversarial_loss 164 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SphereAR: Hyperspherical Latents Improve Continuous-Token Autoregressive Generation 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2509.24335-b31b1b.svg)](https://arxiv.org/abs/2509.24335)  4 | [![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-SphereAR-yellow)](https://huggingface.co/guolinke/SphereAR)  5 | 6 | 7 |

8 | 9 |

10 | 11 | This is the official PyTorch implementation of paper [Hyperspherical Latents Improve Continuous-Token Autoregressive Generation](https://arxiv.org/abs/2509.24335). 12 | 13 | ``` 14 | @article{ke2025hyperspherical, 15 | title={Hyperspherical Latents Improve Continuous-Token Autoregressive Generation}, 16 | author={Guolin Ke and Hui Xue}, 17 | journal={arXiv preprint arXiv:2509.24335}, 18 | year={2025} 19 | } 20 | ``` 21 | 22 | 23 | ## Introduction 24 | 25 |

26 | 
27 | SphereAR is a simple yet effective approach to continuous-token autoregressive (AR) image generation: it makes AR scale-invariant by constraining all AR inputs and outputs---**including after CFG**---to lie on a fixed-radius hypersphere (constant L2 norm) via hyperspherical VAEs.
28 | 
29 | The model is a **pure next-token** AR generator with **raster** order, matching standard language AR modeling (i.e., it is *not* next-scale AR like VAR and *not* next-set AR like MAR/MaskGIT).
30 | 
31 | On ImageNet 256×256, SphereAR achieves a state-of-the-art FID of **1.34** among AR image generators.
32 | 
33 | 
34 | ## Environment
35 | 
36 | - PyTorch: 2.7.1 (CUDA 12.6 build)
37 | - FlashAttention: 2.8.1
38 | 
39 | ### Install notes
40 | 1. Install PyTorch 2.7.1 (CUDA 12.6) using your preferred method.
41 | 2. Install FlashAttention 2.8.1 from the prebuilt wheel (replace the cp310 tag with your Python version, e.g., cp311 for Python 3.11):
42 | ```shell
43 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
44 | ```
45 | 
46 | 
47 | ## Class-conditional image generation on ImageNet
48 | 
49 | ### Model Checkpoints
50 | 
51 | Name | params | FID (256x256) | weight
52 | --- |:---:|:---:|:---:|
53 | S-VAE | 75M | - | [vae.pt](https://huggingface.co/guolinke/SphereAR/blob/main/vae.pt)
54 | SphereAR-B | 208M | 1.92 | [SphereAR_B.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_B.pt)
55 | SphereAR-L | 479M | 1.54 | [SphereAR_L.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_L.pt)
56 | SphereAR-H | 943M | 1.34 | [SphereAR_H.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_H.pt)
57 | 
58 | ### Evaluation from checkpoints
59 | 
60 | 1. Sample 50,000 images and save them to a `.npz` file.
61 | 
62 | SphereAR-B:
63 | ```shell
64 | ckpt=your_ckpt_path
65 | result_path=your_result_path
66 | torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
67 |   sample_ddp.py --model SphereAR-B --ckpt $ckpt --cfg-scale 4.5 \
68 |   --sample-dir $result_path --per-proc-batch-size 256 --to-npz
69 | ```
70 | 
71 | SphereAR-L:
72 | ```shell
73 | ckpt=your_ckpt_path
74 | result_path=your_result_path
75 | torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
76 |   sample_ddp.py --model SphereAR-L --ckpt $ckpt --cfg-scale 4.6 \
77 |   --sample-dir $result_path --per-proc-batch-size 256 --to-npz
78 | ```
79 | 
80 | SphereAR-H:
81 | ```shell
82 | ckpt=your_ckpt_path
83 | result_path=your_result_path
84 | torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
85 |   sample_ddp.py --model SphereAR-H --ckpt $ckpt --cfg-scale 4.5 \
86 |   --sample-dir $result_path --per-proc-batch-size 256 --to-npz
87 | ```
88 | 
89 | 2. Compute metrics following [OpenAI’s evaluation protocol](https://github.com/openai/guided-diffusion/tree/main/evaluations): download the [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) and run `python evaluator.py VIRTUAL_imagenet256_labeled.npz your_generated.npz` to compute the metrics. TensorFlow is required; we use `tensorflow==2.19.1`.
90 | 
91 | 
92 | ### Reproduce our training
93 | 
94 | 1. Download the [ImageNet](http://image-net.org/download) dataset. **Note**: our code supports training directly from the downloaded tar file, so no decompression is needed; a minimal sketch of the idea is shown below.
95 | 
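For reference, a minimal sketch of the indexed-tar reading that makes this possible. The names here are illustrative; the full implementation, including the nested per-class tars and the label mapping, is `ImageNetTarDataset` in `SphereAR/dataset.py`:

```python
import io
import tarfile

from PIL import Image


def build_index(tar_path):
    # scan the tar once, recording where each member's bytes start and how long they are
    with tarfile.open(tar_path, "r:") as tar:
        return [(m.offset_data, m.size) for m in tar if m.isfile()]


def read_image(tar_path, offset, size):
    # afterwards, every access is a single seek + read into the (uncompressed) tar
    with open(tar_path, "rb") as f:
        f.seek(offset)
        data = f.read(size)
    return Image.open(io.BytesIO(data)).convert("RGB")
```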
96 | 2. Train the S-VAE:
97 | 
98 | ```shell
99 | data_path=your_data_path/ILSVRC2012_img_train.tar
100 | result_path=your_result_path
101 | torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
102 |   train.py --results-dir $result_path --data-path $data_path \
103 |   --image-size 256 --epochs 100 --patch-size 16 --latent-dim 16 --vae-only \
104 |   --lr 1e-4 --global-batch-size 256 --warmup-steps -1 --decay-start -1
105 | ```
106 | 
107 | 3. Train the AR model:
108 | 
109 | ```shell
110 | data_path=your_data_path/ILSVRC2012_img_train.tar
111 | result_path=your_result_path
112 | vae_ckpt=your_vae_path
113 | torchrun --nproc_per_node=8 --master_addr=$WORKER_0_HOST --node_rank=$LOCAL_RANK --master_port=$WORKER_0_PORT --nnodes=$WORKER_NUM \
114 |   train.py --results-dir $result_path --data-path $data_path --image-size 256 \
115 |   --model SphereAR-B --epochs 400 --patch-size 16 --latent-dim 16 \
116 |   --lr 3e-4 --global-batch-size 512 --trained-vae $vae_ckpt --ema 0.9999
117 | ```
118 | The script above trains `SphereAR-B`; to train other sizes, set `--model` to `SphereAR-L` or `SphereAR-H`.
119 | We trained on A100 GPUs with the following setups: 8×A100 for SphereAR-B, 16×A100 for SphereAR-L, and 32×A100 for SphereAR-H.
120 | Training takes about 3 days for 400 epochs.
121 | 
122 | **Note**: We use `torch.compile` for acceleration. Occasionally the TorchInductor compile step can hang; if that happens, re-run the job. Enabling Dynamo logs tends to reduce stalls: `export TORCH_LOGS="+dynamo"`. To avoid repeated compilation cost across runs, enable the compile caches:
123 | 
124 | ```shell
125 | export TORCHINDUCTOR_FX_GRAPH_CACHE=1
126 | export TORCHINDUCTOR_AUTOGRAD_CACHE=1
127 | ```
128 | 
129 | Set these environment variables in your shell (or job script) before launching training.
130 | 
--------------------------------------------------------------------------------
/sample_ddp.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # DiT: https://github.com/facebookresearch/DiT/blob/main/sample.py
3 | import math
4 | import os
5 | import time
6 | 
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | from PIL import Image
11 | from tqdm import tqdm
12 | 
13 | from SphereAR.model import create_model, get_model_args
14 | 
15 | 
16 | def create_npz_from_sample_folder(sample_dir, num=50_000):
17 |     """
18 |     Builds a single .npz file from a folder of .png samples.
19 |     """
20 |     samples = []
21 |     for i in tqdm(range(num), desc="Building .npz file from samples"):
22 |         sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
23 |         sample_np = np.asarray(sample_pil).astype(np.uint8)
24 |         samples.append(sample_np)
25 |     samples = np.stack(samples)
26 |     assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
27 |     npz_path = f"{sample_dir}.npz"
28 |     np.savez(npz_path, arr_0=samples)
29 |     print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
30 |     # delete the sample folder to save space
31 |     os.system(f"rm -r {sample_dir}")
32 |     return npz_path
33 | 
34 | 
35 | def main(args):
36 |     assert (
37 |         torch.cuda.is_available()
38 |     ), "Sampling with DDP requires at least one GPU. 
sample.py supports CPU-only usage" 39 | 40 | dist.init_process_group("nccl") 41 | rank = dist.get_rank() 42 | device = rank % torch.cuda.device_count() 43 | seed = args.seed * dist.get_world_size() + rank 44 | torch.manual_seed(seed) 45 | torch.cuda.set_device(device) 46 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 47 | torch.set_float32_matmul_precision("high") 48 | torch.backends.cudnn.deterministic = True 49 | torch.backends.cudnn.benchmark = False 50 | torch.set_grad_enabled(False) 51 | 52 | # create and load gpt model 53 | precision = {"none": torch.float32, "bf16": torch.bfloat16}[args.mixed_precision] 54 | model = create_model(args, device) 55 | 56 | checkpoint = torch.load(args.ckpt, map_location="cpu", weights_only=False) 57 | if "ema" in checkpoint and not args.no_ema: 58 | print("use ema weight") 59 | model_weight = checkpoint["ema"] 60 | elif "model" in checkpoint: # ddp 61 | model_weight = checkpoint["model"] 62 | else: 63 | raise Exception("please check model weight") 64 | 65 | model.load_state_dict(model_weight, strict=True) 66 | model.eval() 67 | del checkpoint 68 | 69 | # Create folder to save samples: 70 | model_string_name = args.model.replace("/", "-") 71 | ckpt_string_name = ( 72 | os.path.basename(args.ckpt).replace(".pth", "").replace(".pt", "") 73 | ) 74 | folder_name = ( 75 | f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-" 76 | f"steps-{args.sample_steps}-cfg-{args.cfg_scale}-seed-{args.seed}" 77 | ) 78 | if not args.no_ema: 79 | folder_name += "-ema" 80 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" 81 | 82 | if os.path.isfile(sample_folder_dir + ".npz"): 83 | if rank == 0: 84 | print(f"Found {sample_folder_dir}.npz, skipping sampling.") 85 | dist.barrier() 86 | dist.destroy_process_group() 87 | return 1 88 | if rank == 0: 89 | os.makedirs(sample_folder_dir, exist_ok=True) 90 | print(f"Saving .png samples at {sample_folder_dir}") 91 | dist.barrier() 92 | 93 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 94 | n = args.per_proc_batch_size 95 | global_batch_size = n * dist.get_world_size() 96 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples: 97 | total_samples = int( 98 | math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size 99 | ) 100 | if rank == 0: 101 | print(f"Total number of images that will be sampled: {total_samples}") 102 | assert ( 103 | total_samples % dist.get_world_size() == 0 104 | ), "total_samples must be divisible by world_size" 105 | samples_needed_this_gpu = int(total_samples // dist.get_world_size()) 106 | assert ( 107 | samples_needed_this_gpu % n == 0 108 | ), "samples_needed_this_gpu must be divisible by the per-GPU batch size" 109 | iterations = int(samples_needed_this_gpu // n) 110 | total = 0 111 | start_time = time.time() 112 | for i in tqdm(range(iterations), desc="Sampling"): 113 | # Sample inputs: 114 | c_indices = torch.randint(0, args.num_classes, (n,), device=device) 115 | with torch.amp.autocast("cuda", dtype=precision): 116 | samples = model.sample( 117 | c_indices, 118 | sample_steps=args.sample_steps, 119 | cfg_scale=args.cfg_scale, 120 | ) 121 | 122 | samples = ( 123 | torch.clamp(127.5 * samples + 128.0, 0, 255) 124 | .permute(0, 2, 3, 1) 125 | .to("cpu", dtype=torch.uint8) 126 | .numpy() 127 | ) 128 | 129 | # Save samples to disk as individual .png files 130 | for i, sample in enumerate(samples): 131 | index = i * 
dist.get_world_size() + rank + total 132 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 133 | total += global_batch_size 134 | print( 135 | f"Rank {rank} has sampled {total} images so far, cost {time.time() - start_time:.2f} seconds" 136 | ) 137 | 138 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 139 | dist.barrier() 140 | if rank == 0 and args.to_npz: 141 | print(f"Total time taken for sampling: {time.time() - start_time:.2f} seconds") 142 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples) 143 | print("Done.") 144 | dist.barrier() 145 | dist.destroy_process_group() 146 | 147 | 148 | if __name__ == "__main__": 149 | parser = get_model_args() 150 | parser.add_argument("--ckpt", type=str, default=None) 151 | parser.add_argument("--sample-dir", type=str, default="samples") 152 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 153 | parser.add_argument("--num-fid-samples", type=int, default=50000) 154 | parser.add_argument("--cfg-scale", type=float, default=4.6) 155 | parser.add_argument("--seed", type=int, default=99) 156 | parser.add_argument("--sample-steps", type=int, default=100) 157 | parser.add_argument("--no-ema", action="store_true") 158 | parser.add_argument( 159 | "--mixed-precision", type=str, default="bf16", choices=["none", "bf16"] 160 | ) 161 | parser.add_argument("--to-npz", action="store_true") 162 | args = parser.parse_args() 163 | main(args) 164 | -------------------------------------------------------------------------------- /SphereAR/gan/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import os, hashlib 4 | import requests 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torchvision import models 10 | from collections import namedtuple 11 | 12 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} 13 | 14 | CKPT_MAP = {"vgg_lpips": "vgg.pth"} 15 | 16 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} 17 | 18 | 19 | def download(url, local_path, chunk_size=1024): 20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 21 | with requests.get(url, stream=True) as r: 22 | total_size = int(r.headers.get("content-length", 0)) 23 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 24 | with open(local_path, "wb") as f: 25 | for data in r.iter_content(chunk_size=chunk_size): 26 | if data: 27 | f.write(data) 28 | pbar.update(chunk_size) 29 | 30 | 31 | def md5_hash(path): 32 | with open(path, "rb") as f: 33 | content = f.read() 34 | return hashlib.md5(content).hexdigest() 35 | 36 | 37 | def get_ckpt_path(name, root, check=False): 38 | assert name in URL_MAP 39 | path = os.path.join(root, CKPT_MAP[name]) 40 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 41 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 42 | download(URL_MAP[name], path) 43 | md5 = md5_hash(path) 44 | assert md5 == MD5_MAP[name], md5 45 | return path 46 | 47 | 48 | class LPIPS(nn.Module): 49 | # Learned perceptual metric 50 | def __init__(self, use_dropout=True): 51 | super().__init__() 52 | self.scaling_layer = ScalingLayer() 53 | self.chns = [64, 128, 256, 512, 512] # vg16 features 54 | self.net = vgg16(pretrained=True, requires_grad=False) 55 | self.lin0 = NetLinLayer(self.chns[0], 
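        # each lin{0..4} is a learned 1x1 conv that weights the squared feature differences of one VGG stage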
use_dropout=use_dropout) 56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 60 | self.load_from_pretrained() 61 | for param in self.parameters(): 62 | param.requires_grad = False 63 | 64 | def load_from_pretrained(self, name="vgg_lpips"): 65 | ckpt = get_ckpt_path( 66 | name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache") 67 | ) 68 | self.load_state_dict( 69 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 70 | ) 71 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 72 | 73 | @classmethod 74 | def from_pretrained(cls, name="vgg_lpips"): 75 | if name != "vgg_lpips": 76 | raise NotImplementedError 77 | model = cls() 78 | ckpt = get_ckpt_path( 79 | name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache") 80 | ) 81 | model.load_state_dict( 82 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False 83 | ) 84 | return model 85 | 86 | def forward(self, input, target): 87 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 88 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 89 | feats0, feats1, diffs = {}, {}, {} 90 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 91 | for kk in range(len(self.chns)): 92 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( 93 | outs1[kk] 94 | ) 95 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 96 | 97 | res = [ 98 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True) 99 | for kk in range(len(self.chns)) 100 | ] 101 | val = res[0] 102 | for l in range(1, len(self.chns)): 103 | val += res[l] 104 | return val 105 | 106 | 107 | class ScalingLayer(nn.Module): 108 | def __init__(self): 109 | super(ScalingLayer, self).__init__() 110 | self.register_buffer( 111 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] 112 | ) 113 | self.register_buffer( 114 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] 115 | ) 116 | 117 | def forward(self, inp): 118 | return (inp - self.shift) / self.scale 119 | 120 | 121 | class NetLinLayer(nn.Module): 122 | """A single linear layer which does a 1x1 conv""" 123 | 124 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 125 | super(NetLinLayer, self).__init__() 126 | layers = ( 127 | [ 128 | nn.Dropout(), 129 | ] 130 | if (use_dropout) 131 | else [] 132 | ) 133 | layers += [ 134 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 135 | ] 136 | self.model = nn.Sequential(*layers) 137 | 138 | 139 | class vgg16(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True): 141 | super(vgg16, self).__init__() 142 | vgg_pretrained_features = models.vgg16( 143 | weights=models.VGG16_Weights.IMAGENET1K_V1 144 | ).features 145 | self.slice1 = torch.nn.Sequential() 146 | self.slice2 = torch.nn.Sequential() 147 | self.slice3 = torch.nn.Sequential() 148 | self.slice4 = torch.nn.Sequential() 149 | self.slice5 = torch.nn.Sequential() 150 | self.N_slices = 5 151 | for x in range(4): 152 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 153 | for x in range(4, 9): 154 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 155 | for x in range(9, 16): 156 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 157 | for x in range(16, 23): 158 | self.slice4.add_module(str(x), 
vgg_pretrained_features[x]) 159 | for x in range(23, 30): 160 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 161 | if not requires_grad: 162 | for param in self.parameters(): 163 | param.requires_grad = False 164 | 165 | def forward(self, X): 166 | h = self.slice1(X) 167 | h_relu1_2 = h 168 | h = self.slice2(h) 169 | h_relu2_2 = h 170 | h = self.slice3(h) 171 | h_relu3_3 = h 172 | h = self.slice4(h) 173 | h_relu4_3 = h 174 | h = self.slice5(h) 175 | h_relu5_3 = h 176 | vgg_outputs = namedtuple( 177 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] 178 | ) 179 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 180 | return out 181 | 182 | 183 | def normalize_tensor(x, eps=1e-10): 184 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 185 | return x / (norm_factor + eps) 186 | 187 | 188 | def spatial_average(x, keepdim=True): 189 | return x.mean([2, 3], keepdim=keepdim) 190 | -------------------------------------------------------------------------------- /SphereAR/dataset.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import io 3 | import math 4 | import os 5 | import pickle 6 | import tarfile 7 | from functools import lru_cache 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | from torch.utils.data import Dataset 13 | from torchvision.datasets import ImageFolder 14 | 15 | 16 | @contextlib.contextmanager 17 | def numpy_seed(seed, *addl_seeds): 18 | """Context manager which seeds the NumPy PRNG with the specified seed and 19 | restores the state afterward""" 20 | if seed is None: 21 | yield 22 | return 23 | 24 | def check_seed(s): 25 | assert type(s) == int or type(s) == np.int32 or type(s) == np.int64 26 | 27 | check_seed(seed) 28 | if len(addl_seeds) > 0: 29 | for s in addl_seeds: 30 | check_seed(s) 31 | seed = int(hash((seed, *addl_seeds)) % 1e8) 32 | state = np.random.get_state() 33 | np.random.seed(seed) 34 | try: 35 | yield 36 | finally: 37 | np.random.set_state(state) 38 | 39 | 40 | def build_flat_index(outer_path: str, idx_path: str): 41 | if os.path.exists(idx_path): 42 | print(f"Index file {idx_path} already exists. Skipping index building.") 43 | return pickle.load(open(idx_path, "rb")) 44 | entries = [] # (offset, size, label) 45 | cats = set() 46 | idx = 0 47 | with tarfile.open(outer_path, "r:") as outer: 48 | for sub in outer.getmembers(): 49 | if not sub.isfile() or not sub.name.endswith(".tar"): 50 | continue 51 | outer_off = sub.offset_data 52 | sub_fobj = outer.extractfile(sub) 53 | with tarfile.open(fileobj=sub_fobj, mode="r:") as inner: 54 | for m in inner.getmembers(): 55 | if not m.isfile(): 56 | continue 57 | cat = m.name.split("_", 1)[0] 58 | cats.add(cat) 59 | abs_off = outer_off + m.offset_data 60 | entries.append((abs_off, m.size, cat)) 61 | if idx % 1000 == 1: 62 | print(idx, m.name, abs_off, m.size, cat) 63 | idx += 1 64 | sorted_cats = sorted(cats) 65 | cat2idx = {c: i for i, c in enumerate(sorted_cats)} 66 | 67 | flat = [(off, size, cat2idx[c]) for off, size, c in entries] 68 | 69 | os.makedirs(os.path.dirname(idx_path), exist_ok=True) 70 | with open(idx_path, "wb") as f: 71 | pickle.dump( 72 | flat, 73 | f, 74 | ) 75 | print(f"Built flat index with {len(flat)} images.") 76 | return flat 77 | 78 | 79 | class ImageNetTarDataset(Dataset): 80 | """ 81 | ImageNet dataset stored in a tar file, avoid to decompress the whole dataset. 
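    Random access is served by a prebuilt flat index of (absolute offset, size, label) entries, so each read is a single seek into the tar file.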
82 | You can direct use the original downloaded tar file (ILSVRC2012_img_train.tar) from official ImageNet website. 83 | The best practice is to copy the tar file to node's local disk or ramdisk (like /dev/shm/) first, to avoid remote I/O bottleneck. 84 | """ 85 | 86 | def __init__( 87 | self, 88 | tar_file, 89 | ): 90 | self.tar_file = tar_file 91 | self.tar_handle = None 92 | self.files = build_flat_index(tar_file, tar_file + ".index") 93 | self.num_examples = len(self.files) 94 | 95 | def __len__(self): 96 | return self.num_examples 97 | 98 | def get_raw_image(self, index): 99 | if self.tar_handle is None: 100 | self.tar_handle = open(self.tar_file, "rb") 101 | 102 | offset, size, label = self.files[index] 103 | self.tar_handle.seek(offset) 104 | data = self.tar_handle.read(size) 105 | image = Image.open(io.BytesIO(data)).convert("RGB") 106 | return image, label 107 | 108 | @lru_cache(maxsize=16) 109 | def __getitem__(self, idx): 110 | return self.get_raw_image(idx) 111 | 112 | 113 | def center_crop_arr(pil_image, image_size): 114 | """ 115 | Center cropping implementation from ADM. 116 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 117 | """ 118 | while min(*pil_image.size) >= 2 * image_size: 119 | pil_image = pil_image.resize( 120 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 121 | ) 122 | 123 | scale = image_size / min(*pil_image.size) 124 | pil_image = pil_image.resize( 125 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 126 | ) 127 | 128 | arr = np.array(pil_image) 129 | crop_y = (arr.shape[0] - image_size) // 2 130 | crop_x = (arr.shape[1] - image_size) // 2 131 | return Image.fromarray( 132 | arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 133 | ) 134 | 135 | 136 | def numpy_randrange(start, end): 137 | return int(np.random.randint(start, end)) 138 | 139 | 140 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 141 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 142 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 143 | smaller_dim_size = numpy_randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 144 | 145 | # We are not on a new enough PIL to support the `reducing_gap` 146 | # argument, which uses BOX downsampling at powers of two first. 147 | # Thus, we do it by hand to improve downsample quality. 148 | while min(*pil_image.size) >= 2 * smaller_dim_size: 149 | pil_image = pil_image.resize( 150 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 151 | ) 152 | 153 | scale = smaller_dim_size / min(*pil_image.size) 154 | pil_image = pil_image.resize( 155 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 156 | ) 157 | 158 | arr = np.array(pil_image) 159 | crop_y = numpy_randrange(0, arr.shape[0] - image_size + 1) 160 | crop_x = numpy_randrange(0, arr.shape[1] - image_size + 1) 161 | return Image.fromarray( 162 | arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 163 | ) 164 | 165 | 166 | def crop(pil_image, left, top, right, bottom): 167 | """ 168 | Crop the image to the specified box. 
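    The box is (left, top, right, bottom), following the PIL.Image.crop convention.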
169 | """ 170 | return pil_image.crop((left, top, right, bottom)) 171 | 172 | 173 | class ImageCropDataset(Dataset): 174 | 175 | def __init__( 176 | self, 177 | raw_dataset, 178 | resolution, 179 | patch_size, 180 | seed=42, 181 | ): 182 | self.raw_dataset = raw_dataset 183 | self.resolution = resolution 184 | self.patch_size = patch_size 185 | self.aug_ratio = 1.0 186 | self.seed = seed 187 | self.epoch = None 188 | 189 | def set_epoch(self, epoch): 190 | self.epoch = epoch 191 | 192 | def set_aug_ratio(self, aug_ratio): 193 | self.aug_ratio = aug_ratio 194 | 195 | def __len__(self): 196 | return len(self.raw_dataset) 197 | 198 | def crop_and_flip(self, image): 199 | is_aug = np.random.rand() < self.aug_ratio 200 | if not is_aug: 201 | image = center_crop_arr(image, self.resolution) 202 | else: 203 | image = random_crop_arr(image, self.resolution) 204 | 205 | arr = np.asarray(image) 206 | 207 | is_flip = int(np.random.randint(0, 2)) 208 | if is_flip == 1: 209 | # horizontal flip 210 | arr = arr[:, ::-1, :] 211 | 212 | return arr.transpose(2, 0, 1) # HWC to CHW 213 | 214 | def __getitem__(self, idx): 215 | with numpy_seed(self.seed, self.epoch, idx): 216 | image, label = self.raw_dataset[idx] 217 | samples = self.crop_and_flip(image) 218 | # to [-1, 1] 219 | samples = (samples.astype(np.float32) / 255.0 - 0.5) * 2.0 220 | samples = torch.from_numpy(samples).float() 221 | return ( 222 | samples, 223 | torch.tensor(label).long(), 224 | ) 225 | 226 | 227 | def build_dataset(args): 228 | # use tarred imagenet dataset if data_path ends with .tar 229 | raw_dataset = ( 230 | ImageNetTarDataset(args.data_path) 231 | if args.data_path.endswith(".tar") 232 | else ImageFolder(args.data_path) 233 | ) 234 | return ImageCropDataset( 235 | raw_dataset, 236 | args.image_size, 237 | args.patch_size, 238 | seed=args.global_seed if hasattr(args, "global_seed") else 42, 239 | ) 240 | -------------------------------------------------------------------------------- /SphereAR/diff_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | from .sampling import euler_maruyama, euler 9 | 10 | 11 | def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0): 12 | half = dim // 2 13 | t = time_factor * t.float() 14 | freqs = torch.exp( 15 | -math.log(max_period) 16 | * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) 17 | / half 18 | ) 19 | 20 | args = t[:, None] * freqs[None] 21 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 22 | if dim % 2: 23 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 24 | if torch.is_floating_point(t): 25 | embedding = embedding.to(t) 26 | return embedding 27 | 28 | 29 | class DiffHead(nn.Module): 30 | """Diffusion Loss""" 31 | 32 | def __init__( 33 | self, 34 | ch_target, 35 | ch_cond, 36 | ch_latent, 37 | depth_latent, 38 | depth_adanln, 39 | grad_checkpointing=False, 40 | ): 41 | super(DiffHead, self).__init__() 42 | self.ch_target = ch_target 43 | self.net = MlpEncoder( 44 | in_channels=ch_target, 45 | model_channels=ch_latent, 46 | z_channels=ch_cond, 47 | num_res_blocks=depth_latent, 48 | num_ada_ln_blocks=depth_adanln, 49 | grad_checkpointing=grad_checkpointing, 50 | ) 51 | 52 | def forward(self, target, z): 53 | with torch.autocast(device_type="cuda", enabled=False): 54 | with torch.no_grad(): 55 | t = 
torch.randn((target.shape[0]), device=target.device).sigmoid() 56 | noise = torch.randn_like(target) 57 | ti = t.view(-1, 1) 58 | x = (1.0 - ti) * noise + ti * target 59 | v = target - noise 60 | 61 | output = self.net(x, t, z) 62 | 63 | with torch.autocast(device_type="cuda", enabled=False): 64 | output = output.float() 65 | loss = torch.mean((output - v) ** 2) 66 | return loss 67 | 68 | def sample( 69 | self, 70 | z, 71 | cfg, 72 | num_sampling_steps, 73 | ): 74 | return euler_maruyama( 75 | self.ch_target, 76 | self.net.forward, 77 | z, 78 | cfg, 79 | num_sampling_steps=num_sampling_steps, 80 | ) 81 | 82 | def initialize_weights(self): 83 | self.net.initialize_weights() 84 | 85 | 86 | class TimestepEmbedder(nn.Module): 87 | """ 88 | Embeds scalar timesteps into vector representations. 89 | """ 90 | 91 | def __init__(self, hidden_size, frequency_embedding_size=256): 92 | super().__init__() 93 | self.mlp = nn.Sequential( 94 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 95 | nn.SiLU(), 96 | nn.Linear(hidden_size, hidden_size, bias=True), 97 | ) 98 | self.frequency_embedding_size = frequency_embedding_size 99 | 100 | def forward(self, t): 101 | t_freq = timestep_embedding(t, self.frequency_embedding_size) 102 | t_emb = self.mlp(t_freq) 103 | return t_emb 104 | 105 | 106 | class ResBlock(nn.Module): 107 | def __init__(self, channels): 108 | super().__init__() 109 | self.channels = channels 110 | self.norm = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True) 111 | hidden_dim = int(channels * 1.5) 112 | self.w1 = nn.Linear(channels, hidden_dim * 2, bias=True) 113 | self.w2 = nn.Linear(hidden_dim, channels, bias=True) 114 | 115 | def forward(self, x, scale, shift, gate): 116 | h = self.norm(x) * (1 + scale) + shift 117 | h1, h2 = self.w1(h).chunk(2, dim=-1) 118 | h = self.w2(F.silu(h1) * h2) 119 | return x + h * gate 120 | 121 | 122 | class FinalLayer(nn.Module): 123 | def __init__(self, channels, out_channels): 124 | super().__init__() 125 | self.norm_final = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=False) 126 | self.ada_ln_modulation = nn.Linear(channels, channels * 2, bias=True) 127 | self.linear = nn.Linear(channels, out_channels, bias=True) 128 | 129 | def forward(self, x, y): 130 | scale, shift = self.ada_ln_modulation(y).chunk(2, dim=-1) 131 | x = self.norm_final(x) * (1.0 + scale) + shift 132 | x = self.linear(x) 133 | return x 134 | 135 | 136 | class MlpEncoder(nn.Module): 137 | 138 | def __init__( 139 | self, 140 | in_channels, 141 | model_channels, 142 | z_channels, 143 | num_res_blocks, 144 | num_ada_ln_blocks=2, 145 | grad_checkpointing=False, 146 | ): 147 | super().__init__() 148 | 149 | self.in_channels = in_channels 150 | self.model_channels = model_channels 151 | self.out_channels = in_channels 152 | self.num_res_blocks = num_res_blocks 153 | self.grad_checkpointing = grad_checkpointing 154 | 155 | self.time_embed = TimestepEmbedder(model_channels) 156 | self.cond_embed = nn.Linear(z_channels, model_channels) 157 | 158 | self.input_proj = nn.Linear(in_channels, model_channels) 159 | self.res_blocks = nn.ModuleList() 160 | for i in range(num_res_blocks): 161 | self.res_blocks.append( 162 | ResBlock( 163 | model_channels, 164 | ) 165 | ) 166 | # share adaLN for consecutive blocks, to save computation and parameters 167 | self.ada_ln_blocks = nn.ModuleList() 168 | for i in range(num_ada_ln_blocks): 169 | self.ada_ln_blocks.append( 170 | nn.Linear(model_channels, model_channels * 3, bias=True) 171 | ) 172 | self.ada_ln_switch_freq = max(1, 
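        # e.g. num_res_blocks=8 with num_ada_ln_blocks=2 -> a new (scale, shift, gate) every 4 blocks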
num_res_blocks // num_ada_ln_blocks) 173 | assert ( 174 | num_res_blocks % num_ada_ln_blocks 175 | ) == 0, "num_res_blocks must be divisible by num_ada_ln_blocks" 176 | self.final_layer = FinalLayer(model_channels, self.out_channels) 177 | 178 | self.initialize_weights() 179 | 180 | def initialize_weights(self): 181 | def _basic_init(module): 182 | if isinstance(module, nn.Linear): 183 | torch.nn.init.xavier_uniform_(module.weight) 184 | if module.bias is not None: 185 | nn.init.constant_(module.bias, 0) 186 | 187 | self.apply(_basic_init) 188 | 189 | # Initialize timestep embedding MLP 190 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) 191 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) 192 | 193 | for block in self.ada_ln_blocks: 194 | nn.init.constant_(block.weight, 0) 195 | nn.init.constant_(block.bias, 0) 196 | 197 | # Zero-out output layers 198 | nn.init.constant_(self.final_layer.ada_ln_modulation.weight, 0) 199 | nn.init.constant_(self.final_layer.ada_ln_modulation.bias, 0) 200 | nn.init.constant_(self.final_layer.linear.weight, 0) 201 | nn.init.constant_(self.final_layer.linear.bias, 0) 202 | 203 | @torch.compile() 204 | def forward(self, x, t, c): 205 | """ 206 | Apply the model to an input batch. 207 | :param x: an [N x C] Tensor of inputs. 208 | :param t: a 1-D batch of timesteps. 209 | :param c: conditioning from AR transformer. 210 | :return: an [N x C] Tensor of outputs. 211 | """ 212 | x = self.input_proj(x) 213 | t = self.time_embed(t) 214 | c = self.cond_embed(c) 215 | 216 | y = F.silu(t + c) 217 | scale, shift, gate = self.ada_ln_blocks[0](y).chunk(3, dim=-1) 218 | if self.grad_checkpointing and self.training: 219 | for i, block in enumerate(self.res_blocks): 220 | if i > 0 and i % self.ada_ln_switch_freq == 0: 221 | ada_ln_block = self.ada_ln_blocks[i // self.ada_ln_switch_freq] 222 | scale, shift, gate = ada_ln_block(y).chunk(3, dim=-1) 223 | x = checkpoint(block, x, scale, shift, gate, use_reentrant=False) 224 | else: 225 | for i, block in enumerate(self.res_blocks): 226 | if i > 0 and i % self.ada_ln_switch_freq == 0: 227 | ada_ln_block = self.ada_ln_blocks[i // self.ada_ln_switch_freq] 228 | scale, shift, gate = ada_ln_block(y).chunk(3, dim=-1) 229 | x = block(x, scale, shift, gate) 230 | 231 | return self.final_layer(x, y) 232 | -------------------------------------------------------------------------------- /SphereAR/vae.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # taming-transformers: https://github.com/CompVis/taming-transformers 3 | # maskgit: https://github.com/google-research/maskgit 4 | # llamagen: https://github.com/FoundationVision/LlamaGen 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d 11 | from .psd import PowerSphericalDistribution, l2_norm 12 | 13 | 14 | class VAE(nn.Module): 15 | 16 | def __init__( 17 | self, 18 | latent_dim=16, 19 | image_size=256, 20 | patch_size=16, 21 | z_channels=512, 22 | cnn_chs=[64, 64, 128, 256, 512], 23 | encoder_vit_layers=6, 24 | decoder_vit_layers=12, 25 | ): 26 | super().__init__() 27 | self.z_channels = z_channels 28 | n_head = z_channels // 64 29 | self.encoder = ViTEncoder( 30 | n_layers=encoder_vit_layers, 31 | d_model=z_channels, 32 | n_heads=n_head, 33 | cnn_chs=cnn_chs, 34 | image_size=image_size, 35 | patch_size=patch_size, 36 | ) 37 | self.decoder = ViTDecoder( 38 | n_layers=decoder_vit_layers, 39 | 
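# The decoder mirrors the encoder but with twice the ViT depth by default
# (decoder_vit_layers=12 vs. encoder_vit_layers=6) and with the CNN channel
# schedule reversed for the upsampling path (cnn_chs[::-1] below).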
d_model=z_channels, 40 | n_heads=n_head, 41 | cnn_chs=cnn_chs[::-1], 42 | image_size=image_size, 43 | patch_size=patch_size, 44 | ) 45 | self.latent_dim = latent_dim 46 | self.quant_proj = nn.Linear(z_channels, latent_dim + 1, bias=True) 47 | self.post_quant_proj = nn.Linear(latent_dim, z_channels, bias=False) 48 | 49 | def initialize_weights(self): 50 | 51 | self.quant_proj.reset_parameters() 52 | self.post_quant_proj.reset_parameters() 53 | self.encoder.output.reset_parameters() 54 | 55 | def normalize(self, x): 56 | x = l2_norm(x) 57 | x = x * (self.latent_dim**0.5) 58 | return x 59 | 60 | def encode(self, x): 61 | x = self.encoder(x) 62 | x = self.quant_proj(x) 63 | mu = x[..., :-1] 64 | kappa = x[..., -1] 65 | mu = l2_norm(mu) 66 | kappa = F.softplus(kappa) + 1.0 67 | qz = PowerSphericalDistribution(mu, kappa) 68 | loss = qz.kl_to_uniform() 69 | x = qz.rsample() 70 | x = x * (self.latent_dim**0.5) 71 | return x, loss.mean() 72 | 73 | def decode(self, x): 74 | x = self.post_quant_proj(x) 75 | dec = self.decoder(x) 76 | return dec 77 | 78 | 79 | class ResDownBlock(nn.Module): 80 | def __init__(self, in_ch, out_ch): 81 | super().__init__() 82 | self.block = nn.Sequential( 83 | nn.GroupNorm(min(32, in_ch // 4), in_ch, eps=1e-6, affine=True), 84 | nn.SiLU(), 85 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False), 86 | nn.GroupNorm(min(32, out_ch // 4), out_ch, eps=1e-6, affine=True), 87 | nn.SiLU(), 88 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False), 89 | ) 90 | self.shortcut = nn.Conv2d( 91 | in_ch, out_ch, kernel_size=2, stride=2, padding=0, bias=False 92 | ) 93 | 94 | def forward(self, x): 95 | return self.shortcut(x) + self.block(x) 96 | 97 | 98 | class PatchifyNet(nn.Module): 99 | 100 | def __init__(self, chs): 101 | super().__init__() 102 | layers = [] 103 | for i in range(len(chs) - 1): 104 | layers.append(ResDownBlock(chs[i], chs[i + 1])) 105 | self.net = nn.Sequential(*layers) 106 | 107 | def forward(self, x): 108 | x = self.net(x) 109 | return x 110 | 111 | 112 | class NCHW_to_NLC(nn.Module): 113 | def __init__(self): 114 | super().__init__() 115 | 116 | def forward(self, x): 117 | n, c, h, w = x.shape 118 | x = x.permute(0, 2, 3, 1).view(n, h * w, c) 119 | return x 120 | 121 | 122 | class NLC_to_NCHW(nn.Module): 123 | def __init__(self): 124 | super().__init__() 125 | 126 | def forward(self, x): 127 | n, l, c = x.shape 128 | h = w = int(l**0.5) 129 | x = x.view(n, h, w, c).permute(0, 3, 1, 2) 130 | return x 131 | 132 | 133 | class ResUpBlock(nn.Module): 134 | def __init__(self, in_ch, out_ch): 135 | super().__init__() 136 | num_groups = min(32, out_ch // 4) 137 | self.block = nn.Sequential( 138 | nn.GroupNorm(min(32, in_ch // 4), in_ch, eps=1e-6, affine=True), 139 | nn.SiLU(), 140 | nn.ConvTranspose2d( 141 | in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False 142 | ), 143 | nn.GroupNorm(num_groups, out_ch, eps=1e-6, affine=True), 144 | nn.SiLU(), 145 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False), 146 | ) 147 | self.shortcut = nn.ConvTranspose2d( 148 | in_ch, out_ch, kernel_size=2, stride=2, bias=False 149 | ) 150 | self.block2 = nn.Sequential( 151 | nn.GroupNorm(num_groups, out_ch, eps=1e-6, affine=True), 152 | nn.SiLU(), 153 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False), 154 | nn.GroupNorm(num_groups, out_ch, eps=1e-6, affine=True), 155 | nn.SiLU(), 156 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False), 157 | ) 158 | 159 | def 
forward(self, x): 160 | h = self.block(x) 161 | x = self.shortcut(x) 162 | x = x + h 163 | x = x + self.block2(x) 164 | return x 165 | 166 | 167 | class UnpatchifyNet(nn.Module): 168 | 169 | def __init__(self, chs): 170 | super().__init__() 171 | layers = [] 172 | for i in range(len(chs) - 1): 173 | layers.append(ResUpBlock(chs[i], chs[i + 1])) 174 | self.net = nn.Sequential(*layers) 175 | 176 | def forward(self, x): 177 | x = self.net(x) 178 | return x 179 | 180 | 181 | class ViTEncoder(nn.Module): 182 | 183 | def __init__( 184 | self, 185 | n_layers=6, 186 | n_heads=8, 187 | d_model=512, 188 | cnn_chs=[64, 64, 128, 256, 512], 189 | image_size=256, 190 | patch_size=16, 191 | ): 192 | super().__init__() 193 | self.conv_in = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 194 | self.patchify = nn.Sequential( 195 | PatchifyNet(cnn_chs), 196 | NCHW_to_NLC(), 197 | ) 198 | assert cnn_chs[-1] == d_model 199 | self.register_num_tokens = 4 200 | self.register_token = nn.Embedding(self.register_num_tokens, d_model) 201 | self.layers = nn.ModuleList() 202 | for _ in range(n_layers): 203 | self.layers.append( 204 | TransformerBlock( 205 | d_model, 206 | n_heads, 207 | 0.0, 208 | 0.0, 209 | drop_path=0.0, 210 | causal=False, 211 | ) 212 | ) 213 | self.norm = nn.RMSNorm(d_model, eps=1e-6) 214 | self.output = nn.Linear(d_model, d_model, bias=False) 215 | raw_2d_pos = get_2d_pos(image_size, patch_size) 216 | self.register_buffer( 217 | "freqs_cis", 218 | precompute_freqs_cis_2d( 219 | raw_2d_pos, d_model // n_heads, cls_token_num=self.register_num_tokens 220 | ).clone(), 221 | ) 222 | 223 | def forward(self, image): 224 | x = self.conv_in(image) 225 | x = self.patchify(x) 226 | x_null = self.register_token.weight.view(1, -1, x.shape[-1]).expand( 227 | x.shape[0], -1, -1 228 | ) 229 | x = torch.cat([x_null, x], dim=1) 230 | for layer in self.layers: 231 | x = layer(x, self.freqs_cis) 232 | x = x[:, self.register_num_tokens :, :] 233 | x = self.output(self.norm(x)) 234 | return x 235 | 236 | 237 | class ViTDecoder(nn.Module): 238 | def __init__( 239 | self, 240 | n_layers=12, 241 | n_heads=8, 242 | d_model=512, 243 | cnn_chs=[512, 256, 128, 64, 64], 244 | image_size=256, 245 | patch_size=16, 246 | ): 247 | super().__init__() 248 | self.d_model = d_model 249 | assert d_model == cnn_chs[0] 250 | self.conv_in = nn.Sequential( 251 | NLC_to_NCHW(), 252 | nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False), 253 | NCHW_to_NLC(), 254 | ) 255 | self.register_num_tokens = 4 256 | self.register_token = nn.Embedding(self.register_num_tokens, d_model) 257 | self.layers = nn.ModuleList() 258 | for _ in range(n_layers): 259 | self.layers.append( 260 | TransformerBlock( 261 | d_model, 262 | n_heads, 263 | 0.0, 264 | 0.0, 265 | drop_path=0.0, 266 | causal=False, 267 | ) 268 | ) 269 | self.unpatchify = nn.Sequential( 270 | NLC_to_NCHW(), 271 | UnpatchifyNet(chs=cnn_chs), 272 | ) 273 | self.conv_out = nn.Sequential( 274 | nn.GroupNorm(16, cnn_chs[-1], eps=1e-6, affine=True), 275 | nn.SiLU(), 276 | nn.Conv2d(cnn_chs[-1], 3, kernel_size=3, stride=1, padding=1, bias=False), 277 | ) 278 | 279 | raw_2d_pos = get_2d_pos(image_size, patch_size) 280 | self.register_buffer( 281 | "freqs_cis", 282 | precompute_freqs_cis_2d( 283 | raw_2d_pos, d_model // n_heads, cls_token_num=self.register_num_tokens 284 | ).clone(), 285 | ) 286 | 287 | @torch.compile() 288 | def forward(self, x): 289 | x = self.conv_in(x) 290 | x_null = self.register_token.weight.view(1, -1, x.shape[-1]).expand( 291 | 
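# The 4 learned register tokens are broadcast across the batch here, take part
# in attention at every layer, and are stripped again before unpatchify; they
# act as scratch slots (cf. ViT "registers") and never reach the pixel output.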
x.shape[0], -1, -1 292 | ) 293 | x = torch.cat([x_null, x], dim=1) 294 | for layer in self.layers: 295 | x = layer(x, self.freqs_cis) 296 | x = x[:, self.register_num_tokens :, :] 297 | x = self.unpatchify(x) 298 | x = self.conv_out(x) 299 | return x 300 | -------------------------------------------------------------------------------- /SphereAR/layers.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from flash_attn import flash_attn_func 7 | from torch.nn import RMSNorm 8 | from torch.nn import functional as F 9 | 10 | 11 | def drop_path( 12 | x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True 13 | ): 14 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 15 | 16 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 17 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 18 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 19 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 20 | 'survival rate' as the argument. 21 | 22 | """ 23 | if drop_prob == 0.0 or not training: 24 | return x 25 | keep_prob = 1 - drop_prob 26 | shape = (x.shape[0],) + (1,) * ( 27 | x.ndim - 1 28 | ) # work with diff dim tensors, not just 2D ConvNets 29 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 30 | if keep_prob > 0.0 and scale_by_keep: 31 | random_tensor.div_(keep_prob) 32 | return x * random_tensor 33 | 34 | 35 | class DropPath(torch.nn.Module): 36 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" 37 | 38 | def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): 39 | super(DropPath, self).__init__() 40 | self.drop_prob = drop_prob 41 | self.scale_by_keep = scale_by_keep 42 | 43 | def forward(self, x): 44 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 45 | 46 | def extra_repr(self): 47 | return f"drop_prob={round(self.drop_prob,3):0.3f}" 48 | 49 | 50 | def find_multiple(n: int, k: int): 51 | if n % k == 0: 52 | return n 53 | return n + k - (n % k) 54 | 55 | 56 | @lru_cache(maxsize=16) 57 | def get_causal_mask(seq_q, seq_k, device): 58 | offset = seq_k - seq_q 59 | i = torch.arange(seq_q, device=device).unsqueeze(1) 60 | j = torch.arange(seq_k, device=device).unsqueeze(0) 61 | causal_mask = (j > (offset + i)).bool() 62 | causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) 63 | return causal_mask 64 | 65 | 66 | class Attention(nn.Module): 67 | def __init__( 68 | self, 69 | dim, 70 | n_head, 71 | attn_dropout_p, 72 | resid_dropout_p, 73 | causal: bool = True, 74 | ): 75 | super().__init__() 76 | assert dim % n_head == 0 77 | self.dim = dim 78 | self.head_dim = dim // n_head 79 | self.scale = self.head_dim**-0.5 80 | self.n_head = n_head 81 | total_kv_dim = (self.n_head * 3) * self.head_dim 82 | 83 | self.wqkv = nn.Linear(dim, total_kv_dim, bias=False) 84 | self.wo = nn.Linear(dim, dim, bias=False) 85 | 86 | self.attn_dropout_p = attn_dropout_p 87 | self.resid_dropout = nn.Dropout(resid_dropout_p) 88 | self.causal = causal 89 | 90 | self.k_cache = None 91 | self.v_cache = None 92 | self.kv_cache_size = None 93 | 94 | def enable_kv_cache(self, bsz, max_seq_len): 95 | if self.kv_cache_size != (bsz, 
max_seq_len): 96 | device = self.wo.weight.device 97 | dtype = self.wo.weight.dtype 98 | self.k_cache = torch.zeros( 99 | (bsz, self.n_head, max_seq_len, self.head_dim), 100 | device=device, 101 | dtype=dtype, 102 | ) 103 | self.v_cache = torch.zeros( 104 | (bsz, self.n_head, max_seq_len, self.head_dim), 105 | device=device, 106 | dtype=dtype, 107 | ) 108 | self.kv_cache_size = (bsz, max_seq_len) 109 | 110 | def update_kv_cache( 111 | self, start_pos, end_pos, keys: torch.Tensor, values: torch.Tensor 112 | ): 113 | self.k_cache[:, :, start_pos:end_pos, :] = keys 114 | self.v_cache[:, :, start_pos:end_pos, :] = values 115 | return ( 116 | self.k_cache[:, :, :end_pos, :], 117 | self.v_cache[:, :, :end_pos, :], 118 | ) 119 | 120 | def naive_attention(self, xq, keys, values, is_causal): 121 | xq = xq * self.scale 122 | # q: [B, H, 1, D], k: [B, H, D, L] -> attn [B, H, 1, L] 123 | attn = xq @ keys.transpose(-1, -2) 124 | seq_q, seq_k = attn.shape[-2], attn.shape[-1] 125 | if is_causal and seq_q > 1: 126 | causal_mask = get_causal_mask(seq_q, seq_k, attn.device) 127 | attn.masked_fill_(causal_mask, float("-inf")) 128 | attn = torch.softmax(attn, dim=-1) 129 | if self.attn_dropout_p > 0 and self.training: 130 | attn = F.dropout(attn, p=self.attn_dropout_p, training=self.training) 131 | # [B, H, 1, L] @ [B, H, L, D] -> [B, H, 1, D] 132 | return attn @ values 133 | 134 | def forward( 135 | self, 136 | x: torch.Tensor, 137 | freqs_cis: torch.Tensor = None, 138 | start_pos: Optional[int] = None, 139 | end_pos: Optional[int] = None, 140 | ): 141 | bsz, seqlen, _ = x.shape 142 | xq, xk, xv = self.wqkv(x).chunk(3, dim=-1) 143 | 144 | xq = xq.view(bsz, seqlen, self.n_head, self.head_dim) 145 | xk = xk.view(bsz, seqlen, self.n_head, self.head_dim) 146 | xv = xv.view(bsz, seqlen, self.n_head, self.head_dim) 147 | 148 | if freqs_cis is not None: 149 | xq = apply_rotary_emb(xq, freqs_cis) 150 | xk = apply_rotary_emb(xk, freqs_cis) 151 | 152 | is_causal = self.causal 153 | if self.k_cache is not None and start_pos is not None: 154 | xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv)) 155 | keys, values = self.update_kv_cache(start_pos, end_pos, xk, xv) 156 | output = self.naive_attention(xq, keys, values, is_causal=is_causal) 157 | output = output.transpose(1, 2).contiguous() 158 | else: 159 | output = flash_attn_func( 160 | xq, 161 | xk, 162 | xv, 163 | causal=is_causal, 164 | dropout_p=self.attn_dropout_p if self.training else 0, 165 | ) 166 | 167 | output = output.view(bsz, seqlen, self.dim) 168 | 169 | output = self.resid_dropout(self.wo(output)) 170 | return output 171 | 172 | 173 | class FeedForward(nn.Module): 174 | 175 | def __init__(self, dim, dropout_p=0.1, mlp_ratio=4.0): 176 | super().__init__() 177 | hidden_dim = mlp_ratio * dim 178 | hidden_dim = int(2 * hidden_dim / 3) 179 | hidden_dim = find_multiple(hidden_dim, 256) 180 | 181 | self.w1 = nn.Linear(dim, hidden_dim * 2, bias=False) 182 | self.w2 = nn.Linear(hidden_dim, dim, bias=False) 183 | self.ffn_dropout = nn.Dropout(dropout_p) 184 | 185 | def forward(self, x): 186 | h1, h2 = self.w1(x).chunk(2, dim=-1) 187 | return self.ffn_dropout(self.w2(F.silu(h1) * h2)) 188 | 189 | 190 | class TransformerBlock(nn.Module): 191 | def __init__( 192 | self, 193 | dim, 194 | n_head, 195 | attn_dropout_p: float = 0.0, 196 | resid_dropout_p: float = 0.0, 197 | drop_path: float = 0.0, 198 | causal: bool = True, 199 | ): 200 | super().__init__() 201 | self.attention = Attention( 202 | dim=dim, 203 | n_head=n_head, 204 | attn_dropout_p=attn_dropout_p, 
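# Incremental decoding (see SphereAR.sample): call enable_kv_cache(bsz, max_len)
# on the Attention module once, then pass start_pos/end_pos per step, e.g. with
# hypothetical shapes:
#   attn.enable_kv_cache(2, 272)  # 16 class tokens + 16x16 patch tokens
#   y = attn(x_step, freqs_cis_step, start_pos=16, end_pos=17)
# When a cache exists, flash_attn is bypassed and naive_attention runs the
# single-query case over the cached keys/values.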
205 | resid_dropout_p=resid_dropout_p, 206 | causal=causal, 207 | ) 208 | self.feed_forward = FeedForward( 209 | dim=dim, 210 | dropout_p=resid_dropout_p, 211 | ) 212 | self.attention_norm = RMSNorm(dim, eps=1e-6) 213 | self.ffn_norm = RMSNorm(dim, eps=1e-6) 214 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 215 | 216 | def forward( 217 | self, 218 | x: torch.Tensor, 219 | freqs_cis: torch.Tensor, 220 | ): 221 | h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis)) 222 | out = h + self.drop_path(self.feed_forward(self.ffn_norm(h))) 223 | return out 224 | 225 | def forward_onestep( 226 | self, 227 | x: torch.Tensor, 228 | freqs_cis: torch.Tensor, 229 | start_pos: int, 230 | end_pos: int, 231 | ): 232 | h = x + self.drop_path( 233 | self.attention(self.attention_norm(x), freqs_cis, start_pos, end_pos) 234 | ) 235 | out = h + self.drop_path(self.feed_forward(self.ffn_norm(h))) 236 | return out 237 | 238 | 239 | def get_2d_pos(resolution, patch_size, num_scales=1): 240 | max_pos = resolution // patch_size 241 | coords_list = [] 242 | 243 | for i in range(num_scales): 244 | scale = 2 ** (num_scales - i - 1) 245 | P = max(resolution // scale // patch_size, 1) 246 | edge = float(max_pos) / P 247 | centers = (torch.arange(P, dtype=torch.float32) + 0.5) * edge 248 | grid_y, grid_x = torch.meshgrid(centers, centers, indexing="ij") 249 | coords = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1) 250 | coords_list.append(coords) 251 | 252 | return torch.cat(coords_list, dim=0) 253 | 254 | 255 | def precompute_freqs_cis_2d( 256 | pos_2d, n_elem: int, base: float = 10000, cls_token_num=120 257 | ): 258 | # split the dimension into half, one for x and one for y 259 | half_dim = n_elem // 2 260 | freqs = 1.0 / ( 261 | base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim) 262 | ) 263 | t = pos_2d + 1.0 264 | if cls_token_num > 0: 265 | t = torch.cat( 266 | [torch.zeros((cls_token_num, 2), device=freqs.device), t], 267 | dim=0, 268 | ) 269 | freqs = torch.outer(t.flatten(), freqs).view(*t.shape[:-1], -1) 270 | return torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) 271 | 272 | 273 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor): 274 | # x: (bs, seq_len, n_head, head_dim) 275 | # freqs_cis (seq_len, head_dim // 2, 2) 276 | xshaped = x.float().reshape( 277 | *x.shape[:-1], -1, 2 278 | ) # (bs, seq_len, n_head, head_dim//2, 2) 279 | freqs_cis = freqs_cis.view( 280 | 1, xshaped.size(1), 1, xshaped.size(3), 2 281 | ) # (1, seq_len, 1, head_dim//2, 2) 282 | x_out2 = torch.stack( 283 | [ 284 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 285 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], 286 | ], 287 | dim=-1, 288 | ) 289 | x_out2 = x_out2.flatten(3) 290 | return x_out2.type_as(x) 291 | -------------------------------------------------------------------------------- /SphereAR/model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | from torch.utils.checkpoint import checkpoint 8 | 9 | from .diff_head import DiffHead 10 | from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d 11 | from .vae import VAE 12 | 13 | 14 | def get_model_args(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--model", type=str, 
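# Valid sizes are SphereAR-B, SphereAR-L, and SphereAR-H; the SphereAR_models
# registry that populates `choices` is defined at the bottom of this file.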
choices=list(SphereAR_models.keys()), default="SphereAR-L" 18 | ) 19 | parser.add_argument("--vae-only", action="store_true", help="only train vae") 20 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 21 | parser.add_argument("--patch-size", type=int, default=16, choices=[16]) 22 | parser.add_argument("--num-classes", type=int, default=1000) 23 | parser.add_argument("--cls-token-num", type=int, default=16) 24 | parser.add_argument("--latent-dim", type=int, default=16) 25 | parser.add_argument("--diff-batch-mul", type=int, default=4) 26 | parser.add_argument("--grad-checkpointing", action="store_true") 27 | return parser 28 | 29 | 30 | def create_model(args, device): 31 | model = SphereAR_models[args.model]( 32 | resolution=args.image_size, 33 | patch_size=args.patch_size, 34 | latent_dim=args.latent_dim, 35 | vae_only=args.vae_only, 36 | diff_batch_mul=args.diff_batch_mul, 37 | cls_token_num=args.cls_token_num, 38 | num_classes=args.num_classes, 39 | grad_checkpointing=args.grad_checkpointing, 40 | ).to(device, memory_format=torch.channels_last) 41 | return model 42 | 43 | 44 | class SphereAR(nn.Module): 45 | 46 | def __init__( 47 | self, 48 | dim, 49 | n_layer, 50 | n_head, 51 | diff_layers, 52 | diff_dim, 53 | diff_adanln_layers, 54 | latent_dim, 55 | patch_size, 56 | resolution, 57 | diff_batch_mul, 58 | vae_only=False, 59 | grad_checkpointing=False, 60 | cls_token_num=16, 61 | num_classes: int = 1000, 62 | class_dropout_prob: float = 0.1, 63 | ): 64 | super().__init__() 65 | 66 | self.n_layer = n_layer 67 | self.resolution = resolution 68 | self.patch_size = patch_size 69 | self.num_classes = num_classes 70 | self.cls_token_num = cls_token_num 71 | self.class_dropout_prob = class_dropout_prob 72 | self.latent_dim = latent_dim 73 | 74 | self.vae = VAE( 75 | latent_dim=latent_dim, image_size=resolution, patch_size=patch_size 76 | ) 77 | self.vae_only = vae_only 78 | self.grad_checkpointing = grad_checkpointing 79 | 80 | if not vae_only: 81 | self.cls_embedding = nn.Embedding(num_classes + 1, dim * self.cls_token_num) 82 | self.proj_in = nn.Linear(latent_dim, dim, bias=True) 83 | self.emb_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True) 84 | self.h, self.w = resolution // patch_size, resolution // patch_size 85 | self.total_tokens = self.h * self.w + self.cls_token_num 86 | 87 | self.layers = torch.nn.ModuleList() 88 | for layer_id in range(n_layer): 89 | self.layers.append( 90 | TransformerBlock( 91 | dim, 92 | n_head, 93 | causal=True, 94 | ) 95 | ) 96 | 97 | self.norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True) 98 | self.pos_for_diff = nn.Embedding(self.h * self.w, dim) 99 | self.head = DiffHead( 100 | ch_target=latent_dim, 101 | ch_cond=dim, 102 | ch_latent=diff_dim, 103 | depth_latent=diff_layers, 104 | depth_adanln=diff_adanln_layers, 105 | grad_checkpointing=grad_checkpointing, 106 | ) 107 | self.diff_batch_mul = diff_batch_mul 108 | 109 | patch_2d_pos = get_2d_pos(resolution, patch_size) 110 | 111 | self.register_buffer( 112 | "freqs_cis", 113 | precompute_freqs_cis_2d( 114 | patch_2d_pos, 115 | dim // n_head, 116 | 10000, 117 | cls_token_num=self.cls_token_num, 118 | )[:-1], 119 | persistent=False, 120 | ) 121 | self.freeze_vae() 122 | 123 | self.initialize_weights() 124 | 125 | def non_decay_keys(self): 126 | return ["proj_in", "cls_embedding"] 127 | 128 | def freeze_module(self, module: nn.Module): 129 | for param in module.parameters(): 130 | param.requires_grad = False 131 | 132 | def freeze_vae(self): 133 | 
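# The AR transformer trains on top of a frozen VAE: train.py loads a
# --trained-vae checkpoint (or weights from a prior --vae-only run) and calls
# freeze_vae() again after loading, so VAE params receive no gradients here.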
self.freeze_module(self.vae) 134 | self.vae.eval() 135 | 136 | def initialize_weights(self): 137 | # Initialize nn.Linear and nn.Embedding 138 | self.apply(self.__init_weights) 139 | if not self.vae_only: 140 | self.head.initialize_weights() 141 | self.vae.initialize_weights() 142 | 143 | def __init_weights(self, module): 144 | std = 0.02 145 | if isinstance(module, nn.Linear): 146 | module.weight.data.normal_(mean=0.0, std=std) 147 | if module.bias is not None: 148 | module.bias.data.zero_() 149 | elif isinstance(module, nn.Embedding): 150 | module.weight.data.normal_(mean=0.0, std=std) 151 | 152 | def drop_label(self, class_id): 153 | if self.class_dropout_prob > 0.0 and self.training: 154 | is_drop = ( 155 | torch.rand(class_id.shape, device=class_id.device) 156 | < self.class_dropout_prob 157 | ) 158 | class_id = torch.where(is_drop, self.num_classes, class_id) 159 | return class_id 160 | 161 | def forward( 162 | self, 163 | images, 164 | class_id, 165 | ): 166 | 167 | vae_latent, kl_loss = self.vae.encode(images) 168 | 169 | if not self.vae_only: 170 | x = vae_latent.detach() 171 | x = self.proj_in(x[:, :-1, :]) 172 | class_id = self.drop_label(class_id) 173 | bsz = x.shape[0] 174 | c = self.cls_embedding(class_id).view(bsz, self.cls_token_num, -1) 175 | x = torch.cat([c, x], dim=1) 176 | x = self.emb_norm(x) 177 | 178 | if self.grad_checkpointing and self.training: 179 | for layer in self.layers: 180 | block = partial(layer.forward, freqs_cis=self.freqs_cis) 181 | x = checkpoint(block, x, use_reentrant=False) 182 | else: 183 | for layer in self.layers: 184 | x = layer(x, self.freqs_cis) 185 | 186 | x = x[:, -self.h * self.w :, :] 187 | x = self.norm(x) 188 | x = x + self.pos_for_diff.weight 189 | 190 | target = vae_latent.detach() 191 | x = x.view(-1, x.shape[-1]) 192 | target = target.view(-1, target.shape[-1]) 193 | 194 | x = x.repeat(self.diff_batch_mul, 1) 195 | target = target.repeat(self.diff_batch_mul, 1) 196 | loss = self.head(target, x) 197 | recon = None 198 | else: 199 | loss = torch.tensor(0.0, device=images.device, dtype=images.dtype) 200 | recon = self.vae.decode(vae_latent) 201 | 202 | return loss, kl_loss, recon 203 | 204 | def enable_kv_cache(self, bsz): 205 | for layer in self.layers: 206 | layer.attention.enable_kv_cache(bsz, self.total_tokens) 207 | 208 | @torch.compile() 209 | def forward_model(self, x, start_pos, end_pos): 210 | x = self.emb_norm(x) 211 | for layer in self.layers: 212 | x = layer.forward_onestep( 213 | x, self.freqs_cis[start_pos:end_pos,], start_pos, end_pos 214 | ) 215 | x = self.norm(x) 216 | return x 217 | 218 | def head_sample(self, x, diff_pos, sample_steps, cfg_scale, cfg_schedule="linear"): 219 | x = x + self.pos_for_diff.weight[diff_pos : diff_pos + 1, :] 220 | x = x.view(-1, x.shape[-1]) 221 | seq_len = self.h * self.w 222 | if cfg_scale > 1.0: 223 | if cfg_schedule == "constant": 224 | cfg_iter = cfg_scale 225 | elif cfg_schedule == "linear": 226 | start = 1.0 227 | cfg_iter = start + (cfg_scale - start) * diff_pos / seq_len 228 | else: 229 | raise NotImplementedError(f"unknown cfg_schedule {cfg_schedule}") 230 | else: 231 | cfg_iter = 1.0 232 | pred = self.head.sample(x, num_sampling_steps=sample_steps, cfg=cfg_iter) 233 | pred = pred.view(-1, 1, pred.shape[-1]) 234 | # Important: normalize here, for both next-token prediction and vae decoding 235 | pred = self.vae.normalize(pred) 236 | return pred 237 | 238 | @torch.no_grad() 239 | def sample(self, cond, sample_steps, cfg_scale=1.0, cfg_schedule="linear"): 240 | self.eval() 241 | if 
cfg_scale > 1.0: 242 | cond_null = torch.ones_like(cond) * self.num_classes 243 | cond_combined = torch.cat([cond, cond_null]) 244 | else: 245 | cond_combined = cond 246 | bsz = cond_combined.shape[0] 247 | act_bsz = bsz // 2 if cfg_scale > 1.0 else bsz 248 | self.enable_kv_cache(bsz) 249 | 250 | c = self.cls_embedding(cond_combined).view(bsz, self.cls_token_num, -1) 251 | last_pred = None 252 | all_preds = [] 253 | for i in range(self.h * self.w): 254 | if i == 0: 255 | x = self.forward_model(c, 0, self.cls_token_num) 256 | else: 257 | x = self.proj_in(last_pred) 258 | x = self.forward_model( 259 | x, i + self.cls_token_num - 1, i + self.cls_token_num 260 | ) 261 | last_pred = self.head_sample( 262 | x[:, -1:, :], 263 | i, 264 | sample_steps, 265 | cfg_scale, 266 | cfg_schedule, 267 | ) 268 | all_preds.append(last_pred) 269 | 270 | x = torch.cat(all_preds, dim=-2)[:act_bsz] 271 | recon = self.vae.decode(x) 272 | return recon 273 | 274 | def get_fsdp_wrap_module_list(self): 275 | return list(self.layers) 276 | 277 | 278 | def SphereAR_H(**kwargs): 279 | return SphereAR( 280 | n_layer=40, 281 | n_head=20, 282 | dim=1280, 283 | diff_layers=12, 284 | diff_dim=1280, 285 | diff_adanln_layers=3, 286 | **kwargs, 287 | ) 288 | 289 | 290 | def SphereAR_L(**kwargs): 291 | return SphereAR( 292 | n_layer=32, 293 | n_head=16, 294 | dim=1024, 295 | diff_layers=8, 296 | diff_dim=1024, 297 | diff_adanln_layers=2, 298 | **kwargs, 299 | ) 300 | 301 | 302 | def SphereAR_B(**kwargs): 303 | return SphereAR( 304 | n_layer=24, 305 | n_head=12, 306 | dim=768, 307 | diff_layers=6, 308 | diff_dim=768, 309 | diff_adanln_layers=2, 310 | **kwargs, 311 | ) 312 | 313 | 314 | SphereAR_models = { 315 | "SphereAR-B": SphereAR_B, 316 | "SphereAR-L": SphereAR_L, 317 | "SphereAR-H": SphereAR_H, 318 | } 319 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # llamagen: https://github.com/FoundationVision/LlamaGen/ 3 | import math 4 | 5 | import torch 6 | 7 | torch.backends.cuda.matmul.allow_tf32 = True 8 | torch.backends.cudnn.allow_tf32 = True 9 | 10 | import inspect 11 | import os 12 | import shutil 13 | import time 14 | from copy import deepcopy 15 | from multiprocessing.pool import ThreadPool 16 | 17 | import torch.distributed as dist 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | from torch.utils.data import DataLoader 20 | from torch.utils.data.distributed import DistributedSampler 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | from SphereAR.dataset import build_dataset 24 | from SphereAR.gan.gan_loss import GANLoss 25 | from SphereAR.model import create_model, get_model_args 26 | from SphereAR.utils import create_logger, requires_grad, update_ema 27 | 28 | 29 | def create_optimizer(model, weight_decay, learning_rate, betas, logger): 30 | def is_decay_param(name, param, no_decay_keys): 31 | for key in no_decay_keys: 32 | if key in name: 33 | return False 34 | if param.dim() < 2: 35 | return False 36 | return True 37 | 38 | # start with all of the candidate parameters 39 | param_dict = {pn: p for pn, p in model.named_parameters()} 40 | # filter out those that do not require grad 41 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 42 | no_decay_keys = model.non_decay_keys() if hasattr(model, "non_decay_keys") else [] 43 | decay_params = [ 44 | p for n, p in param_dict.items() if is_decay_param(n, p, 
no_decay_keys) 45 | ] 46 | nodecay_params = [ 47 | p for n, p in param_dict.items() if not is_decay_param(n, p, no_decay_keys) 48 | ] 49 | optim_groups = [ 50 | {"params": decay_params, "weight_decay": weight_decay}, 51 | {"params": nodecay_params, "weight_decay": 0.0}, 52 | ] 53 | num_decay_params = sum(p.numel() for p in decay_params) 54 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 55 | logger.info( 56 | f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters" 57 | ) 58 | logger.info( 59 | f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters" 60 | ) 61 | # Create AdamW optimizer and use the fused version if it is available 62 | fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters 63 | extra_args = dict(fused=True) if fused_available else dict() 64 | optimizer = torch.optim.AdamW( 65 | optim_groups, lr=learning_rate, betas=betas, **extra_args 66 | ) 67 | logger.info(f"using fused AdamW: {fused_available}") 68 | return optimizer 69 | 70 | 71 | def adjust_learning_rate(args, cur_steps, total_steps, optimizer): 72 | if cur_steps < args.warmup_steps and args.warmup_steps > 0: 73 | lr = args.lr * cur_steps / args.warmup_steps 74 | elif ( 75 | args.decay_start > 0 76 | and cur_steps >= args.decay_start 77 | and args.decay_start < total_steps 78 | ): 79 | # decay from decay_start to total_steps, with learning rate cosine decay to min_lr 80 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( 81 | 1.0 82 | + math.cos( 83 | math.pi 84 | * (cur_steps - args.decay_start) 85 | / max(total_steps - args.decay_start, 1e-8) 86 | ) 87 | ) 88 | else: 89 | lr = args.lr 90 | for param_group in optimizer.param_groups: 91 | param_group["lr"] = lr 92 | return lr 93 | 94 | 95 | def init_distributed_mode(args): 96 | args.rank = int(os.environ["RANK"]) 97 | args.world_size = int(os.environ["WORLD_SIZE"]) 98 | args.gpu = int(os.environ["LOCAL_RANK"]) 99 | 100 | args.distributed = True 101 | device = torch.device("cuda", args.gpu) 102 | torch.cuda.set_device(device) 103 | 104 | print(f"| distributed init (rank {args.rank}, gpu {args.gpu})", flush=True) 105 | 106 | dist.init_process_group( 107 | backend="nccl", 108 | init_method="env://", 109 | world_size=args.world_size, 110 | rank=args.rank, 111 | device_id=device, 112 | ) 113 | dist.barrier() 114 | return device 115 | 116 | 117 | def get_orig_model(model): 118 | if isinstance(model, DDP): 119 | model = model.module 120 | if hasattr(model, "_orig_mod"): 121 | model = model._orig_mod 122 | return model 123 | 124 | 125 | def _linear_decay_ratio(epoch: int, start: int, end: int) -> float: 126 | if start < 0 or end <= start: 127 | return 1.0 128 | if epoch < start: 129 | r = 1.0 130 | elif epoch >= end: 131 | r = 0.0 132 | else: 133 | r = 1.0 - (epoch - start) / float(end - start) 134 | return max(0.0, min(1.0, r)) 135 | 136 | 137 | def create_dataloader(dataset, sampler, epoch, args): 138 | sampler.set_epoch(epoch) 139 | dataset.set_epoch(epoch) 140 | # linear decay of aug_ratio 141 | aug_ratio = _linear_decay_ratio( 142 | epoch, args.aug_decay_start_epoch, args.aug_decay_end_epoch 143 | ) 144 | dataset.set_aug_ratio(aug_ratio) 145 | loader = DataLoader( 146 | dataset, 147 | batch_size=int(args.global_batch_size // dist.get_world_size()), 148 | shuffle=False, 149 | sampler=sampler, 150 | num_workers=args.num_workers, 151 | pin_memory=True, 152 | drop_last=True, 153 | ) 154 | return loader 155 | 156 | 157 | def update_loss_dict(running_loss_dict, 
**kwargs): 158 | for k, v in kwargs.items(): 159 | if v is not None: 160 | if torch.is_tensor(v): 161 | v = v.item() 162 | running_loss_dict[k] = running_loss_dict.get(k, 0.0) + v 163 | return running_loss_dict 164 | 165 | 166 | def vae_loss( 167 | args, 168 | gan_model, 169 | images, 170 | recon, 171 | kl_loss, 172 | train_steps, 173 | running_loss_dict, 174 | ): 175 | if args.vae_only: 176 | recon_loss, p_loss, gen_loss = gan_model( 177 | images, 178 | recon, 179 | optimizer_idx=0, 180 | global_step=train_steps + 1, 181 | ) 182 | running_loss_dict = update_loss_dict( 183 | running_loss_dict, 184 | kl_loss=kl_loss, 185 | recon_loss=recon_loss, 186 | p_loss=p_loss, 187 | gen_loss=gen_loss, 188 | ) 189 | loss = ( 190 | args.perceptual_weight * p_loss 191 | + gen_loss 192 | + args.reconstruction_weight * recon_loss 193 | + args.kl_weight * kl_loss 194 | ) 195 | return loss, running_loss_dict 196 | else: 197 | return 0.0, running_loss_dict 198 | 199 | 200 | def logging( 201 | running_loss_dict, 202 | running_gnorm, 203 | log_steps, 204 | steps_per_sec, 205 | train_steps, 206 | device, 207 | logger, 208 | tsb_writer, 209 | ): 210 | keys = sorted(running_loss_dict.keys()) 211 | running_losses = [running_loss_dict[k] for k in keys] 212 | # Reduce loss history over all processes: 213 | all_loss = torch.tensor( 214 | running_losses, 215 | device=device, 216 | ) 217 | dist.all_reduce(all_loss, op=dist.ReduceOp.SUM) 218 | 219 | avg_gnorm = running_gnorm / log_steps 220 | all_loss = [ 221 | (keys[i], all_loss[i].item() / dist.get_world_size() / log_steps) 222 | for i in range(len(keys)) 223 | ] 224 | loss_str = ", ".join([f"{k}: {v:.4f}" for k, v in all_loss]) 225 | 226 | logger.info( 227 | f"(step={train_steps:07d}): {loss_str} ,Train Steps/Sec: {steps_per_sec:.2f}, Train Grad Norm: {avg_gnorm:.4f}" 228 | ) 229 | if tsb_writer is not None: 230 | for k, v in all_loss: 231 | tsb_writer.add_scalar(f"train/{k}", v, train_steps) 232 | tsb_writer.add_scalar("train/steps_per_sec", steps_per_sec, train_steps) 233 | tsb_writer.add_scalar("train/grad_norm", avg_gnorm, train_steps) 234 | 235 | 236 | def copy_ckp_func(src_file, dest_path, cur_epoch, keep_freq): 237 | dest_file = os.path.join(dest_path, "last.pt") 238 | if os.path.exists(dest_file): 239 | shutil.copyfile(dest_file, os.path.join(dest_path, "prev.pt")) 240 | shutil.copyfile(src_file, dest_file) 241 | if cur_epoch > 0 and keep_freq > 0 and cur_epoch % keep_freq == 0: 242 | shutil.copyfile( 243 | dest_file, 244 | os.path.join(dest_path, f"epoch_{cur_epoch}.pt"), 245 | ) 246 | 247 | 248 | def main(args): 249 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 250 | 251 | # Setup DDP: 252 | device = init_distributed_mode(args) 253 | assert ( 254 | args.global_batch_size % dist.get_world_size() == 0 255 | ), f"Batch size must be divisible by world size." 
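# Seeding note: each rank derives a distinct seed below (global_seed *
# world_size + rank), e.g. global_seed=0 with world_size=8 gives ranks seeds
# 0..7, so augmentation and dropout streams differ across processes.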
256 | rank = dist.get_rank() 257 | seed = args.global_seed * dist.get_world_size() + rank 258 | torch.manual_seed(seed) 259 | torch.cuda.set_device(device) 260 | 261 | results_dir = args.results_dir 262 | 263 | if rank == 0: 264 | os.makedirs(args.results_dir, exist_ok=True) 265 | logger = create_logger(results_dir) 266 | logger.info(f"Experiment directory created at {results_dir}") 267 | ckp_async_thread = ThreadPool(processes=1) 268 | else: 269 | logger = create_logger(None) 270 | ckp_async_thread = None 271 | 272 | # training args 273 | logger.info(f"{args}") 274 | 275 | # training env 276 | logger.info( 277 | f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}." 278 | ) 279 | 280 | model = create_model(args, device) 281 | 282 | if args.trained_vae != "": 283 | vae_info = torch.load(args.trained_vae, map_location="cpu", weights_only=False) 284 | vae_info = vae_info["model"] 285 | res = model.load_state_dict(vae_info, strict=False) 286 | for k in res.missing_keys: 287 | if "vae" in k: 288 | raise ValueError(f"Fail to load VAE weights from {args.trained_vae}") 289 | model.freeze_vae() 290 | logger.info(f"loaded pretrained VAE from {args.trained_vae}") 291 | 292 | logger.info(model) 293 | logger.info(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}") 294 | 295 | if args.ema > 0: 296 | ema_model = deepcopy(model).to( 297 | device 298 | ) # Create an EMA of the model for use after training 299 | requires_grad(ema_model, False) 300 | logger.info( 301 | f"EMA Parameters: {sum(p.numel() for p in ema_model.parameters()):,}" 302 | ) 303 | 304 | # Setup optimizer 305 | optimizer = create_optimizer( 306 | model, args.weight_decay, args.lr, (args.beta1, args.beta2), logger 307 | ) 308 | 309 | # Setup gan loss for vae training 310 | if args.vae_only: 311 | gan_model = GANLoss( 312 | disc_start=args.disc_start, 313 | disc_weight=args.disc_weight, 314 | disc_type=args.disc_type, 315 | disc_loss=args.disc_loss, 316 | gen_adv_loss=args.gen_loss, 317 | image_size=args.image_size, 318 | reconstruction_loss=args.reconstruction_loss, 319 | ).to(device, memory_format=torch.channels_last) 320 | logger.info( 321 | f"Discriminator Parameters: {sum(p.numel() for p in gan_model.discriminator.parameters()):,}" 322 | ) 323 | optimizer_disc = create_optimizer( 324 | gan_model.discriminator, 325 | args.weight_decay, 326 | args.lr, 327 | (args.beta1, args.beta2), 328 | logger, 329 | ) 330 | else: 331 | gan_model = None 332 | optimizer_disc = None 333 | 334 | dataset = build_dataset(args) 335 | sampler = DistributedSampler( 336 | dataset, 337 | num_replicas=dist.get_world_size(), 338 | rank=rank, 339 | shuffle=True, 340 | seed=args.global_seed, 341 | ) 342 | 343 | logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})") 344 | 345 | checkpoint_path = f"{results_dir}/last.pt" 346 | total_steps = args.epochs * int(len(dataset) / args.global_batch_size) 347 | 348 | # Prepare models for training: 349 | if os.path.exists(checkpoint_path): 350 | checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) 351 | start_epoch = checkpoint["epochs"] 352 | train_steps = int(start_epoch * int(len(dataset) / args.global_batch_size)) 353 | model.load_state_dict(checkpoint["model"], strict=True) 354 | if args.ema > 0: 355 | ema_model.load_state_dict( 356 | checkpoint["ema"] if "ema" in checkpoint else checkpoint["model"] 357 | ) 358 | optimizer.load_state_dict(checkpoint["optimizer"]) 359 | if args.vae_only: 360 | gan_model.discriminator.load_state_dict( 361 | 
checkpoint["model_disc"], strict=True 362 | ) 363 | optimizer_disc.load_state_dict(checkpoint["optimizer_disc"]) 364 | 365 | del checkpoint 366 | 367 | logger.info(f"Resume training from checkpoint: {checkpoint_path}") 368 | logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}") 369 | else: 370 | train_steps = 0 371 | start_epoch = 0 372 | if args.ema > 0: 373 | update_ema(ema_model, model, decay=0) 374 | 375 | if not args.no_compile: 376 | logger.info("compiling the model... (may take several minutes)") 377 | model = torch.compile(model) # requires PyTorch 2.0 378 | if args.vae_only: 379 | gan_model = torch.compile(gan_model) 380 | 381 | model = DDP(model.to(device), device_ids=[args.gpu]) 382 | model.train() 383 | 384 | if args.vae_only: 385 | gan_model = DDP(gan_model.to(device), device_ids=[args.gpu]) 386 | gan_model.train() 387 | if args.ema > 0: 388 | ema_model.eval() 389 | 390 | ptdtype = {"none": torch.float32, "bf16": torch.bfloat16}[args.mixed_precision] 391 | 392 | log_steps = 0 393 | running_loss_dict = {} 394 | running_gnorm = 0 395 | start_time = time.time() 396 | 397 | logger.info(f"Training for {args.epochs} epochs ({total_steps} steps)") 398 | tsb_writer = SummaryWriter(log_dir=results_dir) if rank == 0 else None 399 | 400 | for epoch in range(start_epoch, args.epochs): 401 | loader = create_dataloader(dataset, sampler, epoch, args) 402 | 403 | logger.info(f"Beginning epoch {epoch}...") 404 | for images, classes in loader: 405 | classes = classes.to(device, non_blocking=True) 406 | images = images.to(device, non_blocking=True).contiguous( 407 | memory_format=torch.channels_last 408 | ) 409 | 410 | optimizer.zero_grad(set_to_none=True) 411 | 412 | with torch.amp.autocast("cuda", dtype=ptdtype): 413 | ar_loss, kl_loss, recon = model(images, classes) 414 | gan_g_loss, running_loss_dict = vae_loss( 415 | args, 416 | gan_model, 417 | images, 418 | recon, 419 | kl_loss, 420 | train_steps, 421 | running_loss_dict, 422 | ) 423 | if not args.vae_only: 424 | running_loss_dict = update_loss_dict(running_loss_dict, loss=ar_loss) 425 | loss = ar_loss + gan_g_loss 426 | loss.backward() 427 | 428 | if args.max_grad_norm != 0.0: 429 | gnorm = torch.nn.utils.clip_grad_norm_( 430 | model.parameters(), args.max_grad_norm 431 | ) 432 | running_gnorm += gnorm.item() 433 | cur_lr = adjust_learning_rate(args, train_steps, total_steps, optimizer) 434 | running_loss_dict = update_loss_dict(running_loss_dict, lr=cur_lr) 435 | optimizer.step() 436 | 437 | if args.ema > 0: 438 | update_ema(ema_model, get_orig_model(model), decay=args.ema) 439 | 440 | if args.vae_only: 441 | # discriminator training 442 | optimizer_disc.zero_grad(set_to_none=True) 443 | with torch.amp.autocast("cuda", dtype=ptdtype): 444 | disc_loss = gan_model( 445 | images, 446 | recon.detach(), 447 | optimizer_idx=1, 448 | global_step=train_steps + 1, 449 | ) 450 | running_loss_dict = update_loss_dict( 451 | running_loss_dict, disc_loss=disc_loss 452 | ) 453 | disc_loss.backward() 454 | if args.max_grad_norm != 0.0: 455 | torch.nn.utils.clip_grad_norm_( 456 | get_orig_model(gan_model).discriminator.parameters(), 457 | args.max_grad_norm, 458 | ) 459 | adjust_learning_rate(args, train_steps, total_steps, optimizer_disc) 460 | optimizer_disc.step() 461 | 462 | log_steps += 1 463 | train_steps += 1 464 | if train_steps % args.log_every == 0: 465 | # Measure training speed: 466 | torch.cuda.synchronize() 467 | end_time = time.time() 468 | steps_per_sec = log_steps / (end_time - start_time) 469 | logging( 470 | 
running_loss_dict, 471 | running_gnorm, 472 | log_steps, 473 | steps_per_sec, 474 | train_steps, 475 | device, 476 | logger, 477 | tsb_writer, 478 | ) 479 | running_loss_dict = {} 480 | running_gnorm = 0 481 | log_steps = 0 482 | start_time = time.time() 483 | 484 | # save checkpoint at the end of each epoch 485 | if rank == 0: 486 | checkpoint = { 487 | "model": get_orig_model(model).state_dict(), 488 | "optimizer": optimizer.state_dict(), 489 | "epochs": epoch + 1, 490 | "args": args, 491 | } 492 | if args.vae_only: 493 | checkpoint["model_disc"] = get_orig_model( 494 | gan_model 495 | ).discriminator.state_dict() 496 | checkpoint["optimizer_disc"] = optimizer_disc.state_dict() 497 | if args.ema > 0: 498 | checkpoint["ema"] = ema_model.state_dict() 499 | # save on /dev/shm (memory), then async copy to remote 500 | local_file = os.path.join(args.tmp_results_dir, "last.pt") 501 | torch.save(checkpoint, local_file) 502 | ckp_async_thread.apply_async( 503 | copy_ckp_func, 504 | args=(local_file, results_dir, epoch + 1, args.keep_freq), 505 | error_callback=lambda e: logger.error("async copy error :" + str(e)), 506 | ) 507 | 508 | dist.barrier() 509 | 510 | if ckp_async_thread is not None: 511 | ckp_async_thread.close() 512 | ckp_async_thread.join() 513 | 514 | if rank == 0: 515 | # free space by removing prev checkpoint 516 | os.system(f"rm {results_dir}/prev.pt") 517 | logger.info("Done!") 518 | dist.destroy_process_group() 519 | 520 | 521 | if __name__ == "__main__": 522 | parser = get_model_args() 523 | parser.add_argument("--data-path", type=str, required=True) 524 | parser.add_argument("--trained-vae", type=str, default="") 525 | parser.add_argument("--aug-decay-start-epoch", type=int, default=350) 526 | parser.add_argument("--aug-decay-end-epoch", type=int, default=375) 527 | parser.add_argument("--ema", default=-1, type=float) 528 | parser.add_argument("--no-compile", action="store_true") 529 | parser.add_argument("--tmp-results-dir", type=str, default="/dev/shm/") 530 | parser.add_argument("--results-dir", type=str, default="results") 531 | parser.add_argument("--epochs", type=int, default=400) 532 | parser.add_argument("--lr", type=float, default=3e-4) 533 | parser.add_argument("--min-lr", type=float, default=1e-5) 534 | parser.add_argument("--warmup-steps", type=int, default=20000) 535 | parser.add_argument("--decay-start", type=int, default=20000) 536 | parser.add_argument( 537 | "--weight-decay", type=float, default=5e-2, help="Weight decay to use" 538 | ) 539 | parser.add_argument("--beta1", type=float, default=0.9) 540 | parser.add_argument("--beta2", type=float, default=0.95) 541 | parser.add_argument( 542 | "--max-grad-norm", default=1.0, type=float, help="Max gradient norm." 
543 | ) 544 | parser.add_argument("--global-batch-size", type=int, default=256) 545 | parser.add_argument("--global-seed", type=int, default=0) 546 | parser.add_argument("--num-workers", type=int, default=16) 547 | parser.add_argument("--log-every", type=int, default=100) 548 | parser.add_argument( 549 | "--mixed-precision", type=str, default="bf16", choices=["none", "bf16"] 550 | ) 551 | parser.add_argument("--reconstruction-weight", type=float, default=1.0) 552 | parser.add_argument("--reconstruction-loss", type=str, default="l1") 553 | parser.add_argument("--perceptual-weight", type=float, default=1.0) 554 | parser.add_argument("--disc-weight", type=float, default=0.5) 555 | parser.add_argument("--kl-weight", type=float, default=0.004) 556 | parser.add_argument("--disc-start", type=int, default=20000) 557 | parser.add_argument( 558 | "--disc-type", type=str, choices=["patchgan", "stylegan"], default="patchgan" 559 | ) 560 | parser.add_argument( 561 | "--disc-loss", 562 | type=str, 563 | choices=["hinge", "vanilla", "non-saturating"], 564 | default="hinge", 565 | ) 566 | parser.add_argument( 567 | "--gen-loss", type=str, choices=["hinge", "non-saturating"], default="hinge" 568 | ) 569 | parser.add_argument("--keep-freq", type=int, default=50) 570 | main(parser.parse_args()) 571 | -------------------------------------------------------------------------------- /evaluator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | import os 4 | import random 5 | import warnings 6 | import zipfile 7 | from abc import ABC, abstractmethod 8 | from contextlib import contextmanager 9 | from functools import partial 10 | from multiprocessing import cpu_count 11 | from multiprocessing.pool import ThreadPool 12 | from typing import Iterable, Optional, Tuple 13 | 14 | import numpy as np 15 | import requests 16 | import tensorflow.compat.v1 as tf 17 | from scipy import linalg 18 | from tqdm.auto import tqdm 19 | 20 | INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" 21 | INCEPTION_V3_PATH = "classify_image_graph_def.pb" 22 | 23 | FID_POOL_NAME = "pool_3:0" 24 | FID_SPATIAL_NAME = "mixed_6/conv:0" 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("ref_batch", help="path to reference batch npz file") 30 | parser.add_argument("sample_batch", help="path to sample batch npz file") 31 | args = parser.parse_args() 32 | 33 | config = tf.ConfigProto( 34 | allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph 35 | ) 36 | config.gpu_options.allow_growth = True 37 | evaluator = Evaluator(tf.Session(config=config)) 38 | 39 | print("warming up TensorFlow...") 40 | # This will cause TF to print a bunch of verbose stuff now rather 41 | # than after the next print(), to help prevent confusion. 
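# Pipeline (following the OpenAI guided-diffusion evaluation protocol): pool_3
# activations feed FID and precision/recall, the mixed_6/conv spatial
# activations feed sFID, and a softmax head over the pooled features yields
# the Inception Score.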
42 | evaluator.warmup() 43 | 44 | print("computing reference batch activations...") 45 | ref_acts = evaluator.read_activations(args.ref_batch) 46 | print("computing/reading reference batch statistics...") 47 | ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) 48 | 49 | print("computing sample batch activations...") 50 | sample_acts = evaluator.read_activations(args.sample_batch) 51 | print("computing/reading sample batch statistics...") 52 | sample_stats, sample_stats_spatial = evaluator.read_statistics( 53 | args.sample_batch, sample_acts 54 | ) 55 | 56 | print("Computing evaluations...") 57 | IS = evaluator.compute_inception_score(sample_acts[0]) 58 | FID = sample_stats.frechet_distance(ref_stats) 59 | sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial) 60 | print("Inception Score:", IS) 61 | print("FID:", FID) 62 | print("sFID:", sFID) 63 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) 64 | print("Precision:", prec) 65 | print("Recall:", recall) 66 | 67 | txt_path = args.sample_batch.replace(".npz", ".txt") 68 | print("writing to {}".format(txt_path)) 69 | with open(txt_path, "w") as f: 70 | print("Inception Score:", IS, file=f) 71 | print("FID:", FID, file=f) 72 | print("sFID:", sFID, file=f) 73 | print("Precision:", prec, file=f) 74 | print("Recall:", recall, file=f) 75 | 76 | 77 | class InvalidFIDException(Exception): 78 | pass 79 | 80 | 81 | class FIDStatistics: 82 | def __init__(self, mu: np.ndarray, sigma: np.ndarray): 83 | self.mu = mu 84 | self.sigma = sigma 85 | 86 | def frechet_distance(self, other, eps=1e-6): 87 | """ 88 | Compute the Frechet distance between two sets of statistics. 89 | """ 90 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 91 | mu1, sigma1 = self.mu, self.sigma 92 | mu2, sigma2 = other.mu, other.sigma 93 | 94 | mu1 = np.atleast_1d(mu1) 95 | mu2 = np.atleast_1d(mu2) 96 | 97 | sigma1 = np.atleast_2d(sigma1) 98 | sigma2 = np.atleast_2d(sigma2) 99 | 100 | assert ( 101 | mu1.shape == mu2.shape 102 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" 103 | assert ( 104 | sigma1.shape == sigma2.shape 105 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" 106 | 107 | diff = mu1 - mu2 108 | 109 | # product might be almost singular 110 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 111 | if not np.isfinite(covmean).all(): 112 | msg = ( 113 | "fid calculation produces singular product; adding %s to diagonal of cov estimates" 114 | % eps 115 | ) 116 | warnings.warn(msg) 117 | offset = np.eye(sigma1.shape[0]) * eps 118 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 119 | 120 | # numerical error might give slight imaginary component 121 | if np.iscomplexobj(covmean): 122 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 123 | m = np.max(np.abs(covmean.imag)) 124 | raise ValueError("Imaginary component {}".format(m)) 125 | covmean = covmean.real 126 | 127 | tr_covmean = np.trace(covmean) 128 | 129 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 130 | 131 | 132 | class Evaluator: 133 | def __init__( 134 | self, 135 | session, 136 | batch_size=64, 137 | softmax_batch_size=512, 138 | ): 139 | self.sess = session 140 | self.batch_size = batch_size 141 | self.softmax_batch_size = softmax_batch_size 142 | self.manifold_estimator = ManifoldEstimator(session) 143 | with self.sess.graph.as_default(): 144 | 
self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) 145 | self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) 146 | self.pool_features, self.spatial_features = _create_feature_graph( 147 | self.image_input 148 | ) 149 | self.softmax = _create_softmax_graph(self.softmax_input) 150 | 151 | def warmup(self): 152 | self.compute_activations(np.zeros([1, 8, 64, 64, 3])) 153 | 154 | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]: 155 | with open_npz_array(npz_path, "arr_0") as reader: 156 | return self.compute_activations(reader.read_batches(self.batch_size)) 157 | 158 | def compute_activations( 159 | self, batches: Iterable[np.ndarray] 160 | ) -> Tuple[np.ndarray, np.ndarray]: 161 | """ 162 | Compute image features for downstream evals. 163 | 164 | :param batches: a iterator over NHWC numpy arrays in [0, 255]. 165 | :return: a tuple of numpy arrays of shape [N x X], where X is a feature 166 | dimension. The tuple is (pool_3, spatial). 167 | """ 168 | preds = [] 169 | spatial_preds = [] 170 | for batch in tqdm(batches): 171 | batch = batch.astype(np.float32) 172 | pred, spatial_pred = self.sess.run( 173 | [self.pool_features, self.spatial_features], {self.image_input: batch} 174 | ) 175 | preds.append(pred.reshape([pred.shape[0], -1])) 176 | spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) 177 | return ( 178 | np.concatenate(preds, axis=0), 179 | np.concatenate(spatial_preds, axis=0), 180 | ) 181 | 182 | def read_statistics( 183 | self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray] 184 | ) -> Tuple[FIDStatistics, FIDStatistics]: 185 | obj = np.load(npz_path) 186 | if "mu" in list(obj.keys()): 187 | return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( 188 | obj["mu_s"], obj["sigma_s"] 189 | ) 190 | return tuple(self.compute_statistics(x) for x in activations) 191 | 192 | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: 193 | mu = np.mean(activations, axis=0) 194 | sigma = np.cov(activations, rowvar=False) 195 | return FIDStatistics(mu, sigma) 196 | 197 | def compute_inception_score( 198 | self, activations: np.ndarray, split_size: int = 5000 199 | ) -> float: 200 | softmax_out = [] 201 | for i in range(0, len(activations), self.softmax_batch_size): 202 | acts = activations[i : i + self.softmax_batch_size] 203 | softmax_out.append( 204 | self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}) 205 | ) 206 | preds = np.concatenate(softmax_out, axis=0) 207 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 208 | scores = [] 209 | for i in range(0, len(preds), split_size): 210 | part = preds[i : i + split_size] 211 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 212 | kl = np.mean(np.sum(kl, 1)) 213 | scores.append(np.exp(kl)) 214 | return float(np.mean(scores)) 215 | 216 | def compute_prec_recall( 217 | self, activations_ref: np.ndarray, activations_sample: np.ndarray 218 | ) -> Tuple[float, float]: 219 | radii_1 = self.manifold_estimator.manifold_radii(activations_ref) 220 | radii_2 = self.manifold_estimator.manifold_radii(activations_sample) 221 | pr = self.manifold_estimator.evaluate_pr( 222 | activations_ref, radii_1, activations_sample, radii_2 223 | ) 224 | return (float(pr[0][0]), float(pr[1][0])) 225 | 226 | 227 | class ManifoldEstimator: 228 | """ 229 | A helper for comparing manifolds of feature vectors. 
230 | 231 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 232 | """ 233 | 234 | def __init__( 235 | self, 236 | session, 237 | row_batch_size=10000, 238 | col_batch_size=10000, 239 | nhood_sizes=(3,), 240 | clamp_to_percentile=None, 241 | eps=1e-5, 242 | ): 243 | """ 244 | Estimate the manifold of given feature vectors. 245 | 246 | :param session: the TensorFlow session. 247 | :param row_batch_size: row batch size to compute pairwise distances 248 | (trades off memory usage against performance). 249 | :param col_batch_size: column batch size to compute pairwise distances. 250 | :param nhood_sizes: neighborhood sizes (values of k) used to estimate the manifold. 251 | :param clamp_to_percentile: prune hyperspheres that have radius larger than 252 | the given percentile. 253 | :param eps: small number for numerical stability. 254 | """ 255 | self.distance_block = DistanceBlock(session) 256 | self.row_batch_size = row_batch_size 257 | self.col_batch_size = col_batch_size 258 | self.nhood_sizes = nhood_sizes 259 | self.num_nhoods = len(nhood_sizes) 260 | self.clamp_to_percentile = clamp_to_percentile 261 | self.eps = eps 262 | 263 | def warmup(self): 264 | feats, radii = ( 265 | np.zeros([1, 2048], dtype=np.float32), 266 | np.zeros([1, 1], dtype=np.float32), 267 | ) 268 | self.evaluate_pr(feats, radii, feats, radii) 269 | 270 | def manifold_radii(self, features: np.ndarray) -> np.ndarray: 271 | num_images = len(features) 272 | 273 | # Estimate manifold of features by calculating distances to k-NN of each sample. 274 | radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) 275 | distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) 276 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) 277 | 278 | for begin1 in range(0, num_images, self.row_batch_size): 279 | end1 = min(begin1 + self.row_batch_size, num_images) 280 | row_batch = features[begin1:end1] 281 | 282 | for begin2 in range(0, num_images, self.col_batch_size): 283 | end2 = min(begin2 + self.col_batch_size, num_images) 284 | col_batch = features[begin2:end2] 285 | 286 | # Compute distances between batches. 287 | distance_batch[0 : end1 - begin1, begin2:end2] = ( 288 | self.distance_block.pairwise_distances(row_batch, col_batch) 289 | ) 290 | 291 | # Radius of each hypersphere = distance to the k-th nearest neighbor (column k of the partitioned distances; column 0 is the point itself). 292 | radii[begin1:end1, :] = np.concatenate( 293 | [ 294 | x[:, self.nhood_sizes] 295 | for x in _numpy_partition( 296 | distance_batch[0 : end1 - begin1, :], seq, axis=1 297 | ) 298 | ], 299 | axis=0, 300 | ) 301 | 302 | if self.clamp_to_percentile is not None: 303 | max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) 304 | radii[radii > max_distances] = 0 305 | return radii 306 | 307 | def evaluate( 308 | self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray 309 | ): 310 | """ 311 | Evaluate whether new feature vectors lie on the estimated manifold. 
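Returns a dict with: "fraction", the overall share of eval features that
fall inside the reference manifold; "batch_predictions", the per-feature
membership indicators; "max_realism_score", the realism score of
Kynkäänniemi et al. (for each eval feature, the max over reference
hyperspheres of radius / distance); and "nearest_indices", the index of
each eval feature's nearest reference neighbor.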
312 | """ 313 | num_eval_images = eval_features.shape[0] 314 | num_ref_images = radii.shape[0] 315 | distance_batch = np.zeros( 316 | [self.row_batch_size, num_ref_images], dtype=np.float32 317 | ) 318 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) 319 | max_realism_score = np.zeros([num_eval_images], dtype=np.float32) 320 | nearest_indices = np.zeros([num_eval_images], dtype=np.int32) 321 | 322 | for begin1 in range(0, num_eval_images, self.row_batch_size): 323 | end1 = min(begin1 + self.row_batch_size, num_eval_images) 324 | feature_batch = eval_features[begin1:end1] 325 | 326 | for begin2 in range(0, num_ref_images, self.col_batch_size): 327 | end2 = min(begin2 + self.col_batch_size, num_ref_images) 328 | ref_batch = features[begin2:end2] 329 | 330 | distance_batch[0 : end1 - begin1, begin2:end2] = ( 331 | self.distance_block.pairwise_distances(feature_batch, ref_batch) 332 | ) 333 | 334 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold. 335 | # If a feature vector is inside a hypersphere of some reference sample, then 336 | # the new sample lies on the estimated manifold. 337 | # The radii of the hyperspheres are determined from distances of neighborhood size k. 338 | samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii 339 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype( 340 | np.int32 341 | ) 342 | 343 | max_realism_score[begin1:end1] = np.max( 344 | radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1 345 | ) 346 | nearest_indices[begin1:end1] = np.argmin( 347 | distance_batch[0 : end1 - begin1, :], axis=1 348 | ) 349 | 350 | return { 351 | "fraction": float(np.mean(batch_predictions)), 352 | "batch_predictions": batch_predictions, 353 | "max_realism_score": max_realism_score, 354 | "nearest_indices": nearest_indices, 355 | } 356 | 357 | def evaluate_pr( 358 | self, 359 | features_1: np.ndarray, 360 | radii_1: np.ndarray, 361 | features_2: np.ndarray, 362 | radii_2: np.ndarray, 363 | ) -> Tuple[np.ndarray, np.ndarray]: 364 | """ 365 | Evaluate precision and recall efficiently. 366 | 367 | :param features_1: [N1 x D] feature vectors for reference batch. 368 | :param radii_1: [N1 x K1] radii for reference vectors. 369 | :param features_2: [N2 x D] feature vectors for the other batch. 370 | :param radii_2: [N2 x K2] radii for other vectors. 
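(Radii are as produced by manifold_radii: one column per neighborhood
size in nhood_sizes, so K1 and K2 count neighborhood sizes, not samples.)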
371 | :return: a tuple of arrays for (precision, recall): 372 | - precision: an np.ndarray of length K1 373 | - recall: an np.ndarray of length K2 374 | """ 375 | features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool) 376 | features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool) 377 | for begin_1 in range(0, len(features_1), self.row_batch_size): 378 | end_1 = begin_1 + self.row_batch_size 379 | batch_1 = features_1[begin_1:end_1] 380 | for begin_2 in range(0, len(features_2), self.col_batch_size): 381 | end_2 = begin_2 + self.col_batch_size 382 | batch_2 = features_2[begin_2:end_2] 383 | batch_1_in, batch_2_in = self.distance_block.less_thans( 384 | batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2] 385 | ) 386 | features_1_status[begin_1:end_1] |= batch_1_in 387 | features_2_status[begin_2:end_2] |= batch_2_in 388 | return ( 389 | np.mean(features_2_status.astype(np.float64), axis=0), 390 | np.mean(features_1_status.astype(np.float64), axis=0), 391 | ) 392 | 393 | 394 | class DistanceBlock: 395 | """ 396 | Calculate pairwise distances between vectors. 397 | 398 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34 399 | """ 400 | 401 | def __init__(self, session): 402 | self.session = session 403 | 404 | # Initialize TF graph to calculate pairwise distances: try fast float16 math, falling back to float32 whenever the half-precision result overflows to non-finite values. 405 | with session.graph.as_default(): 406 | self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) 407 | self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) 408 | distance_block_16 = _batch_pairwise_distances( 409 | tf.cast(self._features_batch1, tf.float16), 410 | tf.cast(self._features_batch2, tf.float16), 411 | ) 412 | self.distance_block = tf.cond( 413 | tf.reduce_all(tf.math.is_finite(distance_block_16)), 414 | lambda: tf.cast(distance_block_16, tf.float32), 415 | lambda: _batch_pairwise_distances( 416 | self._features_batch1, self._features_batch2 417 | ), 418 | ) 419 | 420 | # Extra graph ops for the radius (less-than) membership tests used by evaluate_pr. 421 | self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) 422 | self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) 423 | dist32 = tf.cast(self.distance_block, tf.float32)[..., None] 424 | self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) 425 | self._batch_2_in = tf.math.reduce_any( 426 | dist32 <= self._radii1[:, None], axis=0 427 | ) 428 | 429 | def pairwise_distances(self, U, V): 430 | """ 431 | Evaluate pairwise distances between two batches of feature vectors. 432 | """ 433 | return self.session.run( 434 | self.distance_block, 435 | feed_dict={self._features_batch1: U, self._features_batch2: V}, 436 | ) 437 | 438 | def less_thans(self, batch_1, radii_1, batch_2, radii_2): 439 | return self.session.run( 440 | [self._batch_1_in, self._batch_2_in], 441 | feed_dict={ 442 | self._features_batch1: batch_1, 443 | self._features_batch2: batch_2, 444 | self._radii1: radii_1, 445 | self._radii2: radii_2, 446 | }, 447 | ) 448 | 449 | 450 | def _batch_pairwise_distances(U, V): 451 | """ 452 | Compute pairwise distances between two batches of feature vectors. 453 | """ 454 | with tf.variable_scope("pairwise_dist_block"): 455 | # Squared norms of each row in U and V. 456 | norm_u = tf.reduce_sum(tf.square(U), 1) 457 | norm_v = tf.reduce_sum(tf.square(V), 1) 458 | 459 | # norm_u as a column vector and norm_v as a row vector. 
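# The full matrix of squared distances then follows from the expansion
# ||u - v||^2 = ||u||^2 - 2 u.v + ||v||^2, evaluated for all pairs at once
# via broadcasting; e.g. u = (1, 0), v = (0, 1) gives 1 - 2*0 + 1 = 2.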
460 | norm_u = tf.reshape(norm_u, [-1, 1]) 461 | norm_v = tf.reshape(norm_v, [1, -1]) 462 | 463 | # Pairwise squared Euclidean distances. 464 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) 465 | 466 | return D 467 | 468 | 469 | class NpzArrayReader(ABC): 470 | @abstractmethod 471 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 472 | pass 473 | 474 | @abstractmethod 475 | def remaining(self) -> int: 476 | pass 477 | 478 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: 479 | def gen_fn(): 480 | while True: 481 | batch = self.read_batch(batch_size) 482 | if batch is None: 483 | break 484 | yield batch 485 | 486 | rem = self.remaining() 487 | num_batches = rem // batch_size + int(rem % batch_size != 0) 488 | return BatchIterator(gen_fn, num_batches) 489 | 490 | 491 | class BatchIterator: 492 | def __init__(self, gen_fn, length): 493 | self.gen_fn = gen_fn 494 | self.length = length 495 | 496 | def __len__(self): 497 | return self.length 498 | 499 | def __iter__(self): 500 | return self.gen_fn() 501 | 502 | 503 | class StreamingNpzArrayReader(NpzArrayReader): 504 | def __init__(self, arr_f, shape, dtype): 505 | self.arr_f = arr_f 506 | self.shape = shape 507 | self.dtype = dtype 508 | self.idx = 0 509 | 510 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 511 | if self.idx >= self.shape[0]: 512 | return None 513 | 514 | bs = min(batch_size, self.shape[0] - self.idx) 515 | self.idx += bs 516 | 517 | if self.dtype.itemsize == 0: 518 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) 519 | 520 | read_count = bs * np.prod(self.shape[1:]) 521 | read_size = int(read_count * self.dtype.itemsize) 522 | data = _read_bytes(self.arr_f, read_size, "array data") 523 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) 524 | 525 | def remaining(self) -> int: 526 | return max(0, self.shape[0] - self.idx) 527 | 528 | 529 | class MemoryNpzArrayReader(NpzArrayReader): 530 | def __init__(self, arr): 531 | self.arr = arr 532 | self.idx = 0 533 | 534 | @classmethod 535 | def load(cls, path: str, arr_name: str): 536 | with open(path, "rb") as f: 537 | arr = np.load(f)[arr_name] 538 | return cls(arr) 539 | 540 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 541 | if self.idx >= self.arr.shape[0]: 542 | return None 543 | 544 | res = self.arr[self.idx : self.idx + batch_size] 545 | self.idx += batch_size 546 | return res 547 | 548 | def remaining(self) -> int: 549 | return max(0, self.arr.shape[0] - self.idx) 550 | 551 | 552 | @contextmanager 553 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: 554 | with _open_npy_file(path, arr_name) as arr_f: 555 | version = np.lib.format.read_magic(arr_f) 556 | if version == (1, 0): 557 | header = np.lib.format.read_array_header_1_0(arr_f) 558 | elif version == (2, 0): 559 | header = np.lib.format.read_array_header_2_0(arr_f) 560 | else: 561 | yield MemoryNpzArrayReader.load(path, arr_name) 562 | return 563 | shape, fortran, dtype = header 564 | if fortran or dtype.hasobject: 565 | yield MemoryNpzArrayReader.load(path, arr_name) 566 | else: 567 | yield StreamingNpzArrayReader(arr_f, shape, dtype) 568 | 569 | 570 | def _read_bytes(fp, size, error_template="ran out of data"): 571 | """ 572 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 573 | 574 | Read from file-like object until size bytes are read. 
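The retry loop below matters when streaming a member out of a zip
archive: ZipExtFile.read can return fewer bytes than requested.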
575 | Raises ValueError if EOF is encountered before size bytes are read. 576 | Non-blocking objects only supported if they derive from io objects. 577 | Required as e.g. ZipExtFile in python 2.6 can return less data than 578 | requested. 579 | """ 580 | data = bytes() 581 | while True: 582 | # io files (default in python3) return None or raise on 583 | # would-block, python2 file will truncate, probably nothing can be 584 | # done about that. note that regular files can't be non-blocking 585 | try: 586 | r = fp.read(size - len(data)) 587 | data += r 588 | if len(r) == 0 or len(data) == size: 589 | break 590 | except io.BlockingIOError: 591 | pass 592 | if len(data) != size: 593 | msg = "EOF: reading %s, expected %d bytes got %d" 594 | raise ValueError(msg % (error_template, size, len(data))) 595 | else: 596 | return data 597 | 598 | 599 | @contextmanager 600 | def _open_npy_file(path: str, arr_name: str): 601 | with open(path, "rb") as f: 602 | with zipfile.ZipFile(f, "r") as zip_f: 603 | if f"{arr_name}.npy" not in zip_f.namelist(): 604 | raise ValueError(f"missing {arr_name} in npz file") 605 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f: 606 | yield arr_f 607 | 608 | 609 | def _download_inception_model(): 610 | if os.path.exists(INCEPTION_V3_PATH): 611 | return 612 | print("downloading InceptionV3 model...") 613 | with requests.get(INCEPTION_V3_URL, stream=True) as r: 614 | r.raise_for_status() 615 | tmp_path = INCEPTION_V3_PATH + ".tmp" 616 | with open(tmp_path, "wb") as f: 617 | for chunk in tqdm(r.iter_content(chunk_size=8192)): 618 | f.write(chunk) 619 | os.rename(tmp_path, INCEPTION_V3_PATH) 620 | 621 | 622 | def _create_feature_graph(input_batch): 623 | _download_inception_model() 624 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 625 | with open(INCEPTION_V3_PATH, "rb") as f: 626 | graph_def = tf.GraphDef() 627 | graph_def.ParseFromString(f.read()) 628 | pool3, spatial = tf.import_graph_def( 629 | graph_def, 630 | input_map={"ExpandDims:0": input_batch}, 631 | return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], 632 | name=prefix, 633 | ) 634 | _update_shapes(pool3) 635 | spatial = spatial[..., :7]  # keep only the first 7 feature channels for the spatial (sFID) statistics 636 | return pool3, spatial 637 | 638 | 639 | def _create_softmax_graph(input_batch): 640 | _download_inception_model() 641 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 642 | with open(INCEPTION_V3_PATH, "rb") as f: 643 | graph_def = tf.GraphDef() 644 | graph_def.ParseFromString(f.read()) 645 | (matmul,) = tf.import_graph_def( 646 | graph_def, return_elements=["softmax/logits/MatMul"], name=prefix 647 | ) 648 | w = matmul.inputs[1] 649 | logits = tf.matmul(input_batch, w) 650 | return tf.nn.softmax(logits) 651 | 652 | 653 | def _update_shapes(pool3): 654 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 655 | ops = pool3.graph.get_operations() 656 | for op in ops: 657 | for o in op.outputs: 658 | shape = o.get_shape() 659 | if shape._dims is not None: # pylint: disable=protected-access 660 | # shape = [s.value for s in shape] TF 1.x 661 | shape = [s for s in shape] # TF 2.x 662 | new_shape = [] 663 | for j, s in enumerate(shape): 664 | if s == 1 and j == 0: 665 | new_shape.append(None) 666 | else: 667 | new_shape.append(s) 668 | o.__dict__["_shape_val"] = tf.TensorShape(new_shape) 669 | return pool3 670 | 671 | 672 | def _numpy_partition(arr, kth, **kwargs): 673 | num_workers = min(cpu_count(), len(arr)) 674 | chunk_size = len(arr) // num_workers 675 | extra = len(arr) % num_workers 676 
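# Split the rows as evenly as possible across workers: the first `extra`
# workers get one extra row. E.g. 10 rows on 4 workers -> chunks of
# 3, 3, 2, 2, each partitioned in parallel with np.partition.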
| 677 | start_idx = 0 678 | batches = [] 679 | for i in range(num_workers): 680 | size = chunk_size + (1 if i < extra else 0) 681 | batches.append(arr[start_idx : start_idx + size]) 682 | start_idx += size 683 | 684 | with ThreadPool(num_workers) as pool: 685 | return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) 686 | 687 | 688 | if __name__ == "__main__": 689 | main() 690 | --------------------------------------------------------------------------------