├── models
│   ├── __init__.py
│   ├── gpt.py
│   └── vqvae.py
├── requirements.txt
├── .gitignore
├── utils
│   ├── losses.py
│   ├── logger.py
│   ├── annotations.py
│   └── dataset.py
├── static
│   └── configs
│       ├── train_vqvae_config.json
│       └── train_gpt_config.json
├── LICENSE
├── train_vqvae.py
├── vqvae_encode.py
├── generate.py
├── train_gpt.py
├── trainers
│   ├── vqvae_trainer.py
│   └── gpt_trainer.py
└── README.md

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .gpt import GPTLmHeadModel
from .vqvae import CnnDecoder, CnnEncoder, QuantizedCodebook

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
dm-haiku==0.0.11
optax==0.1.7
numpy==1.22.3
tqdm==4.64.0
datasets==2.0.0
Pillow==9.3.0
black
# The code also imports jax, tensorboardX, and skimage; the original file did
# not list or pin these, so they are added here unpinned.
jax
jaxlib
tensorboardX
scikit-image

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__
.python-version

# train configs
/train_vqvae_config.json
/train_gpt_config.json

# training outputs
runs
datasets
generated

--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
import jax.numpy as jnp
import jax.nn as nn


def mse(y, y_pred):
    return jnp.mean((y - y_pred) ** 2)


def cross_entropy(y, y_pred):
    # y is expected as one-hot (or soft) targets; y_pred as unnormalized logits
    return jnp.mean(jnp.sum(-y * nn.log_softmax(y_pred), axis=-1))

--------------------------------------------------------------------------------
/static/configs/train_vqvae_config.json:
--------------------------------------------------------------------------------
{
    "seed": 23,
    "dataset": "mnist",
    "resize_shape": [40, 40],
    "K": 256,
    "D": 128,
    "compression_level": 3,
    "res_layers": 2,
    "commitment_loss": 0.25,
    "train_dset_percentage": 100,
    "test_dset_percentage": 100,
    "train_steps": 15000,
    "test_steps": 50,
    "test_every": 250,
    "train_batch_size": 64,
    "test_batch_size": 64,
    "learning_rate": 3e-4,
    "weight_decay": 1e-5,
    "logdir": "runs/vqvae",
    "output_name": "train_state.pkl"
}

--------------------------------------------------------------------------------
/static/configs/train_gpt_config.json:
--------------------------------------------------------------------------------
{
    "seed": 23,
    "num_heads": 4,
    "hidden_dim": 64,
    "num_layers": 4,
    "dropout_rate": 0.1,
    "vqvae_config": "runs/vqvae/exp0/config.json",
    "vqvae_state": "runs/vqvae/exp0/train_state.pkl",
    "train_steps": 15000,
    "test_steps": 20,
    "test_every": 250,
    "train_dataset": "datasets/exp0-encoded/train",
    "test_dataset": "datasets/exp0-encoded/test",
    "train_batch_size": 64,
    "test_batch_size": 64,
    "generate_samples": 16,
    "sample_temperature": 0.3,
    "learning_rate": 3e-4,
    "weight_decay": 1e-5,
    "logdir": "runs/gpt",
    "output_name": "train_state.pkl"
}
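The two JSON files above map field-for-field onto the VqVaeConfig and GPTConfig NamedTuples defined in utils/annotations.py below. A minimal sketch of how the train scripts consume them (the path is simply the file shipped in this repo):

    import json

    from utils.annotations import VqVaeConfig

    with open("static/configs/train_vqvae_config.json") as f:
        # keys must match the NamedTuple fields exactly
        config = VqVaeConfig(**json.load(f))
    print(config.K, config.D)  # 256 128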
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Andy Lo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
from typing import Any
from pathlib import Path
import re

from tensorboardX import SummaryWriter

prefix = "exp"


def get_writer(base_dir: str, disable: bool = False) -> SummaryWriter:
    logdir = Path(base_dir)
    logdir.mkdir(parents=True, exist_ok=True)
    max_run = -1
    for path in logdir.iterdir():
        match = re.fullmatch(rf"{prefix}([0-9]+)", path.name)
        # skip entries that are not exp<N> run directories instead of crashing
        if match is not None:
            max_run = max(max_run, int(match.group(1)))
    return SummaryWriter(
        logdir=str(logdir / f"{prefix}{max_run + 1}"),
        flush_secs=10,
        write_to_disk=not disable,
    )


def log_dict(writer: SummaryWriter, logs: dict[str, Any], step: int, prefix: str = ""):
    for k, v in logs.items():
        if k.startswith("scalar_"):
            k = k[len("scalar_") :]
            value = v
            if isinstance(value, list):
                # average accumulated scalars (e.g. losses over the test steps)
                value = sum(value) / len(value)
            writer.add_scalar(f"{prefix}{k}", value, step)
        elif k.startswith("images_"):
            k = k[len("images_") :]
            value = v
            if isinstance(value, list):
                value = value[0]
            writer.add_images(f"{prefix}{k}", value, step, dataformats="NHWC")
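A hedged usage sketch for utils/logger.py above: the scalar_ / images_ key prefixes decide whether log_dict routes a value to add_scalar or add_images (the base_dir and array shapes here are illustrative only):

    import numpy as np

    from utils.logger import get_writer, log_dict

    writer = get_writer("runs/demo")  # picks the next free exp<N> directory
    logs = {
        "scalar_loss": 0.42,                                # single scalar
        "scalar_epoch": [1, 1, 2],                          # lists are averaged
        "images_reconstruction": np.zeros((4, 40, 40, 1)),  # NHWC image batch
    }
    log_dict(writer, logs, step=0, prefix="train/")
    writer.close()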
--------------------------------------------------------------------------------
/utils/annotations.py:
--------------------------------------------------------------------------------
from typing import Any, NamedTuple, TypedDict

import numpy as np
import haiku as hk
import optax
from jax.random import KeyArray


class VqVaeConfig(NamedTuple):
    seed: int
    dataset: str
    resize_shape: tuple[int, int]
    K: int
    D: int
    compression_level: int
    res_layers: int
    commitment_loss: float
    train_dset_percentage: int
    test_dset_percentage: int
    train_steps: int
    test_steps: int
    test_every: int
    train_batch_size: int
    test_batch_size: int
    learning_rate: float
    weight_decay: float
    logdir: str
    output_name: str


class GPTConfig(NamedTuple):
    seed: int
    num_heads: int
    hidden_dim: int
    num_layers: int
    dropout_rate: float
    vqvae_config: str
    vqvae_state: str
    train_steps: int
    test_steps: int
    test_every: int
    train_dataset: str
    test_dataset: str
    train_batch_size: int
    test_batch_size: int
    generate_samples: int
    sample_temperature: float
    learning_rate: float
    weight_decay: float
    logdir: str
    output_name: str


class VqVaeTuple(NamedTuple):
    encoder: Any
    decoder: Any
    quantizer: Any


class VqVaeState(NamedTuple):
    params: hk.Params
    state: hk.State
    opt_state: optax.OptState


class GPTState(NamedTuple):
    params: hk.Params
    state: hk.State
    opt_state: optax.OptState
    rng: KeyArray


class VqVaeBatch(TypedDict):
    image: np.ndarray
    label: np.ndarray


class GPTBatch(TypedDict):
    label: np.ndarray
    encoding_indices: np.ndarray

--------------------------------------------------------------------------------
/models/gpt.py:
--------------------------------------------------------------------------------
from typing import Optional

import haiku as hk
import jax.nn as nn
import jax.numpy as jnp


class CausalSelfAttention(hk.MultiHeadAttention):
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # x is B x N x H
        seq_len = x.shape[1]
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))
        # mask is B x num_heads x N x N
        causal_mask = jnp.tile(causal_mask, (x.shape[0], self.num_heads, 1, 1))
        return super().__call__(x, x, x, causal_mask)


class DecoderBlock(hk.Module):
    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        model_size: int,
        weight_init_scale: float,
        dropout_rate: float,
        name: Optional[str] = None,
    ):
        super().__init__(name)
        self.causal_atten = CausalSelfAttention(
            num_heads, hidden_dim, weight_init_scale
        )
        self.dropout_rate = dropout_rate

    def __call__(self, x: jnp.ndarray, is_training: bool) -> jnp.ndarray:
        # Structured according to the original paper:
        # https://arxiv.org/pdf/1706.03762.pdf#section.3
        res = self.causal_atten(x)
        if is_training:
            res = hk.dropout(hk.next_rng_key(), self.dropout_rate, res)
        x += res
        x = hk.LayerNorm(-1, True, True)(x)

        dim = x.shape[-1]
        res = hk.Linear(dim * 4)(x)
        res = nn.gelu(res)
        res = hk.Linear(dim)(res)
        if is_training:
            res = hk.dropout(hk.next_rng_key(), self.dropout_rate, res)
        x += res
        x = hk.LayerNorm(-1, True, True)(x)
        return x


class GPTLmHeadModel(hk.Module):
    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        num_layers: int,
        num_classes: int,
        dropout_rate: float,
        max_length: int,
        name: Optional[str] = None,
    ):
        super().__init__(name)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.max_length = max_length
        self.model_size = self.num_heads * self.hidden_dim

        self.init_scale = 2.0 / num_layers
        self.embed = hk.Embed(num_classes, self.model_size)
        self.positional_embeddings = hk.get_parameter(
            "pos_embs",
            [self.max_length, self.model_size],
            init=hk.initializers.TruncatedNormal(stddev=0.02),
        )
        self.blocks = [
            DecoderBlock(
                self.num_heads,
                self.hidden_dim,
                self.model_size,
                self.init_scale,
                self.dropout_rate,
            )
            for _ in range(num_layers)
        ]
        self.lm_head = hk.Linear(num_classes)

    def __call__(self, x, is_training: bool):
        seq_length = x.shape[1]
        x = self.embed(x) + self.positional_embeddings[:seq_length]
        for block in self.blocks:
            x = block(x, is_training)
        x = self.lm_head(x)
        # softmax is taken outside
        return x
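A minimal forward pass through GPTLmHeadModel above, wrapped in hk.transform as Haiku requires. The hyperparameters mirror train_gpt_config.json; the sequence length of 25 is an assumption based on the 5 x 5 grid of code indices a 40 x 40 input yields at compression level 3, ignoring any label token the GPT trainer may prepend:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    from models import GPTLmHeadModel


    def forward(tokens, is_training: bool):
        model = GPTLmHeadModel(
            num_heads=4, hidden_dim=64, num_layers=4,
            num_classes=256, dropout_rate=0.1, max_length=25,
        )
        return model(tokens, is_training)


    model = hk.transform(forward)
    rng = jax.random.PRNGKey(0)
    tokens = jnp.zeros((2, 25), dtype=jnp.int32)  # batch of 2 token sequences
    params = model.init(rng, tokens, is_training=True)
    logits = model.apply(params, rng, tokens, is_training=True)  # (2, 25, 256)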
--------------------------------------------------------------------------------
/train_vqvae.py:
--------------------------------------------------------------------------------
from typing import Optional
import pickle
import json
import argparse
from pathlib import Path
from collections import defaultdict

from tqdm import tqdm
import jax
import optax

from trainers.vqvae_trainer import VqVaeTrainer
from utils.dataset import load_dset
from utils.logger import get_writer, log_dict
from utils.annotations import VqVaeConfig, VqVaeState


def parse_args():
    parser = argparse.ArgumentParser(description="Train a VQ-VAE on the MNIST dataset.")
    parser.add_argument(
        "-f", "--file", type=str, required=True, help="path to the json config file."
    )
    parser.add_argument(
        "-chkp", "--checkpoint", type=str, help="path to train state pkl file."
    )
    return parser.parse_args()


def main(config: VqVaeConfig, checkpoint: Optional[str] = None):
    writer = get_writer(config.logdir)
    exp_dir = writer.logdir

    # save config file
    with open(Path(exp_dir) / "config.json", "w") as f:
        json.dump(config._asdict(), f, indent=4)

    # load dataset
    _, dset_train = load_dset(
        name=config.dataset,
        split="train",
        batch_size=config.train_batch_size,
        percentage=config.train_dset_percentage,
        resize_shape=config.resize_shape,
        seed=config.seed,
    )
    _, dset_test = load_dset(
        name=config.dataset,
        split="test",
        batch_size=config.test_batch_size,
        percentage=config.test_dset_percentage,
        resize_shape=config.resize_shape,
        seed=config.seed,
    )

    # initialize model
    optimizer = optax.adamw(config.learning_rate, weight_decay=config.weight_decay)
    trainer = VqVaeTrainer(
        K=config.K,
        D=config.D,
        compression_level=config.compression_level,
        res_layers=config.res_layers,
        commitment_loss=config.commitment_loss,
        optimizer=optimizer,
    )
    if checkpoint is None:
        key = jax.random.PRNGKey(config.seed)
        train_state = trainer.initial_state(key, next(dset_train)[1])
    else:
        with open(checkpoint, "rb") as f:
            train_state: VqVaeState = pickle.load(f)

    # training loop
    for i in tqdm(range(config.train_steps)):
        # update
        epoch, batch = next(dset_train)
        train_state, logs = trainer.update(train_state, batch)

        # log
        logs["scalar_epoch"] = epoch
        log_dict(writer, logs, step=i, prefix="train/")

        # evaluate
        if (i + 1) % config.test_every == 0:
            logs = defaultdict(list)
            for _ in range(config.test_steps):
                _, batch = next(dset_test)
                log = trainer.evaluate(train_state, batch)
                for k, v in log.items():
                    logs[k].append(v)
            log_dict(writer, logs, step=i, prefix="test/")

    # save model
    with open(Path(exp_dir) / config.output_name, "wb") as f:
        pickle.dump(train_state, f)

    writer.close()


if __name__ == "__main__":
    args = parse_args()
    with open(args.file, "r") as f:
        config = VqVaeConfig(**json.load(f))

    main(config, checkpoint=args.checkpoint)
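Training the VQ-VAE is launched as python train_vqvae.py -f static/configs/train_vqvae_config.json; adding --checkpoint runs/vqvae/exp0/train_state.pkl resumes from a pickled train state instead of initializing fresh parameters. Each run saves its config.json and final train_state.pkl into a new exp<N> directory under the configured logdir.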
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
from typing import Iterator, Optional, Union
from itertools import count
from pathlib import Path

import numpy as np
import datasets
from datasets.arrow_dataset import Dataset
from datasets.features.features import Features
from skimage.transform import resize

from utils.annotations import VqVaeBatch, GPTBatch


def process_image(img, shape: tuple[int, int]) -> np.ndarray:
    img = np.array(img, dtype=np.float32) / 255
    img = resize(img, shape)
    img = img[..., None]
    return img


def load_dset(
    name: str,
    split: str,
    batch_size: int,
    percentage: int,
    resize_shape: tuple[int, int],
    repeat: bool = True,
    seed: Optional[int] = None,
) -> tuple[Features, Iterator[tuple[int, VqVaeBatch]]]:
    """
    Loads a dataset with preprocessing, batching, and repeating.

    Args:
        name (str): The name of the dataset on Hugging Face Hub.
        split (str): The split of the dataset, such as "train" / "test".
        batch_size (int): The batch size to load the data with.
        percentage (int): The percentage of the dataset to use.
        resize_shape (tuple[int, int]): Shape to resize the images to.
        repeat (bool, optional): Whether or not to repeat the dataset after
            iterating through one epoch. Defaults to True.
        seed (Optional[int], optional): The seed used to shuffle the dataset.
            Defaults to None.

    Returns:
        tuple[Features, Iterator[tuple[int, VqVaeBatch]]]: A tuple of dataset
        features and an iterator which yields preprocessed, batched, and
        repeated data.
    """
    dset = datasets.load_dataset(name, split=f"{split}[:{percentage}%]")
    assert isinstance(dset, Dataset)

    features: Features = dset.features

    def preprocess(batch) -> VqVaeBatch:
        return {
            "image": np.array(
                [process_image(img, resize_shape) for img in batch["image"]]
            ),
            "label": np.array(batch["label"]),
        }

    dset.set_transform(preprocess)

    def iterator(dset: Dataset) -> Iterator[tuple[int, VqVaeBatch]]:
        counter = count()
        while True:
            dset = dset.shuffle(seed)
            epoch = next(counter)
            # yield full batches only; the "+ 1" keeps the final batch when the
            # dataset length is an exact multiple of batch_size
            for i in range(0, len(dset) - batch_size + 1, batch_size):
                yield epoch, dset[i : i + batch_size]

            if not repeat:
                break

    return features, iterator(dset)


def load_vqvae_processed(
    path: Union[str, Path],
    batch_size: int,
    repeat: bool = True,
    seed: Optional[int] = None,
) -> tuple[Features, Iterator[tuple[int, GPTBatch]]]:
    dset = datasets.load_from_disk(str(path))
    assert isinstance(dset, Dataset)

    features: Features = dset.features

    def preprocess(batch) -> GPTBatch:
        return {
            "encoding_indices": np.array(batch["encoding_indices"]),
            "label": np.array(batch["label"]),
        }

    dset.set_transform(preprocess)

    def iterator(dset: Dataset) -> Iterator[tuple[int, GPTBatch]]:
        counter = count()
        while True:
            dset = dset.shuffle(seed)
            epoch = next(counter)
            for i in range(0, len(dset) - batch_size + 1, batch_size):
                yield epoch, dset[i : i + batch_size]

            if not repeat:
                break

    return features, iterator(dset)
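A small consumption sketch for the load_dset iterator above (arguments follow train_vqvae_config.json; the printed shapes assume MNIST resized to 40 x 40):

    from utils.dataset import load_dset

    features, batches = load_dset(
        name="mnist", split="train", batch_size=64,
        percentage=100, resize_shape=(40, 40), seed=23,
    )
    epoch, batch = next(batches)  # yields (epoch, batch) pairs, repeating by default
    print(batch["image"].shape)   # (64, 40, 40, 1), float32 scaled to [0, 1]
    print(batch["label"].shape)   # (64,)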
--------------------------------------------------------------------------------
/vqvae_encode.py:
--------------------------------------------------------------------------------
import pickle
import json
import argparse
from pathlib import Path

import datasets
import jax
import numpy as np

from trainers.vqvae_trainer import VqVaeTrainer
from utils.dataset import process_image
from utils.annotations import VqVaeConfig, VqVaeState


def parse_args():
    parser = argparse.ArgumentParser(
        description="Encode the MNIST dataset with a VQ-VAE.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-p",
        "--path",
        type=str,
        required=True,
        default=argparse.SUPPRESS,
        help="path to directory of the trained VQ-VAE model.",
    )
    parser.add_argument(
        "-o",
        "--out",
        type=str,
        required=True,
        default=argparse.SUPPRESS,
        help="path to directory to save the processed datasets.",
    )
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=64,
        help="batch size to process the dataset with.",
    )
    parser.add_argument(
        "-P",
        "--percentage",
        type=int,
        default=100,
        help="percentage of dataset to encode.",
    )
    return parser.parse_args()


def main(path: str, out_path: str, batch_size: int, percentage: int):
    model_dir = Path(path)

    with open(model_dir / "config.json", "r") as f:
        config = VqVaeConfig(**json.load(f))
    with open(model_dir / config.output_name, "rb") as f:
        vqvae_state: VqVaeState = pickle.load(f)

    trainer = VqVaeTrainer(
        K=config.K,
        D=config.D,
        compression_level=config.compression_level,
        res_layers=config.res_layers,
        commitment_loss=config.commitment_loss,
        optimizer=None,
    )

    @jax.jit
    def infer(vqvae_state: VqVaeState, x: np.ndarray):
        params, state = vqvae_state.params, vqvae_state.state
        z_e, _ = trainer.apply.encode(params, state, None, x, is_training=False)
        result, _ = trainer.apply.quantize(params, state, None, z_e)
        indices = result["encoding_indices"]

        # z1, z2 are only needed for the consistency assertion below
        z1 = trainer.lookup_indices(vqvae_state, indices)
        z2 = result["quantize"]

        return result, (z1, z2)

    def encode(batch):
        images = np.array(
            [process_image(img, shape=config.resize_shape) for img in batch["image"]]
        )
        result, (z1, z2) = infer(vqvae_state, images)
        batch["encoding_indices"] = np.array(result["encoding_indices"])

        assert np.allclose(z1, z2, atol=1e-6, rtol=0)
        assert batch["encoding_indices"].ndim == 3
        assert batch["encoding_indices"].dtype == np.int32
        assert np.max(batch["encoding_indices"]) < config.K
        assert np.min(batch["encoding_indices"]) >= 0

        return batch

    out_dir = Path(out_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    for split in ("train", "test"):
        dset = datasets.load_dataset("mnist", split=f"{split}[:{percentage}%]")
        dset = dset.map(encode, batched=True, batch_size=batch_size)
        dset.save_to_disk(str(out_dir / split))


if __name__ == "__main__":
    args = parse_args()
    main(
        path=args.path,
        out_path=args.out,
        batch_size=args.batch_size,
        percentage=args.percentage,
    )
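Running python vqvae_encode.py -p runs/vqvae/exp0 -o datasets/exp0-encoded encodes both MNIST splits with the trained VQ-VAE and writes them via save_to_disk; the resulting datasets/exp0-encoded/train and datasets/exp0-encoded/test directories are exactly the paths train_gpt_config.json expects.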
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
import argparse
import json
import pickle
from pathlib import Path

from PIL import Image
import jax
import numpy as np

from utils.annotations import GPTBatch, GPTConfig, GPTState, VqVaeConfig, VqVaeState
from trainers.gpt_trainer import VqVaeGPTTrainer
from trainers.vqvae_trainer import VqVaeTrainer


def parse_args():
    parser = argparse.ArgumentParser(
        description="Generate MNIST samples by sampling VQ-VAE codes with a GPT-style transformer.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-p",
        "--path",
        type=str,
        required=True,
        default=argparse.SUPPRESS,
        help="path to the pretrained GPT directory.",
    )
    parser.add_argument(
        "-o",
        "--out",
        type=str,
        required=True,
        default=argparse.SUPPRESS,
        help="output directory to save generated images.",
    )
    parser.add_argument(
        "-s", "--seed", type=int, default=0, help="seed to sample results from."
    )
    parser.add_argument(
        "-t",
        "--temperature",
        type=float,
        default=0.2,
        help="temperature to sample results.",
    )
    parser.add_argument(
        "-S", "--samples", type=int, default=8, help="generates S x S samples."
    )
    return parser.parse_args()


def main(path: str, seed: int, temp: float, samples: int, out_path: str):
    # load configs
    model_dir = Path(path)
    with open(model_dir / "config.json", "r") as f:
        gpt_config = GPTConfig(**json.load(f))
    with open(gpt_config.vqvae_config, "r") as f:
        vqvae_config = VqVaeConfig(**json.load(f))
    h, w = vqvae_config.resize_shape

    # load the VQ-VAE model for decoding sampled indices into images
    with open(gpt_config.vqvae_state, "rb") as f:
        vqvae_state: VqVaeState = pickle.load(f)
    vqvae = VqVaeTrainer(
        K=vqvae_config.K,
        D=vqvae_config.D,
        compression_level=vqvae_config.compression_level,
        res_layers=vqvae_config.res_layers,
        commitment_loss=vqvae_config.commitment_loss,
        optimizer=None,
    )

    @jax.jit
    def decode_indices(vqvae_state: VqVaeState, indices):
        z_q = vqvae.lookup_indices(vqvae_state, indices)
        img, _ = vqvae.apply.decode(
            vqvae_state.params, vqvae_state.state, None, z_q, is_training=False
        )
        return img

    # load GPT and initialize input/output shapes
    with open(model_dir / gpt_config.output_name, "rb") as f:
        gpt_state: GPTState = pickle.load(f)
    x = np.zeros((1, h, w, 1), dtype=np.float32)
    res, _ = vqvae.forward(vqvae_state.params, vqvae_state.state, x, False)
    sample: GPTBatch = {
        "encoding_indices": res["encoding_indices"],
        "label": np.zeros((1,)),
    }
    gpt = VqVaeGPTTrainer(
        num_label_classes=10,
        vqvae_config=vqvae_config,
        num_heads=gpt_config.num_heads,
        hidden_dim=gpt_config.hidden_dim,
        num_layers=gpt_config.num_layers,
        dropout_rate=gpt_config.dropout_rate,
        sample=sample,
        optimizer=None,
    )

    # generate
    rng = jax.random.PRNGKey(seed)
    out_dir = Path(out_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    for label in range(10):
        image = np.zeros((samples * h, samples * w), dtype=np.uint8)
        for i in range(samples):
            for j in range(samples):
                indices, rng = gpt.generate(gpt_state, rng, label, temp=temp)
                img = decode_indices(vqvae_state, indices)
                img = (img[0, :, :, 0] * 255).astype(np.uint8)
                # paste the decoded sample into the (i, j) cell of the grid
                top, left = i * h, j * w
                image[top : top + h, left : left + w] = img
        im = Image.fromarray(image)
        im.save(str(out_dir / f"generated_{label}.png"))


if __name__ == "__main__":
    args = parse_args()
    main(
        path=args.path,
        out_path=args.out,
        seed=args.seed,
        temp=args.temperature,
        samples=args.samples,
    )
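With a trained prior, python generate.py -p runs/gpt/exp0 -o generated -t 0.2 -S 8 writes one generated_<label>.png per digit class, each an 8 x 8 grid of samples; the generated/ output directory is already excluded by .gitignore.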
--------------------------------------------------------------------------------
/train_gpt.py:
--------------------------------------------------------------------------------
from typing import Optional
from pathlib import Path
import pickle
import json
import argparse
from collections import defaultdict

import jax
import optax
from tqdm import tqdm
import numpy as np

from trainers.gpt_trainer import VqVaeGPTTrainer
from trainers.vqvae_trainer import VqVaeTrainer
from utils.dataset import load_vqvae_processed
from utils.annotations import GPTConfig, GPTState, VqVaeConfig, VqVaeState
from utils.logger import get_writer, log_dict


def parse_args():
    parser = argparse.ArgumentParser(
        description="Train a GPT-style transformer on the VQ-VAE tokens of the MNIST dataset."
    )
    parser.add_argument(
        "-f", "--file", type=str, required=True, help="path to the json config file."
    )
    parser.add_argument(
        "-chkp", "--checkpoint", type=str, help="path to train state pkl file."
    )
    return parser.parse_args()


def main(config: GPTConfig, checkpoint: Optional[str] = None):
    writer = get_writer(config.logdir)
    exp_dir = writer.logdir

    # save config file
    with open(Path(exp_dir) / "config.json", "w") as f:
        json.dump(config._asdict(), f, indent=4)

    # load dataset
    features, dset_train = load_vqvae_processed(
        path=config.train_dataset,
        batch_size=config.train_batch_size,
        repeat=True,
        seed=config.seed,
    )
    _, dset_test = load_vqvae_processed(
        path=config.test_dataset,
        batch_size=config.test_batch_size,
        repeat=True,
        seed=config.seed,
    )
    label_classes = features["label"].num_classes

    # load vqvae for evaluation
    with open(config.vqvae_config, "r") as f:
        vqvae_config = VqVaeConfig(**json.load(f))
    with open(config.vqvae_state, "rb") as f:
        vqvae_state: VqVaeState = pickle.load(f)
    vqvae = VqVaeTrainer(
        K=vqvae_config.K,
        D=vqvae_config.D,
        compression_level=vqvae_config.compression_level,
        res_layers=vqvae_config.res_layers,
        commitment_loss=vqvae_config.commitment_loss,
        optimizer=None,
    )

    @jax.jit
    def decode_indices(vqvae_state: VqVaeState, indices):
        z_q = vqvae.lookup_indices(vqvae_state, indices)
        img, _ = vqvae.apply.decode(
            vqvae_state.params, vqvae_state.state, None, z_q, is_training=False
        )
        return img

    # initialize model
    _, sample = next(dset_train)
    optimizer = optax.adamw(config.learning_rate, weight_decay=config.weight_decay)
    trainer = VqVaeGPTTrainer(
        label_classes,
        vqvae_config,
        config.num_heads,
        config.hidden_dim,
        config.num_layers,
        config.dropout_rate,
        sample,
        optimizer,
    )
    key = jax.random.PRNGKey(config.seed)
    if checkpoint is None:
        key, key1 = jax.random.split(key)
        train_state = trainer.initial_state(key1, sample)
    else:
        with open(checkpoint, "rb") as f:
            train_state: GPTState = pickle.load(f)

    # training loop
    for i in tqdm(range(config.train_steps)):
        # update
        epoch, batch = next(dset_train)
        train_state, logs = trainer.update(train_state, batch)

        # log
        logs["scalar_epoch"] = epoch
        log_dict(writer, logs, step=i, prefix="train/")

        # evaluate
        if (i + 1) % config.test_every == 0:
            # test loss
            logs = defaultdict(list)
            for _ in range(config.test_steps):
                _, batch = next(dset_test)
                train_state, log = trainer.evaluate(train_state, batch)
                for k, v in log.items():
                    logs[k].append(v)
            log_dict(writer, logs, step=i, prefix="test/")

            # generate samples
            for label in range(label_classes):
                images = []
                for _ in range(config.generate_samples):
                    indices, key = trainer.generate(
                        train_state, key, label, temp=config.sample_temperature
                    )
                    img = decode_indices(vqvae_state, indices)
                    images.append(img[0])
                images = np.array(images)
                logs = {f"images_generate_{label}": images}
                log_dict(writer, logs, step=i, prefix="test/")

    with open(Path(exp_dir) / config.output_name, "wb") as f:
        pickle.dump(train_state, f)

    writer.close()


if __name__ == "__main__":
    args = parse_args()
    with open(args.file, "r") as f:
        config = GPTConfig(**json.load(f))

    main(config, checkpoint=args.checkpoint)
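python train_gpt.py -f static/configs/train_gpt_config.json trains the prior over VQ-VAE tokens; it requires the encoded datasets and the VQ-VAE run referenced in its config to exist already, so the intended pipeline order is train_vqvae.py, then vqvae_encode.py, then train_gpt.py, and finally generate.py.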
--------------------------------------------------------------------------------
/trainers/vqvae_trainer.py:
--------------------------------------------------------------------------------
from typing import Any, Callable, NamedTuple, Optional
import functools

import haiku as hk
import jax.numpy as jnp
import jax
import optax
from jax.random import KeyArray
from optax import GradientTransformation

from models import CnnEncoder, CnnDecoder, QuantizedCodebook
from utils.annotations import VqVaeBatch, VqVaeState
from utils.losses import mse


class VqVaeApply(NamedTuple):
    encode: Callable[..., Any]
    decode: Callable[..., Any]
    quantize: Callable[..., Any]
    embed: Callable[..., Any]


class VqVaeTrainer:
    def __init__(
        self,
        K: int,
        D: int,
        compression_level: int,
        res_layers: int,
        commitment_loss: float,
        optimizer: Optional[GradientTransformation],
    ):
        self.K = K
        self.D = D
        self.compression_level = compression_level
        self.res_layers = res_layers

        transformed = self.build(
            self.K, self.D, self.compression_level, self.res_layers, commitment_loss
        )
        self.init = transformed.init
        self.apply = VqVaeApply(*transformed.apply)

        self.optimizer = optimizer

    @staticmethod
    def build(
        K: int, D: int, compression_level: int, res_layers: int, commitment_loss: float
    ):
        def f():
            encoder = CnnEncoder(
                out_channels=D,
                downscale_level=compression_level,
                res_layers=res_layers,
                name="encoder",
            )
            decoder = CnnDecoder(
                in_channels=D,
                upscale_level=compression_level,
                res_layers=res_layers,
                name="decoder",
            )
            quantizer = QuantizedCodebook(K, D, commitment_loss, name="quantizer")

            def encode(x, is_training: bool):
                return encoder(x, is_training)

            def decode(x, is_training: bool):
                return decoder(x, is_training)

            def quantize(x):
                return quantizer(x)

            def embed(x):
                return quantizer.embed(x)

            def init(x, is_training: bool):
                encodings = encode(x, is_training)
                result = quantize(encodings)
                x_pred = decode(result["quantize"], is_training)
                z_q = embed(result["encoding_indices"])
                return x_pred, z_q

            return init, (encode, decode, quantize, embed)

        return hk.multi_transform_with_state(f)

    def initial_state(self, rng: KeyArray, batch: VqVaeBatch) -> VqVaeState:
        assert self.optimizer is not None
        params, state = self.init(rng, batch["image"], is_training=True)
        opt_state = self.optimizer.init(params)
        return VqVaeState(params, state, opt_state)

    def forward(self, params: hk.Params, state: hk.State, x, is_training: bool):
        z_e, state = self.apply.encode(params, state, None, x, is_training)
        result, state = self.apply.quantize(params, state, None, z_e)
        z_q = result["quantize"]
        x_pred, state = self.apply.decode(params, state, None, z_q, is_training)
        result["x_pred"] = x_pred
        return result, state

    def loss(
        self, params: hk.Params, state: hk.State, batch: VqVaeBatch, is_training: bool
    ):
        x = batch["image"]
        result, state = self.forward(params, state, x, is_training)
        reconstruct_loss = mse(x, result["x_pred"])
        loss = reconstruct_loss + result["codebook_loss"]
        return loss, (state, result)

    @functools.partial(jax.jit, static_argnums=0)
    def update(
        self, vqvae_state: VqVaeState, batch: VqVaeBatch
    ) -> tuple[VqVaeState, dict[str, Any]]:
        assert self.optimizer is not None

        loss_and_grad = jax.value_and_grad(self.loss, has_aux=True)
        (loss, (state, _)), grads = loss_and_grad(
            vqvae_state.params, vqvae_state.state, batch, True
        )
        updates, opt_state = self.optimizer.update(
            grads, vqvae_state.opt_state, vqvae_state.params
        )
        params = optax.apply_updates(vqvae_state.params, updates)
        new_vqvae_state = VqVaeState(params, state, opt_state)
        logs = {"scalar_loss": jax.device_get(loss)}
        return new_vqvae_state, logs

    @functools.partial(jax.jit, static_argnums=0)
    def evaluate(self, vqvae_state: VqVaeState, batch: VqVaeBatch) -> dict[str, Any]:
        loss, (_, result) = self.loss(
            vqvae_state.params, vqvae_state.state, batch, is_training=False
        )
        logs = {
            "scalar_loss": jax.device_get(loss),
            "images_original": batch["image"],
            "images_reconstruction": result["x_pred"],
        }
        return logs

    def lookup_indices(self, vqvae_state: VqVaeState, indices) -> jnp.ndarray:
        z_q, _ = self.apply.embed(vqvae_state.params, vqvae_state.state, None, indices)
        return z_q
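An end-to-end sketch of driving VqVaeTrainer above with a dummy batch; the hyperparameters come from train_vqvae_config.json, and the all-zeros batch is purely illustrative:

    import jax
    import numpy as np
    import optax

    from trainers.vqvae_trainer import VqVaeTrainer

    trainer = VqVaeTrainer(
        K=256, D=128, compression_level=3, res_layers=2,
        commitment_loss=0.25, optimizer=optax.adamw(3e-4, weight_decay=1e-5),
    )
    batch = {
        "image": np.zeros((8, 40, 40, 1), dtype=np.float32),  # VqVaeBatch layout
        "label": np.zeros((8,), dtype=np.int64),
    }
    state = trainer.initial_state(jax.random.PRNGKey(0), batch)
    state, logs = trainer.update(state, batch)  # one jitted optimization step
    print(logs["scalar_loss"])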
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------