├── models ├── __init__.py ├── discriminator.py ├── vggish.py ├── modeling.py ├── checkpoint.py ├── vgg.py ├── models_vit.py ├── stylegan_discriminator.py └── models_vqgan.py ├── training ├── __init__.py ├── train_state.py ├── train_model_2.py ├── optimization.py └── train_model.py ├── examples └── sample.wav ├── additional_file └── vgg.h5 ├── run.sh ├── data ├── __init__.py ├── mixtures.py ├── metrics.py ├── dataloader.py ├── tasks.py ├── preprocessors.py └── imagenet_utils.py ├── README.md ├── tpu_startup_script.sh ├── audio_diff.py ├── configs ├── audio_audioset_sh.yaml ├── audio_audioset_sh_vocab_4096.yaml ├── audio_audioset_bb.yaml ├── image_laion400m_ss_1.yaml ├── image_laion400m_ss.yaml ├── audio_audioset_ll.yaml └── audio_audioset_ss.yaml ├── network_test.py ├── setup.py ├── dataset_test.py ├── config.py ├── train.py └── tpu_run.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiasenlu/vit-vqgan-jax/HEAD/examples/sample.wav -------------------------------------------------------------------------------- /additional_file/vgg.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiasenlu/vit-vqgan-jax/HEAD/additional_file/vgg.h5 -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # python3 train.py configs/image_laion400m_ss.yaml 2 | 3 | # 4 | python3 train.py configs/audio_audioset_ss.yaml -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from seqio.dataset_providers import * 2 | from seqio.utils import * 3 | from seqio.vocabularies import * -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VIT-VQGAN 2 | 3 | ## Run code on TPU VMs 4 | 5 | Once you've created a TPU VM, run 6 | 7 | ```bash 8 | python tpu_run.py 9 | ``` 10 | 11 | This SSHes into your TPU VM, installs dependencies, and then runs your training command. 12 | You should modify the launch code in `tpu_run.py` and the run configuration in `configs/*.yaml`. 
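## Run a single config

Each YAML under `configs/` fully specifies one run, and `train.py` takes the config path as its only argument, as in `run.sh`:

```bash
python3 train.py configs/audio_audioset_ss.yaml
```

The sketch below shows how a config can be loaded and a trained generator run for offline inspection. It mirrors `network_test.py`; the checkpoint path is a placeholder for your own run directory, and `Generator` is assumed to be importable from `training.train_model` (it is used that way via the star import in `network_test.py`).

```python
import numpy as np
import yaml

from models.checkpoint import load_checkpoint
from training.train_model import Generator  # assumed export; network_test.py uses it via a star import

# Load a run configuration and build the generator wrapper from it.
with open('configs/audio_audioset_ss.yaml') as f:
    config = yaml.load(f, yaml.FullLoader)
generator = Generator.from_config(config, 'generator')

# Placeholder path: point this at one of your own saved checkpoints.
ckpt = load_checkpoint(path='gs://YOUR_BUCKET/YOUR_RUN/ckpt_500000')
params_g = ckpt['params_g']

# Stand-in batch: the audio configs use (128, 256) log-mel inputs with a single channel.
data = np.zeros((1, 128, 256, 1), dtype=np.float32)
reconstruction = generator.apply({'params': params_g}, data, train=False)
```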
-------------------------------------------------------------------------------- /data/mixtures.py: -------------------------------------------------------------------------------- 1 | import seqio 2 | # To ensure all tasks are registered 3 | from data import tasks 4 | 5 | MixtureRegistry = seqio.MixtureRegistry 6 | 7 | MixtureRegistry.add( 8 | "audio_datasets", 9 | [ 10 | "vit_vqgan_audioset", 11 | "vit_vqgan_acav20m", 12 | "vit_vqgan_yttemoporal1b", 13 | ], 14 | default_rate=1.0) 15 | -------------------------------------------------------------------------------- /data/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, List, Optional, Mapping, Any, Dict 2 | import numpy as np 3 | from seqio.metrics import Scalar, Text 4 | 5 | 6 | def cls_accuracy_metric(targets: Sequence[int], predictions): 7 | predictions = np.array(predictions).astype(np.float32) 8 | targets = np.array(targets).astype(np.int32) 9 | score = np.equal(targets, np.argmax(predictions, axis=-1)).astype(np.float32) 10 | score = np.mean(score) 11 | return { 12 | "score": Scalar(score), 13 | } -------------------------------------------------------------------------------- /tpu_startup_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | # This script will get ran on the servers 5 | 6 | # this locks the python executable down to hopefully stop if from being fiddled with... 7 | screen -d -m python -c 'import time; time.sleep(999999999)' 8 | 9 | # initializes jax and installs ray on cloud TPUs. 10 | # Note that `clu` already installs tensorflow-cpu, which is needed to kick out the default tensorflow 11 | # installing on sudo doesn't work, idk why though 12 | /usr/bin/python3 -m pip install --upgrade pip 13 | 14 | cd ~ 15 | 16 | python3 -m pip install -e '.[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 17 | pip3 install --upgrade fabric dataclasses optax tqdm cloudpickle smart_open[gcs] func_timeout aioredis==1.3.1 wandb pandas simplejson 18 | 19 | # 32 * 1024 ** 3 -> 32 gigabytes 20 | export TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD=34359738368 -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union 3 | import math 4 | import functools 5 | 6 | from flax.linen.linear import DenseGeneral 7 | import flax.linen as nn 8 | from flax.linen.module import merge_param 9 | 10 | import jax 11 | import jax.numpy as jnp 12 | from jax import lax, random 13 | import numpy as np 14 | import einops 15 | import re 16 | import pickle 17 | 18 | from models.clip import CLIP, CLIPConfig, MultiLevelDViT 19 | from models.stylegan_discriminator import stylegan_discriminator 20 | 21 | class Discriminator(nn.Module): 22 | """Discriminators""" 23 | dtype: Any = jnp.float32 24 | num_channels: int = 3 25 | resolution: int = 256 26 | use_clip: bool = True 27 | 28 | def setup(self): 29 | self.stylegan_disc = stylegan_discriminator( 30 | num_channels = self.num_channels, 31 | dtype = self.dtype, 32 | resolution = self.resolution) 33 | 34 | if self.use_clip: 35 | self.clip = CLIP(CLIPConfig) 36 | self.multi_level_dvit = MultiLevelDViT(dtype=self.dtype) 37 | 38 | def get_stylegan_logit(self, x, train=True): 39 | return self.stylegan_disc(x, train=train) 40 | 41 | def 
get_clip_feature(self, x, train=False): 42 | return self.clip(x, train=False) 43 | 44 | def get_clip_logit(self, x, train=True): 45 | return self.multi_level_dvit(x, train=train) 46 | 47 | @nn.compact 48 | def __call__(self, x, c=None, train=True): 49 | stylegan_logit = self.get_stylegan_logit(x, train=train) 50 | 51 | clip_logit = None 52 | if self.use_clip: 53 | clip_feat = self.get_clip_feature(x, train=train) 54 | clip_logit = self.get_clip_logit(clip_feat, train=train) 55 | 56 | return stylegan_logit, clip_logit -------------------------------------------------------------------------------- /audio_diff.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import sys 4 | import librosa 5 | import scipy.signal.windows 6 | import soundfile as sf 7 | import numpy as np 8 | from io import BytesIO 9 | from PIL import Image 10 | from scipy.io import wavfile 11 | import io 12 | import matplotlib.pyplot as plt 13 | from PIL import Image 14 | 15 | sample_rate = 22050 16 | n_fft = 2048 17 | hop_length=736 18 | eps = 0.1 19 | 20 | params={ 21 | 'n_fft': n_fft, # Manually selected by Sangho 22 | 'hop_length': hop_length, # Manually selected by Sangho 23 | 'window': scipy.signal.windows.hann, # Default 24 | 'n_mels': 64, # Manually selected by Sanho 25 | 'fmin': 20.0, # Manually selected by Sanho 26 | 'fmax': sample_rate / 2.0, # Default 22050 27 | } # Spectrogram therefore has shape (64, 376) for 10s 28 | 29 | video_fn = '0012y1s1bJI_000350.mp4' 30 | audio_fn = 'audio.wav' 31 | ffmpeg_process = subprocess.Popen( 32 | ['ffmpeg', '-y', '-i', video_fn, '-ac', '1', '-ar', str(sample_rate), audio_fn], 33 | stdout=-1, stderr=-1, text=True 34 | ) 35 | 36 | stdout, stderr = ffmpeg_process.communicate(None, timeout=5.0) 37 | ffmpeg_process.kill() 38 | 39 | sr, waveform = wavfile.read(audio_fn, mmap=True) 40 | waveform = waveform.astype('float32') 41 | waveform /= max(np.abs(waveform).max(), 1.0) 42 | 43 | window_size = 2.0 44 | playback_speed = 1 45 | 46 | st = float(60 * 0 + 0.0) 47 | start_idx = int(sr * st) 48 | end_idx = start_idx + int(sr * window_size) * playback_speed 49 | 50 | y = waveform[start_idx:end_idx] 51 | 52 | mel = librosa.feature.melspectrogram(y=y, sr=sr, **params) 53 | log_mel = np.log(mel + eps) - np.log(eps) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /configs/audio_audioset_sh.yaml: -------------------------------------------------------------------------------- 1 | # ImageNet-1K 2 | # trainset size: 1_281_167 3 | # bsize of 4096 so 312 is one epoch 4 | data: 5 | task_name: "audio_datasets" 6 | task: "audio" 7 | input_size: [128, 256] 8 | patch_size: [8, 8] 9 | 10 | model: 11 | vocab_size: 8192 12 | proj_dim: 32 13 | # Transformers 14 | encoder_hidden_size: 512 15 | encoder_num_layers: 8 16 | encoder_mlp_dim: 2048 17 | encoder_num_heads: 8 18 | decoder_hidden_size: 1280 19 | decoder_num_layers: 32 20 | decoder_mlp_dim: 5120 21 | decoder_num_heads: 16 22 | dropout_rate: 0.0 23 | dropath_rate: 0.0 24 | attention_dropout_rate: 0.0 25 | default_input_size: [128, 256] 26 | output_channel: 1 27 | # PE 28 | add_position_embedding: False 29 | 30 | # Misc. 
31 | use_bfloat16: True 32 | 33 | loss: 34 | # loss 35 | codebook_weight: 1.0 36 | loggaussian_weight: 1.0 37 | loglaplace_weight: 0.0 38 | perceptual_weight: 0.1 39 | adversarial_weight: 0.02 40 | disc_g_start: 100000 #50k 41 | disc_d_start: 80000 #45k 42 | 43 | device: 44 | use_tpu: True 45 | initialize_ckpt: "" 46 | output_dir: "gs://jiasen-us-east/audio_audioset_sh_all" 47 | wandb_api: "4e6d1a3bbc9e8bce0ee37bec376733982d76e01b" 48 | wandb_project: "vit-vqgan-audio" 49 | wandb_entity: "jiasenl" 50 | wandb_name: "" 51 | batch_size: 256 52 | save_every_nsteps: 50000 53 | commit_every_nsteps: 100 54 | 55 | optimizer_g: 56 | optim: "adamw" 57 | learning_rate: 0.0001 58 | end_learning_rate: 0.0 59 | num_train_steps: 1000000 # 200k 60 | num_warmup_steps: 20000 61 | weight_decay_rate: 0.0001 62 | beta_1: 0.9 63 | beta_2: 0.99 64 | adafactor: False 65 | use_bfloat16_optim: False 66 | eps: 0.00000001 67 | use_bfloat16_weights: False 68 | do_bias_correction: True 69 | global_max_norm: 1.0 70 | 71 | optimizer_d: 72 | optim: "adamw" 73 | learning_rate: 0.0001 74 | end_learning_rate: 0.0 75 | num_train_steps: 1000000 # 300 epochs 76 | num_warmup_steps: 20000 77 | weight_decay_rate: 0.0001 78 | beta_1: 0.9 79 | beta_2: 0.99 80 | adafactor: False 81 | use_bfloat16_optim: False 82 | eps: 0.00000001 83 | use_bfloat16_weights: False 84 | do_bias_correction: True 85 | global_max_norm: 1.0 -------------------------------------------------------------------------------- /configs/audio_audioset_sh_vocab_4096.yaml: -------------------------------------------------------------------------------- 1 | # ImageNet-1K 2 | # trainset size: 1_281_167 3 | # bsize of 4096 so 312 is one epoch 4 | data: 5 | task_name: "audio_datasets" 6 | task: "audio" 7 | input_size: [128, 256] 8 | patch_size: [8, 8] 9 | 10 | model: 11 | vocab_size: 4096 12 | proj_dim: 32 13 | # Transformers 14 | encoder_hidden_size: 512 15 | encoder_num_layers: 8 16 | encoder_mlp_dim: 2048 17 | encoder_num_heads: 8 18 | decoder_hidden_size: 1280 19 | decoder_num_layers: 32 20 | decoder_mlp_dim: 5120 21 | decoder_num_heads: 16 22 | dropout_rate: 0.0 23 | dropath_rate: 0.0 24 | attention_dropout_rate: 0.0 25 | default_input_size: [128, 256] 26 | output_channel: 1 27 | # PE 28 | add_position_embedding: False 29 | 30 | # Misc. 
31 | use_bfloat16: True 32 | 33 | loss: 34 | # loss 35 | codebook_weight: 1.0 36 | loggaussian_weight: 1.0 37 | loglaplace_weight: 0.0 38 | perceptual_weight: 0.1 39 | adversarial_weight: 0.02 40 | disc_g_start: 100000 #50k 41 | disc_d_start: 80000 #45k 42 | 43 | device: 44 | use_tpu: True 45 | initialize_ckpt: "" 46 | output_dir: "gs://jiasen-us-east/audio_audioset_sh_all_4096" 47 | wandb_api: "4e6d1a3bbc9e8bce0ee37bec376733982d76e01b" 48 | wandb_project: "vit-vqgan-audio" 49 | wandb_entity: "jiasenl" 50 | wandb_name: "" 51 | batch_size: 256 52 | save_every_nsteps: 50000 53 | commit_every_nsteps: 100 54 | 55 | optimizer_g: 56 | optim: "adamw" 57 | learning_rate: 0.0001 58 | end_learning_rate: 0.0 59 | num_train_steps: 1000000 # 200k 60 | num_warmup_steps: 20000 61 | weight_decay_rate: 0.0001 62 | beta_1: 0.9 63 | beta_2: 0.99 64 | adafactor: False 65 | use_bfloat16_optim: False 66 | eps: 0.00000001 67 | use_bfloat16_weights: False 68 | do_bias_correction: True 69 | global_max_norm: 1.0 70 | 71 | optimizer_d: 72 | optim: "adamw" 73 | learning_rate: 0.0001 74 | end_learning_rate: 0.0 75 | num_train_steps: 1000000 # 300 epochs 76 | num_warmup_steps: 20000 77 | weight_decay_rate: 0.0001 78 | beta_1: 0.9 79 | beta_2: 0.99 80 | adafactor: False 81 | use_bfloat16_optim: False 82 | eps: 0.00000001 83 | use_bfloat16_weights: False 84 | do_bias_correction: True 85 | global_max_norm: 1.0 -------------------------------------------------------------------------------- /configs/audio_audioset_bb.yaml: -------------------------------------------------------------------------------- 1 | # ImageNet-1K 2 | # trainset size: 1_281_167 3 | # bsize of 4096 so 312 is one epoch 4 | 5 | data: 6 | task_name: "vit_vqgan_audioset" 7 | task: "audio" 8 | input_size: [128, 256] 9 | patch_size: [8, 8] 10 | 11 | model: 12 | vocab_size: 8192 13 | proj_dim: 32 14 | # Transformers 15 | encoder_hidden_size: 768 16 | encoder_num_layers: 12 17 | encoder_mlp_dim: 3072 18 | encoder_num_heads: 12 19 | decoder_hidden_size: 768 20 | decoder_num_layers: 12 21 | decoder_mlp_dim: 3072 22 | decoder_num_heads: 12 23 | dropout_rate: 0.0 24 | attention_dropout_rate: 0.0 25 | default_input_size: [64, 60] 26 | output_channel: 1 27 | use_bias: False 28 | act_fn: 'relu' 29 | # PE 30 | add_position_embedding: False 31 | 32 | # Misc. 
33 | use_bfloat16: False 34 | 35 | loss: 36 | # loss 37 | codebook_weight: 1.0 38 | loggaussian_weight: 1.0 39 | loglaplace_weight: 0.1 40 | perceptual_weight: 0.1 41 | adversarial_weight: 0.1 42 | disc_g_start: 100000 #50k 43 | disc_d_start: 80000 #45k 44 | 45 | device: 46 | use_tpu: True 47 | initialize_ckpt: "" 48 | output_dir: "gs://jiasen-us-east/audio_audioset_bb" 49 | wandb_api: "4e6d1a3bbc9e8bce0ee37bec376733982d76e01b" 50 | wandb_project: "vit-vqgan-audio" 51 | wandb_entity: "jiasenl" 52 | wandb_name: "" 53 | batch_size: 1024 54 | save_every_nsteps: 50000 55 | commit_every_nsteps: 100 56 | 57 | optimizer_g: 58 | optim: "adamw" 59 | learning_rate: 0.0001 60 | end_learning_rate: 0.00005 61 | num_train_steps: 500000 # 500k 62 | num_warmup_steps: 20000 63 | weight_decay_rate: 0.0001 64 | beta_1: 0.9 65 | beta_2: 0.99 66 | adafactor: False 67 | use_bfloat16_optim: False 68 | eps: 0.00000001 69 | use_bfloat16_weights: False 70 | do_bias_correction: True 71 | global_max_norm: 1.0 72 | 73 | optimizer_d: 74 | optim: "adamw" 75 | learning_rate: 0.0001 76 | end_learning_rate: 0.00005 77 | num_train_steps: 500000 78 | num_warmup_steps: 20000 79 | weight_decay_rate: 0.0001 80 | beta_1: 0.9 81 | beta_2: 0.99 82 | adafactor: False 83 | use_bfloat16_optim: False 84 | eps: 0.00000001 85 | use_bfloat16_weights: False 86 | do_bias_correction: True 87 | global_max_norm: 1.0 88 | -------------------------------------------------------------------------------- /configs/image_laion400m_ss_1.yaml: -------------------------------------------------------------------------------- 1 | # ImageNet-1K 2 | # trainset size: 1_281_167 3 | # bsize of 4096 so 312 is one epoch 4 | 5 | data: 6 | task_name: "vit_vqgan_liaon_400m" 7 | task: "image" 8 | 9 | # Vision 10 | input_size: [256, 256] 11 | patch_size: [8, 8] 12 | 13 | model: 14 | vocab_size: 8192 15 | proj_dim: 32 16 | # Transformers 17 | encoder_hidden_size: 512 18 | encoder_num_layers: 8 19 | encoder_mlp_dim: 2048 20 | encoder_num_heads: 8 21 | decoder_hidden_size: 512 22 | decoder_num_layers: 8 23 | decoder_mlp_dim: 2048 24 | decoder_num_heads: 8 25 | dropout_rate: 0.0 26 | dropath_rate: 0.0 27 | attention_dropout_rate: 0.0 28 | 29 | # PE 30 | add_position_embedding: False 31 | 32 | # Misc. 
33 | use_bfloat16: True 34 | 35 | loss: 36 | # loss 37 | codebook_weight: 1.0 38 | loggaussian_weight: 1.0 39 | loglaplace_weight: 0.0 40 | perceptual_weight: 0.1 41 | adversarial_weight: 0.1 42 | disc_g_start: 100000 43 | disc_d_start: 100000 44 | disc_d_flip_update: 200000 45 | 46 | device: 47 | use_tpu: True 48 | initialize_ckpt: "" 49 | output_dir: "gs://jiasen-us-east/vit_vqgan_laion400m_ss_new" 50 | wandb_api: "4e6d1a3bbc9e8bce0ee37bec376733982d76e01b" 51 | wandb_project: "vit-vqgan-image-new" 52 | wandb_entity: "jiasenl" 53 | wandb_name: "" 54 | batch_size: 512 # 16 / core 55 | save_every_nsteps: 50000 56 | commit_every_nsteps: 100 57 | 58 | optimizer_g: 59 | optim: "adamw" 60 | learning_rate: 0.0001 61 | end_learning_rate: 0.00005 62 | num_train_steps: 500000 # 200k 63 | num_warmup_steps: 20000 64 | weight_decay_rate: 0.0001 65 | beta_1: 0.9 66 | beta_2: 0.99 67 | adafactor: False 68 | use_bfloat16_optim: False 69 | eps: 0.00000001 70 | use_bfloat16_weights: False 71 | do_bias_correction: True 72 | global_max_norm: 1.0 73 | 74 | optimizer_d: 75 | optim: "adamw" 76 | learning_rate: 0.001 77 | end_learning_rate: 0.00005 78 | num_train_steps: 500000 # 300 epochs 79 | num_warmup_steps: 20000 80 | weight_decay_rate: 0.0001 81 | beta_1: 0.9 82 | beta_2: 0.99 83 | adafactor: False 84 | use_bfloat16_optim: False 85 | eps: 0.00000001 86 | use_bfloat16_weights: False 87 | do_bias_correction: True 88 | global_max_norm: 1.0 -------------------------------------------------------------------------------- /configs/image_laion400m_ss.yaml: -------------------------------------------------------------------------------- 1 | # ImageNet-1K 2 | # trainset size: 1_281_167 3 | # bsize of 4096 so 312 is one epoch 4 | 5 | data: 6 | task_name: "vit_vqgan_liaon_400m" 7 | task: "image" 8 | 9 | # Vision 10 | input_size: [256, 256] 11 | patch_size: [8, 8] 12 | 13 | model: 14 | vocab_size: 8192 15 | proj_dim: 32 16 | # Transformers 17 | encoder_hidden_size: 512 18 | encoder_num_layers: 8 19 | encoder_mlp_dim: 2048 20 | encoder_num_heads: 8 21 | decoder_hidden_size: 512 22 | decoder_num_layers: 8 23 | decoder_mlp_dim: 2048 24 | decoder_num_heads: 8 25 | dropout_rate: 0.0 26 | dropath_rate: 0.0 27 | attention_dropout_rate: 0.0 28 | 29 | # PE 30 | add_position_embedding: False 31 | 32 | # Misc. 
33 | use_bfloat16: True 34 | 35 | loss: 36 | # loss 37 | codebook_weight: 1.0 38 | loggaussian_weight: 1.0 39 | loglaplace_weight: 0.0 40 | perceptual_weight: 0.1 41 | adversarial_weight: 0.1 42 | disc_g_start: 100000 43 | disc_d_start: 100000 44 | disc_d_flip_update: 200000 45 | 46 | device: 47 | use_tpu: True 48 | initialize_ckpt: "" 49 | output_dir: "gs://jiasen-us-east/vit_vqgan_laion400m_ss_new_final_init" 50 | wandb_api: "4e6d1a3bbc9e8bce0ee37bec376733982d76e01b" 51 | wandb_project: "vit-vqgan-image-new" 52 | wandb_entity: "jiasenl" 53 | wandb_name: "" 54 | batch_size: 512 # 16 / core 55 | save_every_nsteps: 50000 56 | commit_every_nsteps: 100 57 | 58 | optimizer_g: 59 | optim: "adamw" 60 | learning_rate: 0.0001 61 | end_learning_rate: 0.00005 62 | num_train_steps: 500000 # 200k 63 | num_warmup_steps: 20000 64 | weight_decay_rate: 0.0001 65 | beta_1: 0.9 66 | beta_2: 0.99 67 | adafactor: False 68 | use_bfloat16_optim: False 69 | eps: 0.00000001 70 | use_bfloat16_weights: False 71 | do_bias_correction: True 72 | global_max_norm: 1.0 73 | 74 | optimizer_d: 75 | optim: "adamw" 76 | learning_rate: 0.001 77 | end_learning_rate: 0.00005 78 | num_train_steps: 500000 # 300 epochs 79 | num_warmup_steps: 20000 80 | weight_decay_rate: 0.0001 81 | beta_1: 0.9 82 | beta_2: 0.99 83 | adafactor: False 84 | use_bfloat16_optim: False 85 | eps: 0.00000001 86 | use_bfloat16_weights: False 87 | do_bias_correction: True 88 | global_max_norm: 1.0 -------------------------------------------------------------------------------- /configs/audio_audioset_ll.yaml: -------------------------------------------------------------------------------- 1 | # ImageNet-1K 2 | # trainset size: 1_281_167 3 | # bsize of 4096 so 312 is one epoch 4 | 5 | data: 6 | task_name: "vit_vqgan_audioset" 7 | task: "audio" 8 | input_size: [128, 256] 9 | patch_size: [8, 8] 10 | 11 | model: 12 | vocab_size: 8192 13 | proj_dim: 32 14 | # Transformers 15 | encoder_hidden_size: 1024 16 | encoder_num_layers: 24 17 | encoder_mlp_dim: 4096 18 | encoder_num_heads: 16 19 | decoder_hidden_size: 1024 20 | decoder_num_layers: 24 21 | decoder_mlp_dim: 4096 22 | decoder_num_heads: 16 23 | dropout_rate: 0.0 24 | dropath_rate: 0.0 25 | attention_dropout_rate: 0.0 26 | default_input_size: [64, 60] 27 | output_channel: 1 28 | use_bias: False 29 | act_fn: 'relu' 30 | # PE 31 | add_position_embedding: False 32 | 33 | # Misc. 
34 | use_bfloat16: False 35 | 36 | loss: 37 | # loss 38 | codebook_weight: 1.0 39 | loggaussian_weight: 1.0 40 | loglaplace_weight: 0.1 41 | perceptual_weight: 0.1 42 | adversarial_weight: 0.1 43 | disc_g_start: 100000 #50k 44 | disc_d_start: 80000 #45k 45 | 46 | device: 47 | use_tpu: True 48 | initialize_ckpt: "" 49 | output_dir: "gs://jiasen-us-east/audio_audioset_ll" 50 | wandb_api: "4e6d1a3bbc9e8bce0ee37bec376733982d76e01b" 51 | wandb_project: "vit-vqgan-audio" 52 | wandb_entity: "jiasenl" 53 | wandb_name: "" 54 | batch_size: 1024 55 | save_every_nsteps: 50000 56 | commit_every_nsteps: 100 57 | 58 | optimizer_g: 59 | optim: "adamw" 60 | learning_rate: 0.0001 61 | end_learning_rate: 0.00005 62 | num_train_steps: 200000 # 200k 63 | num_warmup_steps: 20000 64 | weight_decay_rate: 0.0001 65 | beta_1: 0.9 66 | beta_2: 0.99 67 | adafactor: False 68 | use_bfloat16_optim: False 69 | eps: 0.00000001 70 | use_bfloat16_weights: False 71 | do_bias_correction: True 72 | global_max_norm: 1.0 73 | 74 | optimizer_d: 75 | optim: "adamw" 76 | learning_rate: 0.0001 77 | end_learning_rate: 0.00005 78 | num_train_steps: 200000 # 300 epochs 79 | num_warmup_steps: 20000 80 | weight_decay_rate: 0.0001 81 | beta_1: 0.9 82 | beta_2: 0.99 83 | adafactor: False 84 | use_bfloat16_optim: False 85 | eps: 0.00000001 86 | use_bfloat16_weights: False 87 | do_bias_correction: True 88 | global_max_norm: 1.0 89 | -------------------------------------------------------------------------------- /configs/audio_audioset_ss.yaml: -------------------------------------------------------------------------------- 1 | # ImageNet-1K 2 | # trainset size: 1_281_167 3 | # bsize of 4096 so 312 is one epoch 4 | 5 | data: 6 | task_name: "audio_datasets" 7 | task: "audio" 8 | input_size: [128, 256] 9 | patch_size: [8, 8] 10 | 11 | model: 12 | vocab_size: 8192 13 | proj_dim: 32 14 | # Transformers 15 | encoder_hidden_size: 512 16 | encoder_num_layers: 8 17 | encoder_mlp_dim: 2048 18 | encoder_num_heads: 8 19 | decoder_hidden_size: 512 20 | decoder_num_layers: 8 21 | decoder_mlp_dim: 2048 22 | decoder_num_heads: 8 23 | dropout_rate: 0.0 24 | dropath_rate: 0.0 25 | attention_dropout_rate: 0.0 26 | default_input_size: [128, 256] 27 | output_channel: 1 28 | # PE 29 | add_position_embedding: False 30 | 31 | # Misc. 
32 | use_bfloat16: True 33 | 34 | loss: 35 | # loss 36 | codebook_weight: 1.0 37 | loggaussian_weight: 1.0 38 | loglaplace_weight: 0 39 | perceptual_weight: 0.1 40 | adversarial_weight: 0.1 41 | disc_g_start: 100000 #50k 42 | disc_d_start: 100000 43 | disc_d_flip_update: 200000 44 | 45 | device: 46 | use_tpu: True 47 | initialize_ckpt: "" 48 | output_dir: "gs://jiasen-us-east/audio/audio_audioset_ss_new_final_init" 49 | wandb_api: "4e6d1a3bbc9e8bce0ee37bec376733982d76e01b" 50 | wandb_project: "vit-vqgan-audio" 51 | wandb_entity: "jiasenl" 52 | wandb_name: "" 53 | batch_size: 512 54 | save_every_nsteps: 50000 55 | commit_every_nsteps: 100 56 | 57 | 58 | optimizer_g: 59 | optim: "adamw" 60 | learning_rate: 0.0001 61 | end_learning_rate: 0.00005 62 | num_train_steps: 1000000 # 500k 63 | num_warmup_steps: 20000 64 | weight_decay_rate: 0.0001 65 | beta_1: 0.9 66 | beta_2: 0.99 67 | adafactor: False 68 | use_bfloat16_optim: False 69 | eps: 0.00000001 70 | use_bfloat16_weights: False 71 | do_bias_correction: True 72 | global_max_norm: 1.0 73 | end_learning_rate: 0.0 74 | 75 | optimizer_d: 76 | optim: "adamw" 77 | learning_rate: 0.001 78 | end_learning_rate: 0.0001 79 | num_train_steps: 500000 # 300 epochs 80 | num_warmup_steps: 20000 81 | weight_decay_rate: 0.0001 82 | beta_1: 0.9 83 | beta_2: 0.99 84 | adafactor: False 85 | use_bfloat16_optim: False 86 | eps: 0.00000001 87 | use_bfloat16_weights: False 88 | do_bias_correction: True 89 | global_max_norm: 1.0 -------------------------------------------------------------------------------- /models/vggish.py: -------------------------------------------------------------------------------- 1 | from jax import random 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | import functools 5 | from typing import Any, Tuple 6 | import h5py 7 | import warnings 8 | 9 | from tqdm import tqdm 10 | import requests 11 | import os 12 | import tempfile 13 | 14 | from models.vgg import download 15 | 16 | URLS = { 17 | 'vggish': 'https://github.com/harritaylor/torchvggish/' 18 | 'releases/download/v0.1/vggish-10086976.pth', 19 | 'pca': 'https://github.com/harritaylor/torchvggish/' 20 | 'releases/download/v0.1/vggish_pca_params-970ea276.pth' 21 | } 22 | 23 | def normalize_tensor(x, eps=1e-10): 24 | norm_factor = jnp.sqrt(jnp.sum(x ** 2, axis=-1, keepdims=True)) 25 | return x / (norm_factor + eps) 26 | 27 | def spatial_average(x, keepdims=True): 28 | return jnp.mean(x, axis=[1,2], keepdims=keepdims) 29 | 30 | class VGG(nn.Module): 31 | dtype: Any = jnp.float32 32 | ckpt_dir: str = None 33 | 34 | def setup(self): 35 | ckpt_file = download(self.ckpt_dir, URLS['vggish']) 36 | import torch 37 | self.param_dict = torch.load(ckpt_file) 38 | 39 | def _forward(self, x): 40 | out = [] 41 | cnt = 0 42 | for v in [64, "M", 128, "M", 256, 256, "M", 512, 512, "M"]: 43 | if v == "M": 44 | out.append(x) 45 | x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 46 | cnt += 1 47 | else: 48 | w = lambda *_ : jnp.transpose(jnp.array(self.param_dict[f'features.{cnt}.weight']), (2,3,1,0)) 49 | b = lambda *_ : jnp.array(self.param_dict[f'features.{cnt}.bias']) 50 | x = nn.Conv( 51 | features=v, 52 | kernel_size=(3, 3), 53 | padding=((1, 1), (1, 1)), 54 | kernel_init=w, 55 | use_bias=True, 56 | bias_init=b, 57 | dtype=self.dtype)(x) 58 | cnt += 1 59 | 60 | x = nn.relu(x) 61 | cnt += 1 62 | 63 | return out 64 | 65 | @nn.compact 66 | def __call__(self, x0, x1, train=False): 67 | 68 | act0 = self._forward(x0) 69 | act1 = self._forward(x1) 70 | 71 | diffs = {} 72 | 73 | num = 
len(act0) 74 | for i in range(num): 75 | diffs[i] = (normalize_tensor(act0[i]) - normalize_tensor(act1[i])) ** 2 76 | 77 | res = [spatial_average(jnp.sum(diffs[i], axis=-1, keepdims=True), keepdims=True) for i in range(num)] 78 | 79 | return jnp.reshape(sum(res), (-1)) 80 | -------------------------------------------------------------------------------- /network_test.py: -------------------------------------------------------------------------------- 1 | from models.vggish import VGG 2 | import jax.numpy as jnp 3 | from jax.random import PRNGKey 4 | from training.train_model import * 5 | import yaml 6 | 7 | import tensorflow as tf 8 | import seqio 9 | from data.tasks import TaskRegistry 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import librosa.display 13 | from PIL import Image 14 | 15 | import subprocess 16 | import os 17 | import sys 18 | import librosa 19 | import scipy.signal.windows 20 | import soundfile as sf 21 | import numpy as np 22 | from io import BytesIO 23 | from PIL import Image 24 | from scipy.io import wavfile 25 | import io 26 | from PIL import Image 27 | 28 | from models.checkpoint import initialize_using_checkpoint, save_checkpoint, load_checkpoint, bf16_to_f32 29 | 30 | 31 | # load the data 32 | window_size = 4.08 33 | sample_rate = 16000 34 | n_fft = 1024 35 | win_len = 1024 36 | hop_len=256 37 | n_mels = 128 38 | fmin = 0.0 39 | eps = 0.1 40 | max_wav_value=32768.0 41 | playback_speed = 1 42 | fmax = 8000 43 | 44 | audio_fn = 'examples/1aigfM5Tmqk_000020.wav' 45 | sr, waveform = wavfile.read(audio_fn, mmap=True) 46 | waveform = waveform.astype('float32') 47 | waveform /= max_wav_value 48 | 49 | st = float(60 * 0 + 0.0) 50 | start_idx = int(sr * st) 51 | end_idx = start_idx + int(sr * window_size) * playback_speed 52 | waveform = waveform[start_idx:end_idx] 53 | 54 | librosa_melspec = librosa.feature.melspectrogram( 55 | waveform, 56 | sr=sample_rate, 57 | n_fft=n_fft, 58 | hop_length=hop_len, 59 | win_length=win_len, 60 | center=True, 61 | pad_mode="reflect", 62 | power=2.0, 63 | n_mels=n_mels, 64 | ) 65 | 66 | audio = librosa_melspec.reshape((1, 1, 128, 256, 1)) 67 | audio_mask = tf.cast(audio != 0, tf.float32) 68 | audio = tf.math.log(tf.clip_by_value(audio, 1e-5, 1e5)) 69 | audio = (audio + 5.0945) / 3.8312 70 | audio = audio * audio_mask 71 | audio = audio.numpy() 72 | 73 | seed = 0 74 | aux_rng_keys=["dropout", "drop_path"] 75 | 76 | dummy_batch = {'inputs': audio} 77 | 78 | with open('configs/audio_audioset_sh.yaml', 'r') as f: 79 | config = yaml.load(f, yaml.FullLoader) 80 | generator = Generator.from_config(config, 'generator') 81 | 82 | ckpt_path = 'gs://jiasen-us-east/audio_audioset_sh/2022-12-06-11:51.47/ckpt_700000' 83 | ckpt = load_checkpoint(path=ckpt_path) 84 | cache_params_g = ckpt['params_g'] 85 | del ckpt 86 | 87 | data = audio.reshape(1, 128, 256, 1) 88 | 89 | dataset = seqio.get_mixture_or_task("vit_vqgan_audioset").get_dataset( 90 | sequence_length={}, 91 | split="train", 92 | num_epochs=1, 93 | shard_info=seqio.ShardInfo(index=0, num_shards=10), 94 | use_cached=False, 95 | seed=42, 96 | shuffle=False, 97 | ) 98 | 99 | for ex in zip(dataset.as_numpy_iterator()): 100 | import pdb; pdb.set_trace() 101 | data = ex[0]['inputs'].reshape(1, 128, 256, 1) 102 | rec = generator.apply({'params': cache_params_g},data,train=False,) 103 | 104 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 
The T5X Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Install T5X.""" 16 | 17 | import os 18 | import sys 19 | import setuptools 20 | 21 | # To enable importing version.py directly, we add its path to sys.path. 22 | __version__ = '0.0.0' 23 | 24 | # Get the long description from the README file. 25 | with open('README.md') as fp: 26 | _LONG_DESCRIPTION = fp.read() 27 | 28 | _jax_version = '0.2.27' 29 | _jaxlib_version = '0.1.76' 30 | 31 | setuptools.setup( 32 | name='vit-vqgan', 33 | version=__version__, 34 | description='ViT-VQGAN in JAX', 35 | long_description=_LONG_DESCRIPTION, 36 | long_description_content_type='text/markdown', 37 | author='AI2', 38 | author_email='jiasenl@allenai.org', 39 | url='http://github.com/jiasenlu/vit-vgqan-jax', 40 | license='Apache 2.0', 41 | packages=setuptools.find_packages(), 42 | package_data={ 43 | '': ['**/*.gin'], # not all subdirectories may have __init__.py. 44 | }, 45 | scripts=[], 46 | install_requires=[ 47 | 'absl-py', 48 | 'cached_property', 49 | 'protobuf==3.19.4', 50 | 'google-api-core==2.8.2', 51 | # TODO(adarob): Replace with 'clu' once >0.0.6 is released. 52 | 'clu==0.0.8', 53 | 'flax==0.6.3', 54 | 'gin-config', 55 | f'jax >= {_jax_version},<0.4.0', 56 | f'jaxlib >= {_jaxlib_version},<0.4.0', 57 | 'numpy', 58 | 'orbax==0.0.2', 59 | 't5', 60 | 'tensorflow', 61 | 'einops', 62 | 'tfds-nightly', 63 | 'tensorflow_probability', 64 | 'tensorflow-addons', 65 | 'tensorflow-datasets @ git+https://github.com/tensorflow/datasets', 66 | 'pycocoevalcap', 67 | 'tensorstore >= 0.1.20', 68 | 'librosa', 69 | 'sk-video', 70 | 'SoundFile', 71 | 'scikit-image', 72 | 'wandb' 73 | ], 74 | extras_require={ 75 | 'gcp': [ 76 | 'gevent', 'google-api-python-client', 'google-compute-engine', 77 | 'google-cloud-storage', 'oauth2client' 78 | ], 79 | 'test': ['pytest'], 80 | 81 | 'data': ['ffmpeg-python', 'scikit-video', 'librosa', 'scikit-image', 'pafy', 'youtube_dl==2020.12.2', 'tensorflow_io', 'pydub'], 82 | 83 | # Cloud TPU requirements. 
84 | 'tpu': [f'jax[tpu] >= {_jax_version}'], 85 | }, 86 | classifiers=[ 87 | 'Development Status :: 4 - Beta', 88 | 'Intended Audience :: Developers', 89 | 'Intended Audience :: Science/Research', 90 | 'License :: OSI Approved :: Apache Software License', 91 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 92 | ], 93 | keywords='machinelearning', 94 | ) -------------------------------------------------------------------------------- /dataset_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import seqio 3 | from data.tasks import TaskRegistry 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import librosa.display 7 | from PIL import Image 8 | 9 | eps = 0.1 10 | min_level_db=-100 11 | 12 | print(tf.executing_eagerly()) 13 | tf.config.run_functions_eagerly(True) 14 | tf.data.experimental.enable_debug_mode() 15 | 16 | def _normalize(S): 17 | return np.clip((S - min_level_db) / -min_level_db, 0, 1) 18 | 19 | def plot_spectrogram(data, eps=0.1, n_fft=2048, hop_length=736, sr=22050, fmax=22050/2.0): 20 | fig, ax = plt.subplots(1, 1) 21 | data = data * (100) - 100 22 | data = np.transpose(np.reshape(data, (64, 60))) 23 | img = librosa.display.specshow(data, x_axis='time', y_axis='mel', sr=sr, fmax=sr/2, n_fft=n_fft, hop_length=hop_length, ax=ax) 24 | fig.colorbar(img, ax=ax, format='%+2.0f dB') 25 | fig.tight_layout(pad=0) 26 | fig.canvas.draw() 27 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 28 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 29 | plt.close(fig) 30 | return data 31 | 32 | 33 | dataset = seqio.get_mixture_or_task("vit_vqgan_yttemoporal1b").get_dataset( 34 | sequence_length={}, 35 | split="train", 36 | num_epochs=1, 37 | shard_info=seqio.ShardInfo(index=0, num_shards=10), 38 | use_cached=False, 39 | seed=42, 40 | shuffle=False, 41 | ) 42 | max_value = 0 43 | min_value = 10000 44 | 45 | log_mel_list = [] 46 | nlog_mel_list = [] 47 | n_s_list = [] 48 | 49 | psum = 0 50 | count = 0 51 | cnt = 0 52 | psum_sq = 0 53 | for ex in zip(dataset.as_numpy_iterator()): 54 | # img = plot_spectrogram(ex['inputs']) 55 | # Image.fromarray(img).save('mel_spectrogram.png') 56 | import pdb; pdb.set_trace() 57 | psum += ex[0]['inputs'].sum() 58 | psum_sq += (ex[0]['inputs']**2).sum() 59 | 60 | count += (ex[0]['inputs'] != 0).sum() 61 | cnt += 1 62 | 63 | if cnt % 1000 == 0: 64 | print(cnt) 65 | 66 | total_mean = psum / count 67 | 68 | print(total_mean) 69 | # calulate 70 | 71 | total_var = (psum_sq / count) - (total_mean ** 2) 72 | total_std = np.sqrt(total_var) 73 | 74 | import pdb; pdb.set_trace() 75 | 76 | # from torchvision.utils import draw_bounding_boxes, save_image 77 | # import torch 78 | # import numpy as np 79 | # image = inputs 80 | # image_plot = torch.tensor(np.array(image, dtype=np.float32), dtype=torch.float32) 81 | # image_plot = image_plot.permute(2,0,1) 82 | # save_image(image_plot, 'input_image.jpg') 83 | 84 | # normalized_log_mel = log_mel / 11.535412 85 | # spec = np.exp(log_mel + np.log(eps)) - eps 86 | # S = librosa.power_to_db(spec, ref=np.max) 87 | # s_normalize = _normalize(S) 88 | 89 | # log_mel_list.append(log_mel) 90 | # nlog_mel_list.append(normalized_log_mel) 91 | # n_s_list.append(s_normalize) 92 | 93 | # log_mel_list = np.concatenate(log_mel_list, 0) 94 | # nlog_mel_list = np.concatenate(nlog_mel_list, 0) 95 | # n_s_list = np.concatenate(n_s_list, 0) 96 | 97 | # log_mel_list = np.reshape(log_mel_list, (-1)) 98 | # 
plt.hist(log_mel_list.tolist(), color = 'blue', edgecolor = 'black') 99 | # plt.savefig('log_mel_list.png') 100 | 101 | # plt.clf() 102 | # nlog_mel_list = np.reshape(nlog_mel_list, (-1)) 103 | # plt.hist(nlog_mel_list.tolist(), color = 'blue', edgecolor = 'black') 104 | # plt.savefig('nlog_mel_list.png') 105 | 106 | # plt.clf() 107 | # n_s_list = np.reshape(n_s_list, (-1)) 108 | # plt.hist(n_s_list.tolist(), color = 'blue', edgecolor = 'black') 109 | # plt.savefig('n_s_list.png') 110 | 111 | # import pdb; pdb.set_trace() 112 | 113 | # from torchvision.utils import save_image 114 | # import torch 115 | # import numpy as np 116 | # image = x[0] 117 | # image_plot = torch.tensor(np.array(image, dtype=np.float32), dtype=torch.float32) 118 | # # image_plot = image_plot.permute(2,0,1) 119 | # save_image(image_plot, 'input_image1.jpg') 120 | # import pdb; pdb.set_trace() 121 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """Global UnifiedIO config parameters""" 2 | from typing import Any, Sequence 3 | 4 | from flax import struct 5 | from jax import numpy as jnp 6 | 7 | import seqio 8 | import tensorflow as tf 9 | 10 | import matplotlib.pyplot as plt 11 | from PIL import Image 12 | import numpy as np 13 | 14 | TFDS_DATA_DIR = 'gs://jiasen-us-east/datasets' 15 | 16 | MIN_LEVEL_DB = -100 17 | EPS = 0.1 18 | AMIN = 1e-10 19 | TOP_DB=80.0 20 | 21 | # Constants used when encoding region 22 | VOCAB_START = 100 23 | NUM_DETECTION_BIN = 1000 24 | 25 | # Controls data augmentation 26 | RANDOM_SCALE_MAX = 1.2 27 | RANDOM_SCALE_MIN = 1.0 28 | 29 | # Controls input/output image sizes 30 | IMAGE_INPUT_SIZE = [384, 384] 31 | IMAGE_INPUT_D = 16 32 | IMAGE_TARGET_SIZE = [256, 256] 33 | IMAGE_TARGET_D = 16 34 | 35 | FINETUNE_VIDEO_INPUT_SIZE = [384,384] 36 | VIDEO_INPUT_D = 16 37 | 38 | IMAGE_FEATURES = { 39 | "inputs": seqio.ContinuousFeature(dtype=tf.float32, rank=2), 40 | "targets": seqio.ContinuousFeature(dtype=tf.int32, rank=0), 41 | } 42 | 43 | VIDEO_FEATURES = { 44 | "inputs": seqio.ContinuousFeature(dtype=tf.float32, rank=3), 45 | "targets": seqio.ContinuousFeature(dtype=tf.int32, rank=0), 46 | } 47 | 48 | VIT_VQGAN_OUTPUT_FEATURES = { 49 | "inputs": seqio.ContinuousFeature(dtype=tf.float32, rank=3), 50 | } 51 | 52 | AUDIO_FEATURE_DESCRIPTION = { 53 | 'id': tf.io.FixedLenFeature([], tf.string), 54 | 'text': tf.io.FixedLenFeature([], tf.string), 55 | 'video': tf.io.FixedLenFeature([], tf.string), 56 | 'video_nframes': tf.io.FixedLenFeature([], tf.int64), 57 | 'video_width': tf.io.FixedLenFeature([], tf.int64), 58 | 'video_height': tf.io.FixedLenFeature([], tf.int64), 59 | 'video_nchannels': tf.io.FixedLenFeature([], tf.int64), 60 | 'audio': tf.io.FixedLenFeature([], tf.string), 61 | 'audio_nspectrograms': tf.io.FixedLenFeature([], tf.int64), 62 | 'audio_nmels': tf.io.FixedLenFeature([], tf.int64), 63 | 'audio_nhops': tf.io.FixedLenFeature([], tf.int64) 64 | } 65 | 66 | ACAV20M_FEATURE_DESCRIPTION = { 67 | 'id': tf.io.FixedLenFeature([], tf.string), 68 | 'video': tf.io.FixedLenFeature([], tf.string), 69 | 'video_nframes': tf.io.FixedLenFeature([], tf.int64), 70 | 'video_width': tf.io.FixedLenFeature([], tf.int64), 71 | 'video_height': tf.io.FixedLenFeature([], tf.int64), 72 | 'video_nchannels': tf.io.FixedLenFeature([], tf.int64), 73 | 'audio': tf.io.FixedLenFeature([], tf.string), 74 | 'audio_nspectrograms': tf.io.FixedLenFeature([], tf.int64), 75 | 'audio_nmels': 
tf.io.FixedLenFeature([], tf.int64), 76 | 'audio_nhops': tf.io.FixedLenFeature([], tf.int64) 77 | } 78 | 79 | NUM_CHUNKS = 8 80 | YTTEMOPORAL1B_FEATURE_DESCRIPTION = { 81 | 'id': tf.io.FixedLenFeature([], tf.string), 82 | 'video': tf.io.FixedLenFeature([NUM_CHUNKS], tf.string), 83 | 'video_nframes': tf.io.FixedLenFeature([], tf.int64), 84 | 'video_width': tf.io.FixedLenFeature([], tf.int64), 85 | 'video_height': tf.io.FixedLenFeature([], tf.int64), 86 | 'video_nchannels': tf.io.FixedLenFeature([], tf.int64), 87 | 'audio': tf.io.FixedLenFeature([], tf.string), 88 | 'audio_nspectrograms': tf.io.FixedLenFeature([], tf.int64), 89 | 'audio_nmels': tf.io.FixedLenFeature([], tf.int64), 90 | 'audio_nhops': tf.io.FixedLenFeature([], tf.int64), 91 | 'caption': tf.io.FixedLenFeature([NUM_CHUNKS], tf.string), 92 | 'caption_nsentences': tf.io.FixedLenFeature([], tf.int64), 93 | } 94 | eps = 0.1 95 | 96 | def plot_spectrogram(log_mel, eps=0.1, ylabel='freq_bin', aspect='auto', xmax=None, to_db=True): 97 | import librosa 98 | fig, axs = plt.subplots(1, 1) 99 | spec = np.exp(log_mel + np.log(eps)) - eps 100 | if to_db: 101 | spec = librosa.power_to_db(spec, ref=np.max) 102 | axs.set_ylabel(ylabel) 103 | axs.set_xlabel('frame') 104 | im = axs.imshow(spec, origin='lower', aspect=aspect) 105 | if xmax: 106 | axs.set_xlim((0, xmax)) 107 | fig.colorbar(im, ax=axs) 108 | fig.tight_layout(pad=0) 109 | fig.canvas.draw() 110 | 111 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 112 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 113 | plt.close(fig) 114 | return data -------------------------------------------------------------------------------- /models/modeling.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import jax 4 | from jax import numpy as jnp, lax 5 | import numpy as np 6 | import flax 7 | import flax.linen as nn 8 | from typing import Any, Dict, Union, Optional, Sequence 9 | import clu.parameter_overview 10 | from copy import deepcopy 11 | from dataclasses import dataclass 12 | from models import discriminator 13 | 14 | from models.checkpoint import bf16_to_f32 15 | from models.models_vit import VisionTransformer 16 | from models.models_vqgan import Generator 17 | from models.vggish import VGG as audio_vgg 18 | from models.vgg import VGG 19 | from models.discriminator import Discriminator 20 | 21 | class Model(nn.Module): 22 | config: Dict = None 23 | model_str: str = 'generator' 24 | 25 | @classmethod 26 | def from_config(cls, config, model_str, **kwargs): 27 | my_config = deepcopy(config) 28 | my_config['model']['data'] = my_config['data'] 29 | return cls(config=my_config['model'], model_str=model_str, **kwargs) 30 | 31 | def setup(self): 32 | for k, v in self.config.items(): 33 | setattr(self, k, v) 34 | 35 | self.dtype = jnp.bfloat16 if self.config.get('use_bfloat16', False) else jnp.float32 36 | print(f"Using dtype {self.dtype}", flush=True) 37 | 38 | if self.model_str == 'generator': 39 | self.model = Generator( 40 | vocab_size=self.config['vocab_size'], 41 | proj_dim=self.config['proj_dim'], 42 | patch_size=self.config['data']['patch_size'], 43 | encoder_hidden_size=self.config['encoder_hidden_size'], 44 | encoder_num_layers=self.config['encoder_num_layers'], 45 | encoder_mlp_dim=self.config['encoder_mlp_dim'], 46 | encoder_num_heads=self.config['encoder_num_heads'], 47 | decoder_hidden_size=self.config['decoder_hidden_size'], 48 | decoder_num_layers=self.config['decoder_num_layers'], 49 | 
decoder_mlp_dim=self.config['decoder_mlp_dim'], 50 | decoder_num_heads=self.config['decoder_num_heads'], 51 | dtype=self.dtype, 52 | dropout_rate=self.config.get('dropout_rate', 0.0), 53 | droppath_rate=self.config.get(' droppath_rate', 0.0), 54 | attention_dropout_rate=self.config.get('attention_dropout_rate', 0.0), 55 | add_position_embedding=self.config.get('add_position_embedding', False), 56 | default_input_size=self.config.get('default_input_size', (256, 256)), 57 | output_channel=self.config.get('output_channel', 3), 58 | use_bias=self.config.get('use_bias', False), 59 | act_fn=self.config.get('act_fn', 'relu'), 60 | ) 61 | 62 | elif self.model_str == 'vgg': 63 | if self.config['data']['task'] == 'image': 64 | self.model = VGG(dtype=self.dtype) 65 | else: 66 | self.model = audio_vgg(dtype=self.dtype) 67 | 68 | elif self.model_str == 'discriminator': 69 | if self.config['data']['task'] == 'image': 70 | self.model = Discriminator(dtype=self.dtype) 71 | else: 72 | self.model = Discriminator( 73 | num_channels = 1, 74 | resolution = 64, 75 | dtype=self.dtype, 76 | use_clip=False) 77 | 78 | def init_from_dummy_batch( 79 | self, 80 | dummy_batch, 81 | seed=0, 82 | aux_rng_keys=["dropout", "drop_path"], 83 | ): 84 | 85 | if self.model_str == 'vgg': 86 | def init_model(rngs, x0, x1): 87 | return self.init(rngs, x0, x1, train=False) 88 | else: 89 | def init_model(rngs, x): 90 | return self.init(rngs, x, train=False) 91 | 92 | num_keys = len(aux_rng_keys) 93 | rng = jax.random.PRNGKey(seed) 94 | key, *subkeys = jax.random.split(rng, num_keys + 1) 95 | rng_keys = {aux_rng_keys[ix]: subkeys[ix] for ix in range(len(aux_rng_keys))} 96 | dummy_batch_jax = {k: jnp.asarray(v[0, 0, None]) for k, v in dummy_batch.items()} 97 | 98 | print("start compiling", flush=True) 99 | x = dummy_batch_jax['inputs'] 100 | if self.model_str == 'vgg': 101 | params = jax.jit(init_model, backend='cpu')({'params': key, **rng_keys}, x, x)['params'] 102 | # params = init_model({'params': key, **rng_keys}, x, x)['params'] 103 | else: 104 | params = jax.jit(init_model, backend='cpu')({'params': key, **rng_keys}, x)['params'] 105 | # params = init_model({'params': key, **rng_keys}, x)['params'] 106 | 107 | rngs = flax.core.FrozenDict(rng_keys) 108 | 109 | # in case anything got initialized to bf16 110 | params = bf16_to_f32(params) 111 | print(clu.parameter_overview.get_parameter_overview(params), flush=True) 112 | 113 | return params, rngs 114 | 115 | def __call__(self, batch): 116 | raise NotImplementedError() -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pretraining dataloader 3 | """ 4 | import time 5 | import math 6 | import tensorflow as tf 7 | import seqio 8 | import functools 9 | from copy import deepcopy 10 | import random 11 | import warnings 12 | import numpy as np 13 | import jax 14 | from jax.experimental import multihost_utils 15 | import clu.data 16 | 17 | from data.tasks import TaskRegistry 18 | from data.mixtures import MixtureRegistry 19 | 20 | from data.data_utils import get_shape_list 21 | 22 | with warnings.catch_warnings(): 23 | warnings.simplefilter("ignore") 24 | tf.config.experimental.set_visible_devices([], 'GPU') 25 | 26 | logger = tf.get_logger() 27 | 28 | def handle_batch(batched_tensor, num_devices=None, use_bfloat16=False): 29 | """ 30 | Deal with the fact that for a batched tensor, the pointers are off 31 | nvm i'm just not going to worry about 
that and make the pointers only valid in-batch since we never 32 | link to anything outside of the batch 33 | :param batched_tensor: 34 | :return: 35 | """ 36 | # Mask batch 37 | # logger.info("BEFORE HANDLING BATCH") 38 | # for k, v in batched_tensor.items(): 39 | # logger.info("{}: {}".format(k, v.shape)) 40 | 41 | batch_size, height, width, channel = get_shape_list(batched_tensor['inputs'], 4) 42 | if num_devices is not None: 43 | assert num_devices <= batch_size 44 | assert batch_size % num_devices == 0 45 | shape_prefix = [num_devices, batch_size // num_devices] 46 | # logger.info("{} devices: shape prefix is {}".format(num_devices, shape_prefix)) 47 | else: 48 | # logger.info("No devices, batch size is just {}".format(batch_size)) 49 | shape_prefix = [batch_size] 50 | 51 | batched_tensor["inputs"] = tf.reshape(batched_tensor['inputs'], shape_prefix + [height, width, channel]) 52 | 53 | if use_bfloat16: 54 | batched_tensor['inputs'] = tf.cast(batched_tensor['inputs'], dtype=tf.bfloat16) 55 | 56 | return batched_tensor 57 | 58 | 59 | def make_dataset(config, batch_size, current_host, num_hosts, num_devices=None, seed=None, is_training=True): 60 | """ 61 | Create seqio dataset 62 | :param merged_config: 63 | :param batch_size: 64 | :param current_host: 65 | :param num_hosts: 66 | :param num_devices: 67 | :param is_training: 68 | :return: 69 | """ 70 | merged_config = deepcopy(config['data']) 71 | merged_config.update(config['model']) 72 | if seed is not None: 73 | multihost_utils.assert_equal( 74 | np.array(seed), 75 | f'`seed` is not same across hosts; {jax.process_index()} has a seed of ' 76 | f'{seed}') 77 | logger.info( 78 | "Initializing dataset for task '%s' with a replica batch size of %d and " 79 | 'a seed of %d', merged_config['task_name'], batch_size, seed) 80 | 81 | mixture_or_task = seqio.get_mixture_or_task(merged_config['task_name']) 82 | 83 | shard_info = seqio.ShardInfo(index=current_host, num_shards=num_hosts) 84 | 85 | sequence_length = { 86 | 'input_size': merged_config['input_size'], 87 | 'patch_size': merged_config['patch_size'], 88 | 'rand_aug': merged_config.get('rand_aug', None), 89 | 'rand_erase': merged_config.get('rand_erase', 0.0), 90 | 'is_training': is_training, 91 | } 92 | 93 | dataset = mixture_or_task.get_dataset( 94 | split="train" if is_training else "validation", 95 | sequence_length=sequence_length, 96 | shuffle=True if is_training else False, 97 | shard_info=shard_info, 98 | trim_output_features=True, 99 | seed=None, 100 | ) 101 | 102 | dataset = dataset.batch(batch_size, drop_remainder=True) 103 | dataset = dataset.map(functools.partial(handle_batch, num_devices=num_devices, 104 | use_bfloat16=merged_config['use_bfloat16'])) 105 | return dataset 106 | 107 | 108 | def input_fn_builder(config, seed, is_training=True, make_dataset_fn=make_dataset): 109 | """ 110 | Get input fn for TPU use -- for training 111 | :param config: 112 | :param is_training: 113 | :param as_numpy_iter: 114 | :return: 115 | """ 116 | import jax 117 | from flax import jax_utils 118 | 119 | current_host = jax.process_index() 120 | num_hosts = jax.process_count() 121 | num_devices = jax.local_device_count() 122 | batch_size = config['device']['batch_size'] // num_hosts 123 | random.seed(seed) 124 | tf.random.set_seed(seed) 125 | 126 | dataset = make_dataset_fn( 127 | config, 128 | batch_size=batch_size, 129 | current_host=current_host, 130 | num_hosts=num_hosts, 131 | num_devices=num_devices, 132 | seed=seed, 133 | is_training=is_training, 134 | ) 135 | 136 | # dataset = 
clu.data.TfDatasetIterator(dataset) 137 | # return dataset 138 | 139 | def _multi_iterator0(): 140 | n_epochs = 0 141 | while True: 142 | print(f"Resetting iterator, epoch={n_epochs + 1}", flush=True) 143 | try: 144 | dataset_iter = iter(dataset) 145 | for item in dataset_iter: 146 | item = jax.tree_map(lambda x: x._numpy(), item) 147 | yield item 148 | except Exception as e: 149 | print(str(e)) 150 | time.sleep(5) 151 | n_epochs += 1 152 | 153 | if config['device'].get('prefetch_size', 1) > 0: 154 | return jax_utils.prefetch_to_device(_multi_iterator0(), size=config['device'].get('prefetch_size', 1)) 155 | return _multi_iterator0() -------------------------------------------------------------------------------- /data/tasks.py: -------------------------------------------------------------------------------- 1 | # Model that can be imported to register all tasks 2 | import os 3 | from seqio import FileDataSource, TaskRegistry 4 | 5 | from data.metrics import * 6 | from data.preprocessors import * 7 | 8 | from config import * 9 | 10 | 11 | TaskRegistry.add( 12 | "encoder_only_imagenet2012", 13 | source=seqio.TfdsDataSource( 14 | tfds_name="imagenet2012:5.1.0", 15 | tfds_data_dir=TFDS_DATA_DIR, 16 | ), 17 | preprocessors=[ 18 | functools.partial( 19 | rekey, key_map={ 20 | "image": ["image"], 21 | "label": ["label"] 22 | }), 23 | functools.partial( 24 | encoder_only_preprocessor, 25 | ), 26 | ], 27 | metric_fns=[cls_accuracy_metric], 28 | output_features=IMAGE_FEATURES, 29 | ) 30 | 31 | TaskRegistry.add( 32 | "vitvqgan_imagenet2012", 33 | source=seqio.TfdsDataSource( 34 | tfds_name="imagenet2012:5.1.0", 35 | tfds_data_dir=TFDS_DATA_DIR, 36 | ), 37 | preprocessors=[ 38 | functools.partial( 39 | rekey, key_map={ 40 | "image": ["image"], 41 | }), 42 | functools.partial( 43 | vit_vqgan_preprocessor, 44 | ), 45 | ], 46 | metric_fns=[cls_accuracy_metric], 47 | output_features=VIT_VQGAN_OUTPUT_FEATURES, 48 | ) 49 | 50 | TFRECORD_LAION400M_FEATURES = { 51 | 'image': tf.io.FixedLenFeature(shape=(), dtype=tf.string), 52 | 'text':tf.io.FixedLenFeature(shape=(), dtype=tf.string), 53 | } 54 | 55 | TaskRegistry.add( 56 | "vit_vqgan_liaon_400m", 57 | source=seqio.TFExampleDataSource( 58 | split_to_filepattern={ 59 | "train": os.path.join('gs://unified-io-2/pretrain-datasets', "laion400m", "1.0.0", "laion400m-train*"), 60 | }, 61 | feature_description=TFRECORD_LAION400M_FEATURES, 62 | ), 63 | preprocessors=[ 64 | functools.partial( 65 | rekey, key_map={ 66 | "image": ["image"], 67 | }), 68 | functools.partial( 69 | vit_vqgan_preprocessor, 70 | decode_jpeg=True, 71 | ), 72 | ], 73 | metric_fns=[], 74 | output_features=VIT_VQGAN_OUTPUT_FEATURES, 75 | ) 76 | 77 | TaskRegistry.add( 78 | "vit_vqgan_imagenet2012", 79 | source=seqio.TfdsDataSource( 80 | tfds_name="imagenet2012:5.1.0", 81 | tfds_data_dir=TFDS_DATA_DIR, 82 | ), 83 | preprocessors=[ 84 | functools.partial( 85 | rekey, key_map={ 86 | "image": ["image"], 87 | }), 88 | functools.partial( 89 | vit_vqgan_preprocessor, 90 | decode_jpeg=False, 91 | ), 92 | ], 93 | metric_fns=[], 94 | output_features=VIT_VQGAN_OUTPUT_FEATURES, 95 | ) 96 | 97 | TaskRegistry.add( 98 | "caltech_birds2011", 99 | source=seqio.TfdsDataSource( 100 | tfds_name="caltech_birds2011:0.1.1", 101 | tfds_data_dir=TFDS_DATA_DIR, 102 | ), 103 | preprocessors=[ 104 | functools.partial( 105 | rekey, key_map={ 106 | "image": ["image"], 107 | }), 108 | functools.partial( 109 | vit_vqgan_preprocessor, 110 | decode_jpeg=False, 111 | ), 112 | ], 113 | metric_fns=[], 114 | 
output_features=VIT_VQGAN_OUTPUT_FEATURES, 115 | ) 116 | 117 | audioset_keys = list(AUDIO_FEATURE_DESCRIPTION.keys()) 118 | audioset_keymap = {key: [key] for key in audioset_keys} 119 | audioset_keymap["class_name"] = ["text"] 120 | del audioset_keymap["text"] 121 | 122 | TaskRegistry.add( 123 | "vit_vqgan_audioset", 124 | source=seqio.TFExampleDataSource( 125 | split_to_filepattern={ 126 | "train": os.path.join("gs://unified-io-2/pretrain-datasets", "audioset", "1.0.0", "audioset-train*"), 127 | }, 128 | feature_description=AUDIO_FEATURE_DESCRIPTION, 129 | ), 130 | preprocessors=[ 131 | functools.partial( 132 | rekey, key_map=audioset_keymap, 133 | ), 134 | functools.partial( 135 | audio_preprocessor, 136 | decode_video_string=True 137 | ), 138 | ], 139 | metric_fns=[], 140 | output_features=VIT_VQGAN_OUTPUT_FEATURES, 141 | ) 142 | 143 | TaskRegistry.add( 144 | "vit_vqgan_acav20m", 145 | source=seqio.TFExampleDataSource( 146 | split_to_filepattern={ 147 | "train": os.path.join("gs://unified-io-2/pretrain-datasets", "acav20m_v1", "1.0.0", "acav20m-train*"), 148 | }, 149 | feature_description=ACAV20M_FEATURE_DESCRIPTION, 150 | ), 151 | preprocessors=[ 152 | functools.partial( 153 | rekey, key_map=audioset_keymap, 154 | ), 155 | functools.partial( 156 | audio_preprocessor, 157 | decode_video_string=True 158 | ), 159 | ], 160 | metric_fns=[], 161 | output_features=VIT_VQGAN_OUTPUT_FEATURES, 162 | ) 163 | 164 | TaskRegistry.add( 165 | "vit_vqgan_hdvila10m", 166 | source=seqio.TFExampleDataSource( 167 | split_to_filepattern={ 168 | "train": os.path.join("gs://unified-io-2/pretrain-datasets", "hdvila10m", "1.0.0", "hdvila10m-train*"), 169 | }, 170 | feature_description=ACAV20M_FEATURE_DESCRIPTION, 171 | ), 172 | preprocessors=[ 173 | functools.partial( 174 | rekey, key_map=audioset_keymap, 175 | ), 176 | functools.partial( 177 | audio_preprocessor, 178 | decode_video_string=True 179 | ), 180 | ], 181 | metric_fns=[], 182 | output_features=VIT_VQGAN_OUTPUT_FEATURES, 183 | ) 184 | 185 | TaskRegistry.add( 186 | "vit_vqgan_yttemoporal1b", 187 | source=seqio.TFExampleDataSource( 188 | split_to_filepattern={ 189 | "train": os.path.join("gs://unified-io-2/pretrain-datasets", "yttemporal1b", "1.0.0", "yttemporal1b-train*"), 190 | }, 191 | feature_description=YTTEMOPORAL1B_FEATURE_DESCRIPTION, 192 | ), 193 | preprocessors=[ 194 | functools.partial( 195 | rekey, key_map=audioset_keymap, 196 | ), 197 | functools.partial( 198 | audio_preprocessor, 199 | decode_video_string=True, 200 | random_start=False, 201 | ), 202 | ], 203 | metric_fns=[], 204 | output_features=VIT_VQGAN_OUTPUT_FEATURES, 205 | ) 206 | -------------------------------------------------------------------------------- /training/train_state.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified training state with easy param freezing support and updating rngs over training iterations 3 | 4 | Shamelessly copied from audax (https://github.com/SarthakYadav/audax/blob/master/audax/training_utils/trainstate.py) 5 | Written for audax by / Copyright 2022, Sarthak Yadav 6 | """ 7 | from typing import Any, Callable, Dict 8 | import jax 9 | from flax import core 10 | from flax import struct 11 | import optax 12 | import flax 13 | 14 | 15 | class TrainState_v2(struct.PyTreeNode): 16 | """Simple train state for the common case with a single Optax optimizer. 
17 | 18 | Synopsis:: 19 | state = TrainState.create( 20 | apply_fn=model.apply, 21 | params=variables['params'], 22 | tx=tx) 23 | grad_fn = jax.grad(make_loss_fn(state.apply_fn)) 24 | for batch in data: 25 | grads = grad_fn(state.params, batch) 26 | state = state.apply_gradients(grads=grads) 27 | 28 | Note that you can easily extend this dataclass by subclassing it for storing 29 | additional data (e.g. additional variable collections). 30 | 31 | For more exotic usecases (e.g. multiple optimizers) it's probably best to 32 | fork the class and modify it. 33 | 34 | Args: 35 | step: Counter starts at 0 and is incremented by every call to 36 | `.apply_gradients()`. 37 | apply_fn: Usually set to `model.apply()`. Kept in this dataclass for 38 | convenience to have a shorter params list for the `train_step()` function 39 | in your training loop. 40 | params: The parameters to be updated by `tx` and used by `apply_fn`. 41 | frozen_params: 42 | tx: An Optax gradient transformation. 43 | opt_state: The state for `tx`. 44 | """ 45 | step: int 46 | params_g: core.FrozenDict[str, Any] 47 | params_d: core.FrozenDict[str, Any] 48 | params_p: core.FrozenDict[str, Any] 49 | frozen_params_g: core.FrozenDict[str, Any] 50 | frozen_params_d: core.FrozenDict[str, Any] 51 | aux_rng_keys_g: core.FrozenDict[str, Any] 52 | aux_rng_keys_d: core.FrozenDict[str, Any] 53 | aux_rng_keys_p: core.FrozenDict[str, Any] 54 | opt_state_g: optax.OptState 55 | opt_state_d: optax.OptState 56 | apply_fn_g: Callable = struct.field(pytree_node=False) 57 | apply_fn_d: Callable = struct.field(pytree_node=False) 58 | apply_fn_p: Callable = struct.field(pytree_node=False) 59 | tx_g: optax.GradientTransformation = struct.field(pytree_node=False) 60 | tx_d: optax.GradientTransformation = struct.field(pytree_node=False) 61 | 62 | def apply_gradients_g(self, *, grads, **kwargs): 63 | """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. 64 | 65 | Note that internally this function calls `.tx.update()` followed by a call 66 | to `optax.apply_updates()` to update `params` and `opt_state`. 67 | 68 | Args: 69 | grads: Gradients that have the same pytree structure as `.params`. 70 | **kwargs: Additional dataclass attributes that should be `.replace()`-ed. 71 | 72 | Returns: 73 | An updated instance of `self` with `step` incremented by one, `params` 74 | and `opt_state` updated by applying `grads`, and additional attributes 75 | replaced as specified by `kwargs`. 76 | """ 77 | 78 | updates, new_opt_state = self.tx_g.update(grads, self.opt_state_g, self.params_g) 79 | new_params = optax.apply_updates(self.params_g, updates) 80 | rng_keys = self.update_rng_keys_g() 81 | 82 | return self.replace( 83 | params_g=new_params, 84 | frozen_params_g=self.frozen_params_g, 85 | opt_state_g=new_opt_state, 86 | aux_rng_keys_g=rng_keys, 87 | **kwargs) 88 | 89 | def apply_gradients_d(self, *, grads, **kwargs): 90 | """Updates `step`, `params`, `opt_state` and `**kwargs` in return value. 91 | 92 | Note that internally this function calls `.tx.update()` followed by a call 93 | to `optax.apply_updates()` to update `params` and `opt_state`. 94 | 95 | Args: 96 | grads: Gradients that have the same pytree structure as `.params`. 97 | **kwargs: Additional dataclass attributes that should be `.replace()`-ed. 98 | 99 | Returns: 100 | An updated instance of `self` with `step` incremented by one, `params` 101 | and `opt_state` updated by applying `grads`, and additional attributes 102 | replaced as specified by `kwargs`. 
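    Note that, unlike flax's stock `TrainState.apply_gradients`, neither
    `apply_gradients_g` nor `apply_gradients_d` actually increments `step`;
    the training loop advances it explicitly with
    `state.replace(step=state.step + 1)`.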
103 | """ 104 | 105 | updates, new_opt_state = self.tx_d.update(grads, self.opt_state_d, self.params_d) 106 | new_params = optax.apply_updates(self.params_d, updates) 107 | rng_keys = self.update_rng_keys_d() 108 | 109 | return self.replace( 110 | params_d=new_params, 111 | frozen_params_d=self.frozen_params_d, 112 | opt_state_d=new_opt_state, 113 | aux_rng_keys_d=rng_keys, 114 | **kwargs) 115 | 116 | 117 | def update_rng_keys_g(self): 118 | unfrozen = flax.core.unfreeze(self.aux_rng_keys_g) 119 | for k in self.aux_rng_keys_g.keys(): 120 | unfrozen[k] = jax.random.split(unfrozen[k], 1)[0] 121 | return flax.core.freeze(unfrozen) 122 | 123 | def update_rng_keys_d(self): 124 | unfrozen = flax.core.unfreeze(self.aux_rng_keys_d) 125 | for k in self.aux_rng_keys_d.keys(): 126 | unfrozen[k] = jax.random.split(unfrozen[k], 1)[0] 127 | return flax.core.freeze(unfrozen) 128 | 129 | @property 130 | def get_all_params(self): 131 | return {**self.params, **self.frozen_params} 132 | 133 | @classmethod 134 | def create(cls, *, models, params, frozen_params, txs, aux_rng_keys, **kwargs): 135 | """Creates a new instance with `step=0` and initialized `opt_state`.""" 136 | opt_states = {} 137 | for k, tx in txs.items(): 138 | opt_states[k] = tx.init(params[k]) 139 | 140 | apply_fns = {k: v.apply for k, v in models.items()} 141 | 142 | return cls( 143 | step=0, 144 | apply_fn_g = apply_fns['g'], 145 | apply_fn_d = apply_fns['d'], 146 | apply_fn_p = apply_fns['p'], 147 | params_g=params['g'], 148 | params_d=params['d'], 149 | params_p=params['p'], 150 | frozen_params_g=frozen_params['g'], 151 | frozen_params_d=frozen_params['d'], 152 | tx_g=txs['g'], 153 | tx_d=txs['d'], 154 | opt_state_g=opt_states['g'], 155 | opt_state_d=opt_states['d'], 156 | aux_rng_keys_g=aux_rng_keys['g'], 157 | aux_rng_keys_d=aux_rng_keys['d'], 158 | aux_rng_keys_p=aux_rng_keys['p'], 159 | **kwargs, 160 | ) -------------------------------------------------------------------------------- /models/checkpoint.py: -------------------------------------------------------------------------------- 1 | from flax.training import checkpoints 2 | from training.train_state import TrainState_v2 3 | import flax 4 | import jax 5 | from typing import Optional, Any 6 | import clu.parameter_overview 7 | import operator 8 | import jax.numpy as jnp 9 | import os 10 | from pathlib import Path 11 | 12 | 13 | def _treemap_cast(from_dtype, to_dtype, tree): 14 | """ 15 | Convert leaves in a tree from `from_dtype` to `to_dtype` 16 | :param from_dtype: 17 | :param to_dtype: 18 | :param tree: 19 | :return: 20 | """ 21 | 22 | def _do_cast(x): 23 | if not hasattr(x, 'dtype'): # for ints and stuff 24 | return x 25 | if x.dtype == from_dtype: 26 | return x.astype(to_dtype) 27 | return x 28 | 29 | return jax.tree_map(_do_cast, tree) 30 | 31 | 32 | def _compress_state(state: TrainState_v2): 33 | """ 34 | For saving i'll cast float32 down to float16, keep bfloat unchanged 35 | I'm doing this because float16 has more precision 36 | :param state: 37 | :return: 38 | """ 39 | return _treemap_cast(from_dtype=jnp.float32, to_dtype=jnp.float16, tree=state) 40 | 41 | 42 | def _decompress_state(state: TrainState_v2): 43 | return _treemap_cast(from_dtype=jnp.float16, to_dtype=jnp.float32, tree=state) 44 | 45 | 46 | def bf16_to_f32(params): 47 | """ 48 | Cast params to float32 49 | :param params: 50 | :return: 51 | """ 52 | return _treemap_cast(from_dtype=jnp.bfloat16, to_dtype=jnp.float32, tree=params) 53 | 54 | 55 | def f32_to_bf16(params): 56 | """ 57 | Cast params to float32 58 | 
:param params: 59 | :return: 60 | """ 61 | return _treemap_cast(from_dtype=jnp.float32, to_dtype=jnp.bfloat16, tree=params) 62 | 63 | 64 | def _flatten_info(input_dict, *, prefix="", delimiter="/"): 65 | output_keys = [] 66 | for key, value in input_dict.items(): 67 | nested_key = f"{prefix}{delimiter}{key}" if prefix else key 68 | if isinstance(value, (dict, flax.core.FrozenDict)): 69 | output_keys += _flatten_info(value, prefix=nested_key, delimiter=delimiter) 70 | else: 71 | output_keys += [(nested_key, value.shape)] 72 | return output_keys 73 | 74 | 75 | def initialize_using_checkpoint(model_params, cache_params): 76 | model_params = flax.core.unfreeze(model_params) 77 | for key in cache_params: 78 | assert key in model_params 79 | model_info = _flatten_info(model_params[key]) 80 | cache_info = _flatten_info(cache_params[key]) 81 | assert set(model_info) == set(cache_info) 82 | model_params[key] = cache_params[key] 83 | model_params = flax.core.freeze(model_params) 84 | return model_params 85 | 86 | 87 | def save_checkpoint(state: TrainState_v2, path: str, keep=None, overwrite=True, with_shard_optimizer=False, 88 | no_optimizer=False): 89 | """ 90 | :param state: 91 | :param path: Path where we'll save stuff to 92 | :param keep: If specified this is how many we should keep 93 | :param overwrite: If we should overwrite 94 | :return: 95 | """ 96 | 97 | step = int(state.step[0]) 98 | if keep is None: 99 | keep = 100000000 100 | 101 | if jax.process_index() == 0: 102 | print(f"Saving checkpoint at step {step}, path {path}", flush=True) 103 | 104 | if with_shard_optimizer: 105 | print("Dealing with sharded optimizer", flush=True) 106 | params_g = jax.device_get(jax.tree_map(lambda x: x[0], state.params_g)) 107 | params_d = jax.device_get(jax.tree_map(lambda x: x[0], state.params_d)) 108 | params_p = jax.device_get(jax.tree_map(lambda x: x[0], state.params_p)) 109 | 110 | opt_state_g = jax.device_get(state.opt_state_g) 111 | opt_state_d = jax.device_get(state.opt_state_d) 112 | 113 | state = state.replace(step=step, 114 | params_g=params_g, 115 | params_d=params_d, 116 | params_p=params_p, 117 | opt_state_g=opt_state_g, 118 | opt_state_d=opt_state_d) 119 | elif no_optimizer: 120 | print("Not including the optimizer state", flush=True) 121 | params_g = jax.device_get(jax.tree_map(lambda x: x[0], state.params_g)) 122 | params_d = jax.device_get(jax.tree_map(lambda x: x[0], state.params_d)) 123 | params_p = jax.device_get(jax.tree_map(lambda x: x[0], state.params_p)) 124 | 125 | state = state.replace(step=step, 126 | params_g=params_g, 127 | params_d=params_d, 128 | params_p=params_p, 129 | opt_state=None, 130 | ) 131 | else: 132 | # Get first replica 133 | state = jax.device_get(jax.tree_map(lambda x: x[0], state)) 134 | 135 | state = _compress_state(state) 136 | 137 | checkpoints.save_checkpoint(path, state, step=step, prefix='ckpt_', keep=keep, overwrite=overwrite) 138 | 139 | 140 | def load_checkpoint(path: str, state: Optional[TrainState_v2] = None, step=None, use_bfloat16_weights=False): 141 | """ 142 | Loads a checkpoint. I'm saving the weights in float16 and the adam variables in a weird bfloat16 format. 
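    Weights that `save_checkpoint` compressed to float16 are cast back to
    float32 on load (and optionally down to bfloat16 when
    `use_bfloat16_weights=True`). A minimal usage sketch, with a placeholder
    path:

        state = load_checkpoint(path='gs://my-bucket/ckpts', state=state, step=None)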
143 | :param state: 144 | :param path: 145 | :param step: 146 | :param to_float32: Whether to convert weights to float32 -- needed for training 147 | :return: 148 | """ 149 | # Temporarily compress the state to be equal to what we're loading 150 | if state is not None: 151 | state = _compress_state(state) 152 | 153 | state = checkpoints.restore_checkpoint(ckpt_dir=path, target=state, step=step, prefix='ckpt_', parallel=True) 154 | state = _decompress_state(state) 155 | if use_bfloat16_weights: 156 | state = state.replace( 157 | params_g=f32_to_bf16(state.params_g), 158 | params_d=f32_to_bf16(state.params_d), 159 | params_p=f32_to_bf16(state.params_p)) 160 | 161 | return state 162 | 163 | 164 | def log_param_shapes(params: Any) -> int: 165 | """ 166 | # Maybe could be useful: 167 | https://github.com/google-research/scenic/blob/ab3083d8cbfe3216119a0f24fce23ca988e20355/scenic/common_lib/debug_utils.py 168 | 169 | Prints out shape of parameters and total number of trainable parameters. 170 | Args: 171 | params: PyTree of model parameters. 172 | print_params_nested_dict: If True, it prints parameters in shape of a nested 173 | dict. 174 | Returns: 175 | int; Total number of trainable parameters. 176 | """ 177 | print(clu.parameter_overview.get_parameter_overview(params)) 178 | total_params = jax.tree_util.tree_reduce(operator.add, jax.tree_map(lambda x: x.size, params)) 179 | # logging.info('Total params: %d', total_params) 180 | return total_params 181 | 182 | 183 | def tree_map_nested_keys(f, params): 184 | """ 185 | Tree map, but you get the KEY and the VALUE 186 | :param f: function returning nested keys joined by a '/' and values 187 | :param params: 188 | :return: new tree 189 | """ 190 | leaves, treedef = jax.tree_util.tree_flatten(params) 191 | params_flat = clu.parameter_overview.flatten_dict(params) 192 | for i, k in enumerate(sorted(params_flat.keys())): 193 | assert params_flat[k] is leaves[i] 194 | leaves[i] = f(k, leaves[i]) 195 | return treedef.unflatten(leaves) 196 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is the training script 3 | """ 4 | 5 | import sys 6 | 7 | import os 8 | import yaml 9 | from datetime import datetime 10 | import pytz 11 | import jax 12 | import jax.numpy as jnp 13 | from data.dataloader import input_fn_builder 14 | from training.train_model import * 15 | from flax import jax_utils 16 | from training.optimization import construct_train_state 17 | from models.checkpoint import initialize_using_checkpoint, save_checkpoint, load_checkpoint, bf16_to_f32 18 | from jax.experimental import multihost_utils 19 | import argparse 20 | import numpy as np 21 | import functools 22 | import time 23 | import decimal 24 | import simplejson 25 | from config import plot_spectrogram, eps 26 | 27 | # jax.config.update('jax_log_compiles', True) 28 | is_on_gpu = any([x.platform == 'gpu' for x in jax.local_devices()]) 29 | if not is_on_gpu: 30 | assert any([x.platform == 'tpu' for x in jax.local_devices()]) 31 | print('JAX process: {} / {}. Local devices {}. 
Using {}'.format( 32 | jax.process_index(), jax.process_count(), jax.local_devices(), 'GPU' if is_on_gpu else 'TPU'), flush=True) 33 | 34 | parser = argparse.ArgumentParser(description='Train model!') 35 | parser.add_argument( 36 | 'config_file', 37 | help='Where the config.yaml is located', 38 | type=str, 39 | ) 40 | parser.add_argument( 41 | '-output_dir', 42 | help='Override output directory (otherwise we do whats in the config file and add timestamp).', 43 | dest='output_dir', 44 | default='', 45 | type=str, 46 | ) 47 | 48 | parser.add_argument( 49 | '-disable_wandb', 50 | help='dont log this result on weights and biases', 51 | dest='disable_wandb', 52 | action='store_true', 53 | ) 54 | args = parser.parse_args() 55 | 56 | print(f"Loading from {args.config_file}", flush=True) 57 | with open(args.config_file, 'r') as f: 58 | config = yaml.load(f, yaml.FullLoader) 59 | 60 | seattle_time = pytz.utc.localize(datetime.utcnow()).astimezone(pytz.timezone('America/Los_Angeles')) 61 | seattle_time = seattle_time.strftime("%Y-%m-%d-%H:%M.%S") 62 | 63 | if is_on_gpu: 64 | config['data']['num_train_files'] = 1 65 | config['device']['output_dir'] = 'temp' 66 | config['model']['use_bfloat16'] = False 67 | config['device']['batch_size'] = 6 68 | 69 | config['optimizer']['num_train_steps_override'] = 1000 70 | elif args.output_dir == '': 71 | config['device']['output_dir'] = os.path.join(config['device']['output_dir'], seattle_time) 72 | else: 73 | config['device']['output_dir'] = args.output_dir 74 | 75 | config['_path'] = args.config_file 76 | if (jax.process_index() == 0) and (not is_on_gpu) and (not args.disable_wandb): 77 | import wandb 78 | wandb_api, wandb_project, wandb_entity, wandb_name = ( 79 | config['device']['wandb_api'], 80 | config['device']['wandb_project'], 81 | config['device']['wandb_entity'], 82 | config['device']['wandb_name'], 83 | ) 84 | del config['device']['wandb_api'] 85 | del config['device']['wandb_project'] 86 | del config['device']['wandb_entity'] 87 | del config['device']['wandb_name'] 88 | os.environ["WANDB_API_KEY"] = wandb_api 89 | wandb.init( 90 | project=wandb_project, 91 | entity=wandb_entity, 92 | name=wandb_name, 93 | config=config, 94 | ) 95 | else: 96 | wandb = None 97 | 98 | 99 | seed = config['device'].get('seed', None) 100 | if seed is None: 101 | seed = multihost_utils.broadcast_one_to_all(np.int32(time.time())) 102 | 103 | ds_train_iter = input_fn_builder(config, seed, is_training=True) 104 | 105 | dummy_batch = next(ds_train_iter) 106 | 107 | for k, v in dummy_batch.items(): 108 | print("{}: {} {}".format(k, v.shape, v.dtype), flush=True) 109 | 110 | aux_rng_keys=["dropout", "drop_path"] 111 | 112 | generator = Generator.from_config(config, 'generator') 113 | discriminator = Discriminator.from_config(config, 'discriminator') 114 | lpips = LPIPS.from_config(config, 'vgg') 115 | 116 | if is_on_gpu: 117 | print("DEBUG GPU BATCH!", flush=True) 118 | rng = jax.random.PRNGKey(0) 119 | num_keys = len(aux_rng_keys) 120 | key, *subkeys = jax.random.split(rng, num_keys + 1) 121 | rng_keys = {aux_rng_keys[ix]: subkeys[ix] for ix in range(len(aux_rng_keys))} 122 | generator.init({'params': key, **rng_keys}, {k: jnp.asarray(v[0]) for k, v in dummy_batch.items()}) 123 | 124 | g_params, g_rng_keys = generator.init_from_dummy_batch(dummy_batch, seed, aux_rng_keys) 125 | d_params, d_rng_keys = discriminator.init_from_dummy_batch(dummy_batch, seed, aux_rng_keys) 126 | p_params, p_rng_keys = lpips.init_from_dummy_batch(dummy_batch, seed, aux_rng_keys) 127 | 128 | state = 
construct_train_state( 129 | opt_config={'g': config['optimizer_g'], 'd': config['optimizer_d']}, 130 | models={'g': generator, 'd': discriminator, 'p': lpips}, 131 | params={'g': g_params, 'd': d_params, 'p': p_params}, 132 | rng_keys={'g': g_rng_keys, 'd': d_rng_keys, 'p': p_rng_keys}) 133 | 134 | step = None 135 | # Initialize params using merlot reserve checkpoint 136 | ckpt_path = config['device'].get('initialize_ckpt', '') 137 | if ckpt_path: 138 | ckpt = load_checkpoint(path=ckpt_path) 139 | cache_params_g = ckpt['params_g'] 140 | cache_params_d = ckpt['params_d'] 141 | cache_params_p = ckpt['params_p'] 142 | cache_opt_state_g = ckpt['opt_state_g'] 143 | cache_opt_state_d = ckpt['opt_state_d'] 144 | step = ckpt['step'] 145 | del ckpt 146 | 147 | print(f"{ckpt_path}: {list(cache_params_g.keys())} loaded on the model", flush=True) 148 | print(f"{ckpt_path}: {list(cache_params_d.keys())} loaded on the model", flush=True) 149 | print(f"{ckpt_path}: {list(cache_params_p.keys())} loaded on the model", flush=True) 150 | 151 | state = state.replace( 152 | params_p=initialize_using_checkpoint(state.params_p, cache_params_p), 153 | params_g=initialize_using_checkpoint(state.params_g, cache_params_g), 154 | params_d=initialize_using_checkpoint(state.params_d, cache_params_d), 155 | step = step, 156 | ) 157 | 158 | # load if we can 159 | state = load_checkpoint(state=state, path=config['device']['initialize_ckpt'], step=None, 160 | use_bfloat16_weights=config['optimizer_g'].get('use_bfloat16_weights', False)) 161 | start_step = int(state.step) 162 | state = jax_utils.replicate(state) 163 | 164 | p_train_step = jax.pmap(functools.partial(train_step, config=config,), 165 | axis_name='batch', donate_argnums=(0, 1,)) 166 | 167 | 168 | # p_train_step = jax.vmap(functools.partial(train_step, config=config,), 169 | # axis_name='batch')#, donate_argnums=(0, 1,)) 170 | 171 | train_metrics = [] 172 | time_elapsed = [] 173 | num_train_steps = config['optimizer_g'].get('num_train_steps_override', config['optimizer_g']['num_train_steps']) 174 | log_every = config['device'].get('commit_every_nsteps', 50) 175 | 176 | for n in range(start_step, num_train_steps): 177 | st = time.time() 178 | batch = next(ds_train_iter) 179 | state, loss_info = p_train_step(state, batch) 180 | 181 | # Async transfer. 
Basically we queue the last thing, then log the thing from `log_every` iterations ago 182 | if jax.process_index() == 0: 183 | image_info = {k:v[0] for k, v in loss_info.items() if 'image' in k} 184 | jax.tree_map(lambda x: x.copy_to_host_async(), image_info) 185 | 186 | loss_info = {k:v for k, v in loss_info.items() if 'image' not in k} 187 | train_metrics.append(jax.tree_map(lambda x: x[0], loss_info)) 188 | jax.tree_map(lambda x: x.copy_to_host_async(), train_metrics[-1]) 189 | 190 | step_for_logging = n - log_every 191 | if step_for_logging >= 0: 192 | train_metrics[step_for_logging] = {k: float(v) for k, v in train_metrics[step_for_logging].items()} 193 | tmp_metrics = {k:v for k, v in train_metrics[step_for_logging].items()} 194 | if (n + 1) % log_every == 0: 195 | if wandb is not None: 196 | for k, v in image_info.items(): tmp_metrics['image' + '/' + k] = wandb.Image(np.array(v[0]), caption=k) 197 | stats = { 198 | k: decimal.Decimal("{:.6f}".format(v)) if isinstance(v, float) else v 199 | for k, v in train_metrics[step_for_logging].items() 200 | } 201 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 202 | print("@iter {} stats: {:s}".format(step_for_logging + start_step, json_stats), flush=True) 203 | if wandb is not None: 204 | wandb.log(tmp_metrics, step=step_for_logging + start_step, commit=(n + 1) % log_every == 0) 205 | 206 | if (n + 1) % config['device']['save_every_nsteps'] == 0 or (n + 1) == num_train_steps: 207 | save_checkpoint(state, path=config['device']['output_dir']) 208 | print(f"Saving @iter {n:03d}.", flush=True) 209 | 210 | time_elapsed.append(time.time() - st) 211 | if len(time_elapsed) >= 100: 212 | tsum = sum(time_elapsed) 213 | print("Completed 100 batches in {:.3f}sec, avg {:.3f} it/sec".format(tsum, 100.0/tsum), flush=True) 214 | time_elapsed = [] 215 | 216 | if wandb is not None: 217 | wandb.finish() -------------------------------------------------------------------------------- /training/train_model_2.py: -------------------------------------------------------------------------------- 1 | from models.modeling import * 2 | from models.checkpoint import f32_to_bf16, bf16_to_f32 3 | from training.train_state import TrainState_v2 4 | 5 | class Generator(Model): 6 | """ 7 | Generator 8 | """ 9 | def __call__(self, x, *, train): 10 | """ 11 | Does a forward pass for pretraining 12 | :param batch: Everything from pretraining 13 | :return: 14 | """ 15 | rec, diff = self.model(x, train=train) 16 | return rec, diff 17 | 18 | class Discriminator(Model): 19 | """ 20 | Discriminator 21 | """ 22 | def __call__(self, x, *, train): 23 | """ 24 | Does a forward pass for pretraining 25 | :param batch: Everything from pretraining 26 | :return: 27 | """ 28 | output = self.model(x, train=train) 29 | return output 30 | 31 | class LPIPS(Model): 32 | """ 33 | LPIPS 34 | """ 35 | def __call__(self, x0, x1, *, train): 36 | """ 37 | Does a forward pass for pretraining 38 | :param batch: Everything from pretraining 39 | :return: 40 | """ 41 | output = self.model(x0, x1, train=train) 42 | return output 43 | 44 | def vanilla_d_loss(logits_fake, logits_real = None): 45 | loss_fake = jnp.mean(jax.nn.softplus(-logits_fake)) * 2 if logits_real is None else jnp.mean(jax.nn.softplus(logits_fake)) 46 | loss_real = 0 if logits_real is None else jnp.mean(jax.nn.softplus(-logits_real)) 47 | return 0.5 * (loss_real + loss_fake) 48 | 49 | def hinge_d_loss(logits_fake, logits_real = None): 50 | loss_fake = jnp.mean(jax.nn.relu(1.0 + logits_fake)) 51 | loss_real = 
jnp.mean(jax.nn.relu(1.0 - logits_real)) 52 | return 0.5 * (loss_real + loss_fake) 53 | 54 | def calculate_adaptive_weight(): 55 | pass 56 | 57 | def VQLPIPS( 58 | input, 59 | rec, 60 | codebook_loss, 61 | p_logits, 62 | g_logits, 63 | codebook_weight: float = 1.0, 64 | loggaussian_weight: float = 1.0, 65 | loglaplace_weight: float = 0.0, 66 | perceptual_weight: float = 0.1, 67 | adversarial_weight: float = 0.1, 68 | disc_g_start: int = 0, 69 | step: int = 0, 70 | ): 71 | 72 | loglaplace_loss = jnp.mean(jnp.abs(rec - input)) 73 | loggaussian_loss = jnp.mean((rec - input)**2) 74 | perceptual_loss = jnp.mean(p_logits) 75 | 76 | d_weight_factor = jnp.where(step > disc_g_start, adversarial_weight,0.0) 77 | adversarial_loss = jnp.mean(vanilla_d_loss(g_logits)) 78 | 79 | nll_loss = loglaplace_weight * loglaplace_loss + loggaussian_weight * loggaussian_loss \ 80 | + perceptual_weight * perceptual_loss + d_weight_factor * adversarial_loss 81 | 82 | loss = nll_loss + codebook_weight * codebook_loss 83 | return loss, (loglaplace_loss, loggaussian_loss, perceptual_loss, adversarial_loss, codebook_loss) 84 | 85 | def train_step(state: TrainState_v2, batch, config=None): 86 | """ 87 | Note: we'll compile this with pmap so no need to jit 88 | :param state: 89 | :param batch: 90 | :param use_bfloat16_grads: Whether to use bfloat16 for storing grads. I think it is probably OK considering 91 | momentum is bfloat16 anyways. i'm just going to cast down (rounding down here rather 92 | than to nearest or anything) 93 | :return: 94 | """ 95 | def _loss_fn_nll(params): 96 | rec, codebook_loss = state.apply_fn_g( 97 | {'params': params}, 98 | rngs=state.aux_rng_keys_g, 99 | x=batch['inputs'], 100 | train=True, 101 | ) 102 | 103 | p_logits = state.apply_fn_p( 104 | {'params': state.params_p}, 105 | rngs=state.aux_rng_keys_p, 106 | x0=batch['inputs'], 107 | x1=rec, 108 | train=False, 109 | ) 110 | 111 | g_logits = state.apply_fn_d( 112 | {'params': state.params_d}, 113 | rngs=state.aux_rng_keys_d, 114 | x=rec, 115 | train=False, 116 | ) 117 | 118 | loss, aux_loss = VQLPIPS( 119 | batch['inputs'], 120 | rec, 121 | codebook_loss, 122 | p_logits, 123 | g_logits, 124 | codebook_weight = config['loss']['codebook_weight'], 125 | loggaussian_weight = config['loss']['loggaussian_weight'], 126 | loglaplace_weight = config['loss']['loglaplace_weight'], 127 | perceptual_weight = config['loss']['perceptual_weight'], 128 | adversarial_weight = config['loss']['adversarial_weight'], 129 | disc_g_start = config['loss']['disc_g_start'], 130 | step = state.step, 131 | ) 132 | 133 | return jnp.mean(loss), (rec, aux_loss) 134 | 135 | def _loss_fn_d(params): 136 | rec, _ = state.apply_fn_g( 137 | {'params': state.params_g}, 138 | rngs=state.aux_rng_keys_g, 139 | x=batch['inputs'], 140 | train=False, 141 | ) 142 | rec = jax.lax.stop_gradient(rec) 143 | d_real_logit = state.apply_fn_d( 144 | {'params': params}, 145 | rngs=state.aux_rng_keys_d, 146 | x=batch['inputs'], 147 | train=True, 148 | ) 149 | d_fake_logit = state.apply_fn_d( 150 | {'params': params}, 151 | rngs=state.aux_rng_keys_d, 152 | x=rec, 153 | train=True, 154 | ) 155 | d_loss = vanilla_d_loss(d_fake_logit, d_real_logit) 156 | 157 | d_weight = jnp.where(state.step > config['loss']['disc_d_start'], 1.0, 0.0) 158 | d_loss = d_loss * d_weight 159 | 160 | return jnp.mean(d_loss), (jnp.mean(d_real_logit),jnp.mean(d_fake_logit)) 161 | 162 | use_bfloat16_grads = config['model']['use_bfloat16'] 163 | params_g = state.params_g 164 | if use_bfloat16_grads: 165 | params_g = 
f32_to_bf16(params_g) 166 | 167 | grad_fn_nll = jax.value_and_grad(_loss_fn_nll, has_aux=True) 168 | output_nll, grads_nll = grad_fn_nll(params_g) 169 | loss_nll, aux_output = output_nll 170 | rec, aux_loss = aux_output 171 | loglaplace_loss, loggaussian_loss, p_loss, loss_g, c_loss = aux_loss 172 | 173 | # grad_fn_g = jax.value_and_grad(_loss_fn_g, has_aux=False) 174 | # loss_g, grads_g = grad_fn_g(params_g) 175 | 176 | # last_layer_nll = jnp.linalg.norm(grads_nll['model']['decoder_proj']['kernel']) 177 | # last_layer_g = jnp.linalg.norm(grads_g['model']['decoder_proj']['kernel']) 178 | 179 | # d_weight_factor = jnp.where(state.step > config['loss']['disc_g_start'], config['loss']['adversarial_weight'],0.0) 180 | # d_weight = last_layer_nll / (last_layer_g + 1e-4) * d_weight_factor 181 | 182 | # grads_g = jax.tree_map(lambda x: d_weight*x, grads_g) 183 | 184 | # grads = jax.tree_map(lambda x: jnp.nan_to_num(x, copy=False), grads) 185 | grads_nll = jax.lax.pmean(grads_nll, axis_name='batch') 186 | # grads_g = jax.lax.pmean(grads_g, axis_name='batch') 187 | 188 | # Cast up grads here (after the pmean) which reduces bandwidth maybe 189 | if use_bfloat16_grads: 190 | grads_nll = bf16_to_f32(grads_nll) 191 | # grads_g = bf16_to_f32(grads_g) 192 | 193 | state = state.apply_gradients_g(grads=grads_nll) 194 | # state = state.apply_gradients_g(grads=grads_g) 195 | 196 | loss_info = {} 197 | loss_info['loss'] = loss_nll 198 | loss_info['loglaplace_loss'] = loglaplace_loss 199 | loss_info['loggaussian_loss'] = loggaussian_loss 200 | loss_info['perceptual_loss'] = p_loss 201 | loss_info['codebook_loss'] = c_loss 202 | # loss_info['d_weight'] = d_weight 203 | 204 | if config['data']['task'] == 'audio': 205 | loss_info['image_rec'] = rec 206 | loss_info['image_ori'] = batch['inputs'] 207 | else: 208 | loss_info['image_rec'] = (rec.clip(-1.0, 1.0) + 1.0) / 2.0 209 | loss_info['image_ori'] = (batch['inputs'] + 1.0) / 2.0 210 | 211 | # discriminator 212 | grad_fn_d = jax.value_and_grad(_loss_fn_d, has_aux=True) 213 | params_d = state.params_d 214 | if use_bfloat16_grads: 215 | params_d = f32_to_bf16(state.params_d) 216 | output_d, grads_d = grad_fn_d(params_d) 217 | 218 | loss_d, aux_output_d = output_d 219 | d_real_logit, d_fake_logit = aux_output_d 220 | 221 | loss_info['g_loss'] = loss_g 222 | loss_info['d_loss'] = loss_d 223 | loss_info['d_real_logit'] = d_real_logit 224 | loss_info['d_fake_logit'] = d_fake_logit 225 | 226 | if use_bfloat16_grads: 227 | grads_d = bf16_to_f32(grads_d) 228 | 229 | grads_d = jax.lax.pmean(grads_d, axis_name='batch') 230 | state = state.apply_gradients_d(grads=grads_d) 231 | 232 | # Average metrics over all replicas (maybe this isn't a great idea, idk) 233 | # loss_info = jax.lax.pmean(loss_info, axis_name='batch') 234 | loss_info = bf16_to_f32(loss_info) 235 | state = state.replace(step = state.step + 1) 236 | 237 | return state, loss_info -------------------------------------------------------------------------------- /tpu_run.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import subprocess 4 | import time 5 | 6 | import glob 7 | import requests 8 | from fabric import Connection 9 | from dataclasses import dataclass 10 | import multiprocessing.pool 11 | import regex as re 12 | import pandas as pd 13 | import random 14 | multiprocessing.set_start_method("spawn") 15 | 16 | from func_timeout import func_set_timeout 17 | import time 18 | import json 19 | 20 | @functools.lru_cache() 21 | def 
get_bearer(): 22 | return subprocess.check_output("gcloud auth print-access-token", shell=True).decode("utf-8").strip() 23 | 24 | 25 | @functools.lru_cache() 26 | def get_project(): 27 | return subprocess.check_output("gcloud config list --format 'value(core.project)'", shell=True).decode( 28 | "utf-8").strip() 29 | 30 | @dataclass 31 | class TPUCreator: 32 | """ 33 | Utility for creating TPUs and stuff 34 | """ 35 | name: str 36 | tpu_size: int 37 | zone: str = 'europe-west4-a' 38 | preemptible: bool = False 39 | network: str='main' 40 | subnetwork: str='europe-west4' 41 | version: str='v2-alpha' 42 | accelerator_type: str='v3' 43 | 44 | @property 45 | def base_url(self): 46 | # https://cloud.google.com/tpu/docs/reference/rest/v2alpha1/projects.locations.nodes/create 47 | return f'https://tpu.googleapis.com/v2alpha1/projects/{get_project()}/locations/{self.zone}/nodes' 48 | 49 | def check_tpu(self): 50 | response = requests.get(f'{self.base_url}/{self.name}', 51 | headers={'Authorization': f'Bearer {get_bearer()}'}) 52 | return response.json() 53 | 54 | def create_tpu(self): 55 | """ 56 | Tries to create a TPU, 57 | :return: returns True if successful and False otherwise 58 | """ 59 | if not os.path.expanduser('~/.ssh/google_compute_engine'): 60 | raise ValueError("Must create SSH keys in legacy mode with something like" 61 | "ssh-keygen -m PEM -t rsa -b 4096 -C \"$(whoami)@$(hostname)\" -f ~/.ssh/google_compute_engine") 62 | 63 | try: 64 | status = self.check_tpu() 65 | 66 | if status["state"] not in ["CREATING", "READY"]: 67 | print("deleting TPU") 68 | self.delete_tpu() 69 | 70 | while True: 71 | try: 72 | print("deleting check: {}".format(self.check_tpu()["state"]), flush=True) 73 | time.sleep(1) 74 | except: 75 | break 76 | except: 77 | pass 78 | 79 | data = { 80 | "accelerator_type": f'{self.accelerator_type}-{self.tpu_size}', 81 | "runtime_version": f'{self.version}', 82 | "network_config": {"enable_external_ips": True, "network": self.network, "subnetwork": self.subnetwork}, 83 | "tags": "unified_io", 84 | } 85 | 86 | if self.preemptible: 87 | data["schedulingConfig"] = {"preemptible": True} 88 | 89 | response = requests.post(self.base_url, 90 | headers={'Authorization': f'Bearer {get_bearer()}', 91 | 'Content-Type': 'application/json', }, 92 | params=(('node_id', self.name),), json=data) 93 | print(response.json()) 94 | return response.status_code == 200 95 | 96 | def delete_tpu(self): 97 | response = requests.delete(f'{self.base_url}/{self.name}', headers={'Authorization': f'Bearer {get_bearer()}'}) 98 | return response.json() 99 | 100 | def wait_until_tpu_ready(self): 101 | desired_state = {'state': 'READY', 'health': 'HEALTHY'} 102 | # desired_state = {'state': 'READY'} 103 | while True: 104 | ret = self.check_tpu() 105 | 106 | print(f"wait_until_tpu_ready check: {ret}", flush=True) 107 | 108 | if ("error" in ret) or (ret["state"] == "TERMINATED"): 109 | return False 110 | 111 | matches = True 112 | for k, expected_v in desired_state.items(): 113 | if k not in ret: 114 | matches = False 115 | continue 116 | if ret[k] != expected_v: 117 | matches = False 118 | 119 | if matches: 120 | return True 121 | time.sleep(30) 122 | 123 | def get_connections(self): 124 | host = self.name 125 | zone = self.zone 126 | key_path = os.path.expanduser('~/.ssh/google_compute_engine') 127 | 128 | out = subprocess.getoutput(f"gcloud alpha compute tpus tpu-vm describe --zone {zone} {host} --format json") 129 | out = json.loads(out) 130 | ips = [x["accessConfig"]["externalIp"] for x in 
out["networkEndpoints"]] 131 | print(f"Identified {ips} ips for host {host}") 132 | 133 | # This will (sometimes?) take care of some know-host issue that would otherwise prevent us 134 | # from ssh-ing in normally 135 | # Might be some ssh things we could do to fix this in a better way 136 | print(f"Testing connection with gcloud ssh....") 137 | exit_code = os.system('gcloud alpha compute tpus tpu-vm ssh {} --zone {} --command="echo gcloud connected"'.format(host, zone)) 138 | if exit_code != 0: 139 | raise ValueError(f"gcloud connection failed, host {host} might be not be reachable") 140 | 141 | conns = [Connection(h, connect_kwargs={"key_filename": key_path}) for h in ips] 142 | return conns 143 | 144 | def install_dependencies(conn): 145 | """ 146 | Upload all the code 147 | :param conn: 148 | :param address: 149 | :return: 150 | """ 151 | try: 152 | conn.run('killall -9 python3') 153 | except Exception as e: 154 | print(e) 155 | 156 | try: 157 | conn.run('killall -9 screen') 158 | except Exception as e: 159 | print(e) 160 | 161 | print(f"Starting on {conn}", flush=True) 162 | conn.run('rm -rf *.py') 163 | conn.run('rm -rf *.json') 164 | conn.run('rm -rf screenlog.0') 165 | conn.run('rm -rf configs && rm -rf data && rm -rf models && rm -rf training && rm -rf additional_file') 166 | conn.run('rm -rf startup.sh') 167 | 168 | # copy credential for some error 169 | conn.run(f"mkdir /home/jiasenl/.config/gcloud -p") 170 | conn.put('/home/jiasenl/.config/gcloud/application_default_credentials.json', f'/home/jiasenl/.config/gcloud') 171 | 172 | # conn.sudo('rm -rf *') 173 | local_code_path = os.path.expanduser('~/vit-vqgan-jax/') 174 | # Copy files 175 | for i in glob.glob(os.path.join(local_code_path, '*.py')): 176 | conn.put(i, f'') 177 | 178 | for i in glob.glob(os.path.join(local_code_path, '*.md')): 179 | conn.put(i, f'') 180 | 181 | # Copy python files 182 | for ok_folder in ['data', 'models', 'training']: 183 | conn.sudo(f'rm -rf {ok_folder}') 184 | conn.run(f"mkdir {ok_folder} -p") 185 | for i in glob.glob(os.path.join(local_code_path, ok_folder, '*.py')): 186 | conn.put(i, f'{ok_folder}/') 187 | 188 | for ok_folder in ['configs', ]: 189 | conn.run(f'rm -rf {ok_folder}') 190 | conn.run(f"mkdir {ok_folder} -p") 191 | for i in glob.glob(os.path.join(local_code_path, ok_folder, '*.yaml')): 192 | conn.put(i, f'{ok_folder}/') 193 | 194 | # vgg 195 | for ok_folder in ['additional_file', ]: 196 | conn.run(f'rm -rf {ok_folder}') 197 | conn.run(f"mkdir {ok_folder} -p") 198 | for i in glob.glob(os.path.join(local_code_path, ok_folder, '*.h5')): 199 | conn.put(i, f'{ok_folder}/') 200 | 201 | conn.put(os.path.join(local_code_path, 'tpu_startup_script.sh'), "/tmp/startup_vit_vqgan.sh") 202 | conn.sudo('chmod +x /tmp/startup_vit_vqgan.sh', hide=True) 203 | conn.run('/tmp/startup_vit_vqgan.sh', hide=True) 204 | 205 | conn.put(os.path.join(local_code_path, 'run.sh'), "") 206 | conn.sudo('chmod +x run.sh', hide=True) 207 | 208 | if __name__ == '__main__': 209 | tpu_creator = TPUCreator(name='jiasen-v3-64-1', zone='us-east1-d', 210 | tpu_size=64, network='rowan', subnetwork='rowan', accelerator_type='v3', 211 | version='v2-alpha') 212 | # tpu_creator = TPUCreator(name='jiasen-v4-64-0', zone='us-central2-b', 213 | # tpu_size=64, network='default', subnetwork='default', accelerator_type='v4', 214 | # version='tpu-vm-v4-base') 215 | """ 216 | while True: 217 | tpu_creator.create_tpu() 218 | tpu_is_ready = tpu_creator.wait_until_tpu_ready() 219 | # info = tpu_creator.check_tpu() 220 | # if 'error' in info: 
221 | if tpu_is_ready: 222 | break 223 | else: 224 | info = tpu_creator.check_tpu() 225 | print(f"\n~ERROR retrying: \n{info['error']}\n", flush=True) 226 | time.sleep(60 * 5) 227 | """ 228 | 229 | conns = tpu_creator.get_connections() 230 | 231 | with multiprocessing.pool.ThreadPool(processes=len(conns)) as p: 232 | p.map(install_dependencies, conns) 233 | time.sleep(30) 234 | 235 | def _run_things(conn): 236 | with conn.cd(''): 237 | local_code_path = os.path.expanduser('~/vit-vqgan-jax/') 238 | conn.put(os.path.join(local_code_path, 'additional_file', 'utils.py'), "/home/jiasenl/.local/lib/python3.8/site-packages/seqio/utils.py") 239 | conn.run('screen -d -m -L bash -c ./run.sh', pty=False) 240 | print('done') 241 | 242 | with multiprocessing.pool.ThreadPool(processes=len(conns)) as p: 243 | p.map(_run_things, conns) 244 | -------------------------------------------------------------------------------- /training/optimization.py: -------------------------------------------------------------------------------- 1 | import optax 2 | from optax import GradientTransformation 3 | from optax._src.base import NO_PARAMS_MSG 4 | import jax 5 | import flax 6 | import chex 7 | import jax.numpy as jnp 8 | import functools 9 | from optax._src import numerics, wrappers, transform 10 | from flax.core.frozen_dict import FrozenDict 11 | from flax.training import train_state 12 | from models.checkpoint import f32_to_bf16, bf16_to_f32 13 | from optax._src.factorized import _factored_dims 14 | import numpy as np 15 | from typing import NamedTuple, Any 16 | from training.train_state import TrainState_v2 17 | 18 | 19 | class ScaleByAdamState(NamedTuple): 20 | """State for the Adam algorithm.""" 21 | count: chex.Array 22 | mu: optax.Updates 23 | nu: optax.Updates 24 | 25 | 26 | def _bias_correction(moment, decay, count): 27 | """Perform bias correction. This becomes a no-op as count goes to infinity.""" 28 | bias_correction = 1 - decay ** count 29 | return jax.tree_map(lambda t: t / bias_correction.astype(t.dtype), moment) 30 | 31 | 32 | ######################### 33 | # Bfloat16 34 | # using this as the sign bit 35 | # cube root to exchange range for mantissa precision 36 | # i think this encoding probably isn't as good as I would want were i to do this again. 
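# In more detail (this just summarizes the code below): `_unsigned_bfloat16_encode`
# cubes the value, casts the cube to bfloat16, and then repurposes the sign bit:
# it stores -v_bf whenever v_bf * (1 + 2**-9) (`missing_precision`) is the closer
# approximation of the true cube, so the sign acts as roughly one extra mantissa
# bit. `_unsigned_bfloat16_decode` undoes this by re-applying the factor to
# negatively-stored values and taking the cube root. Cubing trades dynamic range
# for mantissa resolution, which is why only the non-negative Adam second moment
# (`nu`) is stored with this scheme, while `mu` is a plain bfloat16 cast.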
37 | # the issue is for really small numbers we almost never choose the negative option 38 | missing_precision = 1 + (1 / 2 ** 9) 39 | 40 | def _unsigned_bfloat16_decode(v): 41 | v_abs = jnp.abs(v).astype(jnp.float32) 42 | v_abs = jax.lax.select(v >= 0, v_abs, v_abs * missing_precision) 43 | return jnp.cbrt(v_abs) 44 | 45 | 46 | def _unsigned_bfloat16_encode(v): 47 | v_pow = jnp.power(v, 3) 48 | v_bf = v_pow.astype(jnp.bfloat16) 49 | v_bf32 = v_bf.astype(jnp.float32) 50 | 51 | err0 = jnp.abs(v_bf32 - v_pow) 52 | err1 = jnp.abs(v_bf32 * missing_precision - v_pow) 53 | return jax.lax.select(err0 < err1, v_bf, -v_bf) 54 | 55 | 56 | def scale_by_bfloat16_adam( 57 | b1: float = 0.9, 58 | b2: float = 0.999, 59 | eps: float = 1e-8, 60 | eps_root: float = 0.0, 61 | use_bfloat16=True, 62 | do_bias_correction=True, 63 | ) -> GradientTransformation: 64 | """ 65 | Scales by bfloat16 adam 66 | :param b1: 67 | :param b2: 68 | :param eps: 69 | :param eps_root: 70 | :param use_bfloat16: 71 | :param do_bias_correction: 72 | :return: 73 | """ 74 | if not use_bfloat16: 75 | assert do_bias_correction 76 | return optax.scale_by_adam(b1, b2, eps, eps_root) 77 | 78 | _init = functools.partial(jnp.zeros_like, dtype=jnp.bfloat16) 79 | 80 | def init_fn(params): 81 | running_m = jax.tree_map(_init, params) 82 | running_v = jax.tree_map(_init, params) 83 | return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=running_m, nu=running_v) 84 | 85 | def _momentum_update(grad, current_m): 86 | # Cast up here 87 | current_m = current_m.astype(jnp.float32) 88 | next_m = (1 - b1) * grad + b1 * current_m 89 | return next_m 90 | 91 | def _secondorder_update(grad, current_v): 92 | current_v_dec = _unsigned_bfloat16_decode(current_v) 93 | next_v = (1 - b2) * jnp.square(grad) + b2 * current_v_dec 94 | return next_v 95 | 96 | def update_fn(updates, state, params=None): 97 | del params 98 | 99 | next_m = jax.tree_map(_momentum_update, updates, state.mu) 100 | next_m_enc = jax.tree_map(lambda x: x.astype(jnp.bfloat16), next_m) 101 | 102 | next_v = jax.tree_map(_secondorder_update, updates, state.nu) 103 | next_v_enc = jax.tree_map(_unsigned_bfloat16_encode, next_v) 104 | 105 | count_inc = numerics.safe_int32_increment(state.count) 106 | if do_bias_correction: 107 | next_m = _bias_correction(next_m, b1, count_inc) 108 | next_v = _bias_correction(next_v, b2, count_inc) 109 | 110 | # Apply updates 111 | updates = jax.tree_map( 112 | lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), next_m, next_v) 113 | 114 | return updates, ScaleByAdamState(count=count_inc, mu=next_m_enc, nu=next_v_enc) 115 | 116 | return GradientTransformation(init_fn, update_fn) 117 | 118 | 119 | def lr_scale_linearwarmup_cosinedecay(num_warmup_steps, num_train_steps, final_lr_scale=0.1): 120 | """ 121 | :param num_warmup_steps: Linear warmup for this many steps 122 | :param num_train_steps: Cosine decay for num_train_steps - num_warmup_steps 123 | :param final_lr_scale: We will end at this * learning_rate 124 | :return: 125 | """ 126 | assert num_warmup_steps <= num_train_steps 127 | 128 | def schedule(step): 129 | warmup_scale = step / num_warmup_steps 130 | post_warmup_scale = (step - num_warmup_steps) / (num_train_steps - num_warmup_steps + 1.0) 131 | post_warmup_scale = jnp.minimum(post_warmup_scale, 1.0) 132 | 133 | # linear -> cosine 134 | post_warmup_scale = (1.0 - (1.0 - jnp.cos(jnp.pi * post_warmup_scale)) / 2.0) 135 | post_warmup_scale = final_lr_scale + (1.0 - final_lr_scale) * post_warmup_scale 136 | 137 | return jax.lax.select(step < 
num_warmup_steps, warmup_scale, post_warmup_scale) 138 | 139 | return schedule 140 | 141 | 142 | def lr_scale_linearwarmup_lineardecay(num_warmup_steps, num_train_steps): 143 | """ 144 | :param num_warmup_steps: Linear warmup for this many steps 145 | :param num_train_steps: Linear decay for num_train_steps - num_warmup_steps 146 | :param final_lr_scale: We will end at this * learning_rate 147 | :return: 148 | """ 149 | assert num_warmup_steps <= num_train_steps 150 | 151 | def schedule(step): 152 | warmup_scale = step / num_warmup_steps 153 | post_warmup_scale = (step - num_warmup_steps) / (num_train_steps - num_warmup_steps + 1.0) 154 | post_warmup_scale = 1.0 - jnp.minimum(post_warmup_scale, 1.0) 155 | return jax.lax.select(step < num_warmup_steps, warmup_scale, post_warmup_scale) 156 | 157 | return schedule 158 | 159 | 160 | def construct_optimizer(opt_config, return_chainables=False): 161 | chainables = [] 162 | # clip_norm = opt_config.get('global_max_norm', None) 163 | # if clip_norm: 164 | # chainables = [optax.clip_by_global_norm(clip_norm)] 165 | # optim = opt_config.get('optim', 'adamw').lower() 166 | # if optim == 'adamw': 167 | # opt = scale_by_bfloat16_adam(b1=opt_config.get('beta_1', 0.9), 168 | # b2=opt_config.get('beta_2', 0.98), 169 | # eps=opt_config.get('eps', 1e-8), 170 | # use_bfloat16=opt_config.get('use_bfloat16_optim', True), 171 | # do_bias_correction=opt_config.get('do_bias_correction', False), 172 | # ) 173 | # chainables += [opt] 174 | # elif optim == 'sgd': 175 | # opt = transform.trace( 176 | # decay=opt_config.get('momentum', 0.9), 177 | # nesterov=opt_config.get('nesterov', False), 178 | # accumulator_dtype=jnp.bfloat16 if opt_config.get('use_bfloat16_optim', True) else None, 179 | # ) 180 | # chainables += [opt] 181 | # else: 182 | # raise NotImplementedError 183 | 184 | # chainables += [ 185 | # optax.add_decayed_weights(weight_decay=opt_config['weight_decay_rate'], 186 | # mask=lambda p: jax.tree_map(lambda x: x.ndim > 1, p), 187 | # ), 188 | # optax.scale_by_schedule(lr_scale_linearwarmup_cosinedecay(num_warmup_steps=opt_config['num_warmup_steps'], 189 | # num_train_steps=opt_config['num_train_steps'], 190 | # final_lr_scale=opt_config.get('final_lr_scale', 0.02), 191 | # )), 192 | # optax.scale(-opt_config['learning_rate']), 193 | # ] 194 | learning_rate_fn = optax.warmup_cosine_decay_schedule( 195 | init_value = 0.0, 196 | peak_value = opt_config['learning_rate'], 197 | warmup_steps = opt_config['num_warmup_steps'], 198 | decay_steps = opt_config['num_train_steps'], 199 | end_value = opt_config['end_learning_rate'], 200 | ) 201 | chainables += [ 202 | optax.clip(max_delta = 1.0), 203 | optax.adamw( 204 | learning_rate = learning_rate_fn, 205 | b1=opt_config.get('beta_1', 0.9), 206 | b2=opt_config.get('beta_2', 0.98), 207 | weight_decay=1e-4, 208 | ) 209 | ] 210 | 211 | if return_chainables: 212 | return chainables 213 | 214 | tx = optax.chain(*chainables) 215 | return tx 216 | 217 | 218 | def construct_train_state(opt_config, models, params, rng_keys, return_chainables=False): 219 | """ 220 | :param optimizer_params: Dict like 221 | { 222 | learning_rate: 0.001 223 | num_train_steps: 93600 # 300 epochs 224 | num_warmup_steps: 10000 225 | weight_decay_rate: 0.1 226 | beta_2: 0.999 227 | clip_norm: 0.0 228 | adafactor: False 229 | use_bfloat16_adam: True 230 | } 231 | 232 | :return: 233 | """ 234 | 235 | txs = {} 236 | for k, v in opt_config.items(): txs[k] = construct_optimizer(v) 237 | 238 | state = TrainState_v2.create( 239 | models=models, 240 | 
params=params, 241 | frozen_params={'g': flax.core.freeze({}), 'd': flax.core.freeze({})}, 242 | txs=txs, 243 | aux_rng_keys=rng_keys, 244 | ) 245 | return state -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | from jax import random 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | import functools 5 | from typing import Any, Tuple 6 | import h5py 7 | import warnings 8 | 9 | from tqdm import tqdm 10 | import requests 11 | import os 12 | import tempfile 13 | # import t5x.examples.vit_vqgan.layers as layers 14 | 15 | URLS = {'vgg16': 'https://www.dropbox.com/s/ew3vhtlg5kks8mz/vgg16_weights.h5?dl=1', 16 | 'vgg19': 'https://www.dropbox.com/s/1sn02fnkj579u1w/vgg19_weights.h5?dl=1'} 17 | 18 | def download(ckpt_dir, url): 19 | name = url[url.rfind('/') + 1 : url.rfind('?')] 20 | if ckpt_dir is None: 21 | ckpt_dir = tempfile.gettempdir() 22 | ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels') 23 | ckpt_file = os.path.join(ckpt_dir, name) 24 | if not os.path.exists(ckpt_file): 25 | print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}') 26 | if not os.path.exists(ckpt_dir): 27 | os.makedirs(ckpt_dir) 28 | 29 | response = requests.get(url, stream=True) 30 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 31 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 32 | 33 | # first create temp file, in case the download fails 34 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp') 35 | with open(ckpt_file_temp, 'wb') as file: 36 | for data in response.iter_content(chunk_size=1024): 37 | progress_bar.update(len(data)) 38 | file.write(data) 39 | progress_bar.close() 40 | 41 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 42 | print('An error occured while downloading, please try again.') 43 | if os.path.exists(ckpt_file_temp): 44 | os.remove(ckpt_file_temp) 45 | else: 46 | # if download was successful, rename the temp file 47 | os.rename(ckpt_file_temp, ckpt_file) 48 | return ckpt_file 49 | 50 | def normalize_tensor(x, eps=1e-10): 51 | norm_factor = jnp.sqrt(jnp.sum(x ** 2, axis=-1, keepdims=True)) 52 | return x / (norm_factor + eps) 53 | 54 | def spatial_average(x, keepdims=True): 55 | return jnp.mean(x, axis=[1,2], keepdims=keepdims) 56 | 57 | class VGG(nn.Module): 58 | """ 59 | VGG. 60 | Attributes: 61 | output (str): 62 | Output of the module. Available options are: 63 | - 'softmax': Output is a softmax tensor of shape [N, 1000] 64 | - 'log_softmax': Output is a softmax tensor of shape [N, 1000] 65 | - 'logits': Output is a tensor of shape [N, 1000] 66 | - 'activations': Output is a dictionary containing the VGG activations 67 | pretrained (str): 68 | Indicates if and what type of weights to load. Options are: 69 | - 'imagenet': Loads the network parameters trained on ImageNet 70 | - None: Parameters of the module are initialized randomly 71 | normalize (bool): 72 | If True, the input will be normalized with the ImageNet statistics. 73 | architecture (str): 74 | Architecture type: 75 | - 'vgg16' 76 | - 'vgg19' 77 | include_head (bool): 78 | If True, include the three fully-connected layers at the top of the network. 79 | This option is useful when you want to obtain activations for images whose 80 | size is different than 224x224. 81 | num_classes (int): 82 | Number of classes. Only relevant if 'include_head' is True. 
83 | kernel_init (function): 84 | A function that takes in a shape and returns a tensor. 85 | bias_init (function): 86 | A function that takes in a shape and returns a tensor. 87 | ckpt_dir (str): 88 | The directory to which the pretrained weights are downloaded. 89 | Only relevant if a pretrained model is used. 90 | If this argument is None, the weights will be saved to a temp directory. 91 | dtype (str): Data type. 92 | """ 93 | output: str='softmax' 94 | pretrained: str='imagenet' 95 | normalize: bool=True 96 | architecture: str='vgg16' 97 | include_head: bool=False 98 | num_classes: int=1000 99 | kernel_init: functools.partial=nn.initializers.lecun_normal() 100 | bias_init: functools.partial=nn.initializers.zeros 101 | ckpt_dir: str=None 102 | lpips: bool=True 103 | enable_dropout: bool=True 104 | dropout_rate: float=0.5 105 | chns: Tuple[int] = (64, 128, 256, 512, 512) 106 | vgg_output_names: Tuple[str] = ('relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3') 107 | dtype: str='float32' 108 | 109 | def setup(self): 110 | self.param_dict = None 111 | if self.pretrained == 'imagenet': 112 | ckpt_file = download(self.ckpt_dir, URLS[self.architecture]) 113 | self.param_dict = h5py.File(ckpt_file, 'r') 114 | 115 | if self.lpips: 116 | self.param_lpips = h5py.File('additional_file/vgg.h5', 'r') 117 | 118 | def _conv_block(self, x, features, num_layers, block_num, act, dtype='float32'): 119 | 120 | for l in range(num_layers): 121 | layer_name = f'conv{block_num}_{l + 1}' 122 | w = self.kernel_init if self.param_dict is None else lambda *_ : jnp.array(self.param_dict[layer_name]['weight']) 123 | b = self.bias_init if self.param_dict is None else lambda *_ : jnp.array(self.param_dict[layer_name]['bias']) 124 | 125 | x = nn.Conv(features=features, 126 | kernel_size=(3, 3), 127 | padding=((1, 1), (1, 1)), 128 | kernel_init=w, 129 | use_bias=True, 130 | bias_init=b, 131 | dtype=self.dtype)(x) 132 | 133 | act[layer_name] = x 134 | x = nn.relu(x) 135 | act[f'relu{block_num}_{l + 1}'] = x 136 | return x 137 | 138 | def _net_lin_layer(self, x, block_num, chn_out=1, deterministic=False): 139 | if self.enable_dropout: 140 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 141 | 142 | w = self.kernel_init if self.param_lpips is None else \ 143 | lambda *_ : jnp.transpose(jnp.array(self.param_lpips[f'lin{block_num}.model.1.weight']), (2,3,1,0)) 144 | 145 | x = nn.Conv(features=chn_out, 146 | kernel_size=(1,1), 147 | strides=1, 148 | padding=((0, 0), (0, 0)), 149 | kernel_init=w, 150 | use_bias=False, 151 | dtype=self.dtype 152 | )(x) 153 | 154 | return x 155 | 156 | def _forward(self, x): 157 | act = {} 158 | x = self._conv_block(x, features=64, num_layers=2, block_num=1, act=act, dtype=self.dtype) 159 | x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 160 | 161 | x = self._conv_block(x, features=128, num_layers=2, block_num=2, act=act, dtype=self.dtype) 162 | x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 163 | 164 | x = self._conv_block(x, features=256, num_layers=3 if self.architecture == 'vgg16' else 4, block_num=3, act=act, dtype=self.dtype) 165 | x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 166 | 167 | x = self._conv_block(x, features=512, num_layers=3 if self.architecture == 'vgg16' else 4, block_num=4, act=act, dtype=self.dtype) 168 | x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 169 | 170 | x = self._conv_block(x, features=512, num_layers=3 if self.architecture == 'vgg16' else 4, block_num=5, act=act, dtype=self.dtype) 171 | x = 
nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) 172 | 173 | return x, act 174 | 175 | @nn.compact 176 | def __call__(self, x0, x1, train=False): 177 | """ 178 | Args: 179 | x (tensor of shape [N, H, W, 3]): 180 | Batch of input images (RGB format). Images must be in range [0, 1]. 181 | If 'include_head' is True, the images must be 224x224. 182 | train (bool): Training mode. 183 | Returns: 184 | If output == 'logits' or output == 'softmax': 185 | (tensor): Output tensor of shape [N, num_classes]. 186 | If output == 'activations': 187 | (dict): Dictionary of activations. 188 | """ 189 | assert x0.shape == x1.shape 190 | 191 | if self.output not in ['softmax', 'log_softmax', 'logits', 'activations']: 192 | raise ValueError('Wrong argument. Possible choices for output are "softmax", "logits", and "activations".') 193 | 194 | if self.pretrained is not None and self.pretrained != 'imagenet': 195 | raise ValueError('Wrong argument. Possible choices for pretrained are "imagenet" and None.') 196 | 197 | if self.include_head and (x0.shape[1] != 224 or x0.shape[2] != 224): 198 | raise ValueError('Wrong argument. If include_head is True, then input image must be of size 224x224.') 199 | 200 | if self.normalize: 201 | mean = jnp.array([-.030,-.088,-.188]).reshape(1, 1, 1, -1).astype(x0.dtype) 202 | std = jnp.array([.458,.448,.450]).reshape(1, 1, 1, -1).astype(x0.dtype) 203 | 204 | x0 = (x0 - mean) / std 205 | x1 = (x1 - mean) / std 206 | 207 | if self.pretrained == 'imagenet': 208 | if self.num_classes != 1000: 209 | warnings.warn(f'The user specified parameter \'num_classes\' was set to {self.num_classes} ' 210 | 'but will be overwritten with 1000 to match the specified pretrained checkpoint \'imagenet\', if ', UserWarning) 211 | 212 | num_classes = 1000 213 | else: 214 | num_classes = self.num_classes 215 | 216 | x0, act0 = self._forward(x0) 217 | x1, act1 = self._forward(x1) 218 | 219 | diffs = {} 220 | # calculate the diff. 221 | for i, n in enumerate(self.vgg_output_names): 222 | # normalize the tensor and calculate the diffs. 
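        # Each activation map is unit-normalized along its channel axis
        # (normalize_tensor) before the squared difference is taken, as in LPIPS.
        # The per-layer maps collected in `diffs` are combined below: either
        # weighted by the pretrained 1x1 "lin" convolutions (self.lpips=True)
        # or reduced directly, then spatially averaged and summed over the
        # five layers to give one perceptual distance per image.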
223 | diffs[i] = (normalize_tensor(act0[n]) - normalize_tensor(act1[n])) ** 2 224 | 225 | if self.lpips: 226 | # LPIPS from https://github.com/richzhang/PerceptualSimilarity/blob/31bc1271ae6f13b7e281b9959ac24a5e8f2ed522/lpips/lpips.py#L87 227 | # spaital average and linear 228 | res = [spatial_average(self._net_lin_layer(diffs[block_num], block_num, deterministic= not train)) for block_num, _ in enumerate(self.vgg_output_names)] 229 | else: 230 | res = [spatial_average(jnp.sum(diffs[block_num], axis=1, keepdims=True), keepdims=True) for block_num, _ in enumerate(self.vgg_output_names)] 231 | 232 | return jnp.reshape(sum(res), (-1)) 233 | 234 | -------------------------------------------------------------------------------- /data/preprocessors.py: -------------------------------------------------------------------------------- 1 | import random 2 | from functools import reduce 3 | import einops 4 | 5 | from config import * 6 | from data.data_utils import * 7 | from data.imagenet_utils import _preprocess_image, RandomErasing 8 | 9 | 10 | AUTOTUNE = tf.data.experimental.AUTOTUNE 11 | rekey = seqio.preprocessors.rekey 12 | 13 | def log10(x): 14 | numerator = tf.math.log(x) 15 | denominator = tf.math.log(tf.constant(10, dtype=numerator.dtype)) 16 | return numerator / denominator 17 | 18 | def get_from_dict(data, keys): 19 | """Iterate nested dictionary""" 20 | return reduce(dict.get, keys, data) 21 | 22 | @seqio.utils.map_over_dataset 23 | def rekey(x, key_map=None): 24 | """Replace the feature keys according to the mapping in `key_map`. 25 | For example, if the dataset returns examples of the format: 26 | {'foo': 'something', 'bar': 'something else'} 27 | and key_map = {'boo': 'foo', 'spar': 'bar'} then this function will return 28 | examples with the format 29 | {'boo': 'something', 'spar': 'something else'} 30 | If a mapping is to an empty key or None, set the new key to an empty string. 31 | Args: 32 | x: an example to process. 33 | key_map: dictionary mapping new keys to original keys 34 | Returns: 35 | A preprocessed example with the format listed above. 
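  Note: in this codebase the values of `key_map` are key paths (lists) resolved
  through nested dicts by `get_from_dict`, e.g. `{'image': ['image']}` or
  `{'class_name': ['text']}` as used in data/tasks.py; a deeper path such as
  `['features', 'audio']` (hypothetical) would also work.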
36 | """ 37 | if key_map: 38 | return { 39 | new_key: get_from_dict(x, old_key) if old_key else '' 40 | for new_key, old_key in key_map.items() 41 | } 42 | return x 43 | 44 | 45 | def vit_vqgan_preprocessor(ds, sequence_length, decode_jpeg=False): 46 | 47 | image_input_size = [256, 256] 48 | is_training = sequence_length.get('is_training', True) 49 | 50 | def to_inputs_and_targets(ex): 51 | if decode_jpeg: 52 | img = tf.image.decode_jpeg(ex['image'], channels=3) 53 | else: 54 | img = ex['image'] 55 | 56 | img = tf.image.convert_image_dtype(img, dtype=tf.float32) 57 | img, img_mask, this_image_info = resize_and_pad(img, image_input_size, 58 | do_random_scale=is_training, 59 | random_scale_max=RANDOM_SCALE_MAX, 60 | random_scale_min=RANDOM_SCALE_MIN, 61 | shrink_both_sides=True, 62 | do_flip_if_vertical=False, 63 | random_scale_ratio=1) 64 | 65 | image_info, masks, boxes, labels, indices = this_image_info 66 | 67 | #[-1, 1] 68 | img = img * 2 - 1 69 | return {'inputs': img} 70 | 71 | return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE) 72 | 73 | 74 | def audio_preprocessor(ds, sequence_length, decode_video_string=True, random_start=True): 75 | 76 | def to_inputs_and_targets(ex): 77 | audio_shape = (ex['audio_nspectrograms'], ex['audio_nmels'], ex['audio_nhops']) 78 | spectrograms = tf.io.parse_tensor(ex['audio'], tf.float32) 79 | spectrograms = tf.reshape(spectrograms, audio_shape) 80 | 81 | rand_idx = tf.random.uniform([], minval=0, maxval=tf.shape(spectrograms)[0], dtype=tf.int32) 82 | audio = spectrograms[rand_idx] 83 | # random augment the audio masks. for 2 second audio, it will [72, 192] 84 | 85 | if random_start: 86 | audio = audio[:, 72:192] 87 | # random select a start point. 88 | rand_start = tf.random.uniform([], minval=0, maxval=17, dtype=tf.int32) 89 | audio_padded = tf.pad(audio, [[0,0], [rand_start*8, (17-rand_start)*8]], "CONSTANT") 90 | audio_mask = tf.cast(audio_padded != 0, tf.float32) 91 | else: 92 | audio_padded = audio 93 | audio_mask = tf.ones_like(audio_padded) 94 | 95 | audio_padded = tf.math.log(tf.clip_by_value(audio_padded, 1e-5, 1e5)) 96 | audio_padded = (audio_padded + 5.0945) / 3.8312 97 | audio_padded = audio_padded * audio_mask 98 | audio_padded = tf.expand_dims(audio_padded, -1) 99 | 100 | return {'inputs': audio_padded} 101 | 102 | return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE) 103 | 104 | 105 | def encoder_only_preprocessor( 106 | ds, sequence_length, class_id=None, class_map=None, decode_jpeg=False, 107 | ): 108 | image_input_size = sequence_length.get('input_size', IMAGE_INPUT_SIZE) 109 | image_input_d = sequence_length.get('patch_size', IMAGE_INPUT_D) 110 | is_training = sequence_length.get('is_training', True) 111 | rand_aug = sequence_length.get('rand_aug', None) 112 | rand_erase = sequence_length.get('rand_erase', 0.0) 113 | if rand_aug is not None: 114 | augmentation_settings = {} 115 | augmentation_settings['randaugment'] = dict(num_layers=rand_aug[0], magnitude=rand_aug[1]) 116 | augmentation_settings['cutmix'] = False 117 | augmentation_settings['mixup_alpha'] = None 118 | 119 | if class_id is not None: 120 | keys_tensor = tf.constant(class_id) 121 | table = tf.lookup.StaticHashTable( 122 | tf.lookup.KeyValueTensorInitializer( 123 | keys_tensor, 124 | tf.constant([i for i in range(len(class_id))], tf.int32), 125 | ), 126 | default_value=-1, 127 | ) 128 | elif class_map is not None: 129 | keys_tensor = tf.constant([i for i in range(len(class_map))], tf.int32) 130 | table = 
tf.lookup.StaticHashTable( 131 | tf.lookup.KeyValueTensorInitializer( 132 | keys_tensor, 133 | class_map 134 | ), 135 | default_value=21843, 136 | ) 137 | if rand_erase > 0: 138 | random_eraser = RandomErasing(probability=rand_erase) 139 | 140 | def to_inputs_and_targets(ex): 141 | if rand_aug is not None: 142 | image_inputs, _ = _preprocess_image( 143 | ex['image'], is_training, image_input_size, augmentation_settings) 144 | image_input_masks = tf.ones_like(image_inputs)[:, :, 0] 145 | else: 146 | if decode_jpeg: 147 | img = tf.image.decode_jpeg(ex['image'], channels=3) 148 | else: 149 | img = ex['image'] 150 | 151 | img = tf.image.convert_image_dtype(img, dtype=tf.float32) 152 | img, img_mask, this_image_info = resize_and_pad(img, image_input_size, 153 | do_random_scale=is_training, 154 | random_scale_max=RANDOM_SCALE_MAX, 155 | random_scale_min=RANDOM_SCALE_MIN, 156 | shrink_both_sides=True, 157 | do_flip_if_vertical=False, 158 | random_scale_ratio=0.5, 159 | resize_method='random' if is_training else tf.image.ResizeMethod.BILINEAR) 160 | image_inputs = img 161 | image_inputs = normalize_image(image_inputs) 162 | image_input_masks = img_mask 163 | if is_training and rand_erase > 0: 164 | image_inputs = random_eraser.distort(image_inputs) 165 | 166 | # Arrange into a list of patches 167 | image_inputs = einops.rearrange( 168 | image_inputs, '(h dh) (w dw) c -> (h w) (dh dw c)', 169 | dh=image_input_d, dw=image_input_d) 170 | 171 | if 'class_id' in ex: 172 | cid = ex['class_id'] 173 | label = table.lookup(cid) 174 | ex['label'] = label 175 | elif class_map is not None: 176 | ex['label'] = table.lookup(tf.cast(ex['label'], tf.int32)) 177 | return { 178 | 'inputs': image_inputs, 179 | 'targets': tf.cast(ex['label'], tf.int32), 180 | } 181 | 182 | return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE) 183 | 184 | 185 | def video_encoder_only_preprocessor(ds, sequence_length, class_name=None, decode_video_string=False): 186 | video_input_size = sequence_length.get('input_size', FINETUNE_VIDEO_INPUT_SIZE) 187 | video_input_d = sequence_length.get('patch_size', VIDEO_INPUT_D) 188 | num_frames = sequence_length.get('num_frames', 4) 189 | 190 | if class_name is not None: 191 | table = create_lookup_table_of_class_names(class_name) 192 | 193 | is_training = sequence_length.get('is_training', True) 194 | 195 | def to_inputs_and_targets(ex): 196 | 197 | # create video_inputs shape: (T,H,W,C) [0,255] 198 | if decode_video_string: 199 | # parse if stored as a TFRecord (byte string) 200 | video_shape = (ex['video_nframes'], ex['video_height'], ex['video_width'], ex['video_nchannels']) 201 | video = tf.io.parse_tensor(ex['video'], tf.uint8) 202 | video = tf.reshape(video, video_shape) 203 | # save shape in the TFRecord 204 | # tf.reshape(video, shape) 205 | else: 206 | video = ex['video'] 207 | 208 | video = convert_video_dtype(video,tf.float32) # [0,1] 209 | #video: TxHxWx3; video_mask: TxHxW 210 | 211 | video, video_mask, _ = resize_and_pad( 212 | video, 213 | is_video=True, 214 | desired_output_size=video_input_size, 215 | do_random_scale=is_training, 216 | random_scale_max=RANDOM_SCALE_MAX, 217 | random_scale_min=RANDOM_SCALE_MIN, 218 | shrink_both_sides=True, 219 | do_flip_if_vertical=False, 220 | random_scale_ratio=0.5, 221 | resize_method='random' if is_training else tf.image.ResizeMethod.BILINEAR) 222 | video_inputs = video 223 | video_input_masks = video_mask 224 | 225 | # Sample a fixed number of frames 226 | # [T, H, W, C] 227 | video_inputs, indices = 
sample_uniform_sequence( 228 | sequence=video_inputs, 229 | num_steps=num_frames, 230 | random=is_training, 231 | ) 232 | video_inputs = normalize_video(video_inputs) 233 | 234 | video_input_masks = tf.gather(video_input_masks, indices) 235 | 236 | video_inputs = einops.rearrange( 237 | video_inputs, 't (h dh) (w dw) c -> t (h w) (dh dw c)', 238 | dh=video_input_d, dw=video_input_d) 239 | 240 | # create text_targets 241 | label = tf.cast(ex['label'], tf.int32) 242 | 243 | return { 244 | 'inputs': video_inputs, 245 | 'targets': label, 246 | } 247 | 248 | return ds.map(to_inputs_and_targets, num_parallel_calls=tf.data.experimental.AUTOTUNE) -------------------------------------------------------------------------------- /models/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Callable, Optional, Tuple, Type 16 | 17 | import flax.linen as nn 18 | from flax.linen.module import merge_param 19 | import jax.numpy as jnp 20 | from jax import lax, random 21 | import numpy as np 22 | import einops 23 | 24 | Array = Any 25 | PRNGKey = Any 26 | Shape = Tuple[int] 27 | Dtype = Any 28 | 29 | class IdentityLayer(nn.Module): 30 | """Identity layer, convenient for giving a name to an array.""" 31 | 32 | @nn.compact 33 | def __call__(self, x): 34 | return x 35 | 36 | def get_sinusoid_encoding_table(seq_length, emb_dim, dtype): 37 | """Sinusoid position encoding table: excerpt from original Transformer""" 38 | def get_position_angle_vec(position): 39 | return [ 40 | position / np.power(10000, 2 * (dim_j // 2) / emb_dim) 41 | for dim_j in range(emb_dim) 42 | ] 43 | 44 | sinusoid_table = np.array( 45 | [get_position_angle_vec(pos_i) for pos_i in range(seq_length)] 46 | ) 47 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 48 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 49 | 50 | pos_emb = jnp.array(sinusoid_table).astype(dtype) 51 | return pos_emb 52 | 53 | 54 | def drop_path(x: jnp.array, rng, drop_rate: float = 0.) -> jnp.array: 55 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 56 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 57 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 58 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 59 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 60 | 'survival rate' as the argument. 61 | """ 62 | if drop_rate == 0.: 63 | return x 64 | keep_prob = 1. 
- drop_rate 65 | mask = random.bernoulli(key=rng, p=keep_prob, shape=(x.shape[0],) + (1,)*(x.ndim-1)) 66 | mask = jnp.broadcast_to(mask, x.shape) 67 | return lax.select(mask, x / keep_prob, jnp.zeros_like(x)) 68 | 69 | 70 | class DropPath(nn.Module): 71 | rate: float = 0. 72 | deterministic: Optional[bool] = None 73 | 74 | @nn.compact 75 | def __call__(self, x, deterministic: bool): 76 | deterministic = merge_param( 77 | 'deterministic', self.deterministic, deterministic) 78 | if deterministic or self.rate == 0.: 79 | return x 80 | else: 81 | rng = self.make_rng('drop_path') 82 | return drop_path(x, rng, self.rate) 83 | 84 | 85 | class AddPositionEmbs(nn.Module): 86 | """Adds learned positional embeddings to the inputs. 87 | 88 | Attributes: 89 | posemb_init: positional embedding initializer. 90 | """ 91 | 92 | posemb_init: Callable[[PRNGKey, Shape, Dtype], Array] 93 | dtype: Dtype = jnp.float32 94 | 95 | @nn.compact 96 | def __call__(self, inputs): 97 | """Applies the AddPositionEmbs module. 98 | 99 | Args: 100 | inputs: Inputs to the layer. 101 | 102 | Returns: 103 | Output tensor with shape `(bs, timesteps, in_dim)`. 104 | """ 105 | # inputs.shape is (batch_size, seq_len, emb_dim). 106 | assert inputs.ndim == 3, ('Number of dimensions should be 3,' 107 | ' but it is: %d' % inputs.ndim) 108 | pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) 109 | pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape) 110 | pe = pe.astype(self.dtype) 111 | return inputs + pe 112 | 113 | 114 | class MlpBlock(nn.Module): 115 | """Transformer MLP / feed-forward block.""" 116 | 117 | mlp_dim: int 118 | dtype: Dtype = jnp.float32 119 | out_dim: Optional[int] = None 120 | dropout_rate: float = 0.0 121 | kernel_init: Callable[[PRNGKey, Shape, Dtype], 122 | Array] = nn.initializers.xavier_uniform() 123 | bias_init: Callable[[PRNGKey, Shape, Dtype], 124 | Array] = nn.initializers.normal(stddev=1e-6) 125 | 126 | @nn.compact 127 | def __call__(self, inputs, *, deterministic): 128 | """Applies Transformer MlpBlock module.""" 129 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim 130 | x = nn.Dense( 131 | features=self.mlp_dim, 132 | dtype=self.dtype, 133 | kernel_init=self.kernel_init, 134 | bias_init=self.bias_init)( # pytype: disable=wrong-arg-types 135 | inputs) 136 | x = nn.tanh(x) 137 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 138 | output = nn.Dense( 139 | features=actual_out_dim, 140 | dtype=self.dtype, 141 | kernel_init=self.kernel_init, 142 | bias_init=self.bias_init)( # pytype: disable=wrong-arg-types 143 | x) 144 | output = nn.Dropout( 145 | rate=self.dropout_rate)( 146 | output, deterministic=deterministic) 147 | return output 148 | 149 | 150 | class Encoder1DBlock(nn.Module): 151 | """Transformer encoder layer. 152 | 153 | Attributes: 154 | inputs: input data. 155 | mlp_dim: dimension of the mlp on top of attention block. 156 | dtype: the dtype of the computation (default: float32). 157 | dropout_rate: dropout rate. 158 | attention_dropout_rate: dropout for attention heads. 159 | deterministic: bool, deterministic or not (to apply dropout). 160 | num_heads: Number of heads in nn.MultiHeadDotProductAttention 161 | """ 162 | 163 | mlp_dim: int 164 | num_heads: int 165 | dtype: Dtype = jnp.float32 166 | dropout_rate: float = 0.0 167 | droppath_rate: float = 0.0 168 | attention_dropout_rate: float = 0.0 169 | 170 | @nn.compact 171 | def __call__(self, inputs, *, deterministic): 172 | """Applies Encoder1DBlock module. 
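(The `Encoder1DBlock` definition continues below.) Because `DropPath` draws from a named RNG stream via `make_rng('drop_path')`, any module that uses it must be given that stream, alongside 'dropout', whenever it runs with `deterministic=False`. A minimal usage sketch for the `DropPath` module defined above; the shapes are arbitrary:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 197, 768))                   # (batch, tokens, features)
drop = DropPath(rate=0.1)

# The 'drop_path' rng stream must be supplied whenever deterministic=False.
variables = drop.init(
    {'params': jax.random.PRNGKey(0), 'drop_path': jax.random.PRNGKey(1)},
    x, deterministic=False)

y = drop.apply(variables, x, deterministic=False,
               rngs={'drop_path': jax.random.PRNGKey(2)})
assert y.shape == x.shape    # whole samples are zeroed or rescaled by 1 / keep_prob
```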
173 | 174 | Args: 175 | inputs: Inputs to the layer. 176 | deterministic: Dropout will not be applied when set to true. 177 | 178 | Returns: 179 | output after transformer encoder block. 180 | """ 181 | 182 | # Attention block. 183 | assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}' 184 | x = nn.LayerNorm(dtype=self.dtype)(inputs) 185 | x = nn.MultiHeadDotProductAttention( 186 | dtype=self.dtype, 187 | kernel_init=nn.initializers.xavier_uniform(), 188 | broadcast_dropout=False, 189 | deterministic=deterministic, 190 | dropout_rate=self.attention_dropout_rate, 191 | num_heads=self.num_heads)( 192 | x, x) 193 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 194 | x = DropPath(rate=self.droppath_rate)(x, deterministic=deterministic) + inputs 195 | 196 | # MLP block. 197 | y = nn.LayerNorm(dtype=self.dtype)(x) 198 | y = MlpBlock( 199 | mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)( 200 | y, deterministic=deterministic) 201 | 202 | return x + DropPath(rate=self.droppath_rate)(y, deterministic=deterministic) 203 | 204 | 205 | class Encoder(nn.Module): 206 | """Transformer Model Encoder for sequence to sequence translation. 207 | 208 | Attributes: 209 | num_layers: number of layers 210 | mlp_dim: dimension of the mlp on top of attention block 211 | num_heads: Number of heads in nn.MultiHeadDotProductAttention 212 | dropout_rate: dropout rate. 213 | attention_dropout_rate: dropout rate in self attention. 214 | """ 215 | 216 | num_layers: int 217 | mlp_dim: int 218 | num_heads: int 219 | dtype: Dtype = jnp.float32 220 | dropout_rate: float = 0.0 221 | droppath_rate: float = 0.0 222 | attention_dropout_rate: float = 0.0 223 | add_position_embedding: bool = True 224 | 225 | @nn.compact 226 | def __call__(self, x, *, train): 227 | """Applies Transformer model on the inputs. 228 | 229 | Args: 230 | x: Inputs to the layer. 231 | train: Set to `True` when training. 232 | 233 | Returns: 234 | output of a transformer encoder. 235 | """ 236 | assert x.ndim == 3 # (batch, len, emb) 237 | 238 | if self.add_position_embedding: 239 | x = AddPositionEmbs( 240 | posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. 
241 | dtype=self.dtype, 242 | name='posembed_input')( 243 | x) 244 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 245 | else: 246 | x = get_sinusoid_encoding_table(x.shape[-2], x.shape[-1], self.dtype) + x 247 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 248 | 249 | # Input Encoder 250 | dpr = [x for x in np.linspace(0, self.droppath_rate, self.num_layers)] 251 | for lyr in range(self.num_layers): 252 | x = Encoder1DBlock( 253 | mlp_dim=self.mlp_dim, 254 | dropout_rate=self.dropout_rate, 255 | dtype=self.dtype, 256 | droppath_rate=dpr[lyr], 257 | attention_dropout_rate=self.attention_dropout_rate, 258 | name=f'encoderblock_{lyr}', 259 | num_heads=self.num_heads)( 260 | x, deterministic=not train) 261 | encoded = nn.LayerNorm(name='encoder_norm')(x) 262 | 263 | return encoded 264 | 265 | def space_to_depth( 266 | frames: jnp.ndarray, 267 | spatial_block_size: Any = [1, 1]) -> jnp.ndarray: 268 | """Space to depth transform.""" 269 | if len(frames.shape) == 4: 270 | return einops.rearrange( 271 | frames, 'b (h dh) (w dw) c -> b (h w) (dh dw c)', 272 | dh=spatial_block_size[0], dw=spatial_block_size[1]) 273 | elif len(frames.shape) == 5: 274 | return einops.rearrange( 275 | frames, 'b t (h dh) (w dw) c -> b t (dh dw c)', 276 | dh=spatial_block_size[0], dw=spatial_block_size[1]) 277 | else: 278 | raise ValueError( 279 | 'Frames should be of rank 4 (batch, height, width, channels)' 280 | ' or rank 5 (batch, time, height, width, channels)') 281 | 282 | def reverse_space_to_depth( 283 | frames: jnp.ndarray, 284 | temporal_block_size: int = 1, 285 | spatial_block_size: int = 1, 286 | height: int = 16, 287 | width: int = 16) -> jnp.ndarray: 288 | """Reverse space to depth transform.""" 289 | if len(frames.shape) == 3: 290 | return einops.rearrange( 291 | frames, 'b (h w) (dh dw c) -> b (h dh) (w dw) c', 292 | h=height, w=width, dh=spatial_block_size, dw=spatial_block_size) 293 | elif len(frames.shape) == 4: 294 | return einops.rearrange( 295 | frames, 'b h w (dh dw c) -> b (h dh) (w dw) c', 296 | dh=spatial_block_size, dw=spatial_block_size) 297 | elif len(frames.shape) == 5: 298 | return einops.rearrange( 299 | frames, 'b t h w (dt dh dw c) -> b (t dt) (h dh) (w dw) c', 300 | dt=temporal_block_size, dh=spatial_block_size, dw=spatial_block_size) 301 | else: 302 | raise ValueError( 303 | 'Frames should be of rank 4 (batch, height, width, channels)' 304 | ' or rank 5 (batch, time, height, width, channels)') 305 | 306 | class VisionTransformer(nn.Module): 307 | """VisionTransformer.""" 308 | 309 | num_classes: int 310 | patch_size: int 311 | hidden_size: int 312 | num_layers: int 313 | mlp_dim: int 314 | num_heads: int 315 | dtype: Dtype = jnp.float32 316 | dropout_rate: float = 0.0 317 | droppath_rate: float = 0.0 318 | attention_dropout_rate: float = 0.0 319 | add_position_embedding: bool = False 320 | representation_size: Optional[int] = None 321 | classifier: str = 'token' 322 | head_bias_init: float = 0. 323 | 324 | @nn.compact 325 | def __call__(self, inputs, *, train): 326 | 327 | x = inputs 328 | if x.ndim != 3: 329 | x = space_to_depth(x, spatial_block_size=self.patch_size) 330 | 331 | x = nn.Dense( 332 | features=self.hidden_size, 333 | dtype=self.dtype, 334 | name='embedding', 335 | )(x) 336 | 337 | n, l, c = x.shape 338 | 339 | # If we want to add a class token, add it here. 
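A note on `space_to_depth` above before the class-token logic continues below: the rank-5 branch maps `'b t (h dh) (w dw) c -> b t (dh dw c)'`, which drops `h` and `w` from the output and would be rejected by einops, and `VisionTransformer` passes its integer `patch_size` where the function indexes `spatial_block_size[0]` and `[1]`, so a `(dh, dw)` pair is presumably intended in practice. The sketch below shows the rank-4 pattern as written and the rank-5 pattern that appears to be intended; the correction is an assumption about intent, not a confirmed fix:

```python
import jax.numpy as jnp
import einops

frames = jnp.zeros((2, 4, 256, 256, 3))       # (batch, time, H, W, C)

# Rank-4 branch, as written above: one token per 8x8 patch.
images = frames[:, 0]
patches = einops.rearrange(
    images, 'b (h dh) (w dw) c -> b (h w) (dh dw c)', dh=8, dw=8)
print(patches.shape)                           # (2, 1024, 192)

# Presumed rank-5 (video) equivalent, keeping the patch axis explicit.
video_patches = einops.rearrange(
    frames, 'b t (h dh) (w dw) c -> b t (h w) (dh dw c)', dh=8, dw=8)
print(video_patches.shape)                     # (2, 4, 1024, 192)
```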
340 | if self.classifier == 'token': 341 | cls = self.param('cls', nn.initializers.zeros, (1, 1, c)) 342 | cls = jnp.tile(cls, [n, 1, 1]) 343 | x = jnp.concatenate([cls, x], axis=1) 344 | 345 | x = Encoder( 346 | num_layers=self.num_layers, 347 | mlp_dim=self.mlp_dim, 348 | num_heads=self.num_heads, 349 | dtype=self.dtype, 350 | dropout_rate=self.dropout_rate, 351 | droppath_rate=self.droppath_rate, 352 | attention_dropout_rate=self.attention_dropout_rate, 353 | add_position_embedding=self.add_position_embedding, 354 | )(x, train=train) 355 | 356 | if self.classifier == 'token': 357 | x = x[:, 0] 358 | elif self.classifier == 'gap': 359 | x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2) 360 | elif self.classifier == 'unpooled': 361 | pass 362 | else: 363 | raise ValueError(f'Invalid classifier={self.classifier}') 364 | 365 | if self.representation_size is not None: 366 | x = nn.Dense(features=self.representation_size, name='pre_logits')(x) 367 | x = nn.tanh(x) 368 | else: 369 | x = IdentityLayer(name='pre_logits')(x) 370 | 371 | if self.num_classes: 372 | x = nn.Dense( 373 | features=self.num_classes, 374 | name='head', 375 | kernel_init=nn.initializers.zeros, 376 | bias_init=nn.initializers.constant(self.head_bias_init))(x) 377 | return x -------------------------------------------------------------------------------- /training/train_model.py: -------------------------------------------------------------------------------- 1 | from models.modeling import * 2 | from models.checkpoint import f32_to_bf16, bf16_to_f32 3 | from training.train_state import TrainState_v2 4 | from models.discriminator import Discriminator 5 | 6 | class Generator(Model): 7 | """ 8 | Generator 9 | """ 10 | def __call__(self, x, *, train): 11 | """ 12 | Does a forward pass for pretraining 13 | :param batch: Everything from pretraining 14 | :return: 15 | """ 16 | rec, diff = self.model(x, train=train) 17 | return rec, diff 18 | 19 | class Discriminator(Model): 20 | """ 21 | Discriminator 22 | """ 23 | def get_stylegan_logit(self, x, *, train=True): 24 | return self.model.get_stylegan_logit(x, train=train) 25 | 26 | def get_clip_feature(self, x, *, train=False): 27 | return self.model.get_clip_feature(x, train=False) 28 | 29 | def get_clip_logit(self, x, *, train=True): 30 | return self.model.get_clip_logit(x, train=train) 31 | 32 | def __call__(self, x, *, train): 33 | """ 34 | Does a forward pass for pretraining 35 | :param batch: Everything from pretraining 36 | :return: 37 | """ 38 | output = self.model(x, train=train) 39 | return output 40 | 41 | 42 | class LPIPS(Model): 43 | """ 44 | LPIPS 45 | """ 46 | def __call__(self, x0, x1, *, train): 47 | """ 48 | Does a forward pass for pretraining 49 | :param batch: Everything from pretraining 50 | :return: 51 | """ 52 | output = self.model(x0, x1, train=train) 53 | return output 54 | 55 | def vanilla_d_loss(logits_fake, logits_real = None): 56 | loss_fake = jnp.mean(jax.nn.softplus(-logits_fake)) * 2 if logits_real is None else jnp.mean(jax.nn.softplus(logits_fake)) 57 | loss_real = 0 if logits_real is None else jnp.mean(jax.nn.softplus(-logits_real)) 58 | return 0.5 * (loss_real + loss_fake) 59 | 60 | def hinge_d_loss(logits_fake, logits_real = None): 61 | loss_fake = jnp.mean(jax.nn.relu(1.0 + logits_fake)) 62 | loss_real = jnp.mean(jax.nn.relu(1.0 - logits_real)) 63 | return 0.5 * (loss_real + loss_fake) 64 | 65 | def calculate_adaptive_weight(): 66 | pass 67 | 68 | def VQLPIPS( 69 | input, 70 | rec, 71 | codebook_loss, 72 | p_logits, 73 | 
codebook_weight: float = 1.0, 74 | loggaussian_weight: float = 1.0, 75 | loglaplace_weight: float = 0.0, 76 | perceptual_weight: float = 0.1, 77 | ): 78 | 79 | loglaplace_loss = jnp.mean(jnp.abs(rec - input)) 80 | loggaussian_loss = jnp.mean((rec - input)**2) 81 | perceptual_loss = jnp.mean(p_logits) 82 | 83 | nll_loss = loglaplace_weight * loglaplace_loss + loggaussian_weight * loggaussian_loss \ 84 | + perceptual_weight * perceptual_loss 85 | 86 | loss = nll_loss + codebook_weight * codebook_loss 87 | return loss, (loglaplace_loss, loggaussian_loss, perceptual_loss, codebook_loss) 88 | 89 | def train_step(state: TrainState_v2, batch, config=None): 90 | """ 91 | Note: we'll compile this with pmap so no need to jit 92 | :param state: 93 | :param batch: 94 | :param use_bfloat16_grads: Whether to use bfloat16 for storing grads. I think it is probably OK considering 95 | momentum is bfloat16 anyways. i'm just going to cast down (rounding down here rather 96 | than to nearest or anything) 97 | :return: 98 | """ 99 | def _loss_fn_nll(params): 100 | rec, codebook_loss = state.apply_fn_g( 101 | {'params': params}, 102 | rngs=state.aux_rng_keys_g, 103 | x=batch['inputs'], 104 | train=True, 105 | ) 106 | 107 | p_logits = state.apply_fn_p( 108 | {'params': state.params_p}, 109 | rngs=state.aux_rng_keys_p, 110 | x0=batch['inputs'], 111 | x1=rec, 112 | train=False, 113 | ) 114 | 115 | loss, aux_loss = VQLPIPS( 116 | batch['inputs'], 117 | rec, 118 | codebook_loss, 119 | p_logits, 120 | codebook_weight = config['loss']['codebook_weight'], 121 | loggaussian_weight = config['loss']['loggaussian_weight'], 122 | loglaplace_weight = config['loss']['loglaplace_weight'], 123 | perceptual_weight = config['loss']['perceptual_weight'], 124 | ) 125 | 126 | return jnp.mean(loss), (rec, aux_loss) 127 | 128 | def _loss_fn_g_image(params): 129 | rec, codebook_loss = state.apply_fn_g( 130 | {'params': params}, 131 | rngs=state.aux_rng_keys_g, 132 | x=batch['inputs'], 133 | train=True, 134 | ) 135 | 136 | style_g_logits = state.apply_fn_d( 137 | {'params': state.params_d}, 138 | method=Discriminator.get_stylegan_logit, 139 | rngs=state.aux_rng_keys_d, 140 | x=rec, 141 | train=False, 142 | ) 143 | 144 | style_g_loss = jnp.mean(vanilla_d_loss(style_g_logits)) 145 | 146 | clip_feat = state.apply_fn_d( 147 | {'params': state.params_d}, 148 | method=Discriminator.get_clip_feature, 149 | rngs=state.aux_rng_keys_d, 150 | x=(rec+1.0)/2.0, 151 | train=False, 152 | ) 153 | 154 | clip_logits = state.apply_fn_d( 155 | {'params': state.params_d}, 156 | method=Discriminator.get_clip_logit, 157 | rngs=state.aux_rng_keys_d, 158 | x=clip_feat, 159 | train=True, 160 | ) 161 | 162 | clip_loss = 0 163 | for logits in clip_logits: 164 | clip_loss += jnp.mean(vanilla_d_loss(logits)) 165 | 166 | clip_loss = clip_loss / len(clip_logits) 167 | return style_g_loss + clip_loss 168 | 169 | def _loss_fn_d_image(params): 170 | rec, _ = state.apply_fn_g( 171 | {'params': state.params_g}, 172 | rngs=state.aux_rng_keys_g, 173 | x=batch['inputs'], 174 | train=False, 175 | ) 176 | rec = jax.lax.stop_gradient(rec) 177 | 178 | c_fake_feat = state.apply_fn_d( 179 | {'params': state.params_d}, 180 | method=Discriminator.get_clip_feature, 181 | rngs=state.aux_rng_keys_d, 182 | x=(rec+1.0)/2.0, 183 | train=False, 184 | ) 185 | c_real_feat = state.apply_fn_d( 186 | {'params': state.params_d}, 187 | method=Discriminator.get_clip_feature, 188 | rngs=state.aux_rng_keys_d, 189 | x=(batch['inputs']+1.0)/2.0, 190 | train=False, 191 | ) 192 | 193 | c_fake_feat = 
jax.lax.stop_gradient(c_fake_feat) 194 | c_real_feat = jax.lax.stop_gradient(c_real_feat) 195 | 196 | c_real_logit = state.apply_fn_d( 197 | {'params': state.params_d}, 198 | method=Discriminator.get_clip_logit, 199 | rngs=state.aux_rng_keys_d, 200 | x=c_real_feat, 201 | train=True, 202 | ) 203 | 204 | c_fake_logit = state.apply_fn_d( 205 | {'params': state.params_d}, 206 | method=Discriminator.get_clip_logit, 207 | rngs=state.aux_rng_keys_d, 208 | x=c_fake_feat, 209 | train=True, 210 | ) 211 | 212 | d_real_logit = state.apply_fn_d( 213 | {'params': params}, 214 | method=Discriminator.get_stylegan_logit, 215 | rngs=state.aux_rng_keys_d, 216 | x=batch['inputs'], 217 | train=True, 218 | ) 219 | d_fake_logit = state.apply_fn_d( 220 | {'params': params}, 221 | method=Discriminator.get_stylegan_logit, 222 | rngs=state.aux_rng_keys_d, 223 | x=rec, 224 | train=True, 225 | ) 226 | 227 | # adding flip update for the discriminator. 228 | flip_update = jnp.where( 229 | jnp.logical_and( 230 | state.step < config['loss']['disc_d_flip_update'], 231 | jnp.array_equal(jnp.mod(state.step, jnp.array(3)), jnp.array(0))), 232 | True, False) 233 | 234 | def positive_branch(arg): 235 | d_real_logit, d_fake_logit, c_real_logit, c_fake_logit = arg 236 | d_loss = vanilla_d_loss(d_real_logit, d_fake_logit) 237 | 238 | clip_loss = 0 239 | for real_logits, fake_logits in zip(c_real_logit, c_fake_logit): 240 | clip_loss += jnp.mean(vanilla_d_loss(real_logits, fake_logits)) 241 | 242 | return d_loss, clip_loss 243 | 244 | def negative_branch(arg): 245 | d_real_logit, d_fake_logit, c_real_logit, c_fake_logit = arg 246 | d_loss = vanilla_d_loss(d_fake_logit, d_real_logit) 247 | 248 | clip_loss = 0 249 | for real_logits, fake_logits in zip(c_real_logit, c_fake_logit): 250 | clip_loss += jnp.mean(vanilla_d_loss(fake_logits, real_logits)) 251 | 252 | clip_loss = clip_loss / len(c_real_logit) 253 | 254 | return d_loss, clip_loss 255 | 256 | d_loss, clip_loss = jax.lax.cond(flip_update, positive_branch, negative_branch, (d_real_logit, d_fake_logit, c_real_logit, c_fake_logit)) 257 | 258 | d_weight = jnp.where(state.step > config['loss']['disc_d_start'], 1.0, 0.0) 259 | loss = (d_loss + clip_loss) * d_weight 260 | return jnp.mean(loss), (jnp.mean(d_real_logit),jnp.mean(d_fake_logit), [jnp.mean(x) for x in c_real_logit], [jnp.mean(x) for x in c_fake_logit], jnp.mean(d_loss), jnp.mean(clip_loss)) 261 | 262 | 263 | def _loss_fn_g_audio(params): 264 | rec, codebook_loss = state.apply_fn_g( 265 | {'params': params}, 266 | rngs=state.aux_rng_keys_g, 267 | x=batch['inputs'], 268 | train=True, 269 | ) 270 | 271 | style_g_logits = state.apply_fn_d( 272 | {'params': state.params_d}, 273 | method=Discriminator.get_stylegan_logit, 274 | rngs=state.aux_rng_keys_d, 275 | x=rec, 276 | train=False, 277 | ) 278 | style_g_loss = jnp.mean(vanilla_d_loss(style_g_logits)) 279 | return style_g_loss 280 | 281 | def _loss_fn_d_audio(params): 282 | rec, _ = state.apply_fn_g( 283 | {'params': state.params_g}, 284 | rngs=state.aux_rng_keys_g, 285 | x=batch['inputs'], 286 | train=False, 287 | ) 288 | rec = jax.lax.stop_gradient(rec) 289 | 290 | d_real_logit = state.apply_fn_d( 291 | {'params': params}, 292 | method=Discriminator.get_stylegan_logit, 293 | rngs=state.aux_rng_keys_d, 294 | x=batch['inputs'], 295 | train=True, 296 | ) 297 | d_fake_logit = state.apply_fn_d( 298 | {'params': params}, 299 | method=Discriminator.get_stylegan_logit, 300 | rngs=state.aux_rng_keys_d, 301 | x=rec, 302 | train=True, 303 | ) 304 | 305 | # adding flip update for the 
discriminator. 306 | flip_update = jnp.where( 307 | jnp.logical_and( 308 | state.step < config['loss']['disc_d_flip_update'], 309 | jnp.array_equal(jnp.mod(state.step, jnp.array(3)), jnp.array(0))), 310 | True, False) 311 | 312 | def positive_branch(arg): 313 | d_real_logit, d_fake_logit = arg 314 | d_loss = vanilla_d_loss(d_real_logit, d_fake_logit) 315 | 316 | return d_loss 317 | 318 | def negative_branch(arg): 319 | d_real_logit, d_fake_logit = arg 320 | d_loss = vanilla_d_loss(d_fake_logit, d_real_logit) 321 | 322 | return d_loss 323 | 324 | d_loss = jax.lax.cond(flip_update, positive_branch, negative_branch, (d_real_logit, d_fake_logit)) 325 | 326 | d_weight = jnp.where(state.step > config['loss']['disc_d_start'], 1.0, 0.0) 327 | loss = d_loss * d_weight 328 | return jnp.mean(loss), (jnp.mean(d_real_logit),jnp.mean(d_fake_logit) , jnp.mean(d_loss)) 329 | 330 | 331 | 332 | use_bfloat16_grads = config['model']['use_bfloat16'] 333 | params_g = state.params_g 334 | if use_bfloat16_grads: 335 | params_g = f32_to_bf16(params_g) 336 | 337 | grad_fn_nll = jax.value_and_grad(_loss_fn_nll, has_aux=True) 338 | output_nll, grads_nll = grad_fn_nll(params_g) 339 | loss_nll, aux_output = output_nll 340 | rec, aux_loss = aux_output 341 | loglaplace_loss, loggaussian_loss, p_loss, c_loss = aux_loss 342 | 343 | if config['data']['task'] == 'image': 344 | grad_fn_g = jax.value_and_grad(_loss_fn_g_image, has_aux=False) 345 | else: 346 | grad_fn_g = jax.value_and_grad(_loss_fn_g_audio, has_aux=False) 347 | 348 | loss_g, grads_g = grad_fn_g(params_g) 349 | 350 | # new way to calculate d_weight, get the norm of the gradient. 351 | # grads_g = jax.tree_map(lambda x, y: 0.1 * (1 / jnp.linalg.norm(x))*y, params_g, grads_g) 352 | 353 | # -------------------------------------------------------------- 354 | last_layer_nll = jnp.linalg.norm(grads_nll['model']['decoder_proj']['kernel']) 355 | last_layer_g = jnp.linalg.norm(grads_g['model']['decoder_proj']['kernel']) 356 | 357 | d_weight_factor = jnp.where(state.step > config['loss']['disc_g_start'], config['loss']['adversarial_weight'],0.0) 358 | d_weight = last_layer_nll / (last_layer_g + 1e-4) * d_weight_factor 359 | 360 | grads_g = jax.tree_map(lambda x: d_weight*x, grads_g) 361 | # # -------------------------------------------------------------- 362 | 363 | # grads = jax.tree_map(lambda x: jnp.nan_to_num(x, copy=False), grads) 364 | grads_nll = jax.lax.pmean(grads_nll, axis_name='batch') 365 | grads_g = jax.lax.pmean(grads_g, axis_name='batch') 366 | 367 | # Cast up grads here (after the pmean) which reduces bandwidth maybe 368 | if use_bfloat16_grads: 369 | grads_nll = bf16_to_f32(grads_nll) 370 | grads_g = bf16_to_f32(grads_g) 371 | 372 | state = state.apply_gradients_g(grads=grads_nll) 373 | state = state.apply_gradients_g(grads=grads_g) 374 | 375 | loss_info = {} 376 | loss_info['loss'] = loss_nll 377 | loss_info['loglaplace_loss'] = loglaplace_loss 378 | loss_info['loggaussian_loss'] = loggaussian_loss 379 | loss_info['perceptual_loss'] = p_loss 380 | loss_info['codebook_loss'] = c_loss 381 | loss_info['d_weight'] = d_weight 382 | 383 | loss_info['image_rec'] = (rec.clip(-1.0, 1.0) + 1.0) / 2.0 384 | loss_info['image_ori'] = (batch['inputs'] + 1.0) / 2.0 385 | 386 | # discriminator 387 | 388 | if config['data']['task'] == 'image': 389 | grad_fn_d = jax.value_and_grad(_loss_fn_d_image, has_aux=True) 390 | else: 391 | grad_fn_d = jax.value_and_grad(_loss_fn_d_audio, has_aux=True) 392 | 393 | params_d = state.params_d 394 | if use_bfloat16_grads: 395 | 
params_d = f32_to_bf16(state.params_d) 396 | output_d, grads_d = grad_fn_d(params_d) 397 | 398 | loss_d, aux_output_d = output_d 399 | 400 | if config['data']['task'] == 'image': 401 | d_real_logit, d_fake_logit, c_real_logit, c_fake_logit, d_loss, clip_loss = aux_output_d 402 | loss_info['c_real_logit_0'] = c_real_logit[-1] 403 | loss_info['c_fake_logit_0'] = c_fake_logit[-1] 404 | loss_info['d_loss_clip'] = clip_loss 405 | else: 406 | d_real_logit, d_fake_logit, d_loss = aux_output_d 407 | 408 | loss_info['g_loss'] = loss_g 409 | loss_info['d_loss'] = loss_d 410 | loss_info['d_real_logit'] = d_real_logit 411 | loss_info['d_fake_logit'] = d_fake_logit 412 | loss_info['d_loss_stylegan'] = d_loss 413 | 414 | if use_bfloat16_grads: 415 | grads_d = bf16_to_f32(grads_d) 416 | 417 | grads_d = jax.lax.pmean(grads_d, axis_name='batch') 418 | state = state.apply_gradients_d(grads=grads_d) 419 | 420 | # Average metrics over all replicas (maybe this isn't a great idea, idk) 421 | # loss_info = jax.lax.pmean(loss_info, axis_name='batch') 422 | loss_info = bf16_to_f32(loss_info) 423 | state = state.replace(step = state.step + 1) 424 | 425 | return state, loss_info -------------------------------------------------------------------------------- /data/imagenet_utils.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from typing import Any, Generator, Mapping, Optional, Sequence, Text, Tuple 3 | 4 | import jax 5 | import numpy as np 6 | import math 7 | import tensorflow.compat.v2 as tf 8 | import tensorflow_datasets as tfds 9 | import tensorflow_probability as tfp 10 | 11 | from . import autoaugment 12 | 13 | MEAN_RGB = (0.485, 0.456, 0.406) 14 | STDDEV_RGB = (0.229, 0.224, 0.225) 15 | AUTOTUNE = tf.data.experimental.AUTOTUNE 16 | 17 | INPUT_DIM = 224 18 | 19 | 20 | def _preprocess_image( 21 | image_bytes: tf.Tensor, 22 | is_training: bool, 23 | image_size: Sequence[int], 24 | augmentation_settings: Mapping[str, Any], 25 | ) -> Tuple[tf.Tensor, tf.Tensor]: 26 | """Returns processed and resized images.""" 27 | 28 | # Get the image crop. 29 | if is_training: 30 | image, im_shape = _decode_and_random_crop(image_bytes) 31 | image = tf.image.random_flip_left_right(image) 32 | else: 33 | image, im_shape = _decode_and_center_crop(image_bytes) 34 | assert image.dtype == tf.uint8 35 | 36 | # Optionally apply RandAugment: https://arxiv.org/abs/1909.13719 37 | if is_training: 38 | if augmentation_settings['randaugment'] is not None: 39 | # Input and output images are dtype uint8. 40 | image = autoaugment.distort_image_with_randaugment( 41 | image, 42 | num_layers=augmentation_settings['randaugment']['num_layers'], 43 | magnitude=augmentation_settings['randaugment']['magnitude']) 44 | 45 | # Resize and normalize the image crop. 46 | # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without 47 | # clamping overshoots. This means values returned will be outside the range 48 | # [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]). 49 | image = tf.image.resize( 50 | image, image_size, tf.image.ResizeMethod.BICUBIC) 51 | image = image / 255. 
52 | image = _normalize_image(image) 53 | 54 | return image, im_shape 55 | 56 | 57 | def _normalize_image(image: tf.Tensor) -> tf.Tensor: 58 | """Normalize the image to zero mean and unit variance.""" 59 | image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype) 60 | image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype) 61 | return image 62 | 63 | 64 | def _decode_and_random_crop( 65 | image_bytes: tf.Tensor 66 | ) -> Tuple[tf.Tensor, tf.Tensor]: 67 | """Make a random crop of INPUT_DIM.""" 68 | 69 | if image_bytes.dtype == tf.dtypes.string: 70 | jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) 71 | else: 72 | jpeg_shape = tf.shape(image_bytes) 73 | 74 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 75 | image, im_shape = _distorted_bounding_box_crop( 76 | image_bytes, 77 | jpeg_shape=jpeg_shape, 78 | bbox=bbox, 79 | min_object_covered=0.1, 80 | aspect_ratio_range=(3 / 4, 4 / 3), 81 | area_range=(0.08, 1.0), 82 | max_attempts=10) 83 | 84 | if tf.reduce_all(tf.equal(jpeg_shape, tf.shape(image))): 85 | # If the random crop failed fall back to center crop. 86 | image, im_shape = _decode_and_center_crop(image_bytes, jpeg_shape) 87 | return image, im_shape 88 | 89 | 90 | def _distorted_bounding_box_crop( 91 | image_bytes: tf.Tensor, 92 | *, 93 | jpeg_shape: tf.Tensor, 94 | bbox: tf.Tensor, 95 | min_object_covered: float, 96 | aspect_ratio_range: Tuple[float, float], 97 | area_range: Tuple[float, float], 98 | max_attempts: int, 99 | ) -> Tuple[tf.Tensor, tf.Tensor]: 100 | """Generates cropped_image using one of the bboxes randomly distorted.""" 101 | bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box( 102 | jpeg_shape, 103 | bounding_boxes=bbox, 104 | min_object_covered=min_object_covered, 105 | aspect_ratio_range=aspect_ratio_range, 106 | area_range=area_range, 107 | max_attempts=max_attempts, 108 | use_image_if_no_bounding_boxes=True) 109 | 110 | # Crop the image to the specified bounding box. 
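(The crop-window construction continues below.) A hedged usage sketch of `_preprocess_image` above, using the same RandAugment settings dictionary that this repo's `encoder_only_preprocessor` builds; the JPEG path is a placeholder, and only the 'randaugment' entry is read by this function:

```python
import tensorflow.compat.v2 as tf

augmentation_settings = {
    'randaugment': dict(num_layers=2, magnitude=9),   # the only key used here
    'cutmix': False,
    'mixup_alpha': None,
}

image_bytes = tf.io.read_file('/path/to/example.jpg')  # placeholder path
image, im_shape = _preprocess_image(
    image_bytes,
    is_training=True,
    image_size=[224, 224],
    augmentation_settings=augmentation_settings)
# `image` is float32 and roughly zero-mean / unit-variance after _normalize_image;
# bicubic resizing can leave a few values slightly outside the nominal range.
```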
111 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 112 | target_height, target_width, _ = tf.unstack(bbox_size) 113 | crop_window = [offset_y, offset_x, target_height, target_width] 114 | 115 | if image_bytes.dtype == tf.dtypes.string: 116 | image = tf.image.decode_and_crop_jpeg(image_bytes, 117 | tf.stack(crop_window), 118 | channels=3) 119 | else: 120 | image = tf.image.crop_to_bounding_box(image_bytes, *crop_window) 121 | 122 | im_shape = tf.stack([target_height, target_width]) 123 | return image, im_shape 124 | 125 | 126 | def _center_crop(image, crop_dim): 127 | """Center crops an image to a target dimension.""" 128 | image_height = image.shape[0] 129 | image_width = image.shape[1] 130 | offset_height = ((image_height - crop_dim) + 1) // 2 131 | offset_width = ((image_width - crop_dim) + 1) // 2 132 | return tf.image.crop_to_bounding_box( 133 | image, offset_height, offset_width, crop_dim, crop_dim) 134 | 135 | 136 | def _decode_and_center_crop( 137 | image_bytes: tf.Tensor, 138 | jpeg_shape: Optional[tf.Tensor] = None, 139 | ) -> Tuple[tf.Tensor, tf.Tensor]: 140 | """Crops to center of image with padding then scales.""" 141 | if jpeg_shape is None: 142 | if image_bytes.dtype == tf.dtypes.string: 143 | jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) 144 | else: 145 | jpeg_shape = tf.shape(image_bytes) 146 | 147 | image_height = jpeg_shape[0] 148 | image_width = jpeg_shape[1] 149 | 150 | padded_center_crop_size = tf.cast( 151 | ((INPUT_DIM / (INPUT_DIM + 32)) * 152 | tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) 153 | 154 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 155 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 156 | crop_window = [offset_height, offset_width, 157 | padded_center_crop_size, padded_center_crop_size] 158 | 159 | if image_bytes.dtype == tf.dtypes.string: 160 | image = tf.image.decode_and_crop_jpeg(image_bytes, 161 | tf.stack(crop_window), 162 | channels=3) 163 | else: 164 | image = tf.image.crop_to_bounding_box(image_bytes, *crop_window) 165 | 166 | im_shape = tf.stack([padded_center_crop_size, padded_center_crop_size]) 167 | return image, im_shape 168 | 169 | 170 | def cutmix_padding(h, w): 171 | """Returns image mask for CutMix. 172 | Taken from (https://github.com/google/edward2/blob/master/experimental 173 | /marginalization_mixup/data_utils.py#L367) 174 | Args: 175 | h: image height. 176 | w: image width. 177 | """ 178 | r_x = tf.random.uniform([], 0, w, tf.int32) 179 | r_y = tf.random.uniform([], 0, h, tf.int32) 180 | 181 | # Beta dist in paper, but they used Beta(1,1) which is just uniform. 182 | image1_proportion = tf.random.uniform([]) 183 | patch_length_ratio = tf.math.sqrt(1 - image1_proportion) 184 | r_w = tf.cast(patch_length_ratio * tf.cast(w, tf.float32), tf.int32) 185 | r_h = tf.cast(patch_length_ratio * tf.cast(h, tf.float32), tf.int32) 186 | bbx1 = tf.clip_by_value(tf.cast(r_x - r_w // 2, tf.int32), 0, w) 187 | bby1 = tf.clip_by_value(tf.cast(r_y - r_h // 2, tf.int32), 0, h) 188 | bbx2 = tf.clip_by_value(tf.cast(r_x + r_w // 2, tf.int32), 0, w) 189 | bby2 = tf.clip_by_value(tf.cast(r_y + r_h // 2, tf.int32), 0, h) 190 | 191 | # Create the binary mask. 
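(The CutMix mask construction continues below.) A worked example of the padded center crop in `_decode_and_center_crop` above, using an assumed 375x500 uint8 image rather than data from this repo; the crop side is (224 / 256) * min(H, W), and the caller later resizes the crop to 224x224:

```python
import tensorflow.compat.v2 as tf

image = tf.zeros([375, 500, 3], dtype=tf.uint8)   # assumed input size
cropped, im_shape = _decode_and_center_crop(image)

# padded_center_crop_size = int((224 / 256) * min(375, 500)) = 328
# offsets: ((375 - 328) + 1) // 2 = 24 (height), ((500 - 328) + 1) // 2 = 86 (width)
print(im_shape.numpy())        # [328 328]
print(cropped.shape)           # (328, 328, 3)
```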
192 | pad_left = bbx1 193 | pad_top = bby1 194 | pad_right = tf.maximum(w - bbx2, 0) 195 | pad_bottom = tf.maximum(h - bby2, 0) 196 | r_h = bby2 - bby1 197 | r_w = bbx2 - bbx1 198 | 199 | mask = tf.pad( 200 | tf.ones((r_h, r_w)), 201 | paddings=[[pad_top, pad_bottom], [pad_left, pad_right]], 202 | mode='CONSTANT', 203 | constant_values=0) 204 | mask.set_shape((h, w)) 205 | return mask[..., None] # Add channel dim. 206 | 207 | 208 | def my_cutmix(batch): 209 | """Apply CutMix: https://arxiv.org/abs/1905.04899.""" 210 | batch = dict(**batch) 211 | bs = tf.shape(batch['images'])[0] // 2 212 | mask = batch['mask'][:bs] 213 | images = (mask * batch['images'][:bs] + (1.0 - mask) * batch['images'][bs:]) 214 | mix_labels = batch['labels'][bs:] 215 | labels = batch['labels'][:bs] 216 | ratio = batch['cutmix_ratio'][:bs] 217 | return {'images': images, 'labels': labels, 218 | 'mix_labels': mix_labels, 'ratio': ratio} 219 | 220 | 221 | def my_mixup(batch): 222 | """Apply mixup: https://arxiv.org/abs/1710.09412.""" 223 | batch = dict(**batch) 224 | bs = tf.shape(batch['images'])[0] // 2 225 | ratio = batch['mixup_ratio'][:bs, None, None, None] 226 | images = (ratio * batch['images'][:bs] + (1.0 - ratio) * batch['images'][bs:]) 227 | mix_labels = batch['labels'][bs:] 228 | labels = batch['labels'][:bs] 229 | ratio = ratio[..., 0, 0, 0] # Unsqueeze 230 | return {'images': images, 'labels': labels, 231 | 'mix_labels': mix_labels, 'ratio': ratio} 232 | 233 | 234 | def my_mixup_cutmix(batch): 235 | """Apply mixup to half the batch, and cutmix to the other.""" 236 | batch = dict(**batch) 237 | bs = tf.shape(batch['images'])[0] // 4 238 | mixup_ratio = batch['mixup_ratio'][:bs, None, None, None] 239 | mixup_images = (mixup_ratio * batch['images'][:bs] 240 | + (1.0 - mixup_ratio) * batch['images'][bs:2*bs]) 241 | mixup_labels = batch['labels'][:bs] 242 | mixup_mix_labels = batch['labels'][bs:2*bs] 243 | 244 | cutmix_mask = batch['mask'][2*bs:3*bs] 245 | 246 | cutmix_images = (cutmix_mask * batch['images'][2*bs:3*bs] 247 | + (1.0 - cutmix_mask) * batch['images'][-bs:]) 248 | cutmix_labels = batch['labels'][2*bs:3*bs] 249 | cutmix_mix_labels = batch['labels'][-bs:] 250 | cutmix_ratio = batch['cutmix_ratio'][2*bs : 3*bs] 251 | 252 | return {'images': tf.concat([mixup_images, cutmix_images], axis=0), 253 | 'labels': tf.concat([mixup_labels, cutmix_labels], axis=0), 254 | 'mix_labels': tf.concat([mixup_mix_labels, cutmix_mix_labels], 0), 255 | 'ratio': tf.concat([mixup_ratio[..., 0, 0, 0], cutmix_ratio], axis=0)} 256 | 257 | 258 | def _fill_rectangle(image, 259 | center_width, 260 | center_height, 261 | half_width, 262 | half_height, 263 | replace=None): 264 | """Fills blank area.""" 265 | image_height = tf.shape(image)[0] 266 | image_width = tf.shape(image)[1] 267 | 268 | lower_pad = tf.maximum(0, center_height - half_height) 269 | upper_pad = tf.maximum(0, image_height - center_height - half_height) 270 | left_pad = tf.maximum(0, center_width - half_width) 271 | right_pad = tf.maximum(0, image_width - center_width - half_width) 272 | 273 | cutout_shape = [ 274 | image_height - (lower_pad + upper_pad), 275 | image_width - (left_pad + right_pad) 276 | ] 277 | padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] 278 | mask = tf.pad( 279 | tf.zeros(cutout_shape, dtype=image.dtype), 280 | padding_dims, 281 | constant_values=1) 282 | mask = tf.expand_dims(mask, -1) 283 | mask = tf.tile(mask, [1, 1, 3]) 284 | 285 | if replace is None: 286 | fill = tf.random.normal(tf.shape(image), dtype=image.dtype) 287 | 
elif isinstance(replace, tf.Tensor): 288 | fill = replace 289 | else: 290 | fill = tf.ones_like(image, dtype=image.dtype) * replace 291 | image = tf.where(tf.equal(mask, 0), fill, image) 292 | 293 | return image 294 | 295 | 296 | def fill_rectangle(image, image_width, image_height, half_width, half_height): 297 | center_height = tf.random.uniform( 298 | shape=[], 299 | minval=0, 300 | maxval=tf.cast(image_height - 2 * half_height, tf.int32), 301 | dtype=tf.int32) 302 | center_width = tf.random.uniform( 303 | shape=[], 304 | minval=0, 305 | maxval=tf.cast(image_width - 2 * half_width, tf.int32), 306 | dtype=tf.int32) 307 | 308 | image = _fill_rectangle( 309 | image, 310 | center_width, 311 | center_height, 312 | half_width, 313 | half_height, 314 | replace=None) 315 | return image 316 | 317 | 318 | class ImageAugment(object): 319 | """Image augmentation class for applying image distortions.""" 320 | 321 | def distort( 322 | self, 323 | image: tf.Tensor 324 | ) -> tf.Tensor: 325 | """Given an image tensor, returns a distorted image with the same shape. 326 | Args: 327 | image: `Tensor` of shape [height, width, 3] or 328 | [num_frames, height, width, 3] representing an image or image sequence. 329 | Returns: 330 | The augmented version of `image`. 331 | """ 332 | raise NotImplementedError() 333 | 334 | def distort_with_boxes( 335 | self, 336 | image: tf.Tensor, 337 | bboxes: tf.Tensor 338 | ) -> Tuple[tf.Tensor, tf.Tensor]: 339 | """Distorts the image and bounding boxes. 340 | Args: 341 | image: `Tensor` of shape [height, width, 3] or 342 | [num_frames, height, width, 3] representing an image or image sequence. 343 | bboxes: `Tensor` of shape [num_boxes, 4] or [num_frames, num_boxes, 4] 344 | representing bounding boxes for an image or image sequence. 345 | Returns: 346 | The augmented version of `image` and `bboxes`. 347 | """ 348 | raise NotImplementedError 349 | 350 | 351 | class RandomErasing(ImageAugment): 352 | """Applies RandomErasing to a single image. 353 | Reference: https://arxiv.org/abs/1708.04896 354 | Implementation is inspired by 355 | https://github.com/rwightman/pytorch-image-models. 356 | """ 357 | 358 | def __init__(self, 359 | probability: float = 0.25, 360 | min_area: float = 0.02, 361 | max_area: float = 1 / 3, 362 | min_aspect: float = 0.3, 363 | max_aspect: Optional[float] = None, 364 | min_count=1, 365 | max_count=1, 366 | trials=10): 367 | """Applies RandomErasing to a single image. 368 | Args: 369 | probability: Probability of augmenting the image. Defaults to `0.25`. 370 | min_area: Minimum area of the random erasing rectangle. Defaults to 371 | `0.02`. 372 | max_area: Maximum area of the random erasing rectangle. Defaults to `1/3`. 373 | min_aspect: Minimum aspect rate of the random erasing rectangle. Defaults 374 | to `0.3`. 375 | max_aspect: Maximum aspect rate of the random erasing rectangle. Defaults 376 | to `None`. 377 | min_count: Minimum number of erased rectangles. Defaults to `1`. 378 | max_count: Maximum number of erased rectangles. Defaults to `1`. 379 | trials: Maximum number of trials to randomly sample a rectangle that 380 | fulfills constraint. Defaults to `10`. 
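A usage sketch for this `RandomErasing` class (its implementation continues below); the probability is set to 1.0 purely for illustration, and the input is a random float image rather than data from this repo:

```python
import tensorflow.compat.v2 as tf

eraser = RandomErasing(probability=1.0)      # always attempt an erase, for illustration
image = tf.random.uniform([224, 224, 3])     # float image in [0, 1)
augmented = eraser.distort(image)            # same spatial shape, one noise-filled patch
print(tf.shape(augmented).numpy())           # [224 224 3]
```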
381 | """ 382 | self._probability = probability 383 | self._min_area = float(min_area) 384 | self._max_area = float(max_area) 385 | self._min_log_aspect = math.log(min_aspect) 386 | self._max_log_aspect = math.log(max_aspect or 1 / min_aspect) 387 | self._min_count = min_count 388 | self._max_count = max_count 389 | self._trials = trials 390 | 391 | def distort(self, image: tf.Tensor) -> tf.Tensor: 392 | """Applies RandomErasing to single `image`. 393 | Args: 394 | image (tf.Tensor): Of shape [height, width, 3] representing an image. 395 | Returns: 396 | tf.Tensor: The augmented version of `image`. 397 | """ 398 | uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0) 399 | mirror_cond = tf.less(uniform_random, self._probability) 400 | image = tf.cond(mirror_cond, lambda: self._erase(image), lambda: image) 401 | return image 402 | 403 | @tf.function 404 | def _erase(self, image: tf.Tensor) -> tf.Tensor: 405 | """Erase an area.""" 406 | if self._min_count == self._max_count: 407 | count = self._min_count 408 | else: 409 | count = tf.random.uniform( 410 | shape=[], 411 | minval=int(self._min_count), 412 | maxval=int(self._max_count - self._min_count + 1), 413 | dtype=tf.int32) 414 | 415 | image_height = tf.shape(image)[0] 416 | image_width = tf.shape(image)[1] 417 | area = tf.cast(image_width * image_height, tf.float32) 418 | 419 | for _ in range(count): 420 | # Work around since break is not supported in tf.function 421 | is_trial_successful = False 422 | for _ in range(self._trials): 423 | erase_area = tf.random.uniform( 424 | shape=[], 425 | minval=area * self._min_area, 426 | maxval=area * self._max_area) 427 | aspect_ratio = tf.math.exp( 428 | tf.random.uniform( 429 | shape=[], 430 | minval=self._min_log_aspect, 431 | maxval=self._max_log_aspect)) 432 | 433 | half_height = tf.cast( 434 | tf.math.round(tf.math.sqrt(erase_area * aspect_ratio) / 2), 435 | dtype=tf.int32) 436 | half_width = tf.cast( 437 | tf.math.round(tf.math.sqrt(erase_area / aspect_ratio) / 2), 438 | dtype=tf.int32) 439 | 440 | do_erase = tf.logical_and( 441 | tf.logical_not(is_trial_successful), tf.less(2 * half_height, image_height), 442 | ) 443 | do_erase = tf.logical_and( 444 | do_erase, tf.less(2 * half_width, image_width), 445 | ) 446 | 447 | image, is_trial_successful = tf.cond( 448 | do_erase, 449 | lambda: (fill_rectangle(image, image_width, image_height, half_width, half_height), True), 450 | lambda: (image, is_trial_successful), 451 | ) 452 | 453 | return image -------------------------------------------------------------------------------- /models/stylegan_discriminator.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 | import jax 5 | from jax import random 6 | import jax.numpy as jnp 7 | import flax.linen as nn 8 | from typing import Any, Tuple, List, Callable 9 | import h5py 10 | from models import ops 11 | 12 | URLS = {'afhqcat': 'https://www.dropbox.com/s/qygbjkefyqyu9k9/stylegan2_discriminator_afhqcat.h5?dl=1', 13 | 'afhqdog': 'https://www.dropbox.com/s/kmoxbp33qswz64p/stylegan2_discriminator_afhqdog.h5?dl=1', 14 | 'afhqwild': 'https://www.dropbox.com/s/jz1hpsyt3isj6e7/stylegan2_discriminator_afhqwild.h5?dl=1', 15 | 'brecahad': 'https://www.dropbox.com/s/h0cb89hruo6pmyj/stylegan2_discriminator_brecahad.h5?dl=1', 16 | 'car': 'https://www.dropbox.com/s/2ghjrmxih7cic76/stylegan2_discriminator_car.h5?dl=1', 17 | 'cat': 'https://www.dropbox.com/s/zfhjsvlsny5qixd/stylegan2_discriminator_cat.h5?dl=1', 18 | 'church': 
'https://www.dropbox.com/s/jlno7zeivkjtk8g/stylegan2_discriminator_church.h5?dl=1', 19 | 'cifar10': 'https://www.dropbox.com/s/eldpubfkl4c6rur/stylegan2_discriminator_cifar10.h5?dl=1', 20 | 'ffhq': 'https://www.dropbox.com/s/m42qy9951b7lq1s/stylegan2_discriminator_ffhq.h5?dl=1', 21 | 'horse': 'https://www.dropbox.com/s/19f5pxrcdh2g8cw/stylegan2_discriminator_horse.h5?dl=1', 22 | 'metfaces': 'https://www.dropbox.com/s/xnokaunql12glkd/stylegan2_discriminator_metfaces.h5?dl=1'} 23 | 24 | RESOLUTION = {'metfaces': 1024, 25 | 'ffhq': 1024, 26 | 'church': 256, 27 | 'cat': 256, 28 | 'horse': 256, 29 | 'car': 512, 30 | 'brecahad': 512, 31 | 'afhqwild': 512, 32 | 'afhqdog': 512, 33 | 'afhqcat': 512, 34 | 'cifar10': 32} 35 | 36 | C_DIM = {'metfaces': 0, 37 | 'ffhq': 0, 38 | 'church': 0, 39 | 'cat': 0, 40 | 'horse': 0, 41 | 'car': 0, 42 | 'brecahad': 0, 43 | 'afhqwild': 0, 44 | 'afhqdog': 0, 45 | 'afhqcat': 0, 46 | 'cifar10': 10} 47 | 48 | ARCHITECTURE = {'metfaces': 'resnet', 49 | 'ffhq': 'resnet', 50 | 'church': 'resnet', 51 | 'cat': 'resnet', 52 | 'horse': 'resnet', 53 | 'car': 'resnet', 54 | 'brecahad': 'resnet', 55 | 'afhqwild': 'resnet', 56 | 'afhqdog': 'resnet', 57 | 'afhqcat': 'resnet', 58 | 'cifar10': 'orig'} 59 | 60 | MBSTD_GROUP_SIZE = {'metfaces': None, 61 | 'ffhq': None, 62 | 'church': None, 63 | 'cat': None, 64 | 'horse': None, 65 | 'car': None, 66 | 'brecahad': None, 67 | 'afhqwild': None, 68 | 'afhqdog': None, 69 | 'afhqcat': None, 70 | 'cifar10': 32} 71 | 72 | def download(ckpt_dir, url): 73 | name = url[url.rfind('/') + 1 : url.rfind('?')] 74 | if ckpt_dir is None: 75 | ckpt_dir = tempfile.gettempdir() 76 | ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels') 77 | ckpt_file = os.path.join(ckpt_dir, name) 78 | if not os.path.exists(ckpt_file): 79 | print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}') 80 | if not os.path.exists(ckpt_dir): 81 | os.makedirs(ckpt_dir) 82 | 83 | response = requests.get(url, stream=True) 84 | total_size_in_bytes = int(response.headers.get('content-length', 0)) 85 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 86 | 87 | # first create temp file, in case the download fails 88 | ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp') 89 | with open(ckpt_file_temp, 'wb') as file: 90 | for data in response.iter_content(chunk_size=1024): 91 | progress_bar.update(len(data)) 92 | file.write(data) 93 | progress_bar.close() 94 | 95 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 96 | print('An error occured while downloading, please try again.') 97 | if os.path.exists(ckpt_file_temp): 98 | os.remove(ckpt_file_temp) 99 | else: 100 | # if download was successful, rename the temp file 101 | os.rename(ckpt_file_temp, ckpt_file) 102 | return ckpt_file 103 | 104 | class FromRGBLayer(nn.Module): 105 | """ 106 | From RGB Layer. 107 | Attributes: 108 | fmaps (int): Number of output channels of the convolution. 109 | kernel (int): Kernel size of the convolution. 110 | lr_multiplier (float): Learning rate multiplier. 111 | activation (str): Activation function: 'relu', 'lrelu', etc. 112 | param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored. 113 | clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping. 114 | dtype (str): Data dtype. 115 | rng (jax.random.PRNGKey): PRNG for initialization. 
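(The `FromRGBLayer` definition continues below.) The `download` helper above uses `os`, `tempfile`, `requests`, and `tqdm`, none of which appear in the imports shown at the top of this file; a sketch of the imports it appears to require, assuming the usual flaxmodels-style dependencies:

```python
import os
import tempfile

import requests
from tqdm import tqdm
```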
116 | """ 117 | fmaps: int 118 | kernel: int=1 119 | lr_multiplier: float=1 120 | activation: str='leaky_relu' 121 | param_dict: h5py.Group=None 122 | clip_conv: float=None 123 | dtype: str='float32' 124 | rng: Any=random.PRNGKey(0) 125 | 126 | @nn.compact 127 | def __call__(self, x, y): 128 | """ 129 | Run From RGB Layer. 130 | Args: 131 | x (tensor): Input image of shape [N, H, W, num_channels]. 132 | y (tensor): Input tensor of shape [N, H, W, out_channels]. 133 | Returns: 134 | (tensor): Output tensor of shape [N, H, W, out_channels]. 135 | """ 136 | w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps] 137 | w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'fromrgb', self.rng) 138 | 139 | w = self.param(name='weight', init_fn=lambda *_ : w) 140 | b = self.param(name='bias', init_fn=lambda *_ : b) 141 | w = ops.equalize_lr_weight(w, self.lr_multiplier) 142 | b = ops.equalize_lr_bias(b, self.lr_multiplier) 143 | 144 | x = x.astype(self.dtype) 145 | x = ops.conv2d(x, w.astype(x.dtype)) 146 | x += b.astype(x.dtype) 147 | x = ops.apply_activation(x, activation=self.activation) 148 | if self.clip_conv is not None: 149 | x = jnp.clip(x, -self.clip_conv, self.clip_conv) 150 | if y is not None: 151 | x += y 152 | return x 153 | 154 | 155 | class DiscriminatorLayer(nn.Module): 156 | """ 157 | Discriminator Layer. 158 | Attributes: 159 | fmaps (int): Number of output channels of the convolution. 160 | kernel (int): Kernel size of the convolution. 161 | use_bias (bool): If True, use bias. 162 | down (bool): If True, downsample the spatial resolution. 163 | resample_kernel (Tuple): Kernel that is used for FIR filter. 164 | activation (str): Activation function: 'relu', 'lrelu', etc. 165 | layer_name (str): Layer name. 166 | param_dict (h5py.Group): Parameter dict with pretrained parameters. 167 | lr_multiplier (float): Learning rate multiplier. 168 | clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping. 169 | dtype (str): Data dtype. 170 | rng (jax.random.PRNGKey): PRNG for initialization. 171 | """ 172 | fmaps: int 173 | kernel: int=3 174 | use_bias: bool=True 175 | down: bool=False 176 | resample_kernel: Tuple=None 177 | activation: str='leaky_relu' 178 | layer_name: str=None 179 | param_dict: h5py.Group=None 180 | lr_multiplier: float=1 181 | clip_conv: float=None 182 | dtype: str='float32' 183 | rng: Any=random.PRNGKey(0) 184 | 185 | @nn.compact 186 | def __call__(self, x): 187 | """ 188 | Run Discriminator Layer. 189 | Args: 190 | x (tensor): Input tensor of shape [N, H, W, C]. 191 | Returns: 192 | (tensor): Output tensor of shape [N, H, W, fmaps]. 
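(The `DiscriminatorLayer` body continues below.) These layers fetch kernels through `ops.get_weight` and rescale them on every forward pass with `ops.equalize_lr_weight` / `ops.equalize_lr_bias`, StyleGAN2's equalized learning rate. `models/ops.py` is not shown in this excerpt, so the snippet below only sketches the runtime rescaling idea; it is not that module's actual implementation:

```python
import numpy as np

def equalized_lr_scale_sketch(weight_shape, lr_multiplier=1.0, gain=1.0):
    # Weights are stored near unit scale and multiplied by a He-style constant
    # (gain / sqrt(fan_in)) * lr_multiplier at run time.
    fan_in = int(np.prod(weight_shape[:-1]))   # k * k * in_channels for a conv kernel
    return gain / np.sqrt(fan_in) * lr_multiplier

print(round(equalized_lr_scale_sketch((3, 3, 256, 512)), 5))   # 0.02083 = 1 / sqrt(2304)
```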
193 | """ 194 | w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps] 195 | if self.use_bias: 196 | w, b = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng) 197 | else: 198 | w = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng) 199 | 200 | w = self.param(name='weight', init_fn=lambda *_ : w) 201 | w = ops.equalize_lr_weight(w, self.lr_multiplier) 202 | if self.use_bias: 203 | b = self.param(name='bias', init_fn=lambda *_ : b) 204 | b = ops.equalize_lr_bias(b, self.lr_multiplier) 205 | 206 | x = x.astype(self.dtype) 207 | x = ops.conv2d(x, w, down=self.down, resample_kernel=self.resample_kernel) 208 | if self.use_bias: x += b.astype(x.dtype) 209 | x = ops.apply_activation(x, activation=self.activation) 210 | if self.clip_conv is not None: 211 | x = jnp.clip(x, -self.clip_conv, self.clip_conv) 212 | return x 213 | 214 | 215 | class DiscriminatorBlock(nn.Module): 216 | """ 217 | Discriminator Block. 218 | Attributes: 219 | fmaps (int): Number of output channels of the convolution. 220 | kernel (int): Kernel size of the convolution. 221 | resample_kernel (Tuple): Kernel that is used for FIR filter. 222 | activation (str): Activation function: 'relu', 'lrelu', etc. 223 | param_dict (h5py.Group): Parameter dict with pretrained parameters. 224 | lr_multiplier (float): Learning rate multiplier. 225 | architecture (str): Architecture: 'orig', 'resnet'. 226 | nf (Callable): Callable that returns the number of feature maps for a given layer. 227 | clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping. 228 | dtype (str): Data dtype. 229 | rng (jax.random.PRNGKey): Random seed for initialization. 230 | """ 231 | res: int 232 | kernel: int=3 233 | resample_kernel: Tuple=(1, 3, 3, 1) 234 | activation: str='leaky_relu' 235 | param_dict: Any=None 236 | lr_multiplier: float=1 237 | architecture: str='resnet' 238 | nf: Callable=None 239 | clip_conv: float=None 240 | dtype: str='float32' 241 | rng: Any=random.PRNGKey(0) 242 | 243 | @nn.compact 244 | def __call__(self, x): 245 | """ 246 | Run Discriminator Block. 247 | Args: 248 | x (tensor): Input tensor of shape [N, H, W, C]. 249 | Returns: 250 | (tensor): Output tensor of shape [N, H, W, fmaps]. 251 | """ 252 | init_rng = self.rng 253 | x = x.astype(self.dtype) 254 | residual = x 255 | for i in range(2): 256 | init_rng, init_key = random.split(init_rng) 257 | x = DiscriminatorLayer(fmaps=self.nf(self.res - (i + 1)), 258 | kernel=self.kernel, 259 | down=i == 1, 260 | resample_kernel=self.resample_kernel if i == 1 else None, 261 | activation=self.activation, 262 | layer_name=f'conv{i}', 263 | param_dict=self.param_dict, 264 | lr_multiplier=self.lr_multiplier, 265 | clip_conv=self.clip_conv, 266 | dtype=self.dtype, 267 | rng=init_key)(x) 268 | 269 | 270 | if self.architecture == 'resnet': 271 | init_rng, init_key = random.split(init_rng) 272 | residual = DiscriminatorLayer(fmaps=self.nf(self.res - 2), 273 | kernel=1, 274 | use_bias=False, 275 | down=True, 276 | resample_kernel=self.resample_kernel, 277 | activation='linear', 278 | layer_name='skip', 279 | param_dict=self.param_dict, 280 | lr_multiplier=self.lr_multiplier, 281 | dtype=self.dtype, 282 | rng=init_key)(residual) 283 | 284 | x = (x + residual) * np.sqrt(0.5, dtype=x.dtype) 285 | return x 286 | 287 | 288 | class stylegan_discriminator(nn.Module): 289 | """ 290 | Discriminator. 
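Flax port of the StyleGAN2 discriminator. When `pretrained` names one of the checkpoints in URLS, the resolution, architecture, label dimensionality and minibatch-std group size declared below are overridden by the values registered for that checkpoint (RESOLUTION, ARCHITECTURE, C_DIM, MBSTD_GROUP_SIZE).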
291 | Attributes: 292 | resolution (int): Input resolution. Overridden based on dataset. 293 | num_channels (int): Number of input color channels. Overridden based on dataset. 294 | c_dim (int): Dimensionality of the labels (c), 0 if no labels. Overridden based on dataset. 295 | fmap_base (int): Overall multiplier for the number of feature maps. 296 | fmap_decay (int): Log2 feature map reduction when doubling the resolution. 297 | fmap_min (int): Minimum number of feature maps in any layer. 298 | fmap_max (int): Maximum number of feature maps in any layer. 299 | mapping_layers (int): Number of additional mapping layers for the conditioning labels. 300 | mapping_fmaps (int): Number of activations in the mapping layers, None = default. 301 | mapping_lr_multiplier (float): Learning rate multiplier for the mapping layers. 302 | architecture (str): Architecture: 'orig', 'resnet'. 303 | activation (str): Activation function: 'relu', 'leaky_relu', etc. 304 | mbstd_group_size (int): Group size for the minibatch standard deviation layer, None = entire minibatch. 305 | mbstd_num_features (int): Number of features for the minibatch standard deviation layer, 0 = disable. 306 | resample_kernel (Tuple): Low-pass filter to apply when resampling activations, None = box filter. 307 | num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions. 308 | clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping. 309 | pretrained (str): Use pretrained model, None for random initialization. 310 | ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used. 311 | dtype (str): Data type. 312 | rng (jax.random.PRNGKey): PRNG for initialization. 313 | """ 314 | # Input dimensions. 315 | resolution: int=256 316 | num_channels: int=3 317 | c_dim: int=0 318 | 319 | # Capacity. 320 | fmap_base: int=16384 321 | fmap_decay: int=1 322 | fmap_min: int=1 323 | fmap_max: int=512 324 | 325 | # Internal details. 326 | mapping_layers: int=0 327 | mapping_fmaps: int=None 328 | mapping_lr_multiplier: float=0.1 329 | architecture: str='resnet' 330 | activation: str='leaky_relu' 331 | mbstd_group_size: int=None 332 | mbstd_num_features: int=1 333 | resample_kernel: Tuple=(1, 3, 3, 1) 334 | num_fp16_res: int=0 335 | clip_conv: float=None 336 | 337 | # Pretraining 338 | pretrained: str=None 339 | ckpt_dir: str=None 340 | 341 | dtype: str='float32' 342 | rng: Any=random.PRNGKey(0) 343 | 344 | def setup(self): 345 | self.resolution_ = self.resolution 346 | self.c_dim_ = self.c_dim 347 | self.architecture_ = self.architecture 348 | self.mbstd_group_size_ = self.mbstd_group_size 349 | self.param_dict = None 350 | if self.pretrained is not None: 351 | assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}' 352 | ckpt_file = download(self.ckpt_dir, URLS[self.pretrained]) 353 | self.param_dict = h5py.File(ckpt_file, 'r')['discriminator'] 354 | self.resolution_ = RESOLUTION[self.pretrained] 355 | self.architecture_ = ARCHITECTURE[self.pretrained] 356 | self.mbstd_group_size_ = MBSTD_GROUP_SIZE[self.pretrained] 357 | self.c_dim_ = C_DIM[self.pretrained] 358 | 359 | assert self.architecture in ['orig', 'resnet'] 360 | 361 | @nn.compact 362 | def __call__(self, x, c=None, train=True): 363 | """ 364 | Run Discriminator. 365 | Args: 366 | x (tensor): Input image of shape [N, H, W, num_channels]. 367 | c (tensor): Input labels, shape [N, c_dim].
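train (bool): Training-mode flag, accepted for interface compatibility; it is not referenced inside this method.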
368 | 369 | Returns: 370 | (tensor): Output tensor of shape [N, 1]. 371 | """ 372 | resolution_log2 = int(np.log2(self.resolution_)) 373 | assert self.resolution_ == 2**resolution_log2 and self.resolution_ >= 4 374 | def nf(stage): return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max) 375 | if self.mapping_fmaps is None: 376 | mapping_fmaps = nf(0) 377 | else: 378 | mapping_fmaps = self.mapping_fmaps 379 | 380 | init_rng = self.rng 381 | # Label embedding and mapping. 382 | if self.c_dim_ > 0: 383 | c = ops.LinearLayer(in_features=self.c_dim_, 384 | out_features=mapping_fmaps, 385 | lr_multiplier=self.mapping_lr_multiplier, 386 | param_dict=self.param_dict, 387 | layer_name='label_embedding', 388 | dtype=self.dtype, 389 | rng=init_rng)(c) 390 | 391 | c = ops.normalize_2nd_moment(c) 392 | for i in range(self.mapping_layers): 393 | init_rng, init_key = random.split(init_rng) 394 | c = ops.LinearLayer(in_features=self.c_dim_, 395 | out_features=mapping_fmaps, 396 | lr_multiplier=self.mapping_lr_multiplier, 397 | param_dict=self.param_dict, 398 | layer_name=f'fc{i}', 399 | dtype=self.dtype, 400 | rng=init_key)(c) 401 | 402 | # Layers for >=8x8 resolutions. 403 | y = None 404 | for res in range(resolution_log2, 2, -1): 405 | res_str = f'block_{2**res}x{2**res}' 406 | if res == resolution_log2: 407 | init_rng, init_key = random.split(init_rng) 408 | x = FromRGBLayer(fmaps=nf(res - 1), 409 | kernel=1, 410 | activation=self.activation, 411 | param_dict=self.param_dict[res_str] if self.param_dict is not None else None, 412 | clip_conv=self.clip_conv, 413 | dtype=self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32', 414 | rng=init_key)(x, y) 415 | 416 | init_rng, init_key = random.split(init_rng) 417 | x = DiscriminatorBlock(res=res, 418 | kernel=3, 419 | resample_kernel=self.resample_kernel, 420 | activation=self.activation, 421 | param_dict=self.param_dict[res_str] if self.param_dict is not None else None, 422 | architecture=self.architecture_, 423 | nf=nf, 424 | clip_conv=self.clip_conv, 425 | dtype=self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32', 426 | rng=init_key)(x) 427 | 428 | # Layers for 4x4 resolution. 429 | dtype = jnp.float32 430 | x = x.astype(dtype) 431 | if self.mbstd_num_features > 0: 432 | x = ops.minibatch_stddev_layer(x, self.mbstd_group_size_, self.mbstd_num_features) 433 | init_rng, init_key = random.split(init_rng) 434 | x = DiscriminatorLayer(fmaps=nf(1), 435 | kernel=3, 436 | use_bias=True, 437 | activation=self.activation, 438 | layer_name='conv0', 439 | param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None, 440 | clip_conv=self.clip_conv, 441 | dtype=dtype, 442 | rng=init_rng)(x) 443 | 444 | # Switch to NCHW so that the pretrained weights still work after reshaping 445 | x = jnp.transpose(x, axes=(0, 3, 1, 2)) 446 | x = jnp.reshape(x, newshape=(-1, x.shape[1] * x.shape[2] * x.shape[3])) 447 | 448 | init_rng, init_key = random.split(init_rng) 449 | x = ops.LinearLayer(in_features=x.shape[1], 450 | out_features=nf(0), 451 | activation=self.activation, 452 | param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None, 453 | layer_name='fc0', 454 | dtype=dtype, 455 | rng=init_key)(x) 456 | 457 | # Output layer. 
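# The final linear layer emits a single logit in the unconditional case; when c_dim_ > 0 it
# instead emits mapping_fmaps features that are combined with the label embedding `c` via an
# inner product scaled by 1/sqrt(mapping_fmaps) (projection-discriminator-style conditioning).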
458 | init_rng, init_key = random.split(init_rng) 459 | x = ops.LinearLayer(in_features=x.shape[1], 460 | out_features=1 if self.c_dim_ == 0 else mapping_fmaps, 461 | param_dict=self.param_dict, 462 | layer_name='output', 463 | dtype=dtype, 464 | rng=init_key)(x) 465 | 466 | if self.c_dim_ > 0: 467 | x = jnp.sum(x * c, axis=1, keepdims=True) / jnp.sqrt(mapping_fmaps) 468 | 469 | return jnp.reshape(x, (-1)) 470 | -------------------------------------------------------------------------------- /models/models_vqgan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union, Dict 16 | 17 | from flax import struct 18 | import flax.linen as nn 19 | from flax.linen.module import merge_param 20 | import jax 21 | import jax.numpy as jnp 22 | from jax import lax, random 23 | from jax._src import dtypes 24 | 25 | import numpy as np 26 | import einops 27 | from models.models_vit import Encoder, space_to_depth, reverse_space_to_depth, DropPath 28 | from jax.nn import initializers 29 | from models.vgg import VGG 30 | from models.discriminator import Discriminator 31 | import collections 32 | from flax.linen.linear import DenseGeneral 33 | from flax.linen.linear import PrecisionLike 34 | from flax.linen.linear import default_kernel_init 35 | from flax.linen.initializers import zeros 36 | from flax.linen.attention import dot_product_attention 37 | import functools 38 | 39 | Array = jnp.ndarray 40 | DType = jnp.dtype 41 | PRNGKey = jnp.ndarray 42 | Shape = Iterable[int] 43 | 44 | Initializer = Callable[[PRNGKey, Shape, DType], Array] 45 | Dtype = Any 46 | 47 | DTypeLikeFloat = Any 48 | DTypeLikeComplex = Any 49 | DTypeLikeInexact = Any # DTypeLikeFloat | DTypeLikeComplex 50 | RealNumeric = Any # Scalar jnp array or float 51 | 52 | KeyArray = random.KeyArray 53 | Array = Any 54 | 55 | ACT2FN = { 56 | "tanh": nn.tanh, 57 | "relu": nn.relu, 58 | "swish": nn.swish, 59 | } 60 | 61 | def recover_tree(keys, values): 62 | """Recovers a tree as a nested dict from flat names and values. 63 | This function is useful to analyze checkpoints that are without need to access 64 | the exact source code of the experiment. In particular, it can be used to 65 | extract an reuse various subtrees of the scheckpoint, e.g. subtree of 66 | parameters. 67 | Args: 68 | keys: a list of keys, where '/' is used as separator between nodes. 69 | values: a list of leaf values. 70 | Returns: 71 | A nested tree-like dict. 72 | """ 73 | tree = {} 74 | sub_trees = collections.defaultdict(list) 75 | for k, v in zip(keys, values): 76 | if '.' 
not in k: 77 | tree[k] = v 78 | else: 79 | k_left, k_right = k.split('.', 1) 80 | sub_trees[k_left].append((k_right, v)) 81 | for k, kv_pairs in sub_trees.items(): 82 | k_subtree, v_subtree = zip(*kv_pairs) 83 | tree[k] = recover_tree(k_subtree, v_subtree) 84 | return tree 85 | 86 | 87 | def uniform(scale: RealNumeric = 1e-2, offset: RealNumeric = 0, 88 | dtype: DTypeLikeInexact = jnp.float_): 89 | """Builds an initializer that returns real uniformly-distributed random arrays. 90 | """ 91 | def init(key: KeyArray, 92 | shape: Shape, 93 | dtype: DTypeLikeInexact = dtype) -> Array: 94 | dtype = dtypes.canonicalize_dtype(dtype) 95 | return random.uniform(key, shape, dtype) * scale + offset 96 | return init 97 | 98 | def l2_normalize(x, axis=None, eps=1e-12): 99 | """Normalizes along dimension `axis` using an L2 norm. 100 | This specialized function exists for numerical stability reasons. 101 | Args: 102 | x: An input ndarray. 103 | axis: Dimension along which to normalize, e.g. `1` to separately normalize 104 | vectors in a batch. Passing `None` views `t` as a flattened vector when 105 | calculating the norm (equivalent to Frobenius norm). 106 | eps: Epsilon to avoid dividing by zero. 107 | Returns: 108 | An array of the same shape as 'x' L2-normalized along 'axis'. 109 | """ 110 | return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) 111 | 112 | class MlpBlock(nn.Module): 113 | """Transformer MLP / feed-forward block.""" 114 | 115 | mlp_dim: int 116 | dtype: Dtype = jnp.float32 117 | out_dim: Optional[int] = None 118 | dropout_rate: float = 0.0 119 | use_bias: bool = True 120 | act_fn: str = 'relu' 121 | kernel_init: Callable[[PRNGKey, Shape, Dtype], 122 | Array] = nn.initializers.xavier_uniform() 123 | bias_init: Callable[[PRNGKey, Shape, Dtype], 124 | Array] = nn.initializers.normal(stddev=1e-6) 125 | 126 | @nn.compact 127 | def __call__(self, inputs, *, deterministic): 128 | """Applies Transformer MlpBlock module.""" 129 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim 130 | x = nn.Dense( 131 | features=self.mlp_dim, 132 | dtype=self.dtype, 133 | use_bias=self.use_bias, 134 | kernel_init=self.kernel_init, 135 | bias_init=self.bias_init, 136 | )( # pytype: disable=wrong-arg-types 137 | inputs) 138 | 139 | x = ACT2FN[self.act_fn](x) 140 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 141 | 142 | output = nn.Dense( 143 | features=actual_out_dim, 144 | dtype=self.dtype, 145 | use_bias=self.use_bias, 146 | kernel_init=self.kernel_init, 147 | bias_init=self.bias_init, 148 | )( # pytype: disable=wrong-arg-types 149 | x) 150 | output = nn.Dropout( 151 | rate=self.dropout_rate)( 152 | output, deterministic=deterministic) 153 | return output 154 | 155 | 156 | class MultiHeadDotProductAttention(nn.Module): 157 | """Multi-head dot-product attention. 158 | 159 | Attributes: 160 | num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) 161 | should be divisible by the number of heads. 162 | dtype: the dtype of the computation 163 | (default: infer from inputs and params) 164 | param_dtype: the dtype passed to parameter initializers (default: float32) 165 | qkv_features: dimension of the key, query, and value. 166 | out_features: dimension of the last projection 167 | broadcast_dropout: bool: use a broadcasted dropout along batch dims. 
168 | dropout_rate: dropout rate 169 | deterministic: if false, the attention weight is masked randomly 170 | using dropout, whereas if true, the attention weights 171 | are deterministic. 172 | precision: numerical precision of the computation see `jax.lax.Precision` 173 | for details. 174 | kernel_init: initializer for the kernel of the Dense layers. 175 | bias_init: initializer for the bias of the Dense layers. 176 | use_bias: bool: whether pointwise QKVO dense transforms use bias. 177 | attention_fn: dot_product_attention or compatible function. Accepts 178 | query, key, value, and returns output of shape 179 | `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` 180 | decode: whether to prepare and use an autoregressive cache. 181 | """ 182 | num_heads: int 183 | dtype: Optional[Dtype] = None 184 | param_dtype: Dtype = jnp.float32 185 | qkv_features: Optional[int] = None 186 | out_features: Optional[int] = None 187 | broadcast_dropout: bool = True 188 | dropout_rate: float = 0. 189 | deterministic: Optional[bool] = None 190 | precision: PrecisionLike = None 191 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init 192 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = zeros 193 | use_bias: bool = True 194 | attention_fn: Callable[[Array, Array, Array], Array] = dot_product_attention 195 | decode: bool = False 196 | params_init: Any = None 197 | 198 | @nn.compact 199 | def __call__(self, 200 | inputs_q: Array, 201 | inputs_kv: Array, 202 | mask: Optional[Array] = None, 203 | deterministic: Optional[bool] = None): 204 | """Applies multi-head dot product attention on the input data. 205 | 206 | Projects the inputs into multi-headed query, key, and value vectors, 207 | applies dot-product attention and project the results to an output vector. 208 | 209 | Args: 210 | inputs_q: input queries of shape 211 | `[batch_sizes..., length, features]`. 212 | inputs_kv: key/values of shape 213 | `[batch_sizes..., length, features]`. 214 | mask: attention mask of shape 215 | `[batch_sizes..., num_heads, query_length, key/value_length]`. 216 | Attention weights are masked out if their corresponding mask value 217 | is `False`. 218 | deterministic: if false, the attention weight is masked randomly 219 | using dropout, whereas if true, the attention weights 220 | are deterministic. 221 | 222 | Returns: 223 | output of shape `[batch_sizes..., length, features]`. 
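Note: when `params_init` is provided, the fused `to_qkv` kernel from the pretrained checkpoint is transposed and split into separate query/key/value kernels and `to_out` initializes the output projection; otherwise `kernel_init` is used for all projections.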
224 | """ 225 | features = self.out_features or inputs_q.shape[-1] 226 | qkv_features = self.qkv_features or inputs_q.shape[-1] 227 | assert qkv_features % self.num_heads == 0, ( 228 | 'Memory dimension must be divisible by number of heads.') 229 | head_dim = qkv_features // self.num_heads 230 | 231 | dense = functools.partial(DenseGeneral, 232 | axis=-1, 233 | dtype=self.dtype, 234 | param_dtype=self.param_dtype, 235 | features=(self.num_heads, head_dim), 236 | use_bias=False, 237 | precision=self.precision) 238 | 239 | # project inputs_q to multi-headed q/k/v 240 | # dimensions are then [batch..., length, n_heads, n_features_per_head] 241 | 242 | if self.params_init is not None: 243 | qkv_kernel = jnp.split(jnp.transpose(np.array(self.params_init['to_qkv']['weight']), (1,0)), 3, axis=1) 244 | 245 | query_kernel_init = self.kernel_init if self.params_init is None else lambda *_ : jnp.array(qkv_kernel[0]) 246 | key_kernel_init = self.kernel_init if self.params_init is None else lambda *_ : jnp.array(qkv_kernel[1]) 247 | value_kernel_init = self.kernel_init if self.params_init is None else lambda *_ : jnp.array(qkv_kernel[2]) 248 | 249 | query, key, value = (dense(kernel_init=query_kernel_init, name='query')(inputs_q), 250 | dense(kernel_init=key_kernel_init, name='key')(inputs_kv), 251 | dense(kernel_init=value_kernel_init, name='value')(inputs_kv)) 252 | 253 | dropout_rng = None 254 | if self.dropout_rate > 0.: # Require `deterministic` only if using dropout. 255 | m_deterministic = merge_param('deterministic', self.deterministic, 256 | deterministic) 257 | if not m_deterministic: 258 | dropout_rng = self.make_rng('dropout') 259 | else: 260 | m_deterministic = True 261 | 262 | # apply attention 263 | x = self.attention_fn( 264 | query, 265 | key, 266 | value, 267 | mask=mask, 268 | dropout_rng=dropout_rng, 269 | dropout_rate=self.dropout_rate, 270 | broadcast_dropout=self.broadcast_dropout, 271 | deterministic=m_deterministic, 272 | dtype=self.dtype, 273 | precision=self.precision) # pytype: disable=wrong-keyword-args 274 | # back to the original inputs dimensions 275 | 276 | out_kernel_init = self.kernel_init if self.params_init is None else lambda *_ : jnp.transpose(jnp.array(self.params_init['to_out']['weight']), (1,0)) 277 | out_bias_init = nn.initializers.zeros if self.params_init is None else lambda *_ : jnp.array(self.params_init['to_out']['bias']) 278 | 279 | out = DenseGeneral(features=features, 280 | axis=(-2, -1), 281 | kernel_init=out_kernel_init, 282 | bias_init=out_bias_init, 283 | use_bias=self.use_bias, 284 | dtype=self.dtype, 285 | param_dtype=self.param_dtype, 286 | precision=self.precision, 287 | name='out')(x) 288 | return out 289 | 290 | 291 | class TransformerLayer(nn.Module): 292 | mlp_dim: int 293 | num_heads: int 294 | dtype: Dtype = jnp.float32 295 | dropout_rate: float = 0.0 296 | droppath_rate: float = 0.0 297 | attention_dropout_rate: float = 0.0 298 | use_bias: bool = False 299 | act_fn: str = 'relu' 300 | 301 | @nn.compact 302 | def __call__(self, inputs, *, deterministic): 303 | assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}' 304 | 305 | x = nn.LayerNorm(dtype=self.dtype)(inputs) 306 | 307 | # x = MultiHeadDotProductAttention( 308 | # dtype=self.dtype, 309 | # kernel_init=nn.initializers.xavier_uniform(), 310 | # broadcast_dropout=False, 311 | # deterministic=deterministic, 312 | # dropout_rate=self.attention_dropout_rate, 313 | # num_heads=self.num_heads, 314 | # use_bias=self.use_bias)( 315 | # x, x) 316 | 317 | x =
nn.MultiHeadDotProductAttention( 318 | num_heads=self.num_heads, 319 | dtype=self.dtype, 320 | broadcast_dropout=False, 321 | deterministic=deterministic, 322 | dropout_rate=self.attention_dropout_rate, 323 | use_bias=self.use_bias, 324 | )(x, x) 325 | 326 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 327 | x = DropPath(rate=self.droppath_rate)(x, deterministic=deterministic) + inputs 328 | 329 | # MLP block. 330 | y = nn.LayerNorm(dtype=self.dtype)(x) 331 | 332 | y = MlpBlock( 333 | mlp_dim=self.mlp_dim, 334 | dtype=self.dtype, 335 | act_fn=self.act_fn, 336 | dropout_rate=self.dropout_rate)( 337 | y, deterministic=deterministic) 338 | 339 | return x + DropPath(rate=self.droppath_rate)(y, deterministic=deterministic) 340 | 341 | 342 | class Transformer(nn.Module): 343 | """Transformer Model for sequence to sequence translation. 344 | 345 | Attributes: 346 | num_layers: number of layers 347 | mlp_dim: dimension of the mlp on top of attention block 348 | num_heads: Number of heads in nn.MultiHeadDotProductAttention 349 | dropout_rate: dropout rate. 350 | attention_dropout_rate: dropout rate in self attention. 351 | """ 352 | 353 | num_layers: int 354 | mlp_dim: int 355 | num_heads: int 356 | dtype: Dtype = jnp.float32 357 | dropout_rate: float = 0.0 358 | droppath_rate: float = 0.0 359 | attention_dropout_rate: float = 0.0 360 | add_position_embedding: bool = True 361 | use_bias: bool = False 362 | act_fn: str = 'relu' 363 | 364 | @nn.compact 365 | def __call__(self, x, *, train): 366 | 367 | assert x.ndim == 3 # (batch, len, emb) 368 | 369 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 370 | dpr = [x for x in np.linspace(0, self.droppath_rate, self.num_layers)] 371 | for lyr in range(self.num_layers): 372 | x = TransformerLayer( 373 | mlp_dim=self.mlp_dim, 374 | dropout_rate=self.dropout_rate, 375 | dtype=self.dtype, 376 | droppath_rate=dpr[lyr], 377 | attention_dropout_rate=self.attention_dropout_rate, 378 | name=f'encoderblock_{lyr}', 379 | num_heads=self.num_heads, 380 | use_bias=self.use_bias, 381 | act_fn=self.act_fn, 382 | )(x, deterministic=not train) 383 | 384 | x = nn.LayerNorm(name='encoder_norm')(x) 385 | return x 386 | 387 | 388 | class VectorQuantizer(nn.Module): 389 | n_e: int 390 | e_dim: int 391 | beta: float = 0.25 392 | embedding_init: Callable[[PRNGKey, Shape, DType], Array] = uniform(2.0, -1.0) 393 | dtype: Any = jnp.float32 394 | param_dict: Any = None 395 | 396 | def setup(self): 397 | 398 | kernel_init = self.embedding_init if self.param_dict is None \ 399 | else lambda *_ : jnp.array(self.param_dict['embedding']['weight']) 400 | 401 | self.embedding = self.param( 402 | 'embedding', 403 | kernel_init, (self.n_e, self.e_dim), 404 | jnp.float32) 405 | 406 | def get_codebook_entry(self, indices): 407 | # indices are expected to be of shape (batch, num_tokens) 408 | # get quantized latent vectors 409 | z_q = jnp.take(self.embedding, indices, axis=0) 410 | # normalize latent variable (Ze(x) in the paper) 411 | z_q = l2_normalize(z_q, axis=-1) 412 | return z_q 413 | 414 | @nn.compact 415 | def __call__(self, z: Array) -> Array: 416 | 417 | z_reshaped = jnp.reshape(z, (-1, self.e_dim)) 418 | # first normalize the input. 
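# Both the flattened latents and the codebook are L2-normalized, so the squared-distance argmin
# below is equivalent to a cosine-similarity lookup (the factorized, unit-normalized codebook of
# ViT-VQGAN). The beta-weighted commitment term and the straight-through estimator that copies
# decoder gradients onto the encoder output follow the standard VQ formulation.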
419 | z_reshaped_norm = l2_normalize(z_reshaped, axis=-1) #/ jnp.linalg.norm(z_reshaped, axis=-1, keepdims=True) 420 | embedding_norm = l2_normalize(self.embedding, axis=-1) #/ jnp.linalg.norm(self.embedding, axis=-1, keepdims=True) 421 | 422 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 423 | d = jnp.sum(z_reshaped_norm ** 2, axis=1, keepdims=True) + \ 424 | jnp.sum(embedding_norm ** 2, axis=1) - 2 * \ 425 | jnp.einsum('ij,kj->ik', z_reshaped_norm, embedding_norm) 426 | 427 | min_encoding_indices = jnp.reshape(jnp.argmin(d, axis=1), z.shape[:-1]) 428 | 429 | # z_q = jnp.take(self.embedding, min_encoding_indices, axis=0) 430 | z_q = self.get_codebook_entry(min_encoding_indices) 431 | z_norm = l2_normalize(z, axis=-1) 432 | 433 | # e_mean = jnp.mean(min_encoding_indices, axis=0) 434 | # perplexity = jnp.exp(-jnp.sum(e_mean * jnp.log(e_mean + 1e-10))) 435 | perplexity = None 436 | min_encodings = None 437 | 438 | loss = self.beta * jnp.mean(jnp.square((jax.lax.stop_gradient(z_q)-z_norm))) + \ 439 | jnp.mean(jnp.square((z_q - jax.lax.stop_gradient(z_norm)))) 440 | 441 | z_q = z + jax.lax.stop_gradient(z_q - z) 442 | 443 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 444 | 445 | def get_2d_sincos_pos_embed(emb_dim, image_size, image_patch_size, dtype, class_token=False, temperature=10000.): 446 | """ 447 | (Absolute, additive) 2D sinusoidal positional embeddings used in MoCo v3, MAE 448 | Args: 449 | emb_dim (int): embedding dimension 450 | image_size (tuple): image size 451 | image_patch_size (int): image patch size 452 | class_token (bool): whether to use class token 453 | """ 454 | h, w = image_size[0] // image_patch_size[0], image_size[1] // image_patch_size[1] 455 | grid_h = jnp.arange(h, dtype=jnp.float32) 456 | grid_w = jnp.arange(w, dtype=jnp.float32) 457 | grid_w, grid_h = jnp.meshgrid(grid_w, grid_h, indexing='xy') 458 | 459 | assert emb_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' 460 | emb_w = get_1d_sincos_pos_embed_from_grid(emb_dim // 2, grid_w, jnp.float32, temperature) # (H*W, D/2) 461 | emb_h = get_1d_sincos_pos_embed_from_grid(emb_dim // 2, grid_h, jnp.float32, temperature) # (H*W, D/2) 462 | pos_emb = jnp.concatenate([emb_w, emb_h], axis=1) # (H*W, D) 463 | if class_token: 464 | pos_emb = jnp.concatenate([jnp.zeros([1, emb_dim], dtype=pos_emb.dtype), pos_emb], axis=0) 465 | pos_emb = pos_emb.astype(dtype) 466 | return pos_emb 467 | 468 | def get_1d_sincos_pos_embed_from_grid(emb_dim, pos, dtype, temperature=10000.): 469 | """ 470 | (Absolute, additive) 1D sinusoidal positional embeddings used in MoCo v3, MAE 471 | Args: 472 | emb_dim (int):output dimension for each position 473 | pos: a list of positions to be encoded: size (M, ) 474 | out: (M, D) 475 | """ 476 | assert emb_dim % 2 == 0 477 | omega = jnp.arange(emb_dim // 2, dtype=jnp.float32) 478 | omega /= emb_dim / 2. 479 | omega = 1. 
/ temperature**omega # (D/2,) 480 | 481 | pos = pos.reshape(-1).astype(jnp.float32) # (M,) 482 | out = jnp.einsum('m,d->md', pos, omega) # (M, D/2), outer product 483 | 484 | emb_sin = jnp.sin(out) # (M, D/2) 485 | emb_cos = jnp.cos(out) # (M, D/2) 486 | 487 | emb = jnp.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 488 | return emb.astype(dtype) 489 | 490 | class Generator(nn.Module): 491 | """An encoder-decoder Transformer model.""" 492 | vocab_size: int 493 | proj_dim: int 494 | patch_size: Any 495 | encoder_hidden_size: int 496 | encoder_num_layers: int 497 | encoder_mlp_dim: int 498 | encoder_num_heads: int 499 | decoder_hidden_size: int 500 | decoder_num_layers: int 501 | decoder_mlp_dim: int 502 | decoder_num_heads: int 503 | dtype: Any = jnp.float32 504 | dropout_rate: float = 0.0 505 | droppath_rate: float = 0.0 506 | attention_dropout_rate: float = 0.0 507 | add_position_embedding: bool = False 508 | head_bias_init: float = 0. 509 | default_input_size: Any = (256, 256) 510 | output_channel: int = 3 511 | use_bias: bool = False 512 | act_fn: str = 'relu' 513 | 514 | def setup(self): 515 | self.encoder_position_embedding = get_2d_sincos_pos_embed( 516 | emb_dim=self.encoder_hidden_size, 517 | image_size=self.default_input_size, 518 | image_patch_size=self.patch_size, 519 | dtype=self.dtype, 520 | class_token=False, 521 | ) 522 | 523 | self.decoder_position_embedding = get_2d_sincos_pos_embed( 524 | emb_dim=self.decoder_hidden_size, 525 | image_size=self.default_input_size, 526 | image_patch_size=self.patch_size, 527 | dtype=self.dtype, 528 | class_token=False, 529 | ) 530 | 531 | def encode(self, x, train=True): 532 | x = space_to_depth(x, spatial_block_size=self.patch_size) 533 | x = nn.Dense( 534 | features=self.encoder_hidden_size, 535 | dtype=self.dtype, 536 | name='embedding', 537 | )(x) 538 | 539 | x += jnp.expand_dims(self.encoder_position_embedding, 0) 540 | 541 | x = Transformer( 542 | num_layers=self.encoder_num_layers, 543 | mlp_dim=self.encoder_mlp_dim, 544 | num_heads=self.encoder_num_heads, 545 | dtype=self.dtype, 546 | dropout_rate=self.dropout_rate, 547 | droppath_rate=self.droppath_rate, 548 | attention_dropout_rate=self.attention_dropout_rate, 549 | add_position_embedding=self.add_position_embedding, 550 | use_bias=self.use_bias, 551 | act_fn=self.act_fn 552 | )(x, train=train) 553 | 554 | x = ACT2FN[self.act_fn](x) 555 | 556 | x = nn.Dense( 557 | features=self.proj_dim, 558 | dtype=self.dtype, 559 | use_bias=self.use_bias, 560 | name='encoder_proj' 561 | )(x) 562 | 563 | x = nn.LayerNorm(use_scale=False, name='encoder_norm')(x) 564 | return x 565 | 566 | def decode(self, x, image_shape, train=True): 567 | x = nn.Dense( 568 | features=self.decoder_hidden_size, 569 | dtype=self.dtype, 570 | use_bias=self.use_bias, 571 | name='decoder_proj' 572 | )(x) 573 | 574 | x += jnp.expand_dims(self.decoder_position_embedding, 0) 575 | 576 | x = Transformer( 577 | num_layers=self.decoder_num_layers, 578 | mlp_dim=self.decoder_mlp_dim, 579 | num_heads=self.decoder_num_heads, 580 | dtype=self.dtype, 581 | dropout_rate=self.dropout_rate, 582 | droppath_rate=self.droppath_rate, 583 | attention_dropout_rate=self.attention_dropout_rate, 584 | add_position_embedding=self.add_position_embedding, 585 | use_bias=self.use_bias, 586 | act_fn=self.act_fn 587 | )(x, train=train) 588 | 589 | img_size = self.default_input_size 590 | x = jnp.reshape(x, (-1, img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1], self.decoder_hidden_size)) 591 | 592 | x = nn.ConvTranspose( 593 | 
features = self.output_channel, 594 | kernel_size = self.patch_size, 595 | strides = self.patch_size, 596 | use_bias=self.use_bias, 597 | )(x) 598 | return x 599 | 600 | @nn.compact 601 | def __call__(self, x, *, train): 602 | 603 | h = self.encode(x, train=train) 604 | quant, emb_loss, _ = VectorQuantizer( 605 | n_e=self.vocab_size, 606 | e_dim=self.proj_dim, 607 | beta=0.25, 608 | )(h) 609 | 610 | rec = self.decode(quant, x.shape[1:], train=train) 611 | return rec, emb_loss --------------------------------------------------------------------------------
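Usage note (illustrative sketch, not part of the repository): the `Generator` above follows the standard Flax `init`/`apply` pattern: images are encoded into patch tokens, quantized against the L2-normalized codebook, and decoded back to pixels, returning the reconstruction and the codebook loss. The constructor arguments below are placeholder values; the real settings live in `configs/*.yaml`.

```python
import jax
import jax.numpy as jnp

from models.models_vqgan import Generator

# Placeholder hyper-parameters for a 256x256, patch-8 model; see configs/*.yaml for the real values.
model = Generator(
    vocab_size=8192, proj_dim=32, patch_size=(8, 8),
    encoder_hidden_size=512, encoder_num_layers=8, encoder_mlp_dim=2048, encoder_num_heads=8,
    decoder_hidden_size=512, decoder_num_layers=8, decoder_mlp_dim=2048, decoder_num_heads=8,
    default_input_size=(256, 256), output_channel=3)

images = jnp.zeros((2, 256, 256, 3), dtype=jnp.float32)        # dummy NHWC batch
variables = model.init(jax.random.PRNGKey(0), images, train=False)
rec, emb_loss = model.apply(variables, images, train=False)     # reconstruction, codebook loss
```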