├── .gitignore
├── README.md
├── container
│   ├── Dockerfile
│   └── magix-gpu.def
├── generate.py
├── magix
│   ├── __init__.py
│   ├── checkpoint_utils.py
│   ├── lora.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── gemma_model.py
│   │   ├── llama_model.py
│   │   ├── mistral_model.py
│   │   └── t5_model.py
│   └── spmd_utils.py
├── setup.py
├── train.py
└── train_lora.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Magix 2 | Magix is a minimalist toolkit for training LLMs with flexible data and model parallelism. 3 | 4 | ## Features 5 | - Training billion-scale LLMs on GPUs and TPUs. 6 | - Familiar Huggingface model interfaces and ecosystem (datasets, hub, etc.). 7 | - Pre-defined model-parallel (sharding) rules for popular models like Llama, Mistral, Gemma, etc. 8 | - Acceleration with flash attention and operation fusion. 9 | - Fast checkpoint save/restore with arbitrary device and parallelism designs. 10 | 11 | ## Magix 101 12 | If you have ever used Huggingface Flax transformers, using Magix is as simple as adding a few magic functions to the usual workflow. 13 | 14 | 1. We start by importing the necessary dependencies, 15 | ``` 16 | import magix 17 | from magix.models.llama_model import FlaxLlamaForCausalLM 18 | ``` 19 | 20 | 2. We will explicitly reason about all the GPU (or TPU) devices available to us, placing them in a grid (a.k.a. mesh) using the `magix.create_device_mesh` function. 21 | ``` 22 | # Assume we have 4 GPUs in total; we can arrange them arbitrarily. 23 | # Say we arrange them into a 2x2 mesh and name the first axis `data` and the second axis `model`. 24 | # These axes will be responsible for data and model parallelism, respectively. 25 | 26 | mesh = magix.create_device_mesh((2,2), names=('data', 'model')) 27 | ``` 28 | 29 | 3. Next, we load our model onto the mesh; each device will hold a part (shard) of the full model. Instead of the familiar `from_pretrained`, we use the `magix.load_model_hub` function, which calls `from_pretrained` internally but also places the model shards correctly. 30 | ``` 31 | model, params = magix.load_model_hub( 32 | FlaxLlamaForCausalLM, 33 | 'meta-llama/Llama-2-13b', 34 | FlaxLlamaForCausalLM.partition_rules, # use the pre-defined partitioning 35 | mesh 36 | ) 37 | ``` 38 | Here `params` is partitioned and placed onto the mesh. As a side note, JAX reasons about the model definition and its parameters separately, analogous to `y = f(x|θ)`. 39 | 40 | 4. For training, you will also need to do something similar and build the optimizer state onto the mesh, 41 | ``` 42 | opt_state = magix.initialize_opt_state(optimizer, params, sharding_config, mesh) 43 | ``` 44 | 45 | 5. You may have seen tutorials using `jax.pmap`. For our case, with both data and model parallelism, we will use the more powerful `jax.jit`, 46 | ``` 47 | train_step = jax.jit( 48 | train_step, # or generate_step 49 | donate_argnums=..., # set based on the actual function inputs 50 | out_shardings=(magix.item_sharding(params), magix.item_sharding(opt_state), ...) # set based on the actual function outputs 51 | ) 52 | ``` 53 | 54 | With all these pieces, you are ready to start your training/inference loop; a minimal end-to-end sketch is shown below.
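The sketch below shows how the five steps fit together into a training step. It is illustrative only: it assumes an `optax` optimizer, a toy next-token cross-entropy loss, and a hypothetical `data_loader` that yields tokenized batches; none of these are provided by Magix itself.
```
import jax
import optax
import magix
from jax.sharding import NamedSharding, PartitionSpec
from magix.models.llama_model import FlaxLlamaForCausalLM

# Steps 1-3: device mesh and sharded model parameters
mesh = magix.create_device_mesh((2, 2), names=('data', 'model'))
model, params = magix.load_model_hub(
    FlaxLlamaForCausalLM,
    'meta-llama/Llama-2-13b',
    FlaxLlamaForCausalLM.partition_rules,
    mesh,
)

# Step 4: optimizer state, sharded with the same partitioning rules
optimizer = optax.adamw(learning_rate=5e-5)
opt_state = magix.initialize_opt_state(
    optimizer, params, FlaxLlamaForCausalLM.partition_rules, mesh)

def loss_fn(params, batch):
    # toy next-token cross-entropy; padding/label masking omitted for brevity
    logits = model(
        batch['input_ids'], attention_mask=batch['attention_mask'], params=params)[0]
    return optax.softmax_cross_entropy_with_integer_labels(
        logits[:, :-1], batch['input_ids'][:, 1:]).mean()

def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Step 5: jit with explicit output shardings; params and opt_state buffers are donated
train_step = jax.jit(
    train_step,
    donate_argnums=(0, 1),
    out_shardings=(
        magix.item_sharding(params),
        magix.item_sharding(opt_state),
        NamedSharding(mesh, PartitionSpec()),  # scalar loss, replicated
    ),
)

with mesh:
    for batch in data_loader:  # data_loader: your own iterator of tokenized batches
        params, opt_state, loss = train_step(params, opt_state, batch)
```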
55 | 56 | Take a look at the complete scripts in [train.py](https://github.com/luyug/magix/blob/main/train.py), [train_lora.py](https://github.com/luyug/magix/blob/main/train_lora.py) and [generate.py](https://github.com/luyug/magix/blob/main/generate.py). 57 | 58 | ## Example: Train a Mistral Chatbot with LoRA and Data & Tensor Parallelism 59 | Assume we have 4 GPUs. Let's train `mistral-7b` on `UltraChat` with data and tensor parallelism, `dp=2` and `tp=2` (`mesh_shape=2 2`): 60 | ``` 61 | python train_lora.py \ 62 | --checkpoint_dir /absolute/path/to/checkpoint \ 63 | --model_type mistral \ 64 | --model_name mistralai/Mistral-7B-v0.1 \ 65 | --tokenizer_name mistralai/Mistral-7B-v0.1 \ 66 | --train_file HuggingFaceH4/ultrachat_200k \ 67 | --split train_sft \ 68 | --train_data_field messages \ 69 | --use_chat_template \ 70 | --batch_size 32 \ 71 | --num_epochs 1 \ 72 | --learning_rate 5e-5 \ 73 | --seed 12345 \ 74 | --mesh_shape 2 2 \ 75 | --weight_decay 0.001 \ 76 | --max_length 1024 77 | ``` 78 | After training, let's solve some math problems. Run generation with full tensor parallelism, `tp=4` (`mesh_shape=1 -1`): 79 | ``` 80 | python generate.py \ 81 | --prompts gsm8k \ 82 | --hf_data_config main \ 83 | --hf_data_split test \ 84 | --use_chat_template \ 85 | --data_field question \ 86 | --output_file generation.jsonl \ 87 | --mesh_shape 1 -1 \ 88 | --model_type mistral \ 89 | --model_name_or_path mistralai/Mistral-7B-v0.1 \ 90 | --tokenizer_name_or_path mistralai/Mistral-7B-v0.1 \ 91 | --model_config_name mistralai/Mistral-7B-v0.1 \ 92 | --batch_size 32 \ 93 | --pad_to_multiple_of 64 \ 94 | --max_length 512 \ 95 | --lora /absolute/path/to/checkpoint/EVALUATION_STEP/lora 96 | ``` 97 | 98 | ## Running on GPUs 99 | We recommend using the NVIDIA JAX-Toolbox [container image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax). We provide an example [Dockerfile](https://github.com/luyug/magix/blob/main/container/Dockerfile) and a Singularity [definition file](https://github.com/luyug/magix/blob/main/container/magix-gpu.def). 100 | 101 | ## Running on TPUs 102 | Install the appropriate `jax` build, `torch-cpu`, and then the rest of the dependencies: 103 | ``` 104 | pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 105 | 106 | # get torch-cpu for model conversion 107 | pip install torch --index-url https://download.pytorch.org/whl/cpu 108 | 109 | git clone https://github.com/luyug/magix.git 110 | cd magix 111 | pip install -e .
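# Optionally, sanity-check that JAX can see the TPU devices before training
python -c "import jax; print(jax.devices())"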
112 | ``` 113 | -------------------------------------------------------------------------------- /container/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/nvidia/jax:jax-2024-03-08 2 | 3 | RUN apt-get update && \ 4 | apt-get install -y --no-install-recommends python3-pip && \ 5 | apt-get clean && \ 6 | rm -rf /var/lib/apt/lists/* && \ 7 | pip install --no-cache-dir transformers sentencepiece simple_parsing datasets orbax==0.4.8 && \ 8 | pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu -------------------------------------------------------------------------------- /container/magix-gpu.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: ghcr.io/nvidia/jax:jax-2024-03-08 3 | 4 | %post 5 | apt-get update && \ 6 | apt-get install -y --no-install-recommends \ 7 | python3-pip \ 8 | && \ 9 | apt-get clean && \ 10 | rm -rf /var/lib/apt/lists/* 11 | 12 | pip install --no-cache-dir transformers sentencepiece simple_parsing datasets orbax==0.4.8 13 | pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import json 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 7 | from tqdm import tqdm, trange 8 | from functools import partial 9 | 10 | import jax 11 | from jax.sharding import Mesh 12 | from jax.sharding import PartitionSpec as PS 13 | from jax.sharding import NamedSharding 14 | import orbax.checkpoint 15 | 16 | import datasets 17 | from transformers import AutoTokenizer, AutoConfig 18 | from simple_parsing import ArgumentParser 19 | from simple_parsing.helpers import list_field 20 | 21 | import magix 22 | import magix.models 23 | import magix.lora 24 | 25 | @dataclass 26 | class GenerateArgs: 27 | prompts: str = None 28 | use_chat_template: bool = False 29 | data_field : str = 'prompt' 30 | hf_data_config: str = None 31 | hf_data_split: str = 'test' 32 | output_file: str = 'generated.txt' 33 | batch_size: int = 32 34 | pad_to_multiple_of: int = 64 35 | sample: bool = False 36 | tempearature: float = 0.7 37 | seed: int = 42 38 | max_length: int = 256 39 | model_type: str = 'llama' 40 | model_name_or_path: str = None 41 | model_config_name: Optional[str] = None 42 | tokenizer_name_or_path: str = None 43 | mesh_shape: List[int] = list_field(1, -1) 44 | hf_format: bool = False 45 | lora: str = None 46 | lora_alpha: float = 32.0 47 | 48 | def main(): 49 | parser = ArgumentParser() 50 | parser.add_arguments(GenerateArgs, dest="generate_args") 51 | args = parser.parse_args().generate_args 52 | 53 | tokenizer = AutoTokenizer.from_pretrained( 54 | args.tokenizer_name_or_path, 55 | add_eos_token=False, 56 | use_fast=True, 57 | padding_side='left', 58 | legacy=False 59 | ) 60 | tokenizer.pad_token = tokenizer.eos_token 61 | 62 | _model_cls = magix.models.CAUSAL_LM_MODEL_MAPPING.get(args.model_type) 63 | if _model_cls is None: 64 | raise ValueError(f"Model type {args.model_type} not found") 65 | 66 | mesh = magix.create_device_mesh(args.mesh_shape) 67 | 68 | if args.hf_format or not os.path.exists(args.model_name_or_path): 69 | model, params = magix.load_model_hub( 70 | _model_cls, 71 | args.model_name_or_path, 72 | 
_model_cls.partition_rules, 73 | mesh, 74 | half=True, 75 | from_pt=True, 76 | ) 77 | else: 78 | model, params = magix.load_model_local( 79 | _model_cls, 80 | args.model_name_or_path, 81 | _model_cls.partition_rules, 82 | mesh, 83 | model_config=AutoConfig.from_pretrained(args.model_config_name), 84 | ) 85 | 86 | if args.lora is not None: 87 | lora = magix.lora.Lora( 88 | args.lora_alpha, 89 | rules={ 90 | 'layers/.*/kernel': 1, # rank place holder 91 | } 92 | ) 93 | # infer the lora parameters 94 | lora_params_absract = jax.eval_shape(lora.init_params, jax.random.PRNGKey(0), params) 95 | lora_params_sharding = magix.lora.create_lora_sharding(_model_cls.partition_rules, mesh, lora_params_absract) 96 | lora_params = magix.checkpoint_utils.load_by_sharding_no_manager(lora_params_sharding, args.lora) 97 | params = jax.jit( 98 | lora.apply, 99 | donate_argnums=(0,), 100 | in_shardings=(magix.item_sharding(params), magix.item_sharding(lora_params)), 101 | out_shardings=magix.item_sharding(params) 102 | ) (params, lora_params) 103 | del lora_params 104 | 105 | def tokenize(batch): 106 | return tokenizer( 107 | batch, 108 | padding=True, 109 | max_length=args.max_length, 110 | pad_to_multiple_of=args.pad_to_multiple_of, 111 | truncation=True, 112 | return_tensors="np", 113 | ) 114 | 115 | @partial( 116 | jax.jit, 117 | static_argnames=('sample', 'tempearature',), 118 | out_shardings=NamedSharding(mesh, PS()), 119 | donate_argnums=(3,) 120 | ) 121 | def generate( 122 | params, 123 | inputs, 124 | mask, 125 | rng_key=None, 126 | sample=False, 127 | tempearature=1.0, 128 | ): 129 | generation = model.generate( 130 | inputs, 131 | attention_mask=mask, 132 | prng_key=rng_key, 133 | max_length=args.max_length, 134 | params=params, 135 | do_sample=sample, 136 | temperature=tempearature, 137 | ).sequences 138 | new_rng_key, _ = jax.random.split(rng_key) 139 | 140 | return generation, new_rng_key 141 | 142 | if args.prompts.endswith('.txt'): 143 | with open(args.prompts, 'r') as f: 144 | prompts = [l.strip() for l in f] 145 | elif args.prompts.endswith('.jsonl'): 146 | with open(args.prompts, 'r') as f: 147 | prompts = [json.loads(l)[args.data_field] for l in f] 148 | else: 149 | prompts = datasets.load_dataset( 150 | args.prompts, args.hf_data_config 151 | )[args.hf_data_split][args.data_field] 152 | 153 | if args.use_chat_template: 154 | CHAT_FORMAT = '<|user|>\n{prompt}{eos}<|assistant|>\n' 155 | prompts = [CHAT_FORMAT.format(prompt=p, eos=tokenizer.eos_token) for p in prompts] 156 | 157 | rng_key = jax.random.PRNGKey(args.seed) 158 | 159 | with open(args.output_file, 'w') as f: 160 | with mesh: 161 | for i in trange(0, len(prompts), args.batch_size): 162 | batch = prompts[i:i+args.batch_size] 163 | batch_size = len(batch) 164 | if batch_size < args.batch_size: 165 | batch += ['EMPTY'] * (args.batch_size - len(batch)) 166 | batch = tokenize(batch) 167 | generated, rng_key = generate( 168 | params, 169 | batch['input_ids'], 170 | batch['attention_mask'], 171 | rng_key, 172 | sample=args.sample, 173 | tempearature=args.tempearature, 174 | ) 175 | input_seq_len = batch['input_ids'].shape[1] 176 | generated = generated[:, input_seq_len:] 177 | generated = tokenizer.batch_decode( 178 | generated, skip_special_tokens=True) 179 | for p, g in zip(prompts[i:i+batch_size], generated): 180 | f.write(json.dumps({'prompt': p, 'generated': g}) + '\n') 181 | 182 | if __name__ == "__main__": 183 | main() 184 | -------------------------------------------------------------------------------- /magix/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .checkpoint_utils import ( 2 | load_model_and_optimizer_local, 3 | load_model_local, 4 | load_model_hub, 5 | save_model_local, 6 | get_chckpoint_manager 7 | ) 8 | from .spmd_utils import ( 9 | initialize_opt_state, 10 | item_sharding, 11 | create_device_mesh 12 | ) -------------------------------------------------------------------------------- /magix/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | import orbax.checkpoint 2 | import numpy as np 3 | import jax 4 | import logging 5 | from typing import Any, Iterable 6 | from functools import partial 7 | from jax.sharding import Mesh 8 | 9 | from . import spmd_utils 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | 15 | def array_restore_args_from_sharding_pytree(pytree): 16 | return jax.tree_util.tree_map( 17 | lambda s: orbax.checkpoint.ArrayRestoreArgs( 18 | restore_type=jax.Array, 19 | sharding=s, 20 | ), 21 | pytree) 22 | 23 | 24 | def load_model_hub( 25 | model_cls, 26 | model_name, 27 | sharding_config, 28 | mesh, 29 | ignore_mismatched_sizes=True, 30 | half=False, 31 | from_pt=True, 32 | ): 33 | # Define sharding function using sharding config over mesh 34 | get_sharding = partial( 35 | spmd_utils.get_sharding, 36 | sharding_config=sharding_config, 37 | mesh=mesh 38 | ) 39 | 40 | # Load model from hub 41 | with jax.default_device(jax.local_devices(backend="cpu")[0]): 42 | with Mesh(devices = np.array(jax.local_devices(backend='cpu')[0]).reshape(1,1), axis_names=('data', 'model')): 43 | model = model_cls.from_pretrained(model_name, ignore_mismatched_sizes=ignore_mismatched_sizes,from_pt=from_pt) 44 | if not half: 45 | model.params = model.to_fp32(model.params) 46 | else: 47 | model.params = model.to_bf16(model.params) 48 | logger.info("Model loaded from hub") 49 | 50 | 51 | # Shard model onto device mesh 52 | model_sharding = jax.tree_util.tree_map_with_path(get_sharding, model.params) 53 | sharded_params = jax.tree_map( 54 | lambda a, s: jax.make_array_from_callback(a.shape, s, lambda i: a[i]), 55 | model.params, model_sharding 56 | ) 57 | logger.info("Model shards transferred to devices") 58 | 59 | return model, sharded_params 60 | 61 | 62 | def load_by_sharding( 63 | checkpoint_manager: orbax.checkpoint.CheckpointManager, 64 | items: Iterable[str], 65 | dummies: Iterable[Any], 66 | shardings: Iterable[Any], 67 | ): 68 | restore_kwargs = { 69 | item: {'restore_args': array_restore_args_from_sharding_pytree(s)} 70 | for item, s in zip(items, shardings) 71 | } 72 | restored = checkpoint_manager.restore( 73 | checkpoint_manager.latest_step(), 74 | items={item: dummy for item, dummy in zip(items, dummies)}, 75 | restore_kwargs=restore_kwargs 76 | ) 77 | return restored 78 | 79 | 80 | def load_by_sharding_no_manager(sharding, path): 81 | checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) 82 | restored = checkpointer.restore( 83 | path, 84 | restore_args=array_restore_args_from_sharding_pytree(sharding) 85 | ) 86 | return restored 87 | 88 | 89 | def load_model_and_optimizer_local( 90 | model_cls, 91 | optimizer, 92 | checkpoint_manager, 93 | sharding_config, 94 | mesh, 95 | model_name=None, 96 | model_config=None, 97 | step=None, 98 | ): 99 | # Create sharding function using sharding config over mesh 100 | get_sharding = partial( 101 | spmd_utils.get_sharding, 102 | sharding_config=sharding_config, 103 | mesh=mesh 104 | ) 105 | 106 | # Load 
model config from hub 107 | if model_config is None: 108 | model_config = model_cls.config_class.from_pretrained(model_name) 109 | 110 | # Create model instance and get shape pytrees for model and optimizer 111 | with Mesh(devices = np.array(jax.devices('cpu')[0]).reshape(1,1), axis_names=('data', 'model')): 112 | model_no_init = model_cls(model_config, _do_init=False) 113 | 114 | def opt_shape(): 115 | params = model_no_init.init_weights(model_no_init.key, model_no_init.input_shape) 116 | return optimizer.init(params) 117 | opt_shapes = jax.eval_shape(opt_shape) 118 | 119 | # Define sharding for model and optimizer 120 | model_sharding = jax.tree_util.tree_map_with_path(get_sharding, model_no_init._params_shape_tree) 121 | opt_sharding = jax.tree_util.tree_map_with_path(get_sharding, opt_shapes) 122 | 123 | # Restore model and optimizer from local storage 124 | step = checkpoint_manager.latest_step() if step is None else step 125 | restored = load_by_sharding( 126 | checkpoint_manager, 127 | items=['model', 'optimizer'], 128 | dummies=[model_no_init._params_shape_tree, opt_shapes], 129 | shardings=[model_sharding, opt_sharding] 130 | ) 131 | params, opt_state = restored['model'], restored['optimizer'] 132 | logger.info( 133 | "Model and optimizer restored from local storage at step %d", checkpoint_manager.latest_step()) 134 | 135 | return model_no_init, params, opt_state 136 | 137 | 138 | def load_model_local( 139 | model_cls, 140 | path, 141 | sharding_config, 142 | mesh, 143 | model_name=None, 144 | model_config=None, 145 | ): 146 | # Create sharding function using sharding config over mesh 147 | get_sharding = partial( 148 | spmd_utils.get_sharding, 149 | sharding_config=sharding_config, 150 | mesh=mesh 151 | ) 152 | 153 | # Load model config from hub 154 | if model_config is None: 155 | model_config = model_cls.config_class.from_pretrained(model_name) 156 | 157 | # Create model instance and get shape pytrees for model and optimizer 158 | with Mesh(devices = np.array(jax.devices('cpu')[0]).reshape(1,1), axis_names=('data', 'model')): 159 | model_no_init = model_cls(model_config, _do_init=False) 160 | 161 | # Define sharding for model and optimizer 162 | model_sharding = jax.tree_util.tree_map_with_path(get_sharding, model_no_init._params_shape_tree) 163 | 164 | # Restore model 165 | checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) 166 | params = checkpointer.restore( 167 | path, 168 | restore_args=array_restore_args_from_sharding_pytree(model_sharding) 169 | ) 170 | logger.info("Model restored from local storage at %s", path) 171 | 172 | return model_no_init, params 173 | 174 | 175 | def save_model_local( 176 | params, 177 | path, 178 | ): 179 | checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler()) 180 | checkpointer.save(path, params) 181 | logger.info("Model saved to local storage at %s", path) 182 | 183 | 184 | def get_chckpoint_manager(checkpoint_dir, save_steps=500, max_to_keep=3, items=['model', 'optimizer'], json_items=[]): 185 | options = orbax.checkpoint.CheckpointManagerOptions( 186 | save_interval_steps=save_steps, max_to_keep=max_to_keep) 187 | def get_checkpointer(): 188 | return orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler()) 189 | def get_json_checkpointer(): 190 | return orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.JsonCheckpointHandler()) 191 | checkpoint_manager = orbax.checkpoint.CheckpointManager( 192 | checkpoint_dir, 193 | {item: get_checkpointer() for 
item in items} | {item: get_json_checkpointer() for item in json_items}, 194 | options 195 | ) 196 | return checkpoint_manager -------------------------------------------------------------------------------- /magix/lora.py: -------------------------------------------------------------------------------- 1 | import re 2 | import logging 3 | from functools import partial 4 | from typing import Any, Callable, Dict 5 | from collections import namedtuple 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import flax 10 | from jax.sharding import PartitionSpec as PS, NamedSharding 11 | 12 | from . import spmd_utils 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | LoraPair = namedtuple('LoraPair', ['in_matrix', 'out_matrix']) 17 | 18 | 19 | def create_lora_sharding(sharding_config, mesh, lora_abs): 20 | def create_sharding_one(k, v): 21 | if v is None: # skip 22 | return None 23 | if isinstance(v, LoraPair): 24 | v = v[0] 25 | spec = spmd_utils.get_sharding(k, v, sharding_config) 26 | spec_in = PS(*spec[:-1], None) 27 | spec_out = PS(*spec[:-2], None, spec[-1]) 28 | return LoraPair(NamedSharding(mesh, spec_in), NamedSharding(mesh, spec_out)) 29 | else: 30 | return spmd_utils.get_sharding(k, v, sharding_config, mesh) 31 | 32 | return jax.tree_util.tree_map_with_path( 33 | create_sharding_one, lora_abs, is_leaf=lambda x: isinstance(x, LoraPair)) 34 | 35 | 36 | def adapt_params(params, lora_states, alpha=32): 37 | def adapt_one_param(p, l): 38 | if l is not None: 39 | if isinstance(l, dict): 40 | l = LoraPair(l['in_matrix'], l['out_matrix']) 41 | l = tuple(map(lambda x: jnp.astype(x, jnp.bfloat16), l)) 42 | return p + (alpha / l[0].shape[1])*jnp.matmul(l[0], l[1]) 43 | return p 44 | return jax.tree_map(adapt_one_param, params, lora_states) 45 | 46 | 47 | def init_lora_params(prng, params, rules): 48 | # initialization guard 49 | assert rules is not None, "LORA rules must be provided for initialization" 50 | for v in rules.values(): 51 | assert v > 0, "LORA rank must be greater than 0 for initialization" 52 | 53 | init = jax.nn.initializers.he_uniform() 54 | def init_one_param(prng, path, param): 55 | path_str = "/".join(path) 56 | for r in rules: 57 | if re.search(r, path_str): 58 | lora_rank = rules[r] 59 | assert len(param.shape) >= 2 60 | if len(param.shape) != 2: 61 | logger.warn( 62 | 'Initializing LORA for a tensor parameter.' 63 | 'Will apply the decomposition to the last two dimensions.' 
64 | ) 65 | 66 | new_rng, in_rng = jax.random.split(prng, 2) 67 | leading_dims = param.shape[:-2] 68 | in_dims = leading_dims + (param.shape[-2], lora_rank) 69 | out_dims = leading_dims + (lora_rank, param.shape[-1]) 70 | in_mat = init(in_rng, in_dims) 71 | out_mat = jnp.zeros(out_dims) 72 | return LoraPair(in_mat, out_mat), new_rng 73 | 74 | return None, prng 75 | 76 | flat_params = flax.traverse_util.flatten_dict(params) 77 | lora_state = {} 78 | for path, param in flat_params.items(): 79 | lora_matrices, prng = init_one_param(prng, path, param) 80 | lora_state[path] = lora_matrices 81 | 82 | return flax.traverse_util.unflatten_dict(lora_state) 83 | 84 | 85 | class Lora: 86 | def __init__( 87 | self, 88 | alpha: float, 89 | rules: Dict[str, int]=None, 90 | ): 91 | self.apply = partial(adapt_params, alpha=alpha) 92 | self.init_params = partial(init_lora_params, rules=rules) -------------------------------------------------------------------------------- /magix/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama_model import FlaxLlamaModel, FlaxLlamaForCausalLM 2 | from .mistral_model import FlaxMistralModel, FlaxMistralForCausalLM 3 | from .bert_model import FlaxBertModel 4 | from .t5_model import FlaxT5EncoderModel 5 | from .gemma_model import FlaxGemmaModel, FlaxGemmaForCausalLM 6 | 7 | ENCODER_MODEL_MAPPING = { 8 | "llama": FlaxLlamaModel, 9 | "mistral": FlaxMistralModel, 10 | "bert": FlaxBertModel, 11 | "t5": FlaxT5EncoderModel, 12 | "gemma": FlaxGemmaModel, 13 | } 14 | 15 | CAUSAL_LM_MODEL_MAPPING = { 16 | "llama": FlaxLlamaForCausalLM, 17 | "mistral": FlaxMistralForCausalLM, 18 | "gemma": FlaxGemmaForCausalLM, 19 | } -------------------------------------------------------------------------------- /magix/models/bert_model.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, Tuple 2 | 3 | import flax 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from flax.core.frozen_dict import FrozenDict, freeze, unfreeze 9 | from flax.linen.attention import dot_product_attention_weights 10 | from flax.traverse_util import flatten_dict, unflatten_dict 11 | from jax import lax 12 | from jax.sharding import PartitionSpec as PS 13 | 14 | 15 | from transformers.modeling_flax_outputs import ( 16 | FlaxBaseModelOutputWithPastAndCrossAttentions, 17 | FlaxBaseModelOutputWithPoolingAndCrossAttentions, 18 | FlaxMaskedLMOutput, 19 | FlaxMultipleChoiceModelOutput, 20 | FlaxNextSentencePredictorOutput, 21 | FlaxQuestionAnsweringModelOutput, 22 | FlaxSequenceClassifierOutput, 23 | FlaxTokenClassifierOutput, 24 | ) 25 | from transformers.modeling_flax_utils import ( 26 | ACT2FN, 27 | FlaxPreTrainedModel 28 | ) 29 | from transformers.utils import ModelOutput, logging 30 | from transformers.models.bert import BertConfig 31 | 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | 36 | @flax.struct.dataclass 37 | class FlaxBertForPreTrainingOutput(ModelOutput): 38 | """ 39 | Output type of [`BertForPreTraining`]. 40 | 41 | Args: 42 | prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`): 43 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 
44 | seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`): 45 | Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation 46 | before SoftMax). 47 | hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 48 | Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape 49 | `(batch_size, sequence_length, hidden_size)`. 50 | 51 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 52 | attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 53 | Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 54 | sequence_length)`. 55 | 56 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 57 | heads. 58 | """ 59 | 60 | prediction_logits: jnp.ndarray = None 61 | seq_relationship_logits: jnp.ndarray = None 62 | hidden_states: Optional[Tuple[jnp.ndarray]] = None 63 | attentions: Optional[Tuple[jnp.ndarray]] = None 64 | 65 | 66 | class FlaxBertEmbeddings(nn.Module): 67 | """Construct the embeddings from word, position and token_type embeddings.""" 68 | 69 | config: BertConfig 70 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 71 | 72 | def setup(self): 73 | self.word_embeddings = nn.Embed( 74 | self.config.vocab_size, 75 | self.config.hidden_size, 76 | embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 77 | dtype=self.dtype, 78 | ) 79 | self.position_embeddings = nn.Embed( 80 | self.config.max_position_embeddings, 81 | self.config.hidden_size, 82 | embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 83 | dtype=self.dtype, 84 | ) 85 | self.token_type_embeddings = nn.Embed( 86 | self.config.type_vocab_size, 87 | self.config.hidden_size, 88 | embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 89 | dtype=self.dtype, 90 | ) 91 | self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) 92 | self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) 93 | 94 | def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): 95 | # Embed 96 | inputs_embeds = self.word_embeddings(input_ids.astype("i4")) 97 | position_embeds = self.position_embeddings(position_ids.astype("i4")) 98 | token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) 99 | 100 | # Sum all embeddings 101 | hidden_states = inputs_embeds + token_type_embeddings + position_embeds 102 | 103 | # Layer Norm 104 | hidden_states = self.LayerNorm(hidden_states) 105 | hidden_states = self.dropout(hidden_states, deterministic=deterministic) 106 | return hidden_states 107 | 108 | 109 | class FlaxBertSelfAttention(nn.Module): 110 | config: BertConfig 111 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 112 | 113 | def setup(self): 114 | self.head_dim = self.config.hidden_size // self.config.num_attention_heads 115 | if self.config.hidden_size % self.config.num_attention_heads != 0: 116 | raise ValueError( 117 | "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " 118 | " : {self.config.num_attention_heads}" 119 | ) 120 | 121 | self.query = nn.Dense( 122 | self.config.hidden_size, 123 | 
dtype=self.dtype, 124 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 125 | ) 126 | self.key = nn.Dense( 127 | self.config.hidden_size, 128 | dtype=self.dtype, 129 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 130 | ) 131 | self.value = nn.Dense( 132 | self.config.hidden_size, 133 | dtype=self.dtype, 134 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 135 | ) 136 | 137 | 138 | def _split_heads(self, hidden_states): 139 | return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) 140 | 141 | def _merge_heads(self, hidden_states): 142 | return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) 143 | 144 | def __call__( 145 | self, 146 | hidden_states, 147 | attention_mask, 148 | layer_head_mask, 149 | key_value_states: Optional[jnp.ndarray] = None, 150 | init_cache: bool = False, 151 | deterministic=True, 152 | output_attentions: bool = False, 153 | ): 154 | # if key_value_states are provided this layer is used as a cross-attention layer 155 | # for the decoder 156 | is_cross_attention = key_value_states is not None 157 | batch_size = hidden_states.shape[0] 158 | 159 | # get query proj 160 | query_states = self.query(hidden_states) 161 | # get key, value proj 162 | if is_cross_attention: 163 | # cross_attentions 164 | key_states = self.key(key_value_states) 165 | value_states = self.value(key_value_states) 166 | else: 167 | # self_attention 168 | key_states = self.key(hidden_states) 169 | value_states = self.value(hidden_states) 170 | 171 | query_states = self._split_heads(query_states) 172 | key_states = self._split_heads(key_states) 173 | value_states = self._split_heads(value_states) 174 | 175 | attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) 176 | 177 | # Convert the boolean attention mask to an attention bias. 
178 | if attention_mask is not None: 179 | # attention mask in the form of attention bias 180 | attention_bias = lax.select( 181 | attention_mask > 0, 182 | jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 183 | jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), 184 | ) 185 | else: 186 | attention_bias = None 187 | 188 | dropout_rng = None 189 | if not deterministic and self.config.attention_probs_dropout_prob > 0.0: 190 | dropout_rng = self.make_rng("dropout") 191 | 192 | attn_weights = dot_product_attention_weights( 193 | query_states, 194 | key_states, 195 | bias=attention_bias, 196 | dropout_rng=dropout_rng, 197 | dropout_rate=self.config.attention_probs_dropout_prob, 198 | broadcast_dropout=True, 199 | deterministic=deterministic, 200 | dtype=self.dtype, 201 | precision=None, 202 | ) 203 | 204 | # Mask heads if we want to 205 | if layer_head_mask is not None: 206 | attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) 207 | 208 | attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) 209 | attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) 210 | 211 | outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) 212 | return outputs 213 | 214 | 215 | class FlaxBertSelfOutput(nn.Module): 216 | config: BertConfig 217 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 218 | 219 | def setup(self): 220 | self.dense = nn.Dense( 221 | self.config.hidden_size, 222 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 223 | dtype=self.dtype, 224 | ) 225 | self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) 226 | self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) 227 | 228 | def __call__(self, hidden_states, input_tensor, deterministic: bool = True): 229 | hidden_states = self.dense(hidden_states) 230 | hidden_states = self.dropout(hidden_states, deterministic=deterministic) 231 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 232 | return hidden_states 233 | 234 | 235 | class FlaxBertAttention(nn.Module): 236 | config: BertConfig 237 | dtype: jnp.dtype = jnp.float32 238 | 239 | def setup(self): 240 | self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype) 241 | self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype) 242 | 243 | def __call__( 244 | self, 245 | hidden_states, 246 | attention_mask, 247 | layer_head_mask, 248 | key_value_states=None, 249 | init_cache=False, 250 | deterministic=True, 251 | output_attentions: bool = False, 252 | ): 253 | # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) 254 | # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable 255 | # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) 256 | attn_outputs = self.self( 257 | hidden_states, 258 | attention_mask, 259 | layer_head_mask=layer_head_mask, 260 | key_value_states=key_value_states, 261 | init_cache=init_cache, 262 | deterministic=deterministic, 263 | output_attentions=output_attentions, 264 | ) 265 | attn_output = attn_outputs[0] 266 | hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) 267 | 268 | outputs = (hidden_states,) 269 | 270 | if output_attentions: 271 | outputs += (attn_outputs[1],) 272 | 273 | return outputs 274 | 275 | 276 | class FlaxBertIntermediate(nn.Module): 277 | config: BertConfig 278 | dtype: jnp.dtype = jnp.float32 # the dtype of the 
computation 279 | 280 | def setup(self): 281 | self.dense = nn.Dense( 282 | self.config.intermediate_size, 283 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 284 | dtype=self.dtype, 285 | ) 286 | self.activation = ACT2FN[self.config.hidden_act] 287 | 288 | def __call__(self, hidden_states): 289 | hidden_states = self.dense(hidden_states) 290 | hidden_states = self.activation(hidden_states) 291 | return hidden_states 292 | 293 | 294 | class FlaxBertOutput(nn.Module): 295 | config: BertConfig 296 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 297 | 298 | def setup(self): 299 | self.dense = nn.Dense( 300 | self.config.hidden_size, 301 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 302 | dtype=self.dtype, 303 | ) 304 | self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) 305 | self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) 306 | 307 | def __call__(self, hidden_states, attention_output, deterministic: bool = True): 308 | hidden_states = self.dense(hidden_states) 309 | hidden_states = self.dropout(hidden_states, deterministic=deterministic) 310 | hidden_states = self.LayerNorm(hidden_states + attention_output) 311 | return hidden_states 312 | 313 | 314 | class FlaxBertLayer(nn.Module): 315 | config: BertConfig 316 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 317 | 318 | def setup(self): 319 | self.attention = FlaxBertAttention(self.config, dtype=self.dtype) 320 | self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype) 321 | self.output = FlaxBertOutput(self.config, dtype=self.dtype) 322 | 323 | def __call__( 324 | self, 325 | hidden_states, 326 | attention_mask, 327 | layer_head_mask, 328 | encoder_hidden_states: Optional[jnp.ndarray] = None, 329 | encoder_attention_mask: Optional[jnp.ndarray] = None, 330 | init_cache: bool = False, 331 | deterministic: bool = True, 332 | output_attentions: bool = False, 333 | ): 334 | # Self Attention 335 | attention_outputs = self.attention( 336 | hidden_states, 337 | attention_mask, 338 | layer_head_mask=layer_head_mask, 339 | init_cache=init_cache, 340 | deterministic=deterministic, 341 | output_attentions=output_attentions, 342 | ) 343 | attention_output = attention_outputs[0] 344 | 345 | # Cross-Attention Block 346 | if encoder_hidden_states is not None: 347 | cross_attention_outputs = self.crossattention( 348 | attention_output, 349 | attention_mask=encoder_attention_mask, 350 | layer_head_mask=layer_head_mask, 351 | key_value_states=encoder_hidden_states, 352 | deterministic=deterministic, 353 | output_attentions=output_attentions, 354 | ) 355 | attention_output = cross_attention_outputs[0] 356 | 357 | hidden_states = self.intermediate(attention_output) 358 | hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) 359 | 360 | outputs = (hidden_states,) 361 | 362 | if output_attentions: 363 | outputs += (attention_outputs[1],) 364 | if encoder_hidden_states is not None: 365 | outputs += (cross_attention_outputs[1],) 366 | return outputs 367 | 368 | 369 | class FlaxBertLayerCollection(nn.Module): 370 | config: BertConfig 371 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 372 | gradient_checkpointing: bool = False 373 | 374 | def setup(self): 375 | self.layers = [ 376 | FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) 377 | ] 378 | 379 | def __call__( 380 | self, 381 | hidden_states, 382 | 
attention_mask, 383 | head_mask, 384 | encoder_hidden_states: Optional[jnp.ndarray] = None, 385 | encoder_attention_mask: Optional[jnp.ndarray] = None, 386 | init_cache: bool = False, 387 | deterministic: bool = True, 388 | output_attentions: bool = False, 389 | output_hidden_states: bool = False, 390 | return_dict: bool = True, 391 | ): 392 | all_attentions = () if output_attentions else None 393 | all_hidden_states = () if output_hidden_states else None 394 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 395 | 396 | # Check if head_mask has a correct number of layers specified if desired 397 | if head_mask is not None: 398 | if head_mask.shape[0] != (len(self.layers)): 399 | raise ValueError( 400 | f"The head_mask should be specified for {len(self.layers)} layers, but it is for " 401 | f" {head_mask.shape[0]}." 402 | ) 403 | 404 | for i, layer in enumerate(self.layers): 405 | if output_hidden_states: 406 | all_hidden_states += (hidden_states,) 407 | 408 | layer_outputs = layer( 409 | hidden_states, 410 | attention_mask, 411 | head_mask[i] if head_mask is not None else None, 412 | encoder_hidden_states, 413 | encoder_attention_mask, 414 | init_cache, 415 | deterministic, 416 | output_attentions, 417 | ) 418 | 419 | hidden_states = layer_outputs[0] 420 | 421 | if output_attentions: 422 | all_attentions += (layer_outputs[1],) 423 | 424 | if encoder_hidden_states is not None: 425 | all_cross_attentions += (layer_outputs[2],) 426 | 427 | if output_hidden_states: 428 | all_hidden_states += (hidden_states,) 429 | 430 | outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) 431 | 432 | if not return_dict: 433 | return tuple(v for v in outputs if v is not None) 434 | 435 | return FlaxBaseModelOutputWithPastAndCrossAttentions( 436 | last_hidden_state=hidden_states, 437 | hidden_states=all_hidden_states, 438 | attentions=all_attentions, 439 | cross_attentions=all_cross_attentions, 440 | ) 441 | 442 | 443 | class FlaxBertEncoder(nn.Module): 444 | config: BertConfig 445 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 446 | gradient_checkpointing: bool = False 447 | 448 | def setup(self): 449 | self.layer = FlaxBertLayerCollection( 450 | self.config, 451 | dtype=self.dtype, 452 | gradient_checkpointing=self.gradient_checkpointing, 453 | ) 454 | 455 | def __call__( 456 | self, 457 | hidden_states, 458 | attention_mask, 459 | head_mask, 460 | encoder_hidden_states: Optional[jnp.ndarray] = None, 461 | encoder_attention_mask: Optional[jnp.ndarray] = None, 462 | init_cache: bool = False, 463 | deterministic: bool = True, 464 | output_attentions: bool = False, 465 | output_hidden_states: bool = False, 466 | return_dict: bool = True, 467 | ): 468 | return self.layer( 469 | hidden_states, 470 | attention_mask, 471 | head_mask=head_mask, 472 | encoder_hidden_states=encoder_hidden_states, 473 | encoder_attention_mask=encoder_attention_mask, 474 | init_cache=init_cache, 475 | deterministic=deterministic, 476 | output_attentions=output_attentions, 477 | output_hidden_states=output_hidden_states, 478 | return_dict=return_dict, 479 | ) 480 | 481 | 482 | class FlaxBertPooler(nn.Module): 483 | config: BertConfig 484 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 485 | 486 | def setup(self): 487 | self.dense = nn.Dense( 488 | self.config.hidden_size, 489 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 490 | dtype=self.dtype, 491 | ) 492 | 493 | def __call__(self, hidden_states): 494 
| cls_hidden_state = hidden_states[:, 0] 495 | cls_hidden_state = self.dense(cls_hidden_state) 496 | return nn.tanh(cls_hidden_state) 497 | 498 | 499 | class FlaxBertPredictionHeadTransform(nn.Module): 500 | config: BertConfig 501 | dtype: jnp.dtype = jnp.float32 502 | 503 | def setup(self): 504 | self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype) 505 | self.activation = ACT2FN[self.config.hidden_act] 506 | self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) 507 | 508 | def __call__(self, hidden_states): 509 | hidden_states = self.dense(hidden_states) 510 | hidden_states = self.activation(hidden_states) 511 | return self.LayerNorm(hidden_states) 512 | 513 | 514 | class FlaxBertLMPredictionHead(nn.Module): 515 | config: BertConfig 516 | dtype: jnp.dtype = jnp.float32 517 | bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros 518 | 519 | def setup(self): 520 | self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype) 521 | self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False) 522 | self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) 523 | 524 | def __call__(self, hidden_states, shared_embedding=None): 525 | hidden_states = self.transform(hidden_states) 526 | 527 | if shared_embedding is not None: 528 | hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) 529 | else: 530 | hidden_states = self.decoder(hidden_states) 531 | 532 | bias = jnp.asarray(self.bias, self.dtype) 533 | hidden_states += bias 534 | return hidden_states 535 | 536 | 537 | class FlaxBertOnlyMLMHead(nn.Module): 538 | config: BertConfig 539 | dtype: jnp.dtype = jnp.float32 540 | 541 | def setup(self): 542 | self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) 543 | 544 | def __call__(self, hidden_states, shared_embedding=None): 545 | hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding) 546 | return hidden_states 547 | 548 | 549 | class FlaxBertOnlyNSPHead(nn.Module): 550 | dtype: jnp.dtype = jnp.float32 551 | 552 | def setup(self): 553 | self.seq_relationship = nn.Dense(2, dtype=self.dtype) 554 | 555 | def __call__(self, pooled_output): 556 | return self.seq_relationship(pooled_output) 557 | 558 | 559 | class FlaxBertPreTrainingHeads(nn.Module): 560 | config: BertConfig 561 | dtype: jnp.dtype = jnp.float32 562 | 563 | def setup(self): 564 | self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype) 565 | self.seq_relationship = nn.Dense(2, dtype=self.dtype) 566 | 567 | def __call__(self, hidden_states, pooled_output, shared_embedding=None): 568 | prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding) 569 | seq_relationship_score = self.seq_relationship(pooled_output) 570 | return prediction_scores, seq_relationship_score 571 | 572 | 573 | class FlaxBertPreTrainedModel(FlaxPreTrainedModel): 574 | """ 575 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 576 | models. 
577 | """ 578 | 579 | config_class = BertConfig 580 | base_model_prefix = "bert" 581 | module_class: nn.Module = None 582 | 583 | partition_rules = { 584 | 'word_embeddings': PS('data', 'model'), 585 | '(query|key|value)/kernel': PS('data', 'model'), 586 | 'intermediate/dense/kernel': PS('data', 'model'), 587 | 'output/dense/kernel': PS('model', 'data'), 588 | 'predictions/transform/dense/kernel': PS('data', 'model'), 589 | 'predictions/decoder/kernel': PS('model', 'data'), 590 | } 591 | 592 | def __init__( 593 | self, 594 | config: BertConfig, 595 | input_shape: Tuple = (1, 1), 596 | seed: int = 0, 597 | dtype: jnp.dtype = jnp.float32, 598 | _do_init: bool = True, 599 | gradient_checkpointing: bool = False, 600 | **kwargs, 601 | ): 602 | module = self.module_class( 603 | config=config, 604 | dtype=dtype, 605 | gradient_checkpointing=gradient_checkpointing, 606 | **kwargs, 607 | ) 608 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 609 | 610 | def enable_gradient_checkpointing(self): 611 | self._module = self.module_class( 612 | config=self.config, 613 | dtype=self.dtype, 614 | gradient_checkpointing=True, 615 | ) 616 | 617 | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: 618 | # init input tensors 619 | input_ids = jnp.zeros(input_shape, dtype="i4") 620 | token_type_ids = jnp.zeros_like(input_ids) 621 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) 622 | attention_mask = jnp.ones_like(input_ids) 623 | head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) 624 | 625 | params_rng, dropout_rng = jax.random.split(rng) 626 | rngs = {"params": params_rng, "dropout": dropout_rng} 627 | 628 | if self.config.add_cross_attention: 629 | encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) 630 | encoder_attention_mask = attention_mask 631 | module_init_outputs = self.module.init( 632 | rngs, 633 | input_ids, 634 | attention_mask, 635 | token_type_ids, 636 | position_ids, 637 | head_mask, 638 | encoder_hidden_states, 639 | encoder_attention_mask, 640 | return_dict=False, 641 | ) 642 | else: 643 | module_init_outputs = self.module.init( 644 | rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False 645 | ) 646 | 647 | random_params = module_init_outputs["params"] 648 | 649 | if params is not None: 650 | random_params = flatten_dict(unfreeze(random_params)) 651 | params = flatten_dict(unfreeze(params)) 652 | for missing_key in self._missing_keys: 653 | params[missing_key] = random_params[missing_key] 654 | self._missing_keys = set() 655 | return freeze(unflatten_dict(params)) 656 | else: 657 | return random_params 658 | 659 | 660 | def __call__( 661 | self, 662 | input_ids, 663 | attention_mask=None, 664 | token_type_ids=None, 665 | position_ids=None, 666 | head_mask=None, 667 | encoder_hidden_states=None, 668 | encoder_attention_mask=None, 669 | params: dict = None, 670 | dropout_rng: jax.random.PRNGKey = None, 671 | train: bool = False, 672 | output_attentions: Optional[bool] = None, 673 | output_hidden_states: Optional[bool] = None, 674 | return_dict: Optional[bool] = None, 675 | past_key_values: dict = None, 676 | ): 677 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 678 | output_hidden_states = ( 679 | output_hidden_states if output_hidden_states is not None else 
self.config.output_hidden_states 680 | ) 681 | return_dict = return_dict if return_dict is not None else self.config.return_dict 682 | 683 | # init input tensors if not passed 684 | if token_type_ids is None: 685 | token_type_ids = jnp.zeros_like(input_ids) 686 | 687 | if position_ids is None: 688 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) 689 | 690 | if attention_mask is None: 691 | attention_mask = jnp.ones_like(input_ids) 692 | 693 | if head_mask is None: 694 | head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) 695 | 696 | # Handle any PRNG if needed 697 | rngs = {} 698 | if dropout_rng is not None: 699 | rngs["dropout"] = dropout_rng 700 | 701 | inputs = {"params": params or self.params} 702 | 703 | if self.config.add_cross_attention: 704 | # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed 705 | # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be 706 | # changed by FlaxBertAttention module 707 | if past_key_values: 708 | inputs["cache"] = past_key_values 709 | mutable = ["cache"] 710 | else: 711 | mutable = False 712 | 713 | outputs = self.module.apply( 714 | inputs, 715 | jnp.array(input_ids, dtype="i4"), 716 | jnp.array(attention_mask, dtype="i4"), 717 | token_type_ids=jnp.array(token_type_ids, dtype="i4"), 718 | position_ids=jnp.array(position_ids, dtype="i4"), 719 | head_mask=jnp.array(head_mask, dtype="i4"), 720 | encoder_hidden_states=encoder_hidden_states, 721 | encoder_attention_mask=encoder_attention_mask, 722 | deterministic=not train, 723 | output_attentions=output_attentions, 724 | output_hidden_states=output_hidden_states, 725 | return_dict=return_dict, 726 | rngs=rngs, 727 | mutable=mutable, 728 | ) 729 | 730 | # add updated cache to model output 731 | if past_key_values is not None and return_dict: 732 | outputs, past_key_values = outputs 733 | outputs["past_key_values"] = unfreeze(past_key_values["cache"]) 734 | return outputs 735 | elif past_key_values is not None and not return_dict: 736 | outputs, past_key_values = outputs 737 | outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] 738 | 739 | else: 740 | outputs = self.module.apply( 741 | inputs, 742 | jnp.array(input_ids, dtype="i4"), 743 | jnp.array(attention_mask, dtype="i4"), 744 | token_type_ids=jnp.array(token_type_ids, dtype="i4"), 745 | position_ids=jnp.array(position_ids, dtype="i4"), 746 | head_mask=jnp.array(head_mask, dtype="i4"), 747 | deterministic=not train, 748 | output_attentions=output_attentions, 749 | output_hidden_states=output_hidden_states, 750 | return_dict=return_dict, 751 | rngs=rngs, 752 | ) 753 | 754 | return outputs 755 | 756 | 757 | class FlaxBertModule(nn.Module): 758 | config: BertConfig 759 | dtype: jnp.dtype = jnp.float32 # the dtype of the computation 760 | add_pooling_layer: bool = True 761 | gradient_checkpointing: bool = False 762 | 763 | def setup(self): 764 | self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) 765 | self.encoder = FlaxBertEncoder( 766 | self.config, 767 | dtype=self.dtype, 768 | gradient_checkpointing=self.gradient_checkpointing, 769 | ) 770 | self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) 771 | 772 | def __call__( 773 | self, 774 | input_ids, 775 | attention_mask, 776 | token_type_ids: Optional[jnp.ndarray] = None, 777 | position_ids: Optional[jnp.ndarray] = None, 778 | head_mask: Optional[jnp.ndarray] = 
None, 779 | encoder_hidden_states: Optional[jnp.ndarray] = None, 780 | encoder_attention_mask: Optional[jnp.ndarray] = None, 781 | init_cache: bool = False, 782 | deterministic: bool = True, 783 | output_attentions: bool = False, 784 | output_hidden_states: bool = False, 785 | return_dict: bool = True, 786 | ): 787 | # make sure `token_type_ids` is correctly initialized when not passed 788 | if token_type_ids is None: 789 | token_type_ids = jnp.zeros_like(input_ids) 790 | 791 | # make sure `position_ids` is correctly initialized when not passed 792 | if position_ids is None: 793 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) 794 | 795 | hidden_states = self.embeddings( 796 | input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic 797 | ) 798 | outputs = self.encoder( 799 | hidden_states, 800 | attention_mask, 801 | head_mask=head_mask, 802 | deterministic=deterministic, 803 | encoder_hidden_states=encoder_hidden_states, 804 | encoder_attention_mask=encoder_attention_mask, 805 | init_cache=init_cache, 806 | output_attentions=output_attentions, 807 | output_hidden_states=output_hidden_states, 808 | return_dict=return_dict, 809 | ) 810 | hidden_states = outputs[0] 811 | pooled = self.pooler(hidden_states) if self.add_pooling_layer else None 812 | 813 | if not return_dict: 814 | # if pooled is None, don't return it 815 | if pooled is None: 816 | return (hidden_states,) + outputs[1:] 817 | return (hidden_states, pooled) + outputs[1:] 818 | 819 | return FlaxBaseModelOutputWithPoolingAndCrossAttentions( 820 | last_hidden_state=hidden_states, 821 | pooler_output=pooled, 822 | hidden_states=outputs.hidden_states, 823 | attentions=outputs.attentions, 824 | cross_attentions=outputs.cross_attentions, 825 | ) 826 | 827 | 828 | 829 | class FlaxBertModel(FlaxBertPreTrainedModel): 830 | module_class = FlaxBertModule 831 | 832 | 833 | class FlaxBertForPreTrainingModule(nn.Module): 834 | config: BertConfig 835 | dtype: jnp.dtype = jnp.float32 836 | gradient_checkpointing: bool = False 837 | 838 | def setup(self): 839 | self.bert = FlaxBertModule( 840 | config=self.config, 841 | dtype=self.dtype, 842 | gradient_checkpointing=self.gradient_checkpointing, 843 | ) 844 | self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) 845 | 846 | def __call__( 847 | self, 848 | input_ids, 849 | attention_mask, 850 | token_type_ids, 851 | position_ids, 852 | head_mask, 853 | deterministic: bool = True, 854 | output_attentions: bool = False, 855 | output_hidden_states: bool = False, 856 | return_dict: bool = True, 857 | ): 858 | # Model 859 | outputs = self.bert( 860 | input_ids, 861 | attention_mask, 862 | token_type_ids, 863 | position_ids, 864 | head_mask, 865 | deterministic=deterministic, 866 | output_attentions=output_attentions, 867 | output_hidden_states=output_hidden_states, 868 | return_dict=return_dict, 869 | ) 870 | 871 | if self.config.tie_word_embeddings: 872 | shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] 873 | else: 874 | shared_embedding = None 875 | 876 | hidden_states = outputs[0] 877 | pooled_output = outputs[1] 878 | 879 | prediction_scores, seq_relationship_score = self.cls( 880 | hidden_states, pooled_output, shared_embedding=shared_embedding 881 | ) 882 | 883 | if not return_dict: 884 | return (prediction_scores, seq_relationship_score) + outputs[2:] 885 | 886 | return FlaxBertForPreTrainingOutput( 887 | prediction_logits=prediction_scores, 888 | 
seq_relationship_logits=seq_relationship_score, 889 | hidden_states=outputs.hidden_states, 890 | attentions=outputs.attentions, 891 | ) 892 | 893 | 894 | class FlaxBertForPreTraining(FlaxBertPreTrainedModel): 895 | module_class = FlaxBertForPreTrainingModule 896 | 897 | 898 | class FlaxBertForMaskedLMModule(nn.Module): 899 | config: BertConfig 900 | dtype: jnp.dtype = jnp.float32 901 | gradient_checkpointing: bool = False 902 | 903 | def setup(self): 904 | self.bert = FlaxBertModule( 905 | config=self.config, 906 | add_pooling_layer=False, 907 | dtype=self.dtype, 908 | gradient_checkpointing=self.gradient_checkpointing, 909 | ) 910 | self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) 911 | 912 | def __call__( 913 | self, 914 | input_ids, 915 | attention_mask, 916 | token_type_ids, 917 | position_ids, 918 | head_mask, 919 | deterministic: bool = True, 920 | output_attentions: bool = False, 921 | output_hidden_states: bool = False, 922 | return_dict: bool = True, 923 | ): 924 | # Model 925 | outputs = self.bert( 926 | input_ids, 927 | attention_mask, 928 | token_type_ids, 929 | position_ids, 930 | head_mask, 931 | deterministic=deterministic, 932 | output_attentions=output_attentions, 933 | output_hidden_states=output_hidden_states, 934 | return_dict=return_dict, 935 | ) 936 | 937 | hidden_states = outputs[0] 938 | if self.config.tie_word_embeddings: 939 | shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"] 940 | else: 941 | shared_embedding = None 942 | 943 | # Compute the prediction scores 944 | logits = self.cls(hidden_states, shared_embedding=shared_embedding) 945 | 946 | if not return_dict: 947 | return (logits,) + outputs[1:] 948 | 949 | return FlaxMaskedLMOutput( 950 | logits=logits, 951 | hidden_states=outputs.hidden_states, 952 | attentions=outputs.attentions, 953 | ) 954 | 955 | 956 | class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): 957 | module_class = FlaxBertForMaskedLMModule 958 | 959 | 960 | class FlaxBertForNextSentencePredictionModule(nn.Module): 961 | config: BertConfig 962 | dtype: jnp.dtype = jnp.float32 963 | gradient_checkpointing: bool = False 964 | 965 | def setup(self): 966 | self.bert = FlaxBertModule( 967 | config=self.config, 968 | dtype=self.dtype, 969 | gradient_checkpointing=self.gradient_checkpointing, 970 | ) 971 | self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) 972 | 973 | def __call__( 974 | self, 975 | input_ids, 976 | attention_mask, 977 | token_type_ids, 978 | position_ids, 979 | head_mask, 980 | deterministic: bool = True, 981 | output_attentions: bool = False, 982 | output_hidden_states: bool = False, 983 | return_dict: bool = True, 984 | ): 985 | return_dict = return_dict if return_dict is not None else self.config.return_dict 986 | 987 | # Model 988 | outputs = self.bert( 989 | input_ids, 990 | attention_mask, 991 | token_type_ids, 992 | position_ids, 993 | head_mask, 994 | deterministic=deterministic, 995 | output_attentions=output_attentions, 996 | output_hidden_states=output_hidden_states, 997 | return_dict=return_dict, 998 | ) 999 | 1000 | pooled_output = outputs[1] 1001 | seq_relationship_scores = self.cls(pooled_output) 1002 | 1003 | if not return_dict: 1004 | return (seq_relationship_scores,) + outputs[2:] 1005 | 1006 | return FlaxNextSentencePredictorOutput( 1007 | logits=seq_relationship_scores, 1008 | hidden_states=outputs.hidden_states, 1009 | attentions=outputs.attentions, 1010 | ) 1011 | 1012 | 1013 | class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel): 1014 | 
module_class = FlaxBertForNextSentencePredictionModule 1015 | 1016 | class FlaxBertForSequenceClassificationModule(nn.Module): 1017 | config: BertConfig 1018 | dtype: jnp.dtype = jnp.float32 1019 | gradient_checkpointing: bool = False 1020 | 1021 | def setup(self): 1022 | self.bert = FlaxBertModule( 1023 | config=self.config, 1024 | dtype=self.dtype, 1025 | gradient_checkpointing=self.gradient_checkpointing, 1026 | ) 1027 | classifier_dropout = ( 1028 | self.config.classifier_dropout 1029 | if self.config.classifier_dropout is not None 1030 | else self.config.hidden_dropout_prob 1031 | ) 1032 | self.dropout = nn.Dropout(rate=classifier_dropout) 1033 | self.classifier = nn.Dense( 1034 | self.config.num_labels, 1035 | dtype=self.dtype, 1036 | ) 1037 | 1038 | def __call__( 1039 | self, 1040 | input_ids, 1041 | attention_mask, 1042 | token_type_ids, 1043 | position_ids, 1044 | head_mask, 1045 | deterministic: bool = True, 1046 | output_attentions: bool = False, 1047 | output_hidden_states: bool = False, 1048 | return_dict: bool = True, 1049 | ): 1050 | # Model 1051 | outputs = self.bert( 1052 | input_ids, 1053 | attention_mask, 1054 | token_type_ids, 1055 | position_ids, 1056 | head_mask, 1057 | deterministic=deterministic, 1058 | output_attentions=output_attentions, 1059 | output_hidden_states=output_hidden_states, 1060 | return_dict=return_dict, 1061 | ) 1062 | 1063 | pooled_output = outputs[1] 1064 | pooled_output = self.dropout(pooled_output, deterministic=deterministic) 1065 | logits = self.classifier(pooled_output) 1066 | 1067 | if not return_dict: 1068 | return (logits,) + outputs[2:] 1069 | 1070 | return FlaxSequenceClassifierOutput( 1071 | logits=logits, 1072 | hidden_states=outputs.hidden_states, 1073 | attentions=outputs.attentions, 1074 | ) 1075 | 1076 | 1077 | class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel): 1078 | module_class = FlaxBertForSequenceClassificationModule 1079 | 1080 | 1081 | class FlaxBertForMultipleChoiceModule(nn.Module): 1082 | config: BertConfig 1083 | dtype: jnp.dtype = jnp.float32 1084 | gradient_checkpointing: bool = False 1085 | 1086 | def setup(self): 1087 | self.bert = FlaxBertModule( 1088 | config=self.config, 1089 | dtype=self.dtype, 1090 | gradient_checkpointing=self.gradient_checkpointing, 1091 | ) 1092 | self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) 1093 | self.classifier = nn.Dense(1, dtype=self.dtype) 1094 | 1095 | def __call__( 1096 | self, 1097 | input_ids, 1098 | attention_mask, 1099 | token_type_ids, 1100 | position_ids, 1101 | head_mask, 1102 | deterministic: bool = True, 1103 | output_attentions: bool = False, 1104 | output_hidden_states: bool = False, 1105 | return_dict: bool = True, 1106 | ): 1107 | num_choices = input_ids.shape[1] 1108 | input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None 1109 | attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None 1110 | token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None 1111 | position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None 1112 | 1113 | # Model 1114 | outputs = self.bert( 1115 | input_ids, 1116 | attention_mask, 1117 | token_type_ids, 1118 | position_ids, 1119 | head_mask, 1120 | deterministic=deterministic, 1121 | output_attentions=output_attentions, 1122 | output_hidden_states=output_hidden_states, 1123 | return_dict=return_dict, 1124 | ) 1125 | 1126 | 
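# NOTE: the choice inputs were flattened to (batch_size * num_choices, seq_len) above,
# so the single per-example logit produced below is reshaped back to
# (batch_size, num_choices) before being returned.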
pooled_output = outputs[1] 1127 | pooled_output = self.dropout(pooled_output, deterministic=deterministic) 1128 | logits = self.classifier(pooled_output) 1129 | 1130 | reshaped_logits = logits.reshape(-1, num_choices) 1131 | 1132 | if not return_dict: 1133 | return (reshaped_logits,) + outputs[2:] 1134 | 1135 | return FlaxMultipleChoiceModelOutput( 1136 | logits=reshaped_logits, 1137 | hidden_states=outputs.hidden_states, 1138 | attentions=outputs.attentions, 1139 | ) 1140 | 1141 | 1142 | class FlaxBertForTokenClassificationModule(nn.Module): 1143 | config: BertConfig 1144 | dtype: jnp.dtype = jnp.float32 1145 | gradient_checkpointing: bool = False 1146 | 1147 | def setup(self): 1148 | self.bert = FlaxBertModule( 1149 | config=self.config, 1150 | dtype=self.dtype, 1151 | add_pooling_layer=False, 1152 | gradient_checkpointing=self.gradient_checkpointing, 1153 | ) 1154 | classifier_dropout = ( 1155 | self.config.classifier_dropout 1156 | if self.config.classifier_dropout is not None 1157 | else self.config.hidden_dropout_prob 1158 | ) 1159 | self.dropout = nn.Dropout(rate=classifier_dropout) 1160 | self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) 1161 | 1162 | def __call__( 1163 | self, 1164 | input_ids, 1165 | attention_mask, 1166 | token_type_ids, 1167 | position_ids, 1168 | head_mask, 1169 | deterministic: bool = True, 1170 | output_attentions: bool = False, 1171 | output_hidden_states: bool = False, 1172 | return_dict: bool = True, 1173 | ): 1174 | # Model 1175 | outputs = self.bert( 1176 | input_ids, 1177 | attention_mask, 1178 | token_type_ids, 1179 | position_ids, 1180 | head_mask, 1181 | deterministic=deterministic, 1182 | output_attentions=output_attentions, 1183 | output_hidden_states=output_hidden_states, 1184 | return_dict=return_dict, 1185 | ) 1186 | 1187 | hidden_states = outputs[0] 1188 | hidden_states = self.dropout(hidden_states, deterministic=deterministic) 1189 | logits = self.classifier(hidden_states) 1190 | 1191 | if not return_dict: 1192 | return (logits,) + outputs[1:] 1193 | 1194 | return FlaxTokenClassifierOutput( 1195 | logits=logits, 1196 | hidden_states=outputs.hidden_states, 1197 | attentions=outputs.attentions, 1198 | ) 1199 | 1200 | 1201 | class FlaxBertForQuestionAnsweringModule(nn.Module): 1202 | config: BertConfig 1203 | dtype: jnp.dtype = jnp.float32 1204 | gradient_checkpointing: bool = False 1205 | 1206 | def setup(self): 1207 | self.bert = FlaxBertModule( 1208 | config=self.config, 1209 | dtype=self.dtype, 1210 | add_pooling_layer=False, 1211 | gradient_checkpointing=self.gradient_checkpointing, 1212 | ) 1213 | self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) 1214 | 1215 | def __call__( 1216 | self, 1217 | input_ids, 1218 | attention_mask, 1219 | token_type_ids, 1220 | position_ids, 1221 | head_mask, 1222 | deterministic: bool = True, 1223 | output_attentions: bool = False, 1224 | output_hidden_states: bool = False, 1225 | return_dict: bool = True, 1226 | ): 1227 | # Model 1228 | outputs = self.bert( 1229 | input_ids, 1230 | attention_mask, 1231 | token_type_ids, 1232 | position_ids, 1233 | head_mask, 1234 | deterministic=deterministic, 1235 | output_attentions=output_attentions, 1236 | output_hidden_states=output_hidden_states, 1237 | return_dict=return_dict, 1238 | ) 1239 | 1240 | hidden_states = outputs[0] 1241 | 1242 | logits = self.qa_outputs(hidden_states) 1243 | start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) 1244 | start_logits = start_logits.squeeze(-1) 1245 | end_logits = 
end_logits.squeeze(-1) 1246 | 1247 | if not return_dict: 1248 | return (start_logits, end_logits) + outputs[1:] 1249 | 1250 | return FlaxQuestionAnsweringModelOutput( 1251 | start_logits=start_logits, 1252 | end_logits=end_logits, 1253 | hidden_states=outputs.hidden_states, 1254 | attentions=outputs.attentions, 1255 | ) 1256 | 1257 | class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel): 1258 | module_class = FlaxBertForQuestionAnsweringModule -------------------------------------------------------------------------------- /magix/models/gemma_model.py: -------------------------------------------------------------------------------- 1 | """Flax Gemma model.""" 2 | from typing import Optional, Tuple 3 | 4 | import math 5 | import flax.linen as nn 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from flax.core.frozen_dict import FrozenDict, freeze, unfreeze 10 | from flax.linen import combine_masks, make_causal_mask 11 | from flax.linen.attention import dot_product_attention_weights 12 | from flax.traverse_util import flatten_dict, unflatten_dict 13 | from jax import lax 14 | from jax.sharding import PartitionSpec as PS 15 | 16 | from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput 17 | from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel 18 | from transformers.utils import logging 19 | from transformers.models.gemma import GemmaConfig 20 | 21 | try: 22 | from transformer_engine.jax import fused_attn as te_attn 23 | _te_available = True 24 | except ImportError: 25 | _te_available = False 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | def create_sinusoidal_positions(num_pos, dim): 30 | inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2)[: (dim // 2)] / dim)) 31 | freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") 32 | 33 | emb = np.concatenate((freqs, freqs), axis=-1) 34 | out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) 35 | return jnp.array(out[:, :, :num_pos]) 36 | 37 | 38 | # Copied from transformers.models.llama.modeling_flax_llama.rotate_half 39 | def rotate_half(tensor): 40 | """Rotates half the hidden dims of the input.""" 41 | rotate_half_tensor = jnp.concatenate( 42 | (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1 43 | ) 44 | return rotate_half_tensor 45 | 46 | 47 | # Copied from transformers.models.llama.modeling_flax_llama.apply_rotary_pos_emb 48 | def apply_rotary_pos_emb(tensor, sin_pos, cos_pos): 49 | return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos) 50 | 51 | 52 | class FlaxGemmaRMSNorm(nn.Module): 53 | config: GemmaConfig 54 | dtype: jnp.dtype = jnp.float32 55 | 56 | def setup(self): 57 | self.epsilon = self.config.rms_norm_eps 58 | self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size) 59 | 60 | def __call__(self, hidden_states): 61 | variance = jnp.asarray(hidden_states, dtype=jnp.float32) 62 | variance = jnp.power(variance, 2) 63 | variance = variance.mean(-1, keepdims=True) 64 | # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt` 65 | hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon) 66 | 67 | return (1 + self.weight) * jnp.asarray(hidden_states, dtype=self.dtype) 68 | 69 | 70 | # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaRotaryEmbedding with Llama->Gemma 71 | class FlaxGemmaRotaryEmbedding(nn.Module): 72 | config: GemmaConfig 73 | dtype: jnp.dtype = jnp.float32 74 | 75 | # 
Ignore copy 76 | def setup(self): 77 | head_dim = self.config.head_dim 78 | self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim) 79 | 80 | def __call__(self, key, query, position_ids): 81 | sincos = self.sincos[position_ids] 82 | sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1) 83 | 84 | key = apply_rotary_pos_emb(key, sin_pos, cos_pos) 85 | query = apply_rotary_pos_emb(query, sin_pos, cos_pos) 86 | 87 | key = jnp.asarray(key, dtype=self.dtype) 88 | query = jnp.asarray(query, dtype=self.dtype) 89 | 90 | return key, query 91 | 92 | 93 | class FlaxGemmaAttention(nn.Module): 94 | config: GemmaConfig 95 | dtype: jnp.dtype = jnp.float32 96 | causal: bool = True 97 | fused_attention: bool = True 98 | 99 | def setup(self): 100 | config = self.config 101 | self.embed_dim = config.hidden_size 102 | self.num_heads = config.num_attention_heads 103 | self.head_dim = config.head_dim 104 | self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 105 | 106 | self.num_key_value_heads = config.num_key_value_heads 107 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 108 | 109 | kernel = jax.nn.initializers.normal(self.config.initializer_range) 110 | self.q_proj = nn.Dense( 111 | self.num_heads * self.head_dim, 112 | use_bias=config.attention_bias, 113 | dtype=jnp.bfloat16, 114 | kernel_init=kernel 115 | ) 116 | self.k_proj = nn.Dense( 117 | self.num_key_value_heads * self.head_dim, 118 | use_bias=config.attention_bias, 119 | dtype=jnp.bfloat16, 120 | kernel_init=kernel, 121 | ) 122 | self.v_proj = nn.Dense( 123 | self.num_key_value_heads * self.head_dim, 124 | use_bias=config.attention_bias, 125 | dtype=jnp.bfloat16, 126 | kernel_init=kernel, 127 | ) 128 | self.o_proj = nn.Dense( 129 | self.embed_dim, 130 | use_bias=config.attention_bias, 131 | dtype=jnp.bfloat16, 132 | kernel_init=kernel 133 | ) 134 | 135 | self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") 136 | self.rotary_emb = FlaxGemmaRotaryEmbedding(config, dtype=self.dtype) 137 | 138 | def _split_heads(self, hidden_states, num_heads): 139 | return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) 140 | 141 | def _merge_heads(self, hidden_states): 142 | return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads * self.head_dim,)) 143 | 144 | @nn.compact 145 | # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache 146 | def _concatenate_to_cache(self, key, value, query, attention_mask): 147 | """ 148 | This function takes projected key, value states from a single input token and concatenates the states to cached 149 | states from previous steps. This function is slighly adapted from the official Flax repository: 150 | https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 151 | """ 152 | # detect if we're initializing by absence of existing cache data. 
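# The cache buffers below are preallocated to the full
# (batch, max_length, num_heads, head_dim) shape; `cache_index` tracks how many
# positions have been written so far, and each decoding step writes its new
# key/value slice in place with `lax.dynamic_update_slice`. The padding mask
# built afterwards keeps queries from attending to the still-empty tail of the cache.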
153 | is_initialized = self.has_variable("cache", "cached_key") 154 | cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) 155 | cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) 156 | cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) 157 | 158 | if is_initialized: 159 | *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape 160 | # update key, value caches with our new 1d spatial slices 161 | cur_index = cache_index.value 162 | indices = (0,) * len(batch_dims) + (cur_index, 0, 0) 163 | key = lax.dynamic_update_slice(cached_key.value, key, indices) 164 | value = lax.dynamic_update_slice(cached_value.value, value, indices) 165 | cached_key.value = key 166 | cached_value.value = value 167 | num_updated_cache_vectors = query.shape[1] 168 | cache_index.value = cache_index.value + num_updated_cache_vectors 169 | # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 170 | pad_mask = jnp.broadcast_to( 171 | jnp.arange(max_length) < cur_index + num_updated_cache_vectors, 172 | tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), 173 | ) 174 | attention_mask = combine_masks(pad_mask, attention_mask) 175 | return key, value, attention_mask 176 | 177 | def __call__( 178 | self, 179 | hidden_states, 180 | attention_mask, 181 | position_ids, 182 | deterministic: bool = True, 183 | init_cache: bool = False, 184 | output_attentions: bool = False, 185 | ): 186 | query = self.q_proj(hidden_states) 187 | key = self.k_proj(hidden_states) 188 | value = self.v_proj(hidden_states) 189 | 190 | query = lax.with_sharding_constraint(query, PS('data', None, 'model')) 191 | key = lax.with_sharding_constraint(key, PS('data', None, 'model')) 192 | value = lax.with_sharding_constraint(value, PS('data', None, 'model')) 193 | 194 | query = self._split_heads(query, self.num_heads) 195 | key = self._split_heads(key, self.num_key_value_heads) 196 | value = self._split_heads(value, self.num_key_value_heads) 197 | 198 | key, query = self.rotary_emb(key, query, position_ids) 199 | 200 | query_length, key_length = query.shape[1], key.shape[1] 201 | 202 | if self.has_variable("cache", "cached_key"): 203 | mask_shift = self.variables["cache"]["cache_index"] 204 | max_decoder_length = self.variables["cache"]["cached_key"].shape[1] 205 | causal_mask = lax.dynamic_slice( 206 | self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) 207 | ) 208 | else: 209 | causal_mask = self.causal_mask[:, :, :query_length, :key_length] 210 | 211 | batch_size = hidden_states.shape[0] 212 | causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) 213 | 214 | attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) 215 | attention_mask = combine_masks(attention_mask, causal_mask) 216 | 217 | dropout_rng = None 218 | if not deterministic and self.config.attention_dropout > 0.0: 219 | dropout_rng = self.make_rng("dropout") 220 | 221 | # During fast autoregressive decoding, we feed one position at a time, 222 | # and cache the keys and values step by step. 
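# The Transformer Engine fused kernel is only taken for prompt-style forward
# passes (query length >= 32, no decoding cache); cached decoding falls back to
# the plain dot-product path, where keys/values are repeated across the
# key/value groups to match the number of query heads.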
223 | if self.has_variable("cache", "cached_key") or init_cache: 224 | key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) 225 | 226 | use_fused_attention = ( 227 | self.fused_attention 228 | and _te_available 229 | and query_length >= 32 230 | and not init_cache 231 | and not self.has_variable("cache", "cached_key") 232 | ) 233 | 234 | if not use_fused_attention: 235 | attention_bias = lax.select( 236 | attention_mask > 0, 237 | jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 238 | jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), 239 | ) 240 | 241 | key = jnp.repeat(key, repeats=self.num_key_value_groups, axis=2) 242 | value = jnp.repeat(value, repeats=self.num_key_value_groups, axis=2) 243 | 244 | # usual dot product attention 245 | attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype 246 | attn_weights = dot_product_attention_weights( 247 | query, 248 | key, 249 | bias=attention_bias, 250 | dropout_rng=dropout_rng, 251 | dropout_rate=self.config.attention_dropout, 252 | deterministic=deterministic, 253 | dtype=attention_dtype, 254 | ) 255 | 256 | if self.attention_softmax_in_fp32: 257 | attn_weights = attn_weights.astype(self.dtype) 258 | 259 | attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) 260 | else: 261 | query, key, value = map(lambda x: x.astype(jnp.bfloat16), (query, key, value)) 262 | kv = jnp.stack([key, value], axis=2) 263 | 264 | attn_output = te_attn.cross_fused_attn( 265 | query, 266 | kv, 267 | None, 268 | ~attention_mask, 269 | None, 270 | attn_bias_type=te_attn.AttnBiasType.NO_BIAS, 271 | attn_mask_type=te_attn.AttnMaskType.PADDING_CAUSAL_MASK, 272 | scaling_factor=1.0 / math.sqrt(query.shape[-1]), 273 | dropout_probability=self.config.attention_dropout, 274 | is_training=not deterministic, 275 | ) 276 | 277 | 278 | attn_output = self._merge_heads(attn_output) 279 | attn_output = self.o_proj(attn_output) 280 | 281 | outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) 282 | return outputs 283 | 284 | 285 | # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaMLP with Llama->Gemma 286 | class FlaxGemmaMLP(nn.Module): 287 | config: GemmaConfig 288 | dtype: jnp.dtype = jnp.float32 289 | 290 | def setup(self): 291 | embed_dim = self.config.hidden_size 292 | inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim 293 | 294 | kernel_init = jax.nn.initializers.normal(self.config.initializer_range) 295 | self.act = ACT2FN[self.config.hidden_act] 296 | 297 | self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=jnp.bfloat16, kernel_init=kernel_init) 298 | self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=jnp.bfloat16, kernel_init=kernel_init) 299 | self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=jnp.bfloat16, kernel_init=kernel_init) 300 | 301 | def __call__(self, hidden_states): 302 | up_proj_states = self.up_proj(hidden_states) 303 | gate_states = self.act(self.gate_proj(hidden_states)) 304 | 305 | up_proj_states = lax.with_sharding_constraint( 306 | up_proj_states, PS('data', None, 'model')) 307 | gate_states = lax.with_sharding_constraint( 308 | gate_states, PS('data', None, 'model')) 309 | 310 | hidden_states = self.down_proj(up_proj_states * gate_states) 311 | return hidden_states 312 | 313 | 314 | # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaDecoderLayer with Llama->Gemma 315 | class 
FlaxGemmaDecoderLayer(nn.Module): 316 | config: GemmaConfig 317 | dtype: jnp.dtype = jnp.float32 318 | 319 | def setup(self): 320 | self.input_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) 321 | self.self_attn = FlaxGemmaAttention(self.config, dtype=self.dtype) 322 | self.post_attention_layernorm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) 323 | self.mlp = FlaxGemmaMLP(self.config, dtype=self.dtype) 324 | 325 | def __call__( 326 | self, 327 | hidden_states, 328 | attention_mask=None, 329 | position_ids=None, 330 | deterministic: bool = True, 331 | init_cache: bool = False, 332 | output_attentions: bool = False, 333 | ): 334 | residual = hidden_states 335 | hidden_states = self.input_layernorm(hidden_states) 336 | outputs = self.self_attn( 337 | hidden_states, 338 | attention_mask=attention_mask, 339 | position_ids=position_ids, 340 | deterministic=deterministic, 341 | init_cache=init_cache, 342 | output_attentions=output_attentions, 343 | ) 344 | # residual connection 345 | attn_output = outputs[0] 346 | hidden_states = residual + attn_output 347 | 348 | residual = hidden_states 349 | hidden_states = self.post_attention_layernorm(hidden_states) 350 | hidden_states = self.mlp(hidden_states) 351 | # residual connection 352 | hidden_states = residual + hidden_states 353 | 354 | return (hidden_states,) + outputs[1:] 355 | 356 | 357 | # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel with GPTNeo->Gemma, GPT_NEO->GEMMA, transformer->model 358 | class FlaxGemmaPreTrainedModel(FlaxPreTrainedModel): 359 | """ 360 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 361 | models. 362 | """ 363 | 364 | config_class = GemmaConfig 365 | base_model_prefix = "model" 366 | module_class: nn.Module = None 367 | 368 | partition_rules = { 369 | 'embed_tokens/embedding': PS('model', 'data'), 370 | 'lm_head': PS('data', 'model'), 371 | 'mlp/(gate|up)_proj': PS('data', 'model'), 372 | 'mlp/down_proj': PS('model', 'data'), 373 | 'self_attn/(k|q|v)_proj': PS('data', 'model'), 374 | 'self_attn/o_proj': PS('model', 'data'), 375 | # 'layernorm/weight': PS('model'), 376 | } 377 | 378 | def __init__( 379 | self, 380 | config: GemmaConfig, 381 | input_shape: Tuple = (1, 1), 382 | seed: int = 0, 383 | dtype: jnp.dtype = jnp.float32, 384 | _do_init: bool = True, 385 | **kwargs, 386 | ): 387 | module = self.module_class(config=config, dtype=dtype, **kwargs) 388 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 389 | 390 | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: 391 | # init input tensors 392 | input_ids = jnp.zeros(input_shape, dtype="i4") 393 | attention_mask = jnp.ones_like(input_ids) 394 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) 395 | params_rng, dropout_rng = jax.random.split(rng) 396 | rngs = {"params": params_rng, "dropout": dropout_rng} 397 | 398 | random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] 399 | 400 | if params is not None: 401 | random_params = flatten_dict(unfreeze(random_params)) 402 | params = flatten_dict(unfreeze(params)) 403 | for missing_key in self._missing_keys: 404 | params[missing_key] = random_params[missing_key] 405 | self._missing_keys = set() 406 | return freeze(unflatten_dict(params)) 407 | else: 408 | return random_params 409 | 410 | def 
init_cache(self, batch_size, max_length): 411 | r""" 412 | Args: 413 | batch_size (`int`): 414 | batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. 415 | max_length (`int`): 416 | maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized 417 | cache. 418 | """ 419 | # init input variables to retrieve cache 420 | input_ids = jnp.ones((batch_size, max_length)) 421 | attention_mask = jnp.ones_like(input_ids) 422 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) 423 | 424 | init_variables = self.module.init( 425 | jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True 426 | ) 427 | return unfreeze(init_variables["cache"]) 428 | 429 | def __call__( 430 | self, 431 | input_ids, 432 | attention_mask=None, 433 | position_ids=None, 434 | params: dict = None, 435 | past_key_values: dict = None, 436 | dropout_rng: jax.random.PRNGKey = None, 437 | train: bool = False, 438 | output_attentions: Optional[bool] = None, 439 | output_hidden_states: Optional[bool] = None, 440 | return_dict: Optional[bool] = None, 441 | ): 442 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 443 | output_hidden_states = ( 444 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 445 | ) 446 | return_dict = return_dict if return_dict is not None else self.config.return_dict 447 | 448 | batch_size, sequence_length = input_ids.shape 449 | 450 | if position_ids is None: 451 | if past_key_values is not None: 452 | raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") 453 | 454 | position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) 455 | 456 | if attention_mask is None: 457 | attention_mask = jnp.ones((batch_size, sequence_length)) 458 | 459 | # Handle any PRNG if needed 460 | rngs = {} 461 | if dropout_rng is not None: 462 | rngs["dropout"] = dropout_rng 463 | 464 | inputs = {"params": params or self.params} 465 | 466 | # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGemmaAttention module 467 | if past_key_values: 468 | inputs["cache"] = past_key_values 469 | mutable = ["cache"] 470 | else: 471 | mutable = False 472 | 473 | outputs = self.module.apply( 474 | inputs, 475 | jnp.array(input_ids, dtype="i4"), 476 | jnp.array(attention_mask, dtype="i4"), 477 | jnp.array(position_ids, dtype="i4"), 478 | not train, 479 | False, 480 | output_attentions, 481 | output_hidden_states, 482 | return_dict, 483 | rngs=rngs, 484 | mutable=mutable, 485 | ) 486 | 487 | # add updated cache to model output 488 | if past_key_values is not None and return_dict: 489 | outputs, past_key_values = outputs 490 | outputs["past_key_values"] = unfreeze(past_key_values["cache"]) 491 | return outputs 492 | elif past_key_values is not None and not return_dict: 493 | outputs, past_key_values = outputs 494 | outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] 495 | 496 | return outputs 497 | 498 | 499 | # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaLayerCollection with Llama->Gemma 500 | class FlaxGemmaLayerCollection(nn.Module): 501 | config: GemmaConfig 502 | dtype: jnp.dtype = jnp.float32 503 | 504 | def setup(self): 505 | self.blocks = [ 506 | FlaxGemmaDecoderLayer(self.config, dtype=self.dtype, name=str(i)) 507 | for i in range(self.config.num_hidden_layers) 508 | ] 509 | 510 | def __call__( 511 | self, 512 | hidden_states, 513 | attention_mask=None, 514 | position_ids=None, 515 | deterministic: bool = True, 516 | init_cache: bool = False, 517 | output_attentions: bool = False, 518 | output_hidden_states: bool = False, 519 | return_dict: bool = False, 520 | ): 521 | all_attentions = () if output_attentions else None 522 | all_hidden_states = () if output_hidden_states else None 523 | 524 | for block in self.blocks: 525 | if output_hidden_states: 526 | all_hidden_states += (hidden_states,) 527 | layer_outputs = block( 528 | hidden_states, 529 | attention_mask=attention_mask, 530 | position_ids=position_ids, 531 | deterministic=deterministic, 532 | init_cache=init_cache, 533 | output_attentions=output_attentions, 534 | ) 535 | hidden_states = layer_outputs[0] 536 | 537 | if output_attentions: 538 | all_attentions += (layer_outputs[1],) 539 | 540 | # this contains possible `None` values - `FlaxGemmaModule` will filter them out 541 | outputs = (hidden_states, all_hidden_states, all_attentions) 542 | 543 | return outputs 544 | 545 | 546 | # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModule with Llama->Gemma 547 | class FlaxGemmaModule(nn.Module): 548 | config: GemmaConfig 549 | dtype: jnp.dtype = jnp.float32 550 | 551 | def setup(self): 552 | self.hidden_size = self.config.hidden_size 553 | embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) 554 | self.embed_tokens = nn.Embed( 555 | self.config.vocab_size, 556 | self.hidden_size, 557 | embedding_init=embedding_init, 558 | dtype=self.dtype, 559 | ) 560 | self.layers = FlaxGemmaLayerCollection(self.config, dtype=self.dtype) 561 | self.norm = FlaxGemmaRMSNorm(self.config, dtype=self.dtype) 562 | 563 | # Ignore copy 564 | def __call__( 565 | self, 566 | input_ids, 567 | attention_mask=None, 568 | position_ids=None, 569 | deterministic=True, 570 | init_cache: bool = False, 571 | output_attentions: bool = False, 572 | output_hidden_states: bool = False, 573 | return_dict: bool = True, 574 | ): 575 | input_ids = lax.with_sharding_constraint(input_ids, 
PS('data', None,)) 576 | input_embeds = self.embed_tokens(input_ids.astype("i4")) 577 | input_embeds = lax.with_sharding_constraint(input_embeds, PS('data', None, 'model')) 578 | 579 | input_embeds = input_embeds * (self.config.hidden_size**0.5) 580 | 581 | outputs = self.layers( 582 | input_embeds, 583 | position_ids=position_ids, 584 | attention_mask=attention_mask, 585 | deterministic=deterministic, 586 | init_cache=init_cache, 587 | output_attentions=output_attentions, 588 | output_hidden_states=output_hidden_states, 589 | return_dict=return_dict, 590 | ) 591 | 592 | hidden_states = outputs[0] 593 | hidden_states = self.norm(hidden_states) 594 | 595 | if output_hidden_states: 596 | all_hidden_states = outputs[1] + (hidden_states,) 597 | outputs = (hidden_states, all_hidden_states) + outputs[2:] 598 | else: 599 | outputs = (hidden_states,) + outputs[1:] 600 | 601 | if not return_dict: 602 | return tuple(v for v in outputs if v is not None) 603 | 604 | return FlaxBaseModelOutput( 605 | last_hidden_state=hidden_states, 606 | hidden_states=outputs[1], 607 | attentions=outputs[-1], 608 | ) 609 | 610 | 611 | # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaModel with Llama->Gemma 612 | class FlaxGemmaModel(FlaxGemmaPreTrainedModel): 613 | module_class = FlaxGemmaModule 614 | 615 | 616 | # Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaForCausalLMModule with Llama->Gemma 617 | class FlaxGemmaForCausalLMModule(nn.Module): 618 | config: GemmaConfig 619 | dtype: jnp.dtype = jnp.float32 620 | 621 | def setup(self): 622 | self.model = FlaxGemmaModule(self.config, dtype=self.dtype) 623 | self.lm_head = nn.Dense( 624 | self.config.vocab_size, 625 | use_bias=False, 626 | dtype=jnp.bfloat16, 627 | kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 628 | ) 629 | 630 | # Ignore copy 631 | def __call__( 632 | self, 633 | input_ids, 634 | attention_mask=None, 635 | position_ids=None, 636 | deterministic: bool = True, 637 | init_cache: bool = False, 638 | output_attentions: bool = False, 639 | output_hidden_states: bool = False, 640 | return_dict: bool = True, 641 | ): 642 | outputs = self.model( 643 | input_ids, 644 | position_ids=position_ids, 645 | attention_mask=attention_mask, 646 | deterministic=deterministic, 647 | init_cache=init_cache, 648 | output_attentions=output_attentions, 649 | output_hidden_states=output_hidden_states, 650 | return_dict=return_dict, 651 | ) 652 | 653 | hidden_states = outputs[0] 654 | if self.config.tie_word_embeddings: 655 | shared_kernel = self.model.variables["params"]["embed_tokens"]["embedding"].T 656 | lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) 657 | else: 658 | lm_logits = self.lm_head(hidden_states) 659 | 660 | if not return_dict: 661 | return (lm_logits,) + outputs[1:] 662 | 663 | return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) 664 | 665 | 666 | # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJForCausalLM with GPTJ->Gemma 667 | class FlaxGemmaForCausalLM(FlaxGemmaPreTrainedModel): 668 | module_class = FlaxGemmaForCausalLMModule 669 | 670 | def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): 671 | # initializing the cache 672 | batch_size, seq_length = input_ids.shape 673 | 674 | past_key_values = self.init_cache(batch_size, max_length) 675 | # Note that usually one would have to put 0's in the attention_mask for x > 
input_ids.shape[-1] and x < cache_length. 676 | # But since Gemma uses a causal mask, those positions are masked anyways. 677 | # Thus we can create a single static attention_mask here, which is more efficient for compilation 678 | extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") 679 | if attention_mask is not None: 680 | position_ids = attention_mask.cumsum(axis=-1) - 1 681 | extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) 682 | else: 683 | position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) 684 | 685 | return { 686 | "past_key_values": past_key_values, 687 | "attention_mask": extended_attention_mask, 688 | "position_ids": position_ids, 689 | } 690 | 691 | def update_inputs_for_generation(self, model_outputs, model_kwargs): 692 | model_kwargs["past_key_values"] = model_outputs.past_key_values 693 | model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 694 | return model_kwargs -------------------------------------------------------------------------------- /magix/models/llama_model.py: -------------------------------------------------------------------------------- 1 | """Flax LLaMA model.""" 2 | from functools import partial 3 | from typing import Optional, Tuple 4 | 5 | import math 6 | 7 | import jax 8 | import jax.numpy as jnp 9 | import numpy as np 10 | from jax import lax 11 | from jax.sharding import Mesh, PartitionSpec as PS 12 | from jax.experimental.pallas.ops import attention as attn_ops 13 | from jax.experimental.shard_map import shard_map 14 | from jax._src import mesh as mesh_lib 15 | import jax.experimental.pallas.ops.tpu.flash_attention as tpu_attn_ops 16 | import flax.linen as nn 17 | from flax.core.frozen_dict import FrozenDict, freeze, unfreeze 18 | from flax.linen import combine_masks, make_causal_mask 19 | from flax.linen.attention import dot_product_attention_weights 20 | from flax.traverse_util import flatten_dict, unflatten_dict 21 | 22 | from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput 23 | from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel 24 | from transformers.utils import logging 25 | from transformers import LlamaConfig 26 | 27 | try: 28 | from transformer_engine.jax import fused_attn as te_attn 29 | _te_available = True 30 | except ImportError: 31 | _te_available = False 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | 36 | 37 | def create_sinusoidal_positions(num_pos, dim, base=10000): 38 | inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim)) 39 | freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") 40 | 41 | emb = np.concatenate((freqs, freqs), axis=-1) 42 | out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) 43 | return jnp.array(out[:, :, :num_pos]) 44 | 45 | 46 | def rotate_half(tensor): 47 | """Rotates half the hidden dims of the input.""" 48 | rotate_half_tensor = jnp.concatenate( 49 | (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1 50 | ) 51 | return rotate_half_tensor 52 | 53 | 54 | def apply_rotary_pos_emb(tensor, sin_pos, cos_pos): 55 | return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos) 56 | 57 | 58 | class FlaxLlamaRMSNorm(nn.Module): 59 | config: LlamaConfig 60 | dtype: jnp.dtype = jnp.float32 61 | 62 | def setup(self): 63 | self.epsilon = self.config.rms_norm_eps 64 | self.weight = self.param("weight", lambda _, shape: 
jnp.ones(shape), self.config.hidden_size) 65 | 66 | def __call__(self, hidden_states): 67 | variance = jnp.asarray(hidden_states, dtype=jnp.float32) 68 | variance = jnp.power(variance, 2) 69 | variance = variance.mean(-1, keepdims=True) 70 | # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt` 71 | hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon) 72 | 73 | return self.weight * jnp.asarray(hidden_states, dtype=self.dtype) 74 | 75 | 76 | class FlaxLlamaRotaryEmbedding(nn.Module): 77 | config: LlamaConfig 78 | dtype: jnp.dtype = jnp.float32 79 | 80 | def setup(self): 81 | head_dim = self.config.hidden_size // self.config.num_attention_heads 82 | self.sincos = create_sinusoidal_positions( 83 | self.config.max_position_embeddings, head_dim, base=self.config.rope_theta) 84 | 85 | def __call__(self, key, query, position_ids): 86 | sincos = self.sincos[position_ids] 87 | sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1) 88 | 89 | key = apply_rotary_pos_emb(key, sin_pos, cos_pos) 90 | query = apply_rotary_pos_emb(query, sin_pos, cos_pos) 91 | 92 | key = jnp.asarray(key, dtype=self.dtype) 93 | query = jnp.asarray(query, dtype=self.dtype) 94 | 95 | return key, query 96 | 97 | 98 | class FlaxLlamaAttention(nn.Module): 99 | config: LlamaConfig 100 | dtype: jnp.dtype = jnp.float32 101 | causal: bool = True 102 | is_cross_attention: bool = False 103 | fused_attention: bool = True 104 | 105 | def setup(self): 106 | config = self.config 107 | self.embed_dim = config.hidden_size 108 | self.num_heads = config.num_attention_heads 109 | self.num_kv_heads = getattr(config, "num_key_value_heads", self.num_heads) 110 | self.head_dim = self.embed_dim // self.num_heads 111 | self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 112 | 113 | dense = partial( 114 | nn.Dense, 115 | use_bias=config.attention_bias, 116 | dtype=jnp.bfloat16, 117 | kernel_init=jax.nn.initializers.normal(self.config.initializer_range), 118 | ) 119 | 120 | self.q_proj = dense(self.num_heads * self.head_dim) 121 | self.k_proj = dense(self.num_kv_heads * self.head_dim) 122 | self.v_proj = dense(self.num_kv_heads * self.head_dim) 123 | self.o_proj = dense(self.embed_dim) 124 | 125 | self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") 126 | self.rotary_emb = FlaxLlamaRotaryEmbedding(config, dtype=jnp.float32) 127 | 128 | def _split_heads(self, hidden_states, num_heads): 129 | return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) 130 | 131 | def _merge_heads(self, hidden_states): 132 | return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) 133 | 134 | def repeat_hidden_states(self, hidden_states, times, axis=-2): 135 | return jnp.repeat(hidden_states, times, axis=axis) 136 | 137 | @nn.compact 138 | def _concatenate_to_cache(self, key, value, query, attention_mask): 139 | """ 140 | This function takes projected key, value states from a single input token and concatenates the states to cached 141 | states from previous steps. This function is slighly adapted from the official Flax repository: 142 | https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 143 | """ 144 | # detect if we're initializing by absence of existing cache data. 
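# Same cache-update scheme as the Gemma attention above: buffers are
# preallocated to max_length and new slices are written in place with
# `lax.dynamic_update_slice` while `cache_index` advances.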
145 | is_initialized = self.has_variable("cache", "cached_key") 146 | cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) 147 | cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) 148 | cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) 149 | 150 | if is_initialized: 151 | *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape 152 | # update key, value caches with our new 1d spatial slices 153 | cur_index = cache_index.value 154 | indices = (0,) * len(batch_dims) + (cur_index, 0, 0) 155 | key = lax.dynamic_update_slice(cached_key.value, key, indices) 156 | value = lax.dynamic_update_slice(cached_value.value, value, indices) 157 | cached_key.value = key 158 | cached_value.value = value 159 | num_updated_cache_vectors = query.shape[1] 160 | cache_index.value = cache_index.value + num_updated_cache_vectors 161 | # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 162 | pad_mask = jnp.broadcast_to( 163 | jnp.arange(max_length) < cur_index + num_updated_cache_vectors, 164 | tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), 165 | ) 166 | attention_mask = combine_masks(pad_mask, attention_mask) 167 | return key, value, attention_mask 168 | 169 | def __call__( 170 | self, 171 | hidden_states, 172 | attention_mask, 173 | position_ids, 174 | deterministic: bool = True, 175 | init_cache: bool = False, 176 | output_attentions: bool = False, 177 | ): 178 | query = self.q_proj(hidden_states) 179 | key = self.k_proj(hidden_states) 180 | value = self.v_proj(hidden_states) 181 | 182 | query = lax.with_sharding_constraint(query, PS('data', None, 'model')) 183 | key = lax.with_sharding_constraint(key, PS('data', None, 'model')) 184 | value = lax.with_sharding_constraint(value, PS('data', None, 'model')) 185 | 186 | query = self._split_heads(query, self.num_heads) 187 | key = self._split_heads(key, self.num_kv_heads) 188 | value = self._split_heads(value, self.num_kv_heads) 189 | 190 | if self.num_heads != self.num_kv_heads: 191 | key = self.repeat_hidden_states(key, self.num_heads // self.num_kv_heads) 192 | value = self.repeat_hidden_states(value, self.num_heads // self.num_kv_heads) 193 | 194 | key, query = self.rotary_emb(key, query, position_ids) 195 | 196 | query_length, key_length = query.shape[1], key.shape[1] 197 | 198 | if self.has_variable("cache", "cached_key"): 199 | mask_shift = self.variables["cache"]["cache_index"] 200 | max_decoder_length = self.variables["cache"]["cached_key"].shape[1] 201 | causal_mask = lax.dynamic_slice( 202 | self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) 203 | ) 204 | else: 205 | causal_mask = self.causal_mask[:, :, :query_length, :key_length] 206 | 207 | batch_size = hidden_states.shape[0] 208 | causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) 209 | 210 | attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) 211 | 212 | # During fast autoregressive decoding, we feed one position at a time, 213 | # and cache the keys and values step by step. 
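# Mask handling differs between the two branches below: the plain dot-product
# path folds the combined boolean mask into an additive float bias, while the
# Transformer Engine fused path passes an inverted boolean padding mask together
# with the packed qkv tensor.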
214 | if self.has_variable("cache", "cached_key") or init_cache: 215 | key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) 216 | 217 | use_fused_attention = ( 218 | self.fused_attention 219 | and _te_available 220 | and query_length >= 32 221 | and not init_cache 222 | and not self.has_variable("cache", "cached_key") 223 | ) 224 | 225 | if not use_fused_attention: 226 | attention_mask = combine_masks(attention_mask, causal_mask) # make fp mask 227 | # transform boolean mask into float mask 228 | attention_bias = lax.select( 229 | attention_mask > 0, 230 | jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 231 | jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), 232 | ) 233 | 234 | # usual dot product attention 235 | attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype 236 | attn_weights = dot_product_attention_weights( 237 | query, 238 | key, 239 | bias=attention_bias, 240 | deterministic=deterministic, 241 | dtype=attention_dtype, 242 | ) 243 | 244 | if self.attention_softmax_in_fp32: 245 | attn_weights = attn_weights.astype(self.dtype) 246 | 247 | attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) 248 | 249 | else: 250 | attention_mask = ~combine_masks(attention_mask, causal_mask, dtype='bool') # make bool mask 251 | 252 | query, key, value = map(lambda x: x.astype(jnp.bfloat16), (query, key, value)) 253 | qkv = jnp.stack((query, key, value), axis=2) 254 | 255 | attn_output = te_attn.self_fused_attn( 256 | qkv, 257 | None, 258 | attention_mask, 259 | None, 260 | attn_bias_type=te_attn.AttnBiasType.NO_BIAS, 261 | attn_mask_type=te_attn.AttnMaskType.PADDING_CAUSAL_MASK, 262 | scaling_factor=1.0 / math.sqrt(query.shape[-1]), 263 | dropout_probability=self.config.attention_dropout, 264 | is_training=not deterministic, 265 | ) 266 | 267 | attn_output = self._merge_heads(attn_output) 268 | attn_output = self.o_proj(attn_output) 269 | attn_output = lax.with_sharding_constraint(attn_output, PS('data', None, 'model')) 270 | 271 | outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) 272 | return outputs 273 | 274 | 275 | class FlaxLlamaMLP(nn.Module): 276 | config: LlamaConfig 277 | dtype: jnp.dtype = jnp.float32 278 | 279 | def setup(self): 280 | embed_dim = self.config.hidden_size 281 | inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim 282 | 283 | kernel_init = jax.nn.initializers.normal(self.config.initializer_range) 284 | self.act = ACT2FN[self.config.hidden_act] 285 | 286 | dense = partial( 287 | nn.Dense, 288 | use_bias=False, 289 | dtype=jnp.bfloat16, 290 | kernel_init=kernel_init, 291 | ) 292 | 293 | self.gate_proj = dense(inner_dim) 294 | self.down_proj = dense(embed_dim) 295 | self.up_proj = dense(inner_dim) 296 | 297 | def __call__(self, hidden_states): 298 | up_proj_states = self.up_proj(hidden_states) 299 | gate_states = self.act(self.gate_proj(hidden_states)) 300 | 301 | up_proj_states = lax.with_sharding_constraint(up_proj_states, PS('data', None, 'model')) 302 | gate_states = lax.with_sharding_constraint(gate_states, PS('data', None, 'model')) 303 | 304 | hidden_states = up_proj_states * gate_states 305 | hidden_states = self.down_proj(hidden_states) 306 | return hidden_states 307 | 308 | 309 | class FlaxLlamaDecoderLayer(nn.Module): 310 | config: LlamaConfig 311 | dtype: jnp.dtype = jnp.float32 312 | 313 | def setup(self): 314 | self.input_layernorm = FlaxLlamaRMSNorm(self.config, 
dtype=self.dtype) 315 | self.self_attn = FlaxLlamaAttention(self.config, dtype=jnp.bfloat16) 316 | self.post_attention_layernorm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype) 317 | self.mlp = FlaxLlamaMLP(self.config, dtype=jnp.bfloat16) 318 | 319 | def __call__( 320 | self, 321 | hidden_states, 322 | attention_mask=None, 323 | position_ids=None, 324 | deterministic: bool = True, 325 | init_cache: bool = False, 326 | output_attentions: bool = False, 327 | ): 328 | residual = hidden_states 329 | hidden_states = self.input_layernorm(hidden_states) 330 | outputs = self.self_attn( 331 | hidden_states, 332 | attention_mask=attention_mask, 333 | position_ids=position_ids, 334 | deterministic=deterministic, 335 | init_cache=init_cache, 336 | output_attentions=output_attentions, 337 | ) 338 | # residual connection 339 | attn_output = outputs[0] 340 | hidden_states = residual + attn_output 341 | 342 | residual = hidden_states 343 | hidden_states = self.post_attention_layernorm(hidden_states) 344 | hidden_states = self.mlp(hidden_states) 345 | 346 | # residual connection 347 | hidden_states = residual + hidden_states 348 | 349 | return (hidden_states,) + outputs[1:] 350 | 351 | 352 | class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel): 353 | """ 354 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 355 | models. 356 | """ 357 | 358 | config_class = LlamaConfig 359 | base_model_prefix = "model" 360 | module_class: nn.Module = None 361 | 362 | partition_rules = { 363 | 'embed_tokens/embedding': PS('data', 'model'), 364 | 'lm_head': PS('data', 'model'), 365 | 'mlp/(gate|up)_proj': PS('data', 'model'), 366 | 'mlp/down_proj': PS('model', 'data'), 367 | 'self_attn/(k|q|v)_proj': PS('data', 'model'), 368 | 'self_attn/o_proj': PS('model', 'data'), 369 | # 'norm/weight': PS('model'), 370 | } 371 | 372 | def __init__( 373 | self, 374 | config: LlamaConfig, 375 | input_shape: Tuple = (1, 1), 376 | seed: int = 0, 377 | dtype: jnp.dtype = jnp.float32, 378 | _do_init: bool = True, 379 | **kwargs, 380 | ): 381 | module = self.module_class(config=config, dtype=dtype, **kwargs) 382 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 383 | 384 | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: 385 | # init input tensors 386 | input_ids = jnp.zeros(input_shape, dtype="i4") 387 | attention_mask = jnp.ones_like(input_ids) 388 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) 389 | params_rng, dropout_rng = jax.random.split(rng) 390 | rngs = {"params": params_rng, "dropout": dropout_rng} 391 | 392 | random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] 393 | 394 | if params is not None: 395 | random_params = flatten_dict(unfreeze(random_params)) 396 | params = flatten_dict(unfreeze(params)) 397 | for missing_key in self._missing_keys: 398 | params[missing_key] = random_params[missing_key] 399 | self._missing_keys = set() 400 | return freeze(unflatten_dict(params)) 401 | else: 402 | return random_params 403 | 404 | def init_cache(self, batch_size, max_length): 405 | r""" 406 | Args: 407 | batch_size (`int`): 408 | batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. 409 | max_length (`int`): 410 | maximum possible length for auto-regressive decoding. 
Defines the sequence length of the initialized 411 | cache. 412 | """ 413 | # init input variables to retrieve cache 414 | input_ids = jnp.ones((batch_size, max_length)) 415 | attention_mask = jnp.ones_like(input_ids) 416 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) 417 | 418 | init_variables = self.module.init( 419 | jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True 420 | ) 421 | return unfreeze(init_variables["cache"]) 422 | 423 | def __call__( 424 | self, 425 | input_ids, 426 | attention_mask=None, 427 | position_ids=None, 428 | params: dict = None, 429 | past_key_values: dict = None, 430 | dropout_rng: jax.random.PRNGKey = None, 431 | train: bool = False, 432 | output_attentions: Optional[bool] = None, 433 | output_hidden_states: Optional[bool] = None, 434 | return_dict: Optional[bool] = None, 435 | ): 436 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 437 | output_hidden_states = ( 438 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 439 | ) 440 | return_dict = return_dict if return_dict is not None else self.config.return_dict 441 | 442 | batch_size, sequence_length = input_ids.shape 443 | 444 | if position_ids is None: 445 | if past_key_values is not None: 446 | raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") 447 | 448 | position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) 449 | 450 | if attention_mask is None: 451 | attention_mask = jnp.ones((batch_size, sequence_length)) 452 | 453 | # Handle any PRNG if needed 454 | rngs = {} 455 | if dropout_rng is not None: 456 | rngs["dropout"] = dropout_rng 457 | 458 | inputs = {"params": params or self.params} 459 | 460 | # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxLlamaAttention module 461 | if past_key_values: 462 | inputs["cache"] = past_key_values 463 | mutable = ["cache"] 464 | else: 465 | mutable = False 466 | 467 | outputs = self.module.apply( 468 | inputs, 469 | jnp.array(input_ids, dtype="i4"), 470 | jnp.array(attention_mask, dtype="i4"), 471 | jnp.array(position_ids, dtype="i4"), 472 | not train, 473 | False, 474 | output_attentions, 475 | output_hidden_states, 476 | return_dict, 477 | rngs=rngs, 478 | mutable=mutable, 479 | ) 480 | 481 | # add updated cache to model output 482 | if past_key_values is not None and return_dict: 483 | outputs, past_key_values = outputs 484 | outputs["past_key_values"] = unfreeze(past_key_values["cache"]) 485 | return outputs 486 | elif past_key_values is not None and not return_dict: 487 | outputs, past_key_values = outputs 488 | outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] 489 | 490 | return outputs 491 | 492 | 493 | class FlaxLlamaLayerCollection(nn.Module): 494 | config: LlamaConfig 495 | dtype: jnp.dtype = jnp.float32 496 | 497 | def setup(self): 498 | self.blocks = [ 499 | FlaxLlamaDecoderLayer(self.config, dtype=self.dtype, name=str(i)) 500 | for i in range(self.config.num_hidden_layers) 501 | ] 502 | 503 | def __call__( 504 | self, 505 | hidden_states, 506 | attention_mask=None, 507 | position_ids=None, 508 | deterministic: bool = True, 509 | init_cache: bool = False, 510 | output_attentions: bool = False, 511 | output_hidden_states: bool = False, 512 | return_dict: bool = False, 513 | ): 514 | all_attentions = () if output_attentions else None 515 | all_hidden_states = () if output_hidden_states else None 516 | 517 | for block in self.blocks: 518 | if output_hidden_states: 519 | all_hidden_states += (hidden_states,) 520 | layer_outputs = block( 521 | hidden_states, 522 | attention_mask=attention_mask, 523 | position_ids=position_ids, 524 | deterministic=deterministic, 525 | init_cache=init_cache, 526 | output_attentions=output_attentions, 527 | ) 528 | hidden_states = layer_outputs[0] 529 | 530 | if output_attentions: 531 | all_attentions += (layer_outputs[1],) 532 | 533 | # this contains possible `None` values - `FlaxLlamaModule` will filter them out 534 | outputs = (hidden_states, all_hidden_states, all_attentions) 535 | 536 | return outputs 537 | 538 | 539 | class FlaxLlamaModule(nn.Module): 540 | config: LlamaConfig 541 | dtype: jnp.dtype = jnp.float32 542 | 543 | def setup(self): 544 | self.hidden_size = self.config.hidden_size 545 | embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) 546 | self.embed_tokens = nn.Embed( 547 | self.config.vocab_size, 548 | self.hidden_size, 549 | embedding_init=embedding_init, 550 | dtype=self.dtype, 551 | ) 552 | self.layers = FlaxLlamaLayerCollection(self.config, dtype=self.dtype) 553 | self.norm = FlaxLlamaRMSNorm(self.config, dtype=self.dtype) 554 | 555 | def __call__( 556 | self, 557 | input_ids, 558 | attention_mask=None, 559 | position_ids=None, 560 | deterministic=True, 561 | init_cache: bool = False, 562 | output_attentions: bool = False, 563 | output_hidden_states: bool = False, 564 | return_dict: bool = True, 565 | ): 566 | input_ids = lax.with_sharding_constraint(input_ids, PS('data', None,)) 567 | input_embeds = self.embed_tokens(input_ids.astype("i4")) 568 | input_embeds = lax.with_sharding_constraint(input_embeds, PS('data', None, 'model')) 569 | outputs = self.layers( 570 | input_embeds, 571 | 
position_ids=position_ids, 572 | attention_mask=attention_mask, 573 | deterministic=deterministic, 574 | init_cache=init_cache, 575 | output_attentions=output_attentions, 576 | output_hidden_states=output_hidden_states, 577 | return_dict=return_dict, 578 | ) 579 | 580 | hidden_states = outputs[0] 581 | hidden_states = self.norm(hidden_states) 582 | 583 | if output_hidden_states: 584 | all_hidden_states = outputs[1] + (hidden_states,) 585 | outputs = (hidden_states, all_hidden_states) + outputs[2:] 586 | else: 587 | outputs = (hidden_states,) + outputs[1:] 588 | 589 | if not return_dict: 590 | return tuple(v for v in outputs if v is not None) 591 | 592 | return FlaxBaseModelOutput( 593 | last_hidden_state=hidden_states, 594 | hidden_states=outputs[1], 595 | attentions=outputs[-1], 596 | ) 597 | 598 | 599 | 600 | class FlaxLlamaModel(FlaxLlamaPreTrainedModel): 601 | module_class = FlaxLlamaModule 602 | 603 | 604 | class FlaxLlamaForCausalLMModule(nn.Module): 605 | config: LlamaConfig 606 | dtype: jnp.dtype = jnp.float32 607 | 608 | def setup(self): 609 | self.model = FlaxLlamaModule(self.config, dtype=self.dtype) 610 | self.lm_head = nn.Dense( 611 | self.config.vocab_size, 612 | use_bias=False, 613 | dtype=jnp.bfloat16, 614 | kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 615 | ) 616 | 617 | def __call__( 618 | self, 619 | input_ids, 620 | attention_mask=None, 621 | position_ids=None, 622 | deterministic: bool = True, 623 | init_cache: bool = False, 624 | output_attentions: bool = False, 625 | output_hidden_states: bool = False, 626 | return_dict: bool = True, 627 | ): 628 | outputs = self.model( 629 | input_ids, 630 | position_ids=position_ids, 631 | attention_mask=attention_mask, 632 | deterministic=deterministic, 633 | init_cache=init_cache, 634 | output_attentions=output_attentions, 635 | output_hidden_states=output_hidden_states, 636 | return_dict=return_dict, 637 | ) 638 | 639 | hidden_states = outputs[0] 640 | lm_logits = self.lm_head(hidden_states) 641 | 642 | if not return_dict: 643 | return (lm_logits,) + outputs[1:] 644 | 645 | return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) 646 | 647 | 648 | class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel): 649 | module_class = FlaxLlamaForCausalLMModule 650 | 651 | def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): 652 | # initializing the cache 653 | batch_size, seq_length = input_ids.shape 654 | 655 | past_key_values = self.init_cache(batch_size, max_length) 656 | # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. 657 | # But since Llama uses a causal mask, those positions are masked anyways. 
658 | # Thus we can create a single static attention_mask here, which is more efficient for compilation 659 | extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") 660 | if attention_mask is not None: 661 | position_ids = attention_mask.cumsum(axis=-1) - 1 662 | extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) 663 | else: 664 | position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) 665 | 666 | return { 667 | "past_key_values": past_key_values, 668 | "attention_mask": extended_attention_mask, 669 | "position_ids": position_ids, 670 | } 671 | 672 | def update_inputs_for_generation(self, model_outputs, model_kwargs): 673 | model_kwargs["past_key_values"] = model_outputs.past_key_values 674 | model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 675 | return model_kwargs 676 | -------------------------------------------------------------------------------- /magix/models/mistral_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | from functools import partial 4 | 5 | import flax.linen as nn 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | from jax import lax 10 | from jax.sharding import PartitionSpec as PS 11 | from jax._src import mesh as mesh_lib 12 | from jax.experimental.pallas.ops import attention as attn_ops 13 | from jax.experimental.shard_map import shard_map 14 | from flax.core.frozen_dict import FrozenDict, freeze, unfreeze 15 | from flax.linen import combine_masks, make_causal_mask 16 | from flax.linen.attention import dot_product_attention_weights 17 | from flax.traverse_util import flatten_dict, unflatten_dict 18 | 19 | from transformers.modeling_flax_outputs import ( 20 | FlaxBaseModelOutput, 21 | FlaxCausalLMOutput, 22 | ) 23 | from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, logging 24 | from transformers import MistralConfig 25 | 26 | try: 27 | from transformer_engine.jax import fused_attn as te_attn 28 | _te_available = True 29 | except ImportError: 30 | _te_available = False 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | class FlaxMistralRMSNorm(nn.Module): 35 | config: MistralConfig 36 | dtype: jnp.dtype = jnp.float32 37 | 38 | def setup(self): 39 | self.epsilon = self.config.rms_norm_eps 40 | self.weight = self.param("weight", lambda _, shape: jnp.ones(shape), self.config.hidden_size) 41 | 42 | def __call__(self, hidden_states): 43 | variance = jnp.asarray(hidden_states, dtype=jnp.float32) 44 | variance = jnp.power(variance, 2) 45 | variance = variance.mean(-1, keepdims=True) 46 | # use `jax.numpy.sqrt` as `jax.lax.rsqrt` does not match `torch.rsqrt` 47 | hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon) 48 | 49 | return self.weight * jnp.asarray(hidden_states, dtype=self.dtype) 50 | 51 | 52 | class FlaxMistralRotaryEmbedding(nn.Module): 53 | config: MistralConfig 54 | dtype: jnp.dtype = jnp.float32 55 | 56 | def setup(self): 57 | head_dim = self.config.hidden_size // self.config.num_attention_heads 58 | self.sincos = create_sinusoidal_positions(self.config.max_position_embeddings, head_dim) 59 | 60 | def __call__(self, key, query, position_ids): 61 | sincos = self.sincos[position_ids] 62 | sin_pos, cos_pos = jnp.split(sincos, 2, axis=-1) 63 | 64 | key = apply_rotary_pos_emb(key, sin_pos, cos_pos) 65 | query = apply_rotary_pos_emb(query, 
sin_pos, cos_pos) 66 | 67 | key = jnp.asarray(key, dtype=self.dtype) 68 | query = jnp.asarray(query, dtype=self.dtype) 69 | 70 | return key, query 71 | 72 | 73 | class FlaxMistralMLP(nn.Module): 74 | config: MistralConfig 75 | dtype: jnp.dtype = jnp.float32 76 | 77 | def setup(self): 78 | embed_dim = self.config.hidden_size 79 | inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim 80 | 81 | kernel_init = jax.nn.initializers.normal(self.config.initializer_range) 82 | self.act = ACT2FN[self.config.hidden_act] 83 | 84 | self.gate_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) 85 | self.down_proj = nn.Dense(embed_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) 86 | self.up_proj = nn.Dense(inner_dim, use_bias=False, dtype=self.dtype, kernel_init=kernel_init) 87 | 88 | def __call__(self, hidden_states): 89 | up_proj_states = self.up_proj(hidden_states) 90 | gate_states = self.act(self.gate_proj(hidden_states)) 91 | 92 | up_proj_states = lax.with_sharding_constraint(up_proj_states, PS('data', None, 'model')) 93 | gate_states = lax.with_sharding_constraint(gate_states, PS('data', None, 'model')) 94 | 95 | hidden_states = self.down_proj(up_proj_states * gate_states) 96 | return hidden_states 97 | 98 | 99 | def apply_rotary_pos_emb(tensor, sin_pos, cos_pos): 100 | return (tensor * cos_pos) + (rotate_half(tensor) * sin_pos) 101 | 102 | 103 | def create_sinusoidal_positions(num_pos, dim): 104 | inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim)) 105 | freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") 106 | 107 | emb = np.concatenate((freqs, freqs), axis=-1) 108 | out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1) 109 | return jnp.array(out[:, :, :num_pos]) 110 | 111 | 112 | def rotate_half(tensor): 113 | """Rotates half the hidden dims of the input.""" 114 | rotate_half_tensor = jnp.concatenate( 115 | (-tensor[..., tensor.shape[-1] // 2 :], tensor[..., : tensor.shape[-1] // 2]), axis=-1 116 | ) 117 | return rotate_half_tensor 118 | 119 | 120 | def flax_repeat_kv(hidden_states: jnp.ndarray, n_rep: int) -> jnp.ndarray: 121 | """ 122 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 123 | seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim) 124 | """ 125 | batch, slen, num_key_value_heads, head_dim = hidden_states.shape 126 | if n_rep == 1: 127 | return hidden_states 128 | hidden_states = jnp.repeat(hidden_states[:, :, None, :, :], n_rep, axis=3) 129 | new_size = (batch, slen, num_key_value_heads * n_rep, head_dim) 130 | return jax.lax.reshape(hidden_states, new_size) 131 | 132 | 133 | class FlaxMistralAttention(nn.Module): 134 | config: MistralConfig 135 | dtype: jnp.dtype = jnp.float32 136 | fused_attention: bool = True 137 | 138 | def setup(self): 139 | config = self.config 140 | self.hidden_size = config.hidden_size 141 | self.num_heads = config.num_attention_heads 142 | self.head_dim = self.hidden_size // self.num_heads 143 | self.num_key_value_heads = config.num_key_value_heads 144 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 145 | self.max_position_embeddings = config.max_position_embeddings 146 | self.attention_softmax_in_fp32 = self.dtype is not jnp.float32 147 | self.rope_theta = config.rope_theta 148 | if (self.head_dim * self.num_heads) != self.hidden_size: 149 | raise ValueError( 150 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 151 | f" and `num_heads`: {self.num_heads})." 152 | ) 153 | self.q_proj = nn.Dense(self.num_heads * self.head_dim, use_bias=False, dtype=self.dtype) 154 | self.k_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype) 155 | self.v_proj = nn.Dense(self.num_key_value_heads * self.head_dim, use_bias=False, dtype=self.dtype) 156 | self.o_proj = nn.Dense(self.hidden_size, use_bias=False, dtype=self.dtype) 157 | self.causal_mask = jnp.triu( 158 | make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool"), 159 | k=-config.sliding_window, 160 | ) 161 | self.rotary_emb = FlaxMistralRotaryEmbedding(config, dtype=jnp.float32) 162 | 163 | def _split_heads(self, hidden_states, num_heads): 164 | return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim)) 165 | 166 | def _merge_heads(self, hidden_states): 167 | return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,)) 168 | 169 | @nn.compact 170 | def _concatenate_to_cache(self, key, value, query, attention_mask): 171 | """ 172 | This function takes projected key, value states from a single input token and concatenates the states to cached 173 | states from previous steps. This function is slightly adapted from the official Flax repository: 174 | https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 175 | """ 176 | # detect if we're initializing by absence of existing cache data.
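# Cache layout used below: `cached_key` / `cached_value` are zero-initialized buffers of shape
# (batch, max_length, num_key_value_heads, head_dim) created on the `init_cache` pass, and
# `cache_index` is a scalar int32 counter of filled positions. On each decoding step the new
# key/value slice is written in at `cache_index` via `lax.dynamic_update_slice`, the counter is
# advanced by the number of query positions, and the attention mask is further restricted so the
# query only attends to cache positions that have already been written.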
177 | is_initialized = self.has_variable("cache", "cached_key") 178 | cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) 179 | cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) 180 | cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) 181 | 182 | if is_initialized: 183 | *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape 184 | # update key, value caches with our new 1d spatial slices 185 | cur_index = cache_index.value 186 | indices = (0,) * len(batch_dims) + (cur_index, 0, 0) 187 | key = lax.dynamic_update_slice(cached_key.value, key, indices) 188 | value = lax.dynamic_update_slice(cached_value.value, value, indices) 189 | cached_key.value = key 190 | cached_value.value = value 191 | num_updated_cache_vectors = query.shape[1] 192 | cache_index.value = cache_index.value + num_updated_cache_vectors 193 | # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. 194 | pad_mask = jnp.broadcast_to( 195 | jnp.arange(max_length) < cur_index + num_updated_cache_vectors, 196 | tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), 197 | ) 198 | attention_mask = combine_masks(pad_mask, attention_mask) 199 | return key, value, attention_mask 200 | 201 | def __call__( 202 | self, 203 | hidden_states: jnp.ndarray, 204 | attention_mask: Optional[jnp.ndarray] = None, 205 | position_ids: Optional[jnp.ndarray] = None, 206 | deterministic: bool = True, 207 | output_attentions: bool = False, 208 | init_cache: bool = False, 209 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 210 | query_states = self.q_proj(hidden_states) 211 | key_states = self.k_proj(hidden_states) 212 | value_states = self.v_proj(hidden_states) 213 | 214 | query_states = self._split_heads(query_states, self.num_heads) 215 | key_states = self._split_heads(key_states, self.num_key_value_heads) 216 | value_states = self._split_heads(value_states, self.num_key_value_heads) 217 | 218 | key_states, query_states = self.rotary_emb(key_states, query_states, position_ids) 219 | query_length, key_length = query_states.shape[1], key_states.shape[1] 220 | 221 | if self.has_variable("cache", "cached_key"): 222 | mask_shift = self.variables["cache"]["cache_index"] 223 | max_decoder_length = self.variables["cache"]["cached_key"].shape[1] 224 | causal_mask = lax.dynamic_slice( 225 | self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) 226 | ) 227 | else: 228 | causal_mask = self.causal_mask[:, :, :query_length, :key_length] 229 | 230 | batch_size = hidden_states.shape[0] 231 | causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) 232 | attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) 233 | attention_mask = combine_masks(attention_mask, causal_mask, dtype="bool") 234 | 235 | if self.has_variable("cache", "cached_key") or init_cache: 236 | key_states, value_states, attention_mask = self._concatenate_to_cache( 237 | key_states, value_states, query_states, attention_mask 238 | ) 239 | 240 | use_fused_attention = ( 241 | self.fused_attention 242 | and _te_available 243 | and query_length >= 32 244 | and not init_cache 245 | and not self.has_variable("cache", "cached_key") 246 | ) 247 | 248 | if not use_fused_attention: 249 | attention_bias = lax.select( 250 | attention_mask > 0, 251 
| jnp.full(attention_mask.shape, 0.0).astype(self.dtype), 252 | jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), 253 | ) 254 | 255 | # usual dot product attention 256 | attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype 257 | 258 | key_states = flax_repeat_kv(key_states, self.num_key_value_groups) 259 | value_states = flax_repeat_kv(value_states, self.num_key_value_groups) 260 | attn_weights = dot_product_attention_weights( 261 | query_states, 262 | key_states, 263 | bias=attention_bias, 264 | deterministic=deterministic, 265 | dropout_rng=self.make_rng("dropout") if not deterministic else None, 266 | dropout_rate=self.config.attention_dropout, 267 | dtype=attention_dtype, 268 | ) 269 | 270 | query_states = lax.with_sharding_constraint(query_states, PS('data', None, 'model', None)) 271 | key_states = lax.with_sharding_constraint(key_states, PS('data', None, 'model', None)) 272 | value_states = lax.with_sharding_constraint(value_states, PS('data', None, 'model', None)) 273 | if self.attention_softmax_in_fp32: 274 | attn_weights = attn_weights.astype(self.dtype) 275 | attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) 276 | 277 | else: 278 | query, key, value = map(lambda x: x.astype(jnp.bfloat16), (query_states, key_states, value_states)) 279 | kv = jnp.stack([key, value], axis=2) 280 | 281 | attn_output = te_attn.cross_fused_attn( 282 | query, 283 | kv, 284 | None, 285 | ~attention_mask, 286 | None, 287 | attn_bias_type=te_attn.AttnBiasType.NO_BIAS, 288 | attn_mask_type=te_attn.AttnMaskType.PADDING_CAUSAL_MASK, 289 | scaling_factor=1.0 / math.sqrt(query.shape[-1]), 290 | dropout_probability=self.config.attention_dropout, 291 | is_training=not deterministic, 292 | ) 293 | 294 | attn_output = self._merge_heads(attn_output) 295 | attn_output = self.o_proj(attn_output) 296 | 297 | outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) 298 | return outputs 299 | 300 | 301 | class FlaxMistralDecoderLayer(nn.Module): 302 | config: MistralConfig 303 | dtype: jnp.dtype = jnp.float32 304 | 305 | def setup(self): 306 | self.input_layernorm = FlaxMistralRMSNorm(self.config, dtype=self.dtype) 307 | self.self_attn = FlaxMistralAttention(self.config, dtype=jnp.bfloat16) 308 | self.post_attention_layernorm = FlaxMistralRMSNorm(self.config, dtype=self.dtype) 309 | self.mlp = FlaxMistralMLP(self.config, dtype=jnp.bfloat16) 310 | 311 | def __call__( 312 | self, 313 | hidden_states, 314 | attention_mask=None, 315 | position_ids=None, 316 | deterministic: bool = True, 317 | init_cache: bool = False, 318 | output_attentions: bool = False, 319 | ): 320 | residual = hidden_states 321 | hidden_states = self.input_layernorm(hidden_states) 322 | outputs = self.self_attn( 323 | hidden_states, 324 | attention_mask=attention_mask, 325 | position_ids=position_ids, 326 | deterministic=deterministic, 327 | init_cache=init_cache, 328 | output_attentions=output_attentions, 329 | ) 330 | # residual connection 331 | attn_output = outputs[0] 332 | hidden_states = residual + attn_output 333 | 334 | residual = hidden_states 335 | hidden_states = self.post_attention_layernorm(hidden_states) 336 | hidden_states = self.mlp(hidden_states) 337 | # residual connection 338 | hidden_states = residual + hidden_states 339 | 340 | return (hidden_states,) + outputs[1:] 341 | 342 | 343 | class FlaxMistralPreTrainedModel(FlaxPreTrainedModel): 344 | """ 345 | An abstract class to handle weights initialization and a simple interface for 
downloading and loading pretrained 346 | models. 347 | """ 348 | 349 | config_class = MistralConfig 350 | base_model_prefix = "model" 351 | module_class: nn.Module = None 352 | 353 | partition_rules = { 354 | 'embed_tokens/embedding': PS('data', 'model'), 355 | 'lm_head': PS('data', 'model'), 356 | 'mlp/(gate|up)_proj': PS('data', 'model'), 357 | 'mlp/down_proj': PS('model', 'data'), 358 | 'self_attn/(k|q|v)_proj': PS('data', 'model'), 359 | 'self_attn/o_proj': PS('model', 'data'), 360 | 'norm/weight': PS('model'), 361 | } 362 | 363 | 364 | def __init__( 365 | self, 366 | config: MistralConfig, 367 | input_shape: Tuple = (1, 1), 368 | seed: int = 0, 369 | dtype: jnp.dtype = jnp.float32, 370 | _do_init: bool = True, 371 | **kwargs, 372 | ): 373 | module = self.module_class(config=config, dtype=dtype, **kwargs) 374 | super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) 375 | 376 | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: 377 | # init input tensors 378 | input_ids = jnp.zeros(input_shape, dtype="i4") 379 | attention_mask = jnp.ones_like(input_ids) 380 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) 381 | params_rng, dropout_rng = jax.random.split(rng) 382 | rngs = {"params": params_rng, "dropout": dropout_rng} 383 | 384 | random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"] 385 | 386 | if params is not None: 387 | random_params = flatten_dict(unfreeze(random_params)) 388 | params = flatten_dict(unfreeze(params)) 389 | for missing_key in self._missing_keys: 390 | params[missing_key] = random_params[missing_key] 391 | self._missing_keys = set() 392 | return freeze(unflatten_dict(params)) 393 | else: 394 | return freeze(random_params) 395 | 396 | def init_cache(self, batch_size, max_length): 397 | r""" 398 | Args: 399 | batch_size (`int`): 400 | batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. 401 | max_length (`int`): 402 | maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized 403 | cache. 
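                Example (illustrative sketch; assumes `model` is an already-loaded `FlaxMistralForCausalLM`
                and `input_ids`, `attention_mask`, `position_ids` are prepared `jnp` arrays):

                    past_key_values = model.init_cache(batch_size=1, max_length=128)
                    outputs = model(
                        input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_values=past_key_values,
                    )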
404 | """ 405 | # init input variables to retrieve cache 406 | input_ids = jnp.ones((batch_size, max_length)) 407 | attention_mask = jnp.ones_like(input_ids) 408 | position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) 409 | 410 | init_variables = self.module.init( 411 | jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True 412 | ) 413 | return unfreeze(init_variables["cache"]) 414 | 415 | def __call__( 416 | self, 417 | input_ids, 418 | attention_mask=None, 419 | position_ids=None, 420 | params: dict = None, 421 | past_key_values: dict = None, 422 | dropout_rng: jax.random.PRNGKey = None, 423 | train: bool = False, 424 | output_attentions: Optional[bool] = None, 425 | output_hidden_states: Optional[bool] = None, 426 | return_dict: Optional[bool] = None, 427 | ): 428 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 429 | output_hidden_states = ( 430 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 431 | ) 432 | return_dict = return_dict if return_dict is not None else self.config.return_dict 433 | 434 | batch_size, sequence_length = input_ids.shape 435 | 436 | if position_ids is None: 437 | if past_key_values is not None: 438 | raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") 439 | 440 | position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) 441 | 442 | if attention_mask is None: 443 | attention_mask = jnp.ones((batch_size, sequence_length)) 444 | 445 | # Handle any PRNG if needed 446 | rngs = {} 447 | if dropout_rng is not None: 448 | rngs["dropout"] = dropout_rng 449 | 450 | inputs = {"params": params or self.params} 451 | 452 | # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that it can be changed by FlaxMistralAttention module 453 | if past_key_values: 454 | inputs["cache"] = past_key_values 455 | mutable = ["cache"] 456 | else: 457 | mutable = False 458 | 459 | outputs = self.module.apply( 460 | inputs, 461 | jnp.array(input_ids, dtype="i4"), 462 | jnp.array(attention_mask, dtype="i4"), 463 | jnp.array(position_ids, dtype="i4"), 464 | not train, 465 | False, 466 | output_attentions, 467 | output_hidden_states, 468 | return_dict, 469 | rngs=rngs, 470 | mutable=mutable, 471 | ) 472 | 473 | # add updated cache to model output 474 | if past_key_values is not None and return_dict: 475 | outputs, past_key_values = outputs 476 | outputs["past_key_values"] = unfreeze(past_key_values["cache"]) 477 | return outputs 478 | elif past_key_values is not None and not return_dict: 479 | outputs, past_key_values = outputs 480 | outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] 481 | 482 | return outputs 483 | 484 | 485 | class FlaxMistralLayerCollection(nn.Module): 486 | config: MistralConfig 487 | dtype: jnp.dtype = jnp.float32 488 | 489 | def setup(self): 490 | self.blocks = [ 491 | FlaxMistralDecoderLayer(self.config, dtype=self.dtype, name=str(i)) 492 | for i in range(self.config.num_hidden_layers) 493 | ] 494 | 495 | def __call__( 496 | self, 497 | hidden_states, 498 | attention_mask=None, 499 | position_ids=None, 500 | deterministic: bool = True, 501 | init_cache: bool = False, 502 | output_attentions: bool = False, 503 | output_hidden_states: bool = False, 504 | return_dict: bool = False, 505 | ): 506 | all_attentions = () if output_attentions else None 507 | all_hidden_states = () if output_hidden_states else None 508 | 509 | for block in self.blocks: 510 | if output_hidden_states: 511 | all_hidden_states += (hidden_states,) 512 | layer_outputs = block( 513 | hidden_states, 514 | attention_mask=attention_mask, 515 | position_ids=position_ids, 516 | deterministic=deterministic, 517 | init_cache=init_cache, 518 | output_attentions=output_attentions, 519 | ) 520 | hidden_states = layer_outputs[0] 521 | 522 | if output_attentions: 523 | all_attentions += (layer_outputs[1],) 524 | 525 | # this contains possible `None` values - `FlaxMistralModule` will filter them out 526 | outputs = (hidden_states, all_hidden_states, all_attentions) 527 | 528 | return outputs 529 | 530 | 531 | class FlaxMistralModule(nn.Module): 532 | config: MistralConfig 533 | dtype: jnp.dtype = jnp.float32 534 | 535 | def setup(self): 536 | self.hidden_size = self.config.hidden_size 537 | embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range) 538 | self.embed_tokens = nn.Embed( 539 | self.config.vocab_size, 540 | self.hidden_size, 541 | embedding_init=embedding_init, 542 | dtype=self.dtype, 543 | ) 544 | self.layers = FlaxMistralLayerCollection(self.config, dtype=self.dtype) 545 | self.norm = FlaxMistralRMSNorm(self.config, dtype=self.dtype) 546 | 547 | def __call__( 548 | self, 549 | input_ids, 550 | attention_mask=None, 551 | position_ids=None, 552 | deterministic=True, 553 | init_cache: bool = False, 554 | output_attentions: bool = False, 555 | output_hidden_states: bool = False, 556 | return_dict: bool = True, 557 | ): 558 | input_embeds = self.embed_tokens(input_ids.astype("i4")) 559 | 560 | outputs = self.layers( 561 | input_embeds, 562 | position_ids=position_ids, 563 | attention_mask=attention_mask, 564 | deterministic=deterministic, 565 | init_cache=init_cache, 566 | 
output_attentions=output_attentions, 567 | output_hidden_states=output_hidden_states, 568 | return_dict=return_dict, 569 | ) 570 | 571 | hidden_states = outputs[0] 572 | hidden_states = self.norm(hidden_states) 573 | 574 | if output_hidden_states: 575 | all_hidden_states = outputs[1] + (hidden_states,) 576 | outputs = (hidden_states, all_hidden_states) + outputs[2:] 577 | else: 578 | outputs = (hidden_states,) + outputs[1:] 579 | 580 | if not return_dict: 581 | return tuple(v for v in outputs if v is not None) 582 | 583 | return FlaxBaseModelOutput( 584 | last_hidden_state=hidden_states, 585 | hidden_states=outputs[1], 586 | attentions=outputs[-1], 587 | ) 588 | 589 | 590 | class FlaxMistralModel(FlaxMistralPreTrainedModel): 591 | module_class = FlaxMistralModule 592 | 593 | class FlaxMistralForCausalLMModule(nn.Module): 594 | config: MistralConfig 595 | dtype: jnp.dtype = jnp.float32 596 | 597 | def setup(self): 598 | self.model = FlaxMistralModule(self.config, dtype=self.dtype) 599 | self.lm_head = nn.Dense( 600 | self.config.vocab_size, 601 | use_bias=False, 602 | dtype=jnp.bfloat16, 603 | kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), 604 | ) 605 | 606 | def __call__( 607 | self, 608 | input_ids, 609 | attention_mask=None, 610 | position_ids=None, 611 | deterministic: bool = True, 612 | init_cache: bool = False, 613 | output_attentions: bool = False, 614 | output_hidden_states: bool = False, 615 | return_dict: bool = True, 616 | ): 617 | outputs = self.model( 618 | input_ids, 619 | position_ids=position_ids, 620 | attention_mask=attention_mask, 621 | deterministic=deterministic, 622 | init_cache=init_cache, 623 | output_attentions=output_attentions, 624 | output_hidden_states=output_hidden_states, 625 | return_dict=return_dict, 626 | ) 627 | 628 | hidden_states = outputs[0] 629 | lm_logits = self.lm_head(hidden_states) 630 | 631 | if not return_dict: 632 | return (lm_logits,) + outputs[1:] 633 | 634 | return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) 635 | 636 | 637 | class FlaxMistralForCausalLM(FlaxMistralPreTrainedModel): 638 | module_class = FlaxMistralForCausalLMModule 639 | 640 | def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None): 641 | # initializing the cache 642 | batch_size, seq_length = input_ids.shape 643 | 644 | past_key_values = self.init_cache(batch_size, max_length) 645 | # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. 646 | # But since Mistral uses a causal mask, those positions are masked anyways. 
647 | # Thus we can create a single static attention_mask here, which is more efficient for compilation 648 | extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") 649 | if attention_mask is not None: 650 | position_ids = attention_mask.cumsum(axis=-1) - 1 651 | extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) 652 | else: 653 | position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) 654 | 655 | return { 656 | "past_key_values": past_key_values, 657 | "attention_mask": extended_attention_mask, 658 | "position_ids": position_ids, 659 | } 660 | 661 | def update_inputs_for_generation(self, model_outputs, model_kwargs): 662 | model_kwargs["past_key_values"] = model_outputs.past_key_values 663 | model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 664 | return model_kwargs 665 | -------------------------------------------------------------------------------- /magix/spmd_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import logging 3 | 4 | from functools import partial 5 | 6 | import numpy as np 7 | import jax 8 | from jax.sharding import PartitionSpec as PS 9 | from jax.sharding import NamedSharding, Mesh 10 | from jax.experimental import mesh_utils 11 | from jax._src.tree_util import GetAttrKey 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_sharding(k, v, sharding_config=None, mesh=None): 17 | def get_key(x): 18 | if isinstance(x, GetAttrKey): 19 | name = str(x)[1:] 20 | else: 21 | name = str(getattr(x, 'key', getattr(x, 'idx', None))) 22 | return name 23 | 24 | path = '/'.join([get_key(p) for p in k]) 25 | rule = PS(None) 26 | for param_name, sharding_rule in sharding_config.items(): 27 | if re.search(param_name, path): 28 | rule = sharding_rule 29 | break 30 | if len(v.shape) == 0: 31 | rule = PS() 32 | 33 | if mesh is None: 34 | return rule 35 | 36 | return NamedSharding(mesh, rule) 37 | 38 | 39 | def item_sharding(pytree): 40 | return jax.tree_map(lambda x: x.sharding, pytree) 41 | 42 | 43 | def initialize_opt_state(optimizer, sharded_params, sharding_config, mesh): 44 | get_sharding_fn = partial( 45 | get_sharding, 46 | sharding_config=sharding_config, 47 | mesh=mesh 48 | ) 49 | opt_shapes = jax.eval_shape(optimizer.init, sharded_params) 50 | opt_sharding = jax.tree_util.tree_map_with_path(get_sharding_fn, opt_shapes) 51 | opt_state = jax.jit(optimizer.init, out_shardings=opt_sharding)(sharded_params) 52 | logger.info("Optimizer shards initialized on devices") 53 | return opt_state 54 | 55 | 56 | def create_device_mesh(shape, names=('data', 'model')): 57 | if -1 in shape: 58 | from collections import Counter 59 | assert Counter(shape)[-1] == 1, "Only one -1 is allowed in shape" 60 | shape = np.array(jax.devices()).reshape(shape).shape 61 | 62 | return Mesh(devices=mesh_utils.create_device_mesh(shape), axis_names=names) 63 | 64 | 65 | def duplicate_over(ob, *dup_axes): 66 | def transform_spec(spec): 67 | new_axes = [ax if ax not in dup_axes else None for ax in spec] 68 | return PS(*new_axes) 69 | return jax.tree_map(transform_spec, ob, is_leaf=lambda x: isinstance(x, PS)) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages, find_namespace_packages 2 | 3 | setup( 4 | name='magix', 5 | version='0.0.1', 6 | 
packages=find_namespace_packages(include=['magix*']), 7 | license='Apache 2.0', 8 | author='Luyu Gao', 9 | author_email='luyug@cs.cmu.edu', 10 | python_requires='>=3.10', 11 | install_requires=[ 12 | 'transformers>=4.10.0', 13 | 'datasets>=1.1.3', 14 | 'simple_parsing', 15 | 'sentencepiece', 16 | 'orbax-checkpoint==0.4.8', 17 | ] 18 | ) 19 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from dataclasses import dataclass, field 5 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 6 | from tqdm import tqdm, trange 7 | from functools import partial 8 | 9 | import jax 10 | if jax.default_backend() == 'gpu': 11 | os.environ['XLA_FLAGS'] = ( 12 | # '--xla_gpu_enable_triton_softmax_fusion=true ' 13 | '--xla_gpu_triton_gemm_any=false ' 14 | '--xla_gpu_enable_async_collectives=true ' 15 | '--xla_gpu_enable_async_all_gather=true ' 16 | '--xla_gpu_enable_async_reduce_scatter=true ' 17 | '--xla_gpu_enable_latency_hiding_scheduler=true ' 18 | '--xla_gpu_enable_highest_priority_async_stream=true ' 19 | '--xla_gpu_collective_permute_decomposer_threshold=1024 ' 20 | '--xla_gpu_all_reduce_combine_threshold_bytes=51200 ' 21 | '--xla_gpu_simplify_all_fp_conversions=true ' 22 | ) 23 | import jax.numpy as jnp 24 | import optax 25 | import flax 26 | 27 | from jax.sharding import Mesh 28 | from jax.sharding import PartitionSpec as PS 29 | 30 | import datasets 31 | from transformers import AutoTokenizer 32 | from simple_parsing import ArgumentParser 33 | from simple_parsing.helpers import list_field 34 | 35 | 36 | import magix 37 | import magix.models 38 | from magix import ( 39 | get_chckpoint_manager, 40 | load_model_hub, 41 | load_model_and_optimizer_local, 42 | initialize_opt_state 43 | ) 44 | 45 | def apply_chat_template(turns: Iterable[Dict[str, str]], eos_token: str = None): 46 | ROLE_DICT = { 47 | 'user': '<|user|>', 48 | 'assistant': '<|assistant|>', 49 | 'system': '<|system|>', 50 | } 51 | def _format(turn): 52 | role, content = turn['role'], turn['content'] 53 | return f"{ROLE_DICT[role]}\n{content}{eos_token}" 54 | 55 | return '\n'.join(_format(turn) for turn in turns) 56 | 57 | 58 | class TrainDataset: 59 | def __init__( 60 | self, 61 | train_data, 62 | tokenizer, 63 | field_name: str = 'text', 64 | max_len: int = 1024, 65 | use_chat_template: bool = False, 66 | ): 67 | self.data = train_data 68 | self.tokenizer = tokenizer 69 | self.field_name = field_name 70 | self.max_len = max_len 71 | self.use_chat_template = use_chat_template 72 | 73 | def __len__(self): 74 | return len(self.data) 75 | 76 | def get_batch(self, indices): 77 | batch = self.data[indices] 78 | batch = batch[self.field_name] 79 | if self.use_chat_template: 80 | batch = [apply_chat_template(turns, eos_token=self.tokenizer.eos_token) for turns in batch] 81 | tokenized = self.tokenizer( 82 | batch, max_length=self.max_len+1, padding='max_length', 83 | truncation=True, return_tensors='np', 84 | ) 85 | return dict(tokenized) 86 | 87 | class Batches: 88 | def __init__( 89 | self, 90 | rng: jax.random.PRNGKey, 91 | dataset: TrainDataset, 92 | batch_size: int, 93 | shuffle: bool = False 94 | ): 95 | steps_per_epoch = len(dataset) // batch_size 96 | 97 | if shuffle: 98 | batch_idx = jax.random.permutation(rng, len(dataset)) 99 | else: 100 | batch_idx = jnp.arange(len(dataset)) 101 | 102 | batch_idx = batch_idx[: steps_per_epoch * batch_size] # 
Skip incomplete batch. 103 | batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) 104 | 105 | self.dataset = dataset 106 | self.batch_idx = batch_idx 107 | 108 | def __call__(self, step): 109 | idx = self.batch_idx[step] 110 | batch = self.dataset.get_batch(idx) 111 | return batch 112 | 113 | 114 | def decay_mask_fn(params): 115 | flat_params = flax.traverse_util.flatten_dict(params) 116 | flat_mask = {path: (path[-1] != "bias" and 'layernorm' not in path[-2]) for path in flat_params} 117 | return flax.traverse_util.unflatten_dict(flat_mask) 118 | 119 | 120 | @dataclass 121 | class TrainArgs: 122 | train_file: str = None 123 | train_data_config: str = None 124 | train_data_field: str = 'text' 125 | split: str = 'train' 126 | use_chat_template: bool = False 127 | checkpoint_dir: str = None 128 | max_length: int = 1024 129 | num_epochs: int = 1 130 | batch_size: int = 16 131 | num_target_passages: int = 16 132 | query_num_chunks: int = 4 133 | passage_num_chunks: int = 8 134 | learning_rate: float = 2e-6 135 | weight_decay: float = 0.0001 136 | adam_beta1: float = 0.9 137 | adam_beta2: float = 0.999 138 | max_grad_norm: float = 1.0 139 | save_steps: int = 200 140 | seed: int = 42 141 | 142 | @dataclass 143 | class ModelArgs: 144 | model_type: str = 'llama' 145 | model_name: str = None 146 | tokenizer_name: str = None 147 | model_cache_dir: str = None 148 | mesh_shape: List[int] = list_field(-1, 1) 149 | bf16_model_weights: bool = False 150 | 151 | def main(): 152 | parser = ArgumentParser() 153 | parser.add_arguments(TrainArgs, dest="train_args") 154 | parser.add_arguments(ModelArgs, dest="model_args") 155 | args = parser.parse_args() 156 | train_args: TrainArgs = args.train_args 157 | model_args: ModelArgs = args.model_args 158 | 159 | # logger with date and time 160 | logging.basicConfig( 161 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 162 | datefmt='%m/%d/%Y %H:%M:%S', 163 | level=logging.INFO 164 | ) 165 | logger = logging.getLogger(__name__) 166 | 167 | # dataset setup 168 | if train_args.train_file.endswith('.jsonl'): 169 | train_data = datasets.load_dataset('json', data_files=train_args.train_file)['train'] 170 | else: 171 | train_data = datasets.load_dataset( 172 | train_args.train_file, 173 | train_args.train_data_config 174 | )[train_args.split] 175 | tokenizer = AutoTokenizer.from_pretrained( 176 | model_args.tokenizer_name, 177 | add_eos_token=not train_args.use_chat_template, 178 | use_fast=True, padding_side='right', legacy=False) 179 | tokenizer.pad_token = tokenizer.eos_token 180 | train_dataset = TrainDataset(train_data, tokenizer, train_args.train_data_field, train_args.max_length, train_args.use_chat_template) 181 | 182 | # optimizer setup 183 | total_train_steps = len(train_dataset) // train_args.batch_size * train_args.num_epochs 184 | lr_schedule = optax.warmup_cosine_decay_schedule( 185 | 0, train_args.learning_rate, int(total_train_steps*0.1), int(total_train_steps*0.9)) 186 | 187 | optimizer = optax.adamw( 188 | lr_schedule, 189 | mask=decay_mask_fn, 190 | b1=train_args.adam_beta1, 191 | b2=train_args.adam_beta2, 192 | weight_decay=train_args.weight_decay, 193 | ) 194 | optimizer = optax.chain( 195 | optax.clip_by_global_norm(train_args.max_grad_norm), 196 | optimizer 197 | ) 198 | optimizer = optax.apply_if_finite(optimizer, 10) 199 | 200 | # initalize model parameters and optimizer state 201 | mesh = magix.create_device_mesh(model_args.mesh_shape) 202 | 203 | checkpoint_manager = get_chckpoint_manager(train_args.checkpoint_dir, 
train_args.save_steps) 204 | is_new_train = checkpoint_manager.latest_step() is None 205 | 206 | _model_cls = magix.models.CAUSAL_LM_MODEL_MAPPING.get(model_args.model_type, None) 207 | if _model_cls is None: 208 | raise NotImplementedError(f"Model type {model_args.model_type} is not implemented") 209 | sharding_config = _model_cls.partition_rules 210 | 211 | if is_new_train: 212 | logger.info("Loading model from hub") 213 | model, params = load_model_hub(_model_cls, model_args.model_name, sharding_config, mesh, half=model_args.bf16_model_weights) 214 | opt_state = initialize_opt_state(optimizer, params, sharding_config, mesh) 215 | else: 216 | logger.info("Loading model from checkpoint") 217 | model, params, opt_state = load_model_and_optimizer_local( 218 | _model_cls, optimizer, checkpoint_manager, sharding_config, mesh, model_name=model_args.model_name) 219 | 220 | 221 | def train_step(params, opt_state, batch, dropout_rng): 222 | def compute_loss(params, batch): 223 | input_ids = batch['input_ids'] 224 | attention_mask = jnp.logical_and(batch['attention_mask'][:,:-1], batch['attention_mask'][:,1:]).astype('bool') 225 | logits = model( 226 | input_ids=input_ids[:,:-1], attention_mask=attention_mask, 227 | params=params, train=True, dropout_rng=dropout_rng)[0] 228 | target_ids = input_ids[:,1:] 229 | loss = optax.softmax_cross_entropy_with_integer_labels(logits, target_ids) 230 | loss = loss * attention_mask / attention_mask.sum() 231 | loss = loss.sum() 232 | return loss 233 | 234 | loss, grads = jax.value_and_grad(compute_loss, argnums=0) (params, batch) 235 | metrics = {"loss": loss} 236 | 237 | updates, new_opt_state = optimizer.update(grads, opt_state, params) # transform & update state 238 | new_params = optax.apply_updates(params, updates) 239 | return new_params, new_opt_state, metrics 240 | 241 | p_train_step = jax.jit( 242 | train_step, 243 | donate_argnums=(0,1,2,3), 244 | out_shardings=(magix.item_sharding(params), magix.item_sharding(opt_state), None) 245 | ) 246 | 247 | 248 | rng = jax.random.key(train_args.seed) 249 | dropout_rng, data_rng = jax.random.split(rng) 250 | 251 | # train loop 252 | lastest_step = checkpoint_manager.latest_step() 253 | if lastest_step is None: 254 | lastest_step = -1 255 | 256 | train_metrics = [] 257 | 258 | def combine_metrics(list_of_dicts): 259 | return {key: jnp.array([d[key] for d in list_of_dicts]) for key in list_of_dicts[0]} 260 | 261 | 262 | epochs = tqdm(range(train_args.num_epochs), desc=f"Epoch ... 
(1/{train_args.num_epochs})", position=0) 263 | 264 | logger.info("Starting training loop...") 265 | logger.info(" Num examples = %d", len(train_dataset)) 266 | logger.info(" Num Epochs = %d", train_args.num_epochs) 267 | logger.info(" Instantaneous batch size = %d", train_args.batch_size) 268 | 269 | 270 | with mesh: 271 | for epoch in epochs: 272 | # Create sampling rng 273 | input_rng = jax.random.fold_in(data_rng, epoch) 274 | batch_loader = Batches( 275 | input_rng, train_dataset, train_args.batch_size, shuffle=True) 276 | steps_per_epoch = len(train_dataset) // train_args.batch_size 277 | # train 278 | for step in trange(steps_per_epoch): 279 | cur_step = epoch * (len(train_dataset) // train_args.batch_size) + step 280 | if lastest_step >= cur_step: 281 | continue 282 | elif lastest_step == cur_step: 283 | logger.info('Resuming training from step %d', cur_step) 284 | 285 | batch = batch_loader(step) 286 | dropout_rngs = jax.random.fold_in(dropout_rng, cur_step) 287 | params, opt_state, metrics = p_train_step(params, opt_state, batch, dropout_rngs) 288 | 289 | is_last_step = (cur_step + 1) == total_train_steps 290 | checkpoint_manager.save( 291 | cur_step, items={'model': params, 'optimizer': opt_state}, force=is_last_step 292 | ) 293 | train_metrics.append(metrics) 294 | 295 | if cur_step % 100 == 0 and cur_step > 0: 296 | print( 297 | f"Step... ({cur_step} | Loss: {combine_metrics(train_metrics)['loss'].mean()}, Learning Rate: {lr_schedule(cur_step)})", 298 | flush=True, 299 | ) 300 | train_metrics = [] 301 | 302 | epochs.write( 303 | f"Epoch... ({epoch + 1}/{train_args.num_epochs})" 304 | ) 305 | 306 | if __name__ == '__main__': 307 | main() -------------------------------------------------------------------------------- /train_lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | from dataclasses import dataclass, field 5 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 6 | from tqdm import tqdm, trange 7 | from functools import partial 8 | 9 | import jax 10 | if jax.default_backend() == 'gpu': 11 | os.environ['XLA_FLAGS'] = ( 12 | '--xla_gpu_triton_gemm_any=false ' 13 | '--xla_gpu_enable_async_collectives=true ' 14 | '--xla_gpu_enable_async_all_gather=true ' 15 | '--xla_gpu_enable_async_reduce_scatter=true ' 16 | '--xla_gpu_enable_latency_hiding_scheduler=true ' 17 | '--xla_gpu_enable_highest_priority_async_stream=true ' 18 | '--xla_gpu_collective_permute_decomposer_threshold=1024 ' 19 | '--xla_gpu_all_reduce_combine_threshold_bytes=51200 ' 20 | '--xla_gpu_simplify_all_fp_conversions=true ' 21 | ) 22 | import jax.numpy as jnp 23 | import optax 24 | import flax 25 | 26 | from jax.sharding import Mesh 27 | from jax.sharding import PartitionSpec as PS 28 | 29 | import datasets 30 | from transformers import AutoTokenizer 31 | from simple_parsing import ArgumentParser 32 | from simple_parsing.helpers import list_field 33 | 34 | 35 | import magix 36 | import magix.models 37 | import magix.lora 38 | from magix import ( 39 | get_chckpoint_manager, 40 | load_model_hub, 41 | ) 42 | 43 | def apply_chat_template(turns: Iterable[Dict[str, str]], eos_token: str = None): 44 | ROLE_DICT = { 45 | 'user': '<|user|>', 46 | 'assistant': '<|assistant|>', 47 | 'system': '<|system|>', 48 | } 49 | def _format(turn): 50 | role, content = turn['role'], turn['content'] 51 | return f"{ROLE_DICT[role]}\n{content}{eos_token}" 52 | 53 | return '\n'.join(_format(turn) for turn in turns) 54 | 55 | 56 | 
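# Illustrative example of the prompt format produced by `apply_chat_template` above
# (assuming `</s>` is the tokenizer's eos_token):
#   apply_chat_template(
#       [{'role': 'user', 'content': 'Hello'}, {'role': 'assistant', 'content': 'Hi!'}],
#       eos_token='</s>',
#   )
#   -> '<|user|>\nHello</s>\n<|assistant|>\nHi!</s>'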
class TrainDataset: 57 | def __init__( 58 | self, 59 | train_data, 60 | tokenizer, 61 | field_name: str = 'text', 62 | max_len: int = 1024, 63 | use_chat_template: bool = False, 64 | ): 65 | self.data = train_data 66 | self.tokenizer = tokenizer 67 | self.field_name = field_name 68 | self.max_len = max_len 69 | self.use_chat_template = use_chat_template 70 | 71 | def __len__(self): 72 | return len(self.data) 73 | 74 | def get_batch(self, indices): 75 | batch = self.data[indices] 76 | batch = batch[self.field_name] 77 | if self.use_chat_template: 78 | batch = [apply_chat_template(turns, eos_token=self.tokenizer.eos_token) for turns in batch] 79 | tokenized = self.tokenizer( 80 | batch, max_length=self.max_len+1, padding='max_length', 81 | truncation=True, return_tensors='np', 82 | ) 83 | return dict(tokenized) 84 | 85 | class Batches: 86 | def __init__( 87 | self, 88 | rng: jax.random.PRNGKey, 89 | dataset: TrainDataset, 90 | batch_size: int, 91 | shuffle: bool = False 92 | ): 93 | steps_per_epoch = len(dataset) // batch_size 94 | 95 | if shuffle: 96 | batch_idx = jax.random.permutation(rng, len(dataset)) 97 | else: 98 | batch_idx = jnp.arange(len(dataset)) 99 | 100 | batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. 101 | batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) 102 | 103 | self.dataset = dataset 104 | self.batch_idx = batch_idx 105 | 106 | def __call__(self, step): 107 | idx = self.batch_idx[step] 108 | batch = self.dataset.get_batch(idx) 109 | return batch 110 | 111 | 112 | def decay_mask_fn(params): 113 | flat_params = flax.traverse_util.flatten_dict(params) 114 | flat_mask = {path: (path[-1] != "bias" and 'layernorm' not in path[-2]) for path in flat_params} 115 | return flax.traverse_util.unflatten_dict(flat_mask) 116 | 117 | 118 | @dataclass 119 | class TrainArgs: 120 | train_file: str = None 121 | train_data_config: str = None 122 | train_data_field: str = 'text' 123 | split: str = 'train' 124 | use_chat_template: bool = False 125 | checkpoint_dir: str = None 126 | max_length: int = 1024 127 | num_epochs: int = 1 128 | batch_size: int = 16 129 | num_target_passages: int = 16 130 | query_num_chunks: int = 4 131 | passage_num_chunks: int = 8 132 | learning_rate: float = 2e-6 133 | weight_decay: float = 0.0001 134 | adam_beta1: float = 0.9 135 | adam_beta2: float = 0.999 136 | max_grad_norm: float = 1.0 137 | save_steps: int = 200 138 | seed: int = 42 139 | lora_alpha: float = 32.0 140 | lora_rank: int = 8 141 | 142 | @dataclass 143 | class ModelArgs: 144 | model_type: str = 'llama' 145 | model_name: str = None 146 | tokenizer_name: str = None 147 | model_cache_dir: str = None 148 | mesh_shape: List[int] = list_field(-1, 1) 149 | bf16_model_weights: bool = False 150 | 151 | def main(): 152 | parser = ArgumentParser() 153 | parser.add_arguments(TrainArgs, dest="train_args") 154 | parser.add_arguments(ModelArgs, dest="model_args") 155 | args = parser.parse_args() 156 | train_args: TrainArgs = args.train_args 157 | model_args: ModelArgs = args.model_args 158 | 159 | # logger with date and time 160 | logging.basicConfig( 161 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 162 | datefmt='%m/%d/%Y %H:%M:%S', 163 | level=logging.INFO 164 | ) 165 | logger = logging.getLogger(__name__) 166 | 167 | # dataset setup 168 | if train_args.train_file.endswith('.jsonl'): 169 | train_data = datasets.load_dataset('json', data_files=train_args.train_file)['train'] 170 | else: 171 | train_data = datasets.load_dataset( 172 | 
train_args.train_file, 173 | train_args.train_data_config 174 | )[train_args.split] 175 | tokenizer = AutoTokenizer.from_pretrained( 176 | model_args.tokenizer_name, 177 | add_eos_token=not train_args.use_chat_template, 178 | use_fast=True, padding_side='right', legacy=False) 179 | tokenizer.pad_token = tokenizer.eos_token 180 | train_dataset = TrainDataset(train_data, tokenizer, train_args.train_data_field, train_args.max_length, train_args.use_chat_template) 181 | 182 | # optimizer setup 183 | total_train_steps = len(train_dataset) // train_args.batch_size * train_args.num_epochs 184 | lr_schedule = optax.warmup_cosine_decay_schedule( 185 | 0, train_args.learning_rate, int(total_train_steps*0.1), int(total_train_steps*0.9)) 186 | 187 | optimizer = optax.adamw( 188 | lr_schedule, 189 | mask=decay_mask_fn, 190 | b1=train_args.adam_beta1, 191 | b2=train_args.adam_beta2, 192 | weight_decay=train_args.weight_decay, 193 | ) 194 | optimizer = optax.chain( 195 | optax.clip_by_global_norm(train_args.max_grad_norm), 196 | optimizer 197 | ) 198 | optimizer = optax.apply_if_finite(optimizer, 10) 199 | 200 | lora = magix.lora.Lora( 201 | alpha=train_args.lora_alpha, 202 | rules={ 203 | 'layers/.*/kernel': train_args.lora_rank, 204 | } 205 | ) 206 | 207 | # initalize model parameters and optimizer state 208 | mesh = magix.create_device_mesh(model_args.mesh_shape) 209 | 210 | checkpoint_manager = get_chckpoint_manager(train_args.checkpoint_dir, train_args.save_steps, items=['lora', 'optimizer']) 211 | is_new_train = checkpoint_manager.latest_step() is None 212 | 213 | _model_cls = magix.models.CAUSAL_LM_MODEL_MAPPING.get(model_args.model_type, None) 214 | if _model_cls is None: 215 | raise NotImplementedError(f"Model type {model_args.model_type} is not implemented") 216 | sharding_config = _model_cls.partition_rules 217 | 218 | logger.info("Loading model from hub") 219 | if model_args.model_cache_dir and os.path.exists(model_args.model_cache_dir): 220 | model, params = magix.checkpoint_utils.load_model_local( 221 | _model_cls, 222 | model_args.model_cache_dir, 223 | sharding_config, 224 | mesh, 225 | model_name=model_args.model_name, 226 | ) 227 | else: 228 | model, params = load_model_hub(_model_cls, model_args.model_name, sharding_config, mesh, half=model_args.bf16_model_weights) 229 | # magix.checkpoint_utils.save_model_local(params, model_args.model_cache_dir) 230 | 231 | rng = jax.random.key(train_args.seed) 232 | dropout_rng, data_rng, lora_rng = jax.random.split(rng, 3) 233 | 234 | def create_lora_and_opt_states(rng, params): 235 | lora_state = lora.init_params(rng, params) 236 | opt_state = optimizer.init(lora_state) 237 | return lora_state, opt_state 238 | 239 | lora_state_shapes, opt_shapes = jax.eval_shape(create_lora_and_opt_states, lora_rng, params) 240 | lora_sharding = magix.lora.create_lora_sharding(sharding_config, mesh, lora_state_shapes) 241 | opt_sharding = magix.lora.create_lora_sharding(sharding_config, mesh, opt_shapes) 242 | 243 | if is_new_train: 244 | lora_state = jax.jit(lora.init_params, out_shardings=lora_sharding) (lora_rng, params) 245 | opt_state = jax.jit(optimizer.init, out_shardings=opt_sharding)(lora_state) 246 | else: 247 | loaded = magix.checkpoint_utils.load_by_sharding( 248 | checkpoint_manager, 249 | items=['lora', 'optimizer'], 250 | dummies=[lora_state_shapes, opt_shapes], 251 | shardings=[lora_sharding, opt_sharding] 252 | ) 253 | lora_state, opt_state = loaded['lora'], loaded['optimizer'] 254 | 255 | 256 | def train_step(params, lora_state, opt_state, batch, 
dropout_rng): 257 | def compute_loss(params, lora_state, batch, dropout_rng): 258 | params = lora.apply(params, lora_state) 259 | input_ids = batch['input_ids'] 260 | attention_mask = jnp.logical_and(batch['attention_mask'][:,:-1], batch['attention_mask'][:,1:]).astype('bool') 261 | logits = model( 262 | input_ids=input_ids[:,:-1], attention_mask=attention_mask, 263 | params=params, train=True, dropout_rng=dropout_rng)[0] 264 | target_ids = input_ids[:,1:] 265 | loss = optax.softmax_cross_entropy_with_integer_labels(logits, target_ids) 266 | loss = loss * attention_mask / attention_mask.sum() 267 | loss = loss.sum() 268 | return loss 269 | 270 | loss, grads = jax.value_and_grad(compute_loss, argnums=1) (params, lora_state, batch, dropout_rng) 271 | metrics = {"loss": loss} 272 | 273 | updates, new_opt_state = optimizer.update(grads, opt_state, lora_state) 274 | new_lora_state = optax.apply_updates(lora_state, updates) 275 | return new_lora_state, new_opt_state, metrics 276 | 277 | p_train_step = jax.jit( 278 | train_step, 279 | donate_argnums=(1,2,3), 280 | out_shardings=( 281 | magix.item_sharding(lora_state), 282 | magix.item_sharding(opt_state), 283 | None 284 | ) 285 | ) 286 | p_train_step = partial(p_train_step, params) # safeguard params in a closure 287 | 288 | 289 | # train loop 290 | lastest_step = checkpoint_manager.latest_step() 291 | if lastest_step is None: 292 | lastest_step = -1 293 | 294 | train_metrics = [] 295 | 296 | def combine_metrics(list_of_dicts): 297 | return {key: jnp.array([d[key] for d in list_of_dicts]) for key in list_of_dicts[0]} 298 | 299 | 300 | epochs = tqdm(range(train_args.num_epochs), desc=f"Epoch ... (1/{train_args.num_epochs})", position=0) 301 | 302 | logger.info("Starting training loop...") 303 | logger.info(" Num examples = %d", len(train_dataset)) 304 | logger.info(" Num Epochs = %d", train_args.num_epochs) 305 | logger.info(" Instantaneous batch size = %d", train_args.batch_size) 306 | 307 | 308 | with mesh: 309 | for epoch in epochs: 310 | # Create sampling rng 311 | input_rng = jax.random.fold_in(data_rng, epoch) 312 | batch_loader = Batches( 313 | input_rng, train_dataset, train_args.batch_size, shuffle=True) 314 | steps_per_epoch = len(train_dataset) // train_args.batch_size 315 | # train 316 | for step in trange(steps_per_epoch): 317 | cur_step = epoch * (len(train_dataset) // train_args.batch_size) + step 318 | if lastest_step >= cur_step: 319 | continue 320 | elif lastest_step == cur_step: 321 | logger.info('Resuming training from step %d', cur_step) 322 | 323 | batch = batch_loader(step) 324 | dropout_rngs = jax.random.fold_in(dropout_rng, cur_step) 325 | lora_state, opt_state, metrics = p_train_step(lora_state, opt_state, batch, dropout_rngs) 326 | 327 | is_last_step = (cur_step + 1) == total_train_steps 328 | checkpoint_manager.save( 329 | cur_step, 330 | items={'lora': lora_state, 'optimizer': opt_state}, 331 | force=is_last_step 332 | ) 333 | train_metrics.append(metrics) 334 | 335 | if cur_step % 100 == 0 and cur_step > 0: 336 | print( 337 | f"Step... ({cur_step} | Loss: {combine_metrics(train_metrics)['loss'].mean()}, Learning Rate: {lr_schedule(cur_step)})", 338 | flush=True, 339 | ) 340 | train_metrics = [] 341 | 342 | epochs.write( 343 | f"Epoch... ({epoch + 1}/{train_args.num_epochs})" 344 | ) 345 | checkpoint_manager.wait_until_finished() 346 | 347 | if __name__ == '__main__': 348 | main() --------------------------------------------------------------------------------