├── .gitignore
├── README.md
├── fundamentals.py
└── transformer.py


/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
.idea/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# JAX-examples

Small, self-contained JAX examples: `fundamentals.py` demonstrates grad, jit,
vmap, pmap, explicit PRNG keys and the profiler; `transformer.py` trains a small
transformer language model on LM1B with Haiku and Optax.
--------------------------------------------------------------------------------
/fundamentals.py:
--------------------------------------------------------------------------------
import jax
import jax.numpy as jnp
import numpy as np

# NumPy arrays live on the host; jax.numpy arrays live on the default JAX device.
x = np.zeros(10)
y = jnp.zeros(10)

print(x)
print(y)

x = np.random.rand(1000, 1000)
y = jnp.array(x)

# Commented out IPython magic to ensure Python compatibility.
# %timeit -n 1 -r 1 np.dot(x, x)

# Commented out IPython magic to ensure Python compatibility.
# %timeit -n 1 -r 1 jnp.dot(y, y).block_until_ready()

"""## Automatic differentiation with grad"""

from jax import grad

def f(x):
    return 3*x**2 + 2*x + 5

def f_prime(x):
    # Analytic derivative of f, for comparison with grad(f).
    return 6*x + 2

print(grad(f)(1.0))   # 8.0
print(f_prime(1.0))   # 8.0

"""## XLA and jit"""

from jax import jit

x = np.random.rand(1000, 1000)
y = jnp.array(x)

def f(x):
    for _ in range(10):
        x = 0.5*x + 0.1*jnp.sin(x)
    return x

g = jit(f)

# Commented out IPython magic to ensure Python compatibility.
# %timeit -n 5 -r 5 f(y).block_until_ready()

# Commented out IPython magic to ensure Python compatibility.
# %timeit -n 5 -r 5 g(y).block_until_ready()

"""## pmap"""

from jax import pmap

def f(x):
    return jnp.sin(x) + x**2

print(f(np.arange(4)))
# pmap(f)(np.arange(4))
# Note: Colab exposes only a single accelerator, so pmap across multiple
# devices can't be demonstrated there.

from functools import partial
from jax.lax import psum

@partial(pmap, axis_name="i")
def normalize(x):
    return x / psum(x, "i")

# pmap needs one device per leading-axis element, so guard the call to keep the
# script runnable on a single-device machine such as Colab.
if jax.device_count() >= 8:
    print(normalize(np.arange(8.)))

"""## vmap"""

from jax import vmap

def f(x):
    return jnp.square(x)

print(f(jnp.arange(5)))
print(vmap(f)(jnp.arange(5)))

"""## Pseudo-random number generation"""

from jax import random

key = random.PRNGKey(5)
print(random.uniform(key))

"""## Profiler"""

import jax.profiler

def func1(x):
    return jnp.tile(x, 10) * 0.5

def func2(x):
    y = func1(x)
    return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

# Writes a pprof-compatible snapshot of live device memory.
jax.profiler.save_device_memory_profile("memory.prof")
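"""## Key splitting (added sketch)"""

# Not part of the original notebook: a minimal sketch of how JAX's explicit
# PRNG keys are split so that each draw uses a fresh subkey. The seed (5)
# mirrors the example above; the rest is illustrative.
key = random.PRNGKey(5)
key, subkey_a, subkey_b = random.split(key, 3)
print(random.uniform(subkey_a))  # One independent draw...
print(random.uniform(subkey_b))  # ...and a different one.
# The same key always reproduces the same sample:
print(random.uniform(subkey_a) == random.uniform(subkey_a))  # True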
--------------------------------------------------------------------------------
/transformer.py:
--------------------------------------------------------------------------------
import functools
import logging
import time
from typing import NamedTuple, Optional, Any, Mapping, Tuple

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds


class LanguageDataset(NamedTuple):
  records: tf.data.Dataset
  vocab_size: int


def load(batch_size: int, sequence_length: int) -> LanguageDataset:
  """Loads the LM1B dataset, returning a batch iterator and the vocab size."""
  ds, ds_info = tfds.load(
      'lm1b/subwords32k',
      split=tfds.Split.TRAIN,
      shuffle_files=True,
      with_info=True)

  crop_size = sequence_length + 1
  ds = ds.repeat()
  # Convert the dataset to constant-size int32 tensors.
  ds = ds.map(lambda d: tf.cast(d['text'], tf.int32))
  ds = ds.map(lambda t: _crop_or_pad(t, crop_size, pad_token=0))
  ds = ds.shuffle(batch_size * 10)
  # Create the language modeling observation/target pairs and batch them up.
  ds = ds.map(lambda t: dict(obs=t[:-1], target=t[1:]))
  ds = ds.batch(batch_size, drop_remainder=True)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  ds = iter(tfds.as_numpy(ds))
  return LanguageDataset(ds, ds_info.features['text'].encoder.vocab_size)


def _crop_or_pad(value, size, pad_token):
  """Either crops or pads `value` to be of length `size`."""
  val_size = tf.size(value)
  pad = lambda: tf.pad(  # pylint: disable=g-long-lambda
      value, [[0, size - val_size]],
      'CONSTANT',
      constant_values=pad_token)
  return tf.cond(val_size < size, pad, lambda: value[:size])


############### MODEL ##########################


class SelfAttention(hk.MultiHeadAttention):
  """Self-attention with a causal mask applied."""

  def __call__(
      self,
      query: jnp.ndarray,
      key: Optional[jnp.ndarray] = None,
      value: Optional[jnp.ndarray] = None,
      mask: Optional[jnp.ndarray] = None,
  ) -> jnp.ndarray:
    key = key if key is not None else query
    value = value if value is not None else query

    seq_len = query.shape[1]
    # Lower-triangular [T, T] mask: position t may only attend to positions <= t.
    causal_mask = np.tril(np.ones((seq_len, seq_len)))
    mask = mask * causal_mask if mask is not None else causal_mask

    return super().__call__(query, key, value, mask)


class DenseBlock(hk.Module):
  """A 2-layer MLP with a GELU nonlinearity."""

  def __init__(self,
               init_scale: float,
               widening_factor: int = 4,
               name: Optional[str] = None):
    super().__init__(name=name)
    self._init_scale = init_scale
    self._widening_factor = widening_factor

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    hiddens = x.shape[-1]
    initializer = hk.initializers.VarianceScaling(self._init_scale)
    x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
    x = jax.nn.gelu(x)
    return hk.Linear(hiddens, w_init=initializer)(x)
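# Added sketch (not part of the original script): a quick shape check of the
# causal SelfAttention block above, run through hk.transform. The toy sizes
# (batch=2, seq_len=8, width=64) and the head/key sizes are assumptions; the
# function is not called by the training code below.
def _self_attention_shape_check() -> jnp.ndarray:
  def fwd(x):
    return SelfAttention(num_heads=2, key_size=32, w_init_scale=1.0)(x)

  attn = hk.transform(fwd)
  dummy = jnp.zeros([2, 8, 64])
  params = attn.init(jax.random.PRNGKey(0), dummy)
  return attn.apply(params, None, dummy)  # -> shape [2, 8, 64]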
class Transformer(hk.Module):
  """A transformer stack."""

  def __init__(self,
               num_heads: int,
               num_layers: int,
               dropout_rate: float,
               name: Optional[str] = None):
    super().__init__(name=name)
    self._num_layers = num_layers
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate

  def __call__(self,
               h: jnp.ndarray,
               mask: Optional[jnp.ndarray],
               is_training: bool) -> jnp.ndarray:
    """Connects the transformer.

    Args:
      h: Inputs, [B, T, H].
      mask: Padding mask, [B, T].
      is_training: Whether we're training or not.

    Returns:
      Array of shape [B, T, H].
    """

    init_scale = 2. / self._num_layers
    dropout_rate = self._dropout_rate if is_training else 0.
    if mask is not None:
      mask = mask[:, None, None, :]

    # Pre-norm residual blocks: LayerNorm -> attention / MLP -> dropout -> add.
    for i in range(self._num_layers):
      h_norm = layer_norm(h, name=f'h{i}_ln_1')
      h_attn = SelfAttention(
          num_heads=self._num_heads,
          key_size=64,
          w_init_scale=init_scale,
          name=f'h{i}_attn')(h_norm, mask=mask)
      h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
      h = h + h_attn
      h_norm = layer_norm(h, name=f'h{i}_ln_2')
      h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
      h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
      h = h + h_dense
    h = layer_norm(h, name='ln_f')

    return h


def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
  """Applies a unique LayerNorm to x with default settings."""
  return hk.LayerNorm(axis=-1,
                      create_scale=True,
                      create_offset=True,
                      name=name)(x)


######################################### TRAIN #########################################


batch_size = 16          # Train batch size per core.
sequence_length = 128    # Sequence length to learn on.

d_model = 256            # Model width.
num_heads = 4            # Number of attention heads.
num_layers = 6           # Number of transformer layers.
dropout_rate = 0.1       # Dropout rate.

learning_rate = 2e-4     # Max learning rate.
grad_clip_value = 0.25   # Gradient norm clip value.

checkpoint_dir = '/jax-transformer'  # Directory to store checkpoints.
LOG_EVERY = 50
MAX_STEPS = 10 ** 6
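# Added sketch (not part of the original script): learning_rate above is
# described as a *max* learning rate, which hints at a warmup/decay schedule.
# One common optax pattern is shown below; the warmup length is an arbitrary
# assumption and main() does not use this helper.
def make_optimizer_with_warmup(
    warmup_steps: int = 1000) -> optax.GradientTransformation:
  """Gradient clipping + Adam with linear warmup and cosine decay."""
  lr_schedule = optax.warmup_cosine_decay_schedule(
      init_value=0.0,
      peak_value=learning_rate,
      warmup_steps=warmup_steps,
      decay_steps=MAX_STEPS)
  return optax.chain(
      optax.clip_by_global_norm(grad_clip_value),
      optax.adam(lr_schedule, b1=0.9, b2=0.99))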
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int,
               d_model: int) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Embeds tokens and positions; also returns the padding mask."""
  tokens = data['obs']
  input_mask = jnp.greater(tokens, 0)
  seq_length = tokens.shape[1]

  # Embed the input tokens and positions.
  embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
  token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
  token_embs = token_embedding_map(tokens)
  positional_embeddings = hk.get_parameter(
      'pos_embs', [seq_length, d_model], init=embed_init)
  input_embeddings = token_embs + positional_embeddings
  return input_embeddings, input_mask


def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
  """Creates the model's forward pass."""

  def forward_fn(data: Mapping[str, jnp.ndarray],
                 is_training: bool = True) -> jnp.ndarray:
    """Forward pass."""
    input_embeddings, input_mask = embeddings(data, vocab_size, d_model)

    # Run the transformer over the inputs.
    transformer = Transformer(
        num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
    output_embeddings = transformer(input_embeddings, input_mask, is_training)

    # Reverse the embeddings (untied).
    return hk.Linear(vocab_size)(output_embeddings)

  return forward_fn


def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
  """Computes the loss on `data` w.r.t. `params`."""
  logits = forward_fn(params, rng, data, is_training)
  targets = jax.nn.one_hot(data['target'], vocab_size)
  assert logits.shape == targets.shape

  # Mask out padding tokens (id 0) so they don't contribute to the loss.
  mask = jnp.greater(data['obs'], 0)
  loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
  loss = jnp.sum(loss * mask) / jnp.sum(mask)

  return loss


class GradientUpdater:
  """A stateless abstraction around an init_fn/update_fn pair.

  This extracts some common boilerplate from the training loop.
  """

  def __init__(self, net_init, loss_fn,
               optimizer: optax.GradientTransformation):
    self._net_init = net_init
    self._loss_fn = loss_fn
    self._opt = optimizer

  @functools.partial(jax.jit, static_argnums=0)
  def init(self, master_rng, data):
    """Initializes the state of the updater."""
    out_rng, init_rng = jax.random.split(master_rng)
    params = self._net_init(init_rng, data)
    opt_state = self._opt.init(params)
    out = dict(
        step=np.array(0),
        rng=out_rng,
        opt_state=opt_state,
        params=params,
    )
    return out

  @functools.partial(jax.jit, static_argnums=0)
  def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
    """Updates the state using some data and returns metrics."""
    rng, new_rng = jax.random.split(state['rng'])
    params = state['params']
    loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

    updates, opt_state = self._opt.update(g, state['opt_state'])
    params = optax.apply_updates(params, updates)

    new_state = {
        'step': state['step'] + 1,
        'rng': new_rng,
        'opt_state': opt_state,
        'params': params,
    }

    metrics = {
        'step': state['step'],
        'loss': loss,
    }
    return new_state, metrics
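# Added sketch (not part of the original script): checkpoint_dir is declared
# above but never used, so this shows one simple way it could be wired in. The
# pickle-based format and file naming are assumptions, and main() below does
# not call this helper.
def save_checkpoint(state: Mapping[str, Any], step: int) -> str:
  """Writes a pickled host copy of the training state and returns its path."""
  import os
  import pickle

  os.makedirs(checkpoint_dir, exist_ok=True)
  path = os.path.join(checkpoint_dir, f'checkpoint_{step:08d}.pkl')
  host_state = jax.device_get(state)  # Pull device arrays back to host NumPy.
  with open(path, 'wb') as f:
    pickle.dump(host_state, f)
  return path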
def main():
  # Create the dataset.
  train_dataset, vocab_size = load(batch_size, sequence_length)

  # Set up the model, loss, and updater.
  forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                num_layers, dropout_rate)
  forward_fn = hk.transform(forward_fn)
  loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

  optimizer = optax.chain(
      optax.clip_by_global_norm(grad_clip_value),
      optax.adam(learning_rate, b1=0.9, b2=0.99))

  updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

  # Initialize parameters.
  logging.info('Initializing parameters...')
  rng = jax.random.PRNGKey(428)
  data = next(train_dataset)
  state = updater.init(rng, data)

  logging.info('Starting train loop...')
  prev_time = time.time()
  for step in range(MAX_STEPS):
    data = next(train_dataset)
    state, metrics = updater.update(state, data)
    # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
    # Using values from state/metrics too often will block the runahead and can
    # cause these overheads to become more prominent, so only log every
    # LOG_EVERY steps.
    if step % LOG_EVERY == 0:
      steps_per_sec = LOG_EVERY / (time.time() - prev_time)
      prev_time = time.time()
      metrics.update({'steps_per_sec': steps_per_sec})
      logging.info({k: float(v) for k, v in metrics.items()})


if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  main()
--------------------------------------------------------------------------------