├── lqae
│   ├── __init__.py
│   ├── main
│   │   ├── __init__.py
│   │   ├── eval_loss.py
│   │   ├── linear_main.py
│   │   ├── finetune_main.py
│   │   └── lqae_main.py
│   ├── data
│   │   ├── __init__.py
│   │   └── data.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── model_utils.py
│   │   ├── base_resnet.py
│   │   ├── base_vit.py
│   │   ├── vqae.py
│   │   └── lqae.py
│   ├── utils.py
│   └── jax_utils.py
├── scripts
│   ├── gpu_requirement.yml
│   └── tpu_vm_setup.sh
├── README.md
├── .gitignore
└── jobs
    └── tpu_control.sh

/lqae/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/lqae/main/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/lqae/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import ImageNetDataset, ImageTextDataset
2 | 
--------------------------------------------------------------------------------
/lqae/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .lqae import LQAE
2 | from .vqae import VQAE
3 | 
--------------------------------------------------------------------------------
/scripts/gpu_requirement.yml:
--------------------------------------------------------------------------------
1 | name: lqae
2 | channels:
3 |   - defaults
4 |   - nvidia
5 |   - conda-forge
6 | dependencies:
7 |   - python=3.8
8 |   - pip
9 |   - numpy
10 |   - scipy
11 |   - numba
12 |   - h5py
13 |   - matplotlib
14 |   - scikit-learn
15 |   - jupyter
16 |   - tqdm
17 |   - protobuf
18 |   - pytorch=1.13.0
19 |   - jax=0.3.25
20 |   - jaxlib=*=*cuda*
21 |   - cudatoolkit=11.7
22 |   - cuda-nvcc=11.7
23 |   - cudnn
24 |   - scikit-image
25 |   - pip:
26 |     - flax==0.6.0
27 |     - optax==0.1.3
28 |     - distrax==0.1.2
29 |     - transformers==4.24.0
30 |     - datasets==2.7.0
31 |     - einops
32 |     - tensorflow==2.11.0
33 |     - dill
34 |     - seaborn
35 |     - absl-py
36 |     - opencv-python
37 |     - joblib
38 |     - wandb==0.13.5
39 |     - ml_collections
40 |     - gcsfs==2022.11.0
41 |     - requests
42 |     - jupyter_http_over_ws
43 |     - torchvision==0.14.1
44 |     - timm==0.5.4
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Language Quantized AutoEncoders
2 | 
3 | This is a Jax implementation of our work [Language Quantized AutoEncoders](https://arxiv.org/abs/2302.00902).
4 | 
5 | It contains training and evaluation code.
6 | 
7 | This implementation has been tested on multi-GPU machines and on Google Cloud TPU, and supports both multi-GPU training and multi-host training with TPU Pods.
8 | 
9 | ## Usage
10 | Experiments can be launched via the following commands.
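To set up a GPU machine, you can create the conda environment from the provided spec (a sketch of the usual conda workflow; the environment name `lqae` comes from `scripts/gpu_requirement.yml`):
```
conda env create -f scripts/gpu_requirement.yml
conda activate lqae
```
For TPU VMs, `scripts/tpu_vm_setup.sh` installs the corresponding dependencies.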
11 | 
12 | An example script for launching an LQAE training job:
13 | ```
14 | export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
15 | export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
16 | cd $PROJECT_DIR
17 | export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"
18 | 
19 | echo $PYTHONPATH
20 | export WANDB_API_KEY=''
21 | 
22 | export experiment_name='lqae'
23 | export project_id='lqae'
24 | export wu='5'
25 | export ep='100'
26 | export model='lqae'
27 | export experiment_note=""
28 | export experiment_id="lqae-base"
29 | 
30 | python3 -m lqae.main.lqae_main \
31 |     --model_type="$model" \
32 |     --lqae.bert_min_ratio=0.5 \
33 |     --lqae.bert_max_ratio=0.5 \
34 |     --lqae.quantizer_loss_commitment=0.005 \
35 |     --lqae.quantizer_loss_entropy=0.0 \
36 |     --lqae.quantizer_loss_perplexity=0.0 \
37 |     --lqae.l2_normalize=True \
38 |     --lqae.top_k_value=1 \
39 |     --lqae.top_k_avg=False \
40 |     --lqae.top_k_rnd=False \
41 |     --lqae.vit_encoder_decoder=True \
42 |     --lqae.vit_model_type='base' \
43 |     --lqae.patch_size=16 \
44 |     --lqae.use_bert_codebook=True \
45 |     --lqae.bert_mask_loss_weight=0.0001 \
46 |     --lqae.bert_channel_image_loss_weight=1.0 \
47 |     --lqae.nochannel_image_loss_weight=0.0 \
48 |     --lqae.quantizer_latent_dim=0 \
49 |     --lqae.strawman_codebook=False \
50 |     --lqae.use_bert_ste=False \
51 |     --seed=42 \
52 |     --epochs="$ep" \
53 |     --lr_warmup_epochs="$wu" \
54 |     --batch_size=512 \
55 |     --dataloader_n_workers=16 \
56 |     --log_freq=500 \
57 |     --plot_freq=2000 \
58 |     --save_model_freq=10000 \
59 |     --lr_peak_value=1.5e-4 \
60 |     --weight_decay=0.0005 \
61 |     --load_checkpoint='' \
62 |     --dataset='imagenet' \
63 |     --imagenet_data.path="YOUR IMAGENET FILE in HDF5" \
64 |     --imagenet_data.random_start=True \
65 |     --log_all_worker=False \
66 |     --logging.online=True \
67 |     --logging.project_id="$project_id" \
68 |     --logging.experiment_id="$experiment_id" \
69 |     --logging.experiment_note="$experiment_note" \
70 |     --logging.output_dir="$HOME/experiment_output/$project_id"
71 | 
72 | ```
73 | 
74 | An example of running LLM-based evaluation with an LQAE pretrained model is available in this [colab](https://colab.research.google.com/drive/1_nzC8W6yO9fYP8GLfUmY11hoVQUW9e6Q?usp=sharing).
75 | 
76 | To run experiments more conveniently on TPUs, you may want to use the script in the jobs folder to manage TPU jobs.
77 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | # customized 141 | logs/ 142 | backup/ 143 | temp/ 144 | wandb/ 145 | local/ 146 | 147 | # Job scripts 148 | jobs/*.sh 149 | !jobs/tpu_control.sh 150 | !jobs/pod_test.sh 151 | -------------------------------------------------------------------------------- /scripts/tpu_vm_setup.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | sudo apt-get update && sudo apt-get install -y \ 4 | build-essential \ 5 | python-is-python3 \ 6 | tmux \ 7 | htop \ 8 | git \ 9 | nodejs \ 10 | bmon \ 11 | p7zip-full \ 12 | nfs-common \ 13 | libosmesa6-dev \ 14 | patchelf \ 15 | golang 16 | 17 | 18 | # Python dependencies 19 | cat > $HOME/tpu_requirements.txt <<- EndOfFile 20 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 21 | jax[tpu]==0.3.25 22 | tensorflow==2.11.0 23 | flax==0.6.0 24 | optax==0.1.3 25 | distrax==0.1.2 26 | --extra-index-url https://download.pytorch.org/whl/cpu 27 | torch==1.12.1 28 | transformers==4.24.0 29 | datasets==2.7.0 30 | tqdm 31 | ml_collections 32 | wandb==0.13.5 33 | gcsfs==2022.11.0 34 | requests 35 | typing_extensions==4.2.0 36 | protobuf==3.20.2 37 | ipython 38 | cloudpickle==2.0.0 39 | torchvision==0.14.1 40 | timm==0.5.4 41 | einops 42 | h5py 43 | scikit-image==0.19.2 44 | dill 45 | EndOfFile 46 | 47 | pip install -r $HOME/tpu_requirements.txt 48 | 49 | 50 | # VIM configurations 51 | cat > $HOME/.vimrc <<- EndOfFile 52 | set tabstop=4 53 | set shiftwidth=4 54 | set softtabstop=4 55 | set expandtab 56 | set backspace=indent,eol,start 57 | syntax on 58 | EndOfFile 59 | 60 | # Tmux configurations 61 | cat > $HOME/.tmux.conf <<- EndOfFile 62 | bind r source-file ~/.tmux.conf \; display-message "█▓░ ~/.tmux.conf reloaded." 63 | 64 | # Enable colors, https://github.com/tmux/tmux/wiki/FAQ 65 | set -g default-terminal "tmux-256color" 66 | 67 | # start with window 1 (instead of 0) 68 | set -g base-index 1 69 | setw -g pane-base-index 1 70 | 71 | set -g prefix C-a 72 | 73 | set -g set-titles on 74 | set -g set-titles-string '#(whoami)::#h::#(curl ipecho.net/plain;echo)' 75 | 76 | # Status bar customization 77 | set -g status-interval 5 78 | set -g status-left-length 90 79 | set -g status-right-length 60 80 | set -g status-justify left 81 | 82 | # send the prefix to client inside window (ala nested sessions) 83 | bind-key a send-prefix 84 | 85 | bind-key x kill-pane 86 | 87 | # auto reorder 88 | set-option -g renumber-windows on 89 | 90 | # default window name 91 | set -g status-left "#[fg=green,bg=colour236] #S " 92 | 93 | # default statusbar colors 94 | set-option -g status-style fg=yellow,dim,bg=colour235 95 | 96 | # default window title colors 97 | set-window-option -g window-status-style fg=yellow,bg=colour236,dim 98 | 99 | # active window title colors 100 | set-window-option -g window-status-current-style fg=brightred,bg=colour236 101 | 102 | # basename as window title https://stackoverflow.com/a/37136828 103 | set-window-option -g window-status-current-format '#{window_index} #{pane_current_command} #(echo "#{pane_current_path}" | rev | cut -d'/' -f-3 | rev)' 104 | set-window-option -g window-status-format '#{window_index} #{pane_current_command} #(echo "#{pane_current_path}" | rev | cut -d'/' -f-3 | rev)' 105 | 106 | # pane border 107 | set-option -g pane-border-style fg=white #base2 108 | set-option -g pane-active-border-style fg=brightcyan #base1 109 | 110 | # enable mouse click 111 | set -g mouse on 112 | 113 | # keep window on 114 | set -g remain-on-exit on 115 | 116 | # Longer scrollback history 117 | set -g history-limit 50000 118 | 119 | # Scroll position indicator 120 | set -g mode-style bg=colour235,fg=colour245 121 | 122 | # SSH agent forwarding 123 | # set-environment -g SSH_AUTH_SOCK $SSH_AUTH_SOCK 124 | if-shell '[ -n $SSH_AUTH_SOCK ]' " \ 125 | set-option -sg update-environment \"DISPLAY WINDOWID XAUTHORITY\"; \ 126 | setenv -g SSH_AUTH_SOCK 
/tmp/ssh_auth_sock_tmux; \
127 |     run-shell \"ln -sf $(find /tmp/ssh-* -type s -readable | head -n 1) /tmp/ssh_auth_sock_tmux\" \
128 | "
129 | 
130 | # Drag windows on the status bar
131 | bind-key -n MouseDrag1Status swap-window -t=
132 | EndOfFile
133 | 
134 | 
135 | # HTop Configurations
136 | mkdir -p $HOME/.config/htop
137 | cat > $HOME/.config/htop/htoprc <<- EndOfFile
138 | # Beware! This file is rewritten by htop when settings are changed in the interface.
139 | # The parser is also very primitive, and not human-friendly.
140 | fields=0 48 17 18 38 39 40 2 46 47 49 1
141 | sort_key=46
142 | sort_direction=1
143 | hide_threads=0
144 | hide_kernel_threads=1
145 | hide_userland_threads=1
146 | shadow_other_users=0
147 | show_thread_names=0
148 | show_program_path=1
149 | highlight_base_name=0
150 | highlight_megabytes=1
151 | highlight_threads=1
152 | tree_view=0
153 | header_margin=1
154 | detailed_cpu_time=0
155 | cpu_count_from_zero=0
156 | update_process_names=0
157 | account_guest_in_cpu_meter=0
158 | color_scheme=0
159 | delay=15
160 | left_meters=CPU Memory Swap
161 | left_meter_modes=1 1 1
162 | right_meters=Tasks LoadAverage Uptime
163 | right_meter_modes=2 2 2
164 | EndOfFile
165 | 
--------------------------------------------------------------------------------
/lqae/models/model_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import io
3 | import math
4 | import os
5 | from typing import Any, Callable, Optional
6 | 
7 | import flax
8 | import flax.linen as nn
9 | import jax
10 | import jax.numpy as jnp
11 | import ml_collections
12 | import numpy as np
13 | import requests
14 | from ml_collections import ConfigDict
15 | from functools import partial
16 | 
17 | 
18 | 
19 | 
20 | def normalize_func(x, axis=None, eps=1e-12, use_l2_normalize=True):
21 |     if use_l2_normalize:
22 |         return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)
23 |     else:
24 |         return x
25 | 
26 | 
27 | def squared_euclidean_distance(
28 |     a: jnp.ndarray, b: jnp.ndarray, b2: jnp.ndarray = None, precision: Any = None,
29 |     dot_product: bool = False,
30 | ) -> jnp.ndarray:
31 |     """Computes the pairwise squared Euclidean distance.
32 |     Args:
33 |       a: float32: (n, d): An array of points.
34 |       b: float32: (m, d): An array of points.
35 |       b2: float32: (1, m): optional precomputed squared norms of b (a row vector); computed from b if not given.
36 |       precision: use DEFAULT precision by default
37 |     Returns:
38 |       d: float32: (n, m): Where d[i, j] is the squared Euclidean distance between
39 |       a[i] and b[j].
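      If dot_product is True, the dot-product similarity a @ b.T is returned instead of the squared distance.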
40 | """ 41 | if dot_product: 42 | return jnp.matmul(a, b.T, precision=precision) 43 | if b2 is None: 44 | b2 = jnp.sum(b.T**2, axis=0, keepdims=True) 45 | a2 = jnp.sum(a**2, axis=1, keepdims=True) 46 | ab = jnp.matmul(a, b.T, precision=precision) 47 | d = a2 - 2 * ab + b2 48 | return d 49 | 50 | 51 | def entropy_loss_fn(affinity, loss_type="softmax", temperature=1.0): 52 | """Calculates the entropy loss.""" 53 | flat_affinity = affinity.reshape(-1, affinity.shape[-1]) 54 | flat_affinity /= temperature 55 | probs = jax.nn.softmax(flat_affinity, axis=-1) 56 | log_probs = jax.nn.log_softmax(flat_affinity + 1e-5, axis=-1) 57 | if loss_type == "softmax": 58 | target_probs = probs 59 | elif loss_type == "argmax": 60 | codes = jnp.argmax(flat_affinity, axis=-1) 61 | onehots = jax.nn.one_hot( 62 | codes, flat_affinity.shape[-1], dtype=flat_affinity.dtype 63 | ) 64 | onehots = probs - jax.lax.stop_gradient(probs - onehots) 65 | target_probs = onehots 66 | else: 67 | raise ValueError("Entropy loss {} not supported".format(loss_type)) 68 | avg_probs = jnp.mean(target_probs, axis=0) 69 | avg_entropy = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-5)) 70 | sample_entropy = -jnp.mean(jnp.sum(target_probs * log_probs, axis=-1)) 71 | loss = sample_entropy - avg_entropy 72 | return loss 73 | 74 | 75 | class LinearCLS(nn.Module): 76 | num_classes: int = 1000 77 | 78 | @nn.compact 79 | def __call__(self, x, train=True): 80 | norm = functools.partial( 81 | nn.BatchNorm, 82 | use_running_average=not train, 83 | momentum=0.9, 84 | epsilon=1e-5, 85 | use_scale=False, 86 | use_bias=False, 87 | ) 88 | x = norm(name="bn")(x) 89 | logits = nn.Dense(self.num_classes)(x) 90 | return logits 91 | 92 | 93 | def update_vit_config(model_type, config): 94 | def get_config(model_type): 95 | if model_type == "small": 96 | return 12, 6 97 | elif model_type == "base": 98 | return 12, 12 99 | elif model_type == "large": 100 | return 24, 16 101 | elif model_type == "huge": 102 | return 32, 16 103 | elif model_type == "decoder_small": 104 | return 8, 16 105 | elif model_type == "decoder_base": 106 | return 12, 16 107 | elif model_type == "decoder_large": 108 | return 16, 16 109 | else: 110 | return tuple(int(x) for x in model_type.split("/")) 111 | 112 | if model_type == "debug": 113 | config.enc_num_layers = 2 114 | config.enc_num_heads = 4 115 | config.dec_num_layers = 2 116 | config.dec_num_heads = 4 117 | elif ":" not in model_type: 118 | config.enc_num_layers, config.enc_num_heads = get_config(model_type) 119 | config.dec_num_layers, config.dec_num_heads = get_config(model_type) 120 | else: 121 | encoder_type, decoder_type = model_type.split(":") 122 | config.enc_num_layers, config.enc_num_heads = get_config(encoder_type) 123 | config.dec_num_layers, config.dec_num_heads = get_config(decoder_type) 124 | 125 | assert_hidden_size(config) 126 | 127 | 128 | def assert_avg_rnd(config): 129 | if config.top_k_avg and config.top_k_rnd: 130 | raise ValueError("top_k_avg and top_k_rnd are mutually exclusive") 131 | if config.top_k_value > 1: 132 | assert ( 133 | config.top_k_avg or config.top_k_rnd 134 | ), "top_k_avg or top_k_rnd must be True when top_k_value > 1" 135 | elif config.top_k_value == 1: 136 | assert ( 137 | not config.top_k_avg and not config.top_k_rnd 138 | ), "top_k_avg and top_k_rnd must be False when top_k_value == 1" 139 | 140 | 141 | def assert_hidden_size(config): 142 | assert config.hidden_size % config.enc_num_heads == 0 143 | assert config.hidden_size % config.dec_num_heads == 0 144 | 145 | 146 | ACT2FN = { 147 | 
"gelu": nn.gelu, 148 | "relu": nn.relu, 149 | "silu": nn.swish, 150 | "swish": nn.swish, 151 | "gelu_new": partial(nn.gelu, approximate=True), 152 | "tanh": nn.tanh, 153 | } 154 | -------------------------------------------------------------------------------- /lqae/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pprint 3 | import random 4 | import tempfile 5 | import time 6 | import uuid 7 | from copy import copy 8 | from socket import gethostname 9 | 10 | import absl.flags 11 | import cloudpickle as pickle 12 | import gcsfs 13 | import numpy as np 14 | from absl import logging 15 | from ml_collections import ConfigDict 16 | from ml_collections.config_dict import config_dict 17 | from ml_collections.config_flags import config_flags 18 | 19 | import wandb 20 | 21 | from .jax_utils import init_rng 22 | 23 | 24 | class Timer(object): 25 | def __init__(self): 26 | self._time = None 27 | 28 | def __enter__(self): 29 | self._start_time = time.time() 30 | return self 31 | 32 | def __exit__(self, exc_type, exc_value, exc_tb): 33 | self._time = time.time() - self._start_time 34 | 35 | def __call__(self): 36 | return self._time 37 | 38 | 39 | class WandBLogger(object): 40 | @staticmethod 41 | def get_default_config(updates=None): 42 | config = ConfigDict() 43 | config.project_id = "" 44 | config.experiment_id = config_dict.placeholder(str) 45 | config.experiment_note = config_dict.placeholder(str) 46 | 47 | config.output_dir = "/tmp/" 48 | config.gcs_output_dir = "" 49 | 50 | config.online = False 51 | 52 | if updates is not None: 53 | config.update(ConfigDict(updates).copy_and_resolve_references()) 54 | return config 55 | 56 | def __init__(self, config, variant, enable=True): 57 | self.enable = enable 58 | self.config = self.get_default_config(config) 59 | 60 | if self.config.experiment_id is None or self.config.experiment_id == "": 61 | self.config.experiment_id = uuid.uuid4().hex 62 | else: 63 | self.config.experiment_id = str(self.config.experiment_id) + "_" + uuid.uuid4().hex 64 | 65 | if self.enable: 66 | if self.config.output_dir == "": 67 | self.config.output_dir = tempfile.mkdtemp() 68 | else: 69 | self.config.output_dir = os.path.join( 70 | self.config.output_dir, self.config.experiment_id 71 | ) 72 | os.makedirs(self.config.output_dir, exist_ok=True) 73 | 74 | if self.config.gcs_output_dir != "": 75 | self.config.gcs_output_dir = os.path.join( 76 | self.config.gcs_output_dir, self.config.experiment_id 77 | ) 78 | 79 | self._variant = copy(variant) 80 | 81 | if "hostname" not in self._variant: 82 | self._variant["hostname"] = gethostname() 83 | 84 | if self.enable: 85 | self.run = wandb.init( 86 | config=self._variant, 87 | project=self.config.project_id, 88 | dir=self.config.output_dir, 89 | id=self.config.experiment_id, 90 | resume="allow", 91 | reinit=True, 92 | notes=self.config.experiment_note, 93 | settings=wandb.Settings( 94 | start_method="thread", 95 | _disable_stats=True, 96 | ), 97 | mode="online" if self.config.online else "offline", 98 | ) 99 | else: 100 | self.run = None 101 | 102 | def log(self, *args, **kwargs): 103 | if self.enable: 104 | self.run.log(*args, **kwargs) 105 | 106 | def save_pickle(self, obj, filename): 107 | if self.enable: 108 | with open(os.path.join(self.config.output_dir, filename), "wb") as fout: 109 | pickle.dump(obj, fout) 110 | 111 | if self.config.gcs_output_dir != "": 112 | path = os.path.join(self.config.gcs_output_dir, filename) 113 | with gcsfs.GCSFileSystem().open(path, "wb") 
as fout:
114 |                     pickle.dump(obj, fout)
115 | 
116 |     @property
117 |     def experiment_id(self):
118 |         return self.config.experiment_id
119 | 
120 |     @property
121 |     def variant(self):
122 |         return self._variant
123 | 
124 |     @property
125 |     def output_dir(self):
126 |         return self.config.output_dir
127 | 
128 | 
129 | def define_flags_with_default(**kwargs):
130 |     for key, val in kwargs.items():
131 |         if isinstance(val, ConfigDict):
132 |             config_flags.DEFINE_config_dict(key, val)
133 |         elif isinstance(val, bool):
134 |             # Note that True and False are instances of int.
135 |             absl.flags.DEFINE_bool(key, val, "automatically defined flag")
136 |         elif isinstance(val, int):
137 |             absl.flags.DEFINE_integer(key, val, "automatically defined flag")
138 |         elif isinstance(val, float):
139 |             absl.flags.DEFINE_float(key, val, "automatically defined flag")
140 |         elif isinstance(val, str):
141 |             absl.flags.DEFINE_string(key, val, "automatically defined flag")
142 |         else:
143 |             raise ValueError("Incorrect value type")
144 |     return kwargs
145 | 
146 | 
147 | def set_random_seed(seed):
148 |     np.random.seed(seed)
149 |     random.seed(seed)
150 |     init_rng(seed)
151 | 
152 | 
153 | def print_flags(flags, flags_def):
154 |     logging.info(
155 |         "Running training with hyperparameters: \n{}".format(
156 |             pprint.pformat(
157 |                 [
158 |                     "{}: {}".format(key, val)
159 |                     for key, val in get_user_flags(flags, flags_def).items()
160 |                 ]
161 |             )
162 |         )
163 |     )
164 | 
165 | 
166 | def get_user_flags(flags, flags_def):
167 |     output = {}
168 |     for key in flags_def:
169 |         val = getattr(flags, key)
170 |         if isinstance(val, ConfigDict):
171 |             output.update(flatten_config_dict(val, prefix=key))
172 |         else:
173 |             output[key] = val
174 | 
175 |     return output
176 | 
177 | 
178 | def flatten_config_dict(config, prefix=None):
179 |     output = {}
180 |     for key, val in config.items():
181 |         if isinstance(val, ConfigDict):
182 |             output.update(flatten_config_dict(val, prefix=key if prefix is None else "{}.{}".format(prefix, key)))
183 |         else:
184 |             if prefix is not None:
185 |                 output["{}.{}".format(prefix, key)] = val
186 |             else:
187 |                 output[key] = val
188 |     return output
189 | 
190 | 
191 | def prefix_metrics(metrics, prefix):
192 |     return {"{}/{}".format(prefix, key): value for key, value in metrics.items()}
193 | 
194 | 
195 | def load_pickle(path):
196 |     if path.startswith("gs://"):
197 |         with gcsfs.GCSFileSystem().open(path) as fin:
198 |             data = pickle.load(fin)
199 |     else:
200 |         with open(path, "rb") as fin:
201 |             data = pickle.load(fin)
202 |     return data
203 | 
204 | 
205 | def load_checkpoint(path):
206 |     data = load_pickle(path)
207 |     logging.info(
208 |         "Loading checkpoint from %s, saved at step %d",
209 |         path,
210 |         data["step"],
211 |     )
212 |     return data
213 | 
214 | 
215 | def image_float2int(image):
216 |     return np.clip(image * 255.0, 0.0, 255.0).astype(np.uint8)
217 | 
218 | 
219 | def create_log_images(images, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), n=5):
220 |     images = [np.array(x) for x in images]
221 |     images = [x.reshape(-1, *x.shape[2:]) for x in images]
222 |     rows = np.concatenate(images, axis=2)
223 |     rows = rows * std + mean
224 |     return image_float2int(np.concatenate(rows[:n], axis=0))
225 | 
--------------------------------------------------------------------------------
/lqae/models/base_resnet.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import io
3 | import math
4 | import os
5 | from typing import Any, Callable, Optional
6 | 
7 | import flax
8 | import flax.linen as nn
9 | import jax
10 | import jax.numpy as jnp
11 | 
import ml_collections 12 | import numpy as np 13 | import requests 14 | from ml_collections import ConfigDict 15 | 16 | from .model_utils import ACT2FN 17 | 18 | 19 | class ResBlock(nn.Module): 20 | """Basic Residual Block.""" 21 | 22 | filters: int 23 | norm_fn: Any 24 | conv_fn: Any 25 | dtype: int = jnp.float32 26 | activation_fn: Any = nn.relu 27 | use_conv_shortcut: bool = False 28 | 29 | @nn.compact 30 | def __call__(self, x): 31 | input_dim = x.shape[-1] 32 | residual = x 33 | x = self.norm_fn()(x) 34 | x = self.activation_fn(x) 35 | x = self.conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x) 36 | x = self.norm_fn()(x) 37 | x = self.activation_fn(x) 38 | x = self.conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x) 39 | 40 | if input_dim != self.filters: 41 | if self.use_conv_shortcut: 42 | residual = self.conv_fn( 43 | self.filters, kernel_size=(3, 3), use_bias=False 44 | )(x) 45 | else: 46 | residual = self.conv_fn( 47 | self.filters, kernel_size=(1, 1), use_bias=False 48 | )(x) 49 | return x + residual 50 | 51 | 52 | class ResNetEncoder(nn.Module): 53 | """Encoder Blocks.""" 54 | 55 | config: ConfigDict 56 | dtype: int = jnp.float32 57 | 58 | def setup(self): 59 | self.filters = self.config.filters 60 | self.num_res_blocks = self.config.num_res_blocks 61 | self.channel_multipliers = self.config.channel_multipliers 62 | self.hidden_size = self.config.hidden_size 63 | self.conv_downsample = self.config.conv_downsample 64 | self.norm_type = "GN" 65 | self.activation_fn = ACT2FN["swish"] 66 | 67 | @nn.compact 68 | def __call__(self, x, train): 69 | conv_fn = nn.Conv 70 | norm_fn = get_norm_layer( 71 | train=train, dtype=self.dtype, norm_type=self.norm_type 72 | ) 73 | block_args = dict( 74 | norm_fn=norm_fn, 75 | conv_fn=conv_fn, 76 | dtype=self.dtype, 77 | activation_fn=self.activation_fn, 78 | use_conv_shortcut=False, 79 | ) 80 | x = conv_fn(self.filters, kernel_size=(3, 3), use_bias=False)(x) 81 | num_blocks = len(self.channel_multipliers) 82 | for i in range(num_blocks): 83 | filters = self.filters * self.channel_multipliers[i] 84 | for _ in range(self.num_res_blocks): 85 | x = ResBlock(filters, **block_args)(x) 86 | if i < num_blocks - 1: 87 | if self.conv_downsample: 88 | x = conv_fn(filters, kernel_size=(4, 4), strides=(2, 2))(x) 89 | else: 90 | x = dsample(x) 91 | for _ in range(self.num_res_blocks): 92 | x = ResBlock(filters, **block_args)(x) 93 | x = norm_fn()(x) 94 | x = self.activation_fn(x) 95 | x = conv_fn(self.hidden_size, kernel_size=(1, 1))(x) 96 | return x, None 97 | 98 | 99 | class ResNetDecoder(nn.Module): 100 | """Decoder Blocks.""" 101 | 102 | config: ConfigDict 103 | output_dim: int = 3 104 | dtype: Any = jnp.float32 105 | 106 | def setup(self): 107 | self.filters = self.config.filters 108 | self.num_res_blocks = self.config.num_res_blocks 109 | self.channel_multipliers = self.config.channel_multipliers 110 | self.norm_type = "GN" 111 | self.activation_fn = ACT2FN["swish"] 112 | 113 | @nn.compact 114 | def __call__(self, x, train): 115 | conv_fn = nn.Conv 116 | norm_fn = get_norm_layer( 117 | train=train, dtype=self.dtype, norm_type=self.norm_type 118 | ) 119 | block_args = dict( 120 | norm_fn=norm_fn, 121 | conv_fn=conv_fn, 122 | dtype=self.dtype, 123 | activation_fn=self.activation_fn, 124 | use_conv_shortcut=False, 125 | ) 126 | num_blocks = len(self.channel_multipliers) 127 | filters = self.filters * self.channel_multipliers[-1] 128 | x = conv_fn(filters, kernel_size=(3, 3), use_bias=True)(x) 129 | for _ in range(self.num_res_blocks): 130 
| x = ResBlock(filters, **block_args)(x) 131 | for i in reversed(range(num_blocks)): 132 | filters = self.filters * self.channel_multipliers[i] 133 | for _ in range(self.num_res_blocks): 134 | x = ResBlock(filters, **block_args)(x) 135 | if i > 0: 136 | x = upsample(x, 2) 137 | x = conv_fn(filters, kernel_size=(3, 3))(x) 138 | x = norm_fn()(x) 139 | x = self.activation_fn(x) 140 | x = conv_fn(self.output_dim, kernel_size=(3, 3))(x) 141 | return x 142 | 143 | 144 | def l2_normalize(x, axis=None, eps=1e-12): 145 | """Normalizes along dimension `axis` using an L2 norm. 146 | This specialized function exists for numerical stability reasons. 147 | Args: 148 | x: An input ndarray. 149 | axis: Dimension along which to normalize, e.g. `1` to separately normalize 150 | vectors in a batch. Passing `None` views `t` as a flattened vector when 151 | calculating the norm (equivalent to Frobenius norm). 152 | eps: Epsilon to avoid dividing by zero. 153 | Returns: 154 | An array of the same shape as 'x' L2-normalized along 'axis'. 155 | """ 156 | return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) 157 | 158 | 159 | def tensorflow_style_avg_pooling(x, window_shape, strides, padding: str): 160 | """Avg pooling as done by TF (Flax layer gives different results). 161 | To be specific, Flax includes padding cells when taking the average, 162 | while TF does not. 163 | Args: 164 | x: Input tensor 165 | window_shape: Shape of pooling window; if 1-dim tuple is just 1d pooling, if 166 | 2-dim tuple one gets 2d pooling. 167 | strides: Must have the same dimension as the window_shape. 168 | padding: Either 'SAME' or 'VALID' to indicate pooling method. 169 | Returns: 170 | pooled: Tensor after applying pooling. 171 | """ 172 | pool_sum = jax.lax.reduce_window( 173 | x, 0.0, jax.lax.add, (1,) + window_shape + (1,), (1,) + strides + (1,), padding 174 | ) 175 | pool_denom = jax.lax.reduce_window( 176 | jnp.ones_like(x), 177 | 0.0, 178 | jax.lax.add, 179 | (1,) + window_shape + (1,), 180 | (1,) + strides + (1,), 181 | padding, 182 | ) 183 | return pool_sum / pool_denom 184 | 185 | 186 | def upsample(x, factor=2): 187 | n, h, w, c = x.shape 188 | x = jax.image.resize(x, (n, h * factor, w * factor, c), method="nearest") 189 | return x 190 | 191 | 192 | def dsample(x): 193 | return tensorflow_style_avg_pooling(x, (2, 2), strides=(2, 2), padding="same") 194 | 195 | 196 | def get_norm_layer(train, dtype, norm_type="BN"): 197 | """Normalization layer.""" 198 | if norm_type == "BN": 199 | norm_fn = functools.partial( 200 | nn.BatchNorm, 201 | use_running_average=not train, 202 | momentum=0.9, 203 | epsilon=1e-5, 204 | axis_name=None, 205 | axis_index_groups=None, 206 | dtype=jnp.float32, 207 | ) 208 | elif norm_type == "LN": 209 | norm_fn = functools.partial(nn.LayerNorm, dtype=dtype) 210 | elif norm_type == "GN": 211 | norm_fn = functools.partial(nn.GroupNorm, dtype=dtype) 212 | else: 213 | raise NotImplementedError 214 | return norm_fn 215 | -------------------------------------------------------------------------------- /lqae/jax_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from functools import partial 4 | from typing import Any, Mapping, Text, Tuple, Union 5 | 6 | import dill 7 | import flax 8 | import jax 9 | import jax.numpy as jnp 10 | import numpy as np 11 | from absl import logging 12 | from flax import jax_utils 13 | 14 | 15 | class JaxRNG(object): 16 | """A convenient stateful Jax RNG wrapper. 
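    Each call splits the internal key: rng() returns one fresh key, rng(n) returns a tuple of n keys, and rng(keys=...) returns a dict of named keys.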
Can be used to wrap RNG inside pure function.""" 17 | 18 | @classmethod 19 | def from_seed(cls, seed): 20 | return cls(jax.random.PRNGKey(seed)) 21 | 22 | def __init__(self, rng): 23 | self.rng = rng 24 | 25 | def __call__(self, keys=None): 26 | if keys is None: 27 | self.rng, split_rng = jax.random.split(self.rng) 28 | return split_rng 29 | elif isinstance(keys, int): 30 | split_rngs = jax.random.split(self.rng, num=keys + 1) 31 | self.rng = split_rngs[0] 32 | return tuple(split_rngs[1:]) 33 | else: 34 | split_rngs = jax.random.split(self.rng, num=len(keys) + 1) 35 | self.rng = split_rngs[0] 36 | return {key: val for key, val in zip(keys, split_rngs[1:])} 37 | 38 | 39 | def wrap_function_with_rng(rng): 40 | """To be used as decorator, automatically bookkeep a RNG for the wrapped function.""" 41 | 42 | def wrap_function(function): 43 | def wrapped(*args, **kwargs): 44 | nonlocal rng 45 | rng, split_rng = jax.random.split(rng) 46 | return function(split_rng, *args, **kwargs) 47 | 48 | return wrapped 49 | 50 | return wrap_function 51 | 52 | 53 | def init_rng(seed): 54 | global jax_utils_rng 55 | jax_utils_rng = JaxRNG.from_seed(seed) 56 | 57 | 58 | def next_rng(*args, **kwargs): 59 | global jax_utils_rng 60 | return jax_utils_rng(*args, **kwargs) 61 | 62 | 63 | def get_metrics(metrics, unreplicate=False, stack=False): 64 | if unreplicate: 65 | metrics = flax.jax_utils.unreplicate(metrics) 66 | metrics = jax.device_get(metrics) 67 | if stack: 68 | return jax.tree_map(lambda *args: np.stack(args), *metrics) 69 | else: 70 | return {key: float(val) for key, val in metrics.items()} 71 | 72 | 73 | def get_onehot(labels, num_classes, on_value=1.0, off_value=0.0): 74 | x = labels[..., None] == jnp.arange(num_classes).reshape((1,) * labels.ndim + (-1,)) 75 | x = jax.lax.select(x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value)) 76 | return x.astype(jnp.float32) 77 | 78 | 79 | def extend_and_repeat(tensor, axis, repeat): 80 | return jnp.repeat(jnp.expand_dims(tensor, axis), repeat, axis=axis) 81 | 82 | 83 | def mse_loss(val, target): 84 | return jnp.mean(jnp.square(val - target)) 85 | 86 | 87 | def cross_entropy_loss(logits, labels, smoothing_factor=0.0): 88 | num_classes = logits.shape[-1] 89 | if labels.dtype == jnp.int32 or labels.dtype == jnp.int64: 90 | labels = jax.nn.one_hot(labels, num_classes) 91 | if smoothing_factor > 0.0: 92 | labels = labels * (1.0 - smoothing_factor) + smoothing_factor / num_classes 93 | logp = jax.nn.log_softmax(logits, axis=-1) 94 | return -jnp.mean(jnp.sum(logp * labels, axis=-1)) 95 | 96 | 97 | def accumulated_gradient( 98 | state, accumulated_grads, accumulated_steps, grads, apply_steps, apply_fn=None 99 | ): 100 | accumulated_grads = jax.tree_map( 101 | lambda x, y: x + y, 102 | accumulated_grads, 103 | jax.tree_map(lambda x: x / float(apply_steps), grads), 104 | ) 105 | accumulated_steps = (accumulated_steps + 1) % apply_steps 106 | if apply_fn is None: 107 | apply_fn = lambda s, g: s.apply_gradients(grads=g) 108 | state = jax.lax.cond( 109 | accumulated_steps == 0, 110 | lambda: apply_fn(state, accumulated_grads), 111 | lambda: state, 112 | ) 113 | accumulated_grads = jax.lax.cond( 114 | accumulated_steps == 0, 115 | lambda: jax.tree_map(jnp.zeros_like, accumulated_grads), 116 | lambda: accumulated_grads, 117 | ) 118 | return state, accumulated_grads, accumulated_steps 119 | 120 | 121 | @partial(jax.pmap, axis_name="pmap", donate_argnums=0) 122 | def sync_state_across_devices(state): 123 | i = jax.lax.axis_index("pmap") 124 | 125 | def select(x): 126 | 
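        # zero out every replica except device 0, then psum so all devices adopt device 0's state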
return jax.lax.psum(jnp.where(i == 0, x, jnp.zeros_like(x)), "pmap") 127 | 128 | return jax.tree_map(select, state) 129 | 130 | 131 | def get_random_bounding_box( 132 | image_shape: Tuple[int, int], lambda_cutmix: float, margin: float = 0.0 133 | ) -> Tuple[int, int, int, int]: 134 | ratio = np.sqrt(1 - lambda_cutmix) 135 | img_h, img_w = image_shape 136 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 137 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 138 | cy = np.random.randint(0 + margin_y, img_h - margin_y) 139 | cx = np.random.randint(0 + margin_x, img_w - margin_x) 140 | y_min = np.clip(cy - cut_h // 2, 0, img_h) 141 | y_max = np.clip(cy + cut_h // 2, 0, img_h) 142 | x_min = np.clip(cx - cut_w // 2, 0, img_w) 143 | x_max = np.clip(cx + cut_w // 2, 0, img_w) 144 | return y_min, y_max, x_min, x_max 145 | 146 | 147 | def label_smoothing_fn(labels, smoothing_factor): 148 | num_classes = labels.shape[-1] 149 | labels = labels * (1.0 - smoothing_factor) + smoothing_factor / num_classes 150 | return labels 151 | 152 | 153 | def mixup_cutmix( 154 | images: jnp.ndarray, 155 | labels: jnp.ndarray, 156 | rng: Any, 157 | num_classes: int, 158 | mixup_alpha: float = 1.0, 159 | cutmix_alpha: float = 0.0, 160 | switch_prob: float = 0.5, 161 | label_smoothing: float = 0.0, 162 | image_format: str = "NHWC" 163 | ): 164 | if len(labels.shape) == 1: 165 | labels = jax.nn.one_hot(labels, num_classes) 166 | 167 | if "N" not in image_format: 168 | raise ValueError('Mixup requires "N" to be in "image_format".') 169 | 170 | batch_size = labels.shape[0] 171 | 172 | if cutmix_alpha > 0 and (mixup_alpha <= 0 or np.random.rand() < switch_prob): 173 | do_mixup = False 174 | do_cutmix = True 175 | elif mixup_alpha > 0: 176 | do_mixup = True 177 | do_cutmix = False 178 | else: 179 | return images, labels 180 | 181 | if do_mixup: 182 | weight = jax.random.beta(rng, mixup_alpha, mixup_alpha) 183 | weight *= jnp.ones((batch_size, 1)) 184 | 185 | # Mixup inputs. 186 | # Shape calculations use np to avoid device memory fragmentation: 187 | image_weight_shape = np.ones((images.ndim)) 188 | image_weight_shape[image_format.index("N")] = batch_size 189 | image_weight = jnp.reshape( 190 | weight, image_weight_shape.astype(jnp.int32) 191 | ) 192 | reverse = tuple( 193 | slice(images.shape[i]) if d != "N" else slice(-1, None, -1) 194 | for i, d in enumerate(image_format) 195 | ) 196 | mixed_images = image_weight * images + (1.0 - image_weight) * images[reverse] 197 | label_weight = weight 198 | 199 | elif do_cutmix: 200 | if image_format not in {"NHWC", "NTHWC"}: 201 | raise ValueError( 202 | "Cutmix is only supported for inputs in format" 203 | f" NHWC or NTHWC. Got {image_format}." 
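        # cutmix: lambda ~ Beta(alpha, alpha); the random box below covers a (1 - lambda) fraction of the image, so label_weight is approximately lambda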
204 | ) 205 | cutmix_lambda = np.random.beta(cutmix_alpha, cutmix_alpha) 206 | 207 | y_min, y_max, x_min, x_max = get_random_bounding_box( 208 | images.shape[-3:-1], cutmix_lambda 209 | ) 210 | image_mask = np.ones(images.shape) 211 | if image_format == "NHWC": 212 | image_mask[:, y_min:y_max, x_min:x_max, :] = 0.0 213 | else: 214 | image_mask[:, :, y_min:y_max, x_min:x_max, :] = 0.0 215 | height, width = images.shape[-3], images.shape[-2] 216 | 217 | mixed_images = images * image_mask + jnp.flip(images, axis=0) * ( 218 | 1.0 - image_mask 219 | ) 220 | box_area = (y_max - y_min) * (x_max - x_min) 221 | label_weight = 1.0 - box_area / float(height * width) 222 | 223 | # Mixup label 224 | if label_smoothing > 0: 225 | labels = label_smoothing_fn(labels, label_smoothing) 226 | 227 | mixed_labels = label_weight * labels + (1.0 - label_weight) * labels[::-1] 228 | 229 | return mixed_images, mixed_labels 230 | -------------------------------------------------------------------------------- /jobs/tpu_control.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | function tpu { 4 | trap "trap - SIGINT SIGTERM; return 1;" SIGINT SIGTERM 5 | 6 | # =============== TPU Project Specific Definitions =============== 7 | export PROJECT_NAME="lqae" 8 | export PROJECT_HOME="YOURS" 9 | echo "---------------------------------" 10 | echo "Project name: $PROJECT_NAME" 11 | echo "Project home: $PROJECT_HOME" 12 | echo "---------------------------------" 13 | 14 | if [ "$1" = "YOURS" ]; then 15 | tpu_project='YOURS' 16 | tpu_zone='YOURS' 17 | else 18 | echo "Invalid syntax!" 19 | trap - SIGINT SIGTERM 20 | return 1 21 | fi 22 | # =============== End of TPU Project Specific Definitions =============== 23 | 24 | 25 | if [ "$2" = "list" ]; then 26 | gcloud alpha compute tpus tpu-vm list --zone $tpu_zone --project $tpu_project 27 | elif [ "$2" = "describe" ]; then 28 | gcloud alpha compute tpus tpu-vm describe $3 --zone $tpu_zone --project $tpu_project 29 | elif [ "$2" = "ips" ]; then 30 | _tpu_ips $tpu_zone $tpu_project $3 31 | elif [ "$2" = "delete" ]; then 32 | echo "${@:3}" 33 | for tpu in "${@:3}"; do 34 | echo -n "Are you sure (y/n)? 
" 35 | read REPLY 36 | if [[ $REPLY =~ ^[Yy]$ ]] 37 | then 38 | echo y|gcloud alpha compute tpus tpu-vm delete "$tpu" --zone $tpu_zone --project $tpu_project & 39 | fi 40 | done 41 | elif [ "$2" = "create" ]; then 42 | _tpu_create $tpu_zone $tpu_project $3 $4 43 | elif [ "$2" = "retry_create" ]; then 44 | _tpu_retry_create $tpu_zone $tpu_project $3 $4 45 | elif [ "$2" = "create_pre" ]; then 46 | _tpu_create_pre $tpu_zone $tpu_project $3 $4 47 | elif [ "$2" = "retry_create_pre" ]; then 48 | _tpu_retry_create_pre $tpu_zone $tpu_project $3 $4 49 | elif [ "$2" = "cs" ]; then 50 | _tpu_create $tpu_zone $tpu_project $3 $4 51 | sleep 90s 52 | _tpu_setup $tpu_zone $tpu_project $4 53 | elif [ "$2" = "check" ]; then 54 | _tpu_check $tpu_zone $tpu_project $3 55 | elif [ "$2" = "setup" ]; then 56 | _tpu_setup $tpu_zone $tpu_project $3 57 | elif [ "$2" = "copy" ]; then 58 | _tpu_copy $tpu_zone $tpu_project $3 59 | elif [ "$2" = "stop" ]; then 60 | _tpu_stop $tpu_zone $tpu_project $3 61 | elif [ "$2" = "launch" ]; then 62 | _tpu_launch $tpu_zone $tpu_project $3 $4 63 | elif [ "$2" = "cl" ]; then 64 | _tpu_copy $tpu_zone $tpu_project $3 65 | _tpu_launch $tpu_zone $tpu_project $3 $4 66 | elif [ "$2" = "maintain" ]; then 67 | _tpu_maintain $tpu_zone $tpu_project $3 68 | elif [ "$2" = "ssh" ]; then 69 | _tpu_ssh $tpu_zone $tpu_project $3 "$4" 70 | elif [ "$2" = "reboot" ]; then 71 | _tpu_reboot $tpu_zone $tpu_project $3 72 | elif [ "$2" = "rm" ]; then 73 | _tpu_rm $tpu_zone $tpu_project $3 74 | else 75 | echo "Invalid syntax!" 76 | trap - SIGINT SIGTERM 77 | return 1 78 | fi 79 | trap - SIGINT SIGTERM 80 | } 81 | 82 | 83 | function _tpu_ips { 84 | tpu_zone=$1 85 | tpu_project=$2 86 | tpu_name=$3 87 | gcloud alpha compute tpus tpu-vm describe $tpu_name --zone $tpu_zone --project $tpu_project | grep -oP 'externalIp: \K(.+)$' 88 | } 89 | 90 | function _tpu_create { 91 | tpu_zone=$1 92 | tpu_project=$2 93 | tpu_cores=$3 94 | tpu_name=$4 95 | software_version='tpu-vm-base' 96 | gcloud alpha compute tpus tpu-vm create \ 97 | $tpu_name \ 98 | --accelerator-type="v3-$tpu_cores" \ 99 | --version $software_version \ 100 | --zone $tpu_zone \ 101 | --project $tpu_project 102 | } 103 | 104 | function _tpu_retry_create { 105 | while true; do 106 | _tpu_create "$@" 107 | sleep 30s 108 | done 109 | } 110 | 111 | function _tpu_create_pre { 112 | tpu_zone=$1 113 | tpu_project=$2 114 | tpu_cores=$3 115 | tpu_name=$4 116 | software_version='tpu-vm-base' 117 | gcloud alpha compute tpus tpu-vm create \ 118 | $tpu_name \ 119 | --accelerator-type="v3-$tpu_cores" \ 120 | --version $software_version \ 121 | --zone $tpu_zone \ 122 | --project $tpu_project \ 123 | --preemptible 124 | } 125 | 126 | function _tpu_retry_create_pre { 127 | while true; do 128 | _tpu_create_pre "$@" 129 | sleep 30s 130 | done 131 | } 132 | 133 | function _tpu_setup { 134 | tpu_zone=$1 135 | tpu_project=$2 136 | tpu_name=$3 137 | 138 | tpu_ips=($(echo "$(_tpu_ips $tpu_zone $tpu_project $tpu_name)")) 139 | for host in $tpu_ips[@]; do 140 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/ 141 | ssh $host '~/tpu_vm_setup.sh' & 142 | done 143 | wait &> /dev/null 144 | 145 | for host in $tpu_ips[@]; do 146 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/ 147 | wait 148 | ssh $host '~/tpu_vm_setup.sh' & 149 | done 150 | wait &> /dev/null 151 | } 152 | 153 | function _tpu_check { 154 | tpu_zone=$1 155 | tpu_project=$2 156 | tpu_name=$3 157 | 158 | tpu_ips=($(echo "$(_tpu_ips $tpu_zone $tpu_project $tpu_name)")) 159 | for host in 
$tpu_ips[@]; do 160 | echo "============== Checking host: $host ==============" 161 | ssh $host 'tmux capture-pane -pt launch' 162 | echo "============== End of host: $host ==============" 163 | echo 164 | echo 165 | done 166 | } 167 | 168 | function _tpu_copy { 169 | tpu_zone=$1 170 | tpu_project=$2 171 | tpu_name=$3 172 | 173 | tpu_ips=($(echo "$(_tpu_ips $tpu_zone $tpu_project $tpu_name)")) 174 | for host in $tpu_ips[@]; do 175 | rsync -avPI --exclude=logs --exclude=__pycache__ --exclude=.git --exclude=local $PROJECT_HOME/$PROJECT_NAME $host:~/ & 176 | done 177 | wait &> /dev/null 178 | sleep 1s 179 | 180 | for host in $tpu_ips[@]; do 181 | rsync -avPI --exclude=logs --exclude=__pycache__ --exclude=.git --exclude=local $PROJECT_HOME/$PROJECT_NAME $host:~/ & 182 | done 183 | wait &> /dev/null 184 | sleep 1s 185 | } 186 | 187 | function _tpu_stop { 188 | tpu_zone=$1 189 | tpu_project=$2 190 | tpu_name=$3 191 | 192 | tpu_ips=($(echo "$(_tpu_ips $tpu_zone $tpu_project $tpu_name)")) 193 | for host in $tpu_ips[@]; do 194 | ssh $host 'tmux kill-session -t launch ; pkill -9 python' & 195 | done 196 | wait &> /dev/null 197 | } 198 | 199 | function _tpu_launch { 200 | tpu_zone=$1 201 | tpu_project=$2 202 | tpu_name=$3 203 | command=$4 204 | 205 | if [ -z "$command" ]; then 206 | echo "Invalid syntax!" 207 | return 1 208 | fi 209 | 210 | tpu_ips=($(echo "$(_tpu_ips $tpu_zone $tpu_project $tpu_name)")) 211 | for host in $tpu_ips[@]; do 212 | ssh $host "tmux new -d -s launch ~/$PROJECT_NAME/jobs/$command" & 213 | done 214 | wait &> /dev/null 215 | } 216 | 217 | function _tpu_maintain { 218 | tpu_zone=$1 219 | tpu_project=$2 220 | tpu_name=$3 221 | 222 | gcloud alpha compute tpus tpu-vm simulate-maintenance-event $tpu_name \ 223 | --project $tpu_project \ 224 | --zone=$tpu_zone \ 225 | --workers=all 226 | } 227 | 228 | function _tpu_ssh { 229 | tpu_zone=$1 230 | tpu_project=$2 231 | tpu_name=$3 232 | command="$4" 233 | 234 | if [ -z "$command" ]; then 235 | echo "Invalid syntax!" 
236 | return 1 237 | fi 238 | 239 | tpu_ips=($(echo "$(_tpu_ips $tpu_zone $tpu_project $tpu_name)")) 240 | for host in $tpu_ips[@]; do 241 | ssh $host "$command" & 242 | done 243 | wait &> /dev/null 244 | } 245 | 246 | function _tpu_reboot { 247 | tpu_zone=$1 248 | tpu_project=$2 249 | tpu_name=$3 250 | 251 | tpu_ips=($(echo "$(_tpu_ips $tpu_zone $tpu_project $tpu_name)")) 252 | for host in $tpu_ips[@]; do 253 | ssh $host 'sudo reboot' & 254 | done 255 | wait &> /dev/null 256 | } 257 | 258 | 259 | function _tpu_rm { 260 | tpu_zone=$1 261 | tpu_project=$2 262 | tpu_name=$3 263 | 264 | tpu_ips=($(echo "$(_tpu_ips $tpu_zone $tpu_project $tpu_name)")) 265 | for host in $tpu_ips[@]; do 266 | ssh $host 'rm -rf ~/*' & 267 | done 268 | wait &> /dev/null 269 | } 270 | 271 | function export_function() { 272 | if [[ -n "${ZSH_VERSION}" ]]; then 273 | for f in "$@"; do 274 | zle -N $f 275 | done 276 | else 277 | export -f "$@" 278 | fi 279 | } 280 | 281 | export_function tpu _tpu_ips _tpu_create _tpu_create_pre _tpu_retry_create _tpu_retry_create_pre _tpu_setup _tpu_check _tpu_copy _tpu_stop _tpu_launch _tpu_maintain _tpu_ssh _tpu_reboot _tpu_rm 282 | -------------------------------------------------------------------------------- /lqae/main/eval_loss.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import pprint 3 | from copy import copy, deepcopy 4 | from functools import partial 5 | 6 | import absl.app 7 | import absl.flags 8 | import flax 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | import optax 13 | import torch 14 | from absl import logging 15 | from flax.jax_utils import prefetch_to_device 16 | from flax.training.train_state import TrainState 17 | from tqdm.auto import tqdm, trange 18 | 19 | import wandb 20 | 21 | from ..data import ImageNetDataset, ImageTextDataset 22 | from ..jax_utils import ( 23 | JaxRNG, 24 | accumulated_gradient, 25 | get_metrics, 26 | next_rng, 27 | sync_state_across_devices, 28 | ) 29 | from ..models import LQAE, VQAE 30 | from ..utils import ( 31 | WandBLogger, 32 | create_log_images, 33 | define_flags_with_default, 34 | get_user_flags, 35 | image_float2int, 36 | load_pickle, 37 | set_random_seed, 38 | ) 39 | 40 | FLAGS_DEF = define_flags_with_default( 41 | seed=42, 42 | epochs=200, 43 | batch_size=0, 44 | accumulate_grad_steps=1, 45 | dataloader_n_workers=0, 46 | dataloader_shuffle=False, 47 | log_freq=50, 48 | plot_freq=1000, 49 | save_model_freq=0, 50 | clip_gradient=1e9, 51 | lr_init_value=0.0, 52 | lr_end_value=0.0, 53 | lr_peak_value=1.0e-4, 54 | lr_warmup_epochs=0, 55 | weight_decay=0.0001, 56 | load_checkpoint="", 57 | load_pretrained="", 58 | dataset="imagenet", 59 | cc12m_data=ImageTextDataset.get_default_config(), 60 | imagenet_data=ImageNetDataset.get_default_config(), 61 | # lqae: encoder-decoder with frozen BERT 62 | # vqae: encoder-decoder without BERT 63 | # bert: trainable BERT with frozen encoder-decoder 64 | model_type="lqae", 65 | lqae=LQAE.get_default_config(), 66 | vqae=VQAE.get_default_config(), 67 | logging=WandBLogger.get_default_config(), 68 | log_all_worker=False, 69 | # eval loss ratio 70 | min_ratio=0.15, 71 | max_ratio=0.15, 72 | ) 73 | FLAGS = absl.flags.FLAGS 74 | 75 | 76 | def main(argv): 77 | variant = get_user_flags(FLAGS, FLAGS_DEF) 78 | assert FLAGS.model_type in [ 79 | "lqae", 80 | "vqae", 81 | "bert", 82 | ], "model_type must be one of lqae, vqae, bert" 83 | 84 | variant["jax_process_index"] = jax_process_index = jax.process_index() 85 | 
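    # process index/count locate this host in a multi-host (e.g. TPU pod) run and are logged with the variant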
variant["jax_process_count"] = jax_process_count = jax.process_count() 86 | assert FLAGS.batch_size % jax_process_count == 0 87 | variant["process_batch_size"] = process_batch_size = ( 88 | FLAGS.batch_size // jax_process_count 89 | ) 90 | variant["device_batch_size"] = process_batch_size // jax.local_device_count() 91 | lr_scale = FLAGS.batch_size / 256 92 | variant["effective_lr"] = FLAGS.lr_peak_value * lr_scale 93 | jax_devices = jax.local_devices() 94 | n_devices = len(jax_devices) 95 | assert process_batch_size % n_devices == 0 96 | 97 | logger = WandBLogger( 98 | config=FLAGS.logging, 99 | variant=variant, 100 | enable=FLAGS.log_all_worker or (jax_process_index == 0), 101 | ) 102 | set_random_seed(FLAGS.seed * (jax_process_index + 1)) 103 | 104 | if FLAGS.dataset == "cc12m": 105 | FLAGS.cc12m_data.image_only = True 106 | dataset = ImageTextDataset( 107 | FLAGS.cc12m_data, jax_process_index / jax_process_count 108 | ) 109 | elif FLAGS.dataset == "imagenet": 110 | FLAGS.imagenet_data.image_only = True 111 | dataset = ImageNetDataset( 112 | FLAGS.imagenet_data, jax_process_index / jax_process_count 113 | ) 114 | else: 115 | raise ValueError("Unsupported dataset!") 116 | 117 | val_flags = deepcopy(FLAGS.imagenet_data) 118 | val_flags.partition = "val" 119 | val_flags.transform_type = "test" 120 | val_dataset = ImageNetDataset(val_flags, jax_process_index / jax_process_count) 121 | 122 | steps_per_epoch = int(len(dataset) / FLAGS.batch_size) 123 | val_steps = int(len(val_dataset) / FLAGS.batch_size) 124 | 125 | val_dataloader = torch.utils.data.DataLoader( 126 | val_dataset, 127 | batch_size=process_batch_size, 128 | shuffle=False, 129 | drop_last=True, 130 | num_workers=FLAGS.dataloader_n_workers, 131 | prefetch_factor=2, 132 | persistent_workers=FLAGS.dataloader_n_workers > 0, 133 | ) 134 | 135 | if FLAGS.model_type == "lqae" or FLAGS.model_type == "bert": 136 | logging.info(f"Using LQAE model for {FLAGS.model_type}") 137 | model = LQAE(FLAGS.lqae) 138 | elif FLAGS.model_type == "vqae": 139 | logging.info(f"Using LQAE model for {FLAGS.model_type}") 140 | model = VQAE(FLAGS.vqae) 141 | 142 | def get_loss(output, result_dict, image, train=True): 143 | if "bert_loss" in result_dict: 144 | bert_loss = result_dict["bert_loss"] 145 | else: 146 | bert_loss = 0.0 147 | 148 | if FLAGS.model_type == "lqae" or FLAGS.model_type == "bert": 149 | recon_loss = FLAGS.lqae.bert_channel_image_loss_weight * jnp.mean( 150 | (image - output["bert_channel_image_output"]) ** 2 151 | ) + FLAGS.lqae.nochannel_image_loss_weight * jnp.mean( 152 | (image - output["image_output"]) ** 2 153 | ) 154 | elif FLAGS.model_type == "vqae": 155 | recon_loss = jnp.mean((image - output["image_output"]) ** 2) 156 | 157 | if train: 158 | quantizer_loss = result_dict["quantizer_loss"] 159 | return quantizer_loss, bert_loss, recon_loss 160 | else: 161 | return bert_loss, recon_loss 162 | 163 | @partial(jax.pmap, axis_name="pmap") 164 | def val_step_fn(state, rng, image): 165 | rng_generator = JaxRNG(rng) 166 | 167 | output, result_dict = model.apply( 168 | state.params, 169 | image, 170 | train=False, 171 | ratio={"min_ratio": FLAGS.min_ratio, "max_ratio": FLAGS.max_ratio}, 172 | rngs=rng_generator(keys=model.rng_keys()), 173 | ) 174 | 175 | bert_loss, recon_loss = get_loss(output, result_dict, image, train=False) 176 | 177 | aux = dict( 178 | recon_loss=recon_loss, 179 | bert_loss=bert_loss, 180 | perplexity=result_dict["perplexity"], 181 | codebook_usage=result_dict["codebook_usage"], 182 | ) 183 | 184 | encoding_indices = 
result_dict["encoding_indices"] 185 | encoding_indices = jax.lax.all_gather(encoding_indices, axis_name="pmap") 186 | 187 | aux = jax.lax.pmean(aux, axis_name="pmap") 188 | return aux, rng_generator(), encoding_indices 189 | 190 | if FLAGS.load_checkpoint != "": 191 | checkpoint_data = load_pickle(FLAGS.load_checkpoint) 192 | state = flax.jax_utils.replicate(checkpoint_data["state"], jax_devices) 193 | else: 194 | image = jnp.zeros((6, 256, 256, 3), dtype=jnp.float32) 195 | rngs = next_rng(keys=model.rng_keys()) 196 | params = model.init(rngs, image, train=True) 197 | 198 | assert FLAGS.model_type == "lqae" 199 | if FLAGS.lqae.use_bert_codebook: 200 | params = model.load_bert_params(params) 201 | 202 | state = flax.jax_utils.replicate( 203 | TrainState.create( 204 | params=flax.core.frozen_dict.unfreeze(params), 205 | apply_fn=None, 206 | tx=optax.lars( 207 | learning_rate=0, 208 | weight_decay=0, 209 | momentum=0, 210 | ), 211 | ), 212 | jax_devices, 213 | ) 214 | 215 | del params 216 | 217 | def generate_batch(iterator): 218 | while True: 219 | for images in iterator: 220 | yield images.numpy().reshape(n_devices, -1, *images.shape[1:]) 221 | 222 | state = sync_state_across_devices(state) 223 | sharded_rng = jax.device_put_sharded(next_rng(n_devices), jax_devices) 224 | 225 | val_data_iterator = prefetch_to_device( 226 | generate_batch(val_dataloader), 2, jax_devices 227 | ) 228 | if FLAGS.model_type == "lqae" or FLAGS.model_type == "bert": 229 | codebook_size = 50265 230 | elif FLAGS.model_type == "vqae": 231 | codebook_size = FLAGS.vqae.codebook_size 232 | 233 | val_metrics = [] 234 | val_encoding_indices = [] 235 | for _, val_image in zip( 236 | trange(val_steps, ncols=0, desc="val"), val_data_iterator 237 | ): 238 | val_image = val_image.astype(jnp.float32) 239 | metrics, sharded_rng, encoding_indices = val_step_fn( 240 | state, sharded_rng, val_image 241 | ) 242 | val_metrics.append(metrics) 243 | val_encoding_indices.append(encoding_indices) 244 | log_metrics = get_metrics(val_metrics, unreplicate=True, stack=True) 245 | val_encoding_indices = jax.tree_map( 246 | lambda x: jax.device_get(flax.jax_utils.unreplicate(x)), 247 | val_encoding_indices, 248 | ) 249 | val_encoding_indices = jnp.concatenate(val_encoding_indices, axis=0) 250 | log_metrics = { 251 | f"val_{k}": v 252 | for k, v in jax.tree_map(lambda x: x.mean(), log_metrics).items() 253 | } 254 | val_indices_histogram = jnp.histogram( 255 | val_encoding_indices, bins=512, range=(0, codebook_size - 1) 256 | ) 257 | log_metrics.update( 258 | { 259 | "val_indices_histogram": wandb.Histogram( 260 | np_histogram=val_indices_histogram 261 | ), 262 | "val_encoding_indices": wandb.Histogram(val_encoding_indices), 263 | } 264 | ) 265 | logger.log(log_metrics) 266 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 267 | 268 | 269 | if __name__ == "__main__": 270 | torch.multiprocessing.set_start_method("spawn") 271 | absl.app.run(main) 272 | -------------------------------------------------------------------------------- /lqae/models/base_vit.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import math 3 | from functools import partial 4 | from typing import Any, Dict, Optional, Tuple, Union 5 | 6 | import flax.linen as nn 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | from flax.linen.attention import dot_product_attention_weights 11 | from flax.linen.initializers import variance_scaling 12 | from ml_collections import ConfigDict 13 | 14 | from 
.model_utils import ACT2FN 15 | 16 | 17 | # copied from https://github.com/deepmind/dm-haiku/blob/3f31e279d4ce613ae3e47b97031f8b2d732071b7/haiku/_src/spectral_norm.py#L46 18 | def l2_normalize(x, axis=None, eps=1e-12): 19 | """Normalizes along dimension `axis` using an L2 norm. 20 | This specialized function exists for numerical stability reasons. 21 | Args: 22 | x: An input ndarray. 23 | axis: Dimension along which to normalize, e.g. `1` to separately normalize 24 | vectors in a batch. Passing `None` views `x` as a flattened vector when 25 | calculating the norm (equivalent to Frobenius norm). 26 | eps: Epsilon to avoid dividing by zero. 27 | Returns: 28 | An array of the same shape as 'x' L2-normalized along 'axis'. 29 | """ 30 | return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) 31 | 32 | 33 | # Source: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_flax_gptj.py 34 | def create_sinusoidal_positions(num_pos, dim): 35 | inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim)) 36 | sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype( 37 | "float32" 38 | ) 39 | sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp) 40 | sentinel = dim // 2 + dim % 2 41 | out = np.zeros((num_pos, dim)) 42 | out[:, 0:sentinel] = sin 43 | out[:, sentinel:] = cos 44 | return jnp.array(out) 45 | 46 | 47 | def patch_unflatten(patch, shape): 48 | return patch.reshape(patch.shape[0], *shape, patch.shape[-1]) 49 | 50 | 51 | def patch_flatten(patch): 52 | return patch.reshape(patch.shape[0], -1, patch.shape[-1]), patch.shape[1:-1] 53 | 54 | 55 | class ConvPatches(nn.Module): 56 | patch_size: int 57 | hidden_size: int 58 | dtype: jnp.dtype = jnp.float32 59 | 60 | @nn.compact 61 | def __call__(self, pixel_values): 62 | patch_embeds = nn.Conv( 63 | self.hidden_size, 64 | kernel_size=[self.patch_size, self.patch_size], 65 | strides=[self.patch_size, self.patch_size], 66 | padding="VALID", 67 | use_bias=False, 68 | dtype=self.dtype, 69 | name="patch_embeds", 70 | kernel_init=jax.nn.initializers.normal(0.02), 71 | )(pixel_values) 72 | return patch_embeds 73 | 74 | 75 | class GLU(nn.Module): 76 | dim1: int 77 | dim2: int 78 | activation: str 79 | dropout: float 80 | space_only_conv: bool = False 81 | dtype: jnp.dtype = jnp.float32 82 | 83 | @nn.compact 84 | def __call__(self, hidden_states, deterministic: bool = True): 85 | Dense = partial( 86 | nn.Dense, 87 | use_bias=False, 88 | dtype=self.dtype, 89 | kernel_init=jax.nn.initializers.normal(0.02), 90 | ) 91 | 92 | hidden_states = nn.LayerNorm( 93 | epsilon=1e-5, dtype=self.dtype, name="layernorm_0" 94 | )(hidden_states) 95 | 96 | hidden_gelu = Dense(features=self.dim1, name="fc1")(hidden_states) 97 | hidden_gelu = ACT2FN[self.activation](hidden_gelu) 98 | 99 | hidden_linear = Dense(features=self.dim1, name="fc2")(hidden_states) 100 | 101 | hidden_states = hidden_gelu * hidden_linear 102 | 103 | # suggestion from Katherine Crowson 104 | ndims = len(hidden_states.shape[1:-1]) 105 | assert ndims in [2, 3] 106 | if self.space_only_conv: 107 | kernel = (3, 3) if ndims == 2 else (1, 3, 3) 108 | else: 109 | kernel = (3,) * ndims 110 | hidden_states = nn.Conv( 111 | self.dim1, 112 | kernel_size=kernel, 113 | strides=(1,) * ndims, 114 | padding="SAME", 115 | feature_group_count=self.dim1, 116 | use_bias=False, 117 | dtype=self.dtype, 118 | kernel_init=jax.nn.initializers.normal(0.02), 119 | name="mid_ffn_conv", 120 | )(hidden_states) 121 | 
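The `mid_ffn_conv` above is a depthwise convolution: `feature_group_count` equals the channel count, so each channel is filtered by its own kernel. A minimal, standalone sketch of that pattern (hypothetical sizes, separate from the module above):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

# Depthwise 3x3 conv: feature_group_count == channels, one filter per channel.
conv = nn.Conv(
    features=64,
    kernel_size=(3, 3),
    padding="SAME",
    feature_group_count=64,
    use_bias=False,
)
x = jnp.ones((1, 8, 8, 64))
params = conv.init(jax.random.PRNGKey(0), x)
y = conv.apply(params, x)  # (1, 8, 8, 64), channels never mix
```

122 | hidden_states = nn.LayerNorm( 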
epsilon=1e-5, dtype=self.dtype, name="layernorm_1" 124 | )(hidden_states) 125 | 126 | hidden_states = nn.Dropout(rate=self.dropout)( 127 | hidden_states, deterministic=deterministic 128 | ) 129 | hidden_states = Dense(features=self.dim2, name="fc_out")(hidden_states) 130 | hidden_states = nn.Dropout(rate=self.dropout)( 131 | hidden_states, deterministic=deterministic 132 | ) 133 | return hidden_states 134 | 135 | 136 | class Attention(nn.Module): 137 | hidden_size: int 138 | num_heads: int 139 | dropout: float 140 | dtype: jnp.dtype = jnp.float32 141 | 142 | def _split_heads(self, hidden_states): 143 | head_dim = self.hidden_size // self.num_heads 144 | return hidden_states.reshape( 145 | hidden_states.shape[:2] + (self.num_heads, head_dim) 146 | ) 147 | 148 | def _merge_heads(self, hidden_states): 149 | return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) 150 | 151 | @nn.compact 152 | def __call__( 153 | self, 154 | hidden_states, 155 | deterministic: bool = True, 156 | mask: Optional[jnp.ndarray] = None, 157 | ): 158 | Dense = partial( 159 | nn.Dense, 160 | self.hidden_size, 161 | use_bias=False, 162 | dtype=self.dtype, 163 | kernel_init=jax.nn.initializers.normal(0.02), 164 | ) 165 | hidden_states, shape = patch_flatten(hidden_states) 166 | 167 | query = Dense()(hidden_states) 168 | key = Dense()(hidden_states) 169 | value = Dense()(hidden_states) 170 | 171 | query = self._split_heads(query) 172 | key = self._split_heads(key) 173 | value = self._split_heads(value) 174 | 175 | dropout_rng = None 176 | if not deterministic and self.dropout > 0.0: 177 | dropout_rng = self.make_rng("dropout") 178 | 179 | attn_weights = dot_product_attention_weights( 180 | query, 181 | key, 182 | bias=None, 183 | dropout_rng=dropout_rng, 184 | dropout_rate=self.dropout, 185 | deterministic=deterministic, 186 | dtype=self.dtype, 187 | precision=None, 188 | mask=mask, 189 | ) 190 | 191 | attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) 192 | attn_output = self._merge_heads(attn_output) 193 | attn_output = Dense()(attn_output) 194 | attn_output = patch_unflatten(attn_output, shape) 195 | return attn_output 196 | 197 | 198 | class TransformerBlock(nn.Module): 199 | hidden_size: int 200 | intermediate_size: int 201 | num_heads: int 202 | dropout: float 203 | deterministic: bool 204 | space_only_conv: bool = False 205 | AttentionModule: Any = Attention 206 | dtype: jnp.dtype = jnp.float32 207 | 208 | @nn.compact 209 | def __call__(self, hidden_states, mask: Optional[jnp.ndarray] = None): 210 | deterministic = self.deterministic 211 | residual = hidden_states 212 | 213 | hidden_states = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(hidden_states) 214 | hidden_states = self.AttentionModule( 215 | hidden_size=self.hidden_size, 216 | num_heads=self.num_heads, 217 | dropout=self.dropout, 218 | dtype=self.dtype, 219 | )(hidden_states=hidden_states, deterministic=deterministic, mask=mask) 220 | # normformer 221 | hidden_states = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(hidden_states) 222 | hidden_states = residual + hidden_states 223 | 224 | residual = hidden_states 225 | hidden_states = GLU( 226 | dim1=self.intermediate_size, 227 | dim2=self.hidden_size, 228 | activation="gelu", 229 | dropout=self.dropout, 230 | space_only_conv=self.space_only_conv, 231 | dtype=self.dtype, 232 | )(hidden_states, deterministic=deterministic) 233 | hidden_states = residual + hidden_states 234 | 235 | return hidden_states 236 | 237 | 238 | class Transformer(nn.Module): 239 | num_layers: int 
240 | hidden_size: int 241 | intermediate_size: int 242 | num_heads: int 243 | dropout: float 244 | dtype: jnp.dtype = jnp.float32 245 | 246 | @nn.compact 247 | def __call__( 248 | self, 249 | hidden_states, 250 | deterministic: bool = True, 251 | mask: Optional[jnp.ndarray] = None, 252 | ): 253 | space_only_conv = mask is not None 254 | layer_outputs = [] 255 | for i in range(self.num_layers): 256 | AttentionModule = Attention 257 | assert mask is None 258 | mask_in = mask 259 | 260 | hidden_states = nn.remat(TransformerBlock)( 261 | hidden_size=self.hidden_size, 262 | intermediate_size=self.intermediate_size, 263 | num_heads=self.num_heads, 264 | dropout=self.dropout, 265 | name=str(i), 266 | deterministic=deterministic, 267 | space_only_conv=space_only_conv, 268 | AttentionModule=AttentionModule, 269 | dtype=self.dtype, 270 | )(hidden_states, mask_in) 271 | layer_outputs.append(hidden_states) 272 | return hidden_states, layer_outputs 273 | 274 | 275 | class VitEncoder(nn.Module): 276 | config: ConfigDict 277 | causal_mask: bool = False 278 | dtype: jnp.dtype = jnp.float32 279 | 280 | @nn.compact 281 | def __call__( 282 | self, 283 | pixel_values, 284 | deterministic: bool = True, 285 | mask: Optional[jnp.ndarray] = None, 286 | ): 287 | hidden_states = ConvPatches( 288 | patch_size=self.config.patch_size, 289 | hidden_size=self.config.hidden_size, 290 | dtype=self.dtype, 291 | )(pixel_values) 292 | 293 | position_embeddings = self.param( 294 | "pos_embedding", 295 | jax.nn.initializers.normal(0.02, dtype=jnp.float32), 296 | (1, *hidden_states.shape[1:-1], self.config.hidden_size), 297 | ) 298 | hidden_states += position_embeddings 299 | hidden_states = nn.Dropout(rate=self.config.dropout)( 300 | hidden_states, deterministic=deterministic 301 | ) 302 | 303 | if self.causal_mask: 304 | assert len(hidden_states.shape[1:-1]) == 3, hidden_states.shape 305 | T = hidden_states.shape[1] 306 | mask = jnp.tril(jnp.ones((T, T), dtype=bool)) 307 | else: 308 | mask = None 309 | 310 | hidden_states, layer_outputs = Transformer( 311 | num_layers=self.config.enc_num_layers, 312 | hidden_size=self.config.hidden_size, 313 | intermediate_size=self.config.intermediate_size, 314 | num_heads=self.config.enc_num_heads, 315 | dropout=self.config.dropout, 316 | dtype=self.dtype, 317 | )(hidden_states, deterministic=deterministic, mask=mask) 318 | return hidden_states, layer_outputs 319 | 320 | 321 | class VitDecoder(nn.Module): 322 | config: ConfigDict 323 | causal_mask: bool = False 324 | dtype: jnp.dtype = jnp.float32 325 | 326 | @nn.compact 327 | def __call__( 328 | self, 329 | hidden_states, 330 | deterministic: bool = True, 331 | mask: Optional[jnp.ndarray] = None, 332 | ): 333 | position_embeddings = self.param( 334 | "pos_embedding", 335 | jax.nn.initializers.normal(0.02, dtype=jnp.float32), 336 | (1, *hidden_states.shape[1:-1], self.config.hidden_size), 337 | ) 338 | hidden_states += position_embeddings 339 | 340 | if self.causal_mask: 341 | assert len(hidden_states.shape[1:-1]) == 3, hidden_states.shape 342 | T = hidden_states.shape[1] 343 | mask = jnp.tril(jnp.ones((T, T), dtype=bool)) 344 | else: 345 | mask = None 346 | 347 | hidden_states, layer_outputs = Transformer( 348 | num_layers=self.config.dec_num_layers, 349 | hidden_size=self.config.hidden_size, 350 | intermediate_size=self.config.intermediate_size, 351 | num_heads=self.config.dec_num_heads, 352 | dropout=self.config.dropout, 353 | dtype=self.dtype, 354 | )(hidden_states, deterministic=deterministic, mask=mask) 355 | 356 | images = 
nn.ConvTranspose( 357 | 3, 358 | kernel_size=[self.config.patch_size, self.config.patch_size], 359 | strides=[self.config.patch_size, self.config.patch_size], 360 | padding="VALID", 361 | dtype=self.dtype, 362 | kernel_init=jax.nn.initializers.normal(0.02), 363 | )(hidden_states) 364 | return images 365 | -------------------------------------------------------------------------------- /lqae/main/linear_main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import os 3 | import pprint 4 | from copy import copy, deepcopy 5 | from functools import partial 6 | from typing import Any, Callable, Optional 7 | 8 | import absl.app 9 | import absl.flags 10 | import flax 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | import torch 16 | from flax import linen as nn 17 | from flax.jax_utils import prefetch_to_device 18 | from flax.training import train_state 19 | from tqdm.auto import tqdm, trange 20 | 21 | from ..data import ImageNetDataset 22 | from ..jax_utils import ( 23 | JaxRNG, 24 | cross_entropy_loss, 25 | get_metrics, 26 | next_rng, 27 | sync_state_across_devices, 28 | mixup_cutmix, 29 | ) 30 | from ..models import LQAE 31 | from ..utils import ( 32 | WandBLogger, 33 | define_flags_with_default, 34 | get_user_flags, 35 | load_checkpoint, 36 | load_pickle, 37 | set_random_seed, 38 | ) 39 | 40 | 41 | FLAGS_DEF = define_flags_with_default( 42 | seed=42, 43 | epochs=200, 44 | batch_size=0, 45 | dataloader_n_workers=0, 46 | dataloader_shuffle=False, 47 | log_freq=50, 48 | save_model_freq=0, 49 | lr_init_value=0.0, 50 | lr_end_value=0.0, 51 | lr_peak_value=1.0e-4, 52 | lr_warmup_epochs=0, 53 | momentum=0.9, 54 | weight_decay=0.0001, 55 | clip_gradient=1e9, 56 | load_pretrained="", 57 | load_checkpoint="", 58 | last_embedding_layers="all", 59 | imagenet_data=ImageNetDataset.get_default_config(), 60 | lqae=LQAE.get_default_config(), 61 | logging=WandBLogger.get_default_config(), 62 | log_all_worker=False, 63 | load_bert_params=True, 64 | ) 65 | FLAGS = absl.flags.FLAGS 66 | 67 | 68 | def main(argv): 69 | FLAGS = absl.flags.FLAGS 70 | variant = get_user_flags(FLAGS, FLAGS_DEF) 71 | 72 | variant["jax_process_index"] = jax_process_index = jax.process_index() 73 | variant["jax_process_count"] = jax_process_count = jax.process_count() 74 | assert FLAGS.batch_size % jax_process_count == 0 75 | variant["process_batch_size"] = process_batch_size = ( 76 | FLAGS.batch_size // jax_process_count 77 | ) 78 | variant["device_batch_size"] = process_batch_size // jax.local_device_count() 79 | lr_scale = FLAGS.batch_size / 256 80 | variant["effective_lr"] = FLAGS.lr_peak_value * lr_scale 81 | jax_devices = jax.local_devices() 82 | n_devices = len(jax_devices) 83 | assert process_batch_size % n_devices == 0 84 | 85 | logger = WandBLogger( 86 | config=FLAGS.logging, 87 | variant=variant, 88 | enable=FLAGS.log_all_worker or (jax_process_index == 0), 89 | ) 90 | set_random_seed(FLAGS.seed * (jax_process_index + 1)) 91 | 92 | train_dataset = ImageNetDataset( 93 | FLAGS.imagenet_data, jax_process_index / jax_process_count 94 | ) 95 | val_flags = deepcopy(FLAGS.imagenet_data) 96 | val_flags.partition = "val" 97 | val_flags.transform_type = "test" 98 | val_dataset = ImageNetDataset(val_flags, jax_process_index / jax_process_count) 99 | 100 | steps_per_epoch = int(len(train_dataset) / FLAGS.batch_size) 101 | total_steps = steps_per_epoch * FLAGS.epochs 102 | val_steps = int(len(val_dataset) / FLAGS.batch_size) 103 | 104 | 
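As a quick sanity check on the sharding arithmetic behind these loaders, here is a hedged, standalone sketch (the process and device counts are made-up examples, not the repo's configuration):

```python
# Global batch is split per process (host), then per local device.
batch_size = 512          # global batch (matches the README example)
process_count = 2         # hypothetical number of hosts
local_device_count = 8    # hypothetical devices per host

process_batch_size = batch_size // process_count               # rows loaded per host
device_batch_size = process_batch_size // local_device_count   # rows per device
assert batch_size % process_count == 0
assert process_batch_size % local_device_count == 0
print(process_batch_size, device_batch_size)  # 256 32
```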
train_loader = torch.utils.data.DataLoader( 105 | train_dataset, 106 | batch_size=process_batch_size, 107 | shuffle=FLAGS.dataloader_shuffle, 108 | num_workers=FLAGS.dataloader_n_workers, 109 | prefetch_factor=2, 110 | persistent_workers=FLAGS.dataloader_n_workers > 0, 111 | drop_last=True, 112 | ) 113 | 114 | val_loader = torch.utils.data.DataLoader( 115 | val_dataset, 116 | batch_size=process_batch_size, 117 | shuffle=False, 118 | num_workers=FLAGS.dataloader_n_workers, 119 | prefetch_factor=2, 120 | persistent_workers=FLAGS.dataloader_n_workers > 0, 121 | drop_last=True, 122 | ) 123 | 124 | def combine_representation(output): 125 | encoder_embedding = [x for x in output["encoder_embedding"]] 126 | # Stop gradient on encoder_embedding in Linear 127 | encoder_embedding = jax.lax.stop_gradient(encoder_embedding) 128 | bert_embedding = [x for x in output["bert_embedding"]] 129 | # Stop gradient on bert_embedding in Linear 130 | bert_embedding = jax.lax.stop_gradient(bert_embedding) 131 | all_embedding = encoder_embedding + bert_embedding 132 | if FLAGS.last_embedding_layers == "all": 133 | representation = all_embedding 134 | elif FLAGS.last_embedding_layers == "all_bert": 135 | representation = bert_embedding 136 | elif FLAGS.last_embedding_layers == "all_encoder": 137 | representation = encoder_embedding 138 | else: 139 | representation = [ 140 | all_embedding[-int(x)] for x in FLAGS.last_embedding_layers.split(",") 141 | ] 142 | representation = jax.tree_util.tree_map( 143 | lambda x: jnp.mean(x[:, 1:, :], axis=1), representation 144 | ) 145 | representation = jnp.concatenate(representation, axis=-1) 146 | return representation 147 | 148 | class FinetuneCLS(nn.Module): 149 | backbone: nn.Module 150 | num_classes: int 151 | 152 | @nn.nowrap 153 | def rng_keys(self): 154 | return ('params', 'noise', 'drop_path') 155 | 156 | @nn.compact 157 | def __call__(self, x, deterministic=False): 158 | output = self.backbone.forward_image_representation(x, deterministic) 159 | x = combine_representation(output) 160 | x = nn.LayerNorm()(x) 161 | x = nn.Dense(self.num_classes)(x) 162 | logits = x 163 | return logits 164 | 165 | backbone = LQAE(FLAGS.lqae) 166 | model = FinetuneCLS( 167 | backbone=backbone, 168 | num_classes=train_dataset.num_classes(), 169 | ) 170 | 171 | learning_rate = optax.warmup_cosine_decay_schedule( 172 | init_value=FLAGS.lr_init_value * lr_scale, 173 | peak_value=FLAGS.lr_peak_value * lr_scale, 174 | warmup_steps=FLAGS.lr_warmup_epochs * steps_per_epoch, 175 | decay_steps=total_steps, 176 | end_value=FLAGS.lr_end_value * lr_scale, 177 | ) 178 | 179 | @partial(jax.pmap, axis_name="pmap", donate_argnums=[0]) 180 | def train_step_fn(state, rng, image, label): 181 | rng_generator = JaxRNG(rng) 182 | def loss_fn(params): 183 | logits = model.apply( 184 | params, 185 | image, 186 | deterministic=False, 187 | rngs=rng_generator(keys=backbone.rng_keys()), 188 | ) 189 | loss = cross_entropy_loss(logits, label) 190 | accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == label) 191 | aux = dict(loss=loss, accuracy=accuracy) 192 | return loss, aux 193 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 194 | (loss, aux), grads = jax.lax.pmean( 195 | grad_fn(state.params), 196 | axis_name="pmap", 197 | ) 198 | aux["learning_rate"] = learning_rate(state.step) 199 | state = state.apply_gradients(grads=grads) 200 | return state, aux, rng_generator() 201 | 202 | @partial(jax.pmap, axis_name="pmap") 203 | def eval_step_fn(state, rng, image, label): 204 | rng_generator = JaxRNG(rng) 205 | logits = model.apply( 206 | state.params, 207 | image, 208 | deterministic=True, 209 | rngs=rng_generator(keys=backbone.rng_keys()), 210 | ) 211 | accuracy = jax.lax.pmean( 212 | jnp.mean(jnp.argmax(logits, axis=-1) == label), axis_name="pmap" 213 | ) 214 | aux = dict(accuracy=accuracy) 215 | aux = jax.lax.pmean(aux, axis_name="pmap") 216 | return aux, rng_generator() 217 | 218 | if FLAGS.load_checkpoint != "": 219 | checkpoint_data = load_pickle(FLAGS.load_checkpoint) 220 | state = flax.jax_utils.replicate(checkpoint_data["state"], jax_devices) 221 | start_step = checkpoint_data["step"] 222 | else: 223 | image = jnp.zeros((2, 256, 256, 3), dtype=jnp.float32) 224 | rngs = next_rng(keys=backbone.rng_keys()) 225 | params = model.init(rngs, image) 226 | 227 | if FLAGS.load_pretrained != "": 228 | checkpoint_data = load_checkpoint(FLAGS.load_pretrained) 229 | checkpoint_params = checkpoint_data["state"].params["params"] 230 | checkpoint_params = flax.core.unfreeze(checkpoint_params) 231 | backbone_params = flax.core.unfreeze(params['params']["backbone"]) 232 | for key in backbone_params.keys(): 233 | if key in ["decoder", "lang_model"]: 234 | continue 235 | else: 236 | assert ( 237 | key in checkpoint_params.keys() 238 | ), f"pretrained model missing key={key}" 239 | backbone_params[key] = checkpoint_params[key] 240 | params = flax.core.unfreeze(params) 241 | params["params"].update({"backbone": backbone_params}) 242 | params = flax.core.freeze(params) 243 | 244 | # make sure BERT pretrained code is loaded 245 | if FLAGS.load_bert_params: 246 | params = backbone.load_bert_params(params, False) 247 | 248 | state = train_state.TrainState.create( 249 | params=flax.core.unfreeze(params), 250 | apply_fn=None, 251 | tx=optax.lars( 252 | learning_rate=learning_rate, 253 | weight_decay=FLAGS.weight_decay, 254 | momentum=FLAGS.momentum, 255 | ), 256 | ) 257 | state = flax.jax_utils.replicate(state, jax_devices) 258 | start_step = 0 259 | 260 | del params 261 | 262 | state = sync_state_across_devices(state) 263 | sharded_rng = jax.device_put_sharded(next_rng(n_devices), jax_devices) 264 | 265 | def generate_batch(iterator): 266 | while True: 267 | for batch in iterator: 268 | imgs = batch[0].numpy() 269 | imgs = imgs.reshape(n_devices, -1, *imgs.shape[1:]) 270 | labels = batch[1].numpy() 271 | labels = labels.reshape(n_devices, -1, *labels.shape[1:]) 272 | yield tuple([imgs, labels]) 273 | 274 | train_iterator = prefetch_to_device(generate_batch(train_loader), 2, jax_devices) 275 | val_iterator = prefetch_to_device(generate_batch(val_loader), 2, jax_devices) 276 | 277 | best_val_acc = 0.0 278 | step_counter = trange(start_step, total_steps, desc="train", ncols=0) 279 | 280 | for step, (image, label) in zip(step_counter, train_iterator): 281 | epoch = step // steps_per_epoch 282 | if step % steps_per_epoch == 0: 283 | train_metrics = [] 284 | 285 | image = image.astype(jnp.float32) 286 | label = label.astype(jnp.int32) 287 | 288 | state, metrics, sharded_rng = train_step_fn(state, sharded_rng, image, label) 289 | train_metrics.append(metrics) 290 | 291 | if step % FLAGS.log_freq == 0: 292 | log_metrics = get_metrics(train_metrics, unreplicate=True, stack=True) 293 | log_metrics = { 294 | f"train_{k}": v 295 | for k, v in jax.tree_map(lambda x: x.mean(), log_metrics).items() 296 | } 297 | log_metrics.update({"step": step, "epoch": epoch}) 298 | logger.log(log_metrics) 299 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 300 | 
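The optimizer wiring above trains only a LayerNorm + Dense head with LARS on top of frozen features. A compact, runnable sketch of the same optimizer setup on a toy linear probe (toy dimensions and hyperparameters, not the repo's model):

```python
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

# Tiny linear probe trained with LARS, mirroring the TrainState/optax.lars setup.
params = {
    "w": jax.random.normal(jax.random.PRNGKey(0), (16, 10)) * 0.01,
    "b": jnp.zeros((10,)),
}
tx = optax.lars(learning_rate=0.1, weight_decay=1e-4, momentum=0.9)
state = train_state.TrainState.create(apply_fn=None, params=params, tx=tx)

def loss_fn(p, x, y):
    logits = x @ p["w"] + p["b"]
    return optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 10)).mean()

x = jnp.ones((4, 16))
y = jnp.array([0, 1, 2, 3])
grads = jax.grad(loss_fn)(state.params, x, y)
state = state.apply_gradients(grads=grads)  # one LARS update
```

301 | if FLAGS.save_model_freq > 0 and step % FLAGS.save_model_freq 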
== 0: 302 | save_data = { 303 | "step": step, 304 | "epoch": epoch, 305 | "variant": variant, 306 | "state": jax.device_get(flax.jax_utils.unreplicate(state)), 307 | "best_val_acc": best_val_acc, 308 | } 309 | if jax_process_index == 0: 310 | logger.save_pickle(save_data, "model.pkl") 311 | 312 | if step % steps_per_epoch == 0: 313 | val_metrics = [] 314 | for _, (image, label) in zip( 315 | trange(val_steps, desc="val", ncols=0), val_iterator 316 | ): 317 | image = image.astype(jnp.float32) 318 | label = label.astype(jnp.int32) 319 | 320 | metrics, sharded_rng = eval_step_fn(state, sharded_rng, image, label) 321 | val_metrics.append(metrics) 322 | 323 | log_metrics = get_metrics(val_metrics, unreplicate=True, stack=True) 324 | accuracy = log_metrics["accuracy"].mean() 325 | log_metrics = { 326 | f"val_{k}": v 327 | for k, v in jax.tree_map(lambda x: x.mean(), log_metrics).items() 328 | } 329 | log_metrics.update({"step": step, "epoch": epoch}) 330 | logger.log(log_metrics) 331 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 332 | 333 | if accuracy > best_val_acc: 334 | best_val_acc = accuracy 335 | 336 | if FLAGS.save_model_freq > 0: 337 | save_data = { 338 | "epoch": epoch, 339 | "step": step, 340 | "variant": variant, 341 | "state": jax.device_get(flax.jax_utils.unreplicate(state)), 342 | "best_val_acc": best_val_acc, 343 | } 344 | if jax_process_index == 0: 345 | logger.save_pickle(save_data, "best_model.pkl") 346 | 347 | 348 | if __name__ == "__main__": 349 | torch.multiprocessing.set_start_method("spawn") 350 | absl.app.run(main) 351 | -------------------------------------------------------------------------------- /lqae/models/vqae.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import io 3 | import math 4 | import os 5 | from typing import Any 6 | 7 | import flax 8 | import flax.linen as nn 9 | import jax 10 | import jax.numpy as jnp 11 | import ml_collections 12 | import numpy as np 13 | import requests 14 | from flax.linen.initializers import variance_scaling 15 | from ml_collections import ConfigDict 16 | from ml_collections.config_dict import config_dict 17 | from PIL import Image, ImageFilter 18 | 19 | from .base_resnet import ResNetDecoder, ResNetEncoder 20 | from .base_vit import VitDecoder, VitEncoder 21 | from .model_utils import ( 22 | assert_avg_rnd, 23 | update_vit_config, 24 | entropy_loss_fn, 25 | normalize_func, 26 | squared_euclidean_distance, 27 | ) 28 | 29 | 30 | class VectorQuantizer(nn.Module): 31 | """Basic vector quantizer.""" 32 | 33 | config: ConfigDict 34 | dtype: int = jnp.float32 35 | 36 | def setup(self): 37 | if self.config.quantizer_latent_dim > 0: 38 | self.input_to_latent = nn.Dense( 39 | self.config.quantizer_latent_dim, dtype=self.dtype 40 | ) 41 | self.code_to_latent = nn.Dense( 42 | self.config.quantizer_latent_dim, dtype=self.dtype 43 | ) 44 | else: 45 | self.input_to_latent = self.code_to_latent = lambda x: x 46 | 47 | @nn.compact 48 | def __call__(self, x, train, rng): 49 | l2_normalize = lambda x, axis=1: normalize_func( 50 | x, axis=axis, use_l2_normalize=self.config.l2_normalize 51 | ) 52 | codebook_size = self.config.codebook_size 53 | embed_init = variance_scaling(1.0, "fan_in", "normal", out_axis=0) 54 | codebook = self.param( 55 | "codebook", 56 | embed_init, 57 | (codebook_size, x.shape[-1]), 58 | ) 59 | if self.config.strawman_codebook: 60 | strawman_codebook = self.param( 61 | "strawman_codebook", 62 | jax.nn.initializers.normal(0.02, dtype=jnp.float32), 63 | 
(codebook_size, self.config.quantizer_latent_dim), 64 | ) 65 | strawman_codebook = jnp.asarray(strawman_codebook, dtype=self.dtype) 66 | latent_input = self.input_to_latent(jnp.reshape(x, (-1, x.shape[-1]))) 67 | latent_input = l2_normalize(latent_input, axis=1) 68 | sg_strawman_codebook = jax.lax.stop_gradient( 69 | l2_normalize(strawman_codebook, axis=1) 70 | ) 71 | distances = jnp.reshape( 72 | squared_euclidean_distance(latent_input, sg_strawman_codebook), 73 | x.shape[:-1] + (codebook_size,), 74 | ) 75 | else: 76 | codebook = jnp.asarray(codebook, dtype=self.dtype) 77 | latent_input = self.input_to_latent(jnp.reshape(x, (-1, x.shape[-1]))) 78 | latent_input = l2_normalize(latent_input, axis=1) 79 | latent_codebook = self.code_to_latent(codebook) 80 | latent_codebook = l2_normalize(latent_codebook, axis=1) 81 | sg_latent_codebook = jax.lax.stop_gradient( 82 | l2_normalize(latent_codebook, axis=1) 83 | ) 84 | distances = jnp.reshape( 85 | squared_euclidean_distance(latent_input, sg_latent_codebook), 86 | x.shape[:-1] + (codebook_size,), 87 | ) 88 | 89 | encoding_indices = jax.lax.approx_min_k( 90 | distances, 91 | k=self.config.top_k_value, 92 | reduction_dimension=-1, 93 | aggregate_to_topk=True, 94 | )[1] 95 | 96 | encoding_indices, encodings, quantized = self.get_encoding_quantized( 97 | encoding_indices, train, rng, codebook_size 98 | ) 99 | 100 | codebook_usage = jnp.sum(encodings, axis=(0, 1)) > 0 101 | codebook_usage = jnp.sum(codebook_usage) / codebook_size 102 | if self.config.top_k_avg: 103 | codebook_usage = codebook_usage / self.config.top_k_value 104 | result_dict = dict() 105 | if train: 106 | result_dict = self.get_train_loss(quantized, x, distances) 107 | 108 | if self.config.strawman_codebook: 109 | strawman_quantized = self.quantize_strawman(encodings) 110 | strawman_result_dict = self.get_train_loss( 111 | strawman_quantized, self.input_to_latent(x), distances 112 | ) 113 | for k, v in result_dict.items(): 114 | result_dict[k] = v + strawman_result_dict[k] 115 | else: 116 | latent_quantized = self.code_to_latent(quantized) 117 | latent_result_dict = self.get_train_loss( 118 | latent_quantized, self.input_to_latent(x), distances 119 | ) 120 | for k, v in result_dict.items(): 121 | result_dict[k] = v + latent_result_dict[k] 122 | 123 | quantized = x + jax.lax.stop_gradient(quantized - x) 124 | 125 | avg_probs = jnp.mean(encodings.reshape(-1, encodings.shape[-1]), axis=0) 126 | log_perplexity = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)) 127 | perplexity = jnp.exp(log_perplexity) 128 | 129 | if "quantizer_loss" in result_dict: 130 | result_dict["quantizer_loss"] = ( 131 | result_dict["quantizer_loss"] 132 | + self.config.quantizer_loss_perplexity * log_perplexity 133 | ) 134 | result_dict.update( 135 | { 136 | "encodings": encodings, 137 | "encoding_indices": encoding_indices, 138 | "raw": x, 139 | "perplexity": perplexity, 140 | "codebook_usage": codebook_usage, 141 | } 142 | ) 143 | return quantized, result_dict 144 | 145 | def quantize(self, z: jnp.ndarray) -> jnp.ndarray: 146 | return jnp.dot(z, self.variables["params"]["codebook"]) 147 | 148 | def get_codebook(self) -> jnp.ndarray: 149 | return self.variables["params"]["codebook"] 150 | 151 | def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray: 152 | return jnp.take(self.variables["params"]["codebook"], ids, axis=0) 153 | 154 | def quantize_strawman(self, z: jnp.ndarray) -> jnp.ndarray: 155 | return jnp.dot(z, self.variables["params"]["strawman_codebook"]) 156 | 157 | def get_train_loss(self, quantized, x, 
distances): 158 | e_latent_loss = ( 159 | jnp.mean((jax.lax.stop_gradient(quantized) - x) ** 2) 160 | * self.config.quantizer_loss_commitment 161 | ) 162 | q_latent_loss = jnp.mean((quantized - jax.lax.stop_gradient(x)) ** 2) 163 | entropy_loss = 0.0 164 | if self.config.quantizer_loss_entropy != 0: 165 | entropy_loss = ( 166 | entropy_loss_fn( 167 | -distances, 168 | loss_type=self.config.entropy_loss_type, 169 | temperature=self.config.entropy_temperature, 170 | ) 171 | * self.config.quantizer_loss_entropy 172 | ) 173 | e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32) 174 | q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32) 175 | entropy_loss = jnp.asarray(entropy_loss, jnp.float32) 176 | loss = e_latent_loss + q_latent_loss + entropy_loss 177 | 178 | result_dict = dict( 179 | quantizer_loss=loss, 180 | e_latent_loss=e_latent_loss, 181 | q_latent_loss=q_latent_loss, 182 | entropy_loss=entropy_loss, 183 | ) 184 | return result_dict 185 | 186 | def get_encoding_quantized(self, encoding_indices, train, rng, codebook_size): 187 | if self.config.top_k_rnd: 188 | if train: 189 | encoding_indices = jax.random.choice(rng, encoding_indices, axis=-1) 190 | else: 191 | encoding_indices = encoding_indices[..., 0] 192 | encodings = jax.nn.one_hot( 193 | encoding_indices, codebook_size, dtype=self.dtype 194 | ) 195 | quantized = self.quantize(encodings) 196 | elif self.config.top_k_avg: 197 | encodings = jax.nn.one_hot( 198 | encoding_indices, codebook_size, dtype=self.dtype 199 | ) 200 | quantized = self.quantize(encodings) 201 | quantized = jnp.mean(quantized, axis=-2) 202 | encoding_indices = encoding_indices[..., 0] 203 | else: 204 | encoding_indices = encoding_indices[..., 0] 205 | encodings = jax.nn.one_hot( 206 | encoding_indices, codebook_size, dtype=self.dtype 207 | ) 208 | quantized = self.quantize(encodings) 209 | return encoding_indices, encodings, quantized 210 | 211 | 212 | class VQAE(nn.Module): 213 | config_updates: ... 
= None 214 | dtype: int = jnp.float32 215 | 216 | @staticmethod 217 | @nn.nowrap 218 | def get_default_config(updates=None): 219 | config = ConfigDict() 220 | 221 | # Quantizer config 222 | config.quantizer_loss_entropy = 0.0 223 | config.entropy_temperature = 0.01 224 | config.entropy_loss_type = "softmax" 225 | config.quantizer_loss_commitment = 0.25 226 | config.l2_normalize = False 227 | config.top_k_value = 1 228 | config.top_k_avg = False 229 | config.top_k_rnd = False 230 | config.quantizer_latent_dim = 0 231 | config.strawman_codebook = False 232 | config.quantizer_loss_perplexity = 0.0 233 | # VQ quantizer config 234 | config.codebook_size = 1024 235 | 236 | # ResNet config 237 | config.filters = 128 238 | config.num_res_blocks = 2 239 | config.channel_multipliers = [1, 1, 2, 2, 4] 240 | config.hidden_size = 768 241 | config.conv_downsample = False 242 | 243 | # VIT config 244 | config.vit_encoder_decoder = False 245 | config.vit_model_type = config_dict.placeholder(str) 246 | config.patch_size = 16 247 | config.dropout = 0.0 248 | config.hidden_size = 768 249 | config.mlp_ratio = 4 250 | config.intermediate_size = config.hidden_size * config.mlp_ratio 251 | 252 | # Bert config 253 | config.bert = "roberta-base" 254 | config.bert_min_ratio = 0.15 255 | config.bert_max_ratio = 0.15 256 | config.use_bert_codebook = True 257 | config.bert_loss_mask_only = True 258 | config.bert_mask_loss_weight = 0.0 259 | config.bert_channel_image_loss_weight = 0.0 260 | config.nochannel_image_loss_weight = 0.0 261 | 262 | if updates is not None: 263 | config.update(ConfigDict(updates).copy_and_resolve_references()) 264 | 265 | if config.vit_model_type is not None: 266 | update_vit_config(config.vit_model_type, config) 267 | assert_avg_rnd(config) 268 | 269 | return config 270 | 271 | @nn.nowrap 272 | def get_config(self): 273 | return self.get_default_config(self.config_updates) 274 | 275 | @nn.nowrap 276 | def rng_keys(self): 277 | return ("params", "dropout", "drop_path", "quantizer") 278 | 279 | @nn.nowrap 280 | def no_decay_list(self): 281 | no_decay = ["bias", "embedding"] 282 | return no_decay 283 | 284 | def setup(self): 285 | self.config = self.get_default_config(self.config_updates) 286 | self.quantizer = VectorQuantizer(config=self.config, dtype=self.dtype) 287 | if self.config.vit_encoder_decoder: 288 | self.encoder = VitEncoder(config=self.config, dtype=self.dtype) 289 | self.decoder = VitDecoder(config=self.config, dtype=self.dtype) 290 | else: 291 | self.encoder = ResNetEncoder(config=self.config, dtype=self.dtype) 292 | self.decoder = ResNetDecoder(config=self.config, dtype=self.dtype) 293 | 294 | def encode(self, image, train): 295 | encoded_feature, _ = self.encoder(image, train) 296 | quantized, result_dict = self.quantizer( 297 | encoded_feature, train, self.make_rng("quantizer") 298 | ) 299 | return quantized, result_dict 300 | 301 | def forward_image_representation(self, image, train): 302 | output = {} 303 | encoded_feature, encoder_embedding = self.encoder(image, train) 304 | if encoder_embedding is not None: 305 | encoder_embedding = jax.tree_util.tree_map( 306 | lambda x: jnp.reshape(x, (x.shape[0], -1, x.shape[-1])), 307 | encoder_embedding, 308 | ) 309 | output["encoder_embedding"] = encoder_embedding 310 | else: 311 | encoded_feature = jax.tree_util.tree_map( 312 | lambda x: jnp.reshape(x, (x.shape[0], -1, x.shape[-1])), encoded_feature 313 | ) 314 | output["encoder_embedding"] = [encoded_feature] 315 | all_embedding = [x for x in output["encoder_embedding"]] 316 | 
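The quantizer earlier in this file relies on the straight-through estimator, `x + stop_gradient(quantized - x)`: the forward pass uses the nearest code while gradients flow to the encoder output as if quantization were the identity. A minimal, self-contained sketch of that trick (toy shapes, independent of the class above):

```python
import jax
import jax.numpy as jnp

def quantize_ste(x, codebook):
    # Nearest-code lookup; gradients pass straight through to x.
    d = ((x[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    idx = jnp.argmin(d, axis=-1)
    quantized = codebook[idx]
    return x + jax.lax.stop_gradient(quantized - x), idx

codebook = jax.random.normal(jax.random.PRNGKey(0), (32, 8))
x = jax.random.normal(jax.random.PRNGKey(1), (4, 8))
q, idx = quantize_ste(x, codebook)
print(q.shape, idx.shape)  # (4, 8) (4,)
```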
output["all_embedding"] = all_embedding 317 | return output 318 | 319 | def decode(self, x: jnp.ndarray, train) -> jnp.ndarray: 320 | reconstructed = self.decoder(x, train) 321 | return reconstructed 322 | 323 | def get_codebook_funct(self): 324 | return self.quantizer.get_codebook() 325 | 326 | def decode_from_indices(self, inputs, train): 327 | if isinstance(inputs, dict): 328 | ids = inputs["encoding_indices"] 329 | else: 330 | ids = inputs 331 | features = self.quantizer.decode_ids(ids) 332 | reconstructed_image = self.decode(features, train) 333 | return reconstructed_image 334 | 335 | def encode_to_indices(self, inputs, train): 336 | if isinstance(inputs, dict): 337 | image = inputs["image"] 338 | else: 339 | image = inputs 340 | encoded_feature, _ = self.encoder(image, train) 341 | _, result_dict = self.quantizer( 342 | encoded_feature, train, self.make_rng("quantizer") 343 | ) 344 | ids = result_dict["encoding_indices"] 345 | return ids 346 | 347 | def __call__(self, image, train, ratio=None): 348 | del ratio 349 | quantized, result_dict = self.encode(image, train) 350 | image_output = self.decode(quantized, train) 351 | return {"image_output": image_output}, result_dict 352 | -------------------------------------------------------------------------------- /lqae/main/finetune_main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import os 3 | import pprint 4 | from copy import copy, deepcopy 5 | from functools import partial 6 | from typing import Any, Callable, Optional 7 | 8 | import absl.app 9 | import absl.flags 10 | import flax 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | import optax 15 | import torch 16 | from flax import linen as nn 17 | from flax.jax_utils import prefetch_to_device 18 | from flax.training import train_state 19 | from tqdm.auto import tqdm, trange 20 | 21 | from ..data import ImageNetDataset 22 | from ..jax_utils import ( 23 | JaxRNG, 24 | cross_entropy_loss, 25 | get_metrics, 26 | next_rng, 27 | sync_state_across_devices, 28 | mixup_cutmix, 29 | ) 30 | from ..models import LQAE 31 | from ..utils import ( 32 | WandBLogger, 33 | define_flags_with_default, 34 | get_user_flags, 35 | load_checkpoint, 36 | load_pickle, 37 | set_random_seed, 38 | ) 39 | 40 | 41 | FLAGS_DEF = define_flags_with_default( 42 | seed=42, 43 | epochs=200, 44 | batch_size=0, 45 | dataloader_n_workers=0, 46 | dataloader_shuffle=False, 47 | log_freq=50, 48 | save_model_freq=0, 49 | lr_init_value=0.0, 50 | lr_end_value=0.0, 51 | lr_peak_value=1.0e-4, 52 | lr_warmup_epochs=0, 53 | weight_decay=0.0001, 54 | clip_gradient=1e9, 55 | mixup_alpha = 0.8, 56 | cutmix_alpha = 1.0, 57 | switch_prob = 0.5, 58 | label_smoothing = 0.1, 59 | load_pretrained="", 60 | load_checkpoint="", 61 | last_embedding_layers="all", 62 | imagenet_data=ImageNetDataset.get_default_config(), 63 | lqae=LQAE.get_default_config(), 64 | logging=WandBLogger.get_default_config(), 65 | log_all_worker=False, 66 | load_bert_params=True, 67 | ) 68 | FLAGS = absl.flags.FLAGS 69 | 70 | 71 | def main(argv): 72 | FLAGS = absl.flags.FLAGS 73 | variant = get_user_flags(FLAGS, FLAGS_DEF) 74 | 75 | variant["jax_process_index"] = jax_process_index = jax.process_index() 76 | variant["jax_process_count"] = jax_process_count = jax.process_count() 77 | assert FLAGS.batch_size % jax_process_count == 0 78 | variant["process_batch_size"] = process_batch_size = ( 79 | FLAGS.batch_size // jax_process_count 80 | ) 81 | variant["device_batch_size"] = 
process_batch_size // jax.local_device_count() 82 | lr_scale = FLAGS.batch_size / 256 83 | variant["effective_lr"] = FLAGS.lr_peak_value * lr_scale 84 | jax_devices = jax.local_devices() 85 | n_devices = len(jax_devices) 86 | assert process_batch_size % n_devices == 0 87 | 88 | logger = WandBLogger( 89 | config=FLAGS.logging, 90 | variant=variant, 91 | enable=FLAGS.log_all_worker or (jax_process_index == 0), 92 | ) 93 | set_random_seed(FLAGS.seed * (jax_process_index + 1)) 94 | 95 | train_dataset = ImageNetDataset( 96 | FLAGS.imagenet_data, jax_process_index / jax_process_count 97 | ) 98 | val_flags = deepcopy(FLAGS.imagenet_data) 99 | val_flags.partition = "val" 100 | val_flags.transform_type = "test" 101 | val_dataset = ImageNetDataset(val_flags, jax_process_index / jax_process_count) 102 | 103 | steps_per_epoch = int(len(train_dataset) / FLAGS.batch_size) 104 | total_steps = steps_per_epoch * FLAGS.epochs 105 | val_steps = int(len(val_dataset) / FLAGS.batch_size) 106 | 107 | train_loader = torch.utils.data.DataLoader( 108 | train_dataset, 109 | batch_size=process_batch_size, 110 | shuffle=FLAGS.dataloader_shuffle, 111 | num_workers=FLAGS.dataloader_n_workers, 112 | prefetch_factor=2, 113 | persistent_workers=FLAGS.dataloader_n_workers > 0, 114 | drop_last=True, 115 | ) 116 | 117 | val_loader = torch.utils.data.DataLoader( 118 | val_dataset, 119 | batch_size=process_batch_size, 120 | shuffle=False, 121 | num_workers=FLAGS.dataloader_n_workers, 122 | prefetch_factor=2, 123 | persistent_workers=FLAGS.dataloader_n_workers > 0, 124 | drop_last=True, 125 | ) 126 | 127 | def combine_representation(output): 128 | encoder_embedding = [x for x in output["encoder_embedding"]] 129 | # Stop gradient on encoder_embedding in Finetuning 130 | encoder_embedding = jax.lax.stop_gradient(encoder_embedding) 131 | bert_embedding = [x for x in output["bert_embedding"]] 132 | all_embedding = encoder_embedding + bert_embedding 133 | if FLAGS.last_embedding_layers == "all": 134 | representation = all_embedding 135 | elif FLAGS.last_embedding_layers == "all_bert": 136 | representation = bert_embedding 137 | elif FLAGS.last_embedding_layers == "all_encoder": 138 | representation = encoder_embedding 139 | else: 140 | representation = [ 141 | all_embedding[-int(x)] for x in FLAGS.last_embedding_layers.split(",") 142 | ] 143 | representation = jax.tree_util.tree_map( 144 | lambda x: jnp.mean(x[:, 1:, :], axis=1), representation 145 | ) 146 | representation = jnp.concatenate(representation, axis=-1) 147 | return representation 148 | 149 | class FinetuneCLS(nn.Module): 150 | backbone: nn.Module 151 | num_classes: int 152 | 153 | @nn.nowrap 154 | def rng_keys(self): 155 | return ('params', 'noise', 'drop_path') 156 | 157 | @nn.compact 158 | def __call__(self, x, deterministic=False): 159 | output = self.backbone.forward_image_representation(x, deterministic) 160 | x = combine_representation(output) 161 | x = nn.LayerNorm()(x) 162 | x = nn.Dense(self.num_classes)(x) 163 | logits = x 164 | return logits 165 | 166 | backbone = LQAE(FLAGS.lqae) 167 | model = FinetuneCLS( 168 | backbone=backbone, 169 | num_classes=train_dataset.num_classes(), 170 | ) 171 | 172 | learning_rate = optax.warmup_cosine_decay_schedule( 173 | init_value=FLAGS.lr_init_value * lr_scale, 174 | peak_value=FLAGS.lr_peak_value * lr_scale, 175 | warmup_steps=FLAGS.lr_warmup_epochs * steps_per_epoch, 176 | decay_steps=total_steps, 177 | end_value=FLAGS.lr_end_value * lr_scale, 178 | ) 179 | 180 | mixup_cutmix_fn = partial( 181 | mixup_cutmix, 182 | 
num_classes=train_dataset.num_classes(), 183 | mixup_alpha=FLAGS.mixup_alpha, 184 | cutmix_alpha=FLAGS.cutmix_alpha, 185 | switch_prob=FLAGS.switch_prob, 186 | label_smoothing=FLAGS.label_smoothing 187 | ) 188 | cross_entropy_loss_fn = partial(cross_entropy_loss, smoothing_factor=0. if FLAGS.mixup_alpha > 0. else FLAGS.label_smoothing) 189 | 190 | @partial(jax.pmap, axis_name="pmap", donate_argnums=[0]) 191 | def train_step_fn(state, rng, image, label): 192 | rng_generator = JaxRNG(rng) 193 | augmented_image, augmented_label = mixup_cutmix_fn(image, label, rng) 194 | def loss_fn(params): 195 | logits = model.apply( 196 | params, 197 | augmented_image, 198 | deterministic=False, 199 | rngs=rng_generator(keys=backbone.rng_keys()), 200 | ) 201 | loss = cross_entropy_loss_fn(logits, augmented_label) 202 | accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(augmented_label, axis=-1)) 203 | aux = dict(loss=loss, accuracy=accuracy) 204 | return loss, aux 205 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 206 | (loss, aux), grads = jax.lax.pmean( 207 | grad_fn(state.params), 208 | axis_name="pmap", 209 | ) 210 | aux["learning_rate"] = learning_rate(state.step) 211 | state = state.apply_gradients(grads=grads) 212 | return state, aux, rng_generator() 213 | 214 | @partial(jax.pmap, axis_name="pmap") 215 | def eval_step_fn(state, rng, image, label): 216 | rng_generator = JaxRNG(rng) 217 | logits = model.apply( 218 | state.params, 219 | image, 220 | deterministic=True, 221 | rngs=rng_generator(keys=backbone.rng_keys()), 222 | ) 223 | accuracy = jax.lax.pmean( 224 | jnp.mean(jnp.argmax(logits, axis=-1) == label), axis_name="pmap" 225 | ) 226 | aux = dict(accuracy=accuracy) 227 | aux = jax.lax.pmean(aux, axis_name="pmap") 228 | return aux, rng_generator() 229 | 230 | if FLAGS.load_checkpoint != "": 231 | checkpoint_data = load_pickle(FLAGS.load_checkpoint) 232 | state = flax.jax_utils.replicate(checkpoint_data["state"], jax_devices) 233 | start_step = checkpoint_data["step"] 234 | else: 235 | image = jnp.zeros((2, 256, 256, 3), dtype=jnp.float32) 236 | rngs = next_rng(keys=backbone.rng_keys()) 237 | params = model.init(rngs, image) 238 | 239 | if FLAGS.load_pretrained != "": 240 | checkpoint_data = load_checkpoint(FLAGS.load_pretrained) 241 | checkpoint_params = checkpoint_data["state"].params["params"] 242 | checkpoint_params = flax.core.unfreeze(checkpoint_params) 243 | backbone_params = flax.core.unfreeze(params['params']["backbone"]) 244 | for key in backbone_params.keys(): 245 | if key in ["decoder", "lang_model"]: 246 | continue 247 | else: 248 | assert ( 249 | key in checkpoint_params.keys() 250 | ), f"pretrained model missing key={key}" 251 | backbone_params[key] = checkpoint_params[key] 252 | params = flax.core.unfreeze(params) 253 | params["params"].update({"backbone": backbone_params}) 254 | params = flax.core.freeze(params) 255 | 256 | # make sure BERT pretrained code is loaded 257 | if FLAGS.load_bert_params: 258 | params = backbone.load_bert_params(params, False) 259 | 260 | def get_weight_decay_mask(params): 261 | flattened_params = flax.traverse_util.flatten_dict( 262 | flax.core.unfreeze(params) 263 | ) 264 | def decay(key): 265 | return all([k not in backbone.no_decay_list() for k in key]) 266 | return flax.traverse_util.unflatten_dict( 267 | {key: decay(key) for key in flattened_params.keys()} 268 | ) 269 | 
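`get_weight_decay_mask` above turns parameter paths into a boolean pytree for `optax.adamw`: entries mapped to False (here, names in `backbone.no_decay_list()`, i.e. biases and embeddings) are exempt from weight decay. A minimal sketch of the same masking mechanism on a toy parameter tree (illustrative names only):

```python
import flax
import jax.numpy as jnp
import optax

params = {"dense": {"kernel": jnp.ones((3, 3)), "bias": jnp.zeros((3,))}}
flat = flax.traverse_util.flatten_dict(params)
# Decay kernels, skip any leaf whose path contains "bias".
mask = flax.traverse_util.unflatten_dict(
    {k: ("bias" not in k) for k in flat}
)
tx = optax.adamw(1e-3, weight_decay=1e-4, mask=mask)
```

270 | state = train_state.TrainState.create( 271 | params=flax.core.unfreeze(params), 272 | apply_fn=None, 273 | tx=optax.adamw( 274 | learning_rate=learning_rate, 275 | 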
weight_decay=FLAGS.weight_decay, 276 | b1=0.9, b2=0.999, 277 | mask=get_weight_decay_mask, 278 | ) 279 | ) 280 | state = flax.jax_utils.replicate(state, jax_devices) 281 | start_step = 0 282 | 283 | del params 284 | 285 | state = sync_state_across_devices(state) 286 | sharded_rng = jax.device_put_sharded(next_rng(n_devices), jax_devices) 287 | 288 | def generate_batch(iterator): 289 | while True: 290 | for batch in iterator: 291 | imgs = batch[0].numpy() 292 | imgs = imgs.reshape(n_devices, -1, *imgs.shape[1:]) 293 | labels = batch[1].numpy() 294 | labels = labels.reshape(n_devices, -1, *labels.shape[1:]) 295 | yield tuple([imgs, labels]) 296 | 297 | train_iterator = prefetch_to_device(generate_batch(train_loader), 2, jax_devices) 298 | val_iterator = prefetch_to_device(generate_batch(val_loader), 2, jax_devices) 299 | 300 | best_val_acc = 0.0 301 | step_counter = trange(start_step, total_steps, desc="train", ncols=0) 302 | 303 | for step, (image, label) in zip(step_counter, train_iterator): 304 | epoch = step // steps_per_epoch 305 | if step % steps_per_epoch == 0: 306 | train_metrics = [] 307 | 308 | image = image.astype(jnp.float32) 309 | label = label.astype(jnp.int32) 310 | 311 | state, metrics, sharded_rng = train_step_fn(state, sharded_rng, image, label) 312 | train_metrics.append(metrics) 313 | 314 | if step % FLAGS.log_freq == 0: 315 | log_metrics = get_metrics(train_metrics, unreplicate=True, stack=True) 316 | log_metrics = { 317 | f"train_{k}": v 318 | for k, v in jax.tree_map(lambda x: x.mean(), log_metrics).items() 319 | } 320 | log_metrics.update({"step": step, "epoch": epoch}) 321 | logger.log(log_metrics) 322 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 323 | 324 | if FLAGS.save_model_freq > 0 and step % FLAGS.save_model_freq == 0: 325 | save_data = { 326 | "step": step, 327 | "epoch": epoch, 328 | "variant": variant, 329 | "state": jax.device_get(flax.jax_utils.unreplicate(state)), 330 | "best_val_acc": best_val_acc, 331 | } 332 | if jax_process_index == 0: 333 | logger.save_pickle(save_data, "model.pkl") 334 | 335 | if step % steps_per_epoch == 0: 336 | val_metrics = [] 337 | for _, (image, label) in zip( 338 | trange(val_steps, desc="val", ncols=0), val_iterator 339 | ): 340 | image = image.astype(jnp.float32) 341 | label = label.astype(jnp.int32) 342 | 343 | metrics, sharded_rng = eval_step_fn(state, sharded_rng, image, label) 344 | val_metrics.append(metrics) 345 | 346 | log_metrics = get_metrics(val_metrics, unreplicate=True, stack=True) 347 | accuracy = log_metrics["accuracy"].mean() 348 | log_metrics = { 349 | f"val_{k}": v 350 | for k, v in jax.tree_map(lambda x: x.mean(), log_metrics).items() 351 | } 352 | log_metrics.update({"step": step, "epoch": epoch}) 353 | logger.log(log_metrics) 354 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 355 | 356 | if accuracy > best_val_acc: 357 | best_val_acc = accuracy 358 | 359 | if FLAGS.save_model_freq > 0: 360 | save_data = { 361 | "epoch": epoch, 362 | "step": step, 363 | "variant": variant, 364 | "state": jax.device_get(flax.jax_utils.unreplicate(state)), 365 | "best_val_acc": best_val_acc, 366 | } 367 | if jax_process_index == 0: 368 | logger.save_pickle(save_data, "best_model.pkl") 369 | 370 | 371 | if __name__ == "__main__": 372 | torch.multiprocessing.set_start_method("spawn") 373 | absl.app.run(main) 374 | -------------------------------------------------------------------------------- /lqae/data/data.py: -------------------------------------------------------------------------------- 1 | import 
threading 2 | from io import BytesIO 3 | from queue import Queue 4 | 5 | import gcsfs 6 | import h5py 7 | import numpy as np 8 | import skimage.io 9 | import torch 10 | import torchvision 11 | import transformers 12 | from ml_collections import ConfigDict 13 | from PIL import Image 14 | from skimage.color import gray2rgb, rgba2rgb 15 | from timm.data import create_transform 16 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 17 | from torchvision import transforms 18 | 19 | 20 | class ImageTextDataset(torch.utils.data.Dataset): 21 | @staticmethod 22 | def get_default_config(updates=None): 23 | config = ConfigDict() 24 | config.path = "" 25 | 26 | config.start_index = 0 27 | config.max_length = int(1e9) 28 | config.random_start = False 29 | 30 | config.image_only = False 31 | config.tokenize = True 32 | config.tokenizer = "bert-base-uncased" 33 | config.tokenizer_max_length = 64 34 | 35 | config.transform_type = "pretrain" 36 | config.image_size = 256 37 | 38 | config.image_normalization = "imagenet" 39 | config.custom_image_mean = "" 40 | config.custom_image_std = "" 41 | 42 | config.random_drop_text = 0.0 43 | config.deterministic_drop_text = 0.0 44 | 45 | if updates is not None: 46 | config.update(ConfigDict(updates).copy_and_resolve_references()) 47 | return config 48 | 49 | def __init__(self, config, start_offset_ratio=None): 50 | self.config = self.get_default_config(config) 51 | assert self.config.path != "" 52 | 53 | if self.config.image_normalization == "imagenet": 54 | self.image_mean = (0.485, 0.456, 0.406) 55 | self.image_std = (0.229, 0.224, 0.225) 56 | elif self.config.image_normalization == "cc12m": 57 | self.image_mean = (0.5762, 0.5503, 0.5213) 58 | self.image_std = (0.3207, 0.3169, 0.3307) 59 | elif self.config.image_normalization == "none": 60 | self.image_mean = (0.0, 0.0, 0.0) 61 | self.image_std = (1.0, 1.0, 1.0) 62 | elif self.config.image_normalization == "custom": 63 | self.image_mean = tuple( 64 | float(x) for x in self.config.custom_image_mean.split("-") 65 | ) 66 | self.image_std = tuple( 67 | float(x) for x in self.config.custom_image_std.split("-") 68 | ) 69 | assert len(self.image_mean) == len(self.image_std) == 3 70 | else: 71 | raise ValueError("Unsupported image normalization mode!") 72 | 73 | if self.config.path.startswith("gs://"): 74 | # Loading from GCS 75 | self.h5_file = h5py.File( 76 | gcsfs.GCSFileSystem().open(self.config.path, cache_type="block"), "r" 77 | ) 78 | else: 79 | self.h5_file = h5py.File(self.config.path, "r") 80 | 81 | if self.config.transform_type == "pretrain": 82 | # Use Kaiming's simple pretrain processing 83 | self.transform = transforms.Compose( 84 | [ 85 | transforms.RandomResizedCrop( 86 | self.config.image_size, 87 | scale=(0.2, 1.0), 88 | interpolation=transforms.InterpolationMode.BICUBIC, 89 | ), 90 | transforms.RandomHorizontalFlip(), 91 | transforms.ToTensor(), 92 | transforms.Normalize(mean=self.image_mean, std=self.image_std), 93 | ] 94 | ) 95 | elif self.config.transform_type == "finetune": 96 | # Use Kaiming's finetune processing 97 | self.transform = create_transform( 98 | input_size=self.config.image_size, 99 | is_training=True, 100 | color_jitter=True, 101 | auto_augment=None, 102 | interpolation="bicubic", 103 | re_prob=0, 104 | re_mode=0, 105 | re_count="const", 106 | mean=self.image_mean, 107 | std=self.image_std, 108 | ) 109 | elif self.config.transform_type == "test": 110 | self.transform = transforms.Compose( 111 | [ 112 | transforms.Resize( 113 | self.config.image_size, 114 | 
interpolation=transforms.InterpolationMode.BICUBIC, 115 | ), 116 | transforms.CenterCrop(self.config.image_size), 117 | transforms.ToTensor(), 118 | transforms.Normalize(mean=self.image_mean, std=self.image_std), 119 | ] 120 | ) 121 | elif self.config.transform_type == "resize_only": 122 | self.transform = transforms.Compose( 123 | [ 124 | transforms.Resize( 125 | self.config.image_size, 126 | interpolation=transforms.InterpolationMode.BICUBIC, 127 | ), 128 | transforms.CenterCrop(self.config.image_size), 129 | transforms.ToTensor(), 130 | ] 131 | ) 132 | else: 133 | raise ValueError("Unsupported transform_type!") 134 | 135 | if self.config.tokenize: 136 | self.tokenizer = transformers.BertTokenizer.from_pretrained( 137 | self.config.tokenizer 138 | ) 139 | 140 | if self.config.random_start: 141 | # Bypass numpy random seed 142 | self.random_start_offset = np.random.default_rng().choice(len(self)) 143 | elif start_offset_ratio is not None: 144 | self.random_start_offset = int(len(self) * start_offset_ratio) % len(self) 145 | else: 146 | self.random_start_offset = 0 147 | 148 | def __getstate__(self): 149 | return self.config, self.random_start_offset 150 | 151 | def __setstate__(self, state): 152 | config, random_start_offset = state 153 | self.__init__(config) 154 | self.random_start_offset = random_start_offset 155 | 156 | def __len__(self): 157 | return min( 158 | self.h5_file["jpg"].shape[0] - self.config.start_index, 159 | self.config.max_length, 160 | ) 161 | 162 | def process_index(self, index): 163 | index = (index + self.random_start_offset) % len(self) 164 | return index + self.config.start_index 165 | 166 | def drop_text(self, raw_index): 167 | deterministic_drop = ( 168 | float(raw_index % 100) / 100.0 < self.config.deterministic_drop_text 169 | ) 170 | random_drop = np.random.rand() < self.config.random_drop_text 171 | return deterministic_drop or random_drop 172 | 173 | def __getitem__(self, raw_index): 174 | index = self.process_index(raw_index) 175 | with BytesIO(self.h5_file["jpg"][index]) as fin: 176 | image = skimage.io.imread(fin) 177 | 178 | if len(image.shape) == 2: 179 | image = gray2rgb(image) 180 | elif image.shape[-1] == 4: 181 | image = rgba2rgb(image) 182 | 183 | image = ( 184 | self.transform(Image.fromarray(np.uint8(image))).permute(1, 2, 0).numpy() 185 | ) 186 | image = image.astype(np.float32) 187 | if self.config.image_only: 188 | return image 189 | 190 | with BytesIO(self.h5_file["caption"][index]) as fin: 191 | caption = fin.read().decode("utf-8") 192 | 193 | if not self.config.tokenize: 194 | return image, caption 195 | 196 | if len(caption) == 0 or self.drop_text(raw_index): 197 | tokenized_caption = np.zeros( 198 | self.config.tokenizer_max_length, dtype=np.int32 199 | ) 200 | padding_mask = np.ones(self.config.tokenizer_max_length, dtype=np.float32) 201 | return image, tokenized_caption, padding_mask 202 | 203 | encoded_caption = self.tokenizer( 204 | caption, 205 | padding="max_length", 206 | truncation=True, 207 | max_length=self.config.tokenizer_max_length, 208 | return_tensors="np", 209 | add_special_tokens=False, 210 | ) 211 | 212 | if encoded_caption["input_ids"][0].size == 0: # Empty token 213 | tokenized_caption = np.zeros( 214 | self.config.tokenizer_max_length, dtype=np.int32 215 | ) 216 | padding_mask = np.ones(self.config.tokenizer_max_length, dtype=np.float32) 217 | else: 218 | tokenized_caption = encoded_caption["input_ids"][0] 219 | padding_mask = 1.0 - encoded_caption["attention_mask"][0].astype(np.float32) 220 | 221 | return image, 
tokenized_caption, padding_mask 222 | 223 | @property 224 | def vocab_size(self): 225 | return self.tokenizer.vocab_size 226 | 227 | @property 228 | def text_length(self): 229 | return self.config.tokenizer_max_length 230 | 231 | 232 | class ImageNetDataset(torch.utils.data.Dataset): 233 | @staticmethod 234 | def get_default_config(updates=None): 235 | config = ConfigDict() 236 | config.path = "" 237 | config.partition = "train" 238 | config.image_only = True 239 | 240 | config.start_index = 0 241 | config.max_length = int(1e9) 242 | config.random_start = False 243 | 244 | config.image_normalization = "none" 245 | config.transform_type = "pretrain" 246 | config.image_size = 256 247 | 248 | config.autoaug = "rand-m9-mstd0.5-inc1" 249 | 250 | if updates is not None: 251 | config.update(ConfigDict(updates).copy_and_resolve_references()) 252 | return config 253 | 254 | def __init__(self, config, start_offset_ratio=None): 255 | self.config = self.get_default_config(config) 256 | assert self.config.path != "" 257 | 258 | if self.config.path.startswith("gs://"): 259 | # Loading from GCS 260 | self.h5_file = h5py.File( 261 | gcsfs.GCSFileSystem().open(self.config.path, cache_type="block"), "r" 262 | ) 263 | else: 264 | self.h5_file = h5py.File(self.config.path, "r") 265 | 266 | if self.config.image_normalization == "imagenet": 267 | self.image_mean = (0.485, 0.456, 0.406) 268 | self.image_std = (0.229, 0.224, 0.225) 269 | elif self.config.image_normalization == "cc12m": 270 | self.image_mean = (0.5762, 0.5503, 0.5213) 271 | self.image_std = (0.3207, 0.3169, 0.3307) 272 | elif self.config.image_normalization == "none": 273 | self.image_mean = (0.0, 0.0, 0.0) 274 | self.image_std = (1.0, 1.0, 1.0) 275 | elif self.config.image_normalization == "custom": 276 | self.image_mean = tuple( 277 | float(x) for x in self.config.custom_image_mean.split("-") 278 | ) 279 | self.image_std = tuple( 280 | float(x) for x in self.config.custom_image_std.split("-") 281 | ) 282 | assert len(self.image_mean) == len(self.image_std) == 3 283 | else: 284 | raise ValueError("Unsupported image normalization mode!") 285 | 286 | if self.config.transform_type == "pretrain": 287 | # Use Kaiming's simple pretrain processing 288 | self.transform = transforms.Compose( 289 | [ 290 | transforms.RandomResizedCrop( 291 | self.config.image_size, 292 | scale=(0.2, 1.0), 293 | interpolation=transforms.InterpolationMode.BICUBIC, 294 | ), 295 | transforms.RandomHorizontalFlip(), 296 | transforms.ToTensor(), 297 | transforms.Normalize(mean=self.image_mean, std=self.image_std), 298 | ] 299 | ) 300 | elif self.config.transform_type == "finetune": 301 | # Use Kaiming's finetune processing 302 | self.transform = create_transform( 303 | input_size=self.config.image_size, 304 | is_training=True, 305 | color_jitter=True, 306 | auto_augment=self.config.autoaug, 307 | interpolation="bicubic", 308 | re_prob=0, 309 | re_mode=0, 310 | re_count="const", 311 | mean=self.image_mean, 312 | std=self.image_std, 313 | ) 314 | elif self.config.transform_type == "plain_finetune": 315 | # Use supervised training processing of ViT from "Better plain ViT baselines for ImageNet-1k" https://arxiv.org/abs/2205.01580 316 | self.transform = transforms.Compose( 317 | [ 318 | transforms.RandomResizedCrop( 319 | self.config.image_size, 320 | interpolation=transforms.InterpolationMode.BICUBIC, 321 | ), 322 | transforms.RandomHorizontalFlip(), 323 | transforms.ToTensor(), 324 | transforms.Normalize(mean=self.image_mean, std=self.image_std), 325 | ] 326 | ) 327 | elif 
self.config.transform_type == "linear_prob": 328 | self.transform = transforms.Compose( 329 | [ 330 | transforms.RandomResizedCrop( 331 | self.config.image_size, 332 | interpolation=transforms.InterpolationMode.BICUBIC, 333 | ), 334 | transforms.RandomHorizontalFlip(), 335 | transforms.ToTensor(), 336 | transforms.Normalize(mean=self.image_mean, std=self.image_std), 337 | ] 338 | ) 339 | elif self.config.transform_type == "test": 340 | self.transform = transforms.Compose( 341 | [ 342 | transforms.Resize( 343 | self.config.image_size, 344 | interpolation=transforms.InterpolationMode.BICUBIC, 345 | ), 346 | transforms.CenterCrop(self.config.image_size), 347 | transforms.ToTensor(), 348 | transforms.Normalize(mean=self.image_mean, std=self.image_std), 349 | ] 350 | ) 351 | else: 352 | raise ValueError("Unsupported transform_type!") 353 | 354 | if self.config.random_start: 355 | # Bypass numpy random seed 356 | self.random_start_offset = np.random.default_rng().choice(len(self)) 357 | elif start_offset_ratio is not None: 358 | self.random_start_offset = int(len(self) * start_offset_ratio) % len(self) 359 | else: 360 | self.random_start_offset = 0 361 | 362 | def __getstate__(self): 363 | return self.config, self.random_start_offset 364 | 365 | def __setstate__(self, state): 366 | config, random_start_offset = state 367 | self.__init__(config) 368 | self.random_start_offset = random_start_offset 369 | 370 | def __len__(self): 371 | return min( 372 | self.h5_file["{}_jpg".format(self.config.partition)].shape[0] 373 | - self.config.start_index, 374 | self.config.max_length, 375 | ) 376 | 377 | def process_index(self, index): 378 | index = (index + self.random_start_offset) % len(self) 379 | return index + self.config.start_index 380 | 381 | def __getitem__(self, index): 382 | index = self.process_index(index) 383 | with BytesIO( 384 | self.h5_file["{}_jpg".format(self.config.partition)][index] 385 | ) as fin: 386 | image = skimage.io.imread(fin) 387 | 388 | if len(image.shape) == 2: 389 | image = gray2rgb(image) 390 | elif image.shape[-1] == 4: 391 | image = rgba2rgb(image) 392 | 393 | image = ( 394 | self.transform(Image.fromarray(np.uint8(image))).permute(1, 2, 0).numpy() 395 | ) 396 | image = image.astype(np.float32) 397 | 398 | if self.config.image_only: 399 | return image 400 | 401 | label = self.h5_file["{}_labels".format(self.config.partition)][index] 402 | 403 | return image, label 404 | 405 | def num_classes(self): 406 | return 1000 407 | -------------------------------------------------------------------------------- /lqae/main/lqae_main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import pprint 3 | from copy import copy, deepcopy 4 | from functools import partial 5 | 6 | import absl.app 7 | import absl.flags 8 | import flax 9 | import jax 10 | import jax.numpy as jnp 11 | import numpy as np 12 | import optax 13 | import torch 14 | from absl import logging 15 | from flax.jax_utils import prefetch_to_device 16 | from flax.training.train_state import TrainState 17 | from tqdm.auto import tqdm, trange 18 | 19 | import wandb 20 | 21 | from ..data import ImageNetDataset, ImageTextDataset 22 | from ..jax_utils import ( 23 | JaxRNG, 24 | accumulated_gradient, 25 | get_metrics, 26 | next_rng, 27 | sync_state_across_devices, 28 | ) 29 | from ..models import LQAE, VQAE 30 | from ..utils import ( 31 | WandBLogger, 32 | create_log_images, 33 | define_flags_with_default, 34 | get_user_flags, 35 | image_float2int, 36 | load_pickle, 37 
| set_random_seed, 38 | ) 39 | 40 | FLAGS_DEF = define_flags_with_default( 41 | seed=42, 42 | epochs=200, 43 | batch_size=0, 44 | accumulate_grad_steps=1, 45 | dataloader_n_workers=0, 46 | dataloader_shuffle=False, 47 | log_freq=50, 48 | plot_freq=1000, 49 | save_model_freq=0, 50 | clip_gradient=1e9, 51 | lr_init_value=0.0, 52 | lr_end_value=0.0, 53 | lr_peak_value=1.0e-4, 54 | lr_warmup_epochs=0, 55 | weight_decay=0.0001, 56 | load_checkpoint="", 57 | load_pretrained="", 58 | dataset="imagenet", 59 | cc12m_data=ImageTextDataset.get_default_config(), 60 | imagenet_data=ImageNetDataset.get_default_config(), 61 | # lqae: encoder-decoder with frozen BERT 62 | # vqae: encoder-decoder without BERT 63 | # bert: trainable BERT with frozen encoder-decoder 64 | model_type="lqae", 65 | lqae=LQAE.get_default_config(), 66 | vqae=VQAE.get_default_config(), 67 | logging=WandBLogger.get_default_config(), 68 | log_all_worker=False, 69 | ) 70 | FLAGS = absl.flags.FLAGS 71 | 72 | 73 | def main(argv): 74 | variant = get_user_flags(FLAGS, FLAGS_DEF) 75 | assert FLAGS.model_type in [ 76 | "lqae", 77 | "vqae", 78 | "bert", 79 | ], "model_type must be one of lqae, vqae, bert" 80 | 81 | variant["jax_process_index"] = jax_process_index = jax.process_index() 82 | variant["jax_process_count"] = jax_process_count = jax.process_count() 83 | assert FLAGS.batch_size % jax_process_count == 0 84 | variant["process_batch_size"] = process_batch_size = ( 85 | FLAGS.batch_size // jax_process_count 86 | ) 87 | variant["device_batch_size"] = process_batch_size // jax.local_device_count() 88 | lr_scale = FLAGS.batch_size / 256 89 | variant["effective_lr"] = FLAGS.lr_peak_value * lr_scale 90 | jax_devices = jax.local_devices() 91 | n_devices = len(jax_devices) 92 | assert process_batch_size % n_devices == 0 93 | 94 | logger = WandBLogger( 95 | config=FLAGS.logging, 96 | variant=variant, 97 | enable=FLAGS.log_all_worker or (jax_process_index == 0), 98 | ) 99 | set_random_seed(FLAGS.seed * (jax_process_index + 1)) 100 | 101 | if FLAGS.dataset == "cc12m": 102 | FLAGS.cc12m_data.image_only = True 103 | dataset = ImageTextDataset( 104 | FLAGS.cc12m_data, jax_process_index / jax_process_count 105 | ) 106 | elif FLAGS.dataset == "imagenet": 107 | FLAGS.imagenet_data.image_only = True 108 | dataset = ImageNetDataset( 109 | FLAGS.imagenet_data, jax_process_index / jax_process_count 110 | ) 111 | else: 112 | raise ValueError("Unsupported dataset!") 113 | 114 | val_flags = deepcopy(FLAGS.imagenet_data) 115 | val_flags.partition = "val" 116 | val_flags.transform_type = "test" 117 | val_dataset = ImageNetDataset(val_flags, jax_process_index / jax_process_count) 118 | 119 | steps_per_epoch = int(len(dataset) / FLAGS.batch_size) 120 | total_steps = steps_per_epoch * FLAGS.epochs 121 | val_steps = int(len(val_dataset) / FLAGS.batch_size) 122 | 123 | dataloader = torch.utils.data.DataLoader( 124 | dataset, 125 | batch_size=process_batch_size, 126 | shuffle=FLAGS.dataloader_shuffle, 127 | drop_last=True, 128 | num_workers=FLAGS.dataloader_n_workers, 129 | prefetch_factor=2, 130 | persistent_workers=FLAGS.dataloader_n_workers > 0, 131 | ) 132 | 133 | val_dataloader = torch.utils.data.DataLoader( 134 | val_dataset, 135 | batch_size=process_batch_size, 136 | shuffle=False, 137 | drop_last=True, 138 | num_workers=FLAGS.dataloader_n_workers, 139 | prefetch_factor=2, 140 | persistent_workers=FLAGS.dataloader_n_workers > 0, 141 | ) 142 | 143 | if FLAGS.model_type == "lqae" or FLAGS.model_type == "bert": 144 | logging.info(f"Using LQAE model for 
{FLAGS.model_type}") 145 | model = LQAE(FLAGS.lqae) 146 | elif FLAGS.model_type == "vqae": 147 | logging.info(f"Using LQAE model for {FLAGS.model_type}") 148 | model = VQAE(FLAGS.vqae) 149 | 150 | learning_rate = optax.warmup_cosine_decay_schedule( 151 | init_value=FLAGS.lr_init_value * lr_scale, 152 | peak_value=FLAGS.lr_peak_value * lr_scale, 153 | warmup_steps=FLAGS.lr_warmup_epochs 154 | * steps_per_epoch 155 | // FLAGS.accumulate_grad_steps, 156 | decay_steps=total_steps // FLAGS.accumulate_grad_steps, 157 | end_value=FLAGS.lr_end_value * lr_scale, 158 | ) 159 | 160 | def get_loss(output, result_dict, image, train=True): 161 | if "bert_loss" in result_dict: 162 | bert_loss = result_dict["bert_loss"] 163 | else: 164 | bert_loss = 0.0 165 | 166 | if FLAGS.model_type == "lqae" or FLAGS.model_type == "bert": 167 | recon_loss = FLAGS.lqae.bert_channel_image_loss_weight * jnp.mean( 168 | (image - output["bert_channel_image_output"]) ** 2 169 | ) + FLAGS.lqae.nochannel_image_loss_weight * jnp.mean( 170 | (image - output["image_output"]) ** 2 171 | ) 172 | elif FLAGS.model_type == "vqae": 173 | recon_loss = jnp.mean((image - output["image_output"]) ** 2) 174 | 175 | if train: 176 | quantizer_loss = result_dict["quantizer_loss"] 177 | return quantizer_loss, bert_loss, recon_loss 178 | else: 179 | return bert_loss, recon_loss 180 | 181 | @partial(jax.pmap, axis_name="pmap", donate_argnums=0) 182 | def train_step_fn(state, rng, accumulated_grads, accumulated_steps, image): 183 | rng_generator = JaxRNG(rng) 184 | 185 | def loss_fn(params): 186 | output, result_dict = model.apply( 187 | params, 188 | image, 189 | train=True, 190 | rngs=rng_generator(keys=model.rng_keys()), 191 | ) 192 | 193 | quantizer_loss, bert_loss, recon_loss = get_loss( 194 | output, result_dict, image, train=True 195 | ) 196 | 197 | total_loss = recon_loss + quantizer_loss + bert_loss 198 | 199 | aux = dict( 200 | recon_loss=recon_loss, 201 | quantizer_loss=quantizer_loss, 202 | bert_loss=bert_loss, 203 | total_loss=total_loss, 204 | e_latent_loss=result_dict["e_latent_loss"], 205 | q_latent_loss=result_dict["q_latent_loss"], 206 | entropy_loss=result_dict["entropy_loss"], 207 | perplexity=result_dict["perplexity"], 208 | codebook_usage=result_dict["codebook_usage"], 209 | ) 210 | encoding_indices = result_dict["encoding_indices"] 211 | 212 | return total_loss, (aux, encoding_indices) 213 | 214 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 215 | (loss, (aux, encoding_indices)), grads = grad_fn(state.params) 216 | encoding_indices = jax.lax.all_gather(encoding_indices, axis_name="pmap") 217 | loss, aux = jax.lax.pmean((loss, aux), axis_name="pmap") 218 | aux["train_state_step"] = state.step 219 | aux["learning_rate"] = learning_rate(state.step) 220 | 221 | def global_norm(tree): 222 | """ Return the global L2 norm of a pytree. 
""" 223 | squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree) 224 | flattened, _ = jax.flatten_util.ravel_pytree(squared) 225 | return jnp.sqrt(jnp.sum(flattened)) 226 | 227 | grad_norm = global_norm(grads) 228 | aux["grad_norm"] = grad_norm 229 | 230 | if FLAGS.accumulate_grad_steps > 1: 231 | state, accumulated_grads, accumulated_steps = accumulated_gradient( 232 | state, 233 | accumulated_grads, 234 | accumulated_steps, 235 | grads, 236 | FLAGS.accumulate_grad_steps, 237 | lambda s, g: s.apply_gradients( 238 | grads=jax.lax.pmean(g, axis_name="pmap") 239 | ), 240 | ) 241 | else: 242 | state = state.apply_gradients(grads=jax.lax.pmean(grads, axis_name="pmap")) 243 | return ( 244 | state, 245 | aux, 246 | rng_generator(), 247 | accumulated_grads, 248 | accumulated_steps, 249 | encoding_indices, 250 | ) 251 | 252 | @partial(jax.pmap, axis_name="pmap") 253 | def val_step_fn(state, rng, image): 254 | rng_generator = JaxRNG(rng) 255 | 256 | output, result_dict = model.apply( 257 | state.params, 258 | image, 259 | train=False, 260 | ratio={"min_ratio": 0, "max_ratio": 0}, 261 | rngs=rng_generator(keys=model.rng_keys()), 262 | ) 263 | 264 | bert_loss, recon_loss = get_loss(output, result_dict, image, train=False) 265 | 266 | aux = dict( 267 | recon_loss=recon_loss, 268 | bert_loss=bert_loss, 269 | perplexity=result_dict["perplexity"], 270 | codebook_usage=result_dict["codebook_usage"], 271 | ) 272 | 273 | encoding_indices = result_dict["encoding_indices"] 274 | encoding_indices = jax.lax.all_gather(encoding_indices, axis_name="pmap") 275 | 276 | aux = jax.lax.pmean(aux, axis_name="pmap") 277 | return aux, rng_generator(), encoding_indices 278 | 279 | @partial(jax.pmap, axis_name="pmap") 280 | def reconstruction_fn(state, rng, image): 281 | rng_generator = JaxRNG(rng) 282 | 283 | output, _ = model.apply( 284 | state.params, 285 | image, 286 | train=False, 287 | ratio={"min_ratio": 0, "max_ratio": 0}, 288 | rngs=rng_generator(keys=model.rng_keys()), 289 | ) 290 | if FLAGS.model_type == "lqae" or FLAGS.model_type == "bert": 291 | image_output = output["image_output"] 292 | image_output = jnp.clip(image_output, 0, 1) 293 | bert_channel_image_output = output["bert_channel_image_output"] 294 | bert_channel_image_output = jnp.clip(bert_channel_image_output, 0, 1) 295 | return image, image_output, bert_channel_image_output 296 | elif FLAGS.model_type == "vqae": 297 | image_output = output["image_output"] 298 | image_output = jnp.clip(image_output, 0, 1) 299 | return image, image_output 300 | 301 | def get_weight_decay_mask(params): 302 | flattened_params = flax.traverse_util.flatten_dict(flax.core.unfreeze(params)) 303 | 304 | def decay(key): 305 | return all([k not in model.no_decay_list() for k in key]) 306 | 307 | return flax.traverse_util.unflatten_dict( 308 | {key: decay(key) for key in flattened_params.keys()} 309 | ) 310 | 311 | def get_no_gradient_update_fn(model_type): 312 | if model_type == "lqae": 313 | 314 | def func(key): 315 | return "lang_model" in key 316 | 317 | elif model_type == "bert": 318 | 319 | def func(key): 320 | return "encoder" in key or "decoder" in key 321 | 322 | elif model_type == "vqae": 323 | 324 | def func(key): 325 | return False 326 | 327 | return func 328 | 329 | no_gradient_update = get_no_gradient_update_fn(FLAGS.model_type) 330 | 331 | if FLAGS.load_checkpoint != "": 332 | checkpoint_data = load_pickle(FLAGS.load_checkpoint) 333 | state = flax.jax_utils.replicate(checkpoint_data["state"], jax_devices) 334 | start_step = 
checkpoint_data["step"] 335 | else: 336 | image = jnp.zeros((6, 256, 256, 3), dtype=jnp.float32) 337 | rngs = next_rng(keys=model.rng_keys()) 338 | params = model.init(rngs, image, train=True) 339 | 340 | if FLAGS.model_type == "bert": 341 | if FLAGS.load_pretrained != "": 342 | checkpoint_data = load_pickle(FLAGS.load_pretrained) 343 | checkpoint_params = checkpoint_data["state"].params["params"] 344 | checkpoint_params = flax.core.unfreeze(checkpoint_params) 345 | params = flax.core.unfreeze(params["params"]) 346 | for key in params.keys(): 347 | if key not in checkpoint_params.keys(): 348 | if key in ["lang_model"]: 349 | continue 350 | else: 351 | raise ValueError(f"pretrained model miss key={key}") 352 | params = flax.core.freeze({"params": params}) 353 | params = model.load_bert_params(params) 354 | elif FLAGS.model_type == "lqae": 355 | if FLAGS.lqae.use_bert_codebook: 356 | params = model.load_bert_params(params) 357 | 358 | transform_fn = { 359 | True: optax.set_to_zero(), 360 | False: optax.chain( 361 | optax.clip_by_global_norm(FLAGS.clip_gradient), 362 | optax.adamw( 363 | learning_rate=learning_rate, 364 | b1=0.9, 365 | b2=0.95, 366 | weight_decay=FLAGS.weight_decay, 367 | ), 368 | ), 369 | } 370 | 371 | def label_fn(params): 372 | flattened_params = flax.traverse_util.flatten_dict(params) 373 | return flax.traverse_util.unflatten_dict( 374 | { 375 | key: no_gradient_update(key) 376 | for key, value in flattened_params.items() 377 | } 378 | ) 379 | 380 | def count_params(params): 381 | flattened_params = flax.traverse_util.flatten_dict(params) 382 | tree = flax.traverse_util.unflatten_dict( 383 | { 384 | key: value 385 | for key, value in flattened_params.items() 386 | if no_gradient_update(key) 387 | } 388 | ) 389 | num_params = sum(p.size for p in jax.tree_leaves(tree)) 390 | return num_params 391 | 392 | num_params = count_params(params) 393 | logger.log({"num_learnable_params": num_params}) 394 | 395 | opt = optax.multi_transform(transform_fn, label_fn) 396 | state = flax.jax_utils.replicate( 397 | TrainState.create( 398 | params=flax.core.frozen_dict.unfreeze(params), 399 | apply_fn=None, 400 | tx=opt, 401 | ), 402 | jax_devices, 403 | ) 404 | start_step = 0 405 | 406 | del params 407 | 408 | def generate_batch(iterator): 409 | while True: 410 | for images in iterator: 411 | yield images.numpy().reshape(n_devices, -1, *images.shape[1:]) 412 | 413 | state = sync_state_across_devices(state) 414 | sharded_rng = jax.device_put_sharded(next_rng(n_devices), jax_devices) 415 | 416 | if FLAGS.accumulate_grad_steps > 1: 417 | accumulated_grads = flax.jax_utils.replicate( 418 | jax.tree_map(jnp.zeros_like, flax.jax_utils.unreplicate(state).params), 419 | jax_devices, 420 | ) 421 | accumulated_steps = flax.jax_utils.replicate( 422 | jnp.array(0, jnp.int32), jax_devices 423 | ) 424 | else: 425 | accumulated_grads = flax.jax_utils.replicate( 426 | jnp.array(0, jnp.int32), jax_devices 427 | ) 428 | accumulated_steps = flax.jax_utils.replicate( 429 | jnp.array(0, jnp.int32), jax_devices 430 | ) 431 | 432 | data_iterator = prefetch_to_device(generate_batch(dataloader), 2, jax_devices) 433 | val_data_iterator = prefetch_to_device( 434 | generate_batch(val_dataloader), 2, jax_devices 435 | ) 436 | step_counter = trange(start_step, total_steps, ncols=0, desc="Train") 437 | if FLAGS.model_type == "lqae" or FLAGS.model_type == "bert": 438 | codebook_size = 50265 439 | elif FLAGS.model_type == "vqae": 440 | codebook_size = FLAGS.vqae.codebook_size 441 | 442 | step = 0 443 | for step, image in 
zip(step_counter, data_iterator): 444 | epoch = int(step * jax_process_count / len(dataloader)) 445 | image = image.astype(jnp.float32) 446 | ( 447 | state, 448 | metrics, 449 | sharded_rng, 450 | accumulated_grads, 451 | accumulated_steps, 452 | encoding_indices, 453 | ) = train_step_fn( 454 | state, sharded_rng, accumulated_grads, accumulated_steps, image 455 | ) 456 | if step % FLAGS.log_freq == 0: 457 | log_metrics = {"step": step, "epoch": epoch} 458 | log_metrics.update(get_metrics(metrics, unreplicate=True)) 459 | encoding_indices = jax.device_get( 460 | flax.jax_utils.unreplicate(encoding_indices) 461 | ) 462 | indices_histogram = jnp.histogram( 463 | encoding_indices, bins=512, range=(0, codebook_size - 1) 464 | ) 465 | log_metrics.update( 466 | { 467 | "indices_histogram": wandb.Histogram( 468 | np_histogram=indices_histogram 469 | ), 470 | "encoding_indices": wandb.Histogram(encoding_indices), 471 | } 472 | ) 473 | logger.log(log_metrics) 474 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 475 | 476 | if FLAGS.plot_freq > 0 and step % FLAGS.plot_freq == 0: 477 | log_image = create_log_images( 478 | jax.device_get(reconstruction_fn(state, sharded_rng, image)), 479 | mean=dataset.image_mean, 480 | std=dataset.image_std, 481 | ) 482 | if jax_process_index == 0: 483 | logger.log({"image_prediction": wandb.Image(log_image)}) 484 | 485 | if step % steps_per_epoch == 0: 486 | val_metrics = [] 487 | val_encoding_indices = [] 488 | for _, val_image in zip( 489 | trange(val_steps, ncols=0, desc="val"), val_data_iterator 490 | ): 491 | val_image = val_image.astype(jnp.float32) 492 | metrics, sharded_rng, encoding_indices = val_step_fn( 493 | state, sharded_rng, val_image 494 | ) 495 | val_metrics.append(metrics) 496 | val_encoding_indices.append(encoding_indices) 497 | log_metrics = get_metrics(val_metrics, unreplicate=True, stack=True) 498 | val_encoding_indices = jax.tree_map( 499 | lambda x: jax.device_get(flax.jax_utils.unreplicate(x)), 500 | val_encoding_indices, 501 | ) 502 | val_encoding_indices = jnp.concatenate(val_encoding_indices, axis=0) 503 | log_metrics = { 504 | f"val_{k}": v 505 | for k, v in jax.tree_map(lambda x: x.mean(), log_metrics).items() 506 | } 507 | val_indices_histogram = jnp.histogram( 508 | val_encoding_indices, bins=512, range=(0, codebook_size - 1) 509 | ) 510 | log_metrics.update( 511 | { 512 | "val_indices_histogram": wandb.Histogram( 513 | np_histogram=val_indices_histogram 514 | ), 515 | "val_encoding_indices": wandb.Histogram(val_encoding_indices), 516 | } 517 | ) 518 | logger.log(log_metrics) 519 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 520 | 521 | if FLAGS.save_model_freq > 0 and step % FLAGS.save_model_freq == 0: 522 | save_data = { 523 | "step": step, 524 | "epoch": epoch, 525 | "variant": variant, 526 | "state": jax.device_get(flax.jax_utils.unreplicate(state)), 527 | } 528 | if jax_process_index == 0: 529 | logger.save_pickle(save_data, "model.pkl") 530 | 531 | if FLAGS.save_model_freq > 0: 532 | save_data = { 533 | "step": step, 534 | "epoch": epoch, 535 | "variant": variant, 536 | "state": jax.device_get(flax.jax_utils.unreplicate(state)), 537 | } 538 | if jax_process_index == 0: 539 | logger.save_pickle(save_data, "model.pkl") 540 | 541 | 542 | if __name__ == "__main__": 543 | torch.multiprocessing.set_start_method("spawn") 544 | absl.app.run(main) 545 | -------------------------------------------------------------------------------- /lqae/models/lqae.py: 
-------------------------------------------------------------------------------- 1 | import functools 2 | import io 3 | import math 4 | import os 5 | from typing import Any, Callable, Optional, Tuple 6 | 7 | import flax 8 | import flax.linen as nn 9 | import jax 10 | import jax.numpy as jnp 11 | import ml_collections 12 | import numpy as np 13 | import optax 14 | import requests 15 | from ml_collections import ConfigDict 16 | from ml_collections.config_dict import config_dict 17 | from PIL import Image, ImageFilter 18 | from transformers import RobertaTokenizer 19 | from transformers.models.roberta.modeling_flax_roberta import ( 20 | FlaxBaseModelOutputWithPoolingAndCrossAttentions, 21 | FlaxMaskedLMOutput, 22 | FlaxRobertaEncoder, 23 | FlaxRobertaForMaskedLM, 24 | FlaxRobertaLMHead, 25 | FlaxRobertaPooler, 26 | RobertaConfig, 27 | create_position_ids_from_input_ids, 28 | ) 29 | 30 | from ..jax_utils import JaxRNG, get_onehot, next_rng 31 | from .base_resnet import ResNetDecoder, ResNetEncoder 32 | from .base_vit import VitDecoder, VitEncoder 33 | from .model_utils import ( 34 | assert_avg_rnd, 35 | update_vit_config, 36 | entropy_loss_fn, 37 | normalize_func, 38 | squared_euclidean_distance, 39 | ) 40 | 41 | 42 | class LanguageQuantizer(nn.Module): 43 | """Language quantizer.""" 44 | 45 | config: ConfigDict 46 | codebook: jnp.array 47 | dtype: int = jnp.float32 48 | 49 | def setup(self): 50 | if self.config.quantizer_latent_dim > 0: 51 | self.input_to_latent = nn.Dense( 52 | self.config.quantizer_latent_dim, dtype=self.dtype 53 | ) 54 | self.code_to_latent = nn.Dense( 55 | self.config.quantizer_latent_dim, dtype=self.dtype 56 | ) 57 | else: 58 | self.input_to_latent = self.code_to_latent = lambda x: x 59 | 60 | @nn.compact 61 | def __call__(self, x, train, rng): 62 | l2_normalize = lambda x, axis=1: normalize_func( 63 | x, axis=axis, use_l2_normalize=self.config.l2_normalize 64 | ) 65 | codebook_size = self.codebook.shape[0] 66 | if self.config.strawman_codebook: 67 | strawman_codebook = self.param( 68 | "strawman_codebook", 69 | jax.nn.initializers.normal(0.02, dtype=jnp.float32), 70 | (codebook_size, self.config.quantizer_latent_dim), 71 | ) 72 | strawman_codebook = jnp.asarray(strawman_codebook, dtype=self.dtype) 73 | latent_input = self.input_to_latent(jnp.reshape(x, (-1, x.shape[-1]))) 74 | latent_input = l2_normalize(latent_input, axis=1) 75 | sg_strawman_codebook = jax.lax.stop_gradient( 76 | l2_normalize(strawman_codebook, axis=1) 77 | ) 78 | distances = jnp.reshape( 79 | squared_euclidean_distance(latent_input, sg_strawman_codebook, 80 | dot_product=self.config.dot_product), 81 | x.shape[:-1] + (codebook_size,), 82 | ) 83 | else: 84 | codebook = jnp.asarray(self.codebook, dtype=self.dtype) 85 | latent_input = self.input_to_latent(jnp.reshape(x, (-1, x.shape[-1]))) 86 | latent_input = l2_normalize(latent_input, axis=1) 87 | latent_codebook = self.code_to_latent(codebook) 88 | latent_codebook = l2_normalize(latent_codebook, axis=1) 89 | sg_latent_codebook = jax.lax.stop_gradient( 90 | l2_normalize(latent_codebook, axis=1) 91 | ) 92 | distances = jnp.reshape( 93 | squared_euclidean_distance(latent_input, sg_latent_codebook, 94 | dot_product=self.config.dot_product), 95 | x.shape[:-1] + (codebook_size,), 96 | ) 97 | 98 | encoding_indices = jax.lax.approx_min_k( 99 | distances, 100 | k=self.config.top_k_value, 101 | reduction_dimension=-1, 102 | aggregate_to_topk=True, 103 | )[1] 104 | 105 | encoding_indices, encodings, quantized = self.get_encoding_quantized( 106 | encoding_indices, 
train, rng, codebook_size 107 | ) 108 | 109 | codebook_usage = jnp.sum(encodings, axis=(0, 1)) > 0 110 | codebook_usage = jnp.sum(codebook_usage) / codebook_size 111 | if self.config.top_k_avg: 112 | codebook_usage = codebook_usage / self.config.top_k_value 113 | result_dict = dict() 114 | if train: 115 | # disable gradient for LQAE 116 | quantized = jax.lax.stop_gradient(quantized) 117 | result_dict = self.get_train_loss(quantized, x, distances) 118 | 119 | if self.config.strawman_codebook: 120 | strawman_quantized = self.quantize_strawman(encodings) 121 | strawman_result_dict = self.get_train_loss( 122 | strawman_quantized, self.input_to_latent(x), distances 123 | ) 124 | for k, v in result_dict.items(): 125 | result_dict[k] = v + strawman_result_dict[k] 126 | else: 127 | latent_quantized = self.code_to_latent(quantized) 128 | latent_result_dict = self.get_train_loss( 129 | latent_quantized, self.input_to_latent(x), distances 130 | ) 131 | for k, v in result_dict.items(): 132 | result_dict[k] = v + latent_result_dict[k] 133 | 134 | quantized = x + jax.lax.stop_gradient(quantized - x) 135 | 136 | avg_probs = jnp.mean(encodings.reshape(-1, encodings.shape[-1]), axis=0) 137 | log_perplexity = -jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)) 138 | perplexity = jnp.exp(log_perplexity) 139 | 140 | if "quantizer_loss" in result_dict: 141 | result_dict["quantizer_loss"] = ( 142 | result_dict["quantizer_loss"] 143 | + self.config.quantizer_loss_perplexity * log_perplexity 144 | ) 145 | result_dict.update( 146 | { 147 | "encodings": encodings, 148 | "encoding_indices": encoding_indices, 149 | "raw": x, 150 | "perplexity": perplexity, 151 | "codebook_usage": codebook_usage, 152 | } 153 | ) 154 | return quantized, result_dict 155 | 156 | def quantize(self, z: jnp.ndarray) -> jnp.ndarray: 157 | return jnp.dot(z, self.codebook) 158 | 159 | def get_codebook(self) -> jnp.ndarray: 160 | return self.codebook 161 | 162 | def decode_ids(self, ids: jnp.ndarray) -> jnp.ndarray: 163 | return jnp.take(self.codebook, ids, axis=0) 164 | 165 | def quantize_strawman(self, z: jnp.ndarray) -> jnp.ndarray: 166 | return jnp.dot(z, self.variables["params"]["strawman_codebook"]) 167 | 168 | def get_train_loss(self, quantized, x, distances): 169 | e_latent_loss = ( 170 | jnp.mean((jax.lax.stop_gradient(quantized) - x) ** 2) 171 | * self.config.quantizer_loss_commitment 172 | ) 173 | q_latent_loss = jnp.mean((quantized - jax.lax.stop_gradient(x)) ** 2) 174 | entropy_loss = 0.0 175 | if self.config.quantizer_loss_entropy != 0: 176 | entropy_loss = ( 177 | entropy_loss_fn( 178 | -distances, 179 | loss_type=self.config.entropy_loss_type, 180 | temperature=self.config.entropy_temperature, 181 | ) 182 | * self.config.quantizer_loss_entropy 183 | ) 184 | e_latent_loss = jnp.asarray(e_latent_loss, jnp.float32) 185 | q_latent_loss = jnp.asarray(q_latent_loss, jnp.float32) 186 | entropy_loss = jnp.asarray(entropy_loss, jnp.float32) 187 | loss = e_latent_loss + q_latent_loss + entropy_loss 188 | 189 | result_dict = dict( 190 | quantizer_loss=loss, 191 | e_latent_loss=e_latent_loss, 192 | q_latent_loss=q_latent_loss, 193 | entropy_loss=entropy_loss, 194 | ) 195 | return result_dict 196 | 197 | def get_encoding_quantized(self, encoding_indices, train, rng, codebook_size): 198 | if self.config.top_k_rnd: 199 | if train: 200 | encoding_indices = jax.random.choice(rng, encoding_indices, axis=-1) 201 | else: 202 | encoding_indices = encoding_indices[..., 0] 203 | encodings = jax.nn.one_hot( 204 | encoding_indices, codebook_size, 
dtype=self.dtype 205 | ) 206 | quantized = self.quantize(encodings) 207 | elif self.config.top_k_avg: 208 | encodings = jax.nn.one_hot( 209 | encoding_indices, codebook_size, dtype=self.dtype 210 | ) 211 | quantized = self.quantize(encodings) 212 | quantized = jnp.mean(quantized, axis=-2) 213 | encoding_indices = encoding_indices[..., 0] 214 | else: 215 | encoding_indices = encoding_indices[..., 0] 216 | encodings = jax.nn.one_hot( 217 | encoding_indices, codebook_size, dtype=self.dtype 218 | ) 219 | quantized = self.quantize(encodings) 220 | return encoding_indices, encodings, quantized 221 | 222 | 223 | # Removed word embeddings. Copied from 224 | # https://github.com/huggingface/transformers/blob/12ce2941c7b67c0dedac0f0468b3ed854fa940ab/src/transformers/models/roberta/modeling_flax_roberta.py#L139-L176 225 | class Add_Pos(nn.Module): 226 | """Construct the embeddings from word, position and token_type embeddings.""" 227 | 228 | config: RobertaConfig 229 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 230 | 231 | def setup(self): 232 | self.position_embeddings = nn.Embed( 233 | self.config.max_position_embeddings, 234 | self.config.hidden_size, 235 | embedding_init=jax.nn.initializers.normal( 236 | stddev=self.config.initializer_range 237 | ), 238 | ) 239 | self.token_type_embeddings = nn.Embed( 240 | self.config.type_vocab_size, 241 | self.config.hidden_size, 242 | embedding_init=jax.nn.initializers.normal( 243 | stddev=self.config.initializer_range 244 | ), 245 | ) 246 | self.LayerNorm = nn.LayerNorm( 247 | epsilon=self.config.layer_norm_eps, dtype=self.dtype 248 | ) 249 | self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) 250 | 251 | def __call__( 252 | self, 253 | inputs_embeds, 254 | token_type_ids, 255 | position_ids, 256 | attention_mask, 257 | deterministic: bool = True, 258 | ): 259 | # Embed 260 | position_embeds = self.position_embeddings(position_ids.astype("i4")) 261 | token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) 262 | 263 | # Sum all embeddings 264 | hidden_states = inputs_embeds + token_type_embeddings + position_embeds 265 | 266 | # Layer Norm 267 | hidden_states = self.LayerNorm(hidden_states) 268 | hidden_states = self.dropout(hidden_states, deterministic=deterministic) 269 | return hidden_states 270 | 271 | 272 | @flax.struct.dataclass 273 | class NewFlaxMaskedLMOutput(FlaxMaskedLMOutput): 274 | logits: jnp.ndarray = None 275 | hidden_states: Optional[Tuple[jnp.ndarray]] = None 276 | attentions: Optional[Tuple[jnp.ndarray]] = None 277 | last_hidden_states: Optional[Tuple[jnp.ndarray]] = None 278 | 279 | 280 | # Removed word embeddings. 
Copied from 281 | # https://github.com/huggingface/transformers/blob/12ce2941c7b67c0dedac0f0468b3ed854fa940ab/src/transformers/models/roberta/modeling_flax_roberta.py#L914-L982 282 | # and 283 | # https://github.com/huggingface/transformers/blob/12ce2941c7b67c0dedac0f0468b3ed854fa940ab/src/transformers/models/roberta/modeling_flax_roberta.py#L998-L1053 284 | class Language_Model(nn.Module): 285 | config: RobertaConfig 286 | bert: str = "roberta-base" 287 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 288 | add_pooling_layer: bool = False 289 | gradient_checkpointing: bool = False 290 | 291 | def setup(self): 292 | self.embeddings = Add_Pos(self.config) 293 | self.encoder = FlaxRobertaEncoder( 294 | self.config, 295 | dtype=self.dtype, 296 | gradient_checkpointing=self.gradient_checkpointing, 297 | ) 298 | self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) 299 | self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) 300 | 301 | pretrained_model = FlaxRobertaForMaskedLM.from_pretrained(self.bert) 302 | self.word_embeddings = pretrained_model.params["roberta"]["embeddings"][ 303 | "word_embeddings" 304 | ]["embedding"] 305 | 306 | @nn.nowrap 307 | def rng_keys(self): 308 | return ("params", "dropout") 309 | 310 | def __call__( 311 | self, 312 | hidden_states, 313 | input_ids, 314 | attention_mask, 315 | token_type_ids: Optional[jnp.ndarray] = None, 316 | position_ids: Optional[jnp.ndarray] = None, 317 | head_mask: Optional[jnp.ndarray] = None, 318 | encoder_hidden_states: Optional[jnp.ndarray] = None, 319 | encoder_attention_mask: Optional[jnp.ndarray] = None, 320 | init_cache: bool = False, 321 | deterministic: bool = True, 322 | output_attentions: bool = False, 323 | output_hidden_states: bool = False, 324 | return_dict: bool = True, 325 | ): 326 | 327 | # make sure `token_type_ids` is correctly initialized when not passed 328 | if token_type_ids is None: 329 | token_type_ids = jnp.zeros_like(input_ids) 330 | 331 | # make sure `position_ids` is correctly initialized when not passed 332 | if position_ids is None: 333 | position_ids = create_position_ids_from_input_ids( 334 | input_ids, self.config.pad_token_id 335 | ) 336 | 337 | hidden_states = self.embeddings( 338 | hidden_states, 339 | token_type_ids, 340 | position_ids, 341 | attention_mask, 342 | deterministic=deterministic, 343 | ) 344 | 345 | outputs = self.encoder( 346 | hidden_states, 347 | attention_mask, 348 | head_mask=head_mask, 349 | deterministic=deterministic, 350 | encoder_hidden_states=encoder_hidden_states, 351 | encoder_attention_mask=encoder_attention_mask, 352 | init_cache=init_cache, 353 | output_attentions=output_attentions, 354 | output_hidden_states=output_hidden_states, 355 | return_dict=return_dict, 356 | ) 357 | hidden_states = outputs[0] 358 | pooled = self.pooler(hidden_states) if self.add_pooling_layer else None 359 | 360 | if not return_dict: 361 | # if pooled is None, don't return it 362 | if pooled is None: 363 | outputs = (hidden_states,) + outputs[1:] 364 | else: 365 | outputs = (hidden_states, pooled) + outputs[1:] 366 | else: 367 | outputs = FlaxBaseModelOutputWithPoolingAndCrossAttentions( 368 | last_hidden_state=hidden_states, 369 | pooler_output=pooled, 370 | hidden_states=outputs.hidden_states, 371 | attentions=outputs.attentions, 372 | cross_attentions=outputs.cross_attentions, 373 | ) 374 | 375 | hidden_states = outputs[0] 376 | if self.config.tie_word_embeddings: 377 | shared_embedding = self.word_embeddings 378 | else: 379 | shared_embedding = None 
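# NOTE (added commentary, not from the HuggingFace source this class adapts):
# for roberta-base, config.tie_word_embeddings is True, so the LM head below
# reuses the frozen input word-embedding matrix as its output projection.
# After the head's dense transform, each position is scored against every
# codebook row, roughly: logits[b, t, v] = transform(hidden)[b, t] @ word_embeddings[v].
# This tying is what lets LQAE map continuous hidden states back onto
# discrete RoBERTa token ids without a separately trained output head.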
380 | 381 | # Compute the prediction scores 382 | logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) 383 | 384 | if not return_dict: 385 | return (logits,) + outputs[1:] 386 | 387 | return NewFlaxMaskedLMOutput( 388 | logits=logits, 389 | hidden_states=outputs.hidden_states, 390 | attentions=outputs.attentions, 391 | last_hidden_states=hidden_states, 392 | ) 393 | 394 | 395 | class LQAE(nn.Module): 396 | config_updates: ... = None 397 | dtype: int = jnp.float32 398 | 399 | @staticmethod 400 | @nn.nowrap 401 | def get_default_config(updates=None): 402 | config = ConfigDict() 403 | 404 | # Quantizer config 405 | config.quantizer_loss_entropy = 0.0 406 | config.entropy_temperature = 0.01 407 | config.entropy_loss_type = "softmax" 408 | config.quantizer_loss_commitment = 0.25 409 | config.l2_normalize = False 410 | config.top_k_value = 1 411 | config.top_k_avg = False 412 | config.top_k_rnd = False 413 | config.quantizer_latent_dim = 0 414 | config.strawman_codebook = False 415 | config.strawman_codebook_init = "normal:1.0" 416 | config.quantizer_loss_perplexity = 0.0 417 | config.dot_product = False 418 | 419 | # ResNet config 420 | config.filters = 128 421 | config.num_res_blocks = 2 422 | config.channel_multipliers = [1, 1, 2, 2, 4] 423 | config.hidden_size = 768 424 | config.conv_downsample = False 425 | 426 | # VIT config 427 | config.vit_encoder_decoder = False 428 | config.vit_model_type = config_dict.placeholder(str) 429 | config.patch_size = 16 430 | config.dropout = 0.0 431 | config.hidden_size = 768 432 | config.mlp_ratio = 4 433 | config.intermediate_size = config.hidden_size * config.mlp_ratio 434 | 435 | # Bert config 436 | config.bert = "roberta-base" 437 | config.bert_min_ratio = 0.15 438 | config.bert_max_ratio = 0.15 439 | config.use_bert_codebook = True 440 | config.bert_loss_mask_only = True 441 | config.bert_mask_loss_weight = 0.0 442 | config.bert_channel_image_loss_weight = 0.0 443 | config.nochannel_image_loss_weight = 0.0 444 | config.use_bert_ste = True 445 | 446 | if updates is not None: 447 | config.update(ConfigDict(updates).copy_and_resolve_references()) 448 | 449 | if config.vit_model_type is not None: 450 | update_vit_config(config.vit_model_type, config) 451 | assert_avg_rnd(config) 452 | 453 | return config 454 | 455 | @nn.nowrap 456 | def get_config(self): 457 | return self.get_default_config(self.config_updates) 458 | 459 | @nn.nowrap 460 | def get_bert_config(self): 461 | config = self.get_default_config(self.config_updates) 462 | return RobertaConfig.from_pretrained(config.bert) 463 | 464 | @nn.nowrap 465 | def get_num_bert_layers(self): 466 | return self.get_bert_config().num_hidden_layers + 1 467 | 468 | @nn.nowrap 469 | def get_num_encoder_layers(self): 470 | return self.get_config().enc_num_layers 471 | 472 | @nn.nowrap 473 | def get_hidden_size(self): 474 | return self.get_config().hidden_size 475 | 476 | @nn.nowrap 477 | def rng_keys(self): 478 | return ("params", "dropout", "drop_path", "shuffle", "mask", "quantizer") 479 | 480 | @nn.nowrap 481 | def no_decay_list(self): 482 | no_decay = ["bias", "embedding"] 483 | return no_decay 484 | 485 | def setup(self): 486 | self.config = self.get_default_config(self.config_updates) 487 | ( 488 | self.lang_model, 489 | self.codebook, 490 | self.mask_code, 491 | self.tokenizer, 492 | ) = self.config_language_model() 493 | self.quantizer = LanguageQuantizer( 494 | config=self.config, 495 | codebook=self.codebook, 496 | dtype=self.dtype, 497 | ) 498 | if self.config.vit_encoder_decoder: 499 
| self.encoder = VitEncoder(config=self.config, dtype=self.dtype) 500 | self.decoder = VitDecoder(config=self.config, dtype=self.dtype) 501 | else: 502 | self.encoder = ResNetEncoder(config=self.config, dtype=self.dtype) 503 | self.decoder = ResNetDecoder(config=self.config, dtype=self.dtype) 504 | 505 | def config_language_model(self): 506 | pretrained_bert = FlaxRobertaForMaskedLM.from_pretrained(self.config.bert) 507 | language_model = Language_Model(pretrained_bert.config, self.config.bert) 508 | codebook = pretrained_bert.params["roberta"]["embeddings"]["word_embeddings"][ 509 | "embedding" 510 | ] 511 | 512 | tokenizer = RobertaTokenizer.from_pretrained(self.config.bert) 513 | mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) 514 | mask_code = codebook[mask_token_id] 515 | 516 | return language_model, codebook, mask_code, tokenizer 517 | 518 | @nn.nowrap 519 | def load_bert_params(self, params, pretrain=True): 520 | config = self.get_default_config(self.config_updates) 521 | pretrained_bert = FlaxRobertaForMaskedLM.from_pretrained(config.bert) 522 | pretrained_params = flax.core.unfreeze(pretrained_bert.params) 523 | 524 | if pretrain: 525 | lang_params = flax.core.unfreeze(params)["params"]["lang_model"] 526 | else: 527 | lang_params = flax.core.unfreeze(params)["params"]['backbone']["lang_model"] 528 | 529 | for key in lang_params.keys(): 530 | if key == "embeddings": 531 | for k in lang_params["embeddings"].keys(): 532 | assert ( 533 | k in pretrained_params["roberta"]["embeddings"].keys() 534 | ), f"pretrained model miss key={key}" 535 | lang_params["embeddings"][k] = pretrained_params["roberta"][ 536 | "embeddings" 537 | ][k] 538 | elif key == "lm_head": 539 | assert ( 540 | key in pretrained_params.keys() 541 | ), f"pretrained model miss key={key}" 542 | lang_params[key] = pretrained_params[key] 543 | else: 544 | assert ( 545 | key in pretrained_params["roberta"].keys() 546 | ), f"pretrained model miss key={key}" 547 | lang_params[key] = pretrained_params["roberta"][key] 548 | 549 | params = flax.core.unfreeze(params) 550 | if pretrain: 551 | params["params"].update({"lang_model": lang_params}) 552 | else: 553 | params["params"]["backbone"].update({"lang_model": lang_params}) 554 | params = flax.core.freeze(params) 555 | return params 556 | 557 | def languge_model_encode_decode( 558 | self, 559 | input_code, 560 | input_ids, 561 | ratio={}, 562 | output_hidden_states=False, 563 | ): 564 | input_shape = input_code.shape 565 | if len(input_code.shape) == 4: 566 | input_code = jnp.reshape( 567 | input_code, (input_code.shape[0], -1, input_code.shape[-1]) 568 | ) 569 | input_ids = jnp.reshape(input_ids, (input_ids.shape[0], -1)) 570 | 571 | min_ratio = ratio.get("min_ratio", self.config.bert_min_ratio) 572 | max_ratio = ratio.get("max_ratio", self.config.bert_max_ratio) 573 | assert min_ratio <= max_ratio, "min_ratio must be less than max_ratio" 574 | use_mask = random_ratio_mask( 575 | jnp.zeros((input_code.shape[0], input_code.shape[1])), 576 | min_ratio, 577 | max_ratio, 578 | self.make_rng("mask"), 579 | ).astype(bool) 580 | input_code = jnp.where( 581 | use_mask[..., None], self.mask_code[None, None, ...], input_code 582 | ) 583 | 584 | attention_mask = jnp.ones( 585 | (input_code.shape[0], input_code.shape[1]), dtype=jnp.uint8 586 | ) 587 | bert_output = self.lang_model( 588 | input_code, 589 | input_ids, 590 | attention_mask, 591 | output_hidden_states=output_hidden_states, 592 | deterministic=True, 593 | ) 594 | if self.config.use_bert_ste: 595 | logits 
= bert_output.logits 596 | decoding_indices = jnp.argmax(logits, axis=-1) 597 | codebook_size = self.codebook.shape[0] 598 | encodings = jax.nn.one_hot(decoding_indices, codebook_size, dtype=self.dtype) 599 | argmax_code = jnp.dot(encodings, self.codebook) 600 | softmax_code = jnp.dot(jax.nn.softmax(logits, axis=-1), self.codebook) 601 | output = softmax_code + jax.lax.stop_gradient(argmax_code - softmax_code) 602 | output = jnp.reshape(output, input_shape) 603 | else: 604 | output = bert_output.last_hidden_states 605 | output = jnp.reshape(output, input_shape) 606 | 607 | logits = bert_output.logits 608 | bert_loss = optax.softmax_cross_entropy( 609 | logits, get_onehot(input_ids, logits.shape[-1]) 610 | ) 611 | if self.config.bert_loss_mask_only: 612 | bert_loss = bert_loss * use_mask 613 | bert_loss = jnp.sum(bert_loss, axis=1) / jnp.sum(use_mask, axis=1) 614 | 615 | bert_loss = jnp.mean(bert_loss) * self.config.bert_mask_loss_weight 616 | language_model_output = { 617 | "bert_logits": bert_output.logits, 618 | "bert_hidden_states": bert_output.hidden_states, 619 | "bert_loss": bert_loss, 620 | } 621 | return output, language_model_output 622 | 623 | def encode(self, image, train: bool): 624 | encoded_feature, _ = self.encoder(image, train) 625 | quantized, result_dict = self.quantizer( 626 | encoded_feature, train, rng=self.make_rng("quantizer") 627 | ) 628 | return quantized, result_dict 629 | 630 | def forward_image_representation(self, image, train): 631 | output = {} 632 | encoded_feature, encoder_embedding = self.encoder(image, train) 633 | if encoder_embedding is not None: 634 | encoder_embedding = jax.tree_util.tree_map( 635 | lambda x: jnp.reshape(x, (x.shape[0], -1, x.shape[-1])), 636 | encoder_embedding, 637 | ) 638 | output["encoder_embedding"] = encoder_embedding 639 | else: 640 | encoded_feature = jax.tree_util.tree_map( 641 | lambda x: jnp.reshape(x, (x.shape[0], -1, x.shape[-1])), encoded_feature 642 | ) 643 | output["encoder_embedding"] = [encoded_feature] 644 | quantized, result_dict = self.quantizer( 645 | encoded_feature, train, rng=self.make_rng("quantizer") 646 | ) 647 | _, language_model_output = self.languge_model_encode_decode( 648 | quantized, 649 | result_dict["encoding_indices"], 650 | ratio={"min_ratio": 0, "max_ratio": 0}, 651 | output_hidden_states=True, 652 | ) 653 | bert_embedding = language_model_output["bert_hidden_states"] 654 | output["bert_embedding"] = bert_embedding 655 | return output 656 | 657 | def decode(self, x: jnp.ndarray, train: bool) -> jnp.ndarray: 658 | reconstructed = self.decoder(x, train) 659 | return reconstructed 660 | 661 | def get_codebook_funct(self): 662 | return self.quantizer.get_codebook() 663 | 664 | def decode_from_indices(self, inputs, train): 665 | if isinstance(inputs, dict): 666 | ids = inputs["encoding_indices"] 667 | else: 668 | ids = inputs 669 | features = self.quantizer.decode_ids(ids) 670 | reconstructed_image = self.decode(features, train) 671 | return reconstructed_image 672 | 673 | def encode_to_indices(self, inputs, train): 674 | if isinstance(inputs, dict): 675 | image = inputs["image"] 676 | else: 677 | image = inputs 678 | encoded_feature, _ = self.encoder(image, train) 679 | _, result_dict = self.quantizer( 680 | encoded_feature, train, rng=self.make_rng("quantizer") 681 | ) 682 | ids = result_dict["encoding_indices"] 683 | return ids 684 | 685 | def __call__(self, image, train, ratio={}): 686 | quantized, result_dict = self.encode(image, train) 687 | bert_quantized, language_model_output = 
self.languge_model_encode_decode( 688 | quantized, result_dict["encoding_indices"], ratio 689 | ) 690 | result_dict = {**result_dict, **language_model_output} 691 | image_output = self.decoder(quantized, train) 692 | bert_channel_image_output = self.decoder(bert_quantized, train) 693 | output = { 694 | "image_output": image_output, 695 | "bert_channel_image_output": bert_channel_image_output, 696 | } 697 | return output, result_dict 698 | 699 | 700 | def random_mask(x, ratio, rng): 701 | return (jax.random.uniform(rng, shape=x.shape[:2]) < ratio).astype(jnp.float32) 702 | 703 | 704 | def random_ratio_mask(x, min_ratio, max_ratio, rng): 705 | rng_generator = JaxRNG(rng) 706 | ratio = jax.random.uniform( 707 | rng_generator(), shape=x.shape[:2], minval=min_ratio, maxval=max_ratio 708 | ) 709 | return random_mask(x, ratio, rng_generator()) 710 | --------------------------------------------------------------------------------
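Both `LanguageQuantizer.__call__` (`quantized = x + jax.lax.stop_gradient(quantized - x)`) and the `use_bert_ste` branch (`softmax_code + jax.lax.stop_gradient(argmax_code - softmax_code)`) rely on the straight-through estimator: the forward pass emits a hard, discrete value while the backward pass treats the operation as the identity, so gradients still reach the continuous inputs. A minimal self-contained sketch of the pattern, using a hypothetical `ste_round` as a stand-in for the repo's codebook lookup:

```python
import jax
import jax.numpy as jnp

def ste_round(x):
    # Forward pass: round(x). Backward pass: identity, because the
    # discrete residual is wrapped in stop_gradient.
    return x + jax.lax.stop_gradient(jnp.round(x) - x)

# The gradient behaves as if rounding were the identity map:
# d/dx sum(ste_round(x) ** 2) evaluates to 2 * round(x).
grads = jax.grad(lambda x: jnp.sum(ste_round(x) ** 2))(jnp.array([0.3, 1.7]))
print(grads)  # [0. 4.]
```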