├── vqgan_jax
│   ├── __init__.py
│   ├── configuration_vqgan.py
│   ├── convert_pt_model_to_jax.py
│   └── modeling_flax_vqgan.py
├── README.md
├── setup.py
└── .gitignore
/vqgan_jax/__init__.py:
--------------------------------------------------------------------------------
from .configuration_vqgan import VQGANConfig
from .modeling_flax_vqgan import VQModel
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# vqgan-jax
[WIP] JAX implementation of [taming-transformers](https://github.com/CompVis/taming-transformers) VQGAN.

> (Note: For now this only consists of the `VQModel` for inference. The discriminator, transformer model and training scripts will be added later.)

Use this [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing)
to see how to use this model to encode and reconstruct images.
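
A minimal usage sketch (assuming a checkpoint already converted with
`vqgan_jax/convert_pt_model_to_jax.py`; the checkpoint path below is hypothetical):

```python
import numpy as np
from vqgan_jax import VQModel

model = VQModel.from_pretrained("./vqgan-jax-ckpt")  # hypothetical local path

# A batch of images in NHWC layout, matching the default input_shape=(1, 256, 256, 3).
# The expected value range depends on the checkpoint's preprocessing.
pixel_values = np.random.rand(1, 256, 256, 3).astype(np.float32)

quant_states, indices = model.encode(pixel_values)  # indices: (batch, num_tokens)
reconstruction = model.decode(quant_states)         # (1, 256, 256, 3)
# ...or decode directly from the discrete codes:
reconstruction = model.decode_code(indices)
```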
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools
# To use a consistent encoding
from codecs import open
import os

here = os.path.abspath(os.path.dirname(__file__))

with open(os.path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()

setuptools.setup(
    name='vqgan-jax',
    version='0.0.1',
    description='JAX implementation of VQGAN',
    long_description=long_description,
    long_description_content_type='text/markdown',
    packages=setuptools.find_packages(),
    install_requires=['jax>=0.2.6', 'flax', 'transformers'],
)
--------------------------------------------------------------------------------
/vqgan_jax/configuration_vqgan.py:
--------------------------------------------------------------------------------
from typing import Tuple

from transformers import PretrainedConfig


class VQGANConfig(PretrainedConfig):
    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple = (1, 1, 2, 2, 4),
        attn_resolutions: Tuple = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        self.num_resolutions = len(ch_mult)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Initially taken from GitHub's Python gitignore file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# vscode
.vs
.vscode

# Pycharm
.idea

# TF code
tensorflow_code

# Models
proc_data

# data
/data
serialization_dir

# emacs
*.*~
debug.env

# vim
.*.swp

# ctags
tags

# pre-commit
.pre-commit*

# .lock
*.lock
--------------------------------------------------------------------------------
/vqgan_jax/convert_pt_model_to_jax.py:
--------------------------------------------------------------------------------
import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

import torch

from .modeling_flax_vqgan import VQModel
from .configuration_vqgan import VQGANConfig

regex = r"\w+[.]\d+"


def rename_key(key):
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key
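
# Illustration (not part of the original file): `rename_key` rewrites indexed
# PyTorch submodule names into the names Flax gives list-valued submodules in
# modeling_flax_vqgan.py, e.g.
#
#   >>> rename_key("decoder.up.0.block.1.norm1.weight")
#   'decoder.up_0.block_1.norm1.weight'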

# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensors to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    pt_root_keys = set(k.split(".")[0] for k in pt_state_dict.keys())
    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params
                                and flax_model.base_model_prefix in pt_root_keys)
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params
                             and flax_model.base_model_prefix not in pt_root_keys)

    # Rename parameter names to match the Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

        # Correctly rename weight parameters
        if ("norm" in pt_key and pt_tuple_key[-1] == "bias"
                and pt_tuple_key[:-1] + ("bias",) not in random_flax_state_dict
                and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict):
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
            # conv layer: PyTorch (out, in, h, w) -> Flax (h, w, in, out)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            # linear layer: PyTorch (out, in) -> Flax (in, out)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        if pt_tuple_key in random_flax_state_dict:
            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weights so that a warning is thrown when loading
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)

def convert_model(config_path, pt_state_dict_path, save_path):
    config = VQGANConfig.from_pretrained(config_path)
    model = VQModel(config)

    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
    keys = list(state_dict.keys())
    for key in keys:
        if key.startswith("loss"):
            # discriminator / perceptual-loss weights are not needed for inference
            state_dict.pop(key)
            continue
        renamed_key = rename_key(key)
        state_dict[renamed_key] = state_dict.pop(key)

    state = convert_pytorch_state_dict_to_flax(state_dict, model)
    model.params = state
    model.save_pretrained(save_path)
    return model
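
# Conversion sketch (illustration only; the paths are hypothetical and assume
# a locally downloaded taming-transformers checkpoint):
#
#   from vqgan_jax.convert_pt_model_to_jax import convert_model
#
#   model = convert_model(
#       config_path="./vqgan_f16_16384",          # directory containing config.json
#       pt_state_dict_path="./vqgan_f16_16384/last.ckpt",
#       save_path="./vqgan_f16_16384_flax",
#   )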
--------------------------------------------------------------------------------
/vqgan_jax/modeling_flax_vqgan.py:
--------------------------------------------------------------------------------
# JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers

from functools import partial
from typing import Optional, Tuple
import math

import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict

from transformers.modeling_flax_utils import FlaxPreTrainedModel

from .configuration_vqgan import VQGANConfig


class Upsample(nn.Module):
    in_channels: int
    with_conv: bool
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.with_conv:
            self.conv = nn.Conv(
                self.in_channels,
                kernel_size=(3, 3),
                strides=(1, 1),
                padding=((1, 1), (1, 1)),
                dtype=self.dtype,
            )

    def __call__(self, hidden_states):
        batch, height, width, channels = hidden_states.shape
        hidden_states = jax.image.resize(
            hidden_states,
            shape=(batch, height * 2, width * 2, channels),
            method="nearest",
        )
        if self.with_conv:
            hidden_states = self.conv(hidden_states)
        return hidden_states


class Downsample(nn.Module):
    in_channels: int
    with_conv: bool
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.with_conv:
            self.conv = nn.Conv(
                self.in_channels,
                kernel_size=(3, 3),
                strides=(2, 2),
                padding="VALID",
                dtype=self.dtype,
            )

    def __call__(self, hidden_states):
        if self.with_conv:
            # pad height and width dims (bottom/right only, as in taming-transformers)
            pad = ((0, 0), (0, 1), (0, 1), (0, 0))
            hidden_states = jnp.pad(hidden_states, pad_width=pad)
            hidden_states = self.conv(hidden_states)
        else:
            hidden_states = nn.avg_pool(hidden_states,
                                        window_shape=(2, 2),
                                        strides=(2, 2),
                                        padding="VALID")
        return hidden_states

class ResnetBlock(nn.Module):
    in_channels: int
    out_channels: Optional[int] = None
    use_conv_shortcut: bool = False
    temb_channels: int = 512
    dropout_prob: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.conv1 = nn.Conv(
            self.out_channels_,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        if self.temb_channels:
            self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)

        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.dropout = nn.Dropout(self.dropout_prob)
        self.conv2 = nn.Conv(
            self.out_channels_,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        if self.in_channels != self.out_channels_:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv(
                    self.out_channels_,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    padding=((1, 1), (1, 1)),
                    dtype=self.dtype,
                )
            else:
                self.nin_shortcut = nn.Conv(
                    self.out_channels_,
                    kernel_size=(1, 1),
                    strides=(1, 1),
                    padding="VALID",
                    dtype=self.dtype,
                )

    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            # hidden_states is NHWC, so broadcast the projected embedding
            # over the two spatial dims
            hidden_states = hidden_states + self.temb_proj(
                nn.swish(temb))[:, None, None, :]

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels_:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return hidden_states + residual


class AttnBlock(nn.Module):
    in_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        conv = partial(nn.Conv,
                       self.in_channels,
                       kernel_size=(1, 1),
                       strides=(1, 1),
                       padding="VALID",
                       dtype=self.dtype)

        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.q, self.k, self.v = conv(), conv(), conv()
        self.proj_out = conv()

    def __call__(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)

        query = self.q(hidden_states)
        key = self.k(hidden_states)
        value = self.v(hidden_states)

        # compute attention weights
        batch, height, width, channels = query.shape
        query = query.reshape((batch, height * width, channels))
        key = key.reshape((batch, height * width, channels))
        attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
        attn_weights = attn_weights * (int(channels)**-0.5)
        attn_weights = nn.softmax(attn_weights, axis=2)

        # attend to values
        value = value.reshape((batch, height * width, channels))
        hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
        hidden_states = hidden_states.reshape((batch, height, width, channels))

        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states


class UpsamplingBlock(nn.Module):
    config: VQGANConfig
    curr_res: int
    block_idx: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.block_idx == self.config.num_resolutions - 1:
            block_in = self.config.ch * self.config.ch_mult[-1]
        else:
            block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]

        block_out = self.config.ch * self.config.ch_mult[self.block_idx]
        self.temb_ch = 0

        res_blocks = []
        attn_blocks = []
        for _ in range(self.config.num_res_blocks + 1):
            res_blocks.append(
                ResnetBlock(block_in,
                            block_out,
                            temb_channels=self.temb_ch,
                            dropout_prob=self.config.dropout,
                            dtype=self.dtype))
            block_in = block_out
            if self.curr_res in self.config.attn_resolutions:
                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))

        self.block = res_blocks
        self.attn = attn_blocks

        self.upsample = None
        if self.block_idx != 0:
            self.upsample = Upsample(block_in,
                                     self.config.resamp_with_conv,
                                     dtype=self.dtype)

    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
        for i, res_block in enumerate(self.block):
            hidden_states = res_block(hidden_states,
                                      temb,
                                      deterministic=deterministic)
            if self.attn:
                hidden_states = self.attn[i](hidden_states)

        if self.upsample is not None:
            hidden_states = self.upsample(hidden_states)

        return hidden_states


class DownsamplingBlock(nn.Module):
    config: VQGANConfig
    curr_res: int
    block_idx: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        in_ch_mult = (1,) + tuple(self.config.ch_mult)
        block_in = self.config.ch * in_ch_mult[self.block_idx]
        block_out = self.config.ch * self.config.ch_mult[self.block_idx]
        self.temb_ch = 0

        res_blocks = []
        attn_blocks = []
        for _ in range(self.config.num_res_blocks):
            res_blocks.append(
                ResnetBlock(block_in,
                            block_out,
                            temb_channels=self.temb_ch,
                            dropout_prob=self.config.dropout,
                            dtype=self.dtype))
            block_in = block_out
            if self.curr_res in self.config.attn_resolutions:
                attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))

        self.block = res_blocks
        self.attn = attn_blocks

        self.downsample = None
        if self.block_idx != self.config.num_resolutions - 1:
            self.downsample = Downsample(block_in,
                                         self.config.resamp_with_conv,
                                         dtype=self.dtype)

    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
        for i, res_block in enumerate(self.block):
            hidden_states = res_block(hidden_states,
                                      temb,
                                      deterministic=deterministic)
            if self.attn:
                hidden_states = self.attn[i](hidden_states)

        if self.downsample is not None:
            hidden_states = self.downsample(hidden_states)

        return hidden_states

class MidBlock(nn.Module):
    in_channels: int
    temb_channels: int
    dropout: float
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.block_1 = ResnetBlock(
            self.in_channels,
            self.in_channels,
            temb_channels=self.temb_channels,
            dropout_prob=self.dropout,
            dtype=self.dtype,
        )
        self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
        self.block_2 = ResnetBlock(
            self.in_channels,
            self.in_channels,
            temb_channels=self.temb_channels,
            dropout_prob=self.dropout,
            dtype=self.dtype,
        )

    def __call__(self, hidden_states, temb=None, deterministic: bool = True):
        hidden_states = self.block_1(hidden_states,
                                     temb,
                                     deterministic=deterministic)
        hidden_states = self.attn_1(hidden_states)
        hidden_states = self.block_2(hidden_states,
                                     temb,
                                     deterministic=deterministic)
        return hidden_states


class Encoder(nn.Module):
    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.temb_ch = 0

        # downsampling
        self.conv_in = nn.Conv(
            self.config.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        curr_res = self.config.resolution
        downsample_blocks = []
        for i_level in range(self.config.num_resolutions):
            downsample_blocks.append(
                DownsamplingBlock(self.config,
                                  curr_res,
                                  block_idx=i_level,
                                  dtype=self.dtype))

            if i_level != self.config.num_resolutions - 1:
                curr_res = curr_res // 2
        self.down = downsample_blocks

        # middle
        mid_channels = self.config.ch * self.config.ch_mult[-1]
        self.mid = MidBlock(mid_channels,
                            self.temb_ch,
                            self.config.dropout,
                            dtype=self.dtype)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.conv_out = nn.Conv(
            2 * self.config.z_channels
            if self.config.double_z else self.config.z_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, pixel_values, deterministic: bool = True):
        # timestep embedding
        temb = None

        # downsampling
        hidden_states = self.conv_in(pixel_values)
        for block in self.down:
            hidden_states = block(hidden_states, temb, deterministic=deterministic)

        # middle
        hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)

        # end
        hidden_states = self.norm_out(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states
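
# Shape walkthrough (illustration, not part of the original file): with the
# default VQGANConfig (resolution=256, ch_mult=(1, 1, 2, 2, 4), hence
# num_resolutions=5), the encoder halves the spatial dims at every level but
# the last, i.e. by a total factor of 2**(num_resolutions - 1) = 16:
#
#   pixel_values (1, 256, 256, 3) -> Encoder -> (1, 16, 16, z_channels=256)
#
# The Decoder below mirrors this, mapping 16x16 latents back to 256x256 images.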

class Decoder(nn.Module):
    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.temb_ch = 0

        # compute block_in and curr_res at lowest res
        block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
        curr_res = self.config.resolution // 2**(self.config.num_resolutions - 1)
        self.z_shape = (1, curr_res, curr_res, self.config.z_channels)  # NHWC

        # z to block_in
        self.conv_in = nn.Conv(
            block_in,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # middle
        self.mid = MidBlock(block_in,
                            self.temb_ch,
                            self.config.dropout,
                            dtype=self.dtype)

        # upsampling
        upsample_blocks = []
        for i_level in reversed(range(self.config.num_resolutions)):
            upsample_blocks.append(
                UpsamplingBlock(self.config,
                                curr_res,
                                block_idx=i_level,
                                dtype=self.dtype))
            if i_level != 0:
                curr_res = curr_res * 2
        self.up = list(reversed(upsample_blocks))  # reverse to get consistent order

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
        self.conv_out = nn.Conv(
            self.config.out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states, deterministic: bool = True):
        # timestep embedding
        temb = None

        # z to block_in
        hidden_states = self.conv_in(hidden_states)

        # middle
        hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)

        # upsampling, from lowest to highest resolution
        for block in reversed(self.up):
            hidden_states = block(hidden_states, temb, deterministic=deterministic)

        # end
        if self.config.give_pre_end:
            return hidden_states

        hidden_states = self.norm_out(hidden_states)
        hidden_states = nn.swish(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


class VectorQuantizer(nn.Module):
    """
    Discretization bottleneck of the VQ-VAE; see
    https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py

    Inputs:
    - n_e : number of embeddings
    - e_dim : dimension of embedding
    - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
    """

    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embedding = nn.Embed(self.config.n_embed,
                                  self.config.embed_dim,
                                  dtype=self.dtype)  # TODO: init

    def __call__(self, hidden_states):
        """
        Maps the output of the encoder network z to a discrete one-hot vector
        that is the index of the closest embedding vector e_j:
        z (continuous) -> z_q (discrete), with z.shape = (batch, height, width, channel).

        Quantization pipeline:
        1. get encoder output (B, H, W, C)
        2. flatten input to (B*H*W, C)
        """
        # flatten
        hidden_states_flattened = hidden_states.reshape(
            (-1, self.config.embed_dim))

        # dummy op to init the weights, so we can access them below
        self.embedding(jnp.ones((1, 1), dtype="i4"))

        # distances from z to embeddings e_j: (z - e)^2 = z^2 + e^2 - 2 e * z
        emb_weights = self.variables["params"]["embedding"]["embedding"]
        distance = (jnp.sum(hidden_states_flattened**2, axis=1, keepdims=True) +
                    jnp.sum(emb_weights**2, axis=1) -
                    2 * jnp.dot(hidden_states_flattened, emb_weights.T))

        # get quantized latent vectors
        min_encoding_indices = jnp.argmin(distance, axis=1)
        z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)

        # reshape indices to (batch, num_tokens)
        min_encoding_indices = min_encoding_indices.reshape(
            hidden_states.shape[0], -1)

        # the codebook loss (q_loss) is computed outside the model;
        # here we return only the embeddings and indices
        return z_q, min_encoding_indices

    def get_codebook_entry(self, indices, shape=None):
        # indices are expected to be of shape (batch, num_tokens)
        # get quantized latent vectors
        batch, num_tokens = indices.shape
        z_q = self.embedding(indices)
        z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)),
                          int(math.sqrt(num_tokens)), -1)
        return z_q
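
# Training-time sketch (illustration, not part of the original file): since
# `VectorQuantizer.__call__` returns only `z_q` and the indices, the VQ-VAE
# losses and the straight-through estimator are left to the caller. Assuming
# `z` is the (pre-quantization) encoder output and `beta` the commitment cost:
#
#   codebook_loss = jnp.mean((jax.lax.stop_gradient(z) - z_q) ** 2)
#   commitment_loss = beta * jnp.mean((z - jax.lax.stop_gradient(z_q)) ** 2)
#   z_q = z + jax.lax.stop_gradient(z_q - z)  # copy gradients from z_q to z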

class VQModule(nn.Module):
    config: VQGANConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.encoder = Encoder(self.config, dtype=self.dtype)
        self.decoder = Decoder(self.config, dtype=self.dtype)
        self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
        self.quant_conv = nn.Conv(
            self.config.embed_dim,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )
        self.post_quant_conv = nn.Conv(
            self.config.z_channels,
            kernel_size=(1, 1),
            strides=(1, 1),
            padding="VALID",
            dtype=self.dtype,
        )

    def encode(self, pixel_values, deterministic: bool = True):
        hidden_states = self.encoder(pixel_values, deterministic=deterministic)
        hidden_states = self.quant_conv(hidden_states)
        quant_states, indices = self.quantize(hidden_states)
        return quant_states, indices

    def decode(self, hidden_states, deterministic: bool = True):
        hidden_states = self.post_quant_conv(hidden_states)
        hidden_states = self.decoder(hidden_states, deterministic=deterministic)
        return hidden_states

    def decode_code(self, code_b):
        hidden_states = self.quantize.get_codebook_entry(code_b)
        hidden_states = self.decode(hidden_states)
        return hidden_states

    def __call__(self, pixel_values, deterministic: bool = True):
        quant_states, indices = self.encode(pixel_values, deterministic)
        hidden_states = self.decode(quant_states, deterministic)
        return hidden_states, indices


class VQGANPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface
    for downloading and loading pretrained models.
    """

    config_class = VQGANConfig
    base_model_prefix = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: VQGANConfig,
        input_shape: Tuple = (1, 256, 256, 3),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config,
                         module,
                         input_shape=input_shape,
                         seed=seed,
                         dtype=dtype,
                         _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey,
                     input_shape: Tuple) -> FrozenDict:
        # init input tensors
        pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, pixel_values)["params"]

    def encode(self,
               pixel_values,
               params: dict = None,
               dropout_rng: jax.random.PRNGKey = None,
               train: bool = False):
        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply({"params": params or self.params},
                                 jnp.array(pixel_values),
                                 not train,
                                 rngs=rngs,
                                 method=self.module.encode)

    def decode(self,
               hidden_states,
               params: dict = None,
               dropout_rng: jax.random.PRNGKey = None,
               train: bool = False):
        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(hidden_states),
            not train,
            rngs=rngs,
            method=self.module.decode,
        )

    def decode_code(self, indices, params: dict = None):
        return self.module.apply({"params": params or self.params},
                                 jnp.array(indices, dtype="i4"),
                                 method=self.module.decode_code)

    def __call__(
        self,
        pixel_values,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
    ):
        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(pixel_values),
            not train,
            rngs=rngs,
        )

class VQModel(VQGANPreTrainedModel):
    module_class = VQModule
--------------------------------------------------------------------------------