├── gemma
│   ├── __init__.py
│   ├── layers.py
│   ├── layers_test.py
│   ├── positional_embeddings.py
│   ├── params.py
│   ├── positional_embeddings_test.py
│   ├── sampler_test.py
│   ├── transformer_test.py
│   ├── modules_test.py
│   ├── transformer.py
│   ├── modules.py
│   └── sampler.py
├── pyproject.toml
├── CONTRIBUTING.md
├── examples
│   └── sampling.py
├── README.md
├── colabs
│   ├── sampling_tutorial.ipynb
│   ├── gsm8k_eval.ipynb
│   └── fine_tuning_tutorial.ipynb
└── LICENSE

/gemma/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 DeepMind Technologies Limited.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [tool.poetry]
 2 | name = "gemma"
 3 | version = "1.0.0"
 4 | description = "Open weights large language model (LLM) from Google DeepMind."
 5 | authors = [
 6 |     "Alek Andreev",
 7 |     "Armand Joulin",
 8 |     "Cassidy Hardin",
 9 |     "Juliette Love",
10 |     "Kathleen Kenealy",
11 |     "Laurent Sifre",
12 |     "Léonard Hussenot",
13 |     "Mihir Sanjay Kale",
14 |     "Morgane Riviere",
15 |     "Robert Dadashi",
16 |     "Shreya Pathak",
17 |     "Surya Bhupatiraju",
18 |     "Thomas Mesnard"
19 | ]
20 | python-version = "3.10"
21 | license = "Apache-2.0"
22 | readme = "README.md"
23 | 
24 | [tool.poetry.dependencies]
25 | absl-py = "^2.1.0"
26 | sentencepiece = "^0.1.99"
27 | flax = ">=0.8"
28 | pytest = {version = "^8.0.0", optional = true}
29 | 
30 | [tool.poetry.extras]
31 | test = ["pytest"]
32 | 
33 | [build-system]
34 | requires = ["poetry-core"]
35 | build-backend = "poetry.core.masonry.api"
36 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # How to Contribute
 2 | 
 3 | We would love to accept your patches and contributions to this project.
 4 | 
 5 | ## Before you begin
 6 | 
 7 | ### Sign our Contributor License Agreement
 8 | 
 9 | Contributions to this project must be accompanied by a
10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA).
11 | You (or your employer) retain the copyright to your contribution; this simply
12 | gives us permission to use and redistribute your contributions as part of the
13 | project.
14 | 
15 | If you or your current employer have already signed the Google CLA (even if it
16 | was for a different project), you probably don't need to do it again.
17 | 
18 | Visit https://cla.developers.google.com/ to see your current agreements or to
19 | sign a new one.
20 | 
21 | ### Review our Community Guidelines
22 | 
23 | This project follows [Google's Open Source Community
24 | Guidelines](https://opensource.google/conduct/).
25 | 26 | ## Contribution process 27 | 28 | ### Code Reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use [GitHub pull requests](https://docs.github.com/articles/about-pull-requests) 32 | for this purpose. 33 | -------------------------------------------------------------------------------- /gemma/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Base layers.""" 16 | 17 | from flax import linen as nn 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | 22 | class Einsum(nn.Module): 23 | """Einsum is a convenience module for parameterized tensor multiplication.""" 24 | shape: tuple[int, ...] 25 | 26 | @nn.compact 27 | def __call__(self, eqn: str, x: jax.Array) -> jax.Array: 28 | w = self.param('w', nn.initializers.normal(), self.shape) 29 | return jnp.einsum(eqn, x, w) 30 | 31 | 32 | class RMSNorm(nn.Module): 33 | """RMSNorm layer.""" 34 | 35 | @nn.compact 36 | def __call__(self, x): 37 | scale = self.param('scale', nn.initializers.zeros_init(), (x.shape[-1])) 38 | var = jnp.mean(jnp.square(x), axis=-1, keepdims=True) 39 | normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) 40 | # normed_inputs is a rank-K tensor, K > 1 (K is typically 2 or 3). scale is 41 | # a rank-1 tensor. To avoid implicit rank-promotion, reshape scale to 42 | # a (1, ..., 1, D) tensor, so the rank of scale matches normed_inputs. 43 | scale = jnp.expand_dims(scale, axis=range(len(x.shape) - 1)) 44 | normed_inputs = normed_inputs * (1 + scale) 45 | return normed_inputs 46 | -------------------------------------------------------------------------------- /gemma/layers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for transformer layers.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from gemma import layers 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | 25 | class EinsumTest(parameterized.TestCase): 26 | 27 | @parameterized.parameters( 28 | dict( 29 | inputs_shape=(1, 4), 30 | params_shape=(3, 2, 4, 3), 31 | eqn='TD,SNDH->STNH', 32 | expected_shape=(3, 1, 2, 3), 33 | ), 34 | dict( 35 | inputs_shape=(1, 2, 4), 36 | params_shape=(2, 4, 8), 37 | eqn='ANH,NHD->AD', 38 | expected_shape=(1, 8), 39 | ), 40 | ) 41 | def test_einsum(self, inputs_shape, params_shape, eqn, expected_shape): 42 | einsum = layers.Einsum(params_shape) 43 | output = einsum.apply( 44 | {'params': {'w': jnp.ones(params_shape)}}, 45 | eqn, 46 | jnp.ones(inputs_shape), 47 | ) 48 | self.assertEqual(output.shape, expected_shape) 49 | 50 | @parameterized.parameters(dict(x=[0.1, 0.2], expected=[0.6324429, 1.2648858])) 51 | def test_rmsnorm(self, x, expected): 52 | x = jnp.array([x]) 53 | rmsnorm = layers.RMSNorm() 54 | params = rmsnorm.init(jax.random.PRNGKey(0), x) 55 | output = rmsnorm.apply(params, x) 56 | np.testing.assert_array_equal(output, jnp.array([expected])) 57 | 58 | 59 | if __name__ == '__main__': 60 | absltest.main() 61 | -------------------------------------------------------------------------------- /gemma/positional_embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Utils for positional embeddings (including RoPE).""" 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | _MAX_WAVELENGTH = 10_000 21 | 22 | 23 | def add_positional_embedding( 24 | input_embedding: jax.Array, 25 | position: int, 26 | max_wavelength: int = _MAX_WAVELENGTH, 27 | ) -> jax.Array: 28 | """Adds positional embeddings to input embeddings.""" 29 | embed_dim = input_embedding.shape[-1] 30 | num_timescales = embed_dim // 2 31 | log_timescale_increment = jnp.log(float(max_wavelength)) / jnp.maximum( 32 | jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1 33 | ) 34 | inv_timescales = jnp.exp( 35 | jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment 36 | ) 37 | scaled_time = position * inv_timescales 38 | signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)]) 39 | signal = jnp.pad(signal, [[0, jnp.mod(embed_dim, 2)]]) 40 | position_embedding = signal.astype(jnp.float32) 41 | 42 | return input_embedding + position_embedding 43 | 44 | 45 | def apply_rope( 46 | inputs: jax.Array, # [B, L] 47 | positions: jax.Array, # [B, L] 48 | head_dim: int, 49 | max_wavelength: int = _MAX_WAVELENGTH, 50 | ) -> jax.Array: 51 | """Applies RoPE.""" 52 | fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim 53 | timescale = max_wavelength**fraction 54 | 55 | sinusoid_inp = ( 56 | positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :] 57 | ) 58 | sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :] 59 | sin = jnp.sin(sinusoid_inp) 60 | cos = jnp.cos(sinusoid_inp) 61 | 62 | first_half, second_half = jnp.split(inputs, 2, axis=-1) 63 | first_part = first_half * cos - second_half * sin 64 | second_part = second_half * cos + first_half * sin 65 | out = jnp.concatenate([first_part, second_part], axis=-1) 66 | return out.astype(inputs.dtype) 67 | -------------------------------------------------------------------------------- /gemma/params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Utils for loading Gemma params.""" 16 | 17 | import functools 18 | from typing import Any, Mapping 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | import orbax.checkpoint 23 | 24 | Params = Mapping[str, Any] 25 | 26 | 27 | def load_and_format_params(path: str) -> Params: 28 | """Loads parameters and formats them for compatibility.""" 29 | params = load_params(path) 30 | param_state = jax.tree_util.tree_map(jnp.array, params) 31 | remapped_params = param_remapper(param_state) 32 | nested_params = nest_params(remapped_params) 33 | return nested_params 34 | 35 | 36 | @functools.cache 37 | def load_params(path: str) -> Params: 38 | """Loads parameters from a checkpoint path.""" 39 | checkpointer = orbax.checkpoint.PyTreeCheckpointer() 40 | params = checkpointer.restore(path) 41 | return params 42 | 43 | 44 | def param_remapper(orig_params: Params) -> Params: 45 | """Remaps params to new module layout. 46 | 47 | This is needed here because the model definition does not have a separate 48 | `mlp` module. 49 | 50 | Args: 51 | orig_params: original dict of parameters in Gemma format. 52 | 53 | Returns: 54 | dict of params with different names. 55 | """ 56 | new_params = {} 57 | for k, v in orig_params.items(): 58 | if 'mlp/' in k: 59 | layer_name, param = k.rsplit('/', maxsplit=1) 60 | if layer_name not in new_params: 61 | new_params[layer_name] = {} 62 | if 'w' in v: 63 | new_params[layer_name][param] = v['w'] 64 | else: 65 | new_params[k] = v 66 | return new_params 67 | 68 | 69 | def nest_params(params: Params) -> Params: 70 | """Nests params as a dict of dicts rather than a flat dict.""" 71 | nested_params = {} 72 | for path, param in params.items(): 73 | *path, leaf = path.split('/') 74 | subdict = nested_params 75 | for key in path: 76 | subdict = subdict.setdefault(key, {}) 77 | subdict[leaf] = param 78 | return nested_params 79 | -------------------------------------------------------------------------------- /gemma/positional_embeddings_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for the positional embeddings utilities.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from gemma import positional_embeddings 20 | import jax.numpy as jnp 21 | import numpy as np 22 | 23 | 24 | class PositionalEmbeddingsTest(parameterized.TestCase): 25 | 26 | @parameterized.parameters( 27 | dict( 28 | input_embedding_shape=(2, 1, 1, 5), 29 | position=3, 30 | max_wavelength=100, 31 | expected=[[[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]], 32 | [[[1.1411201, 1.0299965, 0.0100075, 1.99955, 1.0]]]] 33 | ) 34 | ) 35 | def test_adds_positional_embeddings( 36 | self, input_embedding_shape, position, max_wavelength, expected 37 | ): 38 | outputs = positional_embeddings.add_positional_embedding( 39 | jnp.ones(input_embedding_shape), position, max_wavelength 40 | ) 41 | np.testing.assert_array_almost_equal(outputs, jnp.array(expected)) 42 | 43 | @parameterized.parameters( 44 | dict( 45 | input_embedding_shape=(2, 1, 2, 4), 46 | position=3, 47 | head_dim=4, 48 | max_wavelength=100, 49 | expected=[ 50 | [[ 51 | [-1.1311126, 0.6598157, -0.8488725, 1.2508571], 52 | [-1.1311126, 0.6598157, -0.8488725, 1.2508571], 53 | ]], 54 | [[ 55 | [-1.1311126, 0.6598157, -0.8488725, 1.2508571], 56 | [-1.1311126, 0.6598157, -0.8488725, 1.2508571], 57 | ]], 58 | ], 59 | ) 60 | ) 61 | def test_rope_positional_embeddings( 62 | self, input_embedding_shape, position, head_dim, max_wavelength, expected 63 | ): 64 | outputs = positional_embeddings.apply_rope( 65 | jnp.ones(input_embedding_shape), 66 | jnp.array([[position]]), 67 | head_dim, 68 | max_wavelength, 69 | ) 70 | np.testing.assert_array_almost_equal(outputs, jnp.array(expected)) 71 | 72 | 73 | if __name__ == "__main__": 74 | absltest.main() 75 | -------------------------------------------------------------------------------- /examples/sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | r"""An example showing how to load a checkpoint and sample from it. 16 | 17 | Getting Started with Gemma Sampling: 18 | 19 | Prerequisites: 20 | 21 | 1. Download your Gemma checkpoint: Choose the desired checkpoint and download it. 22 | 2. Get the Gemma tokenizer: Download the tokenizer file required for your model. 23 | 3. Install Gemma: Follow the straightforward instructions in the README to install the Gemma repository. 24 | 25 | Ready to Sample! 26 | 27 | Here's how to run the sampling.py script: 28 | 29 | python sampling.py --path_checkpoint=${PATH_TO_THE_GEMMA_CHECKPOINT} \ 30 | --path_tokenizer=${PATH_TO_THE_GEMMA_TOKENIZER} \ 31 | --string_to_sample="Where is Paris?" 
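
Here `--path_checkpoint` should point at the directory holding the model
weights (for example the `2b/` directory extracted from the Kaggle archive),
and `--path_tokenizer` at the extracted `tokenizer.model` file; see the
README for download instructions.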
32 | """ 33 | from typing import Sequence 34 | 35 | from absl import app 36 | from absl import flags 37 | from gemma import params as params_lib 38 | from gemma import sampler as sampler_lib 39 | from gemma import transformer as transformer_lib 40 | 41 | import sentencepiece as spm 42 | 43 | _PATH_CHECKPOINT = flags.DEFINE_string( 44 | "path_checkpoint", None, required=True, help="Path to checkpoint." 45 | ) 46 | _PATH_TOKENIZER = flags.DEFINE_string( 47 | "path_tokenizer", None, required=True, help="Path to tokenizer." 48 | ) 49 | _TOTAL_GENERATION_STEPS = flags.DEFINE_integer( 50 | "total_sampling_steps", 51 | 128, 52 | help="Maximum number of step to run when decoding.", 53 | ) 54 | _STRING_TO_SAMPLE = flags.DEFINE_string( 55 | "string_to_sample", 56 | "Where is Paris ?", 57 | help="Input string to sample.", 58 | ) 59 | 60 | _CACHE_SIZE = 1024 61 | 62 | 63 | def _load_and_sample( 64 | *, 65 | path_checkpoint: str, 66 | path_tokenizer: str, 67 | input_string: str, 68 | cache_size: int, 69 | total_generation_steps: int, 70 | ) -> None: 71 | """Loads and samples a string from a checkpoint.""" 72 | print(f"Loading the parameters from {path_checkpoint}") 73 | parameters = params_lib.load_and_format_params(path_checkpoint) 74 | print("Parameters loaded.") 75 | # Create a sampler with the right param shapes. 76 | vocab = spm.SentencePieceProcessor() 77 | vocab.Load(path_tokenizer) 78 | transformer_config = transformer_lib.TransformerConfig.from_params( 79 | parameters, 80 | cache_size=cache_size 81 | ) 82 | transformer = transformer_lib.Transformer(transformer_config) 83 | sampler = sampler_lib.Sampler( 84 | transformer=transformer, 85 | vocab=vocab, 86 | params=parameters["transformer"], 87 | ) 88 | sampled_str = sampler( 89 | input_strings=[input_string], 90 | total_generation_steps=total_generation_steps, 91 | ).text 92 | 93 | print(f"Input string: {input_string}") 94 | print(f"Sampled string: {sampled_str}") 95 | 96 | 97 | def main(argv: Sequence[str]) -> None: 98 | 99 | if len(argv) > 1: 100 | raise app.UsageError("Too many command-line arguments.") 101 | 102 | _load_and_sample( 103 | path_checkpoint=_PATH_CHECKPOINT.value, 104 | path_tokenizer=_PATH_TOKENIZER.value, 105 | input_string=_STRING_TO_SAMPLE.value, 106 | cache_size=_CACHE_SIZE, 107 | total_generation_steps=_TOTAL_GENERATION_STEPS.value, 108 | ) 109 | 110 | 111 | if __name__ == "__main__": 112 | app.run(main) 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Gemma 2 | 3 | [Gemma](https://ai.google.dev/gemma) is a family of open-weights Large Language 4 | Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini 5 | research and technology. 6 | 7 | This repository contains an inference implementation and examples, based on the 8 | [Flax](https://github.com/google/flax) and [JAX](https://github.com/google/jax). 9 | 10 | ### Learn more about Gemma 11 | 12 | - The [Gemma technical report](https://ai.google.dev/gemma/technical-report) 13 | details the models' capabilities. 14 | - For tutorials, reference implementations in other ML frameworks, and more, 15 | visit https://ai.google.dev/gemma. 16 | 17 | ## Quick start 18 | 19 | ### Installation 20 | 21 | 1. To install Gemma you need to use Python 3.10 or higher. 22 | 23 | 2. Install JAX for CPU, GPU or TPU. Follow instructions at 24 | [the JAX website](https://jax.readthedocs.io/en/latest/installation.html). 25 | 26 | 3. 
Run 27 | 28 | ``` 29 | python -m venv gemma-demo 30 | . gemma-demo/bin/activate 31 | pip install git+https://github.com/google-deepmind/gemma.git 32 | ``` 33 | 34 | ### Downloading the models 35 | 36 | The model checkpoints are available through Kaggle at 37 | http://kaggle.com/models/google/gemma. Select one of the **Flax** model 38 | variations, click the ⤓ button to download the model archive, then extract the 39 | contents to a local directory. 40 | 41 | Alternatively, visit the [gemma]([https://huggingface.co/models?other=gemma.cpp](https://huggingface.co/models?other=gemma_jax)) 42 | models on the Hugging Face Hub. To download the model, you can run the following code if you have `huggingface_hub` installed: 43 | 44 | ``` 45 | from huggingface_hub import snapshot_download 46 | 47 | local_dir = snapshot_download(repo_id="google/gemma-2b-flax") 48 | snapshot_download(repo_id="google/gemma-2b-flax", local_dir=local_dir) 49 | ``` 50 | 51 | In both cases, the archive contains both the model weights and 52 | the tokenizer, for example the 2b Flax variation contains: 53 | 54 | ``` 55 | 2b/ # Directory containing model weights 56 | tokenizer.model # Tokenizer 57 | ``` 58 | 59 | ### Running the unit tests 60 | 61 | To run the unit tests, install the optional `[test]` dependencies (e.g. using 62 | `pip install -e .[test]` from the root of the source tree), then: 63 | 64 | ``` 65 | pytest . 66 | ``` 67 | 68 | Note that the tests in `sampler_test.py` are skipped by default since no 69 | tokenizer is distributed with the Gemma sources. To run these tests, download a 70 | tokenizer following the instructions above, and update the `_VOCAB` constant in 71 | `sampler_test.py` with the path to `tokenizer.model`. 72 | 73 | ## Examples 74 | 75 | To run the example sampling script, pass the paths to the weights directory and 76 | tokenizer: 77 | 78 | ``` 79 | python examples/sampling.py \ 80 | --path_checkpoint=/path/to/archive/contents/2b/ \ 81 | --path_tokenizer=/path/to/archive/contents/tokenizer.model 82 | ``` 83 | 84 | There are also several Colab notebook tutorials: 85 | 86 | - [`colabs/sampling_tutorial.ipynb`](https://colab.sandbox.google.com/github/google-deepmind/gemma/blob/main/colabs/sampling_tutorial.ipynb) 87 | contains a [Colab](http://colab.google) notebook with a sampling example. 88 | 89 | - [`colabs/fine_tuning_tutorial.ipynb`](https://colab.sandbox.google.com/github/google-deepmind/gemma/blob/main/colabs/fine_tuning_tutorial.ipynb) 90 | contains a [Colab](http://colab.google) with a basic tutorial on how to fine 91 | tune Gemma for a task, such as English to French translation. 92 | 93 | - [`colabs/gsm8k_eval.ipynb`](https://colab.sandbox.google.com/github/google-deepmind/gemma/blob/main/colabs/gsm8k_eval.ipynb) 94 | is a [Colab](http://colab.google) with a reference GSM8K eval 95 | implementation. 96 | 97 | To run these notebooks you will need to download a local copy of the weights and 98 | tokenizer (see above), and update the `ckpt_path` and `vocab_path` variables 99 | with the corresponding paths. 100 | 101 | ## System Requirements 102 | 103 | Gemma can run on a CPU, GPU and TPU. For GPU, we recommend a 8GB+ RAM on GPU for 104 | the 2B checkpoint and 24GB+ RAM on GPU for the 7B checkpoint. 105 | 106 | ## Contributing 107 | 108 | We are open to bug reports, pull requests (PR), and other contributions. Please 109 | see [CONTRIBUTING.md](CONTRIBUTING.md) for details on PRs. 
110 | 
111 | ## License
112 | 
113 | Copyright 2024 DeepMind Technologies Limited
114 | 
115 | This code is licensed under the Apache License, Version 2.0 (the "License");
116 | you may not use this file except in compliance with the License. You may obtain
117 | a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
118 | 
119 | Unless required by applicable law or agreed to in writing, software distributed
120 | under the License is distributed on an AS IS BASIS, WITHOUT WARRANTIES OR
121 | CONDITIONS OF ANY KIND, either express or implied. See the License for the
122 | specific language governing permissions and limitations under the License.
123 | 
124 | ## Disclaimer
125 | 
126 | This is not an official Google product.
127 | 
--------------------------------------------------------------------------------
/gemma/sampler_test.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2024 DeepMind Technologies Limited.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     http://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================
15 | """Minimal test for sampler."""
16 | 
17 | from typing import Iterable
18 | 
19 | from absl.testing import absltest
20 | from gemma import sampler as sampler_lib
21 | from gemma import transformer as transformer_lib
22 | import jax
23 | import jax.numpy as jnp
24 | import numpy as np
25 | 
26 | import sentencepiece as spm
27 | 
28 | 
29 | class MockVocab(spm.SentencePieceProcessor):
30 | 
31 |   def __init__(self):
32 |     super().__init__()
33 |     self._start_id = 3
34 |     self._mapping_text_to_id = {
35 |         '<pad>': 0,
36 |         '<s>': 1,
37 |         '</s>': 2,
38 |         'input': 3,
39 |         'string': 4,
40 |         'hello': 5,
41 |         'world': 6,
42 |         'Hello': 7,
43 |         'there': 8,
44 |         '!': 9,
45 |         'My': 10,
46 |         'name': 11,
47 |         'is': 12,
48 |         'Morgane': 13,
49 |     }
50 |     self._vocab_size = len(self._mapping_text_to_id)
51 | 
52 |   def pad_id(self) -> int:
53 |     return 0
54 | 
55 |   def bos_id(self) -> int:
56 |     return 1
57 | 
58 |   def eos_id(self) -> int:
59 |     return 2
60 | 
61 |   def GetPieceSize(self) -> int:  # pylint: disable=invalid-name
62 |     return self._vocab_size
63 | 
64 |   def DecodeIds(self, ids: Iterable[int]) -> str:  # pylint: disable=invalid-name
65 |     reverse_mapping = {v: k for k, v in self._mapping_text_to_id.items()}
66 |     return ' '.join(reverse_mapping[e] for e in ids)
67 | 
68 |   def EncodeAsIds(self, text: str) -> list[int]:  # pylint: disable=invalid-name
69 |     words = text.split(' ')
70 |     return [self._mapping_text_to_id[word] for word in words]
71 | 
72 | 
73 | class SamplerTest(absltest.TestCase):
74 | 
75 |   def test_samples(self):
76 |     vocab = MockVocab()
77 | 
78 |     transformer_config = transformer_lib.TransformerConfig(
79 |         num_layers=6,
80 |         num_embed=vocab.GetPieceSize(),
81 |         embed_dim=768,
82 |         hidden_dim=6144,
83 |         num_heads=4,
84 |         num_kv_heads=4,
85 |         head_dim=256,
86 |         max_cache_length=1024,
87 |     )
88 |     attention_mask = jnp.ones((1, 1, transformer_config.max_cache_length))
89 |     cache = 
transformer_config.init_cache(1, dtype=jnp.float32) 90 | transformer = transformer_lib.Transformer(transformer_config) 91 | params = transformer.init( 92 | jax.random.PRNGKey(0), 93 | jnp.array([[1]]), 94 | jnp.array([[1]]), 95 | cache, 96 | attention_mask, 97 | ) 98 | sampler = sampler_lib.Sampler( 99 | transformer=transformer, 100 | vocab=vocab, 101 | params=params['params'], 102 | ) 103 | 104 | result = sampler(['input string', 'hello world'], total_generation_steps=10) 105 | self.assertIsNotNone(result) 106 | 107 | def test_forward_equivalence(self): 108 | vocab = MockVocab() 109 | transformer_config = transformer_lib.TransformerConfig( 110 | num_layers=2, 111 | num_embed=vocab.GetPieceSize(), 112 | embed_dim=32, 113 | hidden_dim=64, 114 | num_heads=4, 115 | num_kv_heads=1, 116 | head_dim=64, 117 | max_cache_length=8, 118 | ) 119 | 120 | transformer = transformer_lib.Transformer(transformer_config) 121 | raw_input = 'Hello there ! My name is Morgane' 122 | token_input = jnp.asarray( 123 | [vocab.bos_id()] + vocab.EncodeAsIds(raw_input) 124 | ).reshape((1, -1)) 125 | batch_size = 1 126 | cache = transformer_config.init_cache(batch_size, dtype=jnp.float32) 127 | input_mask = token_input != vocab.pad_id() 128 | positions = transformer_lib.build_positions_from_mask( 129 | input_mask 130 | ) 131 | attention_mask = transformer_lib.make_causal_attn_mask( 132 | token_input != vocab.pad_id() 133 | ) 134 | 135 | n_input_tokens = token_input.shape[1] 136 | 137 | params = transformer.init( 138 | jax.random.PRNGKey(42), 139 | token_input, 140 | positions, 141 | cache, 142 | attention_mask, 143 | ) 144 | 145 | output_forward, _ = transformer.apply( 146 | params, 147 | last_tokens=token_input, 148 | positions=positions, 149 | cache=cache, 150 | attention_mask=attention_mask, 151 | ) 152 | output_forward = output_forward[0, :n_input_tokens] 153 | 154 | sampler = sampler_lib.Sampler( 155 | transformer=transformer, 156 | vocab=vocab, 157 | params=params['params'], 158 | ) 159 | 160 | output_transformer = sampler( 161 | [raw_input], total_generation_steps=10, echo=True 162 | ) 163 | out_logits = np.array(output_transformer.logits)[0, 1 : n_input_tokens + 1] 164 | 165 | np.testing.assert_almost_equal(output_forward, out_logits) 166 | 167 | 168 | if __name__ == '__main__': 169 | absltest.main() 170 | -------------------------------------------------------------------------------- /gemma/transformer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================ 15 | """Tests for the Gemma transformer.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from gemma import transformer as transformer_lib 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | jax.config.update('jax_numpy_rank_promotion', 'raise') 25 | 26 | 27 | class TransformerTest(parameterized.TestCase): 28 | 29 | @parameterized.parameters( 30 | dict( 31 | num_layers=3, 32 | num_embed=4, 33 | embed_dim=2, 34 | num_heads=2, 35 | num_kv_heads=2, 36 | hidden_dim=4, 37 | head_dim=4, 38 | cache_size=2, 39 | batch_size=1, 40 | expected_outputs_shape=(1, 1, 4), 41 | expected_cache_shape=(1, 2, 2, 4), 42 | ), 43 | dict( 44 | num_layers=3, 45 | num_embed=4, 46 | embed_dim=2, 47 | num_heads=2, 48 | num_kv_heads=1, 49 | hidden_dim=4, 50 | head_dim=4, 51 | cache_size=2, 52 | batch_size=1, 53 | expected_outputs_shape=(1, 1, 4), 54 | expected_cache_shape=(1, 2, 2, 4), 55 | ), 56 | ) 57 | def test_transformer( 58 | self, 59 | num_layers, 60 | num_embed, 61 | embed_dim, 62 | num_heads, 63 | num_kv_heads, 64 | hidden_dim, 65 | head_dim, 66 | cache_size, 67 | batch_size, 68 | expected_outputs_shape, 69 | expected_cache_shape, 70 | ): 71 | 72 | config = transformer_lib.TransformerConfig( 73 | num_layers=num_layers, 74 | num_embed=num_embed, 75 | embed_dim=embed_dim, 76 | hidden_dim=hidden_dim, 77 | num_heads=num_heads, 78 | head_dim=head_dim, 79 | num_kv_heads=num_kv_heads, 80 | max_cache_length=cache_size, 81 | ) 82 | cache = config.init_cache(batch_size, dtype=jnp.float32) 83 | attention_mask = jnp.ones((batch_size, 1, cache_size), dtype=jnp.bool) 84 | transformer = transformer_lib.Transformer(config=config) 85 | params = transformer.init( 86 | jax.random.PRNGKey(0), 87 | jnp.array([[1]]), 88 | jnp.array([[1]]), 89 | cache, 90 | attention_mask, 91 | ) 92 | 93 | outputs, cache = transformer.apply( 94 | params, jnp.array([[1]]), jnp.array([[1]]), cache, attention_mask 95 | ) 96 | 97 | self.assertEqual(outputs.shape, expected_outputs_shape) 98 | self.assertEqual(cache['layer_0']['v'].shape, expected_cache_shape) 99 | 100 | @parameterized.parameters([ 101 | dict( 102 | config=transformer_lib.TransformerConfig( 103 | num_layers=2, 104 | num_embed=0, # unused 105 | embed_dim=0, # unused 106 | hidden_dim=0, # unused 107 | num_heads=3, 108 | head_dim=4, 109 | num_kv_heads=3, 110 | max_cache_length=2, 111 | ), 112 | keys=['layer_0', 'layer_1'], 113 | k_shape=(1, 2, 3, 4), 114 | v_shape=(1, 2, 3, 4), 115 | ) 116 | ]) 117 | def test_creates_cache(self, config, keys, k_shape, v_shape): 118 | cache = config.init_cache(1) 119 | self.assertEqual(list(cache.keys()), keys) 120 | self.assertEqual(cache['layer_0']['k'].shape, k_shape) 121 | self.assertEqual(cache['layer_0']['v'].shape, v_shape) 122 | 123 | @parameterized.parameters([ 124 | dict( 125 | batch_size=1, 126 | seq_size=4, 127 | config=transformer_lib.TransformerConfig( 128 | num_layers=2, 129 | num_embed=4, # unused 130 | embed_dim=2, 131 | hidden_dim=12, # unused 132 | num_heads=3, 133 | head_dim=4, 134 | num_kv_heads=3, 135 | max_cache_length=6, 136 | ), 137 | ) 138 | ]) 139 | def test_forward_no_cache( 140 | self, 141 | batch_size: int, 142 | seq_size: int, 143 | config: transformer_lib.TransformerConfig, 144 | ): 145 | 146 | token_input = jnp.ones((batch_size, seq_size), dtype=jnp.int32) 147 | empty_cache = config.init_cache(batch_size, dtype=jnp.float32) 148 | transformer = 
transformer_lib.Transformer(config=config) 149 | attention_mask = jnp.ones( 150 | (batch_size, seq_size, config.max_cache_length), dtype=jnp.bool 151 | ) 152 | positions = transformer_lib.build_positions_from_mask(token_input != 0) 153 | params = transformer.init( 154 | jax.random.PRNGKey(0), 155 | token_input, 156 | positions, 157 | empty_cache, 158 | attention_mask, 159 | ) 160 | 161 | output_cache, _ = transformer.apply( 162 | params, token_input, positions, empty_cache, attention_mask 163 | ) 164 | 165 | attention_mask = jnp.ones((batch_size, seq_size, seq_size), dtype=jnp.bool) 166 | output_none, cache_none = transformer.apply( 167 | params, token_input, positions, None, attention_mask 168 | ) 169 | 170 | self.assertIsNone(cache_none) 171 | np.testing.assert_array_almost_equal(output_cache, output_none, 1e-5) 172 | 173 | 174 | if __name__ == '__main__': 175 | absltest.main() 176 | -------------------------------------------------------------------------------- /gemma/modules_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Tests for transformer modules.""" 16 | 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from gemma import modules 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | 25 | class EmbedderTest(parameterized.TestCase): 26 | 27 | @parameterized.parameters( 28 | dict( 29 | vocab_size=10, 30 | embed_dim=4, 31 | inputs=[2, 3], 32 | expected=[[2.0, 2.0, 2.0, 2.0], [2.0, 2.0, 2.0, 2.0]], 33 | ), 34 | ) 35 | def test_encodes(self, vocab_size, embed_dim, inputs, expected): 36 | embedder = modules.Embedder(vocab_size=vocab_size, embed_dim=embed_dim) 37 | output = embedder.apply( 38 | {'params': {'input_embedding': jnp.ones((vocab_size, embed_dim))}}, 39 | inputs, 40 | method=modules.Embedder.encode, 41 | ) 42 | np.testing.assert_array_equal(output, jnp.array(expected)) 43 | 44 | @parameterized.parameters( 45 | dict( 46 | vocab_size=5, 47 | embed_dim=2, 48 | inputs=[[1, 2]], 49 | expected=[[3.0, 3.0, 3.0, 3.0, 3.0]], 50 | ), 51 | ) 52 | def test_decodes(self, vocab_size, embed_dim, inputs, expected): 53 | embedder = modules.Embedder(vocab_size=vocab_size, embed_dim=embed_dim) 54 | output = embedder.apply( 55 | {'params': {'input_embedding': jnp.ones((vocab_size, embed_dim))}}, 56 | jnp.array(inputs), 57 | method=modules.Embedder.decode, 58 | ) 59 | np.testing.assert_array_equal(output, jnp.array(expected)) 60 | 61 | 62 | class AttentionTest(parameterized.TestCase): 63 | 64 | @parameterized.parameters( 65 | dict( 66 | num_heads=2, 67 | head_dim=4, 68 | features=8, 69 | segment_pos=0, 70 | cache_size=2, 71 | batch_size=2, 72 | expected_cache_shape=(2, 2, 2, 4), 73 | expected_output_shape=(2, 1, 8), 74 | ), 75 | ) 76 | def test_attention( 77 | self, 78 | 
num_heads, 79 | head_dim, 80 | features, 81 | segment_pos, 82 | cache_size, 83 | batch_size, 84 | expected_cache_shape, 85 | expected_output_shape, 86 | ): 87 | attn_mask = jnp.ones((batch_size, 1, num_heads)) 88 | attn = modules.Attention(num_heads, num_heads, features, head_dim) 89 | cache = modules.Attention.init_cache( 90 | cache_size, num_heads, head_dim, batch_size, dtype=jnp.float32 91 | ) 92 | x = jnp.ones((batch_size, 1, features)) 93 | params = attn.init( 94 | jax.random.PRNGKey(0), 95 | x, 96 | jnp.array([[segment_pos]]), 97 | cache, 98 | attn_mask, 99 | ) 100 | cache, output = attn.apply( 101 | params, x, jnp.array([[segment_pos]]), cache, attn_mask 102 | ) 103 | 104 | self.assertEqual(cache['k'].shape, expected_cache_shape) 105 | self.assertEqual(output.shape, expected_output_shape) 106 | 107 | 108 | class FeedForwardTest(parameterized.TestCase): 109 | 110 | @parameterized.parameters( 111 | dict( 112 | features=2, 113 | hidden_dim=3, 114 | batch_size=2, 115 | expected_val=[11.72758674, 47.99916], 116 | expected_shape=(2, 1, 2), 117 | ), 118 | ) 119 | def test_ffw( 120 | self, features, hidden_dim, batch_size, expected_val, expected_shape 121 | ): 122 | inputs = jnp.arange(1, batch_size+1)[:, None, None] 123 | inputs = jnp.repeat(inputs, features, axis=-1) 124 | ffw = modules.FeedForward(features=features, hidden_dim=hidden_dim) 125 | params = { 126 | 'gating_einsum': jnp.ones((2, features, hidden_dim)), 127 | 'linear': jnp.ones((hidden_dim, features)), 128 | } 129 | 130 | outputs = ffw.apply({'params': params}, inputs) 131 | 132 | np.testing.assert_array_almost_equal(outputs[:, 0, 0], expected_val) 133 | self.assertEqual(outputs.shape, expected_shape) 134 | 135 | 136 | class BlockTest(parameterized.TestCase): 137 | 138 | @parameterized.parameters( 139 | dict( 140 | num_heads=2, 141 | embed_dim=4, 142 | head_dim=6, 143 | cache_size=3, 144 | batch_size=2, 145 | expected_cache_shape=(2, 3, 2, 6), 146 | expected_output_shape=(2, 1, 4), 147 | ), 148 | ) 149 | def test_block( 150 | self, 151 | num_heads, 152 | embed_dim, 153 | head_dim, 154 | cache_size, 155 | batch_size, 156 | expected_cache_shape, 157 | expected_output_shape, 158 | ): 159 | inputs = jnp.ones((batch_size, 1, embed_dim)) 160 | cache = modules.Attention.init_cache( 161 | cache_size, num_heads, head_dim, batch_size, dtype=jnp.float32 162 | ) 163 | attn_mask = jnp.ones((batch_size, 1, cache_size)) 164 | block = modules.Block(num_heads, num_heads, embed_dim, head_dim, 1) 165 | params = block.init( 166 | jax.random.PRNGKey(0), inputs, jnp.array([[0]]), cache, attn_mask 167 | ) 168 | 169 | new_cache, outputs = block.apply( 170 | params, inputs, jnp.array([[0]]), cache, attn_mask 171 | ) 172 | 173 | self.assertEqual(new_cache['k'].shape, expected_cache_shape) 174 | self.assertEqual(outputs.shape, expected_output_shape) 175 | 176 | 177 | if __name__ == '__main__': 178 | absltest.main() 179 | -------------------------------------------------------------------------------- /gemma/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Gemma transformer.""" 16 | 17 | import dataclasses 18 | 19 | from flax import linen as nn 20 | from gemma import layers 21 | from gemma import modules 22 | from gemma import params as params_lib 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | Cache = dict[str, modules.LayerCache] 27 | 28 | 29 | def make_causal_attn_mask( 30 | input_mask: jax.Array, 31 | ) -> jax.Array: 32 | """Attention mask in batch mode. 33 | 34 | Args: 35 | input_mask: Input mask for the input 36 | 37 | Returns: 38 | Attention mask. 39 | """ 40 | seq_len = input_mask.shape[-1] 41 | attn_mask = input_mask[..., None, :] 42 | causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_)) 43 | # Prefixes can be attended by all tokens 44 | attn_mask *= causal_mask[None, ...] 45 | return attn_mask 46 | 47 | 48 | def build_positions_from_mask(input_mask: jax.Array) -> jax.Array: 49 | """Computes the `positions` from the `input_mask`. 50 | 51 | Args: 52 | input_mask: The tokens `input_mask`, True for non-padded tokens only. 53 | 54 | Returns: 55 | The indices to use for RoPE and absolute position encodings for the given 56 | input mask. 57 | """ 58 | positions = jnp.cumsum(input_mask, axis=-1) 59 | # Subtract one for all positions from the first valid one as they are 60 | # 0-indexed 61 | return positions - (positions >= 1) 62 | 63 | 64 | @dataclasses.dataclass(frozen=True) 65 | class TransformerConfig: 66 | """Configuration for the gemma transformer.""" 67 | 68 | num_layers: int 69 | num_embed: int 70 | embed_dim: int 71 | hidden_dim: int 72 | num_heads: int 73 | head_dim: int 74 | num_kv_heads: int 75 | max_cache_length: int = 1024 76 | 77 | @classmethod 78 | def from_params( 79 | cls, params: params_lib.Params, cache_size: int = 1024 80 | ) -> 'TransformerConfig': 81 | """Creates a TransformerConfig from loaded parameters.""" 82 | num_layers = ( 83 | max([ 84 | int(k.split('_')[1]) 85 | for k in params['transformer'].keys() 86 | if 'layer_' in k 87 | ]) 88 | + 1 89 | ) 90 | 91 | hidden_dim, embed_dim = ( 92 | params['transformer']['layer_0']['mlp']['linear'].shape 93 | ) 94 | 95 | num_heads, head_dim, _ = ( 96 | params['transformer']['layer_0']['attn']['attn_vec_einsum']['w'].shape 97 | ) 98 | 99 | use_qkv_einsum = 'qkv_einsum' in params['transformer']['layer_0']['attn'] 100 | if use_qkv_einsum: 101 | num_kv_heads = num_heads 102 | else: 103 | num_kv_heads = params['transformer']['layer_0']['attn']['kv_einsum'][ 104 | 'w' 105 | ].shape[1] 106 | 107 | num_embed = params['transformer']['embedder']['input_embedding'].shape[0] 108 | 109 | return cls( 110 | num_layers=num_layers, 111 | num_embed=num_embed, 112 | embed_dim=embed_dim, 113 | hidden_dim=hidden_dim, 114 | num_heads=num_heads, 115 | head_dim=head_dim, 116 | num_kv_heads=num_kv_heads, 117 | max_cache_length=cache_size, 118 | ) 119 | 120 | def init_cache( 121 | self, 122 | batch_size: int, 123 | dtype: jnp.dtype = jnp.bfloat16, 124 | ) -> Cache: 125 | """Initializes a new Transformer cache.""" 126 | cache = { 127 | f'layer_{i}': 
modules.Attention.init_cache( 128 | self.max_cache_length, 129 | self.num_heads, 130 | self.head_dim, 131 | batch_size, 132 | dtype, 133 | ) 134 | for i in range(self.num_layers) 135 | } 136 | return cache 137 | 138 | 139 | class Transformer(nn.Module): 140 | """Gemma transformer.""" 141 | 142 | config: TransformerConfig 143 | 144 | def setup(self): 145 | self.embedder = modules.Embedder( 146 | vocab_size=self.config.num_embed, 147 | embed_dim=self.config.embed_dim, 148 | ) 149 | self.blocks = [ 150 | modules.Block( 151 | name=f'layer_{i}', 152 | num_heads=self.config.num_heads, 153 | num_kv_heads=self.config.num_kv_heads, 154 | embed_dim=self.config.embed_dim, 155 | head_dim=self.config.head_dim, 156 | hidden_dim=self.config.hidden_dim, 157 | ) 158 | for i in range(self.config.num_layers) 159 | ] 160 | self.final_norm = layers.RMSNorm() 161 | 162 | def __call__( 163 | self, 164 | last_tokens: jax.Array, # [B, L] 165 | positions: jax.Array, # [B, L] 166 | cache: Cache | None, # (sequence length L') 167 | attention_mask: jax.Array, # [B, L, L'] 168 | ) -> tuple[jax.Array, Cache | None]: 169 | """Transformer forward pass. 170 | 171 | You can run this forward pass two ways: with or without an attention kv 172 | cache. 173 | 174 | Args: 175 | last_tokens: input sequence of tokens. 176 | positions: input absolute positions. 177 | cache: Attention KV cache or None. 178 | attention_mask: transformer input mask. 179 | 180 | Returns: 181 | predicted_logits, new_cache 182 | 183 | predicted_logits: output logits predicted by the model 184 | new_cache: updated cache if the input cache is not None, None elsewhere. 185 | """ 186 | x = self.embedder.encode(last_tokens) 187 | for i, block in enumerate(self.blocks): 188 | layer_name = f'layer_{i}' 189 | layer_cache = cache[layer_name] if cache else None 190 | layer_cache, x = block( 191 | x, 192 | positions, 193 | layer_cache, 194 | attention_mask, 195 | ) 196 | if cache is not None: 197 | cache[layer_name] = layer_cache # pytype: disable=container-type-mismatch 198 | 199 | x = self.final_norm(x) 200 | logits = self.embedder.decode(x) 201 | 202 | return logits, cache # pytype: disable=bad-return-type 203 | -------------------------------------------------------------------------------- /gemma/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Transformer sub-modules.""" 16 | 17 | from flax import linen as nn 18 | from gemma import layers 19 | from gemma import positional_embeddings 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | K_MASK = -2.3819763e38 # Set to a large negative number. 
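# Using a large finite negative value rather than -inf keeps the masked
# softmax numerically well-defined (no NaNs) even if an entire row of
# attention logits happens to be masked out.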
24 | LayerCache = dict[str, jax.Array] 25 | 26 | 27 | class Embedder(nn.Module): 28 | """Embedder module.""" 29 | 30 | vocab_size: int 31 | embed_dim: int 32 | 33 | def setup(self): 34 | self.input_embedding_table = self.param( 35 | 'input_embedding', 36 | nn.initializers.normal(), 37 | (self.vocab_size, self.embed_dim), 38 | ) 39 | 40 | def encode(self, x: jax.Array) -> jax.Array: 41 | x = self.input_embedding_table[(x,)] 42 | x *= jnp.sqrt(self.embed_dim).astype(x.dtype) 43 | return x 44 | 45 | def decode(self, x: jax.Array) -> jax.Array: 46 | return jnp.dot(x, self.input_embedding_table.T) 47 | 48 | 49 | class Attention(nn.Module): 50 | """Attention module.""" 51 | 52 | num_heads: int 53 | num_kv_heads: int 54 | features: int 55 | head_dim: int 56 | 57 | @property 58 | def use_qkv_einsum(self): 59 | return self.num_kv_heads == self.num_heads 60 | 61 | def setup(self): 62 | self.attn_vec_einsum = layers.Einsum( 63 | shape=(self.num_heads, self.head_dim, self.features), 64 | ) 65 | 66 | if self.use_qkv_einsum: 67 | self.qkv_einsum = layers.Einsum( 68 | shape=(3, self.num_heads, self.features, self.head_dim), 69 | ) 70 | else: 71 | self.q_einsum = layers.Einsum( 72 | shape=(self.num_heads, self.features, self.head_dim), 73 | ) 74 | self.kv_einsum = layers.Einsum( 75 | shape=(2, self.num_kv_heads, self.features, self.head_dim), 76 | ) 77 | 78 | def __call__( 79 | self, 80 | x: jax.Array, 81 | segment_pos: jax.Array, 82 | cache: LayerCache | None, 83 | attn_mask: jax.Array, 84 | ) -> tuple[LayerCache | None, jax.Array]: 85 | seq_len = x.shape[1] 86 | 87 | if self.use_qkv_einsum: 88 | query_proj, key_proj, value_proj = self.qkv_einsum('BTD,SNDH->SBTNH', x) 89 | else: 90 | query_proj = self.q_einsum('BTD,NDH->BTNH', x) 91 | key_proj, value_proj = self.kv_einsum('BSD,CKDH->CBSKH', x) 92 | 93 | query_proj = positional_embeddings.apply_rope( 94 | query_proj, 95 | segment_pos, 96 | head_dim=self.head_dim, 97 | ) 98 | query_scaled = query_proj * self.head_dim**-0.5 99 | key_proj = positional_embeddings.apply_rope( 100 | key_proj, 101 | segment_pos, 102 | head_dim=self.head_dim, 103 | ) 104 | 105 | if not self.use_qkv_einsum: 106 | value_proj = jnp.repeat(value_proj, self.num_heads, axis=-2) 107 | key_proj = jnp.repeat(key_proj, self.num_heads, axis=-2) 108 | # Cache is left aligned. 
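    # When a cache is provided, this step's keys/values are written into it
    # at offset `end_index % cache_size` (see `slice_indices` below), so the
    # cache acts as a rolling buffer over the most recent `cache_size`
    # positions.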
109 | if cache is not None: 110 | end_index = cache['end_index'][0] 111 | slice_indices = (0, end_index % cache['v'].shape[1], 0, 0) 112 | value_proj = jax.lax.dynamic_update_slice( 113 | cache['v'], 114 | value_proj, 115 | slice_indices, 116 | ) 117 | key_proj = jax.lax.dynamic_update_slice( 118 | cache['k'], key_proj, slice_indices 119 | ) 120 | 121 | logits = jnp.einsum('BTNH,BSNH->BTNS', query_scaled, key_proj) 122 | padded_logits = jnp.where( 123 | (jnp.expand_dims(attn_mask, -2)), logits, K_MASK 124 | ) 125 | probs = jax.nn.softmax(padded_logits, axis=-1).astype(key_proj.dtype) 126 | encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) 127 | attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded) 128 | 129 | if cache is not None: 130 | new_cache = { 131 | 'v': value_proj, 132 | 'k': key_proj, 133 | 'end_index': cache['end_index'] + seq_len, 134 | } 135 | else: 136 | new_cache = None 137 | 138 | return new_cache, attn_output 139 | 140 | @classmethod 141 | def init_cache( 142 | cls, 143 | cache_size: int, 144 | num_heads: int, 145 | head_dim: int, 146 | batch_size: int, 147 | dtype: jnp.dtype = jnp.bfloat16, 148 | ) -> LayerCache: 149 | del cls # not used 150 | return { 151 | 'v': jnp.zeros( 152 | (batch_size, cache_size, num_heads, head_dim), dtype=dtype 153 | ), 154 | 'k': jnp.zeros( 155 | (batch_size, cache_size, num_heads, head_dim), dtype=dtype 156 | ), 157 | 'end_index': jnp.zeros((batch_size,), dtype=jnp.int32), 158 | } 159 | 160 | 161 | class FeedForward(nn.Module): 162 | """Feed forward module.""" 163 | 164 | features: int 165 | hidden_dim: int 166 | 167 | @nn.compact 168 | def __call__(self, x): 169 | w_gating = self.param( 170 | 'gating_einsum', 171 | nn.initializers.zeros_init(), 172 | ((2, self.features, self.hidden_dim)), 173 | ) 174 | ff_gate = jnp.dot(x, w_gating[0]) 175 | gate_value = nn.gelu(ff_gate) 176 | 177 | ff1 = jnp.dot(x, w_gating[1]) 178 | activations = gate_value * ff1 179 | 180 | w_linear = self.param( 181 | 'linear', 182 | nn.initializers.zeros_init(), 183 | (self.hidden_dim, self.features), 184 | ) 185 | outputs = jnp.dot(activations, w_linear) 186 | 187 | return outputs 188 | 189 | 190 | class Block(nn.Module): 191 | """Transformer block.""" 192 | 193 | num_heads: int 194 | num_kv_heads: int 195 | embed_dim: int 196 | head_dim: int 197 | hidden_dim: int 198 | 199 | def setup(self): 200 | self.pre_attention_norm = layers.RMSNorm() 201 | self.attn = Attention( 202 | num_heads=self.num_heads, 203 | features=self.embed_dim, 204 | head_dim=self.head_dim, 205 | num_kv_heads=self.num_kv_heads, 206 | ) 207 | self.pre_ffw_norm = layers.RMSNorm() 208 | self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim) 209 | 210 | def __call__( 211 | self, 212 | x: jax.Array, 213 | segment_pos: jax.Array, 214 | cache: LayerCache | None, 215 | attn_mask: jax.Array, 216 | ) -> tuple[LayerCache | None, jax.Array]: 217 | inputs_normalized = self.pre_attention_norm(x) 218 | cache, attn_output = self.attn( 219 | inputs_normalized, 220 | segment_pos, 221 | cache, 222 | attn_mask, 223 | ) 224 | attn_output += x 225 | residual = attn_output 226 | attn_output = self.pre_ffw_norm(attn_output) 227 | outputs = self.mlp(attn_output) 228 | outputs = residual + outputs 229 | return cache, outputs 230 | -------------------------------------------------------------------------------- /colabs/sampling_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | 
"metadata": { 6 | "id": "SC77q_zBESaM" 7 | }, 8 | "source": [ 9 | "Copyright 2024 DeepMind Technologies Limited.\n", 10 | "\n", 11 | "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", 12 | "\n", 13 | "http://www.apache.org/licenses/LICENSE-2.0\n", 14 | "\n", 15 | "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n", 16 | "\n", 17 | "---" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "TpESp4p5ESaM" 24 | }, 25 | "source": [ 26 | "# Getting Started with Gemma Sampling: A Step-by-Step Guide\n", 27 | "\n", 28 | "\n", 29 | " \"Open\n", 30 | "\n", 31 | "\n", 32 | "You will find in this colab a detailed tutorial explaining how to load a Gemma checkpoint and sample from it.\n", 33 | "\n" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": { 39 | "id": "LtzOe_3XY9R5" 40 | }, 41 | "source": [ 42 | "## Installation" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "id": "iq2ebV_6YNiU" 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "! pip install git+https://github.com/google-deepmind/gemma.git\n", 54 | "! pip install --user kaggle" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": { 60 | "id": "QOzN-gxIYSB4" 61 | }, 62 | "source": [ 63 | "## Downloading the checkpoint\n", 64 | "\n", 65 | "\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n", 66 | "\n", 67 | "1. Visit https://www.kaggle.com/ and create an account.\n", 68 | "2. Go to your account settings, then the 'API' section.\n", 69 | "3. Click 'Create new token' to download your key.\n", 70 | "\n", 71 | "Then run the cell below." 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": { 78 | "id": "likVQiEEYS5X" 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "import kagglehub\n", 83 | "kagglehub.login()" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "id": "QRxOFyGbYUjZ" 90 | }, 91 | "source": [ 92 | "If everything went well, you should see:\n", 93 | "```\n", 94 | "Kaggle credentials set.\n", 95 | "Kaggle credentials successfully validated.\n", 96 | "```\n", 97 | "\n", 98 | "Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models." 
99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": { 105 | "id": "O-sxcasvESaM" 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "import os\n", 110 | "\n", 111 | "VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", 112 | "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n", 113 | "\n", 114 | "ckpt_path = os.path.join(weights_dir, VARIANT)\n", 115 | "vocab_path = os.path.join(weights_dir, 'tokenizer.model')" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": { 122 | "cellView": "form", 123 | "id": "-jpTUa1YESaM" 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "# @title Python imports\n", 128 | "from gemma import params as params_lib\n", 129 | "from gemma import sampler as sampler_lib\n", 130 | "from gemma import transformer as transformer_lib\n", 131 | "import sentencepiece as spm" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": { 137 | "id": "4fDQsC87ESaN" 138 | }, 139 | "source": [ 140 | "## Start Generating with Your Model\n", 141 | "\n", 142 | "Load and prepare your LLM's checkpoint for use with Flax." 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": { 149 | "cellView": "form", 150 | "id": "57nMYQ4HESaN" 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "# Load parameters\n", 155 | "params = params_lib.load_and_format_params(ckpt_path)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": { 161 | "id": "NWJ3UvHXESaN" 162 | }, 163 | "source": [ 164 | "Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library." 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "cellView": "form", 172 | "id": "khXrjEF0ESaN" 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "vocab = spm.SentencePieceProcessor()\n", 177 | "vocab.Load(vocab_path)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": { 183 | "id": "tCRtZMg0ESaN" 184 | }, 185 | "source": [ 186 | "Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release." 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": null, 192 | "metadata": { 193 | "cellView": "form", 194 | "id": "7InOzQtcESaN" 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "transformer_config = transformer_lib.TransformerConfig.from_params(\n", 199 | " params,\n", 200 | " cache_size=1024 # Number of time steps in the transformer's cache\n", 201 | ")\n", 202 | "transformer = transformer_lib.Transformer(transformer_config)" 203 | ] 204 | }, 205 | { 206 | "cell_type": "markdown", 207 | "metadata": { 208 | "id": "KaU-X3_jESaN" 209 | }, 210 | "source": [ 211 | "Finally, build a sampler on top of your model and your tokenizer."
212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": { 218 | "cellView": "form", 219 | "id": "bdstASGrESaN" 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "# Create a sampler with the right param shapes.\n", 224 | "sampler = sampler_lib.Sampler(\n", 225 | " transformer=transformer,\n", 226 | " vocab=vocab,\n", 227 | " params=params['transformer'],\n", 228 | ")" 229 | ] 230 | }, 231 | { 232 | "cell_type": "markdown", 233 | "metadata": { 234 | "id": "C1fLns-_ESaN" 235 | }, 236 | "source": [ 237 | "You're ready to start sampling! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent." 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "execution_count": null, 243 | "metadata": { 244 | "cellView": "form", 245 | "id": "qA0BhNQvESaN" 246 | }, 247 | "outputs": [], 248 | "source": [ 249 | "input_batch = [\n", 250 | " \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n", 251 | " \"What are the planets of the solar system?\",\n", 252 | " ]\n", 253 | "\n", 254 | "out_data = sampler(\n", 255 | " input_strings=input_batch,\n", 256 | " total_generation_steps=300, # number of steps performed when generating\n", 257 | " )\n", 258 | "\n", 259 | "for input_string, out_string in zip(input_batch, out_data.text):\n", 260 | " print(f\"Prompt:\\n{input_string}\\nOutput:\\n{out_string}\")\n", 261 | " print()\n", 262 | " print(10*'#')" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": { 268 | "id": "tqbJ1SUcESaN" 269 | }, 270 | "source": [ 271 | "You should get an implementation of bubble sort and a description of the solar system.\n" 272 | ] 273 | } 274 | ], 275 | "metadata": { 276 | "colab": { 277 | "private_outputs": true 278 | }, 279 | "language_info": { 280 | "name": "python" 281 | } 282 | }, 283 | "nbformat": 4, 284 | "nbformat_minor": 0 285 | } 286 | -------------------------------------------------------------------------------- /gemma/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================ 15 | """Sampler for Gemma transformer. 16 | 17 | An example of a sampling class for a Gemma model.
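A minimal usage sketch (the paths below are illustrative; see
colabs/sampling_tutorial.ipynb in this repository for a complete walkthrough):

  import sentencepiece as spm
  from gemma import params as params_lib
  from gemma import sampler as sampler_lib
  from gemma import transformer as transformer_lib

  # Load checkpoint weights and the SentencePiece vocabulary.
  params = params_lib.load_and_format_params('/path/to/checkpoint')
  vocab = spm.SentencePieceProcessor()
  vocab.Load('/path/to/tokenizer.model')

  # Deduce the model configuration from the checkpoint and build the model.
  config = transformer_lib.TransformerConfig.from_params(params, cache_size=1024)
  transformer = transformer_lib.Transformer(config)

  # Build the sampler and generate a completion.
  sampler = sampler_lib.Sampler(
      transformer=transformer,
      vocab=vocab,
      params=params['transformer'],
  )
  output = sampler(input_strings=['Hello'], total_generation_steps=32)
  print(output.text)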
18 | """ 19 | from collections.abc import Sequence 20 | import dataclasses 21 | 22 | import chex 23 | from gemma import modules 24 | from gemma import params as params_lib 25 | from gemma import transformer as transformer_lib 26 | import jax 27 | import jax.numpy as jnp 28 | 29 | import sentencepiece as spm 30 | 31 | 32 | def _compute_attention_masks( 33 | time_step: jax.Array, seq_len: int, input_mask: jax.Array 34 | ) -> jax.Array: 35 | """Computes causal attention mask.""" 36 | bsz = input_mask.shape[0] 37 | batch_time_step = jnp.full((bsz, 1), time_step, dtype=jnp.uint32) 38 | causal_padding = jnp.greater( 39 | jnp.expand_dims(jnp.arange(seq_len), 0), batch_time_step 40 | ) 41 | causal_padding = causal_padding * jnp.expand_dims(input_mask, axis=-1) 42 | attention_mask = causal_padding[:, jnp.newaxis, jnp.newaxis, :].astype( 43 | jnp.bool_ 44 | ) 45 | attention_mask = jnp.squeeze(attention_mask, axis=1) 46 | return ~attention_mask 47 | 48 | 49 | @chex.dataclass 50 | class _SamplingState: 51 | """Internal sampling state.""" 52 | 53 | # Decoding step 54 | decoding_step: jnp.int32 55 | 56 | # Number of tokens in the prompt. 57 | num_input_tokens: jnp.ndarray # [B] 58 | 59 | # Fixed-size buffer for accumulating the output tokens. 60 | token_buffer: jnp.ndarray # [B, L] 61 | 62 | # Model state for conditioning the model on autoregressively. 63 | cache: dict[str, modules.LayerCache] 64 | 65 | # Is decoding done on the given sequence 66 | done: jnp.ndarray # [B] 67 | 68 | # Total sampling steps (including the prompt) 69 | total_sampling_steps: int 70 | 71 | # Fixed-size buffer for accumulating the output logits. 72 | logits_buffer: jnp.ndarray | None = None # [B, L, V] 73 | 74 | 75 | @dataclasses.dataclass 76 | class SamplerOutput: 77 | 78 | # Decoded samples from the model. 79 | text: list[str] 80 | 81 | # Per-step logits used during sampling. 82 | logits: list[list[float]] 83 | 84 | # Tokens corresponding to the generated samples. 85 | tokens: list[list[int]] 86 | 87 | 88 | class Sampler: 89 | """Sampler for gemma transformer.""" 90 | 91 | def __init__( 92 | self, 93 | transformer: transformer_lib.Transformer, 94 | vocab: spm.SentencePieceProcessor, 95 | params: params_lib.Params, 96 | ): 97 | """Initializes a sampler for a Gemma model. 98 | 99 | Args: 100 | transformer: an instance of the Gemma transformer. 101 | vocab: vocabulary of the given model. 102 | params: weights of the model. 
103 | """ 104 | self.transformer = transformer 105 | self.vocab = vocab 106 | self.params = params 107 | self._compiled_sample_fn = jax.jit(self._sample_fn) 108 | 109 | @property 110 | def dtype(self) -> jnp.dtype: 111 | return jax.tree_util.tree_leaves(self.params)[0].dtype 112 | 113 | def _sample_step( 114 | self, params, sampler_state: _SamplingState 115 | ) -> _SamplingState: 116 | """Performs a single sampling step.""" 117 | batch_size = sampler_state.token_buffer.shape[0] 118 | decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32) 119 | last_token = sampler_state.token_buffer[:, decoding_step] 120 | input_mask = last_token != self.vocab.pad_id() 121 | attention_mask = _compute_attention_masks( 122 | decoding_step, self.transformer.config.max_cache_length, input_mask 123 | ) 124 | positions = jnp.full((batch_size, 1), decoding_step, dtype=jnp.int32) 125 | last_token = last_token.reshape((batch_size, 1)) 126 | 127 | logits, cache = self.transformer.apply( 128 | {'params': params}, 129 | last_token, 130 | positions, 131 | sampler_state.cache, 132 | attention_mask, 133 | ) 134 | 135 | next_token_candidate = jnp.argmax(logits, axis=-1) # [B, 1] 136 | next_token_candidate = next_token_candidate[:, 0] # [B,] 137 | 138 | next_token_candidate = jnp.where( 139 | decoding_step < sampler_state.num_input_tokens - 1, 140 | sampler_state.token_buffer[:, decoding_step + 1], 141 | next_token_candidate, 142 | ) 143 | 144 | token_buffer = sampler_state.token_buffer.at[:, decoding_step + 1].set( 145 | next_token_candidate 146 | ) 147 | 148 | if sampler_state.logits_buffer is not None: 149 | next_logits = jnp.squeeze(logits, 1) 150 | logits_buffer = sampler_state.logits_buffer.at[:, decoding_step + 1].set( 151 | next_logits 152 | ) 153 | else: 154 | logits_buffer = sampler_state.logits_buffer 155 | 156 | done = sampler_state.done | jnp.equal( 157 | token_buffer[:, decoding_step + 1], self.vocab.eos_id() 158 | ) 159 | 160 | return _SamplingState( 161 | decoding_step=sampler_state.decoding_step + 1, 162 | num_input_tokens=sampler_state.num_input_tokens, 163 | token_buffer=token_buffer, 164 | logits_buffer=logits_buffer, 165 | cache=cache, 166 | done=done, 167 | total_sampling_steps=sampler_state.total_sampling_steps, 168 | ) 169 | 170 | def init_cache(self, bsz) -> dict[str, modules.LayerCache]: 171 | """Initializes the attention cache for each layer.""" 172 | return self.transformer.config.init_cache(bsz, dtype=self.dtype) 173 | 174 | def init_sample_state( 175 | self, 176 | all_input_ids: list[jax.Array], 177 | total_sampling_steps: int, 178 | include_logits: bool = False, 179 | ) -> _SamplingState: 180 | """Initializes the sampling state given input prompts.""" 181 | bsz = len(all_input_ids) 182 | num_input_tokens = [len(input_ids) for input_ids in all_input_ids] 183 | buffer_size = total_sampling_steps + 1 184 | 185 | token_buffer = jnp.full( 186 | ( 187 | bsz, 188 | buffer_size, 189 | ), 190 | self.vocab.pad_id(), 191 | dtype=jnp.int32, 192 | ) 193 | for i, (input_ids, num_tokens) in enumerate( 194 | zip(all_input_ids, num_input_tokens) 195 | ): 196 | token_buffer = token_buffer.at[i, :num_tokens].set(input_ids) 197 | 198 | done = jnp.zeros((bsz,), dtype=jnp.bool_) 199 | 200 | if include_logits: 201 | logits_buffer = jnp.zeros( 202 | (bsz, buffer_size, self.transformer.config.num_embed), 203 | dtype=jnp.float32, 204 | ) 205 | else: 206 | logits_buffer = None 207 | 208 | return _SamplingState( 209 | decoding_step=0, 210 | num_input_tokens=jnp.array(num_input_tokens, 
dtype=jnp.int32), 211 | token_buffer=token_buffer, 212 | logits_buffer=logits_buffer, 213 | cache=self.init_cache(bsz), 214 | done=done, 215 | total_sampling_steps=total_sampling_steps, 216 | ) 217 | 218 | def tokenize(self, input_string: str) -> jax.Array: 219 | """Tokenizes the input string.""" 220 | input_ids = self.vocab.EncodeAsIds(input_string) 221 | input_ids = jnp.array( 222 | [self.vocab.bos_id()] + input_ids, dtype=jnp.int32 223 | ) 224 | return input_ids 225 | 226 | def _sample_fn( 227 | self, 228 | params: params_lib.Params, 229 | initial_sampling_state: _SamplingState, 230 | ) -> _SamplingState: 231 | """Internal sampling function (to be jitted).""" 232 | 233 | def sample_with_params(sampler_state: _SamplingState): 234 | return self._sample_step(params, sampler_state) 235 | 236 | def cond_fn(sampler_state: _SamplingState): 237 | return ( 238 | sampler_state.decoding_step < sampler_state.total_sampling_steps 239 | ) & jnp.any(jnp.logical_not(sampler_state.done)) 240 | 241 | return jax.lax.while_loop( 242 | cond_fn, sample_with_params, initial_sampling_state 243 | ) 244 | 245 | def __call__( 246 | self, 247 | input_strings: Sequence[str], 248 | total_generation_steps: int, 249 | echo: bool = False, 250 | return_logits: bool = True, 251 | ) -> SamplerOutput: 252 | """Samples a completion of each input string. 253 | 254 | Args: 255 | input_strings: input prompts to feed to the model for sampling. 256 | total_generation_steps: number of generation steps to perform after the 257 | end of the longest prompt in the batch. 258 | echo: whether to return the prompt as part of the output sample. 259 | return_logits: whether to return per-step logits used during generation. 260 | 261 | Returns: 262 | sampler_output: A SamplerOutput object containing the generated samples. 263 | """ 264 | 265 | all_input_ids = [self.tokenize(x) for x in input_strings] 266 | max_input_length = max(len(input_ids) for input_ids in all_input_ids) 267 | total_sampling_steps = max_input_length + total_generation_steps 268 | initial_sampling_state = self.init_sample_state( 269 | all_input_ids, 270 | include_logits=return_logits, 271 | total_sampling_steps=total_sampling_steps, 272 | ) 273 | 274 | sampling_state = self._compiled_sample_fn( 275 | self.params, initial_sampling_state 276 | ) 277 | 278 | out_tokens = [] 279 | out_logits = [] 280 | for i, (token_buffer, num_tokens) in enumerate( 281 | zip( 282 | sampling_state.token_buffer, 283 | sampling_state.num_input_tokens, 284 | ) 285 | ): 286 | start_idx = 0 if echo else num_tokens 287 | out_tokens.append(token_buffer[start_idx:total_sampling_steps].tolist()) 288 | if return_logits: 289 | logits_buffer = sampling_state.logits_buffer[i] 290 | out_logits.append( 291 | logits_buffer[start_idx:total_sampling_steps].tolist() 292 | ) 293 | 294 | decoded_outputs = [ 295 | self.vocab.DecodeIds(tokens) for tokens in out_tokens 296 | ] 297 | 298 | result = SamplerOutput( 299 | text=decoded_outputs, 300 | logits=out_logits, 301 | tokens=out_tokens, 302 | ) 303 | return result 304 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions.
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /colabs/gsm8k_eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "J72yaKjJEXip" 7 | }, 8 | "source": [ 9 | "Copyright 2024 DeepMind Technologies Limited.\n", 10 | "\n", 11 | "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", 12 | "\n", 13 | "http://www.apache.org/licenses/LICENSE-2.0\n", 14 | "\n", 15 | "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n", 16 | "\n", 17 | "---" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "kRShUtLfEXiq" 24 | }, 25 | "source": [ 26 | "# GSM8K evaluation using Gemma\n", 27 | "\n", 28 | "\n", 29 | " \"Open\n", 30 | "\n", 31 | "\n", 32 | "The [GSM8K dataset](https://arxiv.org/pdf/2110.14168.pdf) presents a good evaluation challenge for small models for several reasons:\n", 33 | "\n", 34 | "1. **Conceptual Simplicity:** While the problems in GSM8K require multi-step reasoning, they primarily involve elementary mathematical concepts and basic arithmetic operations. This makes the dataset accessible to smaller models that may not have the capacity to handle complex mathematical reasoning.\n", 35 | "\n", 36 | "2. **Linguistic Diversity:** GSM8K emphasizes linguistic diversity, ensuring that problems are not simply variations of the same template. This forces models to generalize their understanding of language and mathematical concepts, rather than relying on superficial pattern matching.\n", 37 | "\n", 38 | "3. **Moderate Difficulty:** The problems in GSM8K are challenging enough to test the limits of small models without being completely intractable. This allows for meaningful evaluation and comparison of different models and methods within a reasonable difficulty range.\n", 39 | "\n", 40 | "4. 
**Natural Language Solutions:** GSM8K provides solutions in natural language, encouraging models to develop verbal analytical skills and produce human-interpretable reasoning steps. This is particularly relevant for smaller models that may struggle with purely symbolic or equation-based solutions.\n", 41 | "\n", 42 | "By focusing on grade-school math concepts and emphasizing linguistic diversity, GSM8K provides a valuable benchmark for evaluating the informal reasoning abilities of smaller language models and identifying areas for improvement.\n", 43 | "\n", 44 | "The 2B Gemma checkpoint achieves a score of 19%, which is higher than the results obtained with [much larger competing checkpoints](https://paperswithcode.com/sota/arithmetic-reasoning-on-gsm8k)." 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "metadata": { 50 | "id": "GMv_56WyEXiq" 51 | }, 52 | "source": [ 53 | "## Installation" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "id": "pIF7Tr8yEXiq" 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "! pip install git+https://github.com/google-deepmind/gemma.git\n", 65 | "! pip install --user kaggle" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": { 71 | "id": "ucx1AgltaZRF" 72 | }, 73 | "source": [ 74 | "## Downloading the checkpoint\n", 75 | "\n", 76 | "To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n", 77 | "\n", 78 | "1. Visit https://www.kaggle.com/ and create an account.\n", 79 | "2. Go to your account settings, then the 'API' section.\n", 80 | "3. Click 'Create new token' to download your key.\n", 81 | "\n", 82 | "Then run the cell below." 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": { 89 | "id": "qai-J2Dgaac0" 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "import kagglehub\n", 94 | "kagglehub.login()" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": { 100 | "id": "rSv5uG5_acQk" 101 | }, 102 | "source": [ 103 | "If everything went well, you should see:\n", 104 | "```\n", 105 | "Kaggle credentials set.\n", 106 | "Kaggle credentials successfully validated.\n", 107 | "```\n", 108 | "\n", 109 | "Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models."
110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "id": "-uHzK733EXiq" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "import os\n", 121 | "\n", 122 | "VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n", 123 | "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n", 124 | "\n", 125 | "ckpt_path = os.path.join(weights_dir, VARIANT)\n", 126 | "vocab_path = os.path.join(weights_dir, 'tokenizer.model')" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": { 133 | "id": "udFgOLxJEXiq" 134 | }, 135 | "outputs": [], 136 | "source": [ 137 | "# @title Python imports\n", 138 | "import re\n", 139 | "from gemma import params as params_lib\n", 140 | "from gemma import sampler as sampler_lib\n", 141 | "from gemma import transformer as transformer_lib\n", 142 | "\n", 143 | "import datasets\n", 144 | "import sentencepiece as spm" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": { 150 | "id": "vNEpwGyREXiq" 151 | }, 152 | "source": [ 153 | "## Load GSM8K dataset" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": { 160 | "cellView": "form", 161 | "id": "E47hYa8dEXiq" 162 | }, 163 | "outputs": [], 164 | "source": [ 165 | "gsm8k = datasets.load_dataset(\"gsm8k\", \"main\", cache_dir='/tmp')\n", 166 | "gsm8k_train, gsm8k_test = gsm8k['train'], gsm8k['test']" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": { 173 | "cellView": "form", 174 | "id": "ReheKSODEXiq" 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "# @title Testing library\n", 179 | "\n", 180 | "def find_numbers(x: str) -> list[str]:\n", 181 | " \"\"\"Finds all numbers in a string.\"\"\"\n", 182 | " # Search for number, possibly negative (hyphen), with thousand separators\n", 183 | " # (comma), and with a decimal point (period between digits).\n", 184 | " numbers = re.compile(\n", 185 | " r'-?[\\d,]*\\.?\\d+',\n", 186 | " re.MULTILINE | re.DOTALL | re.IGNORECASE,\n", 187 | " ).findall(x)\n", 188 | " return numbers\n", 189 | "\n", 190 | "\n", 191 | "def find_number(x: str,\n", 192 | " answer_delimiter: str = 'The answer is') -> str:\n", 193 | " \"\"\"Finds the most relevant number in a string.\"\"\"\n", 194 | " # If model uses the answer delimiter, then select the first number following\n", 195 | " # that format.\n", 196 | " if answer_delimiter in x:\n", 197 | " answer = x.split(answer_delimiter)[-1]\n", 198 | " numbers = find_numbers(answer)\n", 199 | " if numbers:\n", 200 | " return numbers[0]\n", 201 | "\n", 202 | " # In general, select the last number in the string.\n", 203 | " numbers = find_numbers(x)\n", 204 | " if numbers:\n", 205 | " return numbers[-1]\n", 206 | " return ''\n", 207 | "\n", 208 | "\n", 209 | "def maybe_remove_comma(x: str) -> str:\n", 210 | " # Example: 5,600 -> 5600\n", 211 | " return x.replace(',', '')" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": null, 217 | "metadata": { 218 | "cellView": "form", 219 | "id": "cXoCKMi9EXir" 220 | }, 221 | "outputs": [], 222 | "source": [ 223 | "# @title GSM8K Prompts\n", 224 | "\n", 225 | "PREAMBLE = \"\"\"As an expert problem solver solve step by step the following mathematical questions.\"\"\"\n", 226 | "\n", 227 | "# The default gsm8k prompt from the CoT paper\n", 228 | "# https://arxiv.org/pdf/2201.11903.pdf page 35.\n", 229 | "\n", 230 | "PROMPT = \"\"\"Q: There 
are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\n", 231 | "A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.\n", 232 | "\n", 233 | "Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n", 234 | "A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.\n", 235 | "\n", 236 | "Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\n", 237 | "A: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.\n", 238 | "\n", 239 | "Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\n", 240 | "A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.\n", 241 | "\n", 242 | "Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\n", 243 | "A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.\n", 244 | "\n", 245 | "Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\n", 246 | "A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.\n", 247 | "\n", 248 | "Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\n", 249 | "A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.\n", 250 | "\n", 251 | "Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\n", 252 | "A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8.\"\"\"\n", 253 | "\n", 254 | "\n", 255 | "# Extension of the default 8-shot prompt, page 35 in\n", 256 | "# https://arxiv.org/pdf/2201.11903.pdf\n", 257 | "# The extension is intended to improve performance on\n", 258 | "# more complicated gsm8k examples.\n", 259 | "\n", 260 | "EXTRA_3_SHOTS = \"\"\"As an expert problem solver solve step by step the following mathematical questions.\n", 261 | "\n", 262 | "Q: Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. 
If she works 10 hours every day for 5 days, how much money does she make?\n", 263 | "A: Here's how to calculate Tina's earnings:\n", 264 | "\n", 265 | "**Regular Time:**\n", 266 | "- Hours per shift: 8 hours\n", 267 | "- Wage per hour: $18.00\n", 268 | "- Regular pay per shift: 8 hours * $18.00/hour = $144.00\n", 269 | "\n", 270 | "**Overtime:**\n", 271 | "- Overtime hours per shift: 10 hours - 8 hours = 2 hours\n", 272 | "- Overtime pay per hour: $18.00 + ($18.00 / 2) = $27.00\n", 273 | "- Overtime pay per shift: 2 hours * $27.00/hour = $54.00\n", 274 | "\n", 275 | "**Total per day:**\n", 276 | "- Regular pay + overtime pay: $144.00/shift + $54.00/shift = $198.00/day\n", 277 | "\n", 278 | "**Total for 5 days:**\n", 279 | "- 5 days * $198.00/day = $990.00\n", 280 | "\n", 281 | "**Therefore, Tina will make $990.00 in 5 days.** The answer is 990.\n", 282 | "\n", 283 | "Q: Abigail is trying a new recipe for a cold drink. It uses 1/4 of a cup of iced tea and 1 and 1/4 of a cup of lemonade to make one drink. If she fills a pitcher with 18 total cups of this drink, how many cups of lemonade are in the pitcher?\n", 284 | "A: ## Ambiguity in the Problem Statement:\n", 285 | "\n", 286 | "There is one main ambiguity in the problem statement:\n", 287 | "\n", 288 | "**Total volume vs. Number of servings:** The statement \"18 total cups of this drink\" could be interpreted in two ways:\n", 289 | " * **18 cups of the combined volume:** This would mean Abigail used a total of 18 cups of liquid, including both iced tea and lemonade.\n", 290 | " * **18 individual servings:** This would mean Abigail made 18 individual drinks, each containing 1/4 cup of iced tea and 1 1/4 cup of lemonade.\n", 291 | "\n", 292 | "Let us assume the interpretation \"18 cups of the combined volume\".\n", 293 | "\n", 294 | "## Solution assuming 18 cups of combined volume:\n", 295 | "\n", 296 | "**Step 1: Find the proportion of lemonade in one drink:**\n", 297 | "\n", 298 | "* Lemonade: 1 1/4 cups\n", 299 | "* Iced tea: 1/4 cup\n", 300 | "* Total: 1 1/4 + 1/4 = 1 1/2 cups\n", 301 | "* Lemonade proportion: (1 1/4) / (1 1/2) = 5/6\n", 302 | "\n", 303 | "**Step 2: Calculate the amount of lemonade in the pitcher:**\n", 304 | "\n", 305 | "* Total volume: 18 cups\n", 306 | "* Lemonade proportion: 5/6\n", 307 | "* Volume of lemonade: 18 * (5/6) = 15 cups\n", 308 | "\n", 309 | "Therefore, there are 15 cups of lemonade in the pitcher. The answer is 15.\n", 310 | "\n", 311 | "Q: A deep-sea monster rises from the waters once every hundred years to feast on a ship and sate its hunger. Over three hundred years, it has consumed 847 people. Ships have been built larger over time, so each new ship has twice as many people as the last ship. How many people were on the ship the monster ate in the first hundred years?\n", 312 | "A: Let us solve it using algebra. Let x be the number of people on the ship the monster ate in the first hundred years.\n", 313 | "\n", 314 | "The number of people on the ship eaten in the second hundred years is 2x, and in the third hundred years is 4x.\n", 315 | "\n", 316 | "Therefore, the total number of people eaten over three hundred years is x + 2x + 4x = 847.\n", 317 | "\n", 318 | "Combining like terms, we get 7x = 847.\n", 319 | "\n", 320 | "Dividing both sides by 7, we find x = 121.\n", 321 | "\n", 322 | "Therefore, there were 121 people on the ship the monster ate in the first hundred years. 
The answer is 121.\"\"\"" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": { 328 | "id": "9LoeozW4EXir" 329 | }, 330 | "source": [ 331 | "## Load and prepare your LLM's checkpoint for use with Flax.\n", 332 | "\n", 333 | "Start by loading the weights of your model." 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": null, 339 | "metadata": { 340 | "cellView": "form", 341 | "id": "7s15QMbbEXir" 342 | }, 343 | "outputs": [], 344 | "source": [ 345 | "# Load parameters\n", 346 | "params = params_lib.load_and_format_params(ckpt_path)" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "metadata": { 352 | "id": "2tkY-sLuEXir" 353 | }, 354 | "source": [ 355 | "Then load the tokenizer." 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": null, 361 | "metadata": { 362 | "cellView": "form", 363 | "id": "_n0KePI2EXir" 364 | }, 365 | "outputs": [], 366 | "source": [ 367 | "vocab = spm.SentencePieceProcessor()\n", 368 | "vocab.Load(vocab_path)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "markdown", 373 | "metadata": { 374 | "id": "1wvhxNb1EXir" 375 | }, 376 | "source": [ 377 | "Finally, build a sampler from the transformer configuration deduced from the checkpoint." 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": { 384 | "cellView": "form", 385 | "id": "51WOHSzVEXir" 386 | }, 387 | "outputs": [], 388 | "source": [ 389 | "transformer_config = transformer_lib.TransformerConfig.from_params(\n", 390 | " params, cache_size=1024)\n", 391 | "transformer = transformer_lib.Transformer(transformer_config)\n", 392 | "\n", 393 | "# Create a sampler with the right param shapes for the GSM8K prompt below\n", 394 | "sampler = sampler_lib.Sampler(\n", 395 | " transformer=transformer,\n", 396 | " vocab=vocab,\n", 397 | " params=params['transformer'],\n", 398 | ")" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": { 404 | "id": "5NhlBMaIEXir" 405 | }, 406 | "source": [ 407 | "## Main Evaluation loop\n", 408 | "\n", 409 | "You should expect a score of 19.86% with the 2B model." 
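, "\n", "As a quick sanity check of the scoring helpers defined above (a hypothetical example): `maybe_remove_comma(find_number('So they must have planted 21 - 15 = 6 trees. The answer is 6.'))` returns `'6'`, which is the string the loop below compares against the number extracted from the ground-truth answer."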
410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": null, 415 | "metadata": { 416 | "cellView": "form", 417 | "id": "iHxQeQ4hEXir" 418 | }, 419 | "outputs": [], 420 | "source": [ 421 | "%%time\n", 422 | "all_correct = 0\n", 423 | "all_responses = {}\n", 424 | "short_responses = {}\n", 425 | "idx = 0\n", 426 | "correct = 0\n", 427 | "\n", 428 | "TEMPLATE = \"\"\"\n", 429 | "Q: {question}\n", 430 | "A:\"\"\"\n", 431 | "\n", 432 | "for task_id, problem in enumerate(gsm8k_test):\n", 433 | "\n", 434 | " if task_id in all_responses: continue\n", 435 | "\n", 436 | " # Print Task ID\n", 437 | " print(f\"task_id {task_id}\")\n", 438 | "\n", 439 | " # Formulate and print the full prompt\n", 440 | " full_prompt = (PREAMBLE +'\\n\\n' + PROMPT + '\\n' +\n", 441 | " TEMPLATE.format(question=problem['question']))\n", 442 | " short_prompt = PREAMBLE +'\\n' + TEMPLATE.format(question=problem['question'])\n", 443 | "\n", 444 | " input_batch = [full_prompt]\n", 445 | " response = sampler(input_strings=input_batch, total_generation_steps=1024)\n", 446 | " print(response.text)\n", 447 | "\n", 448 | " all_responses[task_id] = response.text[0].split('\\nQ:')[0]\n", 449 | " short_responses[task_id] = maybe_remove_comma(find_number(all_responses[task_id]))\n", 450 | " print(f\"Short answer: {short_responses[task_id]}\")\n", 451 | " try:\n", 452 | " correct += float(maybe_remove_comma(\n", 453 | " find_number(problem['answer']))) == float(short_responses[task_id])\n", 454 | " except ValueError:\n", 455 | " correct += maybe_remove_comma(\n", 456 | " find_number(problem['answer'])) == maybe_remove_comma(\n", 457 | " find_number(short_responses[task_id]))\n", 458 | " print('-'*40)\n", 459 | " print(f\"Ground truth answer {problem['answer']}\")\n", 460 | " print(f\"Short ground truth answer {find_number(problem['answer'])}\")\n", 461 | " print(f\"Correct: {correct} out of {idx+1}\")\n", 462 | " print(\"=\"*40)\n", 463 | " idx += 1\n" 464 | ] 465 | } 466 | ], 467 | "metadata": { 468 | "colab": { 469 | "private_outputs": true 470 | }, 471 | "language_info": { 472 | "name": "python" 473 | } 474 | }, 475 | "nbformat": 4, 476 | "nbformat_minor": 0 477 | } 478 | -------------------------------------------------------------------------------- /colabs/fine_tuning_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "OiBSu3YkEcoX" 7 | }, 8 | "source": [ 9 | "Copyright 2024 DeepMind Technologies Limited.\n", 10 | "\n", 11 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n", 12 | "you may not use this file except in compliance with the License.\n", 13 | "You may obtain a copy of the License at\n", 14 | "\n", 15 | " http://www.apache.org/licenses/LICENSE-2.0\n", 16 | "\n", 17 | "Unless required by applicable law or agreed to in writing, software\n", 18 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n", 19 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 20 | "See the License for the specific language governing permissions and\n", 21 | "limitations under the License.\n", 22 | "\n", 23 | "---" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": { 29 | "id": "Y5OeTiryEcoX" 30 | }, 31 | "source": [ 32 | "# Fine-tuning the 2B Gemma model with flax\n", 33 | "\n", 34 | "\n", 35 | " \"Open\n", 36 | "\n", 37 | "\n", 38 | "In this tutorial you will learn how to fine-tune the 2B Gemma model for a simple 
translation task. To run this colab, you will need to use a TPU v4 runtime." 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": { 44 | "id": "5m81VQOqEcoX" 45 | }, 46 | "source": [ 47 | "## Setup" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": { 54 | "id": "XpSw-_4EEcoY" 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "# @title Installation\n", 59 | "! pip install git+https://github.com/google-deepmind/gemma.git\n", 60 | "! pip install --user kaggle" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": { 66 | "id": "iLafhtv3Rg5F" 67 | }, 68 | "source": [ 69 | "## Downloading the checkpoint\n", 70 | "\n", 71 | "To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n", 72 | "\n", 73 | "1. Visit https://www.kaggle.com/ and create an account.\n", 74 | "2. Go to your account settings, then the 'API' section.\n", 75 | "3. Click 'Create new token' to download your key.\n", 76 | "\n", 77 | "Then run the cell below." 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "metadata": { 84 | "id": "8q5seOhcUBhx" 85 | }, 86 | "outputs": [], 87 | "source": [ 88 | "import kagglehub\n", 89 | "kagglehub.login()" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": { 95 | "id": "jCZSmEVDVv6O" 96 | }, 97 | "source": [ 98 | "If everything went well, you should see:\n", 99 | "```\n", 100 | "Kaggle credentials set.\n", 101 | "Kaggle credentials successfully validated.\n", 102 | "```\n", 103 | "\n", 104 | "Now select and download the checkpoint you want to try. On a single host, only the 2b model can fit in memory for fine-tuning." 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": { 111 | "id": "9PEefz8wEcoY" 112 | }, 113 | "outputs": [], 114 | "source": [ 115 | "import os\n", 116 | "\n", 117 | "VARIANT = '2b-it' # @param ['2b', '2b-it'] {type:\"string\"}\n", 118 | "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n", 119 | "\n", 120 | "ckpt_path = os.path.join(weights_dir, VARIANT)\n", 121 | "vocab_path = os.path.join(weights_dir, 'tokenizer.model')" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": { 128 | "id": "yWaP_LPoEcoY" 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "# @title Python imports\n", 133 | "\n", 134 | "import enum\n", 135 | "import re\n", 136 | "import string\n", 137 | "\n", 138 | "# We import JAX and some related packages.\n", 139 | "import chex\n", 140 | "import jax\n", 141 | "import jax.numpy as jnp\n", 142 | "import optax\n", 143 | "\n", 144 | "# We will use TensorFlow to handle the dataset.\n", 145 | "import tensorflow as tf\n", 146 | "import tensorflow_datasets as tfds\n", 147 | "\n", 148 | "# Finally, we import Gemma.\n", 149 | "from gemma import params as params_lib\n", 150 | "from gemma import sampler as sampler_lib\n", 151 | "from gemma import transformer as transformer_lib\n", 152 | "import sentencepiece as spm" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": { 158 | "id": "ejQhgtjbEcoY" 159 | }, 160 | "source": [ 161 | "## Step 1: prepare the dataset\n", 162 | "\n", 163 | "### The MTNT dataset\n", 164 | "\n", 165 | "In this tutorial, we will use the MTNT dataset, from the paper [MTNT: A Testbed for Machine Translation of Noisy Text](https://arxiv.org/abs/1809.00388). 
This dataset is directly available in the [TensorFlow dataset catalog](https://www.tensorflow.org/datasets/catalog/mtnt).\n", 166 | "\n", 167 | "More precisely, we will focus on the English-to-French translation.\n", 168 | "\n", 169 | "But first, let's have a look at the data itself." 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "metadata": { 176 | "cellView": "form", 177 | "id": "pg8SfQH0EcoY" 178 | }, 179 | "outputs": [], 180 | "source": [ 181 | "ds = tfds.load(\"mtnt/en-fr\", split=\"train\")\n", 182 | "ds = ds.take(2)\n", 183 | "ds = ds.as_numpy_iterator()\n", 184 | "for idx, example in enumerate(ds):\n", 185 | " print(f'Example {idx}:')\n", 186 | " for key, val in example.items():\n", 187 | " print(f'{key}: {val}')\n", 188 | " print()" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": { 194 | "id": "aYy4EJDsEcoY" 195 | }, 196 | "source": [ 197 | "Each sample in the dataset contains two entries:\n", 198 | "- 'src': The original English sentence.\n", 199 | "- 'dst': The corresponding French translation." 200 | ] 201 | }, 202 | { 203 | "cell_type": "markdown", 204 | "metadata": { 205 | "id": "NYC42hJgEcoY" 206 | }, 207 | "source": [ 208 | "### Tokenizer\n", 209 | "\n", 210 | "Let's start by loading our vocabulary-based tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library." 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": { 217 | "cellView": "form", 218 | "id": "TpyG5YW1EcoY" 219 | }, 220 | "outputs": [], 221 | "source": [ 222 | "vocab = spm.SentencePieceProcessor()\n", 223 | "vocab.Load(vocab_path)" 224 | ] 225 | }, 226 | { 227 | "cell_type": "markdown", 228 | "metadata": { 229 | "id": "Ab2MSf-qEcoY" 230 | }, 231 | "source": [ 232 | "Let's customize `SentencePieceProcessor` for our English-to-French translation task. Since we're fine-tuning the English-only Gemma 2B model, we need a few adjustments:\n", 233 | "\n", 234 | "- **Input Prefix**: Adding a common prefix to each input signals the translation task. For example, we could go with a prompt like `Translate this into French: [INPUT_SENTENCE]`.\n", 235 | "\n", 236 | "- **Translation Start suffix**: We add a suffix at the end of each prompt that tells the model exactly when to begin the translation process. A new line should do the job.\n", 237 | "\n", 238 | "- **LM Tokens**: Gemma models expect a *beginning of sequence* token at the beginning of each sequence. Similarly, we need to add an *end of sequence* token at the end of each training example."
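, "\n", "Concretely (an illustrative example of the adjustments above), a source sentence like `Hello` is tokenized as the ids of `<bos>Translate this into French:\\nHello\\n` with no `<eos>` (the model should keep generating from there), while its translation `Bonjour` becomes the ids of `<bos>Bonjour<eos>`. The wrapper below implements this."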
239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": null, 244 | "metadata": { 245 | "cellView": "form", 246 | "id": "L9cjK0uxEcoY" 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "class GemmaTokenizer:\n", 251 | " \"\"\"Custom wrapper around a SentencePieceProcessor for TensorFlow.\"\"\"\n", 252 | "\n", 253 | " def __init__(self,\n", 254 | " spm_processor: spm.SentencePieceProcessor):\n", 255 | " self._spm_processor = spm_processor\n", 256 | "\n", 257 | " @property\n", 258 | " def pad_id(self) -> int:\n", 259 | " \"\"\"Fast access to the pad id.\"\"\"\n", 260 | " return self._spm_processor.pad_id()\n", 261 | "\n", 262 | " def tokenize(self,\n", 263 | " example: str | bytes,\n", 264 | " prefix: str = '',\n", 265 | " suffix: str = '',\n", 266 | " add_eos: bool = True) -> jax.Array:\n", 267 | " \"\"\"\n", 268 | " Tokenization function.\n", 269 | "\n", 270 | " Args:\n", 271 | " example: input string to tokenize.\n", 272 | " prefix: prefix to add to the input string.\n", 273 | " suffix: suffix to add to the input string.\n", 274 | " add_eos: if True, add an end of sentence token at the end of the output\n", 275 | " sequence.\n", 276 | " Returns:\n", 277 | " Tokens corresponding to the input string.\n", 278 | " \"\"\"\n", 279 | " int_list = [self._spm_processor.bos_id()]\n", 280 | " int_list.extend(self._spm_processor.EncodeAsIds(prefix + example + suffix))\n", 281 | " if add_eos:\n", 282 | " int_list.append(self._spm_processor.eos_id())\n", 283 | "\n", 284 | " return jnp.array(int_list, dtype=jnp.int32)\n", 285 | "\n", 286 | " def tokenize_tf_op(self,\n", 287 | " str_tensor: tf.Tensor,\n", 288 | " prefix: str = '',\n", 289 | " suffix: str = '',\n", 290 | " add_eos: bool = True) -> tf.Tensor:\n", 291 | " \"\"\"TensorFlow operator for the tokenize function.\"\"\"\n", 292 | " encoded = tf.numpy_function(\n", 293 | " self.tokenize,\n", 294 | " [str_tensor, prefix, suffix, add_eos],\n", 295 | " tf.int32)\n", 296 | " encoded.set_shape([None])\n", 297 | " return encoded\n", 298 | "\n", 299 | " def to_string(self, tokens: jax.Array) -> str:\n", 300 | " \"\"\"Convert an array of tokens to a string.\"\"\"\n", 301 | " return self._spm_processor.DecodeIds(tokens.tolist())" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": { 307 | "id": "6xuCVkurEcoY" 308 | }, 309 | "source": [ 310 | "Now let's try our custom tokenizer on the MTNT dataset." 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": { 317 | "cellView": "form", 318 | "id": "xEA-97ioEcoY" 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "tokenizer = GemmaTokenizer(vocab)\n", 323 | "\n", 324 | "def tokenize_source(tokenizer, example: tf.Tensor):\n", 325 | " return tokenizer.tokenize_tf_op(example,\n", 326 | " prefix='Translate this into French:\\n',\n", 327 | " suffix='\\n',\n", 328 | " add_eos=False)\n", 329 | "def tokenize_destination(tokenizer, example: tf.Tensor):\n", 330 | " return tokenizer.tokenize_tf_op(example,\n", 331 | " add_eos=True)\n", 332 | "\n", 333 | "ds = tfds.load(\"mtnt/en-fr\",split=\"train\")\n", 334 | "ds = ds.take(2)\n", 335 | "ds = ds.map(lambda x: {'src': tokenize_source(tokenizer, x['src']),\n", 336 | " 'dst': tokenize_destination(tokenizer, x['dst'])})\n", 337 | "ds = ds.as_numpy_iterator()\n", 338 | "for idx, example in enumerate(ds):\n", 339 | " print(f'Example {idx}:')\n", 340 | " for key, val in example.items():\n", 341 | " print(f'{key}: {val}')\n", 342 | " print()" 343 | ] 344 | }, 345 | { 346 | "cell_type": 
"markdown", 347 | "metadata": { 348 | "id": "r-x0aTugEcoY" 349 | }, 350 | "source": [ 351 | "### Data loader\n", 352 | "\n", 353 | "We can now wrap everything a build our data loader." 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": null, 359 | "metadata": { 360 | "cellView": "form", 361 | "id": "XwFFs2mDEcoY" 362 | }, 363 | "outputs": [], 364 | "source": [ 365 | "@chex.dataclass(frozen=True)\n", 366 | "class TrainingInput:\n", 367 | " # Input tokens given to the model\n", 368 | " input_tokens: jax.Array\n", 369 | "\n", 370 | " # A mask that determines which tokens contribute to the target loss\n", 371 | " # calculation.\n", 372 | " target_mask: jax.Array\n", 373 | "\n", 374 | "class DatasetSplit(enum.Enum):\n", 375 | " TRAIN = 'train'\n", 376 | " VALIDATION = 'valid'\n", 377 | "\n", 378 | "\n", 379 | "class MTNTDatasetBuilder:\n", 380 | " \"\"\"Data loader for the MTNT dataset.\"\"\"\n", 381 | "\n", 382 | " N_ITEMS = {DatasetSplit.TRAIN: 35_692,\n", 383 | " DatasetSplit.VALIDATION: 811}\n", 384 | "\n", 385 | " BUFFER_SIZE_SHUFFLE = 10_000\n", 386 | " TRANSLATION_PREFIX = 'Translate this into French:\\n'\n", 387 | " TRANSLATION_SUFFIX = '\\n'\n", 388 | "\n", 389 | " def __init__(self,\n", 390 | " tokenizer : GemmaTokenizer,\n", 391 | " max_seq_len: int):\n", 392 | " \"\"\"Constructor.\n", 393 | "\n", 394 | " Args:\n", 395 | " tokenizer: Gemma tokenizer to use.\n", 396 | " max_seq_len: size of each sequence in a given batch.\n", 397 | " \"\"\"\n", 398 | " self._tokenizer = tokenizer\n", 399 | " self._base_data = {\n", 400 | " DatasetSplit.TRAIN: tfds.load(\"mtnt/en-fr\",split=\"train\"),\n", 401 | " DatasetSplit.VALIDATION: tfds.load(\"mtnt/en-fr\",split=\"valid\"),\n", 402 | " }\n", 403 | " self._max_seq_len = max_seq_len\n", 404 | "\n", 405 | " def _tokenize_source(self, example: tf.Tensor):\n", 406 | " \"\"\"Tokenization function for the source.\"\"\"\n", 407 | " return self._tokenizer.tokenize_tf_op(example,\n", 408 | " prefix=self.TRANSLATION_PREFIX,\n", 409 | " suffix=self.TRANSLATION_SUFFIX,\n", 410 | " add_eos=False)\n", 411 | "\n", 412 | " def _tokenize_destination(self, example: tf.Tensor):\n", 413 | " \"\"\"Tokenization function for the French translation.\"\"\"\n", 414 | " return self._tokenizer.tokenize_tf_op(example,\n", 415 | " add_eos=True)\n", 416 | "\n", 417 | " def _pad_up_to_max_len(self,\n", 418 | " input_tensor: tf.Tensor,\n", 419 | " pad_value: int | bool,\n", 420 | " ) -> tf.Tensor:\n", 421 | " \"\"\"Pad the given tensor up to sequence length of a batch.\"\"\"\n", 422 | " seq_len = tf.shape(input_tensor)[0]\n", 423 | " to_pad = tf.maximum(self._max_seq_len - seq_len, 0)\n", 424 | " return tf.pad(input_tensor,\n", 425 | " [[0, to_pad]],\n", 426 | " mode='CONSTANT',\n", 427 | " constant_values=pad_value,\n", 428 | " )\n", 429 | "\n", 430 | " def _to_training_input(self,\n", 431 | " src_tokens: jax.Array,\n", 432 | " dst_tokens: jax.Array,\n", 433 | " ) -> TrainingInput:\n", 434 | " \"\"\"Build a training input from a tuple of source and destination tokens.\"\"\"\n", 435 | "\n", 436 | " # The input sequence fed to the model is simply the concatenation of the\n", 437 | " # source and the destination.\n", 438 | " tokens = tf.concat([src_tokens, dst_tokens], axis=0)\n", 439 | "\n", 440 | " # We want to prevent the model from updating based on the source (input)\n", 441 | " # tokens. 
To achieve this, we add a target mask to each input.\n", 442 | " q_mask = tf.zeros_like(src_tokens, dtype=tf.bool)\n", 443 | " a_mask = tf.ones_like(dst_tokens, dtype=tf.bool)\n", 444 | " mask = tf.concat([q_mask, a_mask], axis=0)\n", 445 | "\n", 446 | " # If the concatenated sequence is shorter than the max sequence length,\n", 447 | " # then we pad it with pad tokens.\n", 448 | " tokens = self._pad_up_to_max_len(tokens, self._tokenizer.pad_id)\n", 449 | "\n", 450 | " # We don't want to perform the backward pass on the pad tokens.\n", 451 | " mask = self._pad_up_to_max_len(mask, False)\n", 452 | "\n", 453 | " return TrainingInput(input_tokens=tokens, target_mask=mask)\n", 454 | "\n", 455 | "\n", 456 | " def get_train_dataset(self, batch_size: int, num_epochs: int):\n", 457 | " \"\"\"Build the training dataset.\"\"\"\n", 458 | "\n", 459 | " # Tokenize each sample\n", 460 | " ds = self._base_data[DatasetSplit.TRAIN].map(lambda x: (self._tokenize_source(x['src']),\n", 461 | " self._tokenize_destination(x['dst'])))\n", 462 | "\n", 463 | " # Convert them to training inputs\n", 464 | " ds = ds.map(lambda x, y: self._to_training_input(x, y))\n", 465 | "\n", 466 | " # Remove the samples which are too long\n", 467 | " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n", 468 | "\n", 469 | " # Shuffle the dataset\n", 470 | " ds = ds.shuffle(buffer_size=self.BUFFER_SIZE_SHUFFLE)\n", 471 | "\n", 472 | " # Repeat if necessary\n", 473 | " ds = ds.repeat(num_epochs)\n", 474 | "\n", 475 | " # Build batches\n", 476 | " ds = ds.batch(batch_size, drop_remainder=True)\n", 477 | " return ds\n", 478 | "\n", 479 | " def get_validation_dataset(self, batch_size: int):\n", 480 | " \"\"\"Build the validation dataset.\"\"\"\n", 481 | "\n", 482 | " # Same as the training dataset, but no shuffling and no repetition\n", 483 | " ds = self._base_data[DatasetSplit.VALIDATION].map(lambda x: (self._tokenize_source(x['src']),\n", 484 | " self._tokenize_destination(x['dst'])))\n", 485 | " ds = ds.map(lambda x, y: self._to_training_input(x, y))\n", 486 | " ds = ds.filter(lambda x: tf.shape(x.input_tokens)[0] <= self._max_seq_len)\n", 487 | " ds = ds.batch(batch_size, drop_remainder=True)\n", 488 | " return ds" 489 | ] 490 | }, 491 | { 492 | "cell_type": "markdown", 493 | "metadata": { 494 | "id": "_Sq9uC15EcoZ" 495 | }, 496 | "source": [ 497 | "Let's give it a try." 
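]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, a toy sanity check of the masking logic. The snippet below is a hypothetical illustration, not part of the original pipeline: it calls the private helper `_to_training_input` directly so we can see how the loss mask lines up with the concatenated sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sanity check of the masking logic (illustration only).\n",
"toy_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=8)\n",
"toy_input = toy_builder._to_training_input(\n",
"    src_tokens=tf.constant([2, 10, 11], dtype=tf.int32),\n",
"    dst_tokens=tf.constant([12, 13, 1], dtype=tf.int32))\n",
"print(toy_input.input_tokens)  # src + dst concatenated, padded up to length 8\n",
"print(toy_input.target_mask)   # False on source and padding, True on target"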
498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": null, 503 | "metadata": { 504 | "cellView": "form", 505 | "id": "bYeduOaNEcoZ" 506 | }, 507 | "outputs": [], 508 | "source": [ 509 | "tokenizer = GemmaTokenizer(vocab)\n", 510 | "dataset_builder = MTNTDatasetBuilder(tokenizer, max_seq_len=20)\n", 511 | "ds = dataset_builder.get_train_dataset(3, 1)\n", 512 | "ds = ds.take(2)\n", 513 | "ds = ds.as_numpy_iterator()\n", 514 | "for idx, example in enumerate(ds):\n", 515 | " print(f'Example {idx}:')\n", 516 | " for key, val in example.items():\n", 517 | " print(f'{key}: {val}')\n", 518 | " print()" 519 | ] 520 | }, 521 | { 522 | "cell_type": "markdown", 523 | "metadata": { 524 | "id": "_VsT2o6JEcoZ" 525 | }, 526 | "source": [ 527 | "## Fine-tuning the Gemma model\n", 528 | "\n", 529 | "### Getting started\n", 530 | "\n", 531 | "First, let's load the model." 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": null, 537 | "metadata": { 538 | "cellView": "form", 539 | "id": "VDlfziQVEcoZ" 540 | }, 541 | "outputs": [], 542 | "source": [ 543 | "# Load parameters\n", 544 | "\n", 545 | "# TODO: change once the downloading url is known\n", 546 | "params = params_lib.load_and_format_params(ckpt_path)\n", 547 | "\n", 548 | "# We use the `transformer_lib.TransformerConfig.from_params` function to\n", 549 | "# automatically load the correct configuration from a checkpoint. Note that the\n", 550 | "# vocabulary size is smaller than the number of input embeddings due to unused\n", 551 | "# tokens in this release.\n", 552 | "config_2b = transformer_lib.TransformerConfig.from_params(\n", 553 | " params,\n", 554 | " cache_size=30 # Number of time steps in the transformer's cache\n", 555 | ")\n", 556 | "model_2b = transformer_lib.Transformer(config=config_2b)" 557 | ] 558 | }, 559 | { 560 | "cell_type": "markdown", 561 | "metadata": { 562 | "id": "cGbfx6XVEcoZ" 563 | }, 564 | "source": [ 565 | "Can our model translate French? Well, let's try it out!" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": null, 571 | "metadata": { 572 | "cellView": "form", 573 | "id": "jWr6Sea_EcoZ" 574 | }, 575 | "outputs": [], 576 | "source": [ 577 | "sampler_old = sampler_lib.Sampler(\n", 578 | " transformer=model_2b,\n", 579 | " vocab=vocab,\n", 580 | " params=params['transformer'],\n", 581 | ")" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": null, 587 | "metadata": { 588 | "cellView": "form", 589 | "id": "S6937NTjEcoZ" 590 | }, 591 | "outputs": [], 592 | "source": [ 593 | "print(sampler_old(\n", 594 | " [\"Translate this into French:\\nHello, my name is Morgane.\\n\"],\n", 595 | " # number of steps performed when generating\n", 596 | " total_generation_steps=30,\n", 597 | " ).text)" 598 | ] 599 | }, 600 | { 601 | "cell_type": "markdown", 602 | "metadata": { 603 | "id": "0Z0CXW4REcoZ" 604 | }, 605 | "source": [ 606 | "As expected, it didn't work. Let's see if we can get better results by fine-tuning.\n", 607 | "\n", 608 | "Before moving further, don't forget to clear the memory if necessary." 
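]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to check how much accelerator memory is actually in use, the sketch below is one option. It is not part of the original tutorial, and `Device.memory_stats()` is only available on some JAX backends (typically GPU and TPU); elsewhere it may return `None` or use different keys."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch: inspect live device memory (backend-dependent).\n",
"stats = jax.local_devices()[0].memory_stats()\n",
"if stats is not None:\n",
"    print('bytes in use:', stats.get('bytes_in_use'))"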
609 | ] 610 | }, 611 | { 612 | "cell_type": "code", 613 | "execution_count": null, 614 | "metadata": { 615 | "cellView": "form", 616 | "id": "LbJa4S5WEcoZ" 617 | }, 618 | "outputs": [], 619 | "source": [ 620 | "del sampler_old" 621 | ] 622 | }, 623 | { 624 | "cell_type": "markdown", 625 | "metadata": { 626 | "id": "gxf6gVGCEcoZ" 627 | }, 628 | "source": [ 629 | "### Model forward and loss function\n", 630 | "\n", 631 | "The Gemma `Transformer` class inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/flax_basics.html). It offers two essential methods:\n", 632 | "\n", 633 | "- `init`: Initializes the model's parameters.\n", 634 | "\n", 635 | "- `apply`: Executes the model's `__call__` function using a given set of parameters.\n", 636 | "\n", 637 | "Since we are working with pre-trained weights, we won't use the `init` function.\n", 638 | "\n", 639 | "We define a `forward_and_loss_fn` as follows:" 640 | ] 641 | }, 642 | { 643 | "cell_type": "code", 644 | "execution_count": null, 645 | "metadata": { 646 | "cellView": "form", 647 | "id": "iEcV0XEEEcoZ" 648 | }, 649 | "outputs": [], 650 | "source": [ 651 | "def forward_and_loss_fn(params,\n", 652 | " *,\n", 653 | " model: transformer_lib.Transformer,\n", 654 | " input_tokens: jax.Array, # Shape [B, L]\n", 655 | " input_mask: jax.Array, # Shape [B, L]\n", 656 | " positions: jax.Array, # Shape [B, L]\n", 657 | " attention_mask: jax.Array, # [B, L, L]\n", 658 | " ) -> jax.Array:\n", 659 | " \"\"\"Forward pass and loss function.\n", 660 | "\n", 661 | " Args:\n", 662 | " params: model's input parameters.\n", 663 | " model: gemma transformer model to call.\n", 664 | " input_tokens: input tokens sequence, shape [B, L].\n", 665 | " input_mask: tokens to ignore when computing the loss, shape [B, L].\n", 666 | " positions: relative position of each token, shape [B, L].\n", 667 | " attention_mask: input attention mask, shape [B, L, L].\n", 668 | "\n", 669 | " Returns:\n", 670 | " Softmax cross-entropy loss for the next-token prediction task.\n", 671 | " \"\"\"\n", 672 | "\n", 673 | " # Forward pass on the input data.\n", 674 | " # No attention cache is needed here.\n", 675 | " logits, _ = model.apply(\n", 676 | " params,\n", 677 | " input_tokens,\n", 678 | " positions,\n", 679 | " None, # Attention cache is None.\n", 680 | " attention_mask,\n", 681 | " )\n", 682 | "\n", 683 | " # Exclude the last step as it does not appear in the targets.\n", 684 | " logits = logits[0, :-1]\n", 685 | "\n", 686 | " # Similarly, the first token cannot be predicted.\n", 687 | " target_tokens = input_tokens[0, 1:]\n", 688 | " target_mask = input_mask[0, 1:]\n", 689 | "\n", 690 | " # Convert the target labels into one-hot encoded vectors.\n", 691 | " one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])\n", 692 | "\n", 693 | " # Don't update on unwanted tokens.\n", 694 | " one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., None]\n", 695 | "\n", 696 | " # Normalisation factor.\n", 697 | " norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)\n", 698 | "\n", 699 | " # Return the nll loss.\n", 700 | " return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "metadata": { 706 | "id": "Y83DimpjEcoZ" 707 | }, 708 | "source": [ 709 | "The Gemma transformer requires an attention mask and position vector alongside each input. 
We can conveniently generate these using the following function:" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": null, 715 | "metadata": { 716 | "cellView": "form", 717 | "id": "cbWfdHf0EcoZ" 718 | }, 719 | "outputs": [], 720 | "source": [ 721 | "def get_attention_mask_and_positions(example: jax.Array,\n", 722 | " pad_id: int,\n", 723 | " ) -> tuple[jax.Array, jax.Array]:\n", 724 | " \"\"\"Builds the position and attention mask vectors from the given tokens.\"\"\"\n", 725 | " pad_mask = example != pad_id\n", 726 | " current_token_position = transformer_lib.build_positions_from_mask(pad_mask)\n", 727 | " attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)\n", 728 | " return current_token_position, attention_mask" 729 | ] 730 | }, 731 | { 732 | "cell_type": "markdown", 733 | "metadata": { 734 | "id": "xbxYMMWLEcoZ" 735 | }, 736 | "source": [ 737 | "We can now build the `train_step` function, which performs the backward pass and updates the model's parameters accordingly." 738 | ] 739 | }, 740 | { 741 | "cell_type": "code", 742 | "execution_count": null, 743 | "metadata": { 744 | "cellView": "form", 745 | "id": "cPSfp7ZUEcoZ" 746 | }, 747 | "outputs": [], 748 | "source": [ 749 | "def train_step(model: transformer_lib.Transformer,\n", 750 | " params,\n", 751 | " optimizer: optax.GradientTransformation,\n", 752 | " opt_state: optax.OptState,\n", 753 | " pad_id: int,\n", 754 | " example: TrainingInput):\n", 755 | " \"\"\"Train step.\n", 756 | "\n", 757 | " Args:\n", 758 | " model: gemma transformer model.\n", 759 | " params: model's input parameters.\n", 760 | " optimizer: optax optimizer to use.\n", 761 | " opt_state: input optimizer's state.\n", 762 | " pad_id: id of the pad token.\n", 763 | " example: input batch.\n", 764 | "\n", 765 | " Returns:\n", 766 | " Training loss, updated parameters, updated optimizer state.\n", 767 | " \"\"\"\n", 768 | "\n", 769 | " # Build the position and attention mask vectors.\n", 770 | " positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)\n", 771 | "\n", 772 | " # Forward and backward passes\n", 773 | " train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,\n", 774 | " model=model,\n", 775 | " input_tokens=example.input_tokens,\n", 776 | " input_mask=example.target_mask,\n", 777 | " positions=positions,\n", 778 | " attention_mask=attention_mask)\n", 779 | " # Update the parameters\n", 780 | " updates, opt_state = optimizer.update(grads, opt_state)\n", 781 | " params = optax.apply_updates(params, updates)\n", 782 | "\n", 783 | " return train_loss, params, opt_state" 784 | ] 785 | }, 786 | { 787 | "cell_type": "markdown", 788 | "metadata": { 789 | "id": "R2QXp116EcoZ" 790 | }, 791 | "source": [ 792 | "Similarly, we build a `validation_step` function without the backward pass." 
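]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before that, a quick shape check on a toy batch. This is a hypothetical snippet rather than part of the original tutorial: the exact position values depend on `build_positions_from_mask`, so we only inspect shapes here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy batch: one sequence of length 5, padded at the end (illustration only).\n",
"toy_pad_id = tokenizer.pad_id\n",
"toy_tokens = jnp.array([[2, 10, 11, toy_pad_id, toy_pad_id]], dtype=jnp.int32)\n",
"toy_positions, toy_attn = get_attention_mask_and_positions(toy_tokens, toy_pad_id)\n",
"print(toy_positions.shape)  # (1, 5): one position per token\n",
"print(toy_attn.shape)       # (1, 5, 5): causal mask combined with padding"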
793 | ] 794 | }, 795 | { 796 | "cell_type": "code", 797 | "execution_count": null, 798 | "metadata": { 799 | "cellView": "form", 800 | "id": "yU4oR92YEcoa" 801 | }, 802 | "outputs": [], 803 | "source": [ 804 | "def validation_step(model: transformer_lib.Transformer,\n", 805 | " params,\n", 806 | " pad_id: int,\n", 807 | " example: TrainingInput,\n", 808 | " ):\n", 809 | " positions, attention_mask = get_attention_mask_and_positions(example.input_tokens, pad_id)\n", 810 | " val_loss = forward_and_loss_fn(params,\n", 811 | " model=model,\n", 812 | " input_tokens=example.input_tokens,\n", 813 | " input_mask=example.target_mask,\n", 814 | " positions=positions,\n", 815 | " attention_mask=attention_mask)\n", 816 | " return val_loss" 817 | ] 818 | }, 819 | { 820 | "cell_type": "markdown", 821 | "metadata": { 822 | "id": "6g6LFWJbEcoa" 823 | }, 824 | "source": [ 825 | "And now the training loop itself." 826 | ] 827 | }, 828 | { 829 | "cell_type": "code", 830 | "execution_count": null, 831 | "metadata": { 832 | "cellView": "form", 833 | "id": "xT4bAqNLEcoa" 834 | }, 835 | "outputs": [], 836 | "source": [ 837 | "@chex.dataclass(frozen=True)\n", 838 | "class TrainingConfig:\n", 839 | " learning_rate: float\n", 840 | " num_epochs: int\n", 841 | " eval_every_n: int\n", 842 | " batch_size: int\n", 843 | " max_steps: int | None = None\n", 844 | "\n", 845 | "\n", 846 | "def train_loop(\n", 847 | " model: transformer_lib.Transformer,\n", 848 | " params,\n", 849 | " dataset_builder: MTNTDatasetBuilder,\n", 850 | " training_cfg: TrainingConfig):\n", 851 | "\n", 852 | "\n", 853 | " # We jit the train step, making the whole loop much more efficient\n", 854 | " compiled_train_step = jax.jit(train_step, static_argnames=['model', 'optimizer'])\n", 855 | "\n", 856 | " # We do the same with the validation step\n", 857 | " compiled_validation_step = jax.jit(validation_step, static_argnames=['model'])\n", 858 | "\n", 859 | " # To save memory, we use an SGD optimizer instead of the usual Adam. 
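Plain SGD keeps no per-parameter statistics, whereas Adam stores two extra accumulators per parameter.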
Note that\n", 860 | " # for this specific example SGD is more than enough.\n", 861 | " optimizer = optax.sgd(training_cfg.learning_rate)\n", 862 | " opt_state = optimizer.init(params)\n", 863 | "\n", 864 | " # Build the training dataset\n", 865 | " train_ds = dataset_builder.get_train_dataset(batch_size=training_cfg.batch_size,\n", 866 | " num_epochs=training_cfg.num_epochs)\n", 867 | " train_ds = train_ds.as_numpy_iterator()\n", 868 | "\n", 869 | " # Build the validation dataset, with a limited number of samples for this demo\n", 870 | " validation_ds = dataset_builder.get_validation_dataset(batch_size=training_cfg.batch_size)\n", 871 | " validation_ds = validation_ds.take(50)\n", 872 | "\n", 873 | " n_steps = 0\n", 874 | " avg_loss = 0\n", 875 | "\n", 876 | " # A first round of validation loss\n", 877 | " n_steps_eval = 0\n", 878 | " eval_loss = 0\n", 879 | " val_iterator = validation_ds.as_numpy_iterator()\n", 880 | " for val_example in val_iterator:\n", 881 | " eval_loss += compiled_validation_step(model,\n", 882 | " params,\n", 883 | " dataset_builder._tokenizer.pad_id,\n", 884 | " val_example)\n", 885 | " n_steps_eval += 1\n", 886 | " print(f\"Start, validation loss: {eval_loss/n_steps_eval}\")\n", 887 | "\n", 888 | " for train_example in train_ds:\n", 889 | " train_loss, params, opt_state = compiled_train_step(model=model,\n", 890 | " params=params,\n", 891 | " optimizer=optimizer,\n", 892 | " opt_state=opt_state,\n", 893 | " pad_id=dataset_builder._tokenizer.pad_id,\n", 894 | " example=train_example)\n", 895 | " n_steps += 1\n", 896 | " avg_loss += train_loss\n", 897 | " if n_steps % training_cfg.eval_every_n == 0:\n", 898 | " eval_loss = 0\n", 899 | "\n", 900 | " n_steps_eval = 0\n", 901 | " val_iterator = validation_ds.as_numpy_iterator()\n", 902 | " for val_example in val_iterator:\n", 903 | " eval_loss += compiled_validation_step(model,\n", 904 | " params,\n", 905 | " dataset_builder._tokenizer.pad_id,\n", 906 | " val_example)\n", 907 | " n_steps_eval += 1\n", 908 | " avg_loss /= training_cfg.eval_every_n\n", 909 | " eval_loss /= n_steps_eval\n", 910 | " print(f\"STEP {n_steps} training loss: {avg_loss} - eval loss: {eval_loss}\")\n", 911 | " avg_loss = 0\n", 912 | " if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:\n", 913 | " break\n", 914 | " return params" 915 | ] 916 | }, 917 | { 918 | "cell_type": "markdown", 919 | "metadata": { 920 | "id": "muwkf_ZgEcoa" 921 | }, 922 | "source": [ 923 | "We can fine-tune our model on a limited number of steps." 924 | ] 925 | }, 926 | { 927 | "cell_type": "code", 928 | "execution_count": null, 929 | "metadata": { 930 | "cellView": "form", 931 | "id": "7SL2VAmVEcoa" 932 | }, 933 | "outputs": [], 934 | "source": [ 935 | "# Small seq size so that everything fits in memory\n", 936 | "SEQ_SIZE = 25\n", 937 | "tokenizer = GemmaTokenizer(vocab)\n", 938 | "dataset_builder = MTNTDatasetBuilder(tokenizer, SEQ_SIZE)\n", 939 | "training_cfg = TrainingConfig(learning_rate=1e-4,\n", 940 | " num_epochs=1,\n", 941 | " eval_every_n=20,\n", 942 | " batch_size=1,\n", 943 | " max_steps=100)\n", 944 | "\n", 945 | "params = train_loop(model=model_2b,\n", 946 | " params={'params': params['transformer']},\n", 947 | " dataset_builder=dataset_builder,\n", 948 | " training_cfg=training_cfg)" 949 | ] 950 | }, 951 | { 952 | "cell_type": "markdown", 953 | "metadata": { 954 | "id": "abChlybFEcod" 955 | }, 956 | "source": [ 957 | "Both the training loss and the validation loss are going down. But is it working? 
Let's try again with our previous example:" 958 | ] 959 | }, 960 | { 961 | "cell_type": "code", 962 | "execution_count": null, 963 | "metadata": { 964 | "cellView": "form", 965 | "id": "dQ1oCF10Ecod" 966 | }, 967 | "outputs": [], 968 | "source": [ 969 | "sampler = sampler_lib.Sampler(\n", 970 | " transformer=model_2b,\n", 971 | " vocab=vocab,\n", 972 | " params=params['params'],\n", 973 | ")" 974 | ] 975 | }, 976 | { 977 | "cell_type": "markdown", 978 | "metadata": { 979 | "id": "fIwhAvMsEcod" 980 | }, 981 | "source": [ 982 | "To ensure our input matches the training format, remember to use the prefix 'Translate this into French:\\n' and a newline character at the end. This signals the model to begin translation." 983 | ] 984 | }, 985 | { 986 | "cell_type": "code", 987 | "execution_count": null, 988 | "metadata": { 989 | "cellView": "form", 990 | "id": "S5F3fk22Ecod" 991 | }, 992 | "outputs": [], 993 | "source": [ 994 | "sampler(\n", 995 | " [\"Translate this into French:\\nHello, my name is Morgane.\\n\"],\n", 996 | " total_generation_steps=30,\n", 997 | " ).text\n" 998 | ] 999 | } 1000 | ], 1001 | "metadata": { 1002 | "accelerator": "GPU", 1003 | "colab": { 1004 | "private_outputs": true 1005 | }, 1006 | "kernelspec": { 1007 | "display_name": "Python 3", 1008 | "name": "python3" 1009 | }, 1010 | "language_info": { 1011 | "name": "python" 1012 | } 1013 | }, 1014 | "nbformat": 4, 1015 | "nbformat_minor": 0 1016 | } 1017 | --------------------------------------------------------------------------------