├── SphereAR
│   ├── __init__.py
│   ├── gan
│   │   ├── __init__.py
│   │   ├── discriminator_stylegan.py
│   │   ├── discriminator_patchgan.py
│   │   ├── gan_loss.py
│   │   └── lpips.py
│   ├── utils.py
│   ├── psd.py
│   ├── sampling.py
│   ├── dataset.py
│   ├── diff_head.py
│   ├── vae.py
│   ├── layers.py
│   └── model.py
├── figures
│   ├── grid.jpg
│   ├── overview.png
│   └── fid_vs_params.png
├── .gitignore
├── README.md
├── sample_ddp.py
├── train.py
└── evaluator.py
/SphereAR/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/SphereAR/gan/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figures/grid.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guolinke/SphereAR/HEAD/figures/grid.jpg
--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guolinke/SphereAR/HEAD/figures/overview.png
--------------------------------------------------------------------------------
/figures/fid_vs_params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guolinke/SphereAR/HEAD/figures/fid_vs_params.png
--------------------------------------------------------------------------------
/SphereAR/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.distributed as dist
5 | from torch.nn import functional as F
6 |
7 |
8 | def create_logger(logging_dir):
9 | """
10 | Create a logger that writes to a log file and stdout.
11 | """
12 | if dist.get_rank() == 0: # real logger
13 | logging.basicConfig(
14 | level=logging.INFO,
15 | format="[\033[34m%(asctime)s\033[0m] %(message)s",
16 | datefmt="%Y-%m-%d %H:%M:%S",
17 | handlers=[
18 | logging.StreamHandler(),
19 | logging.FileHandler(f"{logging_dir}/log.txt"),
20 | ],
21 | )
22 | logger = logging.getLogger(__name__)
23 | else: # dummy logger (does nothing)
24 | logger = logging.getLogger(__name__)
25 | logger.addHandler(logging.NullHandler())
26 | return logger
27 |
28 |
29 | @torch.no_grad()
30 | def update_ema(ema_model, model, decay=0.9999):
31 | """
32 | Step the EMA model towards the current model.
33 | """
34 | ema_ps = []
35 | ps = []
36 |
37 | for e, m in zip(ema_model.parameters(), model.parameters()):
38 | if m.requires_grad:
39 | ema_ps.append(e)
40 | ps.append(m)
41 | torch._foreach_lerp_(ema_ps, ps, 1.0 - decay)
42 |
43 |
44 | @torch.no_grad()
45 | def sync_frozen_params_once(ema_model, model):
46 | for e, m in zip(ema_model.parameters(), model.parameters()):
47 | if not m.requires_grad:
48 | e.copy_(m)
49 |
50 |
51 | def requires_grad(model, flag=True):
52 | """
53 | Set requires_grad flag for all parameters in a model.
54 | """
55 | for p in model.parameters():
56 | p.requires_grad = flag
57 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | *.tfevents.*
3 | # JetBrains PyCharm IDE
4 | .idea/
5 |
6 | # Compressed files
7 | *.gz
8 | *.tar.gz
9 | *.zip
10 | *.data
11 | *.7z
12 | *.7zip
13 | *.sdf
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # macOS dir files
24 | .DS_Store
25 |
26 | # Distribution / packaging
27 | .Python
28 | env/
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | *.egg-info/
42 | .installed.cfg
43 | *.egg
44 |
45 | # Checkpoints
46 | checkpoints
47 |
48 | # PyInstaller
49 | # Usually these files are written by a python script from a template
50 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
51 | *.manifest
52 | *.spec
53 |
54 | # Installer logs
55 | pip-log.txt
56 | pip-delete-this-directory.txt
57 |
58 | # Unit test / coverage reports
59 | htmlcov/
60 | .tox/
61 | .coverage
62 | .coverage.*
63 | .cache
64 | nosetests.xml
65 | coverage.xml
66 | *.cover
67 | .hypothesis/
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # celery beat schedule file
97 | celerybeat-schedule
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # dotenv
103 | .env
104 |
105 | # virtualenv
106 | .venv
107 | venv/
108 | ENV/
109 |
110 | # Spyder project settings
111 | .spyderproject
112 | .spyproject
113 |
114 | # Rope project settings
115 | .ropeproject
116 |
117 | # mypy
118 | .mypy_cache/
119 |
120 | # VSCODE
121 | .vscode/ftp-sync.json
122 | .vscode/settings.json
123 |
124 | # too big to git
125 | *.lmdb
126 | *.sto
127 | *.pt
128 | *.pkl
129 |
130 | # pytest
131 | .pytest_cache
132 | test/.pytest_cache
133 | /local*
134 | /_*
135 | weights
136 | wandb
137 | weights*
138 | ft_weights*
139 | finetune_scripts/*.dict
140 | *.pdf
--------------------------------------------------------------------------------
/SphereAR/psd.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.distributions import Beta
5 |
6 |
7 | def l2_norm(x: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
8 | return x / torch.clamp(x.norm(dim=-1, keepdim=True), min=eps)
9 |
10 |
11 | class PowerSphericalDistribution:
12 | def __init__(self, mu: torch.Tensor, kappa: torch.Tensor, eps: float = 1e-7):
13 | self.eps = eps
14 | self.mu = l2_norm(mu, eps) # [..., m]
15 | self.kappa = torch.clamp(kappa, min=0.0)
16 |
17 | self.m = self.mu.shape[-1]
18 | self.d = self.m - 1
19 | beta_const = 0.5 * self.d
20 | self.alpha = self.kappa + beta_const # [...,]
21 | self.beta = torch.as_tensor(
22 | beta_const, dtype=self.kappa.dtype, device=self.kappa.device
23 | ).expand_as(self.kappa)
24 |
25 | def _log_normalizer(self) -> torch.Tensor:
26 | # log N_X(κ,d) = -[ (α+β)log 2 + β log π + lgamma(α) - lgamma(α+β) ]
27 | return (
28 | -(self.alpha + self.beta) * math.log(2.0)
29 | - self.beta * math.log(math.pi)
30 | - torch.lgamma(self.alpha)
31 | + torch.lgamma(self.alpha + self.beta)
32 | )
33 |
34 | def log_prob(self, x: torch.Tensor) -> torch.Tensor:
35 | dot = (self.mu * x).sum(dim=-1).clamp(-1.0, 1.0)
36 | return self._log_normalizer() + self.kappa * torch.log1p(dot)
37 |
38 | def entropy(self) -> torch.Tensor:
39 | # H = -[ log N_X + κ ( log 2 + ψ(α) - ψ(α+β) ) ]
40 | return -(
41 | self._log_normalizer()
42 | + self.kappa
43 | * (
44 | math.log(2.0)
45 | + (torch.digamma(self.alpha) - torch.digamma(self.alpha + self.beta))
46 | )
47 | )
48 |
49 | def kl_to_uniform(self) -> torch.Tensor:
50 | # KL(q || U(S^{d})) = -H(q) + log |S^{d}|
51 | d = torch.as_tensor(self.d, dtype=self.kappa.dtype, device=self.kappa.device)
52 | log_area = (
53 | math.log(2.0)
54 | + 0.5 * (d + 1.0) * math.log(math.pi)
55 | - torch.lgamma(0.5 * (d + 1.0))
56 | )
57 | return -self.entropy() + log_area
58 |
59 | def rsample(self):
60 |         Z = Beta(self.alpha, self.beta).rsample()  # 1) Z ~ Beta(alpha, beta), shape [*B]
61 |         t = (2.0 * Z - 1.0).unsqueeze(-1)  # t = 2Z - 1, shape [*B, 1]
62 |
63 | # 2) v ~ U(S^{m-2})
64 | v = torch.randn(
65 | *self.mu.shape[:-1],
66 | self.m - 1,
67 | device=self.mu.device,
68 | dtype=self.mu.dtype,
69 | ) # [*S, *B, m-1]
70 | v = l2_norm(v, self.eps)
71 |
72 | y = torch.cat(
73 | [t, torch.sqrt(torch.clamp(1 - t**2, min=0.0)) * v], dim=-1
74 | ) # [*S, *B, m]
75 |
76 | e1 = torch.zeros_like(self.mu)
77 | e1[..., 0] = 1.0
78 | u = l2_norm(e1 - self.mu, self.eps)
79 | if u.dim() < y.dim():
80 | u = u.view((1,) * (y.dim() - u.dim()) + u.shape)
81 | z = y - 2.0 * (y * u).sum(dim=-1, keepdim=True) * u
82 |
83 | parallel = (self.mu - e1).abs().sum(dim=-1, keepdim=True) < 1e-6
84 | if parallel.any():
85 | p = parallel
86 | if p.dim() < y.dim() - 1:
87 | p = p.view((1,) * (y.dim() - 1 - p.dim()) + p.shape)
88 | z = torch.where(p, y, z)
89 | return z
90 |
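# Usage sketch (shapes and values below are illustrative only):
#   mu = l2_norm(torch.randn(8, 16))           # mean directions on S^15
#   kappa = torch.full((8,), 50.0)             # per-sample concentration
#   q = PowerSphericalDistribution(mu, kappa)
#   z = q.rsample()                            # reparameterized draws on the sphere, [8, 16]
#   kl = q.kl_to_uniform()                     # KL(q || uniform on S^15), shape [8]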
--------------------------------------------------------------------------------
/SphereAR/sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
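# get_score_from_velocity assumes the linear interpolation path
# x_t = t * x_1 + (1 - t) * x_0 with x_0 ~ N(0, I), i.e. alpha_t = t and sigma_t = 1 - t,
# and a network that predicts the flow velocity v = x_1 - x_0. Under this path
# t * v - x_t = -x_0, so the score is (t * v - x_t) / (1 - t), which is what the
# expression below evaluates.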
4 | def get_score_from_velocity(velocity, x, t):
5 | alpha_t, d_alpha_t = t, 1
6 | sigma_t, d_sigma_t = 1 - t, -1
7 | mean = x
8 | reverse_alpha_ratio = alpha_t / d_alpha_t
9 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
10 | score = (reverse_alpha_ratio * velocity - mean) / var
11 | return score
12 |
13 |
14 | def get_velocity_from_cfg(velocity, cfg, cfg_mult):
15 | if cfg_mult == 2:
16 | cond_v, uncond_v = torch.chunk(velocity, 2, dim=0)
17 | velocity = uncond_v + cfg * (cond_v - uncond_v)
18 | return velocity
19 |
20 |
21 | @torch.compile()
22 | def euler_step(x, v, dt: float, cfg: float, cfg_mult: int):
23 | with torch.amp.autocast("cuda", enabled=False):
24 | v = v.to(torch.float32)
25 | v = get_velocity_from_cfg(v, cfg, cfg_mult)
26 | x = x + v * dt
27 | return x
28 |
29 |
30 | @torch.compile()
31 | def euler_maruyama_step(x, v, t, dt: float, cfg: float, cfg_mult: int):
32 | with torch.amp.autocast("cuda", enabled=False):
33 | v = v.to(torch.float32)
34 | v = get_velocity_from_cfg(v, cfg, cfg_mult)
35 | score = get_score_from_velocity(v, x, t)
36 | drift = v + (1 - t) * score
37 | noise_scale = (2.0 * (1.0 - t) * dt) ** 0.5
38 | x = x + drift * dt + noise_scale * torch.randn_like(x)
39 | return x
40 |
41 |
42 | def euler_maruyama(
43 | input_dim,
44 | forward_fn,
45 | c: torch.Tensor,
46 | cfg: float = 1.0,
47 | num_sampling_steps: int = 20,
48 | last_step_size: float = 0.04,
49 | ):
50 | cfg_mult = 1
51 | if cfg > 1.0:
52 | cfg_mult += 1
53 |
54 | x_shape = list(c.shape)
55 | x_shape[0] = x_shape[0] // cfg_mult
56 | x_shape[-1] = input_dim
57 | x = torch.randn(x_shape, device=c.device)
58 | dt = (1.0 - last_step_size) / num_sampling_steps
59 | t = torch.tensor(
60 | 0.0, device=c.device, dtype=torch.float32
61 | ) # use tensor to avoid compile warning
62 | t_batch = torch.zeros(c.shape[0], device=c.device)
63 | for _ in range(num_sampling_steps):
64 | t_batch[:] = t
65 | combined = torch.cat([x] * cfg_mult, dim=0)
66 | v = forward_fn(
67 | combined,
68 | t_batch,
69 | c,
70 | )
71 | x = euler_maruyama_step(x, v, t, dt, cfg, cfg_mult)
72 | t += dt
73 |
74 | combined = torch.cat([x] * cfg_mult, dim=0)
75 | t_batch[:] = 1 - last_step_size
76 | v = forward_fn(
77 | combined,
78 | t_batch,
79 | c,
80 | )
81 | x = euler_step(x, v, last_step_size, cfg, cfg_mult)
82 |
83 | return torch.cat([x] * cfg_mult, dim=0)
84 |
85 |
86 | def euler(
87 | input_dim,
88 | forward_fn,
89 | c,
90 | cfg: float = 1.0,
91 | num_sampling_steps: int = 50,
92 | ):
93 | cfg_mult = 1
94 | if cfg > 1.0:
95 | cfg_mult = 2
96 |
97 | x_shape = list(c.shape)
98 | x_shape[0] = x_shape[0] // cfg_mult
99 | x_shape[-1] = input_dim
100 | x = torch.randn(x_shape, device=c.device)
101 | dt = 1.0 / num_sampling_steps
102 | t = 0
103 | t_batch = torch.zeros(c.shape[0], device=c.device)
104 | for _ in range(num_sampling_steps):
105 | t_batch[:] = t
106 | combined = torch.cat([x] * cfg_mult, dim=0)
107 | v = forward_fn(combined, t_batch, c)
108 | x = euler_step(x, v, dt, cfg, cfg_mult)
109 | t += dt
110 |
111 | return torch.cat([x] * cfg_mult, dim=0)
112 |
--------------------------------------------------------------------------------
/SphereAR/gan/discriminator_stylegan.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py
3 | # stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
4 | # maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
5 | import math
6 | import torch
7 | import torch.nn as nn
8 |
9 | try:
10 | from kornia.filters import filter2d
11 | except ImportError:
12 | pass
13 |
14 |
15 | class Discriminator(nn.Module):
16 | def __init__(
17 | self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256
18 | ):
19 | super().__init__()
20 | channels = {
21 | 4: 512,
22 | 8: 512,
23 | 16: 512,
24 | 32: 512,
25 | 64: 256 * channel_multiplier,
26 | 128: 128 * channel_multiplier,
27 | 256: 64 * channel_multiplier,
28 | 512: 32 * channel_multiplier,
29 | 1024: 16 * channel_multiplier,
30 | }
31 |
32 | log_size = int(math.log(image_size, 2))
33 | in_channel = channels[image_size]
34 |
35 | blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
36 | for i in range(log_size, 2, -1):
37 | out_channel = channels[2 ** (i - 1)]
38 | blocks.append(DiscriminatorBlock(in_channel, out_channel))
39 | in_channel = out_channel
40 | self.blocks = nn.ModuleList(blocks)
41 |
42 | self.final_conv = nn.Sequential(
43 | nn.Conv2d(in_channel, channels[4], 3, padding=1),
44 | leaky_relu(),
45 | )
46 | self.final_linear = nn.Sequential(
47 | nn.Linear(channels[4] * 4 * 4, channels[4]),
48 | leaky_relu(),
49 | nn.Linear(channels[4], 1),
50 | )
51 |
52 | def forward(self, x):
53 | for block in self.blocks:
54 | x = block(x)
55 | x = self.final_conv(x)
56 | x = x.reshape(x.shape[0], -1)
57 | x = self.final_linear(x)
58 | return x
59 |
60 |
61 | class DiscriminatorBlock(nn.Module):
62 | def __init__(self, input_channels, filters, downsample=True):
63 | super().__init__()
64 | self.conv_res = nn.Conv2d(
65 | input_channels, filters, 1, stride=(2 if downsample else 1)
66 | )
67 |
68 | self.net = nn.Sequential(
69 | nn.Conv2d(input_channels, filters, 3, padding=1),
70 | leaky_relu(),
71 | nn.Conv2d(filters, filters, 3, padding=1),
72 | leaky_relu(),
73 | )
74 |
75 | self.downsample = (
76 | nn.Sequential(Blur(), nn.Conv2d(filters, filters, 3, padding=1, stride=2))
77 | if downsample
78 | else None
79 | )
80 |
81 | def forward(self, x):
82 | res = self.conv_res(x)
83 | x = self.net(x)
84 | if exists(self.downsample):
85 | x = self.downsample(x)
86 | x = (x + res) * (1 / math.sqrt(2))
87 | return x
88 |
89 |
90 | class Blur(nn.Module):
91 | def __init__(self):
92 | super().__init__()
93 | f = torch.Tensor([1, 2, 1])
94 | self.register_buffer("f", f)
95 |
96 | def forward(self, x):
97 | f = self.f
98 | f = f[None, None, :] * f[None, :, None]
99 | return filter2d(x, f, normalized=True)
100 |
101 |
102 | def leaky_relu(p=0.2):
103 | return nn.LeakyReLU(p, inplace=True)
104 |
105 |
106 | def exists(val):
107 | return val is not None
108 |
--------------------------------------------------------------------------------
/SphereAR/gan/discriminator_patchgan.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # taming-transformers: https://github.com/CompVis/taming-transformers
3 | import functools
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class NLayerDiscriminator(nn.Module):
9 | """Defines a PatchGAN discriminator as in Pix2Pix
10 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
11 | """
12 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
13 | """Construct a PatchGAN discriminator
14 | Parameters:
15 | input_nc (int) -- the number of channels in input images
16 | ndf (int) -- the number of filters in the last conv layer
17 | n_layers (int) -- the number of conv layers in the discriminator
18 | norm_layer -- normalization layer
19 | """
20 | super(NLayerDiscriminator, self).__init__()
21 | if not use_actnorm:
22 | norm_layer = nn.BatchNorm2d
23 | else:
24 | norm_layer = ActNorm
25 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
26 | use_bias = norm_layer.func != nn.BatchNorm2d
27 | else:
28 | use_bias = norm_layer != nn.BatchNorm2d
29 |
30 | kw = 4
31 | padw = 1
32 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
33 | nf_mult = 1
34 | nf_mult_prev = 1
35 | for n in range(1, n_layers): # gradually increase the number of filters
36 | nf_mult_prev = nf_mult
37 | nf_mult = min(2 ** n, 8)
38 | sequence += [
39 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
40 | norm_layer(ndf * nf_mult),
41 | nn.LeakyReLU(0.2, True)
42 | ]
43 |
44 | nf_mult_prev = nf_mult
45 | nf_mult = min(2 ** n_layers, 8)
46 | sequence += [
47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
48 | norm_layer(ndf * nf_mult),
49 | nn.LeakyReLU(0.2, True)
50 | ]
51 |
52 | sequence += [
53 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
54 | self.main = nn.Sequential(*sequence)
55 |
56 | self.apply(self._init_weights)
57 |
58 | def _init_weights(self, module):
59 | if isinstance(module, nn.Conv2d):
60 | nn.init.normal_(module.weight.data, 0.0, 0.02)
61 | elif isinstance(module, nn.BatchNorm2d):
62 | nn.init.normal_(module.weight.data, 1.0, 0.02)
63 | nn.init.constant_(module.bias.data, 0)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
69 |
70 | class ActNorm(nn.Module):
71 | def __init__(self, num_features, logdet=False, affine=True,
72 | allow_reverse_init=False):
73 | assert affine
74 | super().__init__()
75 | self.logdet = logdet
76 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
77 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
78 | self.allow_reverse_init = allow_reverse_init
79 |
80 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
81 |
82 | def initialize(self, input):
83 | with torch.no_grad():
84 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
85 | mean = (
86 | flatten.mean(1)
87 | .unsqueeze(1)
88 | .unsqueeze(2)
89 | .unsqueeze(3)
90 | .permute(1, 0, 2, 3)
91 | )
92 | std = (
93 | flatten.std(1)
94 | .unsqueeze(1)
95 | .unsqueeze(2)
96 | .unsqueeze(3)
97 | .permute(1, 0, 2, 3)
98 | )
99 |
100 | self.loc.data.copy_(-mean)
101 | self.scale.data.copy_(1 / (std + 1e-6))
102 |
103 | def forward(self, input, reverse=False):
104 | if reverse:
105 | return self.reverse(input)
106 | if len(input.shape) == 2:
107 | input = input[:,:,None,None]
108 | squeeze = True
109 | else:
110 | squeeze = False
111 |
112 | _, _, height, width = input.shape
113 |
114 | if self.training and self.initialized.item() == 0:
115 | self.initialize(input)
116 | self.initialized.fill_(1)
117 |
118 | h = self.scale * (input + self.loc)
119 |
120 | if squeeze:
121 | h = h.squeeze(-1).squeeze(-1)
122 |
123 | if self.logdet:
124 | log_abs = torch.log(torch.abs(self.scale))
125 | logdet = height*width*torch.sum(log_abs)
126 | logdet = logdet * torch.ones(input.shape[0]).to(input)
127 | return h, logdet
128 |
129 | return h
130 |
131 | def reverse(self, output):
132 | if self.training and self.initialized.item() == 0:
133 | if not self.allow_reverse_init:
134 | raise RuntimeError(
135 | "Initializing ActNorm in reverse direction is "
136 | "disabled by default. Use allow_reverse_init=True to enable."
137 | )
138 | else:
139 | self.initialize(output)
140 | self.initialized.fill_(1)
141 |
142 | if len(output.shape) == 2:
143 | output = output[:,:,None,None]
144 | squeeze = True
145 | else:
146 | squeeze = False
147 |
148 | h = output / self.scale - self.loc
149 |
150 | if squeeze:
151 | h = h.squeeze(-1).squeeze(-1)
152 | return h
--------------------------------------------------------------------------------
/SphereAR/gan/gan_loss.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # taming-transformers: https://github.com/CompVis/taming-transformers
3 | # muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch._dynamo as dynamo
8 |
9 | from .lpips import LPIPS
10 | from .discriminator_patchgan import (
11 | NLayerDiscriminator as PatchGANDiscriminator,
12 | )
13 | from .discriminator_stylegan import (
14 | Discriminator as StyleGANDiscriminator,
15 | )
16 |
17 |
18 | def hinge_d_loss(logits_real, logits_fake):
19 | loss_real = torch.mean(F.relu(1.0 - logits_real))
20 | loss_fake = torch.mean(F.relu(1.0 + logits_fake))
21 | d_loss = 0.5 * (loss_real + loss_fake)
22 | return d_loss
23 |
24 |
25 | def vanilla_d_loss(logits_real, logits_fake):
26 | loss_real = torch.mean(F.softplus(-logits_real))
27 | loss_fake = torch.mean(F.softplus(logits_fake))
28 | d_loss = 0.5 * (loss_real + loss_fake)
29 | return d_loss
30 |
31 |
32 | def non_saturating_d_loss(logits_real, logits_fake):
33 | loss_real = torch.mean(
34 |         F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real))
35 | )
36 | loss_fake = torch.mean(
37 |         F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake))
38 | )
39 | d_loss = 0.5 * (loss_real + loss_fake)
40 | return d_loss
41 |
42 |
43 | def hinge_gen_loss(logit_fake):
44 | return -torch.mean(logit_fake)
45 |
46 |
47 | def non_saturating_gen_loss(logit_fake):
48 | return torch.mean(
49 |         F.binary_cross_entropy_with_logits(logit_fake, torch.ones_like(logit_fake))
50 | )
51 |
52 |
53 | def adopt_weight(weight, global_step, threshold=0, value=0.0):
54 | if global_step < threshold:
55 | weight = value
56 | return weight
57 |
58 |
59 | class GANLoss(nn.Module):
60 | def __init__(
61 | self,
62 | disc_start,
63 | disc_loss="hinge",
64 | disc_dim=64,
65 | disc_type="patchgan",
66 | image_size=256,
67 | disc_num_layers=3,
68 | disc_in_channels=3,
69 | disc_weight=0.5,
70 | gen_adv_loss="hinge",
71 | reconstruction_loss="l2",
72 | ):
73 | super().__init__()
74 | # discriminator loss
75 | assert disc_type in ["patchgan", "stylegan"]
76 | assert disc_loss in ["hinge", "vanilla", "non-saturating"]
77 | if disc_type == "patchgan":
78 | self.discriminator = PatchGANDiscriminator(
79 | input_nc=disc_in_channels,
80 | n_layers=disc_num_layers,
81 | ndf=disc_dim,
82 | )
83 | elif disc_type == "stylegan":
84 | self.discriminator = StyleGANDiscriminator(
85 | input_nc=disc_in_channels,
86 | image_size=image_size,
87 | )
88 | else:
89 | raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
90 | if disc_loss == "hinge":
91 | self.disc_loss = hinge_d_loss
92 | elif disc_loss == "vanilla":
93 | self.disc_loss = vanilla_d_loss
94 | elif disc_loss == "non-saturating":
95 | self.disc_loss = non_saturating_d_loss
96 | else:
97 | raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
98 | self.discriminator_iter_start = disc_start
99 | self.disc_weight = disc_weight
100 |
101 | assert gen_adv_loss in ["hinge", "non-saturating"]
102 | # gen_adv_loss
103 | if gen_adv_loss == "hinge":
104 | self.gen_adv_loss = hinge_gen_loss
105 | elif gen_adv_loss == "non-saturating":
106 | self.gen_adv_loss = non_saturating_gen_loss
107 | else:
108 | raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")
109 |
110 | # perceptual loss
111 | self.perceptual_loss = LPIPS().eval()
112 |
113 | # reconstruction loss
114 | if reconstruction_loss == "l1":
115 | self.rec_loss = F.l1_loss
116 | elif reconstruction_loss == "l2":
117 | self.rec_loss = F.mse_loss
118 | else:
119 | raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
120 |
121 | def forward(
122 | self,
123 | inputs,
124 | reconstructions,
125 | optimizer_idx,
126 | global_step,
127 | ):
128 | # inputs = inputs.contiguous()
129 | # reconstructions = reconstructions.contiguous()
130 | # generator update
131 | if optimizer_idx == 0:
132 | # reconstruction loss
133 | rec_loss = self.rec_loss(inputs, reconstructions)
134 |
135 | # perceptual loss
136 | p_loss = self.perceptual_loss(inputs, reconstructions)
137 | p_loss = torch.mean(p_loss)
138 |
139 | # discriminator loss
140 | logits_fake = self.discriminator(reconstructions)
141 | generator_adv_loss = self.gen_adv_loss(logits_fake)
142 |
143 | disc_weight = adopt_weight(
144 | self.disc_weight, global_step, threshold=self.discriminator_iter_start
145 | )
146 |
147 | return (
148 | rec_loss,
149 | p_loss,
150 | disc_weight * generator_adv_loss,
151 | )
152 |
153 | # discriminator update
154 | if optimizer_idx == 1:
155 | logits_real = self.discriminator(inputs.detach())
156 | logits_fake = self.discriminator(reconstructions.detach())
157 |
158 | disc_weight = adopt_weight(
159 | self.disc_weight, global_step, threshold=self.discriminator_iter_start
160 | )
161 | d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)
162 |
163 | return d_adversarial_loss
164 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SphereAR: Hyperspherical Latents Improve Continuous-Token Autoregressive Generation
2 |
3 | [](https://arxiv.org/abs/2509.24335)
4 | [](https://huggingface.co/guolinke/SphereAR)
5 |
6 |
7 |
8 |
9 |
10 |
11 | This is the official PyTorch implementation of the paper [Hyperspherical Latents Improve Continuous-Token Autoregressive Generation](https://arxiv.org/abs/2509.24335).
12 |
13 | ```
14 | @article{ke2025hyperspherical,
15 | title={Hyperspherical Latents Improve Continuous-Token Autoregressive Generation},
16 | author={Guolin Ke and Hui Xue},
17 | journal={arXiv preprint arXiv:2509.24335},
18 | year={2025}
19 | }
20 | ```
21 |
22 |
23 | ## Introduction
24 |
25 |


26 |
27 | SphereAR is a simple yet effective approach to continuous-token autoregressive (AR) image generation: it makes AR scale-invariant by constraining all AR inputs and outputs---**including after CFG**---to lie on a fixed-radius hypersphere (constant L2 norm) via hyperspherical VAEs.
28 |
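The core constraint can be sketched in a few lines of PyTorch. The snippet below is illustrative only (the helper name `project_to_sphere` and the tensor shapes are assumptions, not the repository's API): every latent token is kept at a fixed L2 norm, and the same projection is re-applied after CFG mixing so the guided prediction stays on the sphere.

```python
import torch

def project_to_sphere(z, radius=1.0, eps=1e-7):
    # Keep each latent token at a constant L2 norm, i.e., on a fixed-radius hypersphere.
    return radius * z / z.norm(dim=-1, keepdim=True).clamp(min=eps)

# Illustrative CFG step: mix conditional and unconditional predictions,
# then re-project so the guided token is still on the sphere.
z_cond = project_to_sphere(torch.randn(4, 16))
z_uncond = project_to_sphere(torch.randn(4, 16))
cfg_scale = 4.5
z_guided = project_to_sphere(z_uncond + cfg_scale * (z_cond - z_uncond))
```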
29 | The model is a **pure next-token** AR generator with **raster** order, matching standard language AR modeling (i.e., it is *not* next-scale AR like VAR and *not* next-set AR like MAR/MaskGIT).
30 |
31 | On ImageNet 256×256, SphereAR achieves a state-of-the-art FID of **1.34** among AR image generators.
32 |
33 |
34 | ## Environment
35 |
36 | - PyTorch: 2.7.1 (CUDA 12.6 build)
37 | - FlashAttention: 2.8.1
38 |
39 | ### Install notes
40 | 1. Install PyTorch 2.7.1 (CUDA 12.6) using your preferred method.
41 | 2. Install FlashAttention 2.8.1 from the prebuilt wheel (replace the cp310 tag with your Python version, e.g., cp311 for Python 3.11):
42 | ```shell
43 | pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
44 | ```
45 |
46 |
47 | ## Class-conditional image generation on ImageNet
48 |
49 | ### Model Checkpoints
50 |
51 | Name | Params | FID (256×256) | Weights
52 | --- |:---:|:---:|:---:|
53 | S-VAE | 75M | - | [vae.pt](https://huggingface.co/guolinke/SphereAR/blob/main/vae.pt)
54 | SphereAR-B | 208M | 1.92 | [SphereAR_B.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_B.pt)
55 | SphereAR-L | 479M | 1.54 | [SphereAR_L.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_L.pt)
56 | SphereAR-H | 943M | 1.34 | [SphereAR_H.pt](https://huggingface.co/guolinke/SphereAR/blob/main/SphereAR_H.pt)
57 |
58 | ### Evaluation from checkpoints
59 |
60 | 1. Sample 50,000 images and save to `.npz`.
61 |
62 | SphereAR-B:
63 | ```shell
64 | ckpt=your_ckpt_path
65 | result_path=your_result_path
66 | torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
67 | sample_ddp.py --model SphereAR-B --ckpt $ckpt --cfg-scale 4.5 \
68 | --sample-dir $result_path --per-proc-batch-size 256 --to-npz
69 | ```
70 |
71 | SphereAR-L:
72 | ```shell
73 | ckpt=your_ckpt_path
74 | result_path=your_result_path
75 | torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
76 | sample_ddp.py --model SphereAR-L --ckpt $ckpt --cfg-scale 4.6 \
77 | --sample-dir $result_path --per-proc-batch-size 256 --to-npz
78 | ```
79 |
80 | SphereAR-H:
81 | ```shell
82 | ckpt=your_ckpt_path
83 | result_path=your_result_path
84 | torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
85 | sample_ddp.py --model SphereAR-H --ckpt $ckpt --cfg-scale 4.5 \
86 | --sample-dir $result_path --per-proc-batch-size 256 --to-npz
87 | ```
88 |
89 | 2. Compute metrics following [OpenAI’s evaluation protocol](https://github.com/openai/guided-diffusion/tree/main/evaluations). Download the [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) and run `python evaluator.py VIRTUAL_imagenet256_labeled.npz your_generated.npz` to compute the metrics. TensorFlow is required; we use `tensorflow==2.19.1`.
90 |
91 |
92 | ### Reproduce our training
93 |
94 | 1. Download the [ImageNet](http://image-net.org/download) dataset. **Note**: Our code supports training directly from the tar file (see `ImageNetTarDataset` in `SphereAR/dataset.py`); decompression is not needed.
95 |
96 | 2. Train the S-VAE:
97 |
98 | ```shell
99 | data_path=your_data_path/ILSVRC2012_img_train.tar
100 | result_path=your_result_path
101 | torchrun --nnodes=1 --nproc_per_node=8 --node_rank=0 \
102 | train.py --results-dir $result_path --data-path $data_path \
103 | --image-size 256 --epochs 100 --patch-size 16 --latent-dim 16 --vae-only \
104 | --lr 1e-4 --global-batch-size 256 --warmup-steps -1 --decay-start -1
105 | ```
106 |
107 | 3. Train the AR model:
108 |
109 | ```shell
110 | data_path=your_data_path/ILSVRC2012_img_train.tar
111 | result_path=your_result_path
112 | vae_ckpt=your_vae_path
113 | torchrun --nproc_per_node=8 --master_addr=$WORKER_0_HOST --node_rank=$LOCAL_RANK --master_port=$WORKER_0_PORT --nnodes=$WORKER_NUM \
114 | train.py --results-dir $result_path --data-path $data_path --image-size 256 \
115 | --model SphereAR-B --epochs 400 --patch-size 16 --latent-dim 16 \
116 | --lr 3e-4 --global-batch-size 512 --trained-vae $vae_ckpt --ema 0.9999
117 | ```
118 | You can use the script above to train `SphereAR-B`; to train other sizes, set `--model` to `SphereAR-L` or `SphereAR-H`.
119 | We trained on A100 GPUs with the following setups: 8×A100 for SphereAR-B, 16×A100 for SphereAR-L, and 32×A100 for SphereAR-H.
120 | Training takes about 3 days for 400 epochs.
121 |
122 | **Note**: We use `torch.compile` for acceleration. Occasionally the TorchInductor compile step can hang; if that happens, re-run the job. Enabling Dynamo logs tends to reduce stalls: `export TORCH_LOGS="+dynamo"`. To avoid repeated compilation cost across runs, enable the compile caches:
123 |
124 | ```shell
125 | export TORCHINDUCTOR_FX_GRAPH_CACHE=1
126 | export TORCHINDUCTOR_AUTOGRAD_CACHE=1
127 | ```
128 |
129 | Set these environment variables in your shell (or job script) before launching training.
130 |
131 |
--------------------------------------------------------------------------------
/sample_ddp.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # DiT: https://github.com/facebookresearch/DiT/blob/main/sample.py
3 | import math
4 | import os
5 | import time
6 |
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | from PIL import Image
11 | from tqdm import tqdm
12 |
13 | from SphereAR.model import create_model, get_model_args
14 |
15 |
16 | def create_npz_from_sample_folder(sample_dir, num=50_000):
17 | """
18 | Builds a single .npz file from a folder of .png samples.
19 | """
20 | samples = []
21 | for i in tqdm(range(num), desc="Building .npz file from samples"):
22 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
23 | sample_np = np.asarray(sample_pil).astype(np.uint8)
24 | samples.append(sample_np)
25 | samples = np.stack(samples)
26 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
27 | npz_path = f"{sample_dir}.npz"
28 | np.savez(npz_path, arr_0=samples)
29 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
30 | # delete the sample folder to save space
31 | os.system(f"rm -r {sample_dir}")
32 | return npz_path
33 |
34 |
35 | def main(args):
36 | assert (
37 | torch.cuda.is_available()
38 | ), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
39 |
40 | dist.init_process_group("nccl")
41 | rank = dist.get_rank()
42 | device = rank % torch.cuda.device_count()
43 | seed = args.seed * dist.get_world_size() + rank
44 | torch.manual_seed(seed)
45 | torch.cuda.set_device(device)
46 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
47 | torch.set_float32_matmul_precision("high")
48 | torch.backends.cudnn.deterministic = True
49 | torch.backends.cudnn.benchmark = False
50 | torch.set_grad_enabled(False)
51 |
52 | # create and load gpt model
53 | precision = {"none": torch.float32, "bf16": torch.bfloat16}[args.mixed_precision]
54 | model = create_model(args, device)
55 |
56 | checkpoint = torch.load(args.ckpt, map_location="cpu", weights_only=False)
57 | if "ema" in checkpoint and not args.no_ema:
58 | print("use ema weight")
59 | model_weight = checkpoint["ema"]
60 | elif "model" in checkpoint: # ddp
61 | model_weight = checkpoint["model"]
62 | else:
63 | raise Exception("please check model weight")
64 |
65 | model.load_state_dict(model_weight, strict=True)
66 | model.eval()
67 | del checkpoint
68 |
69 | # Create folder to save samples:
70 | model_string_name = args.model.replace("/", "-")
71 | ckpt_string_name = (
72 | os.path.basename(args.ckpt).replace(".pth", "").replace(".pt", "")
73 | )
74 | folder_name = (
75 | f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-"
76 | f"steps-{args.sample_steps}-cfg-{args.cfg_scale}-seed-{args.seed}"
77 | )
78 | if not args.no_ema:
79 | folder_name += "-ema"
80 | sample_folder_dir = f"{args.sample_dir}/{folder_name}"
81 |
82 | if os.path.isfile(sample_folder_dir + ".npz"):
83 | if rank == 0:
84 | print(f"Found {sample_folder_dir}.npz, skipping sampling.")
85 | dist.barrier()
86 | dist.destroy_process_group()
87 | return 1
88 | if rank == 0:
89 | os.makedirs(sample_folder_dir, exist_ok=True)
90 | print(f"Saving .png samples at {sample_folder_dir}")
91 | dist.barrier()
92 |
93 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
94 | n = args.per_proc_batch_size
95 | global_batch_size = n * dist.get_world_size()
96 | # To make things evenly-divisible, we'll sample a bit more than we need and then discard the extra samples:
97 | total_samples = int(
98 | math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size
99 | )
100 | if rank == 0:
101 | print(f"Total number of images that will be sampled: {total_samples}")
102 | assert (
103 | total_samples % dist.get_world_size() == 0
104 | ), "total_samples must be divisible by world_size"
105 | samples_needed_this_gpu = int(total_samples // dist.get_world_size())
106 | assert (
107 | samples_needed_this_gpu % n == 0
108 | ), "samples_needed_this_gpu must be divisible by the per-GPU batch size"
109 | iterations = int(samples_needed_this_gpu // n)
110 | total = 0
111 | start_time = time.time()
112 | for i in tqdm(range(iterations), desc="Sampling"):
113 | # Sample inputs:
114 | c_indices = torch.randint(0, args.num_classes, (n,), device=device)
115 | with torch.amp.autocast("cuda", dtype=precision):
116 | samples = model.sample(
117 | c_indices,
118 | sample_steps=args.sample_steps,
119 | cfg_scale=args.cfg_scale,
120 | )
121 |
122 | samples = (
123 | torch.clamp(127.5 * samples + 128.0, 0, 255)
124 | .permute(0, 2, 3, 1)
125 | .to("cpu", dtype=torch.uint8)
126 | .numpy()
127 | )
128 |
129 | # Save samples to disk as individual .png files
130 | for i, sample in enumerate(samples):
131 | index = i * dist.get_world_size() + rank + total
132 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
133 | total += global_batch_size
134 | print(
135 | f"Rank {rank} has sampled {total} images so far, cost {time.time() - start_time:.2f} seconds"
136 | )
137 |
138 | # Make sure all processes have finished saving their samples before attempting to convert to .npz
139 | dist.barrier()
140 | if rank == 0 and args.to_npz:
141 | print(f"Total time taken for sampling: {time.time() - start_time:.2f} seconds")
142 | create_npz_from_sample_folder(sample_folder_dir, args.num_fid_samples)
143 | print("Done.")
144 | dist.barrier()
145 | dist.destroy_process_group()
146 |
147 |
148 | if __name__ == "__main__":
149 | parser = get_model_args()
150 | parser.add_argument("--ckpt", type=str, default=None)
151 | parser.add_argument("--sample-dir", type=str, default="samples")
152 | parser.add_argument("--per-proc-batch-size", type=int, default=32)
153 | parser.add_argument("--num-fid-samples", type=int, default=50000)
154 | parser.add_argument("--cfg-scale", type=float, default=4.6)
155 | parser.add_argument("--seed", type=int, default=99)
156 | parser.add_argument("--sample-steps", type=int, default=100)
157 | parser.add_argument("--no-ema", action="store_true")
158 | parser.add_argument(
159 | "--mixed-precision", type=str, default="bf16", choices=["none", "bf16"]
160 | )
161 | parser.add_argument("--to-npz", action="store_true")
162 | args = parser.parse_args()
163 | main(args)
164 |
--------------------------------------------------------------------------------
/SphereAR/gan/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import os, hashlib
4 | import requests
5 | from tqdm import tqdm
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torchvision import models
10 | from collections import namedtuple
11 |
12 | URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
13 |
14 | CKPT_MAP = {"vgg_lpips": "vgg.pth"}
15 |
16 | MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
17 |
18 |
19 | def download(url, local_path, chunk_size=1024):
20 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
21 | with requests.get(url, stream=True) as r:
22 | total_size = int(r.headers.get("content-length", 0))
23 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
24 | with open(local_path, "wb") as f:
25 | for data in r.iter_content(chunk_size=chunk_size):
26 | if data:
27 | f.write(data)
28 | pbar.update(chunk_size)
29 |
30 |
31 | def md5_hash(path):
32 | with open(path, "rb") as f:
33 | content = f.read()
34 | return hashlib.md5(content).hexdigest()
35 |
36 |
37 | def get_ckpt_path(name, root, check=False):
38 | assert name in URL_MAP
39 | path = os.path.join(root, CKPT_MAP[name])
40 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
41 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
42 | download(URL_MAP[name], path)
43 | md5 = md5_hash(path)
44 | assert md5 == MD5_MAP[name], md5
45 | return path
46 |
47 |
48 | class LPIPS(nn.Module):
49 | # Learned perceptual metric
50 | def __init__(self, use_dropout=True):
51 | super().__init__()
52 | self.scaling_layer = ScalingLayer()
53 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
54 | self.net = vgg16(pretrained=True, requires_grad=False)
55 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
60 | self.load_from_pretrained()
61 | for param in self.parameters():
62 | param.requires_grad = False
63 |
64 | def load_from_pretrained(self, name="vgg_lpips"):
65 | ckpt = get_ckpt_path(
66 | name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")
67 | )
68 | self.load_state_dict(
69 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
70 | )
71 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
72 |
73 | @classmethod
74 | def from_pretrained(cls, name="vgg_lpips"):
75 | if name != "vgg_lpips":
76 | raise NotImplementedError
77 | model = cls()
78 | ckpt = get_ckpt_path(
79 | name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")
80 | )
81 | model.load_state_dict(
82 | torch.load(ckpt, map_location=torch.device("cpu")), strict=False
83 | )
84 | return model
85 |
86 | def forward(self, input, target):
87 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
88 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
89 | feats0, feats1, diffs = {}, {}, {}
90 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
91 | for kk in range(len(self.chns)):
92 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
93 | outs1[kk]
94 | )
95 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
96 |
97 | res = [
98 | spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
99 | for kk in range(len(self.chns))
100 | ]
101 | val = res[0]
102 | for l in range(1, len(self.chns)):
103 | val += res[l]
104 | return val
105 |
106 |
107 | class ScalingLayer(nn.Module):
108 | def __init__(self):
109 | super(ScalingLayer, self).__init__()
110 | self.register_buffer(
111 | "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
112 | )
113 | self.register_buffer(
114 | "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
115 | )
116 |
117 | def forward(self, inp):
118 | return (inp - self.shift) / self.scale
119 |
120 |
121 | class NetLinLayer(nn.Module):
122 | """A single linear layer which does a 1x1 conv"""
123 |
124 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
125 | super(NetLinLayer, self).__init__()
126 | layers = (
127 | [
128 | nn.Dropout(),
129 | ]
130 | if (use_dropout)
131 | else []
132 | )
133 | layers += [
134 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
135 | ]
136 | self.model = nn.Sequential(*layers)
137 |
138 |
139 | class vgg16(torch.nn.Module):
140 | def __init__(self, requires_grad=False, pretrained=True):
141 | super(vgg16, self).__init__()
142 | vgg_pretrained_features = models.vgg16(
143 | weights=models.VGG16_Weights.IMAGENET1K_V1
144 | ).features
145 | self.slice1 = torch.nn.Sequential()
146 | self.slice2 = torch.nn.Sequential()
147 | self.slice3 = torch.nn.Sequential()
148 | self.slice4 = torch.nn.Sequential()
149 | self.slice5 = torch.nn.Sequential()
150 | self.N_slices = 5
151 | for x in range(4):
152 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
153 | for x in range(4, 9):
154 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
155 | for x in range(9, 16):
156 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
157 | for x in range(16, 23):
158 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
159 | for x in range(23, 30):
160 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
161 | if not requires_grad:
162 | for param in self.parameters():
163 | param.requires_grad = False
164 |
165 | def forward(self, X):
166 | h = self.slice1(X)
167 | h_relu1_2 = h
168 | h = self.slice2(h)
169 | h_relu2_2 = h
170 | h = self.slice3(h)
171 | h_relu3_3 = h
172 | h = self.slice4(h)
173 | h_relu4_3 = h
174 | h = self.slice5(h)
175 | h_relu5_3 = h
176 | vgg_outputs = namedtuple(
177 | "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
178 | )
179 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
180 | return out
181 |
182 |
183 | def normalize_tensor(x, eps=1e-10):
184 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
185 | return x / (norm_factor + eps)
186 |
187 |
188 | def spatial_average(x, keepdim=True):
189 | return x.mean([2, 3], keepdim=keepdim)
190 |
--------------------------------------------------------------------------------
/SphereAR/dataset.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import io
3 | import math
4 | import os
5 | import pickle
6 | import tarfile
7 | from functools import lru_cache
8 |
9 | import numpy as np
10 | import torch
11 | from PIL import Image
12 | from torch.utils.data import Dataset
13 | from torchvision.datasets import ImageFolder
14 |
15 |
16 | @contextlib.contextmanager
17 | def numpy_seed(seed, *addl_seeds):
18 | """Context manager which seeds the NumPy PRNG with the specified seed and
19 | restores the state afterward"""
20 | if seed is None:
21 | yield
22 | return
23 |
24 | def check_seed(s):
25 | assert type(s) == int or type(s) == np.int32 or type(s) == np.int64
26 |
27 | check_seed(seed)
28 | if len(addl_seeds) > 0:
29 | for s in addl_seeds:
30 | check_seed(s)
31 | seed = int(hash((seed, *addl_seeds)) % 1e8)
32 | state = np.random.get_state()
33 | np.random.seed(seed)
34 | try:
35 | yield
36 | finally:
37 | np.random.set_state(state)
38 |
39 |
40 | def build_flat_index(outer_path: str, idx_path: str):
41 | if os.path.exists(idx_path):
42 | print(f"Index file {idx_path} already exists. Skipping index building.")
43 | return pickle.load(open(idx_path, "rb"))
44 | entries = [] # (offset, size, label)
45 | cats = set()
46 | idx = 0
47 | with tarfile.open(outer_path, "r:") as outer:
48 | for sub in outer.getmembers():
49 | if not sub.isfile() or not sub.name.endswith(".tar"):
50 | continue
51 | outer_off = sub.offset_data
52 | sub_fobj = outer.extractfile(sub)
53 | with tarfile.open(fileobj=sub_fobj, mode="r:") as inner:
54 | for m in inner.getmembers():
55 | if not m.isfile():
56 | continue
57 | cat = m.name.split("_", 1)[0]
58 | cats.add(cat)
59 | abs_off = outer_off + m.offset_data
60 | entries.append((abs_off, m.size, cat))
61 | if idx % 1000 == 1:
62 | print(idx, m.name, abs_off, m.size, cat)
63 | idx += 1
64 | sorted_cats = sorted(cats)
65 | cat2idx = {c: i for i, c in enumerate(sorted_cats)}
66 |
67 | flat = [(off, size, cat2idx[c]) for off, size, c in entries]
68 |
69 | os.makedirs(os.path.dirname(idx_path), exist_ok=True)
70 | with open(idx_path, "wb") as f:
71 | pickle.dump(
72 | flat,
73 | f,
74 | )
75 | print(f"Built flat index with {len(flat)} images.")
76 | return flat
77 |
78 |
79 | class ImageNetTarDataset(Dataset):
80 | """
81 |     ImageNet dataset stored in a tar file; this avoids decompressing the whole dataset.
82 |     You can directly use the original tar file (ILSVRC2012_img_train.tar) downloaded from the official ImageNet website.
83 |     Best practice is to copy the tar file to the node's local disk or a ramdisk (e.g., /dev/shm/) first, to avoid a remote I/O bottleneck.
84 | """
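    # Usage sketch (the path below is illustrative):
    #   ds = ImageNetTarDataset("/dev/shm/ILSVRC2012_img_train.tar")
    #   image, label = ds[0]  # PIL.Image (RGB) and integer class index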
85 |
86 | def __init__(
87 | self,
88 | tar_file,
89 | ):
90 | self.tar_file = tar_file
91 | self.tar_handle = None
92 | self.files = build_flat_index(tar_file, tar_file + ".index")
93 | self.num_examples = len(self.files)
94 |
95 | def __len__(self):
96 | return self.num_examples
97 |
98 | def get_raw_image(self, index):
99 | if self.tar_handle is None:
100 | self.tar_handle = open(self.tar_file, "rb")
101 |
102 | offset, size, label = self.files[index]
103 | self.tar_handle.seek(offset)
104 | data = self.tar_handle.read(size)
105 | image = Image.open(io.BytesIO(data)).convert("RGB")
106 | return image, label
107 |
108 | @lru_cache(maxsize=16)
109 | def __getitem__(self, idx):
110 | return self.get_raw_image(idx)
111 |
112 |
113 | def center_crop_arr(pil_image, image_size):
114 | """
115 | Center cropping implementation from ADM.
116 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
117 | """
118 | while min(*pil_image.size) >= 2 * image_size:
119 | pil_image = pil_image.resize(
120 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
121 | )
122 |
123 | scale = image_size / min(*pil_image.size)
124 | pil_image = pil_image.resize(
125 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
126 | )
127 |
128 | arr = np.array(pil_image)
129 | crop_y = (arr.shape[0] - image_size) // 2
130 | crop_x = (arr.shape[1] - image_size) // 2
131 | return Image.fromarray(
132 | arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
133 | )
134 |
135 |
136 | def numpy_randrange(start, end):
137 | return int(np.random.randint(start, end))
138 |
139 |
140 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
141 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
142 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
143 | smaller_dim_size = numpy_randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
144 |
145 | # We are not on a new enough PIL to support the `reducing_gap`
146 | # argument, which uses BOX downsampling at powers of two first.
147 | # Thus, we do it by hand to improve downsample quality.
148 | while min(*pil_image.size) >= 2 * smaller_dim_size:
149 | pil_image = pil_image.resize(
150 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
151 | )
152 |
153 | scale = smaller_dim_size / min(*pil_image.size)
154 | pil_image = pil_image.resize(
155 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
156 | )
157 |
158 | arr = np.array(pil_image)
159 | crop_y = numpy_randrange(0, arr.shape[0] - image_size + 1)
160 | crop_x = numpy_randrange(0, arr.shape[1] - image_size + 1)
161 | return Image.fromarray(
162 | arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
163 | )
164 |
165 |
166 | def crop(pil_image, left, top, right, bottom):
167 | """
168 | Crop the image to the specified box.
169 | """
170 | return pil_image.crop((left, top, right, bottom))
171 |
172 |
173 | class ImageCropDataset(Dataset):
174 |
175 | def __init__(
176 | self,
177 | raw_dataset,
178 | resolution,
179 | patch_size,
180 | seed=42,
181 | ):
182 | self.raw_dataset = raw_dataset
183 | self.resolution = resolution
184 | self.patch_size = patch_size
185 | self.aug_ratio = 1.0
186 | self.seed = seed
187 | self.epoch = None
188 |
189 | def set_epoch(self, epoch):
190 | self.epoch = epoch
191 |
192 | def set_aug_ratio(self, aug_ratio):
193 | self.aug_ratio = aug_ratio
194 |
195 | def __len__(self):
196 | return len(self.raw_dataset)
197 |
198 | def crop_and_flip(self, image):
199 | is_aug = np.random.rand() < self.aug_ratio
200 | if not is_aug:
201 | image = center_crop_arr(image, self.resolution)
202 | else:
203 | image = random_crop_arr(image, self.resolution)
204 |
205 | arr = np.asarray(image)
206 |
207 | is_flip = int(np.random.randint(0, 2))
208 | if is_flip == 1:
209 | # horizontal flip
210 | arr = arr[:, ::-1, :]
211 |
212 | return arr.transpose(2, 0, 1) # HWC to CHW
213 |
214 | def __getitem__(self, idx):
215 | with numpy_seed(self.seed, self.epoch, idx):
216 | image, label = self.raw_dataset[idx]
217 | samples = self.crop_and_flip(image)
218 | # to [-1, 1]
219 | samples = (samples.astype(np.float32) / 255.0 - 0.5) * 2.0
220 | samples = torch.from_numpy(samples).float()
221 | return (
222 | samples,
223 | torch.tensor(label).long(),
224 | )
225 |
226 |
227 | def build_dataset(args):
228 | # use tarred imagenet dataset if data_path ends with .tar
229 | raw_dataset = (
230 | ImageNetTarDataset(args.data_path)
231 | if args.data_path.endswith(".tar")
232 | else ImageFolder(args.data_path)
233 | )
234 | return ImageCropDataset(
235 | raw_dataset,
236 | args.image_size,
237 | args.patch_size,
238 | seed=args.global_seed if hasattr(args, "global_seed") else 42,
239 | )
240 |
--------------------------------------------------------------------------------
/SphereAR/diff_head.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.utils.checkpoint import checkpoint
7 |
8 | from .sampling import euler_maruyama, euler
9 |
10 |
11 | def timestep_embedding(t, dim, max_period=10000, time_factor: float = 1000.0):
12 | half = dim // 2
13 | t = time_factor * t.float()
14 | freqs = torch.exp(
15 | -math.log(max_period)
16 | * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device)
17 | / half
18 | )
19 |
20 | args = t[:, None] * freqs[None]
21 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
22 | if dim % 2:
23 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
24 | if torch.is_floating_point(t):
25 | embedding = embedding.to(t)
26 | return embedding
27 |
28 |
29 | class DiffHead(nn.Module):
30 | """Diffusion Loss"""
31 |
32 | def __init__(
33 | self,
34 | ch_target,
35 | ch_cond,
36 | ch_latent,
37 | depth_latent,
38 | depth_adanln,
39 | grad_checkpointing=False,
40 | ):
41 | super(DiffHead, self).__init__()
42 | self.ch_target = ch_target
43 | self.net = MlpEncoder(
44 | in_channels=ch_target,
45 | model_channels=ch_latent,
46 | z_channels=ch_cond,
47 | num_res_blocks=depth_latent,
48 | num_ada_ln_blocks=depth_adanln,
49 | grad_checkpointing=grad_checkpointing,
50 | )
51 |
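    # Flow-matching objective: sample t (logit-normal via sigmoid of a Gaussian),
    # interpolate x = (1 - t) * noise + t * target, and regress the network output
    # onto the constant velocity v = target - noise with an MSE loss.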
52 | def forward(self, target, z):
53 | with torch.autocast(device_type="cuda", enabled=False):
54 | with torch.no_grad():
55 | t = torch.randn((target.shape[0]), device=target.device).sigmoid()
56 | noise = torch.randn_like(target)
57 | ti = t.view(-1, 1)
58 | x = (1.0 - ti) * noise + ti * target
59 | v = target - noise
60 |
61 | output = self.net(x, t, z)
62 |
63 | with torch.autocast(device_type="cuda", enabled=False):
64 | output = output.float()
65 | loss = torch.mean((output - v) ** 2)
66 | return loss
67 |
68 | def sample(
69 | self,
70 | z,
71 | cfg,
72 | num_sampling_steps,
73 | ):
74 | return euler_maruyama(
75 | self.ch_target,
76 | self.net.forward,
77 | z,
78 | cfg,
79 | num_sampling_steps=num_sampling_steps,
80 | )
81 |
82 | def initialize_weights(self):
83 | self.net.initialize_weights()
84 |
85 |
86 | class TimestepEmbedder(nn.Module):
87 | """
88 | Embeds scalar timesteps into vector representations.
89 | """
90 |
91 | def __init__(self, hidden_size, frequency_embedding_size=256):
92 | super().__init__()
93 | self.mlp = nn.Sequential(
94 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
95 | nn.SiLU(),
96 | nn.Linear(hidden_size, hidden_size, bias=True),
97 | )
98 | self.frequency_embedding_size = frequency_embedding_size
99 |
100 | def forward(self, t):
101 | t_freq = timestep_embedding(t, self.frequency_embedding_size)
102 | t_emb = self.mlp(t_freq)
103 | return t_emb
104 |
105 |
106 | class ResBlock(nn.Module):
107 | def __init__(self, channels):
108 | super().__init__()
109 | self.channels = channels
110 | self.norm = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=True)
111 | hidden_dim = int(channels * 1.5)
112 | self.w1 = nn.Linear(channels, hidden_dim * 2, bias=True)
113 | self.w2 = nn.Linear(hidden_dim, channels, bias=True)
114 |
115 | def forward(self, x, scale, shift, gate):
116 | h = self.norm(x) * (1 + scale) + shift
117 | h1, h2 = self.w1(h).chunk(2, dim=-1)
118 | h = self.w2(F.silu(h1) * h2)
119 | return x + h * gate
120 |
121 |
122 | class FinalLayer(nn.Module):
123 | def __init__(self, channels, out_channels):
124 | super().__init__()
125 | self.norm_final = nn.LayerNorm(channels, eps=1e-6, elementwise_affine=False)
126 | self.ada_ln_modulation = nn.Linear(channels, channels * 2, bias=True)
127 | self.linear = nn.Linear(channels, out_channels, bias=True)
128 |
129 | def forward(self, x, y):
130 | scale, shift = self.ada_ln_modulation(y).chunk(2, dim=-1)
131 | x = self.norm_final(x) * (1.0 + scale) + shift
132 | x = self.linear(x)
133 | return x
134 |
135 |
136 | class MlpEncoder(nn.Module):
137 |
138 | def __init__(
139 | self,
140 | in_channels,
141 | model_channels,
142 | z_channels,
143 | num_res_blocks,
144 | num_ada_ln_blocks=2,
145 | grad_checkpointing=False,
146 | ):
147 | super().__init__()
148 |
149 | self.in_channels = in_channels
150 | self.model_channels = model_channels
151 | self.out_channels = in_channels
152 | self.num_res_blocks = num_res_blocks
153 | self.grad_checkpointing = grad_checkpointing
154 |
155 | self.time_embed = TimestepEmbedder(model_channels)
156 | self.cond_embed = nn.Linear(z_channels, model_channels)
157 |
158 | self.input_proj = nn.Linear(in_channels, model_channels)
159 | self.res_blocks = nn.ModuleList()
160 | for i in range(num_res_blocks):
161 | self.res_blocks.append(
162 | ResBlock(
163 | model_channels,
164 | )
165 | )
166 | # share adaLN for consecutive blocks, to save computation and parameters
167 | self.ada_ln_blocks = nn.ModuleList()
168 | for i in range(num_ada_ln_blocks):
169 | self.ada_ln_blocks.append(
170 | nn.Linear(model_channels, model_channels * 3, bias=True)
171 | )
172 | self.ada_ln_switch_freq = max(1, num_res_blocks // num_ada_ln_blocks)
173 |         assert (
174 |             num_res_blocks % num_ada_ln_blocks
175 |         ) == 0, "num_res_blocks must be divisible by num_ada_ln_blocks"
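        # e.g. with num_res_blocks=8 and num_ada_ln_blocks=2, ada_ln_switch_freq is 4:
        # res_blocks 0-3 reuse ada_ln_blocks[0] and res_blocks 4-7 reuse ada_ln_blocks[1]
        # (see the corresponding index arithmetic in forward()).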
176 | self.final_layer = FinalLayer(model_channels, self.out_channels)
177 |
178 | self.initialize_weights()
179 |
180 | def initialize_weights(self):
181 | def _basic_init(module):
182 | if isinstance(module, nn.Linear):
183 | torch.nn.init.xavier_uniform_(module.weight)
184 | if module.bias is not None:
185 | nn.init.constant_(module.bias, 0)
186 |
187 | self.apply(_basic_init)
188 |
189 | # Initialize timestep embedding MLP
190 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
191 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
192 |
193 | for block in self.ada_ln_blocks:
194 | nn.init.constant_(block.weight, 0)
195 | nn.init.constant_(block.bias, 0)
196 |
197 | # Zero-out output layers
198 | nn.init.constant_(self.final_layer.ada_ln_modulation.weight, 0)
199 | nn.init.constant_(self.final_layer.ada_ln_modulation.bias, 0)
200 | nn.init.constant_(self.final_layer.linear.weight, 0)
201 | nn.init.constant_(self.final_layer.linear.bias, 0)
202 |
203 | @torch.compile()
204 | def forward(self, x, t, c):
205 | """
206 | Apply the model to an input batch.
207 | :param x: an [N x C] Tensor of inputs.
208 | :param t: a 1-D batch of timesteps.
209 | :param c: conditioning from AR transformer.
210 | :return: an [N x C] Tensor of outputs.
211 | """
212 | x = self.input_proj(x)
213 | t = self.time_embed(t)
214 | c = self.cond_embed(c)
215 |
216 | y = F.silu(t + c)
217 | scale, shift, gate = self.ada_ln_blocks[0](y).chunk(3, dim=-1)
218 | if self.grad_checkpointing and self.training:
219 | for i, block in enumerate(self.res_blocks):
220 | if i > 0 and i % self.ada_ln_switch_freq == 0:
221 | ada_ln_block = self.ada_ln_blocks[i // self.ada_ln_switch_freq]
222 | scale, shift, gate = ada_ln_block(y).chunk(3, dim=-1)
223 | x = checkpoint(block, x, scale, shift, gate, use_reentrant=False)
224 | else:
225 | for i, block in enumerate(self.res_blocks):
226 | if i > 0 and i % self.ada_ln_switch_freq == 0:
227 | ada_ln_block = self.ada_ln_blocks[i // self.ada_ln_switch_freq]
228 | scale, shift, gate = ada_ln_block(y).chunk(3, dim=-1)
229 | x = block(x, scale, shift, gate)
230 |
231 | return self.final_layer(x, y)
232 |
--------------------------------------------------------------------------------
/SphereAR/vae.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # taming-transformers: https://github.com/CompVis/taming-transformers
3 | # maskgit: https://github.com/google-research/maskgit
4 | # llamagen: https://github.com/FoundationVision/LlamaGen
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d
11 | from .psd import PowerSphericalDistribution, l2_norm
12 |
13 |
14 | class VAE(nn.Module):
15 |
16 | def __init__(
17 | self,
18 | latent_dim=16,
19 | image_size=256,
20 | patch_size=16,
21 | z_channels=512,
22 | cnn_chs=[64, 64, 128, 256, 512],
23 | encoder_vit_layers=6,
24 | decoder_vit_layers=12,
25 | ):
26 | super().__init__()
27 | self.z_channels = z_channels
28 | n_head = z_channels // 64
29 | self.encoder = ViTEncoder(
30 | n_layers=encoder_vit_layers,
31 | d_model=z_channels,
32 | n_heads=n_head,
33 | cnn_chs=cnn_chs,
34 | image_size=image_size,
35 | patch_size=patch_size,
36 | )
37 | self.decoder = ViTDecoder(
38 | n_layers=decoder_vit_layers,
39 | d_model=z_channels,
40 | n_heads=n_head,
41 | cnn_chs=cnn_chs[::-1],
42 | image_size=image_size,
43 | patch_size=patch_size,
44 | )
45 | self.latent_dim = latent_dim
46 | self.quant_proj = nn.Linear(z_channels, latent_dim + 1, bias=True)
47 | self.post_quant_proj = nn.Linear(latent_dim, z_channels, bias=False)
48 |
49 | def initialize_weights(self):
50 |
51 | self.quant_proj.reset_parameters()
52 | self.post_quant_proj.reset_parameters()
53 | self.encoder.output.reset_parameters()
54 |
55 | def normalize(self, x):
56 | x = l2_norm(x)
57 | x = x * (self.latent_dim**0.5)
58 | return x
59 |
60 | def encode(self, x):
61 | x = self.encoder(x)
62 | x = self.quant_proj(x)
63 | mu = x[..., :-1]
64 | kappa = x[..., -1]
65 | mu = l2_norm(mu)
66 | kappa = F.softplus(kappa) + 1.0
67 | qz = PowerSphericalDistribution(mu, kappa)
68 | loss = qz.kl_to_uniform()
69 | x = qz.rsample()
70 | x = x * (self.latent_dim**0.5)
71 | return x, loss.mean()
72 |
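    # encode() places each token latent on the unit hypersphere: the Power Spherical
    # posterior q(z | mu, kappa) is sampled with the reparameterization trick
    # (rsample), regularized by its KL divergence to the uniform distribution on the
    # sphere, and the sample is rescaled by sqrt(latent_dim) so its squared norm is
    # latent_dim, roughly the scale of a unit-variance Gaussian latent of that size.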
73 | def decode(self, x):
74 | x = self.post_quant_proj(x)
75 | dec = self.decoder(x)
76 | return dec
77 |
78 |
79 | class ResDownBlock(nn.Module):
80 | def __init__(self, in_ch, out_ch):
81 | super().__init__()
82 | self.block = nn.Sequential(
83 | nn.GroupNorm(min(32, in_ch // 4), in_ch, eps=1e-6, affine=True),
84 | nn.SiLU(),
85 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, bias=False),
86 | nn.GroupNorm(min(32, out_ch // 4), out_ch, eps=1e-6, affine=True),
87 | nn.SiLU(),
88 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
89 | )
90 | self.shortcut = nn.Conv2d(
91 | in_ch, out_ch, kernel_size=2, stride=2, padding=0, bias=False
92 | )
93 |
94 | def forward(self, x):
95 | return self.shortcut(x) + self.block(x)
96 |
97 |
98 | class PatchifyNet(nn.Module):
99 |
100 | def __init__(self, chs):
101 | super().__init__()
102 | layers = []
103 | for i in range(len(chs) - 1):
104 | layers.append(ResDownBlock(chs[i], chs[i + 1]))
105 | self.net = nn.Sequential(*layers)
106 |
107 | def forward(self, x):
108 | x = self.net(x)
109 | return x
110 |
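# Each ResDownBlock halves the spatial resolution, so PatchifyNet with the default
# cnn_chs = [64, 64, 128, 256, 512] stacks four stride-2 blocks for a total
# downsampling factor of 16, matching patch_size=16 (a 256x256 image becomes a
# 16x16 grid of tokens before the ViT encoder).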
111 |
112 | class NCHW_to_NLC(nn.Module):
113 | def __init__(self):
114 | super().__init__()
115 |
116 | def forward(self, x):
117 | n, c, h, w = x.shape
118 | x = x.permute(0, 2, 3, 1).view(n, h * w, c)
119 | return x
120 |
121 |
122 | class NLC_to_NCHW(nn.Module):
123 | def __init__(self):
124 | super().__init__()
125 |
126 | def forward(self, x):
127 | n, l, c = x.shape
128 | h = w = int(l**0.5)
129 | x = x.view(n, h, w, c).permute(0, 3, 1, 2)
130 | return x
131 |
132 |
133 | class ResUpBlock(nn.Module):
134 | def __init__(self, in_ch, out_ch):
135 | super().__init__()
136 | num_groups = min(32, out_ch // 4)
137 | self.block = nn.Sequential(
138 | nn.GroupNorm(min(32, in_ch // 4), in_ch, eps=1e-6, affine=True),
139 | nn.SiLU(),
140 | nn.ConvTranspose2d(
141 | in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False
142 | ),
143 | nn.GroupNorm(num_groups, out_ch, eps=1e-6, affine=True),
144 | nn.SiLU(),
145 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
146 | )
147 | self.shortcut = nn.ConvTranspose2d(
148 | in_ch, out_ch, kernel_size=2, stride=2, bias=False
149 | )
150 | self.block2 = nn.Sequential(
151 | nn.GroupNorm(num_groups, out_ch, eps=1e-6, affine=True),
152 | nn.SiLU(),
153 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
154 | nn.GroupNorm(num_groups, out_ch, eps=1e-6, affine=True),
155 | nn.SiLU(),
156 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
157 | )
158 |
159 | def forward(self, x):
160 | h = self.block(x)
161 | x = self.shortcut(x)
162 | x = x + h
163 | x = x + self.block2(x)
164 | return x
165 |
166 |
167 | class UnpatchifyNet(nn.Module):
168 |
169 | def __init__(self, chs):
170 | super().__init__()
171 | layers = []
172 | for i in range(len(chs) - 1):
173 | layers.append(ResUpBlock(chs[i], chs[i + 1]))
174 | self.net = nn.Sequential(*layers)
175 |
176 | def forward(self, x):
177 | x = self.net(x)
178 | return x
179 |
180 |
181 | class ViTEncoder(nn.Module):
182 |
183 | def __init__(
184 | self,
185 | n_layers=6,
186 | n_heads=8,
187 | d_model=512,
188 | cnn_chs=[64, 64, 128, 256, 512],
189 | image_size=256,
190 | patch_size=16,
191 | ):
192 | super().__init__()
193 |         self.conv_in = nn.Conv2d(3, cnn_chs[0], kernel_size=3, stride=1, padding=1, bias=False)
194 | self.patchify = nn.Sequential(
195 | PatchifyNet(cnn_chs),
196 | NCHW_to_NLC(),
197 | )
198 | assert cnn_chs[-1] == d_model
199 | self.register_num_tokens = 4
200 | self.register_token = nn.Embedding(self.register_num_tokens, d_model)
201 | self.layers = nn.ModuleList()
202 | for _ in range(n_layers):
203 | self.layers.append(
204 | TransformerBlock(
205 | d_model,
206 | n_heads,
207 | 0.0,
208 | 0.0,
209 | drop_path=0.0,
210 | causal=False,
211 | )
212 | )
213 | self.norm = nn.RMSNorm(d_model, eps=1e-6)
214 | self.output = nn.Linear(d_model, d_model, bias=False)
215 | raw_2d_pos = get_2d_pos(image_size, patch_size)
216 | self.register_buffer(
217 | "freqs_cis",
218 | precompute_freqs_cis_2d(
219 | raw_2d_pos, d_model // n_heads, cls_token_num=self.register_num_tokens
220 | ).clone(),
221 | )
222 |
223 | def forward(self, image):
224 | x = self.conv_in(image)
225 | x = self.patchify(x)
226 | x_null = self.register_token.weight.view(1, -1, x.shape[-1]).expand(
227 | x.shape[0], -1, -1
228 | )
229 | x = torch.cat([x_null, x], dim=1)
230 | for layer in self.layers:
231 | x = layer(x, self.freqs_cis)
232 | x = x[:, self.register_num_tokens :, :]
233 | x = self.output(self.norm(x))
234 | return x
235 |
236 |
237 | class ViTDecoder(nn.Module):
238 | def __init__(
239 | self,
240 | n_layers=12,
241 | n_heads=8,
242 | d_model=512,
243 | cnn_chs=[512, 256, 128, 64, 64],
244 | image_size=256,
245 | patch_size=16,
246 | ):
247 | super().__init__()
248 | self.d_model = d_model
249 | assert d_model == cnn_chs[0]
250 | self.conv_in = nn.Sequential(
251 | NLC_to_NCHW(),
252 | nn.Conv2d(d_model, d_model, kernel_size=3, stride=1, padding=1, bias=False),
253 | NCHW_to_NLC(),
254 | )
255 | self.register_num_tokens = 4
256 | self.register_token = nn.Embedding(self.register_num_tokens, d_model)
257 | self.layers = nn.ModuleList()
258 | for _ in range(n_layers):
259 | self.layers.append(
260 | TransformerBlock(
261 | d_model,
262 | n_heads,
263 | 0.0,
264 | 0.0,
265 | drop_path=0.0,
266 | causal=False,
267 | )
268 | )
269 | self.unpatchify = nn.Sequential(
270 | NLC_to_NCHW(),
271 | UnpatchifyNet(chs=cnn_chs),
272 | )
273 | self.conv_out = nn.Sequential(
274 | nn.GroupNorm(16, cnn_chs[-1], eps=1e-6, affine=True),
275 | nn.SiLU(),
276 | nn.Conv2d(cnn_chs[-1], 3, kernel_size=3, stride=1, padding=1, bias=False),
277 | )
278 |
279 | raw_2d_pos = get_2d_pos(image_size, patch_size)
280 | self.register_buffer(
281 | "freqs_cis",
282 | precompute_freqs_cis_2d(
283 | raw_2d_pos, d_model // n_heads, cls_token_num=self.register_num_tokens
284 | ).clone(),
285 | )
286 |
287 | @torch.compile()
288 | def forward(self, x):
289 | x = self.conv_in(x)
290 | x_null = self.register_token.weight.view(1, -1, x.shape[-1]).expand(
291 | x.shape[0], -1, -1
292 | )
293 | x = torch.cat([x_null, x], dim=1)
294 | for layer in self.layers:
295 | x = layer(x, self.freqs_cis)
296 | x = x[:, self.register_num_tokens :, :]
297 | x = self.unpatchify(x)
298 | x = self.conv_out(x)
299 | return x
300 |
--------------------------------------------------------------------------------
/SphereAR/layers.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | from flash_attn import flash_attn_func
7 | from torch.nn import RMSNorm
8 | from torch.nn import functional as F
9 |
10 |
11 | def drop_path(
12 | x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
13 | ):
14 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
15 |
16 |     This is the same as the DropConnect implementation used for EfficientNet and similar networks;
17 |     however, the original name is misleading, as 'Drop Connect' refers to a different form of dropout
18 |     from a separate paper. See the discussion at
19 |     https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956. The layer and argument are
20 |     named 'drop path' here rather than mixing 'DropConnect' as a layer name with a 'survival rate' argument.
21 |
22 | """
23 | if drop_prob == 0.0 or not training:
24 | return x
25 | keep_prob = 1 - drop_prob
26 | shape = (x.shape[0],) + (1,) * (
27 | x.ndim - 1
28 | ) # work with diff dim tensors, not just 2D ConvNets
29 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
30 | if keep_prob > 0.0 and scale_by_keep:
31 | random_tensor.div_(keep_prob)
32 | return x * random_tensor
33 |
34 |
35 | class DropPath(torch.nn.Module):
36 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
37 |
38 | def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
39 | super(DropPath, self).__init__()
40 | self.drop_prob = drop_prob
41 | self.scale_by_keep = scale_by_keep
42 |
43 | def forward(self, x):
44 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
45 |
46 | def extra_repr(self):
47 | return f"drop_prob={round(self.drop_prob,3):0.3f}"
48 |
49 |
50 | def find_multiple(n: int, k: int):
51 | if n % k == 0:
52 | return n
53 | return n + k - (n % k)
54 |
55 |
56 | @lru_cache(maxsize=16)
57 | def get_causal_mask(seq_q, seq_k, device):
58 | offset = seq_k - seq_q
59 | i = torch.arange(seq_q, device=device).unsqueeze(1)
60 | j = torch.arange(seq_k, device=device).unsqueeze(0)
61 | causal_mask = (j > (offset + i)).bool()
62 | causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
63 | return causal_mask
64 |
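# get_causal_mask also covers cached decoding, where the query holds only the last
# seq_q positions of a key sequence of length seq_k: query position i may attend to
# keys 0..(seq_k - seq_q + i). For example, seq_q=1, seq_k=5 yields an all-False mask
# (the single new token sees the whole cache), while seq_q == seq_k recovers the usual
# lower-triangular causal mask.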
65 |
66 | class Attention(nn.Module):
67 | def __init__(
68 | self,
69 | dim,
70 | n_head,
71 | attn_dropout_p,
72 | resid_dropout_p,
73 | causal: bool = True,
74 | ):
75 | super().__init__()
76 | assert dim % n_head == 0
77 | self.dim = dim
78 | self.head_dim = dim // n_head
79 | self.scale = self.head_dim**-0.5
80 | self.n_head = n_head
81 | total_kv_dim = (self.n_head * 3) * self.head_dim
82 |
83 | self.wqkv = nn.Linear(dim, total_kv_dim, bias=False)
84 | self.wo = nn.Linear(dim, dim, bias=False)
85 |
86 | self.attn_dropout_p = attn_dropout_p
87 | self.resid_dropout = nn.Dropout(resid_dropout_p)
88 | self.causal = causal
89 |
90 | self.k_cache = None
91 | self.v_cache = None
92 | self.kv_cache_size = None
93 |
94 | def enable_kv_cache(self, bsz, max_seq_len):
95 | if self.kv_cache_size != (bsz, max_seq_len):
96 | device = self.wo.weight.device
97 | dtype = self.wo.weight.dtype
98 | self.k_cache = torch.zeros(
99 | (bsz, self.n_head, max_seq_len, self.head_dim),
100 | device=device,
101 | dtype=dtype,
102 | )
103 | self.v_cache = torch.zeros(
104 | (bsz, self.n_head, max_seq_len, self.head_dim),
105 | device=device,
106 | dtype=dtype,
107 | )
108 | self.kv_cache_size = (bsz, max_seq_len)
109 |
110 | def update_kv_cache(
111 | self, start_pos, end_pos, keys: torch.Tensor, values: torch.Tensor
112 | ):
113 | self.k_cache[:, :, start_pos:end_pos, :] = keys
114 | self.v_cache[:, :, start_pos:end_pos, :] = values
115 | return (
116 | self.k_cache[:, :, :end_pos, :],
117 | self.v_cache[:, :, :end_pos, :],
118 | )
119 |
120 | def naive_attention(self, xq, keys, values, is_causal):
121 | xq = xq * self.scale
122 |         # q: [B, H, Lq, D] @ k^T: [B, H, D, Lk] -> attn: [B, H, Lq, Lk]
123 | attn = xq @ keys.transpose(-1, -2)
124 | seq_q, seq_k = attn.shape[-2], attn.shape[-1]
125 | if is_causal and seq_q > 1:
126 | causal_mask = get_causal_mask(seq_q, seq_k, attn.device)
127 | attn.masked_fill_(causal_mask, float("-inf"))
128 | attn = torch.softmax(attn, dim=-1)
129 | if self.attn_dropout_p > 0 and self.training:
130 | attn = F.dropout(attn, p=self.attn_dropout_p, training=self.training)
131 | # [B, H, 1, L] @ [B, H, L, D] -> [B, H, 1, D]
132 | return attn @ values
133 |
134 | def forward(
135 | self,
136 | x: torch.Tensor,
137 | freqs_cis: torch.Tensor = None,
138 | start_pos: Optional[int] = None,
139 | end_pos: Optional[int] = None,
140 | ):
141 | bsz, seqlen, _ = x.shape
142 | xq, xk, xv = self.wqkv(x).chunk(3, dim=-1)
143 |
144 | xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
145 | xk = xk.view(bsz, seqlen, self.n_head, self.head_dim)
146 | xv = xv.view(bsz, seqlen, self.n_head, self.head_dim)
147 |
148 | if freqs_cis is not None:
149 | xq = apply_rotary_emb(xq, freqs_cis)
150 | xk = apply_rotary_emb(xk, freqs_cis)
151 |
152 | is_causal = self.causal
153 | if self.k_cache is not None and start_pos is not None:
154 | xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
155 | keys, values = self.update_kv_cache(start_pos, end_pos, xk, xv)
156 | output = self.naive_attention(xq, keys, values, is_causal=is_causal)
157 | output = output.transpose(1, 2).contiguous()
158 | else:
159 | output = flash_attn_func(
160 | xq,
161 | xk,
162 | xv,
163 | causal=is_causal,
164 | dropout_p=self.attn_dropout_p if self.training else 0,
165 | )
166 |
167 | output = output.view(bsz, seqlen, self.dim)
168 |
169 | output = self.resid_dropout(self.wo(output))
170 | return output
171 |
172 |
173 | class FeedForward(nn.Module):
174 |
175 | def __init__(self, dim, dropout_p=0.1, mlp_ratio=4.0):
176 | super().__init__()
177 | hidden_dim = mlp_ratio * dim
178 | hidden_dim = int(2 * hidden_dim / 3)
179 | hidden_dim = find_multiple(hidden_dim, 256)
180 |
181 | self.w1 = nn.Linear(dim, hidden_dim * 2, bias=False)
182 | self.w2 = nn.Linear(hidden_dim, dim, bias=False)
183 | self.ffn_dropout = nn.Dropout(dropout_p)
184 |
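    # SwiGLU-style MLP: e.g. with dim=1024 and mlp_ratio=4.0 the hidden size is
    # int(2 * 4096 / 3) = 2730, rounded up to the next multiple of 256 -> 2816;
    # w1 produces both the gate and the value branch, combined as silu(h1) * h2 below.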
185 | def forward(self, x):
186 | h1, h2 = self.w1(x).chunk(2, dim=-1)
187 | return self.ffn_dropout(self.w2(F.silu(h1) * h2))
188 |
189 |
190 | class TransformerBlock(nn.Module):
191 | def __init__(
192 | self,
193 | dim,
194 | n_head,
195 | attn_dropout_p: float = 0.0,
196 | resid_dropout_p: float = 0.0,
197 | drop_path: float = 0.0,
198 | causal: bool = True,
199 | ):
200 | super().__init__()
201 | self.attention = Attention(
202 | dim=dim,
203 | n_head=n_head,
204 | attn_dropout_p=attn_dropout_p,
205 | resid_dropout_p=resid_dropout_p,
206 | causal=causal,
207 | )
208 | self.feed_forward = FeedForward(
209 | dim=dim,
210 | dropout_p=resid_dropout_p,
211 | )
212 | self.attention_norm = RMSNorm(dim, eps=1e-6)
213 | self.ffn_norm = RMSNorm(dim, eps=1e-6)
214 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
215 |
216 | def forward(
217 | self,
218 | x: torch.Tensor,
219 | freqs_cis: torch.Tensor,
220 | ):
221 | h = x + self.drop_path(self.attention(self.attention_norm(x), freqs_cis))
222 | out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
223 | return out
224 |
225 | def forward_onestep(
226 | self,
227 | x: torch.Tensor,
228 | freqs_cis: torch.Tensor,
229 | start_pos: int,
230 | end_pos: int,
231 | ):
232 | h = x + self.drop_path(
233 | self.attention(self.attention_norm(x), freqs_cis, start_pos, end_pos)
234 | )
235 | out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
236 | return out
237 |
238 |
239 | def get_2d_pos(resolution, patch_size, num_scales=1):
240 | max_pos = resolution // patch_size
241 | coords_list = []
242 |
243 | for i in range(num_scales):
244 | scale = 2 ** (num_scales - i - 1)
245 | P = max(resolution // scale // patch_size, 1)
246 | edge = float(max_pos) / P
247 | centers = (torch.arange(P, dtype=torch.float32) + 0.5) * edge
248 | grid_y, grid_x = torch.meshgrid(centers, centers, indexing="ij")
249 | coords = torch.stack([grid_x.reshape(-1), grid_y.reshape(-1)], dim=1)
250 | coords_list.append(coords)
251 |
252 | return torch.cat(coords_list, dim=0)
253 |
254 |
255 | def precompute_freqs_cis_2d(
256 | pos_2d, n_elem: int, base: float = 10000, cls_token_num=120
257 | ):
258 | # split the dimension into half, one for x and one for y
259 | half_dim = n_elem // 2
260 | freqs = 1.0 / (
261 | base ** (torch.arange(0, half_dim, 2)[: (half_dim // 2)].float() / half_dim)
262 | )
263 | t = pos_2d + 1.0
264 | if cls_token_num > 0:
265 | t = torch.cat(
266 | [torch.zeros((cls_token_num, 2), device=freqs.device), t],
267 | dim=0,
268 | )
269 | freqs = torch.outer(t.flatten(), freqs).view(*t.shape[:-1], -1)
270 | return torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
271 |
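# precompute_freqs_cis_2d splits each head's rotary dimension in half, assigning one
# half to the x coordinate and the other to the y coordinate of every patch, so
# attention receives a 2-D relative-position signal. The cls/register positions get
# coordinate 0 and therefore no rotation. The result stacks cos and sin in the last
# dimension, which apply_rotary_emb below consumes as (seq_len, head_dim // 2, 2).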
272 |
273 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
274 | # x: (bs, seq_len, n_head, head_dim)
275 | # freqs_cis (seq_len, head_dim // 2, 2)
276 | xshaped = x.float().reshape(
277 | *x.shape[:-1], -1, 2
278 | ) # (bs, seq_len, n_head, head_dim//2, 2)
279 | freqs_cis = freqs_cis.view(
280 | 1, xshaped.size(1), 1, xshaped.size(3), 2
281 | ) # (1, seq_len, 1, head_dim//2, 2)
282 | x_out2 = torch.stack(
283 | [
284 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
285 | xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
286 | ],
287 | dim=-1,
288 | )
289 | x_out2 = x_out2.flatten(3)
290 | return x_out2.type_as(x)
291 |
--------------------------------------------------------------------------------
/SphereAR/model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 | from torch.utils.checkpoint import checkpoint
8 |
9 | from .diff_head import DiffHead
10 | from .layers import TransformerBlock, get_2d_pos, precompute_freqs_cis_2d
11 | from .vae import VAE
12 |
13 |
14 | def get_model_args():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument(
17 | "--model", type=str, choices=list(SphereAR_models.keys()), default="SphereAR-L"
18 | )
19 | parser.add_argument("--vae-only", action="store_true", help="only train vae")
20 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
21 | parser.add_argument("--patch-size", type=int, default=16, choices=[16])
22 | parser.add_argument("--num-classes", type=int, default=1000)
23 | parser.add_argument("--cls-token-num", type=int, default=16)
24 | parser.add_argument("--latent-dim", type=int, default=16)
25 | parser.add_argument("--diff-batch-mul", type=int, default=4)
26 | parser.add_argument("--grad-checkpointing", action="store_true")
27 | return parser
28 |
29 |
30 | def create_model(args, device):
31 | model = SphereAR_models[args.model](
32 | resolution=args.image_size,
33 | patch_size=args.patch_size,
34 | latent_dim=args.latent_dim,
35 | vae_only=args.vae_only,
36 | diff_batch_mul=args.diff_batch_mul,
37 | cls_token_num=args.cls_token_num,
38 | num_classes=args.num_classes,
39 | grad_checkpointing=args.grad_checkpointing,
40 | ).to(device, memory_format=torch.channels_last)
41 | return model
42 |
43 |
44 | class SphereAR(nn.Module):
45 |
46 | def __init__(
47 | self,
48 | dim,
49 | n_layer,
50 | n_head,
51 | diff_layers,
52 | diff_dim,
53 | diff_adanln_layers,
54 | latent_dim,
55 | patch_size,
56 | resolution,
57 | diff_batch_mul,
58 | vae_only=False,
59 | grad_checkpointing=False,
60 | cls_token_num=16,
61 | num_classes: int = 1000,
62 | class_dropout_prob: float = 0.1,
63 | ):
64 | super().__init__()
65 |
66 | self.n_layer = n_layer
67 | self.resolution = resolution
68 | self.patch_size = patch_size
69 | self.num_classes = num_classes
70 | self.cls_token_num = cls_token_num
71 | self.class_dropout_prob = class_dropout_prob
72 | self.latent_dim = latent_dim
73 |
74 | self.vae = VAE(
75 | latent_dim=latent_dim, image_size=resolution, patch_size=patch_size
76 | )
77 | self.vae_only = vae_only
78 | self.grad_checkpointing = grad_checkpointing
79 |
80 | if not vae_only:
81 | self.cls_embedding = nn.Embedding(num_classes + 1, dim * self.cls_token_num)
82 | self.proj_in = nn.Linear(latent_dim, dim, bias=True)
83 | self.emb_norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
84 | self.h, self.w = resolution // patch_size, resolution // patch_size
85 | self.total_tokens = self.h * self.w + self.cls_token_num
86 |
87 | self.layers = torch.nn.ModuleList()
88 | for layer_id in range(n_layer):
89 | self.layers.append(
90 | TransformerBlock(
91 | dim,
92 | n_head,
93 | causal=True,
94 | )
95 | )
96 |
97 | self.norm = nn.RMSNorm(dim, eps=1e-6, elementwise_affine=True)
98 | self.pos_for_diff = nn.Embedding(self.h * self.w, dim)
99 | self.head = DiffHead(
100 | ch_target=latent_dim,
101 | ch_cond=dim,
102 | ch_latent=diff_dim,
103 | depth_latent=diff_layers,
104 | depth_adanln=diff_adanln_layers,
105 | grad_checkpointing=grad_checkpointing,
106 | )
107 | self.diff_batch_mul = diff_batch_mul
108 |
109 | patch_2d_pos = get_2d_pos(resolution, patch_size)
110 |
111 | self.register_buffer(
112 | "freqs_cis",
113 | precompute_freqs_cis_2d(
114 | patch_2d_pos,
115 | dim // n_head,
116 | 10000,
117 | cls_token_num=self.cls_token_num,
118 | )[:-1],
119 | persistent=False,
120 | )
121 | self.freeze_vae()
122 |
123 | self.initialize_weights()
124 |
125 | def non_decay_keys(self):
126 | return ["proj_in", "cls_embedding"]
127 |
128 | def freeze_module(self, module: nn.Module):
129 | for param in module.parameters():
130 | param.requires_grad = False
131 |
132 | def freeze_vae(self):
133 | self.freeze_module(self.vae)
134 | self.vae.eval()
135 |
136 | def initialize_weights(self):
137 | # Initialize nn.Linear and nn.Embedding
138 | self.apply(self.__init_weights)
139 | if not self.vae_only:
140 | self.head.initialize_weights()
141 | self.vae.initialize_weights()
142 |
143 | def __init_weights(self, module):
144 | std = 0.02
145 | if isinstance(module, nn.Linear):
146 | module.weight.data.normal_(mean=0.0, std=std)
147 | if module.bias is not None:
148 | module.bias.data.zero_()
149 | elif isinstance(module, nn.Embedding):
150 | module.weight.data.normal_(mean=0.0, std=std)
151 |
152 | def drop_label(self, class_id):
153 | if self.class_dropout_prob > 0.0 and self.training:
154 | is_drop = (
155 | torch.rand(class_id.shape, device=class_id.device)
156 | < self.class_dropout_prob
157 | )
158 | class_id = torch.where(is_drop, self.num_classes, class_id)
159 | return class_id
160 |
161 | def forward(
162 | self,
163 | images,
164 | class_id,
165 | ):
166 |
167 | vae_latent, kl_loss = self.vae.encode(images)
168 |
169 | if not self.vae_only:
170 | x = vae_latent.detach()
171 | x = self.proj_in(x[:, :-1, :])
172 | class_id = self.drop_label(class_id)
173 | bsz = x.shape[0]
174 | c = self.cls_embedding(class_id).view(bsz, self.cls_token_num, -1)
175 | x = torch.cat([c, x], dim=1)
176 | x = self.emb_norm(x)
177 |
178 | if self.grad_checkpointing and self.training:
179 | for layer in self.layers:
180 | block = partial(layer.forward, freqs_cis=self.freqs_cis)
181 | x = checkpoint(block, x, use_reentrant=False)
182 | else:
183 | for layer in self.layers:
184 | x = layer(x, self.freqs_cis)
185 |
186 | x = x[:, -self.h * self.w :, :]
187 | x = self.norm(x)
188 | x = x + self.pos_for_diff.weight
189 |
190 | target = vae_latent.detach()
191 | x = x.view(-1, x.shape[-1])
192 | target = target.view(-1, target.shape[-1])
193 |
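            # Teacher forcing: the trunk saw the cls tokens plus all but the last VAE
            # token, so output position j conditions the diffusion head on predicting
            # target token j. Conditions and targets are flattened over batch and
            # sequence, then repeated diff_batch_mul times so each token gets several
            # noise/timestep draws per step at no extra trunk cost.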
194 | x = x.repeat(self.diff_batch_mul, 1)
195 | target = target.repeat(self.diff_batch_mul, 1)
196 | loss = self.head(target, x)
197 | recon = None
198 | else:
199 | loss = torch.tensor(0.0, device=images.device, dtype=images.dtype)
200 | recon = self.vae.decode(vae_latent)
201 |
202 | return loss, kl_loss, recon
203 |
204 | def enable_kv_cache(self, bsz):
205 | for layer in self.layers:
206 | layer.attention.enable_kv_cache(bsz, self.total_tokens)
207 |
208 | @torch.compile()
209 | def forward_model(self, x, start_pos, end_pos):
210 | x = self.emb_norm(x)
211 | for layer in self.layers:
212 | x = layer.forward_onestep(
213 |                 x, self.freqs_cis[start_pos:end_pos], start_pos, end_pos
214 | )
215 | x = self.norm(x)
216 | return x
217 |
218 | def head_sample(self, x, diff_pos, sample_steps, cfg_scale, cfg_schedule="linear"):
219 | x = x + self.pos_for_diff.weight[diff_pos : diff_pos + 1, :]
220 | x = x.view(-1, x.shape[-1])
221 | seq_len = self.h * self.w
222 | if cfg_scale > 1.0:
223 | if cfg_schedule == "constant":
224 | cfg_iter = cfg_scale
225 | elif cfg_schedule == "linear":
226 | start = 1.0
227 | cfg_iter = start + (cfg_scale - start) * diff_pos / seq_len
228 | else:
229 | raise NotImplementedError(f"unknown cfg_schedule {cfg_schedule}")
230 | else:
231 | cfg_iter = 1.0
232 | pred = self.head.sample(x, num_sampling_steps=sample_steps, cfg=cfg_iter)
233 | pred = pred.view(-1, 1, pred.shape[-1])
234 | # Important: normalize here, for both next-token prediction and vae decoding
235 | pred = self.vae.normalize(pred)
236 | return pred
237 |
238 | @torch.no_grad()
239 | def sample(self, cond, sample_steps, cfg_scale=1.0, cfg_schedule="linear"):
240 | self.eval()
241 | if cfg_scale > 1.0:
242 | cond_null = torch.ones_like(cond) * self.num_classes
243 | cond_combined = torch.cat([cond, cond_null])
244 | else:
245 | cond_combined = cond
246 | bsz = cond_combined.shape[0]
247 | act_bsz = bsz // 2 if cfg_scale > 1.0 else bsz
248 | self.enable_kv_cache(bsz)
249 |
250 | c = self.cls_embedding(cond_combined).view(bsz, self.cls_token_num, -1)
251 | last_pred = None
252 | all_preds = []
253 | for i in range(self.h * self.w):
254 | if i == 0:
255 | x = self.forward_model(c, 0, self.cls_token_num)
256 | else:
257 | x = self.proj_in(last_pred)
258 | x = self.forward_model(
259 | x, i + self.cls_token_num - 1, i + self.cls_token_num
260 | )
261 | last_pred = self.head_sample(
262 | x[:, -1:, :],
263 | i,
264 | sample_steps,
265 | cfg_scale,
266 | cfg_schedule,
267 | )
268 | all_preds.append(last_pred)
269 |
270 | x = torch.cat(all_preds, dim=-2)[:act_bsz]
271 | recon = self.vae.decode(x)
272 | return recon
273 |
274 | def get_fsdp_wrap_module_list(self):
275 | return list(self.layers)
276 |
277 |
278 | def SphereAR_H(**kwargs):
279 | return SphereAR(
280 | n_layer=40,
281 | n_head=20,
282 | dim=1280,
283 | diff_layers=12,
284 | diff_dim=1280,
285 | diff_adanln_layers=3,
286 | **kwargs,
287 | )
288 |
289 |
290 | def SphereAR_L(**kwargs):
291 | return SphereAR(
292 | n_layer=32,
293 | n_head=16,
294 | dim=1024,
295 | diff_layers=8,
296 | diff_dim=1024,
297 | diff_adanln_layers=2,
298 | **kwargs,
299 | )
300 |
301 |
302 | def SphereAR_B(**kwargs):
303 | return SphereAR(
304 | n_layer=24,
305 | n_head=12,
306 | dim=768,
307 | diff_layers=6,
308 | diff_dim=768,
309 | diff_adanln_layers=2,
310 | **kwargs,
311 | )
312 |
313 |
314 | SphereAR_models = {
315 | "SphereAR-B": SphereAR_B,
316 | "SphereAR-L": SphereAR_L,
317 | "SphereAR-H": SphereAR_H,
318 | }
319 |
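# A minimal usage sketch (not part of the training/eval pipeline): build the base
# model and draw class-conditional samples. The class ids, step count, and CFG scale
# below are illustrative assumptions; the layers need flash-attn installed and a CUDA
# device, and meaningful samples require loading trained weights first.
#
#   model = SphereAR_B(
#       resolution=256, patch_size=16, latent_dim=16, diff_batch_mul=4,
#       cls_token_num=16, num_classes=1000,
#   ).cuda().eval()
#   class_ids = torch.tensor([207, 360], device="cuda")
#   images = model.sample(class_ids, sample_steps=50, cfg_scale=4.0)  # [2, 3, 256, 256]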
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # llamagen: https://github.com/FoundationVision/LlamaGen/
3 | import math
4 |
5 | import torch
6 |
7 | torch.backends.cuda.matmul.allow_tf32 = True
8 | torch.backends.cudnn.allow_tf32 = True
9 |
10 | import inspect
11 | import os
12 | import shutil
13 | import time
14 | from copy import deepcopy
15 | from multiprocessing.pool import ThreadPool
16 |
17 | import torch.distributed as dist
18 | from torch.nn.parallel import DistributedDataParallel as DDP
19 | from torch.utils.data import DataLoader
20 | from torch.utils.data.distributed import DistributedSampler
21 | from torch.utils.tensorboard import SummaryWriter
22 |
23 | from SphereAR.dataset import build_dataset
24 | from SphereAR.gan.gan_loss import GANLoss
25 | from SphereAR.model import create_model, get_model_args
26 | from SphereAR.utils import create_logger, requires_grad, update_ema
27 |
28 |
29 | def create_optimizer(model, weight_decay, learning_rate, betas, logger):
30 | def is_decay_param(name, param, no_decay_keys):
31 | for key in no_decay_keys:
32 | if key in name:
33 | return False
34 | if param.dim() < 2:
35 | return False
36 | return True
37 |
38 | # start with all of the candidate parameters
39 | param_dict = {pn: p for pn, p in model.named_parameters()}
40 | # filter out those that do not require grad
41 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
42 | no_decay_keys = model.non_decay_keys() if hasattr(model, "non_decay_keys") else []
43 | decay_params = [
44 | p for n, p in param_dict.items() if is_decay_param(n, p, no_decay_keys)
45 | ]
46 | nodecay_params = [
47 | p for n, p in param_dict.items() if not is_decay_param(n, p, no_decay_keys)
48 | ]
49 | optim_groups = [
50 | {"params": decay_params, "weight_decay": weight_decay},
51 | {"params": nodecay_params, "weight_decay": 0.0},
52 | ]
53 | num_decay_params = sum(p.numel() for p in decay_params)
54 | num_nodecay_params = sum(p.numel() for p in nodecay_params)
55 | logger.info(
56 | f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
57 | )
58 | logger.info(
59 | f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
60 | )
61 | # Create AdamW optimizer and use the fused version if it is available
62 | fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
63 | extra_args = dict(fused=True) if fused_available else dict()
64 | optimizer = torch.optim.AdamW(
65 | optim_groups, lr=learning_rate, betas=betas, **extra_args
66 | )
67 | logger.info(f"using fused AdamW: {fused_available}")
68 | return optimizer
69 |
70 |
71 | def adjust_learning_rate(args, cur_steps, total_steps, optimizer):
72 | if cur_steps < args.warmup_steps and args.warmup_steps > 0:
73 | lr = args.lr * cur_steps / args.warmup_steps
74 | elif (
75 | args.decay_start > 0
76 | and cur_steps >= args.decay_start
77 | and args.decay_start < total_steps
78 | ):
79 | # decay from decay_start to total_steps, with learning rate cosine decay to min_lr
80 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
81 | 1.0
82 | + math.cos(
83 | math.pi
84 | * (cur_steps - args.decay_start)
85 | / max(total_steps - args.decay_start, 1e-8)
86 | )
87 | )
88 | else:
89 | lr = args.lr
90 | for param_group in optimizer.param_groups:
91 | param_group["lr"] = lr
92 | return lr
93 |
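# Schedule above: linear warmup from 0 to args.lr over warmup_steps, a constant
# plateau until decay_start, then cosine decay from args.lr down to args.min_lr at
# total_steps. With the defaults (warmup_steps = decay_start = 20000) the plateau is
# empty and the cosine decay begins right after warmup.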
94 |
95 | def init_distributed_mode(args):
96 | args.rank = int(os.environ["RANK"])
97 | args.world_size = int(os.environ["WORLD_SIZE"])
98 | args.gpu = int(os.environ["LOCAL_RANK"])
99 |
100 | args.distributed = True
101 | device = torch.device("cuda", args.gpu)
102 | torch.cuda.set_device(device)
103 |
104 | print(f"| distributed init (rank {args.rank}, gpu {args.gpu})", flush=True)
105 |
106 | dist.init_process_group(
107 | backend="nccl",
108 | init_method="env://",
109 | world_size=args.world_size,
110 | rank=args.rank,
111 | device_id=device,
112 | )
113 | dist.barrier()
114 | return device
115 |
116 |
117 | def get_orig_model(model):
118 | if isinstance(model, DDP):
119 | model = model.module
120 | if hasattr(model, "_orig_mod"):
121 | model = model._orig_mod
122 | return model
123 |
124 |
125 | def _linear_decay_ratio(epoch: int, start: int, end: int) -> float:
126 | if start < 0 or end <= start:
127 | return 1.0
128 | if epoch < start:
129 | r = 1.0
130 | elif epoch >= end:
131 | r = 0.0
132 | else:
133 | r = 1.0 - (epoch - start) / float(end - start)
134 | return max(0.0, min(1.0, r))
135 |
136 |
137 | def create_dataloader(dataset, sampler, epoch, args):
138 | sampler.set_epoch(epoch)
139 | dataset.set_epoch(epoch)
140 | # linear decay of aug_ratio
141 | aug_ratio = _linear_decay_ratio(
142 | epoch, args.aug_decay_start_epoch, args.aug_decay_end_epoch
143 | )
144 | dataset.set_aug_ratio(aug_ratio)
145 | loader = DataLoader(
146 | dataset,
147 | batch_size=int(args.global_batch_size // dist.get_world_size()),
148 | shuffle=False,
149 | sampler=sampler,
150 | num_workers=args.num_workers,
151 | pin_memory=True,
152 | drop_last=True,
153 | )
154 | return loader
155 |
156 |
157 | def update_loss_dict(running_loss_dict, **kwargs):
158 | for k, v in kwargs.items():
159 | if v is not None:
160 | if torch.is_tensor(v):
161 | v = v.item()
162 | running_loss_dict[k] = running_loss_dict.get(k, 0.0) + v
163 | return running_loss_dict
164 |
165 |
166 | def vae_loss(
167 | args,
168 | gan_model,
169 | images,
170 | recon,
171 | kl_loss,
172 | train_steps,
173 | running_loss_dict,
174 | ):
175 | if args.vae_only:
176 | recon_loss, p_loss, gen_loss = gan_model(
177 | images,
178 | recon,
179 | optimizer_idx=0,
180 | global_step=train_steps + 1,
181 | )
182 | running_loss_dict = update_loss_dict(
183 | running_loss_dict,
184 | kl_loss=kl_loss,
185 | recon_loss=recon_loss,
186 | p_loss=p_loss,
187 | gen_loss=gen_loss,
188 | )
189 | loss = (
190 | args.perceptual_weight * p_loss
191 | + gen_loss
192 | + args.reconstruction_weight * recon_loss
193 | + args.kl_weight * kl_loss
194 | )
195 | return loss, running_loss_dict
196 | else:
197 | return 0.0, running_loss_dict
198 |
199 |
200 | def logging(
201 | running_loss_dict,
202 | running_gnorm,
203 | log_steps,
204 | steps_per_sec,
205 | train_steps,
206 | device,
207 | logger,
208 | tsb_writer,
209 | ):
210 | keys = sorted(running_loss_dict.keys())
211 | running_losses = [running_loss_dict[k] for k in keys]
212 | # Reduce loss history over all processes:
213 | all_loss = torch.tensor(
214 | running_losses,
215 | device=device,
216 | )
217 | dist.all_reduce(all_loss, op=dist.ReduceOp.SUM)
218 |
219 | avg_gnorm = running_gnorm / log_steps
220 | all_loss = [
221 | (keys[i], all_loss[i].item() / dist.get_world_size() / log_steps)
222 | for i in range(len(keys))
223 | ]
224 | loss_str = ", ".join([f"{k}: {v:.4f}" for k, v in all_loss])
225 |
226 | logger.info(
227 |         f"(step={train_steps:07d}): {loss_str}, Train Steps/Sec: {steps_per_sec:.2f}, Train Grad Norm: {avg_gnorm:.4f}"
228 | )
229 | if tsb_writer is not None:
230 | for k, v in all_loss:
231 | tsb_writer.add_scalar(f"train/{k}", v, train_steps)
232 | tsb_writer.add_scalar("train/steps_per_sec", steps_per_sec, train_steps)
233 | tsb_writer.add_scalar("train/grad_norm", avg_gnorm, train_steps)
234 |
235 |
236 | def copy_ckp_func(src_file, dest_path, cur_epoch, keep_freq):
237 | dest_file = os.path.join(dest_path, "last.pt")
238 | if os.path.exists(dest_file):
239 | shutil.copyfile(dest_file, os.path.join(dest_path, "prev.pt"))
240 | shutil.copyfile(src_file, dest_file)
241 | if cur_epoch > 0 and keep_freq > 0 and cur_epoch % keep_freq == 0:
242 | shutil.copyfile(
243 | dest_file,
244 | os.path.join(dest_path, f"epoch_{cur_epoch}.pt"),
245 | )
246 |
247 |
248 | def main(args):
249 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
250 |
251 | # Setup DDP:
252 | device = init_distributed_mode(args)
253 | assert (
254 | args.global_batch_size % dist.get_world_size() == 0
255 |     ), f"Batch size ({args.global_batch_size}) must be divisible by world size ({dist.get_world_size()})."
256 | rank = dist.get_rank()
257 | seed = args.global_seed * dist.get_world_size() + rank
258 | torch.manual_seed(seed)
259 | torch.cuda.set_device(device)
260 |
261 | results_dir = args.results_dir
262 |
263 | if rank == 0:
264 | os.makedirs(args.results_dir, exist_ok=True)
265 | logger = create_logger(results_dir)
266 | logger.info(f"Experiment directory created at {results_dir}")
267 | ckp_async_thread = ThreadPool(processes=1)
268 | else:
269 | logger = create_logger(None)
270 | ckp_async_thread = None
271 |
272 | # training args
273 | logger.info(f"{args}")
274 |
275 | # training env
276 | logger.info(
277 | f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}."
278 | )
279 |
280 | model = create_model(args, device)
281 |
282 | if args.trained_vae != "":
283 | vae_info = torch.load(args.trained_vae, map_location="cpu", weights_only=False)
284 | vae_info = vae_info["model"]
285 | res = model.load_state_dict(vae_info, strict=False)
286 | for k in res.missing_keys:
287 | if "vae" in k:
288 | raise ValueError(f"Fail to load VAE weights from {args.trained_vae}")
289 | model.freeze_vae()
290 | logger.info(f"loaded pretrained VAE from {args.trained_vae}")
291 |
292 | logger.info(model)
293 | logger.info(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
294 |
295 | if args.ema > 0:
296 | ema_model = deepcopy(model).to(
297 | device
298 | ) # Create an EMA of the model for use after training
299 | requires_grad(ema_model, False)
300 | logger.info(
301 | f"EMA Parameters: {sum(p.numel() for p in ema_model.parameters()):,}"
302 | )
303 |
304 | # Setup optimizer
305 | optimizer = create_optimizer(
306 | model, args.weight_decay, args.lr, (args.beta1, args.beta2), logger
307 | )
308 |
309 | # Setup gan loss for vae training
310 | if args.vae_only:
311 | gan_model = GANLoss(
312 | disc_start=args.disc_start,
313 | disc_weight=args.disc_weight,
314 | disc_type=args.disc_type,
315 | disc_loss=args.disc_loss,
316 | gen_adv_loss=args.gen_loss,
317 | image_size=args.image_size,
318 | reconstruction_loss=args.reconstruction_loss,
319 | ).to(device, memory_format=torch.channels_last)
320 | logger.info(
321 | f"Discriminator Parameters: {sum(p.numel() for p in gan_model.discriminator.parameters()):,}"
322 | )
323 | optimizer_disc = create_optimizer(
324 | gan_model.discriminator,
325 | args.weight_decay,
326 | args.lr,
327 | (args.beta1, args.beta2),
328 | logger,
329 | )
330 | else:
331 | gan_model = None
332 | optimizer_disc = None
333 |
334 | dataset = build_dataset(args)
335 | sampler = DistributedSampler(
336 | dataset,
337 | num_replicas=dist.get_world_size(),
338 | rank=rank,
339 | shuffle=True,
340 | seed=args.global_seed,
341 | )
342 |
343 | logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
344 |
345 | checkpoint_path = f"{results_dir}/last.pt"
346 | total_steps = args.epochs * int(len(dataset) / args.global_batch_size)
347 |
348 | # Prepare models for training:
349 | if os.path.exists(checkpoint_path):
350 | checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
351 | start_epoch = checkpoint["epochs"]
352 | train_steps = int(start_epoch * int(len(dataset) / args.global_batch_size))
353 | model.load_state_dict(checkpoint["model"], strict=True)
354 | if args.ema > 0:
355 | ema_model.load_state_dict(
356 | checkpoint["ema"] if "ema" in checkpoint else checkpoint["model"]
357 | )
358 | optimizer.load_state_dict(checkpoint["optimizer"])
359 | if args.vae_only:
360 | gan_model.discriminator.load_state_dict(
361 | checkpoint["model_disc"], strict=True
362 | )
363 | optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
364 |
365 | del checkpoint
366 |
367 | logger.info(f"Resume training from checkpoint: {checkpoint_path}")
368 | logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}")
369 | else:
370 | train_steps = 0
371 | start_epoch = 0
372 | if args.ema > 0:
373 | update_ema(ema_model, model, decay=0)
374 |
375 | if not args.no_compile:
376 | logger.info("compiling the model... (may take several minutes)")
377 | model = torch.compile(model) # requires PyTorch 2.0
378 | if args.vae_only:
379 | gan_model = torch.compile(gan_model)
380 |
381 | model = DDP(model.to(device), device_ids=[args.gpu])
382 | model.train()
383 |
384 | if args.vae_only:
385 | gan_model = DDP(gan_model.to(device), device_ids=[args.gpu])
386 | gan_model.train()
387 | if args.ema > 0:
388 | ema_model.eval()
389 |
390 | ptdtype = {"none": torch.float32, "bf16": torch.bfloat16}[args.mixed_precision]
391 |
392 | log_steps = 0
393 | running_loss_dict = {}
394 | running_gnorm = 0
395 | start_time = time.time()
396 |
397 | logger.info(f"Training for {args.epochs} epochs ({total_steps} steps)")
398 | tsb_writer = SummaryWriter(log_dir=results_dir) if rank == 0 else None
399 |
400 | for epoch in range(start_epoch, args.epochs):
401 | loader = create_dataloader(dataset, sampler, epoch, args)
402 |
403 | logger.info(f"Beginning epoch {epoch}...")
404 | for images, classes in loader:
405 | classes = classes.to(device, non_blocking=True)
406 | images = images.to(device, non_blocking=True).contiguous(
407 | memory_format=torch.channels_last
408 | )
409 |
410 | optimizer.zero_grad(set_to_none=True)
411 |
412 | with torch.amp.autocast("cuda", dtype=ptdtype):
413 | ar_loss, kl_loss, recon = model(images, classes)
414 | gan_g_loss, running_loss_dict = vae_loss(
415 | args,
416 | gan_model,
417 | images,
418 | recon,
419 | kl_loss,
420 | train_steps,
421 | running_loss_dict,
422 | )
423 | if not args.vae_only:
424 | running_loss_dict = update_loss_dict(running_loss_dict, loss=ar_loss)
425 | loss = ar_loss + gan_g_loss
426 | loss.backward()
427 |
428 | if args.max_grad_norm != 0.0:
429 | gnorm = torch.nn.utils.clip_grad_norm_(
430 | model.parameters(), args.max_grad_norm
431 | )
432 | running_gnorm += gnorm.item()
433 | cur_lr = adjust_learning_rate(args, train_steps, total_steps, optimizer)
434 | running_loss_dict = update_loss_dict(running_loss_dict, lr=cur_lr)
435 | optimizer.step()
436 |
437 | if args.ema > 0:
438 | update_ema(ema_model, get_orig_model(model), decay=args.ema)
439 |
440 | if args.vae_only:
441 | # discriminator training
442 | optimizer_disc.zero_grad(set_to_none=True)
443 | with torch.amp.autocast("cuda", dtype=ptdtype):
444 | disc_loss = gan_model(
445 | images,
446 | recon.detach(),
447 | optimizer_idx=1,
448 | global_step=train_steps + 1,
449 | )
450 | running_loss_dict = update_loss_dict(
451 | running_loss_dict, disc_loss=disc_loss
452 | )
453 | disc_loss.backward()
454 | if args.max_grad_norm != 0.0:
455 | torch.nn.utils.clip_grad_norm_(
456 | get_orig_model(gan_model).discriminator.parameters(),
457 | args.max_grad_norm,
458 | )
459 | adjust_learning_rate(args, train_steps, total_steps, optimizer_disc)
460 | optimizer_disc.step()
461 |
462 | log_steps += 1
463 | train_steps += 1
464 | if train_steps % args.log_every == 0:
465 | # Measure training speed:
466 | torch.cuda.synchronize()
467 | end_time = time.time()
468 | steps_per_sec = log_steps / (end_time - start_time)
469 | logging(
470 | running_loss_dict,
471 | running_gnorm,
472 | log_steps,
473 | steps_per_sec,
474 | train_steps,
475 | device,
476 | logger,
477 | tsb_writer,
478 | )
479 | running_loss_dict = {}
480 | running_gnorm = 0
481 | log_steps = 0
482 | start_time = time.time()
483 |
484 | # save checkpoint at the end of each epoch
485 | if rank == 0:
486 | checkpoint = {
487 | "model": get_orig_model(model).state_dict(),
488 | "optimizer": optimizer.state_dict(),
489 | "epochs": epoch + 1,
490 | "args": args,
491 | }
492 | if args.vae_only:
493 | checkpoint["model_disc"] = get_orig_model(
494 | gan_model
495 | ).discriminator.state_dict()
496 | checkpoint["optimizer_disc"] = optimizer_disc.state_dict()
497 | if args.ema > 0:
498 | checkpoint["ema"] = ema_model.state_dict()
499 |             # save to /dev/shm (in memory), then asynchronously copy to the results dir
500 | local_file = os.path.join(args.tmp_results_dir, "last.pt")
501 | torch.save(checkpoint, local_file)
502 | ckp_async_thread.apply_async(
503 | copy_ckp_func,
504 | args=(local_file, results_dir, epoch + 1, args.keep_freq),
505 |                 error_callback=lambda e: logger.error("async copy error: " + str(e)),
506 | )
507 |
508 | dist.barrier()
509 |
510 | if ckp_async_thread is not None:
511 | ckp_async_thread.close()
512 | ckp_async_thread.join()
513 |
514 | if rank == 0:
515 | # free space by removing prev checkpoint
516 |         os.system(f"rm -f {results_dir}/prev.pt")
517 | logger.info("Done!")
518 | dist.destroy_process_group()
519 |
520 |
521 | if __name__ == "__main__":
522 | parser = get_model_args()
523 | parser.add_argument("--data-path", type=str, required=True)
524 | parser.add_argument("--trained-vae", type=str, default="")
525 | parser.add_argument("--aug-decay-start-epoch", type=int, default=350)
526 | parser.add_argument("--aug-decay-end-epoch", type=int, default=375)
527 | parser.add_argument("--ema", default=-1, type=float)
528 | parser.add_argument("--no-compile", action="store_true")
529 | parser.add_argument("--tmp-results-dir", type=str, default="/dev/shm/")
530 | parser.add_argument("--results-dir", type=str, default="results")
531 | parser.add_argument("--epochs", type=int, default=400)
532 | parser.add_argument("--lr", type=float, default=3e-4)
533 | parser.add_argument("--min-lr", type=float, default=1e-5)
534 | parser.add_argument("--warmup-steps", type=int, default=20000)
535 | parser.add_argument("--decay-start", type=int, default=20000)
536 | parser.add_argument(
537 | "--weight-decay", type=float, default=5e-2, help="Weight decay to use"
538 | )
539 | parser.add_argument("--beta1", type=float, default=0.9)
540 | parser.add_argument("--beta2", type=float, default=0.95)
541 | parser.add_argument(
542 | "--max-grad-norm", default=1.0, type=float, help="Max gradient norm."
543 | )
544 | parser.add_argument("--global-batch-size", type=int, default=256)
545 | parser.add_argument("--global-seed", type=int, default=0)
546 | parser.add_argument("--num-workers", type=int, default=16)
547 | parser.add_argument("--log-every", type=int, default=100)
548 | parser.add_argument(
549 | "--mixed-precision", type=str, default="bf16", choices=["none", "bf16"]
550 | )
551 | parser.add_argument("--reconstruction-weight", type=float, default=1.0)
552 | parser.add_argument("--reconstruction-loss", type=str, default="l1")
553 | parser.add_argument("--perceptual-weight", type=float, default=1.0)
554 | parser.add_argument("--disc-weight", type=float, default=0.5)
555 | parser.add_argument("--kl-weight", type=float, default=0.004)
556 | parser.add_argument("--disc-start", type=int, default=20000)
557 | parser.add_argument(
558 | "--disc-type", type=str, choices=["patchgan", "stylegan"], default="patchgan"
559 | )
560 | parser.add_argument(
561 | "--disc-loss",
562 | type=str,
563 | choices=["hinge", "vanilla", "non-saturating"],
564 | default="hinge",
565 | )
566 | parser.add_argument(
567 | "--gen-loss", type=str, choices=["hinge", "non-saturating"], default="hinge"
568 | )
569 | parser.add_argument("--keep-freq", type=int, default=50)
570 | main(parser.parse_args())
571 |
--------------------------------------------------------------------------------
/evaluator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import io
3 | import os
4 | import random
5 | import warnings
6 | import zipfile
7 | from abc import ABC, abstractmethod
8 | from contextlib import contextmanager
9 | from functools import partial
10 | from multiprocessing import cpu_count
11 | from multiprocessing.pool import ThreadPool
12 | from typing import Iterable, Optional, Tuple
13 |
14 | import numpy as np
15 | import requests
16 | import tensorflow.compat.v1 as tf
17 | from scipy import linalg
18 | from tqdm.auto import tqdm
19 |
20 | INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21 | INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22 |
23 | FID_POOL_NAME = "pool_3:0"
24 | FID_SPATIAL_NAME = "mixed_6/conv:0"
25 |
26 |
27 | def main():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("ref_batch", help="path to reference batch npz file")
30 | parser.add_argument("sample_batch", help="path to sample batch npz file")
31 | args = parser.parse_args()
32 |
33 | config = tf.ConfigProto(
34 | allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
35 | )
36 | config.gpu_options.allow_growth = True
37 | evaluator = Evaluator(tf.Session(config=config))
38 |
39 | print("warming up TensorFlow...")
40 | # This will cause TF to print a bunch of verbose stuff now rather
41 | # than after the next print(), to help prevent confusion.
42 | evaluator.warmup()
43 |
44 | print("computing reference batch activations...")
45 | ref_acts = evaluator.read_activations(args.ref_batch)
46 | print("computing/reading reference batch statistics...")
47 | ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
48 |
49 | print("computing sample batch activations...")
50 | sample_acts = evaluator.read_activations(args.sample_batch)
51 | print("computing/reading sample batch statistics...")
52 | sample_stats, sample_stats_spatial = evaluator.read_statistics(
53 | args.sample_batch, sample_acts
54 | )
55 |
56 | print("Computing evaluations...")
57 | IS = evaluator.compute_inception_score(sample_acts[0])
58 | FID = sample_stats.frechet_distance(ref_stats)
59 | sFID = sample_stats_spatial.frechet_distance(ref_stats_spatial)
60 | print("Inception Score:", IS)
61 | print("FID:", FID)
62 | print("sFID:", sFID)
63 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
64 | print("Precision:", prec)
65 | print("Recall:", recall)
66 |
67 | txt_path = args.sample_batch.replace(".npz", ".txt")
68 | print("writing to {}".format(txt_path))
69 | with open(txt_path, "w") as f:
70 | print("Inception Score:", IS, file=f)
71 | print("FID:", FID, file=f)
72 | print("sFID:", sFID, file=f)
73 | print("Precision:", prec, file=f)
74 | print("Recall:", recall, file=f)
75 |
76 |
77 | class InvalidFIDException(Exception):
78 | pass
79 |
80 |
81 | class FIDStatistics:
82 | def __init__(self, mu: np.ndarray, sigma: np.ndarray):
83 | self.mu = mu
84 | self.sigma = sigma
85 |
86 | def frechet_distance(self, other, eps=1e-6):
87 | """
88 | Compute the Frechet distance between two sets of statistics.
89 | """
90 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
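        # FID(mu1, S1; mu2, S2) = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2));
        # a small eps is added to the diagonals below when S1 @ S2 is nearly singular.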
91 | mu1, sigma1 = self.mu, self.sigma
92 | mu2, sigma2 = other.mu, other.sigma
93 |
94 | mu1 = np.atleast_1d(mu1)
95 | mu2 = np.atleast_1d(mu2)
96 |
97 | sigma1 = np.atleast_2d(sigma1)
98 | sigma2 = np.atleast_2d(sigma2)
99 |
100 | assert (
101 | mu1.shape == mu2.shape
102 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
103 | assert (
104 | sigma1.shape == sigma2.shape
105 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
106 |
107 | diff = mu1 - mu2
108 |
109 | # product might be almost singular
110 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
111 | if not np.isfinite(covmean).all():
112 | msg = (
113 | "fid calculation produces singular product; adding %s to diagonal of cov estimates"
114 | % eps
115 | )
116 | warnings.warn(msg)
117 | offset = np.eye(sigma1.shape[0]) * eps
118 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
119 |
120 | # numerical error might give slight imaginary component
121 | if np.iscomplexobj(covmean):
122 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
123 | m = np.max(np.abs(covmean.imag))
124 | raise ValueError("Imaginary component {}".format(m))
125 | covmean = covmean.real
126 |
127 | tr_covmean = np.trace(covmean)
128 |
129 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
130 |
131 |
132 | class Evaluator:
133 | def __init__(
134 | self,
135 | session,
136 | batch_size=64,
137 | softmax_batch_size=512,
138 | ):
139 | self.sess = session
140 | self.batch_size = batch_size
141 | self.softmax_batch_size = softmax_batch_size
142 | self.manifold_estimator = ManifoldEstimator(session)
143 | with self.sess.graph.as_default():
144 | self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
145 | self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
146 | self.pool_features, self.spatial_features = _create_feature_graph(
147 | self.image_input
148 | )
149 | self.softmax = _create_softmax_graph(self.softmax_input)
150 |
151 | def warmup(self):
152 | self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
153 |
154 | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
155 | with open_npz_array(npz_path, "arr_0") as reader:
156 | return self.compute_activations(reader.read_batches(self.batch_size))
157 |
158 | def compute_activations(
159 | self, batches: Iterable[np.ndarray]
160 | ) -> Tuple[np.ndarray, np.ndarray]:
161 | """
162 | Compute image features for downstream evals.
163 |
164 |         :param batches: an iterator over NHWC numpy arrays in [0, 255].
165 | :return: a tuple of numpy arrays of shape [N x X], where X is a feature
166 | dimension. The tuple is (pool_3, spatial).
167 | """
168 | preds = []
169 | spatial_preds = []
170 | for batch in tqdm(batches):
171 | batch = batch.astype(np.float32)
172 | pred, spatial_pred = self.sess.run(
173 | [self.pool_features, self.spatial_features], {self.image_input: batch}
174 | )
175 | preds.append(pred.reshape([pred.shape[0], -1]))
176 | spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
177 | return (
178 | np.concatenate(preds, axis=0),
179 | np.concatenate(spatial_preds, axis=0),
180 | )
181 |
182 | def read_statistics(
183 | self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
184 | ) -> Tuple[FIDStatistics, FIDStatistics]:
185 | obj = np.load(npz_path)
186 | if "mu" in list(obj.keys()):
187 | return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
188 | obj["mu_s"], obj["sigma_s"]
189 | )
190 | return tuple(self.compute_statistics(x) for x in activations)
191 |
192 | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
193 | mu = np.mean(activations, axis=0)
194 | sigma = np.cov(activations, rowvar=False)
195 | return FIDStatistics(mu, sigma)
196 |
197 | def compute_inception_score(
198 | self, activations: np.ndarray, split_size: int = 5000
199 | ) -> float:
200 | softmax_out = []
201 | for i in range(0, len(activations), self.softmax_batch_size):
202 | acts = activations[i : i + self.softmax_batch_size]
203 | softmax_out.append(
204 | self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})
205 | )
206 | preds = np.concatenate(softmax_out, axis=0)
207 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
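    |         # Inception Score: IS = exp(E_x[KL(p(y|x) || p(y))]), estimated on each
    |         # split of `split_size` predictions and then averaged over splits.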
208 | scores = []
209 | for i in range(0, len(preds), split_size):
210 | part = preds[i : i + split_size]
211 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
212 | kl = np.mean(np.sum(kl, 1))
213 | scores.append(np.exp(kl))
214 | return float(np.mean(scores))
215 |
216 | def compute_prec_recall(
217 | self, activations_ref: np.ndarray, activations_sample: np.ndarray
218 | ) -> Tuple[float, float]:
219 | radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
220 | radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
221 | pr = self.manifold_estimator.evaluate_pr(
222 | activations_ref, radii_1, activations_sample, radii_2
223 | )
224 | return (float(pr[0][0]), float(pr[1][0]))
225 |
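    | # Minimal usage sketch for Evaluator (a non-authoritative illustration; it
    | # assumes a TF1-style session, e.g. tf.compat.v1.Session(), and the .npz file
    | # names below are placeholders):
    | #
    | #   evaluator = Evaluator(sess)
    | #   evaluator.warmup()
    | #   ref_acts = evaluator.read_activations("reference_batch.npz")
    | #   sample_acts = evaluator.read_activations("samples.npz")
    | #   ref_stats, ref_stats_spatial = evaluator.read_statistics("reference_batch.npz", ref_acts)
    | #   sample_stats, sample_stats_spatial = evaluator.read_statistics("samples.npz", sample_acts)
    | #   inception_score = evaluator.compute_inception_score(sample_acts[0])
    | #   precision, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
    | #
    | # FID itself is then computed from the pool-feature FIDStatistics pair via the
    | # Frechet-distance routine defined above.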
226 |
227 | class ManifoldEstimator:
228 | """
229 | A helper for comparing manifolds of feature vectors.
230 |
231 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
232 | """
233 |
234 | def __init__(
235 | self,
236 | session,
237 | row_batch_size=10000,
238 | col_batch_size=10000,
239 | nhood_sizes=(3,),
240 | clamp_to_percentile=None,
241 | eps=1e-5,
242 | ):
243 | """
244 |         Set up an estimator for the manifold of a set of feature vectors.
245 |
246 | :param session: the TensorFlow session.
247 | :param row_batch_size: row batch size to compute pairwise distances
248 |             (a parameter to trade off memory usage against performance).
249 | :param col_batch_size: column batch size to compute pairwise distances.
250 | :param nhood_sizes: number of neighbors used to estimate the manifold.
251 | :param clamp_to_percentile: prune hyperspheres that have radius larger than
252 | the given percentile.
253 | :param eps: small number for numerical stability.
254 | """
255 | self.distance_block = DistanceBlock(session)
256 | self.row_batch_size = row_batch_size
257 | self.col_batch_size = col_batch_size
258 | self.nhood_sizes = nhood_sizes
259 | self.num_nhoods = len(nhood_sizes)
260 | self.clamp_to_percentile = clamp_to_percentile
261 | self.eps = eps
262 |
263 | def warmup(self):
264 | feats, radii = (
265 | np.zeros([1, 2048], dtype=np.float32),
266 | np.zeros([1, 1], dtype=np.float32),
267 | )
268 | self.evaluate_pr(feats, radii, feats, radii)
269 |
270 | def manifold_radii(self, features: np.ndarray) -> np.ndarray:
271 | num_images = len(features)
272 |
273 | # Estimate manifold of features by calculating distances to k-NN of each sample.
274 | radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
275 | distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
276 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
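    |         # seq holds partition indices 0..max(k); index 0 corresponds to each
    |         # sample's zero distance to itself, so column k of the partitioned
    |         # distances is the (squared) distance to the k-th nearest true neighbor.
    |         # radii then keeps only the columns listed in nhood_sizes.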
277 |
278 | for begin1 in range(0, num_images, self.row_batch_size):
279 | end1 = min(begin1 + self.row_batch_size, num_images)
280 | row_batch = features[begin1:end1]
281 |
282 | for begin2 in range(0, num_images, self.col_batch_size):
283 | end2 = min(begin2 + self.col_batch_size, num_images)
284 | col_batch = features[begin2:end2]
285 |
286 | # Compute distances between batches.
287 | distance_batch[0 : end1 - begin1, begin2:end2] = (
288 | self.distance_block.pairwise_distances(row_batch, col_batch)
289 | )
290 |
291 | # Find the k-nearest neighbor from the current batch.
292 | radii[begin1:end1, :] = np.concatenate(
293 | [
294 | x[:, self.nhood_sizes]
295 | for x in _numpy_partition(
296 | distance_batch[0 : end1 - begin1, :], seq, axis=1
297 | )
298 | ],
299 | axis=0,
300 | )
301 |
302 | if self.clamp_to_percentile is not None:
303 | max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
304 | radii[radii > max_distances] = 0
305 | return radii
306 |
307 | def evaluate(
308 | self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray
309 | ):
310 | """
311 |         Evaluate whether new feature vectors lie on the estimated manifold.
312 | """
313 | num_eval_images = eval_features.shape[0]
314 | num_ref_images = radii.shape[0]
315 | distance_batch = np.zeros(
316 | [self.row_batch_size, num_ref_images], dtype=np.float32
317 | )
318 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
319 | max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
320 | nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
321 |
322 | for begin1 in range(0, num_eval_images, self.row_batch_size):
323 | end1 = min(begin1 + self.row_batch_size, num_eval_images)
324 | feature_batch = eval_features[begin1:end1]
325 |
326 | for begin2 in range(0, num_ref_images, self.col_batch_size):
327 | end2 = min(begin2 + self.col_batch_size, num_ref_images)
328 | ref_batch = features[begin2:end2]
329 |
330 | distance_batch[0 : end1 - begin1, begin2:end2] = (
331 | self.distance_block.pairwise_distances(feature_batch, ref_batch)
332 | )
333 |
334 |             # From the minibatch of new feature vectors, determine whether they lie in the estimated manifold.
335 |             # If a feature vector is inside a hypersphere of some reference sample, then
336 |             # the new sample lies on the estimated manifold.
337 | # The radii of the hyperspheres are determined from distances of neighborhood size k.
338 | samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
339 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(
340 | np.int32
341 | )
342 |
343 | max_realism_score[begin1:end1] = np.max(
344 | radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
345 | )
346 | nearest_indices[begin1:end1] = np.argmin(
347 | distance_batch[0 : end1 - begin1, :], axis=1
348 | )
349 |
350 | return {
351 | "fraction": float(np.mean(batch_predictions)),
352 | "batch_predictions": batch_predictions,
353 |             "max_realism_score": max_realism_score,
354 | "nearest_indices": nearest_indices,
355 | }
356 |
357 | def evaluate_pr(
358 | self,
359 | features_1: np.ndarray,
360 | radii_1: np.ndarray,
361 | features_2: np.ndarray,
362 | radii_2: np.ndarray,
363 | ) -> Tuple[np.ndarray, np.ndarray]:
364 | """
365 | Evaluate precision and recall efficiently.
366 |
367 | :param features_1: [N1 x D] feature vectors for reference batch.
368 | :param radii_1: [N1 x K1] radii for reference vectors.
369 | :param features_2: [N2 x D] feature vectors for the other batch.
370 |         :param radii_2: [N2 x K2] radii for the other batch's vectors.
371 | :return: a tuple of arrays for (precision, recall):
372 | - precision: an np.ndarray of length K1
373 | - recall: an np.ndarray of length K2
374 | """
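    |         # features_1_status[i, k] becomes True if reference vector i lies inside
    |         # at least one sample hypersphere (contributes to recall);
    |         # features_2_status[j, k] becomes True if sample vector j lies inside a
    |         # reference hypersphere (contributes to precision).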
375 | features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
376 | features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
377 | for begin_1 in range(0, len(features_1), self.row_batch_size):
378 | end_1 = begin_1 + self.row_batch_size
379 | batch_1 = features_1[begin_1:end_1]
380 | for begin_2 in range(0, len(features_2), self.col_batch_size):
381 | end_2 = begin_2 + self.col_batch_size
382 | batch_2 = features_2[begin_2:end_2]
383 | batch_1_in, batch_2_in = self.distance_block.less_thans(
384 | batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
385 | )
386 | features_1_status[begin_1:end_1] |= batch_1_in
387 | features_2_status[begin_2:end_2] |= batch_2_in
388 | return (
389 | np.mean(features_2_status.astype(np.float64), axis=0),
390 | np.mean(features_1_status.astype(np.float64), axis=0),
391 | )
392 |
393 |
394 | class DistanceBlock:
395 | """
396 | Calculate pairwise distances between vectors.
397 |
398 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
399 | """
400 |
401 | def __init__(self, session):
402 | self.session = session
403 |
404 | # Initialize TF graph to calculate pairwise distances.
405 | with session.graph.as_default():
406 | self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
407 | self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
408 | distance_block_16 = _batch_pairwise_distances(
409 | tf.cast(self._features_batch1, tf.float16),
410 | tf.cast(self._features_batch2, tf.float16),
411 | )
412 | self.distance_block = tf.cond(
413 | tf.reduce_all(tf.math.is_finite(distance_block_16)),
414 | lambda: tf.cast(distance_block_16, tf.float32),
415 | lambda: _batch_pairwise_distances(
416 | self._features_batch1, self._features_batch2
417 | ),
418 | )
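    |             # Distances are computed in float16 first for speed and memory; if any
    |             # entry is non-finite (fp16 overflow), the block is recomputed in float32.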
419 |
420 | # Extra logic for less thans.
421 | self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
422 | self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
423 | dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
424 | self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
425 | self._batch_2_in = tf.math.reduce_any(
426 | dist32 <= self._radii1[:, None], axis=0
427 | )
428 |
429 | def pairwise_distances(self, U, V):
430 | """
431 | Evaluate pairwise distances between two batches of feature vectors.
432 | """
433 | return self.session.run(
434 | self.distance_block,
435 | feed_dict={self._features_batch1: U, self._features_batch2: V},
436 | )
437 |
438 | def less_thans(self, batch_1, radii_1, batch_2, radii_2):
439 | return self.session.run(
440 | [self._batch_1_in, self._batch_2_in],
441 | feed_dict={
442 | self._features_batch1: batch_1,
443 | self._features_batch2: batch_2,
444 | self._radii1: radii_1,
445 | self._radii2: radii_2,
446 | },
447 | )
448 |
449 |
450 | def _batch_pairwise_distances(U, V):
451 | """
452 | Compute pairwise distances between two batches of feature vectors.
453 | """
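    |     # Uses the identity ||u - v||^2 = ||u||^2 - 2 u.v + ||v||^2, clamped at zero
    |     # to guard against small negative values from floating-point error. Note that
    |     # the result is the *squared* Euclidean distance; the manifold radii and the
    |     # membership tests above use the same squared units, so they stay consistent.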
454 | with tf.variable_scope("pairwise_dist_block"):
455 | # Squared norms of each row in U and V.
456 | norm_u = tf.reduce_sum(tf.square(U), 1)
457 | norm_v = tf.reduce_sum(tf.square(V), 1)
458 |
459 |         # Reshape norm_u into a column vector and norm_v into a row vector.
460 | norm_u = tf.reshape(norm_u, [-1, 1])
461 | norm_v = tf.reshape(norm_v, [1, -1])
462 |
463 | # Pairwise squared Euclidean distances.
464 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
465 |
466 | return D
467 |
468 |
469 | class NpzArrayReader(ABC):
470 | @abstractmethod
471 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
472 | pass
473 |
474 | @abstractmethod
475 | def remaining(self) -> int:
476 | pass
477 |
478 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
479 | def gen_fn():
480 | while True:
481 | batch = self.read_batch(batch_size)
482 | if batch is None:
483 | break
484 | yield batch
485 |
486 | rem = self.remaining()
487 | num_batches = rem // batch_size + int(rem % batch_size != 0)
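    |         # Wrap the generator in BatchIterator, whose __len__ is
    |         # ceil(rem / batch_size), so tqdm can display a proper progress bar.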
488 | return BatchIterator(gen_fn, num_batches)
489 |
490 |
491 | class BatchIterator:
492 | def __init__(self, gen_fn, length):
493 | self.gen_fn = gen_fn
494 | self.length = length
495 |
496 | def __len__(self):
497 | return self.length
498 |
499 | def __iter__(self):
500 | return self.gen_fn()
501 |
502 |
503 | class StreamingNpzArrayReader(NpzArrayReader):
504 | def __init__(self, arr_f, shape, dtype):
505 | self.arr_f = arr_f
506 | self.shape = shape
507 | self.dtype = dtype
508 | self.idx = 0
509 |
510 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
511 | if self.idx >= self.shape[0]:
512 | return None
513 |
514 | bs = min(batch_size, self.shape[0] - self.idx)
515 | self.idx += bs
516 |
517 | if self.dtype.itemsize == 0:
518 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
519 |
520 | read_count = bs * np.prod(self.shape[1:])
521 | read_size = int(read_count * self.dtype.itemsize)
522 | data = _read_bytes(self.arr_f, read_size, "array data")
523 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
524 |
525 | def remaining(self) -> int:
526 | return max(0, self.shape[0] - self.idx)
527 |
528 |
529 | class MemoryNpzArrayReader(NpzArrayReader):
530 | def __init__(self, arr):
531 | self.arr = arr
532 | self.idx = 0
533 |
534 | @classmethod
535 | def load(cls, path: str, arr_name: str):
536 | with open(path, "rb") as f:
537 | arr = np.load(f)[arr_name]
538 | return cls(arr)
539 |
540 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
541 | if self.idx >= self.arr.shape[0]:
542 | return None
543 |
544 | res = self.arr[self.idx : self.idx + batch_size]
545 | self.idx += batch_size
546 | return res
547 |
548 | def remaining(self) -> int:
549 | return max(0, self.arr.shape[0] - self.idx)
550 |
551 |
552 | @contextmanager
553 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
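    |     # Stream the named array out of the .npz without loading it fully when its
    |     # .npy member has a standard v1/v2 header, C order, and a non-object dtype;
    |     # otherwise fall back to reading the whole array into memory.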
554 | with _open_npy_file(path, arr_name) as arr_f:
555 | version = np.lib.format.read_magic(arr_f)
556 | if version == (1, 0):
557 | header = np.lib.format.read_array_header_1_0(arr_f)
558 | elif version == (2, 0):
559 | header = np.lib.format.read_array_header_2_0(arr_f)
560 | else:
561 | yield MemoryNpzArrayReader.load(path, arr_name)
562 | return
563 | shape, fortran, dtype = header
564 | if fortran or dtype.hasobject:
565 | yield MemoryNpzArrayReader.load(path, arr_name)
566 | else:
567 | yield StreamingNpzArrayReader(arr_f, shape, dtype)
568 |
569 |
570 | def _read_bytes(fp, size, error_template="ran out of data"):
571 | """
572 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
573 |
574 | Read from file-like object until size bytes are read.
575 |     Raises ValueError if EOF is encountered before size bytes are read.
576 | Non-blocking objects only supported if they derive from io objects.
577 | Required as e.g. ZipExtFile in python 2.6 can return less data than
578 | requested.
579 | """
580 | data = bytes()
581 | while True:
582 | # io files (default in python3) return None or raise on
583 | # would-block, python2 file will truncate, probably nothing can be
584 | # done about that. note that regular files can't be non-blocking
585 | try:
586 | r = fp.read(size - len(data))
587 | data += r
588 | if len(r) == 0 or len(data) == size:
589 | break
590 | except io.BlockingIOError:
591 | pass
592 | if len(data) != size:
593 | msg = "EOF: reading %s, expected %d bytes got %d"
594 | raise ValueError(msg % (error_template, size, len(data)))
595 | else:
596 | return data
597 |
598 |
599 | @contextmanager
600 | def _open_npy_file(path: str, arr_name: str):
601 | with open(path, "rb") as f:
602 | with zipfile.ZipFile(f, "r") as zip_f:
603 | if f"{arr_name}.npy" not in zip_f.namelist():
604 | raise ValueError(f"missing {arr_name} in npz file")
605 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
606 | yield arr_f
607 |
608 |
609 | def _download_inception_model():
610 | if os.path.exists(INCEPTION_V3_PATH):
611 | return
612 | print("downloading InceptionV3 model...")
613 | with requests.get(INCEPTION_V3_URL, stream=True) as r:
614 | r.raise_for_status()
615 | tmp_path = INCEPTION_V3_PATH + ".tmp"
616 | with open(tmp_path, "wb") as f:
617 | for chunk in tqdm(r.iter_content(chunk_size=8192)):
618 | f.write(chunk)
619 | os.rename(tmp_path, INCEPTION_V3_PATH)
620 |
621 |
622 | def _create_feature_graph(input_batch):
623 | _download_inception_model()
624 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
625 | with open(INCEPTION_V3_PATH, "rb") as f:
626 | graph_def = tf.GraphDef()
627 | graph_def.ParseFromString(f.read())
628 | pool3, spatial = tf.import_graph_def(
629 | graph_def,
630 |         input_map={"ExpandDims:0": input_batch},
631 | return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
632 | name=prefix,
633 | )
634 | _update_shapes(pool3)
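    |     # Keep only the first 7 channels of the spatial feature map; this matches the
    |     # slice used by the reference OpenAI evaluator for its spatial-FID features.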
635 | spatial = spatial[..., :7]
636 | return pool3, spatial
637 |
638 |
639 | def _create_softmax_graph(input_batch):
640 | _download_inception_model()
641 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
642 | with open(INCEPTION_V3_PATH, "rb") as f:
643 | graph_def = tf.GraphDef()
644 | graph_def.ParseFromString(f.read())
645 | (matmul,) = tf.import_graph_def(
646 |             graph_def, return_elements=["softmax/logits/MatMul"], name=prefix
647 | )
648 | w = matmul.inputs[1]
649 | logits = tf.matmul(input_batch, w)
650 | return tf.nn.softmax(logits)
651 |
652 |
653 | def _update_shapes(pool3):
654 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
655 | ops = pool3.graph.get_operations()
656 | for op in ops:
657 | for o in op.outputs:
658 | shape = o.get_shape()
659 | if shape._dims is not None: # pylint: disable=protected-access
660 | # shape = [s.value for s in shape] TF 1.x
661 | shape = [s for s in shape] # TF 2.x
662 | new_shape = []
663 | for j, s in enumerate(shape):
664 | if s == 1 and j == 0:
665 | new_shape.append(None)
666 | else:
667 | new_shape.append(s)
668 | o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
669 | return pool3
670 |
671 |
672 | def _numpy_partition(arr, kth, **kwargs):
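    |     # Split arr into num_workers near-equal row chunks and apply np.partition to
    |     # each chunk in a thread pool, returning the partitioned chunks in order.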
673 | num_workers = min(cpu_count(), len(arr))
674 | chunk_size = len(arr) // num_workers
675 | extra = len(arr) % num_workers
676 |
677 | start_idx = 0
678 | batches = []
679 | for i in range(num_workers):
680 | size = chunk_size + (1 if i < extra else 0)
681 | batches.append(arr[start_idx : start_idx + size])
682 | start_idx += size
683 |
684 | with ThreadPool(num_workers) as pool:
685 | return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
686 |
687 |
688 | if __name__ == "__main__":
689 | main()
690 |
--------------------------------------------------------------------------------