├── requirements.txt ├── scripts ├── images │ ├── lilly.jpg │ ├── sunflower.JPG │ ├── cow_in_beach.jpg │ └── test_image.jpg ├── run.py ├── run_multimodal.py ├── run_multimodal.py.orig └── run_xla.py ├── tokenizer ├── tokenizer.model └── gemma3_cleaned_262144_v2.spiece.model ├── gemma ├── __init__.py ├── siglip_vision │ ├── __init__.py │ ├── config.py │ ├── preprocessor.py │ ├── pan_and_scan.py │ └── siglip_vision_model.py ├── tokenizer.py ├── gemma3_preprocessor.py ├── config.py ├── gemma3_model.py ├── model_xla.py ├── xla_model_parallel.py └── model.py ├── CONTRIBUTING.md ├── docker ├── Dockerfile ├── xla.Dockerfile └── xla_gpu.Dockerfile ├── setup.py ├── .gitignore ├── README.md └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==2.2.3 2 | pillow==11.1.0 3 | sentencepiece==0.2.0 4 | torch==2.6.0 5 | absl-py==2.1.0 -------------------------------------------------------------------------------- /scripts/images/lilly.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/scripts/images/lilly.jpg -------------------------------------------------------------------------------- /tokenizer/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/tokenizer/tokenizer.model -------------------------------------------------------------------------------- /scripts/images/sunflower.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/scripts/images/sunflower.JPG -------------------------------------------------------------------------------- /scripts/images/cow_in_beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/scripts/images/cow_in_beach.jpg -------------------------------------------------------------------------------- /scripts/images/test_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/scripts/images/test_image.jpg -------------------------------------------------------------------------------- /tokenizer/gemma3_cleaned_262144_v2.spiece.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/gemma_pytorch/HEAD/tokenizer/gemma3_cleaned_262144_v2.spiece.model -------------------------------------------------------------------------------- /gemma/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | -------------------------------------------------------------------------------- /gemma/siglip_vision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We'd love to accept your patches and contributions to this project. 4 | 5 | ## Before you begin 6 | 7 | ### Sign our Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a 10 | [Contributor License Agreement](https://cla.developers.google.com/about) (CLA). 11 | You (or your employer) retain the copyright to your contribution; this simply 12 | gives us permission to use and redistribute your contributions as part of the 13 | project. 14 | 15 | If you or your current employer have already signed the Google CLA (even if it 16 | was for a different project), you probably don't need to do it again. 17 | 18 | Visit to see your current agreements or to 19 | sign a new one. 20 | 21 | ### Review our community guidelines 22 | 23 | This project follows 24 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 25 | 26 | ## Contribution process 27 | 28 | ### Code reviews 29 | 30 | All submissions, including submissions by project members, require review. We 31 | use GitHub pull requests for this purpose. Consult 32 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 33 | information on using pull requests. 34 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Use PyTorch CUDA base image 16 | FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime 17 | 18 | USER root 19 | 20 | # Install tools. 21 | ENV DEBIAN_FRONTEND=noninteractive 22 | RUN apt-get update 23 | RUN apt-get install -y --no-install-recommends apt-utils 24 | RUN apt-get install -y --no-install-recommends curl 25 | RUN apt-get install -y --no-install-recommends wget 26 | RUN apt-get install -y --no-install-recommends git 27 | 28 | # Install libraries. 
29 | ENV PIP_ROOT_USER_ACTION=ignore 30 | RUN python -m pip install --upgrade pip 31 | RUN pip install numpy==1.24.4 32 | RUN pip install sentencepiece==0.1.99 33 | 34 | # Install from source. 35 | COPY . /workspace/gemma/ 36 | WORKDIR /workspace/gemma/ 37 | RUN pip install -e . 38 | -------------------------------------------------------------------------------- /docker/xla.Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Use pytorch/xla base image 16 | FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_tpuvm_20231128 17 | 18 | USER root 19 | 20 | # Install tools. 21 | ENV DEBIAN_FRONTEND=noninteractive 22 | RUN apt-get update 23 | RUN apt-get install -y --no-install-recommends apt-utils 24 | RUN apt-get install -y --no-install-recommends curl 25 | RUN apt-get install -y --no-install-recommends wget 26 | RUN apt-get install -y --no-install-recommends git 27 | 28 | # Install libraries. 29 | ENV PIP_ROOT_USER_ACTION=ignore 30 | RUN python3 -m pip install --upgrade pip 31 | RUN pip install numpy==1.24.4 32 | RUN pip install sentencepiece==0.1.99 33 | 34 | # Install from source. 35 | COPY . /workspace/gemma/ 36 | WORKDIR /workspace/gemma/ 37 | RUN pip install -e . 38 | -------------------------------------------------------------------------------- /docker/xla_gpu.Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Use pytorch/xla base image 16 | FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_11.8_20231128 17 | 18 | USER root 19 | 20 | # Install tools. 21 | ENV DEBIAN_FRONTEND=noninteractive 22 | RUN apt-get update 23 | RUN apt-get install -y --no-install-recommends apt-utils 24 | RUN apt-get install -y --no-install-recommends curl 25 | RUN apt-get install -y --no-install-recommends wget 26 | RUN apt-get install -y --no-install-recommends git 27 | 28 | # Install libraries. 
29 | ENV PIP_ROOT_USER_ACTION=ignore 30 | RUN python3 -m pip install --upgrade pip 31 | RUN pip uninstall -y torch 32 | RUN pip install torch==2.1.1 33 | RUN pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.2.0rc1-cp38-cp38-linux_x86_64.whl 34 | RUN pip install numpy==1.24.4 35 | RUN pip install sentencepiece==0.1.99 36 | 37 | # Install from source. 38 | COPY . /workspace/gemma/ 39 | WORKDIR /workspace/gemma/ 40 | RUN pip install -e . 41 | -------------------------------------------------------------------------------- /gemma/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | from typing import List, Optional 16 | 17 | import sentencepiece 18 | 19 | def _assert_file_exists(model_path: str): 20 | assert os.path.isfile(model_path), model_path 21 | 22 | _BEGIN_IMAGE_TOKEN = 255999 23 | _END_IMAGE_TOKEN = 256000 24 | 25 | class Tokenizer: 26 | 27 | def __init__(self, model_path: Optional[str]): 28 | _assert_file_exists(model_path) 29 | self.sp_model = sentencepiece.SentencePieceProcessor() 30 | self.sp_model.Load(model_path) 31 | 32 | # BOS / EOS token IDs. 33 | self.n_words: int = self.sp_model.GetPieceSize() 34 | self.bos_id: int = self.sp_model.bos_id() 35 | self.eos_id: int = self.sp_model.eos_id() 36 | self.pad_id: int = self.sp_model.pad_id() 37 | self.boi_id: int = _BEGIN_IMAGE_TOKEN 38 | self.eoi_id: int = _END_IMAGE_TOKEN 39 | self.image_token_placeholder_id: int = self.sp_model.pad_id() 40 | 41 | def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]: 42 | """Converts a string into a list of tokens.""" 43 | assert isinstance(s, str) 44 | t = self.sp_model.EncodeAsIds(s) 45 | if bos: 46 | t = [self.bos_id] + t 47 | if eos: 48 | t = t + [self.eos_id] 49 | return t 50 | 51 | def decode(self, t: List[int]) -> str: 52 | """Converts a list of tokens into a string.""" 53 | return self.sp_model.DecodeIds(t) 54 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import io 16 | import os 17 | from typing import List 18 | 19 | import setuptools 20 | 21 | ROOT_DIR = os.path.dirname(__file__) 22 | 23 | 24 | def get_path(*filepath) -> str: 25 | return os.path.join(ROOT_DIR, *filepath) 26 | 27 | 28 | def read_readme() -> str: 29 | """Read the README file.""" 30 | return io.open(get_path("README.md"), "r", encoding="utf-8").read() 31 | 32 | 33 | def get_requirements() -> List[str]: 34 | """Get Python package dependencies from requirements.txt.""" 35 | with open(get_path("requirements.txt")) as f: 36 | requirements = f.read().strip().split("\n") 37 | return requirements 38 | 39 | 40 | setuptools.setup( 41 | name="gemma", 42 | version="0.1", 43 | author="Gemma contributors", 44 | license="Apache 2.0", 45 | description=("Gemma model implementation"), 46 | long_description=read_readme(), 47 | long_description_content_type="text/markdown", 48 | classifiers=[ 49 | "Programming Language :: Python :: 3.8", 50 | "Programming Language :: Python :: 3.9", 51 | "Programming Language :: Python :: 3.10", 52 | "Programming Language :: Python :: 3.11", 53 | "License :: OSI Approved :: Apache Software License", 54 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 55 | ], 56 | packages=setuptools.find_packages(exclude=("benchmarks", "docs", 57 | "examples", "tests")), 58 | python_requires=">=3.11", 59 | install_requires=get_requirements(), 60 | ) 61 | -------------------------------------------------------------------------------- /gemma/siglip_vision/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gemma model config.""" 16 | 17 | import dataclasses 18 | from . import preprocessor 19 | 20 | 21 | # https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/ 22 | @dataclasses.dataclass 23 | class SiglipVisionModelConfig: 24 | """Returns the model config for the vision model of Gemma 3 andPaliGemma.""" 25 | # The number of transformer encoder blocks in the siglip encoder model. 26 | num_hidden_layers: int = 27 27 | # The dimension of the embedding. 28 | embedding_dim: int = 1152 29 | # Whether to use bias in the 2D conv embedding layer. 30 | embedding_use_bias: bool = True 31 | # The number of channels in the input images. 32 | input_channels: int = 3 33 | # The input image size. 34 | image_size: int = preprocessor.DEFAULT_IMAGE_SIZE 35 | # Kernel size of 2D convolution layer. 36 | conv2d_patch_size = 14 37 | # The number of attention heads used in the attention layers of the model. 38 | num_attention_heads: int = 16 39 | # The number of head dimensions. 40 | head_dim: int = 72 41 | # Clarify: is num_key_value same as num_query_groups? 42 | num_key_value_heads: int = 16 43 | # The number of query groups for implementing attention. 44 | num_query_groups: int = 16 45 | # Clarify: usecase of this field is not clear. 
46 | qkv_use_bias: bool = True 47 | # Clarify: usecase of this field is not clear. 48 | output_proj_use_bias: bool = True 49 | # The dimension of the MLP representations. 50 | intermediate_size: int = 4304 51 | # The epsilon used by the layer normalization layers. 52 | layer_norm_eps: float = 1e-6 53 | # Clarify: identify if the dtype varies for the siglip vision model. 54 | dtype: str = 'bfloat16' 55 | # Whether a quantized version of the model is used. 56 | quant: bool = False 57 | # The sequence length of the encoding. 58 | encoding_sequence_length: int = 256 59 | 60 | 61 | def get_siglip_vision_model_config() -> SiglipVisionModelConfig: 62 | """Returns the default model config for the vision model of Gemma 3 and PaliGemma.""" 63 | return SiglipVisionModelConfig() 64 | 65 | -------------------------------------------------------------------------------- /gemma/siglip_vision/preprocessor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Preprocessor for Siglip vision model. 15 | 16 | No neural network is used in the following functions. These are heuristic based. 17 | """ 18 | 19 | from collections.abc import Sequence 20 | 21 | from PIL import Image 22 | import torch 23 | import numpy as np # Import NumPy 24 | 25 | _IMAGE_MEAN = [0.5, 0.5, 0.5] # equivalent to 127.5/255 26 | _IMAGE_STD = [0.5, 0.5, 0.5] # equivalent to 127.5/255 27 | DEFAULT_IMAGE_SIZE = 896 28 | 29 | 30 | def preprocess_images_for_siglip_vision( 31 | images: Sequence[Image.Image], image_size=DEFAULT_IMAGE_SIZE 32 | ) -> torch.Tensor: 33 | """Preprocesses a list of PIL images for Siglip vision model using only PyTorch and PIL. 34 | 35 | Args: 36 | images: A sequence of PIL Image objects. 37 | image_size: The target size for resizing the images. 38 | 39 | Returns: 40 | A sequence of torch.Tensor objects, each of shape (C, H, W). 
41 | """ 42 | processed_images = [] 43 | 44 | mean_tensor = torch.tensor(_IMAGE_MEAN, dtype=torch.float32).reshape(3, 1, 1) 45 | std_tensor = torch.tensor(_IMAGE_STD, dtype=torch.float32).reshape(3, 1, 1) 46 | 47 | for image in images: 48 | # Resize image 49 | image = image.resize((image_size, image_size), Image.Resampling.BILINEAR) 50 | 51 | # Convert to NumPy and ensure float32 type 52 | image_np = np.array(image, dtype=np.float32) / 255.0 # Normalize to [0,1] 53 | 54 | # Convert to PyTorch tensor and rearrange channels 55 | image_tensor = torch.from_numpy(image_np).permute(2, 0, 1) # (H, W, C) → (C, H, W) 56 | 57 | # Normalize 58 | image_tensor = (image_tensor - mean_tensor) / std_tensor 59 | 60 | # Clip the values to [-1, 1] 61 | image_tensor = torch.clamp(image_tensor, -1, 1) 62 | 63 | processed_images.append(image_tensor) 64 | 65 | return processed_images 66 | 67 | 68 | # Example usage: 69 | # Assuming you have a list of PIL images called 'pil_images' 70 | # pil_images = [Image.open("image1.jpg"), Image.open("image2.png")] 71 | # processed_tensors = preprocess_images_pytorch(pil_images) 72 | # for tensor in processed_tensors: print(tensor.shape) 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # VSCode 163 | .vscode/* 164 | !.vscode/extensions.json 165 | 166 | # DS Store 167 | .DS_Store 168 | 169 | # Results 170 | *.csv 171 | 172 | # Python pickle files 173 | *.pkl 174 | 175 | # Sphinx documentation 176 | _build/ 177 | -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import contextlib 17 | import random 18 | 19 | import numpy as np 20 | import torch 21 | from absl import app, flags 22 | 23 | from gemma import config 24 | from gemma import model as gemma_model 25 | 26 | # Define flags 27 | FLAGS = flags.FLAGS 28 | 29 | flags.DEFINE_string('ckpt', None, 'Path to the checkpoint file.', required=True) 30 | flags.DEFINE_string('variant', '4b', 'Model variant.') 31 | flags.DEFINE_string('device', 'cpu', 'Device to run the model on.') 32 | flags.DEFINE_integer('output_len', 10, 'Length of the output sequence.') 33 | flags.DEFINE_integer('seed', 12345, 'Random seed.') 34 | flags.DEFINE_boolean('quant', False, 'Whether to use quantization.') 35 | flags.DEFINE_string('prompt', 'What are large language models?', 'Input prompt for the model.') 36 | 37 | # Define valid text only model variants 38 | _VALID_MODEL_VARIANTS = ['2b', '2b-v2', '7b', '9b', '27b', '1b'] 39 | 40 | # Define valid devices 41 | _VALID_DEVICES = ['cpu', 'cuda'] 42 | 43 | # Validator function for the 'variant' flag 44 | def validate_variant(variant): 45 | if variant not in _VALID_MODEL_VARIANTS: 46 | raise ValueError(f'Invalid variant: {variant}. Valid variants are: {_VALID_MODEL_VARIANTS}') 47 | return True 48 | 49 | # Validator function for the 'device' flag 50 | def validate_device(device): 51 | if device not in _VALID_DEVICES: 52 | raise ValueError(f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}') 53 | return True 54 | 55 | # Register the validator for the 'variant' flag 56 | flags.register_validator('variant', validate_variant, message='Invalid model variant.') 57 | 58 | # Register the validator for the 'device' flag 59 | flags.register_validator('device', validate_device, message='Invalid device.') 60 | 61 | @contextlib.contextmanager 62 | def _set_default_tensor_type(dtype: torch.dtype): 63 | """Sets the default torch dtype to the given dtype.""" 64 | torch.set_default_dtype(dtype) 65 | yield 66 | torch.set_default_dtype(torch.float) 67 | 68 | def main(_): 69 | # Construct the model config. 70 | model_config = config.get_model_config(FLAGS.variant) 71 | model_config.dtype = "float32" 72 | model_config.quant = FLAGS.quant 73 | 74 | # Seed random. 75 | random.seed(FLAGS.seed) 76 | np.random.seed(FLAGS.seed) 77 | torch.manual_seed(FLAGS.seed) 78 | 79 | # Create the model and load the weights. 80 | device = torch.device(FLAGS.device) 81 | with _set_default_tensor_type(model_config.get_dtype()): 82 | model = gemma_model.GemmaForCausalLM(model_config) 83 | model.load_weights(FLAGS.ckpt) 84 | model = model.to(device).eval() 85 | print("Model loading done") 86 | 87 | # Generate the response. 88 | result = model.generate(FLAGS.prompt, device, output_len=FLAGS.output_len) 89 | 90 | # Print the prompts and results. 91 | print('======================================') 92 | print(f'PROMPT: {FLAGS.prompt}') 93 | print(f'RESULT: {result}') 94 | print('======================================') 95 | 96 | if __name__ == "__main__": 97 | app.run(main) 98 | 99 | 100 | # How to run this script: 101 | 102 | # Example command (replace with your actual paths and values): 103 | # python scripts/run.py --device=cpu --ckpt=/path/to/your/pytorch_checkpoint/model.ckpt --output_len=2 --prompt="The name of the capital of Italy is" 104 | # Important: 105 | # - Replace '/path/to/your/pytorch_checkpoint/model.ckpt' with the actual path to your checkpoint file. 106 | # - Choose the correct --variant (model size). 107 | # - Use --device=cuda if you have a GPU; otherwise, use --device=cpu. 
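The flag-driven flow in `scripts/run.py` above can also be reproduced without the absl machinery for quick experiments. Below is a minimal sketch under stated assumptions: the variant and checkpoint path are placeholders, and it reuses only the calls already shown in the script (`config.get_model_config`, `GemmaForCausalLM`, `load_weights`, `generate`).

```python
# Minimal sketch of the loading-and-generation flow from scripts/run.py,
# without the absl flag machinery. The variant and checkpoint path are
# placeholders; substitute a real text-only variant and checkpoint file.
import torch

from gemma import config
from gemma import model as gemma_model

variant = "1b"                      # one of the text-only variants
ckpt_path = "/tmp/ckpt/model.ckpt"  # placeholder checkpoint path
device = torch.device("cpu")

model_config = config.get_model_config(variant)
model_config.dtype = "float32"
model_config.quant = False

# Mirror the _set_default_tensor_type context manager used in the script.
torch.set_default_dtype(model_config.get_dtype())
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()
torch.set_default_dtype(torch.float)

print(model.generate("What are large language models?", device, output_len=10))
```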
-------------------------------------------------------------------------------- /gemma/siglip_vision/pan_and_scan.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Pan and scan image cropping implementation.""" 15 | 16 | from collections.abc import Sequence 17 | 18 | import numpy as np 19 | from PIL import Image 20 | 21 | 22 | def pan_and_scan( 23 | img: Image.Image, 24 | *, 25 | min_crop_size: int = 256, 26 | max_num_crops: int = 4, 27 | ) -> Sequence[Image.Image]: 28 | return _pan_and_scan_os( 29 | img, 30 | min_crop_size=min_crop_size, 31 | max_num_crops=max_num_crops, 32 | )[0] 33 | 34 | 35 | def pan_and_scan_os_with_crop_positions( 36 | img: Image.Image, 37 | *, 38 | min_crop_size: int = 256, 39 | max_num_crops: int = 4, 40 | ) -> tuple[Sequence[Image.Image], Sequence[tuple[int, int, int, int]]]: 41 | return _pan_and_scan_os( 42 | img, 43 | min_crop_size=min_crop_size, 44 | max_num_crops=max_num_crops, 45 | ) 46 | 47 | 48 | def _pan_and_scan_os( 49 | img: Image.Image, 50 | *, 51 | min_crop_size: int, 52 | max_num_crops: int, 53 | ) -> tuple[Sequence[Image.Image], Sequence[tuple[int, int, int, int]]]: 54 | """Pan and scan an image for open source. 55 | 56 | If the image is landscape, the crops are made horizontally and if the image is 57 | portrait, the crops are made vertically. The longer side of the image is split 58 | into [2 - max_num_crops] crops. 59 | 60 | Args: 61 | img: PIL Image object. 62 | min_crop_size: The minimum size of each crop. 63 | max_num_crops: The maximum desired number of crops to be generated. 64 | 65 | Returns: 66 | List of cropped PIL Image objects and a list of crop positions. 67 | """ 68 | w, h = img.size 69 | 70 | # Square or landscape image. 71 | if w >= h: 72 | if w / h < 1.5: 73 | return [img], [(0, 0, h, w)] 74 | 75 | # Select ideal number of crops close to the image aspect ratio and such that 76 | # crop_size > min_crop_size. 77 | num_crops_w = int(np.floor(w / h + 0.5)) # Half round up rounding. 78 | num_crops_w = min( 79 | int(np.floor(w / min_crop_size)), 80 | num_crops_w, 81 | ) 82 | 83 | # Make sure the number of crops is in range [2, max_num_crops]. 84 | num_crops_w = max(2, num_crops_w) 85 | num_crops_w = min(max_num_crops, num_crops_w) 86 | num_crops_h = 1 87 | 88 | # Portrait image. 89 | else: 90 | if h / w < 1.5: 91 | return [img], [(0, 0, h, w)] 92 | 93 | num_crops_h = int(np.floor(h / w + 0.5)) 94 | num_crops_h = min(int(np.floor(h / min_crop_size)), num_crops_h) 95 | num_crops_h = max(2, num_crops_h) 96 | num_crops_h = min(max_num_crops, num_crops_h) 97 | num_crops_w = 1 98 | 99 | crop_size_w = int(np.ceil(w / num_crops_w)) 100 | crop_size_h = int(np.ceil(h / num_crops_h)) 101 | 102 | # Don't apply pan and scan if crop size is too small. 
103 | if min(crop_size_w, crop_size_h) < min_crop_size: 104 | return [img], [(0, 0, h, w)] 105 | 106 | crop_positions_w = [crop_size_w * i for i in range(num_crops_w)] 107 | crop_positions_h = [crop_size_h * i for i in range(num_crops_h)] 108 | 109 | # Generate crops. 110 | crops = [] 111 | crop_positions = [] 112 | for pos_h in crop_positions_h: 113 | for pos_w in crop_positions_w: 114 | crops.append( 115 | img.crop(( 116 | pos_w, 117 | pos_h, 118 | pos_w + crop_size_w, 119 | pos_h + crop_size_h, 120 | )) 121 | ) 122 | crop_positions.append( 123 | (pos_h, pos_w, pos_h + crop_size_h, pos_w + crop_size_w) 124 | ) 125 | 126 | return crops, crop_positions 127 | -------------------------------------------------------------------------------- /scripts/run_multimodal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import contextlib 17 | import random 18 | 19 | from absl import app 20 | from absl import flags 21 | import numpy as np 22 | from PIL import Image 23 | import torch 24 | 25 | from gemma import config 26 | from gemma import gemma3_model 27 | 28 | # Define flags 29 | FLAGS = flags.FLAGS 30 | 31 | _CKPT = flags.DEFINE_string( 32 | 'ckpt', None, 'Path to the checkpoint file.', required=True 33 | ) 34 | _VARIANT = flags.DEFINE_string('variant', '4b', 'Model variant.') 35 | _DEVICE = flags.DEFINE_string('device', 'cpu', 'Device to run the model on.') 36 | _OUTPUT_LEN = flags.DEFINE_integer( 37 | 'output_len', 10, 'Length of the output sequence.' 38 | ) 39 | _SEED = flags.DEFINE_integer('seed', 12345, 'Random seed.') 40 | _QUANT = flags.DEFINE_boolean('quant', False, 'Whether to use quantization.') 41 | 42 | # Define valid multimodal model variants 43 | _VALID_MODEL_VARIANTS = ['4b', '12b', '27b_v3'] 44 | 45 | # Define valid devices 46 | _VALID_DEVICES = ['cpu', 'cuda'] 47 | 48 | 49 | # Validator function for the 'variant' flag 50 | def validate_variant(variant): 51 | if variant not in _VALID_MODEL_VARIANTS: 52 | raise ValueError( 53 | f'Invalid variant: {variant}. Valid variants are:' 54 | f' {_VALID_MODEL_VARIANTS}' 55 | ) 56 | return True 57 | 58 | 59 | # Validator function for the 'device' flag 60 | def validate_device(device): 61 | if device not in _VALID_DEVICES: 62 | raise ValueError( 63 | f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}' 64 | ) 65 | return True 66 | 67 | 68 | # Register the validator for the 'variant' flag 69 | flags.register_validator( 70 | 'variant', validate_variant, message='Invalid model variant.' 
71 | )
72 | 
73 | # Register the validator for the 'device' flag
74 | flags.register_validator('device', validate_device, message='Invalid device.')
75 | 
76 | 
77 | @contextlib.contextmanager
78 | def _set_default_tensor_type(dtype: torch.dtype):
79 |   """Sets the default torch dtype to the given dtype."""
80 |   torch.set_default_dtype(dtype)
81 |   yield
82 |   torch.set_default_dtype(torch.float)
83 | 
84 | 
85 | def main(_):
86 |   # Construct the model config.
87 |   model_config = config.get_model_config(_VARIANT.value)
88 |   model_config.dtype = 'float32'
89 |   model_config.quant = _QUANT.value
90 |   image_paths = {"cow_in_beach": "scripts/images/cow_in_beach.jpg",
91 |                  "lilly": "scripts/images/lilly.jpg",
92 |                  "sunflower": "scripts/images/sunflower.JPG",
93 |                  'golden_test_image': (
94 |                      'scripts/images/test_image.jpg'
95 |                  ),
96 |                  }
97 | 
98 |   image = {}
99 |   for key in image_paths:
100 |     try:
101 |       image[key] = Image.open(image_paths[key])  # Open local file
102 |       image[key].show()
103 |     except IOError as e:
104 |       print(f"Error loading image: {e}")
105 |       exit()
106 | 
107 |   # Seed random.
108 |   random.seed(_SEED.value)
109 |   np.random.seed(_SEED.value)
110 |   torch.manual_seed(_SEED.value)
111 | 
112 |   # Create the model and load the weights.
113 |   device = torch.device(_DEVICE.value)
114 |   with _set_default_tensor_type(model_config.get_dtype()):
115 |     model = gemma3_model.Gemma3ForMultimodalLM(model_config)
116 |     model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
117 |     # model.load_weights(_CKPT.value)
118 |     model = model.to(device).eval()
119 |   print('Model loading done')
120 | 
121 |   # Generate text only.
122 |   result = model.generate(
123 |       [
124 |           [
125 |               '<start_of_turn>user The capital of Italy'
126 |               ' is?<end_of_turn>\n<start_of_turn>model'
127 |           ],
128 |           [
129 |               '<start_of_turn>user What is your'
130 |               ' purpose?<end_of_turn>\n<start_of_turn>model'
131 |           ],
132 |       ],
133 |       device,
134 |       output_len=_OUTPUT_LEN.value,
135 |   )
136 | 
137 |   # Print the results.
138 |   print('======================================')
139 |   print(f'Text only RESULT: {result}')
140 |   print('======================================')
141 | 
142 |   # Generate golden Gemax test image.
143 |   result = model.generate(
144 |       [[
145 |           '<start_of_turn>user\n',
146 |           image['golden_test_image'],
147 |           'Caption this image. <end_of_turn>\n<start_of_turn>model',
148 |       ]],
149 |       device,
150 |       output_len=_OUTPUT_LEN.value,
151 |   )
152 | 
153 |   # Print the result.
154 |   print('======================================')
155 |   print(f'Golden test image RESULT: {result}')
156 |   print('======================================')
157 | 
158 |   # Generate text and image.
159 |   result = model.generate(
160 |       [[
161 |           '<start_of_turn>user\n',
162 |           image['cow_in_beach'],
163 |           (
164 |               'The name of the animal in the image is'
165 |               ' <end_of_turn>\n<start_of_turn>model'
166 |           ),
167 |       ]],
168 |       device,
169 |       output_len=_OUTPUT_LEN.value,
170 |   )
171 | 
172 |   # Print the result.
173 |   print('======================================')
174 |   print(f'Single image RESULT: {result}')
175 |   print('======================================')
176 | 
177 |   # Generate interleaved text and multiple images.
178 |   result = model.generate(
179 |       [[
180 |           '<start_of_turn>user\nThis image',
181 |           image['lilly'],
182 |           'and this image',
183 |           image['sunflower'],
184 |           'are similar because? <end_of_turn>\n<start_of_turn>model',
185 |       ]],
186 |       device,
187 |       output_len=_OUTPUT_LEN.value,
188 |   )
189 | 
190 |   # Print the result.
191 |   print('======================================')
192 |   print(f'Interleave images RESULT: {result}')
193 |   print('======================================')
194 | 
195 | 
196 | if __name__ == '__main__':
197 |   app.run(main)
198 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gemma in PyTorch
2 | 
3 | **Gemma** is a family of lightweight, state-of-the-art open models built from research and technology used to create Google Gemini models. They include both text-only and multimodal decoder-only large language models, with open weights, pre-trained variants, and instruction-tuned variants. For more details, please check out the following links:
4 | 
5 | * [Gemma on Google AI](https://ai.google.dev/gemma)
6 | * [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma-3)
7 | * [Gemma on Vertex AI Model Garden](https://pantheon.corp.google.com/vertex-ai/publishers/google/model-garden/gemma3)
8 | 
9 | This is the official PyTorch implementation of Gemma models. We provide model and inference implementations using both PyTorch and PyTorch/XLA, and support running inference on CPU, GPU and TPU.
10 | 
11 | ## Updates
12 | 
13 | * [March 12th, 2025 🔥] Support Gemma v3. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-3/pytorch) and [Hugging Face](https://huggingface.co/models?other=gemma_torch)
14 | 
15 | * [June 26th, 2024] Support Gemma v2. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma-2/pytorch) and Hugging Face
16 | 
17 | * [April 9th, 2024] Support CodeGemma. You can find the checkpoints [on Kaggle](https://www.kaggle.com/models/google/codegemma/pytorch) and [Hugging Face](https://huggingface.co/collections/google/codegemma-release-66152ac7b683e2667abdee11)
18 | 
19 | * [April 5, 2024] Support Gemma v1.1. You can find the v1.1 checkpoints [on Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch) and [Hugging Face](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b).
20 | 
21 | ## Download Gemma model checkpoint
22 | 
23 | You can find the model checkpoints on Kaggle:
24 | 
25 | - [Gemma 3](https://www.kaggle.com/models/google/gemma-3/pyTorch)
26 | - [Gemma 2](https://www.kaggle.com/models/google/gemma-2/pyTorch)
27 | - [Gemma](https://www.kaggle.com/models/google/gemma/pyTorch)
28 | 
29 | Alternatively, you can find the model checkpoints on the Hugging Face Hub [here](https://huggingface.co/models?other=gemma_torch). To download the models, go to the model repository of the model of interest and click the `Files and versions` tab, and download the model and tokenizer files. For programmatic downloading, if you have `huggingface_hub` installed, you can also run:
30 | 
31 | ```
32 | huggingface-cli download google/gemma-3-4b-it-pytorch
33 | ```
34 | 
35 | The following model sizes are available:
36 | 
37 | - **Gemma 3**:
38 |   - **Text only**: 1b
39 |   - **Multimodal**: 4b, 12b, 27b_v3
40 | - **Gemma 2**:
41 |   - **Text only**: 2b-v2, 9b, 27b
42 | - **Gemma**:
43 |   - **Text only**: 2b, 7b
44 | 
45 | 
46 | Note that you can choose between the 1B, 4B, 12B, and 27B variants.
47 | 
48 | ```
49 | VARIANT=<1b, 2b, 2b-v2, 4b, 7b, 9b, 12b, 27b, 27b_v3>
50 | CKPT_PATH=
51 | ```
52 | 
53 | ## Try it free on Colab
54 | 
55 | Follow the steps at
56 | [https://ai.google.dev/gemma/docs/pytorch_gemma](https://ai.google.dev/gemma/docs/pytorch_gemma).
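The `huggingface-cli download` command in the section above also has a programmatic equivalent. A minimal sketch, assuming `huggingface_hub` is installed, you are logged in (`huggingface-cli login`), and you have accepted the Gemma license for the repository; the local directory is a placeholder:

```python
# Sketch: download a Gemma 3 PyTorch checkpoint from the Hugging Face Hub.
# Assumes `pip install huggingface_hub`, a completed `huggingface-cli login`,
# and an accepted Gemma license; local_dir below is a placeholder path.
from huggingface_hub import snapshot_download

path = snapshot_download(
    repo_id="google/gemma-3-4b-it-pytorch",
    local_dir="/tmp/gemma-3-4b-it-pytorch",
)
print(path)  # directory containing the model and tokenizer files
```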
57 | 
58 | ## Try it out with PyTorch
59 | 
60 | Prerequisite: make sure you have set up docker permissions properly as a non-root user.
61 | 
62 | ```bash
63 | sudo usermod -aG docker $USER
64 | newgrp docker
65 | ```
66 | 
67 | ### Build the docker image.
68 | 
69 | ```bash
70 | DOCKER_URI=gemma:${USER}
71 | 
72 | docker build -f docker/Dockerfile ./ -t ${DOCKER_URI}
73 | ```
74 | 
75 | ### Run Gemma inference on CPU.
76 | 
77 | > NOTE: This is a multimodal example. Use a multimodal variant.
78 | 
79 | ```bash
80 | docker run -t --rm \
81 |     -v ${CKPT_PATH}:/tmp/ckpt \
82 |     ${DOCKER_URI} \
83 |     python scripts/run_multimodal.py \
84 |     --ckpt=/tmp/ckpt \
85 |     --variant="${VARIANT}" \
86 | # add `--quant` for the int8 quantized model.
87 | ```
88 | 
89 | ### Run Gemma inference on GPU.
90 | 
91 | > NOTE: This is a multimodal example. Use a multimodal variant.
92 | 
93 | ```bash
94 | docker run -t --rm \
95 |     --gpus all \
96 |     -v ${CKPT_PATH}:/tmp/ckpt \
97 |     ${DOCKER_URI} \
98 |     python scripts/run_multimodal.py \
99 |     --device=cuda \
100 |     --ckpt=/tmp/ckpt \
101 |     --variant="${VARIANT}"
102 | # add `--quant` for the int8 quantized model.
103 | ```
104 | 
105 | ## Try it out with PyTorch/XLA
106 | 
107 | ### Build the docker image (CPU, TPU).
108 | 
109 | ```bash
110 | DOCKER_URI=gemma_xla:${USER}
111 | 
112 | docker build -f docker/xla.Dockerfile ./ -t ${DOCKER_URI}
113 | ```
114 | 
115 | ### Build the docker image (GPU).
116 | 
117 | ```bash
118 | DOCKER_URI=gemma_xla_gpu:${USER}
119 | 
120 | docker build -f docker/xla_gpu.Dockerfile ./ -t ${DOCKER_URI}
121 | ```
122 | 
123 | ### Run Gemma inference on CPU.
124 | 
125 | > NOTE: This is a multimodal example. Use a multimodal variant.
126 | 
127 | ```bash
128 | docker run -t --rm \
129 |     --shm-size 4gb \
130 |     -e PJRT_DEVICE=CPU \
131 |     -v ${CKPT_PATH}:/tmp/ckpt \
132 |     ${DOCKER_URI} \
133 |     python scripts/run_xla.py \
134 |     --ckpt=/tmp/ckpt \
135 |     --variant="${VARIANT}" \
136 | # add `--quant` for the int8 quantized model.
137 | ```
138 | 
139 | ### Run Gemma inference on TPU.
140 | 
141 | Note: be sure to use the docker container built from `xla.Dockerfile`.
142 | 
143 | ```bash
144 | docker run -t --rm \
145 |     --shm-size 4gb \
146 |     -e PJRT_DEVICE=TPU \
147 |     -v ${CKPT_PATH}:/tmp/ckpt \
148 |     ${DOCKER_URI} \
149 |     python scripts/run_xla.py \
150 |     --ckpt=/tmp/ckpt \
151 |     --variant="${VARIANT}" \
152 | # add `--quant` for the int8 quantized model.
153 | ```
154 | 
155 | ### Run Gemma inference on GPU.
156 | 
157 | Note: be sure to use the docker container built from `xla_gpu.Dockerfile`.
158 | 
159 | ```bash
160 | docker run -t --rm --privileged \
161 |     --shm-size=16g --net=host --gpus all \
162 |     -e USE_CUDA=1 \
163 |     -e PJRT_DEVICE=CUDA \
164 |     -v ${CKPT_PATH}:/tmp/ckpt \
165 |     ${DOCKER_URI} \
166 |     python scripts/run_xla.py \
167 |     --ckpt=/tmp/ckpt \
168 |     --variant="${VARIANT}" \
169 | # add `--quant` for the int8 quantized model.
170 | ```
171 | 
172 | ### Tokenizer Notes
173 | 
174 | 99 unused tokens are reserved in the pretrained tokenizer model to assist with more efficient training/fine-tuning. Unused tokens are in the string format of `<unused[0-98]>` with token id range of `[7-104]`.
175 | 
176 | ```
177 | "<unused0>": 7,
178 | "<unused1>": 8,
179 | "<unused2>": 9,
180 | ...
181 | "<unused98>": 104,
182 | ```
183 | 
184 | ## Disclaimer
185 | 
186 | This is not an officially supported Google product.
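Relating to the Tokenizer Notes above, the `gemma/tokenizer.py` wrapper shown earlier in this listing can be used to inspect those reserved ids directly. A minimal sketch, assuming it is run from the repository root so that `tokenizer/tokenizer.model` resolves:

```python
# Sketch: inspect reserved <unusedN> ids with the Tokenizer wrapper from
# gemma/tokenizer.py. Assumes the working directory is the repository root.
from gemma.tokenizer import Tokenizer

tok = Tokenizer("tokenizer/tokenizer.model")
print(tok.n_words, tok.bos_id, tok.eos_id, tok.pad_id)
print(tok.sp_model.IdToPiece(7))  # expected per the note above: "<unused0>"

ids = tok.encode("What are large language models?", bos=True, eos=False)
print(ids)
print(tok.decode(ids))
```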
187 | -------------------------------------------------------------------------------- /scripts/run_multimodal.py.orig: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import contextlib 17 | import random 18 | 19 | import numpy as np 20 | import torch 21 | from absl import app 22 | from absl import flags 23 | from google3.pyglib import gfile 24 | from PIL import Image 25 | 26 | from google3.third_party.open_models_release.gemma_pytorch.gemma import config 27 | from google3.third_party.open_models_release.gemma_pytorch.gemma import model as gemma_model 28 | from google3.third_party.open_models_release.gemma_pytorch.gemma import gemma3_model 29 | 30 | # Define flags 31 | FLAGS = flags.FLAGS 32 | 33 | _CKPT = flags.DEFINE_string('ckpt', None, 'Path to the checkpoint file.', required=True) 34 | _VARIANT = flags.DEFINE_string('variant', '4b', 'Model variant.') 35 | _DEVICE = flags.DEFINE_string('device', 'cpu', 'Device to run the model on.') 36 | _OUTPUT_LEN = flags.DEFINE_integer('output_len', 10, 'Length of the output sequence.') 37 | _SEED = flags.DEFINE_integer('seed', 12345, 'Random seed.') 38 | _QUANT = flags.DEFINE_boolean('quant', False, 'Whether to use quantization.') 39 | 40 | # Define valid model variants 41 | _VALID_MODEL_VARIANTS = ['4b', '12b', '27b_v3'] 42 | 43 | # Define valid devices 44 | _VALID_DEVICES = ['cpu', 'cuda'] 45 | 46 | # Validator function for the 'variant' flag 47 | def validate_variant(variant): 48 | if variant not in _VALID_MODEL_VARIANTS: 49 | raise ValueError(f'Invalid variant: {variant}. Valid variants are: {_VALID_MODEL_VARIANTS}') 50 | return True 51 | 52 | # Validator function for the 'device' flag 53 | def validate_device(device): 54 | if device not in _VALID_DEVICES: 55 | raise ValueError(f'Invalid device: {device}. Valid devices are: {_VALID_DEVICES}') 56 | return True 57 | 58 | # Register the validator for the 'variant' flag 59 | flags.register_validator('variant', validate_variant, message='Invalid model variant.') 60 | 61 | # Register the validator for the 'device' flag 62 | flags.register_validator('device', validate_device, message='Invalid device.') 63 | 64 | @contextlib.contextmanager 65 | def _set_default_tensor_type(dtype: torch.dtype): 66 | """Sets the default torch dtype to the given dtype.""" 67 | torch.set_default_dtype(dtype) 68 | yield 69 | torch.set_default_dtype(torch.float) 70 | 71 | def main(_): 72 | # Construct the model config. 
73 | model_config = config.get_model_config(_VARIANT.value) 74 | model_config.dtype = "float32" if _DEVICE.value == "cpu" else "float16" 75 | model_config.quant = _QUANT.value 76 | image_paths = {"cow_in_beach": "/cns/pd-d/home/omst/test_images/cow_in_beach.jpg", 77 | "lilly": "/cns/pd-d/home/omst/test_images/lilly.jpg", 78 | "sunflower": "/cns/pd-d/home/omst/test_images/sunflower.JPG", 79 | 'golden_test_image': ( 80 | '/cns/ge-d/home/mriviere/gemma_logits/test_image.jpg' 81 | ), 82 | } 83 | image = {} 84 | for key in image_paths: 85 | try: 86 | with gfile.Open(image_paths[key], 'rb') as f: 87 | image[key] = Image.open(f) 88 | image[key].show() 89 | except gfile.FileError as e: 90 | print(f"Error loading image: {e}") 91 | exit() 92 | 93 | # Seed random. 94 | random.seed(_SEED.value) 95 | np.random.seed(_SEED.value) 96 | torch.manual_seed(_SEED.value) 97 | 98 | # Create the model and load the weights. 99 | device = torch.device(_DEVICE.value) 100 | with _set_default_tensor_type(model_config.get_dtype()): 101 | model = gemma3_model.Gemma3ForMultimodalLM(model_config) 102 | model.load_state_dict(torch.load(_CKPT.value)['model_state_dict']) 103 | # model.load_weights(_CKPT.value) 104 | model = model.to(device).eval() 105 | print("Model loading done") 106 | 107 | # Generate text only. 108 | result = model.generate([["user The capital of Italy is?\nmodel"], ["user What is your purpose?\nmodel"]], device, output_len=_OUTPUT_LEN.value) 109 | 110 | # Print the results. 111 | print('======================================') 112 | print(f'Text only RESULT: {result}') 113 | print('======================================') 114 | 115 | # Generate golden Gemax test image. 116 | result = model.generate( 117 | [['user\n', image['golden_test_image'], 'Caption this image. \nmodel']], 118 | device, 119 | output_len=_OUTPUT_LEN.value, 120 | ) 121 | 122 | # Print the result. 123 | print('======================================') 124 | print(f'Golden test image RESULT: {result}') 125 | print('======================================') 126 | 127 | # Generate text and image. 128 | result = model.generate([['user\n', image["cow_in_beach"], "The name of the animal in the image is \nmodel"]], device, output_len=_OUTPUT_LEN.value) 129 | 130 | 131 | # Print the result. 132 | print('======================================') 133 | print(f'Single image RESULT: {result}') 134 | print('======================================') 135 | 136 | # Generate interleave text and multiple images. 137 | result = model.generate([["user\nThis image", image["lilly"], "and this image", image["sunflower"], "are similar because? \nmodel"]], device, output_len=_OUTPUT_LEN.value) 138 | 139 | # Print the result. 140 | print('======================================') 141 | print(f'Interleave images RESULT: {result}') 142 | print('======================================') 143 | 144 | if __name__ == "__main__": 145 | app.run(main) 146 | 147 | 148 | # How to run this script: 149 | 150 | # Example command (replace with your actual paths and values): 151 | # blaze run third_party/open_models_release/gemma_pytorch/scripts:run_multimodal -- --device=cpu --ckpt=/usr/local/google/home/imayank/Desktop/gemma_files/ckpts/gemma-3.0-4b-pt/mm/model5.ckpt --output_len=5 152 | # Important: 153 | # - Replace '/path/to/your/pytorch_checkpoint/model.ckpt' with the actual path to your checkpoint file. 154 | # - Choose the correct --variant (model size). 155 | # - Use --device=cuda if you have a GPU; otherwise, use --device=cpu. 
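Before the vision model file that follows, the two image-preprocessing helpers shown earlier (`pan_and_scan.py` and `preprocessor.py` under `gemma/siglip_vision/`) can be exercised on their own. A minimal sketch, assuming it is run from the repository root so the bundled sample image resolves; the `.convert("RGB")` call is an added safety step, not part of the original helpers:

```python
# Sketch of the image path into the vision tower: pan-and-scan cropping
# followed by the SigLIP resize/normalize preprocessing shown earlier.
# Assumes the working directory is the repository root.
from PIL import Image

from gemma.siglip_vision import pan_and_scan
from gemma.siglip_vision import preprocessor

img = Image.open("scripts/images/sunflower.JPG").convert("RGB")

# Wide or tall images yield 2-4 crops; near-square images come back unchanged.
crops = pan_and_scan.pan_and_scan(img, min_crop_size=256, max_num_crops=4)

# Each crop becomes a (3, 896, 896) float tensor normalized to roughly [-1, 1].
tensors = preprocessor.preprocess_images_for_siglip_vision(list(crops))
for t in tensors:
    print(t.shape, float(t.min()), float(t.max()))
```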
-------------------------------------------------------------------------------- /gemma/siglip_vision/siglip_vision_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Siglip vision model for gemma 3 and paligemma.""" 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | 20 | from . import config as siglip_vision_config 21 | SiglipVisionModelConfig = siglip_vision_config.SiglipVisionModelConfig 22 | 23 | class AveragePool2D(nn.Module): 24 | """Applies 4x4 average pooling and reshaping.""" 25 | def __init__(self, config): 26 | super().__init__() 27 | self.config = config 28 | 29 | def forward(self, x): 30 | """Applies 4x4 average pooling and reshaping.""" 31 | batch_size, seq_len, channels = x.shape 32 | width = int(seq_len**0.5) 33 | if width * width != seq_len: 34 | raise ValueError( 35 | f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image." 36 | ) 37 | # Bx(64^2)x1152 -> Bx1152x(64^2) -> Bx1152x64x64 38 | x = x.transpose(1, 2).reshape(batch_size, channels, width, width) 39 | # Bx1152x64x64-> Bx1152x16x16 40 | x = F.avg_pool2d(x, kernel_size=4, stride=4) 41 | # Bx1152x64x64-> Bx1152x256 -> Bx256x1152 42 | x = x.flatten(2).transpose(1, 2) 43 | return x 44 | 45 | # Siglip Attention 46 | class SiglipAttention(nn.Module): 47 | """Siglip attention module.""" 48 | 49 | def __init__(self, dim, num_heads, head_dim): 50 | super().__init__() 51 | self.dim = dim 52 | self.num_heads = num_heads 53 | self.head_dim = head_dim 54 | 55 | # Key, Query, Value projections 56 | self.k_proj = nn.Linear(dim, num_heads * head_dim, bias=True) 57 | self.q_proj = nn.Linear(dim, num_heads * head_dim, bias=True) 58 | self.v_proj = nn.Linear(dim, num_heads * head_dim, bias=True) 59 | 60 | # Output projection 61 | self.o_proj = nn.Linear(num_heads * head_dim, dim, bias=True) 62 | 63 | def forward(self, x): 64 | batch_size, seq_len, _ = x.size() 65 | 66 | # Project inputs to key, query, value 67 | k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim) 68 | q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim) 69 | v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim) 70 | 71 | # Transpose for multi-head attention 72 | k = k.transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim) 73 | q = q.transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim) 74 | v = v.transpose(1, 2) # (batch_size, num_heads, seq_len, head_dim) 75 | 76 | # Scaled dot-product attention 77 | scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5) 78 | attn_weights = F.softmax(scores, dim=-1) 79 | attn_output = torch.matmul(attn_weights, v) 80 | 81 | # Transpose back to (batch_size, seq_len, num_heads, head_dim) 82 | attn_output = attn_output.transpose(1, 2).contiguous() 83 | attn_output = attn_output.view( 84 | batch_size, seq_len, 
self.num_heads * self.head_dim 85 | ) 86 | 87 | # Apply output projection 88 | output = self.o_proj(attn_output) 89 | 90 | return output 91 | 92 | 93 | class SiglipMLP(nn.Module): 94 | """Siglip MLP module.""" 95 | def __init__(self, hidden_size, intermediate_size): 96 | super().__init__() 97 | self.fc1 = nn.Linear(hidden_size, intermediate_size) 98 | self.fc2 = nn.Linear(intermediate_size, hidden_size) 99 | 100 | def gelu_tanh(self, x): 101 | return ( 102 | 0.5 103 | * x 104 | * ( 105 | 1 106 | + torch.tanh( 107 | torch.sqrt(torch.tensor(2.0 / torch.pi, device=x.device)) 108 | * (x + 0.044715 * torch.pow(x, 3)) 109 | ) 110 | ) 111 | ) 112 | 113 | def forward(self, x): 114 | x = self.fc1(x) 115 | x = self.gelu_tanh(x) 116 | x = self.fc2(x) 117 | return x 118 | 119 | 120 | class SiglipEncoderBlock(nn.Module): 121 | """Encoder block (Transformer layer) for siglip vision model.""" 122 | 123 | def __init__(self, config: SiglipVisionModelConfig): 124 | super().__init__() 125 | self.self_attn = SiglipAttention( 126 | config.embedding_dim, config.num_attention_heads, config.head_dim 127 | ) 128 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_0/LayerNorm_0 129 | self.layer_norm1 = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps) 130 | self.mlp = SiglipMLP(config.embedding_dim, config.intermediate_size) 131 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_0/LayerNorm_1 132 | self.layer_norm2 = nn.LayerNorm(config.embedding_dim, eps=config.layer_norm_eps) 133 | 134 | def forward(self, x): 135 | # Pre-LN 136 | residual = x 137 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_0/LayerNorm_0 138 | x = self.layer_norm1(x) 139 | x = self.self_attn(x) 140 | x = x + residual # Residual connection *after* LayerNorm 141 | 142 | residual = x 143 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_0/LayerNorm_1 144 | x = self.layer_norm2(x) 145 | x = self.mlp(x) 146 | x = x + residual # Residual connection *after* LayerNorm 147 | return x 148 | 149 | 150 | # https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/ 151 | class SiglipVisionModel(nn.Module): 152 | """Signlip vision model for gemma 3 and paligemma.""" 153 | 154 | def __init__(self, config: SiglipVisionModelConfig): 155 | super().__init__() 156 | 157 | # SigLiPFromPatches_0/siglip_encoder/embedding 158 | self.patch_embedding = nn.Conv2d( 159 | in_channels=config.input_channels, 160 | out_channels=config.embedding_dim, 161 | kernel_size=config.conv2d_patch_size, 162 | stride=config.conv2d_patch_size, 163 | padding=0, 164 | bias=config.embedding_use_bias, 165 | ) 166 | self.num_patches = (config.image_size // config.conv2d_patch_size) ** 2 167 | self.num_positions = self.num_patches 168 | # SigLiPFromPatches_0/siglip_encoder 169 | self.position_embedding = nn.Embedding( 170 | self.num_positions, config.embedding_dim 171 | ) 172 | self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) 173 | 174 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoderblock_i 175 | self.encoder_blocks = nn.ModuleList( 176 | SiglipEncoderBlock(config=config) 177 | for _ in range(config.num_hidden_layers) 178 | ) 179 | # SigLiPFromPatches_0/siglip_encoder/Transformer/encoder_norm 180 | self.final_norm = nn.LayerNorm(config.embedding_dim, config.layer_norm_eps) 181 | self.avg_pool = AveragePool2D(config) 182 | self.config = config 183 | 184 | # pixel_values=Bxconfig.input_channels x config.image_size x config.image_size 185 | 
@torch.inference_mode 186 | def forward( 187 | self, 188 | pixel_values: torch.Tensor, 189 | ) -> torch.Tensor: 190 | # Embed the image according to SiplipVisionEmbeddings. 191 | x = self.patch_embedding(pixel_values) 192 | # (batch_size,channels,height,width)->(batch_size, height*width, channels) 193 | x = x.flatten(2).transpose(1, 2) 194 | 195 | position_ids = self.position_ids.to(pixel_values.device) 196 | x = x + self.position_embedding(position_ids) 197 | 198 | for block in self.encoder_blocks: 199 | x = block(x) # batch_size, height*width, embedding_dim (1152) 200 | x = self.final_norm(x) 201 | 202 | # siglip exit https://source.corp.google.com/piper///depot/google3/third_party/py/gemma/multimodal/vision.py;l=220 203 | return self.avg_pool(x) 204 | -------------------------------------------------------------------------------- /gemma/gemma3_preprocessor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Preprocessor for Gemma3 input.""" 15 | import token 16 | 17 | from typing import Union, Any, Sequence 18 | 19 | import torch 20 | from absl import app 21 | from PIL import Image 22 | from .siglip_vision import preprocessor as siglip_vision_preprocessor 23 | from .siglip_vision import pan_and_scan 24 | from . import tokenizer 25 | from . import config as gemma_config 26 | 27 | CROPPED_IMAGE_PREFIX = "here is the original image" 28 | CROPPED_IMAGE_FILLER = "and here are some crops to help you see better" 29 | 30 | 31 | def gemma3_input_preprocessor( 32 | raw_user_prompt: Sequence[Union[Image.Image, str]], 33 | ) -> Sequence[Union[torch.Tensor, str]]: 34 | """Preprocessor for Gemma3 input. 35 | 36 | Args: 37 | raw_user_prompt: A list of images or strings, as provided by the user. 38 | 39 | Returns: 40 | A list of preprocessed images or strings. 
41 | """ 42 | preprocessed_input: list[Union[torch.Tensor, str]] = [] 43 | for element in raw_user_prompt: 44 | if isinstance(element, Image.Image): 45 | cropped_images = pan_and_scan.pan_and_scan(element) 46 | preprocessed_images_cropped = siglip_vision_preprocessor.preprocess_images_for_siglip_vision(cropped_images) 47 | preprocessed_images_uncropped = siglip_vision_preprocessor.preprocess_images_for_siglip_vision([element]) 48 | if len(preprocessed_images_cropped) == 1: 49 | preprocessed_input.append(preprocessed_images_uncropped[0]) 50 | elif len(preprocessed_images_cropped) > 1: 51 | preprocessed_input.append(CROPPED_IMAGE_PREFIX) 52 | preprocessed_input.append(preprocessed_images_uncropped[0]) 53 | preprocessed_input.append(CROPPED_IMAGE_FILLER) 54 | preprocessed_input.extend(preprocessed_images_cropped) 55 | else: 56 | raise ValueError("No images found in the input.") 57 | else: 58 | preprocessed_input.append(element) 59 | 60 | return preprocessed_input 61 | 62 | 63 | def gemma3_batch_input_preprocessor(raw_input: Sequence[Sequence[Union[Image.Image, str]]]): 64 | """Preprocessor for Gemma3 batch input. 65 | """ 66 | preprocessed_input: list[Sequence[Union[torch.Tensor, str]]] = [] 67 | for element in raw_input: 68 | preprocessed_input.append(gemma3_input_preprocessor(element)) 69 | return preprocessed_input 70 | 71 | 72 | def tokenize_raw_input( 73 | tokenizer_obj: tokenizer.Tokenizer, 74 | raw_input: Sequence[Sequence[Union[str, Image.Image]]], 75 | config: gemma_config.GemmaConfig, 76 | output_len: int, 77 | device: Any, 78 | ) -> dict[str, Any]: 79 | """ 80 | Converts a preprocessed batch of interleaved text and image inputs into 81 | token IDs and an image batch suitable for gemma3 model. 82 | 83 | Args: 84 | preprocessed_batch: List of lists containing strings and torch.Tensor images. 85 | image_token_id: Token ID to represent image placeholders. 86 | max_image_tokens: Number of tokens reserved for each image. 87 | image_size: Expected size of images (C, H, W). 88 | 89 | Returns: 90 | user_input_token_ids: Batch of token IDs with shape (B, L), where L is the max sequence length. 91 | image_batch: Batch of images with shape (B, N, C, H, W), where N is the max number of images. 
92 | """ 93 | if config.vision_config is None: 94 | raise ValueError('vision_config must be provided for Gemma3.') 95 | 96 | preprocessed_batch = gemma3_batch_input_preprocessor(raw_input) 97 | 98 | # Initialize lists to store token IDs and image tensors 99 | all_token_ids = [] 100 | all_images = [] 101 | prompt_lengths = [] 102 | 103 | max_prompt_len = 0 104 | min_prompt_len = float("inf") 105 | max_num_images = 0 106 | # Iterate over each user prompt in the batch 107 | for prompt in preprocessed_batch: 108 | token_ids = [] 109 | images = [] 110 | token_ids.append(tokenizer_obj.bos_id) 111 | # Process each element in the prompt 112 | for element in prompt: 113 | if isinstance(element, str): 114 | # Tokenize text and add to token_ids 115 | tokens = tokenizer_obj.encode(element, bos=False, eos=False) 116 | token_ids.extend(tokens) 117 | elif isinstance(element, torch.Tensor): 118 | # Prepend (dual endline + tokenizer_obj.boi) 119 | token_ids.extend(tokenizer_obj.encode("\n\n", bos=False, eos=False)) 120 | token_ids.append(tokenizer_obj.boi_id) 121 | # Add image placeholder tokens 122 | token_ids.extend( 123 | [tokenizer_obj.image_token_placeholder_id] 124 | * config.vision_config.encoding_sequence_length 125 | ) 126 | # Append (tokenizer_obj.eoi + dual endline) 127 | token_ids.append(tokenizer_obj.eoi_id) 128 | token_ids.extend(tokenizer_obj.encode("\n\n", bos=False, eos=False)) 129 | # Store the image tensor 130 | images.append(element) 131 | else: 132 | raise ValueError( 133 | "Unsupported type in prompt. Expected str or torch.Tensor." 134 | ) 135 | curr_prompt_len = len(token_ids) 136 | prompt_lengths.append(curr_prompt_len) 137 | 138 | max_prompt_len = max(max_prompt_len, curr_prompt_len) 139 | min_prompt_len = min(min_prompt_len, curr_prompt_len) 140 | max_num_images = max(max_num_images, len(images)) 141 | 142 | all_token_ids.append(token_ids) 143 | all_images.append(images) 144 | 145 | max_seq_len = max_prompt_len + output_len 146 | 147 | # Pad token IDs to the maximum sequence length 148 | user_input_token_ids = [] 149 | for token_ids in all_token_ids: 150 | pad_length = max_seq_len - len(token_ids) 151 | padded_token_ids = token_ids + [tokenizer_obj.pad_id] * pad_length 152 | user_input_token_ids.append(padded_token_ids) 153 | 154 | # Pad images to the maximum number of images in the batch 155 | image_batch = [] 156 | image_presence_mask = [] 157 | for images in all_images: 158 | # Check if all images within the current sublist have the same shape 159 | if images: # Check if the sublist is not empty 160 | first_shape = images[0].shape 161 | for img in images: 162 | assert img.shape == first_shape, "Images within a sublist must have the same shape." 163 | pad_length = max_num_images - len(images) 164 | padded_images = images.copy() #create a copy so the original data is not altered. 
165 | presence_mask = [True] * len(images) 166 | 167 | if pad_length > 0: 168 | # Create a list of zero tensors for padding 169 | padding = [ 170 | torch.zeros( 171 | ( 172 | config.vision_config.input_channels, 173 | config.vision_config.image_size, 174 | config.vision_config.image_size, 175 | ), device=device 176 | ) 177 | for _ in range(pad_length) 178 | ] 179 | padded_images.extend(padding) 180 | presence_mask.extend([False] * pad_length) 181 | image_batch.append(padded_images) 182 | image_presence_mask.append(presence_mask) 183 | 184 | # Convert lists to tensors 185 | user_input_token_ids = torch.tensor(user_input_token_ids, dtype=torch.long, device=device) 186 | if max_num_images > 0: 187 | image_batch = torch.stack([torch.stack(images) for images in image_batch]).to( 188 | device 189 | ) 190 | image_presence_mask = torch.tensor(image_presence_mask, dtype=torch.bool, device=device) 191 | else: 192 | image_batch = None 193 | image_presence_mask = None 194 | 195 | # Prepare the output dictionary 196 | output_dict = { 197 | "user_input_token_ids": user_input_token_ids, 198 | "image_batch": image_batch, 199 | "batch_size": len(preprocessed_batch), 200 | "min_prompt_len": min_prompt_len, 201 | "max_prompt_len": max_prompt_len, 202 | "max_seq_len": max_seq_len, 203 | "image_presence_mask": image_presence_mask, 204 | } 205 | 206 | return output_dict 207 | -------------------------------------------------------------------------------- /scripts/run_xla.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | import contextlib 16 | import os 17 | import random 18 | import socket 19 | import sys 20 | from typing import List, Union 21 | 22 | import numpy as np 23 | import torch 24 | import torch.multiprocessing 25 | 26 | from gemma.config import GemmaConfig, get_model_config 27 | from gemma.model_xla import GemmaForCausalLM 28 | from gemma.tokenizer import Tokenizer 29 | import gemma.xla_model_parallel as xla_model_parallel 30 | 31 | USE_CUDA = os.environ.get('USE_CUDA', False) 32 | if not USE_CUDA: 33 | import torch_xla.core.xla_model as xm 34 | import torch_xla.distributed.xla_multiprocessing as xmp 35 | else: 36 | # Choose an available port. 
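# NOTE (editor's annotation, not part of the original file): binding a socket to
# port 0 asks the OS for any free port; the chosen port number is read back via
# getsockname() and exported as MASTER_PORT when torch.distributed is initialized.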
37 | with contextlib.closing(socket.socket(socket.AF_INET, 38 | socket.SOCK_STREAM)) as s: 39 | s.bind(('', 0)) 40 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 41 | MASTER_PORT = str(s.getsockname()[1]) 42 | 43 | 44 | @contextlib.contextmanager 45 | def _set_default_tensor_type(dtype: torch.dtype): 46 | """Sets the default torch dtype to the given dtype.""" 47 | torch.set_default_dtype(dtype) 48 | yield 49 | torch.set_default_dtype(torch.float) 50 | 51 | 52 | def generate( 53 | i: int, 54 | model_config: GemmaConfig, 55 | ckpt_path: str, 56 | prompts: List[str], 57 | output_lens: List[int], 58 | temperatures: Union[List[float], None], 59 | top_ps: List[float], 60 | top_ks: List[int], 61 | seed: int 62 | ): 63 | random.seed(seed) 64 | np.random.seed(seed) 65 | torch.manual_seed(seed) 66 | if USE_CUDA: 67 | os.environ['MASTER_ADDR'] = '127.0.0.1' 68 | os.environ['MASTER_PORT'] = MASTER_PORT 69 | if not torch.distributed.is_initialized(): 70 | torch.distributed.init_process_group( 71 | "nccl", 72 | rank=int(os.environ.get("RANK", 0)), 73 | world_size=int(os.environ.get("WORLD_SIZE", 1))) 74 | xla_model_parallel.set_g_group() 75 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 76 | device = torch.device("cuda", local_rank) 77 | torch.cuda.set_device(local_rank) 78 | else: 79 | device = xm.xla_device() 80 | xm.set_rng_state(seed, device) 81 | 82 | rank = xla_model_parallel.get_model_parallel_rank() 83 | world_size = xla_model_parallel.get_model_parallel_world_size() 84 | if rank > 0: 85 | sys.stdout = open(os.devnull, 'w') 86 | 87 | # build, load and compile model. 88 | with _set_default_tensor_type(model_config.get_dtype()): 89 | model = GemmaForCausalLM(model_config, world_size, rank, device) 90 | model.load_weights(ckpt_path) 91 | model = model.to(device).eval() 92 | 93 | # create tokenizer. 
94 | tokenizer = Tokenizer(model_config.tokenizer) 95 | 96 | prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts] 97 | min_prompt_len = min(len(p) for p in prompt_tokens) 98 | 99 | batch_size = len(prompts) 100 | if temperatures is not None: 101 | assert batch_size == len(temperatures) 102 | assert batch_size == len(top_ps) 103 | assert batch_size == len(top_ks) 104 | max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)]) 105 | assert max_seq_len <= model_config.max_position_embeddings 106 | if model_config.num_key_value_heads < world_size: 107 | assert world_size % model_config.num_key_value_heads == 0 108 | n_local_heads = 1 109 | else: 110 | assert model_config.num_key_value_heads % world_size == 0 111 | n_local_heads = model_config.num_key_value_heads // world_size 112 | 113 | # build KV caches 114 | kv_caches = [] 115 | for _ in range(model_config.num_hidden_layers): 116 | k_cache = torch.zeros( 117 | size=(batch_size, max_seq_len, n_local_heads, 118 | model_config.head_dim), 119 | dtype=model_config.get_dtype(), 120 | device=device, 121 | ) 122 | v_cache = torch.zeros( 123 | size=(batch_size, max_seq_len, n_local_heads, 124 | model_config.head_dim), 125 | dtype=model_config.get_dtype(), 126 | device=device, 127 | ) 128 | kv_caches.append((k_cache, v_cache)) 129 | 130 | # prepare inputs 131 | token_ids_tensor = torch.full((batch_size, max_seq_len), 132 | tokenizer.pad_id, 133 | dtype=torch.int64) 134 | input_token_ids_tensor = torch.full((batch_size, min_prompt_len), 135 | tokenizer.pad_id, 136 | dtype=torch.int64) 137 | for i, p in enumerate(prompt_tokens): 138 | token_ids_tensor[i, :len(p)] = torch.tensor(p) 139 | input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( 140 | p[:min_prompt_len]) 141 | token_ids_tensor = token_ids_tensor.to(device) 142 | prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id 143 | input_token_ids_tensor = input_token_ids_tensor.to(device) 144 | input_positions_tensor = torch.arange(0, min_prompt_len, 145 | dtype=torch.int64).to(device) 146 | mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len), 147 | -2.3819763e38).to(torch.float) 148 | mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) 149 | curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) 150 | output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) 151 | temperatures_tensor = None if not temperatures else torch.FloatTensor(temperatures).to(device) 152 | top_ps_tensor = torch.FloatTensor(top_ps).to(device) 153 | top_ks_tensor = torch.LongTensor(top_ks).to(device) 154 | output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device) 155 | if not USE_CUDA: 156 | xm.mark_step() 157 | 158 | # Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output. 
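# NOTE (editor's annotation, not part of the original file): for example, with
# prompts of 3 and 5 tokens, min_prompt_len is 3, so the first pass consumes 3
# tokens of every prompt at once and each later step feeds one token per sequence.
# For positions still inside a prompt, prompt_mask_tensor is True and torch.where()
# keeps the original prompt token instead of the sampled one, so longer prompts are
# pushed through the decode path without affecting the final output.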
159 | for i in range(max_seq_len - min_prompt_len): 160 | next_token_ids, _ = model( 161 | input_token_ids=input_token_ids_tensor, 162 | input_positions=input_positions_tensor, 163 | kv_write_indices=None, 164 | kv_caches=kv_caches, 165 | mask=curr_mask_tensor, 166 | output_positions=output_positions_tensor, 167 | temperatures=temperatures_tensor, 168 | top_ps=top_ps_tensor, 169 | top_ks=top_ks_tensor, 170 | ) 171 | curr_prompt_mask = prompt_mask_tensor.index_select( 172 | 1, output_index).squeeze(dim=1) 173 | curr_token_ids = token_ids_tensor.index_select( 174 | 1, output_index).squeeze(dim=1) 175 | output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, 176 | next_token_ids).unsqueeze(dim=1) 177 | token_ids_tensor.index_copy_(1, output_index, output_token_ids) 178 | 179 | input_token_ids_tensor = output_token_ids 180 | input_positions_tensor = output_index.unsqueeze(dim=-1) 181 | curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) 182 | output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device) 183 | output_index = output_index + 1 184 | if not USE_CUDA: 185 | xm.mark_step() 186 | 187 | # Detokenization. 188 | token_ids = token_ids_tensor.tolist() 189 | results = [] 190 | for i, tokens in enumerate(token_ids): 191 | trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) + 192 | output_lens[i]] 193 | if tokenizer.eos_id in trimmed_output: 194 | eos_index = trimmed_output.index(tokenizer.eos_id) 195 | trimmed_output = trimmed_output[:eos_index] 196 | results.append(tokenizer.decode(trimmed_output)) 197 | 198 | for prompt, result in zip(prompts, results): 199 | print('======================================') 200 | print(f'PROMPT: {prompt}') 201 | print(f'RESULT: {result}') 202 | print('======================================') 203 | 204 | 205 | def main(args): 206 | model_config = get_model_config(args.variant) 207 | model_config.quant = args.quant 208 | 209 | prompts = [args.prompt] 210 | n = len(prompts) 211 | output_lengths = [args.output_len] * n 212 | temperatures = [0.95] * n 213 | top_ps = [1.0] * n 214 | top_ks = [100] * n 215 | 216 | if USE_CUDA: 217 | os.environ['MASTER_ADDR'] = '127.0.0.1' 218 | os.environ['MASTER_PORT'] = MASTER_PORT 219 | if not torch.distributed.is_initialized(): 220 | torch.distributed.init_process_group( 221 | "nccl", 222 | rank=int(os.environ.get("RANK", 0)), 223 | world_size=int(os.environ.get("WORLD_SIZE", 1))) 224 | xla_model_parallel.set_g_group() 225 | torch.multiprocessing.spawn( 226 | generate, 227 | args=( 228 | model_config, 229 | args.ckpt, 230 | prompts, 231 | output_lengths, 232 | temperatures, 233 | top_ps, 234 | top_ks, 235 | args.seed, 236 | ), 237 | ) 238 | else: 239 | xmp.spawn( 240 | generate, 241 | args=( 242 | model_config, 243 | args.ckpt, 244 | prompts, 245 | output_lengths, 246 | temperatures, 247 | top_ps, 248 | top_ks, 249 | args.seed, 250 | ), 251 | ) 252 | 253 | 254 | if __name__ == '__main__': 255 | parser = argparse.ArgumentParser() 256 | parser.add_argument("--ckpt", type=str, required=True) 257 | parser.add_argument("--variant", 258 | type=str, 259 | default="2b", 260 | choices=["2b", "2b-v2", "7b", "9b", "27b"]) 261 | parser.add_argument("--output_len", type=int, default=4) 262 | parser.add_argument("--seed", type=int, default=12345) 263 | parser.add_argument("--quant", action='store_true') 264 | parser.add_argument("--prompt", type=str, default="The meaning of life is") 265 | args = parser.parse_args() 266 | 267 | main(args) 268 | 
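Editor's note: a minimal, hypothetical way to launch the script above. The checkpoint path is a placeholder, the flags mirror the argparse options defined at the end of scripts/run_xla.py, and USE_CUDA=1 is only needed when running on GPUs rather than XLA devices.

# Hypothetical launcher for scripts/run_xla.py (illustrative only).
import os
import subprocess

env = dict(os.environ, USE_CUDA="1")  # omit to run on XLA/TPU devices instead
subprocess.run(
    [
        "python", "scripts/run_xla.py",
        "--ckpt", "/path/to/gemma-2b.ckpt",  # placeholder checkpoint path
        "--variant", "2b",
        "--output_len", "32",
        "--prompt", "The meaning of life is",
    ],
    env=env,
    check=True,
)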
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /gemma/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Gemma model config.""" 16 | 17 | import dataclasses 18 | import enum 19 | import os 20 | import torch 21 | from typing import Optional, Sequence 22 | from .siglip_vision import config as siglip_vision_config 23 | 24 | 25 | # Keep a mapping from dtype strings to the supported torch dtypes. 26 | _STR_DTYPE_TO_TORCH_DTYPE = dict({ 27 | 'float16': torch.float16, 28 | 'float': torch.float32, 29 | 'float32': torch.float32, 30 | 'bfloat16': torch.bfloat16, 31 | }) 32 | 33 | 34 | class AttentionType(enum.Enum): 35 | GLOBAL = 1 36 | LOCAL_SLIDING = 2 37 | 38 | 39 | class Architecture(enum.Enum): 40 | GEMMA_1 = 1 41 | GEMMA_2 = 2 42 | GEMMA_3 = 3 43 | 44 | 45 | @dataclasses.dataclass 46 | class GemmaConfig: 47 | # The architecture of the model. 48 | architecture: Architecture = Architecture.GEMMA_1 49 | # The number of tokens in the vocabulary. 50 | vocab_size: int = 256000 51 | # The maximum sequence length that this model might ever be used with. 52 | max_position_embeddings: int = 8192 53 | # The number of blocks in the model. 54 | num_hidden_layers: int = 28 55 | # The number of attention heads used in the attention layers of the model. 56 | num_attention_heads: int = 16 57 | # The number of key-value heads for implementing attention. 58 | num_key_value_heads: int = 16 59 | # The hidden size of the model. 60 | hidden_size: int = 3072 61 | # The dimension of the MLP representations. 62 | intermediate_size: int = 24576 63 | # The number of head dimensions. 64 | head_dim: int = 256 65 | # The epsilon used by the rms normalization layers. 
66 | rms_norm_eps: float = 1e-6 67 | # The dtype of the weights. 68 | dtype: str = 'bfloat16' 69 | # Whether a quantized version of the model is used. 70 | quant: bool = False 71 | # The path to the model tokenizer. 72 | tokenizer: Optional[str] = ( 73 | 'tokenizer/tokenizer.model' 74 | ) 75 | # The types of attention used in the layers of the model. 76 | attn_types: Optional[Sequence[AttentionType]] = None 77 | # The size of the sliding window used for local attention. 78 | sliding_window_size: Optional[int] = None 79 | # If provided, the final logits are softcapped to this value. 80 | final_logit_softcapping: Optional[float] = None 81 | # If provided, the attention logits are softcapped to this value. 82 | attn_logit_softcapping: Optional[float] = None 83 | # If provided, the query vector is normalized using the 84 | # inverse square root of this value instead of head_dim. 85 | query_pre_attn_scalar: Optional[int] = None 86 | # Whether to use pre mlp normalization. 87 | use_pre_ffw_norm: bool = False 88 | # Whether to use post mlp normalization. 89 | use_post_ffw_norm: bool = False 90 | # The wave length of the rotary embedding. 91 | rope_wave_length: dict[AttentionType, int] | None = None 92 | # Whether to use QK normalization in the attention blocks. 93 | use_qk_norm: bool = False 94 | # Vision model config. 95 | vision_config: siglip_vision_config.SiglipVisionModelConfig | None = None 96 | # The factor by which the rope wave length is divided for global layers. 97 | rope_scaling_factor: int| None = None 98 | 99 | def get_dtype(self) -> Optional[torch.dtype]: 100 | """Gets the torch dtype from the config dtype string.""" 101 | return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None) 102 | 103 | 104 | def get_config_for_7b(dtype: str = 'bfloat16') -> GemmaConfig: 105 | return GemmaConfig(dtype=dtype) 106 | 107 | 108 | def get_config_for_2b(dtype: str = 'bfloat16') -> GemmaConfig: 109 | return GemmaConfig( 110 | dtype=dtype, 111 | num_hidden_layers=18, 112 | num_attention_heads=8, 113 | num_key_value_heads=1, 114 | hidden_size=2048, 115 | intermediate_size=16384, 116 | ) 117 | 118 | 119 | def get_config_for_2b_v2(dtype: str = 'bfloat16') -> GemmaConfig: 120 | return GemmaConfig( 121 | dtype=dtype, 122 | architecture=Architecture.GEMMA_2, 123 | num_hidden_layers=26, 124 | num_attention_heads=8, 125 | num_key_value_heads=4, 126 | hidden_size=2304, 127 | intermediate_size=9216, 128 | use_pre_ffw_norm=True, 129 | use_post_ffw_norm=True, 130 | final_logit_softcapping=30.0, 131 | attn_logit_softcapping=50.0, 132 | head_dim=256, 133 | attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 13, 134 | sliding_window_size=4096, 135 | ) 136 | 137 | 138 | def get_config_for_9b(dtype: str = 'bfloat16') -> GemmaConfig: 139 | return GemmaConfig( 140 | dtype=dtype, 141 | architecture=Architecture.GEMMA_2, 142 | num_hidden_layers=42, 143 | num_attention_heads=16, 144 | num_key_value_heads=8, 145 | hidden_size=3584, 146 | intermediate_size=14336, 147 | use_pre_ffw_norm=True, 148 | use_post_ffw_norm=True, 149 | final_logit_softcapping=30.0, 150 | attn_logit_softcapping=50.0, 151 | head_dim=256, 152 | attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 21, 153 | sliding_window_size=4096, 154 | ) 155 | 156 | 157 | def get_config_for_27b(dtype: str = 'bfloat16') -> GemmaConfig: 158 | return GemmaConfig( 159 | dtype=dtype, 160 | architecture=Architecture.GEMMA_2, 161 | num_hidden_layers=46, 162 | num_attention_heads=32, 163 | num_key_value_heads=16, 164 | hidden_size=4608, 165 | 
intermediate_size=36864, 166 | use_pre_ffw_norm=True, 167 | use_post_ffw_norm=True, 168 | final_logit_softcapping=30.0, 169 | attn_logit_softcapping=50.0, 170 | head_dim=128, 171 | attn_types=[AttentionType.LOCAL_SLIDING, AttentionType.GLOBAL] * 23, 172 | sliding_window_size=4096, 173 | query_pre_attn_scalar=144, # hidden_size / num_attention_heads 174 | ) 175 | 176 | 177 | def get_config_for_1b(dtype: str) -> GemmaConfig: 178 | return GemmaConfig( 179 | dtype=dtype, 180 | architecture=Architecture.GEMMA_3, 181 | num_hidden_layers=26, 182 | num_attention_heads=4, 183 | num_key_value_heads=1, 184 | hidden_size=1152, 185 | intermediate_size=6912, 186 | use_pre_ffw_norm=True, 187 | use_post_ffw_norm=True, 188 | head_dim=256, 189 | attn_types=( 190 | AttentionType.LOCAL_SLIDING, 191 | AttentionType.LOCAL_SLIDING, 192 | AttentionType.LOCAL_SLIDING, 193 | AttentionType.LOCAL_SLIDING, 194 | AttentionType.LOCAL_SLIDING, 195 | AttentionType.GLOBAL, 196 | ), 197 | sliding_window_size=512, 198 | rope_wave_length={ 199 | AttentionType.LOCAL_SLIDING: 10_000, 200 | AttentionType.GLOBAL: 1_000_000, 201 | }, 202 | vocab_size=262_144, 203 | max_position_embeddings=32_768, 204 | tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model', 205 | use_qk_norm=True, 206 | vision_config=None, 207 | ) 208 | 209 | 210 | def get_config_for_4b(dtype: str) -> GemmaConfig: 211 | return GemmaConfig( 212 | dtype=dtype, 213 | architecture=Architecture.GEMMA_3, 214 | num_hidden_layers=34, 215 | num_attention_heads=8, 216 | num_key_value_heads=4, 217 | hidden_size=2560, 218 | intermediate_size=10240, 219 | use_pre_ffw_norm=True, 220 | use_post_ffw_norm=True, 221 | head_dim=256, 222 | attn_types=( 223 | AttentionType.LOCAL_SLIDING, 224 | AttentionType.LOCAL_SLIDING, 225 | AttentionType.LOCAL_SLIDING, 226 | AttentionType.LOCAL_SLIDING, 227 | AttentionType.LOCAL_SLIDING, 228 | AttentionType.GLOBAL, 229 | ), 230 | sliding_window_size=1024, 231 | rope_wave_length={ 232 | AttentionType.LOCAL_SLIDING: 10_000, 233 | AttentionType.GLOBAL: 1_000_000, 234 | }, 235 | vocab_size=262_144, 236 | tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model', 237 | use_qk_norm=True, 238 | vision_config=siglip_vision_config.get_siglip_vision_model_config(), 239 | rope_scaling_factor=8, 240 | ) 241 | 242 | 243 | def get_config_for_12b(dtype: str) -> GemmaConfig: 244 | return GemmaConfig( 245 | dtype=dtype, 246 | architecture=Architecture.GEMMA_3, 247 | num_hidden_layers=48, 248 | num_attention_heads=16, 249 | num_key_value_heads=8, 250 | hidden_size=3840, 251 | intermediate_size=3840 * 8 // 2, 252 | use_pre_ffw_norm=True, 253 | use_post_ffw_norm=True, 254 | head_dim=256, 255 | attn_types=( 256 | AttentionType.LOCAL_SLIDING, 257 | AttentionType.LOCAL_SLIDING, 258 | AttentionType.LOCAL_SLIDING, 259 | AttentionType.LOCAL_SLIDING, 260 | AttentionType.LOCAL_SLIDING, 261 | AttentionType.GLOBAL, 262 | ), 263 | sliding_window_size=1024, 264 | rope_wave_length={ 265 | AttentionType.LOCAL_SLIDING: 10_000, 266 | AttentionType.GLOBAL: 1_000_000, 267 | }, 268 | vocab_size=262_144, 269 | max_position_embeddings=131_072, 270 | tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model', 271 | use_qk_norm=True, 272 | vision_config=siglip_vision_config.get_siglip_vision_model_config(), 273 | rope_scaling_factor=8, 274 | ) 275 | 276 | 277 | def get_config_for_27b_v3(dtype: str) -> GemmaConfig: 278 | return GemmaConfig( 279 | dtype=dtype, 280 | architecture=Architecture.GEMMA_3, 281 | num_hidden_layers=62, 282 | num_attention_heads=32, 283 | 
num_key_value_heads=16,
284 | hidden_size=5376,
285 | intermediate_size=5376 * 8 // 2,
286 | use_pre_ffw_norm=True,
287 | use_post_ffw_norm=True,
288 | head_dim=128,
289 | query_pre_attn_scalar=5376 // 32,
290 | attn_types=(
291 | AttentionType.LOCAL_SLIDING,
292 | AttentionType.LOCAL_SLIDING,
293 | AttentionType.LOCAL_SLIDING,
294 | AttentionType.LOCAL_SLIDING,
295 | AttentionType.LOCAL_SLIDING,
296 | AttentionType.GLOBAL,
297 | ),
298 | sliding_window_size=1024,
299 | rope_wave_length={
300 | AttentionType.LOCAL_SLIDING: 10_000,
301 | AttentionType.GLOBAL: 1_000_000,
302 | },
303 | vocab_size=262_144,
304 | max_position_embeddings=131_072,
305 | tokenizer='tokenizer/gemma3_cleaned_262144_v2.spiece.model',
306 | use_qk_norm=True,
307 | vision_config=siglip_vision_config.get_siglip_vision_model_config(),
308 | rope_scaling_factor=8,
309 | )
310 | 
311 | 
312 | def get_model_config(variant: str, dtype: str = 'bfloat16') -> GemmaConfig:
313 | """Gets the GemmaConfig for the desired variant and dtype."""
314 | # Gemma1 variants
315 | if variant == '7b':
316 | return get_config_for_7b(dtype)
317 | elif variant == '2b':
318 | return get_config_for_2b(dtype)
319 | # Gemma2 variants
320 | elif variant == '2b-v2':
321 | return get_config_for_2b_v2(dtype)
322 | elif variant == '9b':
323 | return get_config_for_9b(dtype)
324 | elif variant == '27b':
325 | return get_config_for_27b(dtype)
326 | # Gemma3 variants
327 | elif variant == '1b':
328 | return get_config_for_1b(dtype)
329 | elif variant == '4b':
330 | return get_config_for_4b(dtype)
331 | elif variant == '12b':
332 | return get_config_for_12b(dtype)
333 | elif variant == '27b_v3':
334 | return get_config_for_27b_v3(dtype)
335 | # Invalid variants
336 | else:
337 | raise ValueError(
338 | f'Invalid variant {variant}. Supported variants are "1b", "2b", '
339 | '"2b-v2", "4b", "7b", "9b", "12b", "27b", and "27b_v3".'
340 | )
341 | -------------------------------------------------------------------------------- /gemma/gemma3_model.py: --------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Inference-only Gemma 3 multimodal model implementation."""
15 | 
16 | import torch
17 | import os
18 | import json
19 | import gc
20 | from torch import nn
21 | from PIL import Image
22 | from typing import Any, List, Sequence, Tuple, Union
23 | 
24 | from . import model as gemma_model
25 | from . import config as gemma_config
26 | from . import gemma3_preprocessor
27 | from . 
import tokenizer 28 | from .siglip_vision import siglip_vision_model 29 | 30 | class Gemma3ForMultimodalLM(nn.Module): 31 | """Gemma3 model for multimodal causal LM.""" 32 | def __init__( 33 | self, 34 | config: gemma_config.GemmaConfig, 35 | ): 36 | super().__init__() 37 | self.dtype = config.get_dtype() 38 | assert config.architecture == gemma_config.Architecture.GEMMA_3 39 | self.config = config 40 | max_seq_len = config.max_position_embeddings 41 | head_dim = config.head_dim 42 | vocab_size = config.vocab_size 43 | self.tokenizer = tokenizer.Tokenizer(config.tokenizer) 44 | self.text_token_embedder = gemma_model.Embedding(vocab_size, config.hidden_size, config.quant) 45 | self.model = gemma_model.GemmaModel(config) 46 | self.sampler = gemma_model.Sampler(vocab_size, config) 47 | 48 | if config.vision_config is None: 49 | raise ValueError('vision_config must be provided for Gemma3.') 50 | self.siglip_vision_model = siglip_vision_model.SiglipVisionModel(config.vision_config) 51 | # transformer/embedder/mm_soft_embedding_norm 52 | self.mm_soft_embedding_norm = gemma_model.RMSNorm(config.vision_config.embedding_dim, 53 | eps = config.rms_norm_eps) 54 | # transformer/embedder/mm_input_projection 55 | self.mm_input_projection = gemma_model.Linear(config.vision_config.embedding_dim, config.hidden_size, config.quant) 56 | 57 | if config.rope_wave_length is None: 58 | raise ValueError('rope_wave_length must be provided for Gemma3.') 59 | rope_lengths = config.rope_wave_length 60 | defaults = { 61 | gemma_config.AttentionType.LOCAL_SLIDING: 10_000, 62 | gemma_config.AttentionType.GLOBAL: 10_000, 63 | } 64 | self._register_freqs_cis('local_freqs_cis', head_dim, max_seq_len, theta=rope_lengths.get( 65 | gemma_config.AttentionType.LOCAL_SLIDING, defaults[gemma_config.AttentionType.LOCAL_SLIDING] 66 | )) 67 | self._register_freqs_cis('global_freqs_cis', head_dim, max_seq_len, theta=rope_lengths.get( 68 | gemma_config.AttentionType.GLOBAL, defaults[gemma_config.AttentionType.GLOBAL] 69 | ), rope_scaling_factor=config.rope_scaling_factor) 70 | 71 | def _register_freqs_cis( 72 | self, name: str, head_dim: int, max_seq_len: int, theta: int = 10_000, rope_scaling_factor: int = 1 73 | ): 74 | self.register_buffer( 75 | name, gemma_model.precompute_freqs_cis(head_dim, max_seq_len * 2, theta=theta, rope_scaling_factor=rope_scaling_factor) 76 | ) 77 | 78 | @torch.no_grad() 79 | def forward(self, 80 | input_token_ids: torch.Tensor, # B x L 81 | image_patches: torch.Tensor, # B x N x C x H x W (3x896x896) 82 | image_presence_mask: torch.Tensor, # B x N 83 | input_positions: torch.Tensor, 84 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], 85 | mask: torch.Tensor, 86 | output_positions: torch.Tensor, 87 | temperatures: Union[torch.Tensor, None], 88 | top_ps: torch.Tensor, 89 | top_ks: torch.Tensor, 90 | local_mask: torch.Tensor | None = None, 91 | **kwargs) -> Tuple[torch.Tensor, torch.Tensor]: 92 | freqs_cis = {} 93 | freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = ( 94 | self.local_freqs_cis.index_select(0, input_positions) 95 | ) 96 | freqs_cis[gemma_config.AttentionType.GLOBAL] = ( 97 | self.global_freqs_cis.index_select(0, input_positions) 98 | ) 99 | hidden_states = self.text_token_embedder(input_token_ids) 100 | normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device) 101 | hidden_states = hidden_states * normalizer 102 | if image_patches is not None and self.config.vision_config is not None: 103 | # the input has images 104 | B, N, C, H, 
W = image_patches.shape 105 | # Flatten and Pass to SiglipVisionModel, and apply SiglipVisionModel Exit 106 | flattened_input = image_patches.reshape(B * N, C, H, W) # (B*N)xCxHxW 107 | image_embeddings = self.siglip_vision_model(flattened_input) # (B*N)xUxD 108 | image_embeddings = self.mm_soft_embedding_norm(image_embeddings) # (B*N) x U x D 109 | image_embeddings = self.mm_input_projection(image_embeddings) # (B*N) x U x model_dim 110 | hidden_states = self.populate_image_embeddings( 111 | hidden_states.clone(), 112 | image_embeddings.clone(), 113 | input_token_ids.clone(), 114 | image_presence_mask.clone(), 115 | ) 116 | 117 | kv_write_indices = input_positions 118 | 119 | hidden_states = self.model( 120 | hidden_states=hidden_states, 121 | freqs_cis=freqs_cis, 122 | kv_write_indices=kv_write_indices, 123 | kv_caches=kv_caches, 124 | mask=mask, 125 | local_mask=local_mask, 126 | ) 127 | embedder_weight = self.text_token_embedder.weight 128 | if self.config.quant: 129 | embedder_weight = ( 130 | embedder_weight * self.text_token_embedder.weight_scaler.unsqueeze(-1)) 131 | 132 | next_tokens, logits = self.sampler( 133 | embedding=embedder_weight, 134 | hidden_states=hidden_states, 135 | output_positions=output_positions, 136 | temperatures=temperatures, 137 | top_ps=top_ps, 138 | top_ks=top_ks, 139 | ) 140 | return next_tokens, logits 141 | 142 | def populate_image_embeddings(self, 143 | hidden_states: torch.Tensor, # B x L x model_dim 144 | image_embeddings: torch.Tensor, # (B*N) x U x model_dim 145 | input_token_ids: torch.Tensor, # B x L 146 | image_presence_mask: torch.Tensor, # B x N 147 | ): 148 | batch_size, seq_len, model_dim = hidden_states.shape 149 | # Step 1 of 2: Fetch valid image embeddings 150 | # flatten indices of valid image embeddings 151 | valid_image_embeddings_indices = torch.nonzero(image_presence_mask.flatten(), as_tuple=False).squeeze() 152 | # num_valid_images x model_dim 153 | valid_image_embeddings = image_embeddings.index_select(0, valid_image_embeddings_indices) 154 | 155 | # Step 2 of 2: Replace image embeddings at right places. 156 | image_placeholder_mask = input_token_ids == self.tokenizer.image_token_placeholder_id 157 | image_placeholder_indices = image_placeholder_mask.flatten().nonzero(as_tuple=False).squeeze() 158 | 159 | hidden_states = hidden_states.reshape(-1, self.config.hidden_size) 160 | hidden_states[image_placeholder_indices] = valid_image_embeddings.reshape(-1, self.config.hidden_size) 161 | return hidden_states.reshape(batch_size, seq_len, model_dim).contiguous() 162 | 163 | def create_attention_mask(self, input_ids: torch.Tensor, sequence_length: int): 164 | batch_size = input_ids.shape[0] 165 | causal_mask = torch.tril(torch.ones((batch_size, 1, sequence_length, sequence_length), dtype=torch.bool, device=input_ids.device)) 166 | image_token_mask = input_ids == self.tokenizer.image_token_placeholder_id 167 | # Pad the mask to the left with 0. This is to make sure the boundary 168 | # detection works correctly. Boundary (starting index of image patch) is 169 | # detected when the value changes from 0 to 1. 170 | padded_mask = nn.functional.pad(image_token_mask, (1, 0), value=0) 171 | # Find the boundary (starting index) of the image tokens patch. 172 | boundary = padded_mask[:, 1:] > padded_mask[:, :-1] 173 | # Number the boundary. 
174 | # boundary: 175 | # [[False, False, True, False, False], 176 | # [False, True, False, True, False]] 177 | # numbered_boundary: 178 | # [[0, 0, 1, 1, 1], 179 | # [0, 1, 1, 2, 2]] 180 | numbered_boundary = torch.cumsum(boundary, dim=-1) 181 | 182 | # image_token_mask: 183 | # [[False, False, True, True, False], 184 | # [True, True, False, True, True]] 185 | # numbered_boundary: 186 | # [[0, 0, 1, 1, 1], 187 | # [1, 1, 1, 2, 2]] 188 | # q_block_indices: 189 | # [[0, 0, 1, 1, 0], 190 | # [1, 1, 0, 2, 2]] 191 | q_block_indices = image_token_mask * numbered_boundary 192 | kv_block_indices = q_block_indices 193 | # Test the equality of vertical and horizontal numbered patches 194 | # to create the bidirectional mask. 195 | bidirectional_mask = torch.logical_and( 196 | kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1), 197 | q_block_indices.unsqueeze(-1) > 0, 198 | ) 199 | attention_mask = torch.logical_or(causal_mask, bidirectional_mask.unsqueeze(1)) 200 | # The upper triangular matrix's diagonal is shifted by sliding window size 201 | # before doing logical 'and' with attention mask. This is to make sure the 202 | # local attention is within the sliding window. 203 | local_mask = torch.logical_and( 204 | attention_mask, 205 | torch.triu(torch.ones((1, 1, sequence_length, sequence_length), dtype=torch.bool, device=input_ids.device), diagonal=-(self.config.sliding_window_size-1)) 206 | ) 207 | return attention_mask, local_mask 208 | 209 | def generate( 210 | self, 211 | prompts: Sequence[Sequence[Union[str, Image.Image]]], 212 | device: Any, 213 | output_len: int = 100, 214 | temperature: Union[float, None] = 1.0, 215 | top_p: float = 0.95, 216 | top_k: int = 64, 217 | ) -> Sequence[str]: 218 | """Generates responses for given prompts using Gemma model.""" 219 | # Inference only. 220 | processing_result = gemma3_preprocessor.tokenize_raw_input( 221 | self.tokenizer, prompts, self.config, output_len, device 222 | ) 223 | batch_size = processing_result["batch_size"] 224 | user_input_token_ids = processing_result["user_input_token_ids"] 225 | image_batch = processing_result["image_batch"] 226 | min_prompt_len = processing_result["min_prompt_len"] 227 | max_prompt_len = processing_result["max_prompt_len"] 228 | total_seq_len = processing_result["max_seq_len"] 229 | image_presence_mask = processing_result["image_presence_mask"] 230 | 231 | # Create attention mask. 
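# NOTE (editor's annotation, not part of the original file): the boolean masks
# returned by create_attention_mask() are converted below into additive float
# masks: allowed positions become 0 and blocked positions become the most negative
# finite value of the dtype, so their attention logits vanish after softmax.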
232 | min_dtype = torch.finfo(self.dtype).min 233 | if self.config.sliding_window_size is None: 234 | raise ValueError('gemma 3 model requires sliding_window size') 235 | boolean_mask, local_boolean_mask = self.create_attention_mask(user_input_token_ids, total_seq_len) 236 | mask_tensor = torch.where(boolean_mask, 0, torch.tensor(min_dtype, dtype=torch.float32, device=device)).contiguous() 237 | local_mask_tensor = torch.where(local_boolean_mask, 0, torch.tensor(min_dtype, dtype=torch.float32, device=device)).contiguous() 238 | 239 | kv_caches = [] 240 | for _ in range(self.config.num_hidden_layers): 241 | size = (batch_size, total_seq_len, self.config.num_key_value_heads, 242 | self.config.head_dim) 243 | dtype = self.config.get_dtype() 244 | k_cache = torch.zeros(size=size, dtype=dtype, device=device) 245 | v_cache = torch.zeros(size=size, dtype=dtype, device=device) 246 | kv_caches.append((k_cache, v_cache)) 247 | 248 | input_token_ids_tensor = torch.full((batch_size, min_prompt_len), 249 | self.tokenizer.pad_id, 250 | dtype=torch.int64, device=device) 251 | token_ids_tensor = user_input_token_ids.to(device) 252 | for i in range(batch_size): 253 | p = user_input_token_ids[i] 254 | input_token_ids_tensor[i, :min_prompt_len] = p[:min_prompt_len] 255 | 256 | input_positions_tensor = torch.arange(0, min_prompt_len, dtype=torch.int64, device=device) 257 | prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id 258 | curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) 259 | curr_local_mask_tensor = local_mask_tensor.index_select(2, input_positions_tensor) 260 | output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) 261 | temperatures_tensor = None if not temperature else torch.FloatTensor( 262 | [temperature] * batch_size).to(device) 263 | top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) 264 | top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) 265 | output_index = torch.tensor(min_prompt_len, dtype=torch.int64, device=device) 266 | 267 | # Prefill up to min_prompt_len tokens, then treat other prefill as 268 | # decode and ignore output. 
269 | for i in range(total_seq_len - min_prompt_len): 270 | next_token_ids, _ = self( 271 | input_token_ids=input_token_ids_tensor, 272 | image_patches=image_batch, 273 | image_presence_mask=image_presence_mask, 274 | input_positions=input_positions_tensor, 275 | kv_caches=kv_caches, 276 | mask=curr_mask_tensor, 277 | output_positions=output_positions_tensor, 278 | temperatures=temperatures_tensor, 279 | top_ps=top_ps_tensor, 280 | top_ks=top_ks_tensor, 281 | local_mask=curr_local_mask_tensor, 282 | ) 283 | curr_prompt_mask = prompt_mask_tensor.index_select( 284 | 1, output_index).squeeze(dim=1) 285 | curr_token_ids = token_ids_tensor.index_select( 286 | 1, output_index).squeeze(dim=1) 287 | output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, 288 | next_token_ids).unsqueeze(dim=1) 289 | token_ids_tensor.index_copy_(1, output_index, output_token_ids) 290 | 291 | input_token_ids_tensor = output_token_ids 292 | input_positions_tensor = output_index.unsqueeze(dim=-1) 293 | curr_mask_tensor = mask_tensor.index_select(2, 294 | input_positions_tensor) 295 | curr_local_mask_tensor = local_mask_tensor.index_select( 296 | 2, input_positions_tensor 297 | ) if local_mask_tensor is not None else None 298 | output_positions_tensor = torch.tensor(0, dtype=torch.int64, device=device) 299 | output_index = output_index + 1 300 | image_batch = None 301 | image_presence_mask = None 302 | 303 | # Detokenization. 304 | token_ids = token_ids_tensor.tolist() 305 | results = [] 306 | for i, tokens in enumerate(token_ids): 307 | output = tokens 308 | if self.tokenizer.eos_id in output: 309 | eos_index = output.index(self.tokenizer.eos_id) 310 | output = output[:eos_index] 311 | results.append(self.tokenizer.decode(output)) 312 | 313 | return results 314 | 315 | def load_weights(self, model_path: str): 316 | if os.path.isfile(model_path): 317 | self.load_state_dict( 318 | torch.load( 319 | model_path, mmap=True, weights_only=True, 320 | )['model_state_dict'], 321 | strict=False, 322 | ) 323 | else: 324 | index_path = os.path.join(model_path, 'pytorch_model.bin.index.json') 325 | with open(index_path, "r", encoding="utf-8") as f: 326 | index = json.load(f) 327 | shard_files = list(set(index["weight_map"].values())) 328 | for shard_file in shard_files: 329 | shard_path = os.path.join(model_path, shard_file) 330 | state_dict = torch.load(shard_path, map_location="cpu", weights_only=True) 331 | self.load_state_dict(state_dict, strict=False) 332 | del state_dict # Save memory. 333 | gc.collect() 334 | -------------------------------------------------------------------------------- /gemma/model_xla.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Inference-only Gemma model implementation.""" 16 | 17 | import json 18 | import gc 19 | import os 20 | import re 21 | import torch 22 | from torch import nn 23 | import torch.nn.functional as F 24 | from typing import List, Mapping, Optional, Tuple, Union 25 | 26 | from google3.third_party.open_models_release.gemma_pytorch.gemma import config as gemma_config 27 | from google3.third_party.open_models_release.gemma_pytorch.gemma.xla_model_parallel import ( 28 | ColumnParallelLinear, 29 | ParallelEmbedding, 30 | RowParallelLinear, 31 | reduce_from_model_parallel_region, 32 | scatter_to_model_parallel_region, 33 | ) 34 | 35 | 36 | class Sampler(nn.Module): 37 | 38 | def __init__(self, vocab_size: int, world_size: int, rank: int, 39 | config: gemma_config.GemmaConfig) -> None: 40 | super().__init__() 41 | self.vocab_size = vocab_size 42 | self.world_size = world_size 43 | self.rank = rank 44 | self.config = config 45 | 46 | @torch.no_grad() 47 | def forward( 48 | self, 49 | embedding: torch.Tensor, 50 | hidden_states: torch.Tensor, 51 | output_positions: torch.Tensor, 52 | temperatures: Union[torch.Tensor, None], 53 | top_ps: torch.Tensor, 54 | top_ks: torch.Tensor, 55 | embedding_bias: Optional[torch.Tensor] = None, 56 | ) -> Tuple[torch.Tensor, torch.Tensor]: 57 | # Select the last element for each sequence. 58 | # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size) 59 | hidden_states = hidden_states.index_select( 60 | 1, output_positions).squeeze(dim=1) 61 | 62 | hidden_states_parallel = scatter_to_model_parallel_region( 63 | hidden_states, 64 | groups=None, 65 | world_size=self.world_size, 66 | rank=self.rank) 67 | hidden_states_parallel = torch.matmul(hidden_states_parallel, 68 | embedding.t()) 69 | logits = reduce_from_model_parallel_region( 70 | hidden_states_parallel, 71 | groups=None, 72 | world_size=self.world_size, 73 | rank=self.rank, 74 | ) 75 | if embedding_bias is not None: 76 | logits += embedding_bias 77 | if self.config.final_logit_softcapping is not None: 78 | logits = logits / self.config.final_logit_softcapping 79 | logits = torch.tanh(logits) 80 | logits = logits * self.config.final_logit_softcapping 81 | 82 | if temperatures is None: 83 | return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits 84 | 85 | # Apply temperature scaling. 86 | logits.div_(temperatures.unsqueeze(dim=1)) 87 | 88 | # Calculate probabilities with softmax. 89 | probs = torch.softmax(logits, dim=-1, dtype=torch.float) 90 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 91 | 92 | # Apply top-p, top-k. 93 | probs_sum = torch.cumsum(probs_sort, dim=-1) 94 | top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1) 95 | probs_sort = torch.where(top_ps_mask, 0, probs_sort) 96 | 97 | top_ks_mask = torch.arange(probs_idx.shape[-1], 98 | device=probs_idx.device) 99 | top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) 100 | top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1) 101 | probs_sort = torch.where(top_ks_mask, 0, probs_sort) 102 | 103 | # Re-normalization. 
104 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 105 | probs = torch.gather(probs_sort, 106 | dim=-1, 107 | index=torch.argsort(probs_idx, dim=-1)) 108 | 109 | next_token_ids = torch.multinomial(probs, 110 | num_samples=1, 111 | replacement=True).squeeze(dim=-1) 112 | return next_token_ids, logits 113 | 114 | 115 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 116 | """Precomputes the frequency cis.""" 117 | freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 118 | t = torch.arange(end, device=freqs.device) # type: ignore 119 | freqs = torch.outer(t, freqs).float() # type: ignore 120 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 121 | return freqs_cis 122 | 123 | 124 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: 125 | """Applies the rotary embedding to the query and key tensors.""" 126 | x_ = torch.view_as_complex( 127 | torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), 128 | dim=-1)) 129 | x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) 130 | x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) 131 | x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], 132 | -1).transpose(1, 2) 133 | return x_out 134 | 135 | 136 | class RMSNorm(torch.nn.Module): 137 | 138 | def __init__( 139 | self, 140 | dim: int, 141 | eps: float = 1e-6, 142 | add_unit_offset: bool = True, 143 | ): 144 | super().__init__() 145 | self.eps = eps 146 | self.add_unit_offset = add_unit_offset 147 | self.weight = nn.Parameter(torch.ones(dim)) 148 | 149 | def _norm(self, x): 150 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 151 | 152 | def forward(self, x): 153 | # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16) 154 | # See https://github.com/huggingface/transformers/pull/29402 155 | output = self._norm(x.float()) 156 | if self.add_unit_offset: 157 | output = output * (1 + self.weight.float()) 158 | else: 159 | output = output * self.weight.float() 160 | return output.type_as(x) 161 | 162 | 163 | class GemmaMLP(nn.Module): 164 | 165 | def __init__( 166 | self, 167 | hidden_size: int, 168 | intermediate_size: int, 169 | world_size: int, 170 | rank: int, 171 | quant: bool, 172 | ): 173 | super().__init__() 174 | self.hidden_size = hidden_size 175 | self.intermediate_size = intermediate_size 176 | 177 | def init_method(x): 178 | return x 179 | 180 | self.gate_proj = ColumnParallelLinear( 181 | hidden_size, 182 | intermediate_size, 183 | bias=False, 184 | gather_output=False, 185 | init_method=init_method, 186 | world_size=world_size, 187 | rank=rank, 188 | quant=quant, 189 | ) 190 | 191 | self.up_proj = ColumnParallelLinear( 192 | hidden_size, 193 | intermediate_size, 194 | bias=False, 195 | gather_output=False, 196 | init_method=init_method, 197 | world_size=world_size, 198 | rank=rank, 199 | quant=quant, 200 | ) 201 | 202 | self.down_proj = RowParallelLinear( 203 | intermediate_size, 204 | hidden_size, 205 | bias=False, 206 | input_is_parallel=True, 207 | init_method=init_method, 208 | world_size=world_size, 209 | rank=rank, 210 | quant=quant, 211 | ) 212 | 213 | def forward(self, x): 214 | gate = self.gate_proj(x) 215 | gate = F.gelu(gate, approximate="tanh") 216 | up = self.up_proj(x) 217 | fuse = gate * up 218 | outputs = self.down_proj(fuse) 219 | return outputs 220 | 221 | 222 | class GemmaAttention(nn.Module): 223 | 224 | def __init__( 225 | self, 226 | hidden_size: int, 227 | num_heads: int, 228 | num_kv_heads: int, 229 | 
attn_logit_softcapping: Optional[float], 230 | query_pre_attn_scalar: Optional[int], 231 | head_dim: int, 232 | world_size: int, 233 | rank: int, 234 | quant: bool, 235 | attn_type: gemma_config.AttentionType, 236 | sliding_window_size: Optional[int] = None, 237 | ): 238 | super().__init__() 239 | self.rank = rank 240 | 241 | def init_method(x): 242 | return x 243 | 244 | self.total_num_heads = num_heads 245 | assert self.total_num_heads % world_size == 0 246 | self.num_heads = self.total_num_heads // world_size # head per shard 247 | 248 | if num_kv_heads < world_size: 249 | assert world_size % num_kv_heads == 0 250 | self.total_num_kv_heads = world_size 251 | else: 252 | assert num_kv_heads % world_size == 0 253 | self.total_num_kv_heads = num_kv_heads 254 | self.num_kv_heads = self.total_num_kv_heads // world_size # kv head per shard 255 | 256 | assert self.num_heads % self.num_kv_heads == 0 257 | self.num_queries_per_kv = self.num_heads // self.num_kv_heads 258 | 259 | self.hidden_size = hidden_size 260 | self.head_dim = head_dim 261 | 262 | self.q_size = self.num_heads * self.head_dim 263 | self.kv_size = self.num_kv_heads * self.head_dim 264 | 265 | if query_pre_attn_scalar is not None: 266 | self.scaling = query_pre_attn_scalar**-0.5 267 | else: 268 | self.scaling = self.head_dim**-0.5 269 | 270 | self.qkv_proj = ColumnParallelLinear( 271 | self.hidden_size, 272 | (self.total_num_heads + 2 * self.total_num_kv_heads) * 273 | self.head_dim, 274 | bias=False, 275 | gather_output=False, 276 | init_method=init_method, 277 | world_size=world_size, 278 | rank=rank, 279 | quant=quant, 280 | ) 281 | 282 | self.o_proj = RowParallelLinear( 283 | self.total_num_heads * self.head_dim, 284 | self.hidden_size, 285 | bias=False, 286 | input_is_parallel=True, 287 | init_method=init_method, 288 | world_size=world_size, 289 | rank=rank, 290 | quant=quant, 291 | ) 292 | 293 | self.attn_type = attn_type 294 | self.sliding_window_size = sliding_window_size 295 | self.attn_logit_softcapping = attn_logit_softcapping 296 | 297 | def forward( 298 | self, 299 | hidden_states: torch.Tensor, 300 | freqs_cis: torch.Tensor, 301 | kv_write_indices: torch.Tensor, 302 | kv_cache: Tuple[torch.Tensor, torch.Tensor], 303 | mask: torch.Tensor, 304 | ) -> torch.Tensor: 305 | hidden_states_shape = hidden_states.shape 306 | assert len(hidden_states_shape) == 3 307 | 308 | batch_size, input_len, _ = hidden_states_shape 309 | 310 | qkv = self.qkv_proj(hidden_states) 311 | xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], 312 | dim=-1) 313 | 314 | xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) 315 | xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) 316 | xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) 317 | 318 | # Positional embedding. 319 | xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) 320 | xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) 321 | 322 | # Write new kv cache. 
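        # kv_write_indices holds the absolute positions of the current step, so
        # index_copy_ writes the new keys/values into those slots of the
        # preallocated caches.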
323 | # [batch_size, input_len, n_local_kv_heads, head_dim] 324 | k_cache, v_cache = kv_cache 325 | k_cache.index_copy_(1, kv_write_indices, xk) 326 | v_cache.index_copy_(1, kv_write_indices, xv) 327 | 328 | key = k_cache 329 | value = v_cache 330 | if self.num_kv_heads != self.num_heads: 331 | # [batch_size, max_seq_len, n_local_heads, head_dim] 332 | key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) 333 | value = torch.repeat_interleave(value, 334 | self.num_queries_per_kv, 335 | dim=2) 336 | 337 | # [batch_size, n_local_heads, input_len, head_dim] 338 | q = xq.transpose(1, 2) 339 | # [batch_size, n_local_heads, max_seq_len, head_dim] 340 | k = key.transpose(1, 2) 341 | v = value.transpose(1, 2) 342 | 343 | # [batch_size, n_local_heads, input_len, max_seq_len] 344 | q.mul_(self.scaling) 345 | scores = torch.matmul(q, k.transpose(2, 3)) 346 | if ( 347 | self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING 348 | and self.sliding_window_size is not None 349 | ): 350 | all_ones = torch.ones_like(mask) 351 | sliding_mask = torch.triu( 352 | all_ones, -1 * self.sliding_window_size + 1 353 | ) * torch.tril(all_ones, self.sliding_window_size - 1) 354 | mask = torch.where(sliding_mask == 1, mask, -2.3819763e38) 355 | if self.attn_logit_softcapping is not None: 356 | scores = scores / self.attn_logit_softcapping 357 | scores = torch.tanh(scores) 358 | scores = scores * self.attn_logit_softcapping 359 | scores = scores + mask 360 | scores = F.softmax(scores.float(), dim=-1).type_as(q) 361 | 362 | # [batch_size, n_local_heads, input_len, head_dim] 363 | output = torch.matmul(scores, v) 364 | 365 | # [batch_size, input_len, hidden_dim] 366 | output = (output.transpose(1, 2).contiguous().view( 367 | batch_size, input_len, -1)) 368 | output = self.o_proj(output) 369 | return output 370 | 371 | 372 | class GemmaDecoderLayer(nn.Module): 373 | 374 | def __init__( 375 | self, 376 | config: gemma_config.GemmaConfig, 377 | world_size: int, 378 | rank: int, 379 | ): 380 | super().__init__() 381 | self.rank = rank 382 | self.self_attn = GemmaAttention( 383 | hidden_size=config.hidden_size, 384 | num_heads=config.num_attention_heads, 385 | num_kv_heads=config.num_key_value_heads, 386 | attn_logit_softcapping=config.attn_logit_softcapping, 387 | query_pre_attn_scalar=config.query_pre_attn_scalar, 388 | head_dim=config.head_dim, 389 | world_size=world_size, 390 | rank=rank, 391 | quant=config.quant, 392 | attn_type=gemma_config.AttentionType.GLOBAL, 393 | ) 394 | self.mlp = GemmaMLP( 395 | hidden_size=config.hidden_size, 396 | intermediate_size=config.intermediate_size, 397 | world_size=world_size, 398 | rank=rank, 399 | quant=config.quant, 400 | ) 401 | self.input_layernorm = RMSNorm(config.hidden_size, 402 | eps=config.rms_norm_eps) 403 | self.post_attention_layernorm = RMSNorm(config.hidden_size, 404 | eps=config.rms_norm_eps) 405 | 406 | def forward( 407 | self, 408 | hidden_states: torch.Tensor, 409 | freqs_cis: torch.Tensor, 410 | kv_write_indices: torch.Tensor, 411 | kv_cache: Tuple[torch.Tensor, torch.Tensor], 412 | mask: torch.Tensor, 413 | ) -> torch.Tensor: 414 | # Self Attention 415 | residual = hidden_states 416 | hidden_states = self.input_layernorm(hidden_states) 417 | hidden_states = self.self_attn( 418 | hidden_states=hidden_states, 419 | freqs_cis=freqs_cis, 420 | kv_write_indices=kv_write_indices, 421 | kv_cache=kv_cache, 422 | mask=mask, 423 | ) 424 | hidden_states = residual + hidden_states 425 | 426 | # MLP 427 | residual = hidden_states 428 | hidden_states = 
self.post_attention_layernorm(hidden_states) 429 | hidden_states = self.mlp(hidden_states) 430 | hidden_states = residual + hidden_states 431 | 432 | return hidden_states 433 | 434 | 435 | class Gemma2DecoderLayer(nn.Module): 436 | 437 | def __init__( 438 | self, 439 | config: gemma_config.GemmaConfig, 440 | attn_type: gemma_config.AttentionType, 441 | world_size: int, 442 | rank: int, 443 | ): 444 | super().__init__() 445 | self.rank = rank 446 | self.self_attn = GemmaAttention( 447 | hidden_size=config.hidden_size, 448 | num_heads=config.num_attention_heads, 449 | num_kv_heads=config.num_key_value_heads, 450 | attn_logit_softcapping=config.attn_logit_softcapping, 451 | query_pre_attn_scalar=config.query_pre_attn_scalar, 452 | head_dim=config.head_dim, 453 | world_size=world_size, 454 | rank=rank, 455 | quant=config.quant, 456 | attn_type=attn_type, 457 | sliding_window_size=config.sliding_window_size, 458 | ) 459 | self.mlp = GemmaMLP( 460 | hidden_size=config.hidden_size, 461 | intermediate_size=config.intermediate_size, 462 | world_size=world_size, 463 | rank=rank, 464 | quant=config.quant, 465 | ) 466 | self.input_layernorm = RMSNorm(config.hidden_size, 467 | eps=config.rms_norm_eps) 468 | self.post_attention_layernorm = RMSNorm(config.hidden_size, 469 | eps=config.rms_norm_eps) 470 | self.pre_feedforward_layernorm = ( 471 | RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 472 | if config.use_pre_ffw_norm 473 | else None 474 | ) 475 | self.post_feedforward_layernorm = ( 476 | RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 477 | if config.use_post_ffw_norm 478 | else None 479 | ) 480 | 481 | def forward( 482 | self, 483 | hidden_states: torch.Tensor, 484 | freqs_cis: torch.Tensor, 485 | kv_write_indices: torch.Tensor, 486 | kv_cache: Tuple[torch.Tensor, torch.Tensor], 487 | mask: torch.Tensor, 488 | ) -> torch.Tensor: 489 | # Self Attention 490 | residual = hidden_states 491 | hidden_states = self.input_layernorm(hidden_states) 492 | hidden_states = self.self_attn( 493 | hidden_states=hidden_states, 494 | freqs_cis=freqs_cis, 495 | kv_write_indices=kv_write_indices, 496 | kv_cache=kv_cache, 497 | mask=mask, 498 | ) 499 | hidden_states = self.post_attention_layernorm(hidden_states) 500 | hidden_states = residual + hidden_states 501 | 502 | # MLP 503 | residual = hidden_states 504 | if self.pre_feedforward_layernorm is not None: 505 | hidden_states = self.pre_feedforward_layernorm(hidden_states) 506 | hidden_states = self.mlp(hidden_states) 507 | if self.post_feedforward_layernorm is not None: 508 | hidden_states = self.post_feedforward_layernorm(hidden_states) 509 | hidden_states = residual + hidden_states 510 | 511 | return hidden_states 512 | 513 | 514 | class GemmaModel(nn.Module): 515 | 516 | def __init__( 517 | self, 518 | config: gemma_config.GemmaConfig, 519 | world_size: int, 520 | rank: int 521 | ): 522 | super().__init__() 523 | self.config = config 524 | self.rank = rank 525 | self.vocab_size = config.vocab_size 526 | 527 | self.layers = nn.ModuleList() 528 | for i in range(config.num_hidden_layers): 529 | if config.architecture == gemma_config.Architecture.GEMMA_1: 530 | self.layers.append(GemmaDecoderLayer(config, world_size, rank)) 531 | elif config.architecture == gemma_config.Architecture.GEMMA_2: 532 | attn_type = ( 533 | config.attn_types[i] 534 | if config.attn_types is not None 535 | else gemma_config.AttentionType.GLOBAL 536 | ) 537 | self.layers.append( 538 | Gemma2DecoderLayer(config, attn_type, world_size, rank)) 539 | else: 540 | raise 
ValueError(f'Unknown architecture: {config.architecture}') 541 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 542 | 543 | def forward( 544 | self, 545 | hidden_states: torch.Tensor, 546 | freqs_cis: torch.Tensor, 547 | kv_write_indices: torch.Tensor, 548 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], 549 | mask: torch.Tensor, 550 | ) -> torch.Tensor: 551 | for i in range(len(self.layers)): 552 | layer = self.layers[i] 553 | hidden_states = layer( 554 | hidden_states=hidden_states, 555 | freqs_cis=freqs_cis, 556 | kv_write_indices=kv_write_indices, 557 | kv_cache=kv_caches[i], 558 | mask=mask, 559 | ) 560 | hidden_states = self.norm(hidden_states) 561 | return hidden_states 562 | 563 | 564 | class GemmaForCausalLM(nn.Module): 565 | 566 | def __init__( 567 | self, 568 | config: gemma_config.GemmaConfig, 569 | world_size: int, 570 | rank: int, 571 | device: torch.device, 572 | ): 573 | super().__init__() 574 | self.config = config 575 | self.world_size = world_size 576 | self.rank = rank 577 | self.device = device 578 | 579 | assert config.num_attention_heads % world_size == 0 580 | assert config.hidden_size % config.num_attention_heads == 0 581 | 582 | max_seq_len = config.max_position_embeddings 583 | head_dim = config.head_dim 584 | vocab_size = config.vocab_size 585 | 586 | def init_method(x): 587 | return x 588 | 589 | self.embedder = ParallelEmbedding( 590 | vocab_size, 591 | config.hidden_size, 592 | init_method=init_method, 593 | world_size=world_size, 594 | rank=rank, 595 | quant=config.quant, 596 | ) 597 | self.model = GemmaModel(config, world_size, rank) 598 | self.sampler = Sampler(vocab_size, world_size, rank, config) 599 | 600 | rope_theta = getattr(config, 'rope_theta', 10000) 601 | # [head_dim * 2, ] -> complex -> two dim (real, imaginary) implicitly 602 | freqs_cis = precompute_freqs_cis(head_dim, 603 | max_seq_len * 2, 604 | theta=rope_theta) 605 | self.register_buffer('freqs_cis', freqs_cis) 606 | 607 | @torch.no_grad() 608 | def forward( 609 | self, 610 | input_token_ids: torch.Tensor, 611 | input_positions: torch.Tensor, 612 | kv_write_indices: torch.Tensor, 613 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], 614 | mask: torch.Tensor, 615 | output_positions: torch.Tensor, 616 | temperatures: Union[torch.Tensor, None], 617 | top_ps: torch.Tensor, 618 | top_ks: torch.Tensor, 619 | **kwargs, 620 | ) -> Tuple[torch.Tensor, torch.Tensor]: 621 | freqs_cis = self.freqs_cis.index_select(0, input_positions) 622 | kv_write_indices = input_positions 623 | 624 | hidden_states = self.embedder(input_token_ids) 625 | # Gemma normalizes the embedding by sqrt(hidden_size). 
626 | # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 627 | # See https://github.com/huggingface/transformers/pull/29402 628 | normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) 629 | hidden_states = hidden_states * normalizer 630 | # hidden_states should be [batch_size, input_len, hidden_size] 631 | 632 | hidden_states = self.model( 633 | hidden_states=hidden_states, 634 | freqs_cis=freqs_cis, 635 | kv_write_indices=kv_write_indices, 636 | kv_caches=kv_caches, 637 | mask=mask, 638 | ) 639 | embedder_weight = self.embedder.weight 640 | if self.config.quant: 641 | embedder_weight = ( 642 | embedder_weight * self.embedder.weight_scaler.unsqueeze(-1)) 643 | next_tokens, logits = self.sampler( 644 | embedding=embedder_weight, 645 | hidden_states=hidden_states, 646 | output_positions=output_positions, 647 | temperatures=temperatures, 648 | top_ps=top_ps, 649 | top_ks=top_ks, 650 | ) 651 | return next_tokens, logits 652 | 653 | def _load_weights(self, model_state_dict: Mapping[str, torch.Tensor]): 654 | num_attn_heads = self.config.num_attention_heads 655 | num_kv_heads = self.config.num_key_value_heads 656 | head_dim = self.config.head_dim 657 | hidden_size = self.config.hidden_size 658 | 659 | def split(tensor: torch.Tensor, axis: int) -> torch.Tensor: 660 | axis_len = tensor.shape[axis] 661 | split_len = axis_len // self.world_size 662 | split_start = split_len * self.rank 663 | split_end = split_start + split_len 664 | tensor = torch.moveaxis(tensor, axis, 0) 665 | tensor = tensor[split_start:split_end, ...] 666 | tensor = torch.moveaxis(tensor, 0, axis) 667 | return tensor 668 | 669 | for k, v in model_state_dict.items(): 670 | if k == 'freqs_cis': 671 | continue 672 | if (k == 'model.norm.weight' 673 | or k.endswith('_layernorm.weight') 674 | or k.endswith('weight_scaler')): 675 | pass 676 | elif (k == 'embedder.weight' or re.fullmatch( 677 | r'model.layers.\d+.mlp.down_proj.weight', k)): 678 | v = split(v, 1) 679 | elif (re.fullmatch(r'model.layers.\d+.mlp.gate_proj.weight', k) 680 | or re.fullmatch(r'model.layers.\d+.mlp.up_proj.weight', k)): 681 | v = split(v, 0) 682 | elif re.fullmatch(r'model.layers.\d+.self_attn.qkv_proj.weight', 683 | k): 684 | if num_kv_heads <= num_attn_heads: 685 | # If num_kv_heads > self.world_size, we still want 1 686 | # replica. 687 | num_replicas = max(self.world_size // num_kv_heads, 1) 688 | v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim, 689 | hidden_size) 690 | query = v[:num_attn_heads, ...] 
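                    # Keys and values are replicated num_replicas times so that,
                    # after split(), every model-parallel rank owns at least one
                    # KV head.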
691 | key = v[num_attn_heads:num_attn_heads + num_kv_heads, 692 | ...].repeat(num_replicas, 1, 1) 693 | value = v[-num_kv_heads:, ...].repeat(num_replicas, 1, 1) 694 | v = torch.cat( 695 | (split(query, 0), split(key, 0), split(value, 0)), 696 | dim=0) 697 | else: 698 | v = v.reshape(3, num_attn_heads, head_dim, hidden_size) 699 | v = split(v, 1) 700 | v = v.reshape(-1, hidden_size) 701 | elif re.fullmatch(r'model.layers.\d+.self_attn.o_proj.weight', k): 702 | v = v.reshape(hidden_size, num_attn_heads, head_dim) 703 | v = split(v, 1) 704 | v = v.reshape(hidden_size, -1) 705 | else: 706 | raise ValueError(f'Unrecognized key: {k}') 707 | self.state_dict()[k].copy_(v) 708 | 709 | def load_weights(self, model_path: str): 710 | if os.path.isfile(model_path): 711 | checkpoint = torch.load(model_path, weights_only=True) 712 | model_state_dict = checkpoint['model_state_dict'] 713 | self._load_weights(model_state_dict) 714 | else: 715 | index_path = os.path.join(model_path, 'pytorch_model.bin.index.json') 716 | with open(index_path, "r", encoding="utf-8") as f: 717 | index = json.load(f) 718 | shard_files = list(set(index["weight_map"].values())) 719 | for shard_file in shard_files: 720 | shard_path = os.path.join(model_path, shard_file) 721 | state_dict = torch.load(shard_path, map_location="cpu", weights_only=True) 722 | self._load_weights(state_dict) 723 | del state_dict # Save memory. 724 | gc.collect() 725 | -------------------------------------------------------------------------------- /gemma/xla_model_parallel.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from copy import deepcopy 16 | from dataclasses import dataclass 17 | import os 18 | from typing import Callable, List, Optional, Tuple 19 | 20 | import torch 21 | import torch.ao.quantization.fx._decomposed 22 | import torch.distributed as dist 23 | import torch.distributed._functional_collectives as fc 24 | import torch.distributed.distributed_c10d as c10d 25 | import torch.nn.functional as F 26 | import torch.nn.init as init 27 | from torch.nn.parameter import Parameter 28 | 29 | EPS = torch.finfo(torch.float32).eps 30 | 31 | USE_CUDA = os.environ.get('USE_CUDA', False) 32 | if not USE_CUDA: 33 | import torch_xla.core.xla_model as xm 34 | 35 | TAG = None 36 | RANKSET = None 37 | GROUP_SIZE = None 38 | 39 | 40 | def set_g_group(): 41 | global TAG 42 | global RANKSET 43 | global GROUP_SIZE 44 | 45 | assert USE_CUDA, "This hack is only for PyTorch non-XLA CUDA paths, i.e., eager and inductor." 
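    # fc._expand_group() flattens the default c10d process group into the
    # (TAG, RANKSET, GROUP_SIZE) triple used by the functional collective calls
    # below.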
46 | TAG, RANKSET, GROUP_SIZE = fc._expand_group(c10d._get_default_group()) 47 | 48 | 49 | @dataclass 50 | class TensorQConfig: 51 | dtype: torch.dtype = torch.int8 52 | axis: int = -1 53 | quant_min: int = -128 54 | quant_max: int = 127 55 | symmetric_quant: bool = True 56 | 57 | 58 | def _find_per_channel_min_max(x: torch.Tensor, axis: int): 59 | x_dim = x.size() 60 | new_axis_list = list(range(len(x_dim))) 61 | new_axis_list[axis] = 0 62 | new_axis_list[0] = axis 63 | y = x.permute(new_axis_list) 64 | y = torch.flatten(y, start_dim=1) 65 | return torch.aminmax(y, dim=1) 66 | 67 | 68 | def _find_qparams(x: torch.Tensor, qconfig: TensorQConfig): 69 | # Only support per-channel symmetric quant to int8 now 70 | axis = qconfig.axis 71 | dtype = qconfig.dtype 72 | symmetric_quant = qconfig.symmetric_quant 73 | quant_min = qconfig.quant_min 74 | quant_max = qconfig.quant_max 75 | assert axis >= 0 and axis < len(x.shape) 76 | assert dtype == torch.int8 77 | min_val, max_val = _find_per_channel_min_max(x, axis) 78 | min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) 79 | max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) 80 | scale = torch.ones(min_val_neg.size(), dtype=torch.float32) 81 | if symmetric_quant: 82 | max_val_pos = torch.max(-min_val_neg, max_val_pos) 83 | scale = max_val_pos / (float(quant_max - quant_min) / 2) 84 | eps = torch.zeros_like(scale).fill_(EPS) 85 | scale = torch.max(scale, eps) 86 | return scale, None 87 | else: 88 | assert symmetric_quant 89 | 90 | 91 | def _quantize_to_dtype( 92 | x: torch.Tensor, 93 | qconfig: TensorQConfig, 94 | scale: torch.Tensor, 95 | zero_point: Optional[torch.Tensor] = None, 96 | ): 97 | if zero_point is None: 98 | zero_point = torch.zeros_like(scale) 99 | return torch.ops.quantized_decomposed.quantize_per_channel( 100 | x, 101 | scale, 102 | zero_point, 103 | qconfig.axis, 104 | qconfig.quant_min, 105 | qconfig.quant_max, 106 | qconfig.dtype, 107 | ) 108 | 109 | 110 | def quantize_tensor(x: torch.Tensor, qconfig: TensorQConfig): 111 | scale, zp = _find_qparams(x, qconfig) 112 | x_int = _quantize_to_dtype(x, qconfig, scale, zp) 113 | return x_int, scale, zp 114 | 115 | 116 | def get_model_parallel_rank(): 117 | if USE_CUDA: 118 | return dist.get_rank() 119 | return xm.get_ordinal() 120 | 121 | 122 | def get_model_parallel_world_size(): 123 | if USE_CUDA: 124 | return dist.get_world_size() 125 | return xm.xrt_world_size() 126 | 127 | 128 | def get_model_parallel_group(): 129 | return None 130 | 131 | 132 | class _CopyToModelParallelRegion(torch.autograd.Function): 133 | """Pass the input to the model parallel region.""" 134 | 135 | @staticmethod 136 | def forward(ctx, input_, groups, world_size, rank): # type: ignore 137 | ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank 138 | return input_ 139 | 140 | @staticmethod 141 | def backward(ctx, grad_output): # type: ignore 142 | groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank 143 | return my_reduce(grad_output, groups, world_size, rank) 144 | 145 | 146 | class _ReduceFromModelParallelRegion(torch.autograd.Function): 147 | """All-redcue the input from the model parallel region.""" 148 | 149 | @staticmethod 150 | def forward(ctx, input_, groups, world_size, rank): # type: ignore 151 | return my_reduce(input_, groups, world_size, rank) 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): # type: ignore 155 | return grad_output 156 | 157 | 158 | class _ScatterToModelParallelRegion(torch.autograd.Function): 159 | """Split the input and keep only 
the chunk corresponding to this rank.""" 160 | 161 | @staticmethod 162 | def forward(ctx, input_, groups, world_size, rank): # type: ignore 163 | ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank 164 | return my_split(input_, groups, world_size, rank) 165 | 166 | @staticmethod 167 | def backward(ctx, grad_output): # type: ignore 168 | groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank 169 | return my_gather(grad_output, groups, world_size, rank) 170 | 171 | 172 | class _GatherFromModelParallelRegion(torch.autograd.Function): 173 | """Gather the input from model parallel region and concatenate.""" 174 | 175 | @staticmethod 176 | def forward(ctx, input_, groups, world_size, rank): # type: ignore 177 | ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank 178 | return my_gather(input_, groups, world_size, rank) 179 | 180 | @staticmethod 181 | def backward(ctx, grad_output): # type: ignore 182 | groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank 183 | return my_split(grad_output, groups, world_size, rank) 184 | 185 | 186 | # ----------------- 187 | # Helper functions. 188 | # ----------------- 189 | 190 | 191 | def copy_to_model_parallel_region(input_: torch.Tensor, groups, world_size, 192 | rank) -> torch.Tensor: 193 | return _CopyToModelParallelRegion.apply(input_, groups, world_size, rank) 194 | 195 | 196 | def reduce_from_model_parallel_region(input_: torch.Tensor, groups, world_size, 197 | rank) -> torch.Tensor: 198 | return _ReduceFromModelParallelRegion.apply(input_, groups, world_size, 199 | rank) 200 | 201 | 202 | def scatter_to_model_parallel_region(input_: torch.Tensor, groups, world_size, 203 | rank) -> torch.Tensor: 204 | return _ScatterToModelParallelRegion.apply(input_, groups, world_size, 205 | rank) 206 | 207 | 208 | def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size, 209 | rank) -> torch.Tensor: 210 | return _GatherFromModelParallelRegion.apply(input_, groups, world_size, 211 | rank) 212 | 213 | 214 | def ensure_divisibility(numerator: int, denominator: int) -> None: 215 | """Ensure that numerator is divisible by the denominator.""" 216 | assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator) 217 | 218 | 219 | def divide_and_check_no_remainder(numerator: int, denominator: int) -> int: 220 | """Ensure that numerator is divisible by the denominator and return 221 | the division value.""" 222 | ensure_divisibility(numerator, denominator) 223 | return numerator // denominator 224 | 225 | 226 | def split_tensor_along_last_dim( 227 | tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False 228 | ) -> Tuple[torch.Tensor, ...]: 229 | """Split a tensor along its last dimension. 230 | Arguments: 231 | tensor: input tensor. 232 | num_partitions: number of partitions to split the tensor 233 | contiguous_split_chunks: If True, make each chunk contiguous 234 | in memory. 235 | """ 236 | # Get the size and dimension. 237 | last_dim = tensor.dim() - 1 238 | last_dim_size = divide_and_check_no_remainder(tensor.size()[last_dim], num_partitions) 239 | # Split. 240 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 241 | # Note: torch.split does not create contiguous tensors by default.
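    # For example, splitting a (batch, seq, 4096) tensor into 4 partitions yields
    # four (batch, seq, 1024) views of the same storage.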
242 | if contiguous_split_chunks: 243 | return tuple(chunk.contiguous() for chunk in tensor_list) 244 | 245 | return tensor_list 246 | 247 | # Below copied from fairscale/nn/model_parallel/layers.py 248 | 249 | 250 | def my_reduce(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor: 251 | """All-reduce the the input tensor across model parallel group.""" 252 | # Bypass the function if we are using only 1 GPU. 253 | if world_size == 1: 254 | return input_ 255 | 256 | # All-reduce. 257 | if USE_CUDA: 258 | input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG, 259 | RANKSET, GROUP_SIZE) 260 | else: 261 | input_ = xm.all_reduce(xm.REDUCE_SUM, input_, groups=groups) 262 | 263 | return input_ 264 | 265 | 266 | def my_split(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor: 267 | """Split the tensor along its last dimension and keep the 268 | 269 | corresponding slice. 270 | """ 271 | # Bypass the function if we are using only 1 GPU. 272 | if world_size == 1: 273 | return input_ 274 | 275 | # Split along last dimension. 276 | input_list = split_tensor_along_last_dim(input_, world_size) 277 | 278 | # Note: torch.split does not create contiguous tensors by default. 279 | output = input_list[rank].contiguous() 280 | 281 | return output 282 | 283 | 284 | def my_gather(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor: 285 | """Gather tensors and concatinate along the last dimension.""" 286 | # Bypass the function if we are using only 1 GPU. 287 | if world_size == 1: 288 | return input_ 289 | 290 | if USE_CUDA: 291 | last_dim = input_.dim() - 1 292 | 293 | # Using all_reduce to achieve all_gather as torch.ops.c10d_functional.all_gather_into_tensor 294 | # is buggy in 16 bits. 295 | size = input_.size(last_dim) 296 | padding = [0] * (2 * input_.dim()) 297 | ordinal = rank 298 | left, right = ordinal, world_size - 1 - ordinal 299 | idx = input_.dim() - 1 - last_dim 300 | padding[2 * idx] = left * size 301 | padding[2 * idx + 1] = right * size 302 | output = torch.ops.c10d_functional.all_reduce(F.pad(input_, 303 | padding), "sum", 304 | TAG, RANKSET, GROUP_SIZE) 305 | else: 306 | output = xm.all_gather(input_, dim=-1, groups=groups) 307 | 308 | return output 309 | 310 | 311 | def _initialize_affine_weight( 312 | weight: torch.Tensor, 313 | out_features: int, 314 | in_features: int, 315 | per_partition_size: int, 316 | partition_dim: int, 317 | init_method: Callable[[torch.Tensor], torch.Tensor], 318 | world_size: int, 319 | rank: int, 320 | stride: int = 1, 321 | return_master_weight: bool = False, 322 | ) -> Optional[torch.Tensor]: 323 | """Initialize affine weight for model parallel. 324 | 325 | Build the master weight on all processes and scatter 326 | the relevant chunk. 327 | """ 328 | 329 | # If we only use 1 process for model parallelism, bypass scatter. 
330 | if world_size == 1: 331 | init_method(weight) 332 | if return_master_weight: 333 | return weight 334 | return None 335 | 336 | # Initialize master weight 337 | master_weight = torch.empty(out_features, 338 | in_features, 339 | dtype=weight.dtype, 340 | requires_grad=False) 341 | init_method(master_weight) 342 | 343 | # Split and copy 344 | per_partition_per_stride_size = divide_and_check_no_remainder( 345 | per_partition_size, stride) 346 | weight_list = torch.split(master_weight, 347 | per_partition_per_stride_size, 348 | dim=partition_dim) 349 | my_weight_list = weight_list[rank::world_size] 350 | 351 | with torch.no_grad(): 352 | torch.cat(my_weight_list, dim=partition_dim, out=weight) 353 | if return_master_weight: 354 | return master_weight 355 | return None 356 | 357 | 358 | class ParallelEmbedding(torch.nn.Module): 359 | """Embedding parallelized in the embedding dimension. 360 | 361 | This is mainly adapted from torch.nn.Embedding and all the default 362 | values are kept. 363 | Arguments: 364 | num_embeddings: vocabulary size. 365 | embedding_dim: size of hidden state. 366 | init_method: method to initialize weights. 367 | """ 368 | 369 | def __init__( 370 | self, 371 | num_embeddings: int, 372 | embedding_dim: int, 373 | padding_idx: Optional[int] = None, 374 | max_norm: Optional[float] = None, 375 | norm_type: float = 2.0, 376 | scale_grad_by_freq: bool = False, 377 | sparse: bool = False, 378 | init_method: Callable[[torch.Tensor], 379 | torch.Tensor] = init.xavier_normal_, 380 | keep_master_weight_for_test: bool = False, 381 | world_size: Optional[int] = None, 382 | rank: Optional[int] = None, 383 | groups: Optional[List] = None, 384 | quant: bool = False, 385 | ) -> None: 386 | super(ParallelEmbedding, self).__init__() 387 | 388 | if world_size is None: 389 | self.groups = get_model_parallel_group() 390 | self.world_size = get_model_parallel_world_size() 391 | self.rank = get_model_parallel_rank() 392 | else: 393 | self.groups = groups 394 | self.world_size = world_size 395 | self.rank = rank 396 | 397 | # Keep the input dimensions. 398 | self.num_embeddings = num_embeddings 399 | self.embedding_dim = embedding_dim 400 | self.padding_idx = padding_idx 401 | self.max_norm = max_norm 402 | self.norm_type = norm_type 403 | self.scale_grad_by_freq = scale_grad_by_freq 404 | self.sparse = sparse 405 | self._weight = None 406 | self.quant = quant 407 | # Divide the weight matrix along the embedding dimension. 408 | self.embedding_dim_per_partition = divide_and_check_no_remainder( 409 | self.embedding_dim, self.world_size) 410 | 411 | # Allocate weights. 412 | if quant: 413 | self.weight = Parameter( 414 | torch.empty( 415 | (self.num_embeddings, self.embedding_dim_per_partition), 416 | dtype=torch.int8, 417 | ), 418 | requires_grad=False, 419 | ) 420 | self.weight_scaler = Parameter(torch.Tensor(self.num_embeddings)) 421 | else: 422 | self.weight = Parameter( 423 | torch.Tensor(self.num_embeddings, 424 | self.embedding_dim_per_partition)) 425 | 426 | # And initialize. 
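        # For multi-rank setups, the full (num_embeddings, embedding_dim) master
        # weight is materialized and only this rank's slice along the embedding
        # dimension (partition_dim=1) is copied into self.weight.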
427 | _initialize_affine_weight( 428 | self.weight, 429 | self.num_embeddings, 430 | self.embedding_dim, 431 | self.embedding_dim_per_partition, 432 | 1, 433 | init_method, 434 | self.world_size, 435 | self.rank, 436 | stride=1, 437 | return_master_weight=False, 438 | ) 439 | 440 | def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore 441 | input_parallel = copy_to_model_parallel_region(input_, self.groups, 442 | self.world_size, 443 | self.rank) 444 | # PyTorch eager and inductor do not accept negative values in the input to embedding 445 | # layers. Take the modulus to avoid this error. 446 | if USE_CUDA: 447 | input_parallel = torch.remainder(input_parallel, 448 | self.weight.shape[0]) 449 | weight = self.weight 450 | if self.quant: 451 | weight = weight * self.weight_scaler.unsqueeze(-1) 452 | output_parallel = F.embedding( 453 | input_parallel, 454 | weight, 455 | self.padding_idx, 456 | self.max_norm, 457 | self.norm_type, 458 | self.scale_grad_by_freq, 459 | self.sparse, 460 | ) 461 | output = gather_from_model_parallel_region(output_parallel, 462 | self.groups, 463 | self.world_size, self.rank) 464 | return output 465 | 466 | 467 | class ColumnParallelLinear(torch.nn.Module): 468 | """Linear layer with column parallelism. 469 | 470 | The linear layer is defined as Y = XA + b. A is parallelized along 471 | its second dimension as A = [A_1, ..., A_p]. 472 | 473 | Arguments: 474 | in_features: first dimension of matrix A. 475 | out_features: second dimension of matrix A. 476 | bias: If true, add bias 477 | gather_output: If true, call all-gether on output and make Y available to 478 | all GPUs, otherwise, every GPU will have its output which is Y_i = XA_i 479 | init_method: method to initialize weights. Note that bias is always set to 480 | zero. 481 | stride: For the strided linear layers. 482 | keep_master_weight_for_test: This was added for testing and should be set 483 | to False. It returns the master weights used for initialization. 484 | """ 485 | 486 | def __init__( 487 | self, 488 | in_features: int, 489 | out_features: int, 490 | bias: bool = True, 491 | gather_output: bool = True, 492 | init_method: Callable[[torch.Tensor], 493 | torch.Tensor] = init.xavier_normal_, 494 | stride: int = 1, 495 | keep_master_weight_for_test: bool = False, 496 | world_size: Optional[int] = None, 497 | rank: Optional[int] = None, 498 | groups: Optional[List] = None, 499 | quant: bool = False, 500 | ) -> None: 501 | super(ColumnParallelLinear, self).__init__() 502 | 503 | if world_size is None: 504 | self.groups = get_model_parallel_group() 505 | self.world_size = get_model_parallel_world_size() 506 | self.rank = get_model_parallel_rank() 507 | else: 508 | self.groups = groups 509 | self.world_size = world_size 510 | self.rank = rank 511 | 512 | # Keep input parameters 513 | self.in_features = in_features 514 | self.out_features = out_features 515 | self.gather_output = gather_output 516 | self.quant = quant 517 | # Divide the weight matrix along the last dimension. 518 | self.output_size_per_partition = divide_and_check_no_remainder( 519 | out_features, self.world_size) 520 | 521 | # Parameters. 522 | # Note: torch.nn.functional.linear performs XA^T + b and as a result 523 | # we allocate the transpose. 
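        # Column parallelism: each rank stores out_features // world_size rows of
        # the (out_features, in_features) weight, so its local matmul yields a
        # slice of the output that is optionally all-gathered in forward().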
524 | if quant: 525 | self.weight = Parameter( 526 | torch.empty( 527 | (self.output_size_per_partition, self.in_features), 528 | dtype=torch.int8, 529 | ), 530 | requires_grad=False, 531 | ) 532 | self.weight_scaler = Parameter( 533 | torch.Tensor(self.output_size_per_partition)) 534 | else: 535 | self.weight = Parameter( 536 | torch.Tensor(self.output_size_per_partition, self.in_features)) 537 | if bias: 538 | self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) 539 | # Always initialize bias to zero. 540 | with torch.no_grad(): 541 | self.bias.zero_() 542 | else: 543 | self.register_parameter('bias', None) 544 | 545 | # Initialize weight. 546 | self.master_weight = _initialize_affine_weight( 547 | self.weight, 548 | self.out_features, 549 | self.in_features, 550 | self.output_size_per_partition, 551 | 0, 552 | init_method, 553 | self.world_size, 554 | self.rank, 555 | stride=stride, 556 | return_master_weight=keep_master_weight_for_test, 557 | ) 558 | 559 | def get_master_weight(self) -> torch.Tensor: 560 | return gather_from_model_parallel_region( 561 | self.weight.data.transpose(0, 1), 562 | self.groups, 563 | self.world_size, 564 | self.rank, 565 | ).transpose_(0, 1) 566 | 567 | def set_quantize(self): 568 | assert not self.quant 569 | self.weight = Parameter( 570 | torch.empty((self.output_size_per_partition, self.in_features), 571 | dtype=torch.int8), 572 | requires_grad=False, 573 | ) 574 | self.weight_scaler = Parameter( 575 | torch.Tensor(self.output_size_per_partition)) 576 | self.quant = True 577 | 578 | def quantize(self): 579 | assert not self.quant 580 | fp_w = deepcopy(self.weight.data) 581 | orig_dtype = fp_w.dtype 582 | fp_w = fp_w.to(torch.float32) 583 | self.weight = Parameter( 584 | torch.empty((self.output_size_per_partition, self.in_features), 585 | dtype=torch.int8), 586 | requires_grad=False, 587 | ) 588 | self.weight_scaler = Parameter( 589 | torch.Tensor(self.output_size_per_partition)) 590 | qconfig = TensorQConfig(axis=0) 591 | self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig) 592 | self.weight_scaler.data = scale.to(orig_dtype) 593 | self.quant = True 594 | 595 | def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore 596 | # Set up backprop all-reduce. 597 | input_parallel = copy_to_model_parallel_region(input_, self.groups, 598 | self.world_size, 599 | self.rank) 600 | # Matrix multiply. 601 | if self.quant and USE_CUDA: 602 | # GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear. 603 | scaled_weight = self.weight * self.weight_scaler 604 | output_parallel = F.linear(input_parallel, scaled_weight, self.bias) 605 | elif self.quant: 606 | output_parallel = F.linear(input_parallel, self.weight, self.bias) 607 | output_parallel = output_parallel * self.weight_scaler 608 | else: 609 | output_parallel = F.linear(input_parallel, self.weight, self.bias) 610 | if self.gather_output: 611 | # All-gather across the partitions. 612 | output = gather_from_model_parallel_region(output_parallel, 613 | self.groups, 614 | self.world_size, 615 | self.rank) 616 | else: 617 | output = output_parallel 618 | return output 619 | 620 | 621 | class RowParallelLinear(torch.nn.Module): 622 | """Linear layer with row parallelism. 623 | 624 | The linear layer is defined as Y = XA + b. A is parallelized along 625 | its first dimension and X along its second dimension as: 626 | - - 627 | | A_1 | 628 | | . | 629 | A = | . | X = [X_1, ..., X_p] 630 | | . 
| 631 | | A_p | 632 | - - 633 | Arguments: 634 | in_features: first dimension of matrix A. 635 | out_features: second dimension of matrix A. 636 | bias: If true, add bias. Note that bias is not parallelized. 637 | input_is_parallel: If true, we assume that the input is already split 638 | across the GPUs and we do not split again. 639 | init_method: method to initialize weights. Note that bias is always set to 640 | zero. 641 | stride: For the strided linear layers. 642 | keep_master_weight_for_test: This was added for testing and should be set 643 | to False. It returns the master weights used for initialization. 644 | """ 645 | 646 | def __init__( 647 | self, 648 | in_features: int, 649 | out_features: int, 650 | bias: bool = True, 651 | input_is_parallel: bool = False, 652 | init_method: Callable[[torch.Tensor], 653 | torch.Tensor] = init.xavier_normal_, 654 | stride: int = 1, 655 | keep_master_weight_for_test: bool = False, 656 | world_size: Optional[int] = None, 657 | rank: Optional[int] = None, 658 | groups: Optional[List] = None, 659 | quant: bool = False, 660 | ): 661 | super(RowParallelLinear, self).__init__() 662 | 663 | if world_size is None: 664 | self.groups = get_model_parallel_group() 665 | self.world_size = get_model_parallel_world_size() 666 | self.rank = get_model_parallel_rank() 667 | else: 668 | self.groups = groups 669 | self.world_size = world_size 670 | self.rank = rank 671 | 672 | # Keep input parameters 673 | self.in_features = in_features 674 | self.out_features = out_features 675 | self.input_is_parallel = input_is_parallel 676 | self.quant = quant 677 | # Divide the weight matrix along the last dimension. 678 | self.input_size_per_partition = divide_and_check_no_remainder( 679 | in_features, self.world_size) 680 | 681 | # Parameters. 682 | # Note: torch.nn.functional.linear performs XA^T + b and as a result 683 | # we allocate the transpose. 684 | if quant: 685 | self.weight = Parameter( 686 | torch.empty( 687 | (self.out_features, self.input_size_per_partition), 688 | dtype=torch.int8, 689 | ), 690 | requires_grad=False, 691 | ) 692 | self.weight_scaler = Parameter(torch.Tensor(self.out_features)) 693 | else: 694 | self.weight = Parameter( 695 | torch.Tensor(self.out_features, self.input_size_per_partition)) 696 | if bias: 697 | self.bias = Parameter(torch.Tensor(self.out_features)) 698 | # Always initialize bias to zero. 699 | with torch.no_grad(): 700 | self.bias.zero_() 701 | else: 702 | self.register_parameter('bias', None) 703 | 704 | # Initialize weight. 
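        # Row parallelism: each rank stores in_features // world_size columns of
        # the weight, so each rank computes only a partial product that is
        # all-reduced in forward().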
705 | self.master_weight = _initialize_affine_weight( 706 | self.weight, 707 | self.out_features, 708 | self.in_features, 709 | self.input_size_per_partition, 710 | 1, 711 | init_method, 712 | self.world_size, 713 | self.rank, 714 | stride=stride, 715 | return_master_weight=keep_master_weight_for_test, 716 | ) 717 | 718 | def get_master_weight(self) -> torch.Tensor: 719 | return gather_from_model_parallel_region(self.weight.data, self.groups, 720 | self.world_size, self.rank) 721 | 722 | def set_quantize(self): 723 | assert not self.quant 724 | self.weight = Parameter( 725 | torch.empty((self.out_features, self.input_size_per_partition), 726 | dtype=torch.int8), 727 | requires_grad=False, 728 | ) 729 | self.weight_scaler = Parameter(torch.Tensor(self.out_features)) 730 | self.quant = True 731 | 732 | def quantize(self): 733 | assert not self.quant 734 | fp_w = deepcopy(self.weight.data) 735 | orig_dtype = fp_w.dtype 736 | fp_w = fp_w.to(torch.float32) 737 | self.weight = Parameter( 738 | torch.empty((self.out_features, self.input_size_per_partition), 739 | dtype=torch.int8), 740 | requires_grad=False, 741 | ) 742 | self.weight_scaler = Parameter(torch.Tensor(self.out_features)) 743 | qconfig = TensorQConfig(axis=0) 744 | self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig) 745 | self.weight_scaler.data = scale.to(orig_dtype) 746 | self.quant = True 747 | 748 | def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore 749 | # Set up backprop all-reduce. 750 | if self.input_is_parallel: 751 | input_parallel = input_ 752 | else: 753 | input_parallel = scatter_to_model_parallel_region( 754 | input_, self.groups, self.world_size, self.rank) 755 | # Matrix multiply. 756 | if self.quant and USE_CUDA: 757 | # GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear. 758 | scaled_weight = self.weight * self.weight_scaler 759 | output_parallel = F.linear(input_parallel, scaled_weight, self.bias) 760 | elif self.quant: 761 | output_parallel = F.linear(input_parallel, self.weight, self.bias) 762 | output_parallel = output_parallel * self.weight_scaler 763 | else: 764 | output_parallel = F.linear(input_parallel, self.weight) 765 | # All-reduce across all the partitions. 766 | output_ = reduce_from_model_parallel_region(output_parallel, 767 | self.groups, 768 | self.world_size, self.rank) 769 | if self.bias is not None: 770 | output = output_ + self.bias 771 | else: 772 | output = output_ 773 | return output 774 | -------------------------------------------------------------------------------- /gemma/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Inference-only Gemma model implementation.""" 15 | 16 | import json 17 | import gc 18 | import os 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | from typing import Any, List, Optional, Sequence, Tuple, Union, Mapping 23 | 24 | from gemma import config as gemma_config 25 | from gemma import tokenizer 26 | 27 | 28 | class Sampler(nn.Module): 29 | 30 | def __init__(self, vocab_size: int, config: gemma_config.GemmaConfig): 31 | super().__init__() 32 | self.vocab_size = vocab_size 33 | self.config = config 34 | 35 | @torch.no_grad() 36 | def forward( 37 | self, 38 | embedding: torch.Tensor, 39 | hidden_states: torch.Tensor, 40 | output_positions: torch.Tensor, 41 | temperatures: Union[torch.Tensor, None], 42 | top_ps: torch.Tensor, 43 | top_ks: torch.Tensor, 44 | embedding_bias: Optional[torch.Tensor] = None, 45 | ) -> Tuple[torch.Tensor, torch.Tensor]: 46 | # Select the last element for each sequence. 47 | # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size) 48 | hidden_states = hidden_states.index_select( 49 | 1, output_positions).squeeze(dim=1) 50 | logits = torch.matmul(hidden_states, embedding.t()) 51 | if embedding_bias is not None: 52 | logits += embedding_bias 53 | if self.config.final_logit_softcapping is not None: 54 | logits = logits / self.config.final_logit_softcapping 55 | logits = torch.tanh(logits) 56 | logits = logits * self.config.final_logit_softcapping 57 | 58 | if temperatures is None: 59 | return torch.argmax(logits, dim=-1).squeeze(dim=-1), logits 60 | 61 | # Apply temperature scaling. 62 | logits.div_(temperatures.unsqueeze(dim=1)) 63 | 64 | # Calculate probabilities with softmax. 65 | probs = torch.softmax(logits, dim=-1, dtype=torch.float) 66 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 67 | 68 | # Apply top-p, top-k. 69 | probs_sum = torch.cumsum(probs_sort, dim=-1) 70 | top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1) 71 | probs_sort = torch.where(top_ps_mask, 0, probs_sort) 72 | 73 | top_ks_mask = torch.arange(probs_idx.shape[-1], 74 | device=probs_idx.device) 75 | top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1) 76 | top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1) 77 | probs_sort = torch.where(top_ks_mask, 0, probs_sort) 78 | 79 | # Re-normalization. 
80 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 81 | probs = torch.gather(probs_sort, 82 | dim=-1, 83 | index=torch.argsort(probs_idx, dim=-1)) 84 | 85 | next_token_ids = torch.multinomial(probs, 86 | num_samples=1, 87 | replacement=True).squeeze(dim=-1) 88 | return next_token_ids, logits 89 | 90 | 91 | def precompute_freqs_cis(dim: int, 92 | end: int, 93 | theta: float = 10000.0, 94 | rope_scaling_factor:int = 1) -> torch.Tensor: 95 | """Precomputes the frequency cis.""" 96 | freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 97 | freqs = freqs/rope_scaling_factor 98 | t = torch.arange(end, device=freqs.device) 99 | freqs = torch.outer(t, freqs).float() 100 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 101 | return freqs_cis 102 | 103 | 104 | def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: 105 | """Applies the rotary embedding to the query and key tensors.""" 106 | x_ = torch.view_as_complex( 107 | torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), 108 | dim=-1)) 109 | x_out = torch.view_as_real(x_ * freqs_cis).type_as(x) 110 | x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2) 111 | x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2], 112 | -1).transpose(1, 2) 113 | return x_out 114 | 115 | 116 | class Linear(nn.Module): 117 | 118 | def __init__(self, in_features: int, out_features: int, quant: bool): 119 | super().__init__() 120 | if quant: 121 | self.weight = nn.Parameter( 122 | torch.empty((out_features, in_features), dtype=torch.int8), 123 | requires_grad=False, 124 | ) 125 | self.weight_scaler = nn.Parameter(torch.Tensor(out_features)) 126 | else: 127 | self.weight = nn.Parameter( 128 | torch.empty((out_features, in_features)), 129 | requires_grad=False, 130 | ) 131 | self.quant = quant 132 | 133 | def forward(self, x): 134 | weight = self.weight 135 | if self.quant: 136 | weight = weight * self.weight_scaler.unsqueeze(-1) 137 | output = F.linear(x, weight) 138 | return output 139 | 140 | 141 | class Embedding(nn.Module): 142 | 143 | def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool): 144 | super().__init__() 145 | if quant: 146 | self.weight = nn.Parameter( 147 | torch.empty((num_embeddings, embedding_dim), dtype=torch.int8), 148 | requires_grad=False, 149 | ) 150 | self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings)) 151 | else: 152 | self.weight = nn.Parameter( 153 | torch.empty((num_embeddings, embedding_dim)), 154 | requires_grad=False, 155 | ) 156 | self.quant = quant 157 | 158 | def forward(self, x): 159 | weight = self.weight 160 | if self.quant: 161 | weight = weight * self.weight_scaler.unsqueeze(-1) 162 | output = F.embedding(x, weight) 163 | return output 164 | 165 | 166 | class RMSNorm(torch.nn.Module): 167 | 168 | def __init__( 169 | self, 170 | dim: int, 171 | eps: float = 1e-6, 172 | add_unit_offset: bool = True, 173 | ): 174 | super().__init__() 175 | self.eps = eps 176 | self.add_unit_offset = add_unit_offset 177 | self.weight = nn.Parameter(torch.zeros(dim)) 178 | 179 | def _norm(self, x): 180 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 181 | 182 | def forward(self, x): 183 | # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16) 184 | # See https://github.com/huggingface/transformers/pull/29402 185 | output = self._norm(x.float()) 186 | if self.add_unit_offset: 187 | output = output * (1 + self.weight.float()) 188 | else: 189 | output = output * self.weight.float() 190 
| return output.type_as(x) 191 | 192 | 193 | class GemmaMLP(nn.Module): 194 | 195 | def __init__( 196 | self, 197 | hidden_size: int, 198 | intermediate_size: int, 199 | quant: bool, 200 | ): 201 | super().__init__() 202 | self.gate_proj = Linear(hidden_size, intermediate_size, quant) 203 | self.up_proj = Linear(hidden_size, intermediate_size, quant) 204 | self.down_proj = Linear(intermediate_size, hidden_size, quant) 205 | 206 | def forward(self, x): 207 | gate = self.gate_proj(x) 208 | gate = F.gelu(gate, approximate="tanh") 209 | up = self.up_proj(x) 210 | fuse = gate * up 211 | outputs = self.down_proj(fuse) 212 | return outputs 213 | 214 | 215 | class GemmaAttention(nn.Module): 216 | 217 | def __init__( 218 | self, 219 | config: gemma_config.GemmaConfig, 220 | attn_type: gemma_config.AttentionType, 221 | ): 222 | super().__init__() 223 | 224 | self.num_heads = config.num_attention_heads 225 | self.num_kv_heads = config.num_key_value_heads 226 | 227 | assert self.num_heads % self.num_kv_heads == 0 228 | self.num_queries_per_kv = self.num_heads // self.num_kv_heads 229 | 230 | self.hidden_size = config.hidden_size 231 | self.head_dim = config.head_dim 232 | 233 | self.q_size = self.num_heads * self.head_dim 234 | self.kv_size = self.num_kv_heads * self.head_dim 235 | 236 | if config.query_pre_attn_scalar is not None: 237 | self.scaling = config.query_pre_attn_scalar**-0.5 238 | else: 239 | self.scaling = self.head_dim**-0.5 240 | 241 | self.qkv_proj = Linear( 242 | self.hidden_size, 243 | (self.num_heads + 2 * self.num_kv_heads) * self.head_dim, 244 | quant=config.quant) 245 | self.o_proj = Linear( 246 | self.num_heads * self.head_dim, self.hidden_size, quant=config.quant 247 | ) 248 | self.query_norm = ( 249 | RMSNorm(self.head_dim, eps=config.rms_norm_eps) 250 | if config.use_qk_norm 251 | else None 252 | ) 253 | self.key_norm = ( 254 | RMSNorm(self.head_dim, eps=config.rms_norm_eps) 255 | if config.use_qk_norm 256 | else None 257 | ) 258 | 259 | self.attn_type = attn_type 260 | self.sliding_window_size = config.sliding_window_size 261 | self.attn_logit_softcapping = config.attn_logit_softcapping 262 | 263 | def forward( 264 | self, 265 | hidden_states: torch.Tensor, 266 | freqs_cis: torch.Tensor, 267 | kv_write_indices: torch.Tensor, 268 | kv_cache: Tuple[torch.Tensor, torch.Tensor], 269 | mask: torch.Tensor, 270 | local_mask: torch.Tensor = None, 271 | ) -> torch.Tensor: 272 | hidden_states_shape = hidden_states.shape 273 | assert len(hidden_states_shape) == 3 274 | 275 | batch_size, input_len, _ = hidden_states_shape 276 | 277 | qkv = self.qkv_proj(hidden_states) 278 | xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size], 279 | dim=-1) 280 | 281 | xq = xq.view(batch_size, -1, self.num_heads, self.head_dim) 282 | xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim) 283 | xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim) 284 | 285 | if self.query_norm is not None and self.key_norm is not None: 286 | xq = self.query_norm(xq) 287 | xk = self.key_norm(xk) 288 | 289 | # Positional embedding. 290 | xq = apply_rotary_emb(xq, freqs_cis=freqs_cis) 291 | xk = apply_rotary_emb(xk, freqs_cis=freqs_cis) 292 | 293 | # Write new kv cache. 
294 | # [batch_size, input_len, n_local_kv_heads, head_dim] 295 | k_cache, v_cache = kv_cache 296 | k_cache.index_copy_(1, kv_write_indices, xk) 297 | v_cache.index_copy_(1, kv_write_indices, xv) 298 | 299 | key = k_cache 300 | value = v_cache 301 | if self.num_kv_heads != self.num_heads: 302 | # [batch_size, max_seq_len, n_local_heads, head_dim] 303 | key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2) 304 | value = torch.repeat_interleave(value, 305 | self.num_queries_per_kv, 306 | dim=2) 307 | 308 | # [batch_size, n_local_heads, input_len, head_dim] 309 | q = xq.transpose(1, 2) 310 | # [batch_size, n_local_heads, max_seq_len, head_dim] 311 | k = key.transpose(1, 2) 312 | v = value.transpose(1, 2) 313 | 314 | # [batch_size, n_local_heads, input_len, max_seq_len] 315 | q.mul_(self.scaling) 316 | scores = torch.matmul(q, k.transpose(2, 3)) 317 | if ( 318 | self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING 319 | and self.sliding_window_size is not None 320 | and local_mask is not None 321 | ): 322 | mask = local_mask 323 | 324 | if self.attn_logit_softcapping is not None: 325 | scores = scores / self.attn_logit_softcapping 326 | scores = torch.tanh(scores) 327 | scores = scores * self.attn_logit_softcapping 328 | 329 | scores = scores + mask 330 | scores = F.softmax(scores.float(), dim=-1).type_as(q) 331 | 332 | # [batch_size, n_local_heads, input_len, head_dim] 333 | output = torch.matmul(scores, v) 334 | 335 | # [batch_size, input_len, hidden_dim] 336 | output = (output.transpose(1, 2).contiguous().view( 337 | batch_size, input_len, -1)) 338 | output = self.o_proj(output) 339 | return output 340 | 341 | 342 | class GemmaDecoderLayer(nn.Module): 343 | 344 | def __init__( 345 | self, 346 | config: gemma_config.GemmaConfig, 347 | ): 348 | super().__init__() 349 | self.attn_type = gemma_config.AttentionType.GLOBAL 350 | self.self_attn = GemmaAttention( 351 | config=config, 352 | attn_type=self.attn_type) 353 | self.mlp = GemmaMLP( 354 | hidden_size=config.hidden_size, 355 | intermediate_size=config.intermediate_size, 356 | quant=config.quant, 357 | ) 358 | self.input_layernorm = RMSNorm(config.hidden_size, 359 | eps=config.rms_norm_eps) 360 | self.post_attention_layernorm = RMSNorm(config.hidden_size, 361 | eps=config.rms_norm_eps) 362 | 363 | # TODO(imayank): Decouple Gemma versions into separate files. 
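# Gemma (v1) decoder block: pre-norm self-attention followed by a pre-norm
# MLP, each wrapped in a residual connection. This layer type always uses
# GLOBAL attention; local_mask is accepted for interface parity but unused.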
364 | def forward( 365 | self, 366 | hidden_states: torch.Tensor, 367 | freqs_cis: torch.Tensor, 368 | kv_write_indices: torch.Tensor, 369 | kv_cache: Tuple[torch.Tensor, torch.Tensor], 370 | mask: torch.Tensor, 371 | local_mask: torch.Tensor, 372 | ) -> torch.Tensor: 373 | # Self Attention 374 | residual = hidden_states 375 | hidden_states = self.input_layernorm(hidden_states) 376 | hidden_states = self.self_attn( 377 | hidden_states=hidden_states, 378 | freqs_cis=freqs_cis, 379 | kv_write_indices=kv_write_indices, 380 | kv_cache=kv_cache, 381 | mask=mask, 382 | ) 383 | hidden_states = residual + hidden_states 384 | 385 | # MLP 386 | residual = hidden_states 387 | hidden_states = self.post_attention_layernorm(hidden_states) 388 | hidden_states = self.mlp(hidden_states) 389 | hidden_states = residual + hidden_states 390 | 391 | return hidden_states 392 | 393 | 394 | class Gemma2DecoderLayer(nn.Module): 395 | def __init__( 396 | self, 397 | config: gemma_config.GemmaConfig, 398 | attn_type: gemma_config.AttentionType, 399 | ): 400 | super().__init__() 401 | self.attn_type = attn_type 402 | self.self_attn = GemmaAttention( 403 | config=config, 404 | attn_type=self.attn_type, 405 | ) 406 | self.mlp = GemmaMLP( 407 | hidden_size=config.hidden_size, 408 | intermediate_size=config.intermediate_size, 409 | quant=config.quant, 410 | ) 411 | self.input_layernorm = RMSNorm(config.hidden_size, 412 | eps=config.rms_norm_eps) 413 | self.post_attention_layernorm = RMSNorm(config.hidden_size, 414 | eps=config.rms_norm_eps) 415 | self.pre_feedforward_layernorm = ( 416 | RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 417 | if config.use_pre_ffw_norm 418 | else None 419 | ) 420 | self.post_feedforward_layernorm = ( 421 | RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 422 | if config.use_post_ffw_norm 423 | else None 424 | ) 425 | 426 | def forward( 427 | self, 428 | hidden_states: torch.Tensor, 429 | freqs_cis: torch.Tensor, 430 | kv_write_indices: torch.Tensor, 431 | kv_cache: Tuple[torch.Tensor, torch.Tensor], 432 | mask: torch.Tensor, 433 | local_mask: torch.Tensor, 434 | ) -> torch.Tensor: 435 | # Self Attention 436 | residual = hidden_states 437 | hidden_states = self.input_layernorm(hidden_states) 438 | hidden_states = self.self_attn( 439 | hidden_states=hidden_states, 440 | freqs_cis=freqs_cis, 441 | kv_write_indices=kv_write_indices, 442 | kv_cache=kv_cache, 443 | mask=mask, 444 | local_mask=local_mask, 445 | ) 446 | hidden_states = self.post_attention_layernorm(hidden_states) 447 | hidden_states = residual + hidden_states 448 | 449 | # MLP 450 | residual = hidden_states 451 | if self.pre_feedforward_layernorm is not None: 452 | hidden_states = self.pre_feedforward_layernorm(hidden_states) 453 | hidden_states = self.mlp(hidden_states) 454 | if self.post_feedforward_layernorm is not None: 455 | hidden_states = self.post_feedforward_layernorm(hidden_states) 456 | hidden_states = residual + hidden_states 457 | 458 | return hidden_states 459 | 460 | 461 | class GemmaModel(nn.Module): 462 | 463 | def __init__(self, config: gemma_config.GemmaConfig): 464 | super().__init__() 465 | self.config = config 466 | self.vocab_size = config.vocab_size 467 | 468 | self.layers = nn.ModuleList() 469 | for i in range(config.num_hidden_layers): 470 | if config.architecture == gemma_config.Architecture.GEMMA_1: 471 | self.layers.append(GemmaDecoderLayer(config)) 472 | elif config.architecture in ( 473 | gemma_config.Architecture.GEMMA_2, 474 | gemma_config.Architecture.GEMMA_3, 475 | ): 476 | attn_type = ( 
477 | config.attn_types[i % len(config.attn_types)] 478 | if config.attn_types is not None 479 | else gemma_config.AttentionType.GLOBAL 480 | ) 481 | self.layers.append(Gemma2DecoderLayer(config, attn_type)) 482 | else: 483 | raise ValueError(f'Unknown architecture: {config.architecture}') 484 | self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 485 | 486 | def forward( 487 | self, 488 | hidden_states: torch.Tensor, 489 | freqs_cis: Mapping[gemma_config.AttentionType, torch.Tensor], 490 | kv_write_indices: torch.Tensor, 491 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], 492 | mask: torch.Tensor, 493 | local_mask: torch.Tensor, 494 | ) -> torch.Tensor: 495 | for i in range(len(self.layers)): 496 | layer = self.layers[i] 497 | hidden_states = layer( 498 | hidden_states=hidden_states, 499 | freqs_cis=freqs_cis.get(layer.attn_type), 500 | kv_write_indices=kv_write_indices, 501 | kv_cache=kv_caches[i], 502 | mask=mask, 503 | local_mask=local_mask, 504 | ) 505 | hidden_states = self.norm(hidden_states) 506 | return hidden_states 507 | 508 | 509 | class GemmaForCausalLM(nn.Module): 510 | 511 | def __init__( 512 | self, 513 | config: gemma_config.GemmaConfig, 514 | ): 515 | super().__init__() 516 | self.config = config 517 | assert config.hidden_size % config.num_attention_heads == 0 518 | 519 | max_seq_len = config.max_position_embeddings 520 | head_dim = config.head_dim 521 | vocab_size = config.vocab_size 522 | 523 | self.tokenizer = tokenizer.Tokenizer(config.tokenizer) 524 | self.embedder = Embedding(vocab_size, config.hidden_size, config.quant) 525 | self.model = GemmaModel(config) 526 | self.sampler = Sampler(vocab_size, config) 527 | 528 | # Pre-compute rotary embedding table. 529 | if config.architecture == gemma_config.Architecture.GEMMA_3: 530 | if config.rope_wave_length is None: 531 | raise ValueError('rope_wave_length must be provided for Gemma3.') 532 | 533 | rope_lengths = config.rope_wave_length 534 | defaults = { 535 | gemma_config.AttentionType.LOCAL_SLIDING: 10_000, 536 | gemma_config.AttentionType.GLOBAL: 10_000, 537 | } 538 | 539 | for attn_type, name in [ 540 | (gemma_config.AttentionType.LOCAL_SLIDING, 'local_freqs_cis'), 541 | (gemma_config.AttentionType.GLOBAL, 'global_freqs_cis'), 542 | ]: 543 | theta = rope_lengths.get( 544 | attn_type, defaults[attn_type] 545 | ) 546 | self._register_freqs_cis(name, head_dim, max_seq_len, theta=theta) 547 | 548 | else: 549 | self._register_freqs_cis('freqs_cis', head_dim, max_seq_len) 550 | 551 | def _register_freqs_cis( 552 | self, name: str, head_dim: int, max_seq_len: int, theta: int = 10_000 553 | ): 554 | self.register_buffer( 555 | name, precompute_freqs_cis(head_dim, max_seq_len * 2, theta=theta) 556 | ) 557 | 558 | @torch.no_grad() 559 | def forward( 560 | self, 561 | input_token_ids: torch.Tensor, 562 | input_positions: torch.Tensor, 563 | kv_write_indices: torch.Tensor, 564 | kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], 565 | mask: torch.Tensor, 566 | output_positions: torch.Tensor, 567 | temperatures: Union[torch.Tensor, None], 568 | top_ps: torch.Tensor, 569 | top_ks: torch.Tensor, 570 | local_mask: torch.Tensor | None = None, 571 | **kwargs, 572 | ) -> Tuple[torch.Tensor, torch.Tensor]: 573 | freqs_cis = {} 574 | 575 | if self.config.architecture == gemma_config.Architecture.GEMMA_3: 576 | freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = ( 577 | self.local_freqs_cis.index_select(0, input_positions) 578 | ) 579 | freqs_cis[gemma_config.AttentionType.GLOBAL] = ( 580 | 
self.global_freqs_cis.index_select(0, input_positions) 581 | ) 582 | else: 583 | freqs_cis[gemma_config.AttentionType.LOCAL_SLIDING] = ( 584 | self.freqs_cis.index_select(0, input_positions) 585 | ) 586 | freqs_cis[gemma_config.AttentionType.GLOBAL] = ( 587 | self.freqs_cis.index_select(0, input_positions) 588 | ) 589 | 590 | kv_write_indices = input_positions 591 | 592 | # [batch_size, input_len, hidden_size] 593 | hidden_states = self.embedder(input_token_ids) 594 | # Gemma normalizes the embedding by sqrt(hidden_size). 595 | # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 596 | # See https://github.com/huggingface/transformers/pull/29402 597 | normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype, device=hidden_states.device) 598 | hidden_states = hidden_states * normalizer 599 | 600 | hidden_states = self.model( 601 | hidden_states=hidden_states, 602 | freqs_cis=freqs_cis, 603 | kv_write_indices=kv_write_indices, 604 | kv_caches=kv_caches, 605 | mask=mask, 606 | local_mask=local_mask, 607 | ) 608 | embedder_weight = self.embedder.weight 609 | if self.config.quant: 610 | embedder_weight = ( 611 | embedder_weight * self.embedder.weight_scaler.unsqueeze(-1)) 612 | next_tokens, logits = self.sampler( 613 | embedding=embedder_weight, 614 | hidden_states=hidden_states, 615 | output_positions=output_positions, 616 | temperatures=temperatures, 617 | top_ps=top_ps, 618 | top_ks=top_ks, 619 | ) 620 | return next_tokens, logits 621 | 622 | def generate( 623 | self, 624 | prompts: Union[str, Sequence[str]], 625 | device: Any, 626 | output_len: int = 100, 627 | temperature: Union[float, None] = 1.0, 628 | top_p: float = 0.95, 629 | top_k: int = 64, 630 | ) -> Union[str, Sequence[str]]: 631 | """Generates responses for given prompts using Gemma model.""" 632 | # If a single prompt is provided, treat it as a batch of 1. 
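# Overview of the setup below: tokenize the prompts and size the run to
# max_prompt_len + output_len, allocate per-layer KV caches of shape
# [batch_size, max_seq_len, num_kv_heads, head_dim], then build the causal
# mask (plus a sliding-window local mask when sliding_window_size is set)
# that every subsequent forward call slices via index_select.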
633 | is_str_prompt = isinstance(prompts, str) 634 | if is_str_prompt: 635 | prompts = [prompts] 636 | 637 | batch_size = len(prompts) 638 | prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts] 639 | min_prompt_len = min(len(p) for p in prompt_tokens) 640 | max_prompt_len = max(len(p) for p in prompt_tokens) 641 | max_seq_len = max_prompt_len + output_len 642 | assert max_seq_len <= self.config.max_position_embeddings 643 | 644 | # build KV caches 645 | kv_caches = [] 646 | for _ in range(self.config.num_hidden_layers): 647 | size = (batch_size, max_seq_len, self.config.num_key_value_heads, 648 | self.config.head_dim) 649 | dtype = self.config.get_dtype() 650 | k_cache = torch.zeros(size=size, dtype=dtype, device=device) 651 | v_cache = torch.zeros(size=size, dtype=dtype, device=device) 652 | kv_caches.append((k_cache, v_cache)) 653 | 654 | # prepare inputs 655 | token_ids_tensor = torch.full((batch_size, max_seq_len), 656 | self.tokenizer.pad_id, dtype=torch.int64) 657 | input_token_ids_tensor = torch.full((batch_size, min_prompt_len), 658 | self.tokenizer.pad_id, 659 | dtype=torch.int64) 660 | for i, p in enumerate(prompt_tokens): 661 | token_ids_tensor[i, :len(p)] = torch.tensor(p) 662 | input_token_ids_tensor[i, :min_prompt_len] = torch.tensor( 663 | p[:min_prompt_len]) 664 | token_ids_tensor = token_ids_tensor.to(device) 665 | input_token_ids_tensor = input_token_ids_tensor.to(device) 666 | prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id 667 | input_positions_tensor = torch.arange(0, min_prompt_len, 668 | dtype=torch.int64).to(device) 669 | mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len), 670 | -2.3819763e38).to(torch.float) 671 | mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device) 672 | local_mask_tensor = mask_tensor + torch.tril( 673 | torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38, device=device), 674 | diagonal=-self.config.sliding_window_size, 675 | ) if self.config.sliding_window_size else None 676 | curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor) 677 | curr_local_mask_tensor = local_mask_tensor.index_select( 678 | 2, input_positions_tensor 679 | ) if local_mask_tensor is not None else None 680 | output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device) 681 | temperatures_tensor = None if not temperature else torch.FloatTensor( 682 | [temperature] * batch_size).to(device) 683 | top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device) 684 | top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device) 685 | output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to( 686 | device) 687 | 688 | # Prefill up to min_prompt_len tokens, then treat other prefill as 689 | # decode and ignore output. 
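# Each iteration runs one forward pass and samples one token per sequence.
# Until output_index moves past a prompt's true length, the sampled token is
# discarded in favor of the known prompt token (the torch.where on
# prompt_mask_tensor), which is how the remaining prefill is folded into the
# decode loop.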
690 | for i in range(max_seq_len - min_prompt_len): 691 | next_token_ids, _ = self( 692 | input_token_ids=input_token_ids_tensor, 693 | input_positions=input_positions_tensor, 694 | kv_write_indices=None, 695 | kv_caches=kv_caches, 696 | mask=curr_mask_tensor, 697 | output_positions=output_positions_tensor, 698 | temperatures=temperatures_tensor, 699 | top_ps=top_ps_tensor, 700 | top_ks=top_ks_tensor, 701 | local_mask=curr_local_mask_tensor, 702 | ) 703 | 704 | curr_prompt_mask = prompt_mask_tensor.index_select( 705 | 1, output_index).squeeze(dim=1) 706 | curr_token_ids = token_ids_tensor.index_select( 707 | 1, output_index).squeeze(dim=1) 708 | output_token_ids = torch.where(curr_prompt_mask, curr_token_ids, 709 | next_token_ids).unsqueeze(dim=1) 710 | token_ids_tensor.index_copy_(1, output_index, output_token_ids) 711 | 712 | input_token_ids_tensor = output_token_ids 713 | input_positions_tensor = output_index.unsqueeze(dim=-1) 714 | curr_mask_tensor = mask_tensor.index_select(2, 715 | input_positions_tensor) 716 | curr_local_mask_tensor = local_mask_tensor.index_select( 717 | 2, input_positions_tensor 718 | ) if local_mask_tensor is not None else None 719 | output_positions_tensor = torch.tensor(0, dtype=torch.int64).to( 720 | device) 721 | output_index = output_index + 1 722 | 723 | # Detokenization. 724 | token_ids = token_ids_tensor.tolist() 725 | results = [] 726 | for i, tokens in enumerate(token_ids): 727 | trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) 728 | + output_len] 729 | if self.tokenizer.eos_id in trimmed_output: 730 | eos_index = trimmed_output.index(self.tokenizer.eos_id) 731 | trimmed_output = trimmed_output[:eos_index] 732 | results.append(self.tokenizer.decode(trimmed_output)) 733 | 734 | # If a string was provided as input, return a string as output. 735 | return results[0] if is_str_prompt else results 736 | 737 | def load_weights(self, model_path: str): 738 | if os.path.isfile(model_path): 739 | self.load_state_dict( 740 | torch.load( 741 | model_path, mmap=True, weights_only=True, 742 | )['model_state_dict'], 743 | strict=False, 744 | ) 745 | else: 746 | index_path = os.path.join(model_path, 'pytorch_model.bin.index.json') 747 | with open(index_path, "r", encoding="utf-8") as f: 748 | index = json.load(f) 749 | shard_files = list(set(index["weight_map"].values())) 750 | for shard_file in shard_files: 751 | shard_path = os.path.join(model_path, shard_file) 752 | state_dict = torch.load(shard_path, map_location="cpu", weights_only=True) 753 | self.load_state_dict(state_dict, strict=False) 754 | del state_dict # Save memory. 755 | gc.collect() 756 | --------------------------------------------------------------------------------
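A minimal usage sketch for the GemmaForCausalLM class defined in gemma/model.py above. The config helper name (get_model_config), the '2b' variant string, and the checkpoint path are assumptions and may differ from the actual gemma.config API and local file layout; the tokenizer path matches the tokenizer/ directory of this repository.

import torch

from gemma import config as gemma_config
from gemma import model as gemma_model

# Build a model config; get_model_config('2b') is an assumed helper name.
model_config = gemma_config.get_model_config('2b')
model_config.tokenizer = 'tokenizer/tokenizer.model'
model_config.quant = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_dtype(model_config.get_dtype())

model = gemma_model.GemmaForCausalLM(model_config).to(device).eval()
model.load_weights('checkpoints/gemma-2b.ckpt')  # hypothetical checkpoint path

# generate() accepts a single string or a list of strings and returns the
# decoded completion(s).
print(model.generate('Tell me a fun fact about the Moon.',
                     device=device, output_len=64))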