├── models ├── __init__.py ├── gpt.py └── vqvae.py ├── requirements.txt ├── .gitignore ├── utils ├── losses.py ├── logger.py ├── annotations.py └── dataset.py ├── static └── configs │ ├── train_vqvae_config.json │ └── train_gpt_config.json ├── LICENSE ├── train_vqvae.py ├── vqvae_encode.py ├── generate.py ├── train_gpt.py ├── trainers ├── vqvae_trainer.py └── gpt_trainer.py └── README.md /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpt import GPTLmHeadModel 2 | from .vqvae import CnnDecoder, CnnEncoder, QuantizedCodebook 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dm-haiku==0.0.11 2 | optax==0.1.7 3 | numpy==1.22.3 4 | tqdm==4.64.0 5 | datasets==2.0.0 6 | Pillow==9.3.0 7 | black 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .python-version 3 | 4 | # train configs 5 | /train_vqvae_config.json 6 | /train_gpt_config.json 7 | 8 | # training outputs 9 | runs 10 | datasets 11 | generated 12 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | import jax.nn as nn 3 | 4 | 5 | def mse(y, y_pred): 6 | return jnp.mean((y - y_pred) ** 2) 7 | 8 | 9 | def cross_entropy(y, y_pred): 10 | return jnp.mean(jnp.sum(-y * nn.log_softmax(y_pred), axis=-1)) 11 | -------------------------------------------------------------------------------- /static/configs/train_vqvae_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 23, 3 | "dataset": "mnist", 4 | "resize_shape": [40, 40], 5 | "K": 256, 6 | "D": 128, 7 | "compression_level": 3, 8 | "res_layers": 2, 9 | "commitment_loss": 0.25, 10 | "train_dset_percentage": 100, 11 | "test_dset_percentage": 100, 12 | "train_steps": 15000, 13 | "test_steps": 50, 14 | "test_every": 250, 15 | "train_batch_size": 64, 16 | "test_batch_size": 64, 17 | "learning_rate": 3e-4, 18 | "weight_decay": 1e-5, 19 | "logdir": "runs/vqvae", 20 | "output_name": "train_state.pkl" 21 | } 22 | -------------------------------------------------------------------------------- /static/configs/train_gpt_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 23, 3 | "num_heads": 4, 4 | "hidden_dim": 64, 5 | "num_layers": 4, 6 | "dropout_rate": 0.1, 7 | "vqvae_config": "runs/vqvae/exp0/config.json", 8 | "vqvae_state": "runs/vqvae/exp0/train_state.pkl", 9 | "train_steps": 15000, 10 | "test_steps": 20, 11 | "test_every": 250, 12 | "train_dataset": "datasets/exp0-encoded/train", 13 | "test_dataset": "datasets/exp0-encoded/test", 14 | "train_batch_size": 64, 15 | "test_batch_size": 64, 16 | "generate_samples": 16, 17 | "sample_temperature": 0.3, 18 | "learning_rate": 3e-4, 19 | "weight_decay": 1e-5, 20 | "logdir": "runs/gpt", 21 | "output_name": "train_state.pkl" 22 | } 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andy Lo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 
| of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from pathlib import Path 3 | import re 4 | 5 | from tensorboardX import SummaryWriter 6 | 7 | prefix = "exp" 8 | 9 | 10 | def get_writer(base_dir: str, disable: bool = False) -> SummaryWriter: 11 | logdir = Path(base_dir) 12 | logdir.mkdir(parents=True, exist_ok=True) 13 | max_run = -1 14 | for path in logdir.iterdir(): 15 | match = re.fullmatch(rf"{prefix}([0-9]+)", path.name) 16 | max_run = max(max_run, int(match.group(1))) 17 | return SummaryWriter( 18 | logdir=str(logdir / f"{prefix}{max_run+1}"), 19 | flush_secs=10, 20 | write_to_disk=not disable, 21 | ) 22 | 23 | 24 | def log_dict(writer: SummaryWriter, logs: dict[str, Any], step: int, prefix: str = ""): 25 | for k, v in logs.items(): 26 | if k.startswith("scalar_"): 27 | k = k[len("scalar_") :] 28 | value = v 29 | if isinstance(value, list): 30 | value = sum(value) / len(value) 31 | writer.add_scalar(f"{prefix}{k}", value, step) 32 | elif k.startswith("images_"): 33 | k = k[len("images_") :] 34 | value = v 35 | if isinstance(value, list): 36 | value = value[0] 37 | writer.add_images(f"{prefix}{k}", value, step, dataformats="NHWC") 38 | -------------------------------------------------------------------------------- /utils/annotations.py: -------------------------------------------------------------------------------- 1 | from typing import Any, NamedTuple, TypedDict 2 | 3 | import numpy as np 4 | import haiku as hk 5 | import optax 6 | from jax._src.random import KeyArray 7 | 8 | 9 | class VqVaeConfig(NamedTuple): 10 | seed: int 11 | dataset: str 12 | resize_shape: tuple[int, int] 13 | K: int 14 | D: int 15 | compression_level: int 16 | res_layers: int 17 | commitment_loss: float 18 | train_dset_percentage: int 19 | test_dset_percentage: int 20 | train_steps: int 21 | test_steps: int 22 | test_every: int 23 | train_batch_size: int 24 | test_batch_size: int 25 | learning_rate: float 26 | weight_decay: float 27 | logdir: str 28 | output_name: str 29 | 30 | 31 | class GPTConfig(NamedTuple): 32 | seed: int 33 | num_heads: int 34 | hidden_dim: int 35 | num_layers: int 36 | dropout_rate: float 37 | vqvae_config: str 38 | vqvae_state: str 39 | train_steps: int 40 | test_steps: int 41 | test_every: int 42 | train_dataset: str 43 | test_dataset: str 44 | train_batch_size: int 45 | test_batch_size: int 46 | generate_samples: int 47 
| sample_temperature: float 48 | learning_rate: float 49 | weight_decay: float 50 | logdir: str 51 | output_name: str 52 | 53 | 54 | class VqVaeTuple(NamedTuple): 55 | encoder: Any 56 | decoder: Any 57 | quantizer: Any 58 | 59 | 60 | class VqVaeState(NamedTuple): 61 | params: hk.Params 62 | state: hk.State 63 | opt_state: optax.OptState 64 | 65 | 66 | class GPTState(NamedTuple): 67 | params: hk.Params 68 | state: hk.State 69 | opt_state: optax.OptState 70 | rng: KeyArray 71 | 72 | 73 | class VqVaeBatch(TypedDict): 74 | image: np.ndarray 75 | label: np.ndarray 76 | 77 | 78 | class GPTBatch(TypedDict): 79 | label: np.ndarray 80 | encoding_indices: np.ndarray 81 | -------------------------------------------------------------------------------- /models/gpt.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import haiku as hk 4 | import jax.nn as nn 5 | import jax.numpy as jnp 6 | 7 | 8 | class CasualSelfAttention(hk.MultiHeadAttention): 9 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 10 | # x is B x N x H 11 | seq_len = x.shape[1] 12 | casual_mask = jnp.tril(jnp.ones((seq_len, seq_len))) 13 | # mask is B x num_heads x N x N 14 | casual_mask = jnp.tile(casual_mask, (x.shape[0], self.num_heads, 1, 1)) 15 | return super().__call__(x, x, x, casual_mask) 16 | 17 | 18 | class DecoderBlock(hk.Module): 19 | def __init__( 20 | self, 21 | num_heads: int, 22 | hidden_dim: int, 23 | model_size: int, 24 | weight_init_scale: float, 25 | dropout_rate: float, 26 | name: Optional[str] = None, 27 | ): 28 | super().__init__(name) 29 | self.casual_atten = CasualSelfAttention( 30 | num_heads, hidden_dim, weight_init_scale 31 | ) 32 | self.dropout_rate = dropout_rate 33 | 34 | def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray: 35 | # Structured according to original paper 36 | # https://arxiv.org/pdf/1706.03762.pdf#section.3 37 | res = self.casual_atten(x) 38 | if is_training: 39 | res = hk.dropout(hk.next_rng_key(), self.dropout_rate, res) 40 | x += res 41 | x = hk.LayerNorm(-1, True, True)(x) 42 | 43 | dim = x.shape[-1] 44 | res = hk.Linear(dim * 4)(x) 45 | res = nn.gelu(res) 46 | res = hk.Linear(dim)(res) 47 | if is_training: 48 | res = hk.dropout(hk.next_rng_key(), self.dropout_rate, res) 49 | x += res 50 | x = hk.LayerNorm(-1, True, True)(x) 51 | return x 52 | 53 | 54 | class GPTLmHeadModel(hk.Module): 55 | def __init__( 56 | self, 57 | num_heads: int, 58 | hidden_dim: int, 59 | num_layers: int, 60 | num_classes: int, 61 | dropout_rate: float, 62 | max_length: int, 63 | name: Optional[str] = None, 64 | ): 65 | super().__init__(name) 66 | self.num_heads = num_heads 67 | self.hidden_dim = hidden_dim 68 | self.num_layers = num_layers 69 | self.num_classes = num_classes 70 | self.dropout_rate = dropout_rate 71 | self.max_length = max_length 72 | self.model_size = self.num_heads * self.hidden_dim 73 | 74 | self.init_scale = 2.0 / num_layers 75 | self.embed = hk.Embed(num_classes, self.model_size) 76 | self.positional_embeddings = hk.get_parameter( 77 | "pos_embs", 78 | [self.max_length, self.model_size], 79 | init=hk.initializers.TruncatedNormal(stddev=0.02), 80 | ) 81 | self.blocks = [ 82 | DecoderBlock( 83 | self.num_heads, 84 | self.hidden_dim, 85 | self.model_size, 86 | self.init_scale, 87 | self.dropout_rate, 88 | ) 89 | for _ in range(num_layers) 90 | ] 91 | self.lm_head = hk.Linear(num_classes) 92 | 93 | def __call__(self, x, is_training: bool): 94 | seq_length = x.shape[1] 95 | x = self.embed(x) + 
self.positional_embeddings[:seq_length] 96 | for block in self.blocks: 97 | x = block(x, is_training) 98 | x = self.lm_head(x) 99 | # softmax is taken outside 100 | return x 101 | -------------------------------------------------------------------------------- /train_vqvae.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import pickle 3 | import json 4 | import argparse 5 | from pathlib import Path 6 | from collections import defaultdict 7 | 8 | from tqdm import tqdm 9 | import jax 10 | import optax 11 | 12 | from trainers.vqvae_trainer import VqVaeTrainer 13 | from utils.dataset import load_dset 14 | from utils.logger import get_writer, log_dict 15 | from utils.annotations import VqVaeConfig, VqVaeState 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description="Train a VQ-VAE on the MNIST dataset.") 20 | parser.add_argument( 21 | "-f", "--file", type=str, required=True, help="path to the json config file." 22 | ) 23 | parser.add_argument( 24 | "-chkp", "--checkpoint", type=str, help="path to train state pkl file." 25 | ) 26 | return parser.parse_args() 27 | 28 | 29 | def main(config: VqVaeConfig, checkpoint: Optional[str] = None): 30 | writer = get_writer(config.logdir) 31 | exp_dir = writer.logdir 32 | 33 | # save config file 34 | with open(Path(exp_dir) / "config.json", "w") as f: 35 | json.dump(config._asdict(), f, indent=4) 36 | 37 | # load dataset 38 | _, dset_train = load_dset( 39 | name=config.dataset, 40 | split="train", 41 | batch_size=config.train_batch_size, 42 | percentage=config.train_dset_percentage, 43 | resize_shape=config.resize_shape, 44 | seed=config.seed, 45 | ) 46 | _, dset_test = load_dset( 47 | name=config.dataset, 48 | split="test", 49 | batch_size=config.test_batch_size, 50 | percentage=config.test_dset_percentage, 51 | resize_shape=config.resize_shape, 52 | seed=config.seed, 53 | ) 54 | 55 | # initialize model 56 | optimizer = optax.adamw(config.learning_rate, weight_decay=config.weight_decay) 57 | trainer = VqVaeTrainer( 58 | K=config.K, 59 | D=config.D, 60 | compression_level=config.compression_level, 61 | res_layers=config.res_layers, 62 | commitment_loss=config.commitment_loss, 63 | optimizer=optimizer, 64 | ) 65 | if checkpoint is None: 66 | key = jax.random.PRNGKey(config.seed) 67 | train_state = trainer.initial_state(key, next(dset_train)[1]) 68 | else: 69 | with open(checkpoint, "rb") as f: 70 | train_state: VqVaeState = pickle.load(f) 71 | 72 | # training loop 73 | for i in tqdm(range(config.train_steps)): 74 | # update 75 | epoch, batch = next(dset_train) 76 | train_state, logs = trainer.update(train_state, batch) 77 | 78 | # log 79 | logs["scalar_epoch"] = epoch 80 | log_dict(writer, logs, step=i, prefix="train/") 81 | 82 | # evaluate 83 | if (i + 1) % config.test_every == 0: 84 | logs = defaultdict(list) 85 | for _ in range(config.test_steps): 86 | _, batch = next(dset_test) 87 | log = trainer.evaluate(train_state, batch) 88 | for k, v in log.items(): 89 | logs[k].append(v) 90 | log_dict(writer, logs, step=i, prefix="test/") 91 | 92 | # save model 93 | with open(Path(exp_dir) / config.output_name, "wb") as f: 94 | pickle.dump(train_state, f) 95 | 96 | writer.close() 97 | 98 | 99 | if __name__ == "__main__": 100 | args = parse_args() 101 | with open(args.file, "r") as f: 102 | config = VqVaeConfig(**json.load(f)) 103 | 104 | main(config, checkpoint=args.checkpoint) 105 | -------------------------------------------------------------------------------- 
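For readers skimming `train_vqvae.py` above: the script pickles a `VqVaeState` into the experiment directory alongside `config.json`. Below is a minimal sketch of loading that checkpoint back for a quick reconstruction check. It reuses the repository's own classes, but the experiment path `runs/vqvae/exp0` and the dummy input are assumptions for illustration, not part of the training script.

```python
# Illustrative only: load a trained VQ-VAE checkpoint and run one forward pass.
import json
import pickle
from pathlib import Path

import jax.numpy as jnp

from trainers.vqvae_trainer import VqVaeTrainer
from utils.annotations import VqVaeConfig, VqVaeState

exp_dir = Path("runs/vqvae/exp0")  # hypothetical experiment directory

with open(exp_dir / "config.json") as f:
    config = VqVaeConfig(**json.load(f))
with open(exp_dir / config.output_name, "rb") as f:
    vqvae_state: VqVaeState = pickle.load(f)

# optimizer=None is enough for inference (the same pattern generate.py uses).
trainer = VqVaeTrainer(
    K=config.K,
    D=config.D,
    compression_level=config.compression_level,
    res_layers=config.res_layers,
    commitment_loss=config.commitment_loss,
    optimizer=None,
)

# Dummy NHWC batch matching resize_shape; replace with a real preprocessed MNIST image.
x = jnp.zeros((1, *config.resize_shape, 1), dtype=jnp.float32)
result, _ = trainer.forward(vqvae_state.params, vqvae_state.state, x, is_training=False)
print(result["x_pred"].shape, result["encoding_indices"].shape)
```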
/utils/dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional, Union 2 | from itertools import count 3 | from pathlib import Path 4 | from datasets.features.features import Features 5 | 6 | import numpy as np 7 | import datasets 8 | from datasets.arrow_dataset import Dataset 9 | from skimage.transform import resize 10 | 11 | from utils.annotations import VqVaeBatch, GPTBatch 12 | 13 | 14 | def process_image(img, shape: tuple[int, int]) -> np.ndarray: 15 | img = np.array(img, dtype=np.float32) / 255 16 | img = resize(img, shape) 17 | img = img[..., None] 18 | return img 19 | 20 | 21 | def load_dset( 22 | name: str, 23 | split: str, 24 | batch_size: int, 25 | percentage: int, 26 | resize_shape: tuple[int, int], 27 | repeat: bool = True, 28 | seed: Optional[int] = None, 29 | ) -> tuple[Features, Iterator[tuple[int, VqVaeBatch]]]: 30 | """ 31 | Loads a dataset with preprocessing, batching, and repeating. 32 | 33 | Args: 34 | name (str): The name of the dataset on Hugging Face Hub. 35 | split (str): The split of the dataset, such as "train" / "test". 36 | batch_size (int): The batch size to load the data. 37 | percentage (int): The percentage of the dataset to use. 38 | resize_shape (tuple[int, int], optional): Shape to resize the image to. 39 | repeat (bool, optional): Whether or not to repeat the dataset after 40 | iterating through one epoch. Defaults to True. 41 | seed (Optional[int], optional): The seed used to suffle the dataset. Defaults to None. 42 | 43 | Returns: 44 | tuple[Features, Iterator[tuple[int, VqVaeBatch]]]: A tuple of dataset features and the 45 | iterator which yields preprocessed, batched, and repeated data. 46 | """ 47 | dset = datasets.load_dataset(name, split=f"{split}[:{percentage}%]") 48 | assert isinstance(dset, Dataset) 49 | 50 | features: Features = dset.features 51 | 52 | def preprocess(batch) -> VqVaeBatch: 53 | return { 54 | "image": np.array( 55 | [process_image(img, resize_shape) for img in batch["image"]] 56 | ), 57 | "label": np.array(batch["label"]), 58 | } 59 | 60 | dset.set_transform(preprocess) 61 | 62 | def iterator(dset: Dataset) -> Iterator[tuple[int, VqVaeBatch]]: 63 | counter = count() 64 | while True: 65 | dset = dset.shuffle(seed) 66 | epoch = next(counter) 67 | for i in range(0, len(dset) - batch_size, batch_size): 68 | yield epoch, dset[i : i + batch_size] 69 | 70 | if not repeat: 71 | break 72 | 73 | return features, iterator(dset) 74 | 75 | 76 | def load_vqvae_processed( 77 | path: Union[str, Path], 78 | batch_size: int, 79 | repeat: bool = True, 80 | seed: Optional[int] = None, 81 | ) -> tuple[Features, Iterator[tuple[int, GPTBatch]]]: 82 | dset = datasets.load.load_from_disk(str(path)) 83 | assert isinstance(dset, Dataset) 84 | 85 | features: Features = dset.features 86 | 87 | def preprocess(batch) -> GPTBatch: 88 | return { 89 | "encoding_indices": np.array(batch["encoding_indices"]), 90 | "label": np.array(batch["label"]), 91 | } 92 | 93 | dset.set_transform(preprocess) 94 | 95 | def iterator(dset: Dataset) -> Iterator[tuple[int, GPTBatch]]: 96 | counter = count() 97 | while True: 98 | dset = dset.shuffle(seed) 99 | epoch = next(counter) 100 | for i in range(0, len(dset) - batch_size, batch_size): 101 | yield epoch, dset[i : i + batch_size] 102 | 103 | if not repeat: 104 | break 105 | 106 | return features, iterator(dset) 107 | -------------------------------------------------------------------------------- /vqvae_encode.py: 
-------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import argparse 4 | from pathlib import Path 5 | 6 | import datasets 7 | import jax 8 | import numpy as np 9 | 10 | from trainers.vqvae_trainer import VqVaeTrainer 11 | from utils.dataset import process_image 12 | from utils.annotations import VqVaeConfig, VqVaeState 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser( 17 | description="Encode the MNIST dataset with a VQ-VAE.", 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 19 | ) 20 | parser.add_argument( 21 | "-p", 22 | "--path", 23 | type=str, 24 | required=True, 25 | default=argparse.SUPPRESS, 26 | help="path to directory of the trained VQ-VAE model.", 27 | ) 28 | parser.add_argument( 29 | "-o", 30 | "--out", 31 | type=str, 32 | required=True, 33 | default=argparse.SUPPRESS, 34 | help="path to directory to save the processed datasets.", 35 | ) 36 | parser.add_argument( 37 | "-b", 38 | "--batch_size", 39 | type=int, 40 | default=64, 41 | help="batch size to process the dataset with.", 42 | ) 43 | parser.add_argument( 44 | "-P", 45 | "--percentage", 46 | type=int, 47 | default=100, 48 | help="percentage of dataset to encode.", 49 | ) 50 | return parser.parse_args() 51 | 52 | 53 | def main(path: str, out_path: str, batch_size: int, percentage: int): 54 | model_dir = Path(path) 55 | 56 | with open(model_dir / "config.json", "r") as f: 57 | config = VqVaeConfig(**json.load(f)) 58 | with open(model_dir / config.output_name, "rb") as f: 59 | vqvae_state: VqVaeState = pickle.load(f) 60 | 61 | trainer = VqVaeTrainer( 62 | K=config.K, 63 | D=config.D, 64 | compression_level=config.compression_level, 65 | res_layers=config.res_layers, 66 | commitment_loss=config.commitment_loss, 67 | optimizer=None, 68 | ) 69 | 70 | @jax.jit 71 | def infer(vqvae_state: VqVaeState, x: np.ndarray): 72 | params, state = vqvae_state.params, vqvae_state.state 73 | z_e, _ = trainer.apply.encode(params, state, None, x, is_training=False) 74 | result, _ = trainer.apply.quantize(params, state, None, z_e) 75 | indices = result["encoding_indices"] 76 | 77 | # z1, z2 are not necessary but used for assertion 78 | z1 = trainer.lookup_indices(vqvae_state, indices) 79 | z2 = result["quantize"] 80 | 81 | return result, (z1, z2) 82 | 83 | def encode(batch): 84 | images = np.array( 85 | [process_image(img, shape=config.resize_shape) for img in batch["image"]] 86 | ) 87 | result, (z1, z2) = infer(vqvae_state, images) 88 | batch["encoding_indices"] = np.array(result["encoding_indices"]) 89 | 90 | assert np.allclose(z1, z2, atol=1e-6, rtol=0) 91 | assert batch["encoding_indices"].ndim == 3 92 | assert batch["encoding_indices"].dtype == np.int32 93 | assert np.max(batch["encoding_indices"]) < config.K 94 | assert np.min(batch["encoding_indices"]) >= 0 95 | 96 | return batch 97 | 98 | out_dir = Path(out_path) 99 | out_dir.mkdir(parents=True, exist_ok=True) 100 | for split in ("train", "test"): 101 | dset = datasets.load_dataset("mnist", split=f"{split}[:{percentage}%]") 102 | dset = dset.map(encode, batched=True, batch_size=batch_size) 103 | dset.save_to_disk(str(out_dir / split)) 104 | 105 | 106 | if __name__ == "__main__": 107 | args = parse_args() 108 | main( 109 | path=args.path, 110 | out_path=args.out, 111 | batch_size=args.batch_size, 112 | percentage=args.percentage, 113 | ) 114 | -------------------------------------------------------------------------------- /generate.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pickle 4 | from pathlib import Path 5 | 6 | from PIL import Image 7 | import jax 8 | import numpy as np 9 | 10 | from utils.annotations import GPTBatch, GPTConfig, GPTState, VqVaeConfig, VqVaeState 11 | from trainers.gpt_trainer import VqVaeGPTTrainer 12 | from trainers.vqvae_trainer import VqVaeTrainer 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser( 17 | description="Generate MNIST samples by sampling VQ-VAE codes with a GPT-style transformer.", 18 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 19 | ) 20 | parser.add_argument( 21 | "-p", 22 | "--path", 23 | type=str, 24 | required=True, 25 | default=argparse.SUPPRESS, 26 | help="path to the pretrained GPT directory.", 27 | ) 28 | parser.add_argument( 29 | "-o", 30 | "--out", 31 | type=str, 32 | required=True, 33 | default=argparse.SUPPRESS, 34 | help="output directory to save generated images.", 35 | ) 36 | parser.add_argument( 37 | "-s", "--seed", type=int, default=0, help="seed to sample results from." 38 | ) 39 | parser.add_argument( 40 | "-t", 41 | "--temperature", 42 | type=float, 43 | default=0.2, 44 | help="temperature to sample results.", 45 | ) 46 | parser.add_argument( 47 | "-S", "--samples", type=int, default=8, help="Generates S x S samples." 48 | ) 49 | return parser.parse_args() 50 | 51 | 52 | def main(path: str, seed: int, temp: float, samples: int, out_path: str): 53 | # load configs 54 | model_dir = Path(path) 55 | with open(model_dir / "config.json", "r") as f: 56 | gpt_config = GPTConfig(**json.load(f)) 57 | with open(gpt_config.vqvae_config, "r") as f: 58 | vqvae_config = VqVaeConfig(**json.load(f)) 59 | h, w = vqvae_config.resize_shape 60 | 61 | # load VQ-VAE model for decoing sampled indices into image 62 | with open(gpt_config.vqvae_state, "rb") as f: 63 | vqvae_state: VqVaeState = pickle.load(f) 64 | vqvae = VqVaeTrainer( 65 | K=vqvae_config.K, 66 | D=vqvae_config.D, 67 | compression_level=vqvae_config.compression_level, 68 | res_layers=vqvae_config.res_layers, 69 | commitment_loss=vqvae_config.commitment_loss, 70 | optimizer=None, 71 | ) 72 | 73 | @jax.jit 74 | def decode_indices(vqvae_state: VqVaeState, indices): 75 | z_q = vqvae.lookup_indices(vqvae_state, indices) 76 | img, _ = vqvae.apply.decode( 77 | vqvae_state.params, vqvae_state.state, None, z_q, is_training=False 78 | ) 79 | return img 80 | 81 | # load GPT and initialize input output shapes 82 | with open(model_dir / gpt_config.output_name, "rb") as f: 83 | gpt_state: GPTState = pickle.load(f) 84 | x = np.zeros((1, h, w, 1), dtype=np.float32) 85 | res, _ = vqvae.forward(vqvae_state.params, vqvae_state.state, x, False) 86 | sample: GPTBatch = { 87 | "encoding_indices": res["encoding_indices"], 88 | "label": np.zeros((1,)), 89 | } 90 | gpt = VqVaeGPTTrainer( 91 | num_label_classes=10, 92 | vqvae_config=vqvae_config, 93 | num_heads=gpt_config.num_heads, 94 | hidden_dim=gpt_config.hidden_dim, 95 | num_layers=gpt_config.num_layers, 96 | dropout_rate=gpt_config.dropout_rate, 97 | sample=sample, 98 | optimizer=None, 99 | ) 100 | 101 | # generate 102 | rng = jax.random.PRNGKey(seed) 103 | out_dir = Path(out_path) 104 | out_dir.mkdir(parents=True, exist_ok=True) 105 | for label in range(10): 106 | image = np.zeros((samples * h, samples * w), dtype=np.uint8) 107 | for i in range(samples): 108 | for j in range(samples): 109 | indices, rng = gpt.generate(gpt_state, rng, label, temp=temp) 110 | img = 
decode_indices(vqvae_state, indices) 111 | img = (img[0, :, :, 0] * 255).astype(np.uint8) 112 | x, y = i * h, j * w 113 | image[x : x + h, y : y + w] = img 114 | im = Image.fromarray(image) 115 | im.save(str(out_dir / f"generated_{label}.png")) 116 | 117 | 118 | if __name__ == "__main__": 119 | args = parse_args() 120 | main( 121 | path=args.path, 122 | out_path=args.out, 123 | seed=args.seed, 124 | temp=args.temperature, 125 | samples=args.samples, 126 | ) 127 | -------------------------------------------------------------------------------- /train_gpt.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from pathlib import Path 3 | import pickle 4 | import json 5 | import argparse 6 | from collections import defaultdict 7 | 8 | import jax 9 | import optax 10 | from tqdm import tqdm 11 | import numpy as np 12 | 13 | from trainers.gpt_trainer import VqVaeGPTTrainer 14 | from trainers.vqvae_trainer import VqVaeTrainer 15 | from utils.dataset import load_vqvae_processed 16 | from utils.annotations import GPTConfig, GPTState, VqVaeConfig, VqVaeState 17 | from utils.logger import get_writer, log_dict 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser( 22 | description="Train a GPT-style transformer the VQ-VAE tokens of the MNIST dataset." 23 | ) 24 | parser.add_argument( 25 | "-f", "--file", type=str, required=True, help="path to the json config file." 26 | ) 27 | parser.add_argument( 28 | "-chkp", "--checkpoint", type=str, help="path to train state pkl file." 29 | ) 30 | return parser.parse_args() 31 | 32 | 33 | def main(config: GPTConfig, checkpoint: Optional[str] = None): 34 | writer = get_writer(config.logdir) 35 | exp_dir = writer.logdir 36 | 37 | # save config file 38 | with open(Path(exp_dir) / "config.json", "w") as f: 39 | json.dump(config._asdict(), f, indent=4) 40 | 41 | # load dataset 42 | features, dset_train = load_vqvae_processed( 43 | path=config.train_dataset, 44 | batch_size=config.train_batch_size, 45 | repeat=True, 46 | seed=config.seed, 47 | ) 48 | _, dset_test = load_vqvae_processed( 49 | path=config.test_dataset, 50 | batch_size=config.test_batch_size, 51 | repeat=True, 52 | seed=config.seed, 53 | ) 54 | label_classes = features["label"].num_classes 55 | 56 | # load vqvae for evaluation 57 | with open(config.vqvae_config, "r") as f: 58 | vqvae_config = VqVaeConfig(**json.load(f)) 59 | with open(config.vqvae_state, "rb") as f: 60 | vqvae_state: VqVaeState = pickle.load(f) 61 | vqvae = VqVaeTrainer( 62 | K=vqvae_config.K, 63 | D=vqvae_config.D, 64 | compression_level=vqvae_config.compression_level, 65 | res_layers=vqvae_config.res_layers, 66 | commitment_loss=vqvae_config.commitment_loss, 67 | optimizer=None, 68 | ) 69 | 70 | @jax.jit 71 | def decode_indices(vqvae_state: VqVaeState, indices): 72 | z_q = vqvae.lookup_indices(vqvae_state, indices) 73 | img, _ = vqvae.apply.decode( 74 | vqvae_state.params, vqvae_state.state, None, z_q, is_training=False 75 | ) 76 | return img 77 | 78 | # initialize model 79 | _, sample = next(dset_train) 80 | optimizer = optax.adamw(config.learning_rate, weight_decay=config.weight_decay) 81 | trainer = VqVaeGPTTrainer( 82 | label_classes, 83 | vqvae_config, 84 | config.num_heads, 85 | config.hidden_dim, 86 | config.num_layers, 87 | config.dropout_rate, 88 | sample, 89 | optimizer, 90 | ) 91 | key = jax.random.PRNGKey(config.seed) 92 | if checkpoint is None: 93 | key, key1 = jax.random.split(key) 94 | train_state = trainer.initial_state(key1, sample) 95 | 
else: 96 | with open(checkpoint, "rb") as f: 97 | train_state: GPTState = pickle.load(f) 98 | 99 | # training loop 100 | for i in tqdm(range(config.train_steps)): 101 | # update 102 | epoch, batch = next(dset_train) 103 | train_state, logs = trainer.update(train_state, batch) 104 | 105 | # log 106 | logs["scalar_epoch"] = epoch 107 | log_dict(writer, logs, step=i, prefix="train/") 108 | 109 | # evaluate 110 | if (i + 1) % config.test_every == 0: 111 | # test loss 112 | logs = defaultdict(list) 113 | for _ in range(config.test_steps): 114 | _, batch = next(dset_test) 115 | train_state, log = trainer.evaluate(train_state, batch) 116 | for k, v in log.items(): 117 | logs[k].append(v) 118 | log_dict(writer, logs, step=i, prefix="test/") 119 | 120 | # generate samples 121 | for label in range(label_classes): 122 | images = [] 123 | for _ in range(config.generate_samples): 124 | indices, key = trainer.generate( 125 | train_state, key, label, temp=config.sample_temperature 126 | ) 127 | img = decode_indices(vqvae_state, indices) 128 | images.append(img[0]) 129 | images = np.array(images) 130 | logs = {f"images_generate_{label}": images} 131 | log_dict(writer, logs, step=i, prefix="test/") 132 | 133 | with open(Path(exp_dir) / config.output_name, "wb") as f: 134 | pickle.dump(train_state, f) 135 | 136 | writer.close() 137 | 138 | 139 | if __name__ == "__main__": 140 | args = parse_args() 141 | with open(args.file, "r") as f: 142 | config = GPTConfig(**json.load(f)) 143 | 144 | main(config, checkpoint=args.checkpoint) 145 | -------------------------------------------------------------------------------- /trainers/vqvae_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, NamedTuple, Optional 2 | import functools 3 | 4 | import haiku as hk 5 | import jax.numpy as jnp 6 | import jax 7 | import optax 8 | from jax._src.random import KeyArray 9 | from optax._src.base import GradientTransformation 10 | 11 | from models import CnnEncoder, CnnDecoder, QuantizedCodebook 12 | from utils.annotations import VqVaeBatch, VqVaeState 13 | from utils.losses import mse 14 | 15 | 16 | class VqVaeApply(NamedTuple): 17 | encode: Callable[..., Any] 18 | decode: Callable[..., Any] 19 | quantize: Callable[..., Any] 20 | embed: Callable[..., Any] 21 | 22 | 23 | class VqVaeTrainer: 24 | def __init__( 25 | self, 26 | K: int, 27 | D: int, 28 | compression_level: int, 29 | res_layers: int, 30 | commitment_loss: float, 31 | optimizer: Optional[GradientTransformation], 32 | ): 33 | self.K = K 34 | self.D = D 35 | self.compression_level = compression_level 36 | self.res_layers = res_layers 37 | 38 | transformed = self.build( 39 | self.K, self.D, self.compression_level, self.res_layers, commitment_loss 40 | ) 41 | self.init = transformed.init 42 | self.apply = VqVaeApply(*transformed.apply) 43 | 44 | self.optimizer = optimizer 45 | 46 | @staticmethod 47 | def build( 48 | K: int, D: int, compression_level: int, res_layers: int, commitment_loss: float 49 | ): 50 | def f(): 51 | encoder = CnnEncoder( 52 | out_channels=D, 53 | downscale_level=compression_level, 54 | res_layers=res_layers, 55 | name="encoder", 56 | ) 57 | decoder = CnnDecoder( 58 | in_channels=D, 59 | upscale_level=compression_level, 60 | res_layers=res_layers, 61 | name="decoder", 62 | ) 63 | quantizer = QuantizedCodebook(K, D, commitment_loss, name="quantizer") 64 | 65 | def encode(x, is_training: bool): 66 | return encoder(x, is_training) 67 | 68 | def decode(x, is_training: bool): 69 | 
return decoder(x, is_training) 70 | 71 | def quantize(x): 72 | return quantizer(x) 73 | 74 | def embed(x): 75 | return quantizer.embed(x) 76 | 77 | def init(x, is_training: bool): 78 | encodings = encode(x, is_training) 79 | result = quantize(encodings) 80 | x_pred = decode(result["quantize"], is_training) 81 | z_q = embed(result["encoding_indices"]) 82 | return x_pred, z_q 83 | 84 | return init, (encode, decode, quantize, embed) 85 | 86 | return hk.multi_transform_with_state(f) 87 | 88 | def initial_state(self, rng: KeyArray, batch: VqVaeBatch) -> VqVaeState: 89 | params, state = self.init(rng, batch["image"], is_training=True) 90 | opt_state = self.optimizer.init(params) 91 | return VqVaeState(params, state, opt_state) 92 | 93 | def forward(self, params: hk.Params, state: hk.State, x, is_training: bool): 94 | z_e, state = self.apply.encode(params, state, None, x, is_training) 95 | result, state = self.apply.quantize(params, state, None, z_e) 96 | z_q = result["quantize"] 97 | x_pred, state = self.apply.decode(params, state, None, z_q, is_training) 98 | result["x_pred"] = x_pred 99 | return result, state 100 | 101 | def loss( 102 | self, params: hk.Params, state: hk.State, batch: VqVaeBatch, is_training: bool 103 | ): 104 | x = batch["image"] 105 | result, state = self.forward(params, state, x, is_training) 106 | reconstruct_loss = mse(x, result["x_pred"]) 107 | loss = reconstruct_loss + result["codebook_loss"] 108 | return loss, (state, result) 109 | 110 | @functools.partial(jax.jit, static_argnums=0) 111 | def update( 112 | self, vqvae_state: VqVaeState, batch: VqVaeBatch 113 | ) -> tuple[VqVaeState, dict[str, Any]]: 114 | assert self.optimizer is not None 115 | 116 | loss_and_grad = jax.value_and_grad(self.loss, has_aux=True) 117 | (loss, (state, _)), grads = loss_and_grad( 118 | vqvae_state.params, vqvae_state.state, batch, True 119 | ) 120 | updates, opt_state = self.optimizer.update( 121 | grads, vqvae_state.opt_state, vqvae_state.params 122 | ) 123 | params = optax.apply_updates(vqvae_state.params, updates) 124 | new_vqvae_state = VqVaeState(params, state, opt_state) 125 | logs = {"scalar_loss": jax.device_get(loss)} 126 | return new_vqvae_state, logs 127 | 128 | @functools.partial(jax.jit, static_argnums=0) 129 | def evaluate(self, vqvae_state: VqVaeState, batch: VqVaeBatch) -> dict[str, Any]: 130 | loss, (_, result) = self.loss( 131 | vqvae_state.params, vqvae_state.state, batch, is_training=False 132 | ) 133 | logs = { 134 | "scalar_loss": jax.device_get(loss), 135 | "images_original": batch["image"], 136 | "images_reconstruction": result["x_pred"], 137 | } 138 | return logs 139 | 140 | def lookup_indices(self, vqvae_state: VqVaeState, indices) -> jnp.ndarray: 141 | z_q, _ = self.apply.embed(vqvae_state.params, vqvae_state.state, None, indices) 142 | return z_q 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # VQ-VAE + GPT on JAX (and Haiku :scroll:) 4 | 5 |
6 | 7 | This is an implementation of VQ-VAE with a GPT-style sampler 8 | in the [JAX](https://github.com/google/jax) and 9 | [Haiku](https://github.com/deepmind/dm-haiku) ecosystem. 10 | 11 | Instead of using a PixelCNN to sample from the latent space like the 12 | [original paper](https://arxiv.org/pdf/1711.00937.pdf), this 13 | implementation uses a GPT-style, decoder-only transformer to generate samples. 14 | 15 | ## :star2: Generated samples 16 | 17 |
18 | 19 | ![generated_0](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/60f3889d-d2c4-4c40-ab79-edf03c8983c1) 20 | ![generated_1](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/264ffa1d-361d-4364-9829-24da6cae0f34) 21 | ![generated_2](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/89fdd52b-06f6-47c8-bbed-576ee0e7abb6) 22 | ![generated_3](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/a71b4e7e-ef63-4595-9ab4-5c93e5fb3082) 23 | ![generated_4](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/6f272c9f-ff0a-4ccb-9518-6d33af578f71) 24 | ![generated_5](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/005a779c-2ef7-4c53-b362-9fe049f4754e) 25 | ![generated_6](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/c52fa039-f569-483d-9cdd-e10b7eccdfe7) 26 | ![generated_7](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/370adcd2-39d1-455b-a741-f2b84ebac4e7) 27 | ![generated_8](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/170f62c6-491e-42c4-9e27-efa67823ac6d) 28 | ![generated_9](https://github.com/andylolu2/jax-vqvae-gpt/assets/66584117/046d7773-03a1-4694-ac37-469b3ef0e766) 29 | 30 |
31 |
32 | > Generated with
33 | > ```terminal
34 | > python -m generate -p runs/gpt/exp0 -o generated/exp0 -t 0.5 -S 5
35 | > ```
36 | ## :nut_and_bolt: Run it yourself!
37 |
38 | ### Step 0: (Optional, recommended) Create virtual environment
39 |
40 |
41 | ### Step 1: Install requirements
42 |
43 | **You will need to install JAX separately.** This is because
44 | the installation procedure differs depending on which
45 | platform / accelerator / CUDA version you are on.
46 |
47 | Please follow [these instructions](https://github.com/google/jax#installation)
48 | to install JAX accordingly.
49 |
50 | Then, install the project's dependencies:
51 |
52 | ```terminal
53 | pip install -r requirements.txt
54 | ```
55 |
56 | ### Step 2: Create & modify `train_vqvae_config.json`
57 |
58 | Create a copy of the included sample JSON file.
59 |
60 | ```terminal
61 | cp static/configs/train_vqvae_config.json train_vqvae_config.json
62 | ```
63 |
64 | Optionally, you can change the training parameters.
65 |
66 | > Since this file is likely to be changed while experimenting,
67 | > `/train_vqvae_config.json` is included in the [`.gitignore`](.gitignore).
68 |
69 | ### Step 3: Train VQ-VAE!
70 |
71 | ```terminal
72 | python -m train_vqvae -f train_vqvae_config.json
73 | ```
74 |
75 | In another terminal, open `tensorboard` to monitor
76 | the training progress.
77 |
78 | ```terminal
79 | tensorboard --logdir runs/vqvae
80 | ```
81 | > The value you pass to `--logdir` should match `logdir`
82 | > in `train_vqvae_config.json`. By default, it is `runs/vqvae`.
83 |
84 | ### Step 4: Encode the dataset to prepare for training the GPT
85 |
86 | ```terminal
87 | python -m vqvae_encode -p runs/vqvae/exp0/ -o datasets/exp0-encoded
88 | ```
89 | > See `python -m vqvae_encode -h` for usage details.
90 |
91 | This goes through the MNIST dataset and adds a column for
92 | the indices into the quantized codebook of each image.
93 |
94 | ### Step 5: Create & modify `train_gpt_config.json`
95 |
96 | Create a copy of the included sample JSON file.
97 |
98 | ```terminal
99 | cp static/configs/train_gpt_config.json train_gpt_config.json
100 | ```
101 |
102 | Optionally, you can change the training parameters.
103 |
104 | > Since this file is likely to be changed while experimenting,
105 | > `/train_gpt_config.json` is included in the [`.gitignore`](.gitignore).
106 |
107 | ### Step 6: Train the GPT!
108 |
109 | ```terminal
110 | python -m train_gpt -f train_gpt_config.json
111 | ```
112 |
113 | The training script prepends the class label to the sequence of
114 | encoding indices, which allows for conditional generation
115 | afterwards.
116 |
117 | In another terminal, open `tensorboard` to monitor
118 | the training progress.
119 |
120 | ```terminal
121 | tensorboard --logdir runs/gpt
122 | ```
123 | > The value you pass to `--logdir` should match `logdir`
124 | > in `train_gpt_config.json`. By default, it is `runs/gpt`.
125 |
126 | ### Step 7: Generate samples!
127 |
128 | ```terminal
129 | python -m generate -p runs/gpt/exp0 -o generated/exp0 -t 0.4
130 | ```
131 | > See `python -m generate -h` for usage details.
132 |
133 | Voilà! View your generated samples in `generated/exp0`!
134 |
135 | ## References
136 |
137 | 1. Aaron van den Oord, Oriol Vinyals, Koray Kavukcuoglu.
138 | [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937).
139 | 2017. arXiv:1711.00937.
140 | 2. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin.
141 | [Attention Is All You Need](https://arxiv.org/abs/1706.03762).
142 | 2017. arXiv:1706.03762.
143 | 3. Alec Radford and Karthik Narasimhan.
144 | [Improving Language Understanding by Generative Pre-Training.](https://cdn.openai.com/research-covers/language-unsupervised/language_understanding_paper.pdf)
145 | 2018.
146 |
147 | ## Code references
148 |
149 | 1. [DeepMind Haiku examples](https://github.com/deepmind/dm-haiku/tree/main/examples)
150 |
--------------------------------------------------------------------------------
/models/vqvae.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import haiku as hk
4 | import jax
5 | import jax.nn as nn
6 | import jax.numpy as jnp
7 |
8 |
9 | class ResBlock(hk.Module):
10 |     def __init__(self, dim: int, kernel_size: int, name: Optional[str] = None):
11 |         super().__init__(name)
12 |         self.dim = dim
13 |         self.kernel_size = kernel_size
14 |
15 |     def __call__(self, x, is_training: bool) -> jnp.ndarray:
16 |         res = hk.Conv2D(self.dim, self.kernel_size)(x)
17 |         res = hk.BatchNorm(True, True, 0.9)(res, is_training)
18 |         res = nn.relu(res)
19 |         res = hk.Conv2D(self.dim, self.kernel_size)(res)
20 |         res = hk.BatchNorm(True, True, 0.9)(res, is_training)
21 |         x += res
22 |         x = nn.relu(x)
23 |         return x
24 |
25 |
26 | class CnnEncoder(hk.Module):
27 |     def __init__(
28 |         self,
29 |         out_channels: int,
30 |         downscale_level: int,
31 |         res_layers: int = 1,
32 |         kernel_size: int = 5,
33 |         name: Optional[str] = None,
34 |     ):
35 |         super().__init__(name)
36 |         self.out_channels = out_channels
37 |         self.downscale_level = downscale_level
38 |         self.res_layers = res_layers
39 |         self.kernel_size = kernel_size
40 |
41 |     def __call__(self, x, is_training: bool) -> jnp.ndarray:
42 |         for i in range(self.downscale_level - 1, -1, -1):
43 |             num_channels = self.out_channels // (2**i)
44 |             x = hk.Conv2D(num_channels, self.kernel_size, stride=2)(x)
45 |             x = hk.BatchNorm(True, True, 0.9)(x, is_training)
46 |             x = nn.relu(x)
47 |             for _ in range(self.res_layers):
48 |                 x = ResBlock(num_channels, self.kernel_size)(x, is_training)
49 |         return x
50 |
51 |
52 | class CnnDecoder(hk.Module):
53 |     def __init__(
54 |         self,
55 |         in_channels: int,
56 |         upscale_level: int,
57 |         res_layers: int = 1,
58 |         kernel_size: int = 5,
59 |         name: Optional[str] = None,
60 |     ):
61 |         super().__init__(name)
62 |         self.in_channels = in_channels
63 |         self.upscale_level = upscale_level
64 |         self.res_layers = res_layers
65 |         self.kernel_size = kernel_size
66 |
67 |     def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
68 |         for i in range(self.upscale_level - 1):
69 |             num_channels = self.in_channels // (2**i)
70 |             x = hk.Conv2DTranspose(num_channels, self.kernel_size, stride=2)(x)
71 |             x = hk.BatchNorm(True, True, 0.9)(x, is_training)
72 |             x = nn.relu(x)
73 |             for _ in range(self.res_layers):
74 |                 x = ResBlock(num_channels, self.kernel_size)(x, is_training)
75 |         x = hk.Conv2DTranspose(1, self.kernel_size, stride=2)(x)
76 |         x = nn.sigmoid(x)
77 |         return x
78 |
79 |
80 | class QuantizedCodebook(hk.Module):
81 |     def __init__(
82 |         self,
83 |         embed_size_K: int,
84 |         embed_dim_D: int,
85 |         commitment_loss: float,
86 |         name: Optional[str] = None,
87 |     ):
88 |         super().__init__(name)
89 |         self.K = embed_size_K
90 |         self.D = embed_dim_D
91 |         self.beta = commitment_loss
92 |
93 |         initializer = hk.initializers.VarianceScaling(distribution="uniform")
94 |         self.codebook = hk.get_parameter("codebook", (self.K, self.D), init=initializer)
95 |
96
| def __call__(self, inputs) -> dict[str, jnp.ndarray]: 97 | """Connects the module to some inputs. 98 | 99 | Args: 100 | inputs: Tensor, final dimension must be equal to ``embedding_dim``. All 101 | other leading dimensions will be flattened and treated as a large batch. 102 | is_training: boolean, whether this connection is to training data. 103 | 104 | Returns: 105 | dict: Dictionary containing the following keys and values: 106 | * ``quantize``: Tensor containing the quantized version of the input. 107 | * ``loss``: Tensor containing the loss to optimize. 108 | * ``encoding_indices``: Tensor containing the discrete encoding indices, 109 | ie which element of the quantized space each input element was mapped 110 | to. 111 | """ 112 | # input shape A1 x ... x An x D 113 | # shape N x D, N = A1 * ... * An 114 | flattened = jnp.reshape(inputs, (-1, self.D)) 115 | 116 | # shape N x 1 117 | flattened_sqr = jnp.sum(flattened**2, axis=-1, keepdims=True) 118 | 119 | # shape 1 x K 120 | codeboook_sqr = jnp.sum(self.codebook**2, axis=-1, keepdims=True).T 121 | 122 | # shape N x K 123 | # distances = (a-b)^2 = a^2 - 2*a*b + b^2 124 | distances = flattened_sqr - 2 * (flattened @ self.codebook.T) + codeboook_sqr 125 | 126 | # shape A1 x ... x An 127 | encoding_indices = jnp.reshape( 128 | jnp.argmin(distances, axis=-1), inputs.shape[:-1] 129 | ) 130 | 131 | # shape A1 x ... x An x D 132 | quantize = self.codebook[encoding_indices] 133 | 134 | # loss = ||sg[z_e(x)] - e|| + beta||z_e(x) - sg[e]|| 135 | encoding_loss = jnp.mean((jax.lax.stop_gradient(inputs) - quantize) ** 2) 136 | commit_loss = jnp.mean((inputs - jax.lax.stop_gradient(quantize)) ** 2) 137 | loss = encoding_loss + self.beta * commit_loss 138 | 139 | # straight-through estimator for reconstruction loss 140 | quantize = inputs + jax.lax.stop_gradient(quantize - inputs) 141 | 142 | return { 143 | "codebook_loss": loss, 144 | "quantize": quantize, 145 | "encoding_indices": encoding_indices, 146 | } 147 | 148 | def embed(self, indices): 149 | outshape = indices.shape + (self.D,) 150 | x = self.codebook[indices].reshape(outshape) 151 | return x 152 | -------------------------------------------------------------------------------- /trainers/gpt_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import functools 3 | 4 | import haiku as hk 5 | import jax.numpy as jnp 6 | import jax.nn as nn 7 | import jax 8 | import optax 9 | from jax._src.random import KeyArray 10 | from optax._src.base import GradientTransformation 11 | 12 | from models import GPTLmHeadModel 13 | from utils.annotations import GPTBatch, GPTState, VqVaeConfig 14 | from utils.losses import cross_entropy 15 | 16 | 17 | class VqVaeGPTTrainer: 18 | def __init__( 19 | self, 20 | num_label_classes: int, 21 | vqvae_config: VqVaeConfig, 22 | num_heads: int, 23 | hidden_dim: int, 24 | num_layers: int, 25 | dropout_rate: float, 26 | sample: GPTBatch, 27 | optimizer: Optional[GradientTransformation], 28 | ): 29 | self.vqvae_config = vqvae_config 30 | self.num_label_classes = num_label_classes 31 | self.num_classes = num_label_classes + vqvae_config.K 32 | self.decoder_input_shape = sample["encoding_indices"].shape[1:] 33 | self.seq_length = self.tokenize(sample).shape[-1] 34 | 35 | transformed = self.build( 36 | num_heads, 37 | hidden_dim, 38 | num_layers, 39 | self.num_classes, 40 | dropout_rate, 41 | self.seq_length, 42 | ) 43 | self.init = transformed.init 44 | self.apply = transformed.apply 45 | 46 | 
self.optimizer = optimizer 47 | 48 | @staticmethod 49 | def build( 50 | num_heads: int, 51 | hidden_dim: int, 52 | num_layers: int, 53 | num_classes: int, 54 | dropout_rate: float, 55 | seq_length: int, 56 | ): 57 | def init(tokens, is_training: bool): 58 | net = GPTLmHeadModel( 59 | num_heads, hidden_dim, num_layers, num_classes, dropout_rate, seq_length 60 | ) 61 | return net(tokens, is_training) 62 | 63 | return hk.transform_with_state(init) 64 | 65 | def initial_state(self, rng, batch: GPTBatch) -> GPTState: 66 | tokens = self.tokenize(batch) 67 | 68 | rng, rng1 = jax.random.split(rng) 69 | params, state = self.init(rng, tokens, is_training=True) 70 | opt_state = self.optimizer.init(params) 71 | 72 | return GPTState(params, state, opt_state, rng1) 73 | 74 | def tokenize(self, batch: GPTBatch): 75 | # labels shape B x 1 76 | labels = batch["label"][..., None] 77 | vqvae_tokens = batch["encoding_indices"] 78 | # tokens shape B x (W * H) 79 | vqvae_tokens = vqvae_tokens.reshape((vqvae_tokens.shape[0], -1)) 80 | # offset encoding indices 81 | vqvae_tokens += self.num_label_classes 82 | 83 | # add labels as additional tokens 84 | # tokens shape B x (1 + W * H) 85 | tokens = jnp.concatenate((labels, vqvae_tokens), axis=-1) 86 | return tokens 87 | 88 | def forward( 89 | self, params: hk.Params, state: hk.State, rng, tokens, is_training: bool 90 | ) -> tuple[jnp.ndarray, hk.State]: 91 | y_pred, state = self.apply(params, state, rng, tokens, is_training) 92 | return y_pred, state 93 | 94 | def loss(self, params: hk.Params, state: hk.State, rng, tokens, is_training: bool): 95 | y = nn.one_hot(tokens, self.num_classes) 96 | y_pred, state = self.forward(params, state, rng, tokens, is_training) 97 | 98 | # use the first n-1 tokens to predict the nth token 99 | loss = cross_entropy(y[:, 1:], y_pred[:, :-1]) 100 | return loss, state 101 | 102 | @functools.partial(jax.jit, static_argnums=0) 103 | def update(self, gpt_state: GPTState, batch: GPTBatch) -> tuple[GPTState, dict]: 104 | assert self.optimizer is not None 105 | 106 | rng, rng1 = jax.random.split(gpt_state.rng) 107 | tokens = self.tokenize(batch) 108 | loss_and_grad = jax.value_and_grad(self.loss, has_aux=True) 109 | (loss, state), grads = loss_and_grad( 110 | gpt_state.params, gpt_state.state, rng, tokens, True 111 | ) 112 | updates, opt_state = self.optimizer.update( 113 | grads, gpt_state.opt_state, gpt_state.params 114 | ) 115 | params = optax.apply_updates(gpt_state.params, updates) 116 | 117 | new_gpt_state = GPTState(params, state, opt_state, rng1) 118 | logs = {"scalar_loss": jax.device_get(loss)} 119 | return new_gpt_state, logs 120 | 121 | @functools.partial(jax.jit, static_argnums=0) 122 | def evaluate(self, gpt_state: GPTState, batch: GPTBatch) -> tuple[GPTState, dict]: 123 | tokens = self.tokenize(batch) 124 | loss, state = self.loss(gpt_state.params, gpt_state.state, None, tokens, False) 125 | new_gpt_state = GPTState( 126 | gpt_state.params, state, gpt_state.opt_state, gpt_state.rng 127 | ) 128 | logs = { 129 | "scalar_loss": jax.device_get(loss), 130 | } 131 | return new_gpt_state, logs 132 | 133 | @functools.partial(jax.jit, static_argnums=0) 134 | def generate(self, gpt_state: GPTState, rng: KeyArray, label: int, temp: float = 1): 135 | output_len = self.seq_length - 1 136 | padded_tokens = [[label] + [0] * output_len] 137 | tokens = jnp.array(padded_tokens, dtype=jnp.int32) 138 | 139 | def body_fun(i, val: tuple[jnp.ndarray, GPTState, KeyArray]): 140 | # token shape 1 x (W * H + 1) 141 | tokens, gpt_state, rng = val 142 | 
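# Each iteration feeds the full padded sequence through the model. Because of the
# causal mask, the logits at position i depend only on tokens[:, : i + 1], so the
# zero padding beyond position i does not affect this step. The logits are then
# temperature-scaled, restricted to the VQ-VAE codebook range (label logits are
# dropped), softmaxed, and the sampled codebook token is written to position i + 1.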
143 | y_pred, _ = self.apply( 144 | gpt_state.params, gpt_state.state, None, tokens, False 145 | ) 146 | probs = (y_pred[0, i, :] / temp)[self.num_label_classes :] 147 | probs = nn.softmax(probs) 148 | 149 | vqvae_tokens = jnp.arange( 150 | self.num_label_classes, self.num_classes, dtype=jnp.int32 151 | ) 152 | rng, rng1 = jax.random.split(rng) 153 | next_token = jax.random.choice(rng1, vqvae_tokens, p=probs) 154 | tokens = tokens.at[0, i + 1].set(next_token) 155 | 156 | return (tokens, gpt_state, rng) 157 | 158 | # token shape 1 x (H * W + 1) 159 | tokens, _, rng = jax.lax.fori_loop( 160 | 0, output_len, body_fun, (tokens, gpt_state, rng) 161 | ) 162 | # shape 1 x H * W 163 | tokens = tokens[0, 1:] 164 | # shape 1 x H x W 165 | tokens = jnp.reshape(tokens, self.decoder_input_shape)[None, ...] 166 | tokens = tokens - self.num_label_classes 167 | return tokens, rng 168 | --------------------------------------------------------------------------------
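A closing note on `generate` above: the `temp` argument simply rescales the logits before the softmax, so lower temperatures concentrate probability mass on the most likely codebook tokens. The small standalone snippet below illustrates that pattern with made-up logits; it is not part of the repository.

```python
# Temperature-scaled categorical sampling, mirroring the pattern used in generate().
import jax
import jax.nn as nn
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
logits = jnp.array([2.0, 0.5, -1.0])  # hypothetical next-token logits
temp = 0.3                            # lower temperature -> sharper distribution

probs = nn.softmax(logits / temp)
rng, sub = jax.random.split(rng)
next_token = jax.random.choice(sub, jnp.arange(logits.shape[0]), p=probs)
print(int(next_token), probs)
```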