├── .gitignore ├── EasyLM ├── __init__.py ├── bpt.py ├── checkpoint.py ├── data.py ├── jax_utils.py ├── models │ ├── __init__.py │ └── llama │ │ ├── convert_easylm_to_hf.py │ │ ├── convert_hf_to_easylm.py │ │ ├── llama_model.py │ │ ├── llama_serve.py │ │ └── llama_train.py ├── optimizers.py ├── scripts │ ├── __init__.py │ ├── benchmark_attention.py │ ├── convert_checkpoint.py │ ├── diff_checkpoint.py │ ├── lm_eval_harness.py │ └── lm_eval_json.py └── serving.py ├── LICENSE ├── README.md ├── docs ├── README.md ├── checkpointing.md ├── dataset.md ├── evaluation.md ├── koala.md ├── llama.md ├── logger.md ├── optimizer.md ├── parallelism.md └── serving.md ├── examples ├── pretrain_llama_7b.sh └── serve_llama_7b.sh └── scripts ├── gpu_environment.yml ├── tpu_commands.sh └── tpu_vm_setup.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | launcher/ 132 | -------------------------------------------------------------------------------- /EasyLM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/young-geng/EasyLM/fe5b2c354e25d697fce7cd225e23bbbe72570da3/EasyLM/__init__.py -------------------------------------------------------------------------------- /EasyLM/bpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of Blockwise parallel transformer https://arxiv.org/abs/2305.19370 3 | Also include a reference implementation of memory-efficient transformer https://arxiv.org/abs/2112.05682 4 | """ 5 | 6 | import functools 7 | from typing import NamedTuple 8 | 9 | import flax.linen as nn 10 | import jax 11 | import jax.lax as lax 12 | import jax.numpy as jnp 13 | from einops import rearrange 14 | 15 | """ 16 | Computing ffn blockwise without materializing the large hidden tensor, training 17 | 4x longer sequences than the memory-efficient transformer. 18 | Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023 19 | """ 20 | def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True): 21 | # remat_ffn: a rematerialized ffn with policy jax.checkpoint_policies.nothing_saveable() 22 | # inputs: (batch, seq_len, dim) 23 | # chunk_size: the chunk size to split the sequence 24 | inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size) 25 | def scan_ffn(remat_ffn, carry, hidden_states): 26 | outputs = remat_ffn(hidden_states, deterministic=deterministic) 27 | return carry, outputs 28 | scan_axis = inputs.ndim - 2 29 | _, res = nn.scan( 30 | scan_ffn, 31 | variable_broadcast="params", 32 | split_rngs={"params": False, "dropout": True}, 33 | in_axes=scan_axis, 34 | out_axes=scan_axis, 35 | )(remat_ffn, None, inputs) 36 | res = rearrange(res, 'b c n d -> b (c n) d') 37 | return res 38 | 39 | 40 | """ 41 | Compute attention blockwise without materializing the full attention matrix, 42 | initially proposed in memory-efficient transformer https://arxiv.org/abs/2112.05682 Rabe et al. 2021; 43 | flash attention https://arxiv.org/abs/2205.14135 Dao et al. 2022 proposes a CUDA 44 | efficient implementation; blockwise parallel transformer https://arxiv.org/abs/2305.19370 45 | Liu et al. 2023 proposes blockwise computing both attention and FFN, enabling 4x 46 | longer sequences than memory-efficient/flash-attention and fusion of attention and FFN. 
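As an illustrative sizing sketch (numbers are examples, not taken from the papers):
with seq_len = 8192 and the default query/key chunk size of 2048, the scan below
covers 4 x 4 = 16 block pairs of (2048, 2048) logits per batch and head instead of
ever materializing the full 8192 x 8192 attention matrix, carrying a running max,
numerator and denominator so the softmax result stays exact.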
47 | """ 48 | def blockwise_attn( 49 | query, key, value, 50 | bias=None, 51 | deterministic=True, 52 | dropout_rng=None, 53 | attn_pdrop=0.0, 54 | causal=True, 55 | query_chunk_size=2048, 56 | key_chunk_size=2048, 57 | dtype=jnp.float32, 58 | policy=jax.checkpoint_policies.nothing_saveable(), 59 | precision=None, 60 | float32_logits=True, 61 | prevent_cse=True, 62 | ): 63 | # query, key, value: (batch, seq_len, num_heads, dim_per_head) 64 | # bias: (batch, seq_len) can be used to mask out attention (e.g. padding) 65 | # causal: whether to use causal mask 66 | # policy: one of jax.checkpoint_policies 67 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype) 68 | if float32_logits: 69 | query = query.astype(jnp.float32) 70 | key = key.astype(jnp.float32) 71 | 72 | batch, q_len, num_heads, dim_per_head = query.shape 73 | batch, kv_len, num_heads, dim_per_head = key.shape 74 | batch, kv_len, num_heads, dim_per_head = value.shape 75 | 76 | num_q = q_len // query_chunk_size 77 | num_kv = kv_len // key_chunk_size 78 | query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head)) 79 | key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head)) 80 | value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head)) 81 | 82 | query = jnp.moveaxis(query, 1, 0) 83 | key = jnp.moveaxis(key, 1, 0) 84 | value = jnp.moveaxis(value, 1, 0) 85 | 86 | if bias is not None: 87 | for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)): 88 | assert bias_dim == 1 or bias_dim == broadcast_dim 89 | if not deterministic and attn_pdrop > 0.0: 90 | attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng) 91 | attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len)) 92 | else: 93 | attn_dropout = None 94 | 95 | _chunk_bias_fn = functools.partial( 96 | _chunk_attention_bias, 97 | query_chunk_size, key_chunk_size, bias, deterministic, 98 | attn_dropout, attn_pdrop, causal, dtype) 99 | 100 | def scan_attention(args): 101 | query_chunk, query_chunk_idx = args 102 | 103 | @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy) 104 | def scan_kv_block(carry, args): 105 | key_chunk, value_chunk, key_chunk_idx = args 106 | (numerator, denominator, prev_max_score) = carry 107 | attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision) 108 | bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx) 109 | bias_chunk = jnp.moveaxis(bias_chunk, 1, 2) 110 | attn_weights = attn_weights + bias_chunk 111 | 112 | max_score = jnp.max(attn_weights, axis=-1, keepdims=True) 113 | max_score = jnp.maximum(prev_max_score, max_score) 114 | max_score = jax.lax.stop_gradient(max_score) 115 | exp_weights = jnp.exp(attn_weights - max_score) 116 | exp_values = jnp.einsum( 117 | 'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision 118 | ) 119 | correction = jnp.exp(prev_max_score - max_score) 120 | numerator = numerator * correction + exp_values 121 | denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True) 122 | return Carry(numerator, denominator, max_score), None 123 | 124 | def skip_upper_half(carry, args): 125 | key_chunk, value_chunk, key_chunk_idx = args 126 | skip_block = jnp.array(False) 127 | if causal: 128 | skip_block = query_chunk_idx < key_chunk_idx 129 | return jax.lax.cond( 130 | skip_block, 131 | lambda carry, args: (carry, None), 132 | scan_kv_block, 133 | carry, 134 | args, 135 | ) 136 | 137 | init_carry = Carry( 138 | 
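            # Streaming-softmax state carried across key/value chunks:
            #   numerator   accumulates sum_k exp(scores_k - max) @ value_k,
            #   denominator accumulates sum_k exp(scores_k - max),
            #   max_so_far  is the running max used via `correction` in
            #   scan_kv_block to rescale both accumulators whenever it grows.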
jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype), 139 | jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype), 140 | (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype), 141 | ) 142 | (numerator, denominator, max_score), _ = lax.scan( 143 | skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv)) 144 | ) 145 | outputs = (numerator / denominator).astype(dtype) 146 | return outputs 147 | 148 | _, res = lax.scan( 149 | lambda _, x: ((), scan_attention(x)), 150 | (), xs=(query, jnp.arange(0, num_q)) 151 | ) 152 | res = rearrange(res, 'n b c h d -> b (n c) h d') 153 | return res 154 | 155 | 156 | class Carry(NamedTuple): 157 | numerator: jax.Array 158 | denominator: jax.Array 159 | max_so_far: jax.Array 160 | 161 | 162 | def _chunk_attention_bias(query_chunk_size, key_chunk_size, 163 | bias, deterministic, attn_dropout, attn_pdrop, causal, 164 | dtype, query_chunk_idx, key_chunk_idx): 165 | query_offset = query_chunk_idx * query_chunk_size 166 | key_offset = key_chunk_idx * key_chunk_size 167 | chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype) 168 | if bias is not None: 169 | chunk_bias = lax.dynamic_slice( 170 | bias, 171 | start_indices=(0, 0, query_offset, key_offset), 172 | slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)), 173 | ) 174 | 175 | if causal: 176 | query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0) 177 | key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1) 178 | offset = query_offset - key_offset 179 | query_idx += offset 180 | causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min 181 | chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape) 182 | 183 | if not deterministic and attn_pdrop > 0.0: 184 | attn_dropout_slice = lax.dynamic_slice( 185 | attn_dropout, 186 | start_indices=(0, 0, query_offset, key_offset), 187 | slice_sizes=( 188 | *attn_dropout.shape[:2], 189 | min(attn_dropout.shape[-2], query_chunk_size), 190 | min(attn_dropout.shape[-1], key_chunk_size), 191 | ), 192 | ) 193 | chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min 194 | return chunk_bias.astype(dtype) 195 | 196 | 197 | if __name__ == '__main__': 198 | # test 199 | def reference_attn(query, key, value, causal, dtype): 200 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype) 201 | logits = jnp.einsum("bqhc,bkhc->bhqk", query, key) 202 | if causal: 203 | mask_value = jnp.finfo(logits.dtype).min 204 | _, q_seq_len, _, _ = query.shape 205 | _, kv_seq_len, _, _ = key.shape 206 | mask_shape = (q_seq_len, kv_seq_len) 207 | row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) 208 | col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) 209 | causal_mask = (row_ids < col_ids)[None, None, :, :] 210 | logits = logits + jnp.where(causal_mask, mask_value, 0.0) 211 | weights = jax.nn.softmax(logits, axis=-1) 212 | out = jnp.einsum("bhqk,bkhc->bqhc", weights, value) 213 | return out 214 | 215 | # random inputs 216 | shape = (1, 32, 8, 64) 217 | query = jax.random.normal(jax.random.PRNGKey(0), shape) 218 | key = jax.random.normal(jax.random.PRNGKey(1), shape) 219 | value = jax.random.normal(jax.random.PRNGKey(2), shape) 220 | 221 | causal = True 222 | chunk_size = 4 223 | policy = jax.checkpoint_policies.nothing_saveable() 224 | 225 | blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, 
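                                # positional args above: bias, deterministic, dropout_rng, attn_pdrop,
                                # causal, query/key chunk sizes; the next line supplies dtype, policy,
                                # precision, float32_logits and prevent_cse.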
jnp.float32, policy, 'float32', True, False) 226 | reference = reference_attn(query, key, value, causal, 'float32') 227 | 228 | assert jnp.allclose(reference, blockwise, atol=1e-6) 229 | -------------------------------------------------------------------------------- /EasyLM/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from ml_collections import ConfigDict 4 | import mlxu 5 | import jax 6 | import jax.numpy as jnp 7 | import flax 8 | from flax.serialization import ( 9 | from_bytes, to_bytes, to_state_dict, from_state_dict 10 | ) 11 | from flax.traverse_util import flatten_dict, unflatten_dict, empty_node 12 | import msgpack 13 | 14 | from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype 15 | 16 | 17 | class StreamingCheckpointer(object): 18 | """ Custom msgpack checkpointer that saves large train states by serializing 19 | and saving tensors one by one in a streaming fashion. Avoids running 20 | out of memory or local TPU disk with default flax checkpointer. 21 | """ 22 | 23 | @staticmethod 24 | def get_default_config(updates=None): 25 | config = ConfigDict() 26 | config.float_dtype = 'bf16' 27 | config.save_optimizer_state = False 28 | 29 | if updates is not None: 30 | config.update(ConfigDict(updates).copy_and_resolve_references()) 31 | return config 32 | 33 | def __init__(self, config, checkpoint_dir, enable=True): 34 | self.config = self.get_default_config(config) 35 | self.checkpoint_dir = checkpoint_dir 36 | self.enable = enable 37 | 38 | def save_checkpoint(self, train_state, filename, gather_fns=None): 39 | if self.enable: 40 | path = os.path.join(self.checkpoint_dir, filename) 41 | else: 42 | path = '/dev/null' 43 | self.save_train_state_to_file( 44 | train_state, path, gather_fns, self.config.float_dtype 45 | ) 46 | 47 | @staticmethod 48 | def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None): 49 | train_state = to_state_dict(train_state) 50 | packer = msgpack.Packer() 51 | flattend_train_state = flatten_dict(train_state) 52 | if gather_fns is not None: 53 | gather_fns = flatten_dict(to_state_dict(gather_fns)) 54 | 55 | with mlxu.open_file(path, "wb") as fout: 56 | for key, value in flattend_train_state.items(): 57 | if gather_fns is not None: 58 | value = gather_fns[key](value) 59 | value = float_tensor_to_dtype(value, float_dtype) 60 | fout.write(packer.pack((key, to_bytes(value)))) 61 | 62 | def save_pickle(self, obj, filename): 63 | if self.enable: 64 | path = os.path.join(self.checkpoint_dir, filename) 65 | else: 66 | path = '/dev/null' 67 | mlxu.save_pickle(obj, path) 68 | 69 | def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False): 70 | step = int(jax.device_get(train_state.step)) 71 | if self.config.save_optimizer_state: 72 | checkpoint_state = train_state 73 | checkpoint_name = 'streaming_train_state' 74 | checkpoint_gather_fns = gather_fns 75 | else: 76 | checkpoint_state = train_state.params['params'] 77 | checkpoint_name = 'streaming_params' 78 | checkpoint_gather_fns = gather_fns.params['params'] 79 | 80 | if milestone: 81 | # Save a milestone checkpoint that will not be overwritten 82 | self.save_pickle(metadata, f'metadata_{step}.pkl') 83 | self.save_pickle(dataset, f'dataset_{step}.pkl') 84 | self.save_checkpoint( 85 | checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns 86 | ) 87 | else: 88 | # Save a normal checkpoint that can be overwritten 89 | self.save_pickle(metadata, 
'metadata.pkl') 90 | self.save_pickle(dataset, 'dataset.pkl') 91 | self.save_checkpoint( 92 | checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns 93 | ) 94 | 95 | @staticmethod 96 | def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None): 97 | if shard_fns is not None: 98 | shard_fns = flatten_dict( 99 | to_state_dict(shard_fns) 100 | ) 101 | if remove_dict_prefix is not None: 102 | remove_dict_prefix = tuple(remove_dict_prefix) 103 | flattend_train_state = {} 104 | with mlxu.open_file(path) as fin: 105 | # 83886080 bytes = 80 MB, which is 16 blocks on GCS 106 | unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0) 107 | for key, value in unpacker: 108 | key = tuple(key) 109 | if remove_dict_prefix is not None: 110 | if key[:len(remove_dict_prefix)] == remove_dict_prefix: 111 | key = key[len(remove_dict_prefix):] 112 | else: 113 | continue 114 | 115 | tensor = from_bytes(None, value) 116 | if shard_fns is not None: 117 | tensor = shard_fns[key](tensor) 118 | flattend_train_state[key] = tensor 119 | 120 | if target is not None: 121 | flattened_target = flatten_dict( 122 | to_state_dict(target), keep_empty_nodes=True 123 | ) 124 | for key, value in flattened_target.items(): 125 | if key not in flattend_train_state and value == empty_node: 126 | flattend_train_state[key] = value 127 | 128 | train_state = unflatten_dict(flattend_train_state) 129 | if target is None: 130 | return train_state 131 | 132 | return from_state_dict(target, train_state) 133 | 134 | @staticmethod 135 | def load_flax_checkpoint(path, target=None, shard_fns=None): 136 | """ Load a standard flax checkpoint that's not saved with the 137 | msgpack streaming format. 138 | """ 139 | with mlxu.open_file(path, "rb") as fin: 140 | encoded_bytes = fin.read() 141 | 142 | state_dict = flax.serialization.msgpack_restore(encoded_bytes) 143 | if shard_fns is not None: 144 | shard_fns = to_state_dict(shard_fns) 145 | state_dict = tree_apply(shard_fns, state_dict) 146 | 147 | if target is None: 148 | return state_dict 149 | return from_state_dict(target, state_dict) 150 | 151 | @classmethod 152 | def load_trainstate_checkpoint(cls, load_from, trainstate_target=None, 153 | trainstate_shard_fns=None, 154 | disallow_trainstate=False): 155 | if trainstate_target is not None: 156 | params_target = trainstate_target.params['params'] 157 | else: 158 | params_target = None 159 | 160 | if trainstate_shard_fns is not None: 161 | params_shard_fns = trainstate_shard_fns.params['params'] 162 | else: 163 | params_shard_fns = None 164 | 165 | load_type, load_path = load_from.split('::', 1) 166 | if disallow_trainstate: 167 | assert load_type != 'trainstate', 'Loading full trainstate is not allowed!' 
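        # `load_from` is a 'type::path' string; the recognized types are the
        # branches below. Illustrative examples (the paths are placeholders):
        #   'trainstate::/ckpt/streaming_train_state'         full train state
        #   'trainstate_params::/ckpt/streaming_train_state'  params extracted from a saved train state
        #   'params::/ckpt/streaming_params'                  params saved in the streaming format
        #   'flax_params::/ckpt/flax_model.msgpack'           standard (non-streaming) flax msgpack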
168 | train_state = None 169 | restored_params = None 170 | if load_type == 'trainstate': 171 | # Load the entire train state in the streaming format 172 | train_state = cls.load_checkpoint( 173 | path=load_path, 174 | target=trainstate_target, 175 | shard_fns=trainstate_shard_fns, 176 | ) 177 | elif load_type == 'trainstate_params': 178 | # Load the params part of the train state in the streaming format 179 | restored_params = cls.load_checkpoint( 180 | path=load_path, 181 | target=params_target, 182 | shard_fns=params_shard_fns, 183 | remove_dict_prefix=('params', 'params'), 184 | ) 185 | restored_params = {'params': restored_params} 186 | elif load_type == 'params': 187 | # Load the params in the streaming format 188 | restored_params = cls.load_checkpoint( 189 | path=load_path, 190 | target=params_target, 191 | shard_fns=params_shard_fns, 192 | ) 193 | restored_params = {'params': restored_params} 194 | elif load_type == 'flax_params': 195 | # Load the params in the standard flax format (non-streaming) 196 | # This requires the entire params to fit in memory 197 | restored_params = cls.load_flax_checkpoint( 198 | path=load_path, 199 | target=params_target, 200 | shard_fns=params_shard_fns 201 | ) 202 | restored_params = {'params': restored_params} 203 | else: 204 | raise ValueError(f'Invalid load_from type: {load_type}') 205 | 206 | return train_state, restored_params 207 | -------------------------------------------------------------------------------- /EasyLM/data.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | import json 4 | import base64 5 | from multiprocessing import Pool 6 | 7 | import mlxu 8 | import numpy as np 9 | from datasets import load_dataset 10 | 11 | 12 | class DatasetFactory(object): 13 | """ Datset builder class. """ 14 | 15 | @staticmethod 16 | def get_default_config(updates=None): 17 | config = mlxu.config_dict() 18 | config.type = 'huggingface' 19 | config.text_processor = TextProcessor.get_default_config() 20 | config.huggingface_dataset = HuggingfaceDataset.get_default_config() 21 | config.json_dataset = JsonDataset.get_default_config() 22 | return mlxu.update_config_dict(config, updates) 23 | 24 | @classmethod 25 | def load_dataset(cls, config, tokenizer, **kwargs): 26 | config = cls.get_default_config(config) 27 | text_processor = TextProcessor(config.text_processor, tokenizer) 28 | if config.type == 'huggingface': 29 | return HuggingfaceDataset( 30 | config.huggingface_dataset, tokenizer, text_processor, **kwargs 31 | ) 32 | elif config.type == 'json': 33 | return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs) 34 | else: 35 | raise ValueError(f'Unknown dataset type: {config.type}') 36 | 37 | def __init__(self): 38 | raise ValueError('DatasetFactory is a static class and should not be instantiated.') 39 | 40 | 41 | class TextProcessor(object): 42 | """ Example processor that converts a dictionary of texts into tokens. 
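    An illustrative `fields` configuration (the field names are hypothetical):
    fields='[prompt],completion' tokenizes example['prompt'] with loss mask 0.0
    (square brackets disable the loss) followed by example['completion'] with
    loss mask 1.0; '<|bos|>' / '<|eos|>' insert the special tokens, '{field}'
    reads base64-encoded raw token ids, and 'question+answer' joins sub-fields
    with `subfield_separator` before tokenizing.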
""" 43 | 44 | @staticmethod 45 | def get_default_config(updates=None): 46 | config = mlxu.config_dict() 47 | config.fields_from_example = '' 48 | config.fields = '' 49 | config.subfield_separator = ' ' 50 | config.add_bos_token = True 51 | config.add_eos_token = True 52 | config.prepend_text = '' 53 | config.base64_token_dtype = 'i4' 54 | return mlxu.update_config_dict(config, updates) 55 | 56 | def __init__(self, config, tokenizer): 57 | self.config = self.get_default_config(config) 58 | assert self.config.fields != '' or self.config.fields_from_example != '', ( 59 | 'Either fields or fields_from_example must be specified.' 60 | ) 61 | self.tokenizer = tokenizer 62 | 63 | def __call__(self, example, has_aux=False): 64 | if has_aux: 65 | example, *aux = example 66 | else: 67 | aux = tuple() 68 | token_buffer = [] 69 | loss_mask_buffer = [] 70 | 71 | if self.config.add_bos_token: 72 | token_buffer.append(self.tokenizer.bos_token_id) 73 | loss_mask_buffer.append(0.0) 74 | 75 | if self.config.fields_from_example != '': 76 | fields = example[self.config.fields_from_example].split(',') 77 | else: 78 | fields = self.config.fields.split(',') 79 | 80 | for i, field in enumerate(fields): 81 | if field.startswith('[') and field.endswith(']'): 82 | # No loss for this field. 83 | field = field[1:-1] 84 | mask = 0.0 85 | else: 86 | mask = 1.0 87 | 88 | if field.startswith('<|') and field.endswith('|>'): 89 | # Special tokens. 90 | field = field[2:-2] 91 | if field == 'bos': 92 | token_buffer.append(self.tokenizer.bos_token_id) 93 | elif field == 'eos': 94 | token_buffer.append(self.tokenizer.eos_token_id) 95 | else: 96 | # Token ID specified directly. 97 | token_buffer.append(int(field)) 98 | loss_mask_buffer.append(mask) 99 | elif field.startswith('{') and field.endswith('}'): 100 | field = field[1:-1] 101 | # Base64 encoded raw tokens. 102 | tokens = np.frombuffer( 103 | base64.b64decode(example[field]), 104 | dtype=self.config.base64_token_dtype 105 | ).tolist() 106 | token_buffer.extend(tokens) 107 | loss_mask_buffer.extend([mask for _ in range(len(tokens))]) 108 | else: 109 | subfields = field.split('+') 110 | text = self.config.subfield_separator.join( 111 | [example[subfield] for subfield in subfields] 112 | ) 113 | if i == 0: 114 | text = self.config.prepend_text + text 115 | tokens = self.tokenizer.encode(text, add_special_tokens=False) 116 | token_buffer.extend(tokens) 117 | loss_mask_buffer.extend([mask for _ in range(len(tokens))]) 118 | 119 | if self.config.add_eos_token: 120 | token_buffer.append(self.tokenizer.eos_token_id) 121 | loss_mask_buffer.append(1.0) 122 | 123 | return token_buffer, loss_mask_buffer, *aux 124 | 125 | 126 | class HuggingfaceDataset(object): 127 | """ Huggingface dataset, where the dataset is loaded using the huggingface 128 | datasets.load_dataset() function. 
129 | """ 130 | 131 | @staticmethod 132 | def get_default_config(updates=None): 133 | config = mlxu.config_dict() 134 | config.path = 'c4' 135 | config.name = 'en' 136 | config.split = 'train' 137 | config.streaming = False 138 | config.seq_length = 1024 139 | config.batch_size = 8 140 | config.always_start_with_bos = False 141 | config.batch_token_dtype = 'i4' 142 | return mlxu.update_config_dict(config, updates) 143 | 144 | def __init__(self, config, tokenizer, text_processor): 145 | self.config = self.get_default_config(config) 146 | name = self.config.name if self.config.name != '' else None 147 | split = self.config.split if self.config.split != '' else None 148 | self._tokenizer = tokenizer 149 | self._text_processor = text_processor 150 | self._dataset = load_dataset( 151 | self.config.path, name, split=split, streaming=self.config.streaming 152 | ) 153 | 154 | def __iter__(self): 155 | chunk_size = self.config.batch_size * self.config.seq_length 156 | total_tokens = 0 157 | while True: 158 | token_buffer = [] 159 | loss_mask_buffer = [] 160 | for index, example in enumerate(self._dataset): 161 | tokens, loss_masks = self.text_processor(example) 162 | token_buffer.extend(tokens) 163 | loss_mask_buffer.extend(loss_masks) 164 | while len(token_buffer) > chunk_size + 1: 165 | total_tokens += chunk_size 166 | metrics = { 167 | 'dataset_example_index': index, 168 | 'dataset_total_tokens': total_tokens, 169 | } 170 | batch = { 171 | 'input_tokens': np.array(token_buffer[:chunk_size], dtype=self.config.batch_token_dtype).reshape( 172 | self.config.batch_size, -1 173 | ), 174 | 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=self.config.batch_token_dtype).reshape( 175 | self.config.batch_size, -1 176 | ), 177 | 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape( 178 | self.config.batch_size, -1 179 | ), 180 | } 181 | if self.config.always_start_with_bos: 182 | batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id 183 | yield batch, metrics 184 | token_buffer = token_buffer[chunk_size:] 185 | loss_mask_buffer = loss_mask_buffer[chunk_size:] 186 | 187 | def get_state_dict(self): 188 | return dict(config=self.config) 189 | 190 | def load_state_dict(self, state_dict): 191 | if 'config' in state_dict: 192 | self.config.update(mlxu.ConfigDict(state_dict['config'])) 193 | 194 | @property 195 | def seq_length(self): 196 | return self.config.seq_length 197 | 198 | @property 199 | def tokenizer(self): 200 | return self._tokenizer 201 | 202 | @property 203 | def text_processor(self): 204 | return self._text_processor 205 | 206 | @property 207 | def dataset(self): 208 | return self._dataset 209 | 210 | @property 211 | def vocab_size(self): 212 | return len(self._tokenizer) 213 | 214 | 215 | class JsonDataset(object): 216 | """ JSON dataset, where each line of the data file contains a JSON 217 | dictionary with text fields. 
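    A sketch with a hypothetical field name: if every line of the file looks like
    {"text": "..."} and the text processor is configured with fields='text', the
    iterator packs the tokenized lines into rows of seq_length tokens and yields
    the same input_tokens / target_tokens / loss_masks batches as
    HuggingfaceDataset, plus file-location metrics usable for resuming.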
218 | """ 219 | 220 | @staticmethod 221 | def get_default_config(updates=None): 222 | config = mlxu.config_dict() 223 | config.path = '' 224 | config.seq_length = 1024 225 | config.batch_size = 8 226 | config.always_start_with_bos = False 227 | config.start_seek_loc = 0 228 | config.example_index_at_start = 0 229 | config.tokens_count_at_start = 0 230 | config.tokenizer_processes = 1 231 | config.tokenizer_parallel_chunk_size = 32 232 | config.tokenizer_parallel_batch_size = 1024 233 | config.throughput_average_window_size = 200 234 | return mlxu.update_config_dict(config, updates) 235 | 236 | def __init__(self, config, tokenizer, text_processor): 237 | self.config = self.get_default_config(config) 238 | assert self.config.path != '' 239 | self._tokenizer = tokenizer 240 | self._text_processor = text_processor 241 | self._index = self.config.example_index_at_start 242 | self._file_loc = self.config.start_seek_loc 243 | self._total_tokens = self.config.tokens_count_at_start 244 | 245 | def parse_json(self, line): 246 | if not line or line == '\n': 247 | return None 248 | try: 249 | data = json.loads(line) 250 | except json.decoder.JSONDecodeError: 251 | print(f'Error parsing json line:\n{line}') 252 | return None 253 | return data 254 | 255 | def json_iterator(self): 256 | with mlxu.open_file(self.config.path, 'r') as fin: 257 | fin.seek(self._file_loc) 258 | while True: 259 | line = fin.readline() 260 | self._file_loc = fin.tell() 261 | if not line: # Reached EOF 262 | self._index = 0 263 | fin.seek(0) 264 | continue 265 | 266 | data = self.parse_json(line) 267 | if data is not None: 268 | # JSON parsing succeeded 269 | yield data, self._file_loc, self._index 270 | self._index += 1 271 | 272 | def batched(self, iterator, batch_size): 273 | batch = [] 274 | for example in iterator: 275 | batch.append(example) 276 | if len(batch) == batch_size: 277 | yield batch 278 | batch = [] 279 | if len(batch) > 0: 280 | yield batch 281 | 282 | def parallel_example_iterator(self): 283 | if self.config.tokenizer_processes == 1: 284 | for example, loc, index in self.json_iterator(): 285 | yield self.text_processor((example, loc, index), has_aux=True) 286 | else: 287 | process_pool = Pool(self.config.tokenizer_processes) 288 | batched_iterator = self.batched( 289 | self.json_iterator(), self.config.tokenizer_parallel_batch_size 290 | ) 291 | with process_pool as pool: 292 | map_fn = partial(self.text_processor, has_aux=True) 293 | next_batch = pool.map_async( 294 | map_fn, next(batched_iterator), 295 | chunksize=self.config.tokenizer_parallel_chunk_size 296 | ) 297 | while True: 298 | current_batch = next_batch 299 | next_batch = pool.map_async( 300 | map_fn, next(batched_iterator), 301 | chunksize=self.config.tokenizer_parallel_chunk_size 302 | ) 303 | for example in current_batch.get(): 304 | yield example 305 | 306 | def __iter__(self): 307 | chunk_size = self.config.batch_size * self.config.seq_length 308 | token_buffer = [] 309 | loss_mask_buffer = [] 310 | last_time = 0.0 311 | step_times = [] 312 | start_time = time.time() 313 | start_tokens = self._total_tokens 314 | for tokens, loss_masks, loc, index in self.parallel_example_iterator(): 315 | token_buffer.extend(tokens) 316 | loss_mask_buffer.extend(loss_masks) 317 | while len(token_buffer) > chunk_size + 1: 318 | self._total_tokens += chunk_size 319 | step_times.append(time.time() - last_time) 320 | last_time = time.time() 321 | if len(step_times) > self.config.throughput_average_window_size: 322 | step_times = 
step_times[-self.config.throughput_average_window_size:] 323 | average_throughput = chunk_size / np.mean(step_times) 324 | accumulated_throughput = ( 325 | (self._total_tokens - start_tokens) / (time.time() - start_time) 326 | ) 327 | metrics = { 328 | 'dataset_file_loc': loc, 329 | 'dataset_example_index': index, 330 | 'dataset_total_tokens': self._total_tokens, 331 | 'dataset_accumulated_tps': accumulated_throughput, 332 | 'dataset_average_tps': average_throughput, 333 | } 334 | batch = { 335 | 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape( 336 | self.config.batch_size, -1 337 | ), 338 | 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape( 339 | self.config.batch_size, -1 340 | ), 341 | 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape( 342 | self.config.batch_size, -1 343 | ), 344 | } 345 | if self.config.always_start_with_bos: 346 | batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id 347 | yield batch, metrics 348 | token_buffer = token_buffer[chunk_size:] 349 | loss_mask_buffer = loss_mask_buffer[chunk_size:] 350 | 351 | def get_state_dict(self): 352 | return dict( 353 | config=self.config, 354 | index=self._index, 355 | file_loc=self._file_loc, 356 | total_tokens=self._total_tokens, 357 | ) 358 | 359 | def load_state_dict(self, state_dict): 360 | if 'config' in state_dict: 361 | self.config.update(mlxu.ConfigDict(state_dict['config'])) 362 | self._index = state_dict.get('index', self.config.example_index_at_start) 363 | self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc) 364 | self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start) 365 | 366 | @property 367 | def seq_length(self): 368 | return self.config.seq_length 369 | 370 | @property 371 | def tokenizer(self): 372 | return self._tokenizer 373 | 374 | @property 375 | def text_processor(self): 376 | return self._text_processor 377 | 378 | @property 379 | def vocab_size(self): 380 | return len(self.tokenizer) 381 | -------------------------------------------------------------------------------- /EasyLM/jax_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from typing import Any, Mapping, Text, Tuple, Union, NamedTuple 4 | from functools import partial 5 | import re 6 | import dataclasses 7 | import random 8 | from ml_collections import ConfigDict 9 | from ml_collections.config_dict.config_dict import placeholder 10 | 11 | import flax 12 | import jax 13 | import jax.numpy as jnp 14 | from jax.sharding import PartitionSpec as PS 15 | from jax.sharding import Mesh 16 | from jax.experimental import mesh_utils 17 | from jax.experimental.pjit import pjit 18 | from jax.interpreters import pxla 19 | import numpy as np 20 | from transformers import FlaxLogitsWarper 21 | 22 | 23 | class JaxRNG(object): 24 | """ A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside 25 | pure function. 
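    Usage sketch of the call semantics implemented below: rng = JaxRNG.from_seed(42);
    rng() returns one fresh key and advances the internal state, rng(2) returns a
    tuple of two keys, and rng(('params', 'dropout')) returns a dict keyed by those
    names, which is convenient for flax init/apply rngs.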
26 | """ 27 | 28 | @classmethod 29 | def from_seed(cls, seed): 30 | return cls(jax.random.PRNGKey(seed)) 31 | 32 | def __init__(self, rng): 33 | self.rng = rng 34 | 35 | def __call__(self, keys=None): 36 | if keys is None: 37 | self.rng, split_rng = jax.random.split(self.rng) 38 | return split_rng 39 | elif isinstance(keys, int): 40 | split_rngs = jax.random.split(self.rng, num=keys + 1) 41 | self.rng = split_rngs[0] 42 | return tuple(split_rngs[1:]) 43 | else: 44 | split_rngs = jax.random.split(self.rng, num=len(keys) + 1) 45 | self.rng = split_rngs[0] 46 | return {key: val for key, val in zip(keys, split_rngs[1:])} 47 | 48 | 49 | class JaxDistributedConfig(object): 50 | """ Utility class for initializing JAX distributed. """ 51 | 52 | @staticmethod 53 | def get_default_config(updates=None): 54 | config = ConfigDict() 55 | config.initialize_jax_distributed = False 56 | config.coordinator_address = placeholder(str) 57 | config.num_processes = placeholder(int) 58 | config.process_id = placeholder(int) 59 | config.local_device_ids = placeholder(str) 60 | 61 | if updates is not None: 62 | config.update(ConfigDict(updates).copy_and_resolve_references()) 63 | return config 64 | 65 | @classmethod 66 | def initialize(cls, config): 67 | config = cls.get_default_config(config) 68 | if config.initialize_jax_distributed: 69 | if config.local_device_ids is not None: 70 | local_device_ids = [int(x) for x in config.local_device_ids.split(',')] 71 | else: 72 | local_device_ids = None 73 | 74 | jax.distributed.initialize( 75 | coordinator_address=config.coordinator_address, 76 | num_processes=config.num_processes, 77 | process_id=config.process_id, 78 | local_device_ids=local_device_ids, 79 | ) 80 | 81 | 82 | class FlaxTemperatureLogitsWarper(FlaxLogitsWarper): 83 | """ JIT traceable version of FlaxLogitsWarper that performs temperature scaling.""" 84 | def __init__(self, temperature): 85 | self.temperature = temperature 86 | 87 | def __call__(self, input_ids, scores, cur_len): 88 | return scores / jnp.clip(self.temperature, a_min=1e-8) 89 | 90 | 91 | def make_shard_and_gather_fns(partition_specs, dtype_specs=None): 92 | """ Create pytree of sharding and gathering functions from pytree of 93 | partition specs. 
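    Roughly: each shard_fn is a pjit-compiled identity (with an optional dtype
    cast) whose out_shardings is the leaf's PartitionSpec, so applying shard_fns
    leaf-wise places host tensors onto the mesh; each gather_fn does the reverse
    and device_gets the result back to the host. Typically paired with
    match_partition_rules when loading or saving checkpoints.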
94 | """ 95 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) 96 | 97 | def make_to_dtype_fn(dtype_spec): 98 | def to_dtype(tensor): 99 | if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes: 100 | # Convert all float tensors to the same dtype 101 | return tensor.astype(dtype_specs) 102 | elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'): 103 | return tensor.astype(dtype_spec.dtype) 104 | return tensor 105 | return to_dtype 106 | 107 | def make_shard_fn(partition_spec, dtype_spec=None): 108 | jax_shard_function = pjit( 109 | make_to_dtype_fn(dtype_spec), 110 | in_shardings=None, 111 | out_shardings=partition_spec 112 | ) 113 | def shard_fn(tensor): 114 | return jax_shard_function(tensor).block_until_ready() 115 | return shard_fn 116 | 117 | def make_gather_fn(partition_spec, dtype_spec=None): 118 | jax_gather_fn = pjit( 119 | make_to_dtype_fn(dtype_spec), 120 | in_shardings=partition_spec, 121 | out_shardings=None 122 | ) 123 | def gather_fn(tensor): 124 | return jax.device_get(jax_gather_fn(tensor)) 125 | return gather_fn 126 | 127 | if dtype_specs is None or dtype_specs in float_dtypes: 128 | shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs) 129 | gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs) 130 | else: 131 | shard_fns = jax.tree_util.tree_map( 132 | make_shard_fn, partition_specs, dtype_specs 133 | ) 134 | gather_fns = jax.tree_util.tree_map( 135 | make_gather_fn, partition_specs, dtype_specs 136 | ) 137 | return shard_fns, gather_fns 138 | 139 | 140 | def set_random_seed(seed): 141 | np.random.seed(seed) 142 | random.seed(seed) 143 | init_rng(seed) 144 | 145 | 146 | def get_jax_mesh(axis_dims, names): 147 | if axis_dims.startswith('!'): 148 | # Allow splitting a physical mesh axis if needed 149 | mesh_axis_splitting = True 150 | axis_dims = axis_dims[1:] 151 | else: 152 | mesh_axis_splitting = False 153 | 154 | if ':' in axis_dims: 155 | dims = [] 156 | dim_names = [] 157 | for axis in axis_dims.split(','): 158 | name, dim = axis.split(':') 159 | assert name in names 160 | dims.append(int(dim)) 161 | dim_names.append(name) 162 | assert(set(dim_names) == set(names)) 163 | else: 164 | dims = [int(x) for x in axis_dims.split(',')] 165 | dim_names = names 166 | assert len(dims) == len(names) 167 | mesh_shape = np.arange(jax.device_count()).reshape(dims).shape 168 | if mesh_axis_splitting: 169 | physical_mesh = np.array(jax.devices()).reshape(mesh_shape) 170 | else: 171 | physical_mesh = mesh_utils.create_device_mesh(mesh_shape) 172 | return Mesh(physical_mesh, dim_names) 173 | 174 | 175 | def names_in_current_mesh(*names): 176 | """ Check if current mesh axes contain these names. """ 177 | mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names 178 | return set(names) <= set(mesh_axis_names) 179 | 180 | 181 | def get_names_from_parition_spec(partition_specs): 182 | """ Return axis names from partition specs. """ 183 | names = set() 184 | if isinstance(partition_specs, dict): 185 | partition_specs = partition_specs.values() 186 | for item in partition_specs: 187 | if item is None: 188 | continue 189 | elif isinstance(item, str): 190 | names.add(item) 191 | else: 192 | names.update(get_names_from_parition_spec(item)) 193 | 194 | return list(names) 195 | 196 | 197 | def with_sharding_constraint(x, partition_specs): 198 | """ A smarter version of with_sharding_constraint that only applies the 199 | constraint if the current mesh contains the axes in the partition specs. 
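    For example (the axis names here are illustrative), a call like
    with_sharding_constraint(x, PS(('dp', 'fsdp'), None, 'mp')) only applies the
    constraint when the surrounding mesh actually defines those axes and is a
    no-op otherwise, so model code can call it unconditionally.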
200 | """ 201 | axis_names = get_names_from_parition_spec(partition_specs) 202 | if names_in_current_mesh(*axis_names): 203 | x = jax.lax.with_sharding_constraint(x, partition_specs) 204 | return x 205 | 206 | 207 | def wrap_function_with_rng(rng): 208 | """ To be used as decorator, automatically bookkeep a RNG for the wrapped function. """ 209 | def wrap_function(function): 210 | def wrapped(*args, **kwargs): 211 | nonlocal rng 212 | rng, split_rng = jax.random.split(rng) 213 | return function(split_rng, *args, **kwargs) 214 | return wrapped 215 | return wrap_function 216 | 217 | 218 | def init_rng(seed): 219 | global jax_utils_rng 220 | jax_utils_rng = JaxRNG.from_seed(seed) 221 | 222 | 223 | def next_rng(*args, **kwargs): 224 | global jax_utils_rng 225 | return jax_utils_rng(*args, **kwargs) 226 | 227 | 228 | def get_metrics(metrics, unreplicate=False, stack=False): 229 | if unreplicate: 230 | metrics = flax.jax_utils.unreplicate(metrics) 231 | metrics = jax.device_get(metrics) 232 | if stack: 233 | return jax.tree_map(lambda *args: np.stack(args), *metrics) 234 | else: 235 | return {key: float(val) for key, val in metrics.items()} 236 | 237 | 238 | def mse_loss(val, target, valid=None): 239 | if valid is None: 240 | valid = jnp.ones((*target.shape[:2], 1)) 241 | valid = valid.astype(jnp.float32) 242 | loss = jnp.mean( 243 | jnp.where( 244 | valid > 0.0, 245 | jnp.square(val - target), 246 | 0.0 247 | ) 248 | ) 249 | return loss 250 | 251 | 252 | def cross_entropy_loss_and_accuracy(logits, tokens, valid=None): 253 | if valid is None: 254 | valid = jnp.ones(tokens.shape[:2]) 255 | valid = valid.astype(jnp.float32) 256 | valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10) 257 | logits = logits.astype(jnp.float32) # for numerical stability 258 | token_log_prob = jnp.squeeze( 259 | jnp.take_along_axis( 260 | jax.nn.log_softmax(logits, axis=-1), 261 | jnp.expand_dims(tokens, -1), 262 | axis=-1, 263 | ), 264 | -1, 265 | ) 266 | token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0)) 267 | loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length) 268 | correct = jnp.where( 269 | valid > 0.0, 270 | jnp.argmax(logits, axis=-1) == tokens, 271 | jnp.array(False) 272 | ) 273 | accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length) 274 | return loss, accuracy 275 | 276 | 277 | def global_norm(tree): 278 | """ Return the global L2 norm of a pytree. 
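    i.e. the square root of the sum of squares over every leaf, the quantity
    commonly logged as a gradient-norm metric.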
""" 279 | squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree) 280 | flattened, _ = jax.flatten_util.ravel_pytree(squared) 281 | return jnp.sqrt(jnp.sum(flattened)) 282 | 283 | 284 | def average_metrics(metrics): 285 | return jax.tree_map( 286 | lambda *args: jnp.mean(jnp.stack(args)), 287 | *metrics 288 | ) 289 | 290 | 291 | def get_float_dtype_by_name(dtype): 292 | return { 293 | 'bf16': jnp.bfloat16, 294 | 'bfloat16': jnp.bfloat16, 295 | 'fp16': jnp.float16, 296 | 'float16': jnp.float16, 297 | 'fp32': jnp.float32, 298 | 'float32': jnp.float32, 299 | 'fp64': jnp.float64, 300 | 'float64': jnp.float64, 301 | }[dtype] 302 | 303 | 304 | def float_tensor_to_dtype(tensor, dtype): 305 | if dtype is None or dtype == '': 306 | return tensor 307 | if isinstance(dtype, str): 308 | dtype = get_float_dtype_by_name(dtype) 309 | float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64) 310 | if getattr(tensor, 'dtype', None) in float_dtypes: 311 | tensor = tensor.astype(dtype) 312 | return tensor 313 | 314 | 315 | def float_to_dtype(tree, dtype): 316 | return jax.tree_util.tree_map( 317 | partial(float_tensor_to_dtype, dtype=dtype), tree 318 | ) 319 | 320 | 321 | def get_gradient_checkpoint_policy(name): 322 | return { 323 | 'everything_saveable': jax.checkpoint_policies.everything_saveable, 324 | 'nothing_saveable': jax.checkpoint_policies.nothing_saveable, 325 | 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots, 326 | 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, 327 | }[name] 328 | 329 | 330 | def tree_path_to_string(path, sep=None): 331 | keys = [] 332 | for key in path: 333 | if isinstance(key, jax.tree_util.SequenceKey): 334 | keys.append(str(key.idx)) 335 | elif isinstance(key, jax.tree_util.DictKey): 336 | keys.append(str(key.key)) 337 | elif isinstance(key, jax.tree_util.GetAttrKey): 338 | keys.append(str(key.name)) 339 | elif isinstance(key, jax.tree_util.FlattenedIndexKey): 340 | keys.append(str(key.key)) 341 | else: 342 | keys.append(str(key)) 343 | if sep is None: 344 | return tuple(keys) 345 | return sep.join(keys) 346 | 347 | 348 | def flatten_tree(xs, is_leaf=None, sep=None): 349 | flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf) 350 | output = {} 351 | for key, val in flattened: 352 | output[tree_path_to_string(key, sep=sep)] = val 353 | return output 354 | 355 | 356 | def named_tree_map(f, tree, *rest, is_leaf=None, sep=None): 357 | """ An extended version of jax.tree_util.tree_map, where the mapped function 358 | f takes both the name (path) and the tree leaf as input. 359 | """ 360 | return jax.tree_util.tree_map_with_path( 361 | lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r), 362 | tree, *rest, 363 | is_leaf=is_leaf 364 | ) 365 | 366 | 367 | def match_partition_rules(rules, params): 368 | """ Returns a pytree of PartitionSpec according to rules. Supports handling 369 | Flax TrainState and Optax optimizer state. 370 | """ 371 | def get_partition_spec(name, leaf): 372 | if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1: 373 | """ Don't partition scalar values. 
""" 374 | return PS() 375 | for rule, ps in rules: 376 | if re.search(rule, name) is not None: 377 | return ps 378 | raise ValueError(f'Partition rule not found for param: {name}') 379 | return named_tree_map(get_partition_spec, params, sep='/') 380 | 381 | 382 | def get_weight_decay_mask(exclusions): 383 | """ Return a weight decay mask function that computes the pytree masks 384 | according to the given exclusion rules. 385 | """ 386 | def decay(name, _): 387 | for rule in exclusions: 388 | if re.search(rule, name) is not None: 389 | return False 390 | return True 391 | 392 | def weight_decay_mask(params): 393 | return named_tree_map(decay, params, sep='/') 394 | 395 | return weight_decay_mask 396 | 397 | 398 | def tree_apply(fns, tree): 399 | """ Apply a pytree of functions to the pytree. """ 400 | return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree) 401 | 402 | -------------------------------------------------------------------------------- /EasyLM/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/young-geng/EasyLM/fe5b2c354e25d697fce7cd225e23bbbe72570da3/EasyLM/models/__init__.py -------------------------------------------------------------------------------- /EasyLM/models/llama/convert_easylm_to_hf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. 2 | # Copyright 2023 Xinyang Geng 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This script converts LLaMA model checkpoint trained by EsayLM to the 17 | # HuggingFace transformers LLaMA PyTorch format, which can then be loaded 18 | # by HuggingFace transformers. 
19 | 20 | import gc 21 | import json 22 | import math 23 | import os 24 | import shutil 25 | 26 | import numpy as np 27 | import mlxu 28 | import jax 29 | import jax.numpy as jnp 30 | import flax 31 | from flax.traverse_util import flatten_dict 32 | import torch 33 | from transformers import LlamaConfig, LlamaForCausalLM 34 | 35 | from EasyLM.models.llama.llama_model import LLaMAConfigurator 36 | from EasyLM.checkpoint import StreamingCheckpointer 37 | from EasyLM.jax_utils import float_tensor_to_dtype 38 | 39 | 40 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 41 | load_checkpoint='', 42 | output_dir='', 43 | llama=LLaMAConfigurator.get_default_config(), 44 | ) 45 | 46 | def match_keywords(string, positives, negatives): 47 | for positive in positives: 48 | if positive not in string: 49 | return False 50 | for negative in negatives: 51 | if negative in string: 52 | return False 53 | return True 54 | 55 | 56 | def load_and_convert_checkpoint(path): 57 | _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path) 58 | flax_params = flatten_dict(flax_params['params'], sep='.') 59 | torch_params = {} 60 | for key, tensor in flax_params.items(): 61 | if match_keywords(key, ["kernel"], ["norm", 'ln_f']): 62 | tensor = tensor.T 63 | torch_params[key] = torch.tensor( 64 | float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16 65 | ) 66 | return torch_params 67 | 68 | 69 | def read_json(path): 70 | with open(path, "r") as f: 71 | return json.load(f) 72 | 73 | 74 | def write_json(text, path): 75 | with open(path, "w") as f: 76 | json.dump(text, f) 77 | 78 | 79 | def permute(w, n_heads, input_dim, output_dim): 80 | # permute for sliced rotary embedding 81 | return w.view( 82 | n_heads, output_dim // n_heads // 2, 2, input_dim 83 | ).transpose(1, 2).reshape(output_dim, input_dim) 84 | 85 | 86 | def write_model(loaded, model_path): 87 | os.makedirs(model_path, exist_ok=True) 88 | tmp_model_path = os.path.join(model_path, "tmp") 89 | os.makedirs(tmp_model_path, exist_ok=True) 90 | 91 | llama_config = LLaMAConfigurator.finalize_config(FLAGS.llama) 92 | 93 | n_layers = llama_config.num_hidden_layers 94 | n_heads = llama_config.num_attention_heads 95 | n_kv_heads = llama_config.num_key_value_heads 96 | dim = llama_config.hidden_size 97 | dims_per_head = dim // n_heads 98 | base = llama_config.rope_theta 99 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) 100 | 101 | param_count = 0 102 | index_dict = {"weight_map": {}} 103 | for layer_i in range(n_layers): 104 | filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" 105 | state_dict = { 106 | f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( 107 | loaded[f"transformer.h.{layer_i}.attention.wq.kernel"], 108 | llama_config.num_attention_heads, 109 | llama_config.hidden_size, 110 | llama_config.hidden_size, 111 | ), 112 | f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( 113 | loaded[f"transformer.h.{layer_i}.attention.wk.kernel"], 114 | llama_config.num_key_value_heads, 115 | llama_config.hidden_size, 116 | llama_config.hidden_size // ( 117 | llama_config.num_attention_heads 118 | // llama_config.num_key_value_heads 119 | ), 120 | ), 121 | f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"], 122 | f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"], 123 | 124 | f"model.layers.{layer_i}.mlp.gate_proj.weight": 
loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"], 125 | f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"], 126 | f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"], 127 | 128 | f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"], 129 | f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"], 130 | 131 | } 132 | 133 | state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq 134 | for k, v in state_dict.items(): 135 | index_dict["weight_map"][k] = filename 136 | param_count += v.numel() 137 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 138 | 139 | filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" 140 | # Unsharded 141 | state_dict = { 142 | "model.embed_tokens.weight": loaded["transformer.wte.embedding"], 143 | "model.norm.weight": loaded["transformer.ln_f.kernel"], 144 | "lm_head.weight": loaded["lm_head.kernel"], 145 | } 146 | 147 | for k, v in state_dict.items(): 148 | index_dict["weight_map"][k] = filename 149 | param_count += v.numel() 150 | torch.save(state_dict, os.path.join(tmp_model_path, filename)) 151 | 152 | # Write configs 153 | index_dict["metadata"] = {"total_size": param_count * 2} 154 | write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) 155 | 156 | config = LlamaConfig( 157 | vocab_size=llama_config.vocab_size, 158 | hidden_size=llama_config.hidden_size, 159 | intermediate_size=llama_config.intermediate_size, 160 | num_hidden_layers=llama_config.num_hidden_layers, 161 | num_attention_heads=llama_config.num_attention_heads, 162 | num_key_value_heads=llama_config.num_key_value_heads, 163 | initializer_range=llama_config.initializer_range, 164 | rms_norm_eps=llama_config.rms_norm_eps, 165 | max_position_embeddings=llama_config.max_position_embeddings, 166 | rope_theta=llama_config.rope_theta, 167 | ) 168 | config.save_pretrained(tmp_model_path) 169 | 170 | # Make space so we can load the model properly now. 171 | del state_dict 172 | del loaded 173 | gc.collect() 174 | 175 | print("Loading the checkpoint in a Llama model.") 176 | model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16) 177 | # Avoid saving this as part of the config. 
178 | del model.config._name_or_path 179 | 180 | print("Saving in the Transformers format.") 181 | model.save_pretrained(model_path) 182 | shutil.rmtree(tmp_model_path) 183 | 184 | 185 | def main(argv): 186 | assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != "" 187 | write_model( 188 | load_and_convert_checkpoint(FLAGS.load_checkpoint), 189 | model_path=FLAGS.output_dir, 190 | ) 191 | 192 | 193 | if __name__ == "__main__": 194 | mlxu.run(main) -------------------------------------------------------------------------------- /EasyLM/models/llama/convert_hf_to_easylm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python convert_hf_to_easylm.py \ 4 | --hf_model /path/hf_format_dir \ 5 | --output_file /path/easylm_format.easylm \ 6 | --llama.base_model llama_7b \ 7 | --streaming 8 | """ 9 | import os 10 | os.environ["CUDA_VISIBLE_DEVICES"] = '' 11 | import time 12 | from pathlib import Path 13 | 14 | import mlxu 15 | import torch 16 | import flax 17 | from transformers import AutoModelForCausalLM 18 | 19 | from EasyLM.models.llama.llama_model import LLaMAConfigurator 20 | from EasyLM.checkpoint import StreamingCheckpointer 21 | from EasyLM.jax_utils import get_float_dtype_by_name 22 | 23 | 24 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 25 | hf_model="", 26 | output_file="", 27 | streaming=True, 28 | float_dtype="bf16", 29 | llama=LLaMAConfigurator.get_default_config(), 30 | ) 31 | 32 | 33 | def inverse_permute(w, n_heads, input_dim, output_dim): 34 | reshaped_w = w.reshape(n_heads, 2, output_dim // n_heads // 2, input_dim) 35 | transposed_w = reshaped_w.transpose(0, 2, 1, 3) 36 | inverted_w = transposed_w.reshape(output_dim, input_dim) 37 | return inverted_w 38 | 39 | 40 | def main(argv): 41 | start = time.time() 42 | llama_config = LLaMAConfigurator.finalize_config(FLAGS.llama) 43 | hf_model = AutoModelForCausalLM.from_pretrained(FLAGS.hf_model) 44 | ckpt = hf_model.state_dict() 45 | 46 | print(f"Start convert weight to easylm format...") 47 | jax_weights = { 48 | "transformer": { 49 | "wte": {"embedding": ckpt["model.embed_tokens.weight"].numpy()}, 50 | "ln_f": {"kernel": ckpt["model.norm.weight"].numpy()}, 51 | "h": { 52 | "%d" 53 | % (layer): { 54 | "attention": { 55 | "wq": { 56 | "kernel": inverse_permute( 57 | ckpt[f"model.layers.{layer}.self_attn.q_proj.weight"].numpy(), 58 | llama_config.num_attention_heads, 59 | llama_config.hidden_size, 60 | llama_config.hidden_size, 61 | ).transpose() 62 | }, 63 | "wk": { 64 | "kernel": inverse_permute( 65 | ckpt[f"model.layers.{layer}.self_attn.k_proj.weight"].numpy(), 66 | llama_config.num_key_value_heads, 67 | llama_config.hidden_size, 68 | llama_config.hidden_size // ( 69 | llama_config.num_attention_heads 70 | // llama_config.num_key_value_heads 71 | ), 72 | ).transpose() 73 | }, 74 | "wv": { 75 | "kernel": ckpt[f"model.layers.{layer}.self_attn.v_proj.weight"] 76 | .numpy().transpose() 77 | }, 78 | "wo": { 79 | "kernel": ckpt[f"model.layers.{layer}.self_attn.o_proj.weight"] 80 | .numpy().transpose() 81 | }, 82 | }, 83 | "feed_forward": { 84 | "w1": { 85 | "kernel": ckpt[f"model.layers.{layer}.mlp.gate_proj.weight"] 86 | .numpy().transpose() 87 | }, 88 | "w2": { 89 | "kernel": ckpt[f"model.layers.{layer}.mlp.down_proj.weight"] 90 | .numpy().transpose() 91 | }, 92 | "w3": { 93 | "kernel": ckpt[f"model.layers.{layer}.mlp.up_proj.weight"] 94 | .numpy().transpose() 95 | }, 96 | }, 97 | "attention_norm": { 98 | "kernel": 
ckpt[f"model.layers.{layer}.input_layernorm.weight"].numpy() 99 | }, 100 | "ffn_norm": { 101 | "kernel": ckpt[ 102 | f"model.layers.{layer}.post_attention_layernorm.weight" 103 | ].numpy() 104 | }, 105 | } 106 | for layer in range(llama_config.num_hidden_layers) 107 | }, 108 | }, 109 | "lm_head": {"kernel": ckpt["lm_head.weight"].numpy().transpose()}, 110 | } 111 | print(f"Convert weight to easylm format finished...") 112 | print(f"Start to save...") 113 | 114 | if FLAGS.streaming: 115 | StreamingCheckpointer.save_train_state_to_file( 116 | jax_weights, 117 | FLAGS.output_file, 118 | float_dtype=get_float_dtype_by_name(FLAGS.float_dtype), 119 | ) 120 | else: 121 | with mlxu.open_file(FLAGS.output_file, "wb") as fout: 122 | fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True)) 123 | 124 | print( 125 | f"Save finished!!! take time: {time.time() - start} save path: {FLAGS.output_file}" 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | mlxu.run(main) 131 | -------------------------------------------------------------------------------- /EasyLM/models/llama/llama_serve.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | from functools import partial 3 | 4 | import numpy as np 5 | import mlxu 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | from jax.experimental.pjit import pjit 10 | from jax.sharding import PartitionSpec as PS 11 | import optax 12 | from transformers import ( 13 | AutoTokenizer, GenerationConfig, FlaxLogitsProcessorList 14 | ) 15 | 16 | from EasyLM.checkpoint import StreamingCheckpointer 17 | from EasyLM.serving import LMServer 18 | from EasyLM.jax_utils import ( 19 | JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, tree_apply, 20 | set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns, 21 | with_sharding_constraint, FlaxTemperatureLogitsWarper 22 | ) 23 | from EasyLM.models.llama.llama_model import ( 24 | LLaMAConfigurator, FlaxLLaMAForCausalLM 25 | ) 26 | 27 | 28 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 29 | seed=42, 30 | mesh_dim='1,-1,1', 31 | param_dtype='bf16', 32 | dtype='bf16', 33 | input_length=1024, 34 | seq_length=2048, 35 | top_k=50, 36 | top_p=1.0, 37 | do_sample=True, 38 | num_beams=1, 39 | add_bos_token=True, 40 | load_checkpoint='', 41 | tokenizer='openlm-research/open_llama_3b_v2', 42 | llama=LLaMAConfigurator.get_default_config(), 43 | lm_server=LMServer.get_default_config(), 44 | jax_distributed=JaxDistributedConfig.get_default_config(), 45 | ) 46 | 47 | 48 | def main(argv): 49 | JaxDistributedConfig.initialize(FLAGS.jax_distributed) 50 | set_random_seed(FLAGS.seed) 51 | 52 | prefix_tokenizer = AutoTokenizer.from_pretrained( 53 | FLAGS.tokenizer, truncation_side='left', padding_side='left' 54 | ) 55 | prefix_tokenizer.pad_token = prefix_tokenizer.eos_token 56 | tokenizer = AutoTokenizer.from_pretrained( 57 | FLAGS.tokenizer, truncation_side='right', padding_side='right' 58 | ) 59 | tokenizer.pad_token = tokenizer.eos_token 60 | llama_config = LLaMAConfigurator.finalize_config(FLAGS.llama) 61 | 62 | with jax.default_device(jax.devices("cpu")[0]): 63 | _, params = StreamingCheckpointer.load_trainstate_checkpoint( 64 | FLAGS.load_checkpoint, disallow_trainstate=True 65 | ) 66 | 67 | hf_model = FlaxLLaMAForCausalLM( 68 | llama_config, 69 | input_shape=(1, FLAGS.seq_length), 70 | seed=FLAGS.seed, 71 | _do_init=False, 72 | dtype=get_float_dtype_by_name(FLAGS.dtype), 73 | param_dtype=get_float_dtype_by_name(FLAGS.param_dtype), 74 | ) 75 | 
76 | model_ps = match_partition_rules( 77 | LLaMAConfigurator.get_partition_rules(), params 78 | ) 79 | shard_fns, _ = make_shard_and_gather_fns( 80 | model_ps, get_float_dtype_by_name(FLAGS.param_dtype) 81 | ) 82 | 83 | @partial( 84 | pjit, 85 | in_shardings=(model_ps, PS(), PS()), 86 | out_shardings=(PS(), PS(), PS()) 87 | ) 88 | def forward_loglikelihood(params, rng, batch): 89 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'))) 90 | rng_generator = JaxRNG(rng) 91 | input_tokens = batch['input_tokens'] 92 | output_tokens = batch['output_tokens'] 93 | input_mask = batch['input_mask'] 94 | output_mask = batch['output_mask'] 95 | 96 | logits = hf_model.module.apply( 97 | params, input_tokens, attention_mask=input_mask, 98 | deterministic=True, rngs=rng_generator(LLaMAConfigurator.rng_keys()), 99 | ).logits 100 | loglikelihood = -optax.softmax_cross_entropy_with_integer_labels( 101 | logits, output_tokens 102 | ) 103 | loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1) 104 | match_count = jnp.sum( 105 | (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask, 106 | axis=-1 107 | ) 108 | total = jnp.sum(output_mask, axis=-1) 109 | is_greedy = match_count == total 110 | return loglikelihood, is_greedy, rng_generator() 111 | 112 | 113 | @partial( 114 | pjit, 115 | in_shardings=(model_ps, PS(), PS(), PS()), 116 | out_shardings=(PS(), PS()) 117 | ) 118 | def forward_generate(params, rng, batch, temperature): 119 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'))) 120 | rng_generator = JaxRNG(rng) 121 | output = hf_model.generate( 122 | batch['input_tokens'], 123 | attention_mask=batch['attention_mask'], 124 | params=params['params'], 125 | prng_key=rng_generator(), 126 | logits_processor=FlaxLogitsProcessorList( 127 | [FlaxTemperatureLogitsWarper(temperature)] 128 | ), 129 | generation_config=GenerationConfig( 130 | max_new_tokens=FLAGS.seq_length - FLAGS.input_length, 131 | pad_token_id=tokenizer.eos_token_id, 132 | bos_token_id=tokenizer.bos_token_id, 133 | eos_token_id=tokenizer.eos_token_id, 134 | do_sample=FLAGS.do_sample, 135 | num_beams=FLAGS.num_beams, 136 | top_k=FLAGS.top_k, 137 | top_p=FLAGS.top_p, 138 | ) 139 | ).sequences[:, batch['input_tokens'].shape[1]:] 140 | return output, rng_generator() 141 | 142 | @partial( 143 | pjit, 144 | in_shardings=(model_ps, PS(), PS()), 145 | out_shardings=(PS(), PS()) 146 | ) 147 | def forward_greedy_generate(params, rng, batch): 148 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'))) 149 | rng_generator = JaxRNG(rng) 150 | output = hf_model.generate( 151 | batch['input_tokens'], 152 | attention_mask=batch['attention_mask'], 153 | params=params['params'], 154 | prng_key=rng_generator(), 155 | generation_config=GenerationConfig( 156 | max_new_tokens=FLAGS.seq_length - FLAGS.input_length, 157 | pad_token_id=tokenizer.eos_token_id, 158 | bos_token_id=tokenizer.bos_token_id, 159 | eos_token_id=tokenizer.eos_token_id, 160 | do_sample=False, 161 | num_beams=1, 162 | ) 163 | ).sequences[:, batch['input_tokens'].shape[1]:] 164 | return output, rng_generator() 165 | 166 | mesh = LLaMAConfigurator.get_jax_mesh(FLAGS.mesh_dim) 167 | with mesh: 168 | params = tree_apply(shard_fns, params) 169 | sharded_rng = next_rng() 170 | 171 | class ModelServer(LMServer): 172 | 173 | @staticmethod 174 | def loglikelihood(prefix_text, text): 175 | nonlocal sharded_rng 176 | prefix = prefix_tokenizer( 177 | prefix_text, 178 | padding='max_length', 179 | truncation=True, 180 | max_length=FLAGS.input_length, 181 | 
return_tensors='np', 182 | ) 183 | inputs = tokenizer( 184 | text, 185 | padding='max_length', 186 | truncation=True, 187 | max_length=FLAGS.seq_length - FLAGS.input_length, 188 | return_tensors='np', 189 | ) 190 | output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1) 191 | bos_tokens = np.full( 192 | (output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32 193 | ) 194 | input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1) 195 | input_mask = np.concatenate( 196 | [prefix.attention_mask, inputs.attention_mask], axis=1 197 | ) 198 | if FLAGS.add_bos_token: 199 | bos_mask = np.ones_like(input_mask[:, :1]) 200 | else: 201 | bos_mask = np.zeros_like(input_mask[:, :1]) 202 | 203 | input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1) 204 | output_mask = np.concatenate( 205 | [np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1 206 | ) 207 | batch = dict( 208 | input_tokens=input_tokens, 209 | output_tokens=output_tokens, 210 | input_mask=input_mask, 211 | output_mask=output_mask, 212 | ) 213 | with mesh: 214 | loglikelihood, is_greedy, sharded_rng = forward_loglikelihood( 215 | params, sharded_rng, batch 216 | ) 217 | loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy)) 218 | return loglikelihood, is_greedy 219 | 220 | @staticmethod 221 | def loglikelihood_rolling(text): 222 | nonlocal sharded_rng 223 | inputs = tokenizer( 224 | text, 225 | padding='longest', 226 | truncation=False, 227 | max_length=np.iinfo(np.int32).max, 228 | return_tensors='np', 229 | ) 230 | batch_size = inputs.input_ids.shape[0] 231 | output_tokens = inputs.input_ids 232 | attention_mask = inputs.attention_mask 233 | 234 | if output_tokens.shape[1] < FLAGS.seq_length: 235 | padding_length = FLAGS.seq_length - output_tokens.shape[1] 236 | pad_tokens = np.full( 237 | (batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32 238 | ) 239 | output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1) 240 | pad_mask = np.zeros( 241 | (batch_size, padding_length), dtype=inputs.attention_mask.dtype 242 | ) 243 | attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1) 244 | 245 | bos_tokens = np.full( 246 | (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32 247 | ) 248 | input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1) 249 | bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype) 250 | total_seq_length = output_tokens.shape[1] 251 | 252 | total_loglikelihood = 0.0 253 | total_is_greedy = True 254 | # Sliding window 255 | for i in range(0, total_seq_length, FLAGS.seq_length): 256 | # Last window 257 | if i + FLAGS.seq_length > total_seq_length: 258 | last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:]) 259 | last_output_mask[:, :i - total_seq_length] = 0.0 260 | 261 | batch = dict( 262 | input_tokens=input_tokens[:, -FLAGS.seq_length:], 263 | output_tokens=output_tokens[:, -FLAGS.seq_length:], 264 | input_mask=attention_mask[:, -FLAGS.seq_length:], 265 | output_mask=last_output_mask, 266 | ) 267 | 268 | # Normal window 269 | else: 270 | batch = dict( 271 | input_tokens=input_tokens[:, i:i + FLAGS.seq_length], 272 | output_tokens=output_tokens[:, i:i + FLAGS.seq_length], 273 | input_mask=attention_mask[:, i:i + FLAGS.seq_length], 274 | output_mask=attention_mask[:, i:i + FLAGS.seq_length], 275 | ) 276 | 277 | with mesh: 278 | loglikelihood, is_greedy, sharded_rng = forward_loglikelihood( 279 | params, sharded_rng, batch 280 | ) 281 | 
loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy)) 282 | 283 | total_loglikelihood += loglikelihood 284 | total_is_greedy = np.logical_and(is_greedy, total_is_greedy) 285 | 286 | return total_loglikelihood, total_is_greedy 287 | 288 | @staticmethod 289 | def generate(text, temperature): 290 | nonlocal sharded_rng 291 | inputs = prefix_tokenizer( 292 | text, 293 | padding='max_length', 294 | truncation=True, 295 | max_length=FLAGS.input_length, 296 | return_tensors='np', 297 | ) 298 | input_tokens = inputs.input_ids 299 | input_mask = inputs.attention_mask 300 | if FLAGS.add_bos_token: 301 | input_tokens[:, 0] = tokenizer.bos_token_id 302 | input_mask[:, 0] = 1 303 | batch = dict( 304 | input_tokens=input_tokens, 305 | attention_mask=input_mask, 306 | ) 307 | with mesh: 308 | output, sharded_rng = forward_generate( 309 | params, sharded_rng, batch, temperature 310 | ) 311 | output = jax.device_get(output) 312 | output_text = [] 313 | for text in list(tokenizer.batch_decode(output)): 314 | if tokenizer.eos_token in text: 315 | text = text.split(tokenizer.eos_token, maxsplit=1)[0] 316 | output_text.append(text) 317 | 318 | return output_text 319 | 320 | @staticmethod 321 | def greedy_until(prefix_text, until, max_length): 322 | nonlocal sharded_rng 323 | all_outputs = [] 324 | for pf, ut in zip(prefix_text, until): 325 | if isinstance(ut, str): 326 | ut = [ut] 327 | total_length = 0 328 | total_generated = '' 329 | 330 | while total_length < max_length: 331 | pf_tokens = tokenizer( 332 | pf, 333 | padding=False, 334 | truncation=False, 335 | max_length=np.iinfo(np.int32).max, 336 | return_tensors='np', 337 | ) 338 | input_tokens = pf_tokens.input_ids 339 | attention_mask = pf_tokens.attention_mask 340 | 341 | if input_tokens.shape[1] < FLAGS.input_length: 342 | extra = FLAGS.input_length - input_tokens.shape[1] 343 | pad_tokens = np.full( 344 | (1, extra), tokenizer.pad_token_id, dtype=np.int32 345 | ) 346 | input_tokens = np.concatenate( 347 | [pad_tokens, input_tokens], axis=1 348 | ) 349 | pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype) 350 | attention_mask = np.concatenate( 351 | [pad_attention, attention_mask], axis=1 352 | ) 353 | elif input_tokens.shape[1] > FLAGS.input_length: 354 | input_tokens = input_tokens[:, -FLAGS.input_length:] 355 | attention_mask = attention_mask[:, -FLAGS.input_length:] 356 | 357 | if FLAGS.add_bos_token: 358 | input_tokens[:, 0] = tokenizer.bos_token_id 359 | attention_mask[:, 0] = 1 360 | 361 | batch = dict(input_tokens=input_tokens, attention_mask=attention_mask) 362 | 363 | with mesh: 364 | output, sharded_rng = forward_greedy_generate( 365 | params, sharded_rng, batch 366 | ) 367 | output = jax.device_get(output) 368 | 369 | total_length += output.shape[1] 370 | output_text = tokenizer.batch_decode(output)[0] 371 | total_generated = total_generated + output_text 372 | pf = pf + output_text 373 | 374 | done = False 375 | for s in ut: 376 | if s in total_generated: 377 | total_generated = total_generated.split(s, maxsplit=1)[0] 378 | done = True 379 | if done: 380 | break 381 | 382 | all_outputs.append(total_generated) 383 | 384 | return all_outputs 385 | 386 | 387 | server = ModelServer(FLAGS.lm_server) 388 | server.run() 389 | 390 | 391 | if __name__ == "__main__": 392 | mlxu.run(main) 393 | -------------------------------------------------------------------------------- /EasyLM/models/llama/llama_train.py: -------------------------------------------------------------------------------- 1 | import pprint 2 | from 
functools import partial 3 | 4 | from tqdm import tqdm, trange 5 | import numpy as np 6 | import mlxu 7 | 8 | import jax 9 | import jax.numpy as jnp 10 | from jax.experimental.pjit import pjit 11 | from jax.sharding import PartitionSpec as PS 12 | from flax.training.train_state import TrainState 13 | from transformers import AutoTokenizer 14 | 15 | from EasyLM.data import DatasetFactory 16 | from EasyLM.checkpoint import StreamingCheckpointer 17 | from EasyLM.optimizers import OptimizerFactory 18 | from EasyLM.jax_utils import ( 19 | JaxRNG, JaxDistributedConfig, next_rng, match_partition_rules, 20 | cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name, 21 | set_random_seed, average_metrics, make_shard_and_gather_fns, 22 | with_sharding_constraint, 23 | ) 24 | from EasyLM.models.llama.llama_model import ( 25 | LLaMAConfigurator, FlaxLLaMAForCausalLMModule 26 | ) 27 | 28 | 29 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 30 | seed=42, 31 | mesh_dim='1,-1,1', 32 | dtype='fp32', 33 | param_dtype='fp32', 34 | total_steps=10000, 35 | load_llama_config='', 36 | update_llama_config='', 37 | load_checkpoint='', 38 | load_dataset_state='', 39 | log_freq=50, 40 | save_model_freq=0, 41 | save_milestone_freq=0, 42 | eval_steps=0, 43 | tokenizer='openlm-research/open_llama_3b_v2', 44 | train_dataset=DatasetFactory.get_default_config(), 45 | eval_dataset=DatasetFactory.get_default_config(), 46 | optimizer=OptimizerFactory.get_default_config(), 47 | checkpointer=StreamingCheckpointer.get_default_config(), 48 | llama=LLaMAConfigurator.get_default_config(), 49 | logger=mlxu.WandBLogger.get_default_config(), 50 | log_all_worker=False, 51 | jax_distributed=JaxDistributedConfig.get_default_config(), 52 | ) 53 | 54 | 55 | def main(argv): 56 | JaxDistributedConfig.initialize(FLAGS.jax_distributed) 57 | variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF) 58 | flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF) 59 | logger = mlxu.WandBLogger( 60 | config=FLAGS.logger, 61 | variant=variant, 62 | enable=FLAGS.log_all_worker or (jax.process_index() == 0), 63 | ) 64 | set_random_seed(FLAGS.seed) 65 | 66 | tokenizer = AutoTokenizer.from_pretrained(FLAGS.tokenizer) 67 | dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer) 68 | if FLAGS.load_dataset_state != '': 69 | dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state)) 70 | 71 | if FLAGS.eval_steps > 0: 72 | eval_dataset = DatasetFactory.load_dataset( 73 | FLAGS.eval_dataset, dataset.tokenizer 74 | ) 75 | eval_iterator = iter(eval_dataset) 76 | 77 | seq_length = dataset.seq_length 78 | llama_config = LLaMAConfigurator.finalize_config(FLAGS.llama) 79 | 80 | model = FlaxLLaMAForCausalLMModule( 81 | llama_config, 82 | dtype=get_float_dtype_by_name(FLAGS.dtype), 83 | param_dtype=get_float_dtype_by_name(FLAGS.param_dtype), 84 | ) 85 | 86 | optimizer, optimizer_info = OptimizerFactory.get_optimizer(FLAGS.optimizer) 87 | 88 | def create_trainstate_from_params(params): 89 | return TrainState.create(params=params, tx=optimizer, apply_fn=None) 90 | 91 | def init_fn(rng): 92 | rng_generator = JaxRNG(rng) 93 | params = model.init( 94 | input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32), 95 | position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32), 96 | attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32), 97 | rngs=rng_generator(LLaMAConfigurator.rng_keys()), 98 | ) 99 | return TrainState.create(params=params, tx=optimizer, apply_fn=None) 100 | 101 | def train_step(train_state, rng, batch): 102 | 
rng_generator = JaxRNG(rng) 103 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'))) 104 | def loss_and_accuracy(params): 105 | logits = model.apply( 106 | params, batch['input_tokens'], deterministic=False, 107 | rngs=rng_generator(LLaMAConfigurator.rng_keys()), 108 | ).logits 109 | return cross_entropy_loss_and_accuracy( 110 | logits, batch['target_tokens'], batch['loss_masks'] 111 | ) 112 | grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True) 113 | (loss, accuracy), grads = grad_fn(train_state.params) 114 | train_state = train_state.apply_gradients(grads=grads) 115 | metrics = dict( 116 | loss=loss, 117 | accuracy=accuracy, 118 | learning_rate=optimizer_info['learning_rate_schedule'](train_state.step), 119 | gradient_norm=global_norm(grads), 120 | param_norm=global_norm(train_state.params), 121 | ) 122 | return train_state, rng_generator(), metrics 123 | 124 | def eval_step(train_state, rng, batch): 125 | rng_generator = JaxRNG(rng) 126 | batch = with_sharding_constraint(batch, PS(('dp', 'fsdp'))) 127 | logits = model.apply( 128 | train_state.params, batch['input_tokens'], deterministic=True, 129 | rngs=rng_generator(LLaMAConfigurator.rng_keys()), 130 | ).logits 131 | loss, accuracy = cross_entropy_loss_and_accuracy( 132 | logits, batch['target_tokens'], batch['loss_masks'] 133 | ) 134 | metrics = dict( 135 | eval_loss=loss, 136 | eval_accuracy=accuracy, 137 | ) 138 | return rng_generator(), metrics 139 | 140 | train_state_shapes = jax.eval_shape(init_fn, next_rng()) 141 | train_state_partition = match_partition_rules( 142 | LLaMAConfigurator.get_partition_rules(), train_state_shapes 143 | ) 144 | 145 | shard_fns, gather_fns = make_shard_and_gather_fns( 146 | train_state_partition, train_state_shapes 147 | ) 148 | checkpointer = StreamingCheckpointer( 149 | FLAGS.checkpointer, logger.output_dir, 150 | enable=jax.process_index() == 0, 151 | ) 152 | 153 | sharded_init_fn = pjit( 154 | init_fn, 155 | in_shardings=PS(), 156 | out_shardings=train_state_partition 157 | ) 158 | 159 | sharded_create_trainstate_from_params = pjit( 160 | create_trainstate_from_params, 161 | in_shardings=(train_state_partition.params, ), 162 | out_shardings=train_state_partition, 163 | donate_argnums=(0, ), 164 | ) 165 | 166 | sharded_train_step = pjit( 167 | train_step, 168 | in_shardings=(train_state_partition, PS(), PS()), 169 | out_shardings=(train_state_partition, PS(), PS()), 170 | donate_argnums=(0, 1), 171 | ) 172 | 173 | sharded_eval_step = pjit( 174 | eval_step, 175 | in_shardings=(train_state_partition, PS(), PS()), 176 | out_shardings=(PS(), PS()), 177 | donate_argnums=(1,), 178 | ) 179 | 180 | def save_checkpoint(train_state, milestone=False): 181 | step = int(jax.device_get(train_state.step)) 182 | metadata = dict( 183 | step=step, 184 | variant=variant, 185 | flags=flags_config_dict, 186 | llama_config=llama_config.to_dict(), 187 | ) 188 | checkpointer.save_all( 189 | train_state=train_state, 190 | gather_fns=gather_fns, 191 | metadata=metadata, 192 | dataset=dataset.get_state_dict(), 193 | milestone=milestone, 194 | ) 195 | 196 | mesh = LLaMAConfigurator.get_jax_mesh(FLAGS.mesh_dim) 197 | with mesh: 198 | train_state, restored_params = None, None 199 | if FLAGS.load_checkpoint != '': 200 | train_state, restored_params = checkpointer.load_trainstate_checkpoint( 201 | FLAGS.load_checkpoint, train_state_shapes, shard_fns 202 | ) 203 | 204 | if train_state is None and restored_params is None: 205 | # Initialize from scratch 206 | train_state = sharded_init_fn(next_rng()) 207 | elif 
train_state is None and restored_params is not None: 208 | # Restore from params but initialize train_state 209 | train_state = sharded_create_trainstate_from_params(restored_params) 210 | del restored_params 211 | 212 | start_step = int(jax.device_get(train_state.step)) 213 | 214 | if FLAGS.save_model_freq > 0: 215 | save_checkpoint(train_state) 216 | 217 | sharded_rng = next_rng() 218 | 219 | step_counter = trange(start_step, FLAGS.total_steps, ncols=0) 220 | 221 | for step, (batch, dataset_metrics) in zip(step_counter, dataset): 222 | train_state, sharded_rng, metrics = sharded_train_step( 223 | train_state, sharded_rng, batch 224 | ) 225 | 226 | if step % FLAGS.log_freq == 0: 227 | if FLAGS.eval_steps > 0: 228 | eval_metric_list = [] 229 | for _ in range(FLAGS.eval_steps): 230 | eval_batch, _ = next(eval_iterator) 231 | sharded_rng, eval_metrics = sharded_eval_step( 232 | train_state, sharded_rng, eval_batch 233 | ) 234 | eval_metric_list.append(eval_metrics) 235 | metrics.update(average_metrics(eval_metric_list)) 236 | 237 | log_metrics = {"step": step} 238 | log_metrics.update(metrics) 239 | log_metrics.update(dataset_metrics) 240 | log_metrics = jax.device_get(log_metrics) 241 | logger.log(log_metrics) 242 | tqdm.write("\n" + pprint.pformat(log_metrics) + "\n") 243 | 244 | if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0: 245 | save_checkpoint(train_state, milestone=True) 246 | elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0: 247 | save_checkpoint(train_state) 248 | 249 | if FLAGS.save_model_freq > 0: 250 | save_checkpoint(train_state) 251 | 252 | 253 | if __name__ == "__main__": 254 | mlxu.run(main) 255 | -------------------------------------------------------------------------------- /EasyLM/optimizers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import Any, Mapping, Text, Tuple, Union, NamedTuple 4 | from functools import partial 5 | import re 6 | import dataclasses 7 | import random 8 | 9 | from ml_collections.config_dict import config_dict 10 | from ml_collections import ConfigDict 11 | import jax 12 | import jax.numpy as jnp 13 | import numpy as np 14 | from absl import logging 15 | import optax 16 | 17 | from EasyLM.jax_utils import float_to_dtype 18 | 19 | 20 | class OptimizerFactory(object): 21 | """ Configurable optax optimizer factory. 
""" 22 | 23 | def __init__(self): 24 | raise NotImplementedError 25 | 26 | @staticmethod 27 | def get_default_config(updates=None): 28 | config = ConfigDict() 29 | config.accumulate_gradient_steps = 1 30 | config.type = 'adamw' 31 | config.palm_optimizer = PalmOptimizerFactory.get_default_config() 32 | config.adamw_optimizer = AdamWOptimizerFactory.get_default_config() 33 | 34 | if updates is not None: 35 | config.update(ConfigDict(updates).copy_and_resolve_references()) 36 | return config 37 | 38 | @classmethod 39 | def get_optimizer(cls, config, weight_decay_mask=None): 40 | config = cls.get_default_config(config) 41 | if config.type == 'palm': 42 | optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer( 43 | config.palm_optimizer, weight_decay_mask 44 | ) 45 | elif config.type == 'adamw': 46 | optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer( 47 | config.adamw_optimizer, weight_decay_mask 48 | ) 49 | else: 50 | raise ValueError(f'Unknown optimizer type: {config.type}') 51 | 52 | if config.accumulate_gradient_steps > 1: 53 | optimizer = optax.MultiSteps( 54 | optimizer, config.accumulate_gradient_steps 55 | ) 56 | 57 | return optimizer, optimizer_info 58 | 59 | 60 | class PalmOptimizerFactory(object): 61 | """ PaLM optimizer factory. This optimizer implements the optimizer 62 | described in the PaLM paper: https://arxiv.org/abs/2204.02311 63 | """ 64 | 65 | def __init__(self): 66 | raise NotImplementedError 67 | 68 | @staticmethod 69 | def get_default_config(updates=None): 70 | config = ConfigDict() 71 | config.lr = 0.01 72 | config.lr_warmup_steps = 10000 73 | config.b1 = 0.9 74 | config.b2 = 0.99 75 | config.clip_gradient = 1.0 76 | config.weight_decay = 1e-4 77 | config.bf16_momentum = False 78 | 79 | if updates is not None: 80 | config.update(ConfigDict(updates).copy_and_resolve_references()) 81 | return config 82 | 83 | @classmethod 84 | def get_optimizer(cls, config, weight_decay_mask=None): 85 | config = cls.get_default_config(config) 86 | 87 | def learning_rate_schedule(step): 88 | multiplier = config.lr / 0.01 89 | return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps)) 90 | 91 | def weight_decay_schedule(step): 92 | multiplier = config.weight_decay / 1e-4 93 | return -multiplier * jnp.square(learning_rate_schedule(step)) 94 | 95 | optimizer_info = dict( 96 | learning_rate_schedule=learning_rate_schedule, 97 | weight_decay_schedule=weight_decay_schedule, 98 | ) 99 | 100 | optimizer = optax.chain( 101 | optax.clip_by_global_norm(config.clip_gradient), 102 | optax.adafactor( 103 | learning_rate=learning_rate_schedule, 104 | multiply_by_parameter_scale=True, 105 | momentum=config.b1, 106 | decay_rate=config.b2, 107 | factored=False, 108 | clipping_threshold=None, 109 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 110 | ), 111 | optax_add_scheduled_weight_decay( 112 | weight_decay_schedule, weight_decay_mask 113 | ) 114 | ) 115 | return optimizer, optimizer_info 116 | 117 | 118 | class AdamWOptimizerFactory(object): 119 | """ AdamW optimizer with cosine schedule. 
""" 120 | 121 | def __init__(self): 122 | raise NotImplementedError 123 | 124 | @staticmethod 125 | def get_default_config(updates=None): 126 | config = ConfigDict() 127 | config.init_lr = 0.0 128 | config.end_lr = 0.001 129 | config.lr = 0.01 130 | config.lr_warmup_steps = 2000 131 | config.lr_decay_steps = 500000 132 | config.b1 = 0.9 133 | config.b2 = 0.95 134 | config.clip_gradient = 1.0 135 | config.weight_decay = 1e-4 136 | config.bf16_momentum = False 137 | config.multiply_by_parameter_scale = False 138 | 139 | if updates is not None: 140 | config.update(ConfigDict(updates).copy_and_resolve_references()) 141 | return config 142 | 143 | @classmethod 144 | def get_optimizer(cls, config, weight_decay_mask=None): 145 | config = cls.get_default_config(config) 146 | 147 | learning_rate_schedule = optax.warmup_cosine_decay_schedule( 148 | init_value=config.init_lr, 149 | peak_value=config.lr, 150 | warmup_steps=config.lr_warmup_steps, 151 | decay_steps=config.lr_decay_steps, 152 | end_value=config.end_lr, 153 | ) 154 | 155 | optimizer_info = dict( 156 | learning_rate_schedule=learning_rate_schedule, 157 | ) 158 | 159 | if config.multiply_by_parameter_scale: 160 | optimizer = optax.chain( 161 | optax.clip_by_global_norm(config.clip_gradient), 162 | optax.adafactor( 163 | learning_rate=learning_rate_schedule, 164 | multiply_by_parameter_scale=True, 165 | momentum=config.b1, 166 | decay_rate=config.b2, 167 | factored=False, 168 | clipping_threshold=None, 169 | dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 170 | ), 171 | optax_add_scheduled_weight_decay( 172 | lambda step: -learning_rate_schedule(step) * config.weight_decay, 173 | weight_decay_mask 174 | ) 175 | ) 176 | else: 177 | optimizer = optax.chain( 178 | optax.clip_by_global_norm(config.clip_gradient), 179 | optax.adamw( 180 | learning_rate=learning_rate_schedule, 181 | weight_decay=config.weight_decay, 182 | b1=config.b1, 183 | b2=config.b2, 184 | mask=weight_decay_mask, 185 | mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32, 186 | ), 187 | ) 188 | 189 | return optimizer, optimizer_info 190 | 191 | 192 | class OptaxScheduledWeightDecayState(NamedTuple): 193 | count: jax.Array 194 | 195 | 196 | def optax_add_scheduled_weight_decay(schedule_fn, mask=None): 197 | """ Apply weight decay with schedule. 
""" 198 | 199 | def init_fn(params): 200 | del params 201 | return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32)) 202 | 203 | def update_fn(updates, state, params): 204 | if params is None: 205 | raise ValueError('Params cannot be None for weight decay!') 206 | 207 | weight_decay = schedule_fn(state.count) 208 | updates = jax.tree_util.tree_map( 209 | lambda g, p: g + weight_decay * p, updates, params 210 | ) 211 | return updates, OptaxScheduledWeightDecayState( 212 | count=optax.safe_int32_increment(state.count) 213 | ) 214 | 215 | if mask is not None: 216 | return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask) 217 | return optax.GradientTransformation(init_fn, update_fn) 218 | -------------------------------------------------------------------------------- /EasyLM/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/young-geng/EasyLM/fe5b2c354e25d697fce7cd225e23bbbe72570da3/EasyLM/scripts/__init__.py -------------------------------------------------------------------------------- /EasyLM/scripts/benchmark_attention.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from time import time 3 | import os 4 | import numpy as np 5 | import jax 6 | import jax.flatten_util 7 | import jax.numpy as jnp 8 | import mlxu 9 | from EasyLM.bpt import blockwise_attn 10 | from EasyLM.jax_utils import ( 11 | get_float_dtype_by_name, set_random_seed, next_rng, JaxRNG 12 | ) 13 | 14 | 15 | FLAGS, _ = mlxu.define_flags_with_default( 16 | seed=42, 17 | dtype='fp32', 18 | embed_dim=2048, 19 | n_heads=16, 20 | ref_attn_seq_len=2048, 21 | eff_attn_seq_len=16384, 22 | batch_size=1, 23 | query_chunk_size=2048, 24 | key_chunk_size=2048, 25 | warmup_steps=40, 26 | steps=200, 27 | ) 28 | 29 | 30 | def main(argv): 31 | 32 | def random_kqv(rng_key, seq_len): 33 | rng_generator = JaxRNG(rng_key) 34 | kqv = [] 35 | for i in range(3): 36 | kqv.append( 37 | jax.random.normal( 38 | rng_generator(), 39 | (FLAGS.batch_size, seq_len, FLAGS.n_heads, FLAGS.embed_dim // FLAGS.n_heads), 40 | dtype=get_float_dtype_by_name(FLAGS.dtype) 41 | ) 42 | ) 43 | return tuple(kqv) 44 | 45 | def reference_attn(query, key, value): 46 | dtype = get_float_dtype_by_name(FLAGS.dtype) 47 | query = query / jnp.sqrt(query.shape[-1]).astype(dtype) 48 | logits = jnp.einsum("bqhc,bkhc->bhqk", query, key) 49 | mask_value = jnp.finfo(logits.dtype).min 50 | _, q_seq_len, _, _ = query.shape 51 | _, kv_seq_len, _, _ = key.shape 52 | mask_shape = (q_seq_len, kv_seq_len) 53 | row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) 54 | col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) 55 | causal_mask = (row_ids < col_ids)[None, None, :, :] 56 | logits = logits + jnp.where(causal_mask, mask_value, 0.0) 57 | weights = jax.nn.softmax(logits, axis=-1) 58 | out = jnp.einsum("bhqk,bkhc->bqhc", weights, value) 59 | return out 60 | 61 | def efficient_attention(query, key, value): 62 | dtype = get_float_dtype_by_name(FLAGS.dtype) 63 | return blockwise_attn( 64 | query, key, value, 65 | bias=None, 66 | deterministic=True, 67 | dropout_rng=None, 68 | attn_pdrop=0.0, 69 | causal=True, 70 | query_chunk_size=FLAGS.query_chunk_size, 71 | key_chunk_size=FLAGS.key_chunk_size, 72 | dtype=get_float_dtype_by_name(FLAGS.dtype), 73 | policy=jax.checkpoint_policies.nothing_saveable(), 74 | precision=None, 75 | float32_logits=True, 76 | prevent_cse=True, 77 | ) 78 | 79 | 80 | 
@partial(jax.jit, static_argnums=(1,)) 81 | def reference_attn_forward_backward(rng_key, seq_len): 82 | @partial(jax.grad, argnums=(0, 1, 2)) 83 | @partial(jax.checkpoint, policy=jax.checkpoint_policies.nothing_saveable()) 84 | def grad_fn(query, key, value): 85 | out = reference_attn(query, key, value) 86 | return jnp.mean(out) 87 | 88 | query, key, value = random_kqv(rng_key, seq_len) 89 | return jax.flatten_util.ravel_pytree( 90 | grad_fn(query, key, value)[1] 91 | )[0].mean() 92 | 93 | @partial(jax.jit, static_argnums=(1,)) 94 | def efficient_attn_forward_backward(rng_key, seq_len): 95 | @partial(jax.grad, argnums=(0, 1, 2)) 96 | def grad_fn(query, key, value): 97 | out = efficient_attention(query, key, value) 98 | return jnp.mean(out) 99 | 100 | query, key, value = random_kqv(rng_key, seq_len) 101 | return jax.flatten_util.ravel_pytree( 102 | grad_fn(query, key, value)[1] 103 | )[0].mean() 104 | 105 | 106 | set_random_seed(FLAGS.seed) 107 | 108 | jax.block_until_ready(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len)) 109 | jax.block_until_ready(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len)) 110 | 111 | all_results = [] 112 | for i in range(FLAGS.warmup_steps): 113 | all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len)) 114 | jax.block_until_ready(all_results) 115 | 116 | start_time = time() 117 | all_results = [] 118 | for i in range(FLAGS.steps): 119 | all_results.append(reference_attn_forward_backward(next_rng(), FLAGS.ref_attn_seq_len)) 120 | 121 | jax.block_until_ready(all_results) 122 | elapsed_time_ref_attn = time() - start_time 123 | print(f'Reference attention: {elapsed_time_ref_attn:.3f} seconds') 124 | 125 | 126 | all_results = [] 127 | for i in range(FLAGS.warmup_steps): 128 | all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len)) 129 | jax.block_until_ready(all_results) 130 | 131 | 132 | start_time = time() 133 | all_results = [] 134 | for i in range(FLAGS.steps): 135 | all_results.append(efficient_attn_forward_backward(next_rng(), FLAGS.eff_attn_seq_len)) 136 | 137 | jax.block_until_ready(all_results) 138 | elapsed_time_efficient_attn = time() - start_time 139 | print(f'Efficient attention: {elapsed_time_efficient_attn:.3f} seconds') 140 | 141 | flops_ratio = (FLAGS.eff_attn_seq_len / FLAGS.ref_attn_seq_len) ** 2 142 | efficiency = elapsed_time_ref_attn / elapsed_time_efficient_attn * flops_ratio 143 | print(f'Efficiency: {efficiency:.3f}') 144 | 145 | 146 | if __name__ == '__main__': 147 | mlxu.run(main) 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /EasyLM/scripts/convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | # This script converts a model checkpoint trained by EasyLM to a standard 2 | # msgpack checkpoint that can be loaded by huggingface transformers or 3 | # flax.serialization.msgpack_restore. Such conversion allows models to be 4 | # used by other frameworks that integrate with huggingface transformers.
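# Example invocation (a sketch; the paths and the 'params::' checkpoint-type prefix are illustrative placeholders, not taken from this listing):
#   python -m EasyLM.scripts.convert_checkpoint \
#       --load_checkpoint='params::/path/to/easylm_checkpoint' \
#       --output_file='/path/to/model.msgpack' \
#       --streaming=False \
#       --float_dtype='bf16'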
5 | 6 | import pprint 7 | from functools import partial 8 | import os 9 | import numpy as np 10 | import mlxu 11 | import jax.numpy as jnp 12 | import flax.serialization 13 | from EasyLM.checkpoint import StreamingCheckpointer 14 | from EasyLM.jax_utils import float_to_dtype 15 | 16 | 17 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 18 | load_checkpoint='', 19 | output_file='', 20 | streaming=False, 21 | float_dtype='bf16', 22 | ) 23 | 24 | 25 | def main(argv): 26 | assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified' 27 | params = StreamingCheckpointer.load_trainstate_checkpoint( 28 | FLAGS.load_checkpoint, disallow_trainstate=True 29 | )[1]['params'] 30 | 31 | if FLAGS.streaming: 32 | StreamingCheckpointer.save_train_state_to_file( 33 | params, FLAGS.output_file, float_dtype=FLAGS.float_dtype 34 | ) 35 | else: 36 | params = float_to_dtype(params, FLAGS.float_dtype) 37 | with mlxu.open_file(FLAGS.output_file, 'wb') as fout: 38 | fout.write(flax.serialization.msgpack_serialize(params, in_place=True)) 39 | 40 | 41 | if __name__ == "__main__": 42 | mlxu.run(main) 43 | -------------------------------------------------------------------------------- /EasyLM/scripts/diff_checkpoint.py: -------------------------------------------------------------------------------- 1 | # This script computes the parameter difference between a base EasyLM checkpoint 2 | # and a target EasyLM checkpoint (target minus base), or, with --recover_diff, 3 | # adds a previously saved diff back onto the base checkpoint to recover the 4 | # target parameters. The result is written as a streaming checkpoint or a msgpack file. 5 | 6 | import pprint 7 | from functools import partial 8 | import os 9 | import numpy as np 10 | import jax 11 | import jax.numpy as jnp 12 | import flax.serialization 13 | import mlxu 14 | from EasyLM.checkpoint import StreamingCheckpointer 15 | from EasyLM.jax_utils import float_to_dtype 16 | 17 | 18 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 19 | recover_diff=False, 20 | load_base_checkpoint='', 21 | load_target_checkpoint='', 22 | output_file='', 23 | streaming=True, 24 | float_dtype='bf16', 25 | ) 26 | 27 | 28 | def main(argv): 29 | assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != '' 30 | assert FLAGS.output_file != '' 31 | base_params = StreamingCheckpointer.load_trainstate_checkpoint( 32 | FLAGS.load_base_checkpoint, disallow_trainstate=True 33 | )[1]['params'] 34 | 35 | target_params = StreamingCheckpointer.load_trainstate_checkpoint( 36 | FLAGS.load_target_checkpoint, disallow_trainstate=True 37 | )[1]['params'] 38 | 39 | if FLAGS.recover_diff: 40 | params = jax.tree_util.tree_map( 41 | lambda b, t: b + t, base_params, target_params 42 | ) 43 | else: 44 | params = jax.tree_util.tree_map( 45 | lambda b, t: t - b, base_params, target_params 46 | ) 47 | 48 | if FLAGS.streaming: 49 | StreamingCheckpointer.save_train_state_to_file( 50 | params, FLAGS.output_file, float_dtype=FLAGS.float_dtype 51 | ) 52 | else: 53 | params = float_to_dtype(params, FLAGS.float_dtype) 54 | with mlxu.open_file(FLAGS.output_file, 'wb') as fout: 55 | fout.write(flax.serialization.msgpack_serialize(params, in_place=True)) 56 | 57 | 58 | if __name__ == "__main__": 59 | mlxu.run(main) 60 | -------------------------------------------------------------------------------- /EasyLM/scripts/lm_eval_harness.py: -------------------------------------------------------------------------------- 1 | # This script runs lm_eval_harness evaluations against a served language model.
2 | # Typically, you need to run a language model server first, e.g.: 3 | # python -m EasyLM.models.gptj.gptj_serve ... 4 | 5 | import dataclasses 6 | import pprint 7 | from functools import partial 8 | import os 9 | from tqdm import tqdm, trange 10 | import numpy as np 11 | import mlxu 12 | 13 | from flax.traverse_util import flatten_dict 14 | from lm_eval import evaluator, tasks 15 | from lm_eval.base import LM 16 | 17 | from EasyLM.serving import LMClient 18 | 19 | 20 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 21 | tasks='wsc,piqa,winogrande,openbookqa,logiqa', 22 | shots=0, 23 | limit=0, 24 | write_out=False, 25 | lm_client=LMClient.get_default_config(), 26 | logger=mlxu.WandBLogger.get_default_config(), 27 | ) 28 | 29 | 30 | class LMEvalHarnessInterface(LM): 31 | 32 | def __init__(self, lm_client): 33 | self.lm_client = lm_client 34 | 35 | def greedy_until(self, inputs): 36 | prefix, until = zip(*inputs) 37 | return self.lm_client.greedy_until(prefix, until) 38 | 39 | def loglikelihood_rolling(self, inputs): 40 | loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs) 41 | return list(zip(loglikelihood, is_greedy)) 42 | 43 | def loglikelihood(self, inputs): 44 | prefix, text = zip(*inputs) 45 | loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text) 46 | return list(zip(loglikelihood, is_greedy)) 47 | 48 | 49 | def main(argv): 50 | logger = mlxu.WandBLogger( 51 | config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF) 52 | ) 53 | model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client)) 54 | task_list = FLAGS.tasks.split(',') 55 | results = evaluator.evaluate( 56 | model, tasks.get_task_dict(task_list), False, FLAGS.shots, 57 | limit=None if FLAGS.limit <= 0 else FLAGS.limit, 58 | write_out=FLAGS.write_out, 59 | ) 60 | logger.log(flatten_dict(results['results'], sep='/')) 61 | pprint.pprint(results) 62 | 63 | 64 | if __name__ == "__main__": 65 | mlxu.run(main) 66 | -------------------------------------------------------------------------------- /EasyLM/scripts/lm_eval_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import mlxu 3 | from EasyLM.serving import LMClient 4 | 5 | 6 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 7 | input_file='', 8 | output_file='', 9 | prefix_field='prefix', 10 | text_field='text', 11 | until_field='until', 12 | eval_type='loglikelihood', 13 | lm_client=LMClient.get_default_config(), 14 | ) 15 | 16 | 17 | def main(argv): 18 | lm_client = LMClient(FLAGS.lm_client) 19 | with mlxu.open_file(FLAGS.input_file, 'r') as fin: 20 | input_data = json.load(fin) 21 | 22 | if FLAGS.eval_type == 'loglikelihood': 23 | prefix = input_data[FLAGS.prefix_field] 24 | text = input_data[FLAGS.text_field] 25 | loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text) 26 | output_data = { 27 | 'loglikelihood': loglikelihoods, 28 | 'is_greedy': is_greedys, 29 | } 30 | elif FLAGS.eval_type == 'loglikelihood_rolling': 31 | text = input_data[FLAGS.text_field] 32 | loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text) 33 | output_data = { 34 | 'loglikelihood': loglikelihoods, 35 | 'is_greedy': is_greedys, 36 | } 37 | elif FLAGS.eval_type == 'greedy_until': 38 | prefix = input_data[FLAGS.prefix_field] 39 | until = input_data[FLAGS.until_field] 40 | output_data = {'output_text': lm_client.greedy_until(prefix, until)} 41 | elif FLAGS.eval_type == 'generate': 42 | prefix = input_data[FLAGS.prefix_field] 43 | output_data = {'output_text': 
lm_client.generate(prefix)} 44 | else: 45 | raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}') 46 | 47 | with mlxu.open_file(FLAGS.output_file, 'w') as fout: 48 | json.dump(output_data, fout) 49 | 50 | 51 | if __name__ == "__main__": 52 | mlxu.run(main) 53 | -------------------------------------------------------------------------------- /EasyLM/serving.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import pprint 3 | from functools import partial 4 | import re 5 | import os 6 | from threading import Lock 7 | import urllib 8 | import time 9 | from typing import List, Optional, Union 10 | 11 | from pydantic import BaseModel 12 | import absl.logging 13 | from tqdm import tqdm, trange 14 | import numpy as np 15 | import mlxu 16 | from ml_collections import ConfigDict 17 | import uvicorn 18 | from fastapi import FastAPI 19 | import gradio as gr 20 | import requests 21 | from requests.exceptions import Timeout, ConnectionError 22 | 23 | 24 | class InferenceRequest(BaseModel): 25 | prefix_text: Optional[List[str]] = None 26 | text: Optional[List[str]] = None 27 | until: Optional[Union[List[str], List[List[str]]]] = None 28 | temperature: Optional[float] = None 29 | 30 | 31 | class ChatRequest(BaseModel): 32 | prompt: str 33 | context: str = '' 34 | temperature: Optional[float] = None 35 | 36 | 37 | class LMServer(object): 38 | """ HTTP server for serving langauge models. """ 39 | 40 | @staticmethod 41 | def get_default_config(updates=None): 42 | config = ConfigDict() 43 | config.host = '0.0.0.0' 44 | config.port = 5007 45 | config.batch_size = 1 46 | config.logging = False 47 | config.pre_compile = 'loglikelihood' 48 | config.default_temperature = 1.0 49 | config.greedy_until_max_length = 5000 50 | config.prepend_to_prefix = '' 51 | config.append_to_prefix = '' 52 | config.prepend_to_text = '' 53 | config.append_to_text = '' 54 | config.chat_prepend_text = '' 55 | config.chat_user_prefix = '' 56 | config.chat_user_suffix = '' 57 | config.chat_lm_prefix = '' 58 | config.chat_lm_suffix = '' 59 | config.notes = '' 60 | 61 | if updates is not None: 62 | config.update(ConfigDict(updates).copy_and_resolve_references()) 63 | return config 64 | 65 | def __init__(self, config): 66 | self.config = self.get_default_config(config) 67 | self.lock = Lock() 68 | self.app = FastAPI() 69 | self.app.post('/loglikelihood')(self.serve_loglikelihood) 70 | self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling) 71 | self.app.post('/generate')(self.serve_generate) 72 | self.app.post('/greedy-until')(self.serve_greedy_until) 73 | self.app.post('/chat')(self.serve_chat) 74 | self.app.get('/ready')(self.serve_ready) 75 | self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/') 76 | 77 | @staticmethod 78 | def loglikelihood(prefix_text, text): 79 | raise NotImplementedError() 80 | 81 | @staticmethod 82 | def loglikelihood_rolling(text): 83 | raise NotImplementedError() 84 | 85 | @staticmethod 86 | def generate(text, temperature): 87 | raise NotImplementedError() 88 | 89 | @staticmethod 90 | def greedy_until(prefix_text, until, max_length): 91 | raise NotImplementedError() 92 | 93 | @staticmethod 94 | def to_list(x): 95 | if isinstance(x, np.ndarray): 96 | return x.tolist() 97 | return x 98 | 99 | def serve_ready(self): 100 | return 'Ready!\n' 101 | 102 | def serve_loglikelihood(self, data: InferenceRequest): 103 | with self.lock: 104 | if self.config.logging: 105 | absl.logging.info( 106 | '\n========= Serving Log 
Likelihood Request ========= \n' 107 | + pprint.pformat(data) + '\n' 108 | ) 109 | 110 | if data.prefix_text is None: 111 | data.prefix_text = ['' for _ in data.text] 112 | 113 | prefix_text = [ 114 | self.config.prepend_to_prefix + p + self.config.append_to_prefix 115 | for p in data.prefix_text 116 | ] 117 | text = [ 118 | self.config.prepend_to_text + t + self.config.append_to_text 119 | for t in data.text 120 | ] 121 | 122 | log_likelihood = [] 123 | is_greedy = [] 124 | for i in trange(0, len(text), self.config.batch_size, ncols=0): 125 | batch_prefix_text = prefix_text[i:i + self.config.batch_size] 126 | batch_text = text[i:i + self.config.batch_size] 127 | batch_size = len(batch_text) 128 | 129 | if batch_size < self.config.batch_size: 130 | extra = self.config.batch_size - batch_size 131 | batch_prefix_text.extend(['a' for _ in range(extra)]) 132 | batch_text.extend(['a' for _ in range(extra)]) 133 | 134 | batch_log_likelihood, batch_is_greedy = self.loglikelihood( 135 | batch_prefix_text, batch_text 136 | ) 137 | batch_log_likelihood = self.to_list(batch_log_likelihood) 138 | batch_is_greedy = self.to_list(batch_is_greedy) 139 | log_likelihood.extend(batch_log_likelihood[:batch_size]) 140 | is_greedy.extend(batch_is_greedy[:batch_size]) 141 | 142 | output = { 143 | 'prefix_text': data.prefix_text, 144 | 'text': data.text, 145 | 'log_likelihood': log_likelihood, 146 | 'is_greedy': is_greedy, 147 | } 148 | if self.config.logging: 149 | absl.logging.info( 150 | '\n========= Output ========= \n' 151 | + pprint.pformat(output) + '\n' 152 | ) 153 | 154 | return output 155 | 156 | def serve_loglikelihood_rolling(self, data: InferenceRequest): 157 | with self.lock: 158 | if self.config.logging: 159 | absl.logging.info( 160 | '\n========= Serving Log Likelihood Request ========= \n' 161 | + pprint.pformat(data) + '\n' 162 | ) 163 | 164 | text = [ 165 | self.config.prepend_to_text + t + self.config.append_to_text 166 | for t in data.text 167 | ] 168 | log_likelihood = [] 169 | is_greedy = [] 170 | for i in trange(0, len(text), self.config.batch_size, ncols=0): 171 | batch_text = text[i:i + self.config.batch_size] 172 | batch_size = len(batch_text) 173 | 174 | if batch_size < self.config.batch_size: 175 | extra = self.config.batch_size - batch_size 176 | batch_text.extend(['a' for _ in range(extra)]) 177 | 178 | batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling( 179 | batch_text 180 | ) 181 | batch_log_likelihood = self.to_list(batch_log_likelihood) 182 | batch_is_greedy = self.to_list(batch_is_greedy) 183 | log_likelihood.extend(batch_log_likelihood[:batch_size]) 184 | is_greedy.extend(batch_is_greedy[:batch_size]) 185 | 186 | output = { 187 | 'text': data.text, 188 | 'log_likelihood': log_likelihood, 189 | 'is_greedy': is_greedy, 190 | } 191 | if self.config.logging: 192 | absl.logging.info( 193 | '\n========= Output ========= \n' 194 | + pprint.pformat(output) + '\n' 195 | ) 196 | 197 | return output 198 | 199 | def serve_generate(self, data: InferenceRequest): 200 | with self.lock: 201 | if self.config.logging: 202 | absl.logging.info( 203 | '\n========= Serving Generate Request ========= \n' 204 | + pprint.pformat(data) + '\n' 205 | ) 206 | prefix_text = [ 207 | self.config.prepend_to_prefix + p + self.config.append_to_prefix 208 | for p in data.prefix_text 209 | ] 210 | 211 | if data.temperature is None: 212 | data.temperature = self.config.default_temperature 213 | 214 | output_text = [] 215 | for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0): 216 
| batch_prefix_text = prefix_text[i:i + self.config.batch_size] 217 | batch_size = len(batch_prefix_text) 218 | 219 | if batch_size < self.config.batch_size: 220 | extra = self.config.batch_size - batch_size 221 | batch_prefix_text.extend(['a' for _ in range(extra)]) 222 | 223 | batch_output_text = self.generate( 224 | batch_prefix_text, 225 | temperature=data.temperature, 226 | ) 227 | output_text.extend(self.to_list(batch_output_text)[:batch_size]) 228 | 229 | output = { 230 | 'prefix_text': data.prefix_text, 231 | 'output_text': output_text, 232 | 'temperature': data.temperature, 233 | } 234 | if self.config.logging: 235 | absl.logging.info( 236 | '\n========= Output ========= \n' 237 | + pprint.pformat(output) + '\n' 238 | ) 239 | return output 240 | 241 | def serve_greedy_until(self, data: InferenceRequest): 242 | with self.lock: 243 | if self.config.logging: 244 | absl.logging.info( 245 | '\n========= Serving Greedy Until Request ========= \n' 246 | + pprint.pformat(data) + '\n' 247 | ) 248 | prefix_text = [ 249 | self.config.prepend_to_prefix + p + self.config.append_to_prefix 250 | for p in data.prefix_text 251 | ] 252 | until = data.until 253 | max_length = self.config.greedy_until_max_length 254 | 255 | output_text = [] 256 | for i in range(0, len(prefix_text), self.config.batch_size): 257 | batch_prefix_text = prefix_text[i:i + self.config.batch_size] 258 | batch_until = until[i:i + self.config.batch_size] 259 | batch_size = len(batch_prefix_text) 260 | 261 | batch_output_text = self.greedy_until(batch_prefix_text, batch_until, max_length) 262 | output_text.extend(self.to_list(batch_output_text)[:batch_size]) 263 | 264 | output = { 265 | 'prefix_text': data.prefix_text, 266 | 'until': data.until, 267 | 'max_length': max_length, 268 | 'output_text': output_text, 269 | } 270 | if self.config.logging: 271 | absl.logging.info( 272 | '\n========= Output ========= \n' 273 | + pprint.pformat(output) + '\n' 274 | ) 275 | return output 276 | 277 | def process_chat(self, prompt, context, temperature): 278 | context = ( 279 | context + self.config.chat_user_prefix 280 | + prompt + self.config.chat_user_suffix 281 | + self.config.chat_lm_prefix 282 | ) 283 | response = self.generate( 284 | [self.config.chat_prepend_text + context], 285 | temperature=float(temperature), 286 | )[0] 287 | context = context + response + self.config.chat_lm_suffix 288 | return response, context 289 | 290 | def serve_chat(self, data: ChatRequest): 291 | if data.temperature is None: 292 | data.temperature = self.config.default_temperature 293 | response, context = self.process_chat( 294 | data.prompt, data.context, 295 | temperature=data.temperature, 296 | ) 297 | return { 298 | 'response': response, 299 | 'context': context, 300 | 'temperature': data.temperature, 301 | } 302 | 303 | def create_chat_app(self): 304 | with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot: 305 | gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)') 306 | gr.Markdown(self.config.notes) 307 | chatbot = gr.Chatbot(label='Chat history') 308 | msg = gr.Textbox( 309 | placeholder='Type your message here...', 310 | show_label=False 311 | ) 312 | with gr.Row(): 313 | send = gr.Button('Send') 314 | regenerate = gr.Button('Regenerate', interactive=False) 315 | clear = gr.Button('Reset') 316 | 317 | temp_slider = gr.Slider( 318 | label='Temperature', minimum=0, maximum=2.0, 319 | value=self.config.default_temperature 320 | ) 321 | 322 | context_state = gr.State(['', '']) 323 | 324 | def 
user_fn(user_message, history, context): 325 | return { 326 | msg: gr.update(value='', interactive=False), 327 | clear: gr.update(interactive=False), 328 | send: gr.update(interactive=False), 329 | regenerate: gr.update(interactive=False), 330 | chatbot: history + [[user_message, None]], 331 | context_state: [context[1], context[1]], 332 | } 333 | 334 | def model_fn(history, context, temperature): 335 | history[-1][1], new_context = self.process_chat( 336 | history[-1][0], context[0], temperature 337 | ) 338 | return { 339 | msg: gr.update(value='', interactive=True), 340 | clear: gr.update(interactive=True), 341 | send: gr.update(interactive=True), 342 | chatbot: history, 343 | context_state: [context[0], new_context], 344 | regenerate: gr.update(interactive=True), 345 | } 346 | 347 | def regenerate_fn(): 348 | return { 349 | msg: gr.update(value='', interactive=False), 350 | clear: gr.update(interactive=False), 351 | send: gr.update(interactive=False), 352 | regenerate: gr.update(interactive=False), 353 | } 354 | 355 | def clear_fn(): 356 | return { 357 | chatbot: None, 358 | msg: '', 359 | context_state: ['', ''], 360 | regenerate: gr.update(interactive=False), 361 | } 362 | 363 | msg.submit( 364 | user_fn, 365 | inputs=[msg, chatbot, context_state], 366 | outputs=[msg, clear, send, chatbot, context_state, regenerate], 367 | queue=False 368 | ).then( 369 | model_fn, 370 | inputs=[chatbot, context_state, temp_slider], 371 | outputs=[msg, clear, send, chatbot, context_state, regenerate], 372 | queue=True, 373 | concurrency_limit=1, 374 | ) 375 | send.click( 376 | user_fn, 377 | inputs=[msg, chatbot, context_state], 378 | outputs=[msg, clear, send, chatbot, context_state, regenerate], 379 | queue=False 380 | ).then( 381 | model_fn, 382 | inputs=[chatbot, context_state, temp_slider], 383 | outputs=[msg, clear, send, chatbot, context_state, regenerate], 384 | queue=True, 385 | concurrency_limit=1, 386 | ) 387 | regenerate.click( 388 | regenerate_fn, 389 | inputs=None, 390 | outputs=[msg, clear, send, regenerate], 391 | queue=False 392 | ).then( 393 | model_fn, 394 | inputs=[chatbot, context_state, temp_slider], 395 | outputs=[msg, clear, send, chatbot, context_state, regenerate], 396 | queue=True, 397 | concurrency_limit=1, 398 | ) 399 | clear.click( 400 | clear_fn, 401 | inputs=None, 402 | outputs=[chatbot, msg, context_state, regenerate], 403 | queue=False 404 | ) 405 | 406 | gradio_chatbot.queue() 407 | return gradio_chatbot 408 | 409 | def run(self): 410 | if self.config.pre_compile != '': 411 | if self.config.pre_compile == 'all': 412 | pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat'] 413 | else: 414 | pre_compile = self.config.pre_compile.split(',') 415 | 416 | pre_compile_data = ['a' for _ in range(self.config.batch_size)] 417 | for task in pre_compile: 418 | if task == 'loglikelihood': 419 | self.loglikelihood(pre_compile_data, pre_compile_data) 420 | self.loglikelihood_rolling(pre_compile_data) 421 | elif task == 'generate': 422 | self.generate(pre_compile_data, 1.0) 423 | elif task == 'greedy_until': 424 | self.greedy_until( 425 | pre_compile_data, pre_compile_data, 426 | self.config.greedy_until_max_length 427 | ) 428 | elif task == 'chat': 429 | self.process_chat('a', 'a', 1.0) 430 | else: 431 | raise ValueError(f'Invalid precompile task: {task}!') 432 | 433 | uvicorn.run(self.app, host=self.config.host, port=self.config.port) 434 | 435 | 436 | class LMClient(object): 437 | """ A simple client for the LM server. 
""" 438 | 439 | @staticmethod 440 | def get_default_config(updates=None): 441 | config = ConfigDict() 442 | config.url = 'http://localhost:5007' 443 | config.batch_size = 1 444 | config.wait_for_ready = True 445 | config.dummy = False 446 | 447 | if updates is not None: 448 | config.update(ConfigDict(updates).copy_and_resolve_references()) 449 | return config 450 | 451 | def __init__(self, config=None): 452 | self.config = self.get_default_config(config) 453 | if self.config.wait_for_ready: 454 | self.wait_for_ready() 455 | 456 | def wait_for_ready(self): 457 | if self.config.dummy: 458 | return 459 | while True: 460 | try: 461 | requests.get(urllib.parse.urljoin(self.config.url, 'ready')) 462 | return 463 | except (Timeout, ConnectionError) as e: 464 | time.sleep(10) 465 | 466 | @staticmethod 467 | def batched(iterator, batch_size): 468 | batch = [] 469 | for example in iterator: 470 | batch.append(example) 471 | if len(batch) == batch_size: 472 | yield batch 473 | batch = [] 474 | if len(batch) > 0: 475 | yield batch 476 | 477 | def loglikelihood(self, prefix, text): 478 | prefix, text = list(prefix), list(text) 479 | if self.config.dummy: 480 | return [-1.0 for _ in text], [False for _ in text] 481 | 482 | log_likelihood = [] 483 | is_greedy = [] 484 | 485 | batched_iterator = list(zip( 486 | self.batched(prefix, self.config.batch_size), 487 | self.batched(text, self.config.batch_size) 488 | )) 489 | for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0): 490 | response = requests.post( 491 | urllib.parse.urljoin(self.config.url, 'loglikelihood'), 492 | json={'prefix_text': batch_prefix, 'text': batch_text} 493 | ).json() 494 | log_likelihood.extend(response['log_likelihood']) 495 | is_greedy.extend(response['is_greedy']) 496 | 497 | return log_likelihood, is_greedy 498 | 499 | def loglikelihood_rolling(self, text): 500 | text = list(text) 501 | if self.config.dummy: 502 | return [-1.0 for _ in text], [False for _ in text] 503 | 504 | log_likelihood = [] 505 | is_greedy = [] 506 | batched_iterator = list(self.batched(text, self.config.batch_size)) 507 | for batch_text in tqdm(batched_iterator, ncols=0): 508 | response = requests.post( 509 | urllib.parse.urljoin(self.config.url, 'loglikelihood-rolling'), 510 | json={'text': batch_text} 511 | ).json() 512 | log_likelihood.extend(response['log_likelihood']) 513 | is_greedy.extend(response['is_greedy']) 514 | return log_likelihood, is_greedy 515 | 516 | def greedy_until(self, prefix, until): 517 | prefix, until = list(prefix), list(until) 518 | if self.config.dummy: 519 | results = [] 520 | for u in until: 521 | if isinstance(u, str): 522 | results.append('dummy text ' + u) 523 | else: 524 | results.append('dummy text ' + u[0]) 525 | return results 526 | 527 | batched_iterator = list(zip( 528 | self.batched(prefix, self.config.batch_size), 529 | self.batched(until, self.config.batch_size), 530 | )) 531 | output_text = [] 532 | for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0): 533 | response = requests.post( 534 | urllib.parse.urljoin(self.config.url, 'greedy-until'), 535 | json={'prefix_text': batch_prefix, 'until': batch_until} 536 | ).json() 537 | output_text.extend(response['output_text']) 538 | return output_text 539 | 540 | def generate(self, prefix, temperature=None): 541 | prefix = list(prefix) 542 | if self.config.dummy: 543 | return ['' for _ in prefix] 544 | 545 | output_text = [] 546 | batched_iterator = list(self.batched(prefix, self.config.batch_size)) 547 | for batch_prefix in tqdm(batched_iterator, 
ncols=0): 548 | response = requests.post( 549 | urllib.parse.urljoin(self.config.url, 'generate'), 550 | json={ 551 | 'prefix_text': batch_prefix, 552 | 'temperature': temperature, 553 | } 554 | ).json() 555 | output_text.extend(response['output_text']) 556 | return output_text 557 | 558 | def chat(self, prompt, context, temperature=None): 559 | if self.config.dummy: 560 | return '' 561 | response = requests.post( 562 | urllib.parse.urljoin(self.config.url, 'chat'), 563 | json={ 564 | 'prompt': prompt, 565 | 'context': context, 566 | 'temperature': temperature, 567 | } 568 | ).json() 569 | return response['response'], response['context'] 570 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EasyLM 2 | Large language models (LLMs) made easy, EasyLM is a one stop solution for 3 | pre-training, finetuning, evaluating and serving LLMs in JAX/Flax. EasyLM can 4 | scale up LLM training to hundreds of TPU/GPU accelerators by leveraging 5 | JAX's pjit functionality. 
6 | 7 | 8 | Building on top of Huggingface's [transformers](https://huggingface.co/docs/transformers/main/en/index) 9 | and [datasets](https://huggingface.co/docs/datasets/index), this repo provides 10 | an easy-to-use and easy-to-customize codebase for training large language models 11 | without the complexity found in many other frameworks. 12 | 13 | 14 | EasyLM is built with JAX/Flax. By leveraging JAX's pjit utility, EasyLM is able 15 | to train large models that don't fit on a single accelerator by sharding 16 | the model weights and training data across multiple accelerators. Currently, 17 | EasyLM supports multiple TPU/GPU training on a single host as well as multi-host 18 | training on Google Cloud TPU Pods. 19 | 20 | Currently, the following models are supported: 21 | * [LLaMA](https://arxiv.org/abs/2302.13971) 22 | * [LLaMA 2](https://arxiv.org/abs/2307.09288) 23 | * [LLaMA 3](https://llama.meta.com/llama3/) 24 | 25 | ## Discord Server 26 | We are running an unofficial Discord community (unaffiliated with Google) for discussion related to training LLMs in JAX. [Follow this link to join the Discord server](https://discord.gg/Rf4drG3Bhp). We have dedicated channels for several JAX-based LLM frameworks, including EasyLM, [JaxSeq](https://github.com/Sea-Snell/JAXSeq), [Alpa](https://github.com/alpa-projects/alpa) and [Levanter](https://github.com/stanford-crfm/levanter). 27 | 28 | 29 | ## Models Trained with EasyLM 30 | ### OpenLLaMA 31 | OpenLLaMA is our permissively licensed reproduction of LLaMA which can be used 32 | for commercial purposes. Check out the [project main page here](https://github.com/openlm-research/open_llama). 33 | OpenLLaMA can serve as a drop-in replacement for the LLaMA weights in EasyLM. 34 | Please refer to the [LLaMA documentation](docs/llama.md) for more details. 35 | 36 | 37 | ### Koala 38 | Koala is our new chatbot fine-tuned on top of LLaMA. If you are interested in 39 | our Koala chatbot, you can check out the [blogpost](https://bair.berkeley.edu/blog/2023/04/03/koala/) 40 | and [documentation for running it locally](docs/koala.md). 41 | 42 | 43 | ## Installation 44 | The installation method differs between GPU hosts and Cloud TPU hosts. The first 45 | step is to pull from GitHub. 46 | 47 | ``` shell 48 | git clone https://github.com/young-geng/EasyLM.git 49 | cd EasyLM 50 | export PYTHONPATH="${PWD}:$PYTHONPATH" 51 | ``` 52 | 53 | #### Installing on GPU Host 54 | The GPU environment can be installed via [Anaconda](https://www.anaconda.com/products/distribution). 55 | 56 | ``` shell 57 | conda env create -f scripts/gpu_environment.yml 58 | conda activate EasyLM 59 | ``` 60 | 61 | #### Installing on Cloud TPU Host 62 | The TPU host VM comes with Python and PIP pre-installed. Simply run the following 63 | script to set up the TPU host. 64 | 65 | ``` shell 66 | ./scripts/tpu_vm_setup.sh 67 | ``` 68 | 69 | 70 | ## [Documentations](docs/README.md) 71 | The EasyLM documentation can be found in the [docs](docs/) directory.
72 | 73 | 74 | ## Reference 75 | If you found EasyLM useful in your research or applications, please cite using the following BibTeX: 76 | ``` 77 | @software{geng2023easylm, 78 | author = {Geng, Xinyang}, 79 | title = {EasyLM: A Simple And Scalable Training Framework for Large Language Models}, 80 | month = March, 81 | year = 2023, 82 | url = {https://github.com/young-geng/EasyLM} 83 | } 84 | ``` 85 | 86 | 87 | 88 | ## Credits 89 | * The LLaMA implementation is from [JAX_llama](https://github.com/Sea-Snell/JAX_llama) 90 | * The JAX/Flax GPT-J and RoBERTa implementations are from [transformers](https://huggingface.co/docs/transformers/main/en/index) 91 | * Most of the JAX utilities are from [mlxu](https://github.com/young-geng/mlxu) 92 | * The codebase is heavily inspired by [JAXSeq](https://github.com/Sea-Snell/JAXSeq) 93 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # EasyLM Documentations 2 | EasyLM is a framework for pre-training, fine-tuning, evaluating, and serving 3 | large language models in JAX/Flax. EasyLM is designed to be easy to use by 4 | hiding the complexities of distributed model/data parallelism but exposing the 5 | core training and inference details of large language models, making it easy to 6 | customize. EasyLM can scale up LLM training to hundreds of TPU/GPU accelerators 7 | without the need to write complicated distributed training code. 8 | 9 | ## Installation 10 | EasyLM supports both GPU and TPU training. The installation method differs by 11 | the type of accelerator. The first step is to pull from GitHub. 12 | 13 | ``` shell 14 | git clone https://github.com/young-geng/EasyLM.git 15 | cd EasyLM 16 | export PYTHONPATH="${PWD}:$PYTHONPATH" 17 | ``` 18 | 19 | #### Installing on GPU Host 20 | The GPU environment can be installed via [Anaconda](https://www.anaconda.com/products/distribution). 21 | 22 | ``` shell 23 | conda env create -f scripts/gpu_environment.yml 24 | conda activate EasyLM 25 | ``` 26 | 27 | #### Installing on Cloud TPU Host 28 | The TPU host VM comes with Python and PIP pre-installed. Simply run the following 29 | script to set up the TPU host. 30 | 31 | ``` shell 32 | ./scripts/tpu_vm_setup.sh 33 | ``` 34 | 35 | 36 | ## Modular Configuration 37 | EasyLM is designed to be highly modular. Typically, the training or inference 38 | script will combine various modules to form a complete training or 39 | inference process. Building on top of [MLXU](https://github.com/young-geng/mlxu), 40 | EasyLM training or inference scripts can directly plug the configuration of 41 | the modules they use into the command line flags of the script. 42 | 43 | For example, if we have a training script `train.py` that uses the optimizer module, 44 | we can directly plug in the configuration of the optimizer module into the FLAGS 45 | of the training script in this way: 46 | 47 | ``` python 48 | import mlxu 49 | from EasyLM.optimizer import OptimizerFactory 50 | 51 | # Defining the command line flags, flag data type will be inferred from the default value 52 | FLAGS, FLAGS_DEF = mlxu.define_flags_with_default( 53 | seed=42, # Defining an integer flag with default value 42 54 | optimizer=OptimizerFactory.get_default_config(), # Plugging in the default configuration of the optimizer module 55 | ) 56 | 57 | def main(argv): 58 | seed = FLAGS.seed 59 | optimizer, optimizer_info = OptimizerFactory.get_optimizer(FLAGS.optimizer) 60 | ...
61 | 62 | if __name__ == "__main__": 63 | mlxu.run(main) 64 | 65 | ``` 66 | 67 | 68 | By plugging in the configuration of the optimizer module into the FLAGS of the 69 | training script, we can directly control the optimizer module from the command 70 | line. For example, if we want to use the AdamW optimizer with learning rate 1e-4, 71 | we can run the training script with the following command: 72 | 73 | ``` shell 74 | python train.py \ 75 | --seed=42 \ 76 | --optimizer.type=adamw \ 77 | --optimizer.adamw_optimizer.lr=1e-4 78 | ``` 79 | 80 | For more information about the configuration of each module, please refer to the 81 | `get_default_config()` method of the module. 82 | 83 | 84 | ## Documentations for EasyLM Modules and Scripts 85 | Here are the documentations for the common modules and scripts in EasyLM: 86 | * [Parallelism](parallelism.md): model and data parallelism in EasyLM 87 | * [Dataset](dataset.md): data loading and processing 88 | * [Optimizer](optimizer.md): optimizer and gradient accumulation 89 | * [Checkpointing](checkpointing.md): checkpointing 90 | * [Serving](serving.md): serving the language model with an HTTP server 91 | * [Logger](logger.md): logging metrics for training 92 | * [Evaluation](evaluation.md): evaluation of language models on benchmarks 93 | 94 | 95 | 96 | ## Documentations for Language Models Supported by EasyLM 97 | Currently, the following models are supported: 98 | * [LLaMA](llama.md) 99 | * GPT-J 100 | * OPT 101 | * RoBERTa 102 | 103 | 104 | ## Additional Examples and Tutorials 105 | * [Running Koala locally](koala.md) -------------------------------------------------------------------------------- /docs/checkpointing.md: -------------------------------------------------------------------------------- 1 | # Checkpointing 2 | To facilitate training very large language models that do not fit into the 3 | main memory of a single machine, EasyLM adopts a streaming format for model 4 | checkpoints. The streaming checkpointing format is implemented in 5 | [checkpoint.py](/EasyLM/checkpoint.py). During checkpointing, the 6 | StreamingCheckpointer simply flattens a nested state dictionary into a single 7 | level dictionary, and streams the key-value pairs to a file one by one using 8 | messagepack. Because it streams the tensors one by one, the checkpointer only 9 | needs to gather one tensor from the distributed accelerators to the main memory 10 | at a time, hence saving a lot of memory. 11 | 12 | 13 | ## Loading Checkpoint 14 | While EasyLM mainly uses the streaming checkpointing format, it also supports 15 | directly loading the standard flax checkpoint file created using 16 | `flax.training.checkpoints.save_checkpoint`. The loading format can be specified 17 | as part of the path passed into the training or serving script. For example, if 18 | we want to serve a LLaMA model using the streaming checkpointing format, we can 19 | use the following command: 20 | 21 | ``` shell 22 | python -m EasyLM.models.llama.llama_serve \ 23 | --load_checkpoint='params::path/to/checkpoint' 24 | ... 25 | ``` 26 | 27 | Note that the `params::` prefix is used to specify that the checkpoint is in 28 | streaming format. The following prefixes are supported for loading checkpoints: 29 | * `params::`: Streaming checkpointing format. 30 | * `flax::`: Standard flax checkpointing format. 31 | * `trainstate::`: Loading an entire train state with optimizer state; this 32 | option is only supported for the training script.
33 | * `trainstate_params::`: Loading the params part from the entire train state. 34 | 35 | By default, EasyLM does not save the optimizer state in the checkpoint, so 36 | we will rarely need to use the `trainstate::` or `trainstate_params::` options. 37 | 38 | 39 | ## Saving Checkpoint 40 | EasyLM will only save the checkpoint in the streaming format. By default, only 41 | the model parameters are saved in the checkpoint file in the bfloat16 data type. 42 | To configure the checkpointing behavior, you can use the following options: 43 | * `float_dtype`: The float data type of the model parameters in the checkpoint file. 44 | The default value is `bf16`, other supported values are `fp32` and `fp16`. 45 | * `save_optimizer_state`: Whether to save the entire train state in the checkpoint. 46 | 47 | Typically, we pass these options into the training script. For example, for 48 | LLaMA, we can use the following command to save the checkpoint in fp32: 49 | ``` shell 50 | python -m EasyLM.models.llama.llama_train \ 51 | --checkpointer.float_dtype='fp32' \ 52 | ... 53 | ``` 54 | 55 | 56 | ## Converting Checkpoint to and from Standard Flax Format 57 | To facilitate the use of EasyLM trained models with other Flax based libraries, 58 | EasyLM provides a script to convert between the streaming checkpointing format 59 | and the standard flax checkpointing format. The script can be found at 60 | [EasyLM/scripts/convert_checkpoint.py](/EasyLM/scripts/convert_checkpoint.py). 61 | 62 | To convert a checkpoint from the streaming format to the standard flax format, 63 | use the following command: 64 | 65 | ``` shell 66 | python -m EasyLM.scripts.convert_checkpoint \ 67 | --load_checkpoint='params::path/to/checkpoint' \ 68 | --output_file='path/to/output/checkpoint' \ 69 | --streaming=False 70 | ``` 71 | 72 | To convert a standard flax checkpoint to the streaming format, use the following 73 | command: 74 | 75 | ``` shell 76 | python -m EasyLM.scripts.convert_checkpoint \ 77 | --load_checkpoint='flax::path/to/checkpoint' \ 78 | --output_file='path/to/output/checkpoint' \ 79 | --streaming=True 80 | ``` 81 | 82 | 83 | ## Diffing Checkpoint 84 | To facilitate the release of fine-tuned model checkpoints that are based on 85 | a non-public base model checkpoint, EasyLM provides a script to compute the 86 | difference between two checkpoints. The script can be found at 87 | [EasyLM/scripts/diff_checkpoint.py](/EasyLM/scripts/diff_checkpoint.py). 88 | 89 | To compute the difference between a base checkpoint (base model) and a 90 | target checkpoint (fine-tuned model), use the following command: 91 | 92 | ``` shell 93 | python -m EasyLM.scripts.diff_checkpoint \ 94 | --recover_diff=False \ 95 | --load_base_checkpoint='params::path/to/base/checkpoint' \ 96 | --load_target_checkpoint='params::path/to/target/checkpoint' \ 97 | --output_file='path/to/output/checkpoint' \ 98 | --streaming=True 99 | ``` 100 | 101 | The script will output a checkpoint that contains the difference between the 102 | two checkpoints. You can use the `--streaming` flag to specify the format 103 | (streaming or standard flax) of the output checkpoint.
To recover a checkpoint 104 | from a base checkpoint and a diff checkpoint, use the following command: 105 | 106 | ``` shell 107 | python -m EasyLM.scripts.diff_checkpoint \ 108 | --recover_diff=True \ 109 | --load_base_checkpoint='params::path/to/base/checkpoint' \ 110 | --load_target_checkpoint='params::path/to/diff/checkpoint' \ 111 | --output_file='path/to/output/checkpoint' \ 112 | --streaming=True 113 | ``` 114 | -------------------------------------------------------------------------------- /docs/dataset.md: -------------------------------------------------------------------------------- 1 | # Dataset 2 | EasyLM has built-in support for the following types of datasets: 3 | * Huggingface dataset 4 | * JSON dataset 5 | 6 | These dataset modules are implemented in the [data.py](/EasyLM/data.py) file. 7 | 8 | Typically, datasets are configured by passing in command line arguments to the 9 | training script. For example, to use the Huggingface dataset for training GPT-J, 10 | you can use the following command line options: 11 | 12 | ```bash 13 | python -m EasyLM.models.gptj.gptj_train \ 14 | --train_dataset.text_processor.fields='text' \ 15 | --train_dataset.type='huggingface' \ 16 | --train_dataset.huggingface_dataset.path='c4' 17 | ``` 18 | 19 | In this example, we select the Huggingface dataset by specifying the `type` of 20 | `train_dataset` to be `huggingface`. We then specify the path to the dataset, 21 | which is `c4` in this case. The examples loaded from the dataset will be processed 22 | by a TextProcessor, which is configured by the `text_processor` field. 23 | 24 | The following options are supported for the dataset module: 25 | * `type`: The type of the dataset. Supported values are `huggingface` and `json`. 26 | * `text_processor`: The configuration of the TextProcessor used to process the 27 | loaded examples. 28 | * `huggingface_dataset`: The configuration of the Huggingface dataset. 29 | * `json_dataset`: The configuration of the JSON dataset. 30 | 31 | 32 | ## Huggingface Dataset 33 | The Huggingface dataset uses the [datasets](https://huggingface.co/docs/datasets/index) 34 | library to download and load datasets. Here are the configurable options for 35 | the Huggingface dataset: 36 | * `path`: The path to the dataset. Same as the `path` argument in 37 | `datasets.load_dataset`. 38 | * `name`: Name of the dataset within the path. Same as the `name` argument in 39 | `datasets.load_dataset`. 40 | * `split`: The split of the dataset. Same as the `split` argument in 41 | `datasets.load_dataset`. 42 | * `streaming`: Whether to stream the dataset. Same as the `streaming` argument 43 | in `datasets.load_dataset`. 44 | * `seq_length`: The length of the tokenized sequence. 45 | * `batch_size`: Batch size of tokenized examples. 46 | 47 | Each loaded example is a dictionary, which will be processed by a TextProcessor 48 | to become the final tokens and masks. 49 | 50 | 51 | ## JSON Dataset 52 | The JSON dataset loads examples from a text file, where each line represents a 53 | JSON-encoded dictionary. Here are the configurable options for the JSON dataset: 54 | * `path`: Path to the text file. The file can be located on the local file system 55 | or on a Google Cloud Storage bucket. 56 | * `seq_length`: The length of the tokenized sequence. 57 | * `batch_size`: Batch size of tokenized examples. 58 | * `start_seek_loc`: The starting seek location in the file. This is useful when 59 | you want to resume training from a particular location in the file.
60 | * `index_at_start`: The counting index at the beginning. This is useful to 61 | keep the index count when resuming from a particular location in the file. 62 | Note that this is only for logging purposes, and does not affect which 63 | example the loading actually starts from. To start from a different example in the dataset, 64 | you should use the `start_seek_loc` option. 65 | * `tokenizer_processes`: The number of processes to use for tokenization. 66 | Tokenization is done in parallel to speed up the loading process. 67 | 68 | 69 | Each loaded example is a dictionary, which will be processed by a TextProcessor. 70 | 71 | 72 | ## Text Processor 73 | A TextProcessor is used to process the loaded examples from a dataset. Each 74 | input example is a dictionary of multiple text fields. The TextProcessor will 75 | process text fields according to its configurations, and return the final tokens. 76 | 77 | Here are the configurable options for TextProcessor: 78 | * `fields`: A comma separated list of text fields to process. 79 | * `fields_from_example`: Whether to use the keys of the input example as the 80 | text fields to process. If this option is set, the `fields` argument will 81 | be ignored. 82 | * `subfield_separator`: The text separator to use when concatenating subfields 83 | of a text field. 84 | * `add_eos_token`: Whether to add an EOS token to the end of the text. 85 | * `prepend_text`: The text to prepend to the beginning. 86 | 87 | The most important configuration for TextProcessor is the `fields` argument. It 88 | is a comma separated list of text fields to process. Each field consists of one 89 | or more subfields, which are separated by a `+`. Each subfield represents a key 90 | used to extract the text from the input example dictionary. The TextProcessor 91 | joins the extracted subfields of text with the `subfield_separator` at the text 92 | level and then tokenizes the joined text. Finally, the TextProcessor will concatenate 93 | the tokenized text fields at the token level, and add the EOS token if specified. 94 | 95 | Other than the keys in the input example, you can also use the following special 96 | keys to indicate a special token for a text field: 97 | * `<|bos|>`: Beginning of sentence token. 98 | * `<|eos|>`: End of sentence token. 99 | 100 | For each text field, you can encapsulate the subfields with `[]` to specify that 101 | the loss should not be computed for this field. Doing so will set the loss 102 | masks to 0 for this field. This is useful when you want to use the text field 103 | as a prompt for the model. 104 | 105 | 106 | To give a concrete example, if the input example looks like this: 107 | ```python 108 | { 109 | 'question': 'Would ice float on water?', 110 | 'prompt': 'Think step by step.', 111 | 'answer': 'The density of ice is 0.92 g/cm3, and the density of water is 1.0 g/cm3. So ice will float on water.', 112 | } 113 | ``` 114 | To use the `question` and `prompt` as the input text, and `answer` as the target, 115 | we can specify the following configuration for the `fields` argument: 116 | ``` 117 | [question+prompt],answer 118 | ``` 119 | 120 | The `question+prompt` indicates that the `question` and `prompt` should be joined 121 | together with the `subfield_separator`, which is a space by default. The `[]` 122 | indicates that the loss should not be computed for this field. The `answer` field 123 | is then concatenated at the token level, where the loss will be computed.
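To make this concrete, the following is a minimal illustrative sketch of what the `[question+prompt],answer` configuration does to the example above. It is not the actual TextProcessor implementation in [data.py](/EasyLM/data.py), and the `tokenizer` argument is assumed to be any tokenizer object with an `encode` method (for example a Huggingface tokenizer):

```python
# Illustrative sketch of the '[question+prompt],answer' field configuration;
# the real logic lives in EasyLM/data.py.
def process_example(example, tokenizer, subfield_separator=' '):
    # '[question+prompt]': join the subfields with the separator and tokenize.
    # The surrounding '[]' means these tokens receive a loss mask of 0.
    prompt_text = subfield_separator.join([example['question'], example['prompt']])
    prompt_tokens = tokenizer.encode(prompt_text)

    # 'answer': tokenized separately and concatenated at the token level,
    # with a loss mask of 1 so the loss is computed on these tokens.
    answer_tokens = tokenizer.encode(example['answer'])

    tokens = prompt_tokens + answer_tokens
    loss_masks = [0.0] * len(prompt_tokens) + [1.0] * len(answer_tokens)
    return tokens, loss_masks
```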
124 | 125 | -------------------------------------------------------------------------------- /docs/evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluating Language Models 2 | EasyLM has built-in support for evaluating language models on a variety of tasks. 3 | Once the trained language model is served with LMServer, it can be evaluated 4 | against various benchmarks in few-shot and zero-shot settings. 5 | 6 | ## LM Evaluation Harness 7 | EasyLM comes with built-in support for [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness), 8 | which can evaluate the language model on a variety of tasks. For example, 9 | you can use the following command to evaluate the language model served with 10 | the HTTP server: 11 | 12 | ```shell 13 | python -m EasyLM.scripts.lm_eval_harness \ 14 | --lm_client.url='http://localhost:5007/' \ 15 | --tasks='wsc,piqa,winogrande,openbookqa,logiqa' \ 16 | --shots=0 17 | ``` 18 | 19 | The `lm_eval_harness` script supports the following command line options: 20 | * `tasks`: a comma separated list of tasks to evaluate the language model on. 21 | The supported tasks are listed in the 22 | [lm-eval-harness task table](https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md) 23 | * `shots`: the number of shots to use for the evaluation. 24 | * `batch_size`: the batch size to use for each http request. Too large a batch 25 | size may cause the request to time out. Defaults to 1. 26 | * `lm_client`: the configurations for LMClient. See [the LMClient documentation](serving.md) 27 | for more details. 28 | * `logger`: the configurations for the logger. See [the logger documentation](logger.md) 29 | for more details. 30 | 31 | ## Evaluating on MMLU 32 | The served language model can also be evaluated with the [MMLU](https://github.com/hendrycks/test) 33 | benchmark. In order to run the evaluation, you'll need to use [my fork of MMLU](https://github.com/young-geng/mmlu_easylm) which supports EasyLM LMServer. 34 | 35 | ```shell 36 | git clone https://github.com/young-geng/mmlu_easylm.git 37 | cd mmlu_easylm 38 | python evaluate_easylm.py \ 39 | --name='llama' \ 40 | --lm_server_url='http://localhost:5007' \ 41 | --ntrain=5 42 | ``` 43 | 44 | 45 | -------------------------------------------------------------------------------- /docs/koala.md: -------------------------------------------------------------------------------- 1 | # Koala 2 | Koala is a language model fine-tuned on top of LLaMA. 3 | [Check out the blogpost!](https://bair.berkeley.edu/blog/2023/04/03/koala/) 4 | This documentation describes the process of downloading and recovering the 5 | Koala model weights, and running the Koala chatbot locally. 6 | 7 | 8 | ## Obtaining the Weight Diff of Koala 9 | Due to the license of the LLaMA model, we cannot directly release the fine-tuned 10 | Koala model weights. Instead, we release the diff of weights, which can be used 11 | to recover the Koala model weights with the original LLaMA model weights. The diff 12 | weights can be downloaded from the following sources: 13 | * [HuggingFace Hub](https://huggingface.co/young-geng/koala/tree/main). 14 | * [Google Drive](https://drive.google.com/drive/folders/10f7wrlAFoPIy-TECHsx9DKIvbQYunCfl?usp=sharing). 15 | 16 | 17 | ## Recovering the Koala Model Weights 18 | The first step of recovering the Koala model weights is to obtain the original 19 | LLaMA model weights and convert them to the EasyLM checkpoint format.
To convert the weights, 20 | use the following command: 21 | 22 | ``` shell 23 | python -m EasyLM.models.llama.convert_torch_to_easylm \ 24 | --checkpoint_dir='path/to/torch/llama/checkpoint/directory' \ 25 | --output_file='path/to/output/easylm/checkpoint/file' \ 26 | --streaming=True 27 | ``` 28 | 29 | This script will convert the official torch checkpoint from Meta to the 30 | streaming checkpoint format used by EasyLM. For more information 31 | about the checkpoint format of EasyLM, see [the checkpointing documentation](checkpointing.md). 32 | 33 | 34 | After converting the original LLaMA model weights, you can recover the Koala 35 | model weights with the following command: 36 | 37 | ``` shell 38 | python -m EasyLM.scripts.diff_checkpoint \ 39 | --recover_diff=True \ 40 | --load_base_checkpoint='params::path/to/llama/checkpoint/file' \ 41 | --load_target_checkpoint='params::path/to/koala/diff/checkpoint/file' \ 42 | --output_file='path/to/output/checkpoint/file' \ 43 | --streaming=True 44 | ``` 45 | 46 | 47 | ## Serving the Koala Chatbot 48 | You can serve the recovered Koala model with the LMServer of EasyLM. To do so, use the 49 | following command: 50 | 51 | ``` shell 52 | python -m EasyLM.models.llama.llama_serve \ 53 | --load_llama_config='13b' \ 54 | --load_checkpoint="params::/path/to/recovered/checkpoint" \ 55 | --tokenizer.vocab_file='/path/to/tokenizer.model' \ 56 | --mesh_dim='1,1,-1' \ 57 | --dtype='bf16' \ 58 | --input_length=1024 \ 59 | --seq_length=2048 \ 60 | --do_sample=True \ 61 | --lm_server.batch_size=1 \ 62 | --lm_server.port=5009 \ 63 | --lm_server.pre_compile='chat' \ 64 | --lm_server.chat_prepend_text='BEGINNING OF CONVERSATION: ' \ 65 | --lm_server.chat_lm_prefix='GPT:' \ 66 | --lm_server.chat_lm_suffix='' \ 67 | --lm_server.chat_user_prefix='USER: ' \ 68 | --lm_server.chat_user_suffix=' ' 69 | ``` 70 | 71 | Then navigate to `http://localhost:5009` to interact with the chatbot. 72 | 73 | 74 | ## Converting the Koala Weights to HuggingFace Transformers 75 | You can also convert the Koala model weights to the HuggingFace Transformers format, 76 | so it can be used with the LLaMA implementation in transformers. To do so, use 77 | the following command: 78 | 79 | ``` shell 80 | python -m EasyLM.models.llama.convert_easylm_to_hf \ 81 | --load_checkpoint='params::path/to/koala/checkpoint' \ 82 | --tokenizer_path='path/to/llama/tokenizer' \ 83 | --model_size='13b' \ # '7b', '13b', '30b' or '65b' 84 | --output_dir='path/to/output/huggingface/koala/checkpoint' 85 | ``` 86 | 87 | 88 | ## Koala Chatbot Prompts 89 | As can be seen in the serving command above, the Koala chatbot requires a 90 | series of prompts to be prepended and appended to the user input in order to 91 | generate responses correctly. Hence, to use the Koala weights in other frameworks, 92 | you will need to process the prompts accordingly. 93 | 94 | The conversation prefix `BEGINNING OF CONVERSATION: ` is always prepended to 95 | every conversation. For each user input, the user prompt `USER: ` is prepended 96 | to the user input, a space ` ` is appended to the user input and then the 97 | language model prompt `GPT:` is appended to the user input. This whole string 98 | will be used as the prompt input to the language model for generating the response. 99 | For example, in the first round of conversation, when the user inputs `Hello!`, 100 | the whole prompt for generating the first response is: 101 | 102 | ``` 103 | BEGINNING OF CONVERSATION: USER: Hello!
 GPT: 104 | ``` 105 | 106 | After the language model generates the response, we append the response to the 107 | prompt and then append the EOS token `` to the prompt. Suppose the language 108 | model generates the following response: `Hi! How can I help you?`, and for the 109 | next round, the user input is `What is the largest animal on earth?`. Then 110 | the whole prompt for generating the second response is: 111 | 112 | ``` 113 | BEGINNING OF CONVERSATION: USER: Hello! GPT:Hi! How can I help you?USER: What is the largest animal on earth? GPT: 114 | ``` 115 | 116 | Note that because the prompt and the generated parts are tokenized separately, there is 117 | no space between the model prompt `GPT:` and the generated response. 118 | -------------------------------------------------------------------------------- /docs/llama.md: -------------------------------------------------------------------------------- 1 | # LLaMA 2 | LLaMA is a language model developed by Meta. The official implementation can 3 | be found [here](https://github.com/facebookresearch/llama). EasyLM provides 4 | a JAX implementation of LLaMA, located at [EasyLM/models/llama](/EasyLM/models/llama). 5 | 6 | 7 | ## Converting the Official LLaMA Checkpoint to EasyLM Format 8 | If you are using our [OpenLLaMA](https://github.com/openlm-research/open_llama), 9 | you can directly download the EasyLM checkpoints and skip this section. 10 | If you are using the official LLaMA weights from Meta, the first step is to 11 | convert the Huggingface transformers LLaMA checkpoint to the EasyLM checkpoint format. To do so, 12 | use the following command: 13 | 14 | ``` shell 15 | python -m EasyLM.models.llama.convert_hf_to_easylm \ 16 | --hf_model='path/to/transformers/llama/checkpoint' \ 17 | --output_file='path/to/output/easylm/checkpoint' \ 18 | --streaming=True \ 19 | --llama.base_model='llama_7b' 20 | ``` 21 | 22 | This script will convert the Huggingface transformers LLaMA checkpoint to the 23 | streaming checkpoint format used by EasyLM. If you set `--streaming` to `False`, 24 | the script will output a standard flax checkpoint instead. For more information 25 | about the checkpoint format of EasyLM, see [the checkpointing documentation](checkpointing.md). 26 | 27 | 28 | ## Fine-Tuning LLaMA 29 | After converting the checkpoint and setting up the data, you can fine-tune 30 | LLaMA with EasyLM. The training script is implemented in 31 | [EasyLM/models/llama/llama_train.py](/EasyLM/models/llama/llama_train.py). 32 | To fine-tune LLaMA, use the following command: 33 | 34 | ``` shell 35 | python -m EasyLM.models.llama.llama_train \ 36 | --mesh_dim='1,-1,1' \ 37 | --llama.base_model='llama_7b' \ 38 | --load_checkpoint='params::path/to/easylm/llama/checkpoint' \ 39 | ... 40 | ``` 41 | 42 | The following command line options are supported for the training script: 43 | * `seed`: The random seed to use for the training script. 44 | * `mesh_dim`: The mesh dimensions for the data, fully sharded data and model parallelism. 45 | LLaMA uses a 3D mesh, so a comma separated list of 3 values is required. See 46 | [the parallelism documentation](parallelism.md) for more details. 47 | * `dtype`: the float dtype to use for the model activation. Can be `bf16` or `fp16` or `fp32`. 48 | * `params_dtype`: the float dtype to use for the model parameters. Can be `bf16` or `fp16` or `fp32`. 49 | * `total_steps`: The total number of training steps. 50 | * `load_checkpoint`: the checkpoint to load.
See [the checkpointing documentation](checkpointing.md) 51 | for more details. 52 | * `load_dataset_state`: the dataset state to load. Rarely used. 53 | * `log_freq`: the frequency of logging the training metrics. 54 | * `save_model_freq`: the frequency of saving the model checkpoint. The older 55 | checkpoints will be overwritten by the newest checkpoint. 56 | * `save_milestone_freq`: the frequency of saving milestone model checkpoints. 57 | The milestone checkpoints will not be overwritten. 58 | * `eval_steps`: the number of evaluation steps to run to evaluate the model. Setting 59 | to 0 will disable the evaluation. Using this requires the `eval_dataset` to be 60 | properly specified. 61 | * `tokenizer`: Huggingface transformers pretrained tokenizer. 62 | * `train_dataset`: training dataset configuration. See [the dataset documentation](dataset.md) 63 | for more details. 64 | * `eval_dataset`: evaluation dataset configuration. See [the dataset documentation](dataset.md) 65 | for more details. 66 | * `optimizer`: optimizer configuration. See [the optimizer documentation](optimizer.md) 67 | for more details. 68 | * `checkpointer`: checkpointer configuration. See [the checkpointing documentation](checkpointing.md) 69 | for more details. 70 | * `llama`: Specify the LLaMA configuration by starting from a base model. The available configurations 71 | can be found in the [LLaMA model implementation](/EasyLM/models/llama/llama_model.py). 72 | * `logger`: logger configuration. For more details, see [the logger documentation](logger.md). 73 | * `log_all_workers`: whether to log the metrics from all workers in a multi-host 74 | setting. If set to `False`, only the metrics from the first worker will be logged. 75 | * `jax_distributed`: JAX distributed configuration. This only needs to be set when running 76 | multi-host training on GPU. 77 | 78 | 79 | ## Serving LLaMA 80 | You can serve the LLaMA model with the LMServer of EasyLM. To do so, use the 81 | following command: 82 | 83 | ``` shell 84 | python -m EasyLM.models.llama.llama_serve \ 85 | --mesh_dim='1,1,-1' \ 86 | --llama.base_model='llama_7b' \ 87 | --load_checkpoint='params::path/to/easylm/llama/checkpoint' \ 88 | ... 89 | ``` 90 | 91 | The following command line options are supported for the serving script: 92 | * `seed`: The random seed to use for the serving script. 93 | * `mesh_dim`: The mesh dimensions for the data, fully sharded data and model parallelism. 94 | LLaMA uses a 3D mesh, so a comma separated list of 3 values is required. See 95 | [the parallelism documentation](parallelism.md) for more details. 96 | * `dtype`: the float dtype to use for the model activation. Can be `bf16` or `fp16` or `fp32`. 97 | * `params_dtype`: the float dtype to use for the model parameters. Can be `bf16` or `fp16` or `fp32`. 98 | * `input_length`: the maximum length of the input sequence. 99 | * `seq_length`: the maximum length of the total sequence (input and output). 100 | * `top_k`: the number of top-k candidates to use for the sampling. 101 | * `top_p`: the top-p sampling probability. 102 | * `do_sample`: whether to use sampling or greedy decoding. 103 | * `num_beams`: the number of beams to use for beam search. 104 | * `add_bos_token`: whether to add the bos token for loglikelihood 105 | calculation and text generation. 106 | * `llama`: the LLaMA configuration to use. 107 | * `load_checkpoint`: the checkpoint to load. See [the checkpointing documentation](checkpointing.md) 108 | for more details.
109 | * `tokenizer`: Huggingface transformers pretrained tokenizer. 110 | * `lm_server`: the LM server configuration. See [the LM server documentation](serving.md) 111 | for more details. 112 | * `jax_distributed`: JAX distributed configuration. This only needs to be set when running 113 | multi-host training on GPU. 114 | 115 | ## Converting the EasyLM LLaMA Checkpoint to Huggingface LLaMA Checkpoint 116 | To facilitate the interoperability with Huggingface transformers, EasyLM also 117 | provides a script to convert the EasyLM LLaMA checkpoint to the Huggingface 118 | Pytorch LLaMA checkpoint. To do so, use the following command: 119 | 120 | ``` shell 121 | python -m EasyLM.models.llama.convert_easylm_to_hf \ 122 | --load_checkpoint='params::path/to/easylm/checkpoint' \ 123 | --output_dir='path/to/output/huggingface/llama/checkpoint' \ 124 | --llama.base_model='llama_7b' 125 | ``` 126 | -------------------------------------------------------------------------------- /docs/logger.md: -------------------------------------------------------------------------------- 1 | # Logger 2 | EasyLM uses the [MLXU](https://github.com/young-geng/mlxu) library for logging. 3 | Specifically, EasyLM uses the `mlxu.WandBLogger` module for logging. The `WandBLogger` 4 | module is a wrapper of the [Weights & Biases](https://wandb.ai/site) library. The 5 | following options are available for configuring the `WandBLogger` module: 6 | * `online`: Whether to log online. If `False`, the logger will not upload to W&B service. 7 | * `prefix`: The prefix of the W&B project name. 8 | * `project`: The W&B project name. 9 | * `output_dir`: The output directory for checkpointing, can be a local directory or a 10 | Google Cloud Storage directory. 11 | * `wandb_dir`: The output directory for W&B logs. Must be a local directory. 12 | * `random_delay`: Whether to add a random delay to the logging process. 13 | * `experiment_id`: The experiment ID. If not specified, a random ID will be generated. 14 | * `anonymous`: Whether to log anonymously. 15 | * `notes`: The notes for the experiment. 16 | * `entity`: The W&B entity name. 17 | * `prefix_to_id`: Whether to add a prefix to the experiment ID. 18 | 19 | For more information about the logger configuration, please refer to the 20 | [MLXU](https://github.com/young-geng/mlxu) library. 21 | -------------------------------------------------------------------------------- /docs/optimizer.md: -------------------------------------------------------------------------------- 1 | # Optimizers 2 | EasyLM provides a number of optimizers for training neural language models. The 3 | optimizers are implemented in the [optimizer.py](/EasyLM/optimizer.py) 4 | 5 | Currently, the following optimizers are supported: 6 | * AdamW 7 | * PaLM: the optimizer described in the PaLM paper 8 | 9 | In addition to optimizer configurations, the optimizer module also provides 10 | support for gradient accumulation. 11 | 12 | 13 | ## Selecting Optimizer and Gradient Accumulation 14 | Optimizer type can be selected by setting the `type` field in the optimizer 15 | configuration. For example, to use the AdamW optimizer, we can set the `type` to 16 | `adamw` and configuring the `adamw_optimizer` subfields: 17 | ```shell 18 | python train.py --optimizer.type=adamw --optimizer.adamw_optimizer.lr=1e-4 19 | ``` 20 | 21 | To use gradient accumulation, we can set the `accumulate_gradient_steps` field 22 | in the optimizer configuration. 
For example, to use gradient accumulation with 23 | step size 2, we can set the `accumulate_gradient_steps` to 2: 24 | ```shell 25 | python train.py --optimizer.accumulate_gradient_steps=2 26 | ``` 27 | 28 | The following options are supported for the optimizer module: 29 | * `type`: the optimizer type. Currently, `adamw` and `palm` are supported. 30 | * `adamw_optimizer`: the configuration for the AdamW optimizer 31 | * `palm_optimizer`: the configuration for the PaLM optimizer 32 | * `accumulate_gradient_steps`: the number of steps for gradient accumulation 33 | 34 | 35 | ## AdamW Optimizer 36 | The AdamW optimizer implements AdamW with linear learning rate warmup and cosine 37 | learning rate decay. The following options are supported for the AdamW optimizer: 38 | * `init_lr`: the initial learning rate 39 | * `end_lr`: the final learning rate after decay 40 | * `lr`: the peak learning rate 41 | * `lr_warmup_steps`: the number of steps for linear learning rate warmup 42 | * `lr_decay_steps`: the number of steps for cosine learning rate decay 43 | * `b1`: the beta1 parameter for AdamW 44 | * `b2`: the beta2 parameter for AdamW 45 | * `clip_gradient`: the gradient clipping threshold 46 | * `weight_decay`: the weight decay parameter for AdamW 47 | * `bf16_momentum`: whether to use bf16 for momentum to save memory 48 | * `multiply_by_parameter_scale`: whether to multiply the gradient by parameter scale (as in adafactor) 49 | 50 | 51 | ## PaLM Optimizer 52 | The PaLM optimizer implements the optimizer described in the PaLM paper. The optimizer 53 | is essentially adafactor without factoring, with an inverse square root learning rate 54 | decay and weight decay schedule. The following options are supported for the PaLM optimizer: 55 | * `lr`: the initial learning rate 56 | * `lr_warmup_steps`: the number of steps for constant learning rate warmup 57 | * `b1`: beta1 parameter for adafactor 58 | * `b2`: beta2 parameter for adafactor 59 | * `clip_gradient`: the gradient clipping threshold 60 | * `weight_decay`: the weight decay parameter 61 | * `bf16_momentum`: whether to use bf16 for momentum to save memory 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /docs/parallelism.md: -------------------------------------------------------------------------------- 1 | # Model, Data and Fully Sharded Data Parallelism 2 | EasyLM supports flexible model and data parallelism for training and serving 3 | large language models. Specifically, EasyLM uses the PJIT feature of JAX 4 | to parallelize the computation across multiple accelerators or multiple hosts. 5 | To do so, all the accelerators are first grouped into a multi-dimensional mesh, 6 | where each axis represents a different type of parallelism. Currently, EasyLM 7 | uses 3D meshes for most of the models. Typically, the first axis of the mesh is 8 | used for data parallelism, the second axis is used for fully sharded data 9 | parallelism (FSDP), and the third axis is used for model parallelism. 10 | For more information about FSDP, please refer 11 | to [this FSDP tutorial](https://engineering.fb.com/2021/07/15/open-source/fsdp/). 12 | 13 | For example, if we have 8 accelerators for each host and 32 hosts in total, 14 | this gives us a total of 256 accelerators. We can use a 3D mesh of shape 15 | (1, 8, 32) to specify that one model is partitioned across 32 accelerators for 16 | model parallelism, and each partition has 8 replicas for fully sharded data parallelism.
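The mesh itself is constructed inside EasyLM, but as an illustration of the example above, a (1, 8, 32) device mesh could be built directly with JAX as in the following sketch. The axis names used here are only illustrative and are not necessarily the names EasyLM uses internally:

``` python
# Illustrative sketch of a (1, 8, 32) device mesh built with JAX on 256 accelerators;
# EasyLM sets up its own mesh internally, and the axis names below are just examples.
from jax.sharding import Mesh
from jax.experimental import mesh_utils

# Arrange the devices as (data parallel, fully sharded data parallel, model parallel).
devices = mesh_utils.create_device_mesh((1, 8, 32))
mesh = Mesh(devices, axis_names=('dp', 'fsdp', 'mp'))
print(mesh.shape)  # {'dp': 1, 'fsdp': 8, 'mp': 32}
```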
17 | 18 | ## Specifying the Mesh Axis Dimensions 19 | While the multi-dimensional mesh parallelism is not very intuitive, EasyLM hides 20 | most of the complexity from the user. For most use cases, the user only needs 21 | to specify the parallelism axis dimensions based on the memory capacity and the 22 | compute performance of the accelerators used. Typically, this is done by passing 23 | the `mesh_dim` command line argument to the training or serving script. The 24 | `mesh_dim` is a comma-separated list of integers representing the parallelism 25 | mesh axis dimensions. One of the axis dimensions can be `-1`, which means that 26 | the axis dimension will be inferred based on the total number of accelerators. 27 | 28 | For example, if we want to train a LLaMA model on 8 accelerators, 29 | we can pass in the following option for `mesh_dim`: 30 | ``` shell 31 | python -m EasyLM.models.llama.llama_train \ 32 | --mesh_dim='1,8,1' \ 33 | ... 34 | ``` 35 | 36 | This specifies that the model is partitioned across 8 accelerators for FSDP. Note that 37 | we can use `-1` for one of the axis dimensions, which means that the axis dimension 38 | will be inferred based on the total number of accelerators. For example, on an 8 39 | accelerator machine, specifying `1,1,-1` for `mesh_dim` is equivalent to 40 | specifying `1,1,8`. 41 | 42 | 43 | ## Tuning the Parallelism Axis Dimensions 44 | The parallelism axis dimensions can be tuned to achieve the best performance. 45 | Generally, it is recommended to use a larger FSDP axis and a smaller model parallelism 46 | axis to achieve the best throughput. -------------------------------------------------------------------------------- /docs/serving.md: -------------------------------------------------------------------------------- 1 | # Serving and Client 2 | EasyLM provides an HTTP server and a client for serving and querying language 3 | models. The server and client are implemented in [serving.py](/EasyLM/serving.py). 4 | The HTTP server serves multiple endpoints as well as a chat web UI for querying 5 | the language model. These endpoints can be used to interact with the language 6 | model or perform batch evaluation. 7 | 8 | 9 | ## LMServer Interface 10 | The `LMServer` class implements the HTTP server. Each language model serving 11 | script should inherit from this class and implement the following static methods: 12 | 13 | * `loglikelihood(prefix_text, text)`: given a list of prefix text strings and a 14 | list of text strings, return the loglikelihood of the text strings given the 15 | prefix text strings. The prefix text strings do not contribute to the 16 | total loglikelihood. This static method returns a pair of lists, where the 17 | first list contains the loglikelihoods of the text strings and the second list 18 | contains whether the text strings match the greedy decoding choice with the 19 | maximum log likelihood. 20 | * `loglikelihood_rolling(text)`: computes the log likelihood of the text strings, 21 | where the text strings might be longer than the maximum sequence length of the 22 | language model. If the text strings are longer than the maximum sequence length, 23 | the log likelihood is computed using a window. This method returns a pair of 24 | lists, where the first list contains the loglikelihoods of the text strings and 25 | the second list contains whether the text strings match the greedy decoding. 26 | * `generate(prefix_text, temperature)`: given a list of prefix text strings and 27 | a temperature value, generate a list of strings.
This method returns the list 28 | of generated strings. 29 | * `greedy_until(prefix_text, until, max_length)`: given a list of prefix text 30 | strings, a list of until strings, and a maximum length, generate a list of 31 | strings greedily. Generation continues until one of the until strings is 32 | produced, or the maximum length is reached. This method returns the list of 33 | generated strings. 34 | 35 | These static methods are called by the HTTP server to serve the endpoints. These 36 | methods largely follow the interface defined by the [Language Model Evaluation Harness 37 | ](https://github.com/EleutherAI/lm-evaluation-harness) library, which is used by 38 | EasyLM to evaluate the served language models. 39 | 40 | 41 | ## LMServer Endpoints and LMClient 42 | The `LMServer` class implements the following endpoints for querying the language 43 | model with HTTP requests. These endpoints can be queried by sending a JSON 44 | dictionary using the POST method. The following fields are used for each endpoint: 45 | 46 | #### `/loglikelihood` 47 | The input JSON dictionary should contain the following fields: 48 | * `prefix_text`: a list of prefix text strings. 49 | * `text`: a list of text strings. 50 | 51 | The output JSON dictionary contains the following fields: 52 | * `loglikelihood`: a list of loglikelihoods of the text strings given the prefix 53 | text strings. 54 | * `is_greedy`: a list of booleans indicating whether the text strings match the 55 | greedy decoding choice with the maximum log likelihood. 56 | 57 | 58 | #### `/serve_loglikelihood_rolling` 59 | The input JSON dictionary should contain the following fields: 60 | * `text`: a list of text strings. 61 | 62 | The output JSON dictionary contains the following fields: 63 | * `loglikelihood`: a list of loglikelihoods of the text strings. 64 | * `is_greedy`: a list of booleans indicating whether the text strings match the 65 | greedy decoding choice with the maximum log likelihood. 66 | 67 | 68 | #### `/generate` 69 | The input JSON dictionary should contain the following fields: 70 | * `prefix_text`: a list of prefix text strings. 71 | * `temperature`: a temperature value. 72 | 73 | The output JSON dictionary contains the following fields: 74 | * `output_text`: a list of generated text strings. 75 | 76 | 77 | #### `/greedy_until` 78 | The input JSON dictionary should contain the following fields: 79 | * `prefix_text`: a list of prefix text strings. 80 | * `until`: a list of until strings. 81 | 82 | The output JSON dictionary contains the following fields: 83 | * `output_text`: a list of generated text strings. 84 | 85 | 86 | #### `/chat` 87 | The chat endpoint is intended to be used for a dialogue language model. The input 88 | JSON dictionary should contain the following fields: 89 | * `prompt`: a prompt string. 90 | * `context`: a context string. Can be empty for the first query. 91 | * `temperature`: a temperature value. 92 | 93 | The output JSON dictionary contains the following fields: 94 | * `response`: a model response string. 95 | * `context`: the updated context string containing the chat history. This is 96 | used for the next round of dialogue. 97 | 98 | ### Chat UI 99 | For interacting with a dialogue language model over the web UI, simply navigate 100 | to the root of the HTTP server. The chat UI will be served at the root URL. 101 | 102 |
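Outside of the web UI, the endpoints above can be exercised with any HTTP client. A minimal sketch using `curl` against the `/generate` endpoint, assuming the server is running locally and listening on port 35009 as in the example serving script later in this document; the prompt and temperature are placeholders:
``` shell
curl -X POST http://localhost:35009/generate \
    -H 'Content-Type: application/json' \
    -d '{"prefix_text": ["The quick brown fox"], "temperature": 1.0}'
```
The response is a JSON dictionary whose `output_text` field contains one generated string per prefix.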
103 | ### LMClient 104 | The `LMClient` class implements a client for querying the served language model. 105 | The Python methods of this class are similar to the endpoints of the HTTP server. 106 | 107 | 108 | ## LMServer Options 109 | The `LMServer` class implements the following command line options: 110 | * `host`: the host IP address to serve the HTTP server. 111 | * `port`: the port to serve the HTTP server. 112 | * `batch_size`: the batch size for serving the language model. 113 | * `logging`: whether to log the requests to the HTTP server. 114 | * `pre_compile`: a comma-separated list of endpoints to trigger JAX compilation 115 | before serving the language model. This is useful for speeding up the first 116 | request to the language model. The following endpoints are supported: 117 | `loglikelihood`, `generate`, `greedy_until`, `chat`, or `all` for all endpoints. 118 | * `default_temperature`: the default temperature for the `generate` endpoint. 119 | * `greedy_until_max_length`: the maximum length for the `greedy_until` endpoint. 120 | * `prepend_to_prefix`: a string to prepend to the prefix text strings for the 121 | `loglikelihood`, `generate`, and `greedy_until` endpoints. 122 | * `append_to_prefix`: a string to append to the prefix text strings for the 123 | `loglikelihood`, `generate`, and `greedy_until` endpoints. 124 | * `prepend_to_text`: a string to prepend to the text strings for the `loglikelihood` 125 | endpoint. 126 | * `append_to_text`: a string to append to the text strings for the `loglikelihood` 127 | endpoint. 128 | * `chat_prepend_text`: a string to prepend to the context strings for the `chat` 129 | endpoint. 130 | * `chat_user_prefix`: a string to prepend to the user input strings for the `chat` 131 | endpoint. 132 | * `chat_user_suffix`: a string to append to the user input strings for the `chat` 133 | endpoint. 134 | * `chat_lm_prefix`: a string to prepend to the model response strings for the `chat` 135 | endpoint. 136 | * `chat_lm_suffix`: a string to append to the model response strings for the `chat` endpoint. 137 | * `notes`: a string to display on the chat UI. 138 | 139 | 140 | ## LMClient Options 141 | The `LMClient` class implements the following command line options: 142 | * `url`: the base URL of the HTTP server. 143 | * `wait_for_ready`: whether to wait for the HTTP server to be ready before 144 | sending requests. 145 | * `dummy`: whether to use a dummy language model for debugging. If set to True, 146 | the `LMClient` will always return some fixed results. -------------------------------------------------------------------------------- /examples/pretrain_llama_7b.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # This is the example script to pretrain a 7B LLaMA model on a TPU v4-512 pod. 4 | # These hyperparameters are the ones we used to train the OpenLLaMA 7B model on 5 | # the RedPajama dataset. To use this on a TPU pod, you need to run this 6 | # script on every host in the pod. 7 | 8 | # Put your WANDB API key here to enable logging to wandb.
9 | export WANDB_API_KEY='' 10 | 11 | # TPU specific flags to improve training throughput 12 | export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE' 13 | 14 | 15 | python -m EasyLM.models.llama.llama_train \ 16 | --mesh_dim='-1,64,1' \ 17 | --dtype='fp32' \ 18 | --total_steps=250000 \ 19 | --log_freq=50 \ 20 | --save_model_freq=0 \ 21 | --save_milestone_freq=2500 \ 22 | --load_llama_config='7b' \ 23 | --update_llama_config='' \ 24 | --load_dataset_state='' \ 25 | --load_checkpoint='' \ 26 | --tokenizer.vocab_file='/path/to/llama/tokenizer/file' \ 27 | --optimizer.type='adamw' \ 28 | --optimizer.adamw_optimizer.weight_decay=0.1 \ 29 | --optimizer.adamw_optimizer.lr=3e-4 \ 30 | --optimizer.adamw_optimizer.end_lr=3e-5 \ 31 | --optimizer.adamw_optimizer.lr_warmup_steps=2000 \ 32 | --optimizer.adamw_optimizer.lr_decay_steps=250000 \ 33 | --train_dataset.type='json' \ 34 | --train_dataset.text_processor.fields='text' \ 35 | --train_dataset.json_dataset.path='/path/to/shuffled/redpajama/dataset' \ 36 | --train_dataset.json_dataset.seq_length=2048 \ 37 | --train_dataset.json_dataset.batch_size=2048 \ 38 | --train_dataset.json_dataset.tokenizer_processes=16 \ 39 | --checkpointer.save_optimizer_state=True \ 40 | --logger.online=True \ 41 | --logger.prefix='EasyLM' \ 42 | --logger.project="open_llama_7b" \ 43 | --logger.output_dir="/path/to/checkpoint/dir" \ 44 | --logger.wandb_dir="$HOME/experiment_output/open_llama_7b" \ 45 | |& tee $HOME/output.txt 46 | 47 | -------------------------------------------------------------------------------- /examples/serve_llama_7b.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | # This is the example script to serve a 7B LLaMA model on a GPU machine or 4 | # single TPU v3-8 VM. The server will be listening on port 35009. 5 | 6 | 7 | python -m EasyLM.models.llama.llama_serve \ 8 | --load_llama_config='7b' \ 9 | --load_checkpoint="params::/path/to/checkpoint/file" \ 10 | --tokenizer.vocab_file='/path/to/llama/tokenizer/vocab/file' \ 11 | --mesh_dim='1,-1,1' \ 12 | --dtype='bf16' \ 13 | --input_length=1024 \ 14 | --seq_length=2048 \ 15 | --lm_server.batch_size=4 \ 16 | --lm_server.port=35009 \ 17 | --lm_server.pre_compile='all' 18 | 19 | -------------------------------------------------------------------------------- /scripts/gpu_environment.yml: -------------------------------------------------------------------------------- 1 | name: EasyLM 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | dependencies: 7 | - python=3.10 8 | - pip 9 | - numpy 10 | - scipy 11 | - matplotlib 12 | - seaborn 13 | - jupyter 14 | - tqdm 15 | - pytorch=2.3.0 16 | - pytorch-cuda=12.1 17 | - pip: 18 | - jax[cuda12]==0.4.28 19 | - flax==0.8.3 20 | - optax==0.2.2 21 | - transformers==4.41.0 22 | - torch==2.3.0 23 | - sentencepiece 24 | - datasets 25 | - mlxu >= 0.1.13 26 | - einops 27 | - gcsfs 28 | - requests 29 | - lm-eval 30 | - pydantic 31 | - fastapi 32 | - uvicorn 33 | - gradio -------------------------------------------------------------------------------- /scripts/tpu_commands.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | function _tpu_ips { 4 | tpu_zone=$1 5 | tpu_project=$2 6 | tpu_name=$3 7 | gcloud alpha compute tpus tpu-vm describe $tpu_name --zone $tpu_zone --project $tpu_project | grep -oP 'externalIp: \K(.+)$' 8 | 9 | } 10 | 11 | function _tpu_create { 12 | tpu_zone=$1 13 | tpu_project=$2 14 | tpu_gen=$3 15 | tpu_cores=$4 16 | tpu_name=$5 17 | if [ "$tpu_gen" = "v3" ]; then 18 | software_version='tpu-vm-base' 19 | else 20 | software_version='tpu-vm-v4-base' 21 | fi 22 | 23 | if [[ $tpu_cores =~ ^[0-9]+$ ]]; then 24 | gcloud alpha compute tpus tpu-vm create \ 25 | $tpu_name \ 26 | --accelerator-type="$tpu_gen-$tpu_cores" \ 27 | --version $software_version \ 28 | --zone $tpu_zone \ 29 | --project $tpu_project 30 | else 31 | gcloud alpha compute tpus tpu-vm create \ 32 | $tpu_name \ 33 | --type="$tpu_gen" \ 34 | --topology="$tpu_cores" \ 35 | --version $software_version \ 36 | --zone $tpu_zone \ 37 | --project $tpu_project 38 | fi 39 | } 40 | 41 | function _tpu_retry_create { 42 | while true; do 43 | _tpu_create "$@" 44 | sleep 120s 45 | done 46 | } 47 | 48 | function _tpu_cp_ssh_key { 49 | tpu_zone=$1 50 | tpu_project=$2 51 | tpu_name=$3 52 | 53 | gcloud alpha compute tpus tpu-vm scp \ 54 | $HOME/.ssh/authorized_keys \ 55 | $tpu_name:/home/$USER/.ssh/ \ 56 | --worker=all \ 57 | --project $tpu_project \ 58 | --zone $tpu_zone 59 | } 60 | 61 | function _tpu_setup { 62 | tpu_zone=$1 63 | tpu_project=$2 64 | tpu_name=$3 65 | 66 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 67 | for host in $tpu_ips; do 68 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/ 69 | ssh $host '~/tpu_vm_setup.sh' & 70 | done 71 | wait &> /dev/null 72 | 73 | for host in $tpu_ips; do 74 | scp $PROJECT_HOME/$PROJECT_NAME/scripts/tpu_vm_setup.sh $host:~/ 75 | ssh $host '~/tpu_vm_setup.sh' & 76 | done 77 | wait &> /dev/null 78 | } 79 | 80 | function _tpu_check { 81 | tpu_zone=$1 82 | tpu_project=$2 83 | tpu_name=$3 84 | 85 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 86 | for host in $tpu_ips; do 87 | echo "============== Checking host: $host ==============" 88 | ssh $host 'tmux capture-pane -pt launch -S -2000' 89 | echo 90 | echo 91 | done 92 | } 93 | 94 | function _tpu_copy { 95 | tpu_zone=$1 96 | tpu_project=$2 97 | tpu_name=$3 98 | 99 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 100 | for host in $tpu_ips; do 101 | rsync -avPI --exclude=logs --exclude=__pycache__ --exclude=.git $PROJECT_HOME/$PROJECT_NAME $host:~/ & 102 | done 103 | wait &> /dev/null 104 | sleep 1s 105 | 106 | for host in $tpu_ips; do 107 | rsync -avPI --exclude=logs --exclude=__pycache__ --exclude=.git $PROJECT_HOME/$PROJECT_NAME $host:~/ & 108 | done 109 | wait &> /dev/null 110 | sleep 1s 111 | } 112 | 113 | function _tpu_stop { 114 | tpu_zone=$1 115 | tpu_project=$2 116 | tpu_name=$3 117 | 118 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 119 | for host in $tpu_ips; do 120 | ssh $host 'tmux kill-session -t launch ; pkill -9 python' & 121 | done 122 | wait &> /dev/null 123 | } 124 | 125 | function _tpu_launch { 126 | tpu_zone=$1 127 | tpu_project=$2 128 | tpu_name=$3 129 | command=$4 130 | 131 | if [ -z "$command" ]; then 132 | echo "Invalid syntax!" 
133 | return 1 134 | fi 135 | 136 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 137 | for host in $tpu_ips; do 138 | ssh $host "tmux new -d -s launch ~/$PROJECT_NAME/launcher/$command" & 139 | done 140 | wait &> /dev/null 141 | } 142 | 143 | function _tpu_maintain { 144 | tpu_zone=$1 145 | tpu_project=$2 146 | tpu_name=$3 147 | 148 | gcloud alpha compute tpus tpu-vm simulate-maintenance-event $tpu_name \ 149 | --project $tpu_project \ 150 | --zone=$tpu_zone \ 151 | --workers=all 152 | } 153 | 154 | function _tpu_ssh { 155 | tpu_zone=$1 156 | tpu_project=$2 157 | tpu_name=$3 158 | command="$4" 159 | 160 | if [ -z "$command" ]; then 161 | echo "Invalid syntax!" 162 | return 1 163 | fi 164 | 165 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 166 | for host in $tpu_ips; do 167 | ssh $host "$command" & 168 | done 169 | wait &> /dev/null 170 | } 171 | 172 | function _tpu_reboot { 173 | tpu_zone=$1 174 | tpu_project=$2 175 | tpu_name=$3 176 | 177 | tpu_ips=$(_tpu_ips $tpu_zone $tpu_project $tpu_name) 178 | for host in $tpu_ips; do 179 | ssh $host 'sudo reboot' & 180 | done 181 | wait &> /dev/null 182 | } 183 | 184 | 185 | function tpu { 186 | trap "trap - SIGINT SIGTERM; return 1;" SIGINT SIGTERM 187 | 188 | 189 | # =============== TPU Project Specific Definitions =============== 190 | export PROJECT_HOME=' $HOME/tpu_requirements.txt <<- EndOfFile 17 | -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 18 | jax[tpu]==0.4.28 19 | flax==0.8.3 20 | optax==0.2.2 21 | einops 22 | --extra-index-url https://download.pytorch.org/whl/cpu 23 | torch==2.3.0 24 | transformers==4.41.0 25 | datasets==2.19.1 26 | tqdm 27 | requests 28 | typing-extensions 29 | mlxu>=0.1.13 30 | sentencepiece 31 | pydantic 32 | fastapi 33 | uvicorn 34 | gradio 35 | EndOfFile 36 | 37 | pip install --upgrade -r $HOME/tpu_requirements.txt 38 | 39 | 40 | # vim configurations 41 | cat > $HOME/.vimrc <<- EndOfFile 42 | set tabstop=4 43 | set shiftwidth=4 44 | set softtabstop=4 45 | set expandtab 46 | set backspace=indent,eol,start 47 | syntax on 48 | EndOfFile 49 | 50 | # tmux configurations 51 | cat > $HOME/.tmux.conf <<- EndOfFile 52 | bind r source-file ~/.tmux.conf 53 | 54 | set -g prefix C-a 55 | 56 | set -g set-titles on 57 | set -g set-titles-string '#(whoami)::#h::#(curl ipecho.net/plain;echo)' 58 | 59 | set -g default-terminal "screen-256color" 60 | 61 | # Status bar customization 62 | #set -g status-utf8 on 63 | set -g status-bg white 64 | set -g status-fg black 65 | set -g status-interval 5 66 | set -g status-left-length 90 67 | set -g status-right-length 60 68 | 69 | set -g status-justify left 70 | 71 | unbind-key C-o 72 | bind -n C-o prev 73 | unbind-key C-p 74 | bind -n C-p next 75 | unbind-key C-w 76 | bind -n C-w new-window 77 | 78 | unbind-key C-j 79 | bind -n C-j select-pane -D 80 | unbind-key C-k 81 | bind -n C-k select-pane -U 82 | unbind-key C-h 83 | bind -n C-h select-pane -L 84 | unbind-key C-l 85 | bind -n C-l select-pane -R 86 | 87 | unbind-key C-e 88 | bind -n C-e split-window -h 89 | unbind-key C-q 90 | bind -n C-q split-window -v 91 | unbind '"' 92 | unbind % 93 | 94 | unbind-key u 95 | bind-key u split-window -h 96 | unbind-key i 97 | bind-key i split-window -v 98 | EndOfFile 99 | 100 | 101 | # htop Configurations 102 | mkdir -p $HOME/.config/htop 103 | cat > $HOME/.config/htop/htoprc <<- EndOfFile 104 | # Beware! This file is rewritten by htop when settings are changed in the interface. 105 | # The parser is also very primitive, and not human-friendly. 
106 | fields=0 48 17 18 38 39 40 2 46 47 49 1 107 | sort_key=46 108 | sort_direction=1 109 | hide_threads=0 110 | hide_kernel_threads=1 111 | hide_userland_threads=1 112 | shadow_other_users=0 113 | show_thread_names=0 114 | show_program_path=1 115 | highlight_base_name=0 116 | highlight_megabytes=1 117 | highlight_threads=1 118 | tree_view=0 119 | header_margin=1 120 | detailed_cpu_time=0 121 | cpu_count_from_zero=0 122 | update_process_names=0 123 | account_guest_in_cpu_meter=0 124 | color_scheme=0 125 | delay=15 126 | left_meters=CPU Memory Swap 127 | left_meter_modes=1 1 1 128 | right_meters=Tasks LoadAverage Uptime 129 | right_meter_modes=2 2 2 130 | EndOfFile 131 | --------------------------------------------------------------------------------